在计算loss的时候,最常见的一句话就是 tf.nn.softmax_cross_entropy_with_logits ,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

除去name参数用以指定该操作的name,与方法有关的一共两个参数:

第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes

第二个参数labels:实际的标签,大小同上

具体的执行流程大概分为两步:

第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)

第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:

\[H_{y'}(y)=-\sum_i{y'_ilog(y_i)}
\]

其中\(y'_i\)指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)

\(y_i\)就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值

显而易见,预测越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

理论讲完了,上代码

import tensorflow as tf

#our NN's output
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y=tf.nn.softmax(logits)
#true label
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!! with tf.Session() as sess:
softmax=sess.run(y)
c_e = sess.run(cross_entropy)
c_e2 = sess.run(cross_entropy2)
print("step1:softmax result=")
print(softmax)
print("step2:cross_entropy result=")
print(c_e)
print("Function(softmax_cross_entropy_with_logits) result=")
print(c_e2)

输出结果是:

step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228

最后大家可以试试e1/(e1+e2+e3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的

原文链接:【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法

MARSGGBO♥原创







2018-7-30

最新文章

  1. 基于jQuery的email suggest插件
  2. AngularJS开发指南3:Angular主要组成部分以及如何协同工作
  3. Objective之ARC
  4. C# "error CS1729: 'XXClass' does not contain a constructor that takes 0 arguments"的解决方案
  5. rJava配置
  6. BloomFilter——读数学之美札记
  7. 逆波兰表达式 java
  8. 基于表单的身份验证(FBA)
  9. javascript7
  10. hdu1356&hdu1944 博弈论的SG值(王道)
  11. [AI开发]Python+Tensorflow打造自己的计算机视觉API服务
  12. 浅谈 CSS 预处理器: 为什么要使用预处理器?
  13. c++中为什么可以通过指针或引用实现多态,而不可以通过对象呢?
  14. java 值传递 数组传递
  15. SQL类型转换和数学函数
  16. JMS规范概览
  17. saltstack常用模块
  18. python网络编程之线程
  19. 开源 SHOPNC B2B2C结算营运版 wap IM客服 API 手机app 短信通知
  20. 串行 RapidIO

热门文章

  1. Code First 重复外键(简单方法)
  2. Leading and Trailing LightOJ - 1282 (取数的前三位和后三位)
  3. Domino 邮箱服务器接收不存在的邮箱账号的邮件
  4. 【转】cJSON 源码阅读笔记
  5. Linux中配置Aria2 RPC Server
  6. 【loj3056】【hnoi2019】多边形
  7. 51nod1237 最大公约数之和 V3
  8. 【codevs2189】数字三角形+
  9. Flash:使用FileReference上传在Firefox上遇到的问题终于解决了
  10. HDU3613 Manacher//EXKMP//KMP