三步理解--门控循环单元(GRU),TensorFlow实现
1. 什么是GRU
在循环神经⽹络中的梯度计算⽅法中,我们发现,当时间步数较⼤或者时间步较小时,循环神经⽹络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但⽆法解决梯度衰减的问题。通常由于这个原因,循环神经⽹络在实际中较难捕捉时间序列中时间步距离较⼤的依赖关系。
门控循环神经⽹络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较⼤的依赖关系。它通过可以学习的⻔来控制信息的流动。其中,门控循环单元(gatedrecurrent unit,GRU)是⼀种常⽤的门控循环神经⽹络。
2. ⻔控循环单元
2.1 重置门和更新门
GRU它引⼊了重置⻔(reset gate)和更新⻔(update gate)的概念,从而修改了循环神经⽹络中隐藏状态的计算⽅式。
门控循环单元中的重置⻔和更新⻔的输⼊均为当前时间步输⼊ \(X_t\) 与上⼀时间步隐藏状态\(H_{t-1}\),输出由激活函数为sigmoid函数的全连接层计算得到。 如下图所示:
具体来说,假设隐藏单元个数为 h,给定时间步 t 的小批量输⼊ \(X_t\in_{}\mathbb{R}^{n*d}\)(样本数为n,输⼊个数为d)和上⼀时间步隐藏状态 \(H_{t-1}\in_{}\mathbb{R}^{n*h}\)。重置⻔ \(H_t\in_{}\mathbb{R}^{n*h}\) 和更新⻔ \(Z_t\in_{}\mathbb{R}^{n*h}\) 的计算如下:
\[R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\]
\[Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\]
sigmoid函数可以将元素的值变换到0和1之间。因此,重置⻔ \(R_t\) 和更新⻔ \(Z_t\) 中每个元素的值域都是[0, 1]。
2.2 候选隐藏状态
接下来,⻔控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。我们将当前时间步重置⻔的输出与上⼀时间步隐藏状态做按元素乘法(符号为⊙)。如果重置⻔中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上⼀时间步的隐藏状态。如果元素值接近1,那么表⽰保留上⼀时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输⼊连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为[-1,1]。
具体来说,时间步 t 的候选隐藏状态 \(\tilde{H}\in_{}\mathbb{R}^{n*h}\) 的计算为:
\[\tilde{H}_t=tanh(X_tW_{xh}+(R_t⊙H_{t-1})W_{hh}+b_h)\]
从上⾯这个公式可以看出,重置⻔控制了上⼀时间步的隐藏状态如何流⼊当前时间步的候选隐藏状态。而上⼀时间步的隐藏状态可能包含了时间序列截⾄上⼀时间步的全部历史信息。因此,重置⻔可以⽤来丢弃与预测⽆关的历史信息。
2.3 隐藏状态
最后,时间步t的隐藏状态 \(H_t\in_{}\mathbb{R}^{n*h}\) 的计算使⽤当前时间步的更新⻔\(Z_t\)来对上⼀时间步的隐藏状态 \(H_{t-1}\) 和当前时间步的候选隐藏状态 \(\tilde{H}_t\) 做组合:
值得注意的是,更新⻔可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如上图所⽰。假设更新⻔在时间步 \(t^{′}到t(t^{′}<t)\) 之间⼀直近似1。那么,在时间步 \(t^{′}到t\) 间的输⼊信息⼏乎没有流⼊时间步 t 的隐藏状态\(H_t\)实际上,这可以看作是较早时刻的隐藏状态 \(H_{t^{′}-1}\) 直通过时间保存并传递⾄当前时间步 t。这个设计可以应对循环神经⽹络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较⼤的依赖关系。
我们对⻔控循环单元的设计稍作总结:
- 重置⻔有助于捕捉时间序列⾥短期的依赖关系;
- 更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。
3. 代码实现GRU
4. 参考文献
作者:@mantchs
GitHub:https://github.com/NLP-LOVE/ML-NLP
欢迎大家加入讨论!共同完善此项目!群号:【541954936】
最新文章
- AC 设置DMZ口上网
- struts2 CVE-2013-2251 S2-016 action、redirect code injection remote command execution
- Java提高篇——equals()方法和“==”运算符
- JavaScript input file上传前获取文件名、文件类型、文件大小等信息
- IDisposable接口
- 如何搭建Struts2环境
- 【BZOJ】【2005】【NOI2010】能量采集
- Uva 315 Network 判断割点
- Java基础知识强化之集合框架笔记02:集合的继承体系图解
- 【剑指offer】面试题37:两个链表的第一个公共结点
- SQl 判断 表 视图 临时表等 是否存在
- HDU 4391 Paint The Wall 段树(水
- JVM基础01-内存分配
- ubuntu下svn的命令使用
- Jmeter性能测试之进阶BeanShell的使用
- 在IDEA中配置Spring的XML装配
- node.js中express框架的基本使用
- Django REST framework基础:版本、认证、权限、限制
- Android -- taskAffinity
- linux的浅谈io操作
热门文章
- C语言学习推荐《C语言参考手册(原书第5版)》下载
- Adobe全系软件下载安装工具 CCMaker 1.3.6
- 基于SpringBoot的Web API快速开发基础框架
- MongoDB基础教程[菜鸟教程整理]
- Learning the Depths of Moving People by Watching Frozen
- ctrl shift o失效
- 【原】深度学习的一些经验总结和建议 | To do v.s Not To Do
- C#2.0新增功能07 getter/setter 单独可访问性
- SQLServer 问题(一)
- IO-Java实现文件的复制