1. RepVGG
Ding X, Zhang X, Ma N, et al. RepVGG: Making VGG-style ConvNets Great Again[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 13733-13742.
The core idea of this paper is structural re-parameterization: a multi-branch residual block is converted into an equivalent plain VGG-style block. The model is trained with the normal multi-branch residual structure and deployed with the equivalent single-branch VGG block. Because the two are mathematically equivalent, inference can become considerably faster with no loss of accuracy (the outputs differ only by floating-point rounding).
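As shown in the figure below, the left side is the training-time multi-branch structure: a 3x3 convolution followed by BN, a 1x1 convolution followed by BN, and, when the input and output shapes match, an identity branch containing only BN. The three outputs are summed and passed through ReLU. Our goal is to fuse the 3x3 branch, the 1x1 branch, the identity branch and all of their BN layers into a single 3x3 convolution.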
First, let's review the formula of BN:
$$y = \frac{x - E(x)}{\sqrt{D(x) + \xi}} \cdot \gamma + \beta$$
where E(x) and D(x) are the mean and variance of x, ξ is a small constant for numerical stability, and γ and β are the learned scale and shift. Let conv(x) denote the output of the convolution and substitute it into the formula above:
$$y = \frac{\mathrm{conv}(x) - E(\mathrm{conv}(x))}{\sqrt{D(\mathrm{conv}(x)) + \xi}} \cdot \gamma + \beta$$
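Looking closer, this expression is itself a convolution with a bias. At inference time, E(conv(x)) and D(conv(x)) are not computed from the current batch: BN uses the running mean μ and running variance σ² accumulated during training, so both are constants. Writing conv(x) = W * x + b, where W is the kernel and b the convolution bias, and working per output channel:
$$y = \frac{\gamma}{\sqrt{\sigma^2 + \xi}} (W * x) + \beta + \frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \xi}}$$
So the fused weight is $W' = \frac{\gamma}{\sqrt{\sigma^2 + \xi}} W$ and the fused bias is $b' = \beta + \frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \xi}}$. When the convolution has no bias (b = 0), this reduces to $b' = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \xi}}$.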
Through these steps, the convolution and its BN are fused into a single convolution with a bias. During training, the convolution and BN are kept as separate modules and the convolution is created with bias=False: BN's β already supplies a per-channel offset, so a convolution bias would be redundant.
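A minimal PyTorch check of the fusion formula. It assumes a convolution with a bias and a BN layer whose running statistics and affine parameters we randomize; the helper name `fuse_conv_bn` is ours, not from the paper. In `eval()` mode, the fused convolution should reproduce conv followed by BN up to floating-point rounding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Return (weight, bias) of one conv equivalent to bn(conv(x)) in eval mode."""
    std = torch.sqrt(bn.running_var + bn.eps)       # sqrt(sigma^2 + xi), per output channel
    scale = bn.weight / std                           # gamma / sqrt(sigma^2 + xi)
    weight = conv.weight * scale.view(-1, 1, 1, 1)    # scale each output channel's kernel
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = bn.bias + (conv_bias - bn.running_mean) * scale
    return weight.detach(), bias.detach()


torch.manual_seed(0)
conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(16)
# Give BN non-trivial running statistics and affine parameters.
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 2.0)
    bn.bias.uniform_(-1, 1)
bn.eval()  # use the running statistics, as at inference time

x = torch.randn(2, 8, 10, 10)
w, b = fuse_conv_bn(conv, bn)
with torch.no_grad():
    reference = bn(conv(x))
    fused = F.conv2d(x, w, b, padding=1)
max_err = (reference - fused).abs().max().item()
print(max_err)
```

The check uses a biased convolution on purpose: the (b − μ) term is only exercised when b is non-zero and γ differs from 1.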
Next, the 1x1 convolution must be replaced by an equivalent 3x3 convolution. The branch outputs are added together, so they must have the same shape; the 3x3 convolution therefore uses padding 1, while the 1x1 convolution uses padding 0 (and both use the same stride). Under these settings, we only need to zero-pad the 1x1 kernel to a 3x3 kernel, with the original weight at the centre. The result is identical to the 1x1 convolution, because at every output position the only non-zero tap of the padded kernel sits on the same input pixel that the 1x1 kernel reads.
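A small sketch of the padding argument, using random tensors as stand-ins for a trained 1x1 kernel. It compares a 1x1 convolution (padding 0) with the zero-padded 3x3 version (padding 1) for strides 1 and 2; both should give the same shape and, up to rounding, the same values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 9, 9)
w1 = torch.randn(6, 4, 1, 1)  # a 1x1 kernel: [out_channels, in_channels, 1, 1]
# F.pad pads the last two dims as [left, right, top, bottom], so the
# 1x1 weight ends up at the centre of a 3x3 kernel.
w3 = F.pad(w1, [1, 1, 1, 1])

errors = []
for stride in (1, 2):
    out_1x1 = F.conv2d(x, w1, stride=stride, padding=0)
    out_3x3 = F.conv2d(x, w3, stride=stride, padding=1)
    assert out_1x1.shape == out_3x3.shape
    errors.append((out_1x1 - out_3x3).abs().max().item())
print(errors)
```

The stride loop is there because the equivalence also needs both branches to share the stride, not just the padding rule.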
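Finally, the identity branch (the original input) is also easy to express as a 3x3 convolution; this requires in_channels == out_channels and stride 1. Construct an all-zero kernel of shape [out_channels, in_channels, 3, 3] (in_channels // groups for a grouped convolution) and, for each output channel i, set the centre point of input channel i to 1. This is really a special 1x1 convolution whose kernel is the identity matrix, padded to 3x3, so the identity branch's BN can then be fused exactly like a conv+BN branch.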
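A sketch of the identity kernel, including the grouped case. The function name `identity_kernel` is hypothetical; the indexing `i % input_dim` mirrors the RepBlock code in the next section. Convolving with this kernel (padding 1) should return the input unchanged for every group count that divides the channel count.

```python
import torch
import torch.nn.functional as F


def identity_kernel(channels: int, groups: int = 1) -> torch.Tensor:
    """3x3 kernel of shape [channels, channels // groups, 3, 3] that reproduces the input."""
    input_dim = channels // groups
    kernel = torch.zeros(channels, input_dim, 3, 3)
    for i in range(channels):
        # Output channel i reads input channel i, which is local index i % input_dim in its group.
        kernel[i, i % input_dim, 1, 1] = 1.0
    return kernel


torch.manual_seed(0)
x = torch.randn(2, 8, 7, 7)
errors = []
for groups in (1, 2, 8):
    k = identity_kernel(8, groups)
    y = F.conv2d(x, k, padding=1, groups=groups)
    errors.append((y - x).abs().max().item())
print(errors)
```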
Convolution is linear, so it distributes over addition. Let the fused 3x3 kernels of the three branches be A (3x3 branch), B (padded 1x1 branch) and C (identity branch), with biases $b_A$, $b_B$ and $b_C$, and let X be the input feature map. Then
$$A * X + b_A + B * X + b_B + C * X + b_C = (A + B + C) * X + (b_A + b_B + b_C)$$
so the three branches and their BN layers merge into one 3x3 convolution with kernel A + B + C and bias $b_A + b_B + b_C$. The ReLU is applied after the sum, so it is not affected by the merge.
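A quick numerical check of the distributive law above, with random tensors standing in for the already-fused kernels A, B, C and their biases (B is a padded 1x1 kernel and C the identity kernel, as described earlier). Summing the three branch outputs and running one convolution with the summed kernel and bias should agree up to rounding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 12, 12)
# Stand-ins for the fused branches (A, B, C in the text) and their biases.
A, b_A = torch.randn(4, 4, 3, 3), torch.randn(4)
B, b_B = F.pad(torch.randn(4, 4, 1, 1), [1, 1, 1, 1]), torch.randn(4)  # padded 1x1 kernel
C = torch.zeros(4, 4, 3, 3)
C[torch.arange(4), torch.arange(4), 1, 1] = 1.0  # identity kernel
b_C = torch.randn(4)

branches = (F.conv2d(x, A, b_A, padding=1)
            + F.conv2d(x, B, b_B, padding=1)
            + F.conv2d(x, C, b_C, padding=1))
merged = F.conv2d(x, A + B + C, b_A + b_B + b_C, padding=1)
err = (branches - merged).abs().max().item()
print(err)
```

Applying `torch.relu` to `branches` and to `merged` would give matching results as well, since the nonlinearity only sees the summed tensor.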
2. Code implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def _conv_bn(in_channels, out_channels, kernel_size=3, padding=1, stride=1, groups=1):
    # Convolution + BN. The conv bias is disabled because BN's beta already
    # provides a per-channel offset.
    res = nn.Sequential()
    res.add_module("conv", nn.Conv2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     padding_mode="zeros",
                                     groups=groups,
                                     bias=False))
    res.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
    return res


class RepBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, deploy=False):
        super(RepBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deploy = deploy
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.groups = groups
        self.activation = nn.ReLU()
        assert self.kernel_size == 3
        assert self.padding == 1
        if not self.deploy:
            # Training mode: the normal multi-branch structure
            self.brb_3x3 = _conv_bn(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=self.kernel_size,
                                    stride=stride,
                                    padding=self.padding,
                                    groups=groups)
            # Same stride as the 3x3 branch, otherwise the outputs cannot be added
            self.brb_1x1 = _conv_bn(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=1,
                                    stride=stride,
                                    padding=0,
                                    groups=groups)
            # The identity branch exists only when input and output shapes match
            self.brb_identity = (nn.BatchNorm2d(in_channels)
                                 if in_channels == out_channels and stride == 1 else None)
        else:
            # Inference mode: a single re-parameterized 3x3 convolution
            self.brb_rep = nn.Conv2d(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=self.kernel_size,
                                     padding=self.padding,
                                     stride=stride,
                                     groups=groups,
                                     bias=True)

    def forward(self, inputs):
        if self.deploy:
            # Inference mode
            return self.activation(self.brb_rep(inputs))
        identity_out = 0 if self.brb_identity is None else self.brb_identity(inputs)
        return self.activation(self.brb_1x1(inputs) + self.brb_3x3(inputs) + identity_out)

    def _switch_to_deploy(self):
        if self.deploy:
            return  # already converted
        kernel, bias = self._get_equivalent_kernel_bias()
        conv = self.brb_3x3.conv
        self.brb_rep = nn.Conv2d(in_channels=conv.in_channels, out_channels=conv.out_channels,
                                 kernel_size=conv.kernel_size, padding=conv.padding,
                                 padding_mode=conv.padding_mode, stride=conv.stride,
                                 groups=conv.groups, bias=True)
        self.brb_rep.weight.data = kernel
        self.brb_rep.bias.data = bias
        for para in self.parameters():
            para.detach_()
        # Delete the branches that are no longer needed
        self.__delattr__('brb_3x3')
        self.__delattr__('brb_1x1')
        self.__delattr__('brb_identity')
        self.deploy = True

    def _pad_1x1_kernel(self, kernel):
        # Zero-pad a 1x1 kernel to 3x3 (the weight ends up at the centre)
        if kernel is None:
            return 0
        return F.pad(kernel, [1, 1, 1, 1])

    # Fuse the identity, 1x1 and 3x3 branches into the parameters of one 3x3 convolution
    def _get_equivalent_kernel_bias(self):
        brb_3x3_weight, brb_3x3_bias = self._fuse_conv_bn(self.brb_3x3)
        brb_1x1_weight, brb_1x1_bias = self._fuse_conv_bn(self.brb_1x1)
        brb_id_weight, brb_id_bias = self._fuse_conv_bn(self.brb_identity)
        kernel = brb_3x3_weight + self._pad_1x1_kernel(brb_1x1_weight) + brb_id_weight
        bias = brb_3x3_bias + brb_1x1_bias + brb_id_bias
        return kernel, bias

    # Fuse a convolution and its BN into one kernel and bias
    def _fuse_conv_bn(self, branch):
        if branch is None:
            return 0, 0
        bias = 0
        if isinstance(branch, nn.Sequential):
            # A conv + BN block
            kernel = branch.conv.weight            # [out_channels, in_channels // groups, kH, kW]
            if branch.conv.bias is not None:
                bias = branch.conv.bias            # [out_channels]
            running_mean = branch.bn.running_mean  # [out_channels]
            running_var = branch.bn.running_var    # [out_channels]
            gamma = branch.bn.weight               # [out_channels]
            beta = branch.bn.bias                  # [out_channels]
            eps = branch.bn.eps                    # scalar
        else:
            # The identity branch (BN only): write the identity as a 3x3 kernel first
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.out_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.out_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).view(-1, 1, 1, 1)  # reshape to 4-D so it broadcasts over each output channel's kernel
        return kernel * t, beta + (bias - running_mean) * gamma / std


x = torch.randn(3, 3, 49, 49)
repblock = RepBlock(3, 128)
repblock.eval()  # BN must use its running statistics for the fusion to be exact
out = repblock(x)
repblock._switch_to_deploy()
out2 = repblock(x)
print('difference between vgg and repvgg')
print(((out2 - out) ** 2).sum())  # close to 0: only floating-point rounding remains

# Harder case: identity branch present and non-trivial BN statistics
block = RepBlock(16, 16)
with torch.no_grad():
    for m in block.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.running_mean.uniform_(-1, 1)
            m.running_var.uniform_(0.5, 2.0)
            m.weight.uniform_(0.5, 2.0)
            m.bias.uniform_(-1, 1)
block.eval()
x = torch.randn(2, 16, 20, 20)
ref = block(x)
block._switch_to_deploy()
print(((block(x) - ref) ** 2).sum())
```
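In a full network, every block is typically converted once after training. Below is a minimal sketch of such a helper; the name `convert_to_deploy` is hypothetical, and it assumes each re-parameterizable block exposes a `_switch_to_deploy()` method like RepBlock above. It works on a deep copy so the trained multi-branch model stays intact.

```python
import copy

import torch.nn as nn


def convert_to_deploy(model: nn.Module) -> nn.Module:
    """Return a copy of `model` in which every submodule that defines
    `_switch_to_deploy` (assumed convention, as in RepBlock) is re-parameterized."""
    deployed = copy.deepcopy(model)
    # Snapshot the module list first: converting a block adds and deletes submodules.
    for module in list(deployed.modules()):
        if hasattr(module, '_switch_to_deploy'):
            module._switch_to_deploy()
    return deployed.eval()
```

For example, `deploy_model = convert_to_deploy(trained_model)` leaves `trained_model` usable for further training while `deploy_model` runs the single-branch 3x3 convolutions.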