pytorch 数据处理:定义自己的数据集合实例

yizhihongxing

请看下面的详细讲解。

PyTorch数据处理:定义自己的数据集合实例

在进行深度学习任务时,数据预处理是非常重要的一步,而 PyTorch 中,数据预处理也是必不可少的一环。在大多数情况下,我们需要使用已有的数据集,如官方提供的 MNIST、CIFAR10 等数据集;但有时我们也需要自己定义数据集,例如从图片数据集中自定义一个猫狗二分类的数据集。自定义数据集的过程其实并不困难,下面我们就来详细讲解一下如何定义 PyTorch 自己的数据集合实例。

定义自己的数据集合实例包含以下步骤:

  1. 构建包含数据和标签的数据集类。

数据集类需要继承 torch.utils.data.Dataset,并实现以下两个方法:

- `__len__(self)` 返回数据集的长度。
- `__getitem__(self, index)` 给定一个索引 index,返回对应的数据和标签。
  1. 对数据进行预处理。

  2. 创建数据加载器 DataLoader。

接下来我们将对上面三个步骤进行详细讲解。

步骤一:构建数据集类

让我们从构建一个二分类数据集为例。我们有一些猫和狗的图片,我们需要将他们分别标记为1和0,并存储在一个字典中。我们使用 PyTorch 中的 ImageFolder 类来加载数据。

import os
import torch.utils.data as data
from torchvision import datasets, transforms

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def make_dataset(dir):
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                path = os.path.join(root, fname)
                item = (path, int(fname.split('.')[0] == 'cat'))
                images.append(item)
    return images

class CatDogDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.images = make_dataset(self.root)

    def __getitem__(self, index):
        path, target = self.images[index]
        img = default_loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.images)

在上述代码中,我们首先定义了一个 IMG_EXTENSIONS 列表,其中包含了我们所支持的图片格式的后缀名。接着定义了一个函数 is_image_file 用于判断某个文件是否是图片;再定义了 make_dataset 函数用于生成一个元组列表,其中包含了图片的路径和标签,这个标签是根据文件名来判断的,如果文件名是 'cat.XXX.jpg' 的格式,则标签为1,否则标签为0。最后,我们定义了 CatDogDataset 类,继承了 torch.utils.data.Dataset;定义了初始化函数 __init__,用于初始化数据集路径和数据增强方法;定义了 __getitem__ 方法,用于返回指定索引的数据和标签;以及 __len__ 方法返回数据集长度。

步骤二:对数据进行预处理

为了让模型能够更好地利用数据,我们通常需要对数据进行预处理。在 PyTorch 中,我们可以使用 torchvision.transforms 来进行数据预处理,其中包含了很多有用的函数,例如对图像进行裁剪、旋转、缩放等增强操作。下面是一个简单的数据预处理示例:

transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

在上述代码中,我们通过 transforms.Compose 将多个数据预处理操作串联起来,最后将它们应用在我们定义的数据集上。在这个例子中,我们首先使用 transforms.CenterCrop 将图片中心裁剪为 224 × 224,再使用 transforms.ToTensor 将图片(0, 255)范围内的像素转换到(0, 1)范围内的张量,最后使用 transforms.Normalize 对图片进行归一化操作。对于这个归一化操作,我们需要使用图像数据集上的均值和方差。在这里,我们使用 ImageNet 数据集上的均值和方差,这是一个广泛使用的标准值。

步骤三:创建数据加载器 DataLoader

在数据预处理之后,我们需要将定义好的数据集类加载到 PyTorch 的 DataLoader 类当中,并且设定好每次取多少张图片进行训练或测试。DataLoader 可以自动对数据集进行批次切分,这样可以提高模型训练时的效率。

from torch.utils.data import DataLoader

batch_size = 64

trainset = CatDogDataset('path/to/train/data', transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

在上述代码中,我们使用了 torch.utils.data.DataLoader 类来生成数据加载器,其中包含了我们自己定义的 CatDogDataset 类。我们设定了 batch size 为 64,shuffle 参数为 True 表示每个 epoch 随机对数据集进行洗牌,num_workers 表示使用几个进程来加载数据。该数据加载器可以用于训练过程中,例如:

for epoch in range(10): # 循环10次
    for i, (inputs, labels) in enumerate(trainloader, 0): # 每次取一个批次
        print(inputs.shape) # 显示当前批次数据的形状

在上述代码中,我们循环了 10 次,并且每次从训练数据中取出一个 batch size 大小的数据来训练模型。每次取出的数据使用 inputslabels 两个变量进行存储。

示例一:在CIFAR10数据集上创建训练集

以「CIFAR10数据集」为例,首先我们需要安装「torchvision」包,并引入以下依赖:

import torch
import torchvision
import torchvision.transforms as transforms

接着,我们定义数据预处理操作,这里需要注意的一点是,我们不对「CIFAR10数据集」进行归一化,因为这个数据集本身已经做过归一化处理。

transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor()])

transform_val = transforms.Compose([transforms.ToTensor()])

最后,我们以 4 作为 batch size,创建数据加载器。并使用 trainset.train_datatrainset.train_labels,这是 CIFAR10 自带的数据集。训练数据集我们随机抽取 80% 作为训练集,20% 作为测试集。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)

n_train = len(trainset)
train_idx, val_idx = torch.utils.data.random_split(range(n_train), [int(0.8*n_train), int(0.2*n_train)])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
valloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(val_idx))
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

示例二:通过加密文件夹生成训练数据集

