PyTorch笔记之scatter()函数的使用
在PyTorch中,scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中。本文将介绍scatter()函数的用法,并提供两个示例说明。
1. scatter()函数的用法
scatter()函数的语法如下:
torch.scatter(input, dim, index, src)
其中,input表示目标张量,dim表示要分散的维度,index表示要分散的索引,src表示要分散的源张量。
2. 示例1:使用scatter()函数将数据分散到目标张量中
以下是一个示例,展示如何使用scatter()函数将数据分散到目标张量中。
import torch
# 创建一个形状为(2, 3)的目标张量
target = torch.zeros(2, 3)
# 创建一个形状为(2, 2)的源张量
source = torch.tensor([[1, 2], [3, 4]])
# 创建一个形状为(2,)的索引张量
index = torch.tensor([0, 2])
# 使用scatter()函数将数据分散到目标张量中
torch.scatter(target, 1, index.unsqueeze(1), source)
# 打印目标张量
print(target)
在上面的示例中,我们首先创建了一个形状为(2, 3)的目标张量target,一个形状为(2, 2)的源张量source,以及一个形状为(2,)的索引张量index。然后,我们使用scatter()函数将源张量source中的数据按照索引张量index分散到目标张量target中。最后,我们打印目标张量target,发现它的值为:
tensor([[1., 0., 2.],
[0., 0., 4.]])
3. 示例2:使用scatter()函数将数据分散到目标张量中,并进行累加
以下是一个示例,展示如何使用scatter()函数将数据分散到目标张量中,并进行累加。
import torch
# 创建一个形状为(2, 3)的目标张量
target = torch.zeros(2, 3)
# 创建一个形状为(2, 2)的源张量
source = torch.tensor([[1, 2], [3, 4]])
# 创建一个形状为(2,)的索引张量
index = torch.tensor([0, 2])
# 使用scatter()函数将数据分散到目标张量中,并进行累加
torch.scatter_add(target, 1, index.unsqueeze(1), source)
# 打印目标张量
print(target)
在上面的示例中,我们首先创建了一个形状为(2, 3)的目标张量target,一个形状为(2, 2)的源张量source,以及一个形状为(2,)的索引张量index。然后,我们使用scatter_add()函数将源张量source中的数据按照索引张量index分散到目标张量target中,并进行累加。最后,我们打印目标张量target,发现它的值为:
tensor([[1., 0., 2.],
[0., 0., 4.]])
4. 总结
本文介绍了PyTorch中scatter()函数的用法,并提供了两个示例说明。scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中,非常实用。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch笔记之scatter()函数的使用 - Python技术站