首先,expand_dims()
函数是 TensorFlow 中用于增加张量维度的函数,可传入三个参数:
input
: 要增加维度的张量axis
: 新维度所在的位置,取值范围为 $[-(R+1), R]$,其中 R 为原张量的秩,当axis
为负数时表示新维度在倒数第 $|axis|$ 个位置(比如-1
表示最后一个位置)name
: 可选参数,表示操作的名称
以下是使用示例并附有详细解释:
示例一
import tensorflow as tf
# 定义一个张量
x = tf.constant([
[1, 2],
[3, 4]
])
# 增加维度
y = tf.expand_dims(x, axis=0)
print(y)
输出结果为:
Tensor("ExpandDims:0", shape=(1, 2, 2), dtype=int32)
解释如下:
- 此处将形状为
(2, 2)
的张量x
在第 0 个位置增加了一个维度,故输出张量的形状发生了变化,变为(1, 2, 2)
。 - 可以看到,
expand_dims()
函数返回的是一个张量,而不是具体的数值。 - 输出的 tensor 对象名称为
ExpandDims:0
,这是 TensorFlow 在图中自动为该节点命名的。 - 输出 tensor 对象的数据类型为
int32
。
示例二
import tensorflow as tf
# 定义一个张量
x = tf.constant([
[[1, 2], [3, 4]],
[[5, 6], [7, 8]]
])
# 增加维度,在最后一个位置增加一个维度
y = tf.expand_dims(x, axis=-1)
print(y)
输出结果为:
Tensor("ExpandDims_1:0", shape=(2, 2, 2, 1), dtype=int32)
解释如下:
- 此处将 shape 为
(2, 2, 2)
的张量在最后一个位置增加了一个维度,输出张量的形状发生了变化,变为(2, 2, 2, 1)
。 axis=-1
表示最后一个位置,因此新维度被增加到最后。- 输出 tensor 对象名称为
ExpandDims_1:0
,是新声明的节点名称。 - 输出 tensor 对象的数据类型为
int32
。
以上是 expand_dims()
函数的使用方法和示例。需要注意的是,增加维度后的张量形状应该与实际需要的计算一致。进一步地,当张量的秩大于等于3时,需要在传入 axis
参数时确保符号相同。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFlow用expand_dim()来增加维度的方法 - Python技术站