Use GAN network on pytorch to generate 0-9 numbers

Use a GAN (generative adversarial network) in PyTorch to generate images of the handwritten digits 0-9.

1. Background
GAN generative adversarial network can generate images, audio, and videos, and is mainly divided into generator network and discriminator network.

①General process
Input a random signal, let the generator randomly generate a picture, and then let the discriminator identify real pictures and fake pictures.

②What is the "adversarial" part?
In short, the generator learns to produce pictures that make it harder for the discriminator to tell fake pictures from real ones; at the same time, the discriminator continuously improves its ability to distinguish real pictures from fake ones. This competition between the two networks is what makes the training adversarial.

2. Dataset
The data set comes from the MNIST handwritten 0-9 data set (28x28) of torchvision's dataset
Please explore the dataset further on your own.

3. Model
Generator and Discriminator
model.py file

from torch import nn

# 图像生成器
class Generator(nn.Module):
    """Map a 100-dim noise vector to a flattened 28x28 image in [-1, 1].

    The Tanh output activation matches the range of the real images, which
    are normalized into [-1, 1] by the training transforms.
    """

    def __init__(self):
        super(Generator, self).__init__()
        layers = [
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),  # 28 * 28 = 784 pixels per image
            nn.Tanh(),  # squash output into [-1, 1] like the normalized data
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images of shape (N, 784) for noise of shape (N, 100)."""
        return self.gen(x)


# 图像判别器
class Discriminator(nn.Module):
    """Score a flattened 28x28 image with the probability that it is real.

    Two LeakyReLU hidden layers followed by a sigmoid output in [0, 1].
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.f1 = nn.Sequential(nn.Linear(784, 512), nn.LeakyReLU(0.2))
        self.f2 = nn.Sequential(nn.Linear(512, 256), nn.LeakyReLU(0.2))
        self.out = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        """Return real-probabilities of shape (N, 1) for input of shape (N, 784)."""
        for stage in (self.f1, self.f2, self.out):
            x = stage(x)
        return x

4. Training the model
train.py

import os

import torch
from torch import nn
from torch.autograd import Variable
from torchvision import transforms, datasets
from torchvision.utils import save_image

from model import Discriminator, Generator


def to_img(x):
    """Undo the [-1, 1] normalization and reshape to (N, 1, 28, 28) images.

    The generator emits values in [-1, 1]; this maps them back to [0, 1]
    and restores the NCHW layout expected by save_image.
    """
    images = (x + 1) * 0.5       # linear map [-1, 1] -> [0, 1]
    images = images.clamp(0, 1)  # guard against values that drift outside the range
    return images.view(-1, 1, 28, 28)


def GAN_train_model(dataset, generator, discriminator, batch_size, epoch, lr, z_dim, device):
    """Train a vanilla GAN and save sample images after every epoch.

    Args:
        dataset: torch Dataset yielding (image, label) pairs; labels are ignored.
        generator: model mapping (N, z_dim) noise to (N, 784) images in [-1, 1].
        discriminator: model mapping (N, 784) images to (N, 1) real-probabilities.
        batch_size: mini-batch size for the DataLoader.
        epoch: number of passes over the dataset.
        lr: learning rate for both Adam optimizers.
        z_dim: dimensionality of the generator's input noise.
        device: "CUDA" to train on the GPU; any other value trains on the CPU.

    Side effects: writes sample images to ./img/ and the trained models to
    model/ (both directories are created if missing).
    """
    # Shuffle every epoch so mini-batches are not correlated with file order.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    use_cuda = device == "CUDA"
    D = discriminator.cuda() if use_cuda else discriminator
    G = generator.cuda() if use_cuda else generator

    criterion = nn.BCELoss()  # binary cross-entropy against the real/fake labels

    d_optimizer = torch.optim.Adam(D.parameters(), lr=lr)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr)

    # Create output directories up front so save_image/torch.save cannot fail
    # on a missing path.
    os.makedirs('./img', exist_ok=True)
    os.makedirs('model', exist_ok=True)

    steps_per_epoch = len(data_loader)
    for cur_epoch in range(epoch):
        total_d_loss = 0.0
        total_g_loss = 0.0

        for img, _ in data_loader:
            num_img = img.size(0)
            # Flatten each 28x28 image into a 784-dim vector for the MLP models.
            real_img = img.view(num_img, -1)

            # Label convention: 1 for real images, 0 for generated ones.
            real_label = torch.ones(num_img, 1)
            fake_label = torch.zeros(num_img, 1)

            if use_cuda:
                real_img = real_img.cuda()
                real_label = real_label.cuda()
                fake_label = fake_label.cuda()

            # ---- Discriminator step ----
            # Loss on real images: D(real) should approach 1.
            real_out = D(real_img)
            d_loss_real = criterion(real_out, real_label)

            # Loss on generated images: D(fake) should approach 0.
            z = torch.randn(num_img, z_dim)
            if use_cuda:
                z = z.cuda()
            # detach() blocks gradients from flowing into G: only D updates here.
            fake_img = G(z).detach()
            fake_out = D(fake_img)
            d_loss_fake = criterion(fake_out, fake_label)

            d_loss = d_loss_real + d_loss_fake
            total_d_loss += d_loss.item()

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step ----
            # Train G so that D classifies its output as real (target label 1).
            # g_optimizer only holds G's parameters, so D is left untouched.
            z = torch.randn(num_img, z_dim)
            if use_cuda:
                z = z.cuda()

            fake_img = G(z)
            output = D(fake_img)
            g_loss = criterion(output, real_label)
            total_g_loss += g_loss.item()

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        # Report mean per-batch losses; epochs are displayed 1-based.
        print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '.format(
            cur_epoch + 1, epoch, total_d_loss / steps_per_epoch, total_g_loss / steps_per_epoch)
        )

        # Save one batch of real images once, for visual comparison.
        if cur_epoch == 0:
            real_images = to_img(real_img.cpu().data)
            save_image(real_images, './img/real_images.png')

        # Save the last generated batch of this epoch.
        fake_images = to_img(fake_img.data)
        save_image(fake_images, './img/fake_images-{}.png'.format(cur_epoch + 1))

    # Persist the full modules (architecture + weights).
    torch.save(generator, "model/generator.pkl")
    torch.save(discriminator, "model/discriminator.pkl")


if __name__ == "__main__":
    # 图像变化器,转为tensor并标准化数据
    transform = transforms.Compose([
        transforms.ToTensor(),  # 数据范围[0,1],归一化
        transforms.Normalize((0.5,), (0.5,))  # (x-mean) / std,数据范围[-1,1],经过Normalize后,可以加快模型的收敛速度(不确定)
    ])

    # 加载数据集
    dataset = datasets.MNIST(root='./data/',
                             train=True,
                             transform=transform,
                             download=True)

    # 初始生成器generator与判别器discriminator
    discriminator = Discriminator()
    generator = Generator()

    # batch_size
    batch_size = 128

    # epoch次数
    epoch = 100

    # lr学习率
    lr = 3e-4

    # 噪声维度
    z_dim = 100

    GAN_train_model(
        dataset=dataset,
        discriminator=discriminator,
        generator=generator,
        batch_size=batch_size,
        epoch=epoch,
        lr=lr,
        z_dim=z_dim,
        device="CUDA"
    )

5. Training effect (generator)
The 100th epoch training effect
Insert image description here
Real data
Insert image description here
Here you can see that the images produced by the generator closely resemble the real ones.

Guess you like

Origin blog.csdn.net/pk296256948/article/details/127959611