Paper reading notes: RepVGG

1. RepVGG

Ding X, Zhang X, Ma N, et al. RepVGG: Making VGG-style ConvNets Great Again[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 13733-13742.

The core idea of this paper is to convert the complicated multi-branch (residual-style) block into an equivalent, simple VGG-style block. During training the normal multi-branch structure is used; at inference time it is replaced by the equivalent single 3x3 convolution, which greatly speeds up inference without any loss in accuracy.
[Figure: the multi-branch training-time block (left) and its fused single 3x3 convolution form]
As shown in the figure, the left side is the multi-branch structure. Our goal is to fuse the 3x3 convolution, the 1x1 convolution, the identity (original input) branch, and their BN layers into a single 3x3 convolution operation.
First, let's review the formula of BN:
$$y = \frac{x - E(x)}{\sqrt{D(x) + \epsilon}} \cdot \gamma + \beta$$

Let the output of the convolution be $conv(x)$ and substitute it into the formula above:

$$y = \frac{conv(x) - E(conv(x))}{\sqrt{D(conv(x)) + \epsilon}} \cdot \gamma + \beta$$

Looking more closely, this expression can be viewed as a single convolution with a bias: the new weight is the original convolution weight scaled by $\frac{\gamma}{\sqrt{D(conv(x)) + \epsilon}}$ (per output channel), and the bias is $\beta - \frac{\gamma \cdot E(conv(x))}{\sqrt{D(conv(x)) + \epsilon}}$.
Through the above operations, we have fused the convolution and BN into a single convolution with a bias. Before fusion, the convolution and BN were separate layers; when the convolution before BN is created without a bias (bias=False), the derivation above applies directly, and if a bias is present it can simply be folded into the fused bias as well.
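To make the fusion concrete, here is a minimal standalone sketch (the helper name fuse_single_conv_bn is purely illustrative and is not from the paper's code); it folds a BatchNorm2d into the preceding bias-free Conv2d and checks that the fused convolution reproduces the original output:

import torch
import torch.nn as nn

def fuse_single_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    # fold the BN statistics and affine parameters into the conv weight and bias
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)   # gamma / sqrt(var + eps), per output channel
    fused_weight = conv.weight * t
    fused_bias = bn.bias - bn.running_mean * bn.weight / std
    return fused_weight, fused_bias

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
_ = bn(conv(torch.randn(8, 3, 16, 16)))   # one forward pass in train mode to populate running stats
bn.eval()                                 # use the running statistics from now on

x = torch.randn(1, 3, 16, 16)
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
w, b = fuse_single_conv_bn(conv, bn)
fused.weight.data, fused.bias.data = w, b
print((bn(conv(x)) - fused(x)).abs().max())   # on the order of 1e-6

The same folding, extended to handle an optional conv bias and the identity branch, appears in _fuse_conv_bn in the code below.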

Next, the 1x1 convolution must be replaced by an equivalent 3x3 convolution. When using a residual connection, the input and output must have the same shape so that they can be added, which means the 3x3 branch uses padding of 1 on the input feature map. Therefore, we only need to zero-pad the 1x1 kernel into a 3x3 kernel (with the original value at the centre); convolving with this padded kernel and padding 1 gives exactly the same result as the original 1x1 convolution, as the quick check below shows.
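A quick standalone check of this claim (illustrative tensors only):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 7, 7)
k1 = torch.randn(8, 4, 1, 1)            # a 1x1 kernel
k3 = F.pad(k1, [1, 1, 1, 1])            # zero-pad the last two dims -> 3x3 kernel, value at the centre

out_1x1 = F.conv2d(x, k1)               # 1x1 conv, padding defaults to 0
out_3x3 = F.conv2d(x, k3, padding=1)    # padded kernel with padding 1
print((out_1x1 - out_3x3).abs().max())  # ~0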

Finally, converting the identity branch (the original input) into a 3x3 convolution is also very simple: construct an all-zero kernel of shape [out_channels, in_channels, 3, 3] (here in_channels equals out_channels) and, for each output channel, set the centre element of the matching input channel to 1. This can in fact be regarded as a special 1x1 convolution; a small check of the construction follows.
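A small standalone check of this construction for the simple groups=1 case (the grouped case is handled by i % input_dim in the code below):

import torch
import torch.nn.functional as F

channels = 4
id_kernel = torch.zeros(channels, channels, 3, 3)
for i in range(channels):
    id_kernel[i, i, 1, 1] = 1.0          # 1 at the centre of the matching input channel

x = torch.randn(1, channels, 7, 7)
print((F.conv2d(x, id_kernel, padding=1) - x).abs().max())   # ~0: the conv reproduces the input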

Convolution is linear, so it distributes over addition. Let the 3x3 kernel be $A$, the (padded) 1x1 kernel be $B$, the identity kernel be $C$, and the input feature map be $X$; denoting biases with the subscript $b$, we have $$AX + A_b + BX + B_b + CX + C_b = (A + B + C)X + (A_b + B_b + C_b).$$ In this way the three convolutions and their BN layers are successfully merged into a single 3x3 convolution, as sketched below.
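A tiny illustration of this additivity (the tensors here are arbitrary, not the block's real parameters): running three 3x3 convolutions and summing their outputs matches a single convolution whose kernel and bias are the sums.

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 7, 7)
A, B, C = (torch.randn(8, 4, 3, 3) for _ in range(3))
a_b, b_b, c_b = (torch.randn(8) for _ in range(3))

separate = (F.conv2d(x, A, a_b, padding=1)
            + F.conv2d(x, B, b_b, padding=1)
            + F.conv2d(x, C, c_b, padding=1))
merged = F.conv2d(x, A + B + C, a_b + b_b + c_b, padding=1)
print((separate - merged).abs().max())   # ~1e-6, floating-point error only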

2. Code implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def _conv_bn(in_channels, out_channels, kernel_size=3, padding=1, stride=1, groups=1):
    # convolution followed by BN
    res = nn.Sequential()
    res.add_module("conv", nn.Conv2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     padding_mode="zeros",
                                     groups=groups,
                                     bias=True))
    res.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
    return res
 
class RepBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, deploy=False):
        super(RepBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deploy = deploy
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.groups = groups
        self.activation = nn.ReLU()

        assert self.kernel_size == 3
        assert self.padding == 1

        if not self.deploy:
            # training mode: the normal multi-branch structure
            self.brb_3x3 = _conv_bn(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=self.kernel_size,
                                    stride=stride,
                                    padding=self.padding,
                                    groups=groups)
            self.brb_1x1 = _conv_bn(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    groups=groups)
            self.brb_identity = nn.BatchNorm2d(self.in_channels) if self.in_channels == self.out_channels else None
        else:
            # inference mode: a single re-parameterized convolution
            self.brb_rep = nn.Conv2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=self.kernel_size,
                                     padding=self.padding,
                                     stride=stride,
                                     groups=groups,
                                     bias=True)

    def forward(self, inputs):
        if self.deploy:
            # inference mode
            return self.activation(self.brb_rep(inputs))

        if self.brb_identity is None:
            identity_out = 0
        else:
            identity_out = self.brb_identity(inputs)

        return self.activation(self.brb_1x1(inputs) + self.brb_3x3(inputs) + identity_out)

    def _switch_to_deploy(self):
        self.deploy = True
        kernel, bias = self._get_equivalent_kernel_bias()
        self.brb_rep = nn.Conv2d(in_channels=self.brb_3x3.conv.in_channels,
                                 out_channels=self.brb_3x3.conv.out_channels,
                                 kernel_size=self.brb_3x3.conv.kernel_size,
                                 padding=self.brb_3x3.conv.padding,
                                 padding_mode=self.brb_3x3.conv.padding_mode,
                                 stride=self.brb_3x3.conv.stride,
                                 groups=self.brb_3x3.conv.groups,
                                 bias=True)
        self.brb_rep.weight.data = kernel
        self.brb_rep.bias.data = bias
        for para in self.parameters():
            para.detach_()
        # remove the now-unused branches
        self.__delattr__('brb_3x3')
        self.__delattr__('brb_1x1')
        self.__delattr__('brb_identity')

    def _pad_1x1_kernel(self, kernel):
        # zero-pad a 1x1 kernel to a 3x3 kernel (the original value stays at the centre)
        if kernel is None:
            return 0
        else:
            return F.pad(kernel, [1] * 4)

    # fuse the identity, 1x1 and 3x3 branches into the parameters of a single 3x3 convolution
    def _get_equivalent_kernel_bias(self):
        brb_3x3_weight, brb_3x3_bias = self._fuse_conv_bn(self.brb_3x3)
        brb_1x1_weight, brb_1x1_bias = self._fuse_conv_bn(self.brb_1x1)
        brb_id_weight, brb_id_bias = self._fuse_conv_bn(self.brb_identity)
        return (brb_3x3_weight + self._pad_1x1_kernel(brb_1x1_weight) + brb_id_weight,
                brb_3x3_bias + brb_1x1_bias + brb_id_bias)

    ### fuse the parameters of a convolution and its BN layer into a single weight and bias
    def _fuse_conv_bn(self, branch):
        bias = torch.tensor(0, dtype=torch.float32)
        if branch is None:
            return 0, 0
        elif isinstance(branch, nn.Sequential):
            # the branch is a conv+BN block
            kernel = branch.conv.weight  # [out_channels, in_channels, kernel_H, kernel_W]
            if branch.conv.bias is not None:
                bias = branch.conv.bias  # [out_channels]
            running_mean = branch.bn.running_mean  # [out_channels]
            running_var = branch.bn.running_var  # [out_channels]
            gamma = branch.bn.weight  # [out_channels]
            beta = branch.bn.bias  # [out_channels]
            eps = branch.bn.eps  # scalar
        else:
            # the branch is the identity (BN only): build an equivalent 3x3 identity kernel
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.out_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.out_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps

        std = (running_var + eps).sqrt()
        t = (gamma / std).view(-1, 1, 1, 1)  # reshape to 4D so it broadcasts over the kernel
        return kernel * t, beta + (bias - running_mean) * gamma / std

# quick sanity check: outputs before and after re-parameterization should match
input = torch.randn(3, 3, 49, 49)
repblock = RepBlock(3, 128)
repblock.eval()                  # eval mode so BN uses its running statistics
out = repblock(input)
repblock._switch_to_deploy()
out2 = repblock(input)
print('difference between vgg and repvgg')
print(((out2 - out) ** 2).sum())  # should be close to 0 (floating-point error only)

Source: blog.csdn.net/loki2018/article/details/127460571