Pytorch框架之one_hot编码函数解读

yizhihongxing

Pytorch框架之one_hot编码函数解读

一、什么是one_hot编码?

在机器学习中,one_hot编码是将一个分类变量转换成一系列二进制变量的过程,其中只有一个变量包含 1,其他变量都是 0。例如:有一个分类变量"颜色",它有三个类别:"红色"、"黄色"、"绿色",那么对 "颜色" 进行 one_hot 编码会得到如下的结果:

红色 -> [1,0,0]
黄色 -> [0,1,0]
绿色 -> [0,0,1]

二、Pytorch框架中的one_hot编码函数

在Pytorch框架中,使用torch.eye()函数可以很方便的进行one_hot编码。torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)的用法如下:

  • n:int, 行数;
  • m:int,列数,默认为 n;
  • out:Tensor,结果Tensor;
  • dtype:数据类型,默认不填,与输入Tensor一致;
  • layout:布局;
  • device:设备,默认为 CPU;
  • requires_grad:是否记录梯度,False 为不记录,True 为记录。默认为 False。

例如,对于红色、黄色、绿色三个颜色进行one_hot编码的示例代码如下:

import torch

color = torch.tensor([0, 1, 2]) # 颜色
num_classes = 3                 # 颜色的类别数
one_hot = torch.eye(num_classes)[color]
print(one_hot)

打印结果如下:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

可以看到,输出结果正好是 红色 -> [1,0,0]黄色 -> [0,1,0]绿色 -> [0,0,1]

另外,我们也可以使用torch.nn.functional.one_hot()来进行one_hot编码。torch.nn.functional.one_hot()的用法如下:

torch.nn.functional.one_hot(tensor, num_classes=None)

其中,

  • tensor:要进行one_hot编码的Tensor对象;
  • num_classes:one_hot编码后的结果向量的类别数。

num_classesNone时,则自动根据输入 tensor 中的最大值推断出 num_classes

例如,对于红色、黄色、绿色三个颜色进行one_hot编码的示例代码如下:

import torch.nn.functional as F
import torch

color = torch.tensor([0, 1, 2]) # 颜色
num_classes = 3                 # 颜色的类别数
one_hot = F.one_hot(color, num_classes=num_classes)
print(one_hot)

输出结果与使用torch.eye()函数得到的结果相同:

tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

三、总结

通过本次攻略,我们了解了one_hot编码的概念和在Pytorch框架中的实现方式,包括使用torch.eye()torch.nn.functional.one_hot()函数。例如,使用torch.eye()实现:

color = torch.tensor([0, 1, 2])
num_classes = 3
one_hot = torch.eye(num_classes)[color]

使用torch.nn.functional.one_hot()实现:

import torch.nn.functional as F
import torch

color = torch.tensor([0, 1, 2])
num_classes = 3
one_hot = F.one_hot(color, num_classes=num_classes)

当然,使用不同的函数得到的one_hot编码结果是相同的。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch框架之one_hot编码函数解读 - Python技术站

(0)
上一篇 2023年5月20日
下一篇 2023年5月20日

相关文章

  • 预签名 URL:发布图像错误:签名不匹配:Python

    【问题标题】:presigned URL : Post image error: Signature does not match: Python预签名 URL:发布图像错误:签名不匹配:Python 【发布时间】:2023-04-01 02:58:02 【问题描述】: 我将在 lambda 中执行以下命令以生成预签名 URL ”’ import boto…

    Python开发 2023年4月8日
    00
  • Python 可迭代对象 iterable的具体使用

    针对 Python 可迭代对象 iterable 的具体使用,我为您整理了以下完整攻略: 1. 什么是可迭代对象 iterable 可迭代对象 iterable 是指能够提供一个迭代器 iterator 的对象,迭代器是一个带有 next() 方法并且返回一个迭代值的对象。通常,可迭代对象 iterable 包括 list、set、tuple、dict、st…

    python 2023年6月3日
    00
  • Python的形参和实参使用方式

    当我们在Python中定义函数时,可以为函数指定形参,形参是在函数定义时用于接受传递给函数的数据的变量。函数被调用时,需要传递对应个数的实际参数给函数,这些实参的值将被传递给函数内的形参,并在函数内部使用。 Python中形参和实参的使用方式需要注意以下几点: 1. 形参和实参的基本使用 当使用函数时,形参是在函数定义时预先定义好的参数,用于拦截传递给函数的…

    python 2023年5月14日
    00
  • Pandas数据分析之批量拆分/合并Excel

    下面是《Pandas数据分析之批量拆分/合并Excel》的完整实例教程。 1. 教程背景 在实际的工作中,我们经常需要对Excel表格进行批量拆分或合并操作。这些操作如果手动完成往往比较繁琐,而使用Pandas库可以方便地实现这些操作。本篇教程将介绍如何使用Pandas库对Excel表格进行批量拆分和合并。 2. 批量拆分Excel 假设我们有一个包含多个工…

    python 2023年5月13日
    00
  • Python while true实现爬虫定时任务

    实现爬虫的定时任务需要用到while True循环和time.sleep()方法。当然在循环内部还需要完成实际的爬虫任务。下面是具体的步骤: 1. 导入相关模块 首先要导入的模块是requests和beautifulsoup4,用于进行网络请求和网页解析。另外还需要time模块用于设置间隔时间。 import requests from bs4 import…

    python 2023年6月3日
    00
  • python使用requests库爬取拉勾网招聘信息的实现

    Python 使用 requests 库爬取拉勾网招聘信息的实现 环境准备 首先,我们需要确保 Python 安装了 requests 库。如果没有安装,可以使用以下命令进行安装: pip install requests 分析网页结构 在使用 requests 爬取拉勾网招聘信息前,我们需要先分析网页的结构,以便于编写代码。以下是拉勾网的招聘页面的网址: …

    python 2023年5月14日
    00
  • python 教程实现 turtle海龟绘图

    接下来我将为您详细讲解“Python 教程实现 turtle 海龟绘图”的完整攻略,同时会给出两个示例说明。 1. 准备工作 在学习本教程之前,需要提前安装好 Python 环境和 turtle 库。如果您还没有安装 Python 环境和 turtle 库,请先按照官方安装教程进行安装。 2. 创建绘图窗口 在 Python 中,使用 turtle 库进行绘…

    python 2023年5月19日
    00
  • Python seaborn barplot画图案例

    接下来我将向您介绍如何使用Python Seaborn库来创建barplot(条形图)的完整攻略。 步骤一:导入必要的库和数据 我们需要先导入必要的Python库,包括Seaborn、Matplotlib和Pandas。同时,我们还需要加载我们想要绘制的数据集。在这个示例中,我们将使用Seaborn自带的数据集”tips”。 import seaborn a…

    python 2023年5月18日
    00
合作推广
合作推广
分享本页
返回顶部