PyTorch保存模型断点以及加载断点继续训练

yizhihongxing

 

 

 

在训练神经网络时,用到的数据量可能很大,训练周期较长,如果半途中断了训练,下次从头训练就会很费时间,这时我们就想断点续训。

一、神经网络模型的保存,基本两种方式:
1. 保存完整模型model, torch.save(model, save_path) 

2. 只保存模型的参数, torch.save(model.state_dict(), save_path) ,多卡训练的话,在保存参数时,使用 model.module.state_dict( ) 。

二、保存模型的断点checkpoint

断点dictionary中一般保存训练的网络的权重参数、优化器的状态、学习率 lr_scheduler 的状态以及epoch 。

checkpoint = {'parameter': model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scheduler': scheduler.state_dict(),
              'epoch': epoch}
 torch.save(checkpoint, './models/checkpoint/ckpt_{}.pth'.format(epoch+1))

三、加载断点继续训练

if resume: # True
load_ckpt = torch.load(ckpt_dir, map_location=device)                                 # 从断点路径加载断点,指定加载到CPU内存或GPU
load_weights_dict = {k: v for k, v in load_ckpt['parameter'].items()
                                      if model.state_dict()[k].numel() == v.numel()}  # 简单验证
model.load_state_dict(load_weights_dict, strict=False) 

# 如果是多卡训练,加载weights后要设置DDP模式,然后先定义一下optimizer和scheduler,之后再加载断点中保存的optimizer和scheduler以及设置epoch,
optimizer.load_state_dict(load_ckpt[
'optimizer']) # 加载优化器状态 scheduler.load_state_dict(load_ckpt['scheduler']) # 加载scheduler状态
start_epoch
= load_ckpt['epoch']+1 # 设定继续训练的epoch起点 iter_epochs = range(start_epoch, args.epochs) # arg.epochs指出训练的总epoch数,包括断点前的训练次数

 

 

 

 

 

Enjoy it!

原文链接:https://www.cnblogs.com/booturbo/p/17358917.html

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch保存模型断点以及加载断点继续训练 - Python技术站

(0)
上一篇 2023年4月27日
下一篇 2023年4月27日

相关文章

  • 使用 Estimator 构建卷积神经网络

    1,tf.layers基础函数 conv2d(). Constructs a two-dimensional convolutional layer. Takes number of filters, filter kernel size, padding, and activation function as arguments. max_pooling2…

    卷积神经网络 2023年4月6日
    00
  • tensor搭建–windows 10 64bit下安装Tensorflow+Keras+VS2015+CUDA8.0 GPU加速

    原文见于:http://www.jianshu.com/p/c245d46d43f0   作者 xushiluo 关注 2016.12.21 20:32* 字数 3096 阅读 12108评论 18喜欢 19 写在前面的话 2016年11月29日,Google Brain 工程师团队宣布在 TensorFlow 0.12 中加入初步的 Windows 支持。…

    2023年4月8日
    00
  • TensorFlow入门——bazel编译(带GPU)

    这一系列基本上是属于我自己进行到了那个步骤就做到那个步骤的 由于新装了GPU (GTX750ti)和CUDA9.0、CUDNN7.1版本的软件,所以希望TensorFlow能在GPU上运行,也算上补上之前的承诺 说了下初衷,由于现在新的CUDA版本对TensorFlow的支持不好,只能采取编译源码的方式进行 所以大概分为以下几个步骤 1.安装依赖库(这部分我…

    tensorflow 2023年4月8日
    00
  • 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在《手写数字识别——手动搭建全连接层》一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配、梯度计算、准确度的统计等问题,但是这样的实践对机器学习的理解是大有裨益的。在大多数情况下,我们还是希望能多简单就多简单地去搭建网络模型,这同时也算对得起TensorFlow这个强大的工具了。本节,还是以手…

    Keras 2023年4月6日
    00
  • keras写模型时遇到的典型问题,也是最基础的类与对象问题

    自己定义了一个卷积类,现在需要把卷积加入model中,我的操作是这样的: model.add(Convolution1dLayer) 这样就会报错: 正确的写法是: model.add(Convolution1dLayer()) 原因是Convolution1dLayer仅仅是一个类,但model需要添加的层必须是实例(对象),必须把类实例化后才能添加。 实…

    Keras 2023年4月6日
    00
  • keras系列︱图像多分类训练与利用bottleneck features进行微调(三)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72861152 中文文档:http://keras-cn.readthedocs.io/en/latest/  官方文档:https://keras.io/  文档主要是以keras2.0。 训练、训练主要就”练“嘛,所以堆几个案例就知道怎么做了。…

    2023年4月8日
    00
  • caffe中train过程的train数据集、val数据集、test时候的test数据集区别

    val是validation的简称。training dataset 和 validation dataset都是在训练的时候起作用。而因为validation的数据集和training没有交集,所以这部分数据对最终训练出的模型没有贡献。validation的主要作用是来验证是否过拟合、以及用来调节训练参数等。 比如你训练0-10000次迭代过程中,trai…

    Caffe 2023年4月5日
    00
  • TensorFlow实战6——TensorFlow实现VGGNet_16_D

    1 #coding = utf-8 2 from datetime import datetime 3 import tensorflow as tf 4 import time 5 import math 6 7 def conv_op(input_op, name, kh, kw, n_out, dh, dw, p): 8 n_in = input_op…

    tensorflow 2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部