在pytorch中查看可训练参数的例子

如果你想查看在PyTorch中定义的可训练参数(Trainable Parameters),可以使用PyTorch中的nn.Module类提供的parameters()方法,该方法返回一个生成器对象,可以遍历模型中的所有可训练参数。

下面是一个示例代码,展示了如何使用parameters()方法查看可训练参数。

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = MyModel()
print(model)

# 打印模型中的可训练参数
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

上面的代码创建了一个包含两个卷积和池化层以及一个全连接层的简单CNN模型。我们使用named_parameters()方法打印了模型中所有可训练参数的名称和形状。运行上述代码,会输出以下内容:

MyModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=2048, out_features=10, bias=True)
)
conv1.weight torch.Size([16, 3, 3, 3])
conv1.bias torch.Size([16])
conv2.weight torch.Size([32, 16, 3, 3])
conv2.bias torch.Size([32])
fc.weight torch.Size([10, 2048])
fc.bias torch.Size([10])

如上所示,参数名称由模型中每个层的名称和类型组成,以及参数的类型(例如权重和偏置)。

另外一个查看可训练参数的方式是使用state_dict()方法,该方法将可训练参数保存到一个字典中。下面是一个示例代码:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

model = MyModel()
print(model)

# 打印模型中的可训练参数
state_dict = model.state_dict()
for key in state_dict:
    print(key, state_dict[key].shape)

该代码定义了一个包含两个全连接层的简单神经网络模型,并使用state_dict()方法打印了模型中的全部可训练参数名称和形状。

运行上述代码,会输出以下内容:

MyModel(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=5, bias=True)
)
fc1.weight torch.Size([20, 10])
fc1.bias torch.Size([20])
fc2.weight torch.Size([5, 20])
fc2.bias torch.Size([5])

如上所示,使用state_dict()方法可以得到键值对形式的可训练参数名称和形状,其中参数名称与模型中每个层的名称相对应。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:在pytorch中查看可训练参数的例子 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • mysql-8.0.15-winx64 解压版安装教程及退出的三种方式

    以下是“mysql-8.0.15-winx64解压版安装教程及退出的三种方式”的完整攻略: 安装前的准备 下载mysql-8.0.15-winx64解压版,下载地址:https://dev.mysql.com/downloads/mysql/。 解压下载好的zip文件,将解压出的文件夹移动到目标安装位置。 安装步骤 确认文件夹的路径,如 D:\mysql-8…

    人工智能概览 2023年5月25日
    00
  • 在tensorflow中设置保存checkpoint的最大数量实例

    在TensorFlow中,保存Checkpoint是非常重要的一项功能,这能帮助我们在训练模型时保存模型的参数,以便在需要时恢复参数。但是,我们不想保存无限多的Checkpoint文件,因为不仅浪费存储空间,还会降低性能。因此,我们需要设置保存最大数量的Checkpoint文件,当超过设定的数量时,则自动删除最旧的Checkpoint文件。本攻略详细讲解在T…

    人工智能概论 2023年5月24日
    00
  • django创建最简单HTML页面跳转方法

    下面是详细的攻略: 确认Django环境已经搭建 在使用Django创建HTML页面跳转之前,需要确保Django环境已经搭建成功。 第一步:创建Django项目 创建Django项目,使用命令行工具,执行以下命令: django-admin startproject projectname 其中,projectname为你的项目名称。 第二步: 创建Dja…

    人工智能概论 2023年5月25日
    00
  • 树莓派(python)与arduino串口通信的详细步骤

    下面是树莓派和Arduino串口通信的详细步骤。 准备工作 首先,需要准备以下材料和工具: 树莓派和Arduino Uno开发板 USB数据线 Arduino IDE软件 Python编程环境 确定通信端口 将Arduino连接到树莓派,打开终端输入以下命令,查看Arduino的串口号: ls /dev/ttyACM* 如果连了多个串口设备,可能会显示多个串…

    人工智能概览 2023年5月25日
    00
  • Mongodb聚合函数count、distinct、group如何实现数据聚合操作

    MongoDB是目前流行的非关系型数据库之一,在数据聚合操作中,使用其提供的聚合函数可以轻松实现各种聚合操作。本文将详细讲解 MongoDB 聚合函数 count、distinct、group 的使用方法,包括语法和示例。 count函数 count函数用于统计集合中满足条件的文档数量。语法如下: db.collection.count(query, opt…

    人工智能概论 2023年5月25日
    00
  • 浅谈使用java实现阿里云消息队列简单封装

    使用Java实现阿里云消息队列简单封装,需要注意以下几个步骤: 第一步:引入依赖 在pom.xml文件中添加如下依赖: <dependency> <groupId>com.aliyun.openservices</groupId> <artifactId>ons-client</artifactId&gt…

    人工智能概览 2023年5月25日
    00
  • MongoDB设计方法以及技巧示例详解

    MongoDB设计方法以及技巧示例详解 在使用 MongoDB 设计数据库时,需要考虑如何设置数据结构和索引,以及如何查询和优化查询。下面将介绍一些 MongoDB 的设计方法和技巧,并且提供两个示例帮助理解。 MongoDB 数据结构设计 MongoDB 是一种文档型数据库,数据以 BSON 格式存储。设计数据结构时,需要考虑如何组织数据和关联数据。 设计…

    人工智能概览 2023年5月25日
    00
  • java使用tess4j进行图片文字识别功能

    以下是使用tess4j进行图片文字识别功能的完整攻略: 简介 Tess4J是基于Tesseract OCR引擎的Java OCR API。它支持OCR引擎的多种语言,并提供了易于使用的API。使用Tess4J可以方便地实现图片文字识别的功能。 步骤 步骤一:引入Tess4J的Jar包 在项目中引入Tess4J的Jar包,可以去官网(https://sourc…

    人工智能概论 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部