YOLOv4 网络代码

from collections import OrderedDict
import torch
import torch.nn as nn
from Darknet_53 import darknet53


def conv(in_channels, out_channels, kernel_size, stride=1):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.1) unit.

    Padding is ``(kernel_size - 1) // 2``, so an odd kernel keeps the spatial
    size when ``stride == 1``; a falsy ``kernel_size`` gives padding 0.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm width).
        kernel_size: square convolution kernel size.
        stride: convolution stride (default 1).

    Returns:
        ``nn.Sequential`` whose children are named ``conv``, ``bn`` and
        ``relu``; these names appear in the model's state_dict keys.
    """
    if kernel_size:
        padding = (kernel_size - 1) // 2
    else:
        padding = 0

    # NOTE(review): the Conv2d keeps its default bias even though BatchNorm
    # follows (the bias is redundant there). It is left as-is because removing
    # it changes the state_dict keys of every checkpoint built with this code.
    conv_layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    named_layers = [
        ("conv", conv_layer),
        ("bn", nn.BatchNorm2d(out_channels)),
        ("relu", nn.LeakyReLU(0.1)),
    ]
    return nn.Sequential(OrderedDict(named_layers))
class SPP(nn.Module):
    """Spatial pyramid pooling block.

    Applies stride-1 max pooling at several kernel sizes (padding ``k // 2``,
    so odd kernels keep the spatial size) and concatenates the pooled maps
    with the untouched input along dim 1. For ``C`` input channels the output
    has ``C * (len(pool_sizes) + 1)`` channels: the pooled maps in reverse
    order of ``pool_sizes`` (13, 9, 5 for the default), then the input.

    Args:
        pool_sizes: max-pool kernel sizes. Defaults to the immutable tuple
            ``(5, 9, 13)`` so no mutable default is shared between instances;
            any sequence (including a list) is still accepted.
    """

    def __init__(self, pool_sizes=(5, 9, 13)):
        super(SPP, self).__init__()
        self.maxpools = nn.ModuleList(
            [nn.MaxPool2d(pool_size, 1, pool_size // 2) for pool_size in pool_sizes]
        )

    def forward(self, x):
        """Return ``x`` concatenated along dim 1 after its max-pooled versions."""
        # Keep the reversed order: downstream conv weights are tied to this
        # channel layout.
        features = [maxpool(x) for maxpool in self.maxpools[::-1]]
        return torch.cat(features + [x], dim=1)
class Upsample(nn.Module):
    """Channel-reducing 1x1 conv unit followed by 2x nearest-neighbour upsampling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1 conv unit (and of the output).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        reduce_unit = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        enlarge = nn.Upsample(scale_factor=2, mode="nearest")
        self.upsample = nn.Sequential(reduce_unit, enlarge)

    def forward(self, x):
        """Return the reduced feature map at twice the input height and width."""
        return self.upsample(x)
def conv_three(channels_list, in_channels):
    """Stack three conv units: 1x1 -> 3x3 -> 1x1.

    Args:
        channels_list: ``[narrow, wide]``; the 1x1 units output ``narrow``
            channels and the 3x3 unit outputs ``wide`` channels.
        in_channels: channels of the incoming feature map.

    Returns:
        ``nn.Sequential`` whose output has ``channels_list[0]`` channels.
    """
    narrow = channels_list[0]
    wide = channels_list[1]
    squeeze_in = conv(in_channels=in_channels, out_channels=narrow, kernel_size=1)
    expand = conv(in_channels=narrow, out_channels=wide, kernel_size=3)
    squeeze_out = conv(in_channels=wide, out_channels=narrow, kernel_size=1)
    return nn.Sequential(squeeze_in, expand, squeeze_out)
def conv_five(channels_list, in_channels):
    """Stack five conv units: 1x1 -> 3x3 -> 1x1 -> 3x3 -> 1x1.

    Args:
        channels_list: ``[narrow, wide]``; 1x1 units output ``narrow``
            channels, 3x3 units output ``wide`` channels.
        in_channels: channels of the incoming feature map.

    Returns:
        ``nn.Sequential`` (children indexed 0..4) whose output has
        ``channels_list[0]`` channels.
    """
    narrow, wide = channels_list[0], channels_list[1]
    layers = [conv(in_channels=in_channels, out_channels=narrow, kernel_size=1)]
    # Two identical expand/squeeze pairs follow the entry 1x1 unit.
    for _ in range(2):
        layers.append(conv(in_channels=narrow, out_channels=wide, kernel_size=3))
        layers.append(conv(in_channels=wide, out_channels=narrow, kernel_size=1))
    return nn.Sequential(*layers)
