エネルギーベースの生成的敵対ネットワーク(XiaobaiはGAN 10を学習します)

元のリンク:https//arxiv.org/pdf/1609.03126.pdf

前書き

背景:この記事は、GANが回避できない問題、つまり安定して最適化する方法に戻ります。この記事は、以前の「WGAN」および「WシリーズGAN」とは異なり、目的関数を設計することによってこの目標を達成するのではなく、別の方法を見つけます。これは、ディスクリミネーターの構造を変更することによって実現されます。

コアアイデア:ディスクリミネーターをエネルギー関数変更して、人気のある分布に近いときデータのエネルギーを低くし、他の場所ではエネルギーを高くするようにします。

上図からわかるように、ディスクリミネーターの構造が変化し、従来のシングルニューラルネットワーク構造とは異なり、エンコーダーとデコーダーのペアで構成され、コーデック構造全体に基づいてエネルギー出力が構築されています。この設計の最も直接的な利点は、独自の特定のLOSS関数を構築するために頭を悩ませる必要がないことですが、成熟したLOSSを直接​​使用して、このフレームワークを配置し、その出力エネルギーを最適化の目標として使用できます。

基本構造

基本コンセプト

ナッシュ平衡

非協力的なゲーム均衡とも呼ばれ、ゲームプロセスでは、対戦相手の戦略の選択に関係なく、一方の当事者が支配的な戦略と呼ばれる特定の戦略を選択します。ゲーム内の2つのパーティの戦略の組み合わせがそれぞれの支配的な戦略を構成する場合、この組み合わせはナッシュ均衡として定義されます。戦略の組み合わせはナッシュバランスと呼ばれます。各プレーヤーのバランス戦略が期待されるリターンの最大値を達成することである場合、同時に、他のすべてのプレーヤーもこの戦略に従います。

次にGANでは、ディスクリミネーターとジェネレーターの間の協調ゲームプロセスがナッシュ平衡に近づくと、最適化のボトルネックに陥ります。

目的関数

                                                     

その中で[\ cdot] ^ + = max(0、\ cdot)、ここでのmは満たすべきマージンm \ leq D(G(z))です。次に、上記の2つのLOSS関数は、実際にディスクリミネーターの最適化プロセスに制限項目を追加します。この制限項目は、生成されたデータがx分布に近すぎる場合にディスクリミネーターの損失を拡大し、ディスクリミネーターの最適化プロセスを高速化します。ただし、生成されたデータがx分布から離れすぎている場合、制限項目は、生成されたデータがディスクリミネーターを通過するときに生成されるLOSSを制限します。つまり、ディスクリミネーターの最適化プロセスが最初に停止され、ジェネレーターの最適化は、生成されたデータの分布をバツ。

                                                    V(G、D)= \ int_ {x、z} \ mathfrak {L} _D(x、z)p_ {data}(x)p_z(z)dxdz

                                                    U(G、D)= \ int _z \ mathfrak {L} _G(z)p_z(z)dz

弁別器をトレーニングするときはVを最小化し、ジェネレーターをトレーニングするときはUを最小化します。GとDはナッシュ平衡のペアを形成し、次の条件を満たす必要があります。

                                                  V(G ^ *、D ^ *)\ leq V(G ^ *、D)、\ forall D

                                                  U(G ^ *、D ^ *)\ leq V(G、D ^ *)、\ forall G

Dは最適なディスクリミネーターを表し、Gは最適なジェネレーターを表すため、2つの最適化の上限を決定できます。

構造の関係があるためD ^ *(x)\ leq mであるV(G ^ *、D)最小値であります

第2項の2つの要素は、正と負の2つでなければならないため、積分値は[-1,0]の間にある必要があります。したがって、最大値はm、つまりですV(G ^ *、D ^ *)\ leq mだから再び

そしてm \ leq V(G ^ *、D ^ *)、最後にm \ leq V(G ^ *、D ^ *)\ leq m、つまりm = V(G ^ *、D ^ *)、この状況が発生したとき

                                  

この項目はゼロです。つまりP_ {データ} = P_G、最適化の目標達成されています。

コーデック構造

自動エンコーダーのトレーニングに関する一般的な問題は、モデルが学習する可能性があるのはID関数ではないことです。つまり、空間全体をゼロエネルギーに割り当てる可能性があります。この問題を回避するには、モデルがデータマニホールドの外側のポイントにより高いエネルギーを提供するように強制する必要があります。この種のノーマライザーは、自動エンコーダーの再構築能力を制限するように設計されているため、低エネルギーを入力ポイントのより小さな部分にしか分類できません。

