Redes adversarias generativas basadas en energía (Xiaobai aprende GAN 10)

Enlace original: https://arxiv.org/pdf/1609.03126.pdf

Introducción

Antecedentes: este artículo vuelve al problema que GAN no puede solucionar, es decir, cómo optimizar de manera estable. Este artículo es diferente de los anteriores "WGAN" y "GAN de la serie W". No logra este objetivo diseñando una función objetivo, sino que encuentra otra manera. Esto se logra cambiando la estructura del discriminador.

Idea central: cambie el discriminador a una función de energía , de modo que la energía de los datos sea baja cuando esté cerca de la distribución popular y la energía sea alta en otros lugares.

Como puede verse en la figura anterior, la estructura del discriminador ha cambiado. No es como la estructura de red neuronal única anterior, sino que está compuesta por un par de codificadores y decodificadores, y la salida de energía se construye en base a la estructura completa del códec. La ventaja más directa de este diseño es que no necesitamos devanar nuestros cerebros para construir nuestra propia función LOSS específica, sino que podemos usar directamente la LOSS madura, simplemente coloque este marco y deje que su energía de salida sea el objetivo de optimización.

estructura basica

Concepto basico

Equilibrio de Nash

También conocido como equilibrio de juego no cooperativo, en un proceso de juego, independientemente de la elección de estrategia del oponente, una de las partes elegirá una determinada estrategia, que se denomina estrategia dominante. Si las combinaciones de estrategias de las dos partes del juego constituyen sus respectivas estrategias dominantes, entonces esta combinación se define como un equilibrio de Nash. Una combinación de estrategias se denomina equilibrio de Nash. Cuando la estrategia de equilibrio de cada jugador es lograr el valor máximo de su rendimiento esperado, al mismo tiempo, todos los demás jugadores también siguen esta estrategia.

Luego, en GAN, cuando el proceso de juego cooperativo entre el discriminador y el generador se acerque al equilibrio de Nash, caeremos en el cuello de botella de la optimización.

Función objetiva

                                                     

Entre ellos [\ cdot] ^ + = max (0, \ cdot), m aquí está el margen para estar satisfecho m \ leq D (G (z)). Entonces, las dos funciones de PÉRDIDA anteriores añaden un elemento de restricción en el proceso de optimización del discriminador Este elemento de restricción aumenta la PÉRDIDA del discriminador cuando los datos generados están demasiado cerca de la distribución x, y acelera el proceso de optimización del discriminador; Sin embargo, cuando los datos generados están demasiado lejos de la distribución x, el elemento de restricción restringe la PÉRDIDA generada cuando los datos generados pasan a través del discriminador, es decir, el proceso de optimización del discriminador se detiene primero y la optimización del generador arrastrará la distribución de los datos generados a X.

                                                    V (G, D) = \ int_ {x, z} \ mathfrak {L} _D (x, z) p_ {datos} (x) p_z (z) dxdz

                                                    U (G, D) = \ int _z \ mathfrak {L} _G (z) p_z (z) dz

Minimice V cuando entrene el discriminador y minimice U cuando entrene el generador. G y D forman un par de equilibrio de Nash, luego satisfaga:

                                                  V (G ^ *, D ^ *) \ leq V (G ^ *, D), \ forall D

                                                  U (G ^ *, D ^ *) \ leq V (G, D ^ *), \ forall G

D representa el discriminador óptimo y G representa el generador óptimo, por lo que podemos determinar los límites superiores de las dos optimizaciones.

Debido a que la relación entre la estructura, D ^ * (x) \ leq mestá V (G ^ *, D)en el valor mínimo, hay

Dado que los dos factores del segundo término deben ser uno positivo y otro negativo, el valor integrado debe estar entre [-1,0], por lo que el valor máximo es m, es decir V (G ^ *, D ^ *) \ leq m. Así que de nuevo porque

Entonces m \ leq V (G ^ *, D ^ *), al final m \ leq V (G ^ *, D ^ *) \ leq m, es decir m = V (G ^ *, D ^ *), cuando ocurra esta situación

                                  

Este elemento es cero, es decir P_ {datos} = P_G, se alcanza nuestro objetivo de optimización.

Estructura del códec

Un problema común con el entrenamiento de codificadores automáticos es que lo que el modelo puede aprender no es una función de identidad, lo que significa que puede asignar todo el espacio a energía cero. Para evitar este problema, se debe forzar al modelo a proporcionar mayor energía a puntos fuera de la variedad de datos. Este tipo de normalizador está diseñado para limitar la capacidad de reconstrucción del codificador automático de modo que solo pueda clasificar la energía baja en una parte más pequeña del punto de entrada.

