mobilenet v1

论文解读

论文地址:https://arxiv.org/abs/1704.04861

核心思想就是通过depthwise conv替代普通conv.



有关depthwise conv可以参考https://www.cnblogs.com/sdu20112013/p/11759928.html

模型结构:



类似于vgg这种堆叠的结构.

每一层的运算量



可以看到,运算量并不是与参数数量绝对成正比,当然整体趋势而言,参数量更少的模型会运算更快.

代码实现

https://github.com/marvis/pytorch-mobilenet

网络结构:

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__() def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
) def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
) self.model = nn.Sequential(
conv_bn( 3, 32, 2),
conv_dw( 32, 64, 1),
conv_dw( 64, 128, 2),
conv_dw(128, 128, 1),
conv_dw(128, 256, 2),
conv_dw(256, 256, 1),
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 1024, 2),
conv_dw(1024, 1024, 1),
nn.AvgPool2d(7),
)
self.fc = nn.Linear(1024, 1000) def forward(self, x):
x = self.model(x)
x = x.view(-1, 1024)
x = self.fc(x)
return x

参考论文中的结构,第一层是普通的卷积层,后面接的都是可分离卷积.

这里注意groups参数的用法. 当groups=输入channel数目时,即对每个channel分别做卷积.默认groups=1,此时即为普通卷积.

训练伪代码

# create model
model = Net() # define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay) # load data
train_loader = torch.utils.data.DataLoader() # train
for every epoch:
input,target=get_from_data #前向传播得到预测值
output = model(input_var) #计算loss
loss = criterion(output, target_var) #反向传播更新网络参数
optimizer.zero_grad()
loss.backward()
optimizer.step()

最新文章

  1. gulp压缩css文件跟js文件
  2. Jetty嵌入式Web容器攻略
  3. 統計分析dbms_stats包与analyze 的区别
  4. 20150511---Timer计时器(备忘)
  5. Mybatis拦截器介绍
  6. 关于 mkimage
  7. 【Python】不定期更新学习小问题整理
  8. Mobile Computing-天平难题-Uva1354(回溯枚举二叉树)
  9. freemarker的非空判断
  10. CI框架学习——检查用户名与密码是否合法(二)
  11. 高通android开发摘要
  12. 【转】Setting up SDL 2 on Visual Studio 2010 Ultimate
  13. 一个box四周边框阴影
  14. C语言内存四区的学习总结(二)---- 堆区
  15. [SHOI2001]化工厂装箱员(dp?暴力:暴力)
  16. 一个 react 小的 demo
  17. [01] Why Spring
  18. 转:安装PHP出现make: *** [sapi/cli/php] Error 1 解决办法
  19. 实现一个简易的vue的mvvm(defineProperty)
  20. Nodejs编写复制文件及文件夹命令

热门文章

  1. Java线程池Executor&ThreadPool
  2. spring与logstash整合,并将数据传输到Elasticsearch
  3. 使用最新版Mybatis逆向工程生成属性不全的问题
  4. 美化H标签
  5. Robot Framework自定义测试库的作用域的理解
  6. Python IAQ中文版 - Python中少有人回答的问题
  7. @DateTimeFormat注解
  8. Java 多线程爬虫及分布式爬虫架构探索
  9. Vue入门教程 第四篇 (属性与事件)
  10. linux netlink通信机制简介