主要记录在Tensorflow2中使用Keras API接口,有关模型保存、加载的内容;
0. 加载数据、构建网络
首先,为了方便后续有关模型保存、加载相关代码的正常执行,这里加载mnist数据集、构建一个简单的网络结构。
import tensorflow as tf
from libs.load_keras_dataset import load_mnist
注意:下面引入mnist数据集的方式,仅为了方便作者从本地加载、使用;
mnist_path = \'/home/chenz/data/mnist/mnist.npz\'
(x_train, y_train), (x_test, y_test) = load_mnist(data_path=mnist_path)
print("[INFO] x_train: {}, y_train: {}, x_test: {}, y_test: {}".format(
x_train.shape, y_train.shape, x_test.shape, y_test.shape
))
train_labels = y_train[:1000]
test_labels = y_test[:1000]
train_images = x_train[:1000].reshape(-1, 28*28) / 255.0
test_images = x_test[:1000].reshape(-1, 28*28) / 255.0
print("[INFO] train_images: {}, train_labels: {}, test_images: {}, test_labels: {}".format(
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
))
[INFO] x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
[INFO] train_images: (1000, 784), train_labels: (1000,), test_images: (1000, 784), test_labels: (1000,)
定义一个方法,用于构建网络结构,并定义网络编译方式,方便后续使用;
# Build Model
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(512, activation=\'relu\', input_shape=(784,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.1, beta_2=0.2, amsgrad=True),
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
1. model.save() & model.save_weights()
在TensorFlow的Keras API中提供了两种保存模型的方式,分别为model.save()
、model.save_weights()
,从字面上可以简单理解,后者仅保存网络结构权重,前者能够保存整个模型结构;
进一步,从源码文档中可以理清两者的区别:
1.1 model.save()
该方法能够将整个模型进行保存,以两种方式存储,Tensorflow SavedModel
、HDF file
,保存的文件包括:
- 模型结构,能够重新实例化模型;
- 模型权重;
- 优化器的状态,在上次中断的地方继续训练;
可以通过tf.keras.models.load_model
重新实例化保存的模型,通过该方法返回的模型是已经编译过的模型,除非在之前保存模型的时候就没有被编译;
利用Sequential
和Functional
两种形式构建的网络都能够保存成HDF5和SavedModel格式,但是Subclasses
形式的模型仅能够保存成SavedModel格式;
# HDF5格式
model_name.h5
# Tensorflow SavedModel格式
./saved_model
assets/
saved_model.pb
variables/
使用参数说明:
def save(self,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None):
-
filepath
表示模型存储的路径; -
save_format
表示以tf
或者h5
形式进行存储,在TF2中默认tf
,TF1中默认h5
; -
overwrite
表示是否覆盖在目标目录下的已有文件; -
include_optimizer
表示是否保存优化器的状态; -
signatures
仅用于tf
形式,具体使用见tf.saved_model.save
;
filepath
和save_format
结合在一起使用,有如下组合方式:
-
filepath
以.h5
为结尾的文件名,则不论save_format
是tf
或者h5
,则模型将保存成filename.h5
形式;(上级目录需要存在) -
filepath
仅指定文件名,save_format=\'h5\'
,则模型将保存成filename
的HDF形式; -
filepath
指定路径(需存在),save_format=\'tf\'
,则模型将以Tensorflow SavedModel
形式保存到指定路径下;
注意:filepath
不包含后缀时,注意区分是文件目录还是文件名,以tf
形式保存,则需要存在指定路径,以h5
形式保存,则不能存在相同名称路径;
1.2 model.save_weights()
该方法仅保存网络中所有层的权重,
# HDF5格式
weights_2 or weights_3.h5
# Tensorflow 格式
checkpoint
weiths_1.data-00000-of-00001
weigths_1.index
使用参数说明:
def save_weights(self,
filepath,
overwrite=True,
save_format=None,
options=None):
-
filepath
表示存储的模型文件名或路径; -
save_format
用于表示存储格式,HDF5
或者Tensorflow
格式;
filepath
与save_format
结合使用:
-
filepath
以后缀.h5
或者.keras
结尾,设置save_format=None
或者save_format=None
,模型将保存成filename.h5
或filename.keras
格式; -
filepath
不含后缀,如果save_format=\'h5\'
,则模型保存成filename
; -
filepath
不含后缀,如果save_format=\'tf\'
或者save_format=None
,则模型保存成Tensorflow
格式;
2. tf.keras.callbacks.ModelCheckpoint
该方法以回调函数的形式,在模型训练过程中保存模型。
def __init__(self,
filepath,
monitor=\'val_loss\',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode=\'auto\',
save_freq=\'epoch\',
options=None,
**kwargs):
这里仅提及一点,就是在使用参数save_weights_only
时:
- 设置
True
,则调用model.save_weights()
; - 设置
False
,则调用model.save()
;
使用方式:
checkpoint_path = "./saved_model/save_and_load/cp_test_1/cp.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=False,
verbose=1)
model.fit(train_images, train_labels,
epochs=3,
validation_data=(test_images, test_labels),
callbacks=[cp_callback])
3. tf.keras.models.load_model、model.load_weights
上面简单说明了模型保存的两种方式,一种是保存整个模型,另一种则是仅保存模型权重;
完整的模型可以使用tf.keras.models.load_model
加载,只包含权重的模型则使用model.load_weights
加载;
3.1 tf.keras.models.load_model
加载完整模型
model_path = \'./saved_model/save_and_load/save_test/test_5/\'
model = tf.keras.models.load_model(model_path)
model.summary()
- 其中,
model_path
可以为.h5
文件的路径,或者Tensorflow SavedModel
的路径
3.2 model.load_weights
在重新构建网络的基础上,加载模型权重;
model = create_model()
model.load_weights("./saved_model/save_and_load/save_test/weights/weights_1")
model.summary()
4. 总结
- 官方API是推荐Tensorflow格式进行保存模型,不论是保存整个模型,或是仅保存权重;
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFlow2中Keras模型保存与加载 - Python技术站