pytorch实现DCGAN 生成人脸 celeba数据集

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DarrenXf/article/details/86684874

涉及的论文

GAN
https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
DCGAN
https://arxiv.org/pdf/1511.06434.pdf

测试用的数据集

Celeb-A Faces
数据集网站:
http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
下载链接:
百度 网盘 :https://pan.baidu.com/s/1eSNpdRG#list/path=%2F
谷歌 网盘 :https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg

数据集下载后,找到一个文件叫 img_align_celeba.zip
创建一个文件夹data,然后在data内创建一个文件夹celeba.
将img_align_celeba.zip 拷贝进celeba,然后解压

unzip img_align_celeba.zip

会生成这样的目录结构

./data/celeba/
		->img_align_celeba
			->188242.jpg
			->173822.jpg
			->284792.jpg
			...

这一步很重要,因为我们的代码中使用这样的文件结构.

实现DCGAN 包含的文件

main.py
etc.py
graph.py
model.py
show.py
record.py
DCGAN_architecture.py
celeba_dataset.py

main.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-01-25 14:07
# Modified date : 2019-01-27 22:36
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import celeba_dataset
from etc import config
from graph import NNGraph

def run():
    """Build the CelebA dataloader, construct the training graph, and train."""
    dataloader = celeba_dataset.get_dataloader(config)
    g = NNGraph(dataloader, config)
    g.train()

if __name__ == "__main__":
    # Guard the entry point so importing main.py does not start training.
    run()

etc.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-24 17:02
# Modified date : 2019-01-28 23:37
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import torch

config = {}

# --- dataset / input pipeline ---
config["dataset"] = "celeba"
config["batch_size"] = 128
# Images are resized and center-cropped to image_size x image_size.
config["image_size"] = 64
config["num_epochs"] = 5
config["data_path"] = "data/%s" % config["dataset"]
# Number of DataLoader worker processes.
config["workers"] = 2
# Log a status line every `print_every` batches.
config["print_every"] = 200
# Checkpoint (and sample the generator) every `save_every` iterations.
config["save_every"] = 500
config["manual_seed"] = 999
# When True, resume training from an existing checkpoint file.
config["train_load_check_point_file"] = False
#config["manual_seed"] = random.randint(1, 10000) # use if you want new results

# --- model hyper-parameters (DCGAN paper defaults) ---
config["number_channels"] = 3
# Dimension of the generator's latent noise vector z.
config["size_of_z_latent"] = 100
config["number_gpus"] = 1
config["number_of_generator_feature"] = 64
config["number_of_discriminator_feature"] = 64
# Adam settings suggested by the DCGAN paper.
config["learn_rate"] = 0.0002
config["beta1"] =0.5
# Target values fed to the BCE loss for real/fake samples.
config["real_label"] = 1
config["fake_label"] = 0
config["device"] = torch.device("cuda:0" if (torch.cuda.is_available() and config["number_gpus"] > 0) else "cpu")

graph.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-24 17:17
# Modified date : 2019-01-28 17:46
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import os
import time
import torch
import torchvision.utils as vutils
import model
import show
import record

