pytorch 数据集图片显示方法

在PyTorch中,我们可以使用torchvision库来加载和处理图像数据集。本文将详细讲解如何使用PyTorch加载和显示图像数据集,并提供两个示例说明。

1. 加载图像数据集

在PyTorch中,我们可以使用torchvision.datasets模块中的ImageFolder类来加载图像数据集。ImageFolder类会自动将数据集中的图像按照文件夹名称进行分类,并将每个图像的标签设置为文件夹的名称。以下是一个加载图像数据集的示例代码:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.ImageFolder(root='./data/train', transform=transform)
test_dataset = datasets.ImageFolder(root='./data/test', transform=transform)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的缩放、裁剪、转换为张量和归一化等操作。然后,我们使用ImageFolder类加载了训练数据集和测试数据集,并将数据预处理方法作为参数传入。

2. 显示图像数据集

在PyTorch中,我们可以使用matplotlib库来显示图像数据集。以下是一个显示图像数据集的示例代码:

import matplotlib.pyplot as plt
import numpy as np

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个imshow函数,该函数用于显示图像。然后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

示例1:显示CIFAR10数据集

以下是一个显示CIFAR10数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用CIFAR10类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

示例2:显示MNIST数据集

以下是一个显示MNIST数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义数据预处理方法
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 显示图像函数
def imshow(img):
    img = img / 2 + 0.5     # 反归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.show()

# 随机选择一张图像并显示
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(labels)

在上面的代码中,我们首先定义了一个数据预处理方法,该方法包括了图像的转换为张量和归一化等操作。然后,我们使用MNIST类加载了训练数据集,并使用DataLoader类将数据集转换为可迭代的数据加载器。接下来,我们定义了一个imshow函数,该函数用于显示图像。最后,我们使用iter函数和next函数从训练数据集中随机选择一批图像,并使用make_grid函数将这批图像拼接成一个网格。最后,我们调用imshow函数显示这个网格。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 数据集图片显示方法 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Pytorch【直播】2019 年县域农业大脑AI挑战赛—初级准备(一)切图

    比赛地址:https://tianchi.aliyun.com/competition/entrance/231717/introduction 这次比赛给的图非常大5万x5万,在训练之前必须要进行数据的切割。通常切割后的大小为512×512,或者1024×1024. 按照512×512切完后的结果如下: 切图时需要注意的几点是: gdal的二进制安装包wh…

    2023年4月6日
    00
  • pytorch模型预测结果与ndarray互转方式

    PyTorch是一个流行的深度学习框架,它提供了许多工具和函数来构建、训练和测试神经网络模型。在实际应用中,我们通常需要将PyTorch模型的预测结果转换为NumPy数组或将NumPy数组转换为PyTorch张量。在本文中,我们将介绍如何使用PyTorch和NumPy进行模型预测结果和数组之间的转换。 示例1:PyTorch模型预测结果转换为NumPy数组 …

    PyTorch 2023年5月15日
    00
  • pytorch 预训练模型读取修改相关参数的填坑问题

    PyTorch预训练模型读取修改相关参数的填坑问题 在使用PyTorch预训练模型时,有时需要读取模型的参数并进行修改。然而,这个过程中可能会遇到一些填坑问题。本文将提供一个完整的攻略,帮助您解决这些问题。 步骤1:下载预训练模型 首先,您需要下载预训练模型。您可以从PyTorch官方网站或其他来源下载预训练模型。在本文中,我们将使用ResNet18作为示例…

    PyTorch 2023年5月15日
    00
  • pytorch的topk()函数

    pytorch.topk()用于返回Tensor中的前k个元素以及元素对应的索引值。例: import torch item=torch.IntTensor([1,2,4,7,3,2]) value,indices=torch.topk(item,3) print(“value:”,value) print(“indices:”,indices) 输出结果为…

    2023年4月8日
    00
  • Pytorch之如何dropout避免过拟合

    PyTorch之如何使用dropout避免过拟合 在深度学习中,过拟合是一个常见的问题。为了避免过拟合,我们可以使用dropout技术。本文将提供一个完整的攻略,介绍如何使用PyTorch中的dropout技术来避免过拟合,并提供两个示例,分别是使用dropout进行图像分类和使用dropout进行文本分类。 dropout技术 dropout是一种常用的正…

    PyTorch 2023年5月15日
    00
  • pytorch SENet实现案例

    SENet是一种用于图像分类的深度神经网络,它通过引入Squeeze-and-Excitation模块来增强模型的表达能力。本文将深入浅析PyTorch中SENet的实现方法,并提供两个示例说明。 1. PyTorch中SENet的实现方法 PyTorch中SENet的实现方法如下: import torch.nn as nn import torch.nn…

    PyTorch 2023年5月15日
    00
  • pytorch分类模型绘制混淆矩阵以及可视化详解

    以下是关于“pytorch分类模型绘制混淆矩阵以及可视化详解”的完整攻略,其中包含两个示例说明。 示例1:绘制混淆矩阵 步骤1:导入必要的库 在绘制混淆矩阵之前,我们需要导入一些必要的库,包括numpy、matplotlib和sklearn。 import numpy as np import matplotlib.pyplot as plt from sk…

    PyTorch 2023年5月16日
    00
  • pytorch 中pad函数toch.nn.functional.pad()的用法

    torch.nn.functional.pad()是PyTorch中的一个函数,用于在张量的边缘填充值。它的语法如下: torch.nn.functional.pad(input, pad, mode=’constant’, value=0) 其中,input是要填充的张量,pad是填充的数量,mode是填充模式,value是填充的值。 pad参数可以是一个…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部