主要记录在Tensorflow2中使用Keras API接口,有关模型保存、加载的内容;

0. 加载数据、构建网络

首先,为了方便后续有关模型保存、加载相关代码的正常执行,这里加载mnist数据集、构建一个简单的网络结构。

import tensorflow as tf
from libs.load_keras_dataset import load_mnist

注意:下面引入mnist数据集的方式,仅为了方便作者从本地加载、使用;

mnist_path = \'/home/chenz/data/mnist/mnist.npz\'
(x_train, y_train), (x_test, y_test) = load_mnist(data_path=mnist_path)
print("[INFO] x_train: {}, y_train: {}, x_test: {}, y_test: {}".format(
    x_train.shape, y_train.shape, x_test.shape, y_test.shape
))
train_labels = y_train[:1000]
test_labels = y_test[:1000]
train_images = x_train[:1000].reshape(-1, 28*28) / 255.0
test_images = x_test[:1000].reshape(-1, 28*28) / 255.0

print("[INFO] train_images: {}, train_labels: {}, test_images: {}, test_labels: {}".format(
    train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
))
[INFO] x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
[INFO] train_images: (1000, 784), train_labels: (1000,), test_images: (1000, 784), test_labels: (1000,)

定义一个方法,用于构建网络结构,并定义网络编译方式,方便后续使用;

# Build Model
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation=\'relu\', input_shape=(784,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10)
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.1, beta_2=0.2, amsgrad=True),
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

1. model.save() & model.save_weights()

在TensorFlow的Keras API中提供了两种保存模型的方式,分别为model.save()model.save_weights(),从字面上可以简单理解,后者仅保存网络结构权重,前者能够保存整个模型结构

进一步,从源码文档中可以理清两者的区别:

1.1 model.save()

该方法能够将整个模型进行保存,以两种方式存储,Tensorflow SavedModelHDF file,保存的文件包括:

  • 模型结构,能够重新实例化模型;
  • 模型权重;
  • 优化器的状态,在上次中断的地方继续训练;

可以通过tf.keras.models.load_model重新实例化保存的模型,通过该方法返回的模型是已经编译过的模型,除非在之前保存模型的时候就没有被编译;

利用SequentialFunctional两种形式构建的网络都能够保存成HDF5和SavedModel格式,但是Subclasses形式的模型仅能够保存成SavedModel格式;

# HDF5格式
model_name.h5

# Tensorflow SavedModel格式
./saved_model
	assets/
	saved_model.pb
	variables/

使用参数说明:

def save(self,
           filepath,
           overwrite=True,
           include_optimizer=True,
           save_format=None,
           signatures=None,
           options=None):
  • filepath表示模型存储的路径;

  • save_format表示以tf或者h5形式进行存储,在TF2中默认tf,TF1中默认h5

  • overwrite表示是否覆盖在目标目录下的已有文件;

  • include_optimizer表示是否保存优化器的状态;

  • signatures仅用于tf形式,具体使用见tf.saved_model.save

filepathsave_format结合在一起使用,有如下组合方式:

  • filepath.h5为结尾的文件名,则不论save_formattf或者h5,则模型将保存成filename.h5形式;(上级目录需要存在)
  • filepath仅指定文件名,save_format=\'h5\',则模型将保存成filename的HDF形式;
  • filepath指定路径(需存在),save_format=\'tf\',则模型将以Tensorflow SavedModel形式保存到指定路径下;

注意:filepath不包含后缀时,注意区分是文件目录还是文件名,以tf形式保存,则需要存在指定路径,以h5形式保存,则不能存在相同名称路径;

1.2 model.save_weights()

该方法仅保存网络中所有层的权重,

# HDF5格式
weights_2 or weights_3.h5

# Tensorflow 格式
checkpoint 
weiths_1.data-00000-of-00001
weigths_1.index

使用参数说明:

def save_weights(self,
                 filepath,
                 overwrite=True,
                 save_format=None,
                 options=None):
  • filepath表示存储的模型文件名或路径;
  • save_format用于表示存储格式,HDF5或者Tensorflow格式;

filepathsave_format结合使用:

  • filepath以后缀.h5或者.keras结尾,设置save_format=None或者save_format=None,模型将保存成filename.h5filename.keras格式;
  • filepath不含后缀,如果save_format=\'h5\',则模型保存成filename
  • filepath不含后缀,如果save_format=\'tf\'或者save_format=None,则模型保存成Tensorflow格式;

2. tf.keras.callbacks.ModelCheckpoint

该方法以回调函数的形式,在模型训练过程中保存模型。

def __init__(self,
             filepath,
             monitor=\'val_loss\',
             verbose=0,
             save_best_only=False,
             save_weights_only=False,
             mode=\'auto\',
             save_freq=\'epoch\',
             options=None,
             **kwargs):

这里仅提及一点,就是在使用参数save_weights_only时:

  • 设置True,则调用model.save_weights()
  • 设置False,则调用model.save()

使用方式:

checkpoint_path = "./saved_model/save_and_load/cp_test_1/cp.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=False,
                                                 verbose=1)
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])

3. tf.keras.models.load_model、model.load_weights

上面简单说明了模型保存的两种方式,一种是保存整个模型,另一种则是仅保存模型权重;

完整的模型可以使用tf.keras.models.load_model加载,只包含权重的模型则使用model.load_weights加载;

3.1 tf.keras.models.load_model

加载完整模型

model_path = \'./saved_model/save_and_load/save_test/test_5/\'

model = tf.keras.models.load_model(model_path)
model.summary()
  • 其中,model_path可以为.h5文件的路径,或者Tensorflow SavedModel的路径

3.2 model.load_weights

在重新构建网络的基础上,加载模型权重;

model = create_model()
model.load_weights("./saved_model/save_and_load/save_test/weights/weights_1")
model.summary()

4. 总结

  • 官方API是推荐Tensorflow格式进行保存模型,不论是保存整个模型,或是仅保存权重;