[PyTorch] Saving a neural network model

1 Save only the model parameters

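import torch

# Save: write only the learned parameters (the state_dict)
torch.save(model.state_dict(), './parameter.pkl')
# Load: rebuild the model from its class first, then copy the saved parameters into it
model = TheModelClass(...)
model.load_state_dict(torch.load('./parameter.pkl'))
model.eval()  # put dropout/batch-norm layers into inference mode before evaluating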
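A minimal, self-contained round trip of the parameters-only approach might look like the sketch below. It assumes a small hypothetical `TinyNet` class standing in for `TheModelClass(...)` and writes to a temporary directory instead of `./parameter.pkl`. It saves the state_dict, builds a fresh instance, loads the parameters into it, and checks that both models give identical outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass: a two-layer MLP."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def save_and_reload_params(model: nn.Module, path: str) -> nn.Module:
    """Save only the parameters, then load them into a freshly built TinyNet."""
    torch.save(model.state_dict(), path)
    restored = TinyNet()  # the architecture has to be rebuilt in code
    restored.load_state_dict(torch.load(path))
    restored.eval()
    return restored


torch.manual_seed(0)
original = TinyNet()
original.eval()

with tempfile.TemporaryDirectory() as tmp:
    restored = save_and_reload_params(original, os.path.join(tmp, "parameter.pkl"))

# The state_dict maps parameter names to tensors
print(list(original.state_dict().keys()))

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
print(same)
```

The file holds only tensors keyed by names such as `fc1.weight`, so the loading side needs the class definition to recreate the structure before `load_state_dict` can fill it in.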

2 Save the complete model

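import torch

# Save: pickle the entire model object (class reference + parameters)
torch.save(model, './model.pkl')
# Load: the model's class must be defined or importable here. PyTorch 2.6+ defaults to
# weights_only=True, which refuses full model objects, so pass weights_only=False
# (keyword available since PyTorch 1.13), and only for files you trust.
model = torch.load('./model.pkl', weights_only=False)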
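The sketch below, assuming a hypothetical `TinyNet` class, PyTorch 1.13 or newer, and a temporary directory in place of `./model.pkl`, saves the whole object and loads it back without instantiating the class by hand. Pickle stores a reference to the class rather than its code, so the class still has to be defined (or importable) in the loading script.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; its definition must be importable wherever the file is loaded."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
model = TinyNet()
model.eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    # Pickles the whole object: a reference to the class plus its parameters
    torch.save(model, path)
    # weights_only=False is required to unpickle a full nn.Module
    # (the default became True in PyTorch 2.6); use it only on trusted files.
    loaded = torch.load(path, weights_only=False)

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), loaded(x))

print(type(loaded).__name__, same_output)
```

The convenience comes at a cost: the saved file is bound to the exact class and module path used when saving, which is why saving only the state_dict is generally the more portable choice.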


Origin blog.csdn.net/ao1886/article/details/109137719