PyTorch保存模型断点以及加载断点继续训练

 

 

 

在训练神经网络时,用到的数据量可能很大,训练周期较长,如果半途中断了训练,下次从头训练就会很费时间,这时我们就想断点续训。

一、神经网络模型的保存,基本两种方式:
1. 保存完整模型model, torch.save(model, save_path) 

2. 只保存模型的参数, torch.save(model.state_dict(), save_path) ,多卡训练的话,在保存参数时,使用 model.module.state_dict( ) 。

二、保存模型的断点checkpoint

断点dictionary中一般保存训练的网络的权重参数、优化器的状态、学习率 lr_scheduler 的状态以及epoch 。

checkpoint = {'parameter': model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scheduler': scheduler.state_dict(),
              'epoch': epoch}
 torch.save(checkpoint, './models/checkpoint/ckpt_{}.pth'.format(epoch+1))

三、加载断点继续训练

if resume: # True
load_ckpt = torch.load(ckpt_dir, map_location=device)                                 # 从断点路径加载断点,指定加载到CPU内存或GPU
load_weights_dict = {k: v for k, v in load_ckpt['parameter'].items()
                                      if model.state_dict()[k].numel() == v.numel()}  # 简单验证
model.load_state_dict(load_weights_dict, strict=False) 

# 如果是多卡训练,加载weights后要设置DDP模式,然后先定义一下optimizer和scheduler,之后再加载断点中保存的optimizer和scheduler以及设置epoch,
optimizer.load_state_dict(load_ckpt[
'optimizer']) # 加载优化器状态 scheduler.load_state_dict(load_ckpt['scheduler']) # 加载scheduler状态
start_epoch
= load_ckpt['epoch']+1 # 设定继续训练的epoch起点 iter_epochs = range(start_epoch, args.epochs) # arg.epochs指出训练的总epoch数,包括断点前的训练次数

 

 

 

 

 

Enjoy it!

原文链接:https://www.cnblogs.com/booturbo/p/17358917.html

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch保存模型断点以及加载断点继续训练 - Python技术站

(0)
上一篇 2023年4月27日
下一篇 2023年4月27日

相关文章

  • Tensorflow安装以及RuntimeError: The Session graph is empty. Add operations to the graph before calling run().解决方法

    之前装过pytorch,但是很多老的机器学习代码都是tensorflow,所以没办法,还要装个tensorflow。 在安装之前还要安装nvidia驱动还有cudn之类的,这些我已经在之前的篇章介绍过,就不在这细说了,可以直接传送过去看。那么前面这些搞完,直接运行下面的命令:pip install –upgrade tensorflow-gpu 上面这行命…

    tensorflow 2023年4月8日
    00
  • 深度学习之卷积和池化

    转载:http://www.cnblogs.com/zf-blog/p/6075286.html 卷积神经网络(CNN)由输入层、卷积层、激活函数、池化层、全连接层组成,即INPUT-CONV-RELU-POOL-FC (1)卷积层:用它来进行特征提取,如下: 输入图像是32*32*3,3是它的深度(即R、G、B),卷积层是一个5*5*3的filter(感受…

    2023年4月8日
    00
  • 机器学习-数据可视化神器matplotlib学习之路(一)

    直接上代码吧,说明写在备注就好了,这次主要学习一下基本的画图方法和常用的图例图标等 from matplotlib import pyplot as plt import numpy as np #这里是最最基本的代码了 #x轴-2到2均分50个点 x = np.linspace(-2, 2, 50) y = x**2 plt.plot(x, y) plt.…

    机器学习 2023年4月13日
    00
  • keras中的Flatten和Reshape

    最近在看SSD源码的时候,就一直不理解,在模型构建的时候如果使用Flatten或者是Merge层,那么整个数据的shape就发生了变化,那么还可以对应起来么(可能你不知道我在说什么)?后来不知怎么的,就想明白了,只要先前按照同样的方式进行操作,那么就可以对应起来。同样的,只要按照之前操作的逆操作,就可以将数据的shape进行还原。 最后在说一句,在追看Ten…

    Keras 2023年4月6日
    00
  • 循环神经网络中Dropout的应用(转)

    https://blog.csdn.net/wangli0519/article/details/75208155 循环神经网络(RNNs)是基于序列的模型,对自然语言理解、语言生成、视频处理和其他许多任务至关重要。模型的输入是一个符号序列,在每个时间点一个简单的神经网络(RNN单元)应用于一个符号,以及此前时间点的网络输出。RNNs是强大的模型,在许多任务…

    2023年4月8日
    00
  • Keras猫狗大战十:输出Resnet50分类热力图

    图像分类识别中,可以根据热力图来观察模型根据图片的哪部分决定图片属于一个分类。 以前面的Resnet50模型为例:https://www.cnblogs.com/zhengbiqing/p/11964301.html 输出模型结构为: model.summary() ______________________________________________…

    Keras 2023年4月7日
    00
  • Caffe Python MemoryDataLayer Segmentation Fault

    http://home.cnblogs.com/louyihang-loves-baiyan/ 因为利用Pyhon来做数据的预处理比较方便,因此在data_layer选择上,采用了MemoryDataLayer,可以比较方便的直接用Python 根据set_input_array进行feed数据,然后再调用solver进行step就可以了。说一下我碰到的问题…

    Caffe 2023年4月8日
    00
  • 使用Python+TensorFlow2构建基于卷积神经网络(CNN)的ECG心电信号识别分类(一)

    本篇博客以及之后的一个系列,我将记录下我是如何从一个没学过信号处理,不懂什么是深度学习,没接触过心电信号的小白,一步步做出基于CNN的心电信号识别分类的过程。网络上关于ECG方面的相关博客内容不多,可以直接运行的相关代码也寥寥无几,这给初学者造成了很大的困难。希望通过自己的总结和整理能够帮助自己更好的理解这些知识和技术,也能够为同为新接触这方面研究的小伙伴们…

    2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部