MXNet Symbol Batch Normalization fix_gamma=True转ONNX方法

1. 问题

在这里插入图片描述
当MXNet模型的Batch Normalization的fix_gamma参数为True时,会导致转ONNX模型失败,此时输出的ONNX参数如下图所示,导致ONNX的推理结果和MXNet不一致。
在这里插入图片描述

2. 解决方法

出现MXNet Batch Normalization的fix_gamma参数等于True时,可以手动修改batchnorm_gamma参数值,使ONNX模型输出正常,相关代码如下:

sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
reshape_params = {}
for k, v in arg_params.items():
    if 'batchnorm_gamma' in k:
        v = 1 - v
        reshape_params[k] = v
mx.model.save_checkpoint(prefix, epoch, sym, reshape_params, aux_params)

经过修改后ONNX模型结构如下图所示,结果输出正常。
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/linghu8812/article/details/109464911