pytorch carga parte del modelo previamente entrenado (proceso detallado de resolución del problema)

Primero, ingrese la dirección de código https://github.com/jfzhang95/pytorch-deeplab-xception y el código para cargar el modelo pre-entrenado, con algunas modificaciones

    pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k, v in pretrain_dict.items()
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)

Entre ellos, el bucle for busca la misma parte de la clave en el modelo y el modelo de preentrenamiento cargado, carga su valor en el modelo y finalmente actualiza el modelo.
Sin embargo, se produjo un error al reemplazar el modelo de
preentrenamiento : RuntimeError: Error (s) al cargar state_dict para XXX:
Falta clave (s) en state_dict:
Verifiqué algunos datos y descubrí que puede ser un problema causado por claves desalineadas, entonces, Eche un vistazo a la salida de claves de los dos modelos:

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)                         #查看预训练模型的keys
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)                         #查看本地model的keys
    for k, v in pretrain_dict.items()
        if k in state_dict:
            model_dict[k] = v
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)                         #查看model更新后的keys

Compruebe los resultados de la siguiente manera:

module.backbone.conv1.weight
module.backbone.bn1.weight

...
Conv1.weight
bn1.weight
...
Efectivamente, las claves no corresponden. Después de leer muchas soluciones, descubrí que todas deben cargarse una por una. Se siente demasiado problemático, así que lo modifiqué de acuerdo con mis propias ideas y lo cargué con éxito:

    pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
    for k in pretrain_dict.keys():
        print(k)
    model_dict = {
    
    }
    state_dict = model.state_dict()
    for k in state_dict.keys():
        print(k)
    print("分界线") 
    for k, v in pretrain_dict.items():
        for i, j in state_dict.items():    #加上前缀后寻找对应的keys
            m = 'module.backbone.' + i
            if k == m :
                model_dict[i] = v
                print(i)
    state_dict.update(model_dict)
    model.load_state_dict(state_dict)
    for k in model_dict.keys():
        print(k)
    return model

De hecho, las claves del modelo tienen un prefijo.
Novato, registre el proceso de resolución del problema, hay buenos métodos, bienvenido a comunicarse.

Supongo que te gusta

Origin blog.csdn.net/qq_36804414/article/details/106718214
Recomendado
Clasificación