Learning from the official SegViT source code and porting the model into your own local code
1. Environment configuration
The official repository requires mmcv-full==1.4.4 and mmsegmentation==0.24.0.
Before installing them, uninstall any existing versions of mmcv and mmsegmentation:
pip uninstall mmcv
pip uninstall mmcv-full
pip uninstall mmsegmentation
Install mmcv
mmcv comes in two variants: the full version, mmcv (called mmcv-full before 2.0.0), and the lightweight version, mmcv-lite (called mmcv before 2.0.0). The names changed in version 2.0.0. For the specific differences, see the mmcv official documentation and the related blog post.
To install mmcv-full (the full version of mmcv), the main reference is the mmcv official installation guide.
If you want mmcv>=2.0.0, simply follow the official guide; I won't go into details here.
If you need a historical version (I needed mmcv-full==1.4.4), you can follow my notes below.
Before installing mmcv, you must know which versions of PyTorch and CUDA you have.
Check the PyTorch version:
python -c 'import torch;print(torch.__version__)'
If a version number is printed, PyTorch is installed.
Check the CUDA version:
Be careful: what matters is the CUDA version that the PyTorch in your environment was built with.
For example, this is the CUDA version I got with the nvidia-smi command (nvidia-smi reports the highest CUDA version the installed driver supports, which can differ from the one PyTorch uses):
And this is the command I used to check the CUDA version PyTorch was built with:
python -c 'import torch;print(torch.version.cuda)'
It can also be done in an interactive Python session:
python
>>> import torch
>>> torch.version.cuda
Reference blog: https://blog.csdn.net/qq_49821869/article/details/127700187
Here my PyTorch version is 1.11.0, and the corresponding CUDA version is 11.3.
Reference blog: https://blog.csdn.net/qq_41661809/article/details/125345690
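The find-links URL in the install command is derived from these two version numbers. As a rough sketch (the helper name mmcv_full_index_url is hypothetical), it can be built from torch.__version__ and torch.version.cuda. It follows the cu{major}{minor}/torch{version} pattern of the URL used in this post, and it assumes, per the mmcv 1.x installation notes, that wheels are indexed under the x.y.0 release of each PyTorch minor version, so the patch number is reset to 0.

```python
from typing import Optional


def mmcv_full_index_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build an mmcv-full find-links URL (pattern assumed from the URL used in this post)."""
    # Drop a local suffix such as "+cu113" from torch.__version__
    base = torch_version.split("+")[0]
    major, minor = base.split(".")[:2]
    # Assumption: wheels for every X.Y.* release are indexed under torchX.Y.0
    torch_tag = f"torch{major}.{minor}.0"
    # torch.version.cuda is None on CPU-only builds
    if cuda_version is None:
        cuda_tag = "cpu"
    else:
        cuda_tag = "cu" + "".join(cuda_version.split(".")[:2])
    return f"https://download.openmmlab.com/mmcv/dist/{cuda_tag}/{torch_tag}/index.html"


if __name__ == "__main__":
    try:
        import torch
        print(mmcv_full_index_url(torch.__version__, torch.version.cuda))
    except ImportError:
        print("PyTorch is not installed")
```

For PyTorch 1.11.0 with CUDA 11.3 this reproduces the URL used in the pip command; open it in a browser to see which mmcv-full versions are actually available there.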
So, I entered the command:
pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
This failed, so I opened that index URL and found that the lowest mmcv-full version available for this combination was 1.4.7,
so I changed the command to:
pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
mmcv-full is now installed.
Install mmsegmentation
I first installed mmsegmentation following the official instructions, but that installed mmsegmentation==1.0.0, which requires mmcv>=2.0.0 and conflicted with my setup.
Note that the mmsegmentation version must be compatible with the mmcv version:
Reference blog: https://blog.csdn.net/CharilePuth/article/details/122909620
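Once you know the allowed mmcv range for an mmsegmentation release (from the compatibility table in the reference above), checking a version against it is simple arithmetic on version tuples. This is a hypothetical sketch: the helper names are mine, only the "mmsegmentation 1.0.0 needs mmcv>=2.0.0" bound comes from this post, and any other bounds you pass in must come from the official table.

```python
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.4.7' into (1, 4, 7); pre-release tags like 'rc4' are ignored (simplification)."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # treat '1.4' as '1.4.0'
        parts.append(0)
    return tuple(parts)


def is_compatible(mmcv_version: str, min_version: str, max_version: Optional[str] = None) -> bool:
    """True if min_version <= mmcv_version (and mmcv_version < max_version, when given)."""
    v = parse_version(mmcv_version)
    if v < parse_version(min_version):
        return False
    if max_version is not None and v >= parse_version(max_version):
        return False
    return True


# mmsegmentation 1.0.0 needs mmcv>=2.0.0, so mmcv-full 1.4.7 does not fit it
print(is_compatible("1.4.7", "2.0.0"))
```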
So I installed the matching version directly:
pip install mmsegmentation==0.24.0
The installation succeeded.
"Installing packages with pip is as easy as drinking water," as a big shot once said.
2. Code!
Find the model configuration file
On the official repository page, find the config file for the model under Training. From the Highlights section I learned that one of the paper's highlights is the shrink structure, which reduces computational cost, so I chose the shrink structure:
Since my images are 512x512, I picked the COCO model trained at the same 512x512 resolution from the Results section:
Then go back to the configs folder and find the config for this model and dataset.
Reading it shows that the backbone is vit_shrink and the decode head is TPNATMHead:
Pay attention to the parameter settings, and also to the config files listed under _base_: their parameters are merged into the model config when it is loaded.
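To see why the _base_ files matter: with mmcv's Config, a file inherits every field from the files listed in _base_, and the fields it defines itself override the inherited ones key by key (a nested dict with _delete_=True replaces the inherited dict instead of merging into it). The sketch below is a simplified, hypothetical illustration of that merge on plain dicts, not mmcv's implementation, and the sample values are made up rather than taken from the SegViT configs.

```python
import copy
from typing import Any, Dict


def merge_config(base: Dict[str, Any], child: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge child into a copy of base; child values win."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)
            # `_delete_=True` discards the inherited dict instead of merging
            delete = value.pop("_delete_", False)
            if not delete and isinstance(merged.get(key), dict):
                merged[key] = merge_config(merged[key], value)
            else:
                # merge into an empty dict so nested `_delete_` keys are stripped too
                merged[key] = merge_config({}, value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


base = {"model": {"backbone": {"type": "ViT", "num_layers": 24, "drop_path_rate": 0.1}}}
child = {"model": {"backbone": {"drop_path_rate": 0.3}}}
print(merge_config(base, child))
```

So the effective value of a parameter is whatever the most derived file sets, which is why you need to read the _base_ files to know the full set of arguments the model is built with.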
Find the model code
Open the backbone folder, find the vit_shrink network,
and copy it into your own .py file.
Find the decode head code under the decode_heads folder,
and copy it into your own .py file as well.
Patch up the code
- Fill in whatever the copied files are missing from the library. For example, the tpn_atm_head decoder code references content from two other decoder files; copy those two decoders in as well (Ctrl+C, Ctrl+V) and keep only the modules you actually use:
- Check the input and output of the backbone:
The input and output of the decoder are shown in the figure:
Write a SegViT wrapper to test the input and output, declaring the settings from the config file in advance:
import torch
from torch import nn
# vit_shrink and TPNATMHead are the classes copied into your own files in the previous steps

class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7, 23]  # backbone layers whose outputs go to the decoder
        in_channels = 1024     # ViT-Large embedding dimension
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'
        # self.backbone = get_vit_shrink()
        self.backbone = vit_shrink(
            img_size=(img_size, img_size), embed_dims=in_channels, num_layers=24,
            drop_path_rate=0.3, num_heads=16, out_indices=out_indices)
        self.decoder = TPNATMHead(
            img_size=img_size, in_channels=in_channels, channels=in_channels,
            embed_dims=in_channels // 2, num_heads=16, num_classes=num_class,
            num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        # if self.training:
        #     return out['pred'], out['ce_aux']
        # else:
        #     return out
        return out
Run it to check the output type:
if __name__ == "__main__":
    x = torch.randn(4, 3, 512, 512)
    net = SegViT(6)
    # flops, params = profile(net, (x,))
    # print('flops: %.2f G, params: %.2f M' % (flops / 1000000000.0, params / 1000000.0))
    # res, aux = net(x)
    res = net(x)
    print(res)
The output is a dictionary: the prediction is stored under the key pred as a tensor of shape (4, 6, 512, 512), which is correct.
Next we need to find the output of the auxiliary branch.
It is in the forward method of the decode head:
Uncomment that code to get the auxiliary branch output (it is added to atm_out as a dictionary entry; you can confirm this in a debugger), and remember to also uncomment the corresponding code in the __init__ method:
Since I am running on a single GPU, I also changed SyncBN to BN; otherwise an error is raised.
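If the norm layers in your copied code are configured through mmcv-style norm_cfg dicts such as dict(type='SyncBN', requires_grad=True), the switch can be done by walking the config. A minimal sketch, assuming your code reads its norm settings from such dicts; the helper name and example config are hypothetical, and any SyncBN hard-coded inside a module still has to be edited by hand.

```python
from typing import Any


def syncbn_to_bn(cfg: Any) -> Any:
    """Return a copy of cfg with every type='SyncBN' replaced by type='BN'."""
    if isinstance(cfg, dict):
        new = {key: syncbn_to_bn(value) for key, value in cfg.items()}
        if new.get("type") == "SyncBN":
            new["type"] = "BN"
        return new
    if isinstance(cfg, (list, tuple)):
        # rebuild lists and tuples with converted items
        return type(cfg)(syncbn_to_bn(item) for item in cfg)
    return cfg


head_cfg = {"norm_cfg": {"type": "SyncBN", "requires_grad": True}}
print(syncbn_to_bn(head_cfg))
```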
In addition, the output differs between the training and testing phases, which you can verify in a debugger:
    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
- Load the weight file.
It is best to download the weight file in advance and load it from a local path.
def get_vit_shrink(pretrained=True, img_size=512, in_channels=1024, out_indices=[7, 23]):
    model = vit_shrink(
        img_size=(img_size, img_size), embed_dims=in_channels, num_layers=24,
        drop_path_rate=0.3, num_heads=16, out_indices=out_indices)
    if pretrained:
        checkpoint_path = 'path/to/weight_file.pth'  # path of the downloaded weight file
        # load_state_dict expects a dict of tensors, not a path string
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # some checkpoints wrap the weights in a 'state_dict' entry
        state_dict = checkpoint.get('state_dict', checkpoint)
        model.load_state_dict(state_dict, strict=False)
    return model
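Because strict=False silently skips every key that does not match, it is worth checking how many weights were actually loaded. The sketch below works on plain dicts so it can be tried without a model. It unwraps a 'state_dict' entry, optionally keeps only the keys under a prefix such as 'backbone.' (useful if you load a full segmentor checkpoint into the backbone alone; that key layout is an assumption about the checkpoint), and reports missing and unexpected keys. The helper names are hypothetical.

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple


def extract_state_dict(checkpoint: Dict[str, Any], prefix: Optional[str] = None) -> Dict[str, Any]:
    """Unwrap checkpoint['state_dict'] if present; optionally keep only keys under prefix."""
    state_dict = checkpoint.get("state_dict", checkpoint)
    if prefix is None:
        return dict(state_dict)
    # strip the prefix so the keys match the sub-module's own parameter names
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def compare_keys(state_dict: Dict[str, Any], model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, sorted for readability."""
    expected = set(model_keys)
    loaded = set(state_dict)
    return sorted(expected - loaded), sorted(loaded - expected)
```

In real use, checkpoint would come from torch.load(path, map_location='cpu') and model_keys from model.state_dict().keys(). PyTorch's load_state_dict(..., strict=False) also returns missing_keys and unexpected_keys itself, so printing its return value is a quick alternative.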
The final model is:
class SegViT(nn.Module):
    def __init__(self, num_class):
        super(SegViT, self).__init__()
        out_indices = [7, 23]
        in_channels = 1024
        img_size = 512
        # checkpoint = './pretrained/vit_large_p16_384_20220308-d4efb41d.pth'
        # checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_large_p16_384_20220308-d4efb41d.pth'
        # backbone with pretrained weights loaded by get_vit_shrink
        self.backbone = get_vit_shrink(
            img_size=img_size, in_channels=in_channels, out_indices=out_indices)
        self.decoder = TPNATMHead(
            img_size=img_size, in_channels=in_channels, channels=in_channels,
            embed_dims=in_channels // 2, num_heads=16, num_classes=num_class,
            num_layers=3, use_stages=len(out_indices))

    def forward(self, _x):
        x = self.backbone(_x)
        out = self.decoder(x)
        if self.training:
            return out['pred'], out['ce_aux']
        else:
            return out
- Check the final input and output. Done.
3. Run the model
Configure the parameters in your own training framework, then run it.
Done.