tensorflow的模型保存有两种形式:
1. ckpt:可以恢复图和变量,继续做训练
2. pb : 将图序列化,变量成为固定的值,,只可以做inference;不能继续训练
Demo
1 def freeze_graph(input_checkpoint,output_graph): 2 3 ''' 4 :param input_checkpoint: 5 :param output_graph: PB模型保存路径 6 :return 7 void 8 ''' 9 10 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用 11 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径 12 13 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点 14 output_node_names = "InceptionV3/Logits/SpatialSqueeze" # 如果是多个输出节点,使用 ‘,’号隔开 15 16 ############################ Step1: 从ckpt中恢复图: ############################################# 17 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 18 graph = tf.get_default_graph() # 获得默认的图, 可以省略 19 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图,可以省略 20 21 with tf.Session() as sess: # 会使用默认的图 作为当前的图 22 saver.restore(sess, input_checkpoint) #恢复图并得到数据 23 24 ######################## Step2: 创建持久化对象,指定sess,图、以及输出的序列化节点信息 ############## 25 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 26 sess=sess, 27 input_graph_def=input_graph_def,# 等于:sess.graph_def 28 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开 29 ######################### Step3: 模型持久化 ####################################################### 30 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型 31 f.write(output_graph_def.SerializeToString()) #序列化输出 32 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点 33 # for op in graph.get_operations(): 34 35 # print(op.name, op.values()) 36 37 38 ########################### 调用方式 ################################ 39 # 输入ckpt模型路径 40 input_checkpoint='models/model.ckpt-10000' 41 # 输出pb模型的路径 42 out_pb_path="models/pb/frozen_model.pb" 43 # 调用freeze_graph将ckpt转为pb 44 freeze_graph(input_checkpoint,out_pb_path)
解析
函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。
freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。
在保存pb的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称;
tensor name 和 node name 的区别
node name 是 图 的节点,里面包含了很多操作和tensor
tensor 是 node 里面的一个组成部分;
以input 为例,“input:0”是张量的名称,而"input"表示的是节点的名称
PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"input:0"
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Tensorflow Learning1 模型的保存和恢复 - Python技术站