在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助我们实现这些功能。
DataLoader的函数定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False)
dataset:加载的数据集(Dataset对象)
batch_size:batch size
shuffle::是否将数据打乱
sampler: 样本抽样,后续会详细介绍
num_workers:使用多进程加载的进程数,0代表不使用多进程
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
def main(): import visdom import time viz = visdom.Visdom() db = Pokemon('pokeman', 224, 'train') x,y = next(iter(db)) ## print('sample:',x.shape,y.shape,y) viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x')) loader = DataLoader(db,batch_size=32,shuffle=True) for x,y in loader: #为了得一个一个的数据集形式的数据每一组32张 viz.images(db.denormalize(x),nrow=8,win='batch',opts = dict(title = 'batch')) viz.text(str(y.numpy()),win = 'label',opts=dict(title='batch-y')) time.sleep(10)
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在_ getitem _函数中将出现异常,此时最好的解决方案即是将出错的样本剔除
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch之DataLoader()函数 - Python技术站