GANの収束と安定性について(Xiao Bai Xue GAN:14)

元のリンク:https//arxiv.org/pdf/1705.07215v5.pdf

前書き

背景:GANをトレーニングする場合、最適化の目標として、生成された分布とトレーニング分布の間の最高度の適合(つまり、最小距離)を採用することがよくあります。ただし、これにより、トレーニングしたGANはローカルでのみ適合し、グローバルでは適合しないことがよくあります。目標分布に近くないので、これに基づいて後悔最小化を使用して最適化されたDRAGANを提案します

コアアイデア:元のGANの一貫した最小化を変更して最小化を後悔し、急激な勾配の出現を防ぎ、最適化中のGANの安定性を確保します。

基本構造

基本コンセプト

シオンの定理

場合\ phi \ subset \ mathbb {R} ^ m、\ theta \ subset \ mathbb {R} ^ n、および\ phiコンパクト凸集合である\ theta凸集合、J:\ phi \ times \ theta \ rightarrow \ mathbb {R}それは次に、二番目の引数で最初の引数と凸で凹面です。

最後の図を視覚的に指すと、上記の条件が満たされた後、これら2つの変数の最適化順序が最終値に影響を与えないことがわかります。

後悔のないアルゴリズム

一連の凸状損失関数L_1、L_2、L_3、..。が時系列L_1、L_2、...、L_t得られるとすると、の場合\ frac {R(T)} {T} = o(1)、この系列は後悔しません

                                                         R(T):= \ sum_ {t = 1} ^ {T} L_ {t}(k_t)-min_ {k \ in K} \ sum_ {t = 1} ^ {T} L_t(k)

上記の式から、後悔のないアルゴリズムは、最適化時に最適化プロセスと取得されたLOSSが同じレベルになるようにすることです。LOSSが反復プロセスと一致する場合、後悔はありません。

後悔のないアルゴリズムをGANのLOSSと組み合わせるJ(\ cdot、\ theta_t)J(\ phi_t、\ cdot)、Kラウンドのジェネレーター損失関数、Kラウンドのディスクリミネーター損失関数、およびゲームのTラウンド後と見なされます

                                       

それを仮定するとV ^ *、以下の式を得ることができる発電機と識別器はそれぞれ「後悔を有する」ジェネレータと弁別、のバランス値は次のとおりです。

これは、発電機の最善の解決策は、最良の解決策ということである、と言うことです\ bar {\ theta} _T弁別がされ\ bar {\ phi} _Tていること、私たちは「反省」を到達する前に最良のモデルを発見しました。ただし、実際には、ジェネレーターとディスクリミネーターを同時に「後悔」に到達させることは難しいため\ frac {R_1(T)+ R_2(T)} {T}、全体の「後悔」の用語としておおよその用語を使用します。

正規化

次の通常のルールを使用します。

その中に\オメガ(\ cdot)は、通常の機能と\ eta学習率があります。

元のGANの目的関数が次のように表現されている場合

次に、xとzの期待を考慮すると、次のように表すこともできます。

弁別器の最適化

ジェネレータの最適化

次に、最適な値に最適化する前に、それをエラー境界として使用できます

部分平衡

モデル崩壊の出現を防ぎます。

勾配ペナルティ

簡単に言うと、モードが崩壊すると、ディスクリミネーターの勾配がすぐに「スパイク」として表示されます。つまり、ディスクリミネーターの勾配が急速に低下し、最適化速度がジェネレーターの速度を超えて、ジェネレーターがさまざまな分布をマッピングするようになります。モードへ。次に、勾配ペナルティ項を導入して、ジェネレータを待機するようにディスクリミネータの最適化速度を制限します。

上記の方法は確かにGANのトレーニングを安定させることができますが、ノイズペナルティ項の導入はディスクリミネーターのパフォーマンスを妨げるため、最終的なジェネレーターの分布にもフィッティングノイズがかかるため、改善されます。

つまり、ノイズのサイズが制限されます。生成分布とランダム分布がすでに近い場合、つまりジェネレータがディスクリミネータに適合している場合、ノイズペナルティは発生しません。

この記事は最終的に次のペナルティメカニズムを採用しました、\ lambda \ sim 10、K = 1、C \ sim 10そして効果は実験の後でより良くなります:

