Detailed explanation and examples of generative adversarial networks (GANs)

Introduction to GANs

An intuitive way to understand GANs is from a game theory perspective. A GAN consists of two players, a generator and a discriminator, each trying to beat the other. The generator takes random noise sampled from some distribution and tries to transform it into output that follows the distribution of the real data. It always tries to produce a distribution that is indistinguishable from the true one; that is, the fake output should look like a real image. However, the generator receives no explicit supervision or labels and never sees the real images; its only input is a tensor of random floating-point numbers.

The GAN therefore introduces a second player, the discriminator. Its only job is to tell the generator that its output does not look like a real image, so that the generator changes how it produces images until the discriminator is convinced they are real. But a discriminator that only ever sees generated images could always answer "not real", because it knows every image came from the generator. To solve this problem, the GAN brings real images into the game and isolates the discriminator from the generator. The discriminator now receives an image from the set of real images and a fake image from the generator, and has to work out where each one came from.

Initially the discriminator knows nothing and predicts at random. Its task, however, can be framed as classification: it labels each input image as either original or generated, which is binary classification. We therefore train the discriminator network like any other classifier, and through backpropagation it gradually learns to distinguish real images from generated ones.
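
The two-player game described above is usually written as a minimax objective. The formula below is the standard formulation from the original GAN paper rather than something stated in this post; the symbols D (discriminator), G (generator), x (a real image), z (the random noise), p_data and p_z are notation introduced here for illustration.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator maximizes V by pushing D(x) toward 1 and D(G(z)) toward 0, while the generator minimizes it by pushing D(G(z)) toward 1.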


Code example

Dataset introduction:
In this experiment we use a flower dataset to generate images. The dataset contains six categories.

Model training
Training the discriminator:
For real pictures, the output should be as close to 1 as possible.
For fake pictures, the output should be as close to 0 as possible.
Training the generator:
For fake pictures, the discriminator's output should be as close to 1 as possible.
1. When training the generator, the discriminator's parameters are not updated; when training the discriminator, the generator's parameters are not updated.
2. When training the discriminator, call detach on the images produced by the generator to cut the computation graph, so that backpropagation does not pass gradients to the generator. The generator is not being trained at this step, so its gradients are not needed.
3. Training the discriminator takes two backward passes: one that pushes real pictures toward 1 and one that pushes fake pictures toward 0. Alternatively, put both sets of images into a single batch and run one forward pass and one backward pass.
4. For fake pictures, the discriminator is trained to output 0, whereas the generator is trained so that the discriminator outputs 1. This explains a pair of seemingly contradictory lines in the code: error_d_fake = criterion(output, fake_labels) and error_g = criterion(output, true_labels). The discriminator wants to classify the fake picture as fake_labels, the generator wants it classified as true_labels, and the two improve by competing against each other.
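
The apparent contradiction in point 4 can be seen with plain numbers. The sketch below is pure Python and only mimics what a binary cross-entropy criterion computes for a single output; the helper bce_single, the constants TRUE_LABEL and FAKE_LABEL, and the sample outputs are hypothetical names and values chosen for illustration.

```python
import math


def bce_single(output, label, eps=1e-10):
    """Binary cross-entropy of one discriminator output against a 0/1 label."""
    return -(label * math.log(output + eps)
             + (1 - label) * math.log(1 - output + eps))


TRUE_LABEL, FAKE_LABEL = 1, 0

# Discriminator outputs for a fake image, from "clearly fake" to "looks real".
for output in (0.1, 0.5, 0.9):
    error_d_fake = bce_single(output, FAKE_LABEL)  # discriminator wants output -> 0
    error_g = bce_single(output, TRUE_LABEL)       # generator wants output -> 1
    print(f"D(fake)={output:.1f}  error_d_fake={error_d_fake:.3f}  error_g={error_g:.3f}")
```

As the discriminator's output for the fake image rises, error_g falls while error_d_fake grows, so the same output is scored in opposite directions by the two players.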

import os
import torch
from torch.utils.data import Dataset, DataLoader
from dataloader import MyDataset
from model import Generator, Discriminator
import torchvision
import numpy as np
import matplotlib.pyplot as plt
if __name__ == '__main__':
    LR = 0.0002
    EPOCH = 1000  # 50
    BATCH_SIZE = 40
    N_IDEAS = 100
    EPS = 1e-10
    TRAINED = False
    path = r'./data/image'  # root folder of the training images (used by MyDataset below)
    train_data = MyDataset(path=path, resize=96, Len=10000, img_type='jpg')
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    torch.cuda.empty_cache()

    if TRAINED:
        G = torch.load('G.pkl').cuda()
        D = torch.load('D.pkl').cuda()
    else:
        G = Generator(N_IDEAS).cuda()
        D = Discriminator(3).cuda()

    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)

    for epoch in range(EPOCH):
        tmpD, tmpG = 0, 0
        for step, x in enumerate(train_loader):
            x = x.cuda()
            rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()
            G_imgs = G(rand_noise)

            # Train the discriminator: real images -> 1, fake images -> 0.
            # detach() cuts the graph so that no gradient flows into the generator.
            p_d_real = torch.squeeze(D(x))
            p_d_fake = torch.squeeze(D(G_imgs.detach()))
            D_loss = -torch.mean(torch.log(p_d_real + EPS) + torch.log(1. - p_d_fake + EPS))

            optimizerD.zero_grad()
            D_loss.backward()
            optimizerD.step()

            # Train the generator: pass the fakes through the updated discriminator
            # and push its output toward 1. Re-running D here avoids backpropagating
            # through weights that optimizerD.step() has already modified in place.
            p_d_fake = torch.squeeze(D(G_imgs))
            G_loss = -torch.mean(torch.log(p_d_fake + EPS))
            # Saturating alternative: G_loss = torch.mean(torch.log(1. - p_d_fake + EPS))

            optimizerG.zero_grad()
            G_loss.backward()
            optimizerG.step()

            # Accumulate plain Python floats for the per-epoch average.
            tmpD += D_loss.item()
            tmpG += G_loss.item()
        tmpD /= (step + 1)
        tmpG /= (step + 1)
        print(
            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG)
        )
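        # Save one generated sample at selected epochs.
        select_epoch = [1, 5, 10, 20, 50, 80, 100, 150, 200, 400, 500, 800, 999, 1500, 2000, 3000, 4000, 5000, 6000, 8000, 9999]
        if epoch in select_epoch:
            os.makedirs('./result1', exist_ok=True)  # make sure the output folder exists
            # Convert CHW -> HWC and map values from [-1, 1] back to [0, 1].
            plt.imshow(np.squeeze(G_imgs[0].cpu().detach().numpy().transpose((1, 2, 0))) * 0.5 + 0.5)
            plt.savefig('./result1/_%d.png' % epoch)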

    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')

The generated samples change as follows over the course of training.
[Figures: six generated samples from successive stages of training]
The complete code is as follows:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can display Chinese labels
plt.rcParams['axes.unicode_minus'] = False

# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223


noiseSize = 100     # dimension of the noise vector
n_generator_feature = 64        # number of generator feature maps
n_discriminator_feature = 64        # number of discriminator feature maps
batch_size = 50
d_every = 1     # train the discriminator on every batch
g_every = 5     # train the generator once every five batches


class NetGenerator(nn.Module):
    def __init__(self):
        super(NetGenerator,self).__init__()
        self.main = nn.Sequential(      # modules run in the order they are passed to the constructor
            nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(n_generator_feature * 8),
            nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4
            nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 4),
            nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8
            nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature * 2),
            nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16
            nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_generator_feature),
            nn.ReLU(True),      # (n_generator_feature) × 32 × 32
            nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()       # 3 * 96 * 96
        )

    def forward(self, input):
        return self.main(input)


class NetDiscriminator(nn.Module):
    def __init__(self):
        super(NetDiscriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32
            nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16
            nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8
            nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n_discriminator_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4
            nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()        # outputs a probability
        )

    def forward(self, input):
        return self.main(input).view(-1)


def train():
    for i, (image, _) in tqdm.tqdm(enumerate(dataloader)):       # each batch is (images, labels); images: batch_size x 3 x 96 x 96
        # Move the batch to wherever the noise tensor lives (GPU if available, else CPU).
        real_image = image.to(noises.device)

        if (i + 1) % d_every == 0:
            optimizer_d.zero_grad()
            output = Discriminator(real_image)      # push real images toward True (1)
            error_d_real = criterion(output, true_labels)
            error_d_real.backward()

            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises).detach()       # generate fake images from noise; detach cuts the graph
            fake_output = Discriminator(fake_img)       # push fake images toward False (0)
            error_d_fake = criterion(fake_output, fake_labels)
            error_d_fake.backward()
            optimizer_d.step()

        if (i + 1) % g_every == 0:
            optimizer_g.zero_grad()
            noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))
            fake_img = Generator(noises)        # no detach here: the generator needs the gradient
            fake_output = Discriminator(fake_img)       # make the Discriminator judge fake images as True (1)
            error_g = criterion(fake_output, true_labels)
            error_g.backward()
            optimizer_g.step()


def show(num):
    fix_fake_imags = Generator(fix_noises)
    fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5     # map [-1, 1] back to [0, 1]

    # x = torch.rand(64, 3, 96, 96)
    fig = plt.figure(1)
    fig.clf()       # clear the subplots drawn by the previous call

    i = 1
    for image in fix_fake_imags:
        ax = fig.add_subplot(8, 8, i)
        # plt.xticks([]), plt.yticks([])  # remove the axis ticks
        plt.axis('off')
        plt.imshow(image.permute(1, 2, 0).numpy())      # CHW -> HWC for matplotlib
        i += 1
    plt.subplots_adjust(left=None,  # the left side of the subplots of the figure
                        right=None,  # the right side of the subplots of the figure
                        bottom=None,  # the bottom of the subplots of the figure
                        top=None,  # the top of the subplots of the figure
                        wspace=0.05,  # the amount of width reserved for blank space between subplots
                        hspace=0.05)  # the amount of height reserved for white space between subplots
    plt.suptitle('Results after iteration %d' % num, y=0.91, fontsize=15)
    plt.savefig("images/%dcgan.png" % num)      # the images/ folder must already exist


if __name__ == '__main__':
    transform = tv.transforms.Compose([
        tv.transforms.Resize(96),     # image size; transforms.Scale is deprecated in favor of Resize
        tv.transforms.CenterCrop(96),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # scale pixel values to [-1, 1]
    ])

    dataset = tv.datasets.ImageFolder(dir, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # drop_last keeps every batch at batch_size, matching the fixed-size label tensors

    print('Data loaded!')
    Generator = NetGenerator()
    Discriminator = NetDiscriminator()

    optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    true_labels = Variable(torch.ones(batch_size))     # one label per image in the batch
    fake_labels = Variable(torch.zeros(batch_size))
    fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # fixed noise, reused for every preview
    noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # normal distribution with mean 0 and variance 1

    if torch.cuda.is_available():
        print('Cuda is available!')
        Generator.cuda()
        Discriminator.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()


    plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]

    for i in range(3000):        # maximum number of iterations (each one is a full pass over the data)
        train()
        print('Iteration: {}'.format(i))
        if i in plot_epoch:
            show(i)
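
The feature-map sizes written in the comments of NetGenerator and NetDiscriminator can be checked by hand. The sketch below is pure Python and applies the output-size formulas documented for PyTorch's ConvTranspose2d and Conv2d to the kernel, stride and padding values used above; the helper names conv_transpose_out and conv_out are hypothetical and not part of the original code.

```python
def conv_transpose_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def conv_out(size, kernel, stride, padding, dilation=1):
    """Output size of an ordinary convolution along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (kernel, stride, padding) of each layer, copied from the two networks above.
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (5, 3, 1)]
discriminator_layers = [(5, 3, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 1  # the noise enters the generator as a 1 x 1 map
generator_sizes = []
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    generator_sizes.append(size)

# The discriminator starts from the generator's 96 x 96 output.
discriminator_sizes = []
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    discriminator_sizes.append(size)

print(generator_sizes)      # [4, 8, 16, 32, 96]
print(discriminator_sizes)  # [32, 16, 8, 4, 1]
```

The last generator layer uses kernel 5, stride 3 and padding 1 precisely so that 32 maps to 96, and the first discriminator layer mirrors it to bring 96 back down to 32.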


Origin blog.csdn.net/weixin_45807161/article/details/123776427