pytorch查看网络参数显存占用量等操作

yizhihongxing

下面是针对pytorch查看网络参数显存占用量等操作的完整攻略。

1. 查看网络参数总量

为了查看神经网络的参数总量,我们可以使用 torchsummary 库中的 summary 函数。该函数可以打印出我们定义的模型结构及其参数量等相关信息。

首先,我们需要在命令行中使用 pip 安装 torchsummary 库:

pip install torchsummary

然后,我们可以在pytorch代码中导入该库,使用以下代码查看神经网络的参数总量:

import torch
from torchsummary import summary

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例并查看参数总量
model = Net()
summary(model, (3, 32, 32))

其中,我们首先定义了一个神经网络模型 Net,然后在 main 函数中创建了一个 Net 的实例 model。最后,使用 summary 函数查看神经网络模型的参数总量。

在这个示例中,我们的神经网络有 61,706 个可训练的参数。

2. 查看网络的显存占用量

在训练神经网络时,我们需要时刻关注神经网络的显存占用量,以免显存溢出或训练速度过慢。下面我们介绍两种查看神经网络显存占用量的方法。

方法一:pytorch自带工具

pytorch自带诊断工具 torch.cuda.memory_summary 可以用来查看系统显存状态和分配情况,使用方法如下:

import torch

# 定义张量
a = torch.randn(1024, 1024).cuda()

# 打印显存分配情况
print(torch.cuda.memory_summary())

运行上述代码后,我们可以得到类似下面的输出:

|    GPU     |    累计使用 |
|:---------:|:-------------:|
|  Quadro P4000  |  3554 MB      |

这表示当前显存使用了 3554 MB,其中包括了我们之前创建的 1024*1024 的张量占用的显存。

方法二:torch.autograd.profiler

torch.autograd.profiler 是 pytorch 自带的性能评测工具,可以用于查看神经网络的显存使用情况以及计算时间、函数调用顺序等。

import torch

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例
model = Net()

# 创建输入张量,将张量移动到GPU上
inputs = torch.randn(1, 3, 32, 32).cuda()

# 使用Profiler查看网络的显存占用情况
with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as profile:
    model(inputs)

print(profile)

运行上述代码后,我们可以得到类似下面的输出:

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                     CPU time        CUDA time        Calls           CPU total       CUDA total      Input size     
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
_convolution                            0.200us          1.757ms          2               0.000us          3.513ms          [1, 6, 28, 28]  
_max_pool2d                             0.025us          0.120us          2               0.050us          0.240us          [1, 6, 28, 28]  
_convolution                            0.032us          0.272ms          2               0.000us          0.544ms          [1, 16, 10, 10] 
_max_pool2d                             0.004us          0.052us          2               0.008us          0.104us          [1, 16, 10, 10] 
_view                                   0.065us          0.120us          2               0.130us          0.240us          [1, 400]        
_linear                                 0.064us          0.073ms          3               0.000us          0.219ms          [1, 120]        
_relu                                   0.025us          0.002ms          3               0.000us          0.005ms          [1, 120]        
_linear                                 0.020us          0.022ms          3               0.000us          0.067ms          [1, 84]         
_relu                                   0.011us          0.001ms          3               0.000us          0.004ms          [1, 84]         
_linear                                 0.012us          0.002ms          1               0.012us          0.002ms          [1, 10]         
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 0.518ms
Self CUDA time total: 4.059ms

其中,CUDA time 就是每个操作在 GPU 上使用的显存总量,可以通过求和得到模型整体的显存占用情况。在这个示例中,我们的模型总共占用了 4.059 MB 的显存。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch查看网络参数显存占用量等操作 - Python技术站

(0)
上一篇 2023年5月13日
下一篇 2023年5月13日

相关文章

  • Python 实现Numpy中找出array中最大值所对应的行和列

    在Python中,可以使用NumPy库来进行数组操作。本文将详细讲解如何使用NumPy库找出数组中最大值所对应的行和列的完整攻略,包括两个例。 方法一:使用argmax函数 Py库中的argmax函数可以返回数组中最大值所在的索引。可以使用该函数找数组中大值所对应的行和列。下面是一个示例代码: import numpy as np # 创建一个二维数组 ar…

    python 2023年5月14日
    00
  • python中np.multiply()、np.dot()和星号(*)三种乘法运算的区别详解

    以下是关于“Python中np.multiply()、np.dot()和星号(*)三种乘法运算的区别详解”的完整攻略。 背景 在Python中,有三种常用的乘法运算分别是np.multiply()、np.dot()和星号(*)。这三乘法运算在使用时需要其区别。本攻略将详细介这三种乘法运算的区别。 np.multiply()函数 np.multiply()函数…

    python 2023年5月14日
    00
  • np.concatenate()函数数组序列参数的实现

    np.concatenate()函数是NumPy库中的一个函数,用于将两个或多个数组沿指定轴连接在一起。在使用np.concatenate()函数时,可以将多个数组作为一个序列参数传递给函数。本文将介绍np.concatenate()函数序列参数的实现,并提供两个示例。 数组序列参数的实现 在np.concatenate()函数中,可以将多个数组作为一个序列…

    python 2023年5月14日
    00
  • numpy.insert用法及内插插0的方法

    当您需要在NumPy数组中插入值时,可以使用numpy.insert()函数。该函数可以在指定的轴上插入值,并返回一个新的数组。以下是numpy.insert()的语法: numpy.insert(arr, obj, values, axis=None) 其中,参数的含义如: arr:要插入的输入数组。 obj:插入值的索引或者索引数组。 values:要插…

    python 2023年5月14日
    00
  • NumPy数组的高级索引

    NumPy中的高级索引指的是使用整数数组或布尔数组来索引数组的方式,相较于基本索引(使用切片或整数索引),高级索引提供了更加灵活的方式来获取数组中的元素。下面我们来详细介绍一下NumPy中的高级索引。 整数数组索引 整数数组索引是指使用整数数组来作为索引的方式。例如,有一个二维数组a: import numpy as np a = np.array([[0,…

    2023年3月3日
    00
  • pytorch中可视化之hook钩子

    PyTorch中可视化之hook钩子 在PyTorch中,我们可以使用hook钩子来获取模型中间层的输出,以便进行可视化或其他操作。本攻略将详细讲解PyTorch中可视化之hook钩子,包括如何使用hook钩子获取中间层的输出和如何使用hook钩子可视化中间层的输出。 使用hook钩子获取中间层的输出 在PyTorch中,我们可以使用register_for…

    python 2023年5月14日
    00
  • TensorFlow损失函数专题详解

    TensorFlow损失函数专题详解 TensorFlow是一个流行的深度学习框架,可以用于各种任务,例如分类、回归和聚类。在进行这些任务时,损失函数是非常关键的一个部分。本文将详细讲解TensorFlow中一些常用的损失函数。 什么是损失函数? 损失函数是一个衡量模型预测结果与真实结果之间的差异的函数。在训练模型时,我们尝试最小化损失函数的值。在深度学习中…

    python 2023年5月14日
    00
  • NumPy遍历数组最常用的4种方法

    NumPy提供了多种遍历数组的方法,主要有以下几种: 迭代器遍历 使用NumPy的nditer函数可以返回一个用于迭代数组元素的迭代器对象。可以通过设置order参数来指定迭代的顺序,例如order=’C’表示按照C语言的行优先顺序进行迭代,order=’F’表示按照Fortran语言的列优先顺序进行迭代。示例代码如下: import numpy as np…

    Numpy 2023年3月3日
    00
合作推广
合作推广
分享本页
返回顶部