import torch
from torch.utils.data import Dataset, DataLoader


class SmsDataset(Dataset):
    """Map-style dataset over a tab-separated text file of SMS messages.

    Every line is expected to look like ``<label>\t<message>``; the whole
    file is read into memory once, when the dataset is constructed.
    """

    def __init__(self, file_path="./SMSSpamCollection"):
        """Load all lines of the data file.

        Args:
            file_path: Path to the UTF-8 encoded data file. Defaults to
                ``./SMSSpamCollection`` relative to the working directory.

        Raises:
            OSError: If the file cannot be opened.
        """
        super().__init__()
        self.file_path = file_path
        # The context manager releases the file handle as soon as the lines
        # are in memory instead of leaving it to the garbage collector.
        with open(self.file_path, encoding="utf-8") as data_file:
            self.lines = data_file.readlines()

    def __getitem__(self, index):
        """Return the ``(label, sent)`` pair stored on line ``index``.

        Raises:
            IndexError: If ``index`` is out of range, or the line contains
                no tab character and therefore has no message field.
        """
        # Split once and reuse the result for both fields.
        fields = self.lines[index].strip().split("\t")
        return fields[0], fields[1]

    def __len__(self):
        """Return the number of lines read from the data file."""
        return len(self.lines)


# Built at import time, so the data file must exist when this module loads.
sms_dataset = SmsDataset()
dataloader = DataLoader(sms_dataset, batch_size=2, shuffle=True)

if __name__ == "__main__":
    # Print only the first batch; the break stops after one iteration.
    for idx, (label, sent) in enumerate(dataloader):
        print(idx)
        print(label)
        print(sent)
        break
    print(len(sms_dataset))  # number of samples
    print(len(dataloader))  # number of batches
0
('ham', 'spam')
('And popping <#> ibuprofens was no help.', 'This is the 2nd time we have tried 2 contact u. U have won the 750 Pound prize. 2 claim is easy, call 08712101358 NOW! Only 10p per min. BT-national-rate')
5574
2787
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch Dataset数据集和Dataloader迭代数据集 - Python技术站