通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()

目录

0. 前言

1. Pytorch框架加载与保存权重的方法

2. 实例问题说明

3. 加载权重数据

4. 保存权重数据


0. 前言

在深度学习实际应用中,往往涉及到的神经元网络模型都很大,权重参数众多,因此会导致训练epoch次数很多,训练时间长。

如果每次调整非模型相关的参数(训练数据集、优化函数类型、学习率、迭代次数)都要重新训练一次模型,这显然会浪费大量的训练时间。

而且,对于一些成熟的网络模型,已经有前人做过大量的“预训练”,这时如果能基于前人预训练的结果,训练自己的数据集,明显会事半功倍。

因此,加载与保存权重在深度学习实际使用中有很大的必要。

1. Pytorch框架加载与保存权重的方法

①加载权重的方法: .load_state_dict()方法说明:

.load_state_dict()定义:

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):

- state_dict :即要加载的权重,通常是一个文件地址;

- strick: 可以理解为等于"True"时是“精确匹配”,要求要加载的权重与要被加载权重的模型完全匹配。

Pytorch源文件注释:

Args:
    state_dict (dict): a dict containing parameters and
        persistent buffers.
    strict (bool, optional): whether to strictly enforce that the keys
        in :attr:`state_dict` match the keys returned by this module's
        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

*小注释:meth笔误了,应该是mesh,网格

②保存权重的方法:.save()方法说明:

.save()定义:

def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:

- obj:要保存的权重参数;

- f:保存的文件路径

这里仅说明.save()在保存网络模型权重数据上的作用。实际上.save()还有很多应用,例如:保存整个网络,这里不再赘述。

Pytorch源文件注释:

"""Saves an object to a disk file.

See also: `saving-loading-tensors`

Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string or
       os.PathLike object containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

2. 实例问题说明

首先说明本次的实例问题:本次要构建的神经元网络为一个“平方网络”,即网络输出数据为输入数据的平方。

网络模型结构:

输入(1)→全连接层(1×5)→Sigmoid激活函数(5)→全连接层(5×5)→Sigmoid激活函数(5)→全连接层(5×1)→输出(1)

训练数据:

输入数据[1, 2, 3, 4, 5]; 

输出数据[1, 4, 9, 16, 25]

3. 加载权重数据

直接上代码

import torch

class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )

    def forward(self,x):
        return self.net(x)

square_net = LinearNet(1,1)

square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重

if __name__ == '__main__':
    print(square_net(torch.tensor([3.16],dtype=torch.float32)))

其中weight.pth是我已经训练好的权重数据路径,这里定义好网络模型后,直接加载权重数据,不必关心这个权重是如何训练来的,更不必关系具体权重的值是多少。测试输入为3.16输出为:

tensor([9.9180], grad_fn=<AddBackward0>)

这里要注意的是:因为上面strict默认为True,即为“精确匹配”,这里新构建的网络模型结构必须和权重来源的网络模型结构相同

4. 保存权重数据

import torch

input = torch.tensor([[1],[2],[3],[4],[5]], dtype=torch.float32)
output = torch.tensor([[1],[4],[9],[16],[25]], dtype=torch.float32)

class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )

    def forward(self,x):
        return self.net(x)

Loss = torch.nn.MSELoss()
linear_net = LinearNet(1,1)
opt = torch.optim.SGD(linear_net.parameters(), lr= 0.003)

for k in range(1000):
    opt.zero_grad()
    for i in range(len(input)):
        train_out = linear_net(input[i])
        loss = Loss(train_out, output[i])

        loss.backward()
        opt.step()

torch.save(linear_net.state_dict(),'weight.pth')   #保存.pth权重文件

for keys,values in linear_net.state_dict().items():   #查看权重名称及值
    print(keys)
    print(values)
    print('************************************************************************')

if __name__ == '__main__':
    print(linear_net(torch.tensor([3.16],dtype=torch.float32)))

这里可以看到具体的训练过程及相关的训练参数,权重保存在'weight.pth'文件中。

可以通过print查看具体的权重数值:

net.0.weight
tensor([[-0.8204],
        [-1.7341],
        [-0.6987],
        [ 0.9370],
        [-1.5558]])
************************************************************************
net.0.bias
tensor([ 0.9285,  2.1061,  1.0247, -2.9221,  7.1159])
************************************************************************
net.2.weight
tensor([[-1.6075, -1.3072, -1.5342,  2.4527, -3.9922],
        [-0.7101, -1.5125, -0.6791,  2.0325, -2.3406],
        [-1.1707, -1.6899, -0.9883,  2.9682, -1.5409],
        [-1.1992, -2.0559, -0.7610,  2.3890, -1.3782],
        [-1.1274, -1.7907, -1.0860,  2.3549, -3.6847]])
************************************************************************
net.2.bias
tensor([0.4826, 0.7057, 0.9702, 1.0532, 0.4214])
************************************************************************
net.4.weight
tensor([[7.3601, 4.7667, 6.2473, 5.0187, 7.2028]])
************************************************************************
net.4.bias
tensor([-0.2476])
************************************************************************

猜你喜欢

转载自blog.csdn.net/m0_49963403/article/details/129912258