pytorch查看网络参数显存占用量等操作

下面是针对pytorch查看网络参数显存占用量等操作的完整攻略。

1. 查看网络参数总量

为了查看神经网络的参数总量,我们可以使用 torchsummary 库中的 summary 函数。该函数可以打印出我们定义的模型结构及其参数量等相关信息。

首先,我们需要在命令行中使用 pip 安装 torchsummary 库:

pip install torchsummary

然后,我们可以在pytorch代码中导入该库,使用以下代码查看神经网络的参数总量:

import torch
from torchsummary import summary

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例并查看参数总量
model = Net()
summary(model, (3, 32, 32))

其中,我们首先定义了一个神经网络模型 Net,然后在 main 函数中创建了一个 Net 的实例 model。最后,使用 summary 函数查看神经网络模型的参数总量。

在这个示例中,我们的神经网络有 61,706 个可训练的参数。

2. 查看网络的显存占用量

在训练神经网络时,我们需要时刻关注神经网络的显存占用量,以免显存溢出或训练速度过慢。下面我们介绍两种查看神经网络显存占用量的方法。

方法一:pytorch自带工具

pytorch自带诊断工具 torch.cuda.memory_summary 可以用来查看系统显存状态和分配情况,使用方法如下:

import torch

# 定义张量
a = torch.randn(1024, 1024).cuda()

# 打印显存分配情况
print(torch.cuda.memory_summary())

运行上述代码后,我们可以得到类似下面的输出:

|    GPU     |    累计使用 |
|:---------:|:-------------:|
|  Quadro P4000  |  3554 MB      |

这表示当前显存使用了 3554 MB,其中包括了我们之前创建的 1024*1024 的张量占用的显存。

方法二:torch.autograd.profiler

torch.autograd.profiler 是 pytorch 自带的性能评测工具,可以用于查看神经网络的显存使用情况以及计算时间、函数调用顺序等。

import torch

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例
model = Net()

# 创建输入张量,将张量移动到GPU上
inputs = torch.randn(1, 3, 32, 32).cuda()

# 使用Profiler查看网络的显存占用情况
with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as profile:
    model(inputs)

print(profile)

运行上述代码后,我们可以得到类似下面的输出:

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                     CPU time        CUDA time        Calls           CPU total       CUDA total      Input size     
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
_convolution                            0.200us          1.757ms          2               0.000us          3.513ms          [1, 6, 28, 28]  
_max_pool2d                             0.025us          0.120us          2               0.050us          0.240us          [1, 6, 28, 28]  
_convolution                            0.032us          0.272ms          2               0.000us          0.544ms          [1, 16, 10, 10] 
_max_pool2d                             0.004us          0.052us          2               0.008us          0.104us          [1, 16, 10, 10] 
_view                                   0.065us          0.120us          2               0.130us          0.240us          [1, 400]        
_linear                                 0.064us          0.073ms          3               0.000us          0.219ms          [1, 120]        
_relu                                   0.025us          0.002ms          3               0.000us          0.005ms          [1, 120]        
_linear                                 0.020us          0.022ms          3               0.000us          0.067ms          [1, 84]         
_relu                                   0.011us          0.001ms          3               0.000us          0.004ms          [1, 84]         
_linear                                 0.012us          0.002ms          1               0.012us          0.002ms          [1, 10]         
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 0.518ms
Self CUDA time total: 4.059ms

其中,CUDA time 就是每个操作在 GPU 上使用的显存总量,可以通过求和得到模型整体的显存占用情况。在这个示例中,我们的模型总共占用了 4.059 MB 的显存。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch查看网络参数显存占用量等操作 - Python技术站

(0)
上一篇 2023年5月13日
下一篇 2023年5月13日

相关文章

  • python中numpy.zeros(np.zeros)的使用方法

    以下是关于“Python中Numpy.zeros(np.zeros)的使用方法”的完整攻略。 背景 在Python中,Numpy是一个常用的科学计算库,提供了许多方便的函数和工具。其中,numpy.zeros函数用来创建指定形状的全0数组。本攻略将详细介绍numpy.zeros函数的使用方法。 numpy.zeros函数的基本概念 numpy.zeros函数…

    python 2023年5月14日
    00
  • Python实现分段线性插值

    Python实现分段线性插值 分段线性插值是一种常见的插值方法,可以用于在给定的数据点之间估计未知的函数值。在本攻略中,我们将介绍如何使用Python实现分段线性插值,并提供两个示例说明。 问题描述 在某些情况下,我们需要在给定的数据点之间估计未知的函数值。分段线性插值是一种常见的插值方法,可以用于实现这个目标。如何使用Python实现分段线性插值呢?在本攻…

    python 2023年5月14日
    00
  • python numpy查询定位赋值数值所在行列

    在Python中,使用NumPy库可以方便地对数组进行各种操作,包括查询、定位和赋值数值所在行列。下面是查询、位和赋值数值在行列的详细攻略。 查询数值所行列 在NumPy中,可以使用where函数来查询数组中某个数值的位置。面是一个使用where函数查询一个二维数组中某数值的位置的示例代码: import numpy as np # 创建一个3×4的二维数组…

    python 2023年5月14日
    00
  • Python数据分析numpy数组的3种创建方式

    Python数据分析numpy数组的3种创建方式 NumPy是Python中一个非常流行的科学计算库,它提供了许多常用的数学函数和工具。在数据分析,经常需要使用NumPy来存储和处理数据。本攻略将介绍NumPy数组的3种创建方式,包括使用列表、使用NumPy使用文件读取。 列表创建NumPy数组 我们可以使用Python中的列表来创建NumPy数组。下面是一…

    python 2023年5月13日
    00
  • 在Python3 numpy中mean和average的区别详解

    在Python3的numpy中,mean和average都是用于计算数组中元素的平均值的函数,但它们之间有一些区别。 mean函数 mean函数是numpy中的一个函数,用于计算中素的平均值。它的法如下: .mean(a, axis=None, dtype=None, out=None, keepdims=<no value>) ,参数是要计算平…

    python 2023年5月14日
    00
  • MacOS(M1芯片 arm架构)下安装tensorflow的详细过程

    MacOS(M1芯片 arm架构)下安装TensorFlow的详细过程 在MacOS(M1芯片 arm架构)下安装TensorFlow需要一些额外的步骤。本文将详细介绍如何在MacOS(M1芯片 arm架构)下安装TensorFlow。 步骤1:安装Homebrew Homebrew是MacOS下的一个包管理器,可以方便地安装和管理软件包。可以使用以下命令安…

    python 2023年5月14日
    00
  • 变长双向rnn的正确使用姿势教学

    变长双向RNN的正确使用姿势教学 变长双向RNN是一种强大的神经网络模型,它可以处理变长序列数据,例如自然语言文本、音频信号等。在本攻略中,我们将介绍变长双向RNN的正确使用姿势,并提供两个示例说明。 什么是变长双向RNN? 变长双向RNN是一种神经网络模型,它由两个方向的RNN组成,一个从前往后处理输入序列,另一个从后往前处理输入序列。这种结构可以捕捉输入…

    python 2023年5月14日
    00
  • Python强化练习之PyTorch opp算法实现月球登陆器

    PyTorch是一个常用的深度学习框架,提供了许多常用的深度学习算法和工具。在本次强化练习中,我们将使用PyTorch实现月球登陆器的控制算法。以下是Python强化练习之PyTorchopp算法实现月球登陆器的完整攻略,包括算法实现的步骤和示例说明: PyTorchopp算法介绍 PyTorchopp算法是一种常用的强化学习算法,用于解决连续动作空间的问题…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部