Foreword
This article introduces MMSegmentation, a Python library dedicated to semantic segmentation (project address on GitHub). The operating environment is a Kaggle notebook with a P100 GPU.
It covers environment configuration, inference with pre-trained models, and fine-tuning the SOTA model mask2former on a new dataset: the watermelon dataset (data description).
Because the watermelon dataset is small, in the final part we fine-tune mask2former on a glomerulus dataset of histopathological sections (data description).
This tutorial partly references the GitHub project MMSegmentation_Tutorials (project address).
Environment configuration
Running the code requires openmim, mmsegmentation, mmengine, mmdetection and mmcv. Configuring mmcv on Kaggle is troublesome and requires pre-built packages, so I have packaged all the pre-built wheels into the dataset frozen-packages-mmdetection (see the dataset details page).
import IPython.display as display

!pip install -U openmim
!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .
!pip install "mmdet>=3.0.0rc4"
!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl
!pip install wandb
display.clear_output()
In actual testing, running the above code on Kaggle satisfies the requirements of this project and no errors are reported (as of July 13, 2023).
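As a quick sanity check (a minimal sketch; the exact version strings will differ depending on when the cells are run), the installed packages can be imported and their versions printed:
# Minimal sanity check that the key packages import cleanly and report a version.
# Exact version numbers depend on when the installation cells above are run.
import torch
import mmcv
import mmengine
import mmseg
import mmdet

print('torch:', torch.__version__, '| cuda available:', torch.cuda.is_available())
print('mmcv:', mmcv.__version__)
print('mmengine:', mmengine.__version__)
print('mmseg:', mmseg.__version__)
print('mmdet:', mmdet.__version__)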
Import common base packages
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Config
import matplotlib.pyplot as plt
%matplotlib inline

from mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()
from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot

import warnings
warnings.filterwarnings('ignore')
display.clear_output()
Create folders for placing datasets, model pre-trained weights, and model inference output
os.mkdir('checkpoint')
os.mkdir('outputs')
os.mkdir('data')
Download the pre-trained weights of pspnet, segformer, and mask2former on cityscapes, and save them in the checkpoint folder.
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
Download some pictures and videos for testing the model and store them in the data folder.
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()
Image inference
Command-line inference
Use the command line to run inference on images and visualize the results with PIL.
Both the pspnet and segformer models are used for inference.
!python demo/image_demo.py \
    data/street_uk.jpeg \
    configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
    checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
    --out-file outputs/B1_uk_pspnet.jpg \
    --device cuda:0 \
    --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')
!python demo/image_demo.py \
    data/street_uk.jpeg \
    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
    checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
    --out-file outputs/B1_uk_segformer.jpg \
    --device cuda:0 \
    --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')
It can be seen that in practice segformer performs better than pspnet, and it can basically separate the different objects.
API inference
Run image inference with mmsegmentation's Python API.
Use the mask2former model for inference and visualize the result with matplotlib.
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
model = init_model(config_file, checkpoint_file, device='cuda:0')
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)
result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()
display.clear_output()

img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.imshow(pred_mask, alpha=0.55)
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()
As a SOTA model, mask2former works really well!
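Alternatively, the show_result_pyplot helper imported earlier can render the same prediction; a minimal sketch (the keyword arguments beyond the model, image and result are assumptions based on the mmseg 1.x API and may vary slightly between versions):
# Visualize the prediction with mmseg's own helper instead of a manual matplotlib overlay.
# Keyword arguments other than (model, img, result) are assumptions about the mmseg 1.x API.
vis_img = show_result_pyplot(
    model,
    img_path,
    result,
    opacity=0.55,
    show=False,
    out_file='outputs/B2-2.jpg')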
Video inference
Command-line inference
Not recommended: it is very slow.
!python demo/video_demo.py \
    data/street_20220330_174028.mp4 \
    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
    checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
    --device cuda:0 \
    --output-file outputs/B3_video.mp4 \
    --opacity 0.5
API inference
The mask2former model performs inference on the video through the Python API.
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
model = init_model(config_file, checkpoint_file, device='cuda:0')
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)
display.clear_output()

input_video = 'data/street_20220330_174028.mp4'
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('Created temporary folder {} for per-frame prediction results'.format(temp_out_dir))

classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']

def predict_single_frame(img, opacity=0.2):
    result = inference_model(model, img)
    # Convert the predicted semantic map into a paletted image and blend it with the frame
    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))
    show_img = (np.array(seg_img.convert('RGB'))) * (1 - opacity) + img * opacity
    return show_img

imgs = mmcv.VideoReader(input_video)
prog_bar = mmengine.ProgressBar(len(imgs))
for frame_id, img in enumerate(imgs):
    show_img = predict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg'
    cv2.imwrite(temp_path, show_img)
    prog_bar.update()

mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')
shutil.rmtree(temp_out_dir)
print('Deleted temporary folder', temp_out_dir)
Fine-tuning mask2former on a small dataset
Fine-tune the model on the watermelon semantic segmentation dataset.
Download the dataset
!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip
!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null
!rm -rf Watermelon87_Semantic_Seg_Mask.zip
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data

!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'
!for i in `find . -iname '__MACOSX'`; do rm -rf $i; done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i; done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i; done
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'
display.clear_output()
Visual Exploration of Semantic Segmentation Datasets
Visualize Semantic Information
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'
img = cv2.imread(img_path)
mask = cv2.imread(mask_path)
plt.figure(figsize=(8, 8))
plt.imshow(img[:, :, ::-1])
plt.imshow(mask[:, :, 0], alpha=0.6)
plt.axis('off')
plt.show()
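It is also worth checking which integer class IDs actually appear in the mask; a quick sketch (for this dataset, IDs 0 to 5, matching the six classes defined below, are expected):
# Check which class IDs occur in this mask and how many pixels each one covers.
values, counts = np.unique(mask[:, :, 0], return_counts=True)
for v, c in zip(values, counts):
    print(f'class {v}: {c} pixels')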
Define Dataset and Pipeline
In the Dataset section you can set which category each pixel value corresponds to, the label color of each category, the image format, and whether to ignore category 0.
In the Pipeline section you can set the data processing steps for training and validation, as well as the specified image crop size.
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # Classes and their corresponding RGB palette
    METAINFO = {
        'classes': ['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
        'palette': [[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
    }
    # Specify the image suffix and the annotation suffix
    def __init__(self,
                 seg_map_suffix='.png',   # format of the annotation mask images
                 reduce_zero_label=False, # whether to drop the class with ID 0
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""
with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
Add custom_dataset to the __init__.py file.
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate,
RandomRotFlip, Rerange, ResizeShortestEdge,
ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'MyCustomDataset'
]
"""
with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
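A quick sanity check (a minimal sketch) that both files were written; note that tools/train.py runs in a fresh process, so it picks up the new registration automatically:
# Verify the custom dataset module exists and is referenced in the package __init__.py.
print(os.path.exists('mmseg/datasets/MyCustomDataset.py'))
with open('mmseg/datasets/__init__.py') as f:
    print('MyCustomDataset' in f.read())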
Define the dataset preprocessing pipeline
custom_pipeline = """
# Dataset path
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'Watermelon87_Semantic_Seg_Mask/' # dataset path (relative to the mmsegmentation root directory)
# Image crop size fed to the model, usually a multiple of 128; smaller crops use less GPU memory
crop_size = (640, 640)
# Training preprocessing
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
# Test preprocessing
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# TTA (test-time augmentation) post-processing
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
# Training dataloader
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='img_dir/train', seg_map_path='ann_dir/train'),
pipeline=train_pipeline))
# Validation dataloader
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='img_dir/val', seg_map_path='ann_dir/val'),
pipeline=test_pipeline))
# Test dataloader
test_dataloader = val_dataloader
# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
# Test evaluator
test_evaluator = val_evaluator
"""
with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)
Modify the configuration file
Mainly modify the number of classes, the pre-trained weight path, the input image size (generally an integer multiple of 128), and the batch_size; scale the learning rate accordingly (the scaled value is base_lr_default * (your_bs / default_bs)); and change the learning rate decay strategy.
About the learning rate: mainly modify lr inside optimizer; do not modify optim_wrapper.
Freezing the backbone of mask2former can speed up training.
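As a worked example of the linear scaling rule above (the default batch size of 16 is an assumption based on the 8xb2 cityscapes config, i.e. 8 GPUs x 2 images per step):
# Worked example of the linear lr-scaling rule described above.
# Assumed values: the base config is tuned for 8 GPUs x 2 images = 16 images per step,
# while this tutorial trains with batch_size = 2 on a single P100.
default_bs = 8 * 2
your_bs = 2
scale = your_bs / default_bs
print(scale)  # 0.125, i.e. the base lr is divided by 8, matching cfg.optimizer.lr / 8 below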
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)

NUM_CLASS = 6

cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
cfg.model.decode_head.num_classes = NUM_CLASS
# mask2former's classification loss expects one weight per class plus one for the "no object" class
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4  # freeze the backbone to speed up training
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader
cfg.optimizer.lr = cfg.optimizer.lr / 8  # linear lr scaling for batch_size 2 vs the default 16
cfg.work_dir = './work_dirs'
cfg.train_cfg.max_iters = 4000
cfg.train_cfg.val_interval = 50
cfg.default_hooks.logger.interval = 50
cfg.default_hooks.checkpoint.interval = 50
cfg.default_hooks.checkpoint.max_keep_ckpts = 2
cfg.default_hooks.checkpoint.save_best = 'mIoU'
cfg.param_scheduler[0].end = cfg.train_cfg.max_iters
cfg['randomness'] = dict(seed=0)
cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py
Select the best checkpoint and evaluate its accuracy.
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
!python tools/test.py custom_mask2former.py '{best_pth}'
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 | 97.1  | 97.1   |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 | 90.1  | 90.1   |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+
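The aggregate mIoU printed by the test script is simply the unweighted mean of the per-class IoU column; a quick check with the values from the table:
# The reported mIoU is the unweighted mean of the per-class IoU values above.
ious = [98.55, 96.54, 94.37, 85.96, 81.98, 65.57]
print(round(sum(ious) / len(ious), 2))  # 87.16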
Visualize training metrics
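Training metrics can be read back from the local visualization backend; a minimal sketch, assuming mmengine's default LocalVisBackend layout where scalars are appended as JSON lines under work_dirs/<timestamp>/vis_data/scalars.json (the key names 'step', 'loss' and 'mIoU' are assumptions). The same metrics are also streamed to Weights & Biases through the WandbVisBackend configured above.
import json

# Plot training loss and validation mIoU from the local scalar log.
# Path and key names are assumptions based on mmengine's default LocalVisBackend layout.
log_path = glob.glob('work_dirs/*/vis_data/scalars.json')[0]
records = [json.loads(line) for line in open(log_path)]

loss = [(r['step'], r['loss']) for r in records if 'loss' in r and 'step' in r]
miou = [(r['step'], r['mIoU']) for r in records if 'mIoU' in r and 'step' in r]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(*zip(*loss)); ax1.set_title('train loss'); ax1.set_xlabel('iter')
ax2.plot(*zip(*miou)); ax2.set_title('val mIoU'); ax2.set_xlabel('iter')
plt.show()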
Fine-tuning the model on the glomerulus dataset
Fine-tune mask2former on a single-class dataset (glomeruli in histopathological sections).
First clear the working directory, the data folder, and the outputs folder.
!rm -r work_dirs/*
!rm -r data/*
!rm -r outputs/*
Visual Exploration of Semantic Segmentation Datasets
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'
mask = cv2.imread('/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024/VUHSK_1762_29.png')
np.unique(mask)
array([0, 1], dtype=uint8)
Visualize Semantic Segmentation Information
n = 5
opacity = 0.65
fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12, 12))
for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0] + '.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    axes[i//n, i%n].imshow(img[:, :, ::-1])
    axes[i//n, i%n].imshow(mask[:, :, 0], alpha=opacity)
    axes[i//n, i%n].axis('off')
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()
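A rough sketch for estimating the pixel-level class balance over a sample of masks (the 200-mask sample size is arbitrary, and the code assumes masks only contain the values 0 and 1 seen above); this helps gauge class imbalance before training:
# Estimate the pixel-level class balance over a sample of masks.
sample = os.listdir(PATH_MASKS)[:200]
counts = np.zeros(2, dtype=np.int64)
for name in sample:
    m = cv2.imread(os.path.join(PATH_MASKS, name))[:, :, 0]
    vals, cnts = np.unique(m, return_counts=True)
    counts[vals] += cnts  # assumes only values 0 and 1 occur
print('class pixel fractions:', counts / counts.sum())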
Split the training set and validation set
Create new training and validation folders.
!mkdir -p data/images/train
!mkdir -p data/images/val
!mkdir -p data/masks/train
!mkdir -p data/masks/val
Randomly shuffle the data and split it into a 90% training set and a 10% validation set.
def copy_file(og_images, og_masks, tr_images, tr_masks, thor):
    file_names = os.listdir(og_images)
    random.shuffle(file_names)
    split_index = int(thor * len(file_names))
    for file_name in file_names[:split_index]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'train', file_name)
        tr_mask = os.path.join(tr_masks, 'train', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)
    for file_name in file_names[split_index:]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'val', file_name)
        tr_mask = os.path.join(tr_masks, 'val', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'
tr_images = 'data/images'
tr_masks = 'data/masks'
copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)
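A quick count to verify the resulting 90/10 split (a small sketch):
# Verify the split by counting files in each output folder.
for split in ['train', 'val']:
    n_img = len(os.listdir(f'data/images/{split}'))
    n_msk = len(os.listdir(f'data/masks/{split}'))
    print(f'{split}: {n_img} images, {n_msk} masks')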
Redefine Dataset and Pipeline
Mainly modify the classes and their corresponding RGB palette, as well as the path information for the dataloaders.
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # Classes and their corresponding RGB palette
    METAINFO = {
        'classes': ['normal', 'sclerotic'],
        'palette': [[127,127,127], [251,189,8]]
    }
    # Specify the image suffix and the annotation suffix
    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',   # format of the annotation mask images
                 reduce_zero_label=False, # whether to drop the class with ID 0
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""
with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate,
RandomRotFlip, Rerange, ResizeShortestEdge,
ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'MyCustomDataset'
]
"""
with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
Define data preprocessing pipeline
custom_pipeline = """
# Dataset path
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'data/' # dataset path (relative to the mmsegmentation root directory)
# Image crop size fed to the model, usually a multiple of 128; smaller crops use less GPU memory
crop_size = (640, 640)
# Training preprocessing
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
# Test preprocessing
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# TTA (test-time augmentation) post-processing
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
# Training dataloader
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/train', seg_map_path='masks/train'),
pipeline=train_pipeline))
# Validation dataloader
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/val', seg_map_path='masks/val'),
pipeline=test_pipeline))
# Test dataloader
test_dataloader = val_dataloader
# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
# Test evaluator
test_evaluator = val_evaluator
"""
with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)
Modify the configuration file
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
Change the configuration options
NUM_CLASS = 2

cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader
cfg.optimizer.lr = cfg.optimizer.lr / 8
cfg.work_dir = './work_dirs'
cfg.train_cfg.max_iters = 40000
cfg.train_cfg.val_interval = 500
cfg.default_hooks.logger.interval = 50
cfg.default_hooks.checkpoint.interval = 2500
cfg.default_hooks.checkpoint.max_keep_ckpts = 2
cfg.default_hooks.checkpoint.save_best = 'mIoU'
cfg['randomness'] = dict(seed=0)
cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
Save the configuration file and start training
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py
Visualize training metrics
Evaluate the model and test inference speed
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
!python tools/test.py custom_mask2former.py '{best_pth}'
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
Test model inference speed
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
Done image [50 / 200], fps: 2.24 img / s
Done image [100 / 200], fps: 2.24 img / s
Done image [150 / 200], fps: 2.24 img / s
Done image [200 / 200], fps: 2.24 img / s
Overall fps: 2.24 img / s
Average fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0
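As a rough cross-check (a sketch only; warm-up and CUDA synchronization matter for GPU timing, and the figure will not exactly match the benchmark script, which uses the full test pipeline), single-image inference can also be timed through the Python API:
# Rough timing cross-check with the Python API; numbers will differ somewhat
# from tools/analysis_tools/benchmark.py because the data pipelines are not identical.
model = init_model('custom_mask2former.py', best_pth, device='cuda:0')
img = cv2.imread('data/images/val/' + os.listdir('data/images/val')[0])

for _ in range(5):  # warm-up iterations
    inference_model(model, img)
torch.cuda.synchronize()

n = 20
start = time.time()
for _ in range(n):
    inference_model(model, img)
torch.cuda.synchronize()
print(f'{n / (time.time() - start):.2f} img/s')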