Pytorch数据读取之Dataset和DataLoader知识总结

当使用PyTorch进行深度学习时,我们需要将数据转化为张量并通过模型传递,但如何将原始数据转化为张量呢?这就涉及到PyTorch数据读取中的Dataset和DataLoader两个重要的概念。

Dataset

PyTorch中的Dataset是一个抽象类,代表数据集,它可以定义自己的数据形式、读取数据的方式、增加额外的预处理步骤等。我们只需继承该类,并实现__getitem__和__len__两个魔法方法即可。

示例1

以下是一个简单的示例,展示如何创建一个自定义的Dataset:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.data[index]  # 假设data是一个存储数据和标签的列表
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.data)

这个示例中,我们创建了一个名为MyDataset的子类,它包含2个参数:data和transform。MyDataset的__getitem__方法返回数据和标签,并对数据应用了一个可选的图像变换(transform)。__len__方法返回数据集的大小。

示例2

接下来是一个更完整的示例,展示如何将PyTorch自带的CIFAR-10数据集转换成自己的Dataset:

import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset

class MyCIFAR10(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.data = datasets.CIFAR10(root, train=train, transform=transform, target_transform=target_transform)
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.data[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.data)

这个示例中,我们从torchvision中导入CIFAR-10数据集,并使用transforms模块定义一个可选的数据预处理操作。然后我们创建一个名为MyCIFAR10的子类,它继承了Dataset并包含4个参数:root、train、transform和target_transform。我们在__init__方法中创建了一个CIFAR10对象,它将使用传递的参数初始化。继承自Dataset的__getitem__和__len__方法分别返回数据和标签以及数据集大小。

DataLoader

DataLoader是PyTorch提供的一个数据读取器,它可以将Dataset中的数据转化为迭代器,并提供一些有用的功能,如自动批量读取、多进程数据加载、随机打乱数据等。下面是一些常用的DataLoader参数:

  • dataset:数据集。
  • batch_size:每次返回的数据批量大小。
  • shuffle:是否随机打乱数据。
  • num_workers:读取数据的线程数。
  • drop_last:当数据集大小不能整除batch_size时,是否丢弃最后一批数据(默认为False,即不丢弃)。

示例1

以下是一个简单的示例,展示如何使用DataLoader读取数据:

import torch
from torch.utils.data import DataLoader

dataset = MyDataset(data)  # 创建数据集
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)  # 创建数据读取器

for i, batch in enumerate(dataloader):
    x, y = batch
    print('批次', i, ':', x, y)

这个示例中,我们首先创建了一个名为dataset的MyDataset对象。然后我们使用DataLoader创建了一个名为dataloader的数据读取器,它将读取dataset中的数据,并返回batch_size大小的数据批次。我们使用for循环迭代dataloader,并逐批次获取数据。由于我们设置了shuffle参数为True,每个批次的数据都将是随机的。

示例2

下面是一个更完整的示例,展示如何将PyTorch自带的CIFAR-10数据集转换成可用于训练的DataLoader对象:

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

