pytorch中的模型保存

官方宣称,保存和加载模型参数有两种方式:

方式一:

torch.save(net.state_dict(),path)

功能:保存训练完的网络的各层参数(即weights和bias)

其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)

net2.load_state_dict(torch.load(path))

功能:加载保存到path中的各层参数到神经网络

注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

方式二:

torch.save(net,path)

功能:保存训练完的整个网络模型(不止weights和bias)

net2=torch.load(path)

功能:加载保存到path中的整个神经网络

经过自己的尝试之后,发现这种方式只能保存nn.Module模块中的参数,如果想要保存global_step之类的信息,需要一些小技巧:

state = {'net':model.state_dict(),
	'optimizer':optim.state_dict(),
	'global_step':global_step,
	'best_acc':best_acc,
	'best_step':best_step}
torch.save(state, args.saved_model_path)

然后加载的时候,用如下的方式加载:

checkpoint = torch.load(args.saved_model_path)
model.load_state_dict(checkpoint['net'])
optim.load_state_dict(checkpoint['optimizer'])
best_acc = checkpoint['best_acc']
best_step = checkpoint['best_step']
global_step = checkpoint['global_step']

python的变量也可以用类似这样的方式保存和加载,不得不说pytorch真的是很方便啊。
 

猜你喜欢

转载自blog.csdn.net/bonjourdeutsch/article/details/103085397