Pytorch模型的保存与读取方法

版权声明:本文为王小草原创文章,要转载请先联系本人哦 https://blog.csdn.net/sinat_33761963/article/details/84261043

方法一(推荐)

只保存和加载模型的参数

# 保存模型参数
def save_model(the_model, PATH):
    torch.save(the_model.state_dict(), PATH)
# 加载模型参数
def load_model(PATH):
    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))

方法二

在这种情况下,序列化的数据被绑定到特定的类和固定的目录结构,所以当在其他项目中使用时,或者在一些严重的重构器之后它可能会以各种方式break。

# 保存模型参数
def save_model(the_model, PATH):
    torch.save(the_model, PATH)
# 加载模型参数
def load_model(PATH):
    the_model = torch.load(PATH)

猜你喜欢

转载自blog.csdn.net/sinat_33761963/article/details/84261043
今日推荐