如果你想查看在PyTorch中定义的可训练参数(Trainable Parameters),可以使用PyTorch中的nn.Module
类提供的parameters()
方法,该方法返回一个生成器对象,可以遍历模型中的所有可训练参数。
下面是一个示例代码,展示了如何使用parameters()
方法查看可训练参数。
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc = nn.Linear(32 * 8 * 8, 10)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
model = MyModel()
print(model)
# 打印模型中的可训练参数
for name, param in model.named_parameters():
if param.requires_grad:
print(name, param.shape)
上面的代码创建了一个包含两个卷积和池化层以及一个全连接层的简单CNN模型。我们使用named_parameters()
方法打印了模型中所有可训练参数的名称和形状。运行上述代码,会输出以下内容:
MyModel(
(conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc): Linear(in_features=2048, out_features=10, bias=True)
)
conv1.weight torch.Size([16, 3, 3, 3])
conv1.bias torch.Size([16])
conv2.weight torch.Size([32, 16, 3, 3])
conv2.bias torch.Size([32])
fc.weight torch.Size([10, 2048])
fc.bias torch.Size([10])
如上所示,参数名称由模型中每个层的名称和类型组成,以及参数的类型(例如权重和偏置)。
另外一个查看可训练参数的方式是使用state_dict()
方法,该方法将可训练参数保存到一个字典中。下面是一个示例代码:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 5)
def forward(self, x):
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
model = MyModel()
print(model)
# 打印模型中的可训练参数
state_dict = model.state_dict()
for key in state_dict:
print(key, state_dict[key].shape)
该代码定义了一个包含两个全连接层的简单神经网络模型,并使用state_dict()
方法打印了模型中的全部可训练参数名称和形状。
运行上述代码,会输出以下内容:
MyModel(
(fc1): Linear(in_features=10, out_features=20, bias=True)
(fc2): Linear(in_features=20, out_features=5, bias=True)
)
fc1.weight torch.Size([20, 10])
fc1.bias torch.Size([20])
fc2.weight torch.Size([5, 20])
fc2.bias torch.Size([5])
如上所示,使用state_dict()
方法可以得到键值对形式的可训练参数名称和形状,其中参数名称与模型中每个层的名称相对应。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:在pytorch中查看可训练参数的例子 - Python技术站