mmsegmentation
|
|- configs # config files
| |- _base_ ## base config files
| | |- datasets ### dataset-related configs
| | |- models ### model-related configs
| | |- schedules ### training-schedule configs (optimizer, learning rate, etc.)
| | |- default_runtime.py ### default runtime settings
| |- swin ## configs for each segmentation model; they inherit from _base_ and override it
| |- …
|- data # raw and converted dataset files
|- mmseg
| |- core ## core components
| | |- evaluation ### model evaluation code
| |- datasets ## dataset-related code
| | |- pipelines ### data preprocessing
| | |- samplers ### dataset sampling code
| | |- ade.py ### per-dataset preparation code
| | |- …
| |- models ## segmentation model implementations
| | |- backbones ### backbone networks
| | |- decode_heads ### decode heads
| | |- losses ### loss functions
| | |- necks ### necks
| | |- segmentors ### code that assembles the full segmentation network
| | |- utils ### helpers used when building models
| |- apis ## high-level user interface; calls the components under ./mmseg/
| | |- train.py ### training interface
| | |- test.py ### testing interface
| | |- …
| |- ops ## CUDA operators (being migrated into mmcv)
| |- utils ## helper utilities
|- tools
| |- model_converters ## scripts that convert the keys of pretrained backbone checkpoints
| |- convert_datasets ## per-dataset preparation/conversion scripts
| |- train.py ## training script
| |- test.py ## testing script
| |- …
|- …
If you want to convert the keys of a pretrained model from the official repository yourself, we also provide a script, swin2mmseg.py, in the tools directory, which converts model keys from the official repo to MMSegmentation style.
python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth
This script reads the model from PRETRAIN_PATH and stores the converted model at STORE_PATH.
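Conceptually, the conversion just renames entries in the checkpoint's state dict. Below is a minimal sketch of the idea; the rename rule shown is a hypothetical example, not the actual mapping implemented by swin2mmseg.py:

import torch

def convert_keys(pretrain_path, store_path):
    """Sketch of checkpoint key conversion (illustrative only)."""
    ckpt = torch.load(pretrain_path, map_location='cpu')
    # Official checkpoints often nest the weights under a 'model' key.
    state_dict = ckpt.get('model', ckpt)
    new_state_dict = {}
    for old_key, value in state_dict.items():
        # Hypothetical rename rule; the real script applies a series of
        # such renames so the keys match mmseg's module names.
        new_state_dict[old_key.replace('layers.', 'stages.')] = value
    torch.save({'state_dict': new_state_dict}, store_path)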
In our default settings, the trained segmentation models and their corresponding original pretrained backbone checkpoints are listed below:
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k_20220318_015320-48d180dd.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth
# tiny
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth
# small
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth
# base
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth
# large
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth
ADE20K contains more than 25,000 images (20k train, 2k val, 3k test), densely annotated with an open-dictionary label set. For the 2017 Places Challenge, the 100 thing and 50 stuff categories covering 89% of all pixels were selected.
In total there are 150 categories.
Idx  Ratio   Train  Val   Name
1    0.1576  11664  1172  wall
2    0.1072  6046   612   building, edifice
3    0.0878  8265   796   sky
4    0.0621  9336   917   floor, flooring
5    0.0480  6678   641   tree
6    0.0450  6604   643   ceiling
7    0.0398  4023   408   road, route
8    0.0231  1906   199   bed
9    0.0198  4688   460   windowpane, window
10   0.0183  2423   225   grass
11   0.0181  2874   294   cabinet
12   0.0166  3068   310   sidewalk, pavement
13   0.0160  5075   526   person, individual, someone, somebody, mortal, soul
14   0.0151  1804   190   earth, ground
15   0.0118  6666   796   door, double door
16   0.0110  4269   411   table
17   0.0109  1691   160   mountain, mount
18   0.0104  3999   441   plant, flora, plant life
19   0.0104  2149   217   curtain, drape, drapery, mantle, pall
20   0.0103  3261   318   chair
21   0.0098  3164   306   car, auto, automobile, machine, motorcar
22   0.0074  709    75    water
23   0.0067  3296   315   painting, picture
24   0.0065  1191   106   sofa, couch, lounge
25   0.0061  1516   162   shelf
26   0.0060  667    69    house
27   0.0053  651    57    sea
28   0.0052  1847   224   mirror
29   0.0046  1158   128   rug, carpet, carpeting
30   0.0044  480    44    field
31   0.0044  1172   98    armchair
32   0.0044  1292   184   seat
33   0.0033  1386   138   fence, fencing
34   0.0031  698    61    desk
35   0.0030  781    73    rock, stone
36   0.0027  380    43    wardrobe, closet, press
37   0.0026  3089   302   lamp
38   0.0024  404    37    bathtub, bathing tub, bath, tub
39   0.0024  804    99    railing, rail
40   0.0023  1453   153   cushion
41   0.0023  411    37    base, pedestal, stand
42   0.0022  1440   162   box
43   0.0022  800    77    column, pillar
44   0.0020  2650   298   signboard, sign
45   0.0019  549    46    chest of drawers, chest, bureau, dresser
46   0.0019  367    36    counter
47   0.0018  311    30    sand
48   0.0018  1181   122   sink
49   0.0018  287    23    skyscraper
50   0.0018  468    38    fireplace, hearth, open fireplace
51   0.0018  402    43    refrigerator, icebox
52   0.0018  130    12    grandstand, covered stand
53   0.0018  561    64    path
54   0.0017  880    102   stairs, steps
55   0.0017  86     12    runway
56   0.0017  172    11    case, display case, showcase, vitrine
57   0.0017  198    18    pool table, billiard table, snooker table
58   0.0017  930    109   pillow
59   0.0015  139    18    screen door, screen
60   0.0015  564    52    stairway, staircase
61   0.0015  320    26    river
62   0.0015  261    29    bridge, span
63   0.0014  275    22    bookcase
64   0.0014  335    60    blind, screen
65   0.0014  792    75    coffee table, cocktail table
66   0.0014  395    49    toilet, can, commode, crapper, pot, potty, stool, throne
67   0.0014  1309   138   flower
68   0.0013  1112   113   book
69   0.0013  266    27    hill
70   0.0013  659    66    bench
71   0.0012  331    31    countertop
72   0.0012  531    56    stove, kitchen stove, range, kitchen range, cooking stove
73   0.0012  369    36    palm, palm tree
74   0.0012  144    9     kitchen island
75   0.0011  265    29    computer, computing machine, computing device, data processor, electronic computer, information processing system
76   0.0010  324    33    swivel chair
77   0.0009  304    27    boat
78   0.0009  170    20    bar
79   0.0009  68     6     arcade machine
80   0.0009  65     8     hovel, hut, hutch, shack, shanty
81   0.0009  248    25    bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle
82   0.0008  492    49    towel
83   0.0008  2510   269   light, light source
84   0.0008  440    39    truck, motortruck
85   0.0008  147    18    tower
86   0.0008  583    56    chandelier, pendant, pendent
87   0.0007  533    61    awning, sunshade, sunblind
88   0.0007  1989   239   streetlight, street lamp
89   0.0007  71     5     booth, cubicle, stall, kiosk
90   0.0007  618    53    television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box
91   0.0007  135    12    airplane, aeroplane, plane
92   0.0007  83     5     dirt track
93   0.0007  178    17    apparel, wearing apparel, dress, clothes
94   0.0006  1003   104   pole
95   0.0006  182    12    land, ground, soil
96   0.0006  452    50    bannister, banister, balustrade, balusters, handrail
97   0.0006  42     6     escalator, moving staircase, moving stairway
98   0.0006  307    31    ottoman, pouf, pouffe, puff, hassock
99   0.0006  965    114   bottle
100  0.0006  117    13    buffet, counter, sideboard
101  0.0006  354    35    poster, posting, placard, notice, bill, card
102  0.0006  108    9     stage
103  0.0006  557    55    van
104  0.0006  52     4     ship
105  0.0005  99     5     fountain
106  0.0005  57     4     conveyer belt, conveyor belt, conveyer, conveyor, transporter
107  0.0005  292    31    canopy
108  0.0005  77     9     washer, automatic washer, washing machine
109  0.0005  340    38    plaything, toy
110  0.0005  66     3     swimming pool, swimming bath, natatorium
111  0.0005  465    49    stool
112  0.0005  50     4     barrel, cask
113  0.0005  622    75    basket, handbasket
114  0.0005  80     9     waterfall, falls
115  0.0005  59     3     tent, collapsible shelter
116  0.0005  531    72    bag
117  0.0005  282    30    minibike, motorbike
118  0.0005  73     7     cradle
119  0.0005  435    44    oven
120  0.0005  136    25    ball
121  0.0005  116    24    food, solid food
122  0.0004  266    31    step, stair
123  0.0004  58     12    tank, storage tank
124  0.0004  418    83    trade name, brand name, brand, marque
125  0.0004  319    43    microwave, microwave oven
126  0.0004  1193   139   pot, flowerpot
127  0.0004  97     23    animal, animate being, beast, brute, creature, fauna
128  0.0004  347    36    bicycle, bike, wheel, cycle
129  0.0004  52     5     lake
130  0.0004  246    22    dishwasher, dish washer, dishwashing machine
131  0.0004  108    13    screen, silver screen, projection screen
132  0.0004  201    30    blanket, cover
133  0.0004  285    21    sculpture
134  0.0004  268    27    hood, exhaust hood
135  0.0003  1020   108   sconce
136  0.0003  1282   122   vase
137  0.0003  528    65    traffic light, traffic signal, stoplight
138  0.0003  453    57    tray
139  0.0003  671    100   ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
140  0.0003  397    44    fan
141  0.0003  92     8     pier, wharf, wharfage, dock
142  0.0003  228    18    crt screen
143  0.0003  570    59    plate
144  0.0003  217    22    monitor, monitoring device
145  0.0003  206    19    bulletin board, notice board
146  0.0003  130    14    shower
147  0.0003  178    28    radiator
148  0.0002  504    57    glass, drinking glass
149  0.0002  775    96    clock
150  0.0002  421    56    flag
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── cityscapes
│ │ ├── leftImg8bit
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── gtFine
│ │ │ ├── train
│ │ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2012
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClass
│ │ │ ├── ImageSets
│ │ │ │ ├── Segmentation
│ │ ├── VOC2010
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClassContext
│ │ │ ├── ImageSets
│ │ │ │ ├── SegmentationContext
│ │ │ │ │ ├── train.txt
│ │ │ │ │ ├── val.txt
│ │ │ ├── trainval_merged.json
│ │ ├── VOCaug
│ │ │ ├── dataset
│ │ │ │ ├── cls
│ ├── ade
│ │ ├── ADEChallengeData2016
│ │ │ ├── annotations
│ │ │ │ ├── training
│ │ │ │ ├── validation
│ │ │ ├── images
│ │ │ │ ├── training
│ │ │ │ ├── validation
For this walkthrough we train the upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K model. The corresponding config file, with annotations, is shown below.
# configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py
_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
# This can also be downloaded first and the local path used instead.
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    # Set num_classes to the number of classes in your own dataset, not
    # counting the background; the background is automatically 0.
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
# configs/_base_/datasets/ade20k.py
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'  # 1. change to your own data path
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)  # 2. change to the size of your own images
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    # adjust img_scale according to the crop size
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
# mmseg/datasets/ade.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class ADE20KDataset(CustomDataset):
    """ADE20K dataset.

    In segmentation map annotation for ADE20K, 0 stands for background, which
    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
    '.png'.
    """
    # change CLASSES to the class names of your own dataset
    CLASSES = (
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier',
        'crt screen', 'plate', 'monitor', 'bulletin board', 'shower',
        'radiator', 'glass', 'clock', 'flag')

    # the palette colors can be changed in the same way
    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
               [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
               [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
               [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
               [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
               [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
               [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
               [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
               [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
               [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
               [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
               [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
               [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
               [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
               [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
               [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
               [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
               [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
               [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
               [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
               [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
               [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
               [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
               [102, 255, 0], [92, 0, 255]]

    def __init__(self, **kwargs):
        super(ADE20KDataset, self).__init__(
            img_suffix='.jpg',  # the image suffix can be changed here
            seg_map_suffix='.png',  # the label suffix can be changed here
            reduce_zero_label=True,
            **kwargs)

    def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
        """Write the segmentation results to images.

        Args:
            results (list[ndarray]): Testing results of the dataset.
            imgfile_prefix (str): The filename prefix of the png files.
                If the prefix is "somepath/xxx", the png files will be
                named "somepath/xxx.png".
            to_label_id (bool): whether convert output to label_id for
                submission.
            indices (list[int], optional): Indices of input results, if not
                set, all the indices of the dataset will be used.
                Default: None.

        Returns:
            list[str: str]: result txt files which contains corresponding
            semantic segmentation images.
        """
        if indices is None:
            indices = list(range(len(self)))

        mmcv.mkdir_or_exist(imgfile_prefix)
        result_files = []
        for result, idx in zip(results, indices):
            filename = self.img_infos[idx]['filename']
            basename = osp.splitext(osp.basename(filename))[0]
            # the .png output suffix can be changed here
            png_filename = osp.join(imgfile_prefix, f'{basename}.png')
            # The index range of official requirement is from 0 to 150.
            # But the index range of output is from 0 to 149.
            # That is because we set reduce_zero_label=True.
            result = result + 1
            output = Image.fromarray(result.astype(np.uint8))
            output.save(png_filename)
            result_files.append(png_filename)

        return result_files

    def format_results(self,
                       results,
                       imgfile_prefix,
                       to_label_id=True,
                       indices=None):
        """Format the results into dir (standard format for ade20k
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            imgfile_prefix (str | None): The prefix of images files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix".
            to_label_id (bool): whether convert output to label_id for
                submission. Default: False
            indices (list[int], optional): Indices of input results, if not
                set, all the indices of the dataset will be used.
                Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a list containing
                the image paths, tmp_dir is the temporal directory created
                for saving json/png files when img_prefix is not specified.
        """
        if indices is None:
            indices = list(range(len(self)))

        assert isinstance(results, list), 'results must be a list.'
        assert isinstance(indices, list), 'indices must be a list.'

        result_files = self.results2img(results, imgfile_prefix, to_label_id,
                                        indices)
        return result_files
One thing to note: if your images are in .jpg format and your masks are in .png format, everything works out of the box; if your data uses other formats, you need to change the suffixes in mmseg/datasets/custom.py (or in your own dataset class).
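For example, a minimal sketch of a custom dataset class with different suffixes could look like this; the class name, CLASSES, PALETTE, and suffixes below are placeholders for your own data:

# Hypothetical custom dataset, e.g. mmseg/datasets/my_dataset.py;
# remember to import it in mmseg/datasets/__init__.py so it gets registered.
from .builder import DATASETS
from .custom import CustomDataset

@DATASETS.register_module()
class MyDataset(CustomDataset):
    CLASSES = ('background', 'road', 'building')  # your own class names
    PALETTE = [[0, 0, 0], [128, 64, 128], [70, 70, 70]]  # your own colors

    def __init__(self, **kwargs):
        super(MyDataset, self).__init__(
            img_suffix='.tif',  # images are .tif instead of .jpg
            seg_map_suffix='.png',  # masks stay .png
            reduce_zero_label=False,  # here 0 is kept as a real class
            **kwargs)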
# mmseg/datasets/custom.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from prettytable import PrettyTable
from torch.utils.data import Dataset

from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations


@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of CustomDataset should be of the same
    except suffix. A valid img/gt_semantic_seg filename pair should be like
    ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also
    included in the suffix). If split is given, then ``xxx`` is specified in
    txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be
    loaded. Please refer to ``docs/en/tutorials/new_dataset.md`` for more
    details.

    Args:
        pipeline (list[dict]): Processing pipeline
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        split (str, optional): Split txt file. If split is specified, only
            file with suffix in the splits will be loaded. Otherwise, all
            images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir. Default:
            None.
        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        palette (Sequence[Sequence[int]]] | np.ndarray | None): The palette
            of segmentation map. If None is given, and self.PALETTE is None,
            random palette will be generated. Default: None
        gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
            load gt for evaluation, load from disk by default. Default: None.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    CLASSES = None
    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',  # change for your own image format
                 ann_dir=None,
                 seg_map_suffix='.png',  # change for your own label format
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None,
                 gt_seg_map_loader_cfg=None,
                 file_client_args=dict(backend='disk')):
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)
        self.gt_seg_map_loader = LoadAnnotations(
        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
            **gt_seg_map_loader_cfg)

        self.file_client_args = file_client_args
        self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

        if test_mode:
            assert self.CLASSES is not None, \
                '`cls.CLASSES` or `classes` should be specified when testing'

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)

    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                         split):
        """Load annotation from directory.

        Args:
            img_dir (str): Path to image directory
            img_suffix (str): Suffix of images.
            ann_dir (str|None): Path to annotation directory.
            seg_map_suffix (str|None): Suffix of segmentation maps.
            split (str|None): Split txt file. If split is specified, only
                file with suffix in the splits will be loaded. Otherwise, all
                images in img_dir/ann_dir will be loaded. Default: None

        Returns:
            list[dict]: All image info of dataset.
        """
        img_infos = []
        if split is not None:
            lines = mmcv.list_from_file(
                split, file_client_args=self.file_client_args)
            for line in lines:
                img_name = line.strip()
                img_info = dict(filename=img_name + img_suffix)
                if ann_dir is not None:
                    seg_map = img_name + seg_map_suffix
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
        else:
            for img in self.file_client.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=img_suffix,
                    recursive=True):
                img_info = dict(filename=img)
                if ann_dir is not None:
                    seg_map = img.replace(img_suffix, seg_map_suffix)
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
            img_infos = sorted(img_infos, key=lambda x: x['filename'])

        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
        return img_infos

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:
            results['label_map'] = self.label_map

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set
                False).
        """
        if self.test_mode:
            return self.prepare_test_img(idx)
        else:
            return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """
        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
        """Place holder to format result to dataset specific output."""
        raise NotImplementedError

    def get_gt_seg_map_by_idx(self, index):
        """Get one ground truth segmentation map for evaluation."""
        ann_info = self.get_ann_info(index)
        results = dict(ann_info=ann_info)
        self.pre_pipeline(results)
        self.gt_seg_map_loader(results)
        return results['gt_semantic_seg']

    def get_gt_seg_maps(self, efficient_test=None):
        """Get ground truth segmentation maps for evaluation."""
        if efficient_test is not None:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` has been deprecated '
                'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
                'friendly by default. ')

        for idx in range(len(self)):
            ann_info = self.get_ann_info(idx)
            results = dict(ann_info=ann_info)
            self.pre_pipeline(results)
            self.gt_seg_map_loader(results)
            yield results['gt_semantic_seg']

    def pre_eval(self, preds, indices):
        """Collect eval result from each iteration.

        Args:
            preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
                after argmax, shape (N, H, W).
            indices (list[int] | int): the prediction related ground truth
                indices.

        Returns:
            list[torch.Tensor]: (area_intersect, area_union, area_prediction,
                area_ground_truth).
        """
        # In order to compat with batch inference
        if not isinstance(indices, list):
            indices = [indices]
        if not isinstance(preds, list):
            preds = [preds]

        pre_eval_results = []

        for pred, index in zip(preds, indices):
            seg_map = self.get_gt_seg_map_by_idx(index)
            pre_eval_results.append(
                intersect_and_union(
                    pred,
                    seg_map,
                    len(self.CLASSES),
                    self.ignore_index,
                    # as the labels has been converted when dataset initialized
                    # in `get_palette_for_custom_classes ` this `label_map`
                    # should be `dict()`, see
                    # https://github.com/open-mmlab/mmsegmentation/issues/1415
                    # for more ditails
                    label_map=dict(),
                    reduce_zero_label=self.reduce_zero_label))

        return pre_eval_results

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes
                is a tuple or list, override the CLASSES defined by the
                dataset.
            palette (Sequence[Sequence[int]]] | np.ndarray | None): The
                palette of segmentation map. If None is given, random palette
                will be generated. Default: None
        """
        if classes is None:
            self.custom_classes = False
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # dictionary, its keys are the old label ids and its values
            # are the new label ids.
            # used for changing pixel labels in load_annotations.
            self.label_map = {}
            for i, c in enumerate(self.CLASSES):
                if c not in class_names:
                    self.label_map[i] = -1
                else:
                    self.label_map[i] = class_names.index(c)

        palette = self.get_palette_for_custom_classes(class_names, palette)

        return class_names, palette

    def get_palette_for_custom_classes(self, class_names, palette=None):
        if self.label_map is not None:
            # return subset of palette
            palette = []
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != -1:
                    palette.append(self.PALETTE[old_id])
            palette = type(self.PALETTE)(palette)

        elif palette is None:
            if self.PALETTE is None:
                # Get random state before set seed, and restore
                # random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(0, 255, size=(len(class_names), 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE

        return palette

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 gt_seg_maps=None,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list[tuple[torch.Tensor]] | list[str]): per image
                pre_eval results or predict segmentation map for computing
                evaluation metric.
            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
                'mDice' and 'mFscore' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
                used in ConcatDataset

        Returns:
            dict[str, float]: Default metrics.
        """
        if isinstance(metric, str):
            metric = [metric]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metric).issubset(set(allowed_metrics)):
            raise KeyError('metric {} is not supported'.format(metric))

        eval_results = {}
        # test a list of files
        if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                results, str):
            if gt_seg_maps is None:
                gt_seg_maps = self.get_gt_seg_maps()
            num_classes = len(self.CLASSES)
            ret_metrics = eval_metrics(
                results,
                gt_seg_maps,
                num_classes,
                self.ignore_index,
                metric,
                label_map=dict(),
                reduce_zero_label=self.reduce_zero_label)
        # test a list of pre_eval_results
        else:
            ret_metrics = pre_eval_to_metrics(results, metric)

        # Because dataset.CLASSES is required for per-eval.
        if self.CLASSES is None:
            class_names = tuple(range(num_classes))
        else:
            class_names = self.CLASSES

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)
        print_log('Summary:', logger)
        print_log('\n' + summary_table_data.get_string(), logger=logger)

        # each metric dict
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(class_names)
            })

        return eval_results
# configs/_base_/default_runtime.py
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook')  # enable TensorboardLoggerHook
        # dict(type='PaviLoggerHook')  # for internal services
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None  # load the model at the given path as a pretrained model; this does not resume training
resume_from = None  # load the model at the given path as a training checkpoint and resume training from it
workflow = [('train', 1)]
cudnn_benchmark = True
# configs/_base_/schedules/schedule_160k.py
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)  # max_iters: maximum number of training iterations
checkpoint_config = dict(by_epoch=False, interval=16000)  # interval: save a checkpoint every this many iterations
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)  # interval: evaluate every this many iterations; save_best='auto' keeps the best checkpoint
Below is a complete runtime/schedule example from an actual experiment (the paths are from the author's machine):
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook')
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = '/media/lhy/Swin-Transformer-Semantic-Segmentation/checkpoints/deeplabv3plus/deeplabv3plus_r101-d8_512x512_40k_voc12aug_20200613_205333-faf03387.pth'
resume_from = '/media/lhy/mmsegmentation-0.27.0/work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2/best_mIoU_iter_44000.pth'
workflow = [('train', 1)]
cudnn_benchmark = True
# by default we use distributed training on 4 GPUs
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
# enable FP16 training
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
fp16 = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(
    interval=4000,
    metric=['mIoU', 'mFscore'],
    pre_eval=True,
    save_best='mIoU')  # automatically keep the checkpoint with the best mIoU
work_dir = 'work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2'
gpu_ids = range(0, 4)
auto_resume = False
# configs/_base_/models/upernet_swin.py
# model settings
# For norm_cfg: use 'SyncBN' for multi-GPU training; for single-GPU training
# change type to 'BN'.
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg),
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    # 'whole' means whole-image inference; for overlapping sliding-window
    # inference change it to e.g.:
    # test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
    test_cfg=dict(mode='whole'))
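To switch an existing experiment to sliding-window inference without touching the base file, you can override test_cfg in a derived config. A small sketch (the crop and stride values are illustrative):

# Override the inference mode in a derived config (values are illustrative).
_base_ = './upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py'
model = dict(
    test_cfg=dict(
        _delete_=True,  # drop the inherited mode='whole' settings
        mode='slide',
        crop_size=(512, 512),  # usually matches the training crop size
        stride=(341, 341)))  # overlap per axis = crop_size - stride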
The sliding-window code lives in mmsegmentation/mmseg/models/segmentors/encoder_decoder.py:
# TODO refactor
def slide_inference(self, img, img_meta, rescale):
    """Inference by sliding-window with overlap.

    If h_crop > h_img or w_crop > w_img, the small patch will be used to
    decode without padding.
    """

    h_stride, w_stride = self.test_cfg.stride
    h_crop, w_crop = self.test_cfg.crop_size
    batch_size, _, h_img, w_img = img.size()
    num_classes = self.num_classes
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            crop_seg_logit = self.encode_decode(crop_img, img_meta)
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
    assert (count_mat == 0).sum() == 0
    if torch.onnx.is_in_onnx_export():
        # cast count_mat to constant while exporting to ONNX
        count_mat = torch.from_numpy(
            count_mat.cpu().detach().numpy()).to(device=img.device)
    preds = preds / count_mat
    if rescale:
        # remove padding area
        resize_shape = img_meta[0]['img_shape'][:2]
        preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
        preds = resize(
            preds,
            size=img_meta[0]['ori_shape'][:2],
            mode='bilinear',
            align_corners=self.align_corners,
            warning=False)
    return preds
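As a quick sanity check of the grid formula above, a small standalone computation (the image and window sizes are just example values):

h_img, w_img = 1024, 2048  # input image size
h_crop, w_crop = 512, 512  # sliding-window size
h_stride, w_stride = 341, 341  # window step
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1  # -> 3
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1  # -> 6
print(h_grids, w_grids)  # 3 windows vertically, 6 horizontally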
Different Learning Rate (LR) for Backbone and Heads
In semantic segmentation, some methods make the LR of the heads larger than that of the backbone to achieve better performance or faster convergence. In MMSegmentation, you may add the following lines to the config to make the LR of the heads 10 times that of the backbone. With this modification, the LR of any parameter group whose name contains 'head' will be multiplied by 10.
optimizer = dict(
    paramwise_cfg=dict(
        custom_keys={
            'head': dict(lr_mult=10.)}))
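Putting this together with the AdamW settings from the Swin config earlier, a combined optimizer block might look like the sketch below; treat it as one plausible combination rather than an official recipe:

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            # no weight decay for position embedding & layer norm (as above)
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.),
            # heads learn 10x faster than the backbone
            'head': dict(lr_mult=10.)
        }))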
Online Hard Example Mining (OHEM)
We implement a pixel sampler for training-time sampling. Below is an example config of training PSPNet with OHEM enabled. With it, only pixels with a confidence score under 0.7 are used for training, and at least 100,000 pixels are kept during training. If thresh is not specified, the pixels with the top min_kept losses are selected instead.
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)))
For datasets with an imbalanced class distribution, you can change the loss weight of each class. Here is an example for the Cityscapes dataset; class_weight will be passed into CrossEntropyLoss as its weight argument.
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            # DeepLab used this class weight for cityscapes
            class_weight=[0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969,
                          0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843,
                          1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
                          1.0507])))
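If you want to derive such weights for your own dataset, one common recipe (an assumption on our part; not how the DeepLab weights above were produced) is median-frequency balancing over the training masks:

import glob
import numpy as np
from PIL import Image

def median_frequency_weights(mask_dir, num_classes, ignore_index=255):
    """Sketch: per-class weights from pixel frequencies in training masks."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for path in glob.glob(f'{mask_dir}/*.png'):
        mask = np.array(Image.open(path))
        valid = mask[mask != ignore_index]
        counts += np.bincount(valid, minlength=num_classes)[:num_classes]
    freq = counts / counts.sum()
    # classes absent from the data would get a huge weight; handle as needed
    weights = np.median(freq[freq > 0]) / np.maximum(freq, 1e-12)
    return weights.round(4).tolist()

# e.g. class_weight = median_frequency_weights('data/my_dataset/ann_dir/train', 19)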
For loss computation, we support training with multiple losses at the same time. Here is an example config of training unet on the DRIVE dataset, whose loss function is a 1:3 weighted sum of CrossEntropyLoss and DiceLoss:
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)
In this way, loss_weight and loss_name will be, respectively, the weight and the name of the corresponding loss in the training log.
Note: loss_name must start with the prefix loss_ for the loss item to be included in the backward graph.
mmseg already ships description files and loading code for the common public segmentation datasets. If you have used PyTorch before, reading these dataset description files is straightforward; the one piece that tends to puzzle mmseg newcomers is reduce_zero_label. When building your own mmseg dataset, the most common question is therefore whether reduce_zero_label should be True or False.
What does it do? Translated literally, it "reduces zero-valued labels". In multi-class segmentation, if the value 0 in your label files denotes the background class, it is recommended to ignore it.
Looking at the data-loading source, there is a snippet that handles reduce_zero_label: if it is enabled, every label that was originally 0 is set to 255, the default value of the loss function's ignore_index argument, which by default keeps labels valued 255 out of the loss computation. This is why the 150-class ADE dataset mentioned earlier does not include a background class: reduce_zero_label is enabled, so the background pixels that were originally 0 become ignore_index.
# mmseg/datasets/pipelines/loading.py
...
# reduce zero_label
if self.reduce_zero_label:
    # avoid using underflow conversion
    gt_semantic_seg[gt_semantic_seg == 0] = 255
    gt_semantic_seg = gt_semantic_seg - 1
    gt_semantic_seg[gt_semantic_seg == 254] = 255
...
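A tiny standalone demonstration of what this remapping does to a label array:

import numpy as np

gt = np.array([0, 1, 2, 150], dtype=np.uint8)  # 0 = background, 1..150 = classes
gt[gt == 0] = 255  # background -> ignore value
gt = gt - 1  # shift real classes down by one (255 wraps to 254)
gt[gt == 254] = 255  # restore the ignore value
print(gt)  # [255   0   1 149]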
A common problem caused by reduce_zero_label
Take the ADE dataset source as the example: reduce_zero_label defaults to True there. Even a newcomer who has absorbed the previous section may still know ADE only superficially and suspect that the reduce_zero_label enabled in the config throws away the first of the 150 object classes (after all, num_classes is 150, right?), and then naively switch reduce_zero_label off.
Why that is wrong
# configs/_base_/datasets/ade20k.py
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),  # reduce_zero_label defaults to True for ADE
    dict(...),
    ...
]
Only 150 classes, defined in CLASSES, actually take part in training, but the label files in fact contain 151 classes. The background class (everything left unlabeled, or deliberately ignored regions, with value 0 in the label) is not among the 150 CLASSES and must be set to ignore_index during training. That is exactly what reduce_zero_label from the previous section does: it pulls the background out of the 151 classes and maps it to ignore_index. If we wrongly turned reduce_zero_label off, num_classes would be 151.
In the default setting avg_non_ignore=False, which means every pixel counts in the loss calculation, including those labeled with the ignore index.
For loss calculation, we support ignoring certain labels via avg_non_ignore together with ignore_index. This way, the average loss is computed only over non-ignored labels, which may achieve better performance (see the reference in the official docs). Here is an example config of training unet on the Cityscapes dataset: the loss calculation ignores label 0 (treated as background) and averages the loss only over non-ignored labels:
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            avg_non_ignore=True)))
In short, add ignore_index to the decode head and/or auxiliary head and set avg_non_ignore=True:
# model settings
...
    loss_decode=dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
        avg_non_ignore=True),
...
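The difference between the two averaging modes is easy to see in isolation. A small numeric sketch with made-up per-pixel losses, where ignored pixels contribute zero loss:

import numpy as np

pixel_loss = np.array([0.5, 1.0, 0.0, 0.0])  # last two pixels are ignored
valid = np.array([True, True, False, False])
avg_all = pixel_loss.sum() / pixel_loss.size  # 0.375, avg_non_ignore=False
avg_valid = pixel_loss.sum() / valid.sum()  # 0.75, avg_non_ignore=True
print(avg_all, avg_valid)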