介绍:在PyTorch中,PyTorch提供了两个函数:torch.flatten和torch.nn.Flatten用于将多维张量转换为一维张量。然而它们之间的实现方式和特点略有不同。
Torch.flatten()
torch.flatten(input, start_dim=0, end_dim=-1)函数用于将一个输入的多维形状张量展平成形状为“1D”的张量。函数有以下参数:
input
:要展平的张量。start_dim
:展平的起始维度。默认值为0。end_dim
: 展平的结束维度。默认为最后一个维度。
示例1:
import torch
# 定义二维张量 2x3
x = torch.randn(2, 3)
print(x)
#Flatten start_dim=0(default)
out1 = torch.flatten(x)
print(out1)
#Flatten start_dim=1 (This is what we are looking for)
out2 = torch.flatten(x, start_dim=1)
print(out2)
输出:
tensor([[-0.2140, 0.0203, -0.6107],
[ 0.5391, -0.4602, -0.5470]])
tensor([-0.2140, 0.0203, -0.6107, 0.5391, -0.4602, -0.5470])
tensor([-0.2140, 0.0203, -0.6107, 0.5391, -0.4602, -0.5470])
在上面的示例中,使用torch.randn函数创建了一个2x3列的多维形状张量,并通过torch.flatten函数将其展平。这里我们通过两个不同的start_dim参数值讨论了这个函数。
Torch.nn.Flatten()
nn.Module.flatten()函数是PyTorch工具包中包含的类。nn.Module.flatten()和torch.flatten()类似,不过在函数、参数调用和输出方面存在一些差异。
import torch
import torch.nn as nn
# 定义二维张量 2x3
x = torch.randn(2, 3)
print(x)
#Flatten
flatten = nn.Flatten()
out1 = flatten(x)
print(out1)
输出:
tensor([[ 0.7535, -0.2421, -2.3212],
[ 1.0820, -2.2361, -1.2694]])
tensor([ 0.7535, -0.2421, -2.3212, 1.0820, -2.2361, -1.2694])
在上面的示例中,我们创建了一个2x3列的多维Tensor张量,创建一个nn.Flatten()类,并将其应用于输入Tensor张量。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch中torch.flatten()和torch.nn.Flatten()实例详解 - Python技术站