当前位置:   article > 正文

Win10+PyCharm利用MMSegmentation训练自己的数据集

from mmengine.config import Config, DictAction

系统版本:Windows 10 企业版

依赖环境:Anaconda3

运行软件:PyCharm

MMSegmentation版本:V1.1.1

前提:运行环境已经配置好,环境的配置可以参考:Win10系统下MMSegmentation的环境配置-CSDN博客

目录

1. 从官网下载相应的MMSegmentation

2. 定义数据集类

3. 注册数据集类

4. 配置数据处理pipeline文件

5. 配置Config文件

6. 训练模型

7. 测试模型


1. 从官网下载相应的MMSegmentation

从官网下载对应的版本:

这里可以看到不同的版本。本教程使用的是最新的V1.1.1版本。

把下载好的版本当成一个工程项目,直接在Pycharm中打开。

2. 定义数据集类

在mmseg/datasets中新建一个以自己的数据集命名的py文件,例如,我新建了一个名为cag.py的文件。

在cag.py中为自己的数据集创建一个新的类:

  1. # Copyright (c) Lisa. All rights reserved.
  2. """
  3. @Title: 创建一个自己的数据集
  4. @Author: Lisa
  5. @Date: 2023/09/13
  6. """
  7. import mmengine.fileio as fileio
  8. from mmseg.registry import DATASETS
  9. from .basesegdataset import BaseSegDataset
  10. @DATASETS.register_module()
  11. class CAGDataset(BaseSegDataset):
  12. """CAG dataset.
  13. In segmentation map annotation for CAG, 0 stands for background, which is
  14. included in 2 categories. ``reduce_zero_label`` is fixed to False. The
  15. ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
  16. '.png'.
  17. """
  18. # 类别和对应的RGB配色
  19. METAINFO = dict(
  20. classes=('background', 'vessel'), # 类别标签名称
  21. palette=[[0, 0, 0], [255, 255, 255]]) # 类别标签上色的RGB颜色
  22. # 指定图像扩展名,标注扩展名
  23. def __init__(self,
  24. img_suffix='.png', # 输入image的后缀名为’.png‘
  25. seg_map_suffix='.png', # 输入mask/label的后缀名为’.png‘
  26. reduce_zero_label=False,
  27. **kwargs) -> None:
  28. super().__init__(
  29. img_suffix=img_suffix,
  30. seg_map_suffix=seg_map_suffix,
  31. reduce_zero_label=reduce_zero_label,
  32. **kwargs)

3. 注册数据集类

打开mmseg/datasets/__init__.py文件

导入定义好的CAGDataset类,然后添加到__all__中。

4. 配置数据处理pipeline文件

在configs/_base_/datasets中创建一个数据处理pipeline的py文件。

