转载请注明出处:

https://www.cnblogs.com/darkknightzh/p/9410540.html

论文:

MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications

网址:

https://arxiv.org/abs/1704.04861?context=cs

非官方的pytorch代码:

https://github.com/marvis/pytorch-mobilenet

1. 深度可分离卷积

mobilenetV1使用的是深度可分离卷积(Depthwise Separable Convolution,DSC),DSC包含两部分:depthwise convolution(DWC)+ pointwise convolution(PWC)。DWC对输入的通道进行滤波,其不增加通道的数量,PWC用于将PWC不同的通道进行连接,其可以增加通道的数量。通过这种分解的方式,可以明显的减少计算量。

如下图所示,传统的卷积(a),卷积核参数为${{D}_{K}}\centerdot {{D}_{K}}\centerdot M\centerdot N$,其中${{D}_{K}}$为卷积核大小,M为输入的通道数,N为输出的通道数。DWC(b)中卷积核参数为${{D}_{K}}\centerdot {{D}_{K}}\centerdot 1\centerdot M$,其中M个${{D}_{K}}\centerdot {{D}_{K}}$的核和输入特征的对应通道进行卷积,如下式所示。PWC(c)中卷积核参数为$1\centerdot 1\centerdot M\centerdot N$,每个卷积核在特征维度上分别对输入的M个特征进行加权,最终得到N个特征(M≠N时,完成了升维或者降维)。

${{\mathbf{\hat{G}}}_{k,l,m}}=\sum\limits_{i,j}{{{{\mathbf{\hat{K}}}}_{k,l,m}}\centerdot {{\mathbf{F}}_{k+i-1,l+j-1,m}}}$

传统卷积的计算量为:

${{D}_{K}}\centerdot {{D}_{K}}\centerdot M\centerdot N\centerdot {{D}_{F}}\centerdot {{D}_{F}}$

DSC总共的计算量为:

${{D}_{K}}\centerdot {{D}_{K}}\centerdot M\centerdot {{D}_{F}}\centerdot {{D}_{F}}+M\centerdot N\centerdot {{D}_{F}}\centerdot {{D}_{F}}$

当使用3*3的卷积核时,DSC可将计算量降低为原来的1/8到1/9。

需要说明的是,DWC,PWC后面均有BN和ReLU。如下图所示,传统的卷积层为3*3conv+BN+ReLU,Depthwise Separable convolutions为3*3DWC+BN+ReLU+1*1conv+BN+ReLU。

2. 网络结构

mobileNetV1的网络结构如下图所示。其中第一个卷积层为传统的卷积;前面的卷积层均有bn和relu,最后一个全连接层只有BN,无ReLU。

mobileNetV1使用RMSprop训练;由于参数很少,DWC使用比较小的或者不使用weight decay(l2 regularization)。

3. 宽度缩放因子(width multiplier)

文中引入了$\alpha $作为宽度缩放因子,其作用是在整体上对网络的每一层维度(特征数量)进行瘦身。$\alpha $影响模型的参数数量及前向计算时的乘加次数。此时网络每一层的输入为$\alpha M$维,输出为$\alpha N$维。此时DSC的计算量变为:

${{D}_{K}}\centerdot {{D}_{K}}\centerdot \alpha M\centerdot {{D}_{F}}\centerdot {{D}_{F}}+\alpha M\centerdot \alpha N\centerdot {{D}_{F}}\centerdot {{D}_{F}}$

$\alpha \in (0,1]$,典型值为1,0.75,0.5,0.25。

4. 分辨率缩放因子(resolution multiplier)

该因子即为$\rho $,用于降低输入图像的分辨率(如将224*224降低到192*192,160*160,128*128)。

此时DSC的计算量变为:

${{D}_{K}}\centerdot {{D}_{K}}\centerdot \alpha M\centerdot \rho {{D}_{F}}\centerdot \rho {{D}_{F}}+\alpha M\centerdot \alpha N\centerdot \rho {{D}_{F}}\centerdot \rho {{D}_{F}}$

5. pytorch代码

pytorch代码见参考网址中benchmark.py

 class MobileNet(nn.Module):
def __init__(self):
super(MobileNet, self).__init__() def conv_bn(inp, oup, stride): # 第一层传统的卷积:conv3*3+BN+ReLU
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
) def conv_dw(inp, oup, stride): # 其它层的depthwise convolution:conv3*3+BN+ReLU+conv1*1+BN+ReLU
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
) self.model = nn.Sequential(
conv_bn( 3, 32, 2), # 第一层传统的卷积
conv_dw( 32, 64, 1), # 其它层depthwise convolution
conv_dw( 64, 128, 2),
conv_dw(128, 128, 1),
conv_dw(128, 256, 2),
conv_dw(256, 256, 1),
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 1024, 2),
conv_dw(1024, 1024, 1),
nn.AvgPool2d(7),
)
self.fc = nn.Linear(1024, 1000) # 全连接层 def forward(self, x):
x = self.model(x)
x = x.view(-1, 1024)
x = self.fc(x)
return x

最新文章

  1. RandHelper
  2. [综]聚类Clustering
  3. 打造 html5 文件上传组件,实现进度显示及拖拽上传,支持秒传+分片上传+断点续传,兼容IE6+及其它标准浏览器
  4. nginx查看安装了哪些模块
  5. Mybatis Physical Pagination
  6. Qt的QLineEdit显示密码
  7. Relativelayout属性
  8. Excel日期格式单元格写成yyyy.MM.dd格式将无法读取到DataTable
  9. BZOJ 1059 矩阵游戏
  10. 1645: [Usaco2007 Open]City Horizon 城市地平线
  11. 登录模块的进化史,带大家回顾java学习历程(二)
  12. BloomFilter——大规模数据处理利器
  13. Python加密保护-对可执行的exe进行保护
  14. Linux 解压/压缩xxx.zip格式(unZip Zip的安装和使用)
  15. Java 并发编程(二)对象的不变性和安全的公布对象
  16. MySQL自动设置create_time和update_time
  17. torando-ioloop生命周期
  18. HexDump.java解析,android 16进制转换
  19. python-day63--前端
  20. nginx反向代理proxy_set_header自定义header头无效

热门文章

  1. brew装snappy
  2. JVM启动过程
  3. python全栈开发day42-固定定位等
  4. post请求测试代码
  5. P1590 失踪的7
  6. ogg - 从oracle到mysql的同步
  7. 【python学习-6】异常处理
  8. [漏洞分析]thinkphp 5.1.25 insert、insetAll、update方法注入
  9. Alpha测试
  10. Postman使用记录