pytorch 数据处理:定义自己的数据集合实例

请看下面的详细讲解。

PyTorch数据处理:定义自己的数据集合实例

在进行深度学习任务时,数据预处理是非常重要的一步,而 PyTorch 中,数据预处理也是必不可少的一环。在大多数情况下,我们需要使用已有的数据集,如官方提供的 MNIST、CIFAR10 等数据集;但有时我们也需要自己定义数据集,例如从图片数据集中自定义一个猫狗二分类的数据集。自定义数据集的过程其实并不困难,下面我们就来详细讲解一下如何定义 PyTorch 自己的数据集合实例。

定义自己的数据集合实例包含以下步骤:

  1. 构建包含数据和标签的数据集类。

数据集类需要继承 torch.utils.data.Dataset,并实现以下两个方法:

- `__len__(self)` 返回数据集的长度。
- `__getitem__(self, index)` 给定一个索引 index,返回对应的数据和标签。
  1. 对数据进行预处理。

  2. 创建数据加载器 DataLoader。

接下来我们将对上面三个步骤进行详细讲解。

步骤一:构建数据集类

让我们从构建一个二分类数据集为例。我们有一些猫和狗的图片,我们需要将他们分别标记为1和0,并存储在一个字典中。我们使用 PyTorch 中的 ImageFolder 类来加载数据。

import os
import torch.utils.data as data
from torchvision import datasets, transforms

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def make_dataset(dir):
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                path = os.path.join(root, fname)
                item = (path, int(fname.split('.')[0] == 'cat'))
                images.append(item)
    return images

class CatDogDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.images = make_dataset(self.root)

    def __getitem__(self, index):
        path, target = self.images[index]
        img = default_loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.images)

在上述代码中,我们首先定义了一个 IMG_EXTENSIONS 列表,其中包含了我们所支持的图片格式的后缀名。接着定义了一个函数 is_image_file 用于判断某个文件是否是图片;再定义了 make_dataset 函数用于生成一个元组列表,其中包含了图片的路径和标签,这个标签是根据文件名来判断的,如果文件名是 'cat.XXX.jpg' 的格式,则标签为1,否则标签为0。最后,我们定义了 CatDogDataset 类,继承了 torch.utils.data.Dataset;定义了初始化函数 __init__,用于初始化数据集路径和数据增强方法;定义了 __getitem__ 方法,用于返回指定索引的数据和标签;以及 __len__ 方法返回数据集长度。

步骤二:对数据进行预处理

为了让模型能够更好地利用数据,我们通常需要对数据进行预处理。在 PyTorch 中,我们可以使用 torchvision.transforms 来进行数据预处理,其中包含了很多有用的函数,例如对图像进行裁剪、旋转、缩放等增强操作。下面是一个简单的数据预处理示例:

transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

在上述代码中,我们通过 transforms.Compose 将多个数据预处理操作串联起来,最后将它们应用在我们定义的数据集上。在这个例子中,我们首先使用 transforms.CenterCrop 将图片中心裁剪为 224 × 224,再使用 transforms.ToTensor 将图片(0, 255)范围内的像素转换到(0, 1)范围内的张量,最后使用 transforms.Normalize 对图片进行归一化操作。对于这个归一化操作,我们需要使用图像数据集上的均值和方差。在这里,我们使用 ImageNet 数据集上的均值和方差,这是一个广泛使用的标准值。

步骤三:创建数据加载器 DataLoader

在数据预处理之后,我们需要将定义好的数据集类加载到 PyTorch 的 DataLoader 类当中,并且设定好每次取多少张图片进行训练或测试。DataLoader 可以自动对数据集进行批次切分,这样可以提高模型训练时的效率。

from torch.utils.data import DataLoader

batch_size = 64

trainset = CatDogDataset('path/to/train/data', transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

在上述代码中,我们使用了 torch.utils.data.DataLoader 类来生成数据加载器,其中包含了我们自己定义的 CatDogDataset 类。我们设定了 batch size 为 64,shuffle 参数为 True 表示每个 epoch 随机对数据集进行洗牌,num_workers 表示使用几个进程来加载数据。该数据加载器可以用于训练过程中,例如:

for epoch in range(10): # 循环10次
    for i, (inputs, labels) in enumerate(trainloader, 0): # 每次取一个批次
        print(inputs.shape) # 显示当前批次数据的形状

在上述代码中,我们循环了 10 次,并且每次从训练数据中取出一个 batch size 大小的数据来训练模型。每次取出的数据使用 inputslabels 两个变量进行存储。

示例一:在CIFAR10数据集上创建训练集

以「CIFAR10数据集」为例,首先我们需要安装「torchvision」包,并引入以下依赖:

import torch
import torchvision
import torchvision.transforms as transforms

接着,我们定义数据预处理操作,这里需要注意的一点是,我们不对「CIFAR10数据集」进行归一化,因为这个数据集本身已经做过归一化处理。

transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor()])

transform_val = transforms.Compose([transforms.ToTensor()])

最后,我们以 4 作为 batch size,创建数据加载器。并使用 trainset.train_datatrainset.train_labels,这是 CIFAR10 自带的数据集。训练数据集我们随机抽取 80% 作为训练集,20% 作为测试集。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)

n_train = len(trainset)
train_idx, val_idx = torch.utils.data.random_split(range(n_train), [int(0.8*n_train), int(0.2*n_train)])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
valloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(val_idx))
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

示例二:通过加密文件夹生成训练数据集

