当使用PyTorch进行深度学习时,我们需要将数据转化为张量并通过模型传递,但如何将原始数据转化为张量呢?这就涉及到PyTorch数据读取中的Dataset和DataLoader两个重要的概念。
Dataset
PyTorch中的Dataset是一个抽象类,代表数据集,它可以定义自己的数据形式、读取数据的方式、增加额外的预处理步骤等。我们只需继承该类,并实现__getitem__和__len__两个魔法方法即可。
示例1
以下是一个简单的示例,展示如何创建一个自定义的Dataset:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, transform=None):
self.data = data
self.transform = transform
def __getitem__(self, index):
x, y = self.data[index] # 假设data是一个存储数据和标签的列表
if self.transform:
x = self.transform(x)
return x, y
def __len__(self):
return len(self.data)
这个示例中,我们创建了一个名为MyDataset的子类,它包含2个参数:data和transform。MyDataset的__getitem__方法返回数据和标签,并对数据应用了一个可选的图像变换(transform)。__len__方法返回数据集的大小。
示例2
接下来是一个更完整的示例,展示如何将PyTorch自带的CIFAR-10数据集转换成自己的Dataset:
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset
class MyCIFAR10(Dataset):
def __init__(self, root, train=True, transform=None, target_transform=None):
self.data = datasets.CIFAR10(root, train=train, transform=transform, target_transform=target_transform)
self.transform = transform
def __getitem__(self, index):
x, y = self.data[index]
if self.transform:
x = self.transform(x)
return x, y
def __len__(self):
return len(self.data)
这个示例中,我们从torchvision中导入CIFAR-10数据集,并使用transforms模块定义一个可选的数据预处理操作。然后我们创建一个名为MyCIFAR10的子类,它继承了Dataset并包含4个参数:root、train、transform和target_transform。我们在__init__方法中创建了一个CIFAR10对象,它将使用传递的参数初始化。继承自Dataset的__getitem__和__len__方法分别返回数据和标签以及数据集大小。
DataLoader
DataLoader是PyTorch提供的一个数据读取器,它可以将Dataset中的数据转化为迭代器,并提供一些有用的功能,如自动批量读取、多进程数据加载、随机打乱数据等。下面是一些常用的DataLoader参数:
- dataset:数据集。
- batch_size:每次返回的数据批量大小。
- shuffle:是否随机打乱数据。
- num_workers:读取数据的线程数。
- drop_last:当数据集大小不能整除batch_size时,是否丢弃最后一批数据(默认为False,即不丢弃)。
示例1
以下是一个简单的示例,展示如何使用DataLoader读取数据:
import torch
from torch.utils.data import DataLoader
dataset = MyDataset(data) # 创建数据集
dataloader = DataLoader(dataset, batch_size=4, shuffle=True) # 创建数据读取器
for i, batch in enumerate(dataloader):
x, y = batch
print('批次', i, ':', x, y)
这个示例中,我们首先创建了一个名为dataset的MyDataset对象。然后我们使用DataLoader创建了一个名为dataloader的数据读取器,它将读取dataset中的数据,并返回batch_size大小的数据批次。我们使用for循环迭代dataloader,并逐批次获取数据。由于我们设置了shuffle参数为True,每个批次的数据都将是随机的。
示例2
下面是一个更完整的示例,展示如何将PyTorch自带的CIFAR-10数据集转换成可用于训练的DataLoader对象:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
transform_train = transforms.Compose([ # 定义数据预处理操作
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = MyCIFAR10(root='./data', train=True, transform=transform_train) # 创建训练集
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) # 创建训练集读取器
for i, batch in enumerate(trainloader):
x, y = batch
# 训练代码...
这个示例中,我们使用了transforms模块定义了一系列的数据预处理操作。然后我们创建了一个名为trainset的MyCIFAR10对象,它将使用这些预处理操作,并将CIFAR-10训练数据集中的数据和标签作为参数进行初始化。接着我们使用DataLoader创建了一个名为trainloader的训练集读取器,它将随机读取128张图片作为一个批次,并使用2个线程并行读取数据。
总之,Dataset和DataLoader是PyTorch中非常重要的数据读取相关的类,通过它们我们可以有效地读取、批次化和预处理数据,是深度学习中必不可少的组件。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch数据读取之Dataset和DataLoader知识总结 - Python技术站