class NNGraph(object):
    def __init__(self, dataloader, config):
        super(NNGraph, self).__init__()
        self.config = config
        self.train_model = self._get_train_model(config)
        record.record_dict(self.config, self.train_model["config"])
        self.config = self.train_model["config"]
        self.dataloader = dataloader

    def _get_train_model(self, config):
        train_model = model.init_train_model(config)
        train_model = self._load_train_model(train_model)
        return train_model

    def _save_train_model(self):
        model_dict = model.get_model_dict(self.train_model)
        file_full_path = record.get_check_point_file_full_path(self.config)
        torch.save(model_dict, file_full_path)

    def _load_train_model(self, train_model):
        file_full_path = record.get_check_point_file_full_path(self.config)
        if os.path.exists(file_full_path) and self.config["train_load_check_point_file"]:
            checkpoint = torch.load(file_full_path)
            train_model = model.load_model_dict(train_model, checkpoint)
        return train_model

    def _train_step(self, data):
        netG = self.train_model["netG"]
        optimizerG = self.train_model["optimizerG"]
        netD = self.train_model["netD"]
        optimizerD = self.train_model["optimizerD"]
        criterion = self.train_model["criterion"]
        device = self.config["device"]

        real_data = data[0].to(device)

        noise = model.get_noise(real_data, self.config)
        fake_data = netG(noise)
        label = model.get_label(real_data, self.config)

        errD, D_x, D_G_z1 = model.get_Discriminator_loss(netD, optimizerD, real_data, fake_data.detach(), label, criterion, self.config)
        errG, D_G_z2 = model.get_Generator_loss(netG, netD, optimizerG, fake_data, label, criterion, self.config)

        return errD, errG, D_x, D_G_z1, D_G_z2

    def _train_a_step(self, data, i, epoch):
        start = time.time()
        errD, errG, D_x, D_G_z1, D_G_z2 = self._train_step(data)
        end = time.time()
        step_time = end - start

        self.train_model["take_time"] = self.train_model["take_time"] + step_time

        print_every = self.config["print_every"]
        if i % print_every == 0:
            record.print_status(step_time*print_every,
                                self.train_model["take_time"],
                                epoch,
                                i,
                                errD,
                                errG,
                                D_x,
                                D_G_z1,
                                D_G_z2,
                                self.config,
                                self.dataloader)
        return errD, errG

    def _DCGAN_eval(self):
        fixed_noise = self.train_model["fixed_noise"]
        with torch.no_grad():
            netG = self.train_model["netG"]
            fake = netG(fixed_noise).detach().cpu()
            return fake

    def _save_generator_images(self, iters, epoch, i):
        num_epochs = self.config["num_epochs"]
        save_every = self.config["save_every"]
        img_list = self.train_model["img_list"]

        if (iters % save_every == 0) or ((epoch == num_epochs-1) and (i == len(self.dataloader)-1)):
            fake = self._DCGAN_eval()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            self._save_train_model()

    def _train_iters(self):
        num_epochs = self.config["num_epochs"]
        G_losses = self.train_model["G_losses"]
        D_losses = self.train_model["D_losses"]
        iters = self.train_model["current_iters"]
        start_epoch = self.train_model["current_epoch"]

        for epoch in range(start_epoch, num_epochs):
            self.train_model["current_epoch"] = epoch
            for i, data in enumerate(self.dataloader, 0):
                errD, errG = self._train_a_step(data, i, epoch)
                G_losses.append(errG.item())
                D_losses.append(errD.item())
                iters += 1
                self.train_model["current_iters"] = iters
                self._save_generator_images(iters, epoch, i)

    def train(self):
        self._train_iters()
        show.show_images(self.train_model, self.config, self.dataloader)

model.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : model.py
# Create date : 2019-01-24 17:00
# Modified date : 2019-01-29 00:43
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import random
import torch
import torch.nn as nn
import torch.optim as optim

from DCGAN_architecture import Generator, Discriminator

import record

def _weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def _random_init(config):
    manualSeed = config["manual_seed"]
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

def _get_a_net(Net, config):
    """Instantiate `Net` on the configured device, wrap it in DataParallel
    when several GPUs are requested, and apply the DCGAN weight init."""
    device = config["device"]
    gpu_count = config["number_gpus"]
    net = Net(config).to(device)
    if device.type == 'cuda' and gpu_count > 1:
        net = nn.DataParallel(net, list(range(gpu_count)))
    net.apply(_weights_init)
    # Log the network architecture to the run's output file.
    record.save_status(config, net)
    return net

def _get_optimizer(net, config):
    lr = config["learn_rate"]
    beta1 = config["beta1"]
    opt = optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    return opt

def _get_fixed_noise(config):
    nz = config["size_of_z_latent"]
    device = config["device"]
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    return fixed_noise

def load_model_dict(train_model, checkpoint):
    """Restore networks, loss, optimizers and bookkeeping from a checkpoint."""
    # Objects with a state_dict protocol are restored in place.
    for key in ("netG", "netD", "criterion", "optimizerD", "optimizerG"):
        train_model[key].load_state_dict(checkpoint[key])
    # Plain values are copied over directly.
    for key in ("fixed_noise", "G_losses", "D_losses", "img_list",
                "current_iters", "current_epoch", "config", "take_time"):
        train_model[key] = checkpoint[key]
    return train_model

