PyTorch笔记之scatter()函数的使用

yizhihongxing

PyTorch笔记之scatter()函数的使用

在PyTorch中,scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中。本文将介绍scatter()函数的用法,并提供两个示例说明。

1. scatter()函数的用法

scatter()函数的语法如下:

torch.scatter(input, dim, index, src)

其中,input表示目标张量,dim表示要分散的维度,index表示要分散的索引,src表示要分散的源张量。

2. 示例1:使用scatter()函数将数据分散到目标张量中

以下是一个示例,展示如何使用scatter()函数将数据分散到目标张量中。

import torch

# 创建一个形状为(2, 3)的目标张量
target = torch.zeros(2, 3)

# 创建一个形状为(2, 2)的源张量
source = torch.tensor([[1, 2], [3, 4]])

# 创建一个形状为(2,)的索引张量
index = torch.tensor([0, 2])

# 使用scatter()函数将数据分散到目标张量中
torch.scatter(target, 1, index.unsqueeze(1), source)

# 打印目标张量
print(target)

在上面的示例中,我们首先创建了一个形状为(2, 3)的目标张量target,一个形状为(2, 2)的源张量source,以及一个形状为(2,)的索引张量index。然后,我们使用scatter()函数将源张量source中的数据按照索引张量index分散到目标张量target中。最后,我们打印目标张量target,发现它的值为:

tensor([[1., 0., 2.],
        [0., 0., 4.]])

3. 示例2:使用scatter()函数将数据分散到目标张量中,并进行累加

以下是一个示例,展示如何使用scatter()函数将数据分散到目标张量中,并进行累加。

import torch

# 创建一个形状为(2, 3)的目标张量
target = torch.zeros(2, 3)

# 创建一个形状为(2, 2)的源张量
source = torch.tensor([[1, 2], [3, 4]])

# 创建一个形状为(2,)的索引张量
index = torch.tensor([0, 2])

# 使用scatter()函数将数据分散到目标张量中,并进行累加
torch.scatter_add(target, 1, index.unsqueeze(1), source)

# 打印目标张量
print(target)

在上面的示例中,我们首先创建了一个形状为(2, 3)的目标张量target,一个形状为(2, 2)的源张量source,以及一个形状为(2,)的索引张量index。然后,我们使用scatter_add()函数将源张量source中的数据按照索引张量index分散到目标张量target中,并进行累加。最后,我们打印目标张量target,发现它的值为:

tensor([[1., 0., 2.],
        [0., 0., 4.]])

4. 总结

本文介绍了PyTorch中scatter()函数的用法,并提供了两个示例说明。scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中,非常实用。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch笔记之scatter()函数的使用 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Pytorch模型的保存/复用/迁移实现代码

    PyTorch是一个流行的深度学习框架,它提供了许多内置的模型保存、复用和迁移方法。在本攻略中,我们将介绍如何使用PyTorch实现模型的保存、复用和迁移。 模型的保存 在PyTorch中,我们可以使用torch.save()函数将模型保存到磁盘上。以下是一个示例代码,演示了如何保存模型: import torch import torch.nn as nn…

    PyTorch 2023年5月15日
    00
  • pytorch转onnx常见问题

    一、Type Error: Type ‘tensor(bool)’ of input parameter (121) of operator (ScatterND) in node (ScatterND_128) is invalid 问题模型转出成功后,用onnxruntime加载,出现不支持参数问题, 这里出现tensor(bool)是因为代码中使用了b…

    2023年4月8日
    00
  • Pytorch 之激活函数

    1. Sigmod 函数    Sigmoid 函数是应用最广泛的非线性激活函数之一,它可以将值转换为 $0$ 和 $1$ 之间,如果原来的输出具有这样的特点:值越大,归为某类的可能性越大,    那么经过 Sigmod 函数处理的输出就可以代表属于某一类别的概率。其数学表达式为: $$y = frac{1}{1 + e^{-x}} = frac{e^{x}…

    2023年4月6日
    00
  • pytorch 中HWC转CHW

    import torch import numpy as np from torchvision.transforms import ToTensor t = torch.tensor(np.arange(24).reshape(2,4,3)) print(t) #HWC 转CHW print(t.transpose(0,2).transpose(1,2))…

    PyTorch 2023年4月8日
    00
  • pytorch __init__、forward与__call__的用法小结

    在PyTorch中,我们通常使用nn.Module类来定义神经网络模型。在定义模型时,我们需要实现__init__()、forward()和__call__()方法。这些方法分别用于初始化模型参数、定义前向传播过程和调用模型。 init()方法 init()方法用于初始化模型参数。在该方法中,我们通常定义模型的各个层,并初始化它们的参数。以下是一个示例代码,…

    PyTorch 2023年5月15日
    00
  • 使用pytorch测试单张图片(test single image with pytorch)

    以下代码实现使用pytorch测试一张图片 引用文章: https://www.learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/ from __future__ import print_function, division from PI…

    PyTorch 2023年4月7日
    00
  • pytorch简单框架

    网络搭建: mynn.py: import torchfrom torch import nnclass mynn(nn.Module): def __init__(self): super(mynn, self).__init__() self.layer1 = nn.Sequential( nn.Linear(3520, 4096), nn.BatchN…

    PyTorch 2023年4月8日
    00
  • Pytorch模型量化

    在深度学习中,量化指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算。这么做的好处主要有如下几点: 更少的模型体积,接近4倍的减少; 可以更快的计算,由于更少的内存访问和更快的int8计算,可以快2~4倍。 一个量化后的模型,其部分或者全部的tensor操作会使用int类型来计算,而不是使用量化之前的…

    2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部