Generative Adversarial Networks: DCGAN

1. Introduction

Paper: Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

Paper address: https://arxiv.org/abs/1511.06434

DCGAN combines convolutional neural networks (CNNs) with the original GAN: both the generator and the discriminator are deep convolutional networks. It established the basic network architecture used by almost all later GANs, and it greatly improves both the training stability of the original GAN and the quality of the generated results.

2. Improvements

  • Both the generator and the discriminator drop the pooling layers of a conventional CNN (resolution is changed with strided convolutions instead). The discriminator keeps the overall CNN architecture, while the generator replaces the convolutional layers with transposed-convolution ("deconvolution") layers (ConvTranspose2d).
  • Batch Normalization is used in both the discriminator and the generator. It helps with training problems caused by poor initialization, speeds up training, and improves training stability. Note that BN layers are not used in the generator's output layer or in the discriminator's input layer.
  • In the generator, every layer except the output layer uses the ReLU activation function; the output layer uses Tanh().
  • In the discriminator, every layer except the output layer uses the LeakyReLU activation function, which avoids the sparse gradients produced by ReLU.

3. Network structure and code (the structure diagram from the original post is not reproduced here)

import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torchvision import transforms, datasets, utils
from torchvision.datasets import ImageFolder

# Root folder of the anime-faces training images, loaded with ImageFolder.
# ImageFolder yields (image, class_index) pairs; the training loop ignores the index.
ROOT_TRAIN = r'D:\CNN\anime-faces'

# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to
# [-1, 1], the same range as the generator's tanh output.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE(review): no Resize is applied, so the images must already be 64x64 —
# the Discriminator's fc layer (128*15*15 inputs) only fits 64x64 inputs.
train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)  # load the training set

BATCH_SIZE = 256

# If the dataset size leaves a final batch of exactly one image, drop it:
# the generator's BatchNorm1d cannot compute batch statistics from a single
# sample in training mode and raises. Otherwise keep every sample.
dataloader = torch.utils.data.DataLoader(train_dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         num_workers=0,
                                         drop_last=len(train_dataset) % BATCH_SIZE == 1)


