if opt.use_gpu:
model.cuda()
model=nn.DataParallel(model,device_ids=[0,1,2]) # multi-GPU
将模型改为多卡训练以后,直接打印模型的名字会报错:
AttributeError: ‘DataParallel’ object has no attribute ‘model_name’,
加上module可以正常打印。
print(model.module.model_name)
DataParallel也是一个Pytorch的nn.Module,只是这个类其中有一个module的变量用来保存传入的实际模型。nn.DataParallel(m)这句返回的已经不是原始的m了,而是一个DataParallel,原始的m保存在DataParallel的module变量里面。
保存的时候直接取出原始的m:
torch.save(m.module.state_dict(), path)
同样的硬件环境下加载模型时:
def load(self, path):
'''
load the model with specific path
'''
self.load_state_dict(t.load(path))
model = getattr(models, opt.model)().eval()
if opt.load_model_path:
model.load(opt.load_model_path)
if opt.use_gpu:
model.cuda()
model=nn.DataParallel(model,device_ids=opt.device_id) # multi-GPU