下面我来为你详细讲解“PyTorch 自定义数据集加载方法”的完整攻略。
1. 前置条件
在开始介绍如何自定义数据集加载方法之前,需要先了解以下几个前置条件:
- 了解PyTorch库,包括张量(Tensor)、数据集(Dataset)、变换(Transforms)、数据读取器(DataLoader)等基本概念;
- 数据集文件按要求格式存储,例如:每张图片的地址和标签组成一条样本,按照csv或txt文件格式存储。
2. 编写自定义数据集类
在 PyTorch 中,我们可以通过自定义数据集类来加载自有的数据集。此类需继承 PyTorch 的 Dataset 类并重载以下两个方法:
__len__()
: 返回数据集中样本的数量__getitem__()
: 根据索引index返回对应的一条数据记录(包括图像和标签等)
以下是一个简单的自定义数据集类的示例:
import torch.utils.data as data
class CustomDataset(data.Dataset):
def __init__(self, data_file, transform=None):
super(CustomDataset, self).__init__()
self.data, self.label = [], []
with open(data_file, 'r') as f:
for line in f:
content = line.strip().split(',')
self.data.append(content[0]) # 图像路径
self.label.append(int(content[1])) # 图像标签
self.transform = transform
def __len__(self):
return len(self.label)
def __getitem__(self, index):
img, label = self.data[index], self.label[index]
img = Image.open(img).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label
在 CustomDataset
中,我们首先读取数据文件中的样本信息,然后在 __getitem__()
方法中返回到每条样本数据。在返回之前,还可以实现一些图像预处理操作(例如,将图像转换为张量或对图像进行归一化等)。
3. 数据读取及预处理
完成自定义数据集类后,还需要进行数据的读取以及预处理等操作。这里通常采用 PyTorch 提供的 DataLoader 类。
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
# 设置 transforms 变换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 加载自定义数据集
train_set = CustomDataset(data_file='train.csv', transform=transform)
# 创建 DataLoader
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
其中,transforms
变换用于对图像进行预处理操作,CustomDataset
类用于载入数据集,最后通过 DataLoader
创建数据读取器。
4. 示例说明
以下是两个在 PyTorch 中如何自定义数据集加载方法的示例说明。
示例1:加载CIFAR10数据集
CIFAR10 是一个包含十个类别、共计 6 万张 32x32 像素彩色图片的数据集,其中有 5 万张图片用于训练集,1 万张图片用于测试集。
首先,下载 CIFAR10 数据集,然后定义一个名为 CIFAR10Dataset
的类,实现 __init__()
,__len__()
和 __getitem__()
方法:
import torch.utils.data as data
import torchvision.datasets as datasets
class CIFAR10Dataset(data.Dataset):
def __init__(self, root, train=True, transform=None):
super(CIFAR10Dataset, self).__init__()
self.cifar10 = datasets.CIFAR10(root=root, train=train, transform=transform, download=True)
def __len__(self):
return len(self.cifar10)
def __getitem__(self, index):
img, label = self.cifar10[index]
return img, label
然后,定义一个名为 get_dataloader()
的函数,从而得到训练集和测试集的数据批次:
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
# transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def get_dataloader():
train_set = CIFAR10Dataset(root='./data', train=True, transform=transform)
test_set = CIFAR10Dataset(root='./data', train=False, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)
return train_loader, test_loader
最后,我们可以调用 get_dataloader()
函数获取数据集的批次:
train_loader, test_loader = get_dataloader()
示例2:加载一张多标签图像
假设我们现在有一张图片 A ,它是一个多标签图像,即这张图片有多个类别标签(例如,这张图片既包含狗的标签,又包含树的标签)。此时我们就可以采用自定义数据集加载的方法,通过重载 __getitem__()
方法来实现。
首先,我们读取包含该图像路径和标签的文件:
with open('data.txt', 'r') as f:
lines = f.readlines()
file_list, label_list = [], []
for line in lines:
arr = line.strip().split('\t')
file_list.append(arr[0])
label_list.append(list(map(int, arr[1:])))
然后,定义一个名为 CustomMultiLabelDataset
的类,实现 __init__()
,__len__()
和 __getitem__()
方法。
from PIL import Image
class CustomMultiLabelDataset(data.Dataset):
def __init__(self, root, file_list, label_list, transform=None):
super(CustomMultiLabelDataset, self).__init__()
self.root = root
self.file_list = file_list
self.label_list = label_list
self.transform = transform
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
img = Image.open(os.path.join(self.root, self.file_list[index]))
if self.transform is not None:
img = self.transform(img)
labels = torch.FloatTensor(self.label_list[index])
return img, labels
最后,我们可以按照自己的需求,对图像进行预处理以及通过自定义数据集来获取图像和标签:
import torch.utils.data as data
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
root = '/path/to/image/folder'
file_list = ['dog_and_tree.jpg']
label_list = [[1, 0, 1]]
dataset = CustomMultiLabelDataset(root, file_list, label_list, transform=transform)
img, labels = dataset[0]
以上就是 PyTorch 自定义数据集加载方法的完整攻略,以及包含 CIFAR10 数据集和多标签图像数据集的两个示例。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 自定义数据集加载方法 - Python技术站