当前位置:   article > 正文

YOLOv8 添加注意力机制 MHSA：在 YOLOv8 骨干网络的池化金字塔（SPPF）前加入注意力机制

本文介绍如何在 YOLOv8 骨干网络的池化金字塔（SPPF）之前加入多头自注意力（MHSA）模块。

一、在 ultralytics/nn 目录下新建 MHSA.py

在该文件中添加以下代码：

  1. import torch
  2. import torch.nn as nn
  3. class MHSA(nn.Module):
  4. def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
  5. super(MHSA, self).__init__()
  6. self.heads = heads
  7. self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
  8. self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
  9. self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
  10. self.pos = pos_emb
  11. if self.pos:
  12. self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]),
  13. requires_grad=True)
  14. self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]),
  15. requires_grad=True)
  16. self.softmax = nn.Softmax(dim=-1)
  17. def forward(self, x):
  18. n_batch, C, width, height = x.size()
  19. q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
  20. k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
  21. v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
  22. content_content = torch.matmul(q.permute(0, 1, 3, 2), k) # 1,C,h*w,h*w
  23. c1, c2, c3, c4 = content_content.size()
  24. if self.pos:
  25. content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute(
  26. 0, 1, 3, 2) # 1,4,1024,64
  27. content_position = torch.matmul(content_position, q) # ([1, 4, 1024, 256])
  28. content_position = content_position if (
  29. content_content.shape == content_position.shape) else content_position[:, :, :c3, ]
  30. assert (content_content.shape == content_position.shape)
  31. energy = content_content + content_position
  32. else:
  33. energy = content_content
  34. attention = self.softmax(energy)
  35. out = torch.matmul(v, attention.permute(0, 1, 3, 2)) # 1,4,256,64
  36. out = out.view(n_batch, C, width, height)
  37. return out
  38. if __name__ == '__main__':
  39. input = torch.randn(50, 512, 7, 7)
  40. mhsa = MHSA(n_dims=512)
  41. output = mhsa(input)
  42. print(output.shape)

二、在 ultralytics/nn/tasks.py 中导入 MHSA

from ultralytics.nn.MHSA import MHSA

找到 parse_model 函数，在其判断模块类型的 elif 分支中加入以下代码（把输入通道数 ch[f] 作为第一个参数 n_dims 传给 MHSA）：

  1. elif m in {MHSA}:
  2. args=[ch[f],*args]

 

三、修改模型配置文件

在 ultralytics/models/v8 目录下添加配置文件 yolov8n_MHSA.yaml：

  1. # Ultralytics YOLO
    声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/379060
    推荐阅读
    相关标签