概述
在TensorFlow中,tf.Session()
函数用于执行图中的操作。单个图可以拥有多个会话,但是会话不共享状态,由此可以更好地控制实现的方案。会话将操作运行在设备上,并执行同步和异步计算。对于CPU、GPU或TPU等不同类型的设备可以使用不同的会话。
基本语法
在使用tf.Session()
函数前,需要先构建一个表示计算的数据流图。使用tf.Session()
函数,创建一个会话实例sess,变量初始化的操作运行在sess中,并返回目标tensor。构建数据流图后,只要在会话中运行相关的操作,就可以得到执行结果。
import tensorflow as tf
a = tf.constant(2)
b = tf.constant(3)
c = a + b
with tf.Session() as sess:
print(sess.run(c))
这里我们定义了两个常数a
和b
,定义了一个加法操作c=a+b
。之后创建新的会话,使用sess.run()
函数来获取最终结果,注意到sess.run()
函数的参数是一个tensor,代表要获取该tensor的值。
会话模式
当我们创建会话时,可以传入不同的会话类型参数,不同参数的使用场景如下:
InteractiveSession
InteractiveSession和普通Session不同的是,它会把生成的会话注册为默认会话,以后的操作不需要指定使用的会话了。
import tensorflow as tf
sess = tf.InteractiveSession()
a = tf.constant(1)
b = tf.constant(2)
c = a + b
print(c.eval()) # 要使用InteractiveSession,就得调用eval()函数
sess.close()
Session
Session()
是最常用的会话类型,由with来控制上下文,with内是计算图(Graph)的构建代码,with外是数据的获取代码。在with内调用sess.run()
即可计算计算图的节点tensor的值。
import tensorflow as tf
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = a + b
with tf.Session() as sess:
print(sess.run(c, feed_dict={a: 3, b: 4.5}))
这里我们定义了两个占位符,使用feed_dict
字典,将实际的数值传入计算图中。
Session config
使用Session config可以进行配置,如以下设置:
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
with tf.Session(config=config) as sess:
...
allow_and_set_gpu_memory_fraction()
:允许gpu显存分配控制;allow_and_set_gpu_memory_growth()
: 允许随着计算进行自动申请显存,不是一次性申请完所有的显存;log_device_placement=True
: 可以打印各个张量是分布在哪个设备上的。
实例演示
向量加法
import tensorflow as tf
# 创建计算图
a = tf.constant([1, 2], name='a')
b = tf.constant([3, 4], name='b')
result = tf.add(a, b, name='result')
# 创建会话
with tf.Session() as sess:
print(sess.run(result))
矩阵乘法
import tensorflow as tf
# 通过占位符传入数据,占位符前面使用dtype指定数据类型
x = tf.placeholder(tf.float32, shape=[None, 2])
y = tf.placeholder(tf.float32, shape=[None, 1])
# 定义变量
w = tf.Variable(tf.random_normal([2, 1], mean=0, stddev=1), name='w')
b = tf.Variable(tf.zeros([1,]), name='b')
# 计算图
result = tf.matmul(x, w) + b
# 创建会话
with tf.Session() as sess:
# 初始化变量
sess.run(tf.global_variables_initializer())
# 设置数据并进行计算
data_x = [[1, 2], [3, 4]]
data_y = [[5], [6]]
feed = {x: data_x, y: data_y}
print(sess.run(result, feed_dict=feed))
结论
本文详细介绍了TensorFlow的tf.Session()
函数的作用和使用方法。只要掌握数据流图的构建和Session的使用即可运行数据操作和获取结果。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:详解TensorFlow的 tf.Session 函数:创建一个会话 - Python技术站