PyTorch中torch.utils.data.DataLoader实例详解

PyTorch中torch.utils.data.DataLoader实例详解

介绍

在深度学习中,使用大量的数据进行模型的训练是必需的,但是对于包含大量数据集的任务来说,常规的数据输入(如读取整个数据集,并将其存储在内存中)通常会耗费大量的时间和空间。因此,数据加载的高效性至关重要。PyTorch提供了一个名为DataLoader的工具,可以快速且高效地处理数据。

DataLoader在PyTorch中是数据加载的一种方式,它可以通过提供一个数据集dataset和一个批大小batch_size,自动地对数据进行迭代和批量处理。我们可以使用DataLoader从硬盘或者内存中加载数据,并且可以在数据批次之间轻松地对数据进行处理。

基本使用

步骤一:创建数据集

在使用DataLoader之前,我们需要先创建一个数据集。数据集可以是一个文件夹,也可以是一个csv文件或其他类型文件。以下代码展示如何创建一个来自MNIST数据集的数据集:

import torchvision.datasets as dset
dataset = dset.MNIST(root='data/', download=True, transform=None)

这个数据集含有60000个训练图片和10000个测试图片,每张图片都是28x28的灰度图片。dataset对象可以通过getitem()方法访问每个样本。

步骤二:创建数据加载器DataLoader

在创建数据集之后,我们需要将它传递给DataLoader,以便对数据进行批处理和迭代。以下是创建数据加载器的基本语法:

from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

这里的dataset是我们在第一步创建的数据集对象,batch_size是指每批数据的大小,shuffle参数表示在每个时期结束时是否对数据进行重洗。

步骤三: 遍历数据集

现在,我们可以使用DataLoader来遍历数据集,并可以使用for循环语句按批迭代数据集,如下所示:

for x_train, y_train in dataloader:
    # do something...

这里的x_trainy_train分别是一个从数据集中获取的批次中的数据和标签。

示例应用一:图像分类

以下代码展示了如何使用DataLoader从CIFAR10数据集中加载图像数据,然后进行标准化处理,并将其拟合到一个简单的卷积神经网络中进行分类:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# 加载CIFAR10数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = dset.CIFAR10(root='data/', train=True, download=True, transform=transform)
test_data = dset.CIFAR10(root='data/', train=False, download=True, transform=transform)

# 使用数据加载器迭代数据集
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# 定义卷积神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5,padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5,padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 定义优化器和损失函数
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: %d %%' % (
    100 * correct / total))

示例应用二:自定义数据集

步骤一:准备数据

首先,我们需要准备一组自己的数据集,我们可以将所有数据放在一个文件夹中,或者使用csv文件导入数据。

步骤二:自定义数据集类

我们需要创建一个能够读取我们的数据的类。为此,我们需要继承torch.utils.data.Dataset类,并实现两个函数__getitem____len__

from torch.utils.data.dataset import Dataset

class CustomDataset(Dataset):
    def __init__(self, csv_path, img_dir, transform=None):
        self.transform = transform
        self.images = pd.read_csv(csv_path)
        self.img_dir = img_dir

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.images.iloc[index, 0])
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.images.iloc[index, 1]
        return img, label

    def __len__(self):
        return len(self.images)

这个类接受三个参数:csv_path指向我们的csv文件,img_dir指向我们的图像文件夹,transform是一个可选的图像变换操作。

步骤三:创建数据加载器

现在,我们已经定义了用于读取我们的自定义数据集的类。我们可以使用这个类创建DataLoader对象,并将其传递给迭代器。

from torch.utils.data import DataLoader
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor()])

dataset = CustomDataset('train.csv', './train/', transform)

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

总结

在本文中,我们已经学习了如何使用PyTorch的DataLoader对象来加载和处理数据集。在深度学习中,数据集的加载和处理是非常重要的,并且它们可以显着影响模型的性能。对于大型数据集,DataLoader是一种自动将数据加载到GPU上并从中批处理数据的理想工具。在本文中,我们学习了DataLoader的基本使用方法,并提供了两个常见的示例应用程序。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中torch.utils.data.DataLoader实例详解 - Python技术站

(0)
上一篇 2023年5月18日
下一篇 2023年5月18日

