pytorch 数据集图片显示方法

yizhihongxing

在PyTorch中,我们可以使用torchvision库来加载和处理图像数据集。本文将详细讲解如何使用PyTorch加载和显示图像数据集,并提供两个示例说明。

1. 加载图像数据集

在PyTorch中,我们可以使用torchvision.datasets模块中的ImageFolder类来加载图像数据集。ImageFolder类会自动将数据集中的图像按照文件夹名称进行分类,并将每个图像的标签设置为文件夹的名称。以下是一个加载图像数据集的示例代码:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.ImageFolder(root='./data/train', transform=transform)
test_dataset = datasets.ImageFolder(root='./data/test', transform=transform)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的缩放、裁剪、转换为张量和归一化等操作。然后,我们使用ImageFolder类加载了训练数据集和测试数据集,并将数据预处理方法作为参数传入。

2. 显示图像数据集

在PyTorch中,我们可以使用matplotlib库来显示图像数据集。以下是一个显示图像数据集的示例代码:

import matplotlib.pyplot as plt
import numpy as np

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个imshow函数,该函数用于显示图像。然后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

示例1:显示CIFAR10数据集

以下是一个显示CIFAR10数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用CIFAR10类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

示例2:显示MNIST数据集

以下是一个显示MNIST数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用MNIST类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据集图片显示方法 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Pytorch GPU显存充足却显示out of memory的解决方式

    当我们在使用PyTorch进行深度学习训练时,经常会遇到GPU显存充足却显示out of memory的问题。这个问题的原因是PyTorch默认会占用所有可用的GPU显存,而在训练过程中,显存的使用可能会超出我们的预期。本文将提供一个详细的攻略,介绍如何解决PyTorch GPU显存充足却显示out of memory的问题,并提供两个示例说明。 1. 使用…

    PyTorch 2023年5月15日
    00
  • PyTorch Geometric Temporal 介绍 —— 数据结构和RGCN的概念

    Introduction PyTorch Geometric Temporal is a temporal graph neural network extension library for PyTorch Geometric. PyTorch Geometric Temporal 是基于PyTorch Geometric的对时间序列图数据的扩展。 Dat…

    PyTorch 2023年4月8日
    00
  • pytorch的Backward过程用时太长问题及解决

    在PyTorch中,当我们使用反向传播算法进行模型训练时,有时会遇到Backward过程用时太长的问题。这个问题可能会导致训练时间过长,甚至无法完成训练。本文将提供一个完整的攻略,介绍如何解决这个问题。我们将提供两个示例,分别是使用梯度累积和使用半精度训练。 示例1:使用梯度累积 梯度累积是一种解决Backward过程用时太长问题的方法。它的基本思想是将一个…

    PyTorch 2023年5月15日
    00
  • Pytorch自动求导函数详解流程以及与TensorFlow搭建网络的对比

    以下是“PyTorch自动求导函数详解流程以及与TensorFlow搭建网络的对比”的完整攻略,包含两个示例说明。 PyTorch自动求导函数详解流程 PyTorch是一个基于Python的科学计算库,它提供了强大的GPU加速支持和自动求导机制。在PyTorch中,我们可以使用自动求导函数来计算梯度,从而实现反向传播算法。下面是PyTorch自动求导函数的详…

    PyTorch 2023年5月15日
    00
  • Pytorch快速入门及在线体验

    本文搭配了Pytorch在线环境,可以直接在线体验。 Pytorch是Facebook 的 AI 研究团队发布了一个基于 Python的科学计算包,旨在服务两类场合: 1.替代numpy发挥GPU潜能 ;2. 一个提供了高度灵活性和效率的深度学习实验性平台。 1.Pytorch简介 Pytorch是Facebook 的 AI 研究团队发布了一个基于 Pyth…

    2023年4月8日
    00
  • pytorch conditional GAN 调试笔记

    推荐的几个开源实现 znxlwm 使用InfoGAN的结构,卷积反卷积 eriklindernoren 把mnist转成1维,label用了embedding wiseodd 直接从tensorflow代码转换过来的,数据集居然还用tf的数据集。。 Yangyangii 转1维向量,全连接 FangYang970206 提供了多标签作为条件的实现思路 znx…

    2023年4月8日
    00
  • pytorch实现分类

    完整代码 #实现分类 import torch import torch.nn.functional as F from torch.autograd import Variable import matplotlib.pyplot as plt import torch.optim as optim #生成数据 n_data = torch.ones(10…

    PyTorch 2023年4月7日
    00
  • Focal Loss 的Pytorch 实现以及实验

      Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 F L对于简单样本(p比较大)回应较小的loss。 如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss, 但是对于FL就有…

    2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部