[Experiment] SegViT: Semantic Segmentation with Plain Vision Transformers

The goal is to learn from the official SegViT model source code and deploy it in my own local code files.

1. Environment configuration

The official repository requires mmcv-full==1.4.4 and mmsegmentation==0.24.0.
Remember to uninstall any existing versions of mmcv and mmsegmentation before installing.

pip uninstall mmcv
pip uninstall mmcv-full
pip uninstall mmsegmentation

Install mmcv

mmcv comes in two versions: the full version, now called mmcv (formerly mmcv-full), and the lite version, now called mmcv-lite (formerly mmcv); the names changed with version 2.0.0. For the specific differences, see the mmcv official documentation and related blog posts.
To install mmcv-full (that is, the full version of mmcv), mainly refer to the mmcv official documentation.
If you want mmcv>=2.0.0, you can install it directly following the official instructions; no further detail is needed here.
If you want a historical version, for example mmcv-full==1.4.4 in my case, you can follow my notes below.
Before installing mmcv, you must first know the pytorch and cuda versions in your environment.
Check the pytorch version:

python -c 'import torch;print(torch.__version__)'

If a version string is printed, pytorch is installed.
Next, check the cuda version. Be careful to check the cuda version that the pytorch in your environment was built with; the version reported by the nvidia-smi command is the driver's cuda version and is not necessarily the same one.
This is the command to check the cuda version corresponding to pytorch:

python -c 'import torch;print(torch.version.cuda)'

The cuda version can also be checked interactively in the Python interpreter:

Reference blog: https://blog.csdn.net/qq_49821869/article/details/127700187

python

>>>import torch
>>>torch.version.cuda

Here my pytorch version is 1.11.0, and the corresponding cuda version is 11.3.

Reference blog: https://blog.csdn.net/qq_41661809/article/details/125345690

So, I entered the command:

pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html

That was unsuccessful, so I opened the index URL above to check and found that the lowest version available for this cuda/torch combination was 1.4.7, so I changed the command to:

pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html

mmcv-full installation completed
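A quick way to confirm that the full build was picked up is to print its version string (mmcv exposes __version__):

python -c 'import mmcv; print(mmcv.__version__)'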

Install mmsegmentation

I originally installed mmsegmentation following the instructions on the official website, but that required mmcv>=2.0.0 and installed mmsegmentation==1.0.0, which conflicted with my requirements.
Note that the mmsegmentation version must match the mmcv version:

Reference blog: https://blog.csdn.net/CharilePuth/article/details/122909620

(mmcv / mmsegmentation version compatibility table)

So I simply ran:

pip install mmsegmentation==0.24.0

and the installation succeeded.
"Installing a package with pip is as easy as drinking water," as a certain big shot once said.

2. Code!

Find the model configuration file

Go to the official repository and find the config file corresponding to the model under Training. From the Highlights section I learned that one of the highlights of this paper is the shrunk (shrinkage) structure, which reduces the computation cost, so I chose the shrunk version of the model.

Since the images I want to run are 512×512, I looked in the Results section of the repository for the COCO dataset model trained at the same 512×512 resolution.

Return to the configs folder to find the network model config corresponding to this dataset.
Reading the config, I learned that the backbone used is vit_shrink and the decode head is TPNATMHead.
Pay attention to the parameter settings, and also to the _base_ config files, whose parameters are passed in when the model is declared.
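For orientation, the model section of such a config typically looks roughly like the sketch below. The field names follow the usual mmsegmentation config layout and the values are the ones reused later in this post, so treat it as illustrative rather than a copy of the repository's file:

model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='vit_shrink',
        img_size=(512, 512),
        embed_dims=1024,
        num_layers=24,
        num_heads=16,
        drop_path_rate=0.3,
        out_indices=[7, 23]),
    decode_head=dict(
        type='TPNATMHead',
        img_size=512,
        in_channels=1024,
        channels=1024,
        embed_dims=512,
        num_heads=16,
        num_layers=3,
        use_stages=2,
        num_classes=6))   # set to the number of classes in your dataset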

Find the model code

Go into the backbones folder and find the vit_shrink network, then copy and paste it into your own .py file.
Find the decode head code under the decode_heads folder and copy and paste it into your own .py file as well.

Patching up the code

  1. Supplement the missing pieces.
    Fill in whatever the copied files are missing. For example, the tpn_atm_head decode head references code from two other decode heads; simply copy and paste (Ctrl+C / Ctrl+V) those two decoders as well and keep only the modules you actually need.
  2. Check the inputs and outputs.
    The inputs and outputs of the backbone and of the decode head can be seen in the copied code. Write a SegViT class to test them; refer to the configuration file and declare the corresponding settings in advance:
class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7,23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'

        # self.backbone = get_vit_shrink()
        self.backbone = vit_shrink(
            img_size=(img_size,img_size),embed_dims=in_channels,num_layers=24,drop_path_rate=0.3,num_heads=16,out_indices=out_indices)
        self.decoder = TPNATMHead(
            img_size=img_size,in_channels=in_channels,channels=in_channels,embed_dims=in_channels//2,num_heads=16,num_classes=num_class,num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        # if self.training:
            # return out['pred'], out['ce_aux']
        # else:
            # return out
        return out
 

Run it and check the output type

if __name__ == "__main__":
    x = torch.randn(4, 3, 512, 512)
    net = SegViT(6)
    # flops, params = profile(net, (x,))
    # print('flops: %.2f G, params: %.2f M' % (flops / 1000000000.0, params / 1000000.0))
    # res, aux = net(x)
    res = net(x)
    print(res)

Running this shows that the output is a dictionary: the prediction is the value under the key pred, a tensor with shape (4, 6, 512, 512), so the output is correct.
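Since the prediction sits under the key pred, turning the logits into a per-pixel class map is just an argmax over the class dimension (a minimal sketch):

pred_logits = res['pred']              # (4, 6, 512, 512) class logits
pred_mask = pred_logits.argmax(dim=1)  # (4, 512, 512) predicted class indices
print(pred_mask.shape)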
Next we need to find the output of the auxiliary branch.
It is produced in the forward method of the decode head: uncomment the relevant lines to get the auxiliary output (it is added to atm_out as an extra dictionary entry, which you can verify in the debugger), and remember to uncomment the corresponding lines in the initialization function as well.
Also, since I am running on a single GPU, I changed SyncBN to BN, otherwise an error is raised (a runtime alternative is sketched after the forward snippet below).
In addition, the outputs of the training phase and the testing phase are different, which you can check in the debugger:

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
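On the SyncBN point above: as an alternative to editing the copied code by hand, a small pure-PyTorch helper can swap SyncBatchNorm layers for BatchNorm2d at runtime. This is a generic sketch, not part of the original SegViT code:

import torch.nn as nn

def syncbn_to_bn(module):
    # recursively replace SyncBatchNorm with BatchNorm2d for single-GPU runs
    converted = module
    if isinstance(module, nn.SyncBatchNorm):
        converted = nn.BatchNorm2d(module.num_features, module.eps,
                                   module.momentum, module.affine,
                                   module.track_running_stats)
        if module.affine:
            converted.weight.data = module.weight.data.clone()
            converted.bias.data = module.bias.data.clone()
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
    for name, child in module.named_children():
        converted.add_module(name, syncbn_to_bn(child))
    return converted

# usage: net = syncbn_to_bn(net)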
  3. Load the weight file.
    Note that the weight file can be downloaded in advance.
def get_vit_shrink(pretrained=True, img_size=512, in_channels=1024, out_indices=[7,23]):
    model = vit_shrink(
            img_size=(img_size,img_size),embed_dims=in_channels,num_layers=24,drop_path_rate=0.3,num_heads=16,out_indices=out_indices)
    if pretrained:
        # load the downloaded checkpoint ('path to the weight file' is a placeholder)
        checkpoint = torch.load('path to the weight file', map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict, strict=False)
    return model
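If you would rather pull the checkpoint straight from the URL that was commented out earlier instead of pointing at a local file, torch.hub can download and cache it. A sketch of the pretrained branch written that way (the hosted file may or may not wrap the weights under a 'state_dict' key, hence the .get):

    if pretrained:
        ckpt_url = ('https://download.openmmlab.com/mmsegmentation/v0.5/'
                    'pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth')
        checkpoint = torch.hub.load_state_dict_from_url(ckpt_url, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)
        model.load_state_dict(state_dict, strict=False)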

The final model is:

class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7,23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        # checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'

        self.backbone = get_vit_shrink()
        self.decoder = TPNATMHead(
            img_size=img_size,in_channels=in_channels,channels=in_channels,embed_dims=in_channels//2,num_heads=16,num_classes=num_class,num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
 
  4. Check the final input and output (a minimal sanity check is sketched below). Done.
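A minimal sanity check of the final model, assuming the weight path in get_vit_shrink points at a real file and the auxiliary branch has been uncommented as described above:

if __name__ == "__main__":
    net = SegViT(6)
    x = torch.randn(2, 3, 512, 512)

    net.train()
    pred, aux = net(x)          # training mode: main prediction and auxiliary output
    print(pred.shape, aux.shape)

    net.eval()
    with torch.no_grad():
        out = net(x)            # test mode: the raw decoder output dictionary
    print(type(out), out['pred'].shape)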

3. Run the model

In your own framework, configure the parameters and then run the model; a sketch of what a single training step might look like is given below.
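For example, a single training step in a plain PyTorch loop might look like this; the optimizer settings and the 0.4 auxiliary-loss weight are my own assumptions, not values from the SegViT repository:

import torch
import torch.nn.functional as F

net = SegViT(num_class=6).cuda()
optimizer = torch.optim.AdamW(net.parameters(), lr=6e-5, weight_decay=0.01)

images = torch.randn(2, 3, 512, 512).cuda()         # stand-in for a real batch
labels = torch.randint(0, 6, (2, 512, 512)).cuda()  # stand-in for real masks

net.train()
pred, aux = net(images)
# the auxiliary logits may be at a lower resolution; upsample before the loss
aux = F.interpolate(aux, size=labels.shape[-2:], mode='bilinear', align_corners=False)
loss = F.cross_entropy(pred, labels) + 0.4 * F.cross_entropy(aux, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()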

Finish.
