
YOLOv8: Adding the BasicRFB Module to the Head


01 Model Introduction

        1 Introduction

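        The RFB (Receptive Field Block) paper proposes a module that takes the scale and eccentricity of receptive fields (RFs) into account, so that even a lightweight backbone can extract highly discriminative features and the detector stays both fast and accurate. Specifically, RFB uses multi-branch convolution and pooling with different kernel sizes that correspond to RFs of different scales (it "makes use of multi-branch pooling with varying kernels"), applies dilated convolution (dilated conv) to control the eccentricity of each receptive field, and finally reshapes the result to produce the output features.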

        2 Usage

The RFB module is a multi-branch convolution module whose internal structure has two parts:

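        1. Multi-branch convolution layer: following the definition of an RF, using kernels of several sizes works better than using a single fixed size. The design has three points. (1) A bottleneck structure: a 1x1 convolution (written 1x1-s2 in the original notation) reduces the number of channels and is followed by an nxn convolution. (2) A 5x5 convolution is replaced by two stacked 3x3 convolutions, which reduces parameters and adds one more non-linear layer. (3) On the output side, a convolution often has stride=2 or reduces the channel count, so every shortcut connection uses a 1x1 convolution without an activation function to match dimensions.
        2. Dilated convolution layer: it enlarges the receptive field while keeping the parameter count unchanged, so that more context is captured on higher-resolution features. The figure in the original article shows two RFB structures, RFB and RFB-s. Each branch is a regular convolution followed by a dilated convolution; the branches differ mainly in kernel size and dilation factor. (a) RFB borrows its overall idea from Inception; the main difference is the three added dilated convolution layers. (b) RFB-s makes two changes relative to RFB: it replaces the 5x5 convolution with a 3x3 convolution, and it replaces the 3x3 convolution with a 1x3 and a 3x1 convolution. The main goal is to reduce computation, similar to what later Inception versions did to the original Inception structure.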
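To make the two design points above concrete, here is a small sketch in plain Python. The helper names (`effective_kernel`, `conv_weights`) and the channel width of 64 are illustrative assumptions and not part of the article or of any library. It shows that dilation enlarges the span a 3x3 kernel covers without adding weights, and that the factorizations mentioned for RFB and RFB-s lower the weight count.

```python
def effective_kernel(k: int, dilation: int) -> int:
    """Span (in pixels) covered by a k x k kernel with the given dilation."""
    return k + (k - 1) * (dilation - 1)


def conv_weights(c_in: int, c_out: int, kh: int, kw: int) -> int:
    """Weight count of a bias-free convolution; dilation does not change it."""
    return c_in * c_out * kh * kw


c = 64  # illustrative channel width, not taken from the article

# Span of a 3x3 kernel for dilation factors 1, 3 and 5.
spans = [effective_kernel(3, d) for d in (1, 3, 5)]

# One 5x5 layer versus two stacked 3x3 layers (same 5x5 receptive field).
w_5x5 = conv_weights(c, c, 5, 5)
w_two_3x3 = 2 * conv_weights(c, c, 3, 3)

# One 3x3 layer versus a 1x3 layer followed by a 3x1 layer, as in RFB-s.
w_3x3 = conv_weights(c, c, 3, 3)
w_1x3_3x1 = conv_weights(c, c, 1, 3) + conv_weights(c, c, 3, 1)

print("spans:", spans)
print("5x5 vs two 3x3:", w_5x5, w_two_3x3)
print("3x3 vs 1x3 + 3x1:", w_3x3, w_1x3_3x1)
```

The weight counts ignore batch-norm parameters, which are small compared with the convolution weights.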

02 Model Modification Steps

1 Modify modules.py: the modules used by the model

        Under ultralytics/nn/modules, add a file named BasicRFB.py and copy the following code into it:

    import torch
    import torch.nn as nn


    class BasicConv(nn.Module):
        """Conv2d -> BatchNorm2d -> optional ReLU."""

        def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True):
            super(BasicConv, self).__init__()
            self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                                  padding=padding, dilation=dilation, groups=groups, bias=False)
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
            self.relu = nn.ReLU(inplace=True) if relu else nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            return x


    class BasicRFB(nn.Module):
        """Three dilated branches, concatenated, fused by a 1x1 conv and added to a 1x1 shortcut."""

        def __init__(self, in_planes, out_planes, stride=1, scale=0.1, map_reduce=8, vision=1, groups=1):
            super(BasicRFB, self).__init__()
            self.scale = scale
            self.out_channels = out_planes
            inter_planes = in_planes // map_reduce  # bottleneck width

            # Branch 0: 1x1 -> 3x3 -> 3x3 with dilation vision + 1
            self.branch0 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
                BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
                BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 1,
                          dilation=vision + 1, relu=False, groups=groups)
            )
            # Branch 1: 1x1 -> 3x3 -> 3x3 with dilation vision + 2
            self.branch1 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
                BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
                BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 2,
                          dilation=vision + 2, relu=False, groups=groups)
            )
            # Branch 2: 1x1 -> two 3x3 (in place of a 5x5) -> 3x3 with dilation vision + 4
            self.branch2 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
                BasicConv(inter_planes, (inter_planes // 2) * 3, kernel_size=3, stride=1, padding=1, groups=groups),
                BasicConv((inter_planes // 2) * 3, 2 * inter_planes, kernel_size=3, stride=stride, padding=1,
                          groups=groups),
                BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 4,
                          dilation=vision + 4, relu=False, groups=groups)
            )

            # 1x1 conv that fuses the concatenated branches (3 branches x 2 * inter_planes channels)
            self.ConvLinear = BasicConv(6 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
            # 1x1 shortcut without activation, used to match dimensions
            self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
            self.relu = nn.ReLU(inplace=False)

        def forward(self, x):
            x0 = self.branch0(x)
            x1 = self.branch1(x)
            x2 = self.branch2(x)

            out = torch.cat((x0, x1, x2), 1)
            out = self.ConvLinear(out)
            short = self.shortcut(x)
            out = out * self.scale + short  # scaled residual sum
            out = self.relu(out)
            return out
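The channel and spatial bookkeeping of the module above can be checked by hand. The sketch below uses only the standard convolution output-size formula and the constructor defaults shown above (map_reduce=8, vision=1). The helper names and the 256-channel, 40x40 input are illustrative assumptions, not values from the article.

```python
def conv_out_size(size: int, k: int, stride: int, padding: int, dilation: int) -> int:
    """Standard output-size formula for one spatial dimension of a convolution."""
    return (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1


def rfb_channels(in_planes: int, map_reduce: int = 8) -> dict:
    """Channel widths inside BasicRFB for a given input width."""
    inter = in_planes // map_reduce
    return {
        "inter_planes": inter,
        "branch_out": 2 * inter,          # every branch ends with 2 * inter_planes channels
        "branch2_mid": (inter // 2) * 3,  # intermediate width inside branch2
        "concat": 6 * inter,              # three branches concatenated before ConvLinear
    }


vision = 1
size = 40  # illustrative feature-map side length

# The dilated 3x3 layers use padding == dilation, so the spatial size is preserved.
dilations = [vision + 1, vision + 2, vision + 4]
sizes = [conv_out_size(size, k=3, stride=1, padding=d, dilation=d) for d in dilations]

channels = rfb_channels(256)

print("dilations:", dilations, "output sizes:", sizes)
print("channels:", channels)
```

Because the widths come from integer division, in_planes should be at least 2 * map_reduce so that no intermediate width becomes zero.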

2 Modify tasks.py: register the module used by the model

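       In ultralytics/nn/tasks.py, make sure the class can be imported by appending the class name BasicRFB to the end of the import list. Because this statement imports from the ultralytics.nn.modules package, BasicRFB also has to be exposed in that package's __init__.py (for example with from .BasicRFB import BasicRFB); otherwise the import fails.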

    from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                        Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
                                        GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder,
                                        Segment, CBAM, BasicRFB)

3 Modify def parse_model: argument parsing

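        Still in ultralytics/nn/tasks.py, append BasicRFB to the end of the module tuple in parse_model so that the name is recognized when the model YAML is parsed: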

    if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
             BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, BasicRFB):
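Why the tuple matters: for modules listed there, parse_model fills in the input channel count and scales the output channel count before constructing the module. The following is a simplified sketch of that step, based on a reading of parse_model in Ultralytics 8.0.x; the exact code differs between versions, and `resolve_args` is a hypothetical helper name used only for illustration.

```python
import math


def make_divisible(x: float, divisor: int) -> int:
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def resolve_args(c1: int, args: list, width: float, max_channels: int, nc: int) -> list:
    """Simplified sketch: turn YAML args [c2, ...] into constructor args [c1, c2_scaled, ...].

    c1 is the output channel count of the layer feeding this module.
    """
    c2 = args[0]
    if c2 != nc:  # outputs equal to the class count are left unscaled
        c2 = make_divisible(min(c2, max_channels) * width, 8)
    return [c1, c2, *args[1:]]


# Hypothetical YAML args [512] for a BasicRFB row fed by a 128-channel layer,
# with a width multiple of 0.25 and a channel cap of 1024.
resolved = resolve_args(c1=128, args=[512], width=0.25, max_channels=1024, nc=80)
print(resolved)  # these become BasicRFB(in_planes, out_planes)
```

If BasicRFB were left out of the tuple, the YAML args would be passed to the constructor unchanged, so in_planes would not be supplied automatically.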

4 Create yolov8_BasicRFB.yaml: the model configuration file

        Copy one of the .yaml files under ultralytics/models/v8/ to a new file named yolov8_BasicRFB.yaml, and paste the following into it.

    # Ultralytics YOLO