AI插画师:生成对抗网络

AI插画师:生成对抗网络

本文基于《深度学习框架PyTorch入门与实践》,将书中的代码进行了修改并补充了大量的注释。

本文用GAN实现一个生成动漫人物头像的例子

  1. 首先看本实验的代码结构:
    model.py文件实现模型定义,config.py文件对模型的参数进行配置,main.py文件实现训练和生成
    model.py文件实现模型定义,config.py文件对模型的参数进行配置,main.py文件实现训练和生成
  2. 接着看model.py中定义的生成器及判别器
    (这里对生成器的数据进行归一化处理时采用的是LayerNorm(),书中采用的的是BatchNorm(),二者的具体区别可参考官方文档)
import torch
import torch.nn as nn 
from config import opt

#定义生成器
class NetG(nn.Module):

    def __init__(self, opt):
        super(NetG, self).__init__()

        self.main = nn.Sequential(
            #输入是nz维度即100的噪声,feature map是100*1*1
            nn.ConvTranspose2d(opt.nz, opt.ngf * 8, kernel_size = 4, stride = 1,padding = 0, bias = False),
            nn.LayerNorm((opt.ngf * 8, 4, 4)),
            nn.ReLU(True),
            #512*4*4

            nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, kernel_size = 4, stride = 2,padding = 1, bias = False),
            nn.LayerNorm((opt.ngf * 4,8,8)),
            nn.ReLU(True),
            #256*8*8   输出图像尺寸 = (输入图像尺寸 - 1) * stride - 2 * padding + kernel_size

            nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, kernel_size = 4, stride = 2,padding = 1,bias=False),
            nn.LayerNorm((opt.ngf*2,16,16)),
            nn.ReLU(True),
            #128*16*16

            nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, kernel_size = 4,stride = 2,padding = 1,bias=False),
            nn.LayerNorm((opt.ngf,32,32)),
            nn.ReLU(True),
            #64*32*32

            nn.ConvTranspose2d(opt.ngf, 3, kernel_size = 4, stride = 2, padding = 1, bias = False),
            #3*64*64
            
            nn.Tanh()#将输出图片的像素归一化到-1~1
        )
    def forward(self, input):
        return self.main(input)

