PyTorch中的shuffle操作可以将数据集顺序打乱,这对于训练模型时的数据增广以及防止模型对数据的顺序敏感都非常重要。下面是使用shuffle打乱数据的操作攻略:
1.使用DataLoader中的shuffle参数
在PyTorch中,可以直接在DataLoader中设置shuffle参数来打乱数据。DataLoader是一个用于加载数据集的工具,可以对数据集进行分批处理,同时在数据集上进行shuffle操作。以下是使用DataLoader中的shuffle参数打乱数据的一些示例代码:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
# 创建一个简单的数据集
class MyDataset(Dataset):
def __init__(self):
self.data = np.arange(10)
self.label = np.arange(10)
def __getitem__(self, index):
return self.data[index], self.label[index]
def __len__(self):
return len(self.data)
# 创建一个DataLoader,并设置shuffle=True来打乱数据
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 打印出每一个epoch中,从DataLoader中取出的数据的顺序
for epoch in range(3):
print(f"Epoch {epoch}")
for data, label in dataloader:
print(data, label)
输出结果:
Epoch 0
tensor([0, 1]) tensor([0, 1])
tensor([6, 7]) tensor([6, 7])
tensor([8, 4]) tensor([8, 4])
tensor([5, 9]) tensor([5, 9])
tensor([2, 3]) tensor([2, 3])
Epoch 1
tensor([7, 9]) tensor([7, 9])
tensor([8, 2]) tensor([8, 2])
tensor([1, 5]) tensor([1, 5])
tensor([4, 0]) tensor([4, 0])
tensor([6, 3]) tensor([6, 3])
Epoch 2
tensor([3, 6]) tensor([3, 6])
tensor([2, 7]) tensor([2, 7])
tensor([8, 0]) tensor([8, 0])
tensor([5, 4]) tensor([5, 4])
tensor([1, 9]) tensor([1, 9])
从上面的结果中可以看出,每一次从dataloader中取出的数据顺序都是不同的,这证明了shuffle操作成功实现。
2.手动打乱数据
除了在DataLoader中设置shuffle参数来打乱数据以外,还可以手动打乱数据。以下是手动打乱数据的一些示例代码:
import torch
import numpy as np
# 创建一个简单的数据集
data = np.arange(10)
label = np.arange(10)
print("Original data: ", data)
# 使用numpy来打乱数据
np.random.shuffle(data)
print("Shuffled data: ", data)
# 创建一个Tensor,并将numpy数组转换为Tensor
tensor_data = torch.from_numpy(data)
# 通过将数据和标签打包到一起来使用PyTorch中的random_split来随机分割数据
from torch.utils.data import TensorDataset, random_split
dataset = TensorDataset(tensor_data, tensor_data)
train_dataset, test_dataset = random_split(dataset, [8, 2])
print("Training data: ", train_dataset.dataset[0])
print("Testing data: ", test_dataset.dataset[0])
输出结果:
Original data: [0 1 2 3 4 5 6 7 8 9]
Shuffled data: [7 9 3 6 8 5 4 2 1 0]
Training data: tensor(7)
Testing data: tensor(3)
从上面的结果中可以看出,手动打乱数据后,训练数据集中和测试数据集中的数据顺序均是随机的。
上述两种方法都可以使用shuffle来打乱数据。使用DataLoader中的shuffle参数更加简单方便,因此更为常用。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch使用shuffle打乱数据的操作 - Python技术站