Generando imágenes usando GAN en TensorFlow

1. Descripción

        Este artículo analiza en detalle cómo implementar GAN en la recopilación de datos mnist bajo tensorflow. Incluyendo: implementación de código de pasos específicos como establecimiento del marco, lectura de conjuntos de datos, generador, discriminador, función de costos, optimización, etc.

2. Introducción al marco GAN

  • Generador : Este componente se encarga de generar nuevas imágenes.
  • Discriminador : este componente evalúa la calidad de la imagen generada.

        La arquitectura general que desarrollaremos para generar imágenes usando GAN se muestra en la siguiente figura. Las siguientes secciones describen brevemente cómo leer la base de datos, crear la arquitectura requerida, calcular la función de pérdida y entrenar la red. Además, se proporciona código para inspeccionar la red y generar nuevas imágenes.

3. Leer el conjunto de datos

        El conjunto de datos MNIST ocupa una posición importante en el campo de la visión por computadora e incluye una gran cantidad de dígitos escritos a mano con dimensiones de 28 × 28 píxeles. Este conjunto de datos resultó ideal para nuestra implementación de GAN debido a su formato de imagen de un solo canal en escala de grises.

        El fragmento de código que sigue demuestra cómo cargar el conjunto de datos MNIST utilizando funciones integradas en Tensorflow. Después de una carga exitosa, procedemos a normalizar y remodelar la imagen en formato 3D. Esta transformación permite el procesamiento eficiente de datos de imágenes 2D en la arquitectura GAN. Además, se asigna memoria para datos de entrenamiento y validación.

        La forma de cada imagen se define como una matriz de 28x28x1, donde la última dimensión representa el número de canales de la imagen. Dado que el conjunto de datos MNIST contiene imágenes en escala de grises, solo tenemos un canal.

        En este caso particular, establecemos el tamaño del espacio latente (denominado "zsize") en 100. Este valor se puede ajustar en función de requisitos o preferencias específicas.

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
import matplotlib.pyplot as plt
import sys
import numpy as np

num_rows = 28
num_cols = 28
num_channels = 1
input_shape = (num_rows, num_cols, num_channels)
z_size = 100

(train_ims, _), (_, _) = mnist.load_data()
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)

valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

4. Definir el generador

        El generador (D) juega un papel crucial en GAN ya que es responsable de generar imágenes realistas que pueden engañar al discriminador. Es el componente principal de la formación de imágenes en GAN. En este estudio, explotamos una arquitectura específica del generador, que contiene una capa completamente conectada (FC) y adopta la activación Leaky ReLU. Sin embargo, vale la pena señalar que la última capa del generador utiliza la activación TanH en lugar de LeakyReLU. Este ajuste se realiza para garantizar que las imágenes generadas residan dentro del mismo intervalo (-1, 1) que la base de datos MNIST original.

def build_generator():
    gen_model = Sequential()
    gen_model.add(Dense(256, input_dim=z_size))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(512))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(1024))
    gen_model.add(LeakyReLU(alpha=0.2))
    gen_model.add(BatchNormalization(momentum=0.8))
    gen_model.add(Dense(np.prod(input_shape), activation='tanh'))
    gen_model.add(Reshape(input_shape))

    gen_noise = Input(shape=(z_size,))
    gen_img = gen_model(gen_noise)
    return Model(gen_noise, gen_img)

5. Definir el discriminador

        En una Red Adversaria Generativa (GAN), el discriminador (D) realiza la tarea clave de distinguir entre imágenes reales y generadas mediante la evaluación de la autenticidad y la probabilidad. Este componente puede verse como un problema de clasificación binaria. Para realizar esta tarea, podemos adoptar una arquitectura de red simplificada que incluya capas completamente conectadas (FC), activación ReLU con fugas y capas de desconexión. Cabe mencionar que la última capa del discriminador consta de una capa FC seguida de una activación sigmoidea. La función de activación sigmoidea produce las probabilidades de clasificación deseadas.

def build_discriminator():
    disc_model = Sequential()
    disc_model.add(Flatten(input_shape=input_shape))
    disc_model.add(Dense(512))
    disc_model.add(LeakyReLU(alpha=0.2))
    disc_model.add(Dense(256))
    disc_model.add(LeakyReLU(alpha=0.2))
    disc_model.add(Dense(1, activation='sigmoid'))

    disc_img = Input(shape=input_shape)
    validity = disc_model(disc_img)
    return Model(disc_img, validity)

6. Calcule la función de pérdida.

        Para garantizar un buen proceso de generación de imágenes en GAN, es importante determinar métricas apropiadas para evaluar su desempeño. Defina este parámetro mediante una función de pérdida.

        El discriminador es responsable de clasificar las imágenes generadas en reales y falsas y de dar la probabilidad de realidad. Para lograr esta diferencia, el discriminador pretende maximizar la función D(x) cuando se presentan imágenes reales y minimizar D(G(z)) cuando se presentan imágenes falsas.

        El generador, por otra parte, pretende engañar al discriminador creando imágenes realistas que pueden malinterpretarse. Matemáticamente, esto implica escalar D(G(z)). Sin embargo, confiar únicamente en este componente como función de pérdida puede hacer que la red confíe demasiado en resultados erróneos. Para resolver este problema, utilizamos el logaritmo de la función de pérdida (D(G(z))).

        La función de costo general de las imágenes generadas por GAN se puede expresar como un juego mínimo:

