Finally we come to the last big boss: convolutional neural networks~

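Here I still want to focus mainly on the code implementation; I plan to write up the CNN theory properly later. The key to the CNN code is adding the convolutional layer and the pooling layer.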

I. Convolutional layer

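  The forward pass of the convolutional layer is fairly easy. What we mainly care about is the backward pass, which the figure below makes clear: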

  

```python
import numpy as np


def conv_forward_naive(x, w, b, conv_param):
    """Naive conv forward. x: (N, C, H, W), w: (F, C, HH, WW), b: (F,)."""
    stride, pad = conv_param['stride'], conv_param['pad']
    N, C, H, W = x.shape
    F, C, HH, WW = w.shape
    # Zero-pad only the two spatial dimensions
    x_padded = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    H_new = 1 + (H + 2 * pad - HH) // stride
    W_new = 1 + (W + 2 * pad - WW) // stride
    s = stride
    out = np.zeros((N, F, H_new, W_new))
    for i in range(N):  # i-th image
        for f in range(F):  # f-th filter
            for j in range(H_new):
                for k in range(W_new):
                    # Elementwise product of window and filter, summed, plus bias
                    out[i, f, j, k] = np.sum(x_padded[i, :, j*s:HH+j*s, k*s:WW+k*s] * w[f]) + b[f]
    cache = (x, w, b, conv_param)
    return out, cache


def conv_backward_naive(dout, cache):
    x, w, b, conv_param = cache
    pad = conv_param['pad']
    stride = conv_param['stride']
    F, C, HH, WW = w.shape
    N, C, H, W = x.shape
    H_new = 1 + (H + 2 * pad - HH) // stride
    W_new = 1 + (W + 2 * pad - WW) // stride
    dx = np.zeros_like(x)
    dw = np.zeros_like(w)
    db = np.zeros_like(b)
    s = stride
    x_padded = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
    dx_padded = np.pad(dx, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
    for i in range(N):  # i-th image
        for f in range(F):  # f-th filter
            for j in range(H_new):
                for k in range(W_new):
                    window = x_padded[i, :, j*s:HH+j*s, k*s:WW+k*s]
                    db[f] += dout[i, f, j, k]
                    dw[f] += window * dout[i, f, j, k]
                    # The key is "+=": overlapping windows accumulate gradient
                    dx_padded[i, :, j*s:HH+j*s, k*s:WW+k*s] += w[f] * dout[i, f, j, k]
    # Unpad
    dx = dx_padded[:, :, pad:pad+H, pad:pad+W]
    return dx, dw, db
```
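Both conv functions compute the output size as H_new = 1 + (H + 2*pad - HH) / stride. A minimal sketch of that formula as a standalone helper; `conv_output_size` is a hypothetical name, not part of the assignment code, and the error check is an added assumption about how one might want invalid settings handled:

```python
def conv_output_size(size, filter_size, pad, stride):
    """Spatial output size of a conv (or pooling) layer along one dimension.

    Hypothetical helper: raises ValueError when the windows do not tile the
    padded input exactly, instead of silently flooring.
    """
    span = size + 2 * pad - filter_size
    if span < 0 or span % stride != 0:
        raise ValueError(
            f"invalid hyperparameters: size={size}, filter={filter_size}, "
            f"pad={pad}, stride={stride}")
    return 1 + span // stride


# 3x3 filter, pad 1, stride 1 keeps a 32-wide input at 32
print(conv_output_size(32, 3, 1, 1))  # 32
# 2x2 max pooling with stride 2 (pad 0) halves it
print(conv_output_size(32, 2, 0, 2))  # 16
```

Plain integer division silently floors when the windows do not tile the input; raising an error instead is just one possible design choice, not something the naive functions here do.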

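  The backward pass in conv_backward_naive matches the one described in http://www.cnblogs.com/tornadomeet/p/3468450.html: backpropagation through a convolutional layer is itself a proper convolution operation. For stride 1, the gradient with respect to the input is a full convolution of dout with the filter (equivalently, a cross-correlation of zero-padded dout with the filter rotated by 180 degrees), and the gradient with respect to the filter is a cross-correlation of the input with dout.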
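A small sketch that checks the "backward is a convolution" claim numerically, assuming the simplest case: one channel, one filter, stride 1, no padding. The helpers `corr2d` and `grads_loop` are hypothetical names for illustration; `grads_loop` uses the same `+=` scatter as conv_backward_naive, and the results are compared with the convolution form.

```python
import numpy as np


def corr2d(x, k):
    """Valid 2-D cross-correlation with stride 1 (what the forward pass computes)."""
    H, W = x.shape
    kh, kw = k.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out


def grads_loop(x, w, dout):
    """Reference gradients using the same += scatter as conv_backward_naive."""
    HH, WW = w.shape
    dx = np.zeros_like(x)
    dw = np.zeros_like(w)
    for i in range(dout.shape[0]):
        for j in range(dout.shape[1]):
            dx[i:i + HH, j:j + WW] += w * dout[i, j]
            dw += x[i:i + HH, j:j + WW] * dout[i, j]
    return dx, dw


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 6))
w = rng.standard_normal((3, 2))
dout = rng.standard_normal(corr2d(x, w).shape)

# dx: "full" convolution = pad dout by (filter size - 1), then
# cross-correlate with the filter rotated by 180 degrees
HH, WW = w.shape
dout_padded = np.pad(dout, ((HH - 1, HH - 1), (WW - 1, WW - 1)), mode='constant')
dx_conv = corr2d(dout_padded, np.rot90(w, 2))

# dw: cross-correlate the input with dout
dw_conv = corr2d(x, dout)

dx_ref, dw_ref = grads_loop(x, w, dout)
print(np.allclose(dx_conv, dx_ref), np.allclose(dw_conv, dw_ref))  # True True
```

With stride greater than 1 the same idea needs dout to be dilated (zeros inserted between its entries) first, which is one reason the naive loop version is simpler to write.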

II. Pooling layer

  

```python
def max_pool_forward_naive(x, pool_param):
    HH, WW = pool_param['pool_height'], pool_param['pool_width']
    s = pool_param['stride']
    N, C, H, W = x.shape
    H_new = 1 + (H - HH) // s
    W_new = 1 + (W - WW) // s
    out = np.zeros((N, C, H_new, W_new))
    for i in range(N):
        for j in range(C):
            for k in range(H_new):
                for l in range(W_new):
                    window = x[i, j, k*s:HH+k*s, l*s:WW+l*s]
                    out[i, j, k, l] = np.max(window)
    cache = (x, pool_param)
    return out, cache


def max_pool_backward_naive(dout, cache):
    x, pool_param = cache
    HH, WW = pool_param['pool_height'], pool_param['pool_width']
    s = pool_param['stride']
    N, C, H, W = x.shape
    H_new = 1 + (H - HH) // s
    W_new = 1 + (W - WW) // s
    dx = np.zeros_like(x)
    for i in range(N):
        for j in range(C):
            for k in range(H_new):
                for l in range(W_new):
                    window = x[i, j, k*s:HH+k*s, l*s:WW+l*s]
                    # Recompute the max from the forward pass; window == m marks its position
                    m = np.max(window)
                    # "+=" so overlapping windows (stride < pool size) accumulate
                    dx[i, j, k*s:HH+k*s, l*s:WW+l*s] += (window == m) * dout[i, j, k, l]
    return dx
```
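The comparison `window == m` routes the upstream gradient to every entry equal to the window's maximum. With random real-valued inputs ties are rare, but they do occur, for example in a window that is all zeros after a ReLU; then every tied entry receives the full gradient. A small sketch contrasting that with an argmax-based alternative (a swapped-in convention, not the code above; both helper names are hypothetical):

```python
import numpy as np


def window_grad_mask(window, g):
    """Mask routing as in max_pool_backward_naive: every entry equal to the max gets g."""
    return (window == np.max(window)) * g


def window_grad_argmax(window, g):
    """Alternative routing: g goes only to the first maximal entry (np.argmax picks the first)."""
    dwin = np.zeros(window.shape)
    dwin[np.unravel_index(np.argmax(window), window.shape)] = g
    return dwin


tied = np.array([[1.0, 3.0],
                 [3.0, 0.0]])
print(window_grad_mask(tied, 2.0))    # both 3.0 entries receive 2.0
print(window_grad_argmax(tied, 2.0))  # only position (0, 1) receives 2.0
```

Without ties the two versions agree; the argmax version keeps the total gradient flowing out of each window equal to the upstream value.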

III. Differences from before

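  Batch normalization here differs a little from before, because the layer input is now spatial, with shape (N, C, H, W). Spatial batchnorm computes one mean and one variance per channel, pooled over the N, H and W dimensions.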

  

```python
def spatial_batchnorm_forward(x, gamma, beta, bn_param):
    N, C, H, W = x.shape
    # Move channels last and flatten to (N*H*W, C): each channel is normalized
    # separately, so the earlier batchnorm_forward can be reused directly
    x_new = x.transpose(0, 2, 3, 1).reshape(N * H * W, C)
    out, cache = batchnorm_forward(x_new, gamma, beta, bn_param)
    out = out.reshape(N, H, W, C).transpose(0, 3, 1, 2)
    return out, cache


def spatial_batchnorm_backward(dout, cache):
    N, C, H, W = dout.shape
    dout_new = dout.transpose(0, 2, 3, 1).reshape(N * H * W, C)
    dx, dgamma, dbeta = batchnorm_backward(dout_new, cache)
    dx = dx.reshape(N, H, W, C).transpose(0, 3, 1, 2)
    return dx, dgamma, dbeta
```
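The transpose-and-reshape trick works because each column of the (N*H*W, C) matrix holds every value of one channel. A quick sketch checking this, assuming a simplified training-mode batchnorm: `bn_train_2d` is a hypothetical stand-in for the earlier batchnorm_forward, with no running averages and no cache. Applied through the same reshape, it should match computing the statistics directly over axes (0, 2, 3).

```python
import numpy as np


def bn_train_2d(x, gamma, beta, eps=1e-5):
    """Simplified training-mode batchnorm on (M, D) input: per-column mean/variance."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def spatial_bn_via_reshape(x, gamma, beta, eps=1e-5):
    """Same transpose/reshape pattern as spatial_batchnorm_forward."""
    N, C, H, W = x.shape
    x2 = x.transpose(0, 2, 3, 1).reshape(N * H * W, C)
    out = bn_train_2d(x2, gamma, beta, eps)
    return out.reshape(N, H, W, C).transpose(0, 3, 1, 2)


def spatial_bn_direct(x, gamma, beta, eps=1e-5):
    """Per-channel statistics taken directly over the N, H, W axes."""
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    g = gamma.reshape(1, -1, 1, 1)
    b = beta.reshape(1, -1, 1, 1)
    return g * (x - mu) / np.sqrt(var + eps) + b


rng = np.random.default_rng(0)
x = 3.0 * rng.standard_normal((4, 3, 5, 5)) + 2.0
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)

out_reshape = spatial_bn_via_reshape(x, gamma, beta)
out_direct = spatial_bn_direct(x, gamma, beta)
print(np.allclose(out_reshape, out_direct))  # True
```

Reshaping x straight to (N*H*W, C) without the transpose would mix values from different channels into the same column, which is why the transpose comes first.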

IV. Summary

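  Assignment 2 is finally done. Overall, I need to get more familiar with numpy and its concrete operations. The forward pass of the convolutional layer is easy to understand, and its backward pass is not much different from the earlier layers; it just requires a convolution operation.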
