详解PyTorch预定义数据集类datasets.ImageFolder使用方法
简述
datasets.ImageFolder
是PyTorch中预定义的用于处理图像分类任务的数据集类,并且可以轻松地进行自定义。
其中ImageFolder
的基础类是torch.utils.data.Dataset
,这个类是用于构建数据集的基类,我们可以在这个类中实现自定义数据集。
使用方法
首先,我们需要在代码中导入相关的库
import torch
from torchvision import datasets, transforms
在导入库以后,我们需要对数据进行预处理。可以通过transforms库来实现。比如我们需要对图像进行数据增强、缩放,同时将数据转换为tensor类型。
transform = transforms.Compose([transforms.Resize((224, 224)),
transforms.RandomCrop((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
上述代码中,我们使用了transforms.Resize
将图像大小改为(224,224)
,使用transforms.RandomCrop
在图像中随机裁剪(224,224)
大小的图像,使用transforms.RandomHorizontalFlip
对图像进行随机水平翻转,并使用transforms.ToTensor
将图像转换为tensor类型。
接下来,我们可以使用datasets.ImageFolder
类按照给定的路径构建数据集,并进行预处理,同时使用torch.utils.data.DataLoader
构建数据迭代器。
train_dataset = datasets.ImageFolder('data/train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
上述代码中,我们使用datasets.ImageFolder
类构建了训练数据集,并传入预处理的参数transform。之后,我们使用torch.utils.data.DataLoader
构建了数据迭代器,其中batch_size
为批大小,shuffle
表示是否对数据进行随机排序。
最后,我们就可以使用数据迭代器来获取数据进行训练。
for i, (input, label) in enumerate(train_loader):
# 进行训练操作
pass
示例说明
示例一
我们可以通过以下方式来修改datasets.ImageFolder
类的默认标签名称和类名对应的文件夹名称。
class ImageFolderWithPaths(datasets.ImageFolder):
# 重载 __getitem__ 函数来包含文件路径
def __getitem__(self, index):
original_tuple = super().__getitem__(index)
# 文件路径
path = self.imgs[index][0]
tuple_with_path = (original_tuple + (path,))
return tuple_with_path
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# 加载数据集
data_dir = './data'
dataset = ImageFolderWithPaths(data_dir, transform)
# 获取数据并显示文件路径
for inputs, labels, paths in dataset:
print(paths)
上述代码中,我们实现了一个重载__getitem__
函数的自定义ImageFolderWithPaths
类,使得该类在获取数据时可以返回文件路径。接着,我们实例化了这个类并传入数据集目录和预处理参数。最后我们使用for
循环方式来遍历数据集,并输出每一张图片对应的文件路径。
示例二
下面的示例代码展示了如何在训练过程中使用ImageFolder数据集读取顺序打乱的CSV数据。
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import random
class CSVImageDataset(Dataset):
def __init__(self, csv_file_path, transform=None):
self.df = pd.read_csv(csv_file_path)
self.transform = transform
self.dataset_len = len(self.df)
def __getitem__(self, index):
row = self.df.iloc[index]
img_path = row['img_path']
label = row['label']
image = Image.open(img_path).convert("RGB")
if self.transform is not None:
image = self.transform(image)
return (image, label)
def __len__(self):
return self.dataset_len
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# 加载CSV文件并初始化数据集
csv_file = './data/train.csv'
dataset = CSVImageDataset(csv_file, transform)
# 初始化数据迭代器,并打乱数据顺序
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
# 遍历数据集并进行训练
for inputs, labels in train_loader:
# 进行训练操作
pass
上述代码中,我们使用了Pandas库读取CSV文件记录的文件路径和标签,并使用pil库将图像读取为RGB格式的PIL Image类型。
接着,我们定义了一个自定义的图片数据集类CSVImageDataset
,并重载了__getitem__
和__len__
函数对数据进行操作。
最后,我们创建了一个CSVImageDataset
的实例并传入CSV文件路径和预处理参数,然后使用DataLoader
构建了数据迭代器,并使用for
循环遍历每个批次的数据并进行训练。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:详解PyTorch预定义数据集类datasets.ImageFolder使用方法 - Python技术站