指数加权平均 (exponentially weighted averges)

先说一下指数加权平均, 公式如下:

\[v_{t}=\beta v_{t-1}+(1-\beta) \theta_{t}
\]
  • \(\theta_t\) 是第t天的观测值
  • \(v_t\) 是用来替代\(\theta_t\)的估计值,也就是加权平均值
  • \(\beta\) 超参数

设 \(\beta = 0.9\) , 那么公式可以化简为:

\[v_{100} = 0.1 * \theta_t + 0.1 * 0.9 * \theta_{99} + 0.1 * 0.9^{2} \theta_{98}+\ldots+0.1 * 0.9^{99} \theta_{1}
\]

它考虑到了之前所有观测值,但是事件越靠近的观测值权重越大,时间越久远的观测值权重就很小了。

在 \(\beta = 0.9\)时,很多资料认为\(0.9^{10} \approx 0.35 \approx 1 / e\), 把这个数当成一个分界点,权重降低到这个分界点之下就可以忽略不计,而 \(\beta^{\frac{1}{1-\beta}} \approx 1 / e\) , 所以把上面两个公式合到一起就可以认为指数加权平均就是最近 \(N=\frac{1}{1-\beta}\)天的加权平均值

所以

  • \(\beta\) 越小, 加权平均的数据越少,就容易出现震荡
  • \(\beta\) 越大, 加权平均考虑的数据就越多,当出现震荡的时候会由于历史数据的权重导致震荡的幅度减小

Batch Gradient Descent (BGD)

BGD使用整个数据集来计算梯度,这里的损失函数是所有输入的样本数据的loss的和,单个样本的loss可以用交叉熵或者均方误差来计算。

\[\theta=\theta-\eta \cdot \nabla_{\theta} J(\theta)
\]

缺点是每次更新数据都需要计算整个数据集,速度很慢,不能实时的投入数据更新模型。对于凸函数可以收敛到全局最小值,对于非凸函数只能收敛到局部最小值。这是最朴素的优化器了

Stochastic Gradient Descent(SGD)

由于BGD计算梯度太过费时,SGD每次只计算一个样本的loss,然后更新参数。计算时可以先打乱数据,然后一条一条的将数据输入到模型中

\[\theta=\theta-\eta \cdot \nabla_{\theta} J\left(\theta ; x^{(i)} ; y^{(i)}\right)
\]

他的缺点是更新比较频繁,会有严重的震荡。

当我们稍微减小learning rate, SGD和BGD的收敛性是一样的

Mini-Batch Gradient Descent (MBGD)

每次接收batch个样本,然后计算它们的loss的和。

\[\theta=\theta-\eta \cdot \nabla_{\theta} J\left(\theta ; x^{(i: i+n)} ; y^{(i: i+n)}\right)
\]

对于鞍点, BGD会在鞍点附近停止更新,而MSGD会在鞍点周围来回震荡。

Monentum SGD

加入了v的概念,起到一个类似惯性的作用。在更新梯度的时候会照顾到之前已有的梯度。这里的\(v_t\)就是梯度的加权平均

\[\begin{array}{l}
v_{t}=\gamma v_{t-1}+\eta \nabla_{\theta} J(\theta) \\
\theta=\theta-v_{t}
\end{array}
\]

它可以在梯度方向不变的维度上使速度变快,在梯度方向有所改变的维度上更新速度更慢,可以抵消某些维度的摆动,加快收敛并减小震荡。\(\gamma\)一般取值为0.9

Nesterov Accelerated Gradient

它用 \(\theta-\gamma v_{t-1}\)来近似估计下一步 \(\theta\)会到达的位置

\[\begin{array}{l}
v_{t}=\gamma v_{t-1}+\eta \nabla_{\theta} J\left(\theta-\gamma v_{t-1}\right) \\
\theta=\theta-v_{t}
\end{array}
\]

能够让算法提前看到前方的地形梯度,如果前面的梯度比当前位置的梯度大,那我就可以把步子迈得比原来大一些,如果前面的梯度比现在的梯度小,那我就可以把步子迈得小一些

这个算法的公式竟然可以转化为下面的等价的公式:

\[\begin{array}{l}
d_{i}=\beta d_{i-1}+g\left(\theta_{i-1}\right)+\beta\left[g\left(\theta_{i-1}\right)-g\left(\theta_{i-2}\right)\right] \\
\theta_{i}=\theta_{i-1}-\alpha d_{i}
\end{array}
\]

