pytorch查看网络参数显存占用量等操作

下面是针对pytorch查看网络参数显存占用量等操作的完整攻略。

1. 查看网络参数总量

为了查看神经网络的参数总量,我们可以使用 torchsummary 库中的 summary 函数。该函数可以打印出我们定义的模型结构及其参数量等相关信息。

首先,我们需要在命令行中使用 pip 安装 torchsummary 库:

pip install torchsummary

然后,我们可以在pytorch代码中导入该库,使用以下代码查看神经网络的参数总量:

import torch
from torchsummary import summary

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例并查看参数总量
model = Net()
summary(model, (3, 32, 32))

其中,我们首先定义了一个神经网络模型 Net,然后在 main 函数中创建了一个 Net 的实例 model。最后,使用 summary 函数查看神经网络模型的参数总量。

在这个示例中,我们的神经网络有 61,706 个可训练的参数。

2. 查看网络的显存占用量

在训练神经网络时,我们需要时刻关注神经网络的显存占用量,以免显存溢出或训练速度过慢。下面我们介绍两种查看神经网络显存占用量的方法。

方法一:pytorch自带工具

pytorch自带诊断工具 torch.cuda.memory_summary 可以用来查看系统显存状态和分配情况,使用方法如下:

import torch

# 定义张量
a = torch.randn(1024, 1024).cuda()

# 打印显存分配情况
print(torch.cuda.memory_summary())

运行上述代码后,我们可以得到类似下面的输出:

|    GPU     |    累计使用 |
|:---------:|:-------------:|
|  Quadro P4000  |  3554 MB      |

这表示当前显存使用了 3554 MB,其中包括了我们之前创建的 1024*1024 的张量占用的显存。

方法二:torch.autograd.profiler

torch.autograd.profiler 是 pytorch 自带的性能评测工具,可以用于查看神经网络的显存使用情况以及计算时间、函数调用顺序等。

import torch

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例
model = Net()

# 创建输入张量,将张量移动到GPU上
inputs = torch.randn(1, 3, 32, 32).cuda()

# 使用Profiler查看网络的显存占用情况
with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as profile:
    model(inputs)

print(profile)

运行上述代码后,我们可以得到类似下面的输出:

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                     CPU time        CUDA time        Calls           CPU total       CUDA total      Input size     
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
_convolution                            0.200us          1.757ms          2               0.000us          3.513ms          [1, 6, 28, 28]  
_max_pool2d                             0.025us          0.120us          2               0.050us          0.240us          [1, 6, 28, 28]  
_convolution                            0.032us          0.272ms          2               0.000us          0.544ms          [1, 16, 10, 10] 
_max_pool2d                             0.004us          0.052us          2               0.008us          0.104us          [1, 16, 10, 10] 
_view                                   0.065us          0.120us          2               0.130us          0.240us          [1, 400]        
_linear                                 0.064us          0.073ms          3               0.000us          0.219ms          [1, 120]        
_relu                                   0.025us          0.002ms          3               0.000us          0.005ms          [1, 120]        
_linear                                 0.020us          0.022ms          3               0.000us          0.067ms          [1, 84]         
_relu                                   0.011us          0.001ms          3               0.000us          0.004ms          [1, 84]         
_linear                                 0.012us          0.002ms          1               0.012us          0.002ms          [1, 10]         
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 0.518ms
Self CUDA time total: 4.059ms

其中,CUDA time 就是每个操作在 GPU 上使用的显存总量,可以通过求和得到模型整体的显存占用情况。在这个示例中,我们的模型总共占用了 4.059 MB 的显存。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch查看网络参数显存占用量等操作 - Python技术站

(0)
上一篇 2023年5月13日
下一篇 2023年5月13日

相关文章

  • Pytorch数据类型与转换(torch.tensor,torch.FloatTensor)

    PyTorch是一个开源的机器学习框架,提供了丰富的数据类型和转换方式。在使用PyTorch时,我们常常需要将数据转换成特定的数据类型,例如张量类型torch.tensor或浮点类型torch.FloatTensor等。本文将详细讲解PyTorch数据类型与转换的攻略。 PyTorch数据类型介绍 PyTorch提供了多种数据类型,包括整数类型、浮点类型、布…

    python 2023年5月13日
    00
  • 用tensorflow实现弹性网络回归算法

    用TensorFlow实现弹性网络回归算法 弹性网络回归是一种常用的线性回归算法,它可以在保持模型简单性的同时,克服最小二乘法(OLS)的一些缺点,例如对多重共线性的敏感性。本攻略将详细讲解如何使用TensorFlow实现弹性网络回归算法,并提供两个示例。 步骤一:导入库 在使用TensorFlow实现弹性回归算法之前,我们需要先导入相关的库。下面是一个简单…

    python 2023年5月14日
    00
  • python的dataframe和matrix的互换方法

    以下是Python中DataFrame和Matrix互换的方法的完整攻略,包括两个示例。 DataFrame和Matrix互换的方法 在Python中,可以使用NumPy和Pandas库将DataFrame和Matrix互换。以下是DataFrame和Matrix换的基本步骤: 将DataFrame转换为Matrix 使用Pandas的values属性将Da…

    python 2023年5月14日
    00
  • python怎么判断模块安装完成

    Python怎么判断模块安装完成 在Python中,可以使用pip命令安装第三方模块。但是,如何判断模块是否安装完成呢?本文将详细介绍Python如何判断模块安装完成。 方法1:使用import语句 可以使用import语句来判断模块是否安装完成。如果模块已经安装,import语句将不会报错。可以使用以下代码来判断模块是否安装完成: try: import …

    python 2023年5月14日
    00
  • Python数据分析应用之Matplotlib数据可视化详情

    Python数据分析应用之Matplotlib数据可视化详情 在本攻略中,我们将介绍如何使用Matplotlib进行数据可视化。以下是完整的攻略,含两个示例说明。 示例1:绘制折线图 以下是使用Matplotlib绘制折线图的步骤: 导入Matplotlib库。可以使用以下命令导入Matplotlib库: import matplotlib.pyplot a…

    python 2023年5月14日
    00
  • python numpy–数组的组合和分割实例

    Python NumPy – 数组的组合和分割实例 NumPy是Python中用于科学计算的一个重要的库,它提供了高效的多维数组array和与之相关的量。本文将详细讲解NumPy中的数组的组合和割实例,包括水组合、垂直组合、数组割等方法。 水平组合 使用NumPy中的hstack()函数可以将个数组水平组在一起,即将两个数组按列方向拼接。下面是一些示例: i…

    python 2023年5月14日
    00
  • Pytorch实现逻辑回归分类

    下面是关于“Pytorch实现逻辑回归分类”的完整攻略。 1. 逻辑回归分类 逻辑回归是一种二分类算法,用于将输入数据分为两个类别。在逻辑回归中,我们使用sigmoid函数将输入数据映射到0和1之间,然后将其作为概率输出。如果输出概率大于0.5,则将输入数据分类为1,否则分类为0。 2. Pytorch实现逻辑回归分类 在Pytorch中,可以使用torch…

    python 2023年5月14日
    00
  • 在MAC上搭建python数据分析开发环境

    以下是关于“在MAC上搭建Python数据分析开发环境”的完整攻略。 背景 在MAC上搭建Python数据分析开发环境,可以让我们更加高效地进行数据析和开发工作。本攻略将详细介绍在MAC上搭建Python数据分析开发环境的方法。 步骤一:安Python 在MAC上搭建Python数据分析开发环境的第一步是安装Python。可以从Python官网下载最新版本的…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部