def get_model_dict(train_model):
    """Snapshot the whole training state into a plain dict for torch.save."""
    model_dict = {}
    # Networks, loss and optimizers are saved through their state_dict.
    for key in ("netG", "netD", "criterion", "optimizerD", "optimizerG"):
        model_dict[key] = train_model[key].state_dict()
    # Bookkeeping values are stored as-is.
    for key in ("fixed_noise", "G_losses", "D_losses", "img_list",
                "current_iters", "current_epoch", "config", "take_time"):
        model_dict[key] = train_model[key]
    return model_dict


def init_train_model(config):
    """Create every piece of DCGAN training state: the two networks, the
    BCE loss, their optimizers, fixed evaluation noise, and progress
    bookkeeping (losses, sampled images, iteration counters)."""
    _random_init(config)

    train_model = {}
    train_model["netG"] = _get_a_net(Generator, config)
    train_model["netD"] = _get_a_net(Discriminator, config)
    train_model["criterion"] = nn.BCELoss()
    train_model["optimizerD"] = _get_optimizer(train_model["netD"], config)
    train_model["optimizerG"] = _get_optimizer(train_model["netG"], config)
    train_model["fixed_noise"] = _get_fixed_noise(config)

    # Fresh bookkeeping; overwritten by load_model_dict when resuming.
    train_model["G_losses"] = []
    train_model["D_losses"] = []
    train_model["img_list"] = []
    train_model["current_iters"] = 0
    train_model["current_epoch"] = 0
    train_model["config"] = config
    train_model["take_time"] = 0.0

    return train_model

def _run_Discriminator(netD, data, label, loss):
    output = netD(data).view(-1)
    err = loss(output, label)
    err.backward()
    m = output.mean().item()
    return err, m

def get_Discriminator_loss(netD, optimizerD, real_data, fake_data, label, criterion, config):
    """One discriminator update: score a real batch (label arrives filled
    with the real target) and a fake batch, then step once on the summed
    gradients. Returns (total D loss, D(x), D(G(z)) before the step)."""
    netD.zero_grad()
    errD_real, D_x = _run_Discriminator(netD, real_data, label, criterion)
    # Reuse the same label tensor for the fake half of the batch.
    label.fill_(config["fake_label"])
    errD_fake, D_G_z1 = _run_Discriminator(netD, fake_data, label, criterion)
    optimizerD.step()
    return errD_real + errD_fake, D_x, D_G_z1

def get_Generator_loss(netG, netD, optimizerG, fake_data, label, criterion, config):
    """One generator update using the non-saturating GAN loss: fakes are
    scored by D against the real label, then the generator steps once."""
    netG.zero_grad()
    label.fill_(config["real_label"])  # fake labels are real for generator cost
    errG, D_G_z2 = _run_Discriminator(netD, fake_data, label, criterion)
    optimizerG.step()
    return errG, D_G_z2

def get_label(data, config):
    """Return a 1-D float tensor of `real_label` targets, one entry per
    sample in `data`, on the configured device.

    `dtype=torch.float` is set explicitly: `real_label` is an int in the
    config, and without an explicit dtype recent PyTorch versions infer an
    integer tensor from the fill value, which nn.BCELoss rejects.
    """
    b_size = data.size(0)
    real_label = config["real_label"]
    device = config["device"]
    label = torch.full((b_size, ), real_label, dtype=torch.float, device=device)
    return label

def get_noise(data, config):
    """Sample a batch of latent vectors, shape (batch, nz, 1, 1), matching
    the batch size of `data`, on the configured device."""
    batch = data.size(0)
    return torch.randn(batch, config["size_of_z_latent"], 1, 1,
                       device=config["device"])

show.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : show.py
# Create date : 2019-01-24 17:19
# Modified date : 2019-01-28 17:31
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import numpy as np
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from matplotlib import  rcParams

from matplotlib.animation import ImageMagickWriter

import record

rcParams["animation.embed_limit"] = 500

def show_some_batch(real_batch,device):
    """Display an 8x8 grid built from the first 64 images of one batch."""
    grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(grid.cpu(), (1, 2, 0)))
    plt.show()

def _plot_real_and_fake_images(real_batch, device, img_list, save_path):
    """Save a side-by-side figure comparing a real training batch with the
    most recent grid of generator samples."""
    plt.figure(figsize=(30, 30))

    # Left panel: 64 real images from the dataloader.
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
    plt.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))

    # Right panel: the last sampled generator grid.
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))

    plt.savefig("%s/%s" % (save_path, "real_and_fake.jpg"))
    #plt.show()

