PyTorch中torch.utils.data.DataLoader实例详解

PyTorch中torch.utils.data.DataLoader实例详解

介绍

在深度学习中,使用大量的数据进行模型的训练是必需的,但是对于包含大量数据集的任务来说,常规的数据输入(如读取整个数据集,并将其存储在内存中)通常会耗费大量的时间和空间。因此,数据加载的高效性至关重要。PyTorch提供了一个名为DataLoader的工具,可以快速且高效地处理数据。

DataLoader在PyTorch中是数据加载的一种方式,它可以通过提供一个数据集dataset和一个批大小batch_size,自动地对数据进行迭代和批量处理。我们可以使用DataLoader从硬盘或者内存中加载数据,并且可以在数据批次之间轻松地对数据进行处理。

基本使用

步骤一:创建数据集

在使用DataLoader之前,我们需要先创建一个数据集。数据集可以是一个文件夹,也可以是一个csv文件或其他类型文件。以下代码展示如何创建一个来自MNIST数据集的数据集:

import torchvision.datasets as dset
dataset = dset.MNIST(root='data/', download=True, transform=None)

这个数据集含有60000个训练图片和10000个测试图片,每张图片都是28x28的灰度图片。dataset对象可以通过getitem()方法访问每个样本。

步骤二:创建数据加载器DataLoader

在创建数据集之后,我们需要将它传递给DataLoader,以便对数据进行批处理和迭代。以下是创建数据加载器的基本语法:

from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

这里的dataset是我们在第一步创建的数据集对象,batch_size是指每批数据的大小,shuffle参数表示在每个时期结束时是否对数据进行重洗。

步骤三: 遍历数据集

现在,我们可以使用DataLoader来遍历数据集,并可以使用for循环语句按批迭代数据集,如下所示:

for x_train, y_train in dataloader:
    # do something...

这里的x_trainy_train分别是一个从数据集中获取的批次中的数据和标签。

示例应用一:图像分类

以下代码展示了如何使用DataLoader从CIFAR10数据集中加载图像数据,然后进行标准化处理,并将其拟合到一个简单的卷积神经网络中进行分类:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# 加载CIFAR10数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = dset.CIFAR10(root='data/', train=True, download=True, transform=transform)
test_data = dset.CIFAR10(root='data/', train=False, download=True, transform=transform)

# 使用数据加载器迭代数据集
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# 定义卷积神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5,padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5,padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 定义优化器和损失函数
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: %d %%' % (
    100 * correct / total))

示例应用二:自定义数据集

步骤一:准备数据

首先,我们需要准备一组自己的数据集,我们可以将所有数据放在一个文件夹中,或者使用csv文件导入数据。

步骤二:自定义数据集类

我们需要创建一个能够读取我们的数据的类。为此,我们需要继承torch.utils.data.Dataset类,并实现两个函数__getitem____len__

from torch.utils.data.dataset import Dataset

class CustomDataset(Dataset):
    def __init__(self, csv_path, img_dir, transform=None):
        self.transform = transform
        self.images = pd.read_csv(csv_path)
        self.img_dir = img_dir

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.images.iloc[index, 0])
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.images.iloc[index, 1]
        return img, label

    def __len__(self):
        return len(self.images)

这个类接受三个参数:csv_path指向我们的csv文件,img_dir指向我们的图像文件夹,transform是一个可选的图像变换操作。

步骤三:创建数据加载器

现在,我们已经定义了用于读取我们的自定义数据集的类。我们可以使用这个类创建DataLoader对象,并将其传递给迭代器。

from torch.utils.data import DataLoader
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor()])

dataset = CustomDataset('train.csv', './train/', transform)

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

总结

在本文中,我们已经学习了如何使用PyTorch的DataLoader对象来加载和处理数据集。在深度学习中,数据集的加载和处理是非常重要的,并且它们可以显着影响模型的性能。对于大型数据集,DataLoader是一种自动将数据加载到GPU上并从中批处理数据的理想工具。在本文中,我们学习了DataLoader的基本使用方法,并提供了两个常见的示例应用程序。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中torch.utils.data.DataLoader实例详解 - Python技术站

(0)
上一篇 2023年5月18日
下一篇 2023年5月18日

