nn.Module源码介绍(三)---模型加载/保存乱炖

前言

断断续续写了好久,本篇是最后一篇。太肝了,写这东西,写这东西收视率不高,实在不行就换实战把。



一、创建一个简单网络并训练保存

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# 产生数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

x , y =(Variable(x),Variable(y))


class Net(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden1 = nn.Linear(n_input,n_hidden)
        self.hidden2 = nn.Linear(n_hidden,n_hidden)
        self.predict = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.relu(out)
        out =self.predict(out)
        return out

net = Net(1,20,1)
optimizer = torch.optim.SGD(net.parameters(),lr = 0.05)
loss_func = torch.nn.MSELoss()

for t in range(200):
    prediction = net(x)
    loss = loss_func(prediction,y)
    if t % 10 == 0:
        print('LOSS:',loss.data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

torch.save(net,'net.pkl')
torch.save(net.state_dict(),'net_parameter.pkl')

 上述我们创建了一个三层全连接网络。并通过torch.save()进行保存。
 其中:torch.save(net,path) —> 保存整个模型;
 torch.save(net.state_dict(),path) --> 保存模型的参数。

二、State_dict()源码解读

 这个系列文章是用来介绍nn.Module源码的,因此,我将介绍下Module是如何保存模型的。
 这里torch.save(net,path)就是将整个net进行序列化了,没有介绍的东西。
 这里介绍net.state_dict():

def _save_to_state_dict(self, destination, prefix, keep_vars):
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in self._buffers.items():
        if buf is not None and name not in self._non_persistent_buffers_set:
            destination[prefix + name] = buf if keep_vars else buf.detach()

def state_dict(self, destination=None, prefix='', keep_vars=False):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
    self._save_to_state_dict(destination, prefix, keep_vars)
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) # 递归调用
    for hook in self._state_dict_hooks.values():
        hook_result = hook(self, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination

 简单介绍下逻辑:首先创建一个有序字典,之后通过函数_save_to_state_dict()函数添加参数及buffer。然后递归每一个module。逐个将模型参数进行保存。Hook这一块后期介绍。
 此处可以打印下print(net.state_dict())。会出现:

OrderedDict([('hidden1.weight', tensor([[-0.8167], [ 0.7373]])),
             ('hidden1.bias', tensor([0.2706, 0.4230])), 
             ('hidden2.weight', tensor([[0.4133, 0.0363],[0.2908, 0.4775]])), ('hidden2.bias', tensor([-0.5193,  0.1890])), 
             ('predict.weight', tensor([[-0.4416,  0.0453]])), ('predict.bias', tensor([-0.0529]))])

 实际上就是一个有序字典。

三、加载模型

 加载模型就是使用:
 torch.load(path) 加载 整个模型
  net.load_state_dict(torch.load(path)) 加载模型参数
 同上,我们看下net.load_state_dict源代码。

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {
    
    k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    local_state = {
    
    k: v for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]

            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue

            try:
                with torch.no_grad():
                    param.copy_(input_param)
            except Exception as ex:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}, '
                                  'an exception occurred : {}.'
                                  .format(key, param.size(), input_param.size(), ex.args))
        elif strict:
            missing_keys.append(key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)

def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
                    strict: bool = True):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {
    
    } if metadata is None else metadata.get(prefix[:-1], {
    
    })
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
    load = None  # break load->load reference cycle

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                           self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

 简单介绍下思路:创建了三个list:missing_keys,unexpected_keys和error_msgs。这三个list最终若不为None,则说明加载过程中报错了。而加载过程也是递归module,之后遍历有序字典,将key和value一一对应。

四、实战:保存/加载(模型+优化器+epoch)

 其实前面三部分只介绍了两件事:
(1)保存模型
 torch.save(net,path)
 torch.save(net.state_dict(),path)
(2)加载模型
 torch.load(net)
  net.state_dict(torch.load(net))
 但以上并不是在实际中最通用的。在实际训练网络过程中,在训练一定epoch后,会保存参数。但是存在一个问题,倘若由于特殊原因导致训练中断了,当时中断处的学习率以及epoch是多少的信息会丢失。因此,为了程序更加鲁棒,会同时保存epoch,参数及优化器的状态。Okay,以第一部分的代码为例,修改过代码为:

