错误提示:
RuntimeError: Error(s) in loading state_dict for SENET: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var"
错误原因
load_state_dict
中 strict 参数默认是 True :
表示预训练模型的层和自己定义的网络结构层严格对应相等(如层名和维度)
所以当我们修改了网络结构后,如果strict为True的时候就会报错
model = MODEL( num_classes= 500 , senet154_weight = WEIGHT_PATH, multi_scale = True, learn_region=True)
model = torch.nn.DataParallel(model)
vgg16 = model
vgg16.load_state_dict(torch.load('model/senet.pth'),False)
解决方案
将 load_state_dict
中 strict 参数设置为 False
修改前:
vgg16.load_state_dict(torch.load('model/food500_model.pth'))
修改后:
vgg16.load_state_dict(torch.load('model/food500_model.pth'),False)