pytorch 数据集图片显示方法

在PyTorch中,我们可以使用torchvision库来加载和处理图像数据集。本文将详细讲解如何使用PyTorch加载和显示图像数据集,并提供两个示例说明。

1. 加载图像数据集

在PyTorch中,我们可以使用torchvision.datasets模块中的ImageFolder类来加载图像数据集。ImageFolder类会自动将数据集中的图像按照文件夹名称进行分类,并将每个图像的标签设置为文件夹的名称。以下是一个加载图像数据集的示例代码:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.ImageFolder(root='./data/train', transform=transform)
test_dataset = datasets.ImageFolder(root='./data/test', transform=transform)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的缩放、裁剪、转换为张量和归一化等操作。然后,我们使用ImageFolder类加载了训练数据集和测试数据集,并将数据预处理方法作为参数传入。

2. 显示图像数据集

在PyTorch中,我们可以使用matplotlib库来显示图像数据集。以下是一个显示图像数据集的示例代码:

import matplotlib.pyplot as plt
import numpy as np

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个imshow函数,该函数用于显示图像。然后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

示例1:显示CIFAR10数据集

以下是一个显示CIFAR10数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用CIFAR10类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

示例2:显示MNIST数据集

以下是一个显示MNIST数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用MNIST类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据集图片显示方法 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • 对pytorch中Tensor的剖析

    不是python层面Tensor的剖析,是C层面的剖析。   看pytorch下lib库中的TH好一阵子了,TH也是torch7下面的一个重要的库。 可以在torch的github上看到相关文档。看了半天才发现pytorch借鉴了很多torch7的东西。 pytorch大量借鉴了torch7下面lua写的东西并且做了更好的设计和优化。 https://git…

    PyTorch 2023年4月8日
    00
  • pytorch 中改变tensor维度(transpose)、拼接(cat)、压缩(squeeze)详解

    具体示例如下,注意观察维度的变化 1.改变tensor维度的操作:transpose、view、permute、t()、expand、repeat #coding=utf-8 import torch def change_tensor_shape(): x=torch.randn(2,4,3) s=x.transpose(1,2) #shape=[2,3,…

    PyTorch 2023年4月7日
    00
  • pytorch自定义dataset

    参考 一个例子 import torch from torch.utils import data class MyDataset(data.Dataset): def __init__(self): super(MyDataset, self).__init__() self.data = torch.randn(8,2) def __getitem__(…

    PyTorch 2023年4月8日
    00
  • 解决pytorch GPU 计算过程中出现内存耗尽的问题

    在PyTorch中,当进行GPU计算时,可能会出现内存耗尽的问题。本文将介绍如何解决PyTorch GPU计算过程中出现内存耗尽的问题,并提供两个示例说明。 1. 解决内存耗尽的问题 当进行GPU计算时,可能会出现内存耗尽的问题。为了解决这个问题,可以采取以下几种方法: 1.1 减少批量大小 减少批量大小是解决内存耗尽问题的最简单方法。可以通过减少批量大小来…

    PyTorch 2023年5月15日
    00
  • 深度学习训练过程中的学习率衰减策略及pytorch实现

    学习率是深度学习中的一个重要超参数,选择合适的学习率能够帮助模型更好地收敛。 本文主要介绍深度学习训练过程中的6种学习率衰减策略以及相应的Pytorch实现。 1. StepLR 按固定的训练epoch数进行学习率衰减。 举例说明: # lr = 0.05 if epoch < 30 # lr = 0.005 if 30 <= epoch &lt…

    2023年4月8日
    00
  • pytorch seq2seq闲聊机器人beam search返回结果

    decoder.py “”” 实现解码器 “”” import heapq import torch.nn as nn import config import torch import torch.nn.functional as F import numpy as np import random from chatbot.attention impor…

    PyTorch 2023年4月8日
    00
  • pytorch 文本情感分类和命名实体识别NER中LSTM输出的区别

      文本情感分类: 文本情感分类采用LSTM的最后一层输出 比如双层的LSTM,使用正向的最后一层和反向的最后一层进行拼接 def forward(self,input): ”’ :param input: :return: ”’ input_embeded = self.embedding(input) #[batch_size,seq_len,200…

    PyTorch 2023年4月8日
    00
  • 转:pytorch优化器传入多个模型的参数

    pytorch 优化器(optim)不同参数组,不同学习率设置  

    PyTorch 2023年4月7日
    00
合作推广
合作推广
分享本页
返回顶部