Tensorflow 利用tf.contrib.learn建立输入函数的方法

TensorFlow 是目前广泛应用在人工智能领域的深度学习框架之一。在 TensorFlow 中,一般利用 tf.contrib.learn 模块建立模型,并利用输入函数(Input Function)将数据输入到模型中训练和预测。下面,我将详细讲解 TensorFlow 利用 tf.contrib.learn 建立输入函数的方法,包含两个示例。

示例一

首先我们需要导入需要的库:

import tensorflow as tf
from tensorflow.contrib import learn

输入函数需要返回数据(features)和标签(labels)两个对象,所以我们需要定义一个 input_fn 函数,该函数的返回值是一个函数对象。

def input_fn(data_set):
    feature_cols = {feature_name: tf.constant(data_set[feature_name].values) for feature_name in data_set.columns[:-1]}
    labels = tf.constant(data_set['label'].values)
    return feature_cols, labels

在该函数中,我们使用 pd.DataFrame.values 将特征列数据取出并构造成一个字典对象 feature_cols,每个特征列名称作为字典的键。另外构造一个常量对象 labels 即为标签数据。最终函数返回一个二元组,包含 feature_cols 和 labels。

接下来我们需要定义模型估计器。这里我们使用 DNNClassifier 进行分类器训练。

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, feature_columns=feature_columns, model_dir="./tmp/iris_model")

这里使用的是 iris 数据集,因为难度适中,才便于理解其中的逻辑。

在上面的代码中,定义了三层 DNN 模型,其中 [10, 20, 10] 表示每层神经元的个数,n_classes 表示分类器需要预测的类别数目,feature_columns 是一个存放特征列对象的列表。最后一个参数 model_dir 是模型保存路径,保存在当前目录下的 tmp 文件夹下。

利用上述定义的输入函数和模型估计器开始进行模型训练。这里我们训练 1000 次。

import pandas as pd

IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

training_set = pd.read_csv(IRIS_TRAINING, header=0)
test_set = pd.read_csv(IRIS_TEST, header=0)

classifier.fit(input_fn=lambda: input_fn(training_set), steps=1000)

使用 pd.read_csv 函数从文件中读取训练数据集和测试数据集。然后,我们使用 classifier.fit 函数来训练模型,输入数据使用 input_fn(training_set),表示输入数据的方式使用刚才编写的输入函数。训练完成之后,模型会保存到指定的路径中。

在模型训练完成之后,我们可以使用 classifier.predict 函数用新输入数据进行预测。如下所示:

predictions = list(classifier.predict(input_fn=lambda: input_fn(test_set)))
print(predictions)

这里我们将测试集输入进去进行测试,接收到返回值为预测结果,其中每个元素代表一个样本的预测值。

示例二

这里让我们来看一个更加复杂的输入函数例子,使用的是 mnist 数据集,我们将数据集先进行预处理,然后使用 dataset.from_tensor_slices 函数转换数据格式,实现更为复杂的输入函数。

数据预处理:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("tmp/MNIST_data/", one_hot=False)

def input_fn_train():
    data = mnist.train.images.reshape((-1, 28, 28, 1))  # 训练数据集
    labels = mnist.train.labels.astype("int32")
    dataset = tf.data.Dataset.from_tensor_slices((data, labels)) \
        .map(lambda x, y: ({'image': x}, y)) \
        .batch(32) \
        .repeat()
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

def input_fn_test():
    data = mnist.test.images.reshape((-1, 28, 28, 1))  # 训练数据集
    labels = mnist.test.labels.astype("int32")
    dataset = tf.data.Dataset.from_tensor_slices((data, labels)) \
        .map(lambda x, y: ({'image': x}, y)) \
        .batch(32) \
        .repeat()
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

这里将数据 reshape 成了 [batch_size, 28, 28, 1] 的形式,因为 tf.contrib.learn 的 DNNClassifier 需要输入 [batch_size, features] 的输入数据。在返回的函数中,我们通过 tf.data.Dataset.from_tensor_slices 来构造数据集对象,由于训练样本是无限的,这里我们使用 .repeat() 函数来构造无限的数据流,并且按照设定的 batch_size 进行分割。

训练模型:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=784)]
classifier = learn.DNNClassifier(feature_columns=feature_columns,
                                 hidden_units=[256, 32],
                                 optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
                                 n_classes=10,
                                 model_dir="./tmp/mnist_model")

classifier.fit(input_fn=input_fn_train, steps=10000)

