Error de registro: los parámetros de peso de pytorch no coinciden

Escenas:

Después de modificar el modelo yo mismo, me encontré con el problema de los parámetros de peso no coincidentes:
终端出现问题描述如下

size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).
size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).

Análisis de causa

Esto se debe a que la capa completamente conectada en el modelo preentrenado descargado tiene 1000 categorías, y la categoría de código actual solo tiene 2 categorías, por lo que se informará un error de discrepancia.

solución:

Se puede ver en el mensaje de error que los parámetros de peso de la capa fc no coinciden, por lo que solo necesitamos no cargar los parámetros de esta capa.

net = se_resnet50(num_classes=2)
pretrained_dict = torch.load("./senet/seresnet50-60a8950a85b2b.pkl")

model_dict = net.state_dict()
# 重新制作预训练的权重,主要是减去参数不匹配的层,楼主这边层名为“fc”
pretrained_dict = {
    
    k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)}
# 更新权重
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

Supongo que te gusta

Origin blog.csdn.net/baobao135/article/details/129208772
Recomendado
Clasificación