【动手学深度学习Pycharm实现8】Pytorch神经网络参数的保存与读取

前言

很久没更新了,第一个原因是学校的课程任务,第二个原因是在kaggle实战去了,我参加的是泰坦尼克那个比赛,调了快一周的代码,收获也是不小,感受最大的就是:在机器学习的任务中,非常非常重要的就是特征工程,同样的模型,一个好的特征处理工程能让你的准确率提升百分之几,在kaggle上这能让你的排名上升非常多,这是一篇kaggle经验文章,也可以看csdn上的中译版:中译版点这里,我看完之后感觉受益良多。然后,步入正题:今天笔记的主题就是Pytorch神经网络参数的保存与读取了。


环境配置:

python版本:3.8.6
torch版本:1.11.0

导入的库:

import torch
from torch import nn

一、创建一个简单神经网络

多层感知机:

net = nn.Sequential(
    nn.Linear(5, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)

这里的net其实就是nn.Sequential这个类的实例,

二、保存神经网络参数

torch.save(net.state_dict(), 'net.params')

!!! 注意:这里只是保存了神经网络的参数,而没有保存神经网络的结构,也就是说,如果我们后面要读取参数,必须要创建一个与之前相同结构的神经网络。

三、克隆之前的网络

net_clone = nn.Sequential(
    nn.Linear(5, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)

四、克隆的网络加载本体网络保存的参数

net_clone.load_state_dict(torch.load('net.params'))

五、验证两个网络参数是否一致

print('net:', net.state_dict())
print('net_clone:', net_clone.state_dict())

输出结果如下:
在这里插入图片描述
参数一致,保存成功!
或许会有细心的小伙伴发现,为什么只有"0"层和 "2"层的参数,"1"层的呢?其实这里"1"层是ReLU激活层,没有参数,所以没有显示出来了。

附:总代码

我真的没有在凑字数

import torch
from torch import nn

net = nn.Sequential(
    nn.Linear(5, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)

torch.save(net.state_dict(), 'net.params')

net_clone = nn.Sequential(
    nn.Linear(5, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
)

net_clone.load_state_dict(torch.load('net.params'))

print('net:', net.state_dict())
print('net_clone:', net_clone.state_dict())


猜你喜欢

转载自blog.csdn.net/weixin_45887062/article/details/124855586