net = Net(1,2,1)
start_epoch = -1
optimizer = torch.optim.SGD(net.parameters(),lr = 0.05)
loss_func = torch.nn.MSELoss()
for epoch in range(start_epoch + 1,20):
    prediction = net(x)
    loss = loss_func(prediction,y)
    if epoch % 5 == 0:                  # 假如每训练5轮打印此loss,并保存模型
        print('LOSS:',loss.data)
        checkpoint = {
    
    
            'net': net.state_dict(),     # 保存模型
            'optimizer': optimizer.state_dict(), # 保存优化器
            'epoch':epoch                # 保存训练轮数
        }
        torch.save(checkpoint,'net%s.pkl'%(str(epoch)))

 Okay,当然还差断点 加载 部分。我们在修改下,

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# 产生数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

x , y =(Variable(x),Variable(y))


class Net(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden1 = nn.Linear(n_input,n_hidden)
        self.hidden2 = nn.Linear(n_hidden,n_hidden)
        self.predict = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.relu(out)
        out =self.predict(out)
        return out

net = Net(1,2,1)

start_epoch = -1
optimizer = torch.optim.SGD(net.parameters(),lr = 0.05)
loss_func = torch.nn.MSELoss()

Resume = True
if Resume:  # 若加载
    checkpoint = torch.load("/home/wujian/leleDetections/Save_and_Load/net10.pkl")
    net.load_state_dict(checkpoint['net'])              # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器
    start_epoch = checkpoint['epoch']                   # 加载训练轮数

for epoch in range(start_epoch + 1,20):
    prediction = net(x)
    loss = loss_func(prediction,y)

    if epoch % 5 == 0:                   # 假如每训练5轮打印此loss,并保存模型
        print('epoch:%s,LOSS:%d'%(str(epoch),loss.data))
        checkpoint = {
    
    
            'net': net.state_dict(),     # 保存模型
            'optimizer': optimizer.state_dict(), # 保存优化器
            'epoch':epoch                # 保存训练轮数
        }
        torch.save(checkpoint,'net%s.pkl'%(str(epoch)))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

 其中Resume参数控制是否载入模型。

五、网络相同但命名方式不同的加载解析

 在实际问题中,我们往往需要使用别人预训练模型,比如会使用官方提供的vgg.pth。但是假如自己实现相同vgg,加载vgg.pth肯定会出现keys不匹配错误。以简单例子为例:

checkpoint = torch.load('net_parameter.pkl')
print(checkpoint.keys())

 输出结果为:

(['hidden1.weight', 'hidden1.bias', 'hidden2.weight', 'hidden2.bias', 'predict.weight', 'predict.bias'])

 从这可以看出,我们网络命名方式为hidden1…
 现在假如新建一个相同的net2,但各个模块命名方式不同。如何加载呢 ?

class Net1(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net1,self).__init__()
        self.le1 = nn.Linear(n_input,n_hidden)
        self.le2 = nn.Linear(n_hidden,n_hidden)
        self.prele = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out =self.prele(out)
        return out
net1 = Net1(1,2,1)
checkpoint = torch.load('net_parameter.pkl')    # 加载模型
net1.load_state_dict(checkpoint)        

 会出现如下错误(出现错误原因请看源码解析第二部分):
在这里插入图片描述
 这里解决办法有两种:
 (1)通过改写字典中keys名称,使其一一对应肯定就okay。
上代码:

class Net1(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net1,self).__init__()
        self.le1 = nn.Linear(n_input,n_hidden)
        self.le2 = nn.Linear(n_hidden,n_hidden)
        self.prele = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out =self.prele(out)
        return out

def load_from_state_dict(checkpoint,net):
    ori_keys = checkpoint.keys()         # 取出权重中key
    now_keys = net.state_dict().keys()   # 取出现在网络key
    # 将权重的key重命名为现在网络中key
    for ori_key,now_key in zip(list(ori_keys),list(now_keys)):
        checkpoint[now_key] = checkpoint.pop(ori_key)          # 更新字典中的键
    return checkpoint

net1 = Net1(1,2,1)
checkpoint = torch.load('net_parameter.pkl')   
checkpoint = load_from_state_dict(checkpoint,net1)
print('更改后键的名城:\n',checkpoint.keys())
net1.load_state_dict(checkpoint)                  # 加载模型

 (2)控制load_state_dict(checkpoint,strict)中strict参数。
 第二种方法相较于第一种方法更加简单,直接修改一个strict参数,另其为False即可:

class Net1(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net1,self).__init__()
        self.le1 = nn.Linear(n_input,n_hidden)
        self.le2 = nn.Linear(n_hidden,n_hidden)
        self.prele = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out =self.prele(out)
        return out
net1 = Net1(1,2,1)
checkpoint = torch.load('net_parameter.pkl')
net1.load_state_dict(checkpoint,strict=False)

 原因:在源码中:stirct参数控制权重是否严格匹配新网络中键。默认为True,即严格匹配。若不匹配,则往missing_keys列表中添加错误key,然后导致RuntimeError。若另其为False,则不执行这部分,而仅仅考虑是否 权重 shape 匹配的问题。

    elif strict:
        missing_keys.append(key)
if strict:
    for key in state_dict.keys():
        if key.startswith(prefix):
            input_name = key[len(prefix):]
            input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
            if input_name not in self._modules and input_name not in local_state:
                unexpected_keys.append(key)

六、网络不同的权重加载解析

 上部分考虑的是网络相同,但还是不够鲁棒。因为在实际任务中,往往预训练加载模型和实际模型是不同的。有可能只加载一部分权重,或者在新网络中添加新的结构。这个时候如何加载呢?
 (1)比如下面代码,我在原有网络基础上又添加了一层,加载可直接通过strict = False。

class Net1(nn.Module):


    def __init__(self,n_input,n_hidden,n_output):
        super(Net1,self).__init__()
        self.le1 = nn.Linear(n_input,n_hidden)
        self.le2 = nn.Linear(n_hidden,n_hidden)
        self.prele = nn.Linear(n_hidden,n_output)
        self.add_le = nn.Linear(n_output,n_output)      # 添加了一层
    def forward(self,input):
        out = self.le1(input)
        out = F.relu(out)
        out = self.le2(out)
        out = F.relu(out)
        out =self.prele(out)
        out = self.add_le(out)
        return out
net1 = Net1(1,2,1)
checkpoint = torch.load('net_parameter.pkl')
print('原始权重:',checkpoint)
net1.load_state_dict(checkpoint,strict=False)
print('加载后网络权重:',list(net1.named_parameters()))

总结!!!:在加载模型过程中,实际都是直接令strict=False.而load_state_dict函数会自动比较权重中key和自定义网络中key。若相等,则就加载key对应的权重值。若两个key不等,则不加载。
 其实深入考虑下:比如在alexnet解决10分类手写数字问题,若我现在想做个6分类任务,但我的net前半部分权重需要alexnet权重。因此,可以直接stirct=False,同时只要修改alexnet的分类层的key换个名字使其不要和alxenet中权重名字相同即可,核心代码如下:

pretrained_dict = torch.load('models/cifar10_statedict.pkl')
model_dict = model.state_dict()
print('随机初始化权重第一层:',model_dict['conv1.0.weight'])

# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {
    
    k: v for k, v in pretrained_dict.items() if k in model_dict}
