简单记录:GAN(生成对抗网络),PyTorch + MNIST

目录

NET

多层感知器版:

卷积版

损失函数:

train:

总结:


NET

多层感知器版:

##GAN网络,多层感知器版
##判别网络
class discriminator(nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    The network maps an input whose last dimension is 784 (28 * 28) to a
    single raw score per sample. No sigmoid is applied, so the output is an
    unbounded logit rather than a probability; BatchNorm is not used either.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and therefore the state-dict keys dis.0 / dis.2 /
        # dis.4) is part of the checkpoint format, so keep it stable.
        layers = [
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        ]
        self.dis = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (pre-sigmoid) score for every row of ``x``."""
        return self.dis(x)

##生成网络
class generator(nn.Module):
    """Fully connected generator that turns a noise vector into a flat image.

    Args:
        in_size: Length of the input noise vector (default 96).

    The output has 784 values per sample (a flattened 28x28 image) squashed
    into [-1, 1] by the final ``Tanh``.
    """

    def __init__(self, in_size = 96):
        super().__init__()
        # Author's note: a hidden width of 256 gave noticeably worse samples,
        # hence 1024 units per hidden layer.
        hidden = 1024
        layers = [
            nn.Linear(in_size, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 784),
            nn.Tanh(),  # keeps generated pixels in [-1, 1]
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of noise vectors to a batch of flattened images."""
        return self.gen(x)

卷积版

###GAN 卷积版
class DC_discriminator(nn.Module):
    """Convolutional discriminator for 28x28 images.

    Args:
        in_channels: Number of channels of the input images. Defaults to 3
            (the previously hard-coded value); pass 1 for grayscale data
            such as MNIST.

    The classifier head expects exactly 1024 features, which is what the
    convolution stack yields for 28x28 inputs
    (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4, 64 * 4 * 4).
    The output is one raw (pre-sigmoid) score per sample.
    """

    def __init__(self, in_channels=3):
        super(DC_discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5, 1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.02, True),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.02, True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        """Score a batch of images shaped ``(batch, in_channels, 28, 28)``."""
        x = self.conv1(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

##生成网络
class DC_generator(nn.Module):
    """Convolutional generator producing 28x28 images from noise vectors.

    Args:
        in_size: Length of the input noise vector (default 96).
        out_channels: Number of channels of the generated images. Defaults
            to 3 (the previously hard-coded value); pass 1 for grayscale
            data such as MNIST.

    A fully connected stem expands the noise to a 128x7x7 feature map, and
    two stride-2 transposed convolutions upsample it 7 -> 14 -> 28. The final
    ``Tanh`` keeps the output in [-1, 1].
    """

    def __init__(self, in_size = 96, out_channels = 3):
        super(DC_generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_size, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),

            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128),
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),

            nn.ConvTranspose2d(64, out_channels, 4, 2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images shaped ``(batch, out_channels, 28, 28)``."""
        x = self.fc(x)
        # Reinterpret the flat features as a 128-channel 7x7 map.
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x

损失函数:

## Standard GAN losses.
## Binary cross-entropy evaluated directly on raw logits (the sigmoid is
## applied inside the loss), so the networks must NOT end with a Sigmoid.
bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real, logits_fake): # discriminator loss
    """Return the discriminator's BCE loss.

    Real samples are pushed towards label 1 and generated samples towards
    label 0; the two terms are summed.

    Args:
        logits_real: Raw discriminator scores for real samples.
        logits_fake: Raw discriminator scores for generated samples.

    Returns:
        A scalar tensor holding the summed loss.
    """
    # Build the targets from the logits themselves so they always share the
    # logits' shape, dtype and device (the old code forced CUDA and crashed
    # on CPU-only machines).
    true_labels = torch.ones_like(logits_real)
    false_labels = torch.zeros_like(logits_fake)
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

##定义生成网络的损失函数
def generator_loss(logits_fake): # generator loss
    """Return the generator's BCE loss.

    The generator is rewarded when the discriminator scores its samples as
    real, so the fake logits are compared against an all-ones target.

    Args:
        logits_fake: Raw discriminator scores for generated samples.

    Returns:
        A scalar tensor holding the loss.
    """
    # Targets inherit shape, dtype and device from the logits, so this works
    # on CPU as well as on GPU (the old code called .cuda() unconditionally).
    true_labels = torch.ones_like(logits_fake)
    loss = bce_loss(logits_fake, true_labels)
    return loss


## 这里定义 GAN 的损失函数
## 这里用的是最小二乘
def ls_discriminator_loss(logits_real, logits_fake):
    """Return the least-squares discriminator loss.

    Scores for real samples are regressed towards 1 and scores for generated
    samples towards 0; each mean-squared term is weighted by 0.5.
    """
    real_term = 0.5 * (logits_real - 1).pow(2).mean()
    fake_term = 0.5 * logits_fake.pow(2).mean()
    return real_term + fake_term

def ls_generator_loss(logits_fake):
    """Return the least-squares generator loss.

    Half the mean squared distance between the discriminator's scores for
    generated samples and the target value 1.
    """
    squared_gap = (logits_fake - 1).pow(2)
    return 0.5 * squared_gap.mean()

train:

import torch
import torchvision.utils

import six_Net
import torch.nn as nn
import tqdm
import os

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader, sampler
from torch import optim
from torchvision.utils import save_image

##设定参数
# Train / validation split sizes (not referenced in the visible part of the script).
NUM_TRAIN, NUM_VAL = 50000, 5000

# Length of the generator's noise vector and the mini-batch size of the loader.
NOISE_DIM, batch_size = 96, 128

def show_images(images):
    """Draw a batch of images on a near-square grid of subplots.

    Each image is flattened and then reshaped to ``side x side`` where
    ``side = ceil(sqrt(pixel_count))``, so the images are expected to be
    square. Nothing is returned; the figure is left as the current figure.
    """
    flat = np.reshape(images, [images.shape[0], -1])
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))
    img_side = int(np.ceil(np.sqrt(flat.shape[1])))

    plt.figure(figsize=(grid_side, grid_side))
    layout = gridspec.GridSpec(grid_side, grid_side)
    layout.update(wspace=0.05, hspace=0.05)

    for cell, pixels in enumerate(flat):
        axis = plt.subplot(layout[cell])
        # Hide every axis decoration so only the picture itself is visible.
        plt.axis('off')
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect('equal')
        plt.imshow(pixels.reshape([img_side, img_side]))

def preprocess_img(x):
    """Convert an image to a tensor and rescale it from [0, 1] to [-1, 1].

    The target range matches the ``Tanh`` output of the generators.
    """
    as_tensor = transforms.ToTensor()(x)    # values in [0, 1]
    centered = as_tensor - 0.5              # values in [-0.5, 0.5]
    return centered / 0.5                   # values in [-1, 1]

def deprocess_img(x):
    """Map values from [-1, 1] back to [0, 1] (inverse of the training scaling)."""
    shifted = x + 1.0       # [-1, 1] -> [0, 2]
    return shifted / 2.0    # [0, 2]  -> [0, 1]

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# MNIST training split with pixels rescaled to [-1, 1] by preprocess_img.
# download=False, so the dataset files are expected to already be under ./data.
train_set = mnist.MNIST(
    root='./data',
    train=True,
    transform=preprocess_img,
    download=False,
)
# Batches are served in dataset order (no shuffling).
train_data = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=False)

def get_optimizer(net):
    """Build the Adam optimizer used for both networks.

    Uses a learning rate of 3e-4 and betas of (0.5, 0.999) over all
    parameters of ``net``.
    """
    return torch.optim.Adam(
        params=net.parameters(),
        lr=3e-4,
        betas=(0.5, 0.999),
    )

def to_img(x):
    """Turn generator output into a batch of 1x28x28 images in [0, 1].

    Values are rescaled from [-1, 1] to [0, 1], clipped to that range, and
    reshaped to ``(batch, 1, 28, 28)``.
    """
    rescaled = (x + 1.) * 0.5
    clipped = rescaled.clamp(min=0, max=1)  # guard against out-of-range values
    return clipped.view(clipped.shape[0], 1, 28, 28)

##定义训练
def train_gen(D_net, G_net,
              D_optimizer, G_optimizer,
             discriminator_loss, generator_loss,
             num_epochs=10, noise_size=96, num_img=6):
    """Alternately train a discriminator and a generator on ``train_data``.

    For every batch the discriminator is updated first (real batch vs. a
    freshly generated fake batch), then the generator is updated using the
    same noise. Every 20 iterations the losses are printed and a grid of
    generated samples is refreshed; after each epoch the last generated batch
    is written to ``./out/ima_<epoch>.png``.

    Args:
        D_net: Discriminator. It receives the real images flattened to
            ``(batch, -1)``, so it must accept flat vectors.
        G_net: Generator mapping ``(batch, noise_size)`` noise to images.
        D_optimizer: Optimizer over the discriminator's parameters.
        G_optimizer: Optimizer over the generator's parameters.
        discriminator_loss: Callable ``(logits_real, logits_fake) -> loss``.
        generator_loss: Callable ``(logits_fake) -> loss``.
        num_epochs: Number of passes over ``train_data``.
        noise_size: Length of each noise vector fed to ``G_net``.
        num_img: Side length of the preview grid (``num_img ** 2`` samples).

    NOTE(review): real images are always flattened and previews/snapshots are
    reshaped to 28x28, so this loop presumably targets the MLP networks only;
    confirm before using it with convolutional models.
    """
    # squeeze=False keeps ``a`` two-dimensional even when num_img == 1.
    _, a = plt.subplots(num_img, num_img, figsize=(num_img, num_img), squeeze=False)
    plt.ion()  # Turn the interactive mode on, continuously plot

    # save_image does not create missing directories, so make sure the
    # snapshot folder exists before the first epoch finishes.
    os.makedirs('./out', exist_ok=True)

    for epoch in range(num_epochs):
        print()
        for iteration, (ima, _) in enumerate(train_data):
            bs = ima.shape[0]

            ## Discriminator step
            real_data = ima.view(bs, -1).to(device)     # real samples, flattened
            logits_real = D_net(real_data)              # scores for real samples

            # Uniform noise in [-1, 1).
            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5
            g_fake_seed = sample_noise.to(device)
            # detach(): the discriminator update must not backpropagate into
            # the generator; those gradients were computed and then discarded.
            fake_images = G_net(g_fake_seed).detach()
            logits_fake = D_net(fake_images)            # scores for fake samples

            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()                          # update the discriminator

            ## Generator step (same noise, scored by the updated discriminator)
            fake_images = G_net(g_fake_seed)
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)

            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()                          # update the generator

            ## Every 20 iterations: log the losses and refresh the preview grid
            if iteration % 20 == 0 :
                print(f'Epoch: {epoch + 1} | Iter: {iteration} | '
                      f'D_loss: {d_total_error.detach().cpu().numpy()} | '
                      f'G_loss:{g_error.detach().cpu().numpy()}')
                im_gen = deprocess_img(fake_images.detach().cpu().numpy())
                # Never index past the batch (a short last batch may hold
                # fewer than num_img ** 2 samples).
                for i in range(min(num_img ** 2, im_gen.shape[0])):
                    ax = a[i // num_img][i % num_img]
                    ax.imshow(np.reshape(im_gen[i], (28, 28)), cmap='gray')
                    ax.set_xticks(())
                    ax.set_yticks(())
                # 1-based epoch, consistent with the log line and file name.
                plt.suptitle('epoch: {} iteration: {}'.format(epoch + 1, iteration))
                plt.pause(0.01)
        # Snapshot of the last generated batch of this epoch.
        pic = to_img(fake_images.detach().cpu())
        torchvision.utils.save_image(pic, f'./out/ima_{epoch + 1}.png')


# Build the MLP discriminator / generator pair and move both to the selected device.
D_net = six_Net.discriminator().to(device)
G_net = six_Net.generator(NOISE_DIM).to(device)

# One Adam optimizer per network.
D_optimizer = get_optimizer(D_net)
G_optimizer = get_optimizer(G_net)

# BCE-with-logits GAN losses from the network module.
discriminator_loss = six_Net.discriminator_loss
generator_loss = six_Net.generator_loss

# noise_size is tied to NOISE_DIM so it can never drift away from the
# generator's input width (it used to be a separate literal 96).
train_gen(
    D_net, G_net,
    D_optimizer, G_optimizer,
    discriminator_loss, generator_loss,
    num_epochs=10, noise_size=NOISE_DIM, num_img=5,
)

总结:

网络由两个小网络组成,一个负责判别,一个负责生成

判别网络:

先将真正的图片输入 判别器 ,得到的是 1 维的输出,即  真实概率(注意:代码中判别器输出的是未经过 Sigmoid 的 logit,Sigmoid 由 BCEWithLogitsLoss 在内部完成)

随机生成一组数据,送入 生成器 ,得到  假数据

再将 假数据 送入 判别器 中,得到  虚假概率

最后将  真实概率  +  虚假概率 ,送入  损失函数计算

反向传播

生成网络:

随机生成的数据,送入 生成器  得到  虚假数据

将  虚假数据 送入 判别器 看它能不能甄别出来,得到的概率

最后送入  损失函数计算

反向传播

值得注意的是:如果生成的是三通道的彩色人物图片的话,要去掉判别器里的BN层,不然生成的图片人眼根本看不出来

猜你喜欢

转载自blog.csdn.net/qq_42792802/article/details/126164883
今日推荐