官网默认定义如下:
one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
该函数的功能主要是转换成one_hot类型的张量输出。
参数功能如下:
1)indices中的元素指示on_value的位置,不指示的地方都为off_value。indices可以是向量、矩阵。
2)depth表示输出张量的尺寸,indices中元素默认不超过(depth-1),如果超过,输出为[0,0,···,0]
3)on_value默认为1
4)off_value默认为0
5)dtype默认为tf.float32
下面用几个例子说明一下:
1. indices是向量
1 import tensorflow as tf 2 3 indices = [0,2,3,5] 4 depth1 = 6 # indices没有元素超过(depth-1) 5 depth2 = 4 # indices有元素超过(depth-1) 6 a = tf.one_hot(indices,depth1) 7 b = tf.one_hot(indices,depth2) 8 9 with tf.Session() as sess: 10 print('a = \n',sess.run(a)) 11 print('b = \n',sess.run(b))
运行结果:
# 输入是一维的,则输出是一个二维的
a = [[1. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0.] [0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1.]] # shape=(4,6) b = [[1. 0. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.] [0. 0. 0. 0.]] # shape=(4,4)
2. indices是矩阵
1 import tensorflow as tf 2 3 indices = [[2,3],[1,4]] 4 depth1 = 9 # indices没有元素超过(depth-1) 5 depth2 = 4 # indices有元素超过(depth-1) 6 a = tf.one_hot(indices,depth1) 7 b = tf.one_hot(indices,depth2) 8 9 with tf.Session() as sess: 10 print('a = \n',sess.run(a)) 11 print('b = \n',sess.run(b))
运行结果:
# 输入是二维的,则输出是三维的
a = [[[0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0.]] [[0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0.]]] # shape=(2,2,9) b = [[[0. 0. 1. 0.] [0. 0. 0. 1.]] [[0. 1. 0. 0.] [0. 0. 0. 0.]]] # shape=(2,2,4)
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Tensorflow中one_hot() 函数用法 - Python技术站