为了实现从checkpoint中获取TensorFlow的Graph信息,可以使用TensorFlow提供的tf.train.import_meta_graph()和tf.train.Saver()两个函数结合起来。具体步骤如下:
- 加载checkpoint模型
import tensorflow as tf
checkpoint_path = "model.ckpt"
saver = tf.train.import_meta_graph(checkpoint_path + ".meta")
saver.restore(sess, checkpoint_path)
- 获取Graph
graph = tf.get_default_graph()
- 获取Graph中的所有操作
for op in graph.get_operations():
print(op.name)
- 获取指定操作的张量
tensor = graph.get_tensor_by_name("input_tensors:0")
- 获取指定Tensor的shape信息
shape = tensor.get_shape().as_list()
print(shape)
下面是两个示例:
示例1:获取保存在checkpoint中的VGG16模型,并获取该模型中某个卷积层的权重张量。
import tensorflow as tf
import numpy as np
checkpoint_path = "vgg16.ckpt"
saver = tf.train.import_meta_graph(checkpoint_path + ".meta")
sess = tf.Session()
saver.restore(sess, checkpoint_path)
graph = tf.get_default_graph()
# 打印所有操作的name
for op in graph.get_operations():
print(op.name)
# 获取指定操作的张量
w_conv1_1 = graph.get_tensor_by_name("conv1_1/weights:0")
# 打印该张量的shape信息
print(w_conv1_1.get_shape().as_list())
# 获取该张量的值
w_conv1_1_value = sess.run(w_conv1_1)
# 打印该张量的值
print(w_conv1_1_value)
示例2:获取保存在checkpoint中的BERT模型,并获取该模型中某个embedding层的输入张量。
import tensorflow as tf
checkpoint_path = "bert_model.ckpt"
saver = tf.train.import_meta_graph(checkpoint_path + ".meta")
sess = tf.Session()
saver.restore(sess, checkpoint_path)
graph = tf.get_default_graph()
# 打印所有操作的name
for op in graph.get_operations():
print(op.name)
# 获取指定操作的张量
input_tensor = graph.get_tensor_by_name("input_ids:0")
# 打印该张量的shape信息
print(input_tensor.get_shape().as_list())
可以根据自己的需求,调整以上示例中的代码来实现自己需要的功能。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow 实现从checkpoint中获取graph信息 - Python技术站