将tensorflow模型打包成PB文件及PB文件读取方式

yizhihongxing

将TensorFlow模型打包成PB文件及PB文件读取方式

在TensorFlow中,可以将训练好的模型打包成PB文件,以便在其他环境中使用。本文将详细讲解如何将TensorFlow模型打包成PB文件以及如何读取PB文件,并提供两个示例说明。

步骤1:将模型保存为PB文件

在TensorFlow中,可以使用tf.saved_model.simple_save()方法将模型保存为PB文件。可以使用以下代码将模型保存为PB文件:

import tensorflow as tf

# 创建模型
inputs = tf.placeholder(tf.float32, shape=[None, 784], name='inputs')
outputs = tf.placeholder(tf.float32, shape=[None, 10], name='outputs')
hidden = tf.layers.dense(inputs, 256, activation=tf.nn.relu)
logits = tf.layers.dense(hidden, 10, activation=None)
predictions = tf.nn.softmax(logits, name='predictions')

# 保存模型为PB文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(sess, 'model', inputs={'inputs': inputs}, outputs={'predictions': predictions})

在这个代码中,我们首先创建了一个简单的模型,然后使用tf.saved_model.simple_save()方法将模型保存为PB文件。在保存模型时,我们需要指定模型的输入和输出张量,并将其保存在指定的文件夹中。

步骤2:读取PB文件

在TensorFlow中,可以使用tf.saved_model.loader.load()方法读取PB文件。可以使用以下代码读取PB文件:

import tensorflow as tf

# 读取PB文件
with tf.Session() as sess:
    tf.saved_model.loader.load(sess, ['serve'], 'model')
    graph = tf.get_default_graph()
    inputs = graph.get_tensor_by_name('inputs:0')
    predictions = graph.get_tensor_by_name('predictions:0')

在这个代码中,我们使用tf.saved_model.loader.load()方法读取PB文件,并使用tf.get_default_graph()方法获取默认图。然后,我们使用graph.get_tensor_by_name()方法获取输入和输出张量。

示例1:将TensorFlow模型打包成PB文件

以下是将TensorFlow模型打包成PB文件的示例代码:

import tensorflow as tf

# 创建模型
inputs = tf.placeholder(tf.float32, shape=[None, 784], name='inputs')
outputs = tf.placeholder(tf.float32, shape=[None, 10], name='outputs')
hidden = tf.layers.dense(inputs, 256, activation=tf.nn.relu)
logits = tf.layers.dense(hidden, 10, activation=None)
predictions = tf.nn.softmax(logits, name='predictions')

# 保存模型为PB文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(sess, 'model', inputs={'inputs': inputs}, outputs={'predictions': predictions})

在这个示例中,我们创建了一个简单的模型,并使用tf.saved_model.simple_save()方法将模型保存为PB文件。

示例2:读取TensorFlow PB文件

以下是读取TensorFlow PB文件的示例代码:

import tensorflow as tf

# 读取PB文件
with tf.Session() as sess:
    tf.saved_model.loader.load(sess, ['serve'], 'model')
    graph = tf.get_default_graph()
    inputs = graph.get_tensor_by_name('inputs:0')
    predictions = graph.get_tensor_by_name('predictions:0')

在这个示例中,我们使用tf.saved_model.loader.load()方法读取PB文件,并使用graph.get_tensor_by_name()方法获取输入和输出张量。

结语

以上是将TensorFlow模型打包成PB文件及PB文件读取方式的详细攻略,包括将模型保存为PB文件和读取PB文件等步骤,并提供了两个示例。在实际应用中,我们可以根据具体情况来选择合适的方法来保存和读取模型。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:将tensorflow模型打包成PB文件及PB文件读取方式 - Python技术站

(0)
上一篇 2023年5月16日
下一篇 2023年5月16日

相关文章

  • Tensorflow object detection API 搭建物体识别模型(一)

    一、开发环境  1)python3.5  2)tensorflow1.12.0  3)Tensorflow object detection API :https://github.com/tensorflow/models下载到本地,解压   我们需要的目标检测代码在models-research文件中:     其中object_detection中的R…

    tensorflow 2023年4月7日
    00
  • tensorflow学习一

    1.用图(graph)来表示计算任务 2.用op(opreation)来表示图中的计算节点,图有默认的计算节点,构建图的过程就是在其基础上加节点。 3.用tensor表示每个op的输入输出数据,可以使用feed,fetch可以为任意操作设置输入和获取输出。 4.通过Variable来维护状态。 5.整个计算任务放入session的上下文来执行。     te…

    tensorflow 2023年4月8日
    00
  • tensorflow TFRecords文件的生成和读取的方法

    TensorFlow提供了TFRecords文件格式,它是一种二进制文件格式,用于有效地处理大量数据。TFRecords文件包含一系列大小固定的记录。每条记录包含一个二进制数据字符串(实际上是一个字节数组)和它所代表的任何数据以及它的长度。在此过程中,我们将重点介绍如何生成和读取TensorFlow中的TFRecords文件。 生成TFRecords文件 以…

    tensorflow 2023年5月18日
    00
  • ubuntu16.04 使用tensorflow object detection训练自己的模型

    一、构建自己的数据集 1、格式必须为jpg、jpeg或png。 2、在models/research/object_detection文件夹下创建images文件夹,在images文件夹下创建train和val两个文件夹,分别存放训练集图片和测试集图片。 3、下载labelImg目标检测标注工具 (1)下载地址:https://github.com/tzut…

    tensorflow 2023年4月8日
    00
  • 关于tensorflow版本报错问题的解决办法

    #原 config = tf.ConfigProto(allow_soft_placement=True) config = tf.compat.v1.ConfigProto(allow_soft_placement=True) #原 sess = tf.Session(config=config) sess =tf.compat.v1.Session(co…

    tensorflow 2023年4月6日
    00
  • 12 tensorflow实战:修改三维tensor矩阵的某个剖面

    # -*- coding: utf-8 -*- “”” Created on Mon Apr 22 21:02:02 2019 @author: a “”” # -*- coding: utf-8 -*- “”” Created on Sat Dec 1 16:53:26 2018 @author: a “”” import tensorflow as tf…

    tensorflow 2023年4月8日
    00
  • Tensorflow暑期实践——基于多隐层神经网络的手写数字识别

    版权说明:浙江财经大学专业实践深度学习tensorflow——齐峰 目录 1  基于多隐层神经网络的手写数字识别 2  本章内容介绍 3  Tensorflow实现基于单个神经元的手写数字识别 4  Tensorflow实现基于单隐层神经网络的手写数字识别 5.1  载入数据 5.2.1  构建输入层 5.2.2  构建隐藏层h15.2.3  构建隐藏层h2…

    2023年4月8日
    00
  • windows tensorflow无法下载Fashion-mnist的解决办法

    使用下面的语句下载数据集会报错连接超时等 import tensorflow as tf from tensorflow import keras fashion_mnist = keras.datasets.fashion_mnist (train_images, train_labels), (test_images, test_labels) = fa…

    2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部