Generative Adversarial Networks: cGAN (Conditional GAN)

1. Introduction

Paper: Conditional Generative Adversarial Nets

Paper address: https://arxiv.org/abs/1411.1784

Shortcomings of the original GAN: the generated images are random and unpredictable. There is no way to make the network output a specific kind of picture, so the generation target is unclear and controllability is weak.

Improvement: the central idea of cGAN is to control what the GAN generates rather than simply generate pictures at random. Conditional GAN adds extra conditional information to the inputs of both the generator and the discriminator; a generated picture can pass the discriminator only if it is realistic enough and also matches the condition. The core is to feed attribute information into the generator G and the discriminator D. The attribute can be any label information, such as the class of an image or the facial expression in a face image.
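Formally, the objective in the paper is the standard two-player minimax game, with both terms conditioned on the extra information y:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y)))]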

2. Model structure

Additional information y is added to both the discriminator and the generator. y can be a class label or another type of data, and it is introduced into the discriminator and the generator as an extra input layer.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
import tqdm

ROOT_TRAIN = r'D:\CNN\AlexNet\data1\train'

def one_hot(x, num_class=2):  # convert integer labels to one-hot vectors
    return torch.eye(num_class)[x, :]
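# For example, with the default num_class=2 (a quick sketch of the helper's output):
# one_hot(0)                    -> tensor([1., 0.])
# one_hot(torch.tensor([0, 1])) -> tensor([[1., 0.], [0., 1.]])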

train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform, target_transform=one_hot)  # load the training set
dataloader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=64,
                                           shuffle=True,
                                           num_workers=0)

# print(train_dataset[0])  # returns (image, label); with one_hot as target_transform, the label is a tensor of length num_class, e.g. tensor([1., 0.])
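# A sketch of what one batch yields. Note that train_transform has no Resize, so this
# assumes the images on disk are already 224x224 (which the models below expect):
# imgs, labels = next(iter(dataloader))
# imgs.shape    -> torch.Size([64, 3, 224, 224])
# labels.shape  -> torch.Size([64, 2])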

# Define the generator. Inputs: a length-100 noise vector (normally distributed)
# and the one-hot label encoding (the condition); output: a 3*224*224 image tensor.
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(100, 128*56*56)
        self.bn1 = nn.BatchNorm1d(128*56*56)
        self.linear2 = nn.Linear(2, 128*56*56)
        self.bn2 = nn.BatchNorm1d(128*56*56)

        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          stride=1,
                                          padding=1)  #128*56*56
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # 64*112*112
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # 3*224*224

    def forward(self, x1, x2):  # x1: noise input, x2: one-hot label input (condition)
        x1 = F.relu(self.linear1(x1))  # 100 -> 128*56*56
        x1 = self.bn1(x1)
        x2 = F.relu(self.linear2(x2))  # num_class -> 128*56*56
        x2 = self.bn2(x2)
        x1 = x1.view(-1, 128, 56, 56)
        x2 = x2.view(-1, 128, 56, 56)
        x = torch.cat([x1, x2], dim=1)  # 256*56*56
        x = F.relu(self.deconv1(x))  # 256*56*56 -> 128*56*56
        x = self.bn3(x)
        x = F.relu(self.deconv2(x))  # 128*56*56 -> 64*112*112
        x = self.bn4(x)
        x = torch.tanh(self.deconv3(x))  # 64*112*112 -> 3*224*224; no BN on the generator's output layer
        return x
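
# A quick shape-check sketch (not part of training; assumes the one_hot helper above with num_class=2):
# g = Generator()
# z = torch.randn(2, 100)                    # batch of 2 noise vectors
# c = one_hot(torch.tensor([0, 1]))          # 2 one-hot conditions
# g(z, c).shape                              # -> torch.Size([2, 3, 224, 224])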

# Define the discriminator. Input: a 3*224*224 image plus the condition; output: a binary real/fake probability.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(2, 3*224*224)
        self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*55*55, 1)

    def forward(self, x1, x2):  # x1: image input (real or generated), x2: one-hot label input (condition)
        x2 = self.linear(x2)
        x2 = x2.view(-1, 3, 224, 224)
        x = torch.cat([x1, x2], dim=1)  # batchsize, 6, 224, 224
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.3)  # 64*111*111; no BN on the discriminator's input layer
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.3)  # 128*55*55
        x = self.bn(x)
        x = x.view(-1, 128*55*55)  # flatten
        x = torch.sigmoid(self.fc(x))
        return x
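
# Matching shape-check sketch for the discriminator (same assumptions as above):
# d = Discriminator()
# imgs = torch.randn(2, 3, 224, 224)
# c = one_hot(torch.tensor([0, 1]))
# d(imgs, c).shape                           # -> torch.Size([2, 1]), probabilities in (0, 1)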


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
dis = Discriminator().to(device)

# Discriminator optimizer
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)  # a smaller learning rate deliberately weakens the discriminator
# Generator optimizer
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-3)

loss_fn = torch.nn.BCELoss()  # binary cross-entropy loss
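# BCE(p, t) = -[t*log(p) + (1-t)*log(1-p)]. With the all-ones or all-zeros targets used
# below, this reduces to -log(p) for "real" and -log(1-p) for "fake", the standard GAN losses.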

# Plotting helper: draws the images the generator produces in each epoch
def gen_img_plot(model, epoch, noise_input, label_input):  # model is the Generator; noise_input is its random input; label_input is the condition
    prediction = model(noise_input, label_input).detach().permute(0, 2, 3, 1).cpu().numpy()  # detach from the graph and move the channel dimension last
    plt.figure(figsize=(10, 10))
    for i in range(prediction.shape[0]):  # prediction.shape[0] = batch size of noise_input
        plt.subplot(2, 2, i + 1)
        plt.imshow((prediction[i] + 1) / 2)  # map from [-1, 1] to [0, 1]
        plt.axis('off')
    plt.savefig('./CGAN_img/image_CGAN_{}.png'.format(epoch))
    # if epoch == 99:
    #     plt.show()
    plt.close()  # close the figure so figures don't accumulate across 500 epochs

# Fixed inputs for the visualization; 4 images are drawn each epoch
noise_input = torch.randn(4, 100, device=device)  # test input: 4 noise vectors of length 100
# print(noise_input)
label_input0 = torch.randint(0, 2, size=(4, ))  # 4 random integer labels, each 0 or 1 (randint's upper bound is exclusive)
# print(label_input0)
label_input_onehot = one_hot(label_input0).to(device)  # convert the label tensor to one-hot form
# print(label_input_onehot)


# cGAN training loop
D_loss = []
G_loss = []

for epoch in range(500):
    d_epoch_loss = 0  # discriminator loss for this epoch
    g_epoch_loss = 0  # generator loss for this epoch
    count = len(dataloader)  # len(dataloader) returns the number of batches
    count1 = len(train_dataset)  # len(train_dataset) returns the number of samples
    for step, (img, label) in enumerate(tqdm.tqdm(dataloader)):  # the labels returned here are already one-hot
        img = img.to(device)
        label = label.to(device)
        size = img.size(0)  # number of images in this batch
        random_noise = torch.randn(size, 100, device=device)  # noise input for the generator

        d_optim.zero_grad()  # zero the discriminator gradients
        real_output = dis(img, label)  # run the real images through the discriminator to get its predictions on real data
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # real images should be judged 1 (real): discriminator loss on real images
        d_real_loss.backward()  # compute gradients

        gen_img = gen(random_noise, label)  # produce generated images
        fake_output = dis(gen_img.detach(), label)  # judge the generated images together with their labels; detach() cuts the gradient flow into the generator
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # generated images should be judged 0 (fake): discriminator loss on fakes
        d_fake_loss.backward()  # compute gradients

        d_loss = d_real_loss + d_fake_loss  # the discriminator loss has two parts
        d_optim.step()  # discriminator update

        # Generator
        g_optim.zero_grad()  # zero the generator gradients
        fake_output = dis(gen_img, label)  # judge the generated images again, this time letting gradients reach the generator
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # here we want the generated images to be judged 1 (real)
        g_loss.backward()  # compute gradients
        g_optim.step()  # generator update

        with torch.no_grad():  # no gradients are needed while accumulating losses
            d_epoch_loss += d_loss.item()  # accumulate the loss of each batch
            g_epoch_loss += g_loss.item()  # accumulate the loss of each batch

    with torch.no_grad():  # no gradients are needed for logging either
        g_epoch_loss /= count
        d_epoch_loss /= count
        D_loss.append(d_epoch_loss)  # record the average loss of each epoch
        G_loss.append(g_epoch_loss)  # record the average loss of each epoch
        print('Epoch:', epoch)
        gen_img_plot(gen, epoch, noise_input, label_input_onehot)  # save one figure of samples per epoch

    plt.figure(figsize=(10, 10))
    plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
    plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
    plt.xlabel('epoch')  # x-axis label
    plt.legend()
    plt.savefig('loss.png')  # save (and overwrite) the loss curves each epoch
    plt.close()  # close the figure so figures don't pile up across epochs
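
Once training is done, sampling a particular class is just a matter of fixing the condition. A minimal inference sketch (assuming the trained gen above, num_class=2, and an arbitrary choice of class index 1):

gen.eval()  # switch BatchNorm to its running statistics for inference
with torch.no_grad():
    target = torch.tensor([1, 1, 1, 1])      # ask for 4 images of class 1
    cond = one_hot(target).to(device)        # build the one-hot condition
    z = torch.randn(4, 100, device=device)   # fresh noise
    samples = gen(z, cond)                   # [4, 3, 224, 224] images in [-1, 1]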

Although images generated by cGAN have many defects, such as blurred edges and low resolution, it paved the way for the later pix2pixGAN and CycleGAN!

Finally, here are my training results (the dataset is small, only 400 dog pictures, so the effect is not obvious):

Origin: blog.csdn.net/m0_56247038/article/details/130274952