[Paper Notes] MCANet: Medical Image Segmentation with Multi-Scale Cross-Axis Attention


In medical image segmentation, capturing multi-scale information and building long-range dependencies have a large impact on the segmentation result. This paper proposes the Multi-scale Cross-axis Attention (MCA) module, which fuses multi-scale features and uses attention to extract global contextual information.

Paper: MCANet: Medical Image Segmentation with Multi-Scale Cross-Axis Attention

Code: https://github.com/haoshao-nku/medical_seg

1. MCA (Multi-Scale Cross-Axis Attention)

The overall structure of MCA is as follows. The encoder features E2/E3/E4 are concatenated (after first being interpolated to a common resolution) and passed through a 1x1 convolution (which compresses the channel dimension to reduce computation), producing a feature map F that carries multi-scale information. F is then convolved along the x and y directions with kernels of several sizes (for example, a 1x11 convolution runs along the x axis and an 11x1 convolution along the y axis; this is easy to follow against the code). The Q tensors of the two directions are then swapped (this is the "cross-axis" part), attention is applied, the resulting feature maps are summed and fused with E1, and a final convolution produces the output. The module has the following characteristics:

1. The attention mechanism acts on feature maps from multiple different scales;

2. The Multi-Scale x-Axis Convolution and Multi-Scale y-Axis Convolution each focus on features along a different axis, and the attention is computed cross-wise between the two paths, so features along both directions get attended to (see the sketch after this list).
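To make the cross-axis swap concrete, here is a minimal sketch of the row-wise branch in plain PyTorch (the variable names are illustrative, not taken from the paper's code): K and V are built from the x-axis convolution path, while Q is borrowed from the y-axis path.

import torch
import torch.nn.functional as F

# Illustrative shapes; multi-head splitting is omitted for clarity.
B, C, H, W = 1, 8, 16, 16
out_x = torch.randn(B, C, H, W)  # output of the 1xk (x-axis) convolutions
out_y = torch.randn(B, C, H, W)  # output of the kx1 (y-axis) convolutions

# Row-wise attention over H: K/V come from the x-axis path,
# but Q is taken from the y-axis path -- that swap is the "cross-axis" idea.
k = out_x.permute(0, 2, 3, 1).reshape(B, H, W * C)  # b h (w c)
v = out_x.permute(0, 2, 3, 1).reshape(B, H, W * C)
q = out_y.permute(0, 2, 3, 1).reshape(B, H, W * C)

q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, H, H) attention map
row_out = attn @ v + q                            # (B, H, W*C)

Note that the attention map is only H x H (and W x W in the symmetric column-wise branch) rather than (HW) x (HW), which is what keeps the cost well below full self-attention on the feature map.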

The details of MCA are shown in the figure below. The input feature map enters x- and y-direction paths, is convolved with kernels of different sizes and fused within each path, attention is then computed cross-axis (the Q tensors of the x and y paths are swapped), and the output feature map is obtained.
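In symbols, matching the code in the next section (normalize denotes L2 normalization along the feature dimension, and subscripts mark which convolution path each tensor comes from), the two branches compute per head:

out_row = softmax( normalize(Q_y) · normalize(K_x)^T ) · V_x + normalize(Q_y)
out_col = softmax( normalize(Q_x) · normalize(K_y)^T ) · V_y + normalize(Q_x)

The two branch outputs are then projected, summed, and added to the input as a residual.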

2. Code

The code for MCA is shown below; overall it is fairly simple:

import numbers

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS

from ..utils import resize


def to_3d(x):
    # (B, C, H, W) -> (B, H*W, C), so LayerNorm can act on the channel dim.
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    # (B, H*W, C) -> (B, C, H, W).
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        # Scale by the standard deviation only: no mean subtraction, no bias.
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        # Standard LayerNorm: center, scale by std, then affine transform.
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        # Flatten the spatial dims, normalize over channels, restore 4D shape.
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


class Attention(nn.Module):

    def __init__(self, dim, num_heads, LayerNorm_type):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head temperature (defined but not used in forward).
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1)

        # Multi-scale depthwise strip convolutions: 1xk kernels run along the
        # x axis, kx1 kernels along the y axis, with k in {7, 11, 21}.
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)

    def forward(self, x):
        b, c, h, w = x.shape
        x1 = self.norm1(x)

        # Three kernel sizes per axis.
        attn_00 = self.conv0_1(x1)  # 1x7,  x axis
        attn_01 = self.conv0_2(x1)  # 7x1,  y axis
        attn_10 = self.conv1_1(x1)  # 1x11, x axis
        attn_11 = self.conv1_2(x1)  # 11x1, y axis
        attn_20 = self.conv2_1(x1)  # 1x21, x axis
        attn_21 = self.conv2_2(x1)  # 21x1, y axis

        # Sum the scales within each axis path, then project.
        out1 = attn_00 + attn_10 + attn_20  # x-axis features
        out2 = attn_01 + attn_11 + attn_21  # y-axis features
        out1 = self.project_out(out1)
        out2 = self.project_out(out2)

        # Row-wise branch: K/V from the x-axis path;
        # column-wise branch: K/V from the y-axis path.
        k1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        v1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        k2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        v2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        # Each Q is taken from the *other* path: this swap is the cross-axis part.
        q2 = rearrange(out1, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        q1 = rearrange(out2, 'b (head c) h w -> b head h (w c)', head=self.num_heads)

        q1 = F.normalize(q1, dim=-1)
        q2 = F.normalize(q2, dim=-1)
        k1 = F.normalize(k1, dim=-1)
        k2 = F.normalize(k2, dim=-1)

        # H x H attention map per head (mixes information across rows).
        attn1 = (q1 @ k1.transpose(-2, -1))
        attn1 = attn1.softmax(dim=-1)
        out3 = (attn1 @ v1) + q1

        # W x W attention map per head (mixes information across columns).
        attn2 = (q2 @ k2.transpose(-2, -1))
        attn2 = attn2.softmax(dim=-1)
        out4 = (attn2 @ v2) + q2

        out3 = rearrange(out3, 'b head h (w c) -> b (head c) h w',
                         head=self.num_heads, h=h, w=w)
        out4 = rearrange(out4, 'b head w (h c) -> b (head c) h w',
                         head=self.num_heads, h=h, w=w)

        # Project both branches and add the residual input.
        out = self.project_out(out3) + self.project_out(out4) + x
        return out


@MODELS.register_module()
class MCAHead(BaseDecodeHead):

    def __init__(self, in_channels, image_size, heads, c1_channels, **kwargs):
        # Note: c1_channels is accepted but not used in this implementation.
        super(MCAHead, self).__init__(
            in_channels, input_transform='multiple_select', **kwargs)
        self.image_size = image_size
        self.decoder_level = Attention(
            in_channels[1], heads, LayerNorm_type='WithBias')
        # 1x1 conv that aligns channels before the classifier.
        self.align = ConvModule(
            in_channels[3],
            in_channels[0],
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # 1x1 conv that squeezes the concatenated E2/E3/E4 channels.
        self.squeeze = ConvModule(
            sum((in_channels[1], in_channels[2], in_channels[3])),
            in_channels[1],
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Two depthwise-separable 3x3 convs that fuse in the low-level E1 features.
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                in_channels[1] + in_channels[0],
                in_channels[3],
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                in_channels[3],
                in_channels[3],
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        # Bring all encoder levels (E1..E4) to the same resolution.
        inputs = [
            resize(
                level,
                size=self.image_size,
                mode='bilinear',
                align_corners=self.align_corners) for level in inputs
        ]
        y1 = torch.cat([inputs[1], inputs[2], inputs[3]], dim=1)  # concat E2/E3/E4
        x = self.squeeze(y1)                  # 1x1 conv -> multi-scale feature map F
        x = self.decoder_level(x)             # multi-scale cross-axis attention
        x = torch.cat([x, inputs[0]], dim=1)  # fuse the low-level E1 features
        x = self.sep_bottleneck(x)
        output = self.align(x)
        output = self.cls_seg(output)
        return output
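As a quick sanity check, the Attention block can be exercised on its own, since it only depends on torch and einops; MCAHead itself has to be built through an mmsegmentation config, so it is not instantiated here. The shapes below are illustrative (dim must be divisible by num_heads):

import torch

# Smoke test for the Attention block; the output shape matches the input.
attn = Attention(dim=64, num_heads=8, LayerNorm_type='WithBias')
x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
out = attn(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])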