pytorch加载自己的图像数据集实例

下面是 "PyTorch加载自己的图像数据集实例" 的完整攻略:

准备工作

  1. 数据集准备:准备自己的图像数据集,并将其组织为相应的目录结构。例如,我们假设有一份猫狗分类的数据集,其中包含两个类别:狗和猫。则我们可以将其组织为如下目录结构:
dataset
├── train
│   ├── cat
│   │   ├── cat.1.png
│   │   ├── cat.2.png
│   │   ├── ……
│   ├── dog
│   │   ├── dog.1.png
│   │   ├── dog.2.png
│   │   ├── ……
├── val
│   ├── cat
│   │   ├── cat.10.png
│   │   ├── cat.11.png
│   │   ├── ……
│   ├── dog
│   │   ├── dog.10.png
│   │   ├── dog.11.png
│   │   ├── ……

其中,train 目录下是训练集,val 目录下是验证集。每个子目录表示一个类别。每个类别中包含若干张图片,文件名以类别名开头,并编号。

  1. 安装所需依赖包:PyTorch、torchvision

代码实现

import torch
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader

# 定义数据路径和变换
train_path = '/path/to/train'
val_path = '/path/to/val'
transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.classes = os.listdir(self.path)

    def __len__(self):
        num = 0
        for c in self.classes:
            num += len(os.listdir(os.path.join(self.path, c)))
        return num

    def __getitem__(self, index):
        for i, c in enumerate(self.classes):
            images = os.listdir(os.path.join(self.path, c))
            if index < len(images):
                img_path = os.path.join(self.path, c, images[index])
                img = Image.open(img_path).convert('RGB')
                img = self.transform(img)
                label = i
                return img, label
            else:
                index -= len(images)

