pytorch
训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。那么如何显示dataloader
里面带batch
的tensor
类型的图像呢?
显示图像
绘图最常用的库就是matplotlib
:
pip install matplotlib
显示图像会用到matplotlib.pyplot.imshow
方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:
数据加载器中数据的维度是[B, C, H, W]
,我们每次只拿一个数据出来就是[C, H, W]
,而matplotlib.pyplot.imshow
要求的输入维度是[H, W, C]
,所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch
里面的permute
方法(transpose
方法也行,不过要交换两次,没这个方便,numpy
中的transpose
方法倒是可以一次交换完成),用法示例如下:
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])
代码示例
#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
plt.subplot(3, 3, i+1)
plt.title(labels[i].item())
plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.show()
这里以mnist
数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))
。所以,如果你想查看训练集的原始图像,还得反标准化。
- 标准化:
image = (image-mean)/std
- 反标准化:
image = image*std+mean
我拿imagenet
中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:
最终效果
引用参考
https://pytorch.org/docs/stable/tensors.html
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:【pytorch】带batch的tensor类型图像显示 - Python技术站