PyTorch统计参数:网络参数数量方式
在深度学习中,了解模型的参数数量是非常重要的。在PyTorch中,我们可以使用torchsummary
模块来统计模型的参数数量。本文将介绍两种不同的方式来统计模型的参数数量。
1. 使用torchsummary模块
torchsummary
模块是一个用于打印PyTorch模型摘要的工具。它可以打印出模型的输入形状、输出形状和参数数量等信息。以下是使用torchsummary
模块来统计模型参数数量的示例代码。
!pip install torchsummary
import torch
import torch.nn as nn
from torchsummary import summary
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net()
summary(model, (3, 32, 32))
输出结果如下:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 6, 28, 28] 456
MaxPool2d-2 [-1, 6, 14, 14] 0
Conv2d-3 [-1, 16, 10, 10] 2,416
MaxPool2d-4 [-1, 16, 5, 5] 0
Linear-5 [-1, 120] 48,120
Linear-6 [-1, 84] 10,164
Linear-7 [-1, 10] 850
================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
----------------------------------------------------------------
2. 自定义函数
我们也可以自定义一个函数来统计模型的参数数量。以下是使用自定义函数来统计模型参数数量的示例代码。
import torch
import torch.nn as nn
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net()
print(f'Total number of parameters: {count_parameters(model)}')
输出结果如下:
Total number of parameters: 62006
这两种方式都可以用来统计模型的参数数量,选择哪种方式取决于个人喜好和使用场景。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch统计参数网络参数数量方式 - Python技术站