GAN anime character avatar generation

1. Introduction

A simple DCGAN network was built to generate anime character avatars. The dataset of anime character avatars comes from Kaggle; the link is as follows:
link

2. Network structure

  1. Dataset
  2. Generator
  3. Discriminator

2.1 Dataset

Each image in the dataset is 64x64x3; samples are shown below:
[sample avatar images from the dataset]
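
The My_dataset class imported later in the training script is not shown in the post; a minimal sketch of a folder-of-images dataset consistent with how it is used might look like this:

import os
from PIL import Image
from torch.utils.data import Dataset

class My_dataset(Dataset):
    """Loads every image in a folder and applies an optional transform."""

    def __init__(self, root, transform=None):
        self.paths = [os.path.join(root, f) for f in sorted(os.listdir(root))]
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img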

2.2 Generator

The generator's input is an n-dimensional noise vector, so to produce an image the same size as the dataset images we need to upsample. The method used here is transposed convolution, implemented in PyTorch as ConvTranspose2d; its output size is (n - 1) * stride - 2 * padding + kernel_size, which is what the comments in the code below compute.
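A quick standalone check of this formula on the generator's first layer (a minimal sketch, not part of the original code):

import torch
import torch.nn as nn

# first generator layer: 1x1 noise -> 4x4, since (1-1)*1 - 2*0 + 4 = 4
layer = nn.ConvTranspose2d(100, 256, kernel_size=4)
x = torch.randn(1, 100, 1, 1)
print(layer(x).shape)  # torch.Size([1, 256, 4, 4])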
The generator code is as follows:

class Generator(nn.Module):

    def __init__(self, noise_dim=100):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # out_shape = (1-1)*1-2*0+4 = 4*4
            nn.ConvTranspose2d(noise_dim, 256, kernel_size=4),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # out_shape = (4-1)*2-2*1+4 = 8*8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # out_shape = (8-1)*2-2*1+4 = 16*16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # out_shape = (16-1)*2-2*1+4 = 32*32
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # out_shape = (32-1)*2-2*1+4 = 64*64
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.net(input)
        return output

During training we sample random noise of shape batch_size x 100 x 1 x 1; the generator upsamples it into fake images the same size as the dataset images, and these are then fed to the discriminator together with real images so it can tell them apart.
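As a sanity check, the Generator above maps noise to 64x64 RGB images (a minimal sketch):

import torch

generator = Generator()
noise = torch.randn(16, 100, 1, 1)  # a batch of 16 noise vectors
fake_imgs = generator(noise)
print(fake_imgs.shape)  # torch.Size([16, 3, 64, 64])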

2.3 Discriminator

The discriminator's input is either a real image sampled from the dataset or a fake image produced by the generator, and its output is a value between 0 and 1, so a Sigmoid activation is used at the end of the network.
The discriminator's goal is to score real images as "1" (real) and fake images as "0" (fake), while the generator's goal is to produce fakes close enough to the dataset distribution to fool the discriminator; the generator therefore wants its fakes to score as close to "1" (real) as possible. The two networks keep "confronting" each other in this way until they reach, or get close to, an equilibrium.
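For reference, this is the standard GAN minimax objective from Goodfellow et al. (2014), which the label/loss setup below implements in BCE form:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]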
The discriminator code is as follows:

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # 32*32*32
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 16*16*64
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 8*8*128
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 4*4*256
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(4*4*256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        # flatten the (batch_size, 1) output into a 1-D vector of scores
        output = self.net(input)
        return output.view(-1)

The discriminator is a simple feed-forward CNN: strided convolutions repeatedly downsample the image while extracting features, and the final output is a score between 0 and 1 indicating real or fake.
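
A quick shape check for the discriminator (a minimal sketch using the class above):

import torch

discriminator = Discriminator()
imgs = torch.randn(16, 3, 64, 64)  # a batch of 16 RGB 64x64 images
scores = discriminator(imgs)
print(scores.shape)  # torch.Size([16]), each score in (0, 1)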

3. Training phase

The overall training procedure is much like any other deep-learning training loop; the most important parts are the design and computation of the labels and the loss function.
Here is the training code:

import torch
import torch.nn as nn
from torchvision import transforms
from create_dataset import My_dataset, save_img
from torch.utils.data import DataLoader
from net import Generator, Discriminator

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


dataset = My_dataset('./data', transform=transform)
batch_size, epochs = 256, 200
my_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

discriminator = Discriminator()
generator = Generator()
# use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
discriminator = discriminator.to(device)
generator = generator.to(device)

d_optimizer = torch.optim.Adam(discriminator.parameters(), betas=(0.5, 0.99), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), betas=(0.5, 0.99), lr=1e-4)
criterion = nn.BCELoss()

for epoch in range(epochs):

    for i, img in enumerate(my_dataloader):

        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        real_img = img.to(device)
        fake_img = generator(noise)

        real_label = torch.ones(batch_size, device=device)
        fake_label = torch.zeros(batch_size, device=device)
        real_out = discriminator(real_img)
        # detach so the discriminator update does not backpropagate into the generator
        fake_out = discriminator(fake_img.detach())
        real_loss = criterion(real_out, real_label)
        fake_loss = criterion(fake_out, fake_label)

        d_loss = real_loss + fake_loss
        d_optimizer.zero_grad()

        d_loss.backward()
        d_optimizer.step()

        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_img = generator(noise)
        output = discriminator(fake_img)

        g_loss = criterion(output, real_label)
        g_optimizer.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 5 == 0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D_real: {:.6f},D_fake: {:.6f}'.format(
                epoch, epochs, d_loss.data.item(), g_loss.data.item(),
                real_out.data.mean(), fake_out.data.mean()  # mean discriminator scores on the real and fake batches
            ))
        if epoch == 0 and i == len(my_dataloader) - 1:
            save_img(img[:64, :, :, :], './sample/real_images.png')
        if (epoch+1) % 10 == 0 and i == len(my_dataloader)-1:
            save_img(fake_img[:64, :, :, :], './sample/fake_images_{}.png'.format(epoch + 1))

torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')

Before training we define the image labels ourselves: real is 1 (created with torch.ones) and fake is 0 (created with torch.zeros).
The discriminator loss, d_loss in the code, is then the sum of the BCE losses on real images from the dataset and on the fake images produced by the generator.
Next comes the generator loss, g_loss in the code: because the generator's goal is to produce images that look as real as possible, its loss is computed against the label 1.
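Written out, the two losses computed by criterion above are the binary cross-entropies

\mathcal{L}_D = -[\log D(x) + \log(1 - D(G(z)))], \qquad \mathcal{L}_G = -\log D(G(z))

averaged over the batch.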

4. Denormalization and results

4.1 Denormalization

Because the dataset was normalized (standardized to [-1, 1]), the generator outputs must be denormalized before they can be displayed. At first I used torchvision's save_image to save the results, but its built-in denormalization does not match our normalization process, so the images it saves come out a little dark, as shown below (a real image from the dataset):
[image: output of save_image, slightly dark]
So instead I denormalize the data myself and save it with torchvision's make_grid; the result is shown below (a real image from the dataset):
[image: output after manual denormalization, correct brightness]
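
The save_img helper imported from create_dataset is not shown in the post; a minimal sketch consistent with how it is used (assuming the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform above) might look like this:

from PIL import Image
from torchvision.utils import make_grid

def save_img(tensor, path, nrow=8):
    # undo Normalize((0.5,)*3, (0.5,)*3): map [-1, 1] back to [0, 1]
    tensor = tensor.detach().cpu() * 0.5 + 0.5
    grid = make_grid(tensor, nrow=nrow).clamp(0, 1)
    # convert the CHW float grid to an HWC uint8 array and save
    array = (grid.permute(1, 2, 0).numpy() * 255).astype('uint8')
    Image.fromarray(array).save(path)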

4.2 Results

We train for 200 epochs and save the results every 10 epochs. The results at epochs 10, 50, 100, 150, and 200 are shown below:
[generated samples at epochs 10, 50, 100, 150, and 200]
As training progresses, the images produced by the generator become clearer and closer to the dataset distribution.

5. Summary

The final results are still not great: GAN training is not very stable, and in particular getting the images sharp rather than blurry remains a fairly "tough" problem.
(This is my first blog post, so please go easy on me.)

Finally, all of the code is available on my GitHub.
