深度学习——09模型的保存:torch.save()、加载:torch.load()

两种方式

保存模型主要分为两类:
1、保存整个模型
2、保存模型参数

1、第一种

结构模型+模型参数

保存整个网络模型,加载整个网络模型(可能比较耗时)

# 保存方式1
torch.save(vgg16, "vgg16_model1.pth")
# 对应保存的方式1
model = torch.load("vgg16_model1.pth")
print(model)

2、第二种

只保存加载模型参数(推荐)

保存模型的权重参数(速度快,占内存少)

# 保存方式2
torch.save(vgg16.state_dict(),"vgg16_model2.pth")
# 对应保存的方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
model2 = torch.load("vgg16_model2.pth")
vgg16.load_state_dict(model2)

假设网络为:
model = Net(), optimizer = optim.Adam(model.parameters(), lr=args.lr),
假设在某个epoch,要保存模型参数,优化器参数以及epoch
先建立一个字典,保存三个参数:

state = { ‘net’: model.state_dict(), ‘optimizer’: optimizer.state_dict(), ‘epoch’: epoch}

torch.save(state, "./project/mymodel.pth")

当想恢复某一阶段的训练时,那就可以读取之前保存的网络模型参数等。

mymodel= torch.load("./project/mymodel.pth")
model.load_state_dict(mymodel['net'])  #  加载之前的网络模型参数
optimizer.load_state_dict(mymodel['optimizer'])  # 加载之前的优化器的参数
start_epoch = mymodel['epoch'] + 1  #  加载新的训练回合数

猜你喜欢

转载自blog.csdn.net/weixin_48501651/article/details/124794506