应用pytorch训练一个神经网络后,如何保存神经网络呢?
pytorch中有两种方法,第一种是将网络整体保存,第二种是保存神经网络的参数(推荐第二种!),这里就简单讲讲如何保存参数。
- 保存
训练完后加上如下代码
torch.save(model.state_dict(),“model_params.pkl”)
这样在运行结束后你会发现你此项目的文件夹里会多出来一个model_params.pkl文件,
代码中的model是我代码里定义的神经网络的名字,如果你的网络名字叫net,那就写torch.save(net.state_dict(),“model_params.pkl”)。 - 加载
加载时运用下面一行代码
model.load_state_dict(torch.load(‘model_params.pkl’))
我是在另一个python file中加载的,所以前面还要加上你的神经网络是怎么定义的,所以我的加载代码是这样的:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 10, 5, 1, 2)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(10, 20, 5, 1, 2)
self.fc1 = nn.Linear(20*56*56, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 4)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 20*56*56)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net()
model.load_state_dict(torch.load('model_params.pkl'))
补充一下保存整体网络的方法
保存:
torch.save(model, ‘model_params.pkl’)
加载:
由于之前保存的是网络的整体结构,所以在加载的程序中不需要class Net(nn.Module):{……}这一项,这里与保存参书的方法不同。
只需要一行代码
model = torch.load(‘model_params.pkl’)
以上是全部内容,希望对你有帮助哦!