基本使用

  • 使用图(graph)来表示计算任务
  • 激活会话(Session)执行图
  • 使用张量(tensor)表示数据
  • 定义变量(Variable)
  • 使用feed可以任意赋值或者从中获取数据,通常与占位符一起使用

1、综述

  Tensorflow是一个开源框架,使用图来表示计算任务,图中的节点被称作op(operation),一个op获得0个或者多个Tensor,执行计算,产生0个或者多个Tensor,每个tensor是一个类型的多维数组。例如,你可以将一小组图像集表示为一个四维浮点数数字,这四个维度分别是[batch,height,width,channels]。一个TensorFlow图描述了计算的过程,为了进行计算,图必须在会话里启动,会话将图的op分发到诸如CPU或者GPU的设备上,同时提供执行op的方法,这些方法执行后,将产生的tensor返回,在python语言中,返回的tensor是numpy array对象,在C或者C++语言中,返回的tensor是tensorflow:Tensor实例。

2、构建图

import tensorflow as tf
c = tf.constant(0.0)  # 自动默认图建立
g = tf.Graph() # 手动建立图
with g.as_default():
    c1 = tf.constant(0.0)
    print(c1.graph)
    print(g)
    print(c.graph)
g2 = tf.get_default_graph()
print(g2)

tf.reset_default_graph()
g3 = tf.get_default_graph()
print(g3)

 第四节:tensorflow图的基本操作

c是在刚开始的默认图中建立的,所以图的打印值就是原始的默认图的打印值。

然后使用tf.Graph函数建立一个图,并在新建的图里添加变量,可以通过变量的.graph获得所在图。

在新图的作用域外,使用tf.get_default_graph函数又获得了原始的默认值。

使用tf.reset_default_graph函数重新建了一张图代替原来的默认图

3、获取张量tensor

print(c1.name)
t = g.get_tensor_by_name(name = "Const:0")
print(t)

4、变量(Variables):启动图后,必须先初始化

#创建一个变量,初始化为标量0
state = tf.Variable(0,name = 'counter')

#创建一个op,其作用是使state增1
one = tf.constant(1)
new_value = tf.add(state,one)

'''
通常会将一个统计模型中的参数表示为一组变量. 例如, 你可以将一个神经网络的权重作为某个变量存储在一个 tensor 中. 在训练过程中, 通过重复运行训练图, 更新这个 tensor.
'''
update = tf.assign(state,new_value)

#启动图后,变量必须先经过'初始化' op 
#首先必须增加一个 '初始化' op 到图中
init_op = tf.global_variables_initializer()

#启动图,运行op
with tf.Session() as sess:
    sess.run(init_op)
    #打印state初始值
    print(sess.run(state))             #0
    #运行op,更新state,并打印
    for _ in range(3):
        sess.run(update)
        print(sess.run(state))  # 1 2 3

 

5、Feed机制:可以使用一个tensor临时替换一个操作的输出结果,提供feed数据作为run()调用的参数,feed只在调用它的方法内有效,方法结束,feed就会消失,常与占位符tf.placeholder()一起使用。

input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
output = tf.multiply(input1, input2)

#使用7替代input1,2替代input2,feed操作相当于设置一个占位符
with tf.Session() as sess:
  print(sess.run([output], feed_dict={input1:[7.], input2:[2.]}))    #[array([ 14.], dtype=float32)]