当前位置:   article > 正文

Grad-CAM,即梯度加权类激活映射 (Gradient-weighted Class Activation Mapping)

grad-cam

Grad-CAM,即梯度加权类激活映射 (Gradient-weighted Class Activation Mapping),是一种用于解释卷积神经网络决策的方法。它通过可视化模型对于给定输入的关注区域来提供洞察。

原理:

Grad-CAM的关键思想是:先计算目标类别得分相对于某一卷积层输出特征图的梯度,并对梯度做全局平均池化,得到每个通道的权重;再用这些权重对该层的特征图按通道加权求和,并通过 ReLU 只保留对该类别有正向贡献的区域,得到一个“粗糙”的热力图。这个热力图可以被放大并叠加到原始图像上,以显示模型在分类时最关注的区域。

具体步骤如下:

  1. 选择一个卷积层作为解释的来源。通常,我们会选择网络的最后一个卷积层,因为它既包含了高级特征,也保留了空间信息。
  2. 前向传播图像到网络,得到你想解释的类别的得分。
  3. 计算此得分相对于所选卷积层输出的梯度。
  4. 对于该卷积层的每个通道,以上述梯度在空间维度上的全局平均值作为权重,对该通道的特征图进行加权。
  5. 将加权后的各通道特征图相加并通过 ReLU,结果是一个与卷积层空间维度相同的加权热力图。

优势

Grad-CAM的优点是它可以用于任何卷积神经网络,无需进行结构修改或重新训练。它为我们提供了一个简单但直观的方式来理解模型对于特定输入的决策。

