
Visualizing PyTorch Network Feature Maps

0. Background

In object detection tasks we predict from multi-scale feature maps. The underlying intuition is that shallow feature maps contain rich edge information, which helps localize small objects, while deep feature maps contain rich semantic information, which helps localize and recognize large objects. To better understand what a feature map encodes, we can visualize it and see directly what the neural network has learned. This is also somewhat helpful for analyzing why a network works and for improving it.

1. Functions provided by PyTorch

1.1 register_forward_hook

register_forward_hook attaches a hook function to a specific module, so that the module's input and output feature maps can be inspected during the forward pass.
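As a minimal sketch of the API (the toy model and hook below are illustrative and not part of the VGG example later in this article):

```python
import torch
from torch import nn

# a toy two-layer model used only to illustrate the API
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())

def hook_func(module, input, output):
    # "input" is a tuple of the module's inputs; "output" is the module's output tensor
    print(type(module).__name__, "output shape:", output.shape)

handle = net[0].register_forward_hook(hook_func)  # attach the hook to the conv layer
_ = net(torch.randn(1, 3, 32, 32))                # the forward pass triggers the hook
handle.remove()                                   # detach the hook when it is no longer needed
```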

 1.2 save_image

save_image tiles the individual channels of a feature map into a grid image and saves it directly to disk.
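A minimal sketch of this usage (the random tensor below simply stands in for a real feature map):

```python
import torch
from torchvision.utils import save_image

feat = torch.rand(1, 64, 56, 56)   # fake feature map: (N=1, C=64, H, W)
feat = feat.permute(1, 0, 2, 3)    # -> (64, 1, 56, 56): 64 single-channel images
# tile the 64 channels into an 8x8 grid and write it to disk
save_image(feat, "feature_grid.png", nrow=8, normalize=True)
```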

2. Example: visualizing VGG feature maps

  • Input image: example.jpg
  • Code
```python
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
from torchvision.utils import make_grid, save_image
import os

# model
net = models.vgg16_bn(pretrained=True).cuda()

# image pre-processing
transforms_input = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
fImg = Image.open("example.jpg").convert('RGB')
data = transforms_input(fImg).unsqueeze(0).cuda()

# directory for the saved feature images
FEATURE_FOLDER = "./outputs/features"
if not os.path.exists(FEATURE_FOLDER):
    os.makedirs(FEATURE_FOLDER)  # makedirs also creates the parent "./outputs"

# three global variables used to build the feature image names
feature_list = list()
count = 0
idx = 0

def get_image_path_for_hook(module):
    global count
    image_name = feature_list[count] + ".png"
    count += 1
    image_path = os.path.join(FEATURE_FOLDER, image_name)
    return image_path

def hook_func(module, input, output):
    # called every time the hooked module runs a forward pass
    image_path = get_image_path_for_hook(module)
    data = output.clone().detach()
    global idx
    print(idx, "->", data.shape)
    idx += 1
    # (N, C, H, W) -> (C, N, H, W): each channel becomes a single-channel image
    data = data.permute(1, 0, 2, 3)
    save_image(data, image_path, normalize=False)

# register the hook on every Conv2d layer
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(name)
        feature_list.append(name)
        module.register_forward_hook(hook_func)

out = net(data)
```
  • Output log

The log shows which layers matched the condition and the size of each feature map that was saved: the names of the 13 Conv2d layers are printed while the hooks are registered, and during the forward pass each hook prints an index followed by the shape of its output.

```
features.0
features.3
features.7
features.10
features.14
features.17
features.20
features.24
features.27
features.30
features.34
features.37
features.40
0 -> torch.Size([1, 64, 224, 224])
1 -> torch.Size([1, 64, 224, 224])
2 -> torch.Size([1, 128, 112, 112])
3 -> torch.Size([1, 128, 112, 112])
4 -> torch.Size([1, 256, 56, 56])
5 -> torch.Size([1, 256, 56, 56])
6 -> torch.Size([1, 256, 56, 56])
7 -> torch.Size([1, 512, 28, 28])
8 -> torch.Size([1, 512, 28, 28])
9 -> torch.Size([1, 512, 28, 28])
10 -> torch.Size([1, 512, 14, 14])
11 -> torch.Size([1, 512, 14, 14])
12 -> torch.Size([1, 512, 14, 14])
```

The following two lines list all the modules in the VGG network together with their names; the output can be used to verify that the saved feature maps correspond to the correct layers.

```python
for name, layer in net.named_modules():
    print(name, '->', layer)
```
```
features.0 -> Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.1 -> BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.2 -> ReLU(inplace=True)
features.3 -> Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.4 -> BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.5 -> ReLU(inplace=True)
features.6 -> MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
features.7 -> Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.8 -> BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.9 -> ReLU(inplace=True)
features.10 -> Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.11 -> BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.12 -> ReLU(inplace=True)
features.13 -> MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
features.14 -> Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.15 -> BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.16 -> ReLU(inplace=True)
features.17 -> Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.18 -> BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.19 -> ReLU(inplace=True)
features.20 -> Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.21 -> BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.22 -> ReLU(inplace=True)
features.23 -> MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
features.24 -> Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.25 -> BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.26 -> ReLU(inplace=True)
features.27 -> Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.28 -> BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.29 -> ReLU(inplace=True)
features.30 -> Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.31 -> BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.32 -> ReLU(inplace=True)
features.33 -> MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
features.34 -> Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.35 -> BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.36 -> ReLU(inplace=True)
features.37 -> Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.38 -> BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.39 -> ReLU(inplace=True)
features.40 -> Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
features.41 -> BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
features.42 -> ReLU(inplace=True)
features.43 -> MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
```

Saved feature map results

  • features.3.png
  • features.17.png
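With normalize=False the raw activations are written as-is, so the grids from deeper layers can look almost black. As a hedged variation (not part of the original code), hook_func can rescale each channel independently via make_grid, which is already imported above, before saving:

```python
def hook_func(module, input, output):
    image_path = get_image_path_for_hook(module)
    data = output.clone().detach().permute(1, 0, 2, 3)  # (C, 1, H, W)
    # scale_each=True stretches every channel to [0, 1] on its own for better contrast
    grid = make_grid(data, nrow=8, normalize=True, scale_each=True)
    save_image(grid, image_path)
```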

Through the procedure above we obtain the feature maps produced by the different convolutional layers of the VGG network. By changing the module-matching rule, the outputs of other layer types can be captured as well: registering this article's hook_func on any suitable module is enough to visualize its feature maps, as in the sketch below.
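For example, to also capture the outputs of the ReLU and max-pooling layers (a hedged sketch; the rest of the script stays the same), the matching condition can simply be widened:

```python
# hook Conv2d, ReLU and MaxPool2d outputs instead of only Conv2d
for name, module in net.named_modules():
    if isinstance(module, (nn.Conv2d, nn.ReLU, nn.MaxPool2d)):
        feature_list.append(name)
        module.register_forward_hook(hook_func)
```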

References

Pytorch可视化特征图_吹吹自然风-CSDN博客_可视化特征图
