1. Converting official models to MMSegmentation style
If you want to convert the checkpoint keys yourself in order to use pretrained models from the official repository, the tools directory also provides a script, swin2mmseg.py, that converts the keys of official-repo models to MMSegmentation style.
python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
python tools/model_converters/swin2mmseg.py https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth pretrain/swin_base_patch4_window7_224.pth
This script reads the model from PRETRAIN_PATH and saves the converted model to STORE_PATH.
In the default configs, each pretrained model and the original model it corresponds to are already defined.
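After converting, it is worth a quick sanity check that the checkpoint loads and its keys look like MMSegmentation backbone keys; a minimal sketch (the local path is an assumption):

import torch

# Load the converted checkpoint and print a few keys and tensor shapes;
# a bad conversion usually shows up as unexpected key names here.
ckpt = torch.load('pretrain/swin_base_patch4_window7_224.pth', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)
for k in list(state_dict)[:5]:
    print(k, tuple(state_dict[k].shape))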
2. Download the models trained on ADE20K
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_small_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192015-ee2fff1c.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K_20210526_211650-762e2178.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k/upernet_swin_large_patch4_window7_512x512_pretrain_224x224_22K_160k_ade20k_20220318_015320-48d180dd.pth
https://download.openmmlab.com/mmsegmentation/v0.5/swin/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k/upernet_swin_large_patch4_window12_512x512_pretrain_384x384_22K_160k_ade20k_20220318_091743-9ba68901.pth
3. Download the Swin Transformer pretrained models
#tiny
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth
#small
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth
#base
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_20220317-e9b98025.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_22k_20220317-e5c09f74.pth
#large
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth
https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth
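Any of the URLs above can also be fetched programmatically; a small sketch using torch.hub's download helper (the destination directory pretrain/ is an assumption):

import torch

url = ('https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/'
       'swin_tiny_patch4_window7_224_20220317-1cdeb081.pth')
# Downloads the file into ./pretrain (if not already cached) and loads it.
ckpt = torch.hub.load_state_dict_from_url(url, model_dir='pretrain', map_location='cpu')
print(type(ckpt))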
4. Build a data directory with the ADE20K structure
ADE20K contains more than 25,000 images (20k train, 2k val, 3k test) densely annotated with an open-dictionary label set. For the 2017 Places Challenge, 100 thing and 50 stuff categories covering 89% of all pixels were selected.
There are 150 categories in total.
Idx Ratio Train Val Name
1 0.1576 11664 1172 wall
2 0.1072 6046 612 building, edifice
3 0.0878 8265 796 sky
4 0.0621 9336 917 floor, flooring
5 0.0480 6678 641 tree
6 0.0450 6604 643 ceiling
7 0.0398 4023 408 road, route
8 0.0231 1906 199 bed
9 0.0198 4688 460 windowpane, window
10 0.0183 2423 225 grass
11 0.0181 2874 294 cabinet
12 0.0166 3068 310 sidewalk, pavement
13 0.0160 5075 526 person, individual, someone, somebody, mortal, soul
14 0.0151 1804 190 earth, ground
15 0.0118 6666 796 door, double door
16 0.0110 4269 411 table
17 0.0109 1691 160 mountain, mount
18 0.0104 3999 441 plant, flora, plant life
19 0.0104 2149 217 curtain, drape, drapery, mantle, pall
20 0.0103 3261 318 chair
21 0.0098 3164 306 car, auto, automobile, machine, motorcar
22 0.0074 709 75 water
23 0.0067 3296 315 painting, picture
24 0.0065 1191 106 sofa, couch, lounge
25 0.0061 1516 162 shelf
26 0.0060 667 69 house
27 0.0053 651 57 sea
28 0.0052 1847 224 mirror
29 0.0046 1158 128 rug, carpet, carpeting
30 0.0044 480 44 field
31 0.0044 1172 98 armchair
32 0.0044 1292 184 seat
33 0.0033 1386 138 fence, fencing
34 0.0031 698 61 desk
35 0.0030 781 73 rock, stone
36 0.0027 380 43 wardrobe, closet, press
37 0.0026 3089 302 lamp
38 0.0024 404 37 bathtub, bathing tub, bath, tub
39 0.0024 804 99 railing, rail
40 0.0023 1453 153 cushion
41 0.0023 411 37 base, pedestal, stand
42 0.0022 1440 162 box
43 0.0022 800 77 column, pillar
44 0.0020 2650 298 signboard, sign
45 0.0019 549 46 chest of drawers, chest, bureau, dresser
46 0.0019 367 36 counter
47 0.0018 311 30 sand
48 0.0018 1181 122 sink
49 0.0018 287 23 skyscraper
50 0.0018 468 38 fireplace, hearth, open fireplace
51 0.0018 402 43 refrigerator, icebox
52 0.0018 130 12 grandstand, covered stand
53 0.0018 561 64 path
54 0.0017 880 102 stairs, steps
55 0.0017 86 12 runway
56 0.0017 172 11 case, display case, showcase, vitrine
57 0.0017 198 18 pool table, billiard table, snooker table
58 0.0017 930 109 pillow
59 0.0015 139 18 screen door, screen
60 0.0015 564 52 stairway, staircase
61 0.0015 320 26 river
62 0.0015 261 29 bridge, span
63 0.0014 275 22 bookcase
64 0.0014 335 60 blind, screen
65 0.0014 792 75 coffee table, cocktail table
66 0.0014 395 49 toilet, can, commode, crapper, pot, potty, stool, throne
67 0.0014 1309 138 flower
68 0.0013 1112 113 book
69 0.0013 266 27 hill
70 0.0013 659 66 bench
71 0.0012 331 31 countertop
72 0.0012 531 56 stove, kitchen stove, range, kitchen range, cooking stove
73 0.0012 369 36 palm, palm tree
74 0.0012 144 9 kitchen island
75 0.0011 265 29 computer, computing machine, computing device, data processor, electronic computer, information processing system
76 0.0010 324 33 swivel chair
77 0.0009 304 27 boat
78 0.0009 170 20 bar
79 0.0009 68 6 arcade machine
80 0.0009 65 8 hovel, hut, hutch, shack, shanty
81 0.0009 248 25 bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle
82 0.0008 492 49 towel
83 0.0008 2510 269 light, light source
84 0.0008 440 39 truck, motortruck
85 0.0008 147 18 tower
86 0.0008 583 56 chandelier, pendant, pendent
87 0.0007 533 61 awning, sunshade, sunblind
88 0.0007 1989 239 streetlight, street lamp
89 0.0007 71 5 booth, cubicle, stall, kiosk
90 0.0007 618 53 television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box
91 0.0007 135 12 airplane, aeroplane, plane
92 0.0007 83 5 dirt track
93 0.0007 178 17 apparel, wearing apparel, dress, clothes
94 0.0006 1003 104 pole
95 0.0006 182 12 land, ground, soil
96 0.0006 452 50 bannister, banister, balustrade, balusters, handrail
97 0.0006 42 6 escalator, moving staircase, moving stairway
98 0.0006 307 31 ottoman, pouf, pouffe, puff, hassock
99 0.0006 965 114 bottle
100 0.0006 117 13 buffet, counter, sideboard
101 0.0006 354 35 poster, posting, placard, notice, bill, card
102 0.0006 108 9 stage
103 0.0006 557 55 van
104 0.0006 52 4 ship
105 0.0005 99 5 fountain
106 0.0005 57 4 conveyer belt, conveyor belt, conveyer, conveyor, transporter
107 0.0005 292 31 canopy
108 0.0005 77 9 washer, automatic washer, washing machine
109 0.0005 340 38 plaything, toy
110 0.0005 66 3 swimming pool, swimming bath, natatorium
111 0.0005 465 49 stool
112 0.0005 50 4 barrel, cask
113 0.0005 622 75 basket, handbasket
114 0.0005 80 9 waterfall, falls
115 0.0005 59 3 tent, collapsible shelter
116 0.0005 531 72 bag
117 0.0005 282 30 minibike, motorbike
118 0.0005 73 7 cradle
119 0.0005 435 44 oven
120 0.0005 136 25 ball
121 0.0005 116 24 food, solid food
122 0.0004 266 31 step, stair
123 0.0004 58 12 tank, storage tank
124 0.0004 418 83 trade name, brand name, brand, marque
125 0.0004 319 43 microwave, microwave oven
126 0.0004 1193 139 pot, flowerpot
127 0.0004 97 23 animal, animate being, beast, brute, creature, fauna
128 0.0004 347 36 bicycle, bike, wheel, cycle
129 0.0004 52 5 lake
130 0.0004 246 22 dishwasher, dish washer, dishwashing machine
131 0.0004 108 13 screen, silver screen, projection screen
132 0.0004 201 30 blanket, cover
133 0.0004 285 21 sculpture
134 0.0004 268 27 hood, exhaust hood
135 0.0003 1020 108 sconce
136 0.0003 1282 122 vase
137 0.0003 528 65 traffic light, traffic signal, stoplight
138 0.0003 453 57 tray
139 0.0003 671 100 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
140 0.0003 397 44 fan
141 0.0003 92 8 pier, wharf, wharfage, dock
142 0.0003 228 18 crt screen
143 0.0003 570 59 plate
144 0.0003 217 22 monitor, monitoring device
145 0.0003 206 19 bulletin board, notice board
146 0.0003 130 14 shower
147 0.0003 178 28 radiator
148 0.0002 504 57 glass, drinking glass
149 0.0002 775 96 clock
150 0.0002 421 56 flag
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── cityscapes
│   │   ├── leftImg8bit
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── gtFine
│   │   │   ├── train
│   │   │   ├── val
│   ├── VOCdevkit
│   │   ├── VOC2012
│   │   │   ├── JPEGImages
│   │   │   ├── SegmentationClass
│   │   │   ├── ImageSets
│   │   │   │   ├── Segmentation
│   │   ├── VOC2010
│   │   │   ├── JPEGImages
│   │   │   ├── SegmentationClassContext
│   │   │   ├── ImageSets
│   │   │   │   ├── SegmentationContext
│   │   │   │   │   ├── train.txt
│   │   │   │   │   ├── val.txt
│   │   │   ├── trainval_merged.json
│   │   ├── VOCaug
│   │   │   ├── dataset
│   │   │   │   ├── cls
│   ├── ade
│   │   ├── ADEChallengeData2016
│   │   │   ├── annotations
│   │   │   │   ├── training
│   │   │   │   ├── validation
│   │   │   ├── images
│   │   │   │   ├── training
│   │   │   │   ├── validation
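Once the data is in place, it is worth checking that every image has a matching annotation before training; a small sketch, assuming the ADE layout above:

import os

data_root = 'data/ade/ADEChallengeData2016'
for split in ('training', 'validation'):
    imgs = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(data_root, 'images', split))}
    anns = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(data_root, 'annotations', split))}
    # The two sets should be identical: every .jpg needs a .png mask.
    print(split, len(imgs), 'images,', len(anns), 'annotations,', len(imgs ^ anns), 'mismatched')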
5. Modify the base config files
This time we choose the upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K model for training. The corresponding config file contains the following:
_base_ = [
'../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth' # noqa
model = dict(
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
use_abs_pos_embed=False,
drop_path_rate=0.3,
patch_norm=True),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150))
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
1. Set the number of classes to your modified value and load the pretrained model (model architecture config upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py)
_base_ = [
'../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth'  # noqa; you can also download this file first and use the local path here
model = dict(
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
use_abs_pos_embed=False,
drop_path_rate=0.3,
patch_norm=True),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
auxiliary_head=dict(in_channels=384, num_classes=150))  # change num_classes to the number of classes in your own dataset, excluding the background; the background is automatically 0
# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
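For a custom dataset, the architecture part of this file usually needs only the two num_classes values changed; a minimal sketch, where the class count of 5 is a hypothetical example (with reduce_zero_label=True, num_classes excludes the background):

# Hypothetical override config for a 5-class custom dataset.
model = dict(
    decode_head=dict(num_classes=5),
    auxiliary_head=dict(num_classes=5))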
2. Modify the dataset information (dataset type, data root, batch size, etc.) (configs/_base_/datasets/ade20k.py)
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'  # 1. change to your own data path
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)  # 2. change to the crop size suitable for your data
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),  # adjust img_scale according to crop_size
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))
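If your data differs strongly from natural images, the ImageNet mean/std in img_norm_cfg may be a poor fit; a sketch for estimating your own statistics (the path follows the layout above, and subsampling is an assumption to keep it fast):

import os
import numpy as np
import mmcv

img_dir = 'data/ade/ADEChallengeData2016/images/training'
means, stds = [], []
for name in sorted(os.listdir(img_dir))[:200]:  # subsample for speed
    img = mmcv.imread(os.path.join(img_dir, name))[..., ::-1]  # BGR -> RGB
    img = img.reshape(-1, 3).astype(np.float64)
    means.append(img.mean(0))
    stds.append(img.std(0))
print('mean:', np.mean(means, 0), 'std:', np.mean(stds, 0))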
3. Modify the class names in CLASSES and the file suffixes, and the label index ignored in the loss computation (mmseg/datasets/ade.py, mmseg/datasets/custom.py)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
from PIL import Image
from .builder import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class ADE20KDataset(CustomDataset):
"""ADE20K dataset.
In segmentation map annotation for ADE20K, 0 stands for background, which
is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
'.png'.
"""
CLASSES = (
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
'clock', 'flag')  # change to the class names of your own dataset
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]  # the colors can likewise be changed
def __init__(self, **kwargs):
super(ADE20KDataset, self).__init__(
img_suffix='.jpg',  # the image suffix of the dataset can be changed here
seg_map_suffix='.png',  # the label suffix of the dataset can be changed here
reduce_zero_label=True,
**kwargs)
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
"""Write the segmentation results to images.
Args:
results (list[ndarray]): Testing results of the
dataset.
imgfile_prefix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx",
the png files will be named "somepath/xxx.png".
to_label_id (bool): whether convert output to label_id for
submission.
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
list[str: str]: result txt files which contains corresponding
semantic segmentation images.
"""
if indices is None:
indices = list(range(len(self)))
mmcv.mkdir_or_exist(imgfile_prefix)
result_files = []
for result, idx in zip(results, indices):
filename = self.img_infos[idx]['filename']
basename = osp.splitext(osp.basename(filename))[0]
png_filename = osp.join(imgfile_prefix, f'{basename}.png')  # the .png suffix can be changed here
# The index range of official requirement is from 0 to 150.
# But the index range of output is from 0 to 149.
# That is because we set reduce_zero_label=True.
result = result + 1
output = Image.fromarray(result.astype(np.uint8))
output.save(png_filename)
result_files.append(png_filename)
return result_files
def format_results(self,
results,
imgfile_prefix,
to_label_id=True,
indices=None):
"""Format the results into dir (standard format for ade20k evaluation).
Args:
results (list): Testing results of the dataset.
imgfile_prefix (str | None): The prefix of images files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix".
to_label_id (bool): whether convert output to label_id for
submission. Default: False
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a list containing
the image paths, tmp_dir is the temporal directory created
for saving json/png files when img_prefix is not specified.
"""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
result_files = self.results2img(results, imgfile_prefix, to_label_id,
indices)
return result_files
Note that if your images are in jpg format and your masks in png format there is nothing to change; if they come in other formats, you need to modify the suffixes in mmseg/datasets/custom.py.
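As an alternative to editing custom.py in place, you can register a small dataset subclass with the suffixes you need; a sketch with hypothetical class names and a hypothetical .tif image format:

from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

@DATASETS.register_module()
class MyTifDataset(CustomDataset):  # hypothetical dataset class
    CLASSES = ('road', 'building')  # hypothetical class names
    PALETTE = [[128, 64, 128], [70, 70, 70]]

    def __init__(self, **kwargs):
        super().__init__(
            img_suffix='.tif',      # images stored as .tif
            seg_map_suffix='.png',  # masks stored as .png
            reduce_zero_label=True,
            **kwargs)

The full custom.py, with the lines you may need to touch marked, follows: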
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict
import mmcv
import numpy as np
from mmcv.utils import print_log
from prettytable import PrettyTable
from torch.utils.data import Dataset
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations
@DATASETS.register_module()
class CustomDataset(Dataset):
"""Custom dataset for semantic segmentation. An example of file structure
is as followed.
.. code-block:: none
├── data
│ ├── my_dataset
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{img_suffix}
│ │ │ │ ├── yyy{img_suffix}
│ │ │ │ ├── zzz{img_suffix}
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{seg_map_suffix}
│ │ │ │ ├── yyy{seg_map_suffix}
│ │ │ │ ├── zzz{seg_map_suffix}
│ │ │ ├── val
The img/gt_semantic_seg pair of CustomDataset should be of the same
except suffix. A valid img/gt_semantic_seg filename pair should be like
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
in the suffix). If split is given, then ``xxx`` is specified in txt file.
Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be loaded.
Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
Args:
pipeline (list[dict]): Processing pipeline
img_dir (str): Path to image directory
img_suffix (str): Suffix of images. Default: '.jpg'
ann_dir (str, optional): Path to annotation directory. Default: None
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
split (str, optional): Split txt file. If split is specified, only
file with suffix in the splits will be loaded. Otherwise, all
images in img_dir/ann_dir will be loaded. Default: None
data_root (str, optional): Data root for img_dir/ann_dir. Default:
None.
test_mode (bool): If test_mode=True, gt wouldn't be loaded.
ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Default: None.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, and
self.PALETTE is None, random palette will be generated.
Default: None
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
load gt for evaluation, load from disk by default. Default: None.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
CLASSES = None
PALETTE = None
def __init__(self,
pipeline,
img_dir,
img_suffix='.jpg',  # change if needed
ann_dir=None,
seg_map_suffix='.png',  # change if needed
split=None,
data_root=None,
test_mode=False,
ignore_index=255,
reduce_zero_label=False,
classes=None,
palette=None,
gt_seg_map_loader_cfg=None,
file_client_args=dict(backend='disk')):
self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
self.ann_dir = ann_dir
self.seg_map_suffix = seg_map_suffix
self.split = split
self.data_root = data_root
self.test_mode = test_mode
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
self.gt_seg_map_loader = LoadAnnotations(
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
**gt_seg_map_loader_cfg)
self.file_client_args = file_client_args
self.file_client = mmcv.FileClient.infer_client(self.file_client_args)
if test_mode:
assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing'
# join paths if data_root is specified
if self.data_root is not None:
if not osp.isabs(self.img_dir):
self.img_dir = osp.join(self.data_root, self.img_dir)
if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
self.ann_dir = osp.join(self.data_root, self.ann_dir)
if not (self.split is None or osp.isabs(self.split)):
self.split = osp.join(self.data_root, self.split)
# load annotations
self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
self.ann_dir,
self.seg_map_suffix, self.split)
def __len__(self):
"""Total number of samples of data."""
return len(self.img_infos)
def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
split):
"""Load annotation from directory.
Args:
img_dir (str): Path to image directory
img_suffix (str): Suffix of images.
ann_dir (str|None): Path to annotation directory.
seg_map_suffix (str|None): Suffix of segmentation maps.
split (str|None): Split txt file. If split is specified, only file
with suffix in the splits will be loaded. Otherwise, all images
in img_dir/ann_dir will be loaded. Default: None
Returns:
list[dict]: All image info of dataset.
"""
img_infos = []
if split is not None:
lines = mmcv.list_from_file(
split, file_client_args=self.file_client_args)
for line in lines:
img_name = line.strip()
img_info = dict(filename=img_name + img_suffix)
if ann_dir is not None:
seg_map = img_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else:
for img in self.file_client.list_dir_or_file(
dir_path=img_dir,
list_dir=False,
suffix=img_suffix,
recursive=True):
img_info = dict(filename=img)
if ann_dir is not None:
seg_map = img.replace(img_suffix, seg_map_suffix)
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
img_infos = sorted(img_infos, key=lambda x: x['filename'])
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
return img_infos
def get_ann_info(self, idx):
"""Get annotation by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
return self.img_infos[idx]['ann']
def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['seg_fields'] = []
results['img_prefix'] = self.img_dir
results['seg_prefix'] = self.ann_dir
if self.custom_classes:
results['label_map'] = self.label_map
def __getitem__(self, idx):
"""Get training/test data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training/test data (with annotation if `test_mode` is set
False).
"""
if self.test_mode:
return self.prepare_test_img(idx)
else:
return self.prepare_train_img(idx)
def prepare_train_img(self, idx):
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_info = self.img_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return self.pipeline(results)
def prepare_test_img(self, idx):
"""Get testing data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by
pipeline.
"""
img_info = self.img_infos[idx]
results = dict(img_info=img_info)
self.pre_pipeline(results)
return self.pipeline(results)
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""Place holder to format result to dataset specific output."""
raise NotImplementedError
def get_gt_seg_map_by_idx(self, index):
"""Get one ground truth segmentation map for evaluation."""
ann_info = self.get_ann_info(index)
results = dict(ann_info=ann_info)
self.pre_pipeline(results)
self.gt_seg_map_loader(results)
return results['gt_semantic_seg']
def get_gt_seg_maps(self, efficient_test=None):
"""Get ground truth segmentation maps for evaluation."""
if efficient_test is not None:
warnings.warn(
'DeprecationWarning: ``efficient_test`` has been deprecated '
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
'friendly by default. ')
for idx in range(len(self)):
ann_info = self.get_ann_info(idx)
results = dict(ann_info=ann_info)
self.pre_pipeline(results)
self.gt_seg_map_loader(results)
yield results['gt_semantic_seg']
def pre_eval(self, preds, indices):
"""Collect eval result from each iteration.
Args:
preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
after argmax, shape (N, H, W).
indices (list[int] | int): the prediction related ground truth
indices.
Returns:
list[torch.Tensor]: (area_intersect, area_union, area_prediction,
area_ground_truth).
"""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
pre_eval_results = []
for pred, index in zip(preds, indices):
seg_map = self.get_gt_seg_map_by_idx(index)
pre_eval_results.append(
intersect_and_union(
pred,
seg_map,
len(self.CLASSES),
self.ignore_index,
# as the labels has been converted when dataset initialized
# in `get_palette_for_custom_classes ` this `label_map`
# should be `dict()`, see
# https://github.com/open-mmlab/mmsegmentation/issues/1415
# for more ditails
label_map=dict(),
reduce_zero_label=self.reduce_zero_label))
return pre_eval_results
def get_classes_and_palette(self, classes=None, palette=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, random
palette will be generated. Default: None
"""
if classes is None:
self.custom_classes = False
return self.CLASSES, self.PALETTE
self.custom_classes = True
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
if self.CLASSES:
if not set(class_names).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.')
# dictionary, its keys are the old label ids and its values
# are the new label ids.
# used for changing pixel labels in load_annotations.
self.label_map = {}
for i, c in enumerate(self.CLASSES):
if c not in class_names:
self.label_map[i] = -1
else:
self.label_map[i] = class_names.index(c)
palette = self.get_palette_for_custom_classes(class_names, palette)
return class_names, palette
def get_palette_for_custom_classes(self, class_names, palette=None):
if self.label_map is not None:
# return subset of palette
palette = []
for old_id, new_id in sorted(
self.label_map.items(), key=lambda x: x[1]):
if new_id != -1:
palette.append(self.PALETTE[old_id])
palette = type(self.PALETTE)(palette)
elif palette is None:
if self.PALETTE is None:
# Get random state before set seed, and restore
# random state later.
# It will prevent loss of randomness, as the palette
# may be different in each iteration if not specified.
# See: https://github.com/open-mmlab/mmdetection/issues/5844
state = np.random.get_state()
np.random.seed(42)
# random palette
palette = np.random.randint(0, 255, size=(len(class_names), 3))
np.random.set_state(state)
else:
palette = self.PALETTE
return palette
def evaluate(self,
results,
metric='mIoU',
logger=None,
gt_seg_maps=None,
**kwargs):
"""Evaluate the dataset.
Args:
results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval
results or predict segmentation map for computing evaluation
metric.
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
used in ConcatDataset
Returns:
dict[str, float]: Default metrics.
"""
if isinstance(metric, str):
metric = [metric]
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
if not set(metric).issubset(set(allowed_metrics)):
raise KeyError('metric {} is not supported'.format(metric))
eval_results = {}
# test a list of files
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
if gt_seg_maps is None:
gt_seg_maps = self.get_gt_seg_maps()
num_classes = len(self.CLASSES)
ret_metrics = eval_metrics(
results,
gt_seg_maps,
num_classes,
self.ignore_index,
metric,
label_map=dict(),
reduce_zero_label=self.reduce_zero_label)
# test a list of pre_eval_results
else:
ret_metrics = pre_eval_to_metrics(results, metric)
# Because dataset.CLASSES is required for per-eval.
if self.CLASSES is None:
class_names = tuple(range(num_classes))
else:
class_names = self.CLASSES
# summary table
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
# each class table
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)
# for logger
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)
summary_table_data = PrettyTable()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
summary_table_data.add_column(key, [val])
else:
summary_table_data.add_column('m' + key, [val])
print_log('per class results:', logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
print_log('Summary:', logger)
print_log('\n' + summary_table_data.get_string(), logger=logger)
# each metric dict
for key, value in ret_metrics_summary.items():
if key == 'aAcc':
eval_results[key] = value / 100.0
else:
eval_results['m' + key] = value / 100.0
ret_metrics_class.pop('Class', None)
for key, value in ret_metrics_class.items():
eval_results.update({
key + '.' + str(name): value[idx] / 100.0
for idx, name in enumerate(class_names)
})
return eval_results
Points to note: 1. apply the same changes to your own customized config files; 2. for VOC-style datasets, ignoring the specified label index in the loss computation also changes the model evaluation results.
4. Modify the runtime configuration (loading a pretrained model and resuming training from a checkpoint) (configs/_base_/default_runtime.py)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
# dict(type='TensorboardLoggerHook')  # uncomment to enable TensorboardLoggerHook
# dict(type='PaviLoggerHook') # for internal services
])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None  # load a model from the given path as a pretrained model; this does not resume training
resume_from = None  # load a checkpoint from the given path and resume training from that point
workflow = [('train', 1)]
cudnn_benchmark = True
5. Modify the schedule settings (maximum number of training iterations, checkpoint saving interval, evaluation interval and metrics, automatically keeping the best model) (configs/_base_/schedules/schedule_40k.py, configs/_base_/schedules/schedule_160k.py)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)  # max_iters: maximum number of training iterations
checkpoint_config = dict(by_epoch=False, interval=16000)  # interval: save a checkpoint every this many iterations
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)  # interval=16000: evaluate every 16000 iterations with the given metric; adding save_best='auto' keeps the best model
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
dict(type='TensorboardLoggerHook')
])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = '/media/lhy/Swin-Transformer-Semantic-Segmentation/checkpoints/deeplabv3plus/deeplabv3plus_r101-d8_512x512_40k_voc12aug_20200613_205333-faf03387.pth'
resume_from = '/media/lhy/mmsegmentation-0.27.0/work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2/best_mIoU_iter_44000.pth'
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
# enable FP16 (mixed-precision) training
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
fp16 = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(
interval=4000, metric=['mIoU', 'mFscore'], pre_eval=True, save_best='mIoU')  # automatically keep the checkpoint with the best mIoU
work_dir = 'work_dirs/runs/train/road0.5m_1_deeplabv3plus_r101_exp2'
gpu_ids = range(0, 4)
auto_resume = False
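With the configs in place, training is typically launched via tools/train.py (single GPU) or tools/dist_train.sh (multiple GPUs); for example (the work dir and GPU count are placeholders):
python tools/train.py configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py --work-dir work_dirs/my_exp
bash tools/dist_train.sh configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py 4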
When training on a single GPU, scale the learning rate linearly: lr = LR × (batch_size / 16), where LR is the learning rate of the reference multi-GPU setup with a total batch size of 16 (e.g. 8 GPUs × 2 images per GPU). For example, with one GPU and samples_per_gpu=2, lr = 6e-5 × (2/16) = 7.5e-6.
6. Modify the model inference mode and norm_cfg (configs/_base_/models/upernet_swin.py)
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)  # use 'SyncBN' for multi-GPU training; for single-GPU training just change type to 'BN'
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='SwinTransformer',
pretrain_img_size=224,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=backbone_norm_cfg),
decode_head=dict(
type='UPerHead',
in_channels=[96, 192, 384, 768],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=384,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# 'whole' means whole-image inference mode
# for overlapping sliding-window prediction, change to: test_cfg=dict(mode='slide', crop_size=crop_size, stride=(341, 341))
The sliding-window code lives in mmsegmentation/mmseg/models/segmentors/encoder_decoder.py:
# TODO refactor
def slide_inference(self, img, img_meta, rescale):
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
"""
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = img.size()
num_classes = self.num_classes
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
x1 = w_idx * w_stride
y2 = min(y1 + h_crop, h_img)
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = img[:, :, y1:y2, x1:x2]
crop_seg_logit = self.encode_decode(crop_img, img_meta)
preds += F.pad(crop_seg_logit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
if torch.onnx.is_in_onnx_export():
# cast count_mat to constant while exporting to ONNX
count_mat = torch.from_numpy(
count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
if rescale:
# remove padding area
resize_shape = img_meta[0]['img_shape'][:2]
preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
preds = resize(
preds,
size=img_meta[0]['ori_shape'][:2],
mode='bilinear',
align_corners=self.align_corners,
warning=False)
return preds
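As a worked example of the grid computation above (the sizes are assumed): with h_img=1024, h_crop=512 and h_stride=341,

h_grids = max(1024 - 512 + 341 - 1, 0) // 341 + 1  # = 852 // 341 + 1 = 3

so three windows are evaluated per column, at y1 = 0, 341 and 512 (the last start is clamped so the window ends at the image border), and overlapping pixels are averaged through count_mat.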
6. Model optimization tricks
1. Different learning rates (LR) for backbone and heads
In semantic segmentation, some methods make the LR of the heads larger than that of the backbone to achieve better performance or faster convergence. In MMSegmentation, you may add the following lines to the config to make the LR of the heads 10 times that of the backbone. With this modification, the LR of any parameter group whose name contains 'head' is multiplied by 10.
optimizer=dict(
paramwise_cfg = dict(
custom_keys={
'head': dict(lr_mult=10.)}))
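A quick way to convince yourself the multiplier takes effect is to build the optimizer and inspect the parameter groups; a self-contained toy sketch (the two-layer model is a stand-in, not a real segmentor):

import torch.nn as nn
from mmcv.runner import build_optimizer

# Parameters whose names contain 'head' get lr * 10.
model = nn.ModuleDict({'backbone': nn.Linear(4, 4), 'head': nn.Linear(4, 2)})
optimizer_cfg = dict(
    type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005,
    paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))
optimizer = build_optimizer(model, optimizer_cfg)
print(sorted({g['lr'] for g in optimizer.param_groups}))  # [0.01, 0.1]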
2. Online Hard Example Mining (OHEM)
MMSegmentation implements a pixel sampler for training-time sampling. Below is an example config that trains PSPNet with OHEM enabled. With this setting, only pixels whose confidence score is under 0.7 are used for training, and at least 100,000 pixels are kept during training. If thresh is not specified, the pixels with the top min_kept losses are selected instead.
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
decode_head=dict(
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
3. Class-balanced loss
For datasets with an imbalanced class distribution, you can change the loss weight of each class. Here is an example for the Cityscapes dataset; class_weight is passed to CrossEntropyLoss as its weight argument.
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model=dict(
decode_head=dict(
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
# DeepLab used this class weight for cityscapes
class_weight=[0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507])))
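If you need weights for your own dataset, one common recipe is median-frequency balancing over per-class pixel counts; a sketch with made-up counts (DeepLab's Cityscapes weights above were derived by its authors, not by this exact formula):

import numpy as np

pixel_counts = np.array([9e6, 3e6, 1e6, 5e5])  # hypothetical, 4 classes
freq = pixel_counts / pixel_counts.sum()
class_weight = np.median(freq) / freq  # rarer classes get larger weights
print(np.round(class_weight, 4))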
4. Multiple losses
The loss computation supports training with multiple losses at the same time. Below is an example config trained on the DRIVE dataset, whose loss function is a 1:3 weighted sum of CrossEntropyLoss and DiceLoss.
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
In this way, loss_weight and loss_name become the weight and the name of the corresponding loss in the training log, respectively.
Note: the loss_name must start with the prefix loss_ for the term to be included in the backward graph.
5. Ignoring a specified label index in the loss computation
mmseg ships with description files and loading code for the various public segmentation datasets. For anyone who has used PyTorch, the dataset description files are still quite easy to learn; for mmseg beginners, only reduce_zero_label is relatively unfamiliar. When building your own mmseg dataset, the most confusing question is probably whether reduce_zero_label should be True or False.
What is it for? Read literally, it means "reduce the zero-valued label". In a multi-class segmentation task, if the value 0 in the label files is used for the background category, it is recommended to ignore it.
Opening the source fragment that loads the data, you can see the part that handles reduce_zero_label: when it is enabled, all labels that were originally 0 are set to 255, the default ignore_index of the loss function, so those pixels are excluded from the loss computation. The reason the 150 classes of the ADE dataset shown above do not include the background is exactly that reduce_zero_label is enabled and the background value 0 is mapped to ignore_index.
# mmseg/datasets/pipelines/loading.py
...
# reduce zero_label
if self.reduce_zero_label:
# avoid using underflow conversion
gt_semantic_seg[gt_semantic_seg == 0] = 255
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == 254] = 255
...
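A toy demonstration of that mapping, run on a handful of label values:

import numpy as np

# Background 0 becomes 255 (ignore_index); real classes shift down by 1.
gt = np.array([0, 1, 2, 150], dtype=np.uint8)
gt[gt == 0] = 255
gt = gt - 1
gt[gt == 254] = 255
print(gt)  # [255   0   1 149]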
Common problems caused by reduce_zero_label
Take the ADE dataset source code as an example, where reduce_zero_label defaults to True. Even a beginner who has grasped reduce_zero_label from the previous section may have only a superficial understanding of ADE and wonder whether enabling reduce_zero_label in the config is appropriate when 150 classes are used; after all, the first class seems to be ignored even though num_classes is 150, and reduce_zero_label is simply taken for granted.
Analysis of the cause
# configs/_base_/datasets/ade20k.py
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),  # reduce_zero_label defaults to True for ADE
dict(...),
...
]
The labels that actually take part in training contain only the 150 classes defined in CLASSES, but the label files really contain 151 values: the background class (the remaining unlabeled regions, or regions deliberately ignored, stored with value 0 in the label files) is not included in the 150 classes and must be set to ignore_index during training. The reduce_zero_label mechanism from the previous section therefore separates the background out of the 151 values and maps it to ignore_index; if you mistakenly turned reduce_zero_label off, num_classes would have to be 151.
Under the default setting avg_non_ignore=False, every pixel counts in the loss average even though some pixels carry the ignore label index.
For the loss computation, passing avg_non_ignore together with ignore_index is supported to ignore a specific label. In this way, the average loss is computed only over non-ignored labels, which may achieve better performance. Below is an example config for training unet on the Cityscapes dataset: the loss computation ignores label 0 as background, and the loss average is computed only over non-ignored labels.
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            avg_non_ignore=True)))
Simply add ignore_index to the decode head and/or the auxiliary head, together with avg_non_ignore=True:
# model settings
...
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
...
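A toy illustration of what avg_non_ignore changes, using plain PyTorch as a stand-in for the mmseg loss (this mirrors, but is not, the mmseg implementation):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 2, 2)
target = torch.tensor([[[0, 255], [1, 2]]])  # one ignored pixel
# avg_non_ignore=True: average only over the 3 non-ignored pixels.
loss_non_ignore = F.cross_entropy(logits, target, ignore_index=255)
# avg_non_ignore=False: the summed loss is divided by ALL 4 pixels.
loss_all = F.cross_entropy(
    logits, target, ignore_index=255, reduction='sum') / target.numel()
print(float(loss_non_ignore), float(loss_all))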