[PyTorch] Conversion of pre-trained weights

        As is well known, a backbone network pre-trained on a large dataset improves the generalization of the whole model, but once the backbone is replaced, the original weights can no longer be loaded directly. The purpose of this project is to "steal" (reuse) the network's pre-trained weights after you have replaced the backbone.

        To state the conclusion first: we replace the backbone of DeeplabV3+ from Xception with mobilenetv3 and train it both with and without the converted pre-trained weights, over 200 epochs. The results after the 1st epoch and the 171st epoch are as follows:

Without pre-training:
---> epoch 1 mIoU:
   mIoU: 16.11; mPA: 16.67; Accuracy: 96.68
With pre-training:
---> epoch 1 mIoU:
   mIoU: 40.53; mPA: 56.80; Accuracy: 97.09
Without pre-training:
---> epoch 171 mIoU:
   mIoU: 64.68; mPA: 78.00; Accuracy: 98.82
With pre-training:
---> epoch 171 mIoU:
   mIoU: 86.08; mPA: 92.54; Accuracy: 99.56

        It can be seen that converting and loading the pre-trained weights significantly improves both early convergence and the final model performance.

1. The structure of the weight file

        The weight file for PyTorch is a .pth file. It contains a collections.OrderedDict that has been serialized to disk.
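For a concrete look at this structure, a small model can be saved and reloaded (the model and the file name here are illustrative stand-ins, not the DeeplabV3+ setup):

```python
import collections
import torch
import torch.nn as nn

# A tiny stand-in model; 'demo_weights.pth' is an illustrative file name.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
torch.save(model.state_dict(), 'demo_weights.pth')

# torch.load gives back an OrderedDict mapping layer names to Tensors.
state = torch.load('demo_weights.pth', map_location='cpu')
print(type(state).__name__)  # OrderedDict
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```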

         Each entry is a pair of two elements: the name of the layer and that layer's weight Tensor. By iterating over it, the two can be taken out separately:

for k, v in pretrained_dict.items():
    ...  # k: layer name (str), v: weight Tensor

         On inspection, each key k is a plain str. In the original adaptive-loading code, the reason pre-trained weights from another network cannot be used is that the weight names do not match. Since the names are just strings, str.replace() can rewrite the mismatched part, after which the weight file can be loaded into the network.
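A minimal sketch of the idea (the key names and the extra "model." prefix below are made up for illustration; real entries map names to torch Tensors):

```python
from collections import OrderedDict

# Toy stand-in for a pretrained state_dict whose keys carry an extra
# 'model.' prefix relative to the target model.
pretrained_dict = OrderedDict([
    ('model.backbone.conv1.weight', 'w0'),
    ('model.backbone.bn1.weight',   'w1'),
])

# Keys are plain strings, so str.replace rewrites the mismatched part.
renamed = OrderedDict(
    (k.replace('model.', '', 1), v) for k, v in pretrained_dict.items()
)
print(list(renamed.keys()))
# ['backbone.conv1.weight', 'backbone.bn1.weight']
```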

2. Modify the weight name

        First of all, this method only works when the weight names are systematically related, so when you transplant a backbone network, keep the function and module names as consistent with the original network as possible. Since the backbone sits at the beginning of the model, printing the first layer of each state_dict reveals the difference between the modified network and the original one. In this example, the mobilenetv3 from bubbliiiing's YoloV4 and DeeplabV3+ are used as material, moving the mobilenetv3 network into DeeplabV3+.

        In Python 3, an OrderedDict (or its keys view) cannot be indexed with an integer directly, so we use list(...) to turn the keys into an indexable list before taking the first element. Use the following code to print the first key of each:

print('Source model key format: {}'.format(list(pretrained_dict.keys())[0]))
print('Target model key format: {}'.format(list(model_dict.keys())[0]))

        After printing, the difference between the two models becomes visible. Store the differing parts in two variables, then pass them to the replace() function to rewrite the layer names. In this example, the gap between the source model and the target model is that the source model's keys carry one extra "model" prefix; simply deleting it makes the names match.

3. Convert and save

        The routine used in this example has been packaged into a single function:

import numpy as np
import torch

def model_converter(custom_model, model_path):
    model_dict      = custom_model.state_dict()
    pretrained_dict = torch.load(model_path, map_location = torch.device('cpu'))
    load_key, no_load_key, temp_dict = [], [], {}
    #  Print the first key of each state_dict to expose the naming difference
    print('Source model key format: {}'.format(list(pretrained_dict.keys())[0]))
    print('Target model key format: {}'.format(list(model_dict.keys())[0]))
    print('Enter the part that differs between the two models:')
    orgStr    = input('Source model: ')
    targetStr = input('Target model: ')

    print('---> Starting model conversion')
    for k, v in pretrained_dict.items():
        k = k.replace(orgStr, targetStr)
        if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
            temp_dict[k] = v
            load_key.append(k)
        else:
            no_load_key.append(k)
    #  Update the matched weights into the model
    model_dict.update(temp_dict)
    custom_model.load_state_dict(model_dict)
    #  Save the converted weights
    torch.save(custom_model.state_dict(), 'converted_weights.pth')

        Simply pass the instantiated model and the path of the source weight file into this function. After the renaming, the matched entries are converted automatically, while the unmatched entries are collected in no_load_key for manual inspection. In testing, this method can handle most backbone-network conversions.
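As a quick sanity check after conversion (the model below is an illustrative stand-in for the real DeeplabV3+ setup, and the save call stands in for the converter's output), the converted file should load into the target model with strict=True without raising:

```python
import torch
import torch.nn as nn

# Illustrative target model; any nn.Module behaves the same way.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# Stand-in for the file written by model_converter.
torch.save(model.state_dict(), 'converted_weights.pth')

# strict=True raises if any key is missing or unexpected, so a clean
# load confirms that every layer name and shape matched.
state = torch.load('converted_weights.pth', map_location='cpu')
model.load_state_dict(state, strict=True)
print('all keys matched')
```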

Origin blog.csdn.net/weixin_37878740/article/details/130259766