深度学习知识六:(模型量化压缩)----pytorch自定义Module,并通过其理解DoReFaNet网络定义方法。

参考中文官方,详情参考:PyTorch 如何自定义 Module

1.自定义Module

在这里插入图片描述
Module 是 pytorch 组织神经网络的基本方式。Module 包含了模型的参数以及计算逻辑。Function 承载了实际的功能,定义了前向和后向的计算逻辑。
下面以最简单的 MLP 网络结构为例,介绍下如何实现自定义网络结构。完整代码可以参见repo

1.1 Function

Function 是 pytorch 自动求导机制的核心类。Function 是无参数或者说无状态的,它只负责接收输入,返回相应的输出;对于反向,它接收输出相应的梯度,返回输入相应的梯度。

这里我们只关注如何自定义 Function。Function 的定义见源码。下面是简化的代码段:

class Function(object):
    def forward(self, *input):
        raise NotImplementedError
 
    def backward(self, *grad_output):
        raise NotImplementedError

forward 和 backward 的输入和输出都是 Tensor 对象。

Function 对象是 callable 的,即可以通过()的方式进行调用。其中调用的输入和输出都为 Variable 对象。下面的代码示例了如何实现一个 ReLU 激活函数并进行调用:


import torch
from torch.autograd import Function
 
class ReLUF(Function):
    def forward(self, input):
        self.save_for_backward(input)
 
        output = input.clamp(min=0)
        return output
 
    def backward(self, output_grad):
        input = self.to_save[0]
 
        input_grad = output_grad.clone()
        input_grad[input < 0] = 0
        return input_grad
 
## Test
if __name__ == "__main__":
      from torch.autograd import Variable
 
      torch.manual_seed(1111)  
      a = torch.randn(2, 3)
 
      va = Variable(a, requires_grad=True)
      vb = ReLUF()(va)
      print va.data, vb.data
 
      vb.backward(torch.ones(va.size()))
      print vb.grad.data, va.grad.data

如果 backward 中需要用到 forward 的输入,需要在 forward 中显式的保存需要的输入。在上面的代码中,forward 利用self.save_for_backward函数,将输入暂时保存,并在 backward 中利用saved_tensors (python tuple 对象) 取出。

