一、DataLoader、DataSet、Sampler
Pytorch是一个开源的机器学习、深度学习框架,其中DataLoader、DataSet、Sampler是数据处理的核心组件。
1.1 DataLoader
DataLoader是一个数据迭代器,它可以将数据集封装成可迭代的对象,方便我们对数据集进行批量读取,并且可以通过设置参数来实现多线程和数据预处理等功能。
比如我们可以通过设置batch_size、shuffle来实现分批读取随机乱序的数据。
1.2 DataSet
DataSet是一个抽象类,需要自定义数据集的读取和处理方式。需要继承它,并重写__getitem__和__len__两个方法。
重写__getitem__方法,定义如何从数据集中获取一条数据,重写__len__方法,定义数据集的长度。
1.3 Sampler
Sampler是数据集的采样器,可以用来控制数据的采样方式。
比如我们可以通过SequentialSampler来实现顺序采样,RandomSampler来实现随机采样。
二、DataLoader、DataSet、Sampler之间的关系
2.1 DataLoader和DataSet之间的关系
DataLoader是从DataSet中读取数据的工具,DataSet中存储了我们的数据,而DataLoader按照DataSet的要求读取数据。
2.2 DataLoader和Sampler之间的关系
DataLoader中的sampler参数可以控制对数据的采样方式,即可以通过设置sampler参数使用自定义的Sampler来控制数据的采样方式。
三、示例
接下来,我们通过两个示例来进一步说明DataLoader、DataSet、Sampler之间的关系。
3.1 示例1
比如我们有一个数据集,我们想要按照一定的顺序读取数据,这时我们可以使用SequentialSampler来进行采样。
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler
# 创建一个自定义的数据集DataSet,并实现__getitem__和__len__方法。
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建SequentialSampler
sampler = SequentialSampler(dataset)
# 创建DataLoader,并传入dataset和sampler参数
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=4)
# 读取数据
for data in dataloader:
print(data)
上述代码中,我们首先定义了一个自定义的数据集CustomDataset,并实现了__getitem__和__len__方法。
然后,我们创建了一个SequentialSampler,并传入了我们的数据集dataset。
最后,我们创建了一个DataLoader,通过传入dataset和sampler参数,来读取数据。
在循环中,我们依次读取了每个batch_size大小的数据。
3.2 示例2
我们还可以通过自定义Sampler来控制数据的采样方式。比如我们有一个数据集,我们想要跳过其中的一些数据,这时我们可以自定义一个Sampler。
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
# 创建一个自定义的数据集DataSet,并实现__getitem__和__len__方法。
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建一个自定义的Sampler
class CustomSampler(Sampler):
def __init__(self, data, skip_idx):
self.data = data
self.skip_idx = skip_idx
def __iter__(self):
return iter([i for i in range(len(self.data)) if i not in self.skip_idx])
def __len__(self):
return len(self.data) - len(self.skip_idx)
# 创建自定义的数据集
dataset = CustomDataset(list(range(10)))
# 创建自定义的Sampler
sampler = CustomSampler(dataset, [1, 2, 3])
# 创建DataLoader,传入dataset和sampler参数
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=4)
# 读取数据
for data in dataloader:
print(data)
上述代码中,我们首先定义了一个自定义的数据集CustomDataset,并实现了__getitem__和__len__方法。
然后,我们定义了一个自定义的Sampler CustomSampler,实现了__iter__和__len__方法。
其中,__iter__方法返回一个迭代器,控制数据的顺序。
最后,我们创建了一个DataLoader,通过传入dataset和sampler参数,来读取数据。
在循环中,我们依次读取了除了索引为1,2,3的数据之外的数据。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 - Python技术站