在pytorch中查看可训练参数的例子

如果你想查看在PyTorch中定义的可训练参数(Trainable Parameters),可以使用PyTorch中的nn.Module类提供的parameters()方法,该方法返回一个生成器对象,可以遍历模型中的所有可训练参数。

下面是一个示例代码,展示了如何使用parameters()方法查看可训练参数。

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = MyModel()
print(model)

# 打印模型中的可训练参数
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

上面的代码创建了一个包含两个卷积和池化层以及一个全连接层的简单CNN模型。我们使用named_parameters()方法打印了模型中所有可训练参数的名称和形状。运行上述代码,会输出以下内容:

MyModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=2048, out_features=10, bias=True)
)
conv1.weight torch.Size([16, 3, 3, 3])
conv1.bias torch.Size([16])
conv2.weight torch.Size([32, 16, 3, 3])
conv2.bias torch.Size([32])
fc.weight torch.Size([10, 2048])
fc.bias torch.Size([10])

如上所示,参数名称由模型中每个层的名称和类型组成,以及参数的类型(例如权重和偏置)。

另外一个查看可训练参数的方式是使用state_dict()方法,该方法将可训练参数保存到一个字典中。下面是一个示例代码:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

model = MyModel()
print(model)

# 打印模型中的可训练参数
state_dict = model.state_dict()
for key in state_dict:
    print(key, state_dict[key].shape)

该代码定义了一个包含两个全连接层的简单神经网络模型,并使用state_dict()方法打印了模型中的全部可训练参数名称和形状。

运行上述代码,会输出以下内容:

MyModel(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=5, bias=True)
)
fc1.weight torch.Size([20, 10])
fc1.bias torch.Size([20])
fc2.weight torch.Size([5, 20])
fc2.bias torch.Size([5])

如上所示,使用state_dict()方法可以得到键值对形式的可训练参数名称和形状,其中参数名称与模型中每个层的名称相对应。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:在pytorch中查看可训练参数的例子 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Python函数实现学员管理系统

    Python函数实现学员管理系统的完整攻略包括以下步骤: 设计数据结构在实现学员管理系统之前,我们要先设计好存储学员信息的数据结构。比较常用的数据结构有列表(list)、元组(tuple)、字典(dict)等。在这里,我们选择使用字典存储学员信息,例如:{‘id’: 1001, ‘name’: ‘Tom’, ‘age’: 18, ‘gender’: ‘mal…

    人工智能概览 2023年5月25日
    00
  • vivo X Note值得入手吗 vivo X Note体验评测

    vivo X Note值得入手吗 – vivo X Note体验评测 介绍 vivo X Note是vivo推出的一款中高端手机。以下是对该手机的详细评测分析,希望能够帮到想要购买该手机的用户。 外观 vivo X Note采用了前后双玻璃+金属中框的设计,整体感觉非常高端。屏幕正面采用了2.5D曲面玻璃,机身背面也有着类似的设计。该机整体颜色采用亮黑色,非…

    人工智能概览 2023年5月25日
    00
  • spring cloud 使用Zuul 实现API网关服务问题

    下面是关于“Spring Cloud 使用Zuul 实现API网关服务”的完整攻略: 一、什么是API网关服务 API网关服务是一个在客户端和服务器端之间的中间层,用于处理请求、转发流量、筛选和管理API。与其他架构设计不同,API网关服务提供了单一入口点,使得请求能够通过一个位置路由到不同的服务。 二、为什么使用API网关服务 简化了客户端和后端服务的交互…

    人工智能概览 2023年5月25日
    00
  • 详解django中url路由配置及渲染方式

    我们来详细讲解“详解django中url路由配置及渲染方式”的攻略。 1. 什么是URL路由 URL路由(也叫网址路由、URL映射)是指将URL请求映射到相应的处理器上,从而在Web服务器和应用程序之间建立一一对应关系。 在Django中,URL路由是实现模块化开发的核心,通过定义URL映射规则,将请求分发到对应的处理器方法中,并返回响应数据。URL路由是D…

    人工智能概览 2023年5月25日
    00
  • flask和vue前后端分离项目部署的示例代码

    下面我将为你详细讲解Flask和Vue前后端分离项目部署的攻略,分为以下几个步骤: 1. 开发前的准备工作 在开始开发前,我们需要准备好以下工具和环境: Python环境。推荐安装Python 3.6以上的版本。 Node.js环境。推荐安装8.11以上的版本。 Vue CLI。可使用npm install -g @vue/cli命令进行安装。 MySQL数…

    人工智能概论 2023年5月25日
    00
  • win7平台快速安装、启动mongodb的方法

    以下是“win7平台快速安装、启动mongodb的方法”的完整攻略: 安装 MongoDB 访问 MongoDB 官网(https://www.mongodb.com/download-center/community)下载 64 位 Windows 版本的 MSI 文件。 运行 MSI 文件,按照提示进行安装。在安装目标目录选择时,建议选择一个简单的目录,…

    人工智能概论 2023年5月25日
    00
  • Django框架基础模板标签与filter使用方法详解

    我将为你详细讲解“Django框架基础模板标签与filter使用方法详解”的完整攻略。 模板标签 Django框架中的模板标签是创建模板时使用的一种方便的方式,它们可以扩展模板语言的功能。以下是在Django模板中使用常见的标签: if标签 判断条件是否成立,并执行相应操作。示例代码如下: {% if name == ‘john’ %} Hi John! {…

    人工智能概论 2023年5月25日
    00
  • 使用Docker-compose离线部署Django应用的方法

    下面是使用Docker-compose离线部署Django应用的完整攻略: 1. 安装Docker和Docker-compose Docker是一种容器化技术,可以轻松创建、部署和运行应用程序。Docker-compose则可以用来管理多个Docker容器的部署。 在开始部署之前,需要先安装Docker和Docker-compose。安装方法可以参考Dock…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部