tensorflow中训练后的模型是一个pb文件,proto 文件如下:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
通过解析pb文件即可以拿到训练后的的权重信息。
with open(output_graph_path,"rb") as f: output_graph.ParseFromString(f.read())
graph是有node节点组成,遍历所有的node 节点可以获取到 训练的权重信息以及shape大小:
for node in output_graph.node: print 'name:{}'.format(node.name) print 'shape:{},dtype;{}'.format(node.attr['value'].tensor.tensor_shape,node.attr['value'].tensor.dtype) if node.attr['value'].tensor.dtype != 1: continue print tensor_util.MakeNdarray(node.attr['value'].tensor)
graph是一个有向图,由节点和有向边组成,假设有如下计算表达式:t1=MatMul(input, W1)。
图计算表达式包含三个节点,两条边,描述为文字形式如下:
TF 调用protobuf 解析方法,将graph 的字符串描述解析并生成grapdef实例,下一节查看graphdef的输入和输出,并尝试将模型转换为caffe模型
代码在git地址:https://github.com/wudafucode/machine_learning/blob/master/showpb.py
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow 学习笔记(1)—-解析pb文件,打印node的权重信息 - Python技术站