Code

  1. import torch
  2. import cv2
  3. import torch.nn.functional as F
  4. import torchvision.transforms as transforms
  5. import matplotlib.pyplot as plt
  6. from PIL import Image
  class GradCAM:
      """Grad-CAM (Gradient-weighted Class Activation Mapping) for one layer.

      Forward and backward hooks on ``target_layer`` capture the layer's
      output feature maps and the gradient of the class score with respect
      to that output. ``generate_cam`` combines them into a heat map.

      Args:
          model: The network to explain. It is switched to ``eval()`` mode
              by ``generate_cam``.
          target_layer: The module inside ``model`` whose output is
              visualised.
      """

      def __init__(self, model, target_layer):
          self.model = model
          self.target_layer = target_layer
          self.feature_maps = None  # output of target_layer, set by the forward hook
          self.gradients = None  # d(score)/d(output), set by the backward hook
          # Keep the handles so the hooks can be detached again; otherwise
          # every GradCAM instance leaves two hooks on the model forever.
          # register_full_backward_hook is used because the legacy
          # register_backward_hook is deprecated and may report gradients of
          # an inner operation for modules made of several operations.
          self._hook_handles = [
              target_layer.register_forward_hook(self.save_feature_maps),
              target_layer.register_full_backward_hook(self.save_gradients),
          ]

      def remove_hooks(self):
          """Detach the hooks this instance registered on the target layer."""
          for handle in self._hook_handles:
              handle.remove()
          self._hook_handles = []

      def save_feature_maps(self, module, input, output):
          """Forward hook: store a detached reference to the layer output."""
          self.feature_maps = output.detach()

      def save_gradients(self, module, grad_input, grad_output):
          """Backward hook: store the gradient w.r.t. the layer output."""
          self.gradients = grad_output[0].detach()

      def _compute_cam(self, image, class_idx=None):
          """Return the normalised class activation map as a 2-D float array.

          The map has the spatial size of the target layer's output and
          values in [0, 1] (all zeros when nothing contributes positively).
          """
          if image.dim() != 4 or image.size(0) != 1:
              raise ValueError(
                  "GradCAM expects a single image of shape (1, C, H, W), "
                  "got %s" % (tuple(image.shape),)
              )

          self.model.eval()
          # Drop stale captures so a hook that fails to fire is detected.
          self.feature_maps = None
          self.gradients = None

          # Gradients are required even if the caller is inside no_grad().
          with torch.enable_grad():
              output = self.model(image)
              if class_idx is None:
                  class_idx = int(output[0].argmax())
              self.model.zero_grad()
              # Back-propagating the scalar class score is equivalent to the
              # one-hot gradient trick but needs no device-specific tensor.
              output[0, class_idx].backward()

          if self.feature_maps is None or self.gradients is None:
              raise RuntimeError(
                  "GradCAM hooks did not fire; is target_layer part of the "
                  "model's forward pass?"
              )

          # Channel weights: gradient averaged over batch and spatial dims.
          channel_weights = self.gradients.mean(dim=(0, 2, 3))
          # Out-of-place product so the stored feature maps stay untouched.
          weighted = self.feature_maps[0] * channel_weights[:, None, None]
          cam = torch.relu(weighted.mean(dim=0))

          peak = cam.max()
          if peak > 0:  # avoid 0/0 -> NaN when the map is entirely zero
              cam = cam / peak
          return cam.float().cpu().numpy()

      def generate_cam(self, image, class_idx=None):
          """Compute the Grad-CAM heat map and overlay it on the input image.

          Args:
              image: Normalised input tensor of shape (1, 3, H, W), located on
                  the same device as the model.
              class_idx: Index of the class to explain. Defaults to the class
                  with the highest score.

          Returns:
              A tuple ``(heatmap, superimposed_img)`` of ``uint8`` arrays of
              shape (H, W, 3) in RGB channel order: the JET-coloured heat map
              and the heat map blended (weight 0.4) onto the de-normalised
              input image.

          Raises:
              ValueError: If ``image`` is not a single 4-D image tensor.
              RuntimeError: If the hooks captured nothing.
          """
          import numpy as np  # local import: numpy is not imported at file level

          cam = self._compute_cam(image, class_idx)
          # cv2.resize takes (width, height).
          cam = cv2.resize(cam, (image.size(3), image.size(2)))
          heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
          # OpenCV colour maps are BGR; convert so the heat map matches the
          # RGB image it is blended with and displays correctly in matplotlib.
          heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

          original_image = self.unprocess_image(
              image[0].detach().float().cpu().numpy()
          )
          superimposed_img = heatmap * 0.4 + original_image
          superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
          return heatmap, superimposed_img

      def unprocess_image(self, image):
          """Undo the mean/std normalisation and return an (H, W, 3) uint8 image.

          Args:
              image: Array of shape (3, H, W) normalised with the mean and std
                  constants below.

          Returns:
              ``uint8`` array of shape (H, W, 3) with values in [0, 255].
          """
          import numpy as np  # local import: numpy is not imported at file level

          mean = np.array([0.485, 0.456, 0.406])
          std = np.array([0.229, 0.224, 0.225])
          restored = image.transpose(1, 2, 0) * std + mean
          # Clip before the cast: values outside [0, 1] would otherwise wrap
          # around when converted to uint8.
          return (np.clip(restored, 0.0, 1.0) * 255).astype(np.uint8)
  def visualize_gradcam(model, input_image_path, target_layer):
      """Run Grad-CAM on an image file and show the result with matplotlib.

      Args:
          model: The network to explain.
          input_image_path: Path of the image to load.
          target_layer: Module inside ``model`` whose activations are
              visualised.

      The image is resized to 224x224, normalised, and moved to the device
      of the model's parameters. A figure with the heat map and the
      superimposed image is then displayed.
      """
      # Close the file promptly and force three channels so grayscale or
      # RGBA files do not break the 3-channel Normalize below.
      with Image.open(input_image_path) as source:
          img = source.convert('RGB')

      preprocess = transforms.Compose([
          transforms.Resize((224, 224)),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      ])

      # Follow the model's device instead of assuming CUDA is available.
      first_param = next(model.parameters(), None)
      device = first_param.device if first_param is not None else torch.device('cpu')
      input_tensor = preprocess(img).unsqueeze(0).to(device)

      # Create GradCAM
      gradcam = GradCAM(model, target_layer)
      heatmap, result = gradcam.generate_cam(input_tensor)

      plt.figure(figsize=(10, 10))
      plt.subplot(1, 2, 1)
      plt.imshow(heatmap)
      plt.title('Heatmap')
      plt.axis('off')
      plt.subplot(1, 2, 2)
      plt.imshow(result)
      plt.title('Superimposed Image')
      plt.axis('off')
      plt.show()
  78. # Load your model (e.g., resnet20 in this case)
  79. # model = resnet20()
  80. # model.load_state_dict(torch.load("path_to_your_weights.pth"))
  81. # model.to('cuda')
  82. # Visualize GradCAM
  83. # visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