显然,forward 的输入应该和 backward 的输入相对应;同时,forward 的输出应该和 backward 的输入相匹配。
注意:
由于 Function 可能需要暂存 input tensor,因此,建议不复用 Function 对象,以避免遇到内存提前释放的问题。
如示例代码所示,forward的每次调用都重新生成一个 ReLUF 对象,而不能在初始化时生成在 forward 中反复调用。(意思示例代码是对的,其是每个fc都生成一个新对象。因为每个对象都自己在管理自己的的权重跟输入

2 Module

类似于 Function,Module 对象也是 callable 是,输入和输出也是 Variable。不同的是,Module 是[可以]有参数的。Module 包含两个主要部分:参数及计算逻辑(Function 调用)。由于ReLU激活函数没有参数,这里我们以最基本的全连接层为例来说明如何自定义Module。
全连接层的运算逻辑定义如下 Function:

import torch
from torch.autograd import Function
 
class LinearF(Function):
 
     def forward(self, input, weight, bias=None):
         self.save_for_backward(input, weight, bias)
 
         output = torch.mm(input, weight.t())
         if bias is not None:
             output += bias.unsqueeze(0).expand_as(output)
 
         return output
 
     def backward(self, grad_output):
         input, weight, bias = self.saved_tensors
 
         grad_input = grad_weight = grad_bias = None
         if self.needs_input_grad[0]:
             grad_input = torch.mm(grad_output, weight)
         if self.needs_input_grad[1]:
             grad_weight = torch.mm(grad_output.t(), input)
         if bias is not None and self.needs_input_grad[2]:
             grad_bias = grad_output.sum(0).squeeze(0)
 
         if bias is not None:
             return grad_input, grad_weight, grad_bias
         else:
             return grad_input, grad_weight

needs_input_grad 为一个元素为 bool 型的 tuple,长度与 forward 的参数数量相同,用来标识各个输入是否输入计算梯度;对于无需梯度的输入,可以减少不必要的计算。

Function(此处为 LinearF) 定义了基本的计算逻辑,Module 只需要在初始化时为参数分配内存空间,并在计算时,将参数传递给相应的 Function 对象。代码如下:

import torch
import torch.nn as nn
 
class Linear(nn.Module):
 
    def __init__(self, in_features, out_features, bias=True):
         super(Linear, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
         if bias:
             self.bias = nn.Parameter(torch.Tensor(out_features))
         else:
            self.register_parameter('bias', None)
 
    def forward(self, input):
         return LinearF()(input, self.weight, self.bias)

需要注意的是,参数是内存空间由 tensor 对象维护,但 tensor 需要包装为一个Parameter 对象。Parameter 是 Variable 的特殊子类,仅有是不同是 Parameter 默认requires_grad为 True。Varaible 是自动求导机制的核心类,此处暂不介绍,参见教程

2.3自定义循环神经网络(RNN)

可运行代码参考:
RNN
其中Parameters是Variable的一个子类,而且其是自动求导机制,所以我们在定义代码网络的时候基本不需要从在backwards,只要定义好forward就可以了

3 定义DoReFaNet

由于其压缩的是权重,所以我们只需要定义一个卷积类,激活函数类,然后去获取他里面的权重,进行量化,定义网络如下:

class AlexNet_Q(nn.Module):
  def __init__(self, wbit, abit, num_classes=1000):
    super(AlexNet_Q, self).__init__()
    Conv2d = conv2d_Q_fn(w_bit=wbit)
    Linear = linear_Q_fn(w_bit=wbit)

    self.features = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
      nn.BatchNorm2d(96),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      Conv2d(96, 256, kernel_size=5, padding=2),
      nn.BatchNorm2d(256),
      nn.ReLU(inplace=True),
      activation_quantize_fn(a_bit=abit),
      nn.MaxPool2d(kernel_size=3, stride=2),

      Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      activation_quantize_fn(a_bit=abit),

      Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      activation_quantize_fn(a_bit=abit),

      Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      activation_quantize_fn(a_bit=abit),
      nn.MaxPool2d(kernel_size=3, stride=2),
    )
    self.classifier = nn.Sequential(
      Linear(256 * 6 * 6, 4096),
      nn.ReLU(inplace=True),
      activation_quantize_fn(a_bit=abit),

      Linear(4096, 4096),
      nn.ReLU(inplace=True),
      activation_quantize_fn(a_bit=abit),
      nn.Linear(4096, num_classes),
    )

    for m in self.modules():
      if isinstance(m, Conv2d) or isinstance(m, Linear):
        init.xavier_normal_(m.weight.data)

  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), 256 * 6 * 6)
    x = self.classifier(x)
    return 
  • 其中自定义权重量化如下:

class weight_quantize_fn(nn.Module):
  def __init__(self, w_bit):
    super(weight_quantize_fn, self).__init__()
    assert w_bit <= 8 or w_bit == 32
    self.w_bit = w_bit
    self.uniform_q = uniform_quantize(k=w_bit)

  def forward(self, x):
    if self.w_bit == 32:
      weight_q = x
    elif self.w_bit == 1:
      E = torch.mean(torch.abs(x)).detach()
      weight_q = self.uniform_q(x / E) * E
    else:
      weight = torch.tanh(x)
      weight = weight / 2 / torch.max(torch.abs(weight)) + 0.5
      weight_q = 2 * self.uniform_q(weight) - 1
    return weight_q