相关文章

  • Python实现压缩和解压缩ZIP文件的方法分析

    当需要将多个文件合并成一个文件传输或存储时,压缩文件是一个非常有效的方式。ZIP是一种被广泛使用的文件格式,可以减小文件大小,并可以方便地打包和解压文件。 Python实现压缩ZIP文件 Python内置的zipfile模块提供了一种简单的方法来创建和压缩ZIP文件。下面是使用zipfile实现压缩ZIP文件的步骤。 步骤一:导入zipfile模块 使用Py…

    云计算 2023年5月18日
    00
  • 在Node.js中使用HTTP上传文件的方法

    下面是关于“在Node.js中使用HTTP上传文件的方法”的完整攻略,包含两个示例说明。 简介 在Node.js中,我们可以使用HTTP模块来上传文件。本攻略中,我们将介绍如何使用HTTP模块来上传文件,并提供一些最佳实践。 步骤 在使用HTTP模块上传文件时,我们可以通过以下步骤来实现: 创建一个HTTP请求。 将文件添加到请求中。 发送请求。 示例 示例…

    云计算 2023年5月16日
    00
  • 如何用云盾保障全球1500万用户愉快地“嘎嘎”

    如何用云盾保障全球1500万用户愉快地“嘎嘎” 什么是云盾 云盾是阿里云提供的一个针对安全业务的解决方案,在这个方案中可以提供多重安全防护措施,包括但不限于DDoS攻击防护、网站风险防护等。使用云盾可以帮助网站保障用户的安全,防御恶意攻击,同时提高网站的可用性和稳定性。 云盾如何保障全球1500万用户 1. DDoS攻击防护 DDoS攻击是一种常见的网络攻击…

    云计算 2023年5月17日
    00
  • .NET 6更新使.NET生态系统蜕变

    .NET 6更新使.NET生态系统蜕变 .NET 6是微软推出的最新版本的.NET框架,它带来了许多新的功能和改进,使.NET生态系统发生了蜕变。本文将详细讲解.NET 6更新使.NET生态系统蜕变的完整攻略,包括以下内容: .NET 6的新功能和改进 .NET生态系统的蜕变 示例说明 1. .NET 6的新功能和改进 .NET 6带来了许多新的功能和改进,…

    云计算 2023年5月16日
    00
  • 程序打包软件InstallShield 2018最新破解版安装激活教程(附下载)

    程序打包软件InstallShield 2018最新破解版安装激活教程 在本文中,我们将介绍最新破解版的程序打包软件InstallShield 2018的安装、激活和基本使用方法。 下载安装文件 首先,我们需要下载最新版的InstallShield 2018破解版安装文件。可以通过第三方网站进行下载,例如:https://www.crackdll.com/i…

    云计算 2023年5月17日
    00
  • Python操作Access数据库基本步骤分析

    下面是详细讲解“Python操作Access数据库基本步骤分析”的完整攻略。 一、前置准备 安装Python的pyodbc库; 安装Microsoft Access驱动程序。一般情况下,Windows系统自带了Microsoft Access驱动程序,可以通过控制面板中的“ODBC数据源管理器”来查看和配置。 二、连接Access数据库 连接Access数据…

    云计算 2023年5月18日
    00
  • 云计算市场价值爆发,两马之战日趋激烈

        2016年7月5日、8月12日,腾讯云+未来峰会在深圳站、苏州接连召开,深圳峰会中马化腾表示腾讯云的未来战略将升级为探索云上生态,实现全面开放。采用腾讯云的解决方案,就是接受和得到了整个腾讯平台,这是一个最强大的场景,没有之一。倡导与合作伙伴们共建云计算生态,共享云端生态发展的成果。苏州峰会也对外宣布将推动苏州各行各业“互联网+”“云化转型”落地,将…

    云计算 2023年4月10日
    00
  • 华为云新一代iPaaS全域融合集成平台全新升级

    摘要:基于华为十多年的数字化转型实践,华为云通过组装式交付、数智驱动、DevOps、服务化架构、安全可信、韧性6大关键技术助力客户实现应用现代化和高质量增长,华为云新一代iPaaS全域融合集成平台ROMA Connect也应运而生。 本文分享自华为云社区《华为云新一代iPaaS全域融合集成平台全新升级!》,作者:华为云头条。 数字化浪潮席卷,未来每一家企业都…

    云计算 2023年4月18日
    00
合作推广
合作推广
分享本页
返回顶部