CS231n 2016 通关 第五、六章 Batch Normalization 作业
2024-10-08 02:46:07
BN层在实际中应用广泛。
上一次总结了使得训练变得简单的方法,比如SGD+momentum RMSProp Adam,BN是另外的方法。
cell 1 依旧是初始化设置
cell 2 读取cifar-10数据
cell 3 BN的前传
# Check the training-time forward pass by checking means and variances
# of features both before and after batch normalization # Simulate the forward pass for a two-layer network
N, D1, D2, D3 = 200, 50, 60, 3
X = np.random.randn(N, D1)
W1 = np.random.randn(D1, D2)
W2 = np.random.randn(D2, D3)
a = np.maximum(0, X.dot(W1)).dot(W2) print 'Before batch normalization:'
print ' means: ', a.mean(axis=0)
print ' stds: ', a.std(axis=0) # Means should be close to zero and stds close to one
print 'After batch normalization (gamma=1, beta=0)'
a_norm, _ = batchnorm_forward(a, np.ones(D3), np.zeros(D3), {'mode': 'train'})
print ' mean: ', a_norm.mean(axis=0)
print ' std: ', a_norm.std(axis=0) # Now means should be close to beta and stds close to gamma
gamma = np.asarray([1.0, 2.0, 3.0])
beta = np.asarray([11.0, 12.0, 13.0])
a_norm, _ = batchnorm_forward(a, gamma, beta, {'mode': 'train'})
print 'After batch normalization (nontrivial gamma, beta)'
print ' means: ', a_norm.mean(axis=0)
print ' stds: ', a_norm.std(axis=0)
相应的核心代码:
buf_mean = np.mean(x, axis=0)
buf_var = np.var(x, axis=0)
x_hat = x - buf_mean
x_hat = x_hat / (np.sqrt(buf_var + eps)) out = gamma * x_hat + beta
#running_mean = momentum * running_mean + (1 - momentum) * sample_mean
#running_var = momentum * running_var + (1 - momentum) * sample_var
running_mean = momentum * running_mean + (1- momentum) * buf_mean
running_var = momentum * running_var + (1 - momentum) * buf_var
running_mean running_var 是在test时使用的,test时不再另外计算均值和方差。
test 时的前传核心代码:
x_hat = x - running_mean
x_hat = x_hat / (np.sqrt(running_var + eps))
out = gamma * x_hat + beta
cell 5 BN反向传播
通过反向传播,计算beta gamma等参数。
核心代码:
dx_hat = dout * cache['gamma']
dgamma = np.sum(dout * cache['x_hat'], axis=0)
dbeta = np.sum(dout, axis=0)
#x_hat = x - buf_mean
#x_hat = x_hat / (np.sqrt(buf_var + eps))
t1 = cache['x'] - cache['mean']
t2 = (-0.5)*((cache['var'] + cache['eps'])**(-1.5))
t1 = t1 * t2
d_var = np.sum(dx_hat * t1, axis=0) tmean1 = (-1)*((cache['var'] + cache['eps'])**(-0.5))
d_mean = np.sum(dx_hat * tmean1, axis=0) tmean1 = (-1)*tmean1
tx1 = dx_hat * tmean1
tx2 = d_mean * (1.0 / float(N))
tx3 = d_var * (2 * (cache['x'] - cache['mean']) / N)
dx = tx1 + tx2 + tx3
cell 9 BN与其他层结合
形成的结构: {affine - [batch norm] - relu - [dropout]} x (L - 1) - affine - softmax
原理依旧。
之后是对cell 9 的模型,对cifar-10数据训练。
值得注意的是:
使用BN后,正则项与dropout层的需求降低。可以使用较高的学习率加快模型收敛。
附:通关CS231n企鹅群:578975100 validation:DL-CS231n
最新文章
- C#:结构
- html table表头斜线
- 【BZOJ3732】 Network Kruskal+倍增lca
- Python3 字符串
- JDE报表开发笔记(数据选择及继承)
- C#.Net中的转义字符
- ZOJ 2745 01-K Code(DP)(转)
- scalajs_初体验
- bzoj:1656 [Usaco2006 Jan] The Grove 树木
- esb和eai的区别
- 初识 go 语言:数据类型
- Convert List<;Entity>; to Json String.
- Python-类的组合与重用
- 【Core】.NET Core 部署( Docker + CentOS)
- python自动化测试入门篇-jemter
- Hadoop HBase概念学习系列之META表和ROOT表(六)
- 第一Sprint阶段对各组提出的意见
- 试着用React写项目-利用react-router解决跳转路由等问题(三)
- linux内核分析 第八周读书笔记
- springmvc web.xml配置之 -- ContextLoaderListener
热门文章
- 江湖问题研究-- intent传递有没有限制大小,是多少?
- ffplay 播放m3u8 hls Failed to open segment of playlist 0
- wdcp新开站点或绑定域名打不开或无法访问的问题
- 简单的ftp服务器
- 记使用WaitGroup时的一个错误
- Elasticsearch + Logstash + Kibana 搭建教程
- 九度OJ 1090:路径打印 (树、DFS)
- splittability A SequenceFile can be split by Hadoop and distributed across map jobs whereas a GZIP file cannot be.
- 对ShortCut和TWMKey的研究
- Learning Scrapy 中文版翻译 第二章