pytorch加载自己的数据集源码分享

yizhihongxing

下面是关于pytorch加载自己的数据集的完整攻略。

1. 准备数据集

在使用pytorch训练模型需要一个自己的数据集,这里以图像分类任务为例,准备一个包含训练集和测试集的数据集,其中每个图像都分好了类别并放在对应的文件夹中,例如:

dataset
├── train
│   ├── cat
│   │   ├── cat1.jpg
│   │   ├── cat2.jpg
│   │   └── ...
│   ├── dog
│   │   ├── dog1.jpg
│   │   ├── dog2.jpg
│   │   └── ...
│   └── ...
└── test
    ├── cat
    │   ├── cat1.jpg
    │   ├── cat2.jpg
    │   └── ...
    ├── dog
    │   ├── dog1.jpg
    │   ├── dog2.jpg
    │   └── ...
    └── ...

2. 定义Dataset类

接下来需要定义一个torch.utils.data.Dataset的子类,在其中实现数据集的加载、预处理等操作。以下是一个基本的示例:

import torch
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.labels = sorted(os.listdir(root_dir))

    def __len__(self):
        return sum([len(files) for _, _, files in os.walk(self.root_dir)])

    def __getitem__(self, index):
        label = self.labels[index // len(self)]
        img_path = glob.glob(f"{self.root_dir}/{label}/*")[index % len(self)]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

在上面的代码中,我们定义了一个名为MyDataset的子类,实现了4个方法:

  • __init__:初始化方法,接收两个参数,一个是数据集的根目录,另一个是数据集上的转换操作(transform),这里使用了PIL.Image库来加载图像。
  • __len__:返回数据集的长度,这里是遍历数据集中所有图片的数量。
  • __getitem__:根据索引返回对应的图像和标签,并且进行预处理,这里直接返回了图像的张量和标签字符串。
  • labels:这个属性用于保存数据集中所有类别的名称,使用了Python内置的os.listdir方法。

3. 定义DataLoader

接下来可以使用torch.utils.data.DataLoader类来加载数据集,并使用pytorch进行训练。以下是一个基本的示例:

from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = MyDataset("dataset/train", transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

test_set = MyDataset("dataset/test", transform=transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

在上面的代码中,我们使用torchvision.transforms库来定义了一组转换操作,包括将图像大小转换为224x224、转换为张量和标准化操作,并将其传递给MyDataset类的实例中。

然后我们分别创建了训练集和测试集的DataLoader,其中MyDataset是传入数据集的实例,batch_size表示每个batch的大小,shuffle=True表示在每个epoch开始时打乱数据。

至此,我们已经完成了pytorch加载自己的数据集的完整攻略。

4. 示例说明

示例一

假设我们要训练一个resnet18网络来对上述示例中的图片分类,可以按照以下步骤定义并训练这个模型:

import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18

device = "cuda" if torch.cuda.is_available() else "cpu"
model = resnet18(pretrained=True)
model.fc = nn.Linear(512, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):
    running_loss = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch: {epoch+1} Loss: {running_loss/len(train_loader):.4f}")

在上述代码中,我们使用了torchvision内置的resnet18网络,将其输出层改为2个节点的全连接层,用于分类2个类别的图像。

我们使用了交叉熵损失函数和随机梯度下降优化器,每个epoch训练完成后输出当前loss值。

示例二

接下来我们可以使用训练好的模型对测试集中的图像进行预测,示例代码如下:

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy: {100*correct/total:.2f}%")

在上述代码中,我们使用torch.no_grad()上下文管理器来关闭梯度计算,防止内存溢出。然后遍历测试集中的所有图像,进行前向预测,并计算准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch加载自己的数据集源码分享 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • SpringBoot2 整合Nacos组件及环境搭建和入门案例解析

    下面是关于“SpringBoot2 整合Nacos组件及环境搭建和入门案例解析”的完整攻略。 SpringBoot2 整合Nacos组件及环境搭建和入门案例解析 1. 环境搭建 Nacos简介 Nacos是阿里巴巴开源的分布式服务发现、配置管理和服务治理平台。Nacos支持几乎所有主流类型的服务,包括Kubernetes、Mesos、Docker等。 下载N…

    人工智能概览 2023年5月25日
    00
  • win10预览版10074再次更新:OCR中文语言包

    Win10预览版10074再次更新:OCR中文语言包攻略 Win10预览版10074在2015年5月1日再次更新了OCR中文语言包。接下来我们将详细讲解安装和使用该语言包的方法。 1. 下载安装语言包 首先需要下载OCR中文语言包。可以前往微软官网下载安装。具体步骤如下: 访问微软官网; 在搜索框中搜索“OCR中文语言包”; 找到“Win10预览版10074…

    人工智能概览 2023年5月25日
    00
  • 独立部署小程序基于nodejs的服务器过程详解

    下面我来详细解释一下“独立部署小程序基于nodejs的服务器过程详解”的完整攻略,包含以下几个部分: 前提条件 安装Node.js和MongoDB 使用Express框架和Mongoose模块创建基于Node.js的服务端 部署服务端到云服务器上(以阿里云为例) 1. 前提条件 在开始独立部署小程序的服务器之前,需要具备以下技能: 熟悉Node.js和Exp…

    人工智能概论 2023年5月25日
    00
  • 使用PHPWord生成word文档的方法详解

    “使用PHPWord生成word文档的方法详解”是一篇介绍在PHP中使用PHPWord库生成word文档的教程。这里将为大家提供一份完整的攻略,包含了从安装库到使用代码生成word文档的详细步骤。 安装PHPWord 在使用PHPWord之前,需要先将PHPWord库安装到本地。下面是安装步骤: 下载PHPWord库 可以通过在PHPWord的官方GitHu…

    人工智能概论 2023年5月25日
    00
  • Java进程间通信之消息队列

    接下来我将详细讲解Java进程间通信之消息队列的完整攻略。 什么是消息队列 消息队列是一种通过在应用程序之间异步地传输数据来解决耦合问题的技术。它允许发送者,通常是独立的应用程序,将消息发送到队列中而不需要实时处理它。相反,接收者从队列中接收消息并在合适的时候进行处理。 消息队列的作用 使用消息队列可以将应用程序之间的通信和解耦,提高了系统的可靠性、可扩展性…

    人工智能概览 2023年5月25日
    00
  • Go实现分布式系统高可用限流器实战

    Go实现分布式系统高可用限流器实战攻略 什么是限流器? 限流器是用来控制流量的一种重要工具。在分布式系统中,限流器可以帮助我们控制流量并且保证系统的稳定运行。 Go实现分布式系统高可用限流器的步骤 以下是Go实现分布式系统高可用限流器的步骤: 1. 定义限流器的数据结构 我们需要定义一个结构体来表示限流器。这个结构体包含以下字段: 每秒钟可以处理的请求数 r…

    人工智能概览 2023年5月25日
    00
  • Python CategoricalDtype自定义排序实现原理解析

    下面我会详细讲解如何使用Python的CategoricalDtype自定义排序。本文将按照以下步骤进行: 了解CategoricalDtype数据类型的基本概念 自定义排序方法的实现原理 示例演示 1. CategoricalDtype数据类型的基本概念 在Python中,CategoricalDtype是一种广泛使用的数据类型,其主要功能是对分类数据进行…

    人工智能概论 2023年5月25日
    00
  • python 实现任务管理清单案例

    下面是Python实现任务管理清单案例的完整攻略。 1. 准备工作 首先需要安装Python环境。推荐使用Python 3.x版本,可以在Python官网下载可执行程序并安装。 2. 确定需求和功能 本案例实现的功能需求如下: 添加任务 删除任务 修改任务 查看任务列表 3. 编写代码 首先,创建一个名为todolist.py的Python文件。在文件中添加…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部