PyTorch的数据集DataLoader
是十分常用的数据加载和预处理工具,通过将数据传输到GPU并在深度学习过程中进行抽样,而它的shuffle
参数可以打乱数据集的顺序,使损失函数更加随机。但同时,我们也可能需要控制随机的行为,以获得可再现的实验结果。下面是两种设置shuffle
随机数种子的方法:
方法一:使用torch.utils.data.DataLoader
类的WorkerInitFn
参数
我们可以使用WorkerInitFn
来传递一个函数,来控制数据集加载器的每个工作进程的初始化过程。以下是一个示例的代码段:
import random
import torch
from torch.utils.data import DataLoader
class MyDataset(Dataset):
def __init__(self):
super().__init__()
self.data = list(range(10))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 设置随机数种子,获得可再现的实验结果
def worker_init_fn(worker_id):
random.seed(worker_id)
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
num_workers=2, worker_init_fn=worker_init_fn)
for i, batch in enumerate(dataloader):
print(batch)
在这个例子中,我们将worker_init_fn
设置为一个函数,该函数会在每个工作进程初始化时调用,并使用其工作进程ID作为随机数种子,以控制每个进程数据加载顺序的随机性。这里,使用random.seed
来设置随机种子。
当shuffle
参数设置为True
时,DataLoader
会在每个工作进程中打乱数据,并将其放回主进程。 在每个工作进程初始化时,随机数种子被设置成与工作进程ID有关的值。这样,每个进程在打乱数据时使用不同的随机数种子,以确保打乱后的顺序是独立的,而不是互相关联的。
方法二:使用torch.Generator
类
我们也可以使用PyTorch的Random模块来设置DataLoader
类中的随机数种子。具体做法是将shuffle
设置为True
,然后使用PyTorch的工具包生成随机数种子。以下是一个示例的代码段:
import torch
import torch.utils.data as data_utils
torch.manual_seed(42) # 设置随机数种子
# 创建数据集
data = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
target = torch.Tensor([1, 1, 0, 0])
dataset = data_utils.TensorDataset(data, target)
# 创建DataLoader类
batch_size = 2
dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, generator=torch.Generator().manual_seed(42))
# 打印出来
for batch_idx, (data, target) in enumerate(dataloader):
print("Batch index {}, data shape {}, target shape {}".format(batch_idx, data.shape, target.shape))
此例中,我们将DataLoader
类的generator
参数设置为为torch.Generator().manual_seed(42)
,shuffle
参数设置为True
,并使用torch.manual_seed(42)
方法设置随机数种子来控制打乱数据的顺序。在这个例子中,generator
是torch.Generator
对象,我们设置它的随机数种子为42。这样每一次使用DataLoader
类,我们都能得到相同的打乱数据顺序。
这两种设置shuffle
随机数种子的方式,在控制随机性方面有其各自的优点和适用场景,读者可以根据情况选择更加适合自身需求的方法。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch在dataloader类中设置shuffle的随机数种子方式 - Python技术站