print('预训练权重第一层:',pretrained_dict['conv1.0.weight'])
# 更新现有的model_dict
model_dict.update(pretrained_dict) #利用预训练模型的参数,更新模型
model.load_state_dict(model_dict)

 总之一句话:上述代码只是一种思路,懂得原理实际上自己可以任意写,愿意加载哪一层就加载哪一层,不要拘泥。

附录、其他常用的代码片段

(1)加载任意权重片段

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # override the _load_from_state_dict function
    # convert the backbone weights pre-trained in Mask R-CNN
    # use list(state_dict.keys()) to avoid
    # RuntimeError: OrderedDict mutated during iteration
    for key_name in list(state_dict.keys()):
        key_changed = True
        if key_name.startswith('backbone.'):
            new_key_name = f'img_backbone{key_name[8:]}'
        elif key_name.startswith('neck.'):
            new_key_name = f'img_neck{key_name[4:]}'
        elif key_name.startswith('rpn_head.'):
            new_key_name = f'img_rpn_head{key_name[8:]}'
        elif key_name.startswith('roi_head.'):
            new_key_name = f'img_roi_head{key_name[8:]}'
        else:
            key_changed = False
        if key_changed:
            logger = get_root_logger()
            print_log(
                f'{key_name} renamed to be {new_key_name}', logger=logger)
            state_dict[new_key_name] = state_dict.pop(key_name)
    super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys,
                                  error_msgs)

(2)多步长SGD训练:

#这里我设置了不同的epoch对应不同的学习率衰减,在10->20->30,学习率依次衰减为原来的0.1,即一个数量级
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)


#加载恢复
if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler
#保存
for epoch in range(start_epoch+1,80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()
    if epoch %10 ==0:
        print('epoch:',epoch)
        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
    
    
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

参考文献

https://zhuanlan.zhihu.com/p/133250753

猜你喜欢

转载自blog.csdn.net/wulele2/article/details/113393472