显示pb模型的所有网络节点可以通过TensorFlow提供的工具tf.GraphDef().返回一个TensorFlow计算图的protocol buffer定义。可以通过以下步骤在Python API中使用tf.GraphDef():
1.导入TensorFlow模块
import tensorflow as tf
2.定义待加载的pb模型文件路径。其中with open()打开的文件流,读取二进制文件,'rb'代表读取二进制文件。
pb_file_path = "./example.pb"
with open(pb_file_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
3.将pb模型文件解析成一个GraphDef。
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
4.再使用tf.import_graph_def将GraphDef加载到当前计算图中。并通过graph.get_operations()获取graph中的所有操作。
with tf.compat.v1.Session() as sess:
sess.graph.as_default()
tf.import_graph_def(graph_def)
all_nodes = sess.graph.get_operations()
for node in all_nodes:
print(node.name)
该方法可以非常方便地列出给定的pb文件中的全部节点名称。示例如下:
import tensorflow as tf
pb_file_path = "./example.pb"
with open(pb_file_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.compat.v1.Session() as sess:
sess.graph.as_default()
tf.import_graph_def(graph_def)
all_nodes = sess.graph.get_operations()
for node in all_nodes:
print(node.name)
除此之外,也可以通过TensorBoard来可视化展示网络结点,也就是通过TensorFlow的内置工具GraphDef visualizer来将pb文件转化成网络结构图。首先,需要在代码中构造一张计算图,以便导出和可视化。示例如下:
import tensorflow as tf
pb_file_path = "./example.pb"
with tf.Graph().as_default():
graph_def = tf.compat.v1.GraphDef()
with open(pb_file_path, 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# 在with语句块内,利用 tf.summary.FileWriter 将当前的计算图写入日志文件,亦即 events file.
with tf.compat.v1.Session() as sess:
writer = tf.compat.v1.summary.FileWriter('./log/', graph=sess.graph)
writer.close()
此操作将在 './log/' 目录下生成事件文件。在终端中输入命令:tensorboard --logdir="./log/",在浏览器中打开"http://localhost:6006/"即可见到可视化网络结点。例如,使用TensorBoard展示MNIST模型结构:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)
with tf.Graph().as_default() as graph:
x = tf.placeholder(tf.float32, [None, 784], name="Input_Data")
y = tf.placeholder(tf.float32, [None, 10], name="Label_Data")
with tf.name_scope("Model"):
W = tf.Variable(tf.zeros([784, 10]), name="Weight")
b = tf.Variable(tf.zeros([10]), name="Bias")
pred = tf.nn.softmax(tf.matmul(x, W) + b)
with tf.name_scope("LossFunction"):
loss_function = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=[1]))
with tf.name_scope("Training"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss_function)
with tf.name_scope("Accuracy"):
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init = tf.global_variables_initializer()
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
sess.run(init)
writer = tf.compat.v1.summary.FileWriter('./log/', graph=sess.graph)
writer.close()
在终端中,将以上代码保存在mnist.py文件中,输入:tensorboard --logdir="./log/"即可看到MNIST模型的可视化网络结构。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:使用tensorflow显示pb模型的所有网络结点方式 - Python技术站