当前位置:   article > 正文

注意力机制(一)SE模块(Squeeze-and-Excitation Networks)论文总结和代码实现

se模块

Squeeze-and-Excitation Networks(压缩和激励网络)

论文地址:Squeeze-and-Excitation Networks

论文中文版:Squeeze-and-Excitation Networks_中文版

代码地址:GitHub - hujie-frank/SENet: Squeeze-and-Excitation Networks

 

一、论文出发点

为了提高网络的表示能力,许多现有的工作已经显示出增强空间编码的好处。而作者专注于通道,希望能够提出了一种新的架构单元,通过显式地建模出卷积特征通道之间的相互依赖性来提高网络的表示能力。

这里引用“博文:Squeeze-and-Excitation Networks解读”中的总结:核心思想是不同通道的权重应该自适应分配,由网络自己学习出来的,而不是像Inception net一样留下过多人工干预的痕迹。

 

 

二、论文的主要工作

1.提出了一种新的架构单元Squeeze-and-Excitation模块,该模块可以显式地建模卷积特征通道之间的相互依赖性来提高网络的表示能力。

2.提出了一种机制,使网络能够执行特征重新校准,通过这种机制可以学习使用全局信息来选择性地强调信息特征并抑制不太有用的特征。

 

 

三、Squeeze-and-Excitation模块

3.1 Transformation(Ftr)(转型)

