[GAN] スペクトル正規化 スペクトル正規化 - 原理と実装

1. スペクトル正規化の背景

スペクトル正規化は、「Spectral Normalization For Generative Adversarial Networks」論文リンクで提案されています。

ネイティブ GAN の目的関数は、生成されたデータの分布と実際のデータの分布の間の JS の発散 (Jensen-Shannon Divergence) を最適化することに相当します。
両者の間には無視できない重なりがほとんどないため、どれだけ離れていても、JS の発散は一定の log2 となり、最終的にジェネレーターの勾配が (近似的に) 0 になり、勾配が消滅します。
つまり、弁別器がより適切に訓練されるほど、生成器の勾配はより深刻に消失します。

WGAN は、優れた Wasserstein 距離を使用して、ネイティブ GAN の JS 発散を置き換えます。次に、KR 双対原理を使用して、Wasserstein 距離を解く問題を、最適な Lipschitz 連続関数を解く問題に変換します。識別子 D がリプシッツ連続性を満たすようにするために、著者は「勾配クリッピング」を使用して、しきい値の直下で過度に大きなパラメータをクリップします。

「勾配クリッピング」技術は、ニューラル ネットワークの各層のパラメーター マトリックスのスペクトル ノルムの観点からリプシッツ連続性制約を導入するため、ニューラル ネットワークは入力外乱に対する感度が向上し、トレーニング プロセスがより安定して効率的になります。収束しやすい。(深層学習モデルには「敵対的攻撃サンプル」があります。たとえば、画像が 1 ピクセルだけ変更された場合、まったく異なる分類結果が得られます。これは、モデルが入力に対して敏感すぎるケースです。)

極小点付近が平坦(傾きが抑えられている)であれば汎化性能が良く、逆に平坦ではない(鋭さがある)とわずかな変化で変化が生じると理解できます。 Changeが大きいほど汎化性能が悪く、不安定です。

Spectral Norm は、より洗練された方法を使用して、識別子 D がリプシッツ連続性を満たすようにします。これにより、関数の急激な変化が制限され、モデルがより安定します。

2. リプシッツ連続性

リプシッツ条件は、関数の変化の重大度、つまり関数の最大勾配を制限します。

K-リプシッツとは、関数の最大の傾きがKであることを意味し、Kをリプシッツ定数(リプシッツ定数)といいます。たとえば、y = sinx の最大傾きは 1 であるため、1-リプシッツの傾きになります。

ここに画像の説明を挿入
赤い線は sin(x)、最大傾きは 1、黄色は制限領域です
次に、
ここに画像の説明を挿入
重要な理論: 行列 A をスペクトル ノルムで除算すると、1-リプシッツ連続性が得られます。

3. GANのスペクトル正規化の原理

GAN でスペクトル ノルムを実行することは、実際には、識別子 D が 1-リプシッツ条件を満たすようにすることです。

また、弁別器 D は各層によって追加されるバイアスを省略するため、この多層ニューラル ネットワークは実際には複数の複合関数の入れ子操作になります。
最も一般的な入れ子は、畳み込みの 1 つの層、活性化関数の 1 つの層、畳み込みの別の層、活性化関数の別の層であり、層がラップされます。活性化関数は通常 ReLU を選択し、Leaky ReLU は 1-リプシッツです。畳み込み部分が 1-リプシッツ連続であることを確認するだけで済みます (線形層がある場合は、線形層が 1-リプシッツ連続であることも確認します)。ニューラル ネットワーク全体が 1-リプシッツ連続であることを保証できます。

では、畳み込み部分が 1-リプシッツ連続であることを確認するにはどうすればよいでしょうか?
画像上の各位置の畳み込み演算は行列の乗算とみなすことができます。したがって、各層の畳み込みカーネルのパラメータ W を 1-リプシッツ連続になるように制約するだけで、畳み込み部分が 1-リプシッツ連続になるようにすることができ、全体の 1-リプシッツ連続を満たすことができます。ニューラルネットワーク。

したがって、具体的な方法は次のとおりです。
弁別器 D の畳み込みの各層のパラメータ行列 W をスペクトル正規化します。つまり、
ステップ 1:
W に対して SVD を実行します (実装では、べき乗反復法を使用して SVD を近似し、計算コスト )、W の最大特異値を取得します。
ステップ 2:
1-リプシッツ連続性を満たすように、W の各更新後に W の最大の特異値で除算します。

知らせ?:
ディスクリミネーター D がスペクトル ノルムを使用した後は、BatchNorm (または他のノルム) を使用できません。理由も非常に単純で、バッチ ノルムの「分散の除算」と「スケーリング係数の乗算」という 2 つの操作により、明らかに識別器のリプシッツ連続性が破壊されてしまうからです

4. GANのスペクトル正規化の実装

Google は tensorflow を使用してスペクトル正規化関数を実装しています。リンク
pytorch には適切に実装されたスペクトル正規化関数があります torch.nn.utils.spectral_norm (公式ドキュメント) (github)
ここに画像の説明を挿入

import torch.nn as nn
import torch

# 对线性层做谱归一化
sn_module = nn.utils.spectral_norm(nn.Linear(20,40))
# 验证谱归一化后的线性层是否满足1-Lipschitz continuity
print(torch.linalg.norm(sn_module.weight,2))# tensor(1.3898)  为啥不是1.000呢?

ただし、公式ドキュメントでは新しいバージョンのスペクトル正規化関数 torch.nn.utils.parametrizations.spectral_norm を使用することが推奨されています
が、私がこれまで見た限りでは、まだ torch.nn.utils.spectral_norm が使用されています。

使用法:

# #############################################################################################
# 用法1 ref: https://blog.csdn.net/qq_37950002/article/details/115592633
# #############################################################################################
import torch
import torch.nn as nn

class TestModule(nn.Module):
    def __init__(self):
        super(TestModule,self).__init__()
        self.layer1 = nn.Conv2d(16,32,3,1)
        self.layer2 = nn.Linear(32,10)
        self.layer3 = nn.Linear(32,10)
 
    def forward(self,x):
        x = self.layer1(x)
        x = self.layer2(x)

model = TestModule()

def add_sn(m):
    for name, layer in m.named_children():
        m.add_module(name, add_sn(layer))
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            return nn.utils.spectral_norm(m)
        else:
            return m

my_model = add_sn(model)

# #############################################################################################
# 用法2 ref: https://github.com/Vbansal21/Custom_Architecture/EATS/models/v2_discriminator.py
# #############################################################################################
...
# 直接在卷积/线性层外面套一层nn.utils.spectral_norm()
nn.Sequential(
                nn.ReflectionPad1d(7),
                nn.utils.spectral_norm(nn.Conv1d(1, 16, kernel_size=15)),
                nn.LeakyReLU(0.2, True),
            ),
...

おすすめ

転載: blog.csdn.net/lingchen1906/article/details/129827098