pytorch人工智能之torch.gather算子用法示例

PyTorch人工智能之torch.gather算子用法示例

torch.gather是PyTorch中的一个重要算子,用于在指定维度上收集输入张量中指定索引处的值。在本文中,我们将介绍torch.gather的用法,并提供两个示例说明。

torch.gather的用法

torch.gather的语法如下:

torch.gather(input, dim, index, out=None)

其中,参数含义如下:

  • input:输入张量。
  • dim:指定维度。
  • index:索引张量。
  • out:输出张量。

torch.gather的作用是在input张量的dim维度上,根据index张量中的索引,收集input张量中对应位置的值,并将结果存储在输出张量out中。

示例1:使用torch.gather实现分类任务

下面是一个示例,演示了如何使用torch.gather实现分类任务:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

# 加载数据集,并使用DataLoader创建数据加载器
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # 打印训练日志
        print('Epoch: %d, Batch: %d, Loss: %.4f' % (epoch+1, i+1, loss.item()))

    # 在测试集上测试模型
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # 打印测试日志
    print('Epoch: %d, Test Accuracy: %.2f%%' % (epoch+1, 100 * correct / total))

在这个示例中,我们首先定义了一个包含卷积层、池化层、全连接层等的卷积神经网络。然后,我们加载了CIFAR10数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个交叉熵损失函数和一个SGD优化器。最后,我们进行了模型训练,并在测试集上测试了模型的泛化能力。

示例2:使用torch.gather实现双线性插值

下面是一个示例,演示了如何使用torch.gather实现双线性插值:

import torch

def bilinear_interpolate(im, x, y):
    # 获取图像大小
    _, _, h, w = im.size()

    # 计算四个相邻像素的坐标
    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # 确保坐标不超出图像范围
    x0 = torch.clamp(x0, 0, w-1)
    x1 = torch.clamp(x1, 0, w-1)
    y0 = torch.clamp(y0, 0, h-1)
    y1 = torch.clamp(y1, 0, h-1)

    # 计算四个相邻像素的权重
    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    # 收集四个相邻像素的值,并进行双线性插值
    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]
    out = wa.unsqueeze(1) * Ia + wb.unsqueeze(1) * Ib + wc.unsqueeze(1) * Ic + wd.unsqueeze(1) * Id

    return out

# 创建输入张量
im = torch.randn(1, 3, 4, 4)

# 创建坐标张量
x = torch.tensor([0.5, 1.5, 2.5, 3.5])
y = torch.tensor([0.5, 1.5, 2.5, 3.5])

# 进行双线性插值
out = bilinear_interpolate(im, x, y)

# 打印结果
print(out)

在这个示例中,我们首先定义了一个bilinear_interpolate函数,用于实现双线性插值。然后,我们创建了一个输入张量im和一个坐标张量xy。最后,我们使用torch.gather实现了双线性插值,并打印了结果。

总结

本文介绍了torch.gather算子的用法,并提供了两个示例说明。在实现过程中,我们使用torch.gather实现了分类任务和双线性插值,展示了torch.gather的强大功能。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch人工智能之torch.gather算子用法示例 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • pytorch seq2seq闲聊机器人beam search返回结果

    decoder.py “”” 实现解码器 “”” import heapq import torch.nn as nn import config import torch import torch.nn.functional as F import numpy as np import random from chatbot.attention impor…

    PyTorch 2023年4月8日
    00
  • pytorch提取神经网络模型层结构和参数初始化

    torch.nn.Module()类有一些重要属性,我们可用其下面几个属性来实现对神经网络层结构的提取: torch.nn.Module.children() torch.nn.Module.modules() torch.nn.Module.named_children() torch.nn.Module.named_moduless() 为方面说明,我们…

    2023年4月8日
    00
  • pytorch和tensorflow的爱恨情仇之张量

    pytorch和tensorflow的爱恨情仇之基本数据类型:https://www.cnblogs.com/xiximayou/p/13759451.html pytorch版本:1.6.0 tensorflow版本:1.15.0 基本概念:标量、一维向量、二维矩阵、多维张量。 1、pytorch中的张量 (1)通过torch.Tensor()来建立常量 …

    2023年4月8日
    00
  • pytorch中的卷积和池化计算方式详解

    PyTorch中的卷积和池化计算方式 在PyTorch中,卷积和池化是深度学习中非常重要的一部分。在本文中,我们将详细介绍PyTorch中的卷积和池化计算方式,并提供两个示例。 示例1:使用PyTorch中的卷积计算方式 以下是一个使用PyTorch中的卷积计算方式的示例代码: import torch import torch.nn as nn # Def…

    PyTorch 2023年5月16日
    00
  • pytorch中nn.RNN()总结

    nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False) 参数说明 input_size输入特征的维度, 一般rnn中输入的是词向量,那么 input_size 就…

    PyTorch 2023年4月6日
    00
  • pytorch中.pth文件转成.bin的二进制文件

    model_dict = torch.load(save_path) fp = open(‘model_parameter.bin’, ‘wb’) weight_count = 0 num=1 for k, v in model_dict.items(): print(k,num) num=num+1 if ‘num_batches_tracked’ in …

    PyTorch 2023年4月7日
    00
  • pytorch中如何使用DataLoader对数据集进行批处理的方法

    PyTorch中使用DataLoader对数据集进行批处理的方法 在PyTorch中,DataLoader是一个非常有用的工具,它可以用来对数据集进行批处理。本文将详细介绍如何使用DataLoader对数据集进行批处理,并提供两个示例来说明其用法。 1. 创建数据集 在使用DataLoader对数据集进行批处理之前,我们需要先创建一个数据集。以下是一个示例,…

    PyTorch 2023年5月15日
    00
  • Lubuntu安装Pytorch

    PyTorch官方对于PyTorch的定位为: 一个使用GPU加速的numpy替换库 一个深度学习研究平台,提高最大灵活度和速度 具体点来讲, PyTorch是一个Python包,是Torch在Python上的衍生,原先的Torch是用Lua语言写的,虽然效率高,但是普及度不够,社区不够大,改成Python后,受众范围广泛了许多。并且有FaceBook这样的…

    2023年4月7日
    00
合作推广
合作推广
分享本页
返回顶部