pytorch保存与加载神经网络方法

应用pytorch训练一个神经网络后,如何保存神经网络呢?
pytorch中有两种方法,第一种是将网络整体保存,第二种是保存神经网络的参数(推荐第二种!),这里就简单讲讲如何保存参数。

  • 保存
    训练完后加上如下代码
    torch.save(model.state_dict(),“model_params.pkl”)
    这样在运行结束后你会发现你此项目的文件夹里会多出来一个model_params.pkl文件,
    代码中的model是我代码里定义的神经网络的名字,如果你的网络名字叫net,那就写torch.save(net.state_dict(),“model_params.pkl”)。
  • 加载
    加载时运用下面一行代码
    model.load_state_dict(torch.load(‘model_params.pkl’))
    我是在另一个python file中加载的,所以前面还要加上你的神经网络是怎么定义的,所以我的加载代码是这样的:
class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.conv1 = nn.Conv2d(3, 10, 5, 1, 2)
       self.pool = nn.MaxPool2d(2, 2)
       self.conv2 = nn.Conv2d(10, 20, 5, 1, 2)
       self.fc1 = nn.Linear(20*56*56, 120)
       self.fc2 = nn.Linear(120, 84)
       self.fc3 = nn.Linear(84, 4)

   def forward(self, x):
       x = self.pool(F.relu(self.conv1(x)))
       x = self.pool(F.relu(self.conv2(x)))
       x = x.view(-1, 20*56*56)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x
       
model = Net()

model.load_state_dict(torch.load('model_params.pkl'))

补充一下保存整体网络的方法
保存:
torch.save(model, ‘model_params.pkl’)

加载:
由于之前保存的是网络的整体结构,所以在加载的程序中不需要class Net(nn.Module):{……}这一项,这里与保存参书的方法不同。
只需要一行代码
model = torch.load(‘model_params.pkl’)

以上是全部内容,希望对你有帮助哦!

猜你喜欢

转载自blog.csdn.net/zzy_pphz/article/details/104728662