MXNetシンボルバッチ正規化fix_gamma = ONNXメソッドにTrue

1.問題

ここに画像の説明を挿入します
MXNetモデルのバッチ正規化のfix_gammaパラメーターがTrueの場合、ONNXモデルが失敗します。このとき、出力ONNXパラメーターは次の図のようになり、ONNXとMXNetの推論結果は次のようになります。一貫性がありません。
ここに画像の説明を挿入します

2.解決策

MXNetバッチ正規化のfix_gammaパラメーターがTrueに等しい場合、batchnorm_gammaパラメーター値を手動で変更して、ONNXモデルの出力を正常にすることができます。関連するコードは次のとおりです。

sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
reshape_params = {}
for k, v in arg_params.items():
    if 'batchnorm_gamma' in k:
        v = 1 - v
        reshape_params[k] = v
mx.model.save_checkpoint(prefix, epoch, sym, reshape_params, aux_params)

変更後のONNXモデル構造を下図に示しますが、結果は正常です。
ここに画像の説明を挿入します

おすすめ

転載: blog.csdn.net/linghu8812/article/details/109464911