在模型推理时合并BN和Conv层

我们在这里简单讲解一下,在模型推理时合并BN和Conv层,能够简化网络架构,起到加速模型推理的作用。在模型中,BN层一般置于Conv层之后。

Conv:

卷积层的计算简单,公式为:

y = w *x+b

BN:

再来回忆一下BN操作的公式

合并conv和bn:

合并的过程可以用以下式子来表示:

\hat{f}_{i,j} = W_{BN}.(W_{conv}.f_{i,j}+b_{conv}) + b_{BN}

合并的结果我们可以用一个卷积操作来表示。

权重:W=W_{BN} .W_{conv}

偏置:b = W_{BN}.b_{conv} + b_{BN}

由BN层的最后两个公式可得:

y_i = \gamma * \frac{x_i- \mu}{\sqrt{\sigma^2+\epsilon}} + \beta

由于Conv层的输出y,就是BN层的输入​ ,然后代入得:

y_i = \gamma* \frac{w*x+b-u}{\sqrt{\sigma^2+\epsilon}} + \beta =(\frac{w}{\sqrt{\sigma^2+\epsilon}}*\gamma)*x + (\frac{b-u}{\sqrt{\sigma^2+\epsilon}}*\gamma + \beta)

式子中:均值 \mu​ ; 方差 ​\sigma^2 ; 较小的数 \epsilon​ (防止分母为0); 缩放因子 \gamma​ ; 偏置 \beta​ ;

因此,Conv和BN层合并后仅用一个卷积操作表示即可

其权值为:

\frac{w}{\sqrt{\sigma^2+\epsilon}}*\gamma

偏置为:

\frac{b-u}{\sqrt{\sigma^2+\epsilon}}*\gamma + \beta

代码:

def fuse_conv_and_bn(conv, bn):
	#
	# init
	fusedconv = torch.nn.Conv2d(
		conv.in_channels,
		conv.out_channels,
		kernel_size=conv.kernel_size,
		stride=conv.stride,
		padding=conv.padding,
		bias=True
	)
	#
	# prepare filters
	w_conv = conv.weight.clone().view(conv.out_channels, -1)
	w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
	fusedconv.weight.copy_( torch.mm(w_bn, w_conv).view(fusedconv.weight.size()) )
	#
	# prepare spatial bias
	if conv.bias is not None:
		b_conv = conv.bias
	else:
		b_conv = torch.zeros( conv.weight.size(0) )
	b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
	fusedconv.bias.copy_( torch.matmul(w_bn, b_conv) + b_bn )
	#
	# we're done
	return fusedconv

import torch
import torchvision
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
rn18 = torchvision.models.resnet18(pretrained=True)
rn18.eval()
net = torch.nn.Sequential(
	rn18.conv1,
	rn18.bn1
)
y1 = net.forward(x)
fusedconv = fuse_conv_and_bn(net[0], net[1])
y2 = fusedconv.forward(x)
d = (y1 - y2).norm().div(y1.norm()).item()
print("error: %.8f" % d)

参考:

https://nenadmarkus.com/p/fusing-batchnorm-and-conv/

模型部署之Convolution与BatchNorm合并_jefferyqian的博客-CSDN博客

猜你喜欢

转载自blog.csdn.net/u012505617/article/details/125635288