下面是tensorflow 保存模型和取出中间权重的完整攻略,包含两条示例说明。
标准流程
TensorFlow中训练好的模型需要保存下来,以便在需要时进行加载和使用。保存模型需要进行两步,第一步是定义saver,第二步是运行saver实例的save方法。加载模型需要进行两步,第一步是定义saver,第二步是运行saver实例的restore方法。
保存模型
定义saver
import tensorflow as tf
# 定义网络结构
...
# 创建Saver
saver = tf.train.Saver()
运行saver实例的save方法
with tf.Session() as sess:
# 执行训练过程
...
# 保存训练好的模型
saver.save(sess, 'model.ckpt')
其中,model.ckpt
为保存的模型文件的名称。
加载模型
定义saver
import tensorflow as tf
# 定义网络结构
...
# 创建Saver
saver = tf.train.Saver()
运行saver实例的restore方法
with tf.Session() as sess:
# 加载训练好的模型
saver.restore(sess, 'model.ckpt')
# 进行预测或测试
...
其中,model.ckpt
为保存的模型文件的名称。
示例一:保存和加载全部变量
下面我们来看一个保存和加载全部变量的示例。
定义网络结构
import tensorflow as tf
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
创建Saver实例
saver = tf.train.Saver()
执行训练过程和保存模型
with tf.Session() as sess:
tf.global_variables_initializer().run()
# Train
for i in range(1000):
...
# Save Model
saver.save(sess, 'model.ckpt')
加载模型
with tf.Session() as sess:
# Load Model
saver.restore(sess, 'model.ckpt')
# Test
...
示例二:保存和加载部分变量
下面我们来看一个保存和加载部分变量的示例。
定义网络结构
import tensorflow as tf
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
创建Saver实例
saver = tf.train.Saver({'W': W, 'b': b})
执行训练过程和保存模型
with tf.Session() as sess:
tf.global_variables_initializer().run()
# Train
for i in range(1000):
...
# Save Model
saver.save(sess, 'model.ckpt')
加载模型
with tf.Session() as sess:
# Load Model
saver.restore(sess, 'model.ckpt')
# Test
...
在创建Saver实例时传递一个字典,其中键是要保存的变量的名称,值是对应的变量。在加载模型时,只需要传递和保存时相同的变量名即可。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow 保存模型和取出中间权重例子 - Python技术站