DARTS代码分析(Pytorch)
2024-09-05 11:44:53
最近在看DARTS的代码,有一个operations.py的文件,里面是对各类点与点之间操作的方法。
OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: \
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}
首先定义10个操作,依次解释:
class PoolBN(nn.Module):
"""
AvgPool or MaxPool - BN
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
"""
Args:
pool_type: 'max' or 'avg'
"""
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError() self.bn = nn.BatchNorm2d(C, affine=affine) def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out这是池化函数,有最大池化和平均池化方法,count_include_pad=False表示不把填充的0计算进去
class Identity(nn.Module):
def __init__(self):
super().__init__() def forward(self, x):
return x这个表示skip conncet
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise(stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out这个表示将特征图大小变为原来的一半
class DilConv(nn.Module):
""" (Dilated) depthwise separable conv
ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field
5x5 conv => 9x9 receptive field
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
) def forward(self, x):
return self.net(x)深度可分离卷积,groups=C_in,表示把输入特种图分成C_in(输入通道数)那么多组,然后加C_out(输出通道数)1*1的卷积,这样可以对每个通道单独提取特征,同时降低了参数量和计算量。
class SepConv(nn.Module):
""" Depthwise separable conv
DilConv(dilation=1) * 2
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
) def forward(self, x):
return self.net(x)深度可分离卷积,由两个上面的深度分组卷积组成
class FacConv(nn.Module):
""" Factorized conv
ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
) def forward(self, x):
return self.net(x)这个表示长方形的卷积,增加了一点特征图的长和宽
class Zero(nn.Module):
def __init__(self, stride):
super().__init__()
self.stride = stride def forward(self, x):
if self.stride == 1:
return x * 0. # re-sizing by stride
return x[:, :, ::self.stride, ::self.stride] * 0.这个表示把特种图的输出变为全是0,但特征图的大小会根据stride而改变
最新文章
- MFC-01-Chapter01:Hello,MFC---1.3 第一个MFC程序(04)
- CS0234: 命名空间“System.Web.Mvc”中不存在类型或命名空间名称“Html、Ajax”(是否缺少程序集引用?)
- Ubuntu上部署一个简单的Java项目
- bitmag
- [leetcode]_Count and Say
- 编译器的未来——我们还需要C++么?
- mybatis 应用参考
- Android DrawerLayout 点击事情穿透
- action接收到来自jsp页面的请求时出现中文乱码问题处理方法
- 用CSS3实现饼状loading效果
- jq 点击复制div里面的内容 如果粘贴到富文本中,会将样式,里面所有的标签,文字一并粘贴进去
- Hadoop学习------Hadoop安装方式之(三):分布式部署
- java依赖的斗争:依赖倒置、控制反转和依赖注入
- Codeforces963C Cutting Rectangle 【数学】
- 洛谷P1774 最接近神的人_NOI导刊2010提高(02)(求逆序对)
- undo与redo
- MySQL MTS复制: hitting slave_pending_jobs_size_max
- ASP.NET真假分页—真分页
- PHP 调用ffmpeg
- 转-spring boot web相关配置
热门文章
- router-link to 动态赋值
- 记录微信小程序里自带 时间格式 工具
- electron-vue 引入OpenLayer 报错 Unexpected token export
- android的ant编译打包
- permutation 2(递推 + 思维)
- TCP最大报文段MSS源码分析
- SpringBoot2.X&;Prometheus使用
- TP-四种url访问的方式
- C++ STL 中 map 容器
- add_header 'Cache-Control' 'no-store, no-cache, must-revalidate, proxy-revalidate, max-age=0'