pytorch中的transforms.ToTensor和transforms.Normalize的实现

PyTorch是目前非常流行的深度学习框架之一,它提供了transforms模块来进行图像的预处理。其中,transforms.ToTensortransforms.Normalize是常用的图像预处理方法,下面将详细讲解它们的实现。

一. transforms.ToTensor实现

transforms.ToTensor用于将PIL图像或numpy.array数组类型的图像转换成PyTorch中的tensor类型,并将图像像素值归一化到[0,1]之间。其实现可分为以下几个步骤:

  1. 将PIL图像或numpy.array数组类型的图像转换成tensor类型;
  2. 将图像像素值转换为[0,1]之间的小数;
  3. 调整图像的维度,将通道维度放在第二个位置。即变为(C,H,W)的形式。

以下是使用样例一:

from PIL import Image
from torchvision import transforms
import torch

# 加载图像
img = Image.open('test.jpg')

# 定义变换
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

# 进行图像变换
img_t = transform(img)
print(img_t.shape)
print(torch.min(img_t), torch.max(img_t))

输出结果为:

torch.Size([3, 224, 224])
tensor(0.) tensor(1.)

解释一下输出的结果:torch.Size([3, 224, 224])表示变换后图像的维度为(C,H,W),其中C表示图像的通道数,为3;tensor(0.) tensor(1.)表示变换后的图像像素值位于[0,1]之间。

以下是使用样例二:

import numpy as np

# 生成numpy数组
img_np = np.array(img)

# 定义变换
transform = transforms.Compose([
    transforms.ToTensor(),
])

# 进行图像变换
img_t = transform(img_np)
print(img_t.shape)
print(torch.min(img_t), torch.max(img_t))

输出结果为:

torch.Size([3, 300, 400])
tensor(0.) tensor(1.)

解释一下输出的结果:torch.Size([3, 300, 400])表示变换后图像的维度为(C,H,W),其中C表示图像的通道数,为3;tensor(0.) tensor(1.)表示变换后的图像像素值位于[0,1]之间。

需要注意的是,在使用transforms.ToTensor变换时,PIL图像或numpy.array数组类型的图像需要保证通道维度在最后一个位置。

二. transforms.Normalize实现

transforms.Normalize用于对输入图像进行标准化处理,即按照给定的均值和标准差对输入图像进行标准化。其实现可分为以下几个步骤:

  1. 根据给定的均值和标准差对输入图像进行标准化;
  2. 将图像像素值再次归一化到[-1,1]之间。

以下是使用样例一:

from PIL import Image
from torchvision import transforms
import torch

# 加载图像
img = Image.open('test.jpg')

# 定义变换
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 进行图像变换
img_t = transform(img)
print(img_t.shape)
print(torch.min(img_t), torch.max(img_t))

输出结果为:

torch.Size([3, 224, 224])
tensor(-2.1179) tensor(2.2489)

解释一下输出的结果:torch.Size([3, 224, 224])表示变换后图像的维度为(C,H,W),其中C表示图像的通道数,为3;tensor(-2.1179) tensor(2.2489)表示变换后的图像像素值位于[-1,1]之间。

以下是使用样例二:

import numpy as np

# 生成numpy数组
img_np = np.array(img)

# 定义变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 进行图像变换
img_t = transform(img_np)
print(img_t.shape)
print(torch.min(img_t), torch.max(img_t))

输出结果为:

torch.Size([3, 300, 400])
tensor(-2.1179) tensor(2.2489)

解释一下输出的结果:torch.Size([3, 300, 400])表示变换后图像的维度为(C,H,W),其中C表示图像的通道数,为3;tensor(-2.1179) tensor(2.2489)表示变换后的图像像素值位于[-1,1]之间。

