MMSegmentation V0.27.0: Training and Inference on Your Own Dataset (2)

1. Converting official models to the MMSegmentation style

If you want to convert the keys yourself so that you can use a pretrained model from the official repository, MMSegmentation provides the script swin2mmseg.py in the tools/model_converters directory. It converts the keys of a model from the official repo to the MMSegmentation style.

python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth

This script converts the model from PRETRAIN_PATH and stores the converted model in STORE_PATH.
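"Converting the keys" means renaming the entries of the checkpoint's state dict so that they match the module names of MMSegmentation's SwinTransformer backbone. The sketch below shows the idea on a plain dict. The specific renaming rules (dropping the classification head, `layers` to `stages`, `attn.` to `attn.w_msa.`, `mlp.fc1`/`mlp.fc2` to `ffn.layers.*`, `patch_embed.proj` to `patch_embed.projection`, and a `backbone.` prefix) are my assumptions, modeled on what swin2mmseg.py does. The real script also reorders the patch-merging (downsample) weights, so use the script itself for actual conversion.

```python
from collections import OrderedDict


def rename_swin_keys(state_dict):
    """Simplified, hypothetical key renaming from official Swin names to
    MMSegmentation-style names. Values are passed through unchanged."""
    new_state = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith('head'):
            # The ImageNet classification head is not used for segmentation.
            continue
        new_key = key
        if new_key.startswith('layers'):
            if 'attn.' in new_key:
                new_key = new_key.replace('attn.', 'attn.w_msa.')
            elif 'mlp.fc1.' in new_key:
                new_key = new_key.replace('mlp.fc1.', 'ffn.layers.0.0.')
            elif 'mlp.fc2.' in new_key:
                new_key = new_key.replace('mlp.fc2.', 'ffn.layers.1.')
            # Only the leading container name is renamed.
            new_key = new_key.replace('layers', 'stages', 1)
        elif new_key.startswith('patch_embed') and 'proj' in new_key:
            new_key = new_key.replace('proj', 'projection')
        new_state['backbone.' + new_key] = value
    return new_state


sd = {'head.weight': 1,
      'patch_embed.proj.weight': 2,
      'layers.0.blocks.0.attn.qkv.weight': 3,
      'layers.1.blocks.0.mlp.fc1.bias': 4,
      'norm.weight': 5}
for k in rename_swin_keys(sd):
    print(k)
```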
In our default setup, each pretrained model corresponds to one original model from the official repository.

2. Download the UPerNet + Swin models trained on ADE20K

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k_20220318_015320-48d180dd.pth

https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth
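In these URLs, the directory name appears to match the name of the corresponding config file under configs/swin/, which makes it easy to pair a checkpoint with its config. The helper below is a hypothetical sketch that uses only the standard library to split such a URL into that config name and the checkpoint file name.

```python
from pathlib import PurePosixPath
from urllib.parse import urlparse


def split_checkpoint_url(url):
    """Return (config_name, checkpoint_filename) for an OpenMMLab model URL.

    Assumes the layout .../<config_name>/<checkpoint_filename>.pth used by
    the links above.
    """
    path = PurePosixPath(urlparse(url).path)
    return path.parent.name, path.name


url = ('https://download.openmmlab.com/mmsegmentation/v0.5/swin/'
       'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/'
       'upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K'
       '_20210531_112542-e380ad3e.pth')
config_name, ckpt_name = split_checkpoint_url(url)
print(f'configs/swin/{config_name}.py')
print(f'checkpoints/{ckpt_name}')
```

After downloading, a checkpoint can be evaluated with the standard test script, for example `python tools/test.py configs/swin/<config_name>.py checkpoints/<ckpt_name> --eval mIoU`.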

3. Download the Swin Transformer pretrained models

#tiny

https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth

#small
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth

#base
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth

https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth

https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth

https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth

#large
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth

https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth
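Each variant comes with a fixed set of backbone hyperparameters. The UPerNet head's in_channels follow from embed_dims, because the channel count doubles at each of the four stages. The dictionary below records the standard Swin-T/S/B/L settings. Treat it as a reference sketch and check it against the matching config in configs/swin/ before you edit anything.

```python
# Standard Swin Transformer backbone settings per variant.
SWIN_VARIANTS = {
    'tiny': dict(embed_dims=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
    'small': dict(embed_dims=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
    'base': dict(embed_dims=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
    'large': dict(embed_dims=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
}


def head_channels(variant):
    """Derive UPerNet decode/auxiliary in_channels for a Swin variant.

    Stage i (0-based) outputs embed_dims * 2**i channels; the auxiliary
    FCN head in these configs reads the third stage (index 2).
    """
    c = SWIN_VARIANTS[variant]['embed_dims']
    decode_in = [c * 2 ** i for i in range(4)]
    return decode_in, decode_in[2]


for name in SWIN_VARIANTS:
    print(name, head_channels(name))
```

The window7_224 and window12_384 checkpoints also differ in window_size (7 vs. 12), as their file names indicate.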

4. Organize the data directory in the ADE20K structure

ADE20K has more than 25,000 images (20k train, 2k val, 3k test) densely annotated with an open-vocabulary label set. For the 2017 Places Challenge 2, 100 "thing" and 50 "stuff" categories covering 89% of all pixels were selected, giving 150 categories in total.

Idx	Pixel ratio	Train images	Val images	Name
1	0.1576	11664	1172	wall
2	0.1072	6046	612	building, edifice
3	0.0878	8265	796	sky
4	0.0621	9336	917	floor, flooring
5	0.0480	6678	641	tree
6	0.0450	6604	643	ceiling
7	0.0398	4023	408	road, route
8	0.0231	1906	199	bed 
9	0.0198	4688	460	windowpane, window 
10	0.0183	2423	225	grass
11	0.0181	2874	294	cabinet
12	0.0166	3068	310	sidewalk, pavement
13	0.0160	5075	526	person, individual, someone, somebody, mortal, soul
14	0.0151	1804	190	earth, ground
15	0.0118	6666	796	door, double door
16	0.0110	4269	411	table
17	0.0109	1691	160	mountain, mount
18	0.0104	3999	441	plant, flora, plant life
19	0.0104	2149	217	curtain, drape, drapery, mantle, pall
20	0.0103	3261	318	chair
21	0.0098	3164	306	car, auto, automobile, machine, motorcar
22	0.0074	709	75	water
23	0.0067	3296	315	painting, picture
24	0.0065	1191	106	sofa, couch, lounge
25	0.0061	1516	162	shelf
26	0.0060	667	69	house
27	0.0053	651	57	sea
28	0.0052	1847	224	mirror
29	0.0046	1158	128	rug, carpet, carpeting
30	0.0044	480	44	field
31	0.0044	1172	98	armchair
32	0.0044	1292	184	seat
33	0.0033	1386	138	fence, fencing
34	0.0031	698	61	desk
35	0.0030	781	73	rock, stone
36	0.0027	380	43	wardrobe, closet, press
37	0.0026	3089	302	lamp
38	0.0024	404	37	bathtub, bathing tub, bath, tub
39	0.0024	804	99	railing, rail
40	0.0023	1453	153	cushion
41	0.0023	411	37	base, pedestal, stand
42	0.0022	1440	162	box
43	0.0022	800	77	column, pillar
44	0.0020	2650	298	signboard, sign
45	0.0019	549	46	chest of drawers, chest, bureau, dresser
46	0.0019	367	36	counter
47	0.0018	311	30	sand
48	0.0018	1181	122	sink
49	0.0018	287	23	skyscraper
50	0.0018	468	38	fireplace, hearth, open fireplace
51	0.0018	402	43	refrigerator, icebox
52	0.0018	130	12	grandstand, covered stand
53	0.0018	561	64	path
54	0.0017	880	102	stairs, steps
55	0.0017	86	12	runway
56	0.0017	172	11	case, display case, showcase, vitrine
57	0.0017	198	18	pool table, billiard table, snooker table
58	0.0017	930	109	pillow
59	0.0015	139	18	screen door, screen
60	0.0015	564	52	stairway, staircase
61	0.0015	320	26	river
62	0.0015	261	29	bridge, span
63	0.0014	275	22	bookcase
64	0.0014	335	60	blind, screen
65	0.0014	792	75	coffee table, cocktail table
66	0.0014	395	49	toilet, can, commode, crapper, pot, potty, stool, throne
67	0.0014	1309	138	flower
68	0.0013	1112	113	book
69	0.0013	266	27	hill
70	0.0013	659	66	bench
71	0.0012	331	31	countertop
72	0.0012	531	56	stove, kitchen stove, range, kitchen range, cooking stove
73	0.0012	369	36	palm, palm tree
74	0.0012	144	9	kitchen island
75	0.0011	265	29	computer, computing machine, computing device, data processor, electronic computer, information processing system
76	0.0010	324	33	swivel chair
77	0.0009	304	27	boat
78	0.0009	170	20	bar
79	0.0009	68	6	arcade machine
80	0.0009	65	8	hovel, hut, hutch, shack, shanty
81	0.0009	248	25	bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle
82	0.0008	492	49	towel
83	0.0008	2510	269	light, light source
84	0.0008	440	39	truck, motortruck
85	0.0008	147	18	tower
86	0.0008	583	56	chandelier, pendant, pendent
87	0.0007	533	61	awning, sunshade, sunblind
88	0.0007	1989	239	streetlight, street lamp
89	0.0007	71	5	booth, cubicle, stall, kiosk
90	0.0007	618	53	television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box
91	0.0007	135	12	airplane, aeroplane, plane
92	0.0007	83	5	dirt track
93	0.0007	178	17	apparel, wearing apparel, dress, clothes
94	0.0006	1003	104	pole
95	0.0006	182	12	land, ground, soil
96	0.0006	452	50	bannister, banister, balustrade, balusters, handrail
97	0.0006	42	6	escalator, moving staircase, moving stairway
98	0.0006	307	31	ottoman, pouf, pouffe, puff, hassock
99	0.0006	965	114	bottle
100	0.0006	117	13	buffet, counter, sideboard
101	0.0006	354	35	poster, posting, placard, notice, bill, card
102	0.0006	108	9	stage
103	0.0006	557	55	van
104	0.0006	52	4	ship
105	0.0005	99	5	fountain
106	0.0005	57	4	conveyer belt, conveyor belt, conveyer, conveyor, transporter
107	0.0005	292	31	canopy
108	0.0005	77	9	washer, automatic washer, washing machine
109	0.0005	340	38	plaything, toy
110	0.0005	66	3	swimming pool, swimming bath, natatorium
111	0.0005	465	49	stool
112	0.0005	50	4	barrel, cask
113	0.0005	622	75	basket, handbasket
114	0.0005	80	9	waterfall, falls
115	0.0005	59	3	tent, collapsible shelter
116	0.0005	531	72	bag
117	0.0005	282	30	minibike, motorbike
118	0.0005	73	7	cradle
119	0.0005	435	44	oven
120	0.0005	136	25	ball
121	0.0005	116	24	food, solid food
122	0.0004	266	31	step, stair
123	0.0004	58	12	tank, storage tank
124	0.0004	418	83	trade name, brand name, brand, marque
125	0.0004	319	43	microwave, microwave oven
126	0.0004	1193	139	pot, flowerpot
127	0.0004	97	23	animal, animate being, beast, brute, creature, fauna
128	0.0004	347	36	bicycle, bike, wheel, cycle 
129	0.0004	52	5	lake
130	0.0004	246	22	dishwasher, dish washer, dishwashing machine
131	0.0004	108	13	screen, silver screen, projection screen
132	0.0004	201	30	blanket, cover
133	0.0004	285	21	sculpture
134	0.0004	268	27	hood, exhaust hood
135	0.0003	1020	108	sconce
136	0.0003	1282	122	vase
137	0.0003	528	65	traffic light, traffic signal, stoplight
138	0.0003	453	57	tray
139	0.0003	671	100	ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
140	0.0003	397	44	fan
141	0.0003	92	8	pier, wharf, wharfage, dock
142	0.0003	228	18	crt screen
143	0.0003	570	59	plate
144	0.0003	217	22	monitor, monitoring device
145	0.0003	206	19	bulletin board, notice board
146	0.0003	130	14	shower
147	0.0003	178	28	radiator
148	0.0002	504	57	glass, drinking glass
149	0.0002	775	96	clock
150	0.0002	421	56	flag

mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── cityscapes
│ │ ├── leftImg8bit
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── gtFine
│ │ │ ├── train
│ │ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2012
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClass
│ │ │ ├── ImageSets
│ │ │ │ ├── Segmentation
│ │ ├── VOC2010
│ │ │ ├── JPEGImages
│ │ │ ├── SegmentationClassContext
│ │ │ ├── ImageSets
│ │ │ │ ├── SegmentationContext
│ │ │ │ │ ├── train.txt
│ │ │ │ │ ├── val.txt
│ │ │ ├── trainval_merged.json
│ │ ├── VOCaug
│ │ │ ├── dataset
│ │ │ │ ├── cls
│ ├── ade
│ │ ├── ADEChallengeData2016
│ │ │ ├── annotations
│ │ │ │ ├── training
│ │ │ │ ├── validation
│ │ │ ├── images
│ │ │ │ ├── training
│ │ │ │ ├── validation
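Before training, it is worth checking that the ADEChallengeData2016 folders exist and that every training image has a mask. The sketch below assumes the layout above, with .jpg images and .png masks that share a base name. The function name is hypothetical, and unlike MMSegmentation (which scans recursively) it only looks at the top level of each split.

```python
from pathlib import Path


def check_ade_layout(data_root, img_suffix='.jpg', seg_suffix='.png'):
    """Return {split: sorted image stems that have no matching mask}.

    Raises FileNotFoundError if an expected directory is missing.
    Only the top level of each split directory is scanned.
    """
    root = Path(data_root)
    missing = {}
    for split in ('training', 'validation'):
        img_dir = root / 'images' / split
        ann_dir = root / 'annotations' / split
        for d in (img_dir, ann_dir):
            if not d.is_dir():
                raise FileNotFoundError(f'missing directory: {d}')
        img_stems = {p.stem for p in img_dir.glob(f'*{img_suffix}')}
        ann_stems = {p.stem for p in ann_dir.glob(f'*{seg_suffix}')}
        missing[split] = sorted(img_stems - ann_stems)
    return missing
```

Typical use: `print(check_ade_layout('data/ade/ADEChallengeData2016'))`. Empty lists mean every image has a mask.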

5. Modify the basic configuration file

Here we train the upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K model. Its configuration file is configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py, with the following contents:

_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
    auxiliary_head=dict(in_channels=384, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
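The `_base_` list and the `_delete_=True` keys are what keep this file short. The config inherits everything from the four base files, and nested dicts are merged key by key. A dict marked `_delete_=True` replaces the inherited dict instead of being merged into it. The function below is a simplified, hypothetical re-implementation of that merge rule in plain Python, written only to illustrate how mmcv's Config behaves. It is not mmcv code, and the base SGD values are assumed for illustration.

```python
import copy


def merge_cfg(child, base):
    """Merge a child config dict into a copy of its base config dict.

    Nested dicts are merged recursively, unless the child dict carries
    _delete_=True, in which case it replaces the base value entirely.
    """
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)  # avoid mutating the caller's dict
            delete = value.pop('_delete_', False)
            if not delete and isinstance(merged.get(key), dict):
                merged[key] = merge_cfg(value, merged[key])
                continue
        merged[key] = copy.deepcopy(value)
    return merged


base = {'optimizer': dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005),
        'data': dict(samples_per_gpu=4, workers_per_gpu=4)}
child = {'optimizer': dict(_delete_=True, type='AdamW', lr=0.00006,
                           betas=(0.9, 0.999), weight_decay=0.01),
         'data': dict(samples_per_gpu=2)}
cfg = merge_cfg(child, base)
print(cfg['optimizer'])  # no 'momentum' key: the SGD settings were replaced
print(cfg['data'])       # {'samples_per_gpu': 2, 'workers_per_gpu': 4}
```

The same rule explains why `samples_per_gpu=2` in this file overrides `samples_per_gpu=4` from ade20k.py, while `workers_per_gpu` is still inherited.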

1. Set the number of classes and load the pretrained model (model config file upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py)

_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa; you can also download this file first and set checkpoint_file to the local path
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
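    auxiliary_head=dict(in_channels=384, num_classes=150))
# Set num_classes in both heads to the number of classes in your dataset,
# excluding background: with reduce_zero_label=True, label 0 is treated as
# background and ignored.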

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)

2. Modify the dataset settings (dataset type, data root, batch size, etc.) in '../_base_/datasets/ade20k.py'

# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'  # 1. change to the path of your own data
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)  # 2. change to a crop size suited to your data
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),  # adjust img_scale according to crop_size
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
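The comment on Resize says to choose img_scale relative to crop_size. As I understand the 0.x pipeline, keep_ratio=True (the default) together with ratio_range works in two steps. First, both entries of img_scale are multiplied by a random ratio. Then the image is rescaled so that its long edge fits the larger value and its short edge fits the smaller one. The sketch below reproduces that arithmetic (assumed from mmcv's rescale rule; check mmseg/datasets/pipelines/transforms.py if exactness matters) so you can see how large images become before the 512x512 crop.

```python
def rescaled_size(w, h, img_scale=(2048, 512), ratio=1.0):
    """Output (w, h) after a keep-ratio resize to img_scale * ratio.

    The long edge is capped by max(scale) and the short edge by min(scale);
    the tighter of the two limits determines the scale factor.
    """
    scale = (int(img_scale[0] * ratio), int(img_scale[1] * ratio))
    long_cap, short_cap = max(scale), min(scale)
    factor = min(long_cap / max(w, h), short_cap / min(w, h))
    return int(w * factor + 0.5), int(h * factor + 0.5)


# A 683x512 image at the extremes of ratio_range=(0.5, 2.0) and at 1.0.
for r in (0.5, 1.0, 2.0):
    print(r, rescaled_size(683, 512, ratio=r))
```

At ratio 0.5 the short edge (256) is smaller than crop_size. RandomCrop then returns a smaller patch, and Pad fills it up to 512 with seg_pad_val=255, which the loss ignores. If your images are much smaller or larger than ADE20K's, one reasonable heuristic is to set img_scale[1] near your typical short edge.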

3. Modify the class names (CLASSES), the file suffixes, and the label index ignored in the loss calculation (mmseg/datasets/ade.py, mmseg/datasets/custom.py)

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class ADE20KDataset(CustomDataset):
    """ADE20K dataset.

    In segmentation map annotation for ADE20K, 0 stands for background, which
    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
    '.png'.
    """
    CLASSES = (
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag')  # change to the class names of your dataset

    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
               [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
               [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
               [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
               [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
               [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
               [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
               [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
               [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
               [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
               [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
               [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
               [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
               [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
               [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
               [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
               [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
               [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
               [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
               [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
               [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
               [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
               [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
               [102, 255, 0], [92, 0, 255]]  # likewise, the colors can be changed (one RGB triple per class)

    def __init__(self, **kwargs):
        super(ADE20KDataset, self).__init__(
            img_suffix='.jpg',  # image file suffix; change it for your dataset
            seg_map_suffix='.png',  # label (mask) file suffix; change it for your dataset
            reduce_zero_label=True,
            **kwargs)

    def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
        """Write the segmentation results to images.

        Args:
            results (list[ndarray]): Testing results of the
                dataset.
            imgfile_prefix (str): The filename prefix of the png files.
                If the prefix is "somepath/xxx",
                the png files will be named "somepath/xxx.png".
            to_label_id (bool): whether convert output to label_id for
                submission.
            indices (list[int], optional): Indices of input results, if not
                set, all the indices of the dataset will be used.
                Default: None.

        Returns:
            list[str: str]: result txt files which contains corresponding
            semantic segmentation images.
        """
        if indices is None:
            indices = list(range(len(self)))

        mmcv.mkdir_or_exist(imgfile_prefix)
        result_files = []
        for result, idx in zip(results, indices):

            filename = self.img_infos[idx]['filename']
            basename = osp.splitext(osp.basename(filename))[0]

            png_filename = osp.join(imgfile_prefix, f'{basename}.png')  # the output suffix .png can be changed here

            # The  index range of official requirement is from 0 to 150.
            # But the index range of output is from 0 to 149.
            # That is because we set reduce_zero_label=True.
            result = result + 1

            output = Image.fromarray(result.astype(np.uint8))
            output.save(png_filename)
            result_files.append(png_filename)

        return result_files

    def format_results(self,
                       results,
                       imgfile_prefix,
                       to_label_id=True,
                       indices=None):
        """Format the results into dir (standard format for ade20k evaluation).

        Args:
            results (list): Testing results of the dataset.
            imgfile_prefix (str | None): The prefix of images files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix".
            to_label_id (bool): whether convert output to label_id for
                submission. Default: False
            indices (list[int], optional): Indices of input results, if not
                set, all the indices of the dataset will be used.
                Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a list containing
               the image paths, tmp_dir is the temporal directory created
                for saving json/png files when img_prefix is not specified.
        """

        if indices is None:
            indices = list(range(len(self)))

        assert isinstance(results, list), 'results must be a list.'
        assert isinstance(indices, list), 'indices must be a list.'

        result_files = self.results2img(results, imgfile_prefix, to_label_id,
                                        indices)
        return result_files
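The docstring's statement that 0 is background with reduce_zero_label fixed to True, and the `result + 1` in results2img, are two halves of one label shift. The NumPy sketch below mirrors the remapping that LoadAnnotations applies when reduce_zero_label=True (as I understand the 0.x code) and then undoes it the way results2img does. The function names are hypothetical.

```python
import numpy as np


def reduce_zero_label(seg):
    """Map 0 (background) to 255 (ignored) and shift labels 1..150 to 0..149."""
    seg = seg.astype(np.uint8).copy()
    seg[seg == 0] = 255      # set first so 0 cannot underflow
    seg = seg - 1            # uint8: 255 -> 254, k -> k - 1
    seg[seg == 254] = 255    # restore the ignore value
    return seg


def restore_label(pred):
    """Undo the shift on predictions (0..149 -> 1..150), as results2img does."""
    return (pred + 1).astype(np.uint8)


gt = np.array([[0, 1, 150]], dtype=np.uint8)
train_gt = reduce_zero_label(gt)
print(train_gt)                             # [[255   0 149]]
print(restore_label(np.array([[0, 149]])))  # [[  1 150]]
```

This is why num_classes is 150 for ADE20K even though the masks contain 151 distinct values: label 0 never reaches the loss. If your own masks use 0 for a real class rather than for "unlabeled", set reduce_zero_label=False and count that class in num_classes.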

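A note on file formats: if your images are .jpg and your masks are .png, nothing needs to change. Otherwise, update the suffixes. ADE20KDataset passes them explicitly in its __init__ (img_suffix='.jpg', seg_map_suffix='.png' above), so edit them there. The defaults in mmseg/datasets/custom.py, shown below, only apply to dataset classes that do not pass their own suffixes.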
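When no split file is given, CustomDataset.load_annotations in mmseg/datasets/custom.py does two things. It lists the images whose names end with img_suffix, then derives each mask name with `img.replace(img_suffix, seg_map_suffix)`. The sketch below replays that pairing rule on plain file names, so you can check what your suffix settings will produce before you edit the library. The function name and example files are made up.

```python
def pair_files(filenames, img_suffix='.jpg', seg_map_suffix='.png'):
    """Mimic the suffix-based image/mask pairing used without a split file.

    Only names ending with img_suffix (case-sensitive) are treated as
    images; the mask name comes from str.replace, which rewrites every
    occurrence of img_suffix in the name.
    """
    pairs = []
    for name in sorted(filenames):
        if name.endswith(img_suffix):
            pairs.append((name, name.replace(img_suffix, seg_map_suffix)))
    return pairs


files = ['a.jpg', 'b.JPG', 'c.tif', 'd.jpg']
print(pair_files(files))                     # b.JPG and c.tif are skipped
print(pair_files(files, img_suffix='.tif'))  # only c.tif is paired
```

Suffix matching is case-sensitive, so files ending in '.JPG' are not picked up when img_suffix='.jpg'.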


# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from prettytable import PrettyTable
from torch.utils.data import Dataset

from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations


@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of CustomDataset should be of the same
    except suffix. A valid img/gt_semantic_seg filename pair should be like
    ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
    in the suffix). If split is given, then ``xxx`` is specified in txt file.
    Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
    Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.


    Args:
        pipeline (list[dict]): Processing pipeline
        img_dir (str): Path to image directory
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        split (str, optional): Split txt file. If split is specified, only
            file with suffix in the splits will be loaded. Otherwise, all
            images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir. Default:
            None.
        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        palette (Sequence[Sequence[int]]] | np.ndarray | None):
            The palette of segmentation map. If None is given, and
            self.PALETTE is None, random palette will be generated.
            Default: None
        gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
            load gt for evaluation, load from disk by default. Default: None.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',  # change if your images use another suffix
                 ann_dir=None,
                 seg_map_suffix='.png',  # change if your masks use another suffix
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None,
                 gt_seg_map_loader_cfg=None,
                 file_client_args=dict(backend='disk')):
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)
        self.gt_seg_map_loader = LoadAnnotations(
        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
            **gt_seg_map_loader_cfg)

        self.file_client_args = file_client_args
        self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

        if test_mode:
            assert self.CLASSES is not None, \
                '`cls.CLASSES` or `classes` should be specified when testing'

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)

    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                         split):
        """Load annotation from directory.

        Args:
            img_dir (str): Path to image directory
            img_suffix (str): Suffix of images.
            ann_dir (str|None): Path to annotation directory.
            seg_map_suffix (str|None): Suffix of segmentation maps.
            split (str|None): Split txt file. If split is specified, only file
                with suffix in the splits will be loaded. Otherwise, all images
                in img_dir/ann_dir will be loaded. Default: None

        Returns:
            list[dict]: All image info of dataset.
        """

        img_infos = []
        if split is not None:
            lines = mmcv.list_from_file(
                split, file_client_args=self.file_client_args)
            for line in lines:
                img_name = line.strip()
                img_info = dict(filename=img_name + img_suffix)
                if ann_dir is not None:
                    seg_map = img_name + seg_map_suffix
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
        else:
            for img in self.file_client.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=img_suffix,
                    recursive=True):
                img_info = dict(filename=img)
                if ann_dir is not None:
                    seg_map = img.replace(img_suffix, seg_map_suffix)
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
            img_infos = sorted(img_infos, key=lambda x: x['filename'])

        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
        return img_infos

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline."""
        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:
            results['label_map'] = self.label_map

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set
                False).
        """

        if self.test_mode:
            return self.prepare_test_img(idx)
        else:
            return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """

        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
                pipeline.
        """

        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
        """Place holder to format result to dataset specific output."""
        raise NotImplementedError

    def get_gt_seg_map_by_idx(self, index):
        """Get one ground truth segmentation map for evaluation."""
        ann_info = self.get_ann_info(index)
        results = dict(ann_info=ann_info)
        self.pre_pipeline(results)
        self.gt_seg_map_loader(results)
        return results['gt_semantic_seg']

    def get_gt_seg_maps(self, efficient_test=None):
        """Get ground truth segmentation maps for evaluation."""
        if efficient_test is not None:
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` has been deprecated '
                'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
                'friendly by default. ')

        for idx in range(len(self)):
            ann_info = self.get_ann_info(idx)
            results = dict(ann_info=ann_info)
            self.pre_pipeline(results)
            self.gt_seg_map_loader(results)
            yield results['gt_semantic_seg']

    def pre_eval(self, preds, indices):
        """Collect eval result from each iteration.

        Args:
            preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
                after argmax, shape (N, H, W).
            indices (list[int] | int): the prediction related ground truth
                indices.

        Returns:
            list[torch.Tensor]: (area_intersect, area_union, area_prediction,
                area_ground_truth).
        """
        # In order to compat with batch inference
        if not isinstance(indices, list):
            indices = [indices]
        if not isinstance(preds, list):
            preds = [preds]

        pre_eval_results = []

        for pred, index in zip(preds, indices):
            seg_map = self.get_gt_seg_map_by_idx(index)
            pre_eval_results.append(
                intersect_and_union(
                    pred,
                    seg_map,
                    len(self.CLASSES),
                    self.ignore_index,
                    # as the labels has been converted when dataset initialized
                    # in `get_palette_for_custom_classes ` this `label_map`
                    # should be `dict()`, see
                    # https://github.com/open-mmlab/mmsegmentation/issues/1415
                    # for more details
                    label_map=dict(),
                    reduce_zero_label=self.reduce_zero_label))

        return pre_eval_results

    def get_classes_and_palette(self, classes=None, palette=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
            palette (Sequence[Sequence[int]]] | np.ndarray | None):
                The palette of segmentation map. If None is given, random
                palette will be generated. Default: None
        """
        if classes is None:
            self.custom_classes = False
            return self.CLASSES, self.PALETTE

        self.custom_classes = True
        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        if self.CLASSES:
            if not set(class_names).issubset(self.CLASSES):
                raise ValueError('classes is not a subset of CLASSES.')

            # dictionary, its keys are the old label ids and its values
            # are the new label ids.
            # used for changing pixel labels in load_annotations.
            self.label_map = {}
            for i, c in enumerate(self.CLASSES):
                if c not in class_names:
                    self.label_map[i] = -1
                else:
                    self.label_map[i] = class_names.index(c)

        palette = self.get_palette_for_custom_classes(class_names, palette)

        return class_names, palette

    def get_palette_for_custom_classes(self, class_names, palette=None):

        if self.label_map is not None:
            # return subset of palette
            palette = []
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != -1:
                    palette.append(self.PALETTE[old_id])
            palette = type(self.PALETTE)(palette)

        elif palette is None:
            if self.PALETTE is None:
                # Get random state before set seed, and restore
                # random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(0, 255, size=(len(class_names), 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE

        return palette

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 gt_seg_maps=None,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval
                 results or predict segmentation map for computing evaluation
                 metric.
            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
                'mDice' and 'mFscore' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
                used in ConcatDataset

        Returns:
            dict[str, float]: Default metrics.
        """
        if isinstance(metric, str):
            metric = [metric]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metric).issubset(set(allowed_metrics)):
            raise KeyError('metric {} is not supported'.format(metric))

        eval_results = {}
        # test a list of files
        if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                results, str):
            if gt_seg_maps is None:
                gt_seg_maps = self.get_gt_seg_maps()
            num_classes = len(self.CLASSES)
            ret_metrics = eval_metrics(
                results,
                gt_seg_maps,
                num_classes,
                self.ignore_index,
                metric,
                label_map=dict(),
                reduce_zero_label=self.reduce_zero_label)
        # test a list of pre_eval_results
        else:
            ret_metrics = pre_eval_to_metrics(results, metric)

        # Because dataset.CLASSES is required for per-eval.
        if self.CLASSES is None:
            class_names = tuple(range(num_classes))
        else:
            class_names = self.CLASSES

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)
        print_log('Summary:', logger)
        print_log('\n' + summary_table_data.get_string(), logger=logger)

        # each metric dict
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(class_names)
            })

        return eval_results
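The `pre_eval()` and `evaluate()` methods above first reduce every prediction to per-class intersection and union areas (through `intersect_and_union`) and only then turn the accumulated areas into aAcc, IoU and mIoU. The NumPy sketch below illustrates that two-step idea so you can sanity-check your own labels; `intersect_and_union_np` and `miou_from_areas` are hypothetical helpers, not mmseg functions, and they skip the `label_map` and `reduce_zero_label` handling.

```python
import numpy as np


def intersect_and_union_np(pred, label, num_classes, ignore_index=255):
    """Per-class (intersect, union, pred, label) pixel areas for one image."""
    pred = np.asarray(pred).ravel()
    label = np.asarray(label).ravel()
    # Drop pixels whose ground truth is the ignore index.
    mask = label != ignore_index
    pred, label = pred[mask], label[mask]
    # Pixels where prediction and ground truth agree, counted per class.
    intersect = pred[pred == label]
    area_intersect = np.bincount(intersect, minlength=num_classes)[:num_classes]
    area_pred = np.bincount(pred, minlength=num_classes)[:num_classes]
    area_label = np.bincount(label, minlength=num_classes)[:num_classes]
    area_union = area_pred + area_label - area_intersect
    return area_intersect, area_union, area_pred, area_label


def miou_from_areas(per_image_areas):
    """Sum areas over all images, then compute aAcc, per-class IoU and mIoU."""
    total_intersect = sum(a[0] for a in per_image_areas)
    total_union = sum(a[1] for a in per_image_areas)
    total_label = sum(a[3] for a in per_image_areas)
    a_acc = total_intersect.sum() / total_label.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = total_intersect / total_union  # NaN for classes never seen
    return {'aAcc': a_acc, 'IoU': iou, 'mIoU': np.nanmean(iou)}
```

For example, with pred `[0, 0, 1, 1]` and label `[0, 1, 1, 255]` the last pixel is ignored, both classes get an IoU of 0.5, and aAcc is 2/3.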

1. The modified custom configuration file

2. Model evaluation results on the VOC-format dataset after the specified label index is ignored in the calculation

4. Modify the runtime configuration (loading a pretrained model and resuming training from a checkpoint) (configs/_base_/default_runtime.py)

# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook'),  # uncomment to enable TensorboardLoggerHook
        # dict(type='PaviLoggerHook') # for internal services
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
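load_from = None  # load the model at the given path as pretrained weights; this does NOT resume training
resume_from = None  # load the checkpoint at the given path and resume training from that breakpoint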
workflow = [('train', 1)]
cudnn_benchmark = True

5. Modify the schedule configuration (maximum number of training iterations, checkpoint-saving interval, evaluation interval, evaluation metrics, and automatically keeping the best model) (configs/_base_/schedules/schedule_40k.py, .../_base_/schedules/schedule_160k.py)

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)  # max_iters: maximum number of training iterations
checkpoint_config = dict(by_epoch=False, interval=16000)  # interval: save a checkpoint every 16000 iterations
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)  # interval: evaluate every 16000 iterations; metric: evaluation metric; add save_best='auto' to keep the best model
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook')
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = '/media/lhy/Swin-Transformer-Semantic-Segmentation/checkpoints/deeplabv3plus/deeplabv3plus_r101-d8_512x512_40k_voc12aug_20200613_205333-faf03387.pth'
resume_from = '/media/lhy/mmsegmentation-0.27.0/work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2/best_mIoU_iter_44000.pth'
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
# enable FP16 (mixed-precision) training
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
fp16 = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(
    interval=4000, metric=['mIoU', 'mFscore'], pre_eval=True, save_best='mIoU')  # automatically keep the checkpoint with the best mIoU
work_dir = 'work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2'
gpu_ids = range(0, 4)
auto_resume = False

For single-GPU training, scale the learning rate as lr = LR * (batch_size / 16), where LR is the learning rate used for 4-GPU training.
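A quick sketch of that linear scaling rule, assuming (as the formula implies) that LR was tuned for a total batch size of 16; `scale_lr` is a hypothetical helper, not part of mmseg:

```python
def scale_lr(base_lr, batch_size, base_batch_size=16):
    """Linearly scale a learning rate tuned for base_batch_size to batch_size."""
    if batch_size <= 0 or base_batch_size <= 0:
        raise ValueError('batch sizes must be positive')
    # lr = LR * (batch_size / base_batch_size)
    return base_lr * batch_size / base_batch_size
```

For example, `scale_lr(0.01, 4)` returns 0.0025 for a single GPU with 4 images per batch (samples_per_gpu=4).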

6. Modify the model's inference mode and norm_cfg (…/_base_/models/upernet_swin.py)

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)  # use 'SyncBN' for multi-GPU training; for single-GPU training change type to 'BN'
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=96,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg),
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=384,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
# 'whole' means whole-image inference mode.
# For overlapping sliding-window prediction, change it to: test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))

Sliding window code: mmsegmentation/mmseg/models/segmentors/encoder_decoder.py

    # TODO refactor
    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.
        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        num_classes = self.num_classes
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                crop_seg_logit = self.encode_decode(crop_img, img_meta)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(
                count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        if rescale:
            # remove padding area
            resize_shape = img_meta[0]['img_shape'][:2]
            preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
            preds = resize(
                preds,
                size=img_meta[0]['ori_shape'][:2],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False)
        return preds
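The window arithmetic above can be reproduced outside the model to check which crops mode='slide' visits for a given crop_size and stride. This NumPy sketch mirrors the grid and clamping logic of `slide_inference` and averages overlapping logits with a count map. It differs from the original in that it adds logits by slicing instead of `F.pad`, works on one image, and `predict_fn` is a hypothetical stand-in for `self.encode_decode`.

```python
import numpy as np


def slide_windows(h_img, w_img, h_crop, w_crop, h_stride, w_stride):
    """Window boxes (y1, y2, x1, x2) using the same arithmetic as slide_inference."""
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    boxes = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1, x1 = h_idx * h_stride, w_idx * w_stride
            y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
            # Shift the last window back so it keeps the full crop size.
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            boxes.append((y1, y2, x1, x2))
    return boxes


def slide_predict(img, predict_fn, num_classes, crop, stride):
    """Average overlapping window logits; img has shape (H, W, ...)."""
    h_img, w_img = img.shape[:2]
    preds = np.zeros((num_classes, h_img, w_img))
    count = np.zeros((1, h_img, w_img))
    for y1, y2, x1, x2 in slide_windows(h_img, w_img, *crop, *stride):
        # predict_fn must return logits of shape (num_classes, y2 - y1, x2 - x1).
        preds[:, y1:y2, x1:x2] += predict_fn(img[y1:y2, x1:x2])
        count[:, y1:y2, x1:x2] += 1
    assert (count > 0).all()  # every pixel is covered at least once
    return preds / count
```

With a 1024x512 image, crop_size=(512, 512) and stride=(341, 341), the rows visited are 0-512, 341-853 and 512-1024; the last window is shifted back so it keeps the full crop size, and the overlapping rows are averaged.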

7. Model optimization tips

1. Learning rate optimization techniques

In semantic segmentation, some methods make the LR of the head larger than that of the backbone to achieve better performance or faster convergence.
In MMSegmentation, you can add the following lines to the config to make the LR of the heads 10 times that of the backbone. With this modification, the LR of any parameter group whose name contains 'head' is multiplied by 10.

optimizer = dict(
    paramwise_cfg=dict(
        custom_keys={
            'head': dict(lr_mult=10.)}))
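Roughly, custom_keys works by substring matching on parameter names: a parameter whose name contains 'head' (for example decode_head.* and auxiliary_head.*) gets lr = base_lr * lr_mult, while backbone parameters keep the base LR. The sketch below only illustrates that matching; `param_group_lrs` is a simplified, hypothetical helper, not mmcv's `DefaultOptimizerConstructor`, which supports more options (such as decay_mult).

```python
def param_group_lrs(param_names, base_lr, custom_keys):
    """Map each parameter name to its learning rate under custom_keys."""
    # Try longer keys first so the most specific key wins.
    keys = sorted(custom_keys, key=len, reverse=True)
    lrs = {}
    for name in param_names:
        lr = base_lr
        for key in keys:
            if key in name:  # substring match on the parameter name
                lr = base_lr * custom_keys[key].get('lr_mult', 1.0)
                break
        lrs[name] = lr
    return lrs
```

With base_lr=0.01 and custom_keys={'head': dict(lr_mult=10.)}, `decode_head.conv_seg.weight` gets 0.1 while `backbone.stages.0.weight` stays at 0.01.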

2. Online Hard Example Mining (OHEM)

MMSegmentation implements a pixel sampler for training-time sampling. Here is an example configuration for training PSPNet with OHEM enabled.
With this setting, only pixels whose confidence score is below 0.7 are used for training, and at least 100000 pixels are kept during training. If thresh is not specified, the min_kept pixels with the largest loss are selected instead.

_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
    decode_head=dict(
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
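The thresholded selection rule can be sketched in NumPy on a flat array holding, for every pixel, the predicted probability of its ground-truth class. It follows the idea described above: train only on pixels below thresh, but raise the threshold when needed so that about min_kept pixels survive. `ohem_mask` is a hypothetical helper covering only the branch where thresh is set; details of `OHEMPixelSampler` (such as ignore_index handling) may differ.

```python
import numpy as np


def ohem_mask(gt_prob, thresh=0.7, min_kept=100000):
    """Return a 0/1 weight per pixel marking the hard pixels used for training.

    gt_prob: 1-D array with the predicted probability of each pixel's GT class.
    """
    gt_prob = np.asarray(gt_prob, dtype=float)
    if gt_prob.size == 0:
        return np.zeros(0)
    sorted_prob = np.sort(gt_prob)  # ascending: hardest pixels first
    # Probability of the (min_kept + 1)-th hardest pixel, clamped to the array.
    min_threshold = sorted_prob[min(min_kept, gt_prob.size - 1)]
    # Take the larger threshold so that at least min_kept pixels are kept.
    threshold = max(min_threshold, thresh)
    return (gt_prob < threshold).astype(float)
```

With probabilities `[0.9, 0.2, 0.5, 0.8, 0.1]`, thresh=0.3 alone would keep only two pixels, but min_kept=3 raises the threshold to 0.8, so the three hardest pixels are used.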

3. Class balance loss

For datasets with an imbalanced class distribution, you can change the loss weight of each class. Here is an example for the Cityscapes dataset; class_weight is passed to CrossEntropyLoss as its weight argument:

_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
    decode_head=dict(
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            # DeepLab used this class weight for cityscapes
            class_weight=[0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
                        1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
                        1.0865, 1.0955, 1.0865, 1.1529, 1.0507])))
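To see what class_weight does, here is a NumPy sketch of pixel-wise cross-entropy in which each pixel's loss is multiplied by the weight of its ground-truth class before averaging over pixels. `weighted_pixel_ce` is a hypothetical helper; mmseg's `CrossEntropyLoss` also handles ignore_index, avg_non_ignore and other options, and its exact normalization may differ.

```python
import numpy as np


def weighted_pixel_ce(probs, labels, class_weight=None):
    """Mean cross-entropy over pixels, scaled per pixel by its class weight.

    probs: (N, C) softmax probabilities; labels: (N,) integer class ids.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    p_true = probs[np.arange(labels.size), labels]
    per_pixel = -np.log(np.clip(p_true, 1e-12, 1.0))
    if class_weight is not None:
        # A rare class with a larger weight contributes more to the loss.
        per_pixel = per_pixel * np.asarray(class_weight, dtype=float)[labels]
    return float(per_pixel.mean())
```

With two equally uncertain pixels of classes 0 and 1 and class_weight=[1.0, 3.0], the loss doubles from ln 2 to 2 ln 2 because the class-1 pixel now counts three times.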

4. Multiple losses

For loss computation, multiple losses can be trained at the same time. Here is an example configuration for UNet trained on the DRIVE dataset, whose loss function is a 1:3 weighted sum of CrossEntropyLoss and DiceLoss:

_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
            dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
    auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
            dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
    )

In this way, loss_weight and loss_name are, respectively, the weight of the corresponding loss and its name in the training log.
Note: if you want this loss term to be included in the backward graph, its name must be prefixed with loss_.
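A small NumPy sketch of how a list-valued loss_decode could be evaluated: each entry is computed, scaled by its loss_weight, logged under its loss_name, and only names starting with loss_ are summed into the total, following the note above. The soft Dice used here is a common textbook form and is an assumption; mmseg's `DiceLoss` may differ (for example in its denominator), and `combine_losses` and `LOSS_FNS` are hypothetical names.

```python
import numpy as np


def cross_entropy(probs, labels):
    p_true = probs[np.arange(labels.size), labels]
    return float(np.mean(-np.log(np.clip(p_true, 1e-12, 1.0))))


def soft_dice_loss(probs, labels, smooth=1.0):
    """1 - mean per-class soft Dice coefficient."""
    onehot = np.eye(probs.shape[1])[labels]
    inter = (probs * onehot).sum(axis=0)
    denom = probs.sum(axis=0) + onehot.sum(axis=0)
    return float(1.0 - ((2 * inter + smooth) / (denom + smooth)).mean())


LOSS_FNS = {'CrossEntropyLoss': cross_entropy, 'DiceLoss': soft_dice_loss}


def combine_losses(loss_decode, probs, labels):
    """Evaluate each configured loss, log it under loss_name, sum loss_* terms."""
    probs, labels = np.asarray(probs, dtype=float), np.asarray(labels)
    log_vars = {}
    for cfg in loss_decode:
        value = LOSS_FNS[cfg['type']](probs, labels)
        log_vars[cfg['loss_name']] = cfg['loss_weight'] * value
    # Only entries named loss_* contribute to the total that is backpropagated.
    total = sum(v for k, v in log_vars.items() if k.startswith('loss_'))
    return total, log_vars
```

Passing the two dicts from the DRIVE example (loss_ce with weight 1.0, loss_dice with weight 3.0) logs both terms separately and returns their weighted sum as the training loss.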

5. Ignore the specified label index in the loss calculation

mmseg already provides description files and loading code for many public segmentation datasets, and anyone who has used PyTorch will find these description files easy to learn. The one option that tends to puzzle mmseg newcomers is reduce_zero_label, so when building your own mmseg dataset, the most confusing question is probably whether reduce_zero_label should be True or False.

What does it do? As the name says, it "reduces the zero label". In a multi-class segmentation task, if the value 0 in your label files denotes the background class, it is recommended to ignore it.

Looking at the source code that loads the annotations, you can find a snippet that handles reduce_zero_label: if reduce_zero_label is enabled, all pixels originally labeled 0 are set to 255, which is the default ignore_index of the loss function, and every other label is shifted down by 1. By default, pixels labeled 255 do not take part in the loss calculation. This is also why the ADE dataset has only 150 classes and no background class: reduce_zero_label is enabled, so the background with value 0 is set to ignore_index.

# mmseg/datasets/pipelines/loading.py

...
# reduce zero_label
if self.reduce_zero_label:
    # avoid using underflow conversion
    gt_semantic_seg[gt_semantic_seg == 0] = 255
    gt_semantic_seg = gt_semantic_seg - 1
    gt_semantic_seg[gt_semantic_seg == 254] = 255
...
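The same three steps can be run on a label array outside the pipeline to check what your own annotations turn into. This standalone NumPy version mirrors the excerpt above; the function name `reduce_zero_label` here is a hypothetical helper, not the mmseg API.

```python
import numpy as np


def reduce_zero_label(gt_semantic_seg):
    """Map label 0 to 255 (ignore) and shift labels 1..N down to 0..N-1."""
    gt = np.asarray(gt_semantic_seg, dtype=np.uint8).copy()
    # Set 0 to 255 first so that subtracting 1 cannot underflow.
    gt[gt == 0] = 255
    gt -= 1
    # 255 - 1 = 254: restore both former zeros and existing 255s to 255.
    gt[gt == 254] = 255
    return gt
```

For ADE-style labels with values 0 to 150, the background (0) becomes 255 and classes 1 to 150 become 0 to 149, which is why num_classes=150 goes together with reduce_zero_label=True.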

Common problems caused by reduce_zero_label

Take the source code of the ADE dataset as an example: reduce_zero_label defaults to True there. Even a newcomer who has understood reduce_zero_label from the previous section may, with only a superficial understanding of ADE, wonder whether enabling reduce_zero_label in the configuration file ignores the first of the 150 classes, since num_classes is 150.

Error Cause Analysis

# configs/_base_/datasets/ade20k.py

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),  # reduce_zero_label defaults to True for ADE
    dict(...),
    ...
]

Only 150 classes in the labels actually take part in training, and these are the ones defined in CLASSES. The label files, however, contain 151 classes: the background class (the remaining unlabeled or ignored regions, stored as value 0 in the label files) is not among the 150 CLASSES and must be set to ignore_index during training. That is why reduce_zero_label from the previous section is used to separate the background from the 151 classes and map it alone to ignore_index. If reduce_zero_label were mistakenly turned off, num_classes would have to be 151.

In the default setting, avg_non_ignore=False, which means every pixel counts towards the loss average, even pixels that belong to the ignore-index label.
For loss calculation, you can pass ignore_index to ignore a certain label together with avg_non_ignore=True, so that the loss is averaged only over the non-ignored labels, which may achieve better performance. Here is an example configuration for training UNet on Cityscapes: it ignores label 0 as background in the loss calculation and averages the loss only over non-ignored labels:

_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)))

In short, set ignore_index in the decode head or auxiliary head and add avg_non_ignore=True to its loss_decode:

# model settings
...
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
...
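To see what avg_non_ignore changes, here is a NumPy sketch of pixel-wise cross-entropy with an ignored label: ignored pixels contribute zero loss, and the only difference between the two settings is the denominator (all pixels vs. non-ignored pixels). `ce_with_ignore` is a hypothetical helper; mmseg's `CrossEntropyLoss` additionally applies class_weight and other options.

```python
import numpy as np


def ce_with_ignore(probs, labels, ignore_index=255, avg_non_ignore=False):
    """Pixel-wise cross-entropy where ignore_index pixels contribute zero loss.

    avg_non_ignore=False divides by all pixels; True divides by valid pixels only.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    valid = labels != ignore_index
    # Use class 0 as a placeholder index for ignored pixels; their loss is zeroed.
    safe_labels = np.where(valid, labels, 0)
    p_true = probs[np.arange(labels.size), safe_labels]
    per_pixel = np.where(valid, -np.log(np.clip(p_true, 1e-12, 1.0)), 0.0)
    denom = valid.sum() if avg_non_ignore else labels.size
    return float(per_pixel.sum() / max(denom, 1))
```

With 4 pixels of which 3 carry the ignored label 0, the loss is divided by 4 when avg_non_ignore=False but by 1 when it is True, so the ignored background no longer dilutes the average.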

Origin blog.csdn.net/qq_41627642/article/details/126479513