PyTorch implementations of autoencoder variants

Table of contents

1. Denoising Auto-Encoder

2. Dropout Auto-Encoder

3. Adversarial Auto-Encoder

(1) Schematic diagram

(2) Explanation


Autoencoder (AutoEncoder) principle and implementation using the PyTorch framework

1. Denoising Auto-Encoder

        Generally speaking, autoencoder training is relatively stable, as the earlier autoencoder experiments showed. However, the loss function directly measures the distance between the reconstructed sample and the real sample; in the autoencoder introduced above, this metric is simply the Euclidean distance. Such a metric is not an abstract indicator of the fidelity and diversity of the reconstructed samples, so the results on some tasks are mediocre. In image reconstruction, for example, the reconstructed images tend to have blurred edges, and their fidelity still falls far short of the real pictures.

        To encourage the encoder to learn the true distribution of the data, a series of autoencoder variant networks have been proposed.
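        As the section name suggests, a denoising autoencoder corrupts the input with noise and trains the network to reconstruct the clean sample, so the network cannot simply copy its input. Below is a minimal sketch of a single training step; the Gaussian noise level, layer sizes, flattened 28x28 input shape, and MSE reconstruction loss are illustrative assumptions, not necessarily the setup used in the linked repository.

import torch
from torch import nn

# Sketch of one denoising training step: corrupt the input with Gaussian noise,
# then reconstruct the ORIGINAL clean input (shapes and sizes are assumptions).
encoder = nn.Sequential(nn.Linear(784, 128), nn.ReLU())
decoder = nn.Sequential(nn.Linear(128, 784), nn.Sigmoid())
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
criterion = nn.MSELoss()

imgs = torch.rand(16, 784)                                      # a fake batch of flattened images
noisy = (imgs + 0.2 * torch.randn_like(imgs)).clamp(0.0, 1.0)   # corrupted input

recon = decoder(encoder(noisy))   # reconstruct from the noisy input
loss = criterion(recon, imgs)     # the target is the clean input
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())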

GitHub code implementation: https://github.com/KeepTryingTo/Pytorch-GAN

2. Dropout Auto-Encoder

        The autoencoder network also faces the risk of over-fitting, so it needs regularization as well. The Dropout Auto-Encoder reduces the network's expressive power by randomly dropping connections during training, which helps prevent over-fitting.

        A Dropout Auto-Encoder can be built by simply inserting Dropout layers into the network, as in the sketch below.
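        The following minimal sketch illustrates the idea; the layer sizes, dropout probability, and flattened 28x28 input shape are illustrative assumptions, not necessarily the architecture used in the linked repository.

import torch
from torch import nn

# Sketch of a Dropout Auto-Encoder: the same reconstruction objective as a plain
# autoencoder, with Dropout layers inserted to regularize the network.
class DropoutAutoEncoder(nn.Module):
    def __init__(self, in_dim=784, latent_dim=128, p=0.5):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Dropout(p),                         # randomly drop activations during training
            nn.Linear(256, latent_dim), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(256, in_dim), nn.Sigmoid(),  # outputs in [0, 1]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

if __name__ == "__main__":
    model = DropoutAutoEncoder()
    x = torch.rand(16, 784)                        # a fake batch of flattened images
    loss = nn.functional.mse_loss(model(x), x)     # same reconstruction loss as before
    print(loss.item())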

GitHub code implementation: https://github.com/KeepTryingTo/Pytorch-GAN

3. Adversarial Auto-Encoder

Tip: Readers are advised to go through the Generative Adversarial Network basics (the posts listed below) before reading the theory in this part.

GAN principle and its implementation with the PyTorch framework (relatively easy to understand)

PyTorch framework implementation of DCGAN (relatively easy to understand)

The basic principle of CycleGAN and its implementation with the PyTorch framework

The basic principle of WGAN and its implementation with PyTorch

PyTorch framework implementation of WGAN-GP

(1) Schematic diagram

(2) Explanation 

        The latent vector z is sampled from a known prior distribution p(z), which makes it convenient to reconstruct the input from p(z). The adversarial autoencoder uses an additional discriminator network (Discriminator) to judge whether the dimensionality-reduced latent vector z was sampled from the prior distribution p(z). As indicated in the schematic above, the discriminator outputs a value in the [0, 1] interval that indicates whether the latent vector was sampled from the prior p(z): all z sampled from the prior distribution p(z) are labeled True, while z drawn from the encoder's conditional distribution q(z|x) are labeled False.

        With this training scheme, in addition to reconstructing the samples, the conditional distribution q(z|x) is constrained to approximate the prior distribution p(z). (The adversarial autoencoder borrows its training algorithm from generative adversarial networks.)

GitHub code implementation: https://github.com/KeepTryingTo/Pytorch-GAN

Core part of the training loop (reconstruction step, discriminator step, encoder/generator step):

        # ----------------------------------------------------------------
        # 1) Train encoder + decoder on the reconstruction loss
        z_en = encoder(imgs)                 # encode images to latent vectors z
        z_fake = decoder(z_en)               # reconstruct images from z
        loss_ae = loss_AE(z_fake, imgs)      # reconstruction loss
        opt_AE.zero_grad()
        loss_ae.backward()
        opt_AE.step()
        step_loss_AE += loss_ae.item()
        # ----------------------------------------------------------------

        # ----------------------------------------------------------------
        # 2) Train the discriminator: z from the prior p(z) is "real" (label 1),
        #    z produced by the encoder q(z|x) is "fake" (label 0)
        z_size = imgs.shape[0]               # batch size
        z_real = torch.randn(size=(z_size, 128)).to(config.DEVICE)   # samples from the prior p(z)
        z_en_fake = encoder(imgs).detach()   # detach: do not update the encoder here
        discInput = torch.cat((z_real, z_en_fake), dim=0)
        discLabel = torch.cat((torch.ones(z_size, 1), torch.zeros(z_size, 1)), dim=0).to(config.DEVICE)

        discOutput = disc(discInput)
        loss_disc_out = loss_disc(discOutput, discLabel)
        opt_disc.zero_grad()
        loss_disc_out.backward()
        opt_disc.step()
        step_loss_disc += loss_disc_out.item()
        # ----------------------------------------------------------------

        # ----------------------------------------------------------------
        # 3) Train the encoder (as the generator): fool the discriminator into
        #    believing the encoded z came from the prior p(z).
        #    Note: z must NOT be detached here, otherwise no gradient reaches the encoder.
        z_en = encoder(imgs)
        enOutput = disc(z_en)
        loss_en_out = loss_en(enOutput, torch.ones(z_size, 1).to(config.DEVICE))
        opt_en.zero_grad()
        loss_en_out.backward()
        opt_en.step()
        step_loss_en += loss_en_out.item()
        # ----------------------------------------------------------------
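        For reference, the snippet above assumes that the encoder, decoder, and disc networks, their losses (loss_AE, loss_disc, loss_en), the corresponding optimizers, and a config.DEVICE setting have already been created elsewhere. Below is a minimal sketch of what those definitions might look like; the layer sizes, the flattened 28x28 input, and the learning rates are assumptions (only the 128-dimensional latent space is taken from the snippet), not necessarily the exact setup of the linked repository.

import torch
from torch import nn

LATENT_DIM = 128   # matches the dimension of z used in the training snippet

# The snippet also assumes a config module with config.DEVICE, e.g.:
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder: flattened image -> latent vector z (assumed 28x28 = 784 inputs)
encoder = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, LATENT_DIM),
)

# Decoder: latent vector z -> reconstructed (flattened) image
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, 256), nn.ReLU(),
    nn.Linear(256, 784), nn.Sigmoid(),
)

# Discriminator: latent vector z -> probability that z was sampled from the prior p(z)
disc = nn.Sequential(
    nn.Linear(LATENT_DIM, 128), nn.ReLU(),
    nn.Linear(128, 1), nn.Sigmoid(),
)

# Losses and optimizers assumed by the training snippet
loss_AE = nn.MSELoss()     # reconstruction loss
loss_disc = nn.BCELoss()   # real/fake loss for the discriminator
loss_en = nn.BCELoss()     # adversarial loss pushing q(z|x) toward p(z)

opt_AE = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=2e-4)
opt_disc = torch.optim.Adam(disc.parameters(), lr=2e-4)
opt_en = torch.optim.Adam(encoder.parameters(), lr=2e-4)

Under these assumptions, and with image batches flattened to (batch_size, 784), the training loop above can run as written.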

Origin blog.csdn.net/Keep_Trying_Go/article/details/130623570