Pytorch distributed 多卡并行载入模型

Pytorch distributed 多卡并行载入模型

前面的博客介绍了pytorch多卡distribute的方法,这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。
大部分情况下,我们在测试时不需要多卡并行计算。所以,我在测试时只使用单卡。

from collections import OrderedDict


device = torch.device("cuda")

model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)
发布了131 篇原创文章 · 获赞 6 · 访问量 6919

猜你喜欢

转载自blog.csdn.net/Orientliu96/article/details/104702520