详解PyTorch中squeeze()和unsqueeze()函数介绍
在PyTorch中,squeeze()
和unsqueeze()
函数是用于改变张量形状的常用函数。本文将详细介绍这两个函数的用法和示例。
1. unsqueeze()
函数
unsqueeze()
函数用于在指定维度上增加一个维度。以下是unsqueeze()
函数的语法:
torch.unsqueeze(input, dim)
其中,input
是要增加维度的张量,dim
是要增加的维度的索引。例如,如果要在第二个维度上增加一个维度,可以使用以下代码:
import torch
x = torch.randn(3, 4)
y = torch.unsqueeze(x, dim=1)
print(x.shape) # 输出: torch.Size([3, 4])
print(y.shape) # 输出: torch.Size([3, 1, 4])
在上面的示例中,我们创建了一个形状为(3, 4)
的张量x
,然后使用unsqueeze()
函数在第二个维度上增加了一个维度,得到了一个形状为(3, 1, 4)
的张量y
。
2. squeeze()
函数
squeeze()
函数用于删除张量中的所有大小为1的维度。以下是squeeze()
函数的语法:
torch.squeeze(input, dim=None)
其中,input
是要删除维度的张量,dim
是要删除的维度的索引。如果不指定dim
参数,则删除所有大小为1的维度。例如,如果要删除第二个维度上的大小为1的维度,可以使用以下代码:
import torch
x = torch.randn(3, 1, 4)
y = torch.squeeze(x, dim=1)
print(x.shape) # 输出: torch.Size([3, 1, 4])
print(y.shape) # 输出: torch.Size([3, 4])
在上面的示例中,我们创建了一个形状为(3, 1, 4)
的张量x
,然后使用squeeze()
函数删除了第二个维度上的大小为1的维度,得到了一个形状为(3, 4)
的张量y
。
3. 示例
以下是一个使用unsqueeze()
和squeeze()
函数的示例,用于将一个形状为(3, 4)
的张量转换为一个形状为(1, 3, 2, 2)
的张量,然后再将其转换回原始形状。
import torch
# 创建一个形状为(3, 4)的张量
x = torch.randn(3, 4)
print(x.shape) # 输出: torch.Size([3, 4])
# 将张量转换为形状为(1, 3, 2, 2)的张量
y = torch.unsqueeze(x.view(1, 3, 2, 2), dim=0)
print(y.shape) # 输出: torch.Size([1, 3, 2, 2])
# 将张量转换回原始形状
z = torch.squeeze(y, dim=0).view(3, 4)
print(z.shape) # 输出: torch.Size([3, 4])
在上面的示例中,我们首先创建了一个形状为(3, 4)
的张量x
,然后使用view()
函数将其转换为一个形状为(1, 3, 2, 2)
的张量,并使用unsqueeze()
函数在第一个维度上增加了一个维度,得到了一个形状为(1, 3, 2, 2)
的张量y
。最后,我们使用squeeze()
函数删除了第一个维度上的大小为1的维度,并使用view()
函数将其转换回原始形状(3, 4)
,得到了一个形状为(3, 4)
的张量z
。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:详解pytorch中squeeze()和unsqueeze()函数介绍 - Python技术站