# 创建数据集和数据加载器
train_dataset = MyDataset(train_path, transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

val_dataset = MyDataset(val_path, transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

# 使用数据集和数据加载器进行训练和验证
for epoch in range(num_epochs):
    # 训练
    for images, labels in train_loader:
        # 训练操作
        pass

    # 验证
    with torch.no_grad():
        for images, labels in val_loader:
            # 计算模型预测结果,并进行验证操作
            pass

以上代码实现了一个简单的PyTorch数据加载器。其中,我们使用了torchvision.transforms模块定义了图像变换,包括将图像缩放到256x256,并中心裁剪为224x224大小,将图像转换为Tensor类型,并进行归一化操作。

然后,我们定义了一个自定义数据集类MyDataset,该类继承torch.utils.data.Dataset类。其中,__init__方法初始化数据路径和变换,__len__方法返回数据集样本数,__getitem__方法根据索引返回图像和标签。

最后,我们创建了两个数据集实例train_datasetval_dataset,并使用torch.utils.data.DataLoader创建了相应的数据加载器train_loaderval_loader。这样,我们就可以使用数据加载器对模型进行训练和验证。

另外,这里提供两条示例:

示例1:使用自定义数据集训练分类模型

import torch
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader

# 定义数据路径和变换
train_path = '/path/to/train'
val_path = '/path/to/val'
transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.classes = os.listdir(self.path)

    def __len__(self):
        num = 0
        for c in self.classes:
            num += len(os.listdir(os.path.join(self.path, c)))
        return num

    def __getitem__(self, index):
        for i, c in enumerate(self.classes):
            images = os.listdir(os.path.join(self.path, c))
            if index < len(images):
                img_path = os.path.join(self.path, c, images[index])
                img = Image.open(img_path).convert('RGB')
                img = self.transform(img)
                label = i
                return img, label
            else:
                index -= len(images)

# 创建数据集和数据加载器
train_dataset = MyDataset(train_path, transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

val_dataset = MyDataset(val_path, transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

# 定义模型和优化器
model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)  # 二分类问题
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # 计算验证集上的准确率
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100.0 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {acc:.2f}%')

上述示例使用了自定义数据集train_datasetval_dataset,分别表示训练集和验证集。其中,我们使用了一个预训练的ResNet18模型,并替换了其最后一层全连接层以适应二分类问题。然后,我们定义了交叉熵损失和随机梯度下降优化器,使用数据加载器进行训练,并在每个epoch结束后在验证集上计算了模型的准确率。

示例2:使用自定义数据集微调分类模型

import torch
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader

# 定义数据路径和变换
train_path = '/path/to/train'
val_path = '/path/to/val'
transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.classes = os.listdir(self.path)

    def __len__(self):
        num = 0
        for c in self.classes:
            num += len(os.listdir(os.path.join(self.path, c)))
        return num

    def __getitem__(self, index):
        for i, c in enumerate(self.classes):
            images = os.listdir(os.path.join(self.path, c))
            if index < len(images):
                img_path = os.path.join(self.path, c, images[index])
                img = Image.open(img_path).convert('RGB')
                img = self.transform(img)
                label = i
                return img, label
            else:
                index -= len(images)

# 加载预训练模型,并替换其最后一层
model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)  # 二分类问题

# 冻结前面若干层参数
for i, param in enumerate(model.parameters()):
    if i < 40:
        param.requires_grad = False

# 创建数据集和数据加载器
train_dataset = MyDataset(train_path, transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

val_dataset = MyDataset(val_path, transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 微调模型
num_epochs = 10
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # 计算验证集上的准确率
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100.0 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {acc:.2f}%')

上述示例同样使用了自定义数据集train_datasetval_dataset,分别表示训练集和验证集。不同的是,我们使用了一个预训练的ResNet18模型,并将其最后一层替换为适用于二分类问题的全连接层。然后,我们冻结前面若干层的参数,只训练后面的几层,以加速模型收敛。最后,我们使用数据加载器进行微调,并在每个epoch结束后在验证集上计算了模型的准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch加载自己的图像数据集实例 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • numpy取反操作符和Boolean类型与0-1表示方式

    当使用numpy进行数据处理时,经常需要使用取反操作符(~)和Boolean类型与0-1表示方式。本文将详细介绍这些概念,并提供一些示例来说明它们之间的关系。 取反操作符(~) 在numpy中,取反操作符(~)用于对数组中的元素进行逐位反。它的语法如下: numpy.invert(x, /, out=None, *, where=True, casting=…

    python 2023年5月14日
    00
  • 浅析关于Keras的安装(pycharm)和初步理解

    1. PyTorch中Tensor的数据类型 在PyTorch中,Tensor是最基本的数据类型,它是一个多维数组。Tensor可以是标量、向量、矩阵或任意维度的数组。在PyTorch中,Tensor有多种数据类型,包括: torch.FloatTensor:32位浮点数 torch.DoubleTensor:64位浮点数 torch.HalfTensor:…

    python 2023年5月14日
    00
  • Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算

    Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算 NumPy是Python中一个重要的科学计算库,提供了高效的多维数组和各种派生对象以于计各种函数。其中,方差、标准方差、样本标准方差和协方差是用的统计量,本文将讲解如使用NumPy计算这些统计量。 方差的计算 方差是一组数据其平均数之差的平方和的平均,用于衡量数据的离散程度。在Num…

    python 2023年5月13日
    00
  • Python numpy中的ndarray介绍

    Python Numpy中的ndarray介绍 ndarray是Numpy中一个重要的数据结构,它是一个多维数组,可以用于存储和处理大量的数据。本攻略将详细介绍Python Numpy中的ndarray。 导入Numpy模块 在使用Numpy模块之前,需要先导入它。可以以下命令在Python脚本中导入Numpy模块: import numpy as np 在…

    python 2023年5月13日
    00
  • python中numpy包使用教程之数组和相关操作详解

    Python中NumPy包使用教程之数组和相关操作详解 NumPy是Python中一个重要的科学计算库,提供了高效的多维数组和各派生对象以于计算各种函数。本文将入讲解Py中的NumPy包使用教之数组和相关操作详解,包括数组的创建、数组的索引和切片、数组的形状操作、数组的拼接和分裂、数组的复制和视图等。 数组的创建 在NumPy中,可以使用array()函数来…

    python 2023年5月13日
    00
  • python 使用cx-freeze打包程序的实现

    Python使用cx-Freeze打包程序的实现 在Python中,我们可以使用cx-Freeze将Python程序打包成可执行文件。在本攻略中,我们将介绍如何使用cx-Freeze打包程序,并提供两个示例说明。 问题描述 在Python中,我们通常需要将Python程序打包成可执行文件,以便在没有Python环境的计算机上运行。如何使用cx-Freeze打…

    python 2023年5月14日
    00
  • Python如何遍历numpy数组

    Python如何遍历NumPy数组 在Python中,遍历NumPy数组有多种方法,包括使用for循环、使用nditer()函数、使用flat属性等。下面将详细讲解这些方法。 使用for循环遍历NumPy数组 使用循环遍历NumPy数组是最简单的方法。下面是一个示例: import numpy as np # 创建NumPy a = np.array([[1…

    python 2023年5月14日
    00
  • Windows下Python3.6安装第三方模块的方法

    在Windows下,安装Python3.6后,可以使用pip来安装第三方模块。以下是安装第三方模块的步骤: 安装pip 在安装第三方模块之前,需要先安装pip。可以从官方网站下载get-pip.py文件。下载完成后,可以使用以下命令安装pip: python get-pip.py 安装第三方模块 安装pip后,可以使用以下命令安装第三方模块: pip ins…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部