[mmsegmentation] Pitfall guide: adjusting class_weight in the config

While tuning various models in the mmseg project recently, I came across the class_weight parameter. According to the official documentation, it is a way to counter the fitting problems caused by class imbalance and to improve accuracy. It is usually set inside decode_head in the config file, and by default every class gets the same weight. Let's illustrate with an example.

../configs/_base_/models/ann_r50-d8.py

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='ANNHead',
        in_channels=[1024, 2048],
        in_index=[2, 3],
        channels=512,
        project_channels=256,
        query_scales=(1, ),
        key_pool_scales=(1, 3, 6, 8),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

As the config shows, with the default num_classes=19 and no class_weight entry, each of the 19 classes contributes equally to the loss.
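To see what that default means in practice, here is a minimal plain-PyTorch sketch (my own illustration, not mmseg's internal code): leaving the per-class weight unset behaves exactly like weighting every class by 1.0.

import torch
import torch.nn.functional as F

# Minimal sketch: with the default reduction='mean', omitting `weight`
# is equivalent to passing an all-ones weight vector, so no class is
# favoured by default.
num_classes = 19
logits = torch.randn(2, num_classes, 4, 4)          # (N, C, H, W) predictions
target = torch.randint(0, num_classes, (2, 4, 4))   # (N, H, W) labels

loss_default = F.cross_entropy(logits, target)
loss_uniform = F.cross_entropy(logits, target, weight=torch.ones(num_classes))
assert torch.allclose(loss_default, loss_uniform)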

Adjust class_weight according to the per-class sample counts. Suppose num_classes=9 and the sample distribution is imbalanced, as shown in the figure below:

[Figure: per-class sample distribution (not reproduced here)]

The per-class loss weights can then be set from this distribution: the fewer samples a class has, the larger its class_weight; the more samples it has, the smaller its class_weight. The weights must not be pushed too far in either direction, though: over-aggressive weighting can keep the model from converging, and performance may drop instead of improving. One way to derive such weights is sketched below.
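A common recipe for picking such weights is median-frequency balancing; this is my own illustration rather than anything from the original post, and the pixel counts below are made up for 9 classes:

import numpy as np

# Hypothetical per-class pixel counts measured on the training set
# (made-up numbers for illustration).
pixel_counts = np.array(
    [9.2e7, 4.1e7, 1.3e7, 3.5e7, 1.1e7, 6.8e8, 2.9e7, 1.2e7, 2.4e7])
freq = pixel_counts / pixel_counts.sum()

# Median-frequency balancing: rare classes get weights > 1, frequent
# classes get weights < 1; clipping keeps the range moderate so training
# is less likely to diverge.
class_weight = np.clip(np.median(freq) / freq, 0.1, 10.0)
print([round(float(w), 4) for w in class_weight])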

Change class_weight to:

    decode_head=dict(
        type='ANNHead',
        in_channels=[1024, 2048],
        in_index=[2, 3],
        channels=512,
        project_channels=256,
        query_scales=(1, ),
        key_pool_scales=(1, 3, 6, 8),
        dropout_ratio=0.1,
        num_classes=9,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            class_weight=[
                0.8373, 0.918, 1.166, 0.9539, 1.1766, 0.1869, 0.9954,
                1.1489, 1.0152
            ])),
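Note that the list must contain exactly num_classes entries (9 here), in label-index order. As a plain-PyTorch sketch of the effect (as I understand it, mmseg's CrossEntropyLoss forwards class_weight to the underlying cross-entropy as the per-class weight), each pixel's loss term is scaled by the weight of its ground-truth class:

import torch
import torch.nn.functional as F

# Sketch: each pixel's loss term is scaled by the weight of its
# ground-truth class, so class index 5 (weight 0.1869) contributes far
# less to the total loss than class index 4 (weight 1.1766).
class_weight = torch.tensor([
    0.8373, 0.918, 1.166, 0.9539, 1.1766, 0.1869, 0.9954, 1.1489, 1.0152])

logits = torch.randn(2, 9, 4, 4)          # (N, num_classes, H, W)
target = torch.randint(0, 9, (2, 4, 4))   # (N, H, W)
loss = F.cross_entropy(logits, target, weight=class_weight)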

Comparing results before and after modifying class_weight (red marks the modified run):

[Figure: metric comparison before/after the class_weight change (not reproduced here)]

Putting this together wasn't easy; a like, a favorite, and a follow are all appreciated!