在这里我们使用了与上个示例类似的方法来定义 DNNClassifier。需要注意的是,这里因为特征数为 784,直接使用 tf.contrib.layers.real_valued_column 定义即可,如果不确定特征数大小,可以使用更加通用的语句 tf.feature_column.numeric_column。在定义 DNNClassifier 的同时,我们可以指定优化器的类型和 learning_rate,同时也可以设置模型检查点的存放目录。

测试模型:

accuracy_score = classifier.evaluate(input_fn=input_fn_test, steps=40)['accuracy']
print('\nTest Accuracy: {0:f}%\n'.format(accuracy_score*100))

这里我们利用 DNNClassifier.evaluate 函数来计算模型在测试集上的精度。

总结

这篇文章我们主要探讨了 TensorFlow 如何利用 tf.contrib.learn 模块建立输入函数的方法,同时通过两个示例分别展示了如何输入预处理过的数据和复杂的输入函数情况。总的来说,在输入函数的编写中,我们需要注意数据的封装和格式的转换,将数据转换成适合模型输入的样式。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Tensorflow 利用tf.contrib.learn建立输入函数的方法 - Python技术站

(0)
上一篇 2023年5月18日
下一篇 2023年5月18日

相关文章

  • Tensorflow 老版本的安装 – 兵者

    Tensorflow 老版本的安装 Tensorflow 的版本,已经从1.0 进展到2.0 安装比较旧的版本时,有可能发现再pypi镜像中不存在,并没有对应的版本,而是只有2.*; 报错信息可能: ERROR: Could not find a version that satisfies the requirement tensorflow-gpu==1…

    2023年4月8日
    00
  • TensorFlow实战3——TensorFlow实现CNN

    1 from tensorflow.examples.tutorials.mnist import input_data 2 import tensorflow as tf 3 4 mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True) 5 sess = tf.InteractiveSes…

    tensorflow 2023年4月8日
    00
  • TensorFlow的图像NCHW与NHWC

        import tensorflow as tf x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] with tf.Session() as sess: a = tf.reshape(x, [2, 2, 3]) a = sess.run(a) print(a) print(“——————–…

    2023年4月8日
    00
  • tensorflow遇到ImportError: Could not find ‘cudart64_100.dll’错误解决

      在安装tensorflow的时候,使用import tensorflow出现了找不到dll文件的错误,参考了很多博客和stackflow的解决方案,发现其中只说了版本号不匹配,但是没有具体说明什么样的版本才是适配正确的,因此手写此避坑指南。再次感谢Function兄的指导帮助。   笔者环境:   python 版本3.6   tensorflow版本1…

    tensorflow 2023年4月7日
    00
  • python和tensorflow安装

    一、Python安装       python采用anaconda安装,简单方便,下载python3.6的anaconda  linux64的sh安装文件.       1、bash Anaconda-2.1.0-Linux-x86_64.sh       2、python,用于测试     二、Tensorflow安装   1、首先安装 pip (或 Py…

    tensorflow 2023年4月8日
    00
  • 深度学习之TensorFlow安装与初体验

    学习前 搞懂一些关系和概念首先,搞清楚一个关系:深度学习的前身是人工神经网络,深度学习只是人工智能的一种,深层次的神经网络结构就是深度学习的模型,浅层次的神经网络结构是浅度学习的模型。 浅度学习:层数少于3层,使用全连接的一般被认为是浅度神经网络,也就是浅度学习的模型,全连接的可能性过于繁多,如果层数超过三层,计算量呈现指数级增长,计算机无法计算到结果,所以…

    2023年4月5日
    00
  • Tensorflow矩阵运算实例(矩阵相乘,点乘,行/列累加)

    下面是Tensorflow矩阵运算实例(矩阵相乘,点乘,行/列累加)的完整攻略,本攻略包括两条示例说明。 示例1:矩阵相乘 背景 如何使用Tensorflow进行矩阵相乘运算? 实现步骤 首先,需要导入Tensorflow库。 import tensorflow as tf 创建两个矩阵。 a = tf.constant([[2, 3], [4, 5]]) …

    tensorflow 2023年5月17日
    00
  • tensorflow 输出权重到csv或txt的实例

    TensorFlow之如何输出权重到CSV或TXT的实例 在使用TensorFlow进行深度学习模型训练时,我们可能需要将模型的权重输出到CSV或TXT文件中,以便后续分析或使用。本文将提供一个完整的攻略,详细讲解如何输出TensorFlow模型的权重到CSV或TXT文件,并提供两个示例说明。 如何输出TensorFlow模型的权重到CSV或TXT文件 在输…

    tensorflow 2023年5月16日
    00
合作推广
合作推广
分享本页
返回顶部