详解TensorFlow的 tf.Session 函数:创建一个会话

yizhihongxing

概述

在TensorFlow中,tf.Session()函数用于执行图中的操作。单个图可以拥有多个会话,但是会话不共享状态,由此可以更好地控制实现的方案。会话将操作运行在设备上,并执行同步和异步计算。对于CPU、GPU或TPU等不同类型的设备可以使用不同的会话。

基本语法

在使用tf.Session()函数前,需要先构建一个表示计算的数据流图。使用tf.Session()函数,创建一个会话实例sess,变量初始化的操作运行在sess中,并返回目标tensor。构建数据流图后,只要在会话中运行相关的操作,就可以得到执行结果。

import tensorflow as tf

a = tf.constant(2)
b = tf.constant(3)
c = a + b

with tf.Session() as sess:
    print(sess.run(c))

这里我们定义了两个常数ab,定义了一个加法操作c=a+b。之后创建新的会话,使用sess.run()函数来获取最终结果,注意到sess.run()函数的参数是一个tensor,代表要获取该tensor的值。

会话模式

当我们创建会话时,可以传入不同的会话类型参数,不同参数的使用场景如下:

InteractiveSession

InteractiveSession和普通Session不同的是,它会把生成的会话注册为默认会话,以后的操作不需要指定使用的会话了。

import tensorflow as tf

sess = tf.InteractiveSession()
a = tf.constant(1)
b = tf.constant(2)
c = a + b
print(c.eval())  # 要使用InteractiveSession,就得调用eval()函数
sess.close()

Session

Session()是最常用的会话类型,由with来控制上下文,with内是计算图(Graph)的构建代码,with外是数据的获取代码。在with内调用sess.run()即可计算计算图的节点tensor的值。

import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = a + b

with tf.Session() as sess:
    print(sess.run(c, feed_dict={a: 3, b: 4.5}))

这里我们定义了两个占位符,使用feed_dict字典,将实际的数值传入计算图中。

Session config

使用Session config可以进行配置,如以下设置:

config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
with tf.Session(config=config) as sess:
    ...
  • allow_and_set_gpu_memory_fraction():允许gpu显存分配控制;
  • allow_and_set_gpu_memory_growth(): 允许随着计算进行自动申请显存,不是一次性申请完所有的显存;
  • log_device_placement=True: 可以打印各个张量是分布在哪个设备上的。

实例演示

向量加法

import tensorflow as tf

# 创建计算图
a = tf.constant([1, 2], name='a')
b = tf.constant([3, 4], name='b')
result = tf.add(a, b, name='result')
# 创建会话
with tf.Session() as sess:
    print(sess.run(result))

矩阵乘法

import tensorflow as tf

# 通过占位符传入数据,占位符前面使用dtype指定数据类型
x = tf.placeholder(tf.float32, shape=[None, 2])
y = tf.placeholder(tf.float32, shape=[None, 1])
# 定义变量
w = tf.Variable(tf.random_normal([2, 1], mean=0, stddev=1), name='w')
b = tf.Variable(tf.zeros([1,]), name='b')
# 计算图
result = tf.matmul(x, w) + b
# 创建会话
with tf.Session() as sess:
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    # 设置数据并进行计算
    data_x = [[1, 2], [3, 4]]
    data_y = [[5], [6]]
    feed = {x: data_x, y: data_y}
    print(sess.run(result, feed_dict=feed))

结论

本文详细介绍了TensorFlow的tf.Session()函数的作用和使用方法。只要掌握数据流图的构建和Session的使用即可运行数据操作和获取结果。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:详解TensorFlow的 tf.Session 函数:创建一个会话 - Python技术站

(0)
上一篇 2023年3月23日
下一篇 2023年3月30日

相关文章

合作推广
合作推广
分享本页
返回顶部