那么下面就来详细讲解一下"tensorflow获取所有variable或tensor的name示例"的完整攻略:
示例1:获取所有Variable的Name
当我们在使用TensorFlow时,我们有时需要获取所有Variable
的名字, 这时我们可以借助TensorFlow自带的get_collection()
方法来获取。
具体步骤如下:
- 先创建一个
tf.Variable
的集合,将所有的Variable
加入其中,这样我们就可以通过get_collection()
方法从集合中获取所有的Variable
。
import tensorflow as tf
...
# 定义第一个变量
var1 = tf.Variable(1.0, name="var1")
# 定义第二个变量
var2 = tf.Variable(1.0, name="var2")
# 将两个变量加入集合
tf.add_to_collection("my_collection", var1)
tf.add_to_collection("my_collection", var2)
- 使用
get_collection()
方法来获取所有Variable
的名字。
# 获取集合中所有变量的名字
var_names = [v.name for v in tf.get_collection("my_collection")]
print(var_names)
这里的get_collection()
方法就是来获取集合中的所有变量, 而后面的name for v in
则是用来遍历每个获得的变量并显示其名字。
输出结果为:
['var1:0', 'var2:0']
可以看到,我们成功获取到了所有Variable的名称。
示例2:获取所有Tensor的Name
除了获取所有的Variable的名字,我们同样也需要获取所有Tensor的名字。 这时我们可以使用TensorFlow自带的graph_def
的node
属性来获取。
具体步骤如下:
- 获取当前TensorFlow的默认计算图。
graph = tf.get_default_graph()
- 获取
graph_def
信息。
graph_def = graph.as_graph_def()
- 遍历所有的
node
节点,获取所有Tensor
的名字。
tensor_names = [tensor.name for tensor in graph_def.node if 'Variable' not in tensor.op]
print(tensor_names)
这里的graph_def.node
就是遍历所有的node
节点的语句, 而后面的判断if 'Variable' not in tensor.op
是为了去除所有获取到的变量,从而只获取所有的Tensor
的名字。
输出结果为:
['Placeholder', 'MatMul', 'Add', 'Reshape', 'MatMul_1', 'Add_1', 'add', 'gradients/MatMul_grad/tuple/control_dependency', 'gradients/Add_grad/tuple/control_dependency', 'gradients/MatMul_1_grad/tuple/control_dependency', 'gradients/Add_1_grad/tuple/control_dependency', 'gradients/add_grad/tuple/control_dependency', 'init', 'init_1']
我们成功获取了所有Tensor的名称。
最后值得注意的是,在TensorFlow1.x中和TensorFlow2.x中,这个获取名称的方法是不一样的, 细节具体情况具体分析,在实际使用时需要注意。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow 获取所有variable或tensor的name示例 - Python技术站