# Generator: input is a batch of noise vectors (the training loop draws them
# with torch.randn, i.e. standard-normal samples), output is a batch of
# 3x64x64 images with values in (-1, 1).
class Generator(nn.Module):
    """DCGAN generator.

    Shape flow for a batch of N noise vectors:
        (N, latent_dim) --linear1--> (N, 256*16*16) --view--> (N, 256, 16, 16)
        --deconv1 (stride 1)--> (N, 128, 16, 16)
        --deconv2 (stride 2)--> (N, 64, 32, 32)
        --deconv3 (stride 2)--> (N, 3, 64, 64)

    Args:
        latent_dim: length of each input noise vector. Defaults to 100, the
            size used throughout this script.
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(latent_dim, 256*16*16)
        self.bn1 = nn.BatchNorm1d(256*16*16)
        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          stride=1,
                                          padding=1)  # -> 128x16x16 (spatial size unchanged)
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # -> 64x32x32 (spatial size doubled)
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # -> 3x64x64 (spatial size doubled)

    def forward(self, x):
        """Map noise ``x`` of shape (N, latent_dim) to images of shape (N, 3, 64, 64).

        In training mode the BatchNorm layers need N > 1.
        """
        # NOTE(review): ReLU is applied before BatchNorm here; the common DCGAN
        # ordering (e.g. the PyTorch DCGAN tutorial) is BN -> ReLU. Kept as-is
        # so existing checkpoints and training behaviour are unchanged.
        x = F.relu(self.linear1(x))  # (N, latent_dim) -> (N, 256*16*16)
        x = self.bn1(x)
        x = x.view(-1, 256, 16, 16)  # reshape into a 256x16x16 feature map
        x = F.relu(self.deconv1(x))  # -> (N, 128, 16, 16)
        x = self.bn2(x)
        x = F.relu(self.deconv2(x))  # -> (N, 64, 32, 32)
        x = self.bn3(x)
        # Output layer: tanh and no BatchNorm, so values lie in (-1, 1) like the
        # normalized real images.
        return torch.tanh(self.deconv3(x))  # -> (N, 3, 64, 64)


# Discriminator: input is a batch of 3x64x64 images, output is the probability
# (after a sigmoid) that each image is real.
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Shape flow for a batch of N images:
        (N, 3, 64, 64) --conv1 (3x3, stride 2)--> (N, 64, 31, 31)
        --conv2 (3x3, stride 2)--> (N, 128, 15, 15)
        --flatten--> (N, 128*15*15) --fc + sigmoid--> (N, 1)
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2)    # 64x64 -> 31x31
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)  # 31x31 -> 15x15
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*15*15, 1)

    def forward(self, x):
        """Return P(real) of shape (N, 1) for images ``x`` of shape (N, 3, 64, 64)."""
        # F.dropout2d defaults to training=True, so the module's own mode is
        # passed explicitly; otherwise dropout would stay active after .eval().
        # The input layer has no BatchNorm (DCGAN guideline).
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.3, training=self.training)  # -> (N, 64, 31, 31)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.3, training=self.training)  # -> (N, 128, 15, 15)
        x = self.bn(x)
        # Flatten per sample: a wrongly sized input now fails in self.fc instead
        # of being silently folded into the batch dimension.
        x = x.view(x.size(0), -1)
        return torch.sigmoid(self.fc(x))


# Train on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instantiate both networks directly on the chosen device.
gen = Generator().to(device)
dis = Discriminator().to(device)

# Discriminator optimizer: its learning rate is 10x smaller than the
# generator's, deliberately weakening the discriminator.
d_optim = torch.optim.Adam(params=dis.parameters(), lr=1e-4)
# Generator optimizer.
g_optim = torch.optim.Adam(params=gen.parameters(), lr=1e-3)

# Binary cross-entropy on probabilities (the discriminator ends in a sigmoid).
loss_fn = torch.nn.BCELoss()

def gen_img_plot(model, epoch, test_input, out_dir='./face_DCGAN'):
    """Plot the images ``model`` generates from ``test_input`` and save them as a PNG.

    Args:
        model: the generator; called as ``model(test_input)`` and expected to
            return a batch shaped (N, C, H, W) with values in [-1, 1].
        epoch: epoch index, used only in the output file name.
        test_input: batch of generator input noise (one image per row).
        out_dir: directory for ``image_GAN_{epoch}.png``; created if missing.
            Defaults to the directory this script always wrote to.
    """
    # NOTE(review): the model is not switched to eval mode, so generator
    # BatchNorm layers use (and update) batch statistics here, as before.
    with torch.no_grad():
        # NCHW -> NHWC, the channel-last layout imshow expects.
        images = model(test_input).detach().permute(0, 2, 3, 1).cpu().numpy()

    # Smallest square grid holding every sample (2x2 for 4 samples).
    side = int(np.ceil(np.sqrt(images.shape[0])))
    fig = plt.figure(figsize=(10, 10))
    for i, img in enumerate(images):
        plt.subplot(side, side, i + 1)
        plt.imshow((img + 1) / 2)  # map [-1, 1] -> [0, 1]
        plt.axis('off')

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'image_GAN_{}.png'.format(epoch)))
    # Release the figure; a new one is created on every call.
    plt.close(fig)

test_input = torch.randn(4, 100, device=device)  # fixed evaluation noise: 4 vectors of length 100, reused every epoch


# DCGAN training loop.
D_loss = []  # mean discriminator loss of each epoch
G_loss = []  # mean generator loss of each epoch

for epoch in range(100):
    d_epoch_loss = 0  # running sum of per-batch discriminator losses
    g_epoch_loss = 0  # running sum of per-batch generator losses
    count = len(dataloader)  # number of batches in one epoch
    for img, _ in tqdm.tqdm(dataloader):  # ImageFolder class labels are unused
        img = img.to(device)
        size = img.size(0)  # number of images in this batch
        random_noise = torch.randn(size, 100, device=device)  # generator input noise

        # ---- Discriminator step: push D(real) towards 1 and D(G(z)) towards 0 ----
        d_optim.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
        d_real_loss.backward()

        gen_img = gen(random_noise)
        # detach() keeps this loss from back-propagating into the generator.
        fake_output = dis(gen_img.detach())
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss  # for logging; both parts were already back-propagated
        d_optim.step()

        # ---- Generator step: make the just-updated D classify G(z) as real ----
        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        # This backward also writes gradients into the discriminator's
        # parameters; d_optim.zero_grad() clears them on the next batch.
        g_loss.backward()
        g_optim.step()

        # .item() returns a plain Python float, so no autograd graph is retained.
        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()

    # Average over batches and record the epoch means.
    g_epoch_loss /= count
    d_epoch_loss /= count
    D_loss.append(d_epoch_loss)
    G_loss.append(g_epoch_loss)
    with torch.no_grad():  # plotting only needs a forward pass
        gen_img_plot(gen, epoch, test_input)  # one image grid per epoch
    print('Epoch:', epoch)

    # Redraw the loss curves every epoch so loss.png reflects progress so far.
    loss_fig = plt.figure(figsize=(10, 10))
    plt.plot(range(1, len(D_loss)+1), D_loss, label='D_loss')
    plt.plot(range(1, len(G_loss)+1), G_loss, label='G_loss')
    plt.xlabel('epoch')
    plt.legend()
    loss_fig.savefig('loss.png')
    # Close it; otherwise one open figure accumulates per epoch.
    plt.close(loss_fig)


# if __name__ == '__main__':
#     x = torch.rand((4, 3, 224, 224))
#     model = Discriminator()
#     out = model(x)
#     print(out.shape)

Visualization of training results on the anime-faces (cartoon character) dataset (images not reproduced here)

Guess you like

Source: blog.csdn.net/m0_56247038/article/details/130270514