
(Plug-and-play code included) An accuracy booster: the inverted residual attention module iRMB, from a new ICCV 2023 paper

Paper: https://arxiv.org/abs/2301.01146

Code: GitHub - zhangzjn/EMO: [ICCV 2023] Official PyTorch implementation of "Rethinking Mobile Block for Efficient Attention-based Models"

Below is the complete plug-and-play module code, followed by how to use it.

Module code. It depends on a few third-party libraries (torch, timm, einops); just install them with pip, e.g. pip install torch timm einops, using a mirror source if downloads are slow:

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# NOTE: the timm import paths below follow the original EMO repo and may require an
# older timm release; newer timm versions moved these layers to timm.layers.
from timm.models.layers.activations import *
from timm.models.layers import DropPath
from timm.models.efficientnet_builder import _parse_ksize
from timm.models.efficientnet_blocks import num_groups, SqueezeExcite as SE


# ========== For Common ==========
class LayerNorm2d(nn.Module):
    """LayerNorm for NCHW feature maps: normalize over channels, keep NCHW layout."""

    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        x = rearrange(x, 'b c h w -> b h w c').contiguous()
        x = self.norm(x)
        x = rearrange(x, 'b h w c -> b c h w').contiguous()
        return x


def get_norm(norm_layer='in_1d'):
    """Look up a normalization layer constructor by its short name."""
    eps = 1e-6
    norm_dict = {
        'none': nn.Identity,
        'in_1d': partial(nn.InstanceNorm1d, eps=eps),
        'in_2d': partial(nn.InstanceNorm2d, eps=eps),
        'in_3d': partial(nn.InstanceNorm3d, eps=eps),
        'bn_1d': partial(nn.BatchNorm1d, eps=eps),
        'bn_2d': partial(nn.BatchNorm2d, eps=eps),
        # 'bn_2d': partial(nn.SyncBatchNorm, eps=eps),
        'bn_3d': partial(nn.BatchNorm3d, eps=eps),
        'gn': partial(nn.GroupNorm, eps=eps),
        'ln_1d': partial(nn.LayerNorm, eps=eps),
        'ln_2d': partial(LayerNorm2d, eps=eps),
    }
    return norm_dict[norm_layer]


def get_act(act_layer='relu'):
    """Look up an activation layer constructor by its short name."""
    act_dict = {
        'none': nn.Identity,
        'sigmoid': Sigmoid,
        'swish': Swish,
        'mish': Mish,
        'hsigmoid': HardSigmoid,
        'hswish': HardSwish,
        'hmish': HardMish,
        'tanh': Tanh,
        'relu': nn.ReLU,
        'relu6': nn.ReLU6,
        'prelu': PReLU,
        'gelu': GELU,
        'silu': nn.SiLU
    }
    return act_dict[act_layer]


class ConvNormAct(nn.Module):
    """Conv2d -> norm -> activation, with an optional DropPath residual connection."""

    def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
                 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
        super(ConvNormAct, self).__init__()
        self.has_skip = skip and dim_in == dim_out
        padding = math.ceil((kernel_size - stride) / 2)
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x


class iRMB(nn.Module):
    """inverted Residual Mobile Block: windowed multi-head self-attention plus depth-wise conv."""

    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
                 act_layer='relu', v_proj=True, dw_ks=3, stride=1, dilation=1, se_ratio=0.0, dim_head=64,
                 window_size=7, attn_s=True, qkv_bias=False, attn_drop=0., drop=0., drop_path=0.,
                 v_group=False, attn_pre=False, inplace=True):
        super().__init__()
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        self.attn_s = attn_s
        if self.attn_s:
            assert dim_in % dim_head == 0, 'dim should be divisible by num_heads'
            self.dim_head = dim_head
            self.window_size = window_size
            self.num_head = dim_in // dim_head
            self.scale = self.dim_head ** -0.5
            self.attn_pre = attn_pre
            self.qk = ConvNormAct(dim_in, int(dim_in * 2), kernel_size=1, bias=qkv_bias,
                                  norm_layer='none', act_layer='none')
            self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, groups=self.num_head if v_group else 1,
                                 bias=qkv_bias, norm_layer='none', act_layer=act_layer, inplace=inplace)
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            if v_proj:
                self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, bias=qkv_bias, norm_layer='none',
                                     act_layer=act_layer, inplace=inplace)
            else:
                self.v = nn.Identity()
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation,
                                      groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace)
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        B, C, H, W = x.shape
        if self.attn_s:
            # pad H and W so they are divisible by the window size
            if self.window_size <= 0:
                window_size_W, window_size_H = W, H
            else:
                window_size_W, window_size_H = self.window_size, self.window_size
            pad_l, pad_t = 0, 0
            pad_r = (window_size_W - W % window_size_W) % window_size_W
            pad_b = (window_size_H - H % window_size_H) % window_size_H
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0,))
            n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
            x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous()
            # window-wise multi-head self-attention
            b, c, h, w = x.shape
            qk = self.qk(x)
            qk = rearrange(qk, 'b (qk heads dim_head) h w -> qk b heads (h w) dim_head', qk=2,
                           heads=self.num_head, dim_head=self.dim_head).contiguous()
            q, k = qk[0], qk[1]
            attn_spa = (q @ k.transpose(-2, -1)) * self.scale
            attn_spa = attn_spa.softmax(dim=-1)
            attn_spa = self.attn_drop(attn_spa)
            if self.attn_pre:
                # attend first, then expand channels with the value projection
                x = rearrange(x, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
                x_spa = attn_spa @ x
                x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head,
                                  h=h, w=w).contiguous()
                x_spa = self.v(x_spa)
            else:
                # project to values first, then attend
                v = self.v(x)
                v = rearrange(v, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
                x_spa = attn_spa @ v
                x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head,
                                  h=h, w=w).contiguous()
            # merge the windows back and remove the padding
            x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous()
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        else:
            x = self.v(x)
        x = x + self.se(self.conv_local(x)) if self.has_skip else self.se(self.conv_local(x))
        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x
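
Because the block is plug-and-play, you can drop it into an ordinary CNN wherever the channel count is divisible by the chosen dim_head (the only constraint enforced by the assert in __init__ when attn_s=True). The toy model below is only a sketch of that idea under my own assumptions: the stem, channel sizes, and classifier head are made up for illustration and are not the EMO architecture from the paper.

import torch
import torch.nn as nn

# Toy backbone with an iRMB inserted between two plain convolutions.
# All sizes here (3 -> 64 -> 128 channels, 10 classes) are illustrative only.
toy_net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),    # stem: 3 -> 64 channels
    iRMB(64, 64, dim_head=32, window_size=7),                # plug-and-play attention block (64 % 32 == 0)
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # downsample
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(128, 10),                                      # toy 10-class head
)

x = torch.randn(2, 3, 224, 224)
print(toy_net(x).shape)   # expected: torch.Size([2, 10])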

How to use the module:

a = torch.ones(1, 64, 20, 20)  # dummy input; note the default dim_head=64, so the input channel count must be a multiple of 64 (this default can be changed)
b = iRMB(64, 20)               # instantiate the block with the input and output channel counts
c = b(a)
print(c.size())                # output shape: (1, 20, 20, 20)
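
If your input's channel count is not a multiple of 64, lower dim_head when constructing the block, as the comment above notes; the requirement baked into the code is only that dim_in is divisible by dim_head when attn_s=True. A minimal sketch (the channel and head sizes here are made up for illustration):

x = torch.ones(1, 32, 20, 20)       # 32 channels: not a multiple of the default dim_head=64
block = iRMB(32, 32, dim_head=16)   # 32 / 16 = 2 attention heads
y = block(x)
print(y.size())                     # expected: torch.Size([1, 32, 20, 20])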
