PyTorch模型的保存与加载方法实例

以下是PyTorch模型的保存与加载方法实例的详细攻略:

PyTorch提供了多种方法来保存和加载模型,包括使用pickle、torch.save和torch.load等方法。以下是使用torch.save和torch.load方法保存和加载模型的详细步骤:

  1. 定义模型并训练模型。

```python
import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class Net(nn.Module):
def init(self):
super(Net, self).init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

   def forward(self, x):
       x = self.pool(F.relu(self.conv1(x)))
       x = self.pool(F.relu(self.conv2(x)))
       x = x.view(-1, 16 * 5 * 5)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x

# 定义数据集和数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 定义模型、损失函数和优化器
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0

print('Finished Training')
```

  1. 保存模型。

python
# 保存模型
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

这个代码会将模型的参数保存到指定的文件中。

  1. 加载模型。

python
# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

这个代码会从指定的文件中加载模型的参数,并将其应用到新的模型中。

以下是两个示例说明:

示例1:使用保存的模型进行预测

以下是一个使用保存的模型进行预测的示例代码:

import torch
import torchvision.transforms as transforms
from PIL import Image

# 定义数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

# 加载图片并进行预测
image = Image.open('test.jpg')
image = transform(image)
image = image.unsqueeze(0)
output = net(image)
_, predicted = torch.max(output, 1)
print(predicted)

在这个示例中,我们首先定义了数据预处理方式,然后使用Net类加载模型,并使用load_state_dict方法从文件中加载模型的参数。接着,我们使用PIL库加载图片,并进行数据处理。最后,我们使用训练好的模型对图片进行预测,并输出预测结果。

示例2:使用保存的模型进行微调

以下是一个使用保存的模型进行微调的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 定义数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 定义数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 加载模型
net = Net()
net.load_state_dict(torch.load(PATH))

# 将模型的最后一层替换为新的全连接层
num_ftrs = net.fc3.in_features
net.fc3 = nn.Linear(num_ftrs, 2)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

在这个示例中,我们首先定义了数据预处理方式,然后使用CIFAR10类加载数据集。接着,我们使用Net类加载模型,并使用load_state_dict方法从文件中加载模型的参数。我们将模型的最后一层替换为新的全连接层,并使用交叉熵损失函数和随机梯度下降优化器来微调模型。最后,我们使用微调后的模型对测试集进行预测,并输出预测准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch模型的保存与加载方法实例 - Python技术站

(0)
上一篇 2023年5月16日
下一篇 2023年5月16日

相关文章

  • linux中anaconda环境下pytorch的安装(conda安装本地包)

    跑代码的时候遇到和这位博主几乎一模一样的问题,安装的也是同一版本。目前清华源已经停止服务,如果要自己下载pytorch包的话估计只能在官网下载了。 原文:https://blog.csdn.net/summer2day/article/details/88652934 pytorch的安装(1)版本查看查看cuda版本cat /usr/local/cuda/…

    PyTorch 2023年4月8日
    00
  • 运行pytorch代码遇到的error解决办法

    1.no CUDA-capable device is detected 首先考虑的是cuda的驱动问题,查看gpu显示是否正常,然后更新最新的cuda驱动; 第二个考虑的是cuda设备的默认参数是否修改,平常一块显卡的设置是0,多块可能会修改此参数: CUDA_VISIBLE_DEVICES=”3″  ,把它修改为0即可解决。 2.out of gpu m…

    PyTorch 2023年4月7日
    00
  • 深度学习环境搭建常用网址、conda/pip命令行整理(pytorch、paddlepaddle等环境搭建)

    前言:最近研究深度学习,安装了好多环境,记录一下,方便后续查阅。 1. Anaconda软件安装 1.1 Anaconda Anaconda是一个用于科学计算的Python发行版,支持Linux、Mac、Windows,包含了众多流行的科学计算、数据分析的Python包。请自行到官网下载安装,下载速度太慢的话可移步清华源。 官网:https://repo.a…

    2023年4月8日
    00
  • [pytorch]动态调整学习率

    问题描述 在深度学习的过程中,会需要有调节学习率的需求,一种方式是直接通过手动的方式进行调节,即每次都保存一个checkpoint,但这种方式的缺点是需要盯着训练过程,会很浪费时间。因此需要设定自动更新学习率的方法,让模型自适应地调整学习率。 解决思路 通过epoch来动态调整,比如每10次学习率为原来的0.1 实现示例: def adjust_learni…

    PyTorch 2023年4月8日
    00
  • 教你一分钟在win10终端成功安装Pytorch的方法步骤

    PyTorch安装教程 PyTorch是一个基于Python的科学计算库,它支持GPU加速,提供了丰富的神经网络模块,可以用于自然语言处理、计算机视觉、强化学习等领域。本文将提供详细的PyTorch安装教程,以帮助您在Windows 10上成功安装PyTorch。 步骤一:安装Anaconda 在开始安装PyTorch之前,您需要先安装Anaconda。An…

    PyTorch 2023年5月16日
    00
  • pytorch tensor计算三通道均值方式

    以下是PyTorch计算三通道均值的两个示例说明。 示例1:计算图像三通道均值 在这个示例中,我们将使用PyTorch计算图像三通道均值。 首先,我们需要准备数据。我们将使用torchvision库来加载图像数据集。您可以使用以下代码来加载数据集: import torchvision.datasets as datasets import torchvis…

    PyTorch 2023年5月15日
    00
  • PyTorch环境配置及安装过程

    以下是PyTorch环境配置及安装过程的完整攻略,包括Windows、macOS和Linux三个平台的安装步骤。同时,还提供了两个示例说明。 Windows平台 1. 安装Anaconda 在Windows平台上,我们可以使用Anaconda来安装PyTorch。首先,我们需要下载并安装Anaconda。可以在官网上下载对应的安装包,然后按照提示进行安装。 …

    PyTorch 2023年5月16日
    00
  • Pytorch 加载保存模型,进行模型推断【直播】2019 年县域农业大脑AI挑战赛—(三)保存结果

    在模型训练结束,结束后,通常是一个分割模型,输入 1024×1024 输出 4x1024x1024。 一种方法就是将整个图切块,然后每张预测,但是有个不好处就是可能在边界处断续。   由于这种切块再预测很ugly,所以直接遍历整个图预测(这就是相当于卷积啊),防止边界断续,还有一个问题就是防止图过大不能超过20M。 很有意思解决上边的问题。话也不多说了。直接…

    2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部