Tensorflow 利用tf.contrib.learn建立输入函数的方法

yizhihongxing

TensorFlow 是目前广泛应用在人工智能领域的深度学习框架之一。在 TensorFlow 中,一般利用 tf.contrib.learn 模块建立模型,并利用输入函数(Input Function)将数据输入到模型中训练和预测。下面,我将详细讲解 TensorFlow 利用 tf.contrib.learn 建立输入函数的方法,包含两个示例。

示例一

首先我们需要导入需要的库:

import tensorflow as tf
from tensorflow.contrib import learn

输入函数需要返回数据(features)和标签(labels)两个对象,所以我们需要定义一个 input_fn 函数,该函数的返回值是一个函数对象。

def input_fn(data_set):
    feature_cols = {feature_name: tf.constant(data_set[feature_name].values) for feature_name in data_set.columns[:-1]}
    labels = tf.constant(data_set['label'].values)
    return feature_cols, labels

在该函数中,我们使用 pd.DataFrame.values 将特征列数据取出并构造成一个字典对象 feature_cols,每个特征列名称作为字典的键。另外构造一个常量对象 labels 即为标签数据。最终函数返回一个二元组,包含 feature_cols 和 labels。

接下来我们需要定义模型估计器。这里我们使用 DNNClassifier 进行分类器训练。

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, feature_columns=feature_columns, model_dir="./tmp/iris_model")

这里使用的是 iris 数据集,因为难度适中,才便于理解其中的逻辑。

在上面的代码中,定义了三层 DNN 模型,其中 [10, 20, 10] 表示每层神经元的个数,n_classes 表示分类器需要预测的类别数目,feature_columns 是一个存放特征列对象的列表。最后一个参数 model_dir 是模型保存路径,保存在当前目录下的 tmp 文件夹下。

利用上述定义的输入函数和模型估计器开始进行模型训练。这里我们训练 1000 次。

import pandas as pd

IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

training_set = pd.read_csv(IRIS_TRAINING, header=0)
test_set = pd.read_csv(IRIS_TEST, header=0)

classifier.fit(input_fn=lambda: input_fn(training_set), steps=1000)

使用 pd.read_csv 函数从文件中读取训练数据集和测试数据集。然后,我们使用 classifier.fit 函数来训练模型,输入数据使用 input_fn(training_set),表示输入数据的方式使用刚才编写的输入函数。训练完成之后,模型会保存到指定的路径中。

在模型训练完成之后,我们可以使用 classifier.predict 函数用新输入数据进行预测。如下所示:

predictions = list(classifier.predict(input_fn=lambda: input_fn(test_set)))
print(predictions)

这里我们将测试集输入进去进行测试,接收到返回值为预测结果,其中每个元素代表一个样本的预测值。

示例二

这里让我们来看一个更加复杂的输入函数例子,使用的是 mnist 数据集,我们将数据集先进行预处理,然后使用 dataset.from_tensor_slices 函数转换数据格式,实现更为复杂的输入函数。

数据预处理:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("tmp/MNIST_data/", one_hot=False)

def input_fn_train():
    data = mnist.train.images.reshape((-1, 28, 28, 1))  # 训练数据集
    labels = mnist.train.labels.astype("int32")
    dataset = tf.data.Dataset.from_tensor_slices((data, labels)) \
        .map(lambda x, y: ({'image': x}, y)) \
        .batch(32) \
        .repeat()
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

def input_fn_test():
    data = mnist.test.images.reshape((-1, 28, 28, 1))  # 训练数据集
    labels = mnist.test.labels.astype("int32")
    dataset = tf.data.Dataset.from_tensor_slices((data, labels)) \
        .map(lambda x, y: ({'image': x}, y)) \
        .batch(32) \
        .repeat()
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

这里将数据 reshape 成了 [batch_size, 28, 28, 1] 的形式,因为 tf.contrib.learn 的 DNNClassifier 需要输入 [batch_size, features] 的输入数据。在返回的函数中,我们通过 tf.data.Dataset.from_tensor_slices 来构造数据集对象,由于训练样本是无限的,这里我们使用 .repeat() 函数来构造无限的数据流,并且按照设定的 batch_size 进行分割。

训练模型:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=784)]
classifier = learn.DNNClassifier(feature_columns=feature_columns,
                                 hidden_units=[256, 32],
                                 optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
                                 n_classes=10,
                                 model_dir="./tmp/mnist_model")

classifier.fit(input_fn=input_fn_train, steps=10000)