カップリングペナルティ

著者は、ディスクリミネーターの勾配ペナルティが実際にはWGAN_GP結合ペナルティ項で使用されていると述べました。これにより、ジェネレーターはより良い分布を学習できますが、グローバルに実行されます。この記事では、ローカルデータ分布で実行されます。 、次に、比較中の違いは何ですか?

WGAN_GPディスクリミネーターLOSS

著者によって提案されたDRAGANの弁別器LOSS:

最大の違いは、ペナルティ項目が実データに関連しており、学習分布の影響を受けないことです。学習生成分布は学習データに関連しているため、ある意味ではグローバルサンプルに関連しています。関連性があり、変更後のペナルティ期間は部分的なサンプルにのみ関連しています。

コードの練習

参照リンク:https//github.com/jfsantos/dragan-pytorch

# coding: utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import numpy as np
import torch
from torch.autograd import Variable, grad
from torch.nn.init import xavier_normal
from torchvision import datasets, transforms
import torchvision.utils as vutils

def xavier_init(model):
    for param in model.parameters():
        if len(param.size()) == 2:
            xavier_normal(param)


if __name__ == '__main__':
    batch_size = 128
    z_dim = 100
    h_dim = 128
    y_dim = 784
    max_epochs = 1000
    lambda_ = 10

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ])),
        batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor()
                       ])),
        batch_size=batch_size, shuffle=False, drop_last=True)

    generator = torch.nn.Sequential(torch.nn.Linear(z_dim, h_dim),
            torch.nn.Sigmoid(),
            torch.nn.Linear(h_dim, y_dim),
            torch.nn.Sigmoid())

    discriminator = torch.nn.Sequential(torch.nn.Linear(y_dim, h_dim),
            torch.nn.Sigmoid(),
            torch.nn.Linear(h_dim, 1),
            torch.nn.Sigmoid())

    # Init weight matrices (xavier_normal)
    xavier_init(generator)
    xavier_init(discriminator)

    opt_g = torch.optim.Adam(generator.parameters())
    opt_d = torch.optim.Adam(discriminator.parameters())

    criterion = torch.nn.BCELoss()
    X = Variable(torch.FloatTensor(batch_size, y_dim))
    z = Variable(torch.FloatTensor(batch_size, z_dim))
    labels = Variable(torch.FloatTensor(batch_size))

    # Train
    for epoch in range(max_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            X.data.copy_(data)

            # Update discriminator
            # train with real
            discriminator.zero_grad()
            pred_real = discriminator(X)
            labels.data.fill_(1.0)
            loss_d_real = criterion(pred_real, labels)
            loss_d_real.backward()

            # train with fake
            z.data.normal_(0, 1)
            fake = generator.forward(z).detach()
            pred_fake = discriminator(fake)
            labels.data.fill_(0.0)
            loss_d_fake = criterion(pred_fake, labels)
            loss_d_fake.backward()

            # gradient penalty
            alpha = torch.rand(batch_size, 1).expand(X.size())
            x_hat = Variable(alpha * X.data + (1 - alpha) * (X.data + 0.5 * X.data.std() * torch.rand(X.size())), requires_grad=True)
            pred_hat = discriminator(x_hat)
            gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()),
                    create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradient_penalty = lambda_ * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
            gradient_penalty.backward()

            loss_d = loss_d_real + loss_d_fake + gradient_penalty
            opt_d.step()

            # Update generator
            generator.zero_grad()
            z.data.normal_(0, 1)
            gen = generator(z)
            pred_gen = discriminator(gen)
            labels.data.fill_(1.0)
            loss_g = criterion(pred_gen, labels)
            loss_g.backward()
            opt_g.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, max_epochs, batch_idx, len(train_loader),
                     loss_d.item(), loss_g.item()))

            if batch_idx % 100 == 0:
                vutils.save_image(data,
                        'samples/real_samples.png')
                fake = generator(z)
                vutils.save_image(gen.data.view(batch_size, 1, 28, 28),
                        'samples/fake_samples_epoch_%03d.png' % epoch)


ミニストテスト結果

おすすめ

転載: blog.csdn.net/fan1102958151/article/details/106965809