Pytorch关于Dataset 的数据处理

PyTorch关于Dataset的数据处理

在PyTorch中,Dataset是一个抽象类,用于表示数据集。它提供了一种统一的方式来处理数据,使得我们可以轻松地加载和处理数据。在本文中,我们将详细介绍如何使用PyTorch中的Dataset类来处理数据,并提供两个示例来说明其用法。

1. 创建自定义Dataset

要创建自定义Dataset,需要继承PyTorch中的Dataset类,并实现以下两个方法:

  • __len__:返回数据集的大小。
  • __getitem__:返回给定索引的数据样本。

以下是一个示例,展示如何创建一个自定义Dataset:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index][0]
        y = self.data[index][1]
        return torch.tensor(x), torch.tensor(y)

在上面的示例中,我们创建了一个名为CustomDataset的自定义Dataset。它接受一个名为data的参数,该参数是一个列表,其中每个元素都是一个包含输入和输出的元组。在__len__方法中,我们返回数据集的大小。在__getitem__方法中,我们返回给定索引的数据样本,其中输入和输出都被转换为PyTorch张量。

2. 使用自定义Dataset

要使用自定义Dataset,需要将其传递给PyTorch中的DataLoader类。DataLoader类可以自动将数据集分成小批量,并在训练期间加载数据。以下是一个示例,展示如何使用自定义Dataset:

from torch.utils.data import DataLoader

# 创建自定义数据集
data = [(1, 2), (3, 4), (5, 6), (7, 8)]
dataset = CustomDataset(data)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历数据加载器
for batch in dataloader:
    x, y = batch
    print(x, y)

在上面的示例中,我们首先创建了一个自定义数据集dataset,它包含四个元素,每个元素都是一个包含输入和输出的元组。然后,我们使用DataLoader类创建了一个数据加载器dataloader,它将数据集分成大小为2的小批量,并在训练期间加载数据。最后,我们遍历数据加载器,并打印每个小批量的输入和输出。

3. 示例1:使用PyTorch中的Dataset类加载MNIST数据集

以下是一个示例,展示如何使用PyTorch中的Dataset类加载MNIST数据集:

import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# 加载MNIST数据集
train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
test_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)

# 创建数据加载器
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 遍历数据加载器
for batch in train_dataloader:
    x, y = batch
    print(x.shape, y.shape)
    break

在上面的示例中,我们首先使用MNIST类加载MNIST数据集,并将其转换为PyTorch张量。然后,我们使用DataLoader类创建了两个数据加载器,一个用于训练数据,另一个用于测试数据。最后,我们遍历训练数据加载器,并打印第一个小批量的输入和输出。

4. 示例2:使用PyTorch中的Dataset类加载自定义图像数据集

以下是一个示例,展示如何使用PyTorch中的Dataset类加载自定义图像数据集:

import os
import torch
from PIL import Image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(root_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.images[index])
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(0)

# 创建自定义图像数据集
dataset = CustomImageDataset('data/', transform=transforms.ToTensor())

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 遍历数据加载器
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape)
    break

在上面的示例中,我们创建了一个名为CustomImageDataset的自定义图像数据集。它接受一个名为root_dir的参数,该参数是包含图像文件的目录。在__len__方法中,我们返回数据集的大小。在__getitem__方法中,我们加载给定索引的图像,并将其转换为PyTorch张量。最后,我们使用DataLoader类创建了一个数据加载器,并遍历它以打印第一个小批量的输入和输出。

5. 总结

在PyTorch中,Dataset是一个抽象类,用于表示数据集。它提供了一种统一的方式来处理数据,使得我们可以轻松地加载和处理数据。在本文中,我们详细介绍了如何使用PyTorch中的Dataset类来处理数据,并提供了两个示例来说明其用法。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch关于Dataset 的数据处理 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • weight_decay in Pytorch

    在训练人脸属性网络时,发现在优化器里增加weight_decay=1e-4反而使准确率下降 pytorch论坛里说是因为pytorch对BN层的系数也进行了weight_decay,导致BN层的系数趋近于0,使得BN的结果毫无意义甚至错误 当然也有办法不对BN层进行weight_decay, 详见pytorch forums讨论1pytorch forums…

    PyTorch 2023年4月8日
    00
  • python调用pytorch实现deeplabv3+图像语义分割——以分割动漫人物为例

    图像语义分割就是把图像分成若干个特定的、具有独特性质的区域并提出感兴趣目标的技术和过程。本文提供了一个可进行自定义数据集训练基于pytorch的deeplabv3+图像分割模型的方法,训练了一个动漫人物分割模型,不过数据集较小,仅供学习使用 程序输入:动漫图片 程序输出:分割好的动漫人物图片 目录 程序简介 程序/数据集下载 数据集准备 训练步骤 预测演示步…

    2023年4月8日
    00
  • PytorchMNIST(使用Pytorch进行MNIST字符集识别任务)

      都说MNIST相当于机器学习界的Hello World。最近加入实验室,导师给我们安排了一个任务,但是我才刚刚入门呐!!没办法,只能从最基本的学起。   Pytorch是一套开源的深度学习张量库。或者我倾向于把它当成一个独立的深度学习框架。为了写这么一个”Hello World”。查阅了不少资料,也踩了不少坑。不过同时也学习了不少东西,下面我把我的代码记…

    2023年4月7日
    00
  • Pytorch 扩展Tensor维度、压缩Tensor维度

        相信刚接触Pytorch的宝宝们,会遇到这样一个问题,输入的数据维度和实验需要维度不一致,输入的可能是2维数据或3维数据,实验需要用到3维或4维数据,那么我们需要扩展这个维度。其实特别简单,只要对数据加一个扩展维度方法就可以了。 1.1 torch.unsqueeze(self: Tensor, dim: _int)   torch.unsqueez…

    2023年4月8日
    00
  • pytorch1.0实现RNN-LSTM for Classification

    import torch from torch import nn import torchvision.datasets as dsets import torchvision.transforms as transforms import matplotlib.pyplot as plt # 超参数 # Hyper Parameters # 训练整批数据…

    PyTorch 2023年4月6日
    00
  • pytorch 设置种子

    目的: 固定住训练的顺序等变量,使实验可复现 def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = Tr…

    PyTorch 2023年4月6日
    00
  • pytorch 读取和保存模型参数

    只保存参数信息 加载 checkpoint = torch.load(opt.resume) model.load_state_dict(checkpoint) 保存 torch.save(self.state_dict(),file_path) 这而只保存了参数信息,读取时也只有参数信息,模型结构需要手动编写 保存整个模型 保存torch.save(the…

    PyTorch 2023年4月8日
    00
  • Yolov5训练意外中断后如何接续训练详解

    当YOLOv5的训练意外中断时,我们可以通过接续训练来恢复训练过程,以便继续训练模型。下面是接续训练的详细步骤: 首先,我们需要保存当前训练的状态。我们可以使用PyTorch提供的torch.save()函数将模型的参数和优化器的状态保存到文件中。例如,我们可以使用以下代码将模型的参数和优化器的状态保存到文件checkpoint.pth中: torch.sa…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部