Pytorch 多GPU训练过程

    if opt.use_gpu:
        model.cuda()
        model=nn.DataParallel(model,device_ids=[0,1,2]) # multi-GPU

将模型改为多卡训练以后,直接打印模型的名字会报错:
AttributeError: ‘DataParallel’ object has no attribute ‘model_name’,
加上module可以正常打印。

print(model.module.model_name)

DataParallel也是一个Pytorch的nn.Module,只是这个类其中有一个module的变量用来保存传入的实际模型。nn.DataParallel(m)这句返回的已经不是原始的m了,而是一个DataParallel,原始的m保存在DataParallel的module变量里面。

保存的时候直接取出原始的m:

torch.save(m.module.state_dict(), path)

同样的硬件环境下加载模型时:

def load(self, path):
    '''
    load the model with specific path
    '''
    self.load_state_dict(t.load(path))

model = getattr(models, opt.model)().eval()
if opt.load_model_path:
    model.load(opt.load_model_path)
if opt.use_gpu:
    model.cuda()
    model=nn.DataParallel(model,device_ids=opt.device_id) # multi-GPU
发布了9 篇原创文章 · 获赞 0 · 访问量 162

猜你喜欢

转载自blog.csdn.net/weixin_37532614/article/details/104640399