PyTorch笔记之scatter()函数的使用

PyTorch笔记之scatter()函数的使用

在PyTorch中,scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中。本文将介绍scatter()函数的用法,并提供两个示例说明。

1. scatter()函数的用法

scatter()函数的语法如下:

torch.scatter(input, dim, index, src)

其中,input表示目标张量,dim表示要分散的维度,index表示要分散的索引,src表示要分散的源张量。

2. 示例1:使用scatter()函数将数据分散到目标张量中

以下是一个示例,展示如何使用scatter()函数将数据分散到目标张量中。

import torch

# 创建一个形状为(2, 3)的目标张量
target = torch.zeros(2, 3)

# 创建一个形状为(2, 2)的源张量
source = torch.tensor([[1, 2], [3, 4]])

# 创建一个形状为(2,)的索引张量
index = torch.tensor([0, 2])

# 使用scatter()函数将数据分散到目标张量中
torch.scatter(target, 1, index.unsqueeze(1), source)

# 打印目标张量
print(target)

在上面的示例中,我们首先创建了一个形状为(2, 3)的目标张量target,一个形状为(2, 2)的源张量source,以及一个形状为(2,)的索引张量index。然后,我们使用scatter()函数将源张量source中的数据按照索引张量index分散到目标张量target中。最后,我们打印目标张量target,发现它的值为:

tensor([[1., 0., 2.],
        [0., 0., 4.]])

3. 示例2:使用scatter()函数将数据分散到目标张量中,并进行累加

以下是一个示例,展示如何使用scatter()函数将数据分散到目标张量中,并进行累加。

import torch

# 创建一个形状为(2, 3)的目标张量
target = torch.zeros(2, 3)

# 创建一个形状为(2, 2)的源张量
source = torch.tensor([[1, 2], [3, 4]])

# 创建一个形状为(2,)的索引张量
index = torch.tensor([0, 2])

# 使用scatter()函数将数据分散到目标张量中,并进行累加
torch.scatter_add(target, 1, index.unsqueeze(1), source)

# 打印目标张量
print(target)

在上面的示例中,我们首先创建了一个形状为(2, 3)的目标张量target,一个形状为(2, 2)的源张量source,以及一个形状为(2,)的索引张量index。然后,我们使用scatter_add()函数将源张量source中的数据按照索引张量index分散到目标张量target中,并进行累加。最后,我们打印目标张量target,发现它的值为:

tensor([[1., 0., 2.],
        [0., 0., 4.]])

4. 总结

本文介绍了PyTorch中scatter()函数的用法,并提供了两个示例说明。scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中,非常实用。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch笔记之scatter()函数的使用 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • pytorch中Parameter函数用法示例

    PyTorch中Parameter函数用法示例 在PyTorch中,Parameter函数是一个特殊的张量,它被自动注册为模型的可训练参数。本文将介绍Parameter函数的用法,并演示两个示例。 示例一:使用Parameter函数定义可训练参数 import torch import torch.nn as nn class MyModel(nn.Modu…

    PyTorch 2023年5月15日
    00
  • 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图           在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率。sigmod函数曾经是比较流行的,它可以想象成一个神经元的放电率,在中间斜率比较大的地方是神经元的敏感区,在两边斜率很平缓的地方是神经元的抑制区。 当然,流行也是曾…

    2023年4月8日
    00
  • pytorch将部分参数进行加载

    参考:https://blog.csdn.net/LXX516/article/details/80124768 示例代码: 加载相同名称的模块 pretrained_dict=torch.load(model_weight) model_dict=myNet.state_dict() # 1. filter out unnecessary keys pre…

    PyTorch 2023年4月6日
    00
  • Pytorch 张量维度

      Tensor类的成员函数dim()可以返回张量的维度,shape属性与成员函数size()返回张量的具体维度分量,如下代码定义了一个两行三列的张量:   f = torch.randn(2, 3)   print(f.dim())   print(f.size())   print(f.shape)   输出结果:   2   torch.Size([2…

    PyTorch 2023年4月8日
    00
  • Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现。   1. 稀疏编码 首先介绍一下“稀疏编码”这一概念。        早期学者在黑白风景照片中可以提取到许多16*16像素的图像碎片。而这些图像碎片几乎都可由64种正交的边组合得到。而且组合出一张碎…

    2023年4月8日
    00
  • pytorch提取神经网络模型层结构和参数初始化

    torch.nn.Module()类有一些重要属性,我们可用其下面几个属性来实现对神经网络层结构的提取: torch.nn.Module.children() torch.nn.Module.modules() torch.nn.Module.named_children() torch.nn.Module.named_moduless() 为方面说明,我们…

    2023年4月8日
    00
  • pytorch函数之nn.Linear

    class torch.nn.Linear(in_features,out_features,bias = True )[来源] 对传入数据应用线性变换:y = A x+ b   参数: in_features – 每个输入样本的大小 out_features – 每个输出样本的大小 bias – 如果设置为False,则图层不会学习附加偏差。默认值:Tru…

    PyTorch 2023年4月7日
    00
  • pytorch torchversion自带的数据集

        from torchvision.datasets import MNIST # import torchvision # torchvision.datasets. #准备数据集 mnist = MNIST(root=”./mnist”,train=True,download=True) print(mnist) mnist[0][0].show(…

    2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部