相关文章

  • python实现比较类的两个instance(对象)是否相等的方法分析

    针对“python实现比较类的两个instance(对象)是否相等的方法分析”的问题,我用下面的几个方面进行了详细的讲解和分析。 方式一:使用“==”运算符 在python中,可以使用“==”运算符来比较两个instance对象是否相等。默认情况下,比较的是两个对象的内存地址是否相同。但是,对于许多类来说,相等意味着对象的属性值是相同的,因此我们需要覆盖Py…

    云计算 2023年5月18日
    00
  • ZEALER创始人王自如谈ZEALER网站的云计算应用

    以下是关于ZEALER创始人王自如谈ZEALER网站的云计算应用的攻略和示例,供您参考: 1. 什么是云计算 云计算是一种基于互联网的计算方式,它将计算资源(如服务器、存储、数据库等)通过互联网提供给用户使用。云计算可以帮助用户节省成本、提高效率、提高可靠性等。 2. ZEALER网站的云计算应用 ZEALER网站是一个科技媒体网站,它使用云计算技术来提供服…

    云计算 2023年5月16日
    00
  • 一文看懂云计算、虚拟化和容器

    “云计算”这个词,相信大家都非常熟悉。 作为信息科技发展的主流趋势,它频繁地出现在我们的眼前。伴随它一起出现的,还有这些概念名词——OpenStack、Hypervisor、KVM、Docker、K8S… 这些名词概念,全部都属于云计算技术领域的范畴。 对于初学者来说,理解这些概念的具体含义并不是一件容易的事情。 所以,今天这篇文章,将给大家做一个通俗易…

    云计算 2023年4月17日
    00
  • 国际国内云计算发展现状及未来前景

    一、“云计算”概述         云计算(Cloud Computing)是分布式处理(Distributed Computing)、并行处理(Parallel Computing)和网格计算(Grid Computing)的发展。        (一)云计算的基本原理。通过使计算分布在大量的分布式计算机上,而非本地计算机或远程服务器中,企业数据中心的运行…

    云计算 2023年4月12日
    00
  • 云计算与 Cloud Native | 数人云CEO王璞@KVM分享实录

    今天小数又给大家带来一篇干货满满的分享——来自KVM社区线上群分享的实录,分享嘉宾是数人云CEO王璞,题目是《云计算与 Cloud Native》。这是数人云在KVM社区群分享的第一弹,之后还有数人云CTO肖德时、COO谢乐冰的Docker与Mesos的应用实战经验分享,敬请期待! 嘉宾介绍 王璞,数人云创始人兼CEO美国 George Mason 大学计算…

    云计算 2023年4月12日
    00
  • ASP.NET Core中的wwwroot文件夹

    下面是关于“ASP.NET Core中的wwwroot文件夹”的完整攻略,包含两个示例说明。 简介 在ASP.NET Core应用程序中,wwwroot文件夹是一个特殊的文件夹,用于存储静态文件,如HTML、CSS、JavaScript、图像等。在本攻略中,我们将介绍如何在ASP.NET Core应用程序中使用wwwroot文件夹。 步骤 在ASP.NET …

    云计算 2023年5月16日
    00
  • Python手拉手教你爬取贝壳房源数据的实战教程

    “Python手拉手教你爬取贝壳房源数据的实战教程”是一篇教程,详细介绍了使用Python爬虫爬取贝壳网房源数据的全过程。以下是该教程的完整攻略: 一、准备工作 在开始爬虫之前,需要准备相应的工具和环境:1. 安装Python环境和相关库:本教程使用Python3编写,需要安装相关库,如requests、BeautifulSoup等;2. 首先需要了解网站的…

    云计算 2023年5月18日
    00
  • 全部满分!阿里云函数计算通过可信云21项测试

    简介: 在未来,无论是一方云服务,还是三方应用,所有事件都可被函数计算等服务可靠地处理。 今日,“2020 可信云线上峰会”正式召开。会上,中国信通院公布了混合云安全、云组网、函数即服务、消息队列、云计算安全运营中心等首次评估结果。阿里云函数计算通过了基础能力要求、平台可观测能力、服务性能、服务安全和服务计量准确性等 21 项测试,最终以满分成绩通过可信云函…

    云计算 2023年4月12日
    00
合作推广
合作推广
分享本页
返回顶部