在一些实际场景中,我们的数据集可能需要加密以保护数据隐私。我们可以先将数据集文件加密,在使用 PyTorch 加载数据集时进行解密。我们以放置有已加密文件的文件夹作为例子。

import torch
import torch.utils.data as data
from PIL import Image
from Crypto.Cipher import AES

class EncryptedDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

    def __getitem__(self, index):
        path = self.images[index]
        img = Image.open(path)
        img = self.decrypt(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(0) # 返回第二个值用于符合 __getitem__ 的规则

    def __len__(self):
        return len(self.images)

    def decrypt(self, image):
        block_size = 16
        aes = AES.new('your_secret_key', AES.MODE_CBC, 'your_secret_iv')
        encrypted_data = image.tobytes()
        decrypted_data = aes.decrypt(encrypted_data)
        if b'\0' in decrypted_data:
            decrypted_data = decrypted_data.rstrip(b'\0')
        img = Image.frombytes('RGB', image.size, decrypted_data)
        return img

在上述代码中,我们定义了一个 EncryptedDataset 类,用于加载加密的图片文件,这个类需要实现 __init____getitem____len__ 方法。在 __getitem__ 方法中,我们首先使用 Pillow 库的 Image.open 方法加载加密的图片数据,然后使用我们的密钥和向量对图像数据进行解密处理,然后再进行如上述步骤一和步骤二中的数据预处理操作。

在实例化该类时,我们需要传入带有加密数据的图片文件夹路径。

root = 'path/to/your/encrypted/data'
dataset = EncryptedDataset(root, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

最后我们使用 DataLoader 载入我们自定义的数据集,在训练时直接使用该数据加载器即可。

总之,根据需要,我们可以为我们的深度学习任务定义自己的数据集合实例,在这个自定义的数据集上使用批量解压、大小重置、颜色处理等方法来预处理数据。随后我们可以将预处理好的数据集与 PyTorch 的预定义数据集传入到数据加载器中,实现模型训练。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据处理:定义自己的数据集合实例 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • python判断windows系统是32位还是64位的方法

    确定当前运行的操作系统位数可以使用以下两种方法。 1. 使用platform模块 Python中的platform模块提供了许多关于运行Python解释器的平台信息的查询。我们可以使用platform.machine()函数查询当前计算机的处理器类型和操作系统的位数。 以下是一个示例程序: import platform if platform.machin…

    python 2023年5月30日
    00
  • python 提高开发效率的5个小技巧

    Python 提高开发效率的 5 个小技巧 Python 提供了许多方法来提高开发效率。在本文中,我们将介绍一些有用的技巧,可以帮助您更快、更高效地编写 Python 代码。 1. 列表解析式 列表解析式是一种简洁、优美的语法,可用于快速创建、转换或过滤列表。它可以代替大部分for循环,使代码更简单易懂。 例如,以下代码用列表解析式来创建一个由 1 到 10…

    python 2023年5月18日
    00
  • Python中Tkinter组件Listbox的具体使用

    Python中Tkinter组件Listbox的具体使用 在Python的Tkinter库中,Listbox是一种用于显示列表的组件。它可以用于显示一组项,用户可以从中一个或多个选项。本文将详细介绍如何在Python中使用Tkinter库中的Listbox组件,括如何创建Listbox、如何向Listbox中添加选项、如何获取选中的选项等。 创建Listbo…

    python 2023年5月13日
    00
  • python随机模块random使用方法详解

    Python随机模块random使用方法详解 在Python中,random模块是一个非常常用的模块,它可以帮助我们生成随机数、随机字符串、随机选择等。本文详细介绍如何Python的random模块,包括如何生成随数、如何生成随机字符串、如何进行随机选择。 生成随机数 在Python中,我们可以使用random模块的randint()函数、uniform()…

    python 2023年5月14日
    00
  • Python使用pyh生成HTML文档的方法示例

    Python使用pyh生成HTML文档的方法示例 pyh是Python的一个HTML生成库,可以用于生成HTML文档。本文将介绍如何使用pyh生成HTML文档,并提供两个示例。 步骤1:安装pyh库 在使用pyh库之前,我们需要安装它。您可以使用以下命令安装pyh库: pip install pyh 步骤2:生成HTML文档 以下是生成HTML文档的示例代码…

    python 2023年5月15日
    00
  • 浅谈python str.format与制表符\t关于中文对齐的细节问题

    浅谈python str.format与制表符\t关于中文对齐的细节问题 介绍 在Python中,字符串的格式化是经常用到的一个功能。而str.format方法则是目前Python默认推荐的格式化方法之一,因为它可以处理各种数据类型,并且使用起来非常方便。 同时,在输出数据时,经常需要使用到制表符\t来进行表格对齐的操作,而中文对齐的问题则是我们在使用中容易…

    python 2023年5月20日
    00
  • Python 以及如何从 Selenium 元素 WebElement 对象中获取文本?

    【问题标题】:Python and how to get text from Selenium element WebElement object?Python 以及如何从 Selenium 元素 WebElement 对象中获取文本? 【发布时间】:2023-04-03 10:25:01 【问题描述】: 我正在尝试使用 Selenium 方法获取 html…

    Python开发 2023年4月8日
    00
  • 详解Python中time()方法的使用的教程

    详解Python中time()方法的使用的教程 time()方法是Python标准库time模块中的一个函数,它的主要作用是获取当前时间的时间戳(即秒数)。本文将详细讲解Python中time()方法的使用。 time() 方法的语法 time()方法的语法如下: time.time() time() 方法的返回值 time()方法的返回值是从1970年1月…

    python 2023年6月3日
    00
合作推广
合作推广
分享本页
返回顶部