GAN generates anime avatars

This article is mainly based on Li Hongyi's GAN course, combined with some code from the book "Introduction and Practice of the Deep Learning Framework PyTorch".

Introduction to GAN principle

GAN stands for Generative Adversarial Network. GANs solve a famous problem: given a batch of samples, train a system that can generate similar new samples. As the name suggests, a generative adversarial network has two parts: a generator (Generator) and a discriminator (Discriminator).

  • Generator: takes random noise as input and generates a picture
  • Discriminator: judges whether a picture is a real picture or a generated (fake) one

[Figure: generator and discriminator evolving against each other]

The basic training process is shown in the figure. The first-generation generator is initialized randomly; given random noise, the pictures it generates are blurry. The first-generation discriminator learns to tell whether a picture was produced by the first-generation generator or is a real picture. Perhaps the first-generation discriminator decides that anything with color is a real picture. The generator has to fool the discriminator, so it evolves into a second generation that produces colored pictures. The discriminator evolves accordingly, now noticing that real pictures have mouths, and again separates generated pictures from real ones. The generator then evolves into a third generation that produces pictures with mouths and fools the second-generation discriminator, and the discriminator evolves into a third generation in turn. Step by step they fight and evolve against each other, until finally the generator can produce anime avatars.

Algorithm process

Train the Discriminator

First, initialize the generator and the discriminator. In each training iteration, first fix the generator: pass random noise to it and generate the corresponding pictures. As mentioned earlier, the task of the discriminator is to distinguish real pictures from fake ones. The discriminator scores each picture: the higher the score, the more likely the picture is real; the lower the score, the more likely it is fake. So when training the discriminator, it receives pictures from the dataset and pictures from the generator, and adjusts its parameters to give high scores to dataset pictures and low scores to generated pictures. In other words, when training the discriminator we want the score of a dataset picture to be as close to 1 as possible, and the score of a generated picture to be as close to 0 as possible.

[Figure: training the discriminator]
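As a minimal sketch of this step (using hypothetical names D, G and d_optim; the article's actual code appears later), one discriminator update looks like this:

import torch
import torch.nn as nn

bce = nn.BCELoss()

def discriminator_step(D, G, d_optim, real_imgs, noise_dim=100):
    batch = real_imgs.size(0)
    noise = torch.randn(batch, noise_dim, 1, 1)
    fake_imgs = G(noise).detach()  # detach: do not backprop into the generator here
    d_optim.zero_grad()
    loss = bce(D(real_imgs), torch.ones(batch, 1)) \
         + bce(D(fake_imgs), torch.zeros(batch, 1))  # real pictures -> 1, fake -> 0
    loss.backward()
    d_optim.step()
    return loss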

Train the Generator

The discriminator has now been trained; next the generator is trained. What the generator needs to do is take random noise, generate a picture, and "cheat" the discriminator. How is the discriminator cheated? In practice, the generated pictures are sent to the discriminator for scoring, and the goal is to make the score as high as possible (the closer to 1, the better). The generator and discriminator together can be regarded as one huge network that takes noise as input and outputs a score. We adjust the parameters so that the score approaches 1, but the parameters of the Discriminator part in the middle must not be adjusted.

[Figure: training the generator]
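The corresponding generator update, again as a hedged sketch with the same hypothetical names (note there is no detach() here, and only g_optim is stepped, so the discriminator's parameters stay fixed):

def generator_step(D, G, g_optim, batch, noise_dim=100):
    g_optim.zero_grad()
    noise = torch.randn(batch, noise_dim, 1, 1)
    fake_imgs = G(noise)  # gradients must flow back into the generator
    loss = bce(D(fake_imgs), torch.ones(batch, 1))  # we want D to answer "real" (1)
    loss.backward()  # gradients reach D's parameters too, but d_optim is not stepped
    g_optim.step()  # only the generator's parameters are updated
    return loss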

The whole process

[Figure: GAN training algorithm]

  1. Sample m examples $\{x^1, x^2, \dots, x^m\}$ from the dataset; then sample m noise vectors $\{z^1, \dots, z^m\}$ (the dimension of z is chosen by yourself; this is the noise fed into the generator later), and obtain the generated data $\{\tilde{x}^1, \dots, \tilde{x}^m\}$, where $\tilde{x}^i = G(z^i)$. Update the discriminator parameters to maximize $\tilde{V} = \frac{1}{m}\sum_{i=1}^{m}\log D(x^i) + \frac{1}{m}\sum_{i=1}^{m}\log\left(1 - D(\tilde{x}^i)\right)$, which means: the average log of the discriminator's scores on real pictures, plus the average log of one minus its scores on fake pictures. Simply put, the higher the score for real pictures the better, and the closer the score for fake pictures is to 0, the better. (This is equivalent to training a binary classifier, using BCELoss.)
  2. Sample m new noise vectors z and update the generator: the generator feeds the images generated from z to the discriminator, and the higher the resulting scores, the better, i.e. maximize $\frac{1}{m}\sum_{i=1}^{m}\log D(G(z^i))$.

BCE Loss

BCE loss is used for binary classification. Its mathematical formula is
$\mathrm{loss}(x_i, y_i) = -w_i\left[y_i \log x_i + (1 - y_i)\log(1 - x_i)\right]$
BCELoss in PyTorch:

class torch.nn.BCELoss(weight: Optional[torch.Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean')

weight: an optional manual rescaling weight applied to each batch element

size_average: deprecated; the default True averages the loss (use reduction instead)

reduction: the default is 'mean', which averages the loss over the batch; 'sum' sums it instead
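A quick numerical check of BCELoss against the formula above (the values here are made up for illustration):

import torch
import torch.nn as nn

bce = nn.BCELoss()  # reduction='mean' by default
x = torch.tensor([0.9, 0.2])  # predicted probabilities
y = torch.tensor([1.0, 0.0])  # targets
print(bce(x, y))  # mean of -log(0.9) and -log(1 - 0.2), about 0.1643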

Code

After understanding the theory above, you can start implementing. The dataset I use here is Extra Data, or you can try it with the Anime Dataset. You will need to find a proxy to download them yourself.

Initialization

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.datasets as DataSet
import torchvision.transforms as transform
import torch.utils.data as Data
import numpy as np
import torch.nn as nn
import torch.optim as optim
import os


# Save an image tensor to disk
def saveImg(inp, name):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.savefig(name)

    
# Display an image; useful to check that the dataset loaded correctly
def imgshow(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.show()


    
batch_size = 20
# Image preprocessing: resize to 64 x 64, convert to tensor in (0, 1), then Normalize to (-1, 1)
simple_transform = transform.Compose([
    transform.Resize((64, 64)),
    transform.ToTensor(),
    transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Use GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dimension of the noise vector
noise_z = 100
generator_feature_map = 64
# Load the dataset
path = "AnimeDataset"
train_set = DataSet.ImageFolder(path, simple_transform)
train_loader = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)  # drop_last: the fixed-size labels below require full batches
# Labels for real pictures (score 1)
true_label = torch.ones(batch_size).to(device)
true_label = true_label.view(-1, 1)
# Labels for fake pictures (score 0)
false_label = torch.zeros(batch_size).to(device)
false_label = false_label.view(-1, 1)
# Fixed noises, so that after each epoch we can watch the generator evolve on the same inputs
fix_noises = torch.randn(batch_size, noise_z, 1, 1).to(device)
# Random noises
noises = torch.randn(batch_size, noise_z, 1, 1).to(device)
g_train_cycle = 1  # train the generator every g_train_cycle steps
save_img_cycle = 1  # save images every save_img_cycle epochs
print_step = 200  # print loss info every print_step steps

bceloss = nn.BCELoss()

Generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            # 100*1*1 --> (64 * 8) * 4 *4
            nn.ConvTranspose2d(noise_z, generator_feature_map * 8, kernel_size=4, bias=False),
            nn.BatchNorm2d(generator_feature_map * 8),
            nn.ReLU(True))
        self.layer2 = nn.Sequential(
            # (64 * 8) * 4 * 4 --> (64 * 4)*8*8
            nn.ConvTranspose2d(generator_feature_map * 8, generator_feature_map * 4, kernel_size=4, stride=2,
                               padding=1, bias=False),
            nn.BatchNorm2d(generator_feature_map * 4),
            nn.ReLU(True))
        self.layer3 = nn.Sequential(

            # (64*4)*8*8 --> (64*2)*16*16
            nn.ConvTranspose2d(generator_feature_map * 4, generator_feature_map * 2, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.BatchNorm2d(generator_feature_map * 2),
            nn.ReLU(True))
        self.layer4 = nn.Sequential(

            # (64*2)*16*16 --> 64*32*32
            nn.ConvTranspose2d(generator_feature_map * 2, generator_feature_map, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.BatchNorm2d(generator_feature_map),
            nn.ReLU(True))
        self.layer5 = nn.Sequential(
            # 64*32*32 --> 3*64*64
            nn.ConvTranspose2d(generator_feature_map, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        return out
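A quick shape check (illustrative only, reusing noise_z from the initialization above):

g = Generator()
z = torch.randn(1, noise_z, 1, 1)
print(g(z).shape)  # torch.Size([1, 3, 64, 64])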

Discriminator

class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        super(Discriminator, self).__init__()
        # layer1: input 3 x 64 x 64, output (ndf) x 32 x 32
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer2: output (ndf*2) x 16 x 16
        self.layer2 = nn.Sequential(
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer3: output (ndf*4) x 8 x 8
        self.layer3 = nn.Sequential(
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer4: output (ndf*8) x 4 x 4
        self.layer4 = nn.Sequential(
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # layer5: output a single number (a probability)
        self.layer5 = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    # Forward pass of the discriminator
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.view(-1,1)
        return out
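And the matching check for the discriminator (illustrative only):

d = Discriminator()
x = torch.randn(1, 3, 64, 64)
print(d(x).shape)  # torch.Size([1, 1]), one probability per image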

Optimizer

generator = Generator().to(device)
discriminator = Discriminator().to(device)

learning_rate = 0.0002
beta = 0.5
# Initialize the optimizers
g_optim = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta, 0.999))
d_optim = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta, 0.999))

Loss function

def loss_g_func(testLabel, trueLabel):
    return bceloss(testLabel, trueLabel)


def loss_d_func(real_predicts, real_labels, fake_predicts, fake_labels):
    real = bceloss(real_predicts, real_labels)	# scores of real pictures, pushed towards 1
    fake = bceloss(fake_predicts, fake_labels)	# scores of fake pictures, pushed towards 0
    real.backward()
    fake.backward()
    return real + fake
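Note that loss_d_func calls backward() inside the function, so the training loop below only needs to call d_optim.step() afterwards. An equivalent variant (not the original code, just a common alternative) sums the two terms first and backpropagates once:

def loss_d_func_alt(real_predicts, real_labels, fake_predicts, fake_labels):
    loss = bceloss(real_predicts, real_labels) + bceloss(fake_predicts, fake_labels)
    loss.backward()
    return loss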

Start training

# Training loop: alternately train the Discriminator and the Generator
train_num = 25
for trainIdx in range(train_num):
    for step, data in enumerate(train_loader):
        image_x, _ = data
        image_x = image_x.to(device)
        # Train the discriminator
        noises.data.copy_(torch.randn(batch_size, noise_z, 1, 1))
        out = discriminator(image_x)	# scores for the real images
        fake_pic = generator(noises)	# images produced by the generator
        fake_predict = discriminator(fake_pic.detach())	# detach() stops gradients flowing into the generator
        d_optim.zero_grad()
        dloss = loss_d_func(out, true_label, fake_predict, false_label)
        d_optim.step()

        if step % g_train_cycle == 0:
            # Train the generator
            g_optim.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noise_z, 1, 1))
            fake_img = generator(noises)
            fake_out = discriminator(fake_img)
            # make the discriminator score the fake images as close to 1 as possible
            loss_fake = loss_g_func(fake_out, true_label)
            loss_fake.backward()
            g_optim.step()

        if step % print_step == print_step - 1:
            print("train: ", trainIdx, "step: ", step + 1, " d_loss: ", dloss.item(), "mean score: ",
                  torch.mean(out).item())
            print("train: ", trainIdx, "step: ", step + 1, " g_loss: ", loss_fake.item(), "mean score: ",
                  torch.mean(fake_out).item())

    if trainIdx % save_img_cycle == 0:
        fix_fake_image = generator(fix_noises)
        fix_fake_image = fix_fake_image.data.cpu()
        comb_img = torchvision.utils.make_grid(fix_fake_image, nrow=4)
        savepath = os.path.join("gan", "pics", "g_%s.jpg" % trainIdx)
        saveImg(comb_img, savepath)
        torch.save(discriminator.state_dict(), './gan/netd_%s.pth' % trainIdx)
        torch.save(generator.state_dict(), './gan/netg_%s.pth' % trainIdx)
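After training, a saved generator checkpoint can be reloaded to produce new avatars. A short sketch, assuming the last checkpoint from the loop above (netg_24.pth) exists:

g = Generator().to(device)
g.load_state_dict(torch.load('./gan/netg_24.pth'))
g.eval()
with torch.no_grad():
    samples = g(torch.randn(16, noise_z, 1, 1).to(device))
grid = torchvision.utils.make_grid(samples.cpu(), nrow=4)
saveImg(grid, 'samples.jpg')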

Result

This is the result after 1 epoch:

[Generated avatars after 1 epoch]

5 Epochs

[Generated avatars after 5 epochs]

10 Epochs

[Generated avatars after 10 epochs]

25 Epochs

[Generated avatars after 25 epochs]

Origin: blog.csdn.net/qq_36571422/article/details/123883196