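Fusing Convolution and Batch Normalization

Theoretical derivation

The standard building block of a convolutional layer in current CNNs consists of three submodules: Conv + BN + ReLU. At inference time, however, the BN computation can be folded into the Conv layer, which reduces the amount of computation and speeds up inference. This works because BN at inference uses fixed running statistics, so it is just a per-channel affine transform. In essence, fusion modifies the parameters of the convolution kernel: the cost of the Conv layer stays the same, while the cost of the BN layer is eliminated. The derivation follows.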

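Parameters of the Conv layer: weight $w$ and bias $b$, so the Conv output is

$$\mathrm{Conv}(x) = w * x + b$$

Parameters of the BN layer: scale $\gamma$, shift $\beta$, running mean $\mu$, running variance $\sigma^2$, and a small constant $\epsilon$ for numerical stability.

Let the input be $x$; then the output of x -> Conv -> BN is:

$$y = \gamma \cdot \frac{(w * x + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

A simple rearrangement gives:

$$y = \frac{\gamma \, w}{\sqrt{\sigma^2 + \epsilon}} * x + \left( \frac{\gamma \, (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta \right)$$

This is again a convolution, with fused weight and bias (computed per output channel):

$$w_{\text{fused}} = \frac{\gamma \, w}{\sqrt{\sigma^2 + \epsilon}}, \qquad b_{\text{fused}} = \frac{\gamma \, (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta$$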
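
The folding can be checked numerically. The following is a minimal sketch, assuming a scalar (single-channel, 1x1) convolution so that Conv reduces to `w * x + b`; the names `conv_then_bn`, `fold_bn`, and all parameter values are chosen here for illustration and do not come from the original post.

```python
import math


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Scalar Conv followed by inference-mode BN (fixed mean and variance)."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w_fused, b_fused) so that w_fused * x + b_fused matches conv_then_bn."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# Illustrative values only.
w, b = 0.8, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
w_fused, b_fused = fold_bn(w, b, gamma, beta, mean, var)

for x in (-2.0, 0.0, 3.5):
    reference = conv_then_bn(x, w, b, gamma, beta, mean, var)
    fused = w_fused * x + b_fused
    # The two paths should agree up to floating-point rounding.
    assert math.isclose(reference, fused, rel_tol=1e-9, abs_tol=1e-12)
```

In a real Conv2d the same scale and shift are applied independently to each output channel.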

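Code implementation

In practice, the first step is to locate the Conv and BN layers in the model, and then replace or delete the BN layers as appropriate. This example uses the open-source segmentation model library https://github.com/qubvel/segmentation_models.pytorch for the fusion experiment and replaces the BN layers. The library's Conv2dReLU module, an nn.Sequential holding the conv, BN, and ReLU layers at indices 0, 1, and 2, is defined as follows: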

class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):

        if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()

        elif use_batchnorm and use_batchnorm != "inplace":
            bn = nn.BatchNorm2d(out_channels)

        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)
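
The script below walks the model and fuses every Conv2dReLU block whose second element is an nn.BatchNorm2d:

import torch
import torch.nn as nn

import segmentation_models_pytorch.base.modules as md

from utils.torchUtils import fuse_conv_and_bn  # the author's local helper module


def fuseModel(model):  # fuse model Conv2d() + BatchNorm2d() layers
    # Call model.eval() before fusing: fusion relies on the BN running statistics.
    count = 0  # number of fused Conv2d + BatchNorm2d pairs
    for m in list(model.modules()):  # snapshot, since blocks are modified in place
        if isinstance(m, md.Conv2dReLU) and isinstance(m[1], nn.BatchNorm2d):
            m[0] = fuse_conv_and_bn(m[0], m[1])  # update conv
            m[1] = nn.Identity()  # replace BN with a no-op
            count += 1
    print(f"fused {count} Conv2d + BatchNorm2d pairs")
    return model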
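
The helper `fuse_conv_and_bn` is imported from the author's `utils.torchUtils` and its body is not shown. The following is a minimal sketch of what such a helper might look like, based only on the standard folding of inference-mode BN into the preceding convolution. The name `fuse_conv_and_bn_sketch` is a hypothetical stand-in, and the sketch assumes float32 modules, a BN layer with `affine=True` and tracked running statistics, and both modules in eval mode.

```python
import torch
import torch.nn as nn


def fuse_conv_and_bn_sketch(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: return a Conv2d equivalent to conv followed by bn (eval mode)."""
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        padding_mode=conv.padding_mode,
        bias=True,  # the fused layer always needs a bias to absorb the BN shift
    ).requires_grad_(False).to(conv.weight.device)

    with torch.no_grad():
        # Per-output-channel scale: gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
        # A conv built with bias=False contributes a zero bias.
        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused


# Equivalence check on random data (illustrative shapes and values).
torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
with torch.no_grad():
    # Give BN non-trivial parameters so the check is meaningful.
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    fused_out = fuse_conv_and_bn_sketch(conv, bn)(x)

max_abs_diff = (reference - fused_out).abs().max().item()
assert torch.allclose(reference, fused_out, atol=1e-5)
```

Comparing outputs before and after fusion on a few inputs, as done here, is a cheap way to confirm that a fused model still matches the original up to floating-point rounding.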

Reposted from blog.csdn.net/weixin_42990464/article/details/125551559