EBGANフレームワークのエネルギー関数(ディスクリミネーター)も、比較サンプルを生成するジェネレーターによって正規化されていると見なされます。ディスクリミネーターは、コントラストサンプルに高い再構成エネルギーを与える必要があります。この観点から、EBGANフレームワークでは、次の理由により、柔軟性が向上します。(i)ノーマライザー(ジェネレーター)を手動で指定するのではなくトレーニングできる。(2)敵対的なトレーニングモードでは、対照的なサンプルと学習エネルギーが生成される。関数の2つの目標は、直接相互作用できます。

Dの自動エンコーダーの選択は一見任意に見えますが、作成者の設定により、バイナリ分類ネットワークよりも魅力的です。
(1)再構築に基づく出力は、モデルのトレーニングに単一のターゲット情報を使用することではなく、弁別器に多様なターゲットを提供します。バイナリ分類ネットワークのため、2つの目標しか不可能であるため、小さなバッチでは、異なるサンプルに対応する勾配が直交から遠くなる可能性が高く、非効率的なトレーニングにつながり、現在のハードウェアは通常提供しません小さなバッチのサイズを減らします。一方、再構築の損失は、バッチ内で非常に異なる勾配方向を生成する可能性があり、効率を損なうことなく、より大きなバッチサイズを可能にします。
(2)従来、自動エンコーダーはエネルギーベースのモデルを表すために使用されていました。正規化を使用してトレーニングする場合、自動エンコーダーは監視や反例なしでエネルギー多様体を学習できます。これは、EBGAN自動コーディングモデルが実際のサンプルを再構築するようにトレーニングされている場合、ディスクリミネーターはデータマニホールドの検索にも役立つことを意味します。逆に、ジェネレーターからの否定的な例がない場合、バイナリ分類損失でトレーニングされた弁別器は無意味になります。

上記の段落は記事の翻訳から抜粋したもので、コーデック構造の導入として要約できるため、生成されたデータの多様性を高めることができます。同時に、コーデックが使用されるため、エンコードの前後の損失に、元の単純な真と偽の判断を適合させる必要があります。本文中で使用されているMSE。

コードと練習結果

参照リンク:https//github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/ebgan/ebgan.py

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=62, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="number of image channels")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Upsampling
        self.down = nn.Sequential(nn.Conv2d(opt.channels, 64, 3, 2, 1), nn.ReLU())
        # Fully-connected layers
        self.down_size = opt.img_size // 2
        down_dim = 64 * (opt.img_size // 2) ** 2

        self.embedding = nn.Linear(down_dim, 32)

        self.fc = nn.Sequential(
            nn.BatchNorm1d(32, 0.8),
            nn.ReLU(inplace=True),
            nn.Linear(32, down_dim),
            nn.BatchNorm1d(down_dim),
            nn.ReLU(inplace=True),
        )
        # Upsampling
        self.up = nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(64, opt.channels, 3, 1, 1))

    def forward(self, img):
        out = self.down(img)
        embedding = self.embedding(out.view(out.size(0), -1))
        out = self.fc(embedding)
        out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
        return out, embedding


# Reconstruction loss of AE
pixelwise_loss = nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    pixelwise_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def pullaway_loss(embeddings):
    norm = torch.sqrt(torch.sum(embeddings ** 2, -1, keepdim=True))
    normalized_emb = embeddings / norm
    similarity = torch.matmul(normalized_emb, normalized_emb.transpose(1, 0))
    batch_size = embeddings.size(0)
    loss_pt = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
    return loss_pt


# ----------
#  Training
# ----------

# BEGAN hyper parameters
lambda_pt = 0.1
margin = max(1, opt.batch_size / 64.0)

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)
        recon_imgs, img_embeddings = discriminator(gen_imgs)

        # Loss measures generator's ability to fool the discriminator
        g_loss = pixelwise_loss(recon_imgs, gen_imgs.detach()) + lambda_pt * pullaway_loss(img_embeddings)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_recon, _ = discriminator(real_imgs)
        fake_recon, _ = discriminator(gen_imgs.detach())

        d_loss_real = pixelwise_loss(real_recon, real_imgs)
        d_loss_fake = pixelwise_loss(fake_recon, gen_imgs.detach())

        d_loss = d_loss_real
        if (margin - d_loss_fake.data).item() > 0:
            d_loss += margin - d_loss_fake

        d_loss.backward()
        optimizer_D.step()

        # --------------
        # Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

mnistテスト結果

この方法が最適化されて安定していることを証明するだけなので、トレーニングの結果は理想的ではないことがわかりますが、おそらく収束速度は速くありません。

 

 

 

おすすめ

転載: blog.csdn.net/fan1102958151/article/details/106562544