オーバーホール蒸留 (ICCV 2019) の原理とコード分析

論文:特徴抽出の包括的な見直し

公式実装:GitHub - clovaai/overhour-distillation: 「機能蒸留の包括的なオーバーホール」 (ICCV 2019) の公式 PyTorch 実装

この記事の革新性

この論文では、知識蒸留のさまざまな側面を研究し、教師特徴変換、生徒特徴変換、特徴蒸留位置、および距離関数の間で蒸留損失を相乗させる新しい特徴蒸留方法を提案します。具体的には、本論文で提案する蒸留損失には、新しく設計されたマージンレル特徴変換方法、新しい蒸留位置、および部分L2距離関数が含まれています。ImageNet では、この論文で提案された方法により、ResNet-50 は 21.65% のトップ 1 誤差を達成でき、これは教師ネットワーク ResNet-152 の精度よりも優れています。

メソッドの紹介 

蒸留場所

活性化関数はニューラル ネットワークの重要な部分であり、ネットワークを非線形にします。しかし、これまでの蒸留手法の多くは活性化関数を考慮しておらず、蒸留位置も特定の層やブロックの端にある場合が多く、ReLUなどの活性化関数との関係は考慮されていませんでした。この論文で提案する方法では、下図に示すように、蒸留位置は層の端と最初の ReLU の間に位置します。

 

pre-ReLU の位置付けにより、生徒は ReLU を通じて教師モデル間の情報にさらされるようになり、情報の分解や損失が回避されます。

損失関数

蒸留位置が ReLU の前であるため、特徴の正の値には教師が使用する情報が含まれ、負の値には含まれていません。教師ネットワークの値が正の場合、生徒ネットワークは同じ情報を生成する必要があります教師ネットワークが負の値の場合、生徒も負の値を生成して、活性化状態を教師と一致させる必要があります。したがって、著者が提案した教師の変換関数は正の値を保存し、負の値のマージンを持ちます 

ここで、 \(m\) は 0 未満のマージン値であり、著者はこれを margin ReLu と名付けました。\(m\) の特定の値は、各チャネルの負の応答値の期待値として次のように定義されます。

 

\(m\) 一方では、https://github.com/clovaai/overhour-distillation/issues/7に示すように、トレーニング プロセス中に直接計算できます。前の BN 層のパラメータによって計算することもできますが、具体的な計算方法は付録で著者が示しています。

チャネル\(\mathcal{C}\) と教師特徴量の \(i\) 番目の要素\(F^{i}_{t}\) の場合、チャネルのマ​​ージン値\(m_{c }\ ) はトレーニング画像の期待値であり、式 (3) です。通常、\(F^{i}_{t}\) の分布はわからないため、学習プロセス中の平均値によってのみ期待値を得ることができます。しかし、ReLU より前の BN 層はバッチ内の特徴 \(F^{i}_{t}\) の分布を決定し、BN 層は各チャネルの特徴を平均 \(\mu\) 分散\ に正規化します。 (\sigma\) ガウス分布、つまり

各チャネルの平均分散\((\mu,\sigma)\) は BN 層のパラメータ\((\beta,\gamma)\) に対応するため、\(F^{i}_{t} \) 分布の限界値を直接計算できます

 

期待値は、ガウス分布の確率密度関数 pdf を積分することによって取得されます。範囲はゼロ未満です。積分の結果は、正規分布の cdf 累積分布関数 \(\Phi(\cdot)\) で簡単に表すことができます。

 

公式実装ではマージン値もこの方法、つまり式(10)で計算されます。

蒸留の位置が ReLU 関数の前にあるため、負の応答は ReLU によってフィルター処理されないため、蒸留損失関数は ReLU を考慮する必要があります。教師機能では、肯定的な応答はネットワークによって実際に使用されます。つまり、教師の肯定的な応答には特定の値が渡される必要がありますが、否定的な応答には渡されません。教師の否定的な反応については、生徒の反応値が目標値より高ければ下げる必要があり、目標値より低ければ上げる必要はありません。具体的な値が何であっても、 ReLU によって除外されます。したがって、本論文では、次のような部分 L2 距離関数を提案します。

完全な蒸留損失関数は次のとおりです。

 

ここで、 \(\sigma_{m_{c}}\) は教師の変換関数 margin ReLU、 \(r\) は生徒の変換関数 1x1 conv + BN、 \(d_{p}\) は部分距離関数ですL2距離。

実験結果

CIFAR-100 データセットにおける、さまざまな教師ネットワークと生徒ネットワークの結果を表 2 に示します。

さまざまな教師と生徒のネットワークの組み合わせ、この方法と他の蒸留方法の結果を以下のように比較すると、すべての組み合わせにおいて、この論文で提案した方法の誤差が最も低いことがわかります。

ImageNet データセットの他の手法との比較を表 4 に示します。この論文の手法の誤差も最も低いことがわかります。 

コード分​​析

実装コードは主に distiller.py にあります. この記事の最初の革新点は蒸留の位置、つまり ReLU の前にあります。 

t_feats, t_out = self.t_net.extract_feature(x, preReLU=True)
s_feats, s_out = self.s_net.extract_feature(x, preReLU=True)

Student featureの変換は1x1 convolution + BN、つまりself.Connectorsの実装となっており、具体的な実装は以下の通りです

def build_feature_connector(t_channel, s_channel):
    C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False),
         nn.BatchNorm2d(t_channel)]

    for m in C:
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    return nn.Sequential(*C)

教師特徴量の変換が本論文で提案するマージンReLUであり、マージンマージン値の計算は以下のとおり、すなわち上記式(10)

def get_margin_from_BN(bn):
    margin = []
    std = bn.weight.data
    mean = bn.bias.data
    for (s, m) in zip(std, mean):
        s = abs(s.item())
        m = m.item()
        if norm.cdf(-m / s) > 0.001:
            margin.append(- s * math.exp(- (m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m)
        else:
            margin.append(-3 * s)

    return torch.FloatTensor(margin).to(std.device)

蒸留損失関数は次のように実装されます。最初の行は教師の特徴量の変換関数、つまり式 (2) です。

def distillation_loss(source, target, margin):
    target = torch.max(target, margin)
    loss = torch.nn.functional.mse_loss(source, target, reduction="none")
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()

おすすめ

転載: blog.csdn.net/ooooocj/article/details/131195287