本文主要是介绍tensorflow统计graph中的trainable_variables,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
最简单的做法: 转自: https://blog.csdn.net/feynman233/article/details/79187304, 版权归原作者所有。
print(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))
另有篇博客讲解的很详细:原文地址https://blog.csdn.net/shwan_ma/article/details/78879620,版权归原作者所有。
原博主写的很好,将常用的方法记载下来供以后学习参考。
sess.run(tf.global_varibales_initializer())
variable_name = [v.name for v in tf.trainable_variables()]
print(variable_names)
variable_names = [v.name for v in tf.trainable_variables()]
values = sess.run(variable_names)
for k,v in zip(variable_names, values):
print("Variable: ", k)
print("Shape: ", v.shape)
print(v)
for variable in tf.trainable_variables():
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
total_parameters += variable_parameters
这篇关于tensorflow统计graph中的trainable_variables的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!