# PyTorch: average the weights of multiple models (original title: pytorch 多个模型 求平均)

from collections import OrderedDict

import torch

from models.faceland_d import FaceLanndInference_d

if __name__ == '__main__':

    model = FaceLanndInference_d()

    model_paths = ["./weights_d/0.0680_slim128_epoch_52.pth",
    "./weights_d/0.0680_slim128_epoch_52.pth"]
    if model_paths:
        bone_dict = model.state_dict()
        new_state_dict = OrderedDict()
        data_len=len(model_paths)
        for model_path in model_paths:
            state_dict = torch.load(model_path)

            for k, v in state_dict.items():
                head = k[:7]
                if head == 'module.':
                    tmp_name = k[7:]  # remove `module.`
                else:
                    tmp_name = k  # continue
                need_v = bone_dict[tmp_name]

                if tmp_name in new_state_dict:
                    new_state_dict[tmp_name] += v/data_len
                else:
                    new_state_dict[tmp_name] = v/data_len
        model.load_state_dict(new_state_dict, strict=False)
        torch.save(model.state_dict(), "new_weight.pth")

# 猜你喜欢 ("You may also like" — navigation text left over from the source blog page)

# Reposted from (转载自): https://blog.csdn.net/jacke121/article/details/131365177