pytorch中训练模型的保存和提取

有两种方式

只保存模型的参数:

torch.save(the_model.state_dict(), PATH)


之后需要的时候把模型参数再次提取出来:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

保存整个模型

torch.save(the_model, PATH)

提取时:

the_model = torch.load(PATH)
发布了2 篇原创文章 · 获赞 9 · 访问量 164

猜你喜欢

转载自blog.csdn.net/qq_45171138/article/details/104545944
今日推荐