Question

I ran into this problem while training a CNN on the MNIST dataset with TensorFlow. The relevant code is the following:

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

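With a learning rate of \(1e-4\) training runs without problems, but when I raised the learning rate to \(0.01\) or \(0.5\), it quickly failed with this error: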

tensorflow.python.framework.errors.InvalidArgumentError: ReluGrad input is not finite. : Tensor had NaN values

Analysis

Learning Rate

So I added a few lines of code to print the values of y_conv and cross_entropy during training.

y_conv=tf.Print(y_conv,[y_conv],"y_conv: ")
cross_entropy =tf.Print(cross_entropy,[cross_entropy],"cross_entropy: ")

With learning rate \(=0.01\), the program fails:

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.0374929e-06 0.0059775524 0.980205...]
step 0, training accuracy 0.04
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [9.2028862e-10 1.4812358e-05 0.044873074...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [648.49146]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.024463326 1.4828938e-31 0...]
step 1, training accuracy 0.2
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [2.4634053e-11 3.3087209e-34 0...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [nan]
step 2, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [nan nan nan...]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7ff51d92a940 Compute status: Invalid argument: ReluGrad input is not finite. : Tensor had NaN values
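In the log above, y_conv already contains exact zeros just before cross_entropy turns into nan. One plausible reading is that the logarithm of a zero probability is -inf, and multiplying it by a zero label entry gives NaN. The following is a minimal sketch of that arithmetic in plain Python rather than TensorFlow; the helper `log_ieee` and the example vectors are hypothetical and only illustrate the formula `-sum(y_ * log(y_conv))` used in the post.

```python
import math

def log_ieee(p):
    """Natural log that returns -inf at 0, the IEEE-754 behaviour
    (Python's math.log raises ValueError at 0 instead). Assumes p >= 0."""
    if p == 0.0:
        return float("-inf")
    return math.log(p)

def cross_entropy(y_true, y_pred):
    """Same formula as in the post: -sum(y_true * log(y_pred))."""
    return -sum(t * log_ieee(p) for t, p in zip(y_true, y_pred))

y_true = [0.0, 0.0, 1.0]       # hypothetical one-hot label
healthy = [0.1, 0.2, 0.7]      # every probability is strictly positive
saturated = [0.5, 0.0, 0.5]    # one probability has underflowed to exactly 0

print(cross_entropy(y_true, healthy))    # finite loss
print(cross_entropy(y_true, saturated))  # 0 * -inf = nan poisons the sum
```

Note that the NaN appears even though the zero sits at a position where the label is 0, so a single underflowed softmax output is enough to break the whole sum.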

With learning rate \(=1e-4\), the program runs without error:

I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.00056920078 8.4922984e-09 0.00033719366...]
step 0, training accuracy 0.14
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [7.0613837e-10 9.28294e-09 0.00016230672...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [439.95135]
step 1, training accuracy 0.16
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.031509314 3.6221365e-05 0.015359053...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [3.7112056e-07 1.8543299e-09 8.9234991e-06...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [436.37653]
step 2, training accuracy 0.12
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [0.015578311 0.0026688741 0.44736364...]
I tensorflow/core/kernels/logging_ops.cc:64] y_conv: [6.0428465e-07 0.0001744287 0.026451336...]
I tensorflow/core/kernels/logging_ops.cc:64] cross_entropy: [385.33765]

From this we can see that an excessively large learning rate is one cause of the error.

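According to the Stanford CS 224D lecture notes, when NaN values appear while training a deep neural network, a likely cause is that the learning rate is too large: the gradient values grow too large and the gradients explode. The lecture notes define gradient explosion as follows: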

During experimentation, once the gradient value grows extremely large, it causes an overflow (i.e. NaN) which is easily detectable at runtime; this issue is called the Gradient Explosion Problem.
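To make the link between a large learning rate, overflow, and NaN concrete, here is a toy sketch. It is not the CNN from this post: it assumes plain gradient descent on the hypothetical loss f(w) = w**2, chosen only because its behaviour is easy to follow by hand.

```python
import math

def descend(learning_rate, steps=2000, w=1.0):
    """Plain gradient descent on the toy loss f(w) = w**2 (gradient 2*w)."""
    for _ in range(steps):
        grad = 2.0 * w
        w = w - learning_rate * grad
    return w

small = descend(0.01)  # each step multiplies w by 0.98, so w shrinks toward 0
large = descend(1.5)   # each step multiplies w by -2: overflow to inf, then inf - inf = nan

print(small, math.isnan(large))
```

With the small learning rate every update shrinks w, while with the large one the magnitude of w doubles on every step until the float overflows to infinity; the next update then computes inf - inf, which is NaN, and every later value stays NaN.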

Solutions

  1. Decrease the learning rate appropriately.
  2. Add gradient clipping. Gradient clipping was first proposed by Thomas Mikolov: whenever the gradient norm reaches a certain threshold, the gradient is scaled back down to a smaller value. The Stanford CS 224D lecture notes describe it as follows:

To solve the problem of exploding gradients, Thomas Mikolov first introduced a simple heuristic solution that clips gradients to a small number whenever they explode. That is, whenever they reach a certain threshold, they are set back to a small number as shown in Algorithm 1.

Algorithm 1:

\(\frac{\partial E}{\partial W}\to g\)

if \(\Vert g\Vert \ge threshold\) then

\(\frac {threshold}{\Vert g\Vert} g\to g\)

end if
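Algorithm 1 can be sketched in a few lines of plain Python. This is an illustration under assumptions, not the code used in this post: the function name `clip_by_norm`, the list-of-floats gradient, and the example values are all hypothetical, and the threshold is assumed to be positive.

```python
import math

def clip_by_norm(grad, threshold):
    """Rescale grad so that its L2 norm is at most threshold (assumes threshold > 0)."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm >= threshold:
        scale = threshold / norm            # threshold / ||g||
        return [scale * g for g in grad]    # same direction, norm equal to threshold
    return list(grad)                       # small gradients pass through unchanged

exploding = [3.0, 4.0]                      # norm 5, above the threshold
clipped = clip_by_norm(exploding, 1.0)
print(clipped, math.hypot(*clipped))
```

The clipped gradient keeps its direction and only loses magnitude. In this sketch a gradient that is already NaN fails the comparison and is returned unchanged, so clipping helps only if it is applied before the values overflow. TensorFlow offers `tf.clip_by_norm` for the same per-tensor rescaling.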
