ネットワークスリムプルーニング用のカスタムコード

ネットワークスリムプルーニング用のカスタムコード

最近のプロジェクトでは、検出アルゴリズムの最終展開が必要ですが、展開後のアルゴリズムの推論速度がそれほど速くないため、モデルを枝刈りする必要があります。

私が参照している枝刈りアルゴリズムは、2017 ICCV の古典的な枝刈りアルゴリズム「Learning Efficient Convolutional Networks through Network Slimming」です。アルゴリズムの原理も非常にシンプルです。使用する必要がある事前知識は、まずバッチ正規化の原理を理解することです。BN の原理については、バッチ正規化の原理と勾配の消失と勾配の爆発を簡単に分析しまし。この記事で必要なものをご覧いただけます。

私の検出アルゴリズムは mmdet のフレームワーク上に構築されているため、Microsoft のオープン ソースのプルーニング ツール nni を試しましたが、失敗しました。最終的には、プルーニング用のコードを自分で作成する予定でした。以下は、私のアイデアと具体的なコードの実装です。剪定。

1.1 剪定のアイデア

上記の参考文献で述べられている原則に従って、BN 層をプルーニングする必要があります。resnet50 を例に挙げると、通常、ネットワークを構築する場合、Conv 層の後に BN 層が接続され、次の Conv 層は活性化関数層に加えて BN 層に接続されます。便宜上、BN 層を置きます畳み込み層は Conv1 として示され、後者は Conv2 として示されます。それらの間の関係は次のとおりです: Conv1 の出力チャネル数は BN 層のパラメータ次元と一致しており、Conv2 の入力チャネル数とも一致している必要があります。以下の図に示すように、ここに例を示します。 resnet50 の 1 層の Conv1 の入力チャネル数
ここに画像の説明を挿入
は 256、コンボリューション カーネル サイズは [1,1]、出力チャネル数は 64、Conv1 の出力は BN 層の入力として使用され、 BN 層処理後の出力は入力チャネルの数と一致しており、両方とも 64 です。その場合、BN 層の出力は Conv2 の入力として使用され、Conv2 は 64 チャネルの入力を受け入れます。出力も64チャンネルです。

次に、これもとても簡単なBNの剪定方法を紹介します。BN 層は、入力の各チャンネルの 2 つのパラメーターをそれぞれ学習しますβγ \ガンマγ、次の式に対応します:
y 1 ← γ 1 x 1 ^ + β 1 ≡ BN γ 1 , β 1 ( x 1 ) y_{1} \leftarrow \gamma_1 \hat{x_{1}}+\beta_1 \相当 B N_{\gamma_1, \beta_1}\left(x_{1}\right)y1c1バツ1^+b1BN _c1b1( ×1)
したがって、パラメータγ \gammaγ は、各チャネルの重要性を測定するための重みとして見ることができます。しきい値を設定できます。このしきい値より低い場合はチャネルが削除され、高い場合は保持されます。以下の図に示すように。
ここに画像の説明を挿入

1.2 コード実装のアイデア

原理を明確にしたら、次のステップはこのアイデアを実現するためのコードを書くことです。BN 層をプルーニングしたい場合は、BN 層に直接接続されている 2 つの Conv 層を同時にプルーニングする必要があります。

まず、モデル全体をロードします。ここで使用する方法は、torch.load() 関数を直接使用することです。

# 保存整个网络
torch.save(model, PATH)
# 加载整个模型
torch.load(PATH)

この方法でロードされた印刷モデルは次のようになります。
ここに画像の説明を挿入
これはモデルの構造であることがわかります (resnet50 の一部のみを取り出しました)。次に、そのパラメーターを印刷して確認します。上記はパラメーターの寸法の一部です
ここに画像の説明を挿入
。モデルの。BN 層のパラメータは 64 次元であり、前の Conv 層の出力チャネルの数に対応し、BN 層には重み、バイアス、平均、および変数があり、これらは次元を持つ 4 つのパラメータであることがわかります。同時に剪定する必要があります。