中文注释详细版

  1. import torch
  2. import cv2
  3. import torch.nn.functional as F
  4. import torchvision.transforms as transforms
  5. import matplotlib.pyplot as plt
  6. from PIL import Image
  class GradCAM:
      """Grad-CAM (Gradient-weighted Class Activation Mapping) for one layer.

      Forward and backward hooks on ``target_layer`` capture the layer's
      output feature maps and the gradient of the class score with respect
      to that output. ``generate_cam`` combines them into a heat map.

      Args:
          model: The network to explain. It is switched to ``eval()`` mode
              by ``generate_cam``.
          target_layer: The module inside ``model`` whose output is
              visualised.
      """

      def __init__(self, model, target_layer):
          self.model = model  # model to run Grad-CAM on
          self.target_layer = target_layer  # layer whose features are visualised
          self.feature_maps = None  # output of target_layer, set by the forward hook
          self.gradients = None  # d(score)/d(output), set by the backward hook
          # Keep the handles so the hooks can be detached again; otherwise
          # every GradCAM instance leaves two hooks on the model forever.
          # register_full_backward_hook is used because the legacy
          # register_backward_hook is deprecated and may report gradients of
          # an inner operation for modules made of several operations.
          self._hook_handles = [
              target_layer.register_forward_hook(self.save_feature_maps),
              target_layer.register_full_backward_hook(self.save_gradients),
          ]

      def remove_hooks(self):
          """Detach the hooks this instance registered on the target layer."""
          for handle in self._hook_handles:
              handle.remove()
          self._hook_handles = []

      def save_feature_maps(self, module, input, output):
          """Forward hook: store a detached reference to the layer output."""
          self.feature_maps = output.detach()

      def save_gradients(self, module, grad_input, grad_output):
          """Backward hook: store the gradient w.r.t. the layer output."""
          self.gradients = grad_output[0].detach()

      def _compute_cam(self, image, class_idx=None):
          """Return the normalised class activation map as a 2-D float array.

          The map has the spatial size of the target layer's output and
          values in [0, 1] (all zeros when nothing contributes positively).
          """
          if image.dim() != 4 or image.size(0) != 1:
              raise ValueError(
                  "GradCAM expects a single image of shape (1, C, H, W), "
                  "got %s" % (tuple(image.shape),)
              )

          # Put the model in evaluation mode.
          self.model.eval()
          # Drop stale captures so a hook that fails to fire is detected.
          self.feature_maps = None
          self.gradients = None

          # Gradients are required even if the caller is inside no_grad().
          with torch.enable_grad():
              # Forward pass.
              output = self.model(image)
              if class_idx is None:
                  class_idx = int(output[0].argmax())
              # Clear any accumulated parameter gradients.
              self.model.zero_grad()
              # Back-propagating the scalar class score is equivalent to the
              # one-hot gradient trick but needs no device-specific tensor.
              output[0, class_idx].backward()

          if self.feature_maps is None or self.gradients is None:
              raise RuntimeError(
                  "GradCAM hooks did not fire; is target_layer part of the "
                  "model's forward pass?"
              )

          # Channel weights: gradient averaged over batch and spatial dims.
          channel_weights = self.gradients.mean(dim=(0, 2, 3))
          # Out-of-place product so the stored feature maps stay untouched.
          weighted = self.feature_maps[0] * channel_weights[:, None, None]
          cam = torch.relu(weighted.mean(dim=0))

          peak = cam.max()
          if peak > 0:  # avoid 0/0 -> NaN when the map is entirely zero
              cam = cam / peak
          return cam.float().cpu().numpy()

      def generate_cam(self, image, class_idx=None):
          """Compute the Grad-CAM heat map and overlay it on the input image.

          Args:
              image: Normalised input tensor of shape (1, 3, H, W), located on
                  the same device as the model.
              class_idx: Index of the class to explain. Defaults to the class
                  with the highest score.

          Returns:
              A tuple ``(heatmap, superimposed_img)`` of ``uint8`` arrays of
              shape (H, W, 3) in RGB channel order: the JET-coloured heat map
              and the heat map blended (weight 0.4) onto the de-normalised
              input image.

          Raises:
              ValueError: If ``image`` is not a single 4-D image tensor.
              RuntimeError: If the hooks captured nothing.
          """
          import numpy as np  # local import: numpy is not imported at file level

          cam = self._compute_cam(image, class_idx)
          # cv2.resize takes (width, height).
          cam = cv2.resize(cam, (image.size(3), image.size(2)))
          heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
          # OpenCV colour maps are BGR; convert so the heat map matches the
          # RGB image it is blended with and displays correctly in matplotlib.
          heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

          # Superimpose the heat map on the de-normalised input image.
          original_image = self.unprocess_image(
              image[0].detach().float().cpu().numpy()
          )
          superimposed_img = heatmap * 0.4 + original_image
          superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
          return heatmap, superimposed_img

      def unprocess_image(self, image):
          """Undo the mean/std normalisation and return an (H, W, 3) uint8 image.

          Args:
              image: Array of shape (3, H, W) normalised with the mean and std
                  constants below.

          Returns:
              ``uint8`` array of shape (H, W, 3) with values in [0, 255].
          """
          import numpy as np  # local import: numpy is not imported at file level

          mean = np.array([0.485, 0.456, 0.406])
          std = np.array([0.229, 0.224, 0.225])
          restored = image.transpose(1, 2, 0) * std + mean
          # Clip before the cast: values outside [0, 1] would otherwise wrap
          # around when converted to uint8.
          return (np.clip(restored, 0.0, 1.0) * 255).astype(np.uint8)
  def visualize_gradcam(model, input_image_path, target_layer):
      """Run Grad-CAM on an image file and show the result with matplotlib.

      Args:
          model: The network to explain.
          input_image_path: Path of the image to load.
          target_layer: Module inside ``model`` whose activations are
              visualised.

      The image is resized to 224x224, normalised, and moved to the device
      of the model's parameters. A figure with the heat map and the
      superimposed image is then displayed.
      """
      # Load the image. Close the file promptly and force three channels so
      # grayscale or RGBA files do not break the 3-channel Normalize below.
      with Image.open(input_image_path) as source:
          img = source.convert('RGB')

      preprocess = transforms.Compose([
          transforms.Resize((224, 224)),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      ])

      # Follow the model's device instead of assuming CUDA is available.
      first_param = next(model.parameters(), None)
      device = first_param.device if first_param is not None else torch.device('cpu')
      input_tensor = preprocess(img).unsqueeze(0).to(device)

      # Create GradCAM
      gradcam = GradCAM(model, target_layer)
      heatmap, result = gradcam.generate_cam(input_tensor)

      # Show the heat map and the superimposed image side by side.
      # NOTE(review): the titles are Chinese text; the active matplotlib font
      # may need CJK glyphs for them to render - confirm on the target setup.
      plt.figure(figsize=(10, 10))
      plt.subplot(1, 2, 1)
      plt.imshow(heatmap)
      plt.title('热力图')
      plt.axis('off')
      plt.subplot(1, 2, 2)
      plt.imshow(result)
      plt.title('叠加后的图像')
      plt.axis('off')
      plt.show()
  83. # 以下是示例代码,显示如何使用上述代码。
  84. # 首先,你需要加载你的模型和权重。
  85. # model = resnet20()
  86. # model.load_state_dict(torch.load("path_to_your_weights.pth"))
  87. # model.to('cuda')
  88. # 然后,调用`visualize_gradcam`函数来查看结果。
  89. # visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

论文链接:https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/340933
推荐阅读
相关标签
  

闽ICP备14008679号