torch之网络模型的保存与读取
我们今天分别用两种方式来说明模型的保存与读取,申明一点,模型保存方式一定要与读取方式一致。例如:我们用方式一保存的模型,我们读取的时候也需要使用方式一来读取。
1.模型的保存
vgg16 = torchvision.models.vgg16(pretrained=False)
#方式一:
torch.save(vgg16, "vgg16_method1.pth")
#方式二:
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
2.模型的读取
#方式一:
model = torch.load("vgg16_method1.pth")
#方式二:
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))