一个例子
import torch
from torch.utils import data
class MyDataset(data.Dataset):
    """Toy map-style dataset of random samples paired with their indices.

    Each item is a ``(sample, index)`` pair, where ``sample`` is one row of
    a tensor drawn from ``torch.randn`` at construction time (no seed is
    set, so values differ between runs).

    Args:
        num_samples: number of rows, i.e. the dataset length. Defaults to 8.
        num_features: length of each sample row. Defaults to 2.
    """

    def __init__(self, num_samples=8, num_features=2):
        super().__init__()
        self.data = torch.randn(num_samples, num_features)

    def __getitem__(self, index):
        """Return the ``(sample, index)`` pair stored at position ``index``."""
        # Returning the index as the "label" makes batch order easy to see
        # when iterating through a DataLoader.
        return self.data[index], index

    def __len__(self):
        """Return the number of samples (first dimension of ``self.data``)."""
        return self.data.size(0)
# Build the toy dataset and print the raw tensor it wraps. MyDataset fills
# it with torch.randn and no seed is set, so the printed values change from
# run to run.
data_set = MyDataset()
print(data_set.data)
输出
tensor([[-1.3907, -0.0916],
[-0.4626, -1.3323],
[ 1.4242, -2.1718],
[ 1.5850, 0.3320],
[-1.0804, 0.3884],
[ 0.6567, -0.1234],
[ 1.6721, -0.7327],
[-1.9595, -0.3512]])
# Batch the dataset four items at a time. shuffle=False keeps the original
# order, so each printed batch lines up with consecutive rows of the tensor.
data_loader = data.DataLoader(data_set,
                              batch_size=4,
                              shuffle=False)
print(len(data_set))
# Each batch is a (samples, labels) pair. Batching turns the per-item
# integer indices into a single tensor, e.g. tensor([0, 1, 2, 3]).
for number, labels in data_loader:
    print(number)
    print(labels)
输出
8
tensor([[-1.3907, -0.0916],
[-0.4626, -1.3323],
[ 1.4242, -2.1718],
[ 1.5850, 0.3320]])
tensor([0, 1, 2, 3])
tensor([[-1.0804, 0.3884],
[ 0.6567, -0.1234],
[ 1.6721, -0.7327],
[-1.9595, -0.3512]])
tensor([4, 5, 6, 7])
本站文章如无特殊说明，均为本站原创，如若转载，请注明出处：pytorch自定义dataset - Python技术站