PyTorch模型的保存与加载方法实例

yizhihongxing

以下是PyTorch模型的保存与加载方法实例的详细攻略:

PyTorch提供了多种方法来保存和加载模型,包括使用pickle、torch.save和torch.load等方法。以下是使用torch.save和torch.load方法保存和加载模型的详细步骤:

  1. 定义模型并训练模型。

```python
import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class Net(nn.Module):
def init(self):
super(Net, self).init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

   def forward(self, x):
       x = self.pool(F.relu(self.conv1(x)))
       x = self.pool(F.relu(self.conv2(x)))
       x = x.view(-1, 16 * 5 * 5)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x

# 定义数据集和数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 定义模型、损失函数和优化器
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0

print('Finished Training')
```

  1. 保存模型。

python
# 保存模型
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

这个代码会将模型的参数保存到指定的文件中。

  1. 加载模型。

python
# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

这个代码会从指定的文件中加载模型的参数,并将其应用到新的模型中。

以下是两个示例说明:

示例1:使用保存的模型进行预测

以下是一个使用保存的模型进行预测的示例代码:

import torch
import torchvision.transforms as transforms
from PIL import Image

# 定义数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

# 加载图片并进行预测
image = Image.open('test.jpg')
image = transform(image)
image = image.unsqueeze(0)
output = net(image)
_, predicted = torch.max(output, 1)
print(predicted)

在这个示例中,我们首先定义了数据预处理方式,然后使用Net类加载模型,并使用load_state_dict方法从文件中加载模型的参数。接着,我们使用PIL库加载图片,并进行数据处理。最后,我们使用训练好的模型对图片进行预测,并输出预测结果。

示例2:使用保存的模型进行微调

以下是一个使用保存的模型进行微调的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 定义数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 定义数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

# 将模型的最后一层替换为新的全连接层
num_ftrs = net.fc3.in_features
net.fc3 = nn.Linear(num_ftrs, 2)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

在这个示例中,我们首先定义了数据预处理方式,然后使用CIFAR10类加载数据集。接着,我们使用Net类加载模型,并使用load_state_dict方法从文件中加载模型的参数。我们将模型的最后一层替换为新的全连接层,并使用交叉熵损失函数和随机梯度下降优化器来微调模型。最后,我们使用微调后的模型对测试集进行预测,并输出预测准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch模型的保存与加载方法实例 - Python技术站

(0)
上一篇 2023年5月16日
下一篇 2023年5月16日

相关文章

  • pytorch(二十一):交叉验证

    一、K折交叉验证 将训练集分成K份,一份做验证集,其他做测试集。这K份都有机会做验证集             二、代码 1 import torch 2 import torch.nn as nn 3 import torchvision 4 from torchvision import datasets,transforms 5 from torch.…

    PyTorch 2023年4月7日
    00
  • 贝叶斯个性化排序(BPR)pytorch实现

    一、BPR算法的原理: 1、贝叶斯个性化排序(BPR)算法小结https://www.cnblogs.com/pinard/p/9128682.html2、Bayesian Personalized Ranking 算法解析及Python实现https://www.cnblogs.com/wkang/p/10217172.html3、推荐系统中的排序学习ht…

    2023年4月8日
    00
  • pytorch网络参数初始化

    在定义网络时,pythorch会自己初始化参数,但也可以自己初始化,详见官方实现 for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode=’fan_out’, nonlinearity=’relu’) elif isinstanc…

    PyTorch 2023年4月8日
    00
  • 教你两步解决conda安装pytorch时下载速度慢or超时的问题

    当我们使用conda安装PyTorch时,有时会遇到下载速度慢或超时的问题。本文将介绍两个解决方案,帮助您快速解决这些问题。 解决方案一:更换清华源 清华源是国内比较稳定的镜像源之一,我们可以将conda的镜像源更换为清华源,以加速下载速度。具体步骤如下: 打开Anaconda Prompt或终端,输入以下命令: conda config –add cha…

    PyTorch 2023年5月15日
    00
  • PyTorch 之 强大的 hub 模块和搭建神经网络进行气温预测

    PyTorch之强大的hub模块和搭建神经网络进行气温预测 在PyTorch中,我们可以使用hub模块来加载预训练的模型,也可以使用它来分享和重用模型组件。在本文中,我们将介绍如何使用hub模块来加载预训练的模型,并使用它来搭建神经网络进行气温预测,并提供两个示例说明。 示例1:使用hub模块加载预训练的模型 以下是一个使用hub模块加载预训练的模型的示例代…

    PyTorch 2023年5月16日
    00
  • pytorch实现手动线性回归

    import torch import matplotlib.pyplot as plt learning_rate = 0.1 #准备数据 #y = 3x +0.8 x = torch.randn([500,1]) y_true = 3*x + 0.8 #计算预测值 w = torch.rand([],requires_grad=True) b = tor…

    2023年4月8日
    00
  • PyTorch的自适应池化Adaptive Pooling实例

    PyTorch的自适应池化Adaptive Pooling实例 在 PyTorch 中,自适应池化(Adaptive Pooling)是一种常见的池化操作,它可以根据输入的大小自动调整池化的大小。本文将详细讲解 PyTorch 中自适应池化的实现方法,并提供两个示例说明。 1. 二维自适应池化 在 PyTorch 中,我们可以使用 nn.AdaptiveAv…

    PyTorch 2023年5月16日
    00
  • pytorch程序异常后删除占用的显存操作

    在本攻略中,我们将介绍如何在PyTorch程序异常后删除占用的显存操作。我们将使用try-except语句和torch.cuda.empty_cache()函数来实现这个功能。 删除占用的显存操作 在PyTorch程序中,如果出现异常,可能会导致一些变量或模型占用显存。如果不及时清理这些占用的显存,可能会导致显存不足,从而导致程序崩溃。为了避免这种情况,我们…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部