
YOLOv8 Improvement: Replacing the Backbone with LSKNet, a Backbone for Rotated Object Detection

I. The LSKNet Paper (a Backbone for Rotated Object Detection)

Paper: 2303.09030.pdf (arxiv.org)

II. The Structure of the Large Selective Kernel Network

LSK explicitly generates multiple features with a range of large receptive fields, which makes the subsequent kernel selection easier; this sequential decomposition is also more effective and efficient than simply applying one larger kernel. To strengthen the network's ability to focus on the spatial context regions most relevant to the detected targets, LSK adds a spatial selection mechanism that spatially selects feature maps produced by large convolution kernels at different scales.
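
For a rough sense of why the decomposition is cheaper, the numbers below follow the 5x5 and dilated 7x7 depth-wise convolutions used in the LSKblock code in section III; the single-kernel comparison is back-of-the-envelope arithmetic, not a figure from the paper:

    # Effective receptive field and parameter count of LSKblock's decomposition
    # (5x5 depth-wise conv, then 7x7 depth-wise conv with dilation 3),
    # versus a single depth-wise kernel covering the same field.
    k1, k2, d, dim = 5, 7, 3, 64

    rf = k1 + (k2 - 1) * d                         # 5 + 6*3 = 23
    params_decomposed = dim * (k1 ** 2 + k2 ** 2)  # 64 * (25 + 49) = 4736
    params_single = dim * rf ** 2                  # 64 * 529 = 33856

    print(f"receptive field: {rf}x{rf}")
    print(f"decomposed: {params_decomposed} params, single kernel: {params_single} params")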

III. Code Implementation

1. Under the path ultralytics\ultralytics\nn, create a new folder named backbone to hold the code for modified network structures.

Inside that backbone folder, create a new Python file lsknet.py and add the LSKNet network code to it:

    import torch
    import torch.nn as nn
    from timm.layers import DropPath, to_2tuple
    from functools import partial
    import numpy as np

    __all__ = 'lsknet_t', 'lsknet_s'


    class Mlp(nn.Module):
        """1x1 conv -> depth-wise 3x3 conv -> 1x1 conv feed-forward block."""

        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
            self.dwconv = DWConv(hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            x = self.fc1(x)
            x = self.dwconv(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x


    class LSKblock(nn.Module):
        """Large selective kernel: two large-receptive-field branches fused by spatial selection."""

        def __init__(self, dim):
            super().__init__()
            # 5x5 depth-wise conv, then 7x7 depth-wise conv with dilation 3:
            # stacked, they cover a 23x23 receptive field cheaply
            self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
            self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
            self.conv1 = nn.Conv2d(dim, dim // 2, 1)
            self.conv2 = nn.Conv2d(dim, dim // 2, 1)
            self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
            self.conv = nn.Conv2d(dim // 2, dim, 1)

        def forward(self, x):
            attn1 = self.conv0(x)
            attn2 = self.conv_spatial(attn1)

            attn1 = self.conv1(attn1)
            attn2 = self.conv2(attn2)

            # spatial selection: pool across channels, then predict a
            # per-pixel weight for each of the two branches
            attn = torch.cat([attn1, attn2], dim=1)
            avg_attn = torch.mean(attn, dim=1, keepdim=True)
            max_attn, _ = torch.max(attn, dim=1, keepdim=True)
            agg = torch.cat([avg_attn, max_attn], dim=1)
            sig = self.conv_squeeze(agg).sigmoid()
            attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
            attn = self.conv(attn)
            return x * attn


    class Attention(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.proj_1 = nn.Conv2d(d_model, d_model, 1)
            self.activation = nn.GELU()
            self.spatial_gating_unit = LSKblock(d_model)
            self.proj_2 = nn.Conv2d(d_model, d_model, 1)

        def forward(self, x):
            shortcut = x.clone()
            x = self.proj_1(x)
            x = self.activation(x)
            x = self.spatial_gating_unit(x)
            x = self.proj_2(x)
            x = x + shortcut
            return x


    class Block(nn.Module):
        def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_cfg=None):
            super().__init__()
            self.norm1 = nn.BatchNorm2d(dim)
            self.norm2 = nn.BatchNorm2d(dim)
            self.attn = Attention(dim)
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
            layer_scale_init_value = 1e-2
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True)

        def forward(self, x):
            x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
            return x


    class OverlapPatchEmbed(nn.Module):
        """Image to Patch Embedding."""

        def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
            super().__init__()
            patch_size = to_2tuple(patch_size)
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                                  padding=(patch_size[0] // 2, patch_size[1] // 2))
            self.norm = nn.BatchNorm2d(embed_dim)

        def forward(self, x):
            x = self.proj(x)
            _, _, H, W = x.shape
            x = self.norm(x)
            return x, H, W


    class LSKNet(nn.Module):
        def __init__(self, img_size=224, in_chans=3, embed_dims=[64, 128, 256, 512],
                     mlp_ratios=[8, 8, 4, 4], drop_rate=0., drop_path_rate=0.,
                     norm_layer=partial(nn.LayerNorm, eps=1e-6),
                     depths=[3, 4, 6, 3], num_stages=4, norm_cfg=None):
            super().__init__()
            self.depths = depths
            self.num_stages = num_stages
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
            cur = 0
            for i in range(num_stages):
                patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                                patch_size=7 if i == 0 else 3,
                                                stride=4 if i == 0 else 2,
                                                in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                                embed_dim=embed_dims[i], norm_cfg=norm_cfg)
                block = nn.ModuleList([Block(
                    dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j], norm_cfg=norm_cfg)
                    for j in range(depths[i])])
                norm = norm_layer(embed_dims[i])
                cur += depths[i]
                setattr(self, f"patch_embed{i + 1}", patch_embed)
                setattr(self, f"block{i + 1}", block)
                setattr(self, f"norm{i + 1}", norm)
            # record per-stage output channels so parse_model can read them
            self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]

        def forward(self, x):
            B = x.shape[0]
            outs = []
            for i in range(self.num_stages):
                patch_embed = getattr(self, f"patch_embed{i + 1}")
                block = getattr(self, f"block{i + 1}")
                norm = getattr(self, f"norm{i + 1}")
                x, H, W = patch_embed(x)
                for blk in block:
                    x = blk(x)
                # LayerNorm expects (B, N, C); flatten, normalize, restore (B, C, H, W)
                x = x.flatten(2).transpose(1, 2)
                x = norm(x)
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                outs.append(x)
            return outs


    class DWConv(nn.Module):
        def __init__(self, dim=768):
            super(DWConv, self).__init__()
            self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

        def forward(self, x):
            x = self.dwconv(x)
            return x


    def update_weight(model_dict, weight_dict):
        """Copy pretrained weights whose names and shapes match the model."""
        idx, temp_dict = 0, {}
        for k, v in weight_dict.items():
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                idx += 1
        model_dict.update(temp_dict)
        print(f'loading weights... {idx}/{len(model_dict)} items')
        return model_dict


    def lsknet_t(weights=''):
        model = LSKNet(embed_dims=[32, 64, 160, 256], depths=[3, 3, 5, 2], drop_rate=0.1, drop_path_rate=0.1)
        if weights:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['state_dict']))
        return model


    def lsknet_s(weights=''):
        model = LSKNet(embed_dims=[64, 128, 256, 512], depths=[2, 2, 4, 2], drop_rate=0.1, drop_path_rate=0.1)
        if weights:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(weights)['state_dict']))
        return model


    if __name__ == '__main__':
        model = lsknet_t('lsk_t_backbone-2ef8a593.pth')
        inputs = torch.randn((1, 3, 640, 640))
        for i in model(inputs):
            print(i.size())

2. Register the lsknet module in ultralytics\ultralytics\nn\tasks.py.

At the top of the file, import lsknet from the newly created folder:

    from ultralytics.nn.backbone.lsknet import *

The file's _predict_once method must also be swapped for a version that can run the new backbone.

Replace it with:

    def _predict_once(self, x, profile=False, visualize=False):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if hasattr(m, 'backbone'):
                # the backbone returns a list of feature maps; pad it to five
                # entries with None so layer indices stay aligned with the yaml
                x = m(x)
                for _ in range(5 - len(x)):
                    x.insert(0, None)
                for i_idx, i in enumerate(x):
                    y.append(i if i_idx in self.save else None)
                x = x[-1]  # the last output feeds the next layer
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x
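
For intuition on the padding: lsknet_t returns four feature maps, so one None is inserted at the front to make five entries, keeping the yaml's layer indices valid. A quick standalone check (a sketch assuming the lsknet.py file from step 1 is importable):

    import torch
    from ultralytics.nn.backbone.lsknet import lsknet_t

    model = lsknet_t()  # random init; pass a weights path to load pretrained weights
    outs = model(torch.randn(1, 3, 640, 640))
    for o in outs:
        print(o.shape)
    # torch.Size([1, 32, 160, 160])   stride 4
    # torch.Size([1, 64, 80, 80])     stride 8
    # torch.Size([1, 160, 40, 40])    stride 16
    # torch.Size([1, 256, 20, 20])    stride 32
    # _predict_once pads this list with a leading None to five entries,
    # then forwards x = x[-1] (the stride-32 map) to the next layer.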

Then modify the parse_model function in the same file.

Because this replaces YOLOv8's original network structure, parse_model needs additional handling for the swapped-in backbone. The complete modified parse_model is:

    def parse_model(d, ch, verbose=True, warehouse_manager=None):  # model_dict, input_channels(3)
        """Parse a YOLO model.yaml dictionary into a PyTorch model."""
        import ast

        # Args
        max_channels = float('inf')
        nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
        depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
        if scales:
            scale = d.get('scale')
            if not scale:
                scale = tuple(scales.keys())[0]
                LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
            depth, width, max_channels = scales[scale]

        if act:
            Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
            if verbose:
                LOGGER.info(f"{colorstr('activation:')} {act}")  # print

        if verbose:
            LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
        ch = [ch]
        layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
        is_backbone = False
        for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
            try:
                if m == 'node_mode':
                    m = d[m]
                    if len(args) > 0:
                        if args[0] == 'head_channel':
                            args[0] = int(d[args[0]])
                t = m
                m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
            except:
                pass
            for j, a in enumerate(args):
                if isinstance(a, str):
                    with contextlib.suppress(ValueError):
                        try:
                            args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
                        except:
                            args[j] = a

            n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
            if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                     BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
                if args[0] == 'head_channel':
                    args[0] = d[args[0]]
                c1, c2 = ch[f], args[0]
                if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                    c2 = make_divisible(min(c2, max_channels) * width, 8)
                args = [c1, c2, *args[1:]]
                if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
                    args.insert(2, n)  # number of repeats
                    n = 1
            elif m is AIFI:
                args = [ch[f], *args]
            ##### swapped-in backbone: LSKNet #####
            elif m in {lsknet_s, lsknet_t}:
                m = m(*args)  # instantiate the whole backbone here
                c2 = m.channel  # list of per-stage output channels
            elif m in (HGStem, HGBlock):
                c1, cm, c2 = ch[f], args[0], args[1]
                args = [c1, cm, c2, *args[2:]]
                if m is HGBlock:
                    args.insert(4, n)  # number of repeats
                    n = 1
            elif m is nn.BatchNorm2d:
                args = [ch[f]]
            elif m is Concat:
                c2 = sum(ch[x] for x in f)
            elif m in (Detect, Segment, Pose):
                args.append([ch[x] for x in f])
                if m is Segment:
                    args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
                args.insert(1, [ch[x] for x in f])
            else:
                c2 = ch[f]

            if isinstance(c2, list):
                # a list of output channels marks an integrated backbone
                is_backbone = True
                m_ = m
                m_.backbone = True
            else:
                m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
                t = str(m)[8:-2].replace('__main__.', '')  # module type
            m.np = sum(x.numel() for x in m_.parameters())  # number params
            m_.i, m_.f, m_.type = i + 4 if is_backbone else i, f, t  # attach index, 'from' index, type
            if verbose:
                LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}')  # print
            save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
            layers.append(m_)
            if i == 0:
                ch = []
            if isinstance(c2, list):
                ch.extend(c2)
                for _ in range(5 - len(ch)):
                    ch.insert(0, 0)
            else:
                ch.append(c2)
        return nn.Sequential(*layers), sorted(save)
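
To make the new branch concrete, here is a hypothetical trace of what happens when parse_model meets a backbone row such as [-1, 1, lsknet_t, []] (row syntax as used in step 3 below):

    # hypothetical trace of the row [-1, 1, lsknet_t, []] through parse_model
    m = lsknet_t()              # the lsknet branch instantiates the backbone itself
    c2 = m.channel              # [32, 64, 160, 256] -- a list, so is_backbone is set
    ch = []                     # reset because this is row i == 0
    ch.extend(c2)               # [32, 64, 160, 256]
    for _ in range(5 - len(ch)):
        ch.insert(0, 0)         # ch == [0, 32, 64, 160, 256]
    # later rows can now use 'from' indices 2/3/4 for the P3/P4/P5 features,
    # and every subsequent layer index is shifted by +4 (m_.i = i + 4)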

3. Create the yolov8+LSKNet.yaml file:
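
Below is a minimal sketch of what such a yaml could look like, assuming the stock YOLOv8 PAN-FPN head; the head channels, repeats, and layer indices are illustrative, not the author's exact file. The key points are the single lsknet_t backbone row, which expands to layer slots 0-4 per the padding logic above (a leading placeholder plus the four stage outputs at strides 4/8/16/32), and the +4 shift in all subsequent 'from' indices:

    # yolov8+LSKNet.yaml -- hypothetical sketch
    nc: 80  # number of classes

    backbone:
      # this single row occupies layer slots 0-4
      - [-1, 1, lsknet_t, []]  # 0-4
      - [-1, 1, SPPF, [1024, 5]]  # 5

    head:
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]  # 6
      - [[-1, 3], 1, Concat, [1]]  # 7: cat backbone P4 (slot 3)
      - [-1, 3, C2f, [512]]  # 8

      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]  # 9
      - [[-1, 2], 1, Concat, [1]]  # 10: cat backbone P3 (slot 2)
      - [-1, 3, C2f, [256]]  # 11 (P3/8-small)

      - [-1, 1, Conv, [256, 3, 2]]  # 12
      - [[-1, 8], 1, Concat, [1]]  # 13: cat head P4
      - [-1, 3, C2f, [512]]  # 14 (P4/16-medium)

      - [-1, 1, Conv, [512, 3, 2]]  # 15
      - [[-1, 5], 1, Concat, [1]]  # 16: cat head P5
      - [-1, 3, C2f, [1024]]  # 17 (P5/32-large)

      - [[11, 14, 17], 1, Detect, [nc]]  # 18: Detect(P3, P4, P5)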