【Pytorch】保存神经网络模型

1 只保存模型参数

# 保存
torch.save(model.state_dict(), './parameter.pkl')
# 加载
model = TheModelClass(...)
model.load_state_dict(torch.load('./parameter.pkl'))

2 保存完整模型

# 保存
torch.save(model, './model.pkl')
# 加载
model = torch.load('./model.pkl')

猜你喜欢

转载自blog.csdn.net/ao1886/article/details/109137719