repvgg预训练模型迁移到yolov5-repvgg的pt的backbone中/模型的保存与读取

前情

我在训练改进的yolov5-repvgg时,发现yolov5s训练50epoch,和我改进的yolov5-repvgg模型,以yolov5s预训练模型为预训练,俩对比都训50,发现没有yolov5s训50epoch,也可能我的待测物体比较特殊。
但是就是认为是yolov5-repvgg不应该以yolov5s.pt为预训练模型,确实俩都不用yolov5s.pt为预训练模型后,yolov5-repvgg的效果更好。
在这里插入图片描述
我的原模型是yolov5s-repvgg模型,不带预训练模型,训练了50epoch,(我的yolov5-repvgg模型是这个帖子的改进方法yolov5 引入RepVGG模型结构
然后把官方的repvgg预训练模型的参数(代码:DingXiaoH/RepVGG
,权重:官方权重链接下的repvgg-A1-train.pth),移植到上面所述模型的backbone中。

我的权重移植代码:

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@File  : trained_model
@Author: Wendy
@Time  : 2023/1/9 14:47
@Desc  :
yolov5-repvgg移植repvgg的backbone
"""

import torch
from models.yolo import Model
import sys

sys.path.append("../RepVGG-main")
from repvgg import repvgg_model_convert, create_RepVGG_A1

path = r'E:\competition\Wadhwani_AI_Bollworm_Counting_Challenge\pretrained_model_vgg/'

# yolov5-repvgg
model = Model(r"D:\98project\yolov5-7.0\models\yolov5s-haihang-repvgg.yaml", 3, 12, None).to("cpu")
model.load_state_dict(torch.load(path + 'last.pt'), strict=False)
# for k, v in model.state_dict().items():
#     print(f"{k}: {v.shape}")
print("yolov5s-repvgg原权重:")
log1 = open(path + "log_yolov5s_repvgg_original.txt", mode="a+", encoding="utf-8")
for k, v in model.state_dict().items():
    print(k, v, file=log1)

# repvgg
train_model = create_RepVGG_A1(deploy=False)
train_model.load_state_dict(torch.load(path + 'RepVGG-A1-train.pth'))  # or train from scratch
# do whatever you want with train_model
# deploy_model = repvgg_model_convert(train_model, save_path='RepVGG-A0-deploy.pth')
print("repvgg原权重:")
log2 = open(path + "log_repvgg_original.txt", mode="a+", encoding="utf-8")
for k, v in train_model.state_dict().items():
    print(k, v, file=log2)

yolov5s_parms = model.state_dict()
repvgg_parms = train_model.state_dict()
for key_yolov5s, v in yolov5s_parms.items():
    if key_yolov5s.startswith("model.0."):
        key_repvgg = 'stage0.' + key_yolov5s[8:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.1."):
        key_repvgg = 'stage1.0.' + key_yolov5s[8:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.2."):
        key_repvgg = 'stage1.1.' + key_yolov5s[8:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.3."):
        key_repvgg = 'stage2.0.' + key_yolov5s[8:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.4.0."):
        key_repvgg = 'stage2.1.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.4.1."):
        key_repvgg = 'stage2.2.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.4.2."):
        key_repvgg = 'stage2.3.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.5."):
        key_repvgg = 'stage3.0.' + key_yolov5s[8:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.0."):
        key_repvgg = 'stage3.1.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.1."):
        key_repvgg = 'stage3.2.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.3."):
        key_repvgg = 'stage3.4.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.4."):
        key_repvgg = 'stage3.5.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.5."):
        key_repvgg = 'stage3.6.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.6."):
        key_repvgg = 'stage3.7.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.7."):
        key_repvgg = 'stage3.8.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.8."):
        key_repvgg = 'stage3.9.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.9."):
        key_repvgg = 'stage3.10.' + key_yolov5s[10:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.10."):
        key_repvgg = 'stage3.11.' + key_yolov5s[11:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.11."):
        key_repvgg = 'stage3.12.' + key_yolov5s[11:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.6.12."):
        key_repvgg = 'stage3.13.' + key_yolov5s[11:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    elif key_yolov5s.startswith("model.7."):
        key_repvgg = 'stage4.0.' + key_yolov5s[8:]
        yolov5s_parms[key_yolov5s] = repvgg_parms[key_repvgg]
    else:
        pass

print("yolov5s-repvgg新权重:")
log3 = open(path + "log_yolov5s_repvgg_update.txt", mode="a+", encoding="utf-8")
for k, v in yolov5s_parms.items():
    print(k, v, file=log3)

print("repvgg新权重:")
log4 = open(path + "log_repvgg_update.txt", mode="a+", encoding="utf-8")
for k, v in repvgg_parms.items():
    print(k, v, file=log4)

# 保存移植了repvgg的backbone的权重文件

torch.save(model, path+'yolov5s_repvgg_backbone.pt')

我的方式比较暴力,因为我发现他俩的key按照名字可以对上。
在这里插入图片描述
移植结束后,会生成一个yolov5s_repvgg_backbone.pt
我在他的基础上训了50epoch,效果不好,变差了。

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/Qingyou__/article/details/128629736