后面的梯度相减可以认为是梯度的导数,也就是loss的二阶导数。也就是用二阶导数判断了一下曲线的趋势。其中 \(\gamma\)一般取值为0.9

Adagrad (Adaptive gradient algorithm)

可以对低频的参数做较大的更新,对高频的参数做较小的更新。

\[\theta_{t+1, i}=\theta_{t, i}-\frac{\eta}{\sqrt{G_{t, i i}+\epsilon}} \cdot g_{t, i}
\]

这个算法很有意思,G是在某个维度上,t从0开始到现在的所有梯度的平方和。所以对于经常更新的参数,学习率会越来越小,而对于不怎么更新的参数,他的学习率会变得相对更高。

\(\theta\)一般设置为0.01,他的缺点是分母会不断累计,最终学习率会变得非常小。如果初始梯度很大,会导致学习率变得很小。它适合用于稀疏数据。

Adadelta

对Adagrad的改进,对某个维度的历史维度进行平方、相加、开方

\[E\left[g^{2}\right]_{t}=\rho * E\left[g^{2}\right]_{t-1}+(1-\rho) * g_{t}^{2}
\]
\[x_{t+1}=x_{t}-\frac{\eta}{\sqrt{E\left[g^{2}\right]_{t}+\epsilon}} * g_{t}
\]
\[R M S\left(g_{t}\right)=\sqrt{E\left[g^{2}\right]_{t}+\epsilon}
\]

解决了历史梯度一直累加导致的学习率下降问题, \(\epsilon\) 是为了方式分母为0加上的极小值, \(rho\)一般取值为0.9

Adaptive Moment Estimation (Adam)

同时考虑了梯度的平方和梯度的指数衰减。建议\(\beta_1\)=0.9, \(\beta_2\)=0.999, \(\eta\)=10e-8

\[m_{t}=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t}
\]
\[v_{t}=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2}
\]
\[\begin{array}{l}
\hat{m}{t}=\frac{m{t}}{1-\beta_{1}^{t}},
\hat{v}{t}=\frac{v{t}}{1-\beta_{2}^{t}}
\end{array}
\]
\[\theta_{t+1}=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t}
\]

Adam取得了比其他方法更好的效果

总结

如果数据是稀疏的,就用自适用方法,即 Adagrad, Adadelta, RMSprop, Adam。

参考资料:

https://www.cnblogs.com/guoyaohua/p/8542554.html

https://arxiv.org/pdf/1609.04747.pdf

最新文章

  1. struts-标签
  2. SQL常用语句整理
  3. 百度地图Api进阶教程-点击生成和拖动标注4.html
  4. 编写高质量代码改善C#程序的157个建议[用抛异常替代返回错误、不要在不恰当的场合下引发异常、重新引发异常时使用inner Exception]
  5. window.open() 被拦截后的分析
  6. Matlab位运算笔记
  7. lintcode:形状工厂
  8. LeetCode: Next Permutation & Permutations1,2
  9. 原生javascript难点总结(1)---面向对象分析以及带来的思考
  10. date日期比较和格式化方法
  11. 【JS】学习18天Jquery Moblie的总结笔记。
  12. java复习(3)---字符串、数组
  13. CodeForces 816B Karen and Coffee(前缀和,大量查询)
  14. Hive:子查询
  15. 【easy-】437. Path Sum III 二叉树任意起始区间和
  16. DOM-基本概念及使用
  17. 【blog】SpringBoot事务
  18. eclipse开发Java web工程时,jsp第一行报错,如何解决?
  19. android-------开发常用框架汇总
  20. 去除eclipse的validating

热门文章

  1. 【原创】JDK 9-17新功能30分钟详解-语法篇-var
  2. Bert不完全手册7. 为Bert注入知识的力量 Baidu-ERNIE & THU-ERNIE & KBert
  3. ORA-01950: no privileges on tablespace 'USERS'-- 解决办法
  4. MySQL源码分析之SQL函数执行
  5. 【mido】python的midi处理库
  6. Linux上安装jdk 1.8
  7. 《吐血整理》进阶系列教程-拿捏Fiddler抓包教程(15)-Fiddler弱网测试,知否知否,应是必知必会
  8. 牛客小白月赛51-C-E
  9. 运用Filebeat module分析nginx日志
  10. 8. 使用Fluentd+MongoDB采集Apache日志