def Yolov4_head(channels_list, in_channels):
    """Build a detection head: a 3x3 conv unit followed by a plain 1x1 prediction conv.

    Args:
        channels_list: ``[hidden, num_outputs]``. ``hidden`` is the width of
            the 3x3 conv unit. ``num_outputs`` is the number of raw prediction
            channels (the caller passes ``anchors * (5 + num_classes)``).
        in_channels: channels of the incoming feature map.

    Returns:
        ``nn.Sequential`` producing ``num_outputs`` channels at the input's
        spatial size.
    """
    hidden, num_outputs = channels_list[0], channels_list[1]
    return nn.Sequential(
        conv(in_channels=in_channels, out_channels=hidden, kernel_size=3),
        # The prediction layer must emit raw, unbounded values. These are
        # presumably decoded later by the loss / post-processing, which is
        # not in this file. Wrapping it in conv() added BatchNorm and
        # LeakyReLU(0.1), which renormalised the outputs and squashed every
        # negative prediction. It is therefore a plain biased Conv2d.
        # NOTE(review): this changes the state_dict keys of the head's last
        # layer (``N.1.conv.weight`` -> ``N.1.weight``), so checkpoints saved
        # with the old head will not load strictly.
        nn.Conv2d(in_channels=hidden, out_channels=num_outputs, kernel_size=1),
    )
class YoloBody(nn.Module):
    """YOLOv4 detector: darknet53 backbone, SPP, a two-way fusion neck and three heads.

    The neck first fuses features top-down (upsampling the deepest map into
    the shallower ones), then bottom-up (stride-2 convs back down). This is a
    PANet-style path aggregation.

    Args:
        anchors_mask: three anchor-index groups; only ``len()`` of each group
            is used here, to size the head outputs.
        num_classes: number of object classes.
        pretrained: forwarded positionally to ``darknet53``.

    ``forward`` returns ``(out0, out1, out2)`` from the coarsest feature map to
    the finest. Each output has ``len(anchors_mask[i]) * (5 + num_classes)``
    channels, with out0 using group 2, out1 group 1 and out2 group 0.
    """

    def __init__(self, anchors_mask, num_classes, pretrained=False):
        super().__init__()

        def head_channels(group):
            # anchors per cell * (4 box terms + 1 objectness + class scores),
            # e.g. 3 * (5 + 20) = 75 for 20 classes.
            return len(anchors_mask[group]) * (5 + num_classes)

        # The layers below expect the backbone to return maps with
        # 256 / 512 / 1024 channels (finest to deepest).
        self.backbone = darknet53(pretrained)

        # Deepest map: 1024 -> 512, SPP concat -> 2048, back to 512.
        self.conv1 = conv_three(channels_list=[512, 1024], in_channels=1024)
        self.spp = SPP()
        self.conv2 = conv_three(channels_list=[512, 1024], in_channels=2048)

        # Top-down fusion into the middle map.
        self.upsample1 = Upsample(512, 256)
        self.conv_for_p4 = conv(in_channels=512, out_channels=256, kernel_size=1)
        self.make_five_conv1 = conv_five(channels_list=[256, 512], in_channels=512)

        # Top-down fusion into the finest map.
        self.upsample2 = Upsample(in_channels=256, out_channels=128)
        self.conv_for_p3 = conv(in_channels=256, out_channels=128, kernel_size=1)
        self.make_five_conv2 = conv_five(channels_list=[128, 256], in_channels=256)

        # Finest-scale head, then bottom-up fusion back into the middle map.
        self.yolo_head3 = Yolov4_head(channels_list=[256, head_channels(0)], in_channels=128)
        self.down_sample1 = conv(in_channels=128, out_channels=256, kernel_size=3, stride=2)
        self.make_five_conv3 = conv_five(channels_list=[256, 512], in_channels=512)

        # Middle-scale head, then bottom-up fusion back into the deepest map.
        self.yolo_head2 = Yolov4_head(channels_list=[512, head_channels(1)], in_channels=256)
        self.down_sample2 = conv(in_channels=256, out_channels=512, kernel_size=3, stride=2)
        self.make_five_conv4 = conv_five(channels_list=[512, 1024], in_channels=1024)

        # Coarsest-scale head.
        self.yolo_head1 = Yolov4_head(channels_list=[1024, head_channels(2)], in_channels=512)

    def forward(self, x):
        """Run the detector and return ``(out0, out1, out2)``, coarsest to finest."""
        # For a 416x416 input these are presumably 52x52, 26x26 and 13x13
        # (assumes darknet53 strides 8/16/32 -- TODO confirm).
        shallow, middle, deep = self.backbone(x)

        # Deepest map: 1024 -> 512 -> 2048 (SPP) -> 512 channels.
        deep = self.conv2(self.spp(self.conv1(deep)))

        # Top-down: 2x-upsampled deep (256 ch) + reduced middle (256 ch) -> 512 -> 256.
        deep_up = self.upsample1(deep)
        middle = self.make_five_conv1(torch.cat([self.conv_for_p4(middle), deep_up], dim=1))

        # Top-down: 2x-upsampled middle (128 ch) + reduced shallow (128 ch) -> 256 -> 128.
        middle_up = self.upsample2(middle)
        shallow = self.make_five_conv2(torch.cat([self.conv_for_p3(shallow), middle_up], dim=1))

        # Bottom-up: stride-2 downsampled shallow (256) + middle (256) -> 512 -> 256.
        middle = self.make_five_conv3(torch.cat([self.down_sample1(shallow), middle], dim=1))

        # Bottom-up: stride-2 downsampled middle (512) + deep (512) -> 1024 -> 512.
        deep = self.make_five_conv4(torch.cat([self.down_sample2(middle), deep], dim=1))

        pred_fine = self.yolo_head3(shallow)
        pred_mid = self.yolo_head2(middle)
        pred_coarse = self.yolo_head1(deep)
        return pred_coarse, pred_mid, pred_fine


