pytorch加载自己的数据集源码分享

下面是关于pytorch加载自己的数据集的完整攻略。

1. 准备数据集

在使用pytorch训练模型需要一个自己的数据集,这里以图像分类任务为例,准备一个包含训练集和测试集的数据集,其中每个图像都分好了类别并放在对应的文件夹中,例如:

dataset
├── train
│   ├── cat
│   │   ├── cat1.jpg
│   │   ├── cat2.jpg
│   │   └── ...
│   ├── dog
│   │   ├── dog1.jpg
│   │   ├── dog2.jpg
│   │   └── ...
│   └── ...
└── test
    ├── cat
    │   ├── cat1.jpg
    │   ├── cat2.jpg
    │   └── ...
    ├── dog
    │   ├── dog1.jpg
    │   ├── dog2.jpg
    │   └── ...
    └── ...

2. 定义Dataset类

接下来需要定义一个torch.utils.data.Dataset的子类,在其中实现数据集的加载、预处理等操作。以下是一个基本的示例:

import torch
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.labels = sorted(os.listdir(root_dir))

    def __len__(self):
        return sum([len(files) for _, _, files in os.walk(self.root_dir)])

    def __getitem__(self, index):
        label = self.labels[index // len(self)]
        img_path = glob.glob(f"{self.root_dir}/{label}/*")[index % len(self)]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

在上面的代码中,我们定义了一个名为MyDataset的子类,实现了4个方法:

  • __init__:初始化方法,接收两个参数,一个是数据集的根目录,另一个是数据集上的转换操作(transform),这里使用了PIL.Image库来加载图像。
  • __len__:返回数据集的长度,这里是遍历数据集中所有图片的数量。
  • __getitem__:根据索引返回对应的图像和标签,并且进行预处理,这里直接返回了图像的张量和标签字符串。
  • labels:这个属性用于保存数据集中所有类别的名称,使用了Python内置的os.listdir方法。

3. 定义DataLoader

接下来可以使用torch.utils.data.DataLoader类来加载数据集,并使用pytorch进行训练。以下是一个基本的示例:

from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = MyDataset("dataset/train", transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

test_set = MyDataset("dataset/test", transform=transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

在上面的代码中,我们使用torchvision.transforms库来定义了一组转换操作,包括将图像大小转换为224x224、转换为张量和标准化操作,并将其传递给MyDataset类的实例中。

然后我们分别创建了训练集和测试集的DataLoader,其中MyDataset是传入数据集的实例,batch_size表示每个batch的大小,shuffle=True表示在每个epoch开始时打乱数据。

至此,我们已经完成了pytorch加载自己的数据集的完整攻略。

4. 示例说明

示例一

假设我们要训练一个resnet18网络来对上述示例中的图片分类,可以按照以下步骤定义并训练这个模型:

import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18

device = "cuda" if torch.cuda.is_available() else "cpu"
model = resnet18(pretrained=True)
model.fc = nn.Linear(512, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):
    running_loss = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch: {epoch+1} Loss: {running_loss/len(train_loader):.4f}")

在上述代码中,我们使用了torchvision内置的resnet18网络,将其输出层改为2个节点的全连接层,用于分类2个类别的图像。

我们使用了交叉熵损失函数和随机梯度下降优化器,每个epoch训练完成后输出当前loss值。

示例二

接下来我们可以使用训练好的模型对测试集中的图像进行预测,示例代码如下:

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy: {100*correct/total:.2f}%")

在上述代码中,我们使用torch.no_grad()上下文管理器来关闭梯度计算,防止内存溢出。然后遍历测试集中的所有图像,进行前向预测,并计算准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch加载自己的数据集源码分享 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Go 内存分配管理

    Go 内存分配管理的完整攻略 Go语言内存管理继承了C语言的双层结构:堆和栈。栈是自动管理的,而程序员需要负责管理堆上的内存。Go语言采用一个称为垃圾回收器的进程来管理堆上的内存。 内存分配 Go语言的内存分配是通过new()和make()进行的。 new() new()函数会为类型分配内存,并返回指向该类型零值的指针。它的语法为: p := new(Typ…

    人工智能概览 2023年5月25日
    00
  • go语言入门环境搭建及GoLand安装教程详解

    Go语言入门环境搭建及GoLand安装教程详解 概述 Go语言是Google公司推出的一种新型编程语言,具有并发,高性能等特性,因此备受开发者青睐。本文将详细讲解如何搭建Go语言的开发环境和安装GoLand等开发工具。 步骤一:安装Go语言环境 下载Go语言环境安装包 在官网(https://golang.org/dl/)下载对应操作系统的安装包,推荐下载稳…

    人工智能概论 2023年5月25日
    00
  • 在Linux系统中将Redmine和SVN整合入Nginx的方法

    将Redmine和SVN整合入Nginx的方法,可以通过以下步骤完成: 1. 安装和配置Redmine 1.1 安装Ruby和Rails 首先需要安装Ruby和Rails。在命令行输入以下命令: sudo apt-get update sudo apt-get install ruby rails 1.2 下载和解压Redmine 到Redmine官网下载安…

    人工智能概览 2023年5月25日
    00
  • 详解Python如何实现惰性导入-lazy import

    如何实现Python的惰性导入?我们可以通过使用Python的 __import__() 函数和自定义模块加载器实现这一功能。下面是详细的攻略: 1. 了解Python的模块加载顺序 在了解如何实现惰性导入之前,我们先简要介绍一下Python的模块加载顺序。当Python通过 import 或 from 语句加载模块时,会按照以下顺序搜索模块: 当前目录 环…

    人工智能概论 2023年5月25日
    00
  • Django def clean()函数对表单中的数据进行验证操作

    Django中的表单验证是在视图函数中使用的,在视图函数中,使用表单的is_valid()方法进行验证,但是有时候我们需要在表单类中对用户提交的数据进行进一步的自定义验证操作,这时候就需要使用到clean()函数。 clean()函数介绍 clean()函数是在django中的表单验证过程中定义的一个函数,可以对用户提交的数据进行自定义验证操作。clean(…

    人工智能概论 2023年5月25日
    00
  • Vue兼容ie9的问题全面解决方案

    下面是关于“Vue兼容IE9的问题全面解决方案”的攻略: 1. 问题描述 Vue版本从2.x开始,不再支持IE8以及更早的版本,而IE9在Vue项目中的兼容性问题也比较突出,容易导致项目运行出错或数据无法正确展示。 2. 解决方案 2.1 使用babel-polyfill兼容ES6的语法 IE浏览器不支持ES6的语法,我们需要使用babel将ES6转为ES5…

    人工智能概览 2023年5月25日
    00
  • Docker如何部署Python项目的实现详解

    下面我将为你详细讲解“Docker如何部署Python项目的实现详解”。 Docker部署Python项目 1. 什么是Docker? Docker是一种开源软件平台,可以帮助开发者将应用程序与其依赖项打包到一个可移植的容器中,然后发布到任何支持Docker的机器上。 2. Docker如何部署Python项目? Docker部署Python项目的实现方法如…

    人工智能概览 2023年5月25日
    00
  • 在 Ubuntu 12.04 Server 上安装部署 Ruby on Rails 应用

    下面我们详细讲解“在 Ubuntu 12.04 Server 上安装部署 Ruby on Rails 应用”的完整攻略。 1. 前置条件 在安装和部署 Ruby on Rails 应用之前,需要先完成以下几个前置条件: 安装 Ubuntu Server 12.04。 更新操作系统并安装必要的依赖。 安装 Ruby 2.0 或更高版本。 安装 Rails 5 …

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部