F_{tr}X\rightarrow U,经过F_{tr}特征图X变为特征图U。F_{tr}可以看作一个标准的卷积算子。公式定义如下:

                                                       U_{c}=V_{c}*X=\sum_{s=1}^{C'}V_{c}^{s}*X^{s}

其中,X∈R^(H′×W′×C′)为输入特征图U∈R^(H×W×C):输出特征图,V:表示学习到的一组滤波器核,Vc:指的是第c个滤波器的参数,​V_{c}^{s}表示一个2D的空间核,*表示卷积操作。

该卷积算子公式表示,输入特征图X的每一层都经过一个2D空间核的卷积最终得到C个输出的feature map,组成特征图U。

 

3.2 Squeeze(全局信息嵌入)

Fsq就是使用通道的全局平均池化。

原文中为了解决利用通道依赖性的问题,选择将全局空间信息压缩到一个信道描述符中,即使用通道的全局平均池化,将包含全局信息的W×H×C 的特征图直接压缩成一个1×1×C的特征向量Z,C个feature map的通道特征都被压缩成了一个数值,这样使得生成的通道级统计数据Z就包含了上下文信息,缓解了通道依赖性的问题。定义如下:

其中,Zc为Z的第c个元素。

3.3 Excitation(自适应重新校正)

目的为了利用压缩操作中汇聚的信息,我们接下来通过Excitation操作来全面捕获通道依赖性。
实现方法
为了实现这个目标,这个功能必须符合两个标准
第一,它必须是灵活的 (它必须能够学习通道之间的非线性交互)
第二,它必须学习一个非互斥的关系,因为独热激活相反,这里允许强调多个通道。
为了满足这些标准,作者采用了两层全连接构成的门机制,第一个全连接层把C个通道压缩成了C/r个通道来降低计算量,再通过一个RELU非线性激活层,第二个全连接层将通道数恢复回为C个通道,再通过Sigmoid激活得到权重s,最后得到的这个s的维度是1×1×C,它是用来刻画特征图U中C个feature map的权重。r是指压缩的比例。

为什么这里要有两个FC,并且通道先缩小,再放大?

因为一个全连接层无法同时应用relu和sigmoid两个非线性函数,但是两者又缺一不可。为了减少参数,所以设置了r比率。

3.4 Scale(重新加权)

目的:最后是Scale操作,将前面得到的注意力权重加权到每个通道的特征上

实现方法:特征图U中的每个feature map乘以对应的权重,得到SE模块的最终输出\widetilde{X}

 

 

四、模型:SE-Inception和SE-ResNet

通过将一个整体的Inception模块看作SE模块中F_{tr},为Inception网络构建SE模块。

同理, 将一个整体的Residual模块看作SE模块中F_{tr},为ResNet网络构建SE模块。

五、实验

六、结论

本文提出的SE模块,这是一种新颖的架构单元,旨在通过使网络能够执行动态通道特征重新校准来提高网络的表示能力。大量实验证明了SENets的有效性,其在多个数据集上取得了最先进的性能。

七、源码分析

将SEblock嵌入ResNet的残差模块中。

7.1 SE模块

  1. '''-------------一、SE模块-----------------------------'''
  2. #全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
  3. class SE_Block(nn.Module):
  4. def __init__(self, inchannel, ratio=16):
  5. super(SE_Block, self).__init__()
  6. # 全局平均池化(Fsq操作)
  7. self.gap = nn.AdaptiveAvgPool2d((1, 1))
  8. # 两个全连接层(Fex操作)
  9. self.fc = nn.Sequential(
  10. nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
  11. nn.ReLU(),
  12. nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
  13. nn.Sigmoid()
  14. )
  15. def forward(self, x):
  16. # 读取批数据图片数量及通道数
  17. b, c, h, w = x.size()
  18. # Fsq操作:经池化后输出b*c的矩阵
  19. y = self.gap(x).view(b, c)
  20. # Fex操作:经全连接层输出(b,c,1,1)矩阵
  21. y = self.fc(y).view(b, c, 1, 1)
  22. # Fscale操作:将得到的权重乘以原来的特征图x
  23. return x * y.expand_as(x)

7.2 SE-ResNet完整代码

不同版本的ResNet各层主要是由BasicBlock模块(18-layer、34-layer)Bottleneck模块(50-layer、101-layer、152-layer)构成的。

添加SE模块位置:在BasicBlock模块或Bottleneck模块尾部添加,但是要注意放在shortcut之前。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchsummary import summary
  5. '''-------------SE模块-----------------------------'''
  6. #全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
  7. class SE_Block(nn.Module):
  8. def __init__(self, inchannel, ratio=16):
  9. super(SE_Block, self).__init__()
  10. # 全局平均池化(Fsq操作)
  11. self.gap = nn.AdaptiveAvgPool2d((1, 1))
  12. # 两个全连接层(Fex操作)
  13. self.fc = nn.Sequential(
  14. nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
  15. nn.ReLU(),
  16. nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
  17. nn.Sigmoid()
  18. )
  19. def forward(self, x):
  20. # 读取批数据图片数量及通道数
  21. b, c, h, w = x.size()
  22. # Fsq操作:经池化后输出b*c的矩阵
  23. y = self.gap(x).view(b, c)
  24. # Fex操作:经全连接层输出(b,c,1,1)矩阵
  25. y = self.fc(y).view(b, c, 1, 1)
  26. # Fscale操作:将得到的权重乘以原来的特征图x
  27. return x * y.expand_as(x)
  28. '''-------------(18-layer、34-layer)BasicBlock模块-----------------------------'''
  29. # residual block 结构
  30. class BasicBlock(nn.Module):
  31. expansion = 1
  32. def __init__(self, inchannel, outchannel, stride=1):
  33. super(BasicBlock, self).__init__()
  34. self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=3,
  35. stride=stride, padding=1, bias=False)
  36. self.bn1 = nn.BatchNorm2d(outchannel)
  37. self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
  38. stride=1, padding=1, bias=False)
  39. self.bn2 = nn.BatchNorm2d(outchannel)
  40. # SE_Block放在BN之后,shortcut之前
  41. self.SE = SE_Block(outchannel)
  42. self.shortcut = nn.Sequential()
  43. if stride != 1 or inchannel != self.expansion*outchannel:
  44. self.shortcut = nn.Sequential(
  45. nn.Conv2d(inchannel, self.expansion*outchannel,
  46. kernel_size=1, stride=stride, bias=False),
  47. nn.BatchNorm2d(self.expansion*outchannel)
  48. )
  49. def forward(self, x):
  50. out = F.relu(self.bn1(self.conv1(x)))
  51. out = self.bn2(self.conv2(out))
  52. SE_out = self.SE(out)
  53. out = out * SE_out
  54. out += self.shortcut(x)
  55. out = F.relu(out)
  56. return out
  57. '''-------------(50-layer、101-layer、152-layer)Bottleneck模块-----------------------------'''
  58. # residual block 结构
  59. class Bottleneck(nn.Module):
  60. expansion = 4
  61. def __init__(self, inchannel, outchannel, stride=1):
  62. super(Bottleneck, self).__init__()
  63. self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=1, bias=False)
  64. self.bn1 = nn.BatchNorm2d(outchannel)
  65. self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
  66. stride=stride, padding=1, bias=False)
  67. self.bn2 = nn.BatchNorm2d(outchannel)
  68. self.conv3 = nn.Conv2d(outchannel, self.expansion*outchannel,
  69. kernel_size=1, bias=False)
  70. self.bn3 = nn.BatchNorm2d(self.expansion*outchannel)
  71. # SE_Block放在BN之后,shortcut之前
  72. self.SE = SE_Block(self.expansion*outchannel)
  73. self.shortcut = nn.Sequential()
  74. if stride != 1 or inchannel != self.expansion*outchannel:
  75. self.shortcut = nn.Sequential(
  76. nn.Conv2d(inchannel, self.expansion*outchannel,
  77. kernel_size=1, stride=stride, bias=False),
  78. nn.BatchNorm2d(self.expansion*outchannel)
  79. )
  80. def forward(self, x):
  81. out = F.relu(self.bn1(self.conv1(x)))
  82. out = F.relu(self.bn2(self.conv2(out)))
  83. out = self.bn3(self.conv3(out))
  84. SE_out = self.SE(out)
  85. out = out * SE_out
  86. out += self.shortcut(x)
  87. out = F.relu(out)
  88. return out
  89. '''-------------搭建SE_ResNet结构-----------------------------'''
  90. class SE_ResNet(nn.Module):
  91. def __init__(self, block, num_blocks, num_classes=10):
  92. super(SE_ResNet, self).__init__()
  93. self.in_planes = 64
  94. self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
  95. stride=1, padding=1, bias=False) # conv1
  96. self.bn1 = nn.BatchNorm2d(64)
  97. self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) # conv2_x
  98. self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) # conv3_x
  99. self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) # conv4_x
  100. self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # conv5_x
  101. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  102. self.linear = nn.Linear(512 * block.expansion, num_classes)
  103. def _make_layer(self, block, planes, num_blocks, stride):
  104. strides = [stride] + [1]*(num_blocks-1)
  105. layers = []
  106. for stride in strides:
  107. layers.append(block(self.in_planes, planes, stride))
  108. self.in_planes = planes * block.expansion
  109. return nn.Sequential(*layers)
  110. def forward(self, x):
  111. x = F.relu(self.bn1(self.conv1(x)))
  112. x = self.layer1(x)
  113. x = self.layer2(x)
  114. x = self.layer3(x)
  115. x = self.layer4(x)
  116. x = self.avgpool(x)
  117. x = torch.flatten(x, 1)
  118. out = self.linear(x)
  119. return out
  120. def SE_ResNet18():
  121. return SE_ResNet(BasicBlock, [2, 2, 2, 2])
  122. def SE_ResNet34():
  123. return SE_ResNet(BasicBlock, [3, 4, 6, 3])
  124. def SE_ResNet50():
  125. return SE_ResNet(Bottleneck, [3, 4, 6, 3])
  126. def SE_ResNet101():
  127. return SE_ResNet(Bottleneck, [3, 4, 23, 3])
  128. def SE_ResNet152():
  129. return SE_ResNet(Bottleneck, [3, 8, 36, 3])
  130. '''
  131. if __name__ == '__main__':
  132. model = SE_ResNet50()
  133. print(model)
  134. input = torch.randn(1, 3, 224, 224)
  135. out = model(input)
  136. print(out.shape)
  137. # test()
  138. '''
  139. if __name__ == '__main__':
  140. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  141. net = SE_ResNet50().to(device)
  142. # 打印网络结构和参数
  143. summary(net, (3, 224, 224))

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/290834
推荐阅读
相关标签
  

闽ICP备14008679号