这是torchvision.utils模块里面的两个方法,因为比较常用,所以pytorch直接封装好了。

制作网格

网络图像一般用于训练数据或测试数据的可视化。

torchvision.utils.make_grid(tensor, nrow, padding) → torch.Tensor

  • 描述

将多张tensor格式的图像以网格的方式封装到一起。

  • 参数

tensor (tensor or list):四维 (B x C x H x W) mini-batch的tensor数据或者是包含同一尺寸的图片列表。

nrow (int):网格每行图片的个数,默认是8;千万不要理解为图片的行数。

padding (int):四周填充的宽度,默认是2,你可以理解为网格中图片之间的间距。默认填充值是0,也就是黑色。

注:这是三个比较常用的参数,其它参数请参考官方文档

  • 示例
# 以mnist数据集为例,train_loader的batch_size设置为9
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size())  # torch.Size([3, 84, 84])
  • 绘图
    【pytorch】制作网格图像,直接将tensor格式的图像保存到本地

保存本地

tensor数据类型保存时不用再转为PIL.Imagenumpy.ndarraypytorch直接给我们写好了一个方法。

torchvision.utils.save_image(tensor, fp) → None

  • 描述

直接将tensor数据保存为图像。

  • 参数

tensor (Tensor or list):待保存的tensor数据。如果给以一个四维的mini-batchtensor,将调用网格方法,然后再保存到本地。

fp (string or file object)):图像的保存路径。

注:这是两个比较常用的参数,其它参数请参考官方文档

  • 示例
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size())  # torch.Size([3, 84, 84])
torchvision.utils.save_image(images, 'test.jpg')

完整代码

#%% 导入模块
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
images = make_grid(images, 3, 0)
print(images.size())  # torch.Size([3, 84, 84])
save_image(images, 'test.jpg')

引用参考

https://pytorch.org/docs/stable/torchvision/utils.html