Notas del estudio de código MMDetection3D: el papel de fuse-conv-bn

Notas del estudio de código MMDetection3D: el papel de fuse-conv-bn

El código de ajuste de parámetros de fuse-conv-bn en mmdetection3d

parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')

Se puede ver en la descripción en la ayuda que la función principal de fuse-conv-bn es aumentar la velocidad de inferencia del modelo (aumentar la velocidad de inferencia)

Por qué fuse-conv-bn puede acelerar la velocidad de razonamiento del modelo

Razón: Los componentes básicos de la capa convolucional actual de CNN son: Conv + BN + ReLu Three Musketeers, que casi se ha convertido en estándar. Pero, de hecho, en la etapa de razonamiento de la red, los cálculos de la capa BN se pueden fusionar en la capa Conv para reducir la cantidad de cálculos y acelerar el razonamiento. Esencialmente, los parámetros del kernel de convolución se modifican y la cantidad de cálculo de la capa BN se omite sin aumentar la cantidad de cálculo de la capa Conv. La fórmula se deriva de la siguiente manera.
Inserte la descripción de la imagen aquíInserte la descripción de la imagen aquíAdjunte una implementación de código:

def fuse_conv_and_bn(conv, bn):
    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    with torch.no_grad():
        # init
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    bias=True)

        # prepare filters
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0))
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(b_conv + b_bn)

        return fusedconv

Aquí hay una pequeña explicación del contenido del código:

  • Para el cálculo de W, de hecho, solo es necesario multiplicar un coeficiente sobre la base del W original. Por lo tanto, en el código fuente, W se estira en una matriz de vectores de fila (cada vector de fila corresponde a un out_channel) , y se empareja con el coeficiente de posición correspondiente como elemento. Multiplique las matrices de las esquinas para obtener una matriz compuesta de nuevos vectores de fila y, finalmente, restáurela a la escala original; tenga en cuenta que es necesario estirarla a una matriz porque solo dos -matriz dimensional puede hacer multiplicaciones .mm, y una matriz de cuatro dimensiones no puede hacer multiplicaciones directamente.

  • Además, el cálculo del sesgo en el código fuente aquí en realidad no se realiza estrictamente de acuerdo con la fórmula de derivación anterior, pero el coeficiente de b en la fórmula original se cambia de 'μ / sqrt (...)' a 1

Supongo que te gusta

Origin blog.csdn.net/m0_45388819/article/details/109907805
Recomendado
Clasificación