PyTorch 中的 torch.cat()
函数是用来将张量按照给定的维度进行拼接的函数。在这里,我们将详细讲解该函数的使用。本攻略将包含以下内容:
torch.cat()
函数的基本格式及参数说明;- 两个具体的示例,分别说明如何进行张量拼接。
1. torch.cat()
函数的基本格式及参数说明
torch.cat()
函数的基本格式如下:
torch.cat(tensors, dim=0, out=None) -> Tensor
其中,各个参数的含义如下:
tensors
:需要拼接的张量序列,可以是元组、列表或者其他包含张量的可迭代类型。dim
:拼接维度,指定在哪个维度上进行拼接操作。默认为0,即在第0维进行拼接操作。out
:可选参数,表示输出张量,如果未指定,则该函数会自动创建一个新的张量。
2. 两个具体的示例,分别说明如何进行张量拼接
在本节中,我们将给出两个具体的示例,分别说明如何使用 torch.cat()
函数进行张量拼接。
示例一
在这个示例中,我们先定义两个大小相同的张量,并将它们按照第0维拼接。
import torch
# 定义两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 按照第0维进行拼接
c = torch.cat([a, b], dim=0)
print(c)
输出结果如下:
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
在这个示例中,我们首先创建了两个大小相同的 $2\times 2$ 张量 a
和 b
,然后使用 torch.cat()
函数将它们按照第0维进行拼接,得到了一个大小为 $4\times 2$ 的张量 c
。
示例二
在这个示例中,我们先定义两个大小不同的张量,并将它们按照第1维拼接。此外,我们还将使用 out
参数指定输出张量。
import torch
# 定义两个张量,大小不同
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([[7], [8], [9]])
# 按照第1维进行拼接
c = torch.cat([a, b], dim=1, out=torch.zeros(3, 3))
print(c)
输出结果如下:
tensor([[1., 2., 7.],
[3., 4., 8.],
[5., 6., 9.]])
在这个示例中,我们定义了两个张量,其中 a
的大小为 $3\times 2$,而 b
的大小为 $3\times 1$。我们使用 torch.cat()
函数在第1维进行拼接,并将输出张量指定为一个 $3\times 3$ 的全零张量。在这个示例中,我们可以看到,由于 a
和 b
的第一维大小不同,因此在拼接时需要将 b
这个张量在第1维进行扩展,使其大小与 a
相同。最终,我们得到了一个 $3\times 3$ 的张量 c
。
至此,关于 PyTorch 中的 torch.cat()
函数的介绍到此结束。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch中torch.cat()函数举例解析 - Python技术站