def _show_generator_images(G_losses, D_losses, save_path):
    """Plot the G and D loss curves over training iterations and save them."""
    plt.figure(figsize=(40, 20))
    plt.title("Generator and Discriminator Loss During Training")
    for losses, tag in ((G_losses, "G"), (D_losses, "D")):
        plt.plot(losses, label=tag)
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()

    plt.savefig("%s/%s" % (save_path, "G_D_losses.jpg"))
    #plt.show()

def _show_img_list(img_list):
    """Animate the saved generator sample grids (inline HTML for Jupyter,
    then a regular matplotlib window)."""
    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    frames = []
    for img in img_list:
        frames.append([plt.imshow(np.transpose(img, (1, 2, 0)), animated=True)])
    ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)

    HTML(ani.to_jshtml())
    plt.show()

def _save_img_list(img_list, save_path, config):
    """Write the per-checkpoint generator sample grids to an animated GIF.

    NOTE(review): ImageMagickWriter shells out to ImageMagick — presumably
    it must be installed on the machine; confirm before relying on this.
    """
    #_show_img_list(img_list)
    metadata = dict(title='generator images', artist='Matplotlib', comment='Movie support!')
    writer = ImageMagickWriter(fps=1,metadata=metadata)
    # Convert each CHW image grid to HWC for imshow.
    ims = [np.transpose(i, (1, 2, 0)) for i in img_list]
    fig, ax = plt.subplots()
    # The positional 500 is the dpi argument of writer.saving.
    with writer.saving(fig, "%s/img_list.gif" % save_path,500):
        for i in range(len(ims)):
            ax.imshow(ims[i])
            # Each frame i corresponds to i * save_every training iterations.
            ax.set_title("step {}".format(i * config["save_every"]))
            writer.grab_frame()

def show_images(train_model, config, dataloader):
    """Produce all end-of-training artifacts in the checkpoint directory:
    loss curves, a GIF of generator progress, and a real-vs-fake image."""
    save_path = record.get_check_point_path(config)

    _show_generator_images(train_model["G_losses"], train_model["D_losses"], save_path)
    _save_img_list(train_model["img_list"], save_path, config)
    real_batch = next(iter(dataloader))
    _plot_real_and_fake_images(real_batch, config["device"], train_model["img_list"], save_path)

record.py


#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : record.py
# Create date : 2019-01-28 15:51
# Modified date : 2019-01-28 18:07
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import os

def _get_param_str(config):
    # pylint: disable=bad-continuation
    param_str = "%s_%s_%s_%s_%s_%s_%s" % (
                                config["dataset"],
                                config["image_size"],
                                config["batch_size"],
                                config["number_of_generator_feature"],
                                config["number_of_discriminator_feature"],
                                config["size_of_z_latent"],
                                config["learn_rate"],
                                )
    # pylint: enable=bad-continuation
    return param_str

def get_check_point_path(config):
    """Return the checkpoint directory for this run, creating it if needed."""
    directory = "%s/save/%s/" % (config["data_path"], _get_param_str(config))
    if not os.path.exists(directory):
        os.makedirs(directory)
    return directory

def get_check_point_file_full_path(config):
    """Return the full path of this run's checkpoint .tar file."""
    return "%s%scheckpoint.tar" % (get_check_point_path(config),
                                   _get_param_str(config))

def _write_output(config, con):
    """Append one line (`con`) to the run's `output` log file.

    Uses a `with` block so the handle is closed even if the write raises;
    the original opened and closed the file manually and would leak the
    handle on error.
    """
    save_path = get_check_point_path(config)
    file_full_path = "%s/output" % save_path
    with open(file_full_path, "a") as f:
        f.write("%s\n" % con)

def record_dict(config, dic):
    """Log every key/value pair of `dic` via save_status, preceded by a header."""
    save_status(config, "config:")
    for key, value in dic.items():
        save_status(config, "%s : %s" % (key, value))

def save_status(config, con):
    """Print `con` to stdout and append it to the run's output log file."""
    print(con)
    _write_output(config, con)

