Pytorch在dataloader类中设置shuffle的随机数种子方式

PyTorch的数据集DataLoader是十分常用的数据加载和预处理工具,通过将数据传输到GPU并在深度学习过程中进行抽样,而它的shuffle参数可以打乱数据集的顺序,使损失函数更加随机。但同时,我们也可能需要控制随机的行为,以获得可再现的实验结果。下面是两种设置shuffle随机数种子的方法:

方法一:使用torch.utils.data.DataLoader类的WorkerInitFn参数

我们可以使用WorkerInitFn来传递一个函数,来控制数据集加载器的每个工作进程的初始化过程。以下是一个示例的代码段:

import random
import torch
from torch.utils.data import DataLoader

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# 设置随机数种子,获得可再现的实验结果
def worker_init_fn(worker_id):
    random.seed(worker_id)

dataset = MyDataset()

dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
                        num_workers=2, worker_init_fn=worker_init_fn)

for i, batch in enumerate(dataloader):
    print(batch)

在这个例子中,我们将worker_init_fn设置为一个函数,该函数会在每个工作进程初始化时调用,并使用其工作进程ID作为随机数种子,以控制每个进程数据加载顺序的随机性。这里,使用random.seed来设置随机种子。

shuffle参数设置为True时,DataLoader会在每个工作进程中打乱数据,并将其放回主进程。 在每个工作进程初始化时,随机数种子被设置成与工作进程ID有关的值。这样,每个进程在打乱数据时使用不同的随机数种子,以确保打乱后的顺序是独立的,而不是互相关联的。

方法二:使用torch.Generator

我们也可以使用PyTorch的Random模块来设置DataLoader类中的随机数种子。具体做法是将shuffle设置为True,然后使用PyTorch的工具包生成随机数种子。以下是一个示例的代码段:

import torch
import torch.utils.data as data_utils

torch.manual_seed(42)  # 设置随机数种子

# 创建数据集
data = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
target = torch.Tensor([1, 1, 0, 0])
dataset = data_utils.TensorDataset(data, target)

# 创建DataLoader类
batch_size = 2
dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, generator=torch.Generator().manual_seed(42))

# 打印出来
for batch_idx, (data, target) in enumerate(dataloader):
    print("Batch index {}, data shape {}, target shape {}".format(batch_idx, data.shape, target.shape))

此例中,我们将DataLoader类的generator参数设置为为torch.Generator().manual_seed(42)shuffle参数设置为True,并使用torch.manual_seed(42)方法设置随机数种子来控制打乱数据的顺序。在这个例子中,generatortorch.Generator对象,我们设置它的随机数种子为42。这样每一次使用DataLoader类,我们都能得到相同的打乱数据顺序。

这两种设置shuffle随机数种子的方式,在控制随机性方面有其各自的优点和适用场景,读者可以根据情况选择更加适合自身需求的方法。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch在dataloader类中设置shuffle的随机数种子方式 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • 使用python接入微信聊天机器人

    下面是使用Python接入微信聊天机器人的完整攻略。 1. 准备工作 在使用Python接入微信聊天机器人前,我们需要确保以下几点: 已安装Python,建议使用Python 3.x版本。 已安装itchat模块,itchat是一个开源的微信个人号接口,使用pip安装即可,命令如下: pip install itchat 已准备好微信个人号,可以在手机上登录…

    python 2023年5月23日
    00
  • Python爬虫之Selenium实现窗口截图

    下面是“Python爬虫之Selenium实现窗口截图”的攻略: 1. 安装Selenium 首先需要安装Selenium,可使用pip包管理器,输入以下命令: pip install selenium 2. 下载Chromedriver 使用Selenium需要下载浏览器驱动,这里以Chrome浏览器为例,下载对应版本的Chromedriver,在http…

    python 2023年5月14日
    00
  • Python for 循环语句的使用

    下面是Python for循环语句的使用完整攻略。 什么是Python for循环语句? 在Python中,for循环语句可以遍历任何序列的项目,例如一个列表或一个字符串。for循环的一般形式如下: for var in sequence: statements 其中,var 指的是变量,在 for 循环中会被赋值为序列 sequence 中的每个项,一次循…

    python 2023年6月5日
    00
  • OpenCV找到彩色圆圈和位置值Python

    【问题标题】:OpenCV find coloured in circle and position value PythonOpenCV找到彩色圆圈和位置值Python 【发布时间】:2023-04-03 18:39:01 【问题描述】: 我要做的是处理下面的考勤表,告诉我谁在场,谁不在 我目前正在使用 matchTemplate,它使用一个奇异的黑点来查…

    Python开发 2023年4月8日
    00
  • python+pytest接口自动化之日志管理模块loguru简介

    欢迎来到本篇文章,本文主要介绍Python+pytest接口自动化测试中的一个强大的日志管理模块——loguru。 什么是loguru? loguru是一款Python的日志管理模块,具有以下特点: 易于使用,方便快捷地记录Python日志; 提供多种配置方式,满足不同用户的需求; 具有强大的过滤和格式化功能; 支持多进程、多线程、异步I/O等场景下的日志记…

    python 2023年6月6日
    00
  • python实现共轭梯度法

    这里为大家介绍下 Python 实现共轭梯度法的完整攻略。 共轭梯度法概述 共轭梯度法是一种求解线性方程组的迭代方法,它的优点是收敛速度较快,特别是对于大规模稀疏矩阵的求解。共轭梯度法的原理是基于最小化二次型的思想,通过不断迭代改进搜索方向,以达到快速收敛的目的。 在实现共轭梯度法之前,需要先定义一下模型和目标函数。 定义模型 定义模型时,需要定义一个二次型…

    python 2023年6月5日
    00
  • 学习Python爬虫前必掌握知识点

    学习Python爬虫前必掌握知识点,包括以下几个方面: 1. Python基础知识 Python是一门高级编程语言,支持多种编程范式。在学习Python爬虫前,需要掌握Python的基础语法,包括但不限于: 变量的定义与使用 数据类型(数字、字符串、列表、字典、元组等) 条件语句与控制结构(if-else、for、while等) 函数的定义与调用 模块的导入…

    python 2023年5月14日
    00
  • python多线程实现同时执行两个while循环的操作

    实现同时执行两个while循环的操作可以使用python的多线程来实现。需要创建两个线程分别执行两个while循环。 下面是实现多线程的示例代码: import threading def thread_1(): while True: # 线程1的循环内容 print("Thread 1 is running") def thread_…

    python 2023年5月18日
    00
合作推广
合作推广
分享本页
返回顶部