PyTorch人工智能之torch.gather算子用法示例
torch.gather
是PyTorch中的一个重要算子,用于在指定维度上收集输入张量中指定索引处的值。在本文中,我们将介绍torch.gather
的用法,并提供两个示例说明。
torch.gather
的用法
torch.gather
的语法如下:
torch.gather(input, dim, index, out=None)
其中,参数含义如下:
input
:输入张量。dim
:指定维度。index
:索引张量。out
:输出张量。
torch.gather
的作用是在input
张量的dim
维度上,根据index
张量中的索引,收集input
张量中对应位置的值,并将结果存储在输出张量out
中。
示例1:使用torch.gather
实现分类任务
下面是一个示例,演示了如何使用torch.gather
实现分类任务:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.conv1(x)
x = torch.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = torch.relu(x)
x = self.pool(x)
x = x.view(-1, 16 * 5 * 5)
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
return x
# 加载数据集,并使用DataLoader创建数据加载器
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 进行模型训练
for epoch in range(10):
for i, (inputs, targets) in enumerate(train_loader):
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# 打印训练日志
print('Epoch: %d, Batch: %d, Loss: %.4f' % (epoch+1, i+1, loss.item()))
# 在测试集上测试模型
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += (predicted == targets).sum().item()
# 打印测试日志
print('Epoch: %d, Test Accuracy: %.2f%%' % (epoch+1, 100 * correct / total))
在这个示例中,我们首先定义了一个包含卷积层、池化层、全连接层等的卷积神经网络。然后,我们加载了CIFAR10数据集,并使用DataLoader
创建了数据加载器。然后,我们定义了一个交叉熵损失函数和一个SGD优化器。最后,我们进行了模型训练,并在测试集上测试了模型的泛化能力。
示例2:使用torch.gather
实现双线性插值
下面是一个示例,演示了如何使用torch.gather
实现双线性插值:
import torch
def bilinear_interpolate(im, x, y):
# 获取图像大小
_, _, h, w = im.size()
# 计算四个相邻像素的坐标
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
# 确保坐标不超出图像范围
x0 = torch.clamp(x0, 0, w-1)
x1 = torch.clamp(x1, 0, w-1)
y0 = torch.clamp(y0, 0, h-1)
y1 = torch.clamp(y1, 0, h-1)
# 计算四个相邻像素的权重
wa = (x1.float() - x) * (y1.float() - y)
wb = (x1.float() - x) * (y - y0.float())
wc = (x - x0.float()) * (y1.float() - y)
wd = (x - x0.float()) * (y - y0.float())
# 收集四个相邻像素的值,并进行双线性插值
Ia = im[:, :, y0, x0]
Ib = im[:, :, y1, x0]
Ic = im[:, :, y0, x1]
Id = im[:, :, y1, x1]
out = wa.unsqueeze(1) * Ia + wb.unsqueeze(1) * Ib + wc.unsqueeze(1) * Ic + wd.unsqueeze(1) * Id
return out
# 创建输入张量
im = torch.randn(1, 3, 4, 4)
# 创建坐标张量
x = torch.tensor([0.5, 1.5, 2.5, 3.5])
y = torch.tensor([0.5, 1.5, 2.5, 3.5])
# 进行双线性插值
out = bilinear_interpolate(im, x, y)
# 打印结果
print(out)
在这个示例中,我们首先定义了一个bilinear_interpolate
函数,用于实现双线性插值。然后,我们创建了一个输入张量im
和一个坐标张量x
和y
。最后,我们使用torch.gather
实现了双线性插值,并打印了结果。
总结
本文介绍了torch.gather
算子的用法,并提供了两个示例说明。在实现过程中,我们使用torch.gather
实现了分类任务和双线性插值,展示了torch.gather
的强大功能。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch人工智能之torch.gather算子用法示例 - Python技术站