BN (batch normalization) layers are widely used in practice.

Last time I summarized methods that make training easier, such as SGD+momentum, RMSProp, and Adam; BN is another such method.

cell 1 is the usual initialization setup

cell 2 loads the CIFAR-10 data

cell 3: the BN forward pass

# Check the training-time forward pass by checking means and variances
# of features both before and after batch normalization

# Simulate the forward pass for a two-layer network
N, D1, D2, D3 = 200, 50, 60, 3
X = np.random.randn(N, D1)
W1 = np.random.randn(D1, D2)
W2 = np.random.randn(D2, D3)
a = np.maximum(0, X.dot(W1)).dot(W2)

print('Before batch normalization:')
print('  means: ', a.mean(axis=0))
print('  stds: ', a.std(axis=0))

# Means should be close to zero and stds close to one
a_norm, _ = batchnorm_forward(a, np.ones(D3), np.zeros(D3), {'mode': 'train'})
print('After batch normalization (gamma=1, beta=0)')
print('  mean: ', a_norm.mean(axis=0))
print('  std: ', a_norm.std(axis=0))

# Now means should be close to beta and stds close to gamma
gamma = np.asarray([1.0, 2.0, 3.0])
beta = np.asarray([11.0, 12.0, 13.0])
a_norm, _ = batchnorm_forward(a, gamma, beta, {'mode': 'train'})
print('After batch normalization (nontrivial gamma, beta)')
print('  means: ', a_norm.mean(axis=0))
print('  stds: ', a_norm.std(axis=0))

  The corresponding core code (train mode):

buf_mean = np.mean(x, axis=0)
buf_var = np.var(x, axis=0)
x_hat = (x - buf_mean) / np.sqrt(buf_var + eps)
out = gamma * x_hat + beta
# keep exponential moving averages of the batch statistics for test time
running_mean = momentum * running_mean + (1 - momentum) * buf_mean
running_var = momentum * running_var + (1 - momentum) * buf_var

  running_mean and running_var are used at test time, when the mean and variance are no longer recomputed from the batch.
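Putting the train-mode and test-mode snippets together, a minimal self-contained sketch of batchnorm_forward might look like the following. The cache keys ('x', 'x_hat', 'mean', 'var', 'gamma', 'eps') mirror the ones the backward code uses; the eps and momentum defaults are my assumptions, not values fixed by the assignment.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    """Sketch of a BN forward pass handling both 'train' and 'test' modes."""
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)          # assumed default
    momentum = bn_param.get('momentum', 0.9)  # assumed default
    D = x.shape[1]
    running_mean = bn_param.get('running_mean', np.zeros(D))
    running_var = bn_param.get('running_var', np.ones(D))

    if mode == 'train':
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        x_hat = (x - mean) / np.sqrt(var + eps)
        # update the moving averages used at test time
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
        cache = {'x': x, 'x_hat': x_hat, 'mean': mean, 'var': var,
                 'gamma': gamma, 'eps': eps}
    else:
        # test mode: reuse the statistics accumulated during training
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        cache = None

    out = gamma * x_hat + beta
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var
    return out, cache
```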

  Core code for the test-time forward pass:

 x_hat = x - running_mean
x_hat = x_hat / (np.sqrt(running_var + eps))
out = gamma * x_hat + beta

cell 5: the BN backward pass

  Backpropagation yields the gradients with respect to gamma, beta, and the input x.

  Core code:

dx_hat = dout * cache['gamma']
dgamma = np.sum(dout * cache['x_hat'], axis=0)
dbeta = np.sum(dout, axis=0)
# recall the forward pass: x_hat = (x - mean) / sqrt(var + eps); N = x.shape[0]
t1 = cache['x'] - cache['mean']
t2 = (-0.5) * ((cache['var'] + cache['eps']) ** (-1.5))
d_var = np.sum(dx_hat * t1 * t2, axis=0)
tmean1 = (-1) * ((cache['var'] + cache['eps']) ** (-0.5))
d_mean = np.sum(dx_hat * tmean1, axis=0)
tx1 = dx_hat * (-tmean1)                              # path through x_hat directly
tx2 = d_mean * (1.0 / float(N))                       # path through the batch mean
tx3 = d_var * (2 * (cache['x'] - cache['mean']) / N)  # path through the variance
dx = tx1 + tx2 + tx3
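A quick way to gain confidence in this backward pass is a central-difference gradient check. The sketch below re-implements the same forward/backward math in self-contained form (the names bn_forward and bn_backward are mine, not the assignment's) and compares the analytic dx against a numerical gradient of a scalar loss.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta, (x, x_hat, mean, var, gamma, eps)

def bn_backward(dout, cache):
    # same three-path derivation as the snippet above
    x, x_hat, mean, var, gamma, eps = cache
    N = x.shape[0]
    dgamma = np.sum(dout * x_hat, axis=0)
    dbeta = np.sum(dout, axis=0)
    dx_hat = dout * gamma
    d_var = np.sum(dx_hat * (x - mean) * (-0.5) * (var + eps) ** (-1.5), axis=0)
    d_mean = np.sum(dx_hat * (-1) * (var + eps) ** (-0.5), axis=0)
    dx = dx_hat / np.sqrt(var + eps) + d_mean / N + d_var * 2 * (x - mean) / N
    return dx, dgamma, dbeta

np.random.seed(0)
x = np.random.randn(10, 4)
gamma, beta = np.random.randn(4), np.random.randn(4)
w = np.random.randn(10, 4)  # fixed weights: loss L = sum(out * w), so dL/dout = w

out, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(w, cache)

# numerical gradient via central differences
h = 1e-5
dx_num = np.zeros_like(x)
for i in range(x.size):
    xp = x.copy(); xp.flat[i] += h
    xm = x.copy(); xm.flat[i] -= h
    dx_num.flat[i] = (np.sum(bn_forward(xp, gamma, beta)[0] * w)
                      - np.sum(bn_forward(xm, gamma, beta)[0] * w)) / (2 * h)

print('max |dx - dx_num|:', np.max(np.abs(dx - dx_num)))
```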

cell 9: combining BN with other layers

  The resulting architecture:   {affine - [batch norm] - relu - [dropout]} x (L - 1) - affine - softmax

  The principle is the same as before.
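The affine - batch norm - relu sandwich can be wired up as a single helper. This is a hypothetical train-mode-only sketch of my own, without dropout and without the caches the real backward pass would need:

```python
import numpy as np

def affine_bn_relu_forward(x, w, b, gamma, beta, eps=1e-5):
    """Hypothetical helper chaining the three forward steps (train mode only)."""
    a = x.dot(w) + b                              # affine
    mean, var = a.mean(axis=0), a.var(axis=0)
    a_hat = (a - mean) / np.sqrt(var + eps)
    bn = gamma * a_hat + beta                     # batch norm
    out = np.maximum(0, bn)                       # ReLU
    return out
```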

After that, the cell-9 model is trained on the CIFAR-10 data.

Worth noting:

  With BN, the need for regularization and dropout layers drops, and a higher learning rate can be used to speed up convergence.