在这里我们使用了与上个示例类似的方法来定义 DNNClassifier。需要注意的是,这里因为特征数为 784,直接使用 tf.contrib.layers.real_valued_column 定义即可,如果不确定特征数大小,可以使用更加通用的语句 tf.feature_column.numeric_column。在定义 DNNClassifier 的同时,我们可以指定优化器的类型和 learning_rate,同时也可以设置模型检查点的存放目录。

测试模型:

accuracy_score = classifier.evaluate(input_fn=input_fn_test, steps=40)['accuracy']
print('\nTest Accuracy: {0:f}%\n'.format(accuracy_score*100))

这里我们利用 DNNClassifier.evaluate 函数来计算模型在测试集上的精度。

总结

这篇文章我们主要探讨了 TensorFlow 如何利用 tf.contrib.learn 模块建立输入函数的方法,同时通过两个示例分别展示了如何输入预处理过的数据和复杂的输入函数情况。总的来说,在输入函数的编写中,我们需要注意数据的封装和格式的转换,将数据转换成适合模型输入的样式。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Tensorflow 利用tf.contrib.learn建立输入函数的方法 - Python技术站

(0)
上一篇 2023年5月18日
下一篇 2023年5月18日

相关文章

  • 对tensorflow 的模型保存和调用实例讲解

    在TensorFlow中,我们可以使用tf.train.Saver()方法保存模型,并使用tf.train.import_meta_graph()方法调用模型。本文将详细讲解如何对TensorFlow的模型进行保存和调用,并提供两个示例说明。 示例1:保存和调用模型 以下是保存和调用模型的示例代码: import tensorflow as tf # 定义模…

    tensorflow 2023年5月16日
    00
  • tensorflow 和cuda对应关系

    Version Python version Compiler Build tools tensorflow-1.11.0 2.7, 3.3-3.6 GCC 4.8 Bazel 0.15.0 tensorflow-1.10.0 2.7, 3.3-3.6 GCC 4.8 Bazel 0.15.0 tensorflow-1.9.0 2.7, 3.3-3.6 GC…

    tensorflow 2023年4月6日
    00
  • win7上tensorflow2.2.0安装成功 引用DLL load failed时找不到指定模块 tensorflow has no attribute xxx 解决方法

    win7上tensorflow2.2.0安装成功 引用DLL load failed时找不到指定模块 tensorflow has no attribute xxx 解决方法 在Windows 7上安装TensorFlow 2.2.0时,有时会遇到引用DLL load failed时找不到指定模块或者tensorflow has no attribute x…

    tensorflow 2023年5月16日
    00
  • 解决Jupyter notebook[import tensorflow as tf]报错

     参考: https://blog.csdn.net/caicai_zju/article/details/70245099

    tensorflow 2023年4月6日
    00
  • 1.0Tensorflow中出现编译问题的解决方案

    跑简单tf例程的时候遇到这个 sess = tf.Session(),I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 S…

    2023年4月8日
    00
  • tensorflow更改变量的值实例

    在TensorFlow中,我们可以使用tf.Variable.assign()方法更改变量的值。本文将详细讲解TensorFlow更改变量的值的方法,并提供两个示例说明。 示例1:更改变量的值 以下是更改变量的值的示例代码: import tensorflow as tf # 定义变量 x = tf.Variable(1.0) # 打印变量的值 print(…

    tensorflow 2023年5月16日
    00
  • 构建基于深度学习神经网络协同过滤模型(NCF)的视频推荐系统(Python3.10/Tensorflow2.11)

    毋庸讳言,和传统架构(BS开发/CS开发)相比,人工智能技术确实有一定的基础门槛,它注定不是大众化,普适化的东西。但也不能否认,人工智能技术也具备像传统架构一样“套路化”的流程,也就是说,我们大可不必自己手动构建基于神经网络的机器学习系统,直接使用深度学习框架反而更加简单,深度学习可以帮助我们自动地从原始数据中提取特征,不需要手动选择和提取特征。 毋庸讳言,…

    2023年4月5日
    00
  • tensorflow安装问题:ImportError:DLL load failed找不到指定模块

      初步接触图像识别,通过pip下载了需要用到的包,tensorflow有CPU版和GPU版的,因为GPU版的需要配置cuda和cudnn,比较麻烦,所以先拿CPU版的开刀,但是在安装后进行测试时,出现了找不到指定模块的错误,我下载的是tensorflow2.2版本,网上给的教程有调低版本这一方法,如使用tensorflow1.15。但我down下来的测试用…

    2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部