在PyTorch中,torch.cat
是一个非常有用的函数,它可以将多个张量沿着指定的维度拼接在一起。本文将介绍torch.cat
的用法和示例。
用法
torch.cat
的用法如下:
torch.cat(tensors, dim=0, out=None) -> Tensor
其中,tensors
是要拼接的张量序列,dim
是要沿着的维度,out
是输出张量。如果out
未提供,则会创建一个新的张量来存储结果。
示例一:沿着行拼接两个张量
我们可以使用torch.cat
函数沿着行拼接两个张量。示例代码如下:
import torch
# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6]])
# 沿着行拼接两个张量
c = torch.cat((a, b), dim=0)
print(c)
在上述代码中,我们首先创建了两个张量a
和b
,其中a
的形状为(2, 2)
,b
的形状为(1, 2)
。接着,我们使用torch.cat
函数沿着行拼接了这两个张量,得到了一个形状为(3, 2)
的新张量c
。
示例二:沿着列拼接两个张量
除了沿着行拼接,我们还可以使用torch.cat
函数沿着列拼接两个张量。示例代码如下:
import torch
# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5], [6]])
# 沿着列拼接两个张量
c = torch.cat((a, b), dim=1)
print(c)
在上述代码中,我们首先创建了两个张量a
和b
,其中a
的形状为(2, 2)
,b
的形状为(2, 1)
。接着,我们使用torch.cat
函数沿着列拼接了这两个张量,得到了一个形状为(2, 3)
的新张量c
。
总结
本文介绍了torch.cat
函数的用法和示例。torch.cat
函数可以将多个张量沿着指定的维度拼接在一起,非常方便。我们可以使用torch.cat
函数沿着行或列拼接两个张量,得到一个新的张量。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中的torch.cat简单介绍 - Python技术站