下面是pytorch中with torch.no_grad()的用法实例的攻略:
1. 什么是torch.no_grad()
在深度学习模型训练过程中,模型的前向传播和反向传播计算中都需要计算梯度,以便于更新参数。但在模型预测时,我们并不需要计算梯度,因此使用torch.no_grad()可以临时关闭该计算图的梯度计算操作。这可以减小模型权重对显存的占用,同时也加快了计算速度。
2. 示例说明
下面我们通过两个示例来说明怎样使用torch.no_grad()。
示例1:运行一个训练好的模型,生成预测结果
我们先构建一个简单的线性模型,在MNIST数据集上进行训练。当模型训练好之后,我们也许会想利用该模型在测试集上生成预测值。
import torch
import torch.nn as nn
# 构建线性模型
class LinearModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(784, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.linear(x)
return x
model = LinearModel()
# 加载训练好的模型参数
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
# 加载测试集数据
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('data/', train=False, download=True),
batch_size=128, shuffle=False)
# 生成预测值
model.eval() # 将模型切换到评估模式,关闭Dropout和BN的计算
predictions = []
with torch.no_grad(): # 关闭梯度计算
for x, y in test_loader:
x = x.cuda()
y_hat = model(x)
predictions.append(y_hat.argmax(dim=1).cpu())
predictions = torch.cat(predictions)
上面这个例子中,我们首先定义了 LinearModel ,并加载了model.pth
中训练好的模型参数。然后,我们将模型切换到评估模式(即关闭了Dropout和BN的计算),并使用 with torch.no_grad()
进行包裹,来关闭自动求导功能。在这个模式下,代码所做的一切操作,都不会影响模型的权重和偏移的更新。最后,我们遍历了测试集,并生成了预测值。
示例2:计算模型的评估指标
我们来看一个实际的计算模型评估指标的例子,比如准确率。
def evaluate(model, data_loader):
correct, total = 0, 0
model.eval()
with torch.no_grad():
for x, y in data_loader:
x, y = x.cuda(), y.cuda()
y_hat = model(x)
label = y_hat.argmax(dim=1)
correct += (label == y).sum().item()
total += y.size(0)
acc = correct / total
return acc
# 计算模型在验证集上的准确率
val_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('data/', train=False, download=True),
batch_size=128, shuffle=True)
val_acc = evaluate(model, val_loader)
print('Model accuracy on validation set: {:.2f}%'.format(val_acc*100))
在这个例子中,我们定义了一个用于计算准确率的函数,函数的输入是模型和数据集的DataLoader。在函数执行中,我们遍历了data_loader中的数据,计算出正确预测的样本数和总测试样本数,然后计算准确率。由于我们仍然处于评估状态,所以我们再次使用了with torch.no_grad()
。
这两个示例说明了在不需要进行梯度计算或更新模型参数的情况下,使用 torch.no_grad()
可以加快模型运行速度,同时也可以释放GPU显存。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch中with torch.no_grad():的用法实例 - Python技术站