PyTorch中torch.utils.data.DataLoader实例详解

PyTorch中torch.utils.data.DataLoader实例详解

介绍

在深度学习中,使用大量的数据进行模型的训练是必需的,但是对于包含大量数据集的任务来说,常规的数据输入(如读取整个数据集,并将其存储在内存中)通常会耗费大量的时间和空间。因此,数据加载的高效性至关重要。PyTorch提供了一个名为DataLoader的工具,可以快速且高效地处理数据。

DataLoader在PyTorch中是数据加载的一种方式,它可以通过提供一个数据集dataset和一个批大小batch_size,自动地对数据进行迭代和批量处理。我们可以使用DataLoader从硬盘或者内存中加载数据,并且可以在数据批次之间轻松地对数据进行处理。

基本使用

步骤一:创建数据集

在使用DataLoader之前,我们需要先创建一个数据集。数据集可以是一个文件夹,也可以是一个csv文件或其他类型文件。以下代码展示如何创建一个来自MNIST数据集的数据集:

import torchvision.datasets as dset
dataset = dset.MNIST(root='data/', download=True, transform=None)

这个数据集含有60000个训练图片和10000个测试图片,每张图片都是28x28的灰度图片。dataset对象可以通过getitem()方法访问每个样本。

步骤二:创建数据加载器DataLoader

在创建数据集之后,我们需要将它传递给DataLoader,以便对数据进行批处理和迭代。以下是创建数据加载器的基本语法:

from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

这里的dataset是我们在第一步创建的数据集对象,batch_size是指每批数据的大小,shuffle参数表示在每个时期结束时是否对数据进行重洗。

步骤三: 遍历数据集

现在,我们可以使用DataLoader来遍历数据集,并可以使用for循环语句按批迭代数据集,如下所示:

for x_train, y_train in dataloader:
    # do something...

这里的x_trainy_train分别是一个从数据集中获取的批次中的数据和标签。

示例应用一:图像分类

以下代码展示了如何使用DataLoader从CIFAR10数据集中加载图像数据,然后进行标准化处理,并将其拟合到一个简单的卷积神经网络中进行分类:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# 加载CIFAR10数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = dset.CIFAR10(root='data/', train=True, download=True, transform=transform)
test_data = dset.CIFAR10(root='data/', train=False, download=True, transform=transform)

# 使用数据加载器迭代数据集
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# 定义卷积神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5,padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5,padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 定义优化器和损失函数
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: %d %%' % (
    100 * correct / total))

示例应用二:自定义数据集

步骤一:准备数据

首先,我们需要准备一组自己的数据集,我们可以将所有数据放在一个文件夹中,或者使用csv文件导入数据。

步骤二:自定义数据集类

我们需要创建一个能够读取我们的数据的类。为此,我们需要继承torch.utils.data.Dataset类,并实现两个函数__getitem____len__

from torch.utils.data.dataset import Dataset

class CustomDataset(Dataset):
    def __init__(self, csv_path, img_dir, transform=None):
        self.transform = transform
        self.images = pd.read_csv(csv_path)
        self.img_dir = img_dir

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.images.iloc[index, 0])
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.images.iloc[index, 1]
        return img, label

    def __len__(self):
        return len(self.images)

这个类接受三个参数:csv_path指向我们的csv文件,img_dir指向我们的图像文件夹,transform是一个可选的图像变换操作。

步骤三:创建数据加载器

现在,我们已经定义了用于读取我们的自定义数据集的类。我们可以使用这个类创建DataLoader对象,并将其传递给迭代器。

from torch.utils.data import DataLoader
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor()])

dataset = CustomDataset('train.csv', './train/', transform)

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

总结

在本文中,我们已经学习了如何使用PyTorch的DataLoader对象来加载和处理数据集。在深度学习中,数据集的加载和处理是非常重要的,并且它们可以显着影响模型的性能。对于大型数据集,DataLoader是一种自动将数据加载到GPU上并从中批处理数据的理想工具。在本文中,我们学习了DataLoader的基本使用方法,并提供了两个常见的示例应用程序。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中torch.utils.data.DataLoader实例详解 - Python技术站

