PyTorch是一个非常流行的深度学习框架,可以用于训练和部署神经网络模型。在训练好一个模型后,我们需要将其保存下来以便后续使用。PyTorch提供了.pth格式来保存模型的参数,本文将详细讲解如何加载.pth格式的模型实例。
- 加载.pth格式的模型实例
在PyTorch中,可以使用torch.load函数来加载.pth格式的模型实例。以下是加载.pth格式的模型实例的示例代码:
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
x = nn.functional.relu(nn.functional.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 320)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return nn.functional.log_softmax(x, dim=1)
# 加载模型
model = Net()
model.load_state_dict(torch.load('model.pth'))
model.eval()
这个示例中,我们定义了一个名为Net的神经网络模型,并使用torch.load函数来加载.pth格式的模型实例。我们使用torch.nn.Module类来定义模型,并使用torch.nn.functional类来定义模型的前向传播函数。在加载模型后,我们使用model.eval()函数来将模型设置为评估模式。
- 加载.pth格式的模型实例并进行推理
在加载.pth格式的模型实例后,我们可以使用它来进行推理。以下是加载.pth格式的模型实例并进行推理的示例代码:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
x = nn.functional.relu(nn.functional.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 320)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return nn.functional.log_softmax(x, dim=1)
# 加载模型
model = Net()
model.load_state_dict(torch.load('model.pth'))
model.eval()
# 加载数据
test_dataset = MNIST(root='./data', train=False, download=True, transform=ToTensor())
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
# 进行推理
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
这个示例中,我们加载了一个名为model.pth的.pth格式的模型实例,并使用它来进行推理。我们使用torchvision.datasets.MNIST类来加载MNIST数据集,并使用torchvision.transforms.ToTensor类来将数据转换为张量。在进行推理时,我们使用torch.no_grad()上下文管理器来禁用梯度计算,并使用torch.max函数来获取输出张量中每行的最大值及其索引。我们还使用predicted == target来计算正确分类的数量,并使用total来计算总的样本数量。最后,我们计算模型的准确率并输出结果。
总结
本文详细讲解了如何加载.pth格式的模型实例,并提供了两个示例代码来说明如何加载.pth格式的模型实例并进行推理。在实际应用中,我们可以根据自己的需求来修改示例代码,并将其应用到自己的项目中。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 加载(.pth)格式的模型实例 - Python技术站