在PyTorch中,我们可以使用torchvision库来加载和处理图像数据集。本文将详细讲解如何使用PyTorch加载和显示图像数据集,并提供两个示例说明。
1. 加载图像数据集
在PyTorch中,我们可以使用torchvision.datasets模块中的ImageFolder类来加载图像数据集。ImageFolder类会自动将数据集中的图像按照文件夹名称进行分类,并将每个图像的标签设置为文件夹的名称。以下是一个加载图像数据集的示例代码:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义数据预处理方法
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.ImageFolder(root='./data/train', transform=transform)
test_dataset = datasets.ImageFolder(root='./data/test', transform=transform)
在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的缩放、裁剪、转换为张量和归一化等操作。然后,我们使用ImageFolder类加载了训练数据集和测试数据集,并将数据预处理方法作为参数传入。
2. 显示图像数据集
在PyTorch中,我们可以使用matplotlib库来显示图像数据集。以下是一个显示图像数据集的示例代码:
import matplotlib.pyplot as plt
import numpy as np
# 显示图像函数
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 随机选择一张图像并显示
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)
在上面的代码中,我们首先定义了一个imshow函数,该函数用于显示图像。然后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。
示例1:显示CIFAR10数据集
以下是一个显示CIFAR10数据集的示例代码:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 定义数据预处理方法
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
])
# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 显示图像函数
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)
在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用CIFAR10类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。
示例2:显示MNIST数据集
以下是一个显示MNIST数据集的示例代码:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 定义数据预处理方法
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 显示图像函数
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
plt.show()
# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)
在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用MNIST类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据集图片显示方法 - Python技术站