请看下面的详细讲解。
PyTorch数据处理:定义自己的数据集合实例
在进行深度学习任务时,数据预处理是非常重要的一步,而 PyTorch 中,数据预处理也是必不可少的一环。在大多数情况下,我们需要使用已有的数据集,如官方提供的 MNIST、CIFAR10 等数据集;但有时我们也需要自己定义数据集,例如从图片数据集中自定义一个猫狗二分类的数据集。自定义数据集的过程其实并不困难,下面我们就来详细讲解一下如何定义 PyTorch 自己的数据集合实例。
定义自己的数据集合实例包含以下步骤:
- 构建包含数据和标签的数据集类。
数据集类需要继承 torch.utils.data.Dataset
,并实现以下两个方法:
- `__len__(self)` 返回数据集的长度。
- `__getitem__(self, index)` 给定一个索引 index,返回对应的数据和标签。
-
对数据进行预处理。
-
创建数据加载器 DataLoader。
接下来我们将对上面三个步骤进行详细讲解。
步骤一:构建数据集类
让我们从构建一个二分类数据集为例。我们有一些猫和狗的图片,我们需要将他们分别标记为1和0,并存储在一个字典中。我们使用 PyTorch 中的 ImageFolder
类来加载数据。
import os
import torch.utils.data as data
from torchvision import datasets, transforms
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
images = []
for root, _, fnames in sorted(os.walk(dir)):
for fname in sorted(fnames):
if is_image_file(fname):
path = os.path.join(root, fname)
item = (path, int(fname.split('.')[0] == 'cat'))
images.append(item)
return images
class CatDogDataset(data.Dataset):
def __init__(self, root, transform=None):
self.root = root
self.transform = transform
self.images = make_dataset(self.root)
def __getitem__(self, index):
path, target = self.images[index]
img = default_loader(path)
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.images)
在上述代码中,我们首先定义了一个 IMG_EXTENSIONS
列表,其中包含了我们所支持的图片格式的后缀名。接着定义了一个函数 is_image_file
用于判断某个文件是否是图片;再定义了 make_dataset
函数用于生成一个元组列表,其中包含了图片的路径和标签,这个标签是根据文件名来判断的,如果文件名是 'cat.XXX.jpg' 的格式,则标签为1,否则标签为0。最后,我们定义了 CatDogDataset
类,继承了 torch.utils.data.Dataset
;定义了初始化函数 __init__
,用于初始化数据集路径和数据增强方法;定义了 __getitem__
方法,用于返回指定索引的数据和标签;以及 __len__
方法返回数据集长度。
步骤二:对数据进行预处理
为了让模型能够更好地利用数据,我们通常需要对数据进行预处理。在 PyTorch 中,我们可以使用 torchvision.transforms 来进行数据预处理,其中包含了很多有用的函数,例如对图像进行裁剪、旋转、缩放等增强操作。下面是一个简单的数据预处理示例:
transform = transforms.Compose([
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
在上述代码中,我们通过 transforms.Compose
将多个数据预处理操作串联起来,最后将它们应用在我们定义的数据集上。在这个例子中,我们首先使用 transforms.CenterCrop
将图片中心裁剪为 224 × 224,再使用 transforms.ToTensor
将图片(0, 255)范围内的像素转换到(0, 1)范围内的张量,最后使用 transforms.Normalize
对图片进行归一化操作。对于这个归一化操作,我们需要使用图像数据集上的均值和方差。在这里,我们使用 ImageNet 数据集上的均值和方差,这是一个广泛使用的标准值。
步骤三:创建数据加载器 DataLoader
在数据预处理之后,我们需要将定义好的数据集类加载到 PyTorch 的 DataLoader
类当中,并且设定好每次取多少张图片进行训练或测试。DataLoader
可以自动对数据集进行批次切分,这样可以提高模型训练时的效率。
from torch.utils.data import DataLoader
batch_size = 64
trainset = CatDogDataset('path/to/train/data', transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
在上述代码中,我们使用了 torch.utils.data.DataLoader
类来生成数据加载器,其中包含了我们自己定义的 CatDogDataset
类。我们设定了 batch size 为 64,shuffle 参数为 True 表示每个 epoch 随机对数据集进行洗牌,num_workers 表示使用几个进程来加载数据。该数据加载器可以用于训练过程中,例如:
for epoch in range(10): # 循环10次
for i, (inputs, labels) in enumerate(trainloader, 0): # 每次取一个批次
print(inputs.shape) # 显示当前批次数据的形状
在上述代码中,我们循环了 10 次,并且每次从训练数据中取出一个 batch size 大小的数据来训练模型。每次取出的数据使用 inputs
和 labels
两个变量进行存储。
示例一:在CIFAR10数据集上创建训练集
以「CIFAR10数据集」为例,首先我们需要安装「torchvision」包,并引入以下依赖:
import torch
import torchvision
import torchvision.transforms as transforms
接着,我们定义数据预处理操作,这里需要注意的一点是,我们不对「CIFAR10数据集」进行归一化,因为这个数据集本身已经做过归一化处理。
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
transform_val = transforms.Compose([transforms.ToTensor()])
最后,我们以 4 作为 batch size,创建数据加载器。并使用 trainset.train_data
和 trainset.train_labels
,这是 CIFAR10 自带的数据集。训练数据集我们随机抽取 80% 作为训练集,20% 作为测试集。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val)
n_train = len(trainset)
train_idx, val_idx = torch.utils.data.random_split(range(n_train), [int(0.8*n_train), int(0.2*n_train)])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, sampler=torch.utils.data.SubsetRandomSampler(train_idx))
valloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(val_idx))
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
示例二:通过加密文件夹生成训练数据集
在一些实际场景中,我们的数据集可能需要加密以保护数据隐私。我们可以先将数据集文件加密,在使用 PyTorch 加载数据集时进行解密。我们以放置有已加密文件的文件夹作为例子。
import torch
import torch.utils.data as data
from PIL import Image
from Crypto.Cipher import AES
class EncryptedDataset(data.Dataset):
def __init__(self, root, transform=None):
self.root = root
self.transform = transform
def __getitem__(self, index):
path = self.images[index]
img = Image.open(path)
img = self.decrypt(img)
if self.transform is not None:
img = self.transform(img)
return img, torch.tensor(0) # 返回第二个值用于符合 __getitem__ 的规则
def __len__(self):
return len(self.images)
def decrypt(self, image):
block_size = 16
aes = AES.new('your_secret_key', AES.MODE_CBC, 'your_secret_iv')
encrypted_data = image.tobytes()
decrypted_data = aes.decrypt(encrypted_data)
if b'\0' in decrypted_data:
decrypted_data = decrypted_data.rstrip(b'\0')
img = Image.frombytes('RGB', image.size, decrypted_data)
return img
在上述代码中,我们定义了一个 EncryptedDataset
类,用于加载加密的图片文件,这个类需要实现 __init__
、__getitem__
和 __len__
方法。在 __getitem__
方法中,我们首先使用 Pillow 库的 Image.open
方法加载加密的图片数据,然后使用我们的密钥和向量对图像数据进行解密处理,然后再进行如上述步骤一和步骤二中的数据预处理操作。
在实例化该类时,我们需要传入带有加密数据的图片文件夹路径。
root = 'path/to/your/encrypted/data'
dataset = EncryptedDataset(root, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
最后我们使用 DataLoader
载入我们自定义的数据集,在训练时直接使用该数据加载器即可。
总之,根据需要,我们可以为我们的深度学习任务定义自己的数据集合实例,在这个自定义的数据集上使用批量解压、大小重置、颜色处理等方法来预处理数据。随后我们可以将预处理好的数据集与 PyTorch 的预定义数据集传入到数据加载器中,实现模型训练。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据处理:定义自己的数据集合实例 - Python技术站