def conv2d_Q_fn(w_bit):
  class Conv2d_Q(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
      super(Conv2d_Q, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)
      self.w_bit = w_bit
      self.quantize_fn = weight_quantize_fn(w_bit=w_bit)

    def forward(self, input, order=None):
      weight_q = self.quantize_fn(self.weight)
      # print(np.unique(weight_q.detach().numpy()))
      return F.conv2d(input, weight_q, self.bias, self.stride,
                      self.padding, self.dilation, self.groups)

  return Conv2d_Q

其中uniform_quantize进行比特量化函数为:

def uniform_quantize(k):
  class qfn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
      if k == 32:
        out = input
      elif k == 1:
        out = torch.sign(input)
      else:
        n = float(2 ** k - 1)
        out = torch.round(input * n) / n
      return out

    @staticmethod
    def backward(ctx, grad_output):
      grad_input = grad_output.clone()
      return grad_input

  return qfn().apply

  • 激活层激活后进行输出的比特量化为:
class activation_quantize_fn(nn.Module):
  def __init__(self, a_bit):
    super(activation_quantize_fn, self).__init__()
    assert a_bit <= 8 or a_bit == 32
    self.a_bit = a_bit
    self.uniform_q = uniform_quantize(k=a_bit)

  def forward(self, x):
    if self.a_bit == 32:
      activation_q = x
    else:
      activation_q = self.uniform_q(torch.clamp(x, 0, 1))
      # print(np.unique(activation_q.detach().numpy()))
    return activation_q

4、 进行模型通道压缩,以yolov3代码为例子:

通道压缩主要通过编排BN的aX+b=ouput的权重a的分布。训练完后排序这个权重a,根据裁剪比例进行获取阈值,然后把小于阈值的通道进行排除、大于阈值的保留。

  • 获取包含有BN层的网络层数索引:
def parse_module_defs(module_defs):
    CBL_idx = []
    Conv_idx = []
    for i, module_def in enumerate(module_defs):
        # 具有bn层的,才进行保存索引,然后用来压缩训练。
        if module_def['type'] == 'convolutional':
            if module_def['batch_normalize'] == '1':
                CBL_idx.append(i)
            else:
                Conv_idx.append(i)
    #排除一些不要进行压缩的层,为了保证准确度。
    ignore_idx = set()
    for i, module_def in enumerate(module_defs):
        if module_def['type'] == 'shortcut':
            ignore_idx.add(i-1)
            identity_idx = (i + int(module_def['from']))
            if module_defs[identity_idx]['type'] == 'convolutional':
                ignore_idx.add(identity_idx)
            elif module_defs[identity_idx]['type'] == 'shortcut':
                ignore_idx.add(identity_idx - 1)

    ignore_idx.add(84)
    ignore_idx.add(96)
    #获取符合要求跟需要压缩的网络层索引。
    prune_idx = [idx for idx in CBL_idx if idx not in ignore_idx]

    return CBL_idx, Conv_idx, prune_idx

  • 进行训练的代码:
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            imgs = imgs.to(device)
            targets = targets.to(device)

            loss, outputs = model(imgs, targets)

            optimizer.zero_grad()
            loss.backward()
            #进行模型权重的稀疏化训练,opt.s可以控制稀疏值,prune_idx是包含需要进行稀疏训练的
            #卷积层索引。
            BNOptimizer.updateBN(sr_flag, model.module_list, opt.s, prune_idx)

            optimizer.step()
  • 其中进行稀疏训练的主要函数,原理:通过外加opt.s值来使得权重越重要的可以获取到一个正梯度外加值。代码函数如下:
class BNOptimizer():

    @staticmethod
    def updateBN(sr_flag, module_list, s, prune_idx):
        if sr_flag:
            for idx in prune_idx:
                # Squential(Conv, BN, Lrelu)
                bn_module = module_list[idx][1]
                #torch.sign()是一个跳变函数,其值只有-1,0,-1.
                #这里的s指的是附加到权重梯度grad上的值,其可以用来加大或减少梯度值,
                #从而在更新权重的时候,权值大的可以变的越来越大。
                bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data))  # L1

  • 在整个模型训练完后,就可以通过自己设置的prune比例进行删减不必要的网络通道,其函数如下:
percent = 0.85
def prune_and_eval(model, sorted_bn, percent=.0):
    model_copy = deepcopy(model)
    #根据裁剪比例进行获取停止的通道索引
    thre_index = int(len(sorted_bn) * percent)
    #任何根据这个停止的通道索引获取其对应的通道权重值,
    #然后作为刷选阈值。
    thre = sorted_bn[thre_index]
    print(f'Channels with Gamma value less than {thre:.4f} are pruned!')
    remain_num = 0
    for idx in prune_idx:
        bn_module = model_copy.module_list[idx][1]
        #通过阈值获取刷选通道
        mask = obtain_bn_mask(bn_module, thre)
        remain_num += int(mask.sum())
        #根据mask的0\1进行获取保留了通道的
        bn_module.weight.data.mul_(mask)
    mAP = eval_model(model_copy)[2].mean()

    print(f'Number of channels has been reduced from {len(sorted_bn)} to {remain_num}')
    print(f'Prune ratio: {1-remain_num/len(sorted_bn):.3f}')
    print(f'mAP of the pruned model is {mAP:.4f}')

    return thre

猜你喜欢

转载自blog.csdn.net/yangdashi888/article/details/104986121