Transplanting pre-trained RepVGG weights into the backbone of a YOLOv5-RepVGG .pt model (saving and loading model weights)

Background

While training my improved YOLOv5-RepVGG, I compared it with YOLOv5s. YOLOv5s was trained for 50 epochs. My YOLOv5-RepVGG model was initialized from the YOLOv5s pre-trained weights (yolov5s.pt) and was also trained for 50 epochs. The objects being detected in my dataset are rather unusual.
However, I don't think YOLOv5-RepVGG should use yolov5s.pt as its pre-trained model. When neither model uses yolov5s.pt as pre-training, YOLOv5-RepVGG performs better.
[Figure omitted]
My starting point is the YOLOv5s-RepVGG model trained for 50 epochs without any pre-trained weights. (My YOLOv5-RepVGG model improves on the method from this post, which introduces the RepVGG block structure into YOLOv5.)
I then took the parameters of the official RepVGG pre-trained model (code: DingXiaoH/RepVGG; weights: RepVGG-A1-train.pth from the official weight link) and transplanted them into the backbone of that model.

My weight porting code:

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@File  : trained_model
@Author: Wendy
@Time  : 2023/1/9 14:47
@Desc  :
Transplant the backbone weights of a pre-trained RepVGG model into yolov5-repvgg.
"""

import torch
# Presumably the script is run from the YOLOv5 repository root so `models` resolves.
from models.yolo import Model
import sys

# Make a sibling checkout of the RepVGG repository importable (path is relative to
# the current working directory) so that `repvgg` can be imported below.
sys.path.append("../RepVGG-main")
from repvgg import repvgg_model_convert, create_RepVGG_A1

# Directory holding the input checkpoints; logs and the output .pt are written here too.
path = r'E:\competition\Wadhwani_AI_Bollworm_Counting_Challenge\pretrained_model_vgg/'

# yolov5-repvgg: build the architecture from its YAML. Positional args follow YOLOv5's
# Model(cfg, ch, nc, anchors) signature: 3 input channels, 12 classes, default anchors.
model = Model(r"D:\98project\yolov5-7.0\models\yolov5s-haihang-repvgg.yaml", 3, 12, None).to("cpu")

# A YOLOv5 training checkpoint (last.pt) is a dict whose 'ema'/'model' entry is a
# (half-precision) Model object, not a state_dict. Passing that dict directly to
# load_state_dict(strict=False) matches no keys and silently loads nothing, so unwrap it.
# NOTE: on PyTorch >= 2.6 torch.load defaults to weights_only=True and refuses to
# unpickle a full Model object; pass weights_only=False there (trusted files only).
ckpt = torch.load(path + 'last.pt', map_location="cpu")
if isinstance(ckpt, dict) and "model" in ckpt:
    ckpt = ckpt.get("ema") or ckpt["model"]
state = ckpt.float().state_dict() if isinstance(ckpt, torch.nn.Module) else ckpt
load_result = model.load_state_dict(state, strict=False)
# strict=False tolerates mismatches; report them instead of hiding them.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"last.pt: {len(load_result.missing_keys)} missing keys, "
          f"{len(load_result.unexpected_keys)} unexpected keys")

print("yolov5s-repvgg原权重:")
with open(path + "log_yolov5s_repvgg_original.txt", mode="a+", encoding="utf-8") as log1:
    for k, v in model.state_dict().items():
        print(k, v, file=log1)

# repvgg: RepVGG-A1 in its training-time form (deploy=False). The transplant copies
# parameters by key name, which presumably requires the yolov5-repvgg backbone to use
# the same un-fused block layout — confirm against the model YAML.
train_model = create_RepVGG_A1(deploy=False)
# map_location keeps loading working on CPU-only machines, matching the CPU model above.
train_model.load_state_dict(torch.load(path + 'RepVGG-A1-train.pth', map_location="cpu"))

print("repvgg原权重:")
with open(path + "log_repvgg_original.txt", mode="a+", encoding="utf-8") as log2:
    for k, v in train_model.state_dict().items():
        print(k, v, file=log2)

# Parameter dicts of both models. state_dict() returns a fresh dict, so replacing its
# entries does NOT modify the model; the result is written back with
# model.load_state_dict() at the end of this block.
yolov5s_parms = model.state_dict()
repvgg_parms = train_model.state_dict()

# yolov5-repvgg layer prefix -> RepVGG-A1 prefix; the key suffix is copied verbatim.
#   model.0        -> stage0
#   model.1, 2     -> stage1.0, stage1.1
#   model.3        -> stage2.0;  model.4.{0..2}  -> stage2.{1..3}
#   model.5        -> stage3.0;  model.6.{0..12} -> stage3.{1..13}
#   model.7        -> stage4.0
# Every prefix ends with "." so none is a prefix of another (e.g. model.6.1. vs model.6.10.).
BACKBONE_PREFIX_MAP = {
    "model.0.": "stage0.",
    "model.1.": "stage1.0.",
    "model.2.": "stage1.1.",
    "model.3.": "stage2.0.",
    "model.5.": "stage3.0.",
    "model.7.": "stage4.0.",
    **{f"model.4.{i}.": f"stage2.{i + 1}." for i in range(3)},
    **{f"model.6.{i}.": f"stage3.{i + 1}." for i in range(13)},
}


def _repvgg_key(yolo_key):
    """Return the RepVGG-A1 key for a yolov5-repvgg key, or None if it is not transplanted."""
    for yolo_prefix, repvgg_prefix in BACKBONE_PREFIX_MAP.items():
        if yolo_key.startswith(yolo_prefix):
            return repvgg_prefix + yolo_key[len(yolo_prefix):]
    return None


n_copied = 0
for key_yolov5s, v in yolov5s_parms.items():
    key_repvgg = _repvgg_key(key_yolov5s)
    if key_repvgg is None:
        continue  # layers outside the mapped backbone prefixes keep their current weights
    src = repvgg_parms[key_repvgg]  # KeyError here means the two layouts diverge
    if src.shape != v.shape:
        raise ValueError(
            f"shape mismatch: {key_yolov5s} {tuple(v.shape)} vs {key_repvgg} {tuple(src.shape)}"
        )
    # Replacing the value of an existing key while iterating is safe (dict size unchanged).
    yolov5s_parms[key_yolov5s] = src
    n_copied += 1

# Copy the transplanted tensors into the model itself (strict: all keys come from model).
model.load_state_dict(yolov5s_parms)
print(f"transplanted {n_copied} RepVGG tensors into the backbone")

print("yolov5s-repvgg新权重:")
with open(path + "log_yolov5s_repvgg_update.txt", mode="a+", encoding="utf-8") as log3:
    for k, v in yolov5s_parms.items():
        print(k, v, file=log3)

print("repvgg新权重:")
with open(path + "log_repvgg_update.txt", mode="a+", encoding="utf-8") as log4:
    for k, v in repvgg_parms.items():
        print(k, v, file=log4)

# Save the weight file of the model with the transplanted RepVGG backbone.
# This pickles `model` itself, so only changes loaded into the model are persisted.
# NOTE(review): this is a bare Model object; stock YOLOv5 train.py reads ckpt['model']
# from --weights, so confirm the consumer of this file accepts this form.
torch.save(model, path+'yolov5s_repvgg_backbone.pt')

My approach is rather brute-force: I noticed that the parameter keys of the two models can be matched to each other by name.
[Figure omitted]
The transplant produces a file named yolov5s_repvgg_backbone.pt.
I trained for 50 epochs starting from it, but the results were not good; they actually got worse.

[Figure omitted]

You may also like

Source: blog.csdn.net/Qingyou__/article/details/128629736