1.文章原文地址

ImageNet Classification with Deep Convolutional Neural Networks

2.文章摘要

我们训练了一个大型的深度卷积神经网络用于在ImageNet LSVRC-2010竞赛中,将120万(12百万)的高分辨率图像进行1000个类别的分类。在测试集上,网络的top-1和top-5误差分别为37.5%和17.0%,这结果极大的优于先前的最好结果。这个拥有6千万(60百万)参数和65万神经元的神经网络包括了五个卷积层,其中一些卷积层后面会跟着最大池化层,以及三个全连接层,其中全连接层是以1000维的softmax激活函数结尾的。为了可以训练的更快,我们使用了非饱和神经元(如Relu,激活函数输出没有将其限定在特定范围)和一个非常高效的GPU来完成卷积运算,为了减少过拟合,我们在全连接层中使用了近期发展起来的一种正则化方式,即dropout,它被证明是非常有效的。我们也使用了该模型的一个变体用于ILSVRC-2012竞赛中,并且以top-5的测试误差为15.3赢得比赛,该比赛中第二名的top-5测试误差为26.2%。

3.网络结构

4.Pytorch实现

 import torch.nn as nn
from torchsummary import summary try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
} class AlexNet(nn.Module):
def __init__(self,num_classes=1000):
super(AlexNet,self).__init__()
self.features=nn.Sequential(
nn.Conv2d(3,96,kernel_size=11,stride=4,padding=2), #(224+2*2-11)/4+1=55
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2), #(55-3)/2+1=27
nn.Conv2d(96,256,kernel_size=5,stride=1,padding=2), #(27+2*2-5)/1+1=27
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2), #(27-3)/2+1=13
nn.Conv2d(256,384,kernel_size=3,stride=1,padding=1), #(13+1*2-3)/1+1=13
nn.ReLU(inplace=True),
nn.Conv2d(384,384,kernel_size=3,stride=1,padding=1), #(13+1*2-3)/1+1=13
nn.ReLU(inplace=True),
nn.Conv2d(384,256,kernel_size=3,stride=1,padding=1), #13+1*2-3)/1+1=13
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2), #(13-3)/2+1=6
) #6*6*256=9126 self.avgpool=nn.AdaptiveAvgPool2d((6,6))
self.classifier=nn.Sequential(
nn.Dropout(),
nn.Linear(256*6*6,4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096,4096),
nn.ReLU(inplace=True),
nn.Linear(4096,num_classes),
) def forward(self,x):
x=self.features(x)
x=self.avgpool(x)
x=x.view(x.size(0),-1)
x=self.classifier(x)
return x def alexnet(pretrain=False,progress=True,**kwargs):
r"""
Args:
pretrained(bool):If True, retures a model pre-trained on IMageNet
progress(bool):If True, displays a progress bar of the download to stderr
"""
model=AlexNet(**kwargs)
if pretrain:
state_dict=load_state_dict_from_url(model_urls['alexnet'],
progress=progress)
model.load_state_dict(state_dict)
return model if __name__=="__main__":
model=alexnet()
print(summary(model,(3,224,224)))
 Output:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 96, 55, 55] 34,944
ReLU-2 [-1, 96, 55, 55] 0
MaxPool2d-3 [-1, 96, 27, 27] 0
Conv2d-4 [-1, 256, 27, 27] 614,656
ReLU-5 [-1, 256, 27, 27] 0
MaxPool2d-6 [-1, 256, 13, 13] 0
Conv2d-7 [-1, 384, 13, 13] 885,120
ReLU-8 [-1, 384, 13, 13] 0
Conv2d-9 [-1, 384, 13, 13] 1,327,488
ReLU-10 [-1, 384, 13, 13] 0
Conv2d-11 [-1, 256, 13, 13] 884,992
ReLU-12 [-1, 256, 13, 13] 0
MaxPool2d-13 [-1, 256, 6, 6] 0
AdaptiveAvgPool2d-14 [-1, 256, 6, 6] 0
Dropout-15 [-1, 9216] 0
Linear-16 [-1, 4096] 37,752,832
ReLU-17 [-1, 4096] 0
Dropout-18 [-1, 4096] 0
Linear-19 [-1, 4096] 16,781,312
ReLU-20 [-1, 4096] 0
Linear-21 [-1, 1000] 4,097,000
================================================================
Total params: 62,378,344
Trainable params: 62,378,344
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 11.16
Params size (MB): 237.95
Estimated Total Size (MB): 249.69
----------------------------------------------------------------

参考

https://github.com/pytorch/vision/tree/master/torchvision/models

最新文章

  1. 关于在VS 上发布网站
  2. Windows Server 2008R2服务器安装及设置教程
  3. c& c++ enum
  4. 关于git的文件内容冲突解决
  5. 检测android的网络链接状态
  6. 使用 stvd 编译STM8S 时能看到使用RAM ROM大小的方法
  7. appium+python环境搭建
  8. 织梦5.7sp1最新问题:后台不显示编辑器
  9. 从零开始学安全(十九)●PHP数组函数
  10. python中使用for循环,while循环,一条命令打印99乘法表
  11. BZOJ4832[Lydsy1704月赛]抵制克苏恩——期望DP
  12. 微信出现BUG,发送“ 两位数字+15个句号 ”,双方系统会卡崩……
  13. [转载]Frontend Knowledge Structure
  14. Window下Latex加速编译方法以及西农毕设论文模板推荐
  15. js 获取高度
  16. JS在if中的强制类型转换
  17. linux中的 tar命令的 -C 参数,以及其它一些参数
  18. 【原创】CRM 2015/2016,SSRS 生成PDF文件,幷以附件的形式发送邮件
  19. 2017百度春招<有趣的排序>
  20. 获取ios设备系统信息的方法 之 [UIDevice currentDevice]

热门文章

  1. Github克隆代码慢问题解决办法
  2. 【ARTS】01_46_左耳听风-201900923~201900929
  3. centos下通过yum安装redis-cli
  4. 微信小程序 与后台交互----传递和回传时间
  5. [bzoj3829][Poi2014]FarmCraft_树形dp
  6. myeclipse 相关问题
  7. Spring Boot 入门(九):集成Quartz定时任务
  8. MATLAB 单元数组 cell 和结构体 struct 的用法以及区别
  9. vim 常用命令总结(排版精良,内容优质)
  10. golang 切片使用注意事项