Pytorch统计参数网络参数数量方式

PyTorch统计参数:网络参数数量方式

在深度学习中,了解模型的参数数量是非常重要的。在PyTorch中,我们可以使用torchsummary模块来统计模型的参数数量。本文将介绍两种不同的方式来统计模型的参数数量。

1. 使用torchsummary模块

torchsummary模块是一个用于打印PyTorch模型摘要的工具。它可以打印出模型的输入形状、输出形状和参数数量等信息。以下是使用torchsummary模块来统计模型参数数量的示例代码。

!pip install torchsummary

import torch
import torch.nn as nn
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
summary(model, (3, 32, 32))

输出结果如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 6, 28, 28]             456
         MaxPool2d-2            [-1, 6, 14, 14]               0
            Conv2d-3           [-1, 16, 10, 10]           2,416
         MaxPool2d-4             [-1, 16, 5, 5]               0
            Linear-5                  [-1, 120]          48,120
            Linear-6                   [-1, 84]          10,164
            Linear-7                   [-1, 10]             850
================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
----------------------------------------------------------------

2. 自定义函数

我们也可以自定义一个函数来统计模型的参数数量。以下是使用自定义函数来统计模型参数数量的示例代码。

import torch
import torch.nn as nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
print(f'Total number of parameters: {count_parameters(model)}')

输出结果如下:

Total number of parameters: 62006

这两种方式都可以用来统计模型的参数数量,选择哪种方式取决于个人喜好和使用场景。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch统计参数网络参数数量方式 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Linux下PyTorch安装的方法是什么

    这篇文章主要讲解了“Linux下PyTorch安装的方法是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Linux下PyTorch安装的方法是什么”吧! 一、PyTorch简介 PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。2017年1月,由Facebook…

    2023年4月5日
    00
  • pyTorch——(1)基本数据类型

    @ 目录 torch.tensor() torch.FloatTensor() torch.empty() torch.zeros() torch.ones() torch.eye() torch.randn() torch.rand() torch.randint() torch.full() torch.normal() torch.arange() t…

    2023年4月8日
    00
  • pytorch自定义初始化权重的方法

    PyTorch是一个流行的深度学习框架,它提供了许多内置的初始化权重方法。但是,有时候我们需要自定义初始化权重方法来更好地适应我们的模型。在本攻略中,我们将介绍如何自定义初始化权重方法。 方法1:使用nn.Module的apply()函数 我们可以使用nn.Module的apply()函数来自定义初始化权重方法。apply()函数可以递归地遍历整个模型,并对…

    PyTorch 2023年5月15日
    00
  • PyTorch中torch.utils.data.Dataset的介绍与实战

    在PyTorch中,torch.utils.data.Dataset是一个抽象类,用于表示数据集。本文将介绍torch.utils.data.Dataset的基本用法,并提供两个示例说明。 基本用法 要使用torch.utils.data.Dataset,您需要创建一个自定义数据集类,并实现以下两个方法: len():返回数据集的大小。 getitem():…

    PyTorch 2023年5月15日
    00
  • pytorch中使用tensorboard

    完整代码见我的githubpytorch handbook官方介绍tensorboard官方turtorial 显示图片 cat_img = Image.open(‘cat.jpg’) transform = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), tr…

    PyTorch 2023年4月8日
    00
  • Pytorch学习:CIFAR-10分类

    最近在学习Pytorch,先照着别人的代码过一遍,加油!!!   加载数据集 # 加载数据集及预处理 import torchvision as tv import torchvision.transforms as transforms from torchvision.transforms import ToPILImage import torch a…

    PyTorch 2023年4月6日
    00
  • python PyTorch参数初始化和Finetune

    PyTorch参数初始化和Finetune攻略 在深度学习中,参数初始化和Finetune是非常重要的步骤,它们可以影响模型的收敛速度和性能。本文将详细介绍PyTorch中参数初始化和Finetune的实现方法,并提供两个示例说明。 1. 参数初始化方法 在PyTorch中,可以使用torch.nn.init模块中的函数来初始化模型的参数。以下是一些常用的初…

    PyTorch 2023年5月15日
    00
  • 教你如何在Pytorch中使用TensorBoard

    在PyTorch中,我们可以使用TensorBoard来可视化模型的训练过程和结果。TensorBoard是TensorFlow的一个可视化工具,但是它也可以与PyTorch一起使用。下面是一个简单的示例,演示如何在PyTorch中使用TensorBoard。 示例一:使用TensorBoard可视化损失函数 在这个示例中,我们将使用TensorBoard来…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部