La función de energía (discriminador) en el marco de EBGAN también se considera normalizada por un generador que genera muestras comparativas El discriminador debe dar a las muestras de contraste una alta energía de reconstrucción. Desde este punto de vista, el marco EBGAN permite más flexibilidad, porque: (i) el normalizador (generador) se puede entrenar en lugar de especificar manualmente; (2) el modo de entrenamiento adversario produce muestras contrastantes y energía de aprendizaje Los dos objetivos de la función pueden interactuar directamente.

La elección del codificador automático de D parece arbitraria a primera vista, pero la configuración del autor la hace más atractiva que las redes de clasificación binaria:
(1) La salida basada en la reconstrucción no es utilizar una sola información de destino para entrenar el modelo, sino Proporcionar objetivos diversificados para el discriminador. Debido a la red de clasificación binaria, solo son posibles dos objetivos, por lo que en un lote pequeño, es más probable que los gradientes correspondientes a diferentes muestras estén lejos de ser ortogonales, lo que conduce a un entrenamiento ineficiente, y el hardware actual generalmente no proporciona Reducir el tamaño de lotes pequeños. Por otro lado, la pérdida de reconstrucción puede producir direcciones de gradiente muy diferentes dentro de un lote, lo que permite tamaños de lote más grandes sin pérdida de eficiencia.
(2) Tradicionalmente, los codificadores automáticos se utilizan para representar modelos basados ​​en energía. Cuando se entrena con regularización, el autocodificador puede aprender múltiples energías sin supervisión ni contraejemplos. Esto significa que cuando se entrena el modelo de codificación automática EBGAN para reconstruir muestras reales, el discriminador también ayuda a encontrar la variedad de datos. Por el contrario, si no hay ejemplos negativos del generador, el discriminador entrenado con la pérdida de clasificación binaria pierde sentido.

El párrafo anterior está tomado de la traducción del artículo. Se puede resumir como la introducción de la estructura del códec, para que los datos generados puedan tener más diversidad. Al mismo tiempo, debido a que se usa el códec, el juicio simple original de verdadero y falso debe adaptarse a la pérdida antes y después de la codificación. MSE utilizado en el texto.

Resultados de código y práctica

Enlace de referencia: https://github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/ebgan/ebgan.py

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=62, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="number of image channels")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Upsampling
        self.down = nn.Sequential(nn.Conv2d(opt.channels, 64, 3, 2, 1), nn.ReLU())
        # Fully-connected layers
        self.down_size = opt.img_size // 2
        down_dim = 64 * (opt.img_size // 2) ** 2

        self.embedding = nn.Linear(down_dim, 32)

        self.fc = nn.Sequential(
            nn.BatchNorm1d(32, 0.8),
            nn.ReLU(inplace=True),
            nn.Linear(32, down_dim),
            nn.BatchNorm1d(down_dim),
            nn.ReLU(inplace=True),
        )
        # Upsampling
        self.up = nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(64, opt.channels, 3, 1, 1))

    def forward(self, img):
        out = self.down(img)
        embedding = self.embedding(out.view(out.size(0), -1))
        out = self.fc(embedding)
        out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
        return out, embedding


# Reconstruction loss of AE
pixelwise_loss = nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    pixelwise_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def pullaway_loss(embeddings):
    norm = torch.sqrt(torch.sum(embeddings ** 2, -1, keepdim=True))
    normalized_emb = embeddings / norm
    similarity = torch.matmul(normalized_emb, normalized_emb.transpose(1, 0))
    batch_size = embeddings.size(0)
    loss_pt = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
    return loss_pt


# ----------
#  Training
# ----------

# BEGAN hyper parameters
lambda_pt = 0.1
margin = max(1, opt.batch_size / 64.0)

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)
        recon_imgs, img_embeddings = discriminator(gen_imgs)

        # Loss measures generator's ability to fool the discriminator
        g_loss = pixelwise_loss(recon_imgs, gen_imgs.detach()) + lambda_pt * pullaway_loss(img_embeddings)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_recon, _ = discriminator(real_imgs)
        fake_recon, _ = discriminator(gen_imgs.detach())

        d_loss_real = pixelwise_loss(real_recon, real_imgs)
        d_loss_fake = pixelwise_loss(fake_recon, gen_imgs.detach())

        d_loss = d_loss_real
        if (margin - d_loss_fake.data).item() > 0:
            d_loss += margin - d_loss_fake

        d_loss.backward()
        optimizer_D.step()

        # --------------
        # Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

resultados de la prueba mnist

Se puede ver que el resultado del entrenamiento no es el ideal, porque solo demuestra que este método es optimizado y estable, pero quizás la velocidad de convergencia no sea rápida.

 

 

 

Supongo que te gusta

Origin blog.csdn.net/fan1102958151/article/details/106562544
Recomendado
Clasificación