PyTorch保存模型断点以及加载断点继续训练

 

 

 

在训练神经网络时,用到的数据量可能很大,训练周期较长,如果半途中断了训练,下次从头训练就会很费时间,这时我们就想断点续训。

一、神经网络模型的保存,基本两种方式:
1. 保存完整模型model, torch.save(model, save_path) 

2. 只保存模型的参数, torch.save(model.state_dict(), save_path) ,多卡训练的话,在保存参数时,使用 model.module.state_dict( ) 。

二、保存模型的断点checkpoint

断点dictionary中一般保存训练的网络的权重参数、优化器的状态、学习率 lr_scheduler 的状态以及epoch 。

checkpoint = {'parameter': model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scheduler': scheduler.state_dict(),
              'epoch': epoch}
 torch.save(checkpoint, './models/checkpoint/ckpt_{}.pth'.format(epoch+1))

三、加载断点继续训练

if resume: # True
load_ckpt = torch.load(ckpt_dir, map_location=device)                                 # 从断点路径加载断点,指定加载到CPU内存或GPU
load_weights_dict = {k: v for k, v in load_ckpt['parameter'].items()
                                      if model.state_dict()[k].numel() == v.numel()}  # 简单验证
model.load_state_dict(load_weights_dict, strict=False) 

# 如果是多卡训练,加载weights后要设置DDP模式,然后先定义一下optimizer和scheduler,之后再加载断点中保存的optimizer和scheduler以及设置epoch,
optimizer.load_state_dict(load_ckpt[
'optimizer']) # 加载优化器状态 scheduler.load_state_dict(load_ckpt['scheduler']) # 加载scheduler状态
start_epoch
= load_ckpt['epoch']+1 # 设定继续训练的epoch起点 iter_epochs = range(start_epoch, args.epochs) # arg.epochs指出训练的总epoch数,包括断点前的训练次数

 

 

 

 

 

Enjoy it!

原文链接:https://www.cnblogs.com/booturbo/p/17358917.html

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch保存模型断点以及加载断点继续训练 - Python技术站

(0)
上一篇 2023年4月27日
下一篇 2023年4月27日

相关文章

  • 目标检测算法-YOLO算法纵向对比理解

    目标检测算法-YOLO算法纵向对比理解 DeepLearning的目标检测任务主要有两大类:一段式,两段式 其中两段式主要包括RCNN、FastRCNN、FasterRCNN为代表, 一段式主要包括YOLO,SSD等算法 由于一段式直接在最后进行分类(判断所属类别)和回归(标记物体的位置框框),所以现在一段式大有发展。 YOLO v1 论文地址:You On…

    目标检测 2023年4月7日
    00
  • Tensorflow遇到的问题

    问题1、自定义loss function,y_true shape多一个维度 def nce_loss(y_true, y_pred): y_true = tf.reshape(y_true, [-1]) y_true = tf.linalg.diag(y_true) ret = tf.keras.metrics.categorical_crossentro…

    tensorflow 2023年4月8日
    00
  • AMD 处理器 Ubuntu 16.04 LTS 配置 opencv、caffe 小结

    上个随笔讲了在windows 上安装 caffe,并且 跑mnist 这个例程的过程,说真的,就像奶妈一样,每一步都得给奶才干活。最近配置了一台台式机,可以作为以后自己配置学习机的参考。 配置如下:补图。   电脑概览 电脑型号 兼容机操作系统 Ubuntu 16.04 LTSCPU AMD Ryzen 7 1700X Eight-Core Processo…

    Caffe 2023年4月5日
    00
  • [傅里叶变换及其应用学习笔记] 九. 继续卷积的讨论

    这份是本人的学习笔记,课程为网易公开课上的斯坦福大学公开课:傅里叶变换及其应用。   浑浊度(Turbidity)研究是关于测量水的清澈度的研究。大致方法是把光传感器放置到深水区域,然后测量光线的昏暗程度,测量出来的值将随时间变化。 (由于没有真实数据,下面用mathematica比较粗糙地模拟水域的浑浊度数据)         能看到信号主要集中在低频,我…

    2023年4月7日
    00
  • pytorch保存模型和导入模型以及预训练模型

    参考 model.state_dict()中保存了{参数名:参数值}的字典 import torchvision.models as models resnet34 = models.resnet34(pretrained=True) resnet34.state_dict().keys() for param in resnet34.parameters(…

    PyTorch 2023年4月8日
    00
  • python和tensorflow安装

    一、Python安装       python采用anaconda安装,简单方便,下载python3.6的anaconda  linux64的sh安装文件.       1、bash Anaconda-2.1.0-Linux-x86_64.sh       2、python,用于测试     二、Tensorflow安装   1、首先安装 pip (或 Py…

    tensorflow 2023年4月8日
    00
  • CNN之经典卷积网络框架原理

    一、GoogleNet 1、原理介绍        inception 结构   如下图所示,输入数据经过一分四,然后做一些大小不同的卷积,之后再堆叠feature map             inception结构可以理解为把一个输入数据先通过一个1*1的卷积核进行降维然后再通过四个卷积核(分别为1*1,3*3,5*5,maxpooling)进行升维运…

    2023年4月8日
    00
  • [caffe笔记]:杀死caffe多个进程中的某个(发生 leveldb lock 解决方法)

    1.leveldb lock 当运行caffe发生意外停止时,再重新运行训练会发生如下错误: Check failed: status.ok() Failed to open leveldb dish_train_leveldb IO error: lock dish_train_leveldb/LOCK: Resource temporarily unav…

    2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部