nn.Module源码介绍(二)--冻结参数篇

Module源码介绍

本篇是nn.Module源码系列文章介绍第二篇,建议读者在阅读第一篇之后,在来阅读本篇。当然,也可以直接阅读本文,因为提供了大量的实例。
第一篇地址



前置知识:Module中train/eval模块状态切换

  在上篇文章中,介绍了nn.Module是如何完成自定义网络的初始化的。比如现在我新建了一个如下的 conv+bn+conv 的简单网络。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.leconv1 = nn.Conv2d(1, 1, 1, 1, 0)
        self.lebn = nn.BatchNorm2d(1)
        self.leconv2 = nn.Conv2d(1, 1, 1, 1, 0)
    def forward(self,x):
        pass
if __name__ == '__main__':

    input = torch.ones(1,1,2,2)    # 伪造数据
    net = Net()
    for module in net.children():
        print('net包含的模块为:\n',module)
        for p in module.parameters():
            print('当前module需要学习的参数为:\n',p)

  为了方便,我把卷积核维度定义为1*1大小。通过运行上述代码可以发现:卷积核的参数仅有两个weight和bias,且其维度大小为1。BN层需要学习的参数也为两个:平移参数和形变参数。维度也为1。Okay,运行的结果图如下:
在这里插入图片描述
  从上图可以看出,总共需要学习6个参数,为啥是“要学习”?,因为每个参数后面均是 requires_grad=True。我们知道,模型有train状态和eval状态。简单来说就是训练时候让网络所有module(leconv1+lebn+leconv2)处于 train 状态,而测试时候让网络的所有module处于 eval 状态。那么nn.Module是如何区分这两种状态的呢?这里贴下nn.Module的源码:

def train(self: T, mode: bool = True) -> T:
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
def eval(self: T) -> T:
    return self.train(False)

  函数特别简单:即若是train状态下:让net中所有module指定为True;而在eval状态下,则直接给train传入False即可。这样就修改了模型的状态。

实战:随意进行train/eval状态切换

  上述介绍仅仅是介绍了将一个网络中所有module要么全部转成train,要么全部转成eval。比较死板。那么,若仅想让leconv1处于eval状态,而让lebn和leconv2处于train状态呢(这种方式经常遇到,尤其在迁移学习过程中)?
  比较简单:就是找到leconv1然后改变leconv1状态即可。这里主要复写一下train方法即可。上代码:

import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.leconv1 = nn.Conv2d(1, 1, 1, 1, 0)
        self.lebn = nn.BatchNorm2d(1)
        self.leconv2 = nn.Conv2d(1, 1, 1, 1, 0)
    def train(self, mode = True) :
        super(Net, self).train()
        for name, module in self.named_children():   # 遍历模块
            if name == 'leconv1':                    # 若是 leconv1
                module.eval()                        # 则直接让其进入eval状态。
      def forward(self,x):
        pass
if __name__ == '__main__':
    net = Net()
    net.train()

  Okay,到目前为止,你可以随意更改一个网络中任意一层。但是若网络特别深,动辄几百层。这样一层一层找,显然不现实。而且在实际网络中,往往需要冻结所有BN层(此处不做讨论,原因可以自行百度),且看第三部分。

实战:冻结网络中所有BN层

  此处冻结就是让所有BN层处于eval状态:上代码:

import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.leconv1 = nn.Conv2d(1, 1, 1, 1, 0)
        self.lebn = nn.BatchNorm2d(1)
        self.leconv2 = nn.Conv2d(1, 1, 1, 1, 0)
    def train(self, mode = True) :
        super(Net, self).train()
        for module in self.children():
            if isinstance(module,nn.BatchNorm2d): # 若当前module为nn.BatchNorm2d
                module.eval()                       # 指定eval状态
    def forward(self,x):
        pass
if __name__ == '__main__':
    input = torch.ones(1,1,2,2)    # 伪造数据
    net = Net()
    net.train()

  通过上述就能冻结一个net中所有BN层。

nn.Module中指定梯度和梯度清0函数

  该节介绍nn.Module梯度处理函数:requires_grad和zero_grad函数:
  先来看requires_grad_函数:

def requires_grad_(self: T, requires_grad: bool = True) -> T:
     for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self

  可以看出:循环网络中所有参数,然后递归调用requires_grad_函数,将所有参数的梯度设置为True。即这些参数需要更新梯度,需要进行学习。
  在来看看zero_grad_函数:

def zero_grad(self, set_to_none: bool = False) -> None:
    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

  主要借助最后一行代码,将梯度清0。

实战:冻结BN层梯度参数

  上节了解了冻结参数原理,现在假如冻结一个网络中所有BN层的梯度并将BN层内部参数均初始化为1。那么该如何写呢?

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.leconv1 = nn.Conv2d(1, 1, 1, 1, 0)
        self.lebn = nn.BatchNorm2d(1)
        self.leconv2 = nn.Conv2d(1, 1, 1, 1, 0)
    def train(self, mode = True) :
        super(Net, self).train()
        for module in self.children():
            if isinstance(module,nn.BatchNorm2d): # 若当前module为nn.BatchNorm2d
                for p in module.parameters():     # 遍历module中所有参数
                    p.data.fill_(1)               # 初始化为1
                    p.requires_grad_(False)       # 不更新梯度
    def forward(self,x):
        pass
if __name__ == '__main__':
    input = torch.ones(1,1,2,2)    # 伪造数据
    net = Net()
    net.train()
    for name, module in net.named_children():
        print('net包含的模块为:\n',module)
        for p in module.parameters():
            print('当前module需要学习的参数为:\n',p)

  现在在来看下输出结果:
在这里插入图片描述
 此时,初始化为1,且没了requires_grad这项,说明冻结参数成功。

总结

  读到这里读者可能有疑问:eval和requires_grad均能冻结参数。为啥需要两个呢?
  我感觉eval冻结的是module层面,而requires_grad可以直接冻结module里面的任意参数。一个宽泛点,一个更加精细点。在实际操作中,往往将二者混合使用(比如冻结resnet的第一阶段,同时冻结BN层)。
  下篇会介绍nn.Module中apply函数,用来初始化网络权重。后续还有hook的详解。

猜你喜欢

转载自blog.csdn.net/wulele2/article/details/112757387
今日推荐