Pytorch在dataloader类中设置shuffle的随机数种子方式

yizhihongxing

PyTorch的数据集DataLoader是十分常用的数据加载和预处理工具,通过将数据传输到GPU并在深度学习过程中进行抽样,而它的shuffle参数可以打乱数据集的顺序,使损失函数更加随机。但同时,我们也可能需要控制随机的行为,以获得可再现的实验结果。下面是两种设置shuffle随机数种子的方法:

方法一:使用torch.utils.data.DataLoader类的WorkerInitFn参数

我们可以使用WorkerInitFn来传递一个函数,来控制数据集加载器的每个工作进程的初始化过程。以下是一个示例的代码段:

import random
import torch
from torch.utils.data import DataLoader

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# 设置随机数种子,获得可再现的实验结果
def worker_init_fn(worker_id):
    random.seed(worker_id)

dataset = MyDataset()

dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
                        num_workers=2, worker_init_fn=worker_init_fn)

for i, batch in enumerate(dataloader):
    print(batch)

在这个例子中,我们将worker_init_fn设置为一个函数,该函数会在每个工作进程初始化时调用,并使用其工作进程ID作为随机数种子,以控制每个进程数据加载顺序的随机性。这里,使用random.seed来设置随机种子。

shuffle参数设置为True时,DataLoader会在每个工作进程中打乱数据,并将其放回主进程。 在每个工作进程初始化时,随机数种子被设置成与工作进程ID有关的值。这样,每个进程在打乱数据时使用不同的随机数种子,以确保打乱后的顺序是独立的,而不是互相关联的。

方法二:使用torch.Generator

我们也可以使用PyTorch的Random模块来设置DataLoader类中的随机数种子。具体做法是将shuffle设置为True,然后使用PyTorch的工具包生成随机数种子。以下是一个示例的代码段:

import torch
import torch.utils.data as data_utils

torch.manual_seed(42)  # 设置随机数种子

# 创建数据集
data = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
target = torch.Tensor([1, 1, 0, 0])
dataset = data_utils.TensorDataset(data, target)

# 创建DataLoader类
batch_size = 2
dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, generator=torch.Generator().manual_seed(42))

# 打印出来
for batch_idx, (data, target) in enumerate(dataloader):
    print("Batch index {}, data shape {}, target shape {}".format(batch_idx, data.shape, target.shape))

此例中,我们将DataLoader类的generator参数设置为为torch.Generator().manual_seed(42)shuffle参数设置为True,并使用torch.manual_seed(42)方法设置随机数种子来控制打乱数据的顺序。在这个例子中,generatortorch.Generator对象,我们设置它的随机数种子为42。这样每一次使用DataLoader类,我们都能得到相同的打乱数据顺序。

这两种设置shuffle随机数种子的方式,在控制随机性方面有其各自的优点和适用场景,读者可以根据情况选择更加适合自身需求的方法。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch在dataloader类中设置shuffle的随机数种子方式 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • python中tkinter复选框使用操作

    接下来我将为你详细讲解“Python中Tkinter复选框使用操作”的完整攻略,以及两个示例说明。 什么是Tkinter复选框 复选框(Checkbox)是一种用户界面控件,通常用于表示可以选择或取消选择的选项。在Tkinter中,复选框使用Checkbutton控件实现。 如何创建复选框 使用Tkinter创建复选框非常简单,只需要调用Checkbutto…

    python 2023年6月13日
    00
  • python实现的登录和操作开心网脚本分享

    开心网是一个中国社交网络平台,本文将详细讲解如何使用Python实现登录和操作开心网的完整攻略,包括使用requests库发送HTTP请求和处理HTTP响应、使用BeautifulSoup库解析HTML文档、使用selenium库模拟浏览器操作等。 登录开心网 在Python中,我们可以使用requests库发送HTTP POST请求模拟登录开心网。以下是一…

    python 2023年5月15日
    00
  • python集合能干吗

    Python集合是一种无序、不重复的数据类型,可以用于存储各种类型的值,例如数字、字符串和元组等。集合非常适合用于数据去重、判断成员关系、求交集和并集等场景。 数据去重 集合最常用的功能之一就是去重。我们可以将一组数据放到一个集合中,自动去除重复的元素。使用方法如下: # 创建一个列表,包含重复元素 nums = [1, 2, 3, 2, 4, 5, 1] …

    python 2023年5月13日
    00
  • Python科学画图代码分享

    Python科学画图代码分享 前言 Python是一门优秀的编程语言,尤其在科学计算领域拥有广泛的应用。Python科学画图模块也越来越受到关注。通过本篇文章,我们将学习如何用Python科学画图模块来进行数据可视化,并分享一些常用的代码。 本篇文章将重点介绍以下三个主要的Python科学画图模块: Matplotlib:Python中最常用的科学画图模块之…

    python 2023年5月19日
    00
  • Python正则表达式re模块详解(建议收藏!)

    Python正则表达式re模块详解 正则表达式是一种用于描述字符串模式的语言,可以用于匹配、查找、替换和割字符串。Python中的re模块提供了正则表达式支持,方便进行字符串的处理。本文将详细讲解Python正则表达式的使用,包括正则表达式语法、re模块的常用函数以及两个常用匹配实例。 正则表达式语法 正则表达式由一些特殊字符和普通字符组成,用于字符串模式匹…

    python 2023年5月14日
    00
  • 使用python scrapy爬取天气并导出csv文件

    下面是使用Python Scrapy爬取天气数据并导出CSV文件的完整攻略,包括以下步骤: 第一步:安装Scrapy Scrapy是一个Python爬虫框架,可以大大简化爬取网页的过程。安装Scrapy的方法是打开命令行窗口(或者终端),输入以下命令: pip install scrapy 第二步:创建一个Scrapy项目 在命令行窗口中,输入以下命令: s…

    python 2023年6月3日
    00
  • Python中super()函数简介及用法分享

    Python中super()函数简介及用法分享 简介 在Python中,如果需要在子类中调用父类的方法或属性,可以使用super()函数。super()函数返回父类实例的对象,通过它可以调用父类的方法和属性。 super()函数有两个参数,第一个参数是子类类型,第二个参数是对象(self),可以省略。 用法 下面是super()函数的一些常用用法: 1. 调…

    python 2023年6月5日
    00
  • 一文详解Python中哈希表的使用

    一文详解Python中哈希表的使用 什么是哈希表 哈希表也称为散列表,是一种用于存储键值对的数据结构。在哈希表中,每个键都与一个特定的值相关联。哈希表使用哈希函数将键映射到存储桶中,以便快速访问键对应的值。 Python中的哈希表实现在内部使用了散列表。Python的“字典”数据类型就是基于哈希表实现的,也称为dict。字典的键必须是不可变类型,例如数字、字…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部