原理の前の部分で述べたように、γ \gammaをとります。γパラメータはチャネルの重要度を測る重みとして使用され、パラメータ部分の重みに相当する。したがって、BN 層をプルーニングするときは、最初にしきい値を設定し、次に重みの各値をしきい値と比較して、しきい値より大きいインデックスを取得する必要があります。

def find_indice(module, thresh):  #module就是一个BN层
    gamma = module.weight.data
    mask = gamma > thresh
    indices = torch.nonzero(mask).view(-1)
    return indices

次に、このインデックスを使用して、BN レイヤーの 4 つのパラメーターをプルーニングします。パラメーターの枝刈りに加えて、構造の枝刈りにも特別な注意を払う必要があります。つまり、モデル上の BN 層に対応するチャネルの数を枝刈り後の次元に変更することです。

#对参数进行修剪
m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn1"]]  #gamma
m.bias.data = m.bias.data[bn_dict["backbone.layer1.0.bn1"]]      #beta
m.running_mean.data = m.running_mean.data[bn_dict["backbone.layer1.0.bn1"]]
m.running_var.data = m.running_var.data[bn_dict["backbone.layer1.0.bn1"]]
#对结构进行修剪
m.num_features = bn_dict["backbone.layer1.0.bn1"].size()[0]

さらに、BN レイヤーの前の Conv の出力レイヤーをトリミングし、BN レイヤーの後ろの Conv の入力レイヤーをトリミングする必要があります。コード全体は次のとおりです。

#先得到所有BN层需要保留的权重索引
bn_dict = dict()
for name, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        indice = find_indice(m, thresh=0.17)
        bn_dict[name] = indice
        
#进行剪枝
if name == "backbone.layer1.0.conv1":
    m.weight.data = m.weight.data[:, bn_dict["backbone.bn1"], :, :]
    m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn1"], :, :, :]
    m.in_channels = bn_dict["backbone.bn1"].size()[0]
    m.out_channels = bn_dict["backbone.layer1.0.bn1"].size()[0]
if name == "backbone.layer1.0.bn1":
    m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn1"]]  #gamma
    m.bias.data = m.bias.data[bn_dict["backbone.layer1.0.bn1"]]      #beta
    m.running_mean.data = m.running_mean.data[bn_dict["backbone.layer1.0.bn1"]]
    m.running_var.data = m.running_var.data[bn_dict["backbone.layer1.0.bn1"]]
    m.num_features = bn_dict["backbone.layer1.0.bn1"].size()[0]
if name == "backbone.layer1.0.conv2":
    m.weight.data = m.weight.data[:, bn_dict["backbone.layer1.0.bn1"], :, :]
    m.weight.data = m.weight.data[bn_dict["backbone.layer1.0.bn2"], :, :, :]
    m.in_channels = bn_dict["backbone.layer1.0.bn1"].size()[0]
    m.out_channels = bn_dict["backbone.layer1.0.bn2"].size()[0]

切断後、モデルを保存すると、重量と構造がトリミングされたモデルが印刷されます。

torch.save(model, "修剪好的模型的保存路径")

印刷:
ここに画像の説明を挿入
枝刈りされたモデルは直接実行することもできます。

私自身の検出モデルのバックボーンは resnet50 であり、バックボーン部分のみをトリミングしました。トリミング後、mAP は少し低下しました (設定されたしきい値に関連し、トリミングははるかに低くなります) が、微調整後の効果は良くなりました。トリミング前よりも良くなりました。この現象を仮にBN層のγ \gammaと解釈してみます。γパラメータは、アテンション メカニズムとして使用できます。誰かがより良い説明をできる場合は、私を修正してください。

1.3 注意事項

剪定の途中にも穴がいくつかありますので、ここに記録してください。

Resnet50 にはいくつかの残りの接続があることに注意してください。各リレイヤー層の先頭には、Conv と BN で構成される別のダウンサンプル レイヤーがあります。リレイヤーの入力次元が部分的にカットされている場合は、ダウンサンプル レイヤーも調整することを忘れないでください。

resnet50 の構造図を添付します。
ここに画像の説明を挿入

おすすめ

転載: blog.csdn.net/weixin_45453121/article/details/130891939