本文主要是介绍关于tensorflow中的softmax_cross_entropy_with_logits_v2函数的区别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
tf.nn.softmax_cross_entropy_with_logits(记为f1) 和
tf.nn.sparse_softmax_cross_entropy_with_logits(记为f3),以及
tf.nn.softmax_cross_entropy_with_logits_v2(记为f2)
之间的区别。
f1和f3对于参数logits的要求都是一样的,即未经处理的,直接由神经网络输出的数值, 比如 [3.5,2.1,7.89,4.4]
。两个函数不一样的地方在于labels格式的要求,f1的要求labels的格式和logits类似,比如[0,0,1,0]
。而f3的要求labels是一个数值,这个数值记录着ground truth所在的索引。以[0,0,1,0]
为例,这里真值1的索引为2。所以f3要求labels的输入为数字2(tensor)。一般可以用tf.argmax()
来从[0,0,1,0]
中取得真值的索引。
f1和f2之间很像,实际上官方文档已经标记出f1已经是deprecated 状态,推荐使用f2。两者唯一的区别在于f1在进行反向传播的时候,只对logits进行反向传播,labels保持不变。而f2在进行反向传播的时候,同时对logits和labels都进行反向传播,如果将labels传入的tensor设置为stop_gradients,就和f1一样了。
那么问题来了,一般我们在进行监督学习的时候,labels都是标记好的真值,什么时候会需要改变label?f2存在的意义是什么?实际上在应用中labels并不一定都是人工手动标注的,有的时候还可能是神经网络生成的,一个实际的例子就是对抗生成网络(GAN)。
测试用代码:
import tensorflow as tf
import numpy as npTruth = np.array([0,0,1,0])
Pred_logits = np.array([3.5,2.1,7.89,4.4])loss = tf.nn.softmax_cross_entropy_with_logits(labels=Truth,logits=Pred_logits)
loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Truth,logits=Pred_logits)
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(Truth),logits=Pred_logits)with tf.Session() as sess:print(sess.run(loss))print(sess.run(loss2))print(sess.run(loss3))
参考:
https://www.tensorflow.org/api_docs/
https://stats.stackexchange.com/questions/327348/how-is-softmax-cross-entropy-with-logits-different-from-softmax-cross-entropy-wi
这篇关于关于tensorflow中的softmax_cross_entropy_with_logits_v2函数的区别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!