torch之网络模型的保存与读取

torch之网络模型的保存与读取

  我们今天分别用两种方式来说明模型的保存与读取,申明一点,模型保存方式一定要与读取方式一致。例如:我们用方式一保存的模型,我们读取的时候也需要使用方式一来读取。

1.模型的保存

vgg16 = torchvision.models.vgg16(pretrained=False)
#方式一:
torch.save(vgg16, "vgg16_method1.pth")
#方式二:
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

2.模型的读取

#方式一:
model = torch.load("vgg16_method1.pth")
#方式二:
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))

猜你喜欢

转载自blog.csdn.net/weixin_49005845/article/details/125653784
今日推荐