transform_train = transforms.Compose([  # 定义数据预处理操作
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = MyCIFAR10(root='./data', train=True, transform=transform_train)  # 创建训练集
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)  # 创建训练集读取器

for i, batch in enumerate(trainloader):
    x, y = batch
    # 训练代码...

这个示例中,我们使用了transforms模块定义了一系列的数据预处理操作。然后我们创建了一个名为trainset的MyCIFAR10对象,它将使用这些预处理操作,并将CIFAR-10训练数据集中的数据和标签作为参数进行初始化。接着我们使用DataLoader创建了一个名为trainloader的训练集读取器,它将随机读取128张图片作为一个批次,并使用2个线程并行读取数据。

总之,Dataset和DataLoader是PyTorch中非常重要的数据读取相关的类,通过它们我们可以有效地读取、批次化和预处理数据,是深度学习中必不可少的组件。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch数据读取之Dataset和DataLoader知识总结 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • 如何在Pandas中执行SUMIF函数

    在Pandas中执行SUMIF函数,需要使用groupby方法结合agg方法,具体步骤如下: 使用groupby方法按指定列分组 使用agg方法,指定要进行聚合的函数,如sum、count、mean等。 对于需要进行条件筛选的列,使用lambda表达式指定条件 以下是一个示例代码,假设我们有一个sales表,其中包含商品名称、销售数量和销售价格三列数据: i…

    python-answer 2023年3月27日
    00
  • 在Pandas中执行交叉连接的Python程序

    交叉连接在Pandas中的一般称呼是笛卡尔积。笛卡尔积是指将两个数据集的每个元素组合成一个新的数据集。Pandas提供了一个函数,可以快速且简单地进行笛卡尔积操作:pandas.DataFrame.merge()。 下面演示一下如何在Pandas中执行交叉连接的Python程序: 首先,我们需要导入 Pandas 包。接着,我们需要创建两个数据集 df1 和…

    python-answer 2023年3月27日
    00
  • 简单介绍Python中的JSON模块

    当我们想将数据以一种易于读取和存储的方式进行传输时,我们通常会使用JSON数据格式。Python中的JSON模块为我们提供了便捷的方法来操纵JSON数据。 什么是JSON模块 JSON模块是提供了编码和解码JSON数据的Python标准库。该模块提供了四个方法:dump(), dumps(), load()和loads()。 dump(obj, fp, *,…

    python 2023年5月14日
    00
  • Pandas的Apply函数具体使用

    关于Pandas的Apply函数的具体使用,我将为您提供一份完整攻略。下面将会分为以下几个部分: 什么是Pandas的Apply函数? Apply函数的基础用法 Apply函数的高级用法 两条示例说明 1.什么是Pandas的Apply函数? Pandas的apply函数是一种能够作用于Pandas数据的灵活且高性能的函数。此函数可以用于许多相似的目的。比如…

    python 2023年5月14日
    00
  • 如何从Pandas数据框架中创建Boxplot

    当我们想比较不同分组或分类之间的数据分布时,Boxplot是一个非常有效的数据可视化方式。在Python中,我们可以使用Pandas数据框架和Matplotlib库来轻松创建Boxplot图表。 下面是如何从Pandas数据框架中创建Boxplot的步骤: 1. 导入相关库并读取数据 首先,我们需要导入所需的Python库——Pandas和Matplotli…

    python-answer 2023年3月27日
    00
  • Python pandas DataFrame操作的实现代码

    Python pandas DataFrame 操作的实现代码攻略 为了进行Python pandas DataFrame操作,首先需要导入pandas模块。常用的pandas模块操作有以下几种: 创建DataFrame:在pandas模块中,可以通过list、dict和CSV文件创建DataFrame。 读取CSV文件并创建DataFrame:pandas…

    python 2023年5月14日
    00
  • 详细介绍pandas的DataFrame的append方法使用

    当我们在使用 pandas 来处理数据时,DataFrame 是我们使用最频繁的数据结构之一。DataFrame 中的数据以二维表格的形式出现,其中每行代表一个数据样本,每列代表一个特征或变量。 在 pandas 的 DataFrame 中,我们可以使用 append 方法来合并两个 DataFrame。这个方法返回的是一个新的 DataFrame,原始的两…

    python 2023年5月14日
    00
  • 如何在Pandas中创建一个空的DataFrame并向其添加行和列

    在 Pandas 中创建一个空的 DataFrame 并向其添加行和列涉及以下步骤: 导入 Pandas 模块: import pandas as pd 创建空的 DataFrame: df = pd.DataFrame() 添加列到 DataFrame,使用以下语法: df[‘column_name’] = None 其中,column_name 是你想要…

    python-answer 2023年3月27日
    00
合作推广
合作推广
分享本页
返回顶部