cifar10主要是由32x32的三通道彩色图, 总共10个类别,这里我们使用残差网络构造网络结构

网络结构:

第一层:首先经过一个卷积,归一化,激活 32x32x16 -> 32x32x16

第二层:  通过一多个残差模型

残差模块的网络构造:

如果stride != 1 or in_channel != out_channel, 就构造downsample网络结构进行降采样操作

利用残差模块进行第一次残差卷积, 将downsample传入

连续进行多次的残差卷积

from torchvision import transforms
from torch import nn
# 首先对图片进行数据转换 train_transform = transforms.Compose([
transforms.Scale(40), # 相当于是resize操作,
transforms.RandomHorizontalFlip(), # 表示进行左右的翻转
transforms.RandomCrop(32), #表示进行随机的裁剪
transforms.ToTensor(), # 将数据转换为tensor格式
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 进行-均值 / 标准差, 将数据转换为-1, 1 之间 ]) test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]) def conv3x3(in_channels, out_channels, stride=1):
return nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False) class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(ResidualBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(True)
self.conv2 = conv3x3(out_channels, out_channels, stride=1)
self.bn = nn.BatchNorm2d(out_channels)
self.downsample = downsample def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn(x)
out = self.relu(x)
out = self.conv2(x)
out = self.bn(x)
if self.downsample:
residual = self.downsample(x)
out += residual
return self.relu(out) class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 16
self.conv = conv3x3(3, 16)
self.bn = nn.BatchNorm2d(self.in_channels)
self.relu = nn.ReLU(True)
self.layers1 = self.make_block(block, 16, layers[0])
self.layers2 = self.make_block(block, 32, layers[0])
self.layers3 = self.make_block(block, 64, layers[1])
self.avg_pool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_classes) def make_block(self, block, out_channels, blocks, stride=1):
downsample = None
if stride != 1 or out_channels != self.in_channels:
downsample = nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),
nn.BatchNorm2d(out_channels))
layers = []
layers.append(block(self.in_channels, out_channels, stride=stride, downsample = downsample))
for i in blocks:
layers.append(block(self.out_channels, out_channels, stride=stride, downsample=downsample)) return nn.Sequential(*layers) def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.relu(out)
out = self.layers1(out)
out= self.layers2(out)
out = self.layers3(out)
out = self.avg_pool(out)
out = self.fc(out) return out

最新文章

  1. EXTJS中grid的数据特殊显示,不同窗口的数据传递
  2. Divide and Conquer:Cable Master(POJ 1064)
  3. WPF中的VisualTreeHelper
  4. XAML 概述二
  5. hdoj 2097 Sky数
  6. 4.MySQL连接并选择数据库(SQL & C)
  7. 解决Mysql的主从数据库没有同步的两种方法
  8. itunes备份文件解析入门
  9. python Eve RESTFul 尝试笔记
  10. Linux学习之/etc/init.d/functions详解
  11. [转]Publishing and Running ASP.NET Core Applications with IIS
  12. PHP基础入门(三)---PHP函数基础
  13. 云游戏学习与实践(二)——安装GamingAnywhere
  14. js中多维数组转一维
  15. spring IOC 分析及实现
  16. ADB抓取内存命令
  17. cobbler学习
  18. Swift5 语言指南(九) 闭包
  19. mysql pdo设置显示报错
  20. js取float型小数点后x位数的方法

热门文章

  1. MySQL存储引擎MyISAM和InnoDB,索引结构优缺点
  2. Linux ppp 数据收发流程
  3. C#DataGrid列值出现E形式的小数,将DataGrid表格上的数据保存至数据库表时会因格式转换不正确导致报错
  4. 团队第三次作业:Alpha版本发布
  5. es的相关知识二(检索文档)
  6. 使用expect登录批量拷贝本地文件到多个目标主机
  7. JLOI2016 侦查守卫
  8. CF1037H Security——SAM+线段树合并
  9. Qt disconnect函数
  10. HDU 6070 - Dirt Ratio | 2017 Multi-University Training Contest 4