PyTorch-模型保存与加载

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jiang425776024/article/details/88240531

保存: 

model = LinearRegression()
# ......各种操作
model.eval()
#训练完成,保存状态字典到linear.pkl
torch.save(model.state_dict(), './linear.pkl')

加载:

model = LinearRegression()
model.load_state_dict(torch.load('linear.pth'))
#...各种使用,比如预测...
x_test=np.arrar([..............])
x_test = torch.from_numpy(x_test)
predict_y = model(Variable(x_test))

猜你喜欢

转载自blog.csdn.net/jiang425776024/article/details/88240531
今日推荐