Pytorch 中 torch.unsqueeze() 与 torch.squeeze() 函数详细解析
1. 简介
torch.unsqueeze()
和 torch.squeeze()
是 pytorch 中的两个常用函数,用于调整张量的形状。
-
torch.unsqueeze(input, dim=None, *, out=None)
:
在指定维度上增加一个维度。返回的张量与input
张量共享数据存储空间,即input.storage()
与返回的张量的存储空间相同。input.dim()
的值加 1,如果dim=None
则在第一维增加一维,否则在dim
维度增加一维。 -
torch.squeeze(input, dim=None, *, out=None)
:
在指定维度上移除一个维度。返回的张量与input
张量共享数据存储空间,即input.storage()
与返回的张量的存储空间相同。input.dim()
的值减 1,如果dim=None
则移除所有大小为 1 的维度,否则移除指定维度,如果在指定维度上的大小不为 1,则返回的张量与input
张量相同。
2. 用法举例
2.1 使用 torch.unsqueeze()
对于一个张量 A,如果要在第二个维度上增加一个维度,则可以使用以下代码:
import torch
A = torch.randn(3, 4)
B = torch.unsqueeze(A, dim=1)
print(A.shape) # (3, 4)
print(B.shape) # (3, 1, 4)
在代码中,我们使用 torch.randn()
生成一个形状为 (3, 4)
的随机张量 A,然后在第二个维度上增加一个维度得到张量 B,B 的形状为 (3, 1, 4)
。
2.2 使用 torch.squeeze()
对于一个张量 C,如果要将第二个维度上的大小为 1 的维度移除,则可以使用以下代码:
import torch
C = torch.randn(3, 1, 4)
D = torch.squeeze(C, dim=1)
print(C.shape) # (3, 1, 4)
print(D.shape) # (3, 4)
在代码中,我们使用 torch.randn()
生成一个形状为 (3,1,4)
的随机张量 C,然后在第二个维度上移除大小为 1 的维度得到张量 D,D 的形状为 (3, 4)
。如果在第二维度上的大小不为 1,那么返回的张量与输入的张量 C 形状相同。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch中torch.unsqueeze()与torch.squeeze()函数详细解析 - Python技术站