#定义判别器      
class NetD(nn.Module):

    def __init__(self, opt):
        super(NetD, self).__init__()
        self.main = nn.Sequential(
            #输入:3*64*64
            nn.Conv2d(3, opt.ndf, kernel_size = 4, stride = 2, padding = 1,bias=False),
            nn.BatchNorm2d(opt.ndf),
            nn.LeakyReLU(0.2,inplace=True),
            #输出:64*32*32

            nn.Conv2d(opt.ndf, opt.ndf * 2, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.BatchNorm2d(opt.ndf * 2),
            nn.LeakyReLU(0.2,inplace=True),
            #输出:128*16*16

            
            nn.Conv2d(opt.ndf * 2, opt.ndf * 4, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.BatchNorm2d(opt.ndf * 4),
            nn.LeakyReLU(0.2,inplace=True),
            #输出:256*8*8
            
            nn.Conv2d(opt.ndf * 4, opt.ndf * 8, kernel_size = 4, stride = 2, padding = 1, bias=False),
            nn.BatchNorm2d(opt.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            #输出:512*4*4

            nn.Conv2d(opt.ndf * 8, 1, kernel_size = 4, stride = 1, padding = 0, bias=False),
            nn.Sigmoid() #将输出图片的像素归一化到0~1
        )


    def forward(self,x):
        return self.main(x)

#看一下定义好的网络结构
if __name__ == '__main__':
    model = NetG(opt)
    print(model)
    model2 = NetD(opt)
    print(model2)

        
  1. 在开始训练之前,先看模型的配置参数
class Config():
    dataroot = '自定义的数据地址'
    nz = 100 #噪声的维度
    ngf = 64 #生成器feature map数
    ndf = 64 #判别器feature map数
    lr = 0.0002
    num_workers = 50
    batch_size = 64
    image_size = 64
    max_epoch = 120
    beta1 = 0.5 #Adam优化器的beta1参数
    use_gpu = True

opt = Config()
  1. main.py
#训练和生成
import torch
import torch.nn as nn
import model
from torch.optim import lr_scheduler
from torch import optim
import random
import torchvision
from config import opt
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from torchvision import transforms
from torchvision.datasets import ImageFolder


def train(opt):
    #定义模型
    modelG = model.NetG(opt)
    modelD = model.NetD(opt)

    #定义优化器
    optimizer_G = optim.Adam(modelG.parameters(), opt.lr, betas = (opt.beta1, 0.999))
    optimizer_D = optim.Adam(modelD.parameters(), opt.lr, betas = (opt.beta1, 0.999))

    #定义损失函数 单目标二分类交叉熵函数
    criterion = nn.BCELoss()

    #定义dataloader 数据载入
    trfs = torchvision.transforms.Compose([
        transforms.Resize(64, 64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
    ])
    dataset = ImageFolder(opt.dataroot, transform = trfs)
    dataloader = DataLoader(dataset,
                            batch_size = opt.batch_size, 
                            shuffle = True,
                            num_workers = opt.num_workers, 
                            drop_last = True)

    #定义label,真图片label为1,假图片label为0
    true_labels = torch.ones(opt.batch_size)
    fake_labels = torch.zeros(opt.batch_size)

    #生成网络的输入噪声
    noises = torch.randn(opt.batch_size, opt.nz, 1, 1)

    #使用GPU进行训练
    if opt.use_gpu:
        modelD.cuda()
        modelG.cuda()
        criterion.cuda()
        true_labels = true_labels.cuda()
        fake_labels = fake_labels.cuda()
        noises = noises.cuda()

    '''
        下面开始训练网络,训练步骤:
        1.训练判别器,固定生成器
        对于真图片,判别器的输出概率值尽可能接近于1
        对于生成器生成的假图片,判别器尽可能输出0
        2.训练生成器,固定判别器
        生成器生成图片,尽可能让判别器输出1
        3.返回第一步,循环交替训练
    '''
    writer = SummaryWriter()
    cnt = 0
    for epoch in range(1, opt.num_workers): #进行多个epoch的训练
        for ii, (img, _) in enumerate(dataloader): #第一个参数是要拼接的tensor,第二个参数是-1
            if opt.use_gpu:
                img = img.cuda()
            
            #训练判别器
            optimizer_D.zero_grad(set_to_none = True)
            noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
            
            fake_img = modelG(noises).detach() # 将随机噪声放入生成网络中生成一张假图片
            #detach()避免梯度传到G,因为G不用更新
            #print(fake_img.size()) #[64, 3, 64, 64] batch_size * C * H * W 

            fake_out = modelD(fake_img) #判别器判断假的图片
            loss_fake = criterion(fake_out.squeeze(), fake_labels) #得到假图片的loss

            img_out = modelD(img) #将真实图片放入判别器中
            loss_true = criterion(img_out.squeeze(), true_labels)
            loss_D = loss_fake + loss_true #损失包括判真损失和判假损失
            loss_D.backward() #反向传播
            writer.add_scalar('loss_D', loss_D, cnt)
            optimizer_D.step() #更新参数

            #训练生成器
            '''
            目的是希望生成的假图片被判别器判断为真图片
            '''
            optimizer_G.zero_grad(set_to_none = True)
            noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
            fake_img = modelG(noises)
            fake_out = modelD(fake_img).squeeze()
            loss_G = criterion(fake_out, true_labels) #假图片和真实图片label的loss
            loss_G.backward() 
            optimizer_G.step() #反向传播更新的参数是生成网络的参数
            writer.add_scalar('loss_G', loss_G, cnt)
            cnt += 1
            print('epoch: [%d|%d]  batch: [%d|%d]  lossD: %.3f  lossG: %.3f' % (
                epoch, opt.num_workers, ii, len(dataloader), loss_D, loss_G
            ))

        torchvision.utils.save_image(fake_img.data, '/tmp/pycharm_project_238/DCCGAN/imgs/epoch_%03d.png' % (epoch), 
                                     normalize = True)
    
    writer.flush()
    writer.close()
    torch.save(modelG.state_dict(), 'modelG.pth')
    torch.save(modelD.state_dict(), 'modelD.pth')


if __name__ == '__main__':
    train(opt)



  1. 实验结果分析
    如下图所示,分别是训练1个、10个、20个、30个、40个、50个epoch之后神经网络生成的动漫头像。
    在这里插入图片描述在这里插入图片描述在这里插入图片描述
    在这里插入图片描述在这里插入图片描述在这里插入图片描述
    有条件的话可以训练更多的epoch,会得到不少以假乱真的动漫图片。这个程序还可以应用到不同的生成图片场景中,只要将训练图片改成其他类型的图片即可。事实上,上述模型还有很大的改进空间。

猜你喜欢

转载自blog.csdn.net/weixin_44022810/article/details/114440304
今日推荐