需要注意的是,在使用transforms.Normalize时,输入的图像必须是tensor类型。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch中的transforms.ToTensor和transforms.Normalize的实现 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 指针操作数组的两种方法(总结)

    下面我就来详细讲解“指针操作数组的两种方法(总结)”的完整攻略。 什么是指针操作数组? 指针操作数组是指通过指针变量对数组进行操作的一种方式。指针变量存储的是一个地址,该地址指向数组的第一个元素,通过指针变量可以对数组进行遍历、访问、修改等操作。 方法1:指针通过数组名操作数组 指针通过数组名操作数组是指定义一个指向数组的指针变量,然后通过该指针变量对数组进…

    人工智能概览 2023年5月25日
    00
  • Pycharm及python安装详细教程(图解)

    下面是Pycharm及Python安装详细教程的完整攻略: Pycharm及Python安装详细教程(图解) 1.下载Python安装包 在Python官网下载对应系统的安装包,建议选择最新的稳定版本进行下载。 2.安装Python 双击下载的安装包,按照步骤进行安装。安装过程中注意勾选“Add Python to PATH”选项,这样可以方便后面在命令行中…

    人工智能概览 2023年5月25日
    00
  • Django之无名分组和有名分组的实现

    Django之无名分组和有名分组的实现 在Django的url路由中,我们可以通过使用正则表达式来匹配不同的url地址,并且通过分组的方式将匹配到的信息提取出来,这就是Django的分组功能,分组的方式可以分为无名分组和有名分组。 无名分组 无名分组即为不特别指定分组名称的分组方式,使用()来进行分组,$1、$2等都是分组的引用,这种引用方式不直观,难以辨别…

    人工智能概论 2023年5月25日
    00
  • MongoDB C 驱动程序安装(libmongoc) 和 BSON 库(libbson)方法

    安装MongoDB C驱动程序(libmongoc)和BSON库(libbson)方法如下: 安装依赖项 在安装MongoDB C驱动程序和BSON库之前,需要先安装一些依赖项。以下是在Ubuntu系统中安装这些依赖项的命令: sudo apt-get update sudo apt-get install -y autoconf automake libt…

    人工智能概论 2023年5月25日
    00
  • Python的Django框架中的Context使用

    下面是Python的Django框架中的Context使用的完整攻略: 什么是Context? Context是Django框架中一个非常重要的部分,它负责传递模板中需要的变量以及函数等信息。在Django框架中,Context通常是一个字典对象,其中键为变量名,值为对应变量的值。 如何定义Context? 在Django框架中,可以通过定义一个字典来创建C…

    人工智能概览 2023年5月25日
    00
  • django中ORM模型常用的字段的使用方法

    下面是“Django中ORM模型常用字段的使用方法”的攻略。 简介 Django中的ORM(对象关系映射)是一个强大的工具,它使开发人员能够更轻松地与数据库交互。Django中ORM提供了许多内置字段,这些字段可以将Python对象映射为数据库中的列。本攻略将会介绍Django中ORM模型常用的字段和它们的基本使用方法。 CharField CharFiel…

    人工智能概论 2023年5月25日
    00
  • Dubbo本地开发技巧分享

    Dubbo本地开发技巧分享 Dubbo是一个高性能、轻量级的Java RPC框架,被广泛应用于微服务架构中。在进行Dubbo应用开发时,本地开发是必不可少的环节,因此掌握一些Dubbo本地开发技巧是非常有必要的。 本文将会分享几个Dubbo本地开发技巧,包括Dubbo本地开发环境的配置、Dubbo服务的本地调用等。 环境配置 在进行本地开发前,需要首先配置本…

    人工智能概览 2023年5月25日
    00
  • python控制windows剪贴板,向剪贴板中写入图片的实例

    Python控制Windows剪贴板,向剪贴板中写入图片,可以通过下面几个步骤完成。 1. 安装必要的库 首先需要安装pywin32和Pillow两个Python库: pip install pywin32 pip install Pillow 2. 代码实现 以下是一个演示如何将一张图片复制到剪贴板的Python脚本示例: import win32clip…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部