min_G max_D V (D, G) = E (xp_data (x)) (log (D (x))] + E (zp (z)) (log (1 – D (G (z)))])

        Este tipo de entrenamiento GAN requiere un buen equilibrio y puede servir como competición entre dos oponentes. Cada lado intenta influir y superar al otro jugando un juego de MinMax.

        Podemos implementar el generador y el discriminador utilizando pérdida binaria de entropía cruzada.

        Para la implementación del generador y discriminador, podemos utilizar la pérdida de entropía cruzada binaria.

# discriminator
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

z = Input(shape=(z_size,))

# generator
img = generator(z)

disc.trainable = False

validity = disc(img)

# combined model
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')

7. Optimizar las pérdidas

        Para facilitar el entrenamiento de la red, nuestro objetivo es permitir que las GAN participen en el juego MinMax. Este proceso de aprendizaje gira en torno a la optimización de los pesos de la red mediante el uso de descenso de gradiente. Para acelerar el proceso de aprendizaje y evitar la convergencia a un entorno de pérdidas subóptimo, se emplea el descenso de gradiente estocástico (SGD).

        Dado que el discriminador y el generador tienen pérdidas diferentes, una sola función de pérdida no puede optimizar ambos sistemas simultáneamente. Por lo tanto, se utiliza una función de pérdida separada para cada sistema.

def intialize_model():
    disc= build_discriminator()
    disc.compile(loss='binary_crossentropy',
        optimizer='sgd',
        metrics=['accuracy'])

    generator = build_generator()

    z = Input(shape=(z_size,))
    img = generator(z)

    disc.trainable = False

    validity = disc(img)

    combined = Model(z, validity)
    combined.compile(loss='binary_crossentropy', optimizer='sgd')
    return disc, Generator, and combined

        Después de especificar todas las características requeridas, podemos entrenar el sistema y optimizar la pérdida. Los pasos para entrenar una GAN para generar imágenes son los siguientes:

  • Carga una imagen y genera un sonido aleatorio del mismo tamaño que la imagen cargada.
  • Distinga entre imágenes cargadas y sonidos generados y considere la posibilidad de autenticidad.
  • Se genera otro ruido aleatorio de la misma amplitud y se utiliza como entrada al generador.
  • Entrene el generador durante un período de tiempo específico.
  • Repita estos pasos hasta que la imagen sea satisfactoria.
def train(epochs, batch_size=128, sample_interval=50):
    # load images
    (train_ims, _), (_, _) = mnist.load_data()
    # preprocess
    train_ims = train_ims / 127.5 - 1.
    train_ims = np.expand_dims(train_ims, axis=3)

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    # training loop
    for epoch in range(epochs):

        batch_index = np.random.randint(0, train_ims.shape[0], batch_size)
        imgs = train_ims[batch_index]
    # create noise
        noise = np.random.normal(0, 1, (batch_size, z_size))
    # predict using a Generator
        gen_imgs = gen.predict(noise)
    # calculate loss functions
        real_disc_loss = disc.train_on_batch(imgs, valid)
        fake_disc_loss = disc.train_on_batch(gen_imgs, fake)
        disc_loss_total = 0.5 * np.add(real_disc_loss, fake_disc_loss)

        noise = np.random.normal(0, 1, (batch_size, z_size))

        g_loss = full_model.train_on_batch(noise, valid)
   
    # save outputs every few epochs
        if epoch % sample_interval == 0:
            one_batch(epoch)

8. Genera números escritos a mano.

        Usando el conjunto de datos MNIST, podemos crear una función de utilidad para generar predicciones para un conjunto de imágenes usando un generador. La función genera un sonido aleatorio, lo envía al generador, lo ejecuta para mostrar la imagen generada y la guarda en una carpeta especial. Se recomienda ejecutar esta función de utilidad periódicamente, por ejemplo cada 200 ciclos, para monitorear el progreso de la red. La implementación es la siguiente:

