pytorch模型保存(save)与读取(load)

模型参数

1.仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

2.加载model.state_dict

model = TheModelClass(*args, **kwargs) #加载模型类
model.load_state_dict(torch.load(PATH))#加载模型参数
model.eval() #推理模式

整个模型状态

1.保存整个模型的状态:

torch.save(model,PATH)

2.加载整个模型状态:

model = torch.load(PATH)
 
model.eval()

猜你喜欢

转载自blog.csdn.net/weixin_44866921/article/details/132106597
今日推荐