cag_pipeline.py中数据处理pipeline如下:

  1. # 数据处理pipeline
  2. # 参照同济张子豪
  3. # dataset settings 设置数据集路径
  4. dataset_type = 'CAGDataset' # must be the same name of custom dataset. 必须和自定义数据集名称完全一致。
  5. data_root = '../data/CAG' # 数据集根目录, 后续所有的pipeline使用的目录都会在此目录下的子目录读取
  6. # img_scale = (2336, 3504)
  7. # img_scale = (512, 512)
  8. # crop_size = (256, 256)
  9. # 输入模型的图像裁剪尺寸,一般是128的倍数。
  10. crop_size = (512, 512)
  11. # 训练预处理
  12. train_pipeline = [
  13. dict(type='LoadImageFromFile'),
  14. dict(type='LoadAnnotations'),
  15. dict(
  16. type='RandomResize',
  17. scale=(2048, 1024),
  18. ratio_range=(0.5, 2.0),
  19. keep_ratio=True),
  20. # dict(type='Resize', img_scale=img_scale),
  21. dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
  22. dict(type='RandomFlip', prob=0.5),
  23. dict(type='PhotoMetricDistortion'),
  24. dict(type='PackSegInputs')
  25. ]
  26. # 测试 预处理
  27. test_pipeline = [
  28. dict(type='LoadImageFromFile'),
  29. dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
  30. # add loading annotation after ``Resize`` because ground truth
  31. # does not need to do resize data transform
  32. dict(type='LoadAnnotations'),
  33. dict(type='PackSegInputs')
  34. ]
  35. # TTA 后处理
  36. img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
  37. tta_pipeline = [
  38. # dict(type='LoadImageFromFile', backend_args=None),
  39. dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
  40. dict(
  41. type='TestTimeAug',
  42. transforms=[
  43. [
  44. dict(type='Resize', scale_factor=r, keep_ratio=True)
  45. for r in img_ratios
  46. ],
  47. [
  48. dict(type='RandomFlip', prob=0., direction='horizontal'),
  49. dict(type='RandomFlip', prob=1., direction='horizontal')
  50. ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
  51. ])
  52. ]
  53. # 训练 Dataloader
  54. train_dataloader = dict(
  55. batch_size=4, # 4
  56. num_workers=2, # 4 dataloader的线程数目,一般设为2, 4, 8,根据CPU核数确定,或使用os.cpu_count()函数代替,一般num_workers>=4速度提升就不再明显。
  57. persistent_workers=True, # 一种加速图片加载的操作
  58. sampler=dict(type='InfiniteSampler', shuffle=True), # shuffle=True是打乱图片
  59. dataset=dict(
  60. type=dataset_type,
  61. data_root=data_root,
  62. data_prefix=dict(
  63. img_path='images/training',
  64. seg_map_path='annotations/training'),
  65. # ann_file='splits/train.txt',
  66. pipeline=train_pipeline
  67. )
  68. )
  69. # 测试DataLoader
  70. val_dataloader = dict(
  71. batch_size=1,
  72. num_workers=4, # 4
  73. persistent_workers=True,
  74. sampler=dict(type='DefaultSampler', shuffle=False),
  75. dataset=dict(
  76. type=dataset_type,
  77. data_root=data_root,
  78. data_prefix=dict(
  79. img_path='images/validation',
  80. seg_map_path='annotations/validation'),
  81. # ann_file='splits/val.txt',
  82. pipeline=test_pipeline
  83. ))
  84. # 验证DataLoader
  85. test_dataloader = val_dataloader
  86. # 验证 Evaluator
  87. val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore']) # 分割指标评估器
  88. test_evaluator = val_evaluator

5. 配置Config文件

以UNet为例,配置Config文件

  1. """
  2. create_Unet_config.py
  3. 配置Unet的Config文件
  4. """
  5. from mmengine import Config
  6. cfg = Config.fromfile('../configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py ')
  7. dataset_cfg = Config.fromfile('../configs/_base_/datasets/cag_pipeline.py')
  8. cfg.merge_from_dict(dataset_cfg)
  9. # 修改Config配置文件
  10. # 类别个数
  11. NUM_CLASS = 2
  12. cfg.crop_size = (512, 512)
  13. cfg.model.data_preprocessor.size = cfg.crop_size
  14. cfg.model.data_preprocessor.test_cfg = dict(size_divisor=128)
  15. # 单卡训练时,需要把SyncBN改为BN
  16. cfg.norm_cfg = dict(type='BN', requires_grad=True)
  17. cfg.model.backbone.norm_cfg = cfg.norm_cfg
  18. cfg.model.decode_head.norm_cfg = cfg.norm_cfg
  19. cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
  20. # 模型decode/auxiliary输出头,指定为类别个数
  21. cfg.model.decode_head.num_classes = NUM_CLASS
  22. cfg.model.auxiliary_head.num_classes = NUM_CLASS
  23. # 训练Batch Size
  24. cfg.train_dataloader.batch_size = 4
  25. # 结果保存目录
  26. cfg.work_dir = '../work_dirs/my_Unet'
  27. # 模型保存与日志记录
  28. cfg.train_cfg.max_iters = 60000 # 训练迭代次数
  29. cfg.train_cfg.val_interval = 500 # 评估模型间隔 500
  30. cfg.default_hooks.logger.interval = 100 # 日志记录间隔
  31. cfg.default_hooks.checkpoint.interval = 500 # 模型权重保存间隔 2500
  32. cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重 1, 2
  33. cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重
  34. # 随机数种子
  35. cfg['randomness'] = dict(seed=0)
  36. # 查看完整的Config配置文件
  37. print(cfg.pretty_text)
  38. # 保存最终的config配置文件
  39. cfg.dump('../my_Configs/my_Unet_20230913.py')

6. 训练模型

在PyCharm中运行my_train.py主要有两种方式,

方式一:在右上方的my_train处单击“Edit Configurations...”

在Parameters:一栏中输入命令“ ../my_Configs/my_PSPNet_20230906.py  ../work_dirs/my_PSPNet”,然后点击“运行”。

方式二: 改写超参数

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """
  3. 参考train.py改写
  4. """
  5. import argparse
  6. import logging
  7. import os
  8. import os.path as osp
  9. import torch.backends.cudnn
  10. from mmengine.config import Config, DictAction
  11. from mmengine.logging import print_log
  12. from mmengine.runner import Runner
  13. from mmseg.registry import RUNNERS
  14. os.environ['CUDA_LAUNCH_BLOCKING'] = '1' #(上面报错的最后一行的提示信息)
  15. # torch.backends.cudnn.benchmark = True
  16. # 用于解决报错:RuntimeError: CUDA error: an illegal memory access was encountered
  17. # CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
  18. # For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
  19. def parse_args():
  20. parser = argparse.ArgumentParser(description='Train a segmentor')
  21. parser.add_argument('--config', default='../my_Configs/my_PSPNet_20230906.py', help='train config file path') # 不加--号是表示该参数是必须的
  22. # parser.add_argument('--config', default='../configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-80k_cag-512x512_my.py',
  23. # help='train config file path') # 不加--号是表示该参数是必须的
  24. # parser.add_argument('--config', default='../my_Configs/my_KNet_20230830.py',
  25. # help='train config file path') # 不加--号是表示该参数是必须的
  26. # parser.add_argument('--work-dir', default='work_dir', help='the dir to save logs and models')
  27. parser.add_argument('--work-dir', default='../work_dirs/my_PSPNet', help='the dir to save logs and models')
  28. parser.add_argument(
  29. '--resume',
  30. action='store_true',
  31. default=True,
  32. help='resume from the latest checkpoint in the work_dir automatically')
  33. parser.add_argument(
  34. '--amp',
  35. action='store_true',
  36. default=False,
  37. help='enable automatic-mixed-precision training')
  38. parser.add_argument(
  39. '--cfg-options',
  40. nargs='+',
  41. action=DictAction,
  42. help='override some settings in the used config, the key-value pair '
  43. 'in xxx=yyy format will be merged into config file. If the value to '
  44. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  45. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  46. 'Note that the quotation marks are necessary and that no white space '
  47. 'is allowed.')
  48. parser.add_argument(
  49. '--launcher',
  50. choices=['none', 'pytorch', 'slurm', 'mpi'],
  51. default='none',
  52. help='job launcher')
  53. # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
  54. # will pass the `--local-rank` parameter to `tools/train.py` instead
  55. # of `--local_rank`.
  56. parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
  57. args = parser.parse_args()
  58. if 'LOCAL_RANK' not in os.environ:
  59. os.environ['LOCAL_RANK'] = str(args.local_rank)
  60. return args
  61. def main():
  62. args = parse_args()
  63. # load config
  64. cfg = Config.fromfile(args.config)
  65. cfg.launcher = args.launcher
  66. if args.cfg_options is not None:
  67. cfg.merge_from_dict(args.cfg_options)
  68. # work_dir is determined in this priority: CLI > segment in file > filename
  69. if args.work_dir is not None:
  70. # update configs according to CLI args if args.work_dir is not None
  71. cfg.work_dir = args.work_dir
  72. elif cfg.get('work_dir', None) is None:
  73. # use config filename as default work_dir if cfg.work_dir is None
  74. cfg.work_dir = osp.join('./work_dirs',
  75. osp.splitext(osp.basename(args.config))[0])
  76. # enable automatic-mixed-precision training
  77. if args.amp is True:
  78. optim_wrapper = cfg.optim_wrapper.type
  79. if optim_wrapper == 'AmpOptimWrapper':
  80. print_log(
  81. 'AMP training is already enabled in your config.',
  82. logger='current',
  83. level=logging.WARNING)
  84. else:
  85. assert optim_wrapper == 'OptimWrapper', (
  86. '`--amp` is only supported when the optimizer wrapper type is '
  87. f'`OptimWrapper` but got {optim_wrapper}.')
  88. cfg.optim_wrapper.type = 'AmpOptimWrapper'
  89. cfg.optim_wrapper.loss_scale = 'dynamic'
  90. # resume training
  91. cfg.resume = args.resume
  92. # build the runner from config
  93. if 'runner_type' not in cfg:
  94. # build the default runner
  95. runner = Runner.from_cfg(cfg)
  96. else:
  97. # build customized runner from the registry
  98. # if 'runner_type' is set in the cfg
  99. runner = RUNNERS.build(cfg)
  100. # start training
  101. runner.train()
  102. if __name__ == '__main__':
  103. main()

7. 测试模型

 单张测试的话,可以参考以下代码:

  1. """
  2. single_image_predict.py
  3. 用训练得到的模型进行预测-单张图像
  4. """
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from mmseg.apis import init_model, inference_model, show_result_pyplot
  8. import cv2
  9. # 载入模型 KNet
  10. # 模型 config 配置文件
  11. config_file = '../work_dirs/my_FastSCNN/my_FastSCNN_20230911.py'
  12. # 模型checkpoint权重文件
  13. checkpoint_file = '../work_dirs/my_FastSCNN/best_mIoU_iter_50500.pth'
  14. device = 'cuda:0'
  15. model = init_model(config_file, checkpoint_file, device=device)
  16. # 载入测试集图像或新图像
  17. img_path = '../data/.../xxx.png'
  18. img_bgr = cv2.imread(img_path)
  19. # 显示原图
  20. plt.figure(figsize=(8, 8))
  21. plt.imshow(img_bgr[:, :, ::-1])
  22. plt.show()
  23. # 语义分割预测
  24. result = inference_model(model, img_bgr)
  25. print(result.keys())
  26. pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
  27. print(pred_mask.shape)
  28. print(np.unique(pred_mask))
  29. # #****** 可视化语义分割预测结果——方法一(直接显示分割结果)******#
  30. # 定性
  31. plt.figure(figsize=(8, 8))
  32. plt.imshow(pred_mask)
  33. plt.savefig('test_result/k1-0.jpg')
  34. plt.show()
  35. # 定量
  36. print(result.seg_logits.data.shape)
  37. # #****** 可视化语义分割预测结果--方法二(叠加在原因上进行显示)******#
  38. # 显示语义分割结果
  39. plt.figure(figsize=(10, 8))
  40. plt.imshow(img_bgr[:, :, ::-1])
  41. plt.imshow(pred_mask, alpha=0.55)
  42. plt.axis('off')
  43. plt.savefig('test_result/k1-1.jpg')
  44. plt.show()
  45. # #****** 可视化语义分割预测结果--方法三(和原图并排显示) ******#
  46. plt.figure(figsize=(14, 8))
  47. plt.subplot(1, 2, 1)
  48. plt.imshow(img_bgr[:, :, ::-1])
  49. plt.axis('off')
  50. plt.subplot(1, 2, 2)
  51. plt.imshow(img_bgr[:, :, ::-1])
  52. plt.imshow(pred_mask, alpha=0.6)
  53. plt.axis('off')
  54. plt.savefig('test_result/k1-2.jpg')
  55. plt.show()
  56. # #****** 可视化语义分割预测结果-方法四(按配色方案叠加在原图上显示) ******#
  57. # 各类别的配色方案(BGR)
  58. palette = [
  59. ['background', [127, 127, 127]],
  60. ['vessel', [0, 0, 200]]
  61. ]
  62. palette_dict = {}
  63. for idx, each in enumerate(palette):
  64. palette_dict[idx] = each[1]
  65. print('palette_dict:', palette_dict)
  66. opacity = 0.3 # 透明度,越大越接近原图
  67. # 将预测的整数ID,映射为对应类别的颜色
  68. pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
  69. for idx in palette_dict.keys():
  70. pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
  71. pred_mask_bgr = pred_mask_bgr.astype('uint8')
  72. # 将语义分割预测图和原图叠加显示
  73. pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1-opacity, 0)
  74. cv2.imwrite('test_result/k1-3.jpg', pred_viz)
  75. plt.figure(figsize=(8, 8))
  76. plt.imshow(pred_viz[:, :, ::-1])
  77. plt.show()
  78. # #***** 可视化语义分割预测结果-方法五(按mmseg/datasets/cag.py里定义的类别颜色可视化) ***** #
  79. img_viz = show_result_pyplot(model, img_path, result, opacity=0.8, title='MMSeg', out_file='test_result/k1-4.jpg')
  80. print('the shape of img_viz:', img_viz.shape)
  81. plt.figure(figsize=(14, 8))
  82. plt.imshow(img_viz)
  83. plt.show()
  84. # #***** 可视化语义分割预测结果--方法六(加图例) ***** #
  85. from mmseg.datasets import CAGDataset
  86. import numpy as np
  87. import mmcv
  88. from PIL import Image
  89. # 获取类别名和调色板
  90. classes = CAGDataset.METAINFO['classes']
  91. palette = CAGDataset.METAINFO['palette']
  92. opacity = 0.15
  93. # 将分割图按调色板染色
  94. seg_map = pred_mask.astype('uint8')
  95. seg_img = Image.fromarray(seg_map).convert('P')
  96. seg_img.putpalette(np.array(palette, dtype=np.uint8))
  97. from matplotlib import pyplot as plt
  98. import matplotlib.patches as mpatches
  99. plt.figure(figsize=(14, 8))
  100. img_plot = ((np.array(seg_img.convert('RGB')))*(1-opacity) + mmcv.imread(img_path)*opacity) / 255
  101. im = plt.imshow(img_plot)
  102. # 为每一种颜色创建一个图例
  103. patches = [mpatches.Patch(color=np.array(palette[i])/255, label=classes[i]) for i in range(len(classes))]
  104. plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large')
  105. plt.savefig('test_result/k1-6.jpg')
  106. plt.show()

