以下是针对“pytorch dataset实战案例之读取数据集的代码”的完整攻略。
1. 确定数据集
在实现读取数据集的代码之前,首先要确定需要使用的数据集。PyTorch支持的数据集种类很多,例如MNIST手写数字数据集、CIFAR-10图像分类数据集、ImageNet图像分类数据集等。根据不同的场景选择不同的数据集。
2. 继承Dataset类
在PyTorch中,需要继承Dataset类来定义自己的数据集。继承Dataset类后,需要实现__len__()和__getitem__()两个方法。len()方法返回数据集的长度,getitem()方法根据索引返回数据集中对应的数据。
以下是一个示例,用于加载MNIST数据集的代码。其中,MNISTDataset类继承了Dataset类,并且在__init__()方法中读取MNIST数据集的图片和标签。
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class MNISTDataset(Dataset):
def __init__(self, root):
self.dataset = datasets.MNIST(root=root, download=True, transform=transforms.ToTensor())
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img, label = self.dataset[idx]
return img, label
3. 加载数据集
定义好自己的数据集之后,需要使用DataLoader类来加载数据集。DataLoader类可以将数据集分成小批次进行训练。在使用DataLoader类之前,需要确定小批次的大小(batch_size)和是否打乱数据集(shuffle)。
以下是一个示例,用于加载MNIST数据集的代码。其中,mnist_train_loader是训练数据的DataLoader对象。
mnist_train = MNISTDataset('./datasets/mnist')
mnist_train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)
4. 示例演示
以下是一个简单的示例,用于演示如何读取MNIST数据集。在这个示例中,我们加载MNIST数据集并显示一张手写数字图片。
import torch
from torchvision import datasets, transforms
# 定义数据集
mnist = datasets.MNIST('./datasets/mnist', download=True, transform=transforms.ToTensor())
# 打印数据集大小
print(len(mnist))
# 加载数据集
mnist_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
# 显示一张图片
data, label = next(iter(mnist_loader))
img = data[0].numpy().squeeze()
print(label[0])
plt.imshow(img, cmap='gray')
plt.show()
另一个示例展示了如何读取自定义数据集。这里我们使用了一个名为“my_dataset”的文件夹,其中包含了10张猫咪图片和10张狗狗图片。我们将文件夹中的图片打包在一个名为my_dataset.zip的压缩包中,然后使用以下代码读取数据集。
import torch
from torchvision import datasets, transforms
# 定义自定义数据集
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, path):
self.path = path
self.images = []
self.labels = []
img_paths = glob.glob(os.path.join(path, '*.jpg'))
random.shuffle(img_paths)
for img_path in img_paths:
img = Image.open(img_path).convert('RGB')
img = transforms.ToTensor()(img)
label = 1 if 'cat' in img_path else 0
self.images.append(img)
self.labels.append(label)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.images[idx], self.labels[idx]
# 加载数据集
my_dataset = CustomDataset('./my_dataset')
my_loader = torch.utils.data.DataLoader(my_dataset, batch_size=10, shuffle=True)
# 显示一批图片
data, label = next(iter(my_loader))
for i in range(data.shape[0]):
img = data[i].permute(1, 2, 0).numpy()
plt.imshow(img, cmap='gray')
plt.title('cat' if label[i]==1 else 'dog')
plt.show()
在这个示例中,我们定义了自己的CustomDataset类,并在__init__()方法中读取所有的图片文件,将图片转换为张量并保存在self.images和self.labels列表中。在__getitem__()方法中,我们可以根据传入的索引idx返回对应的图片张量和标签。最后通过DataLoader类加载数据集,并使用next(iter(my_loader))方法获取一批数据进行展示。
希望这份攻略能对您有所帮助。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch dataset实战案例之读取数据集的代码 - Python技术站