def print_status(step_time, take_time, epoch, i, errD, errG, D_x, D_G_z1, D_G_z2, config, dataloader):
    """Format one training status line (epoch/batch progress, both losses,
    D's scores, and total elapsed time) and log it via save_status.

    NOTE(review): `step_time` is accepted (the caller passes the estimated
    time per print interval) but never appears in the format string below —
    only `take_time` is printed; confirm whether it should be included.
    """
    num_epochs = config["num_epochs"]
    # pylint: disable=bad-continuation
    print_str = '[%d/%d]\t[%d/%d]\t Loss_D: %.4f\t Loss_G: %.4f\t D(x): %.4f\t D(G(z)): %.4f / %.4f take_time: %.fs' % (
                                                        epoch,
                                                        num_epochs,
                                                        i,
                                                        len(dataloader),
                                                        errD.item(),
                                                        errG.item(),
                                                        D_x,
                                                        D_G_z1,
                                                        D_G_z2,
                                                        take_time,
                                                        )
    # pylint: enable=bad-continuation
    save_status(config, print_str)

DCGAN_architecture.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : DCGAN_architecture.py
# Create date : 2019-01-26 23:16
# Modified date : 2019-01-27 22:47
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import torch.nn as nn

class Generator(nn.Module):
    """DCGAN generator: maps a (N, nz, 1, 1) latent batch to (N, nc, 64, 64)
    images in [-1, 1] via five transposed convolutions."""

    def __init__(self, config):
        super(Generator, self).__init__()
        self.ngpu = config["number_gpus"]
        nz = config["size_of_z_latent"]
        ngf = config["number_of_generator_feature"]
        nc = config["number_channels"]

        # Project z to a (ngf*8, 4, 4) feature map.
        layers = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Three upsampling stages, each doubling the spatial size
        # and halving the channel count: 4->8->16->32.
        channels = ngf * 8
        for _ in range(3):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels // 2),
                nn.ReLU(True),
            ]
            channels //= 2
        # Final upsample to (nc, 64, 64), squashed to [-1, 1].
        layers += [
            nn.ConvTranspose2d(channels, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN discriminator: scores (N, nc, 64, 64) images with a sigmoid
    probability of being real, output shape (N, 1, 1, 1)."""

    def __init__(self, config):
        super(Discriminator, self).__init__()
        self.ngpu = config["number_gpus"]
        ndf = config["number_of_discriminator_feature"]
        nc = config["number_channels"]

        # First stage: no batch norm on the input convolution.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, each halving the spatial size
        # and doubling the channel count: 32->16->8->4.
        channels = ndf
        for _ in range(3):
            layers += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels *= 2
        # Collapse the 4x4 map to a single logit, squashed to [0, 1].
        layers += [
            nn.Conv2d(channels, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

celeba_dataset.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : celeba_dataset.py
# Create date : 2019-01-24 18:02
# Modified date : 2019-01-26 22:57
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

def get_dataloader(config):
    """Build a shuffled DataLoader over the CelebA image folder.

    Expects config["data_path"] to contain at least one subdirectory of
    images (ImageFolder treats each subdirectory as a class). Fixes the
    original stray chained assignment `tf = transform=...`, which leaked a
    useless local name `transform`.
    """
    image_size = config["image_size"]
    batch_size = config["batch_size"]
    dataroot = config["data_path"]
    workers = config["workers"]

    # Resize + center-crop to image_size, then scale pixels to [-1, 1]
    # to match the generator's Tanh output range.
    tf = transforms.Compose([
           transforms.Resize(image_size),
           transforms.CenterCrop(image_size),
           transforms.ToTensor(),
           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
       ])

    dataset = dset.ImageFolder(root=dataroot, transform=tf)
    dataloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=workers)
    return dataloader

运行2个epoch后生成的结果
Loss_G 和 Loss_D 对比图
在这里插入图片描述
真实图片和假图片对比
在这里插入图片描述

运行5个epoch后的结果
Loss_G 和 Loss_D 的对比图
在这里插入图片描述
真实图片和假图片对比

在这里插入图片描述

运行200个epoch 后结果

Loss_G 和 Loss_D 的对比图
在这里插入图片描述

真实图片和假图片对比

在这里插入图片描述

github :https://github.com/darr/DCGAN

猜你喜欢

转载自blog.csdn.net/DarrenXf/article/details/86684874
今日推荐