MMDetection3D code study notes: the role of fuse-conv-bn

The fuse-conv-bn argument is defined in mmdetection3d as follows:

parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')

From the help description, we can see that the main purpose of fuse-conv-bn is to increase the model's inference speed.

Why fuse-conv-bn speeds up model inference

Reason: the basic building block of current CNNs is the Conv + BN + ReLU trio, which has become almost standard. In the inference stage, however, the BN layer's computation can be folded into the Conv layer: the convolution weights and bias are rewritten so that the BN computation disappears entirely, without increasing the cost of the Conv layer itself. The formula is derived as follows.
In inference mode BN uses its fixed running statistics, so for a convolution output z = W ∗ x + b the BN layer computes

y = γ · (z − μ) / √(σ² + ε) + β

Substituting z gives

y = [γ / √(σ² + ε)] · W ∗ x + [γ / √(σ² + ε)] · (b − μ) + β

so a single convolution with

W_fused = [γ / √(σ² + ε)] · W,    b_fused = [γ / √(σ² + ε)] · (b − μ) + β

produces exactly the same output. Here is a code implementation:

import torch


def fuse_conv_and_bn(conv, bn):
    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    with torch.no_grad():
        # init: the fused conv always needs a bias term
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation,
                                    groups=conv.groups,
                                    bias=True)

        # prepare filters: scale each output channel of W by
        # gamma / sqrt(running_var + eps)
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias; the original conv bias (if any) must be
        # scaled by the same per-channel factor before b_bn is added
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mv(w_bn, b_conv) + b_bn)

        return fusedconv
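A minimal self-contained sanity check of the fusion (assuming PyTorch is available; all layer sizes here are arbitrary). The BN module must be in eval mode so its running statistics are used, and the per-channel scaling is written with broadcasting, which is equivalent to the diagonal-matrix multiply above:

```python
import torch

torch.manual_seed(0)

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = torch.nn.BatchNorm2d(8)
bn.eval()  # inference mode: BN uses running_mean / running_var

with torch.no_grad():
    # give BN non-trivial statistics so the test is meaningful
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)

    # fold BN into the conv: scale W per output channel, build the new bias
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    fused.bias.copy_(bn.bias - scale * bn.running_mean)

    x = torch.randn(2, 3, 16, 16)
    out_ref = bn(conv(x))
    out_fused = fused(x)

print(torch.allclose(out_ref, out_fused, atol=1e-5))  # True
```

The outputs match to floating-point precision, confirming that one convolution replaces the Conv + BN pair at inference time.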

A few notes on the code:

  • For the weight W, fusion only multiplies each output channel by a coefficient. The source code therefore flattens W into a 2-D matrix of row vectors (one row per out_channel) and left-multiplies it by a diagonal matrix whose entries are the per-channel coefficients; the result is reshaped back to the original 4-D size. The flattening is needed because .mm only operates on 2-D matrices; a 4-D tensor cannot be multiplied directly.

  • As for the bias: strictly, the derivation requires the original conv bias b to be scaled by γ/√(running_var + ε) before b_bn is added. Some reference implementations skip this scaling (effectively treating the coefficient as 1), which is only exact when the conv layer has no bias, the common case for a conv followed by BN, since BN would cancel a conv bias anyway.
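Both notes can be verified numerically; a small sketch with arbitrary shapes. The first part shows the diagonal-matrix multiply is just per-channel scaling; the second shows that when the conv does carry a bias, dropping the γ/√(σ² + ε) factor on it breaks the fusion:

```python
import torch

torch.manual_seed(0)

# --- note 1: diag-matrix multiply == per-channel broadcast scaling ---
w_conv = torch.randn(4, 3, 3, 3)
scale = torch.rand(4) + 0.5                      # stands in for gamma / sqrt(var + eps)
w2d = w_conv.view(4, -1)                         # .mm needs 2-D operands
fused_mm = torch.mm(torch.diag(scale), w2d).view_as(w_conv)
fused_bc = w_conv * scale.view(-1, 1, 1, 1)      # broadcasting, no matrix needed
print(torch.allclose(fused_mm, fused_bc))        # True

# --- note 2: the conv bias must also be scaled when it exists ---
conv = torch.nn.Conv2d(2, 4, kernel_size=1, bias=True)   # conv WITH a bias
bn = torch.nn.BatchNorm2d(4).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)

    s = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    b_bn = bn.bias - s * bn.running_mean
    w = conv.weight * s.view(-1, 1, 1, 1)

    x = torch.randn(1, 2, 4, 4)
    ref = bn(conv(x))
    exact = torch.nn.functional.conv2d(x, w, bias=s * conv.bias + b_bn)
    sloppy = torch.nn.functional.conv2d(x, w, bias=conv.bias + b_bn)

print(torch.allclose(ref, exact, atol=1e-5))     # True: bias scaled by s
print(torch.allclose(ref, sloppy, atol=1e-5))    # False: scaling dropped
```

With bias=False on the conv (the usual configuration before BN), the two bias formulas coincide and the shortcut is harmless.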

Origin blog.csdn.net/m0_45388819/article/details/109907805