PyTorch中nn.Flatten()函数详解及示例
1. 简介
nn.Flatten()
是PyTorch中的一个函数,它用来将输入张量展平为一维张量。它可以被用来将二维卷积层的输出偏扁为一维传到全连接层里,或者张量reshape的一种更简单的方式。
2. 使用方法
nn.Flatten()可以接受任何形式的输入,但在输入之前必须将通道数(C)和图像大小(dx,dy)确定好,这些尺寸通过计算输入张量的numel()函数来计算得到。
下面是nn.Flatten()函数的基本使用方法和实例:
2.1 常规使用
import torch.nn as nn
flatten = nn.Flatten()
input = torch.randn(3, 4, 5, 6)
output = flatten(input)
print(output.size()) # 输出torch.Size([360])
在这个示例中,我们首先导入了nn.Flatten()
函数。我们随机生成了一个大小为[3,4,5,6]
的张量,其中第一个维度是批大小,后三个维度是图像的大小和通道数量。
然后我们把这个张量input
传入nn.Flatten()中,输出的张量形状变为了[360],所有batch里面的图像都被展平了。
2.2 和nn.Sequential()一起使用
nn.Sequential()
是一个PyTorch中用于封装序列的函数,通常用于将多个层串联起来形成一个神经网络。
下面是nn.Flatten()函数和nn.Sequential()函数一起使用的示例:
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 10, 3, stride=2, padding=1),
nn.Flatten()
)
input = torch.randn(1, 1, 28, 28)
output = model(input)
print(output.size()) # 输出torch.Size([1, 10])
在这个示例中,我们定义了一个包含三个卷积层和一个自动展平的序列模型。序列从一个1通道的28×28像素的图像开始,然后通过三个卷积层,最后通过nn.Flatten()
输送前向传播的一维向量输出。
2.3 和nn.Linear()一起使用
nn.Linear()
是PyTorch中的一个线性变换层,常用来在神经网络中构建一个线性分类器。
下面是nn.Flatten()函数和nn.Linear()函数一起使用的示例:
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 10, 3, stride=2, padding=1),
nn.Flatten(),
nn.Linear(90, 10)
)
input = torch.randn(1, 1, 28, 28)
output = model(input)
print(output.size()) # 输出torch.Size([1, 10])
在这个示例中,我们定义了一个包含三个卷积层、一个自动展平的序列模型,和一个具有10个输出节点的线性分类器。序列从一个1通道的28×28像素的图像开始,然后通过三个卷积层和一个自动展平的序列,最后通过nn.Linear()
输出前向传播结果。
3. 结论
在本文中,我们介绍了PyTorch中的nn.Flatten()函数,并提供了两个使用示例,详细介绍了如何在卷积神经网络中使用这个函数,以帮助巩固您对PyTorch的掌握。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch中nn.Flatten()函数详解及示例 - Python技术站