Pytorch框架之one_hot编码函数解读
一、什么是one_hot编码?
在机器学习中,one_hot编码是将一个分类变量转换成一系列二进制变量的过程,其中只有一个变量包含 1,其他变量都是 0。例如:有一个分类变量"颜色",它有三个类别:"红色"、"黄色"、"绿色",那么对 "颜色" 进行 one_hot 编码会得到如下的结果:
红色 -> [1,0,0]
黄色 -> [0,1,0]
绿色 -> [0,0,1]
二、Pytorch框架中的one_hot编码函数
在Pytorch框架中,使用torch.eye()
函数可以很方便的进行one_hot编码。torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
的用法如下:
n
:int, 行数;m
:int,列数,默认为 n;out
:Tensor,结果Tensor;dtype
:数据类型,默认不填,与输入Tensor一致;layout
:布局;device
:设备,默认为 CPU;requires_grad
:是否记录梯度,False 为不记录,True 为记录。默认为 False。
例如,对于红色、黄色、绿色三个颜色进行one_hot编码的示例代码如下:
import torch
color = torch.tensor([0, 1, 2]) # 颜色
num_classes = 3 # 颜色的类别数
one_hot = torch.eye(num_classes)[color]
print(one_hot)
打印结果如下:
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
可以看到,输出结果正好是 红色 -> [1,0,0]
、黄色 -> [0,1,0]
和 绿色 -> [0,0,1]
。
另外,我们也可以使用torch.nn.functional.one_hot()
来进行one_hot编码。torch.nn.functional.one_hot()
的用法如下:
torch.nn.functional.one_hot(tensor, num_classes=None)
其中,
tensor
:要进行one_hot编码的Tensor对象;num_classes
:one_hot编码后的结果向量的类别数。
当num_classes
为None
时,则自动根据输入 tensor
中的最大值推断出 num_classes
。
例如,对于红色、黄色、绿色三个颜色进行one_hot编码的示例代码如下:
import torch.nn.functional as F
import torch
color = torch.tensor([0, 1, 2]) # 颜色
num_classes = 3 # 颜色的类别数
one_hot = F.one_hot(color, num_classes=num_classes)
print(one_hot)
输出结果与使用torch.eye()
函数得到的结果相同:
tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
三、总结
通过本次攻略,我们了解了one_hot编码的概念和在Pytorch框架中的实现方式,包括使用torch.eye()
和torch.nn.functional.one_hot()
函数。例如,使用torch.eye()
实现:
color = torch.tensor([0, 1, 2])
num_classes = 3
one_hot = torch.eye(num_classes)[color]
使用torch.nn.functional.one_hot()
实现:
import torch.nn.functional as F
import torch
color = torch.tensor([0, 1, 2])
num_classes = 3
one_hot = F.one_hot(color, num_classes=num_classes)
当然,使用不同的函数得到的one_hot编码结果是相同的。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch框架之one_hot编码函数解读 - Python技术站