(0)
上一篇 2023年5月18日
下一篇 2023年5月18日

相关文章

  • c#代码生成URL地址的示例

    对于“c#代码生成URL地址的示例”,我可以提供如下完整攻略: 1. 使用System.Net.Http.HttpClient生成URL地址示例 下面提供一个使用System.Net.Http.HttpClient生成URL地址的示例,具体步骤如下: 创建System.Net.Http.HttpClient实例: using System.Net.Http;…

    云计算 2023年5月17日
    00
  • C#实现滑动开关效果

    C#实现滑动开关效果 滑动开关是一种常见的用户界面控件,它通常用于开关某些功能或选项。在C#中,我们可以使用Windows Forms或WPF来实现滑动开关效果。本文将提供一个完整攻略,包括如何在Windows Forms和WPF中实现滑动开关效果,并提供两个示例说明。 Windows Forms 在Windows Forms中,我们可以使用TrackBar…

    云计算 2023年5月16日
    00
  • python网络编程调用recv函数完整接收数据的三种方法

    Python 的网络编程是一门非常重要的技能,在网络编程中,我们通常使用 recv() 函数来接收数据。但是由于网络不稳定等原因,可能出现一次 recv() 无法接收完整数据的情况。下面我们介绍几种处理这种情况的方法。 方法一:自定义数据长度 使用 recv() 函数时,可以给定一个长度参数,用于判断是否已经接收完整数据。示例代码如下: import soc…

    云计算 2023年5月18日
    00
  • openstack已经成为云计算的事实标准,其依赖的一个重要的核心就是虚拟化技术

    (1)虚拟化的概念   所谓虚拟化就是在物理设备上同时运行多台虚拟机,这些虚拟机共享物理设备的CPU,内存和网络,但是这些虚拟机之间是相互隔离的。  物理机被称为host(宿主机),虚拟机被称为guest。 (2)虚拟化分类   虚拟机的调度管理依赖于hypervisor软件,根据hypervisor所处的位置,可以分为2大类:  1、直接在硬件上安装hyp…

    云计算 2023年4月10日
    00
  • 区块链去中心化是什么意思?详解去中心化的含义

    以下是“区块链去中心化是什么意思?详解去中心化的含义”的完整攻略: 1. 区块链去中心化的含义 区块链去中心化是指在区块链网络中,没有中心化的控制机构或单一的权威机构,而是由网络中的所有节点共同维护和管理。这种去中心化的特点使得区块链网络具有高度的安全性和透明度,同时也能够避免单点故障和数据篡改等问题。 2. 去中心化的含义 2.1. 去中心化的优势 去中心…

    云计算 2023年5月16日
    00
  • 如何用Matlab和Python读取Netcdf文件

    读取NetCDF文件的步骤如下: 1. 安装需要的工具包 在Matlab中使用ncread函数读取NetCDF文件前,需要安装MATLAB NetCDF工具包。安装方法可参考官方文档。 在Python中,需要安装netCDF4库,可通过pip命令安装: pip install netCDF4 2. 导入读取器 在Matlab中,需要导入ncread函数来读取…

    云计算 2023年5月18日
    00
  • 初学python数学建模之数据导入(小白篇)

    当我们进行Python数学建模时,常需要导入数据,而数据导入是我们进行数学建模的首要步骤。下面将会介绍Python中常用的几种数据导入方法及其详细使用步骤。 1. 通过CSV文件导入数据 CSV文件是指逗号分隔值文件,通过Python中内置的csv模块可以轻松读取和导入CSV文件。 CSV文件中的每列都代表一个特征,每行代表一个数据点。以下是使用Python…

    云计算 2023年5月18日
    00
  • Python 分析Nginx访问日志并保存到MySQL数据库实例

    以下是详细的Python分析Nginx访问日志并保存到MySQL数据库实例的攻略: 1. 了解Nginx访问日志格式 在保存Nginx访问日志之前,我们需要了解Nginx日志格式的设置。默认情况下,Nginx日志格式的设置会输出一行类似以下的记录: 10.0.10.153 – – [17/Jan/2022:14:57:24 +0800] "GET …

    云计算 2023年5月18日
    00
合作推广
合作推广
分享本页
返回顶部