单张测试或批量测试,也可以参考以下代码:

  1. """
  2. @title: 根据配置(Config)文件和已训练好的参数(pth)进行推理
  3. @Date: 2023/09/14
  4. """
  5. # 导入必要的库
  6. import os
  7. import numpy as np
  8. import cv2
  9. from tqdm import tqdm
  10. from mmseg.apis import init_model, inference_model, show_result_pyplot
  11. import matplotlib.pyplot as plt
  12. import matplotlib.patches as mpatches
  13. from mmseg.datasets import CAGDataset
  14. import numpy as np
  15. import mmcv
  16. from PIL import Image
  17. def predict_single_img(img_path, model, save=False, save_path=None, save_model=1, show=False, show_model=1):
  18. # 读取图像
  19. img_bgr = cv2.imread(img_path)
  20. # 语义分割预测
  21. result = inference_model(model, img_bgr)
  22. pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
  23. # print(pred_mask.shape)
  24. # print(np.unique(pred_mask))
  25. # 将预测的整数ID, 映射为对应类别的颜色
  26. pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
  27. for idx in palette_dict.keys():
  28. pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
  29. pred_mask_bgr = pred_mask_bgr.astype('uint8')
  30. # 保存分割结果
  31. if save:
  32. # 保存方式一: 直接保存\显示分割(预测)结果
  33. if save_model == 1:
  34. save_image_path = os.path.join(save_path, 'pred-' + img_path.split('\\')[-1])
  35. cv2.imwrite(save_image_path, pred_mask_bgr)
  36. # 保存方式二: 将分割(预测)结果和原图叠加保存
  37. elif save_model == 2:
  38. opacity = 0.7 # 透明度,取值范围(0,1),越大越接近原图
  39. pred_add = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1-opacity, 0)
  40. save_image_path = os.path.join(save_path, 'pred-' + img_path.split('/')[-1])
  41. cv2.imwrite(save_image_path, pred_add)
  42. # 保存方式三:检测分割结果的边界轮廓,然后叠加到原图上
  43. elif save_model == 3:
  44. # 预测图转为灰度图
  45. binary = np.where(0.5 < pred_mask, 1, 0).astype(dtype=np.uint8)
  46. binary = binary * 255
  47. # 检测mask的边界
  48. contours, hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  49. # 绘制边界轮廓
  50. # cv2.drawContours(image, contours, -1, (255, 0, 0), 1) # 蓝色 # (255, 0, 0) 蓝色; (0, 255, 0)绿色; (0, 0, 255)蓝色
  51. cv2.drawContours(img_bgr, contours, -1, (0, 255, 0), 1) # 绿色
  52. # 保存
  53. # save_image_path = os.path.join(save_path, 'pred-' + img_path.split('\\')[-1])
  54. save_image_path = os.path.join(save_path, img_path.split('\\')[-1])
  55. cv2.imwrite(save_image_path, img_bgr)
  56. # cv2.imencode(".png", img_bgr)[1].tofile(save_image_path)
  57. # 显示分割结果
  58. if show:
  59. # #****** 可视化语义分割预测结果——方法一(直接显示分割结果)******#
  60. if show_model == 1:
  61. plt.figure(figsize=(8, 8))
  62. plt.imshow(pred_mask)
  63. plt.savefig('test_result/k1-0.jpg')
  64. plt.show()
  65. print(result.seg_logits.data.shape) # 定量
  66. # #****** 可视化语义分割预测结果--方法二(叠加在原因上进行显示)******#
  67. elif show_model == 2:
  68. # 显示语义分割结果
  69. plt.figure(figsize=(10, 8))
  70. plt.imshow(img_bgr[:, :, ::-1])
  71. plt.imshow(pred_mask, alpha=0.55)
  72. plt.axis('off')
  73. plt.savefig('test_result/k1-1.jpg')
  74. plt.show()
  75. # #****** 可视化语义分割预测结果--方法三(和原图并排显示) ******#
  76. elif show_model == 3:
  77. plt.figure(figsize=(14, 8))
  78. plt.subplot(1, 2, 1)
  79. plt.imshow(img_bgr[:, :, ::-1])
  80. plt.axis('off')
  81. plt.subplot(1, 2, 2)
  82. plt.imshow(img_bgr[:, :, ::-1])
  83. plt.imshow(pred_mask, alpha=0.6)
  84. plt.axis('off')
  85. plt.savefig('test_result/k1-2.jpg')
  86. plt.show()
  87. # #****** 可视化语义分割预测结果-方法四(按配色方案叠加在原图上显示) ******#
  88. elif show_model == 4:
  89. opacity = 0.3 # 透明度,越大越接近原图
  90. # 将预测的整数ID,映射为对应类别的颜色
  91. pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
  92. for idx in palette_dict.keys():
  93. pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
  94. pred_mask_bgr = pred_mask_bgr.astype('uint8')
  95. # 将语义分割预测图和原图叠加显示
  96. pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
  97. cv2.imwrite('test_result/k1-3.jpg', pred_viz)
  98. plt.figure(figsize=(8, 8))
  99. plt.imshow(pred_viz[:, :, ::-1])
  100. plt.show()
  101. # #***** 可视化语义分割预测结果-方法五(按mmseg/datasets/cag.py里定义的类别颜色可视化) ***** #
  102. elif show_model == 5:
  103. img_viz = show_result_pyplot(model, img_path, result, opacity=0.8, title='MMSeg',
  104. out_file='test_result/k1-4.jpg')
  105. print('the shape of img_viz:', img_viz.shape)
  106. plt.figure(figsize=(14, 8))
  107. plt.imshow(img_viz)
  108. plt.show()
  109. # #***** 可视化语义分割预测结果--方法六(加图例) ***** #
  110. elif show_model == 6:
  111. # 获取类别名和调色板
  112. classes = CAGDataset.METAINFO['classes']
  113. palette = CAGDataset.METAINFO['palette']
  114. opacity = 0.15
  115. # 将分割图按调色板染色
  116. seg_map = pred_mask.astype('uint8')
  117. seg_img = Image.fromarray(seg_map).convert('P')
  118. seg_img.putpalette(np.array(palette, dtype=np.uint8))
  119. # from matplotlib import pyplot as plt
  120. # import matplotlib.patches as mpatches
  121. plt.figure(figsize=(14, 8))
  122. img_plot = ((np.array(seg_img.convert('RGB'))) * (1 - opacity) + mmcv.imread(img_path) * opacity) / 255
  123. im = plt.imshow(img_plot)
  124. # 为每一种颜色创建一个图例
  125. patches = [mpatches.Patch(color=np.array(palette[i]) / 255, label=classes[i]) for i in range(len(classes))]
  126. plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize='large')
  127. plt.savefig('test_result/k1-6.jpg')
  128. plt.show()
  129. if __name__ == '__main__':
  130. # 指定测试集路径
  131. test_images_path = '../data/XXX/images/test'
  132. # 指定测试集结果存放路径
  133. result_save_path = '../data/XXX/predict_result/Segformer_model3'
  134. # 检测结果存放的目录是否存在.不存在的话,创建空文件夹
  135. if not os.path.exists(result_save_path):
  136. os.mkdir(result_save_path)
  137. # 载入模型
  138. # 模型config配置文件
  139. config_file = '../my_Configs/my_Segformer_20230907.py'
  140. # 模型checkpoint权重文件
  141. checkpoint_file = '../work_dirs/my_Segformer/best_mIoU_iter_50000.pth'
  142. # 计算硬件
  143. device = 'cuda:0'
  144. # 指定各个类别的BGR配色
  145. palette = [
  146. ['background', [0, 0, 0]],
  147. ['vessel', [255, 255, 255]]
  148. ]
  149. palette_dict = {}
  150. for idx, each in enumerate(palette):
  151. palette_dict[idx] = each[1]
  152. print(palette_dict)
  153. # 加载模型
  154. model = init_model(config_file, checkpoint_file, device=device)
  155. # 单张图像预测函数
  156. opacity = 0.7
  157. # 测试集批量预测
  158. for each in tqdm(os.listdir(test_images_path)):
  159. print(each)
  160. image_path = os.path.join(test_images_path, each)
  161. # predict_single_img(image_path, model, save=True, save_path=result_save_path, save_model=3, show=False)
  162. predict_single_img(image_path, model, save=False, save_path=result_save_path, save_model=3, show=True, show_model=2)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/347432?site=
推荐阅读
相关标签
  

闽ICP备14008679号