RuntimeError: Error(s) in loading state_dict for SENET

错误提示:

RuntimeError: Error(s) in loading state_dict for SENET: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var"
 

错误原因

load_state_dict 中 strict 参数默认是 True :

表示预训练模型的层和自己定义的网络结构层严格对应相等(如层名和维度)

所以当我们修改了网络结构后,如果strict为True的时候就会报错

model = MODEL( num_classes= 500 , senet154_weight = WEIGHT_PATH, multi_scale = True, learn_region=True)
model = torch.nn.DataParallel(model)
vgg16 = model
vgg16.load_state_dict(torch.load('model/senet.pth'),False)

解决方案

将 load_state_dict 中 strict 参数设置为 False

修改前:

vgg16.load_state_dict(torch.load('model/food500_model.pth'))

修改后:

vgg16.load_state_dict(torch.load('model/food500_model.pth'),False)

猜你喜欢

转载自blog.csdn.net/qq_40905284/article/details/130718638