# from torchsummary import summary
# yoloyolo=YoloBody(anchors_mask=["0","0","0"], num_classes=5, pretrained = False)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# summary(yoloyolo, input_size=(3, 416, 416))
# print(yoloyolo)

代码注释不多，欢迎留言共同讨论，顺便给个关注，感谢。

最新文章

  1. (转)spring boot注解 --@EnableAsync 异步调用
  2. Monte Carlo方法简介(转载)
  3. 通过printf设置Linux终端输出的颜色和显示方式
  4. 关于手机微网站ICP备案
  5. rand()和srand()GetTickCount函数用法
  6. 虚拟机VMware里 windows server 2003 扩充C盘方法
  7. Free Editor
  8. 关于我的FGC的OAuth2.0认证。
  9. Wireless Network(POJ 2236)
  10. java Date 和 javascript Date
  11. 与众不同 windows phone (20) - Device(设备)之位置服务(GPS 定位), FM 收音机, 麦克风, 震动器
  12. hdu_1007_Quoit Design(最近点对)
  13. eclipse F3可以查询某个方法的具体定义
  14. ASP.NET Core Web API 索引 (更新Identity Server 4 视频教程)
  15. Cocos Creator EditBox(编辑框/输入框)添加事件的两种方法
  16. oracle 索引的(创建、简介、技巧、怎样查看)
  17. 深入理解java虚拟机,类加载
  18. 20155322 《Java程序设计》课堂实践项目 数据库-3-4
  19. python近期遇到的一些面试问题(三)
  20. 关于如何爬虫妹子图网的源码分析 c#实现

热门文章

  1. JZOJ 1075. 【GDKOI2006】新红黑树
  2. JZOJ 1082. 【GDOI2005】选址
  3. RocketMQ - 消费者启动机制
  4. WHAT IS PPM Encoder ?
  5. IIS部署WGCMS
  6. 在github上如何克隆带子模块的项目?
  7. elasticsearch 内存分配设置
  8. C#导出Excel设置单元格样式
  9. python win32 microsoft excel 类 Range 的 CopyPicture 方法无效
  10. Linux 部署apache2.4