
Attention Mechanisms Explained, with Code Walkthrough


1. SEBlock (channel attention)

First squeeze along the H*W dimensions: global average pooling reduces each channel to a single value.
(B, C, H, W) → (B, C, 1, 1)

Then exploit the correlations across channels to compute a per-channel weight:
(B, C, 1, 1) → (B, C//K, 1, 1) → (B, C, 1, 1) → sigmoid

Finally, multiply the weights with the original features to obtain the re-weighted feature map.

import torch
import torch.nn as nn

class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        # adaptive global pooling: only the output feature-map size needs to be given
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)        # (B, C, H, W) -> (B, C, 1, 1)
        y_out = self.fc1(y)         # per-channel weights in [0, 1]
        return x * y_out            # broadcast multiply over H and W
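A minimal sanity check of the block; the tensor sizes below are just an illustrative example, not from the original post.

# Hypothetical usage: a (B, C, H, W) feature map is re-weighted channel-wise
se = SELayer(channel=64, reduction=4)
x = torch.randn(2, 64, 32, 32)
out = se(x)
print(out.shape)   # torch.Size([2, 64, 32, 32]) -- same shape, channels re-weighted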

2. CBAM (channel attention + spatial attention)

CBAM contains both a channel attention module and a spatial attention module.
The channel attention is largely the same as in SE, except that global max pooling and global average pooling are run in parallel.

Spatial attention: first apply max pooling and mean pooling along the channel dimension, concatenate the two maps along the channel dimension, fuse them with a convolution, and finally multiply the result with the original features. A sketch of chaining the two modules is given after the spatial-attention code below.

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channel, rate=4):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Sequential(
            nn.Conv2d(channel, channel // rate, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // rate, channel, 1, bias=False)
        )
        self.sig = nn.Sigmoid()

    def forward(self, x):
        avg = self.avg_pool(x)              # (B, C, 1, 1)
        avg_feature = self.fc1(avg)
        max_out = self.max_pool(x)          # (B, C, 1, 1)
        max_feature = self.fc1(max_out)     # shared MLP for both branches
        out = max_feature + avg_feature
        out = self.sig(out)
        return x * out

import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        # (B, C, H, W) -> (B, 1, H, W) x2 -> concat (B, 2, H, W) -> conv (B, 1, H, W)
        self.conv1 = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_f = torch.mean(x, dim=1, keepdim=True)
        max_f = torch.max(x, dim=1, keepdim=True).values
        cat = torch.cat([mean_f, max_f], dim=1)
        out = self.conv1(cat)
        return x * self.sigmoid(out)
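A minimal sketch of wiring the two modules together as in CBAM: channel attention first, then spatial attention. The module name CBAMBlock and the shapes are illustrative assumptions; note the code above uses a 3x3 spatial kernel, whereas the original CBAM paper uses 7x7.

class CBAMBlock(nn.Module):
    # hypothetical composition of the two modules defined above
    def __init__(self, channel, rate=4):
        super(CBAMBlock, self).__init__()
        self.channel_attention = ChannelAttention(channel, rate)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x)   # re-weight channels
        x = self.spatial_attention(x)   # re-weight spatial positions
        return x

# Usage: the shape stays (B, C, H, W) throughout
cbam = CBAMBlock(channel=64)
feat = torch.randn(2, 64, 32, 32)
print(cbam(feat).shape)                 # torch.Size([2, 64, 32, 32])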

3. Attention in the Transformer

Scaled Dot-Product Attention

The inputs to this attention are Q, K, and V.

1. Multiply Q with K (transposed) to get the raw attention scores.

2. Scale the scores (divide by the square root of d_k).

3. Apply softmax to obtain the attention weights.

4. Multiply the weights with V to obtain the output.

import torch
import torch.nn as nn
import numpy as np

class ScaledDotProductAttention(nn.Module):
    def __init__(self, scale):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        u = torch.bmm(q, k.transpose(1, 2))   # (B, n_q, n_k) raw scores
        u = u / self.scale                     # scale
        attn = self.softmax(u)                 # attention weights
        output = torch.bmm(attn, v)            # (B, n_q, d_v)
        return output

# scale = np.power(d_k, 0.5): the scaling factor is the square root of K's feature dimension.
# Q: (B, n_q, d_q), K: (B, n_k, d_k), V: (B, n_v, d_v).
# Q and K must share the same feature dimension; K and V must have the same number of entries (n_k == n_v).
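A small usage sketch with made-up shapes, to show the constraints noted above (d_q == d_k and n_k == n_v):

# Hypothetical shapes: 4 queries attending over 6 key/value pairs
q = torch.randn(2, 4, 32)   # (B, n_q, d_q)
k = torch.randn(2, 6, 32)   # (B, n_k, d_k), d_k == d_q
v = torch.randn(2, 6, 64)   # (B, n_v, d_v), n_v == n_k
attn = ScaledDotProductAttention(scale=np.power(32, 0.5))
print(attn(q, k, v).shape)  # torch.Size([2, 4, 64])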

MultiHeadAttention

Reshape the Q, K, V channel dimension into n * C form, i.e. split it into n heads and run attention on each head independently.

1. Project Q, K, V from single-head to multi-head: channel → n_head * new_channel via a linear layer, then fold the head dimension into the batch dimension.

2. Compute the single-head attention output for every head.

3. Split the dimensions back out and concatenate the heads along the channel dimension.

4. A final linear layer maps to the desired output dimension.

import torch
import torch.nn as nn
import numpy as np

class MultiHeadAttention(nn.Module):
    # uses the ScaledDotProductAttention defined above
    def __init__(self, n_head, d_k, d_k_, d_v, d_v_, d_o):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)
        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
        self.fc_o = nn.Linear(n_head * d_v, d_o)

    def forward(self, q, k, v):
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        batch, n_q, d_q_ = q.size()
        batch, n_k, d_k_ = k.size()
        batch, n_v, d_v_ = v.size()
        # linear projections: (B, n, d_) -> (B, n, n_head * d)
        q = self.fc_q(q)
        k = self.fc_k(k)
        v = self.fc_v(v)
        # split heads and fold them into the batch dimension: (n_head * B, n, d)
        q = q.view(batch, n_q, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_k)
        k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)
        v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, d_v)
        output = self.attention(q, k, v)
        # unfold the heads and concatenate them along the channel dimension: (B, n_q, n_head * d_v)
        output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
        output = self.fc_o(output)
        return output
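A quick shape check with hypothetical dimensions (d_k_/d_v_ are the input feature sizes, d_k/d_v the per-head projection sizes):

# Hypothetical dimensions: 8 heads, 64-dim inputs, 16-dim per-head projections
mha = MultiHeadAttention(n_head=8, d_k=16, d_k_=64, d_v=16, d_v_=64, d_o=64)
q = torch.randn(2, 4, 64)   # (B, n_q, d_k_)
k = torch.randn(2, 6, 64)   # (B, n_k, d_k_)
v = torch.randn(2, 6, 64)   # (B, n_v, d_v_)
print(mha(q, k, v).shape)   # torch.Size([2, 4, 64])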
