Torch Learning (37): DCGAN Explained

Introduction

  A detailed walk-through of the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.
  The core of a generative adversarial network lies in the generator, the discriminator, and the interaction between the two; this post covers each of these parts in detail.

1 Generator

  The DCGAN generator is essentially a stack of transposed-convolution layers, batch normalization, and activation functions. Its structure is given in the following table:

Layer            In channels  Out channels  Kernel size  Stride  Padding  Followed by
ConvTranspose2d  nz           ngf*8         4            1       0        BatchNorm2d + ReLU
ConvTranspose2d  ngf*8        ngf*4         4            2       1        BatchNorm2d + ReLU
ConvTranspose2d  ngf*4        ngf*2         4            2       1        BatchNorm2d + ReLU
ConvTranspose2d  ngf*2        ngf           4            2       1        BatchNorm2d + ReLU
ConvTranspose2d  ngf          nc            4            2       1        Tanh

Here nz is the size of the latent vector z (fed in as an nz x 1 x 1 tensor), ngf is the base number of generator feature maps, and nc is the number of output image channels. For MNIST, for example, one can set nz=100, ngf=64, nc=1.
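  Each of the four stride-2 layers doubles the spatial size, taking the 1 x 1 latent map up to 64 x 64. A minimal sketch (not part of the original code) verifying the sizes in the table via the ConvTranspose2d output formula out = (in - 1) * stride - 2 * padding + kernel:

def deconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a ConvTranspose2d layer."""
    return (size - 1) * stride - 2 * padding + kernel

size = deconv_out(1, stride=1, padding=0)  # the 1 x 1 latent map -> 4
for _ in range(4):
    size = deconv_out(size)                # 4 -> 8 -> 16 -> 32 -> 64
print(size)                                # 64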
  The corresponding code is as follows:

class Generator(nn.Module):
    """Generator"""

    def __init__(self):
        super(Generator, self).__init__()
        # Number of GPUs to use
        self.ngpu = ngpu
        # Generator architecture, matching the table above
        self.main = nn.Sequential(
            # Input: the latent vector z, size nz x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State size: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State size: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State size: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State size: (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # State size: (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output

  Printing the generator gives the following structure (MNIST, with nz=100, ngf=64, nc=1, is the running example from here on):

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
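  As a quick sanity check (illustrative, not from the original post; it assumes netG, nz, and device as set up in the main script below), a batch of noise maps to 64 x 64 images whose values lie in [-1, 1] because of the final Tanh:

z = torch.randn(16, nz, 1, 1, device=device)  # a batch of 16 latent vectors
img = netG(z)
print(img.shape)  # torch.Size([16, 1, 64, 64]) for MNIST (nc = 1)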

2 Discriminator

  The discriminator differs from the generator in its choice of convolution layers and activation functions, and its channel count grows layer by layer:

Layer   In channels  Out channels  Kernel size  Stride  Padding  Followed by
Conv2d  nc           ndf           4            2       1        LeakyReLU(0.2)
Conv2d  ndf          ndf*2         4            2       1        BatchNorm2d + LeakyReLU(0.2)
Conv2d  ndf*2        ndf*4         4            2       1        BatchNorm2d + LeakyReLU(0.2)
Conv2d  ndf*4        ndf*8         4            2       1        BatchNorm2d + LeakyReLU(0.2)
Conv2d  ndf*8        1             4            1       0        Sigmoid

Here nc is the number of input image channels and ndf is the base number of discriminator feature maps. Note that the final layer has a single output channel: the probability that the input is real.
  The corresponding code is as follows:

class Discriminator(nn.Module):
    """Discriminator"""

    def __init__(self):
        super(Discriminator, self).__init__()
        # Number of GPUs to use
        self.ngpu = ngpu
        # Discriminator architecture
        self.main = nn.Sequential(
            # Input size: (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        # Same device handling as in the generator
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        # Flatten the output into a 1-D tensor of per-image probabilities
        return output.view(-1, 1).squeeze(1)

  Printing the discriminator gives the following structure:

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
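  The mirror-image check (again illustrative, assuming netD, nc, and device from the main script below): the discriminator maps a batch of 64 x 64 images to one probability per image, thanks to the view/squeeze in forward:

x = torch.randn(16, nc, 64, 64, device=device)  # 16 image-shaped tensors
p = netD(x)
print(p.shape)  # torch.Size([16]); each entry is a probability in (0, 1)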

3 Training: the Interaction Between Generator and Discriminator

  The generator and discriminator interact in numbered steps: the discriminator is first trained on a real batch and on a fake batch, and the generator is then trained to fool the updated discriminator.

  The training code is as follows:

def DCGAN():
    """DCGAN main training loop"""
    # Loop over epochs
    for epoch in range(opt.nepoch):
        # Loop over batches
        for i, data in enumerate(dataset, 0):
            """Step 1: train the discriminator, i.e. maximize log(D(x)) + log(1 - D(G(z)))"""
            # First train on a batch of real images
            # Zero the discriminator's gradients
            netD.zero_grad()
            # Each batch is an (images, labels) pair, hence data = data[0]
            data = data[0].to(device)
            # Number of images in the current batch
            batch_size = data.size(0)
            # Label every image in this batch as real (1)
            label = torch.full((batch_size, ), real_label, dtype=data.dtype, device=device)
            # Forward the real batch through the discriminator
            output = netD(data)
            # Discriminator loss on the real images
            errorD_real = loss(output, label)
            errorD_real.backward()
            D_x = output.mean().item()

            # Then train on a batch of fake images
            # Sample a batch of random noise
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            # The generator forges a batch of fakes
            fake = netG(noise)
            # Label them as fake (0)
            label.fill_(fake_label)
            # Classify the fakes; detach() keeps gradients out of the generator
            output = netD(fake.detach())
            # Discriminator loss on the fake images
            errorD_fake = loss(output, label)
            errorD_fake.backward()
            # Mean discriminator score on the fakes, before the update
            D_G_z1 = output.mean().item()
            # Total discriminator loss
            errorD = errorD_real + errorD_fake
            # Update the discriminator
            optimD.step()

            """Step 2: train the generator, i.e. maximize log(D(G(z)))"""
            # Zero the generator's gradients
            netG.zero_grad()
            # Use the real label: the generator wants its fakes to pass as real
            label.fill_(real_label)
            # Run the fakes through the (just updated) discriminator
            output = netD(fake)
            # Generator loss
            errorG = loss(output, label)
            # Backpropagate through the discriminator into the generator
            errorG.backward()
            # Mean discriminator score on the fakes, after the update
            D_G_z2 = output.mean().item()
            optimG.step()

            # Log the key statistics
            print("[%d/%d][%d/%d] lossD: %.4f lossG: %.4f "
                  "D(x): %.4f D(G(z)): %.4f/%.4f" % (epoch, opt.nepoch, i, len(dataset),
                                                     errorD.item(), errorG.item(),
                                                     D_x, D_G_z1, D_G_z2))

            # Save images; the checkpoint interval can be adjusted as desired
            if i % 100 == 0:
                vutils.save_image(data, "real_image.png", normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                                  'fake_image_%03d_%03d.png' % (epoch, i), normalize=True)
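  It is worth spelling out how BCELoss implements the objectives above: with target 1, BCELoss(p, 1) = -log(p), so minimizing it on real images maximizes log(D(x)); with target 0, BCELoss(p, 0) = -log(1 - p). A small numerical check (not from the original post):

bce = nn.BCELoss()
p = torch.tensor([0.9])                # a hypothetical discriminator output
print(bce(p, torch.ones(1)).item())    # -log(0.9) ≈ 0.1054
print(bce(p, torch.zeros(1)).item())   # -log(0.1) ≈ 2.3026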

4 Parameter Settings

  The following imports are used:

from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

  The parameter settings follow; each option is described by its help string:

def get_parser():
    """Build the argument parser"""
    parser = argparse.ArgumentParser()
    # Which dataset to use; the supported types are listed in the help string
    parser.add_argument("--dataset", required=False, default="mnist",
                        help="dataset type: cifar10 | lsun | mnist | imagenet | folder | lfw | fake")
    parser.add_argument("--data_root", required=False, help="root directory of the dataset",
                        default=r"D:\Data\OneDrive\Code\MIL1\Data")
    parser.add_argument("--workers", type=int, default=2, help="number of data loading workers")
    parser.add_argument("--batch_size", type=int, default=64, help="input batch size")
    parser.add_argument("--image_size", type=int, default=64, help="height/width of the input images")
    parser.add_argument("--nz", type=int, default=100, help="size of the latent vector z")
    parser.add_argument("--ngf", type=int, default=64, help="base number of generator feature maps")
    parser.add_argument("--ndf", type=int, default=64, help="base number of discriminator feature maps")
    parser.add_argument("--nepoch", type=int, default=5, help="number of training epochs")
    parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
    parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for Adam")
    parser.add_argument("--cuda", action="store_true", help="enable CUDA")
    parser.add_argument("--ngpu", type=int, default=1, help="number of GPUs to use")
    parser.add_argument("--manual_seed", type=int, help="manual random seed")
    parser.add_argument("--classes", default="bedroom", help="comma-separated list of LSUN classes")

    return parser.parse_args()
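  Assuming the script is saved as dcgan.py (a hypothetical file name), it could then be launched like this:

python dcgan.py --dataset mnist --data_root ./data --cuda --nepoch 5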

  Seed management:

def get_seed():
    """Manage the random seeds"""
    if opt.manual_seed is None:
        opt.manual_seed = random.randint(1, 10000)
    random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)

  Weight initialization; per the DCGAN paper, weights are drawn from a normal distribution with standard deviation 0.02:

def init_weight(m):
    """Initialize the weights"""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)
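  nn.Module.apply calls the given function on every submodule recursively, so a single netG.apply(init_weight) initializes all layers. A quick illustration (not part of the original post):

net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8))
net.apply(init_weight)      # init_weight runs on both layers
print(net[0].weight.std())  # close to 0.02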

  The main function:

if __name__ == '__main__':
    # Parse the command-line arguments
    opt = get_parser()
    # Fix the random seeds for reproducibility
    get_seed()
    # Device
    device = torch.device("cuda:0" if opt.cuda else "cpu")
    # GPU count, latent size, generator width, discriminator width
    ngpu, nz, ngf, ndf = int(opt.ngpu), int(opt.nz), int(opt.ngf), int(opt.ndf)
    # Data loader and number of image channels
    dataset, nc = get_data()
    # Build the generator
    netG = Generator().to(device)
    netG.apply(init_weight)
    # Build the discriminator
    netD = Discriminator().to(device)
    netD.apply(init_weight)
    # Loss function
    loss = nn.BCELoss()
    # Fixed noise (for monitoring progress) and the real/fake labels
    fixed_noise, real_label, fake_label = torch.randn(opt.batch_size, nz, 1, 1, device=device), 1, 0
    # Optimizers
    optimG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    DCGAN()

5 Data Loading

  The supported datasets are wired up as follows:

def get_data():
    """Build the dataset and wrap it in a DataLoader"""
    # Number of image channels
    nc = 3
    if opt.dataset in ["imagenet", "folder", "lfw"]:
        dataset = dset.ImageFolder(root=opt.data_root,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.image_size),
                                       transforms.CenterCrop(opt.image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == "lsun":
        classes = [c + "_train" for c in opt.classes.split(',')]
        dataset = dset.LSUN(root=opt.data_root, classes=classes,
                            transform=transforms.Compose([
                                transforms.Resize(opt.image_size),
                                transforms.CenterCrop(opt.image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == "cifar10":
        dataset = dset.CIFAR10(root=opt.data_root, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    elif opt.dataset == "mnist":
        dataset = dset.MNIST(root=opt.data_root, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(opt.image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
        nc = 1
    else:
        dataset = dset.FakeData(image_size=(3, opt.image_size, opt.image_size),
                                transform=transforms.ToTensor())

    return (torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)),
            nc)
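  To see why the training loop unpacks data[0], note that each batch from the loader is an (images, labels) pair. A quick illustration (not from the original post, assuming --dataset mnist):

loader, nc = get_data()
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 1, 64, 64])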

6 Complete Code

from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


def get_parser():
    """Build the argument parser"""
    parser = argparse.ArgumentParser()
    # Which dataset to use; the supported types are listed in the help string
    parser.add_argument("--dataset", required=False, default="mnist",
                        help="dataset type: cifar10 | lsun | mnist | imagenet | folder | lfw | fake")
    parser.add_argument("--data_root", required=False, help="root directory of the dataset",
                        default=r"D:\Data\OneDrive\Code\MIL1\Data")
    parser.add_argument("--workers", type=int, default=2, help="number of data loading workers")
    parser.add_argument("--batch_size", type=int, default=64, help="input batch size")
    parser.add_argument("--image_size", type=int, default=64, help="height/width of the input images")
    parser.add_argument("--nz", type=int, default=100, help="size of the latent vector z")
    parser.add_argument("--ngf", type=int, default=64, help="base number of generator feature maps")
    parser.add_argument("--ndf", type=int, default=64, help="base number of discriminator feature maps")
    parser.add_argument("--nepoch", type=int, default=5, help="number of training epochs")
    parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
    parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for Adam")
    parser.add_argument("--cuda", action="store_true", help="enable CUDA")
    parser.add_argument("--ngpu", type=int, default=1, help="number of GPUs to use")
    parser.add_argument("--manual_seed", type=int, help="manual random seed")
    parser.add_argument("--classes", default="bedroom", help="comma-separated list of LSUN classes")

    return parser.parse_args()


def get_seed():
    """Manage the random seeds"""
    if opt.manual_seed is None:
        opt.manual_seed = random.randint(1, 10000)
    random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)


def get_data():
    """Build the dataset and wrap it in a DataLoader"""
    # Number of image channels
    nc = 3
    if opt.dataset in ["imagenet", "folder", "lfw"]:
        dataset = dset.ImageFolder(root=opt.data_root,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.image_size),
                                       transforms.CenterCrop(opt.image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == "lsun":
        classes = [c + "_train" for c in opt.classes.split(',')]
        dataset = dset.LSUN(root=opt.data_root, classes=classes,
                            transform=transforms.Compose([
                                transforms.Resize(opt.image_size),
                                transforms.CenterCrop(opt.image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == "cifar10":
        dataset = dset.CIFAR10(root=opt.data_root, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    elif opt.dataset == "mnist":
        dataset = dset.MNIST(root=opt.data_root, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(opt.image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
        nc = 1
    else:
        dataset = dset.FakeData(image_size=(3, opt.image_size, opt.image_size),
                                transform=transforms.ToTensor())

    return (torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)),
            nc)


def init_weight(m):
    """Initialize the weights"""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
    """Generator"""

    def __init__(self):
        super(Generator, self).__init__()
        # Number of GPUs to use
        self.ngpu = ngpu
        # Generator architecture, matching the table above
        self.main = nn.Sequential(
            # Input: the latent vector z, size nz x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State size: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State size: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State size: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State size: (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # State size: (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output


class Discriminator(nn.Module):
    """Discriminator"""

    def __init__(self):
        super(Discriminator, self).__init__()
        # Number of GPUs to use
        self.ngpu = ngpu
        # Discriminator architecture
        self.main = nn.Sequential(
            # Input size: (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        # Same device handling as in the generator
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        # Flatten the output into a 1-D tensor of per-image probabilities
        return output.view(-1, 1).squeeze(1)


def DCGAN():
    """DCGAN main training loop"""
    # Loop over epochs
    for epoch in range(opt.nepoch):
        # Loop over batches
        for i, data in enumerate(dataset, 0):
            """Step 1: train the discriminator, i.e. maximize log(D(x)) + log(1 - D(G(z)))"""
            # First train on a batch of real images
            # Zero the discriminator's gradients
            netD.zero_grad()
            # Each batch is an (images, labels) pair, hence data = data[0]
            data = data[0].to(device)
            # Number of images in the current batch
            batch_size = data.size(0)
            # Label every image in this batch as real (1)
            label = torch.full((batch_size, ), real_label, dtype=data.dtype, device=device)
            # Forward the real batch through the discriminator
            output = netD(data)
            # Discriminator loss on the real images
            errorD_real = loss(output, label)
            errorD_real.backward()
            D_x = output.mean().item()

            # Then train on a batch of fake images
            # Sample a batch of random noise
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            # The generator forges a batch of fakes
            fake = netG(noise)
            # Label them as fake (0)
            label.fill_(fake_label)
            # Classify the fakes; detach() keeps gradients out of the generator
            output = netD(fake.detach())
            # Discriminator loss on the fake images
            errorD_fake = loss(output, label)
            errorD_fake.backward()
            # Mean discriminator score on the fakes, before the update
            D_G_z1 = output.mean().item()
            # Total discriminator loss
            errorD = errorD_real + errorD_fake
            # Update the discriminator
            optimD.step()

            """Step 2: train the generator, i.e. maximize log(D(G(z)))"""
            # Zero the generator's gradients
            netG.zero_grad()
            # Use the real label: the generator wants its fakes to pass as real
            label.fill_(real_label)
            # Run the fakes through the (just updated) discriminator
            output = netD(fake)
            # Generator loss
            errorG = loss(output, label)
            # Backpropagate through the discriminator into the generator
            errorG.backward()
            # Mean discriminator score on the fakes, after the update
            D_G_z2 = output.mean().item()
            optimG.step()

            # Log the key statistics
            print("[%d/%d][%d/%d] lossD: %.4f lossG: %.4f "
                  "D(x): %.4f D(G(z)): %.4f/%.4f" % (epoch, opt.nepoch, i, len(dataset),
                                                     errorD.item(), errorG.item(),
                                                     D_x, D_G_z1, D_G_z2))

            # Save images; the checkpoint interval can be adjusted as desired
            if i % 100 == 0:
                vutils.save_image(data, "real_image.png", normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                                  'fake_image_%03d_%03d.png' % (epoch, i), normalize=True)


if __name__ == '__main__':
    # Parse the command-line arguments
    opt = get_parser()
    # Fix the random seeds for reproducibility
    get_seed()
    # Device
    device = torch.device("cuda:0" if opt.cuda else "cpu")
    # GPU count, latent size, generator width, discriminator width
    ngpu, nz, ngf, ndf = int(opt.ngpu), int(opt.nz), int(opt.ngf), int(opt.ndf)
    # Data loader and number of image channels
    dataset, nc = get_data()
    # Build the generator
    netG = Generator().to(device)
    netG.apply(init_weight)
    # Build the discriminator
    netD = Discriminator().to(device)
    netD.apply(init_weight)
    # Loss function
    loss = nn.BCELoss()
    # Fixed noise (for monitoring progress) and the real/fake labels
    fixed_noise, real_label, fake_label = torch.randn(opt.batch_size, nz, 1, 1, device=device), 1, 0
    # Optimizers
    optimG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    DCGAN()

7 Sample Output Images

7.1 Real images

7.2 After 200 batches

7.3 After 400 batches

7.4 After 600 batches

  Hardware was limited, so this is as far as the training went.


Reposted from blog.csdn.net/weixin_44575152/article/details/121426785