def one_batch(epoch):
    r, c = 5, 5
    noise_model = np.random.normal(0, 1, (r * c, z_size))
    gen_images = gen.predict(noise_model)

    # Rescale images 0 - 1
    gen_images = gen_images*(0.5) + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_images[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%d.png" % epoch)
    plt.close()

        En nuestros experimentos, entrenamos aproximadamente 32 GAN utilizando un tamaño de lote de 10 000. Para realizar un seguimiento del progreso del entrenamiento, guardamos las imágenes generadas cada 200 épocas y las almacenamos en una carpeta designada llamada "imágenes".

disc, gen, full_model = intialize_model()
train(epochs=10000, batch_size=32, sample_interval=200)

        Ahora, verifiquemos los resultados de la simulación GAN en diferentes etapas: inicialización, 400 épocas, 5000 épocas y los resultados finales en 10000 épocas.

Inicialmente, comenzamos con ruido aleatorio como entrada al generador.

        Después de 400 épocas de entrenamiento, podemos observar algunos avances, aunque las imágenes generadas siguen siendo muy diferentes de las cifras reales.

        Después de entrenar durante 5000 épocas, podemos observar que los números generados comienzan a parecerse al conjunto de datos MNIST.

        Después de completar el entrenamiento completo de 10,000 épocas, obtenemos el siguiente resultado.

        Estas imágenes generadas son muy similares a los datos de dígitos escritos a mano que se utilizan para entrenar la red. Es importante tener en cuenta que estas imágenes no forman parte del conjunto de entrenamiento y son generadas en su totalidad por la red.

9. Próximos pasos

        Ahora que hemos logrado buenos resultados con GAN para la generación de imágenes, hay muchas formas de mejorarlo aún más. Dentro del alcance de esta discusión, podemos considerar probar diferentes parámetros. Aquí hay algunas sugerencias:

  • Explore diferentes valores de la variable de espacio latente z_size para ver si mejora la eficiencia.
  • Aumentar el número de épocas de entrenamiento más allá de 10.000. Duplicar o triplicar la duración del entrenamiento puede mostrar resultados mejorados o degradados.
  • Intente utilizar un conjunto de datos diferente, como Fashion MNIST o Mobile MNIST. Dado que estos conjuntos de datos tienen la misma estructura que MNIST, adapte nuestro código existente.
  • Considere probar arquitecturas alternativas como CycleGun, DCGAN, etc. Modificar las funciones generadora y discriminadora puede ser suficiente para explorar estos modelos.

        Al implementar estos cambios, podemos mejorar aún más las capacidades de las GAN y explorar nuevas posibilidades para la generación de imágenes.

        Estas imágenes generadas son muy similares a los datos de dígitos escritos a mano que se utilizan para entrenar la red. Estas imágenes no forman parte del conjunto de entrenamiento y son generadas en su totalidad por la red.

10. Conclusión

        En resumen, GAN es un poderoso modelo de aprendizaje automático capaz de generar nuevas imágenes basadas en bases de datos existentes. En este tutorial, mostramos cómo diseñar y entrenar una GAN simple usando la biblioteca Tensorflow como ejemplo y la base de datos MNIST.

        Conclusiones clave

  • GAN consta de dos componentes importantes: un generador, responsable de generar nuevas imágenes a partir de entradas aleatorias, y un discriminador, diseñado para distinguir imágenes reales y falsas.
  • A través del proceso de aprendizaje, logramos crear un conjunto de imágenes que se parecen mucho a dígitos escritos a mano, como se muestra en la imagen de ejemplo.
  • Para optimizar el rendimiento de GAN, proporcionamos indicadores coincidentes y funciones de pérdida para ayudar a distinguir imágenes reales y falsas. Al evaluar las GAN con datos invisibles y utilizar el generador, podemos generar imágenes nuevas nunca antes vistas.
  • En general, las GAN ofrecen posibilidades interesantes en la generación de imágenes y tienen un gran potencial en una variedad de aplicaciones como el aprendizaje automático y la visión por computadora.

11. Preguntas frecuentes

        Pregunta 1. ¿Qué es una Red Generativa Adversaria (GAN)?

        Respuesta: Una Red Generativa Adversaria (GAN) es un marco de aprendizaje automático que genera nuevos datos con estadísticas similares a un conjunto de entrenamiento determinado. Utilice GAN con muchos tipos de datos, incluidas imágenes, vídeos o texto.

        Pregunta 2. ¿Qué es un modelo creativo?

        uno. Un modelo generativo es un algoritmo de aprendizaje automático que genera nuevos datos basándose en un conjunto de datos de entrada. Utilice estos modelos para tareas como generación de imágenes, generación de texto y otras formas de síntesis de datos.

        Pregunta 3. ¿Qué es la función de pérdida?

        uno. Una función de pérdida es una función matemática que mide la diferencia entre dos conjuntos de datos. En el contexto de las GAN, un generador de modelos se entrena optimizando una función de pérdida que define la diferencia entre los datos generados y los de entrenamiento, generalmente utilizando registros de clase e imágenes anotadas.

        Pregunta 4. ¿Cuál es la diferencia entre CNN y Gan?

        Respuesta: CNN (red neuronal convolucional) y GAN (red generativa adversaria) son arquitecturas de aprendizaje profundo, pero tienen objetivos diferentes. Las GAN son modelos generativos diseñados para generar nuevos datos similares a un conjunto de entrenamiento determinado, mientras que las CNN se utilizan para tareas de clasificación y reconocimiento. Si bien las CNN se pueden utilizar como modelos generativos configurándolas como codificadores automáticos variables (VAE), las CNN funcionan bien en el entrenamiento discriminativo y son más efectivas en tareas de clasificación de imágenes en visión por computadora.

Supongo que te gusta

Origin blog.csdn.net/gongdiwudu/article/details/131563850
Recomendado
Clasificación