Deep Learning 9: A Simple Explanation of How Generative Adversarial Networks Work

Table of contents

Generative algorithms

Generative Adversarial Networks (GANs)

"Generative" part

"Adversarial" part

How do GANs work?

Tips for training GANs

GAN code example

How can GANs be improved?

Conclusion


Generative algorithms

Generative algorithms can be grouped into three buckets:

  1. Given a label, they predict the associated features (Naive Bayes).
  2. Given a hidden representation, they predict the associated features (variational autoencoders, generative adversarial networks).
  3. Given some of the features, they predict the rest (inpainting, imputation).

We'll explore some basics of Generative Adversarial Networks! GANs have enormous potential because they can learn to mimic, in principle, any data distribution. That is, GANs can be taught to create worlds strikingly similar to our own in any domain: images, music, speech.

[Figure: Example GAN architecture]

Generative Adversarial Networks (GANs)

"Generative" part

  • Called the generator.
  • Given a label, it tries to predict the features.
  • Example: given that an email is marked as spam, predict (generate) the text of the email.
  • Generative models learn the distribution of individual classes.
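To make the generative side concrete, here is a minimal pure-Python sketch (not from the original article) of a Naive-Bayes-style generative model with a single feature: it learns one Gaussian distribution per class and then, given a label, samples new feature values from that class's distribution. The feature (imagined as the number of links in an email), the toy numbers and the function names `fit_class_gaussians` and `generate` are all hypothetical.

```python
import random
import statistics

def fit_class_gaussians(features, labels):
    """Learn one Gaussian (mean, std) over the feature for each class label."""
    params = {}
    for label in set(labels):
        values = [x for x, y in zip(features, labels) if y == label]
        params[label] = (statistics.mean(values), statistics.pstdev(values))
    return params

def generate(params, label, n, rng):
    """Given a label, sample n new feature values from that class's Gaussian."""
    mean, std = params[label]
    return [rng.gauss(mean, std) for _ in range(n)]

# Hypothetical toy data: one feature per email (think "number of links").
features = [0.0, 1.0, 2.0, 1.0, 8.0, 10.0, 12.0, 10.0]
labels = ["ham", "ham", "ham", "ham", "spam", "spam", "spam", "spam"]

params = fit_class_gaussians(features, labels)
new_spam = generate(params, "spam", 5, random.Random(0))  # 5 made-up "spam" feature values
```

The model stores what each class looks like (a mean and a spread), which is exactly what lets it produce new examples for a given label.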

"Adversarial" part

  • Called the discriminator.
  • Given the features, it tries to predict the label.
  • Example: based on the text of an email, predict (discriminate) whether it is spam or not spam.
  • Discriminative models learn the boundary between classes.
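The discriminative counterpart can be sketched the same way, again as a hypothetical pure-Python example on toy data: logistic regression learns P(spam | x) directly by gradient descent on binary cross-entropy, and the class boundary is the point where that probability crosses 0.5. The data, learning rate, step count and function names are illustrative assumptions.

```python
import math

def sigmoid(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def train_logistic(features, targets, lr=0.05, steps=10000):
    """Fit P(y=1 | x) = sigmoid(w*x + c) by full-batch gradient descent on binary cross-entropy."""
    w, c = 0.0, 0.0
    n = len(features)
    for _ in range(steps):
        grad_w = grad_c = 0.0
        for x, y in zip(features, targets):
            err = sigmoid(w * x + c) - y   # derivative of the loss w.r.t. the logit
            grad_w += err * x / n
            grad_c += err / n
        w -= lr * grad_w
        c -= lr * grad_c
    return w, c

# Hypothetical toy data: one feature per email; 1 = spam, 0 = not spam.
features = [0.0, 1.0, 2.0, 1.0, 8.0, 10.0, 12.0, 10.0]
targets = [0, 0, 0, 0, 1, 1, 1, 1]

w, c = train_logistic(features, targets)
boundary = -c / w                        # the x at which P(spam | x) = 0.5
```

This model never learns what a typical spam email looks like; it only learns where the boundary lies, here somewhere between the two clusters.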

How do GANs work?

A neural network called the Generator generates new data instances, while another neural network, the Discriminator, evaluates their authenticity.

You can think of a GAN as a cat-and-mouse game between a counterfeiter (the Generator) and the police (the Discriminator). The counterfeiter learns to make fake money, while the police learn to detect it. Both keep learning and improving: the counterfeiter keeps creating better fakes, and the police keep getting better at spotting them. The end result is that the counterfeiter (the generator) is trained to produce ultra-realistic money!

Let's explore a concrete example with the MNIST dataset of handwritten digits:

[Figure: MNIST handwritten digits dataset]

We will have the Generator create new images like those in the MNIST dataset, which is taken from the real world. When shown an instance from the real MNIST dataset, the Discriminator's goal is to recognize it as authentic.

Meanwhile, the Generator creates new images and passes them to the Discriminator, hoping that they too will be judged real even though they are fake. The Generator's goal is to produce passable handwritten digits, that is, to lie without being caught. The Discriminator's goal is to classify the images coming from the Generator as fake.

[Figure: MNIST handwritten digits + GAN architecture]

GAN steps:

  1. The generator takes in random numbers (a noise vector) and returns an image.
  2. The generated images are fed into the discriminator alongside a stream of images taken from the real dataset.
  3. The discriminator takes in both real and fake images and returns a probability, a number between 0 and 1, where 1 is a prediction that the image is authentic and 0 a prediction that it is fake.
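The three steps above can be traced with a deliberately tiny, hypothetical setup in pure Python: each "image" is a single number, the generator is the linear map x = a*z + b, and the discriminator is a logistic function sigmoid(w*x + c) that returns a probability between 0 and 1. The parameter values and the "real" data distribution (a Gaussian centred at 4) are assumptions chosen only for illustration; nothing is trained yet.

```python
import math
import random

def sigmoid(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def generator(z, a, b):
    """Step 1: map a random number z to a one-number 'image' x = a*z + b."""
    return a * z + b

def discriminator(x, w, c):
    """Step 3: return the probability, between 0 and 1, that x is real."""
    return sigmoid(w * x + c)

rng = random.Random(42)
a, b = 1.0, 0.0          # hypothetical generator parameters
w, c = 0.5, -2.0         # hypothetical discriminator parameters

# Step 1: random noise in, generated samples out.
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]
fake = [generator(z, a, b) for z in noise]

# Step 2: mix the generated samples with samples from the "real" dataset
# (assumed here to be drawn from a Gaussian centred at 4).
real = [rng.gauss(4.0, 0.5) for _ in range(4)]
batch = real + fake
is_real = [1] * len(real) + [0] * len(fake)   # ground truth, known to us

# Step 3: the discriminator scores every sample.
scores = [discriminator(x, w, c) for x in batch]
```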

Two feedback loops:

  1. The discriminator is in a feedback loop with the ground truth of the images (whether each one is real or fake), which we know.
  2. The generator is in a feedback loop with the discriminator: its only feedback is whether the discriminator labels its output as real or fake, regardless of the truth.
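The two feedback loops correspond to two loss functions. The sketch below (hypothetical names, pure Python) uses binary cross-entropy, the same loss the Keras example later in this post compiles with: the discriminator is scored against the true labels (1 for real, 0 for generated), while the generator is scored only on how the discriminator judged its samples, with 1 ("real") as its target.

```python
import math

def bce(p, target):
    """Binary cross-entropy for one predicted probability p and a 0/1 target."""
    eps = 1e-12                                   # keep log() away from 0
    p = min(max(p, eps), 1.0 - eps)
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))

def discriminator_loss(d_real, d_fake):
    """Loop 1: graded against the ground truth (real -> 1, generated -> 0)."""
    losses = [bce(p, 1) for p in d_real] + [bce(p, 0) for p in d_fake]
    return sum(losses) / len(losses)

def generator_loss(d_fake):
    """Loop 2: graded only by the discriminator's verdict on generated samples;
    the generator wants them scored as real (target 1)."""
    return sum(bce(p, 1) for p in d_fake) / len(d_fake)

# Hypothetical discriminator outputs: it is fairly sure the real images are real
# and the fakes are fake, so its loss is low and the generator's is high.
d_real = [0.9, 0.8]
d_fake = [0.2, 0.1]
d_loss = discriminator_loss(d_real, d_fake)
g_loss = generator_loss(d_fake)
```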

Tips for training GANs

Pre-training the discriminator before you start training the generator will establish a clearer gradient.

When you train the discriminator, hold the generator's values (its weights) constant; when you train the generator, hold the discriminator's values constant. This gives each network a clearer read on the gradient it has to learn from.
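Here is a minimal sketch of that alternation, written in pure Python with hand-derived gradients so that the freezing is explicit: `d_step` updates only the discriminator parameters (w, c) while reading the generator's (a, b) as constants, and `g_step` does the reverse. The 1-D setup (real data from a Gaussian with mean 4 and standard deviation 0.5, a linear generator a*z + b, a logistic discriminator sigmoid(w*x + c)), the learning rate and every function name are assumptions for illustration. In the Keras example later in this post, the same effect comes from setting `self.discriminator.trainable = False` before compiling the combined model.

```python
import math
import random

def softplus(u):
    """Numerically stable log(1 + exp(u))."""
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))

def sigmoid(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def d_loss(gen, disc, real, noise):
    """Discriminator BCE: real samples labeled 1, generated samples labeled 0."""
    (a, b), (w, c) = gen, disc
    fake = [a * z + b for z in noise]
    total = sum(softplus(-(w * x + c)) for x in real)   # -log D(x)
    total += sum(softplus(w * x + c) for x in fake)      # -log(1 - D(G(z)))
    return total / (len(real) + len(fake))

def g_loss(gen, disc, noise):
    """Generator loss: it wants D to label its samples 1 (real)."""
    (a, b), (w, c) = gen, disc
    return sum(softplus(-(w * (a * z + b) + c)) for z in noise) / len(noise)

def d_step(gen, disc, real, noise, lr):
    """One gradient step on the discriminator; the generator (a, b) is frozen."""
    (a, b), (w, c) = gen, disc
    samples = [(x, 1.0) for x in real] + [(a * z + b, 0.0) for z in noise]
    n = len(samples)
    grad_w = sum((sigmoid(w * x + c) - y) * x for x, y in samples) / n
    grad_c = sum(sigmoid(w * x + c) - y for x, y in samples) / n
    return (w - lr * grad_w, c - lr * grad_c)

def g_step(gen, disc, noise, lr):
    """One gradient step on the generator; the discriminator (w, c) is frozen."""
    (a, b), (w, c) = gen, disc
    # d/du of -log(sigmoid(u)) is sigmoid(u) - 1, with u = w * (a * z + b) + c
    errs = [sigmoid(w * (a * z + b) + c) - 1.0 for z in noise]
    grad_a = sum(err * w * z for err, z in zip(errs, noise)) / len(noise)
    grad_b = sum(err * w for err in errs) / len(noise)
    return (a - lr * grad_a, b - lr * grad_b)

def train(steps=200, batch=16, lr=0.05, seed=0):
    """Alternate one discriminator step and one generator step per round."""
    rng = random.Random(seed)
    gen, disc = (1.0, 0.0), (0.0, 0.0)     # hypothetical starting parameters
    for _ in range(steps):
        real = [rng.gauss(4.0, 0.5) for _ in range(batch)]   # toy "real" data
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        disc = d_step(gen, disc, real, noise, lr)            # generator frozen
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        gen = g_step(gen, disc, noise, lr)                   # discriminator frozen
    return gen, disc
```

Pre-training the discriminator, as suggested in the previous tip, would amount to calling `d_step` a number of times before entering the loop. This toy discriminator is too simple to compare full distributions, so do not expect a faithful copy of the real data from it; the point is the alternation and the freezing.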

Because GANs are formulated as a game between two networks, it is important to keep the two in balance: if either the generator or the discriminator becomes much better than the other, the GAN can be very difficult to train.
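For reference, the game mentioned here is written in the original GAN paper (Goodfellow et al., 2014) as a minimax objective: the discriminator D tries to maximize V, and the generator G tries to minimize it, with real data x drawn from p_data and noise z drawn from a prior p_z.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice, as in the Keras example in this post, the generator is usually trained to maximize log D(G(z)) instead (by labeling its samples as real), because log(1 - D(G(z))) gives weak gradients early in training, when the discriminator easily rejects the fakes.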

GANs take a long time to train. GANs can take hours on a single GPU and days on a single CPU.

GAN code example

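from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import numpy as np
import os


class GAN():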
    def __init__(self):
        self.img_rows = 28 
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', 
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the generator
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        # The generator takes noise as input and generates imgs
        z = Input(shape=(100,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity 
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        noise_shape = (100,)
        
        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)
        
        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

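        half_batch = batch_size // 2

        # Note: each "epoch" here is a single training step on one batch,
        # not a full pass over the 60,000 MNIST training images
        for epoch in range(epochs):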

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.ones((batch_size, 1))

            # Train the generator
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        os.makedirs("gan/images", exist_ok=True)  # savefig fails if the folder does not exist
        fig.savefig("gan/images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, save_interval=200)

How can GANs be improved?

GANs were only invented in 2014, so they are still very new! They are a promising family of generative models because, unlike other methods, they produce very clean and sharp images and learn weights that contain valuable information about the underlying data. However, as mentioned above, keeping the Discriminator and Generator networks in balance can be difficult, and there is a lot of ongoing work on making GAN training more stable.

Besides producing beautiful images, GANs have been used for semi-supervised learning: the discriminator produces additional outputs that indicate the label of the input. This approach can achieve state-of-the-art results on datasets with very few labeled examples. On MNIST, for example, a fully connected neural network trained this way with only 10 labeled examples per class reached 99.1% accuracy, a result very close to the best known result for a fully supervised method using all 60,000 labeled examples. This is very promising, since labeled examples can be very expensive to obtain in practice.
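One way to read "additional outputs indicative of the input labels" is a discriminator with K class outputs plus an implicit "fake" class. The sketch below is a hypothetical pure-Python illustration, assuming the K+1-class formulation in which the fake-class logit is fixed at 0 (a simplification used by Salimans et al., 2016, "Improved Techniques for Training GANs"); it turns the K class logits into a class distribution and a probability that the input is real.

```python
import math

def semi_supervised_outputs(class_logits):
    """Split a K+1-class discriminator's output into (class probabilities, P(real)).

    class_logits holds the K logits for the real classes; the (K+1)-th "fake"
    logit is assumed to be fixed at 0, so P(real) = Z / (Z + 1), Z = sum(exp(l_k)).
    """
    m = max(class_logits)                          # shift by the max for stability
    exps = [math.exp(l - m) for l in class_logits]
    total = sum(exps)                              # equals Z * exp(-m)
    class_probs = [e / total for e in exps]        # P(class k | x is real)
    p_real = total / (total + math.exp(-m))        # Z / (Z + 1), rescaled by exp(-m)
    return class_probs, p_real

# Hypothetical logits for a 3-class problem.
class_probs, p_real = semi_supervised_outputs([2.0, -1.0, 0.5])
```

During training, labeled examples would update the class probabilities with ordinary cross-entropy, unlabeled real images would push p_real up, and generated images would push it down, which is how the few labels get help from the unlabeled data.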

Conclusion


Origin blog.csdn.net/qq_38998213/article/details/132516465