在一些实际场景中,我们的数据集可能需要加密以保护数据隐私。我们可以先将数据集文件加密,在使用 PyTorch 加载数据集时进行解密。我们以放置有已加密文件的文件夹作为例子。

import torch
import torch.utils.data as data
from PIL import Image
from Crypto.Cipher import AES

class EncryptedDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

    def __getitem__(self, index):
        path = self.images[index]
        img = Image.open(path)
        img = self.decrypt(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(0) # 返回第二个值用于符合 __getitem__ 的规则

    def __len__(self):
        return len(self.images)

    def decrypt(self, image):
        block_size = 16
        aes = AES.new('your_secret_key', AES.MODE_CBC, 'your_secret_iv')
        encrypted_data = image.tobytes()
        decrypted_data = aes.decrypt(encrypted_data)
        if b'\0' in decrypted_data:
            decrypted_data = decrypted_data.rstrip(b'\0')
        img = Image.frombytes('RGB', image.size, decrypted_data)
        return img

在上述代码中,我们定义了一个 EncryptedDataset 类,用于加载加密的图片文件,这个类需要实现 __init____getitem____len__ 方法。在 __getitem__ 方法中,我们首先使用 Pillow 库的 Image.open 方法加载加密的图片数据,然后使用我们的密钥和向量对图像数据进行解密处理,然后再进行如上述步骤一和步骤二中的数据预处理操作。

在实例化该类时,我们需要传入带有加密数据的图片文件夹路径。

root = 'path/to/your/encrypted/data'
dataset = EncryptedDataset(root, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

最后我们使用 DataLoader 载入我们自定义的数据集,在训练时直接使用该数据加载器即可。

总之,根据需要,我们可以为我们的深度学习任务定义自己的数据集合实例,在这个自定义的数据集上使用批量解压、大小重置、颜色处理等方法来预处理数据。随后我们可以将预处理好的数据集与 PyTorch 的预定义数据集传入到数据加载器中,实现模型训练。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据处理:定义自己的数据集合实例 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • Python之日期与时间处理模块(date和datetime)

    Python之日期与时间处理模块(date和datetime) 在Python中日期和时间处理非常方便,Python标准库提供了两个重要的模块date和datetime。本篇文章将详细介绍如何使用这两个模块,并通过示例展示具体的使用方法。 date模块 date模块提供了一个date类,该类表示一个简单的日期对象,包含年月日的信息。 创建日期对象 使用dat…

    python 2023年5月14日
    00
  • Vs Code中8个好用的python 扩展插件

    标题:Vs Code中8个好用的Python扩展插件 首先,为了更好的使用Vs Code编写Python代码,可以安装以下8个好用的Python扩展插件。 1. Python Python是一款由Microsoft官方提供的Vs Code扩展插件,可使Vs Code更好地解析Python代码,并可做到代码智能提示、语法高亮、代码补全、代码格式化等。安装方法为…

    python 2023年5月19日
    00
  • 详解Python WSGI处理抛出异常

    Python WSGI是Python Web Server Gateway Interface的缩写,它定义了应用程序和Web服务器之间的通信接口。WSGI应用程序运行在Web服务器和Python解释器之间,通过环境变量来传递请求和响应数据。在WSGI应用程序的开发中,处理抛出异常是非常重要的一步,因为它可以有效地保证应用程序的稳定性和安全性。 以下是Pyt…

    python-answer 2023年3月25日
    00
  • Python numpy.transpose使用详解

    非常感谢您对于Python numpy.transpose使用的关注。下面是详细讲解的攻略。 Python numpy.transpose使用详解 概述 numpy.transpose() 函数用于对换数组的维度。对于一维数组,它就是将原数组翻转。对于二维数组,就是执行矩阵转置的操作。更高维度的数组操作,是基于这两个维度的操作,多次使用transpose()…

    python 2023年5月18日
    00
  • Python urlopen()和urlretrieve()用法解析

    Python urlopen() 和 urlretrieve() 用法解析 在Python中,我们可以使用urllib库中的urlopen()和urlretrieve()函数来处理URL。这两个函数都可以用于打开URL并读取其内容,但它们的用法略有不同。本文将详细介绍这两个函数的用法,并提供两个示例。 urlopen()函数 urlopen()函数是Pyth…

    python 2023年5月15日
    00
  • python3中rsa加密算法详情

    下面就来详细讲解 Python3 中 RSA 加密算法的完整攻略。 什么是 RSA 加密算法? RSA 是一种非对称加密算法,即加密与解密使用的是不同的密钥。 RSA 加密算法的原理是:使用两个大素数 p 和 q 计算出 N = p * q,然后选取两个数 e 和 d,使得 e * d ≡ 1 (mod (p-1) * (q-1)),e 称为公钥,d 称为私…

    python 2023年5月20日
    00
  • Python学习之文件的创建与写入详解

    Python学习之文件的创建与写入详解 在Python中,文件是信息存储的一种重要方式。Python中的文件操作非常简单,可以轻松地创建、读取和修改文件。本文介绍如何在Python中创建和写入文件。 文件的创建 要在Python中创建一个新文件,可以使用内置的open()函数。 open()函数的语法如下: file = open(filename, mod…

    python 2023年6月2日
    00
  • 如何在 Python 配置的 atom 中修复 linter-Flake8

    【问题标题】:How to Fix linter-Flake8 in atom for Python Configuration如何在 Python 配置的 atom 中修复 linter-Flake8 【发布时间】:2023-04-07 12:23:01 【问题描述】: 简介 在我将atom 安装到我的debian-ParrotOS 中用于编码python…

    Python开发 2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部