下面我就来详细讲解tensorflow模型保存、加载之变量重命名实例的完整攻略。
一、tensorflow模型保存和加载
在tensorflow中,我们通常使用saver对象来保存和加载模型,saver对象是一个tensorflow中的类,用来保存变量,模型,图的实例(saver可以将变量数值作为numpy数组或TensorFlow张量对待,不用在 session 中取回张量)。saver有三个基本方法:
save()
: 将模型保存到磁盘restore()
: 从磁盘恢复模型export_meta_graph()
: 将模型导出到.meta文件
下面分别介绍一下这3个方法的用法。
1. save()方法
通常在tensorflow的训练过程中,我们需要保存一些中间结果(例如训练后的模型),这时候我们就可以使用save()方法将模型保存到磁盘上,它的用法如下:
saver.save(sess, 'model/model.ckpt')
其中,sess是tensorflow的Session对象,‘model/model.ckpt’是模型的保存路径,.ckpt是tensorflow默认的模型文件扩展名。如果开启了tensorboard,则.saver文件将会被写入到上述目录下
2. restore()方法
当你需要从保存的模型中重载参数时,你可以使用restore()方法,它的用法如下:
saver.restore(sess, 'model/model.ckpt')
其中,sess是tensorflow的Session对象,‘model/model.ckpt’是模型的保存路径。在这个过程中,所有的变量和张量被重新拉入到你的本地环境中,这时候你可以使用它们做其它的事情了。
3. export_meta_graph()方法
export_meta_graph()方法可以将TensorFlow图导出为.meta文件,.meta也是tensorflow默认的模型文件扩展名。使用方法如下:
tf.train.export_meta_graph(filename)
其中filename为.meta文件路径,这样做的好处是在训练和测试模型时可以调用相同的模型,只需加载.meta文件。.meta文件中定义了计算图中的所有变量、op、图结构,保存了操作的类型、输入/输出张量的形状和类型,和节点名称。
二、变量重命名实例
我们知道,在节点定义的计算图中,tensor和operation节点的名称是非常重要的,因为它们通常是与其他节点链接的主要方式。对于小型计算图,我们可以直接手动为每个变量制定一个唯一的变量名,但是对于复杂的计算图和具有数百个变量的神经网络,这项任务变得更加困难。
为了解决这个问题,我们可以使用变量作用域(variable scope)来指定名称空间,为变量命名。在这个过程中,我们通常会使用tf.variable_scope()来定义变量作用域。
假设我们有一个用来训练MNIST数据集的简单的线性模型,为了更好的演示变量重命名实例,我们将模型分为两部分:第一部分是输入和权重的初始化,第二部分是模型计算和损失函数。
import tensorflow as tf
def linear_model(inputs, scope='linear_model'):
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
# Linear model
w = tf.get_variable('w', (784, 10), initializer=tf.random_normal_initializer())
b = tf.get_variable('b', (1, 10), initializer=tf.zeros_initializer())
logits = tf.matmul(inputs, w) + b
# Loss
targets = tf.placeholder(tf.float32, (None, 10))
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
# Optimizer
optimizer = tf.train.GradientDescentOptimizer(0.5)
train_step = optimizer.minimize(cross_entropy)
# Output
output = {'logits': logits, 'targets': targets, 'train_step': train_step}
return output
上面的代码实现了一个线性模型,我们使用tf.variable_scope()来定义变量作用域,然后使用tf.get_variable()来获取变量,这里的tf.get_variable()函数具有自动重用的功能,使得我们在执行restore training checkpoints的时候非常方便。
在这个模型中,我们使用tf.AUTO_REUSE设置tf.variable_scope()来重用变量,使用tf.get_variable()获取变量名和该变量的维度和初始化器,然后再对模型进行计算。我们还在模型中定义了损失函数cross_entropy和优化器train_step,最后将它们保存在一个字典中并返回。
下面,我们来演示一个变量重命名的实例:
import tensorflow as tf
def linear_model(inputs, scope='linear_model'):
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
# Linear model
w = tf.get_variable('weights', (784, 10), initializer=tf.random_normal_initializer())
b = tf.get_variable('biases', (1, 10), initializer=tf.zeros_initializer())
logits = tf.matmul(inputs, w) + b
# Loss
targets = tf.placeholder(tf.float32, (None, 10), name='targets')
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets), name='cross_entropy')
# Optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)
train_step = optimizer.minimize(cross_entropy, name='train_step')
# Output
output = {'logits': logits, 'targets': targets, 'train_step': train_step}
return output
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "model/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("weights:", sess.run('weights:0'))
print("biases:", sess.run('biases:0'))
在上述代码中,我们为w和b变量重新命名为‘weights’和‘biases’,并使用tf.nn.softmax_cross_entropy_with_logits()和 tf.train.GradientDescentOptimizer()的参数中的name参数给损失函数和优化器命名。
在模型训练结束后,保存下模型,并作以下测试:
with tf.Session() as sess:
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
sess.run(init)
# Train the model
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
_, loss_val = sess.run([model['train_step'], model['cross_entropy']], feed_dict={model['inputs']: batch_xs, model['targets']: batch_ys})
if i % 50 == 0:
print('Step: %s, Loss: %s' % (i, loss_val))
# Save the model
save_path = saver.save(sess, "model/model.ckpt")
print("Model saved in file: %s" % save_path)
在最后,我们使用了tf.Session()来打开一个会话(Session),读取了保存的模型文件,并检查了其参数值(‘weights’和‘biases’)。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow模型保存、加载之变量重命名实例 - Python技术站