赞
踩
一、在 ultralytics/nn 中新建 MHSA.py
添加以下代码:
import torch
import torch.nn as nn


class MHSA(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    BoTNet-style block: 1x1 convolutions produce queries, keys and values;
    every flattened spatial position attends to every other position, and the
    output has the same shape as the input ``(batch, n_dims, dim2, dim3)``.

    Args:
        n_dims: Number of input (and output) channels. Must be divisible by
            ``heads``.
        width: Size of the learned positional grid along the input's dim 2
            (only used when ``pos_emb`` is True).
        height: Size of the learned positional grid along the input's dim 3
            (only used when ``pos_emb`` is True).
        heads: Number of attention heads.
        pos_emb: If True, add a learned content-position logit ``q_i . r_j``
            to the content logit ``q_i . k_j``.

    Raises:
        ValueError: If ``n_dims`` is not divisible by ``heads``.
    """

    def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
        super().__init__()
        # Without this check, view(B, heads, C // heads, -1) in forward() can
        # still succeed for some spatial sizes and silently mix channels.
        if n_dims % heads != 0:
            raise ValueError(
                f"n_dims ({n_dims}) must be divisible by heads ({heads})"
            )

        self.heads = heads
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.pos = pos_emb
        if self.pos:
            head_dims = n_dims // heads
            # Factorised positional embedding: a (1, height) row part and a
            # (width, 1) column part, broadcast-summed into a (width, height)
            # grid in forward().
            self.rel_h_weight = nn.Parameter(
                torch.randn([1, heads, head_dims, 1, int(height)]),
                requires_grad=True,
            )
            self.rel_w_weight = nn.Parameter(
                torch.randn([1, heads, head_dims, int(width), 1]),
                requires_grad=True,
            )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        Raises:
            ValueError: If ``pos_emb`` is enabled and ``x`` has more spatial
                positions than the learned grid (``width * height``).
        """
        n_batch, C, dim2, dim3 = x.size()
        head_dims = C // self.heads
        # (B, heads, head_dims, N) where N = dim2 * dim3 flattened positions.
        q = self.query(x).view(n_batch, self.heads, head_dims, -1)
        k = self.key(x).view(n_batch, self.heads, head_dims, -1)
        v = self.value(x).view(n_batch, self.heads, head_dims, -1)
        q_t = q.permute(0, 1, 3, 2)  # (B, heads, N, head_dims)

        # Content logits: energy[..., i, j] = q_i . k_j  -> (B, heads, N, N)
        energy = torch.matmul(q_t, k)
        if self.pos:
            n_pos = energy.size(-1)
            # (1, heads, head_dims, width * height)
            r = (self.rel_h_weight + self.rel_w_weight).view(
                1, self.heads, head_dims, -1
            )
            if r.size(-1) < n_pos:
                raise ValueError(
                    f"input has {n_pos} spatial positions but the positional "
                    f"embedding only covers {r.size(-1)}"
                )
            # Content-position logits q_i . r_j (query i, key position j), as
            # in BoTNet's qk^T + qr^T. A larger learned grid is truncated to
            # its first n_pos flattened positions.
            energy = energy + torch.matmul(q_t, r[..., :n_pos])
        attention = self.softmax(energy)
        # out[..., :, i] = sum_j attention[i, j] * v[..., :, j]
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
        return out.view(n_batch, C, dim2, dim3)

if __name__ == '__main__':
    # Smoke test: a 512-channel 7x7 feature map keeps its shape through MHSA.
    # `x` instead of `input` so the builtin is not shadowed.
    x = torch.randn(50, 512, 7, 7)
    mhsa = MHSA(n_dims=512)
    with torch.no_grad():  # inference only; no autograd graph needed
        output = mhsa(x)
    print(output.shape)

二、在 ultralytics/nn/tasks.py 中导入 MHSA:
from ultralytics.nn.MHSA import MHSA
找到 parse_model 函数,在其按模块类型判断的 if/elif 分支中加入:
- elif m in {MHSA}:
- args=[ch[f],*args]
三、修改配置文件
在 ultralytics/models/v8 中添加配置文件 yolov8n_MHSA.yaml:
- # Ultralytics YOLO
(注:原文此处 yolov8n_MHSA.yaml 的配置内容在转载时被截断。)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/379060
推荐阅读
相关标签
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。