1.問題
MXNetモデルのバッチ正規化のfix_gammaパラメーターがTrueの場合、ONNXモデルが失敗します。このとき、出力ONNXパラメーターは次の図のようになり、ONNXとMXNetの推論結果は次のようになります。一貫性がありません。
2.解決策
MXNetバッチ正規化のfix_gammaパラメーターがTrueに等しい場合、batchnorm_gammaパラメーター値を手動で変更して、ONNXモデルの出力を正常にすることができます。関連するコードは次のとおりです。
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
reshape_params = {}
for k, v in arg_params.items():
if 'batchnorm_gamma' in k:
v = 1 - v
reshape_params[k] = v
mx.model.save_checkpoint(prefix, epoch, sym, reshape_params, aux_params)
変更後のONNXモデル構造を下図に示しますが、結果は正常です。