Take Life and Death Lightly, GAN If You Disagree (Part 9): NVIDIA's PGGAN in Practice, Generating High-Resolution Images

1. Introduction

In 2017, NVIDIA Research published a striking GAN paper: Progressive Growing of GANs for Improved Quality, Stability, and Variation (PGGAN for short). Using a progressively growing GAN and a carefully processed CelebA-HQ dataset, it produced stunning generated images at resolutions up to 1024x1024.

Paper

Progressive Growing of GANs for Improved Quality, Stability, and Variation

Code

My code: https://github.com/coolEphemeroptera/CELEBA_PGGAN
Official code: https://github.com/tkarras/progressive_growing_of_gans
Reference code: https://github.com/zhangqianhui/progressive_growing_of_gans_tensorflow


2. Generated Samples

After a casual read of the PGGAN paper and getting the gist of the approach, I decided to give it a GAN. Given my limited hardware, I did not use the CelebA-HQ dataset to generate ultra-high-resolution samples; instead I used the CelebA dataset to generate 128x128 faces as a first attempt. The effort paid off: after several days of experiments and tweaks, I obtained a fairly decent generative model. The training process is shown below.

128x128 generator (figure)
64x64 training process (figure)
128x128 training process (figure)


3. Key Methods Explained

3.1 Increasing Variation Using Minibatch Standard Deviation

Because GANs tend to capture only a subset of the variation in the training data, Salimans et al. proposed 'minibatch discrimination' in 2016 as a remedy: statistics are computed over the feature maps of a training minibatch, driving the generated samples to exhibit similar statistics. Concretely, a minibatch layer is appended near the end of the discriminator to compute these feature-map statistics. PGGAN simplifies this operation while still improving sample variation.

From the paper

Implementation
# add a variation statistic to the feature maps
def MinibatchstateConcat(nhwf, averaging='all'):
    # input: [N,H,W,fmaps]
    s = nhwf.shape
    # batch size (get_N is a small helper from the repo)
    group_size = get_N(nhwf)
    """
    Computation:
            (1) compute the per-position std over the N feature maps -> fmap1: [1,H,W,fmaps]
            (2) average fmap1 into a single value M1: [1,1,1,1]
            (3) tile M1 into N feature maps -> fmap2: [N,H,W,1]
            (4) concatenate fmap2 onto every sample's feature maps
    """
    adjusted_std = lambda x, **kwargs: tf.sqrt(tf.reduce_mean((x - tf.reduce_mean(x, **kwargs)) **2, **kwargs) + 1e-8)
    vals = adjusted_std(nhwf, axis=0, keep_dims=True)
    # average over positions and channels
    vals = tf.reduce_mean(vals, keep_dims=True)
    # tile to batch and spatial size
    vals = tf.tile(vals, multiples=(group_size, s[1].value, s[2].value, 1))
    # concatenate the statistic channel onto every sample's feature maps
    return tf.concat([nhwf, vals], axis=3)
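
A quick usage sketch (the placement here is an assumption for illustration; in PGGAN this layer sits at the start of the discriminator's final 4x4 block):

x = tf.placeholder(tf.float32, [16, 4, 4, 512])  # discriminator features at the 4x4 stage
y = MinibatchstateConcat(x)                      # y: [16, 4, 4, 513], one extra statistic channel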

3.2 生成器和判别器的归一化

PGGAN uses two different mechanisms to constrain gradient magnitudes and unhealthy competition between the two networks; neither involves trainable parameters.

3.2.1 Equalized Learning Rate

From the paper


He initialization ensures that randomly initialized weights do not drastically change the scale of the input signal at initialization. PGGAN, however, applies this scaling not only at initialization but at runtime, on every forward pass: the layer uses w_i / c, where c = gain / sqrt(fan_in) is the per-layer constant from He's initializer. For example, a 3x3 conv with 512 input feature maps has fan_in = 3*3*512 = 4608, so c = sqrt(2/4608) ≈ 0.021.

Implementation

# get runtime-scaled weights (equalized learning rate)
def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None):
    """
    He formula: 0.5*n*var(w)=1 , so: std(w)=sqrt(2)/sqrt(n)=gain/sqrt(fan_in)
    """
    # fan_in = parameters per conv kernel (h*w*fmaps1), or the input node count fmaps1 of a dense layer
    # conv_w:[H,W,fmaps1,fmaps2] or mlp_w:[fmaps1,fmaps2]
    if fan_in is None: fan_in = np.prod(shape[:-1])
    # He init
    std = gain / np.sqrt(fan_in)
    # runtime scaling
    if use_wscale:
        wscale = tf.constant(np.float32(std), name='wscale')
        return  tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal())*wscale
    else:
        return  tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0,std))
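
The conv2d, dense, add_bias, and lrelu helpers used throughout the rest of this post live in ops.py of my repo. A minimal sketch of them, assuming NHWC layout and a leaky-ReLU slope of 0.2 (see the repo for the exact versions):

def conv2d(nhwf, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):
    # stride-1 SAME convolution whose weights go through get_weight()
    w = get_weight([kernel, kernel, nhwf.shape[-1].value, fmaps], gain=gain, use_wscale=use_wscale)
    return tf.nn.conv2d(nhwf, w, strides=[1, 1, 1, 1], padding='SAME')

def dense(nf, fmaps, gain=np.sqrt(2), use_wscale=False):
    # fully-connected layer with the same runtime weight scaling
    w = get_weight([nf.shape[-1].value, fmaps], gain=gain, use_wscale=use_wscale)
    return tf.matmul(nf, w)

def add_bias(x):
    # per-channel bias on the last axis
    b = tf.get_variable('bias', shape=[x.shape[-1].value], initializer=tf.initializers.zeros())
    return tf.nn.bias_add(x, b)

def lrelu(x):
    return tf.nn.leaky_relu(x, alpha=0.2)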

3.2.2 Pixel Normalization in the Generator

From the paper

To avoid exploding generator gradients, pixel normalization is introduced:

Pixel norm is a variant of local response normalization: it normalizes each pixel's feature vector along the channel dimension, b_xy = a_xy / sqrt(mean_j((a_xy^j)^2) + epsilon) with epsilon = 1e-8, so that every spatial position of the feature map has unit length. This normalization strategy is closely tied to the design of the generator's output; note that the generator's output layer has no Tanh or Sigmoid activation.

Implementation

# pixel normalization
def PN(nd):
    # channel axis: 3 for [N,H,W,C] tensors, 1 for [N,C] tensors
    if len(nd.shape) > 2:
        axis_ = 3
    else:
        axis_ = 1
    epsilon = 1e-8
    with tf.variable_scope('PixelNorm'):
        return nd * tf.rsqrt(tf.reduce_mean(tf.square(nd), axis=axis_, keep_dims=True) + epsilon)

4. Building the Progressive Network

As training progresses, the generator and discriminator architectures grow step by step. For example, to train 128x128 images we start from 4x4, and the training stages are:
stage 1: 4x4, stable, level-2 net
stage 2: 8x8, fade-in, level-3 net
stage 3: 8x8, stable, level-3 net
stage 4: 16x16, fade-in, level-4 net
stage 5: 16x16, stable, level-4 net
stage 6: 32x32, fade-in, level-5 net
stage 7: 32x32, stable, level-5 net
stage 8: 64x64, fade-in, level-6 net
stage 9: 64x64, stable, level-6 net
stage 10: 128x128, fade-in, level-7 net
stage 11: 128x128, stable, level-7 net
In code this becomes:

PGGAN(0,latents_size,batch_size,  lowest, highest, level=2, isTransit=False,epochs=epochs,data_size=data_size)
PGGAN(1,latents_size, batch_size, lowest, highest, level=3, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(2,latents_size, batch_size, lowest, highest, level=3, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(3,latents_size, batch_size, lowest, highest, level=4, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(4,latents_size, batch_size, lowest, highest, level=4, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(5,latents_size, batch_size, lowest, highest, level=5, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(6,latents_size, batch_size, lowest, highest, level=5, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(7,latents_size, batch_size, lowest, highest, level=6, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(8,latents_size, batch_size, lowest, highest, level=6, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(9,latents_size, batch_size, lowest, highest, level=7, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(10,latents_size, batch_size, lowest, highest, level=7, isTransit=False, epochs=epochs, data_size=data_size)

4.1 Upsampling and Downsampling

In the paper, upsampling is implemented with nearest-neighbor interpolation and downsampling with average pooling.
PGGAN also drops deconvolution (transposed convolution) in favor of conv + upsample: deconvolution is known to give generative models checkerboard artifacts (see the referenced article on checkerboard artifacts for when they occur).
The paper's convolution blocks for the generator and discriminator are shown below.
Generator conv block: (figure)

Discriminator conv block: (figure)

This is somewhat like the up/down-sampling of a Gaussian pyramid (on Gaussian and Laplacian pyramids, see https://blog.csdn.net/poem_qianmo/article/details/26157633).

Implementation
# upsampling (nearest-neighbor, 2x)
def upsampling2d(nhwf):
    _, h, w, _ = int_shape(nhwf)  # int_shape: static-shape helper from ops.py
    return tf.image.resize_nearest_neighbor(nhwf, (2 * h, 2 * w))

# downsampling (2x2 average pooling)
def downsampling2d(nhwf):
    return tf.nn.avg_pool(nhwf, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')

4.2 Designing the Generator and Discriminator for Each Level (level = log2(res), where res is the current resolution)

Take generating a level-5 (32x32) image as an example: the GAN learns gradually from the lowest resolution 4x4 up to the target 32x32, with the G/D networks growing stage by stage. Below, using the generator as the example, I explain how each stage is built.

(1) Building the initial (level = 2) convolution layers

From the paper

Implementation
with tf.variable_scope('generator',reuse=reuse):
    # ******** build the initial level-2 architecture ******************
    with tf.variable_scope('scale_%d'%(2)):
        nf = PN(latents)
        # paper: CONV4x4 + CONV3x3; the 4x4 conv is replaced by a dense layer (as in the official code)
        with tf.variable_scope('Dense0' ):
            nf = dense(nf,fmaps=fn(2)*4*4,gain=np.sqrt(2)/4,use_wscale=True)  # Dense0: [N,512] to [N,4*4*512]
            nhwf = tf.reshape(nf,[-1, 4, 4,fn(2)])  # reshape: [N,4*4*512] to [N,4,4,512]
            nhwf = PN(lrelu(add_bias(nhwf)))
        with tf.variable_scope('CONV1'):
            nhwf = PN(lrelu(add_bias(conv2d(nhwf,fmaps=fn(2), kernel=3, use_wscale=True))))

(2) Building the higher-level convolution blocks

Section 4.1 introduced the convolution blocks; here we stack them to build the higher-level network. The number of feature maps in each block is fixed; PGGAN specifies:
feats_map_num = [512,512,512,512,256,128,64,32,16]
The overall topology is shown below (figure); the fn() helper used in the code is sketched right after.
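
A minimal sketch of the per-level feature-map helper fn(level), assuming it simply indexes the list above (level 2, i.e. 4x4, maps to the first entry):

def fn(level):
    # level = log2(resolution): level 2 (4x4) -> 512 fmaps, ..., level 10 (1024x1024) -> 16
    return feats_map_num[level - 2]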

Implementation: here we only grow up to 128x128 (limited hardware).

First define the generator convolution block:

def G_CONV_BLOCK(nhwf, level, use_wscale=False):
    """
    upsampling + CONV0 ~= pyrUp
    """
    # upsampling
    with tf.variable_scope('upscale2d'):
        nhwf = upsampling2d(nhwf)
    # CONV0
    with tf.variable_scope('CONV0'):
        nhwf = PN(lrelu(add_bias(conv2d(nhwf, fmaps=fn(level), kernel=3, use_wscale=use_wscale))))
    # CONV1
    with tf.variable_scope('CONV1'):
        nhwf = PN(lrelu(add_bias(conv2d(nhwf, fmaps=fn(level), kernel=3, use_wscale=use_wscale))))
    return nhwf

Then grow the generator level by level. Note that in a fade-in stage we must also keep the previous stage's output (toRGB):

for scale in range(3,level+1):
    if scale == level and isTransit:  # before adding the newest block, take the current output image and upsample it
        RGB0 = upsampling2d(nhwf)  # upsample
        RGB0 = toRGB(RGB0,scale-1,use_wscale=True)  # toRGB with the previous level's kernel
    with tf.variable_scope('scale_%d'%scale):
        nhwf = G_CONV_BLOCK(nhwf,scale,use_wscale=True)  # grow by one conv block

(3) Generator output (merging the feature maps: toRGB)

After the stacked convolutions we have the feature maps; at the output they must be merged into a 3-channel RGB image.
First define the toRGB function:

Implementation
# define toRGB
def toRGB(nhwf, level, use_wscale=False):  
    with tf.variable_scope('level_%d_toRGB' % level):
        return add_bias(conv2d(nhwf, fmaps=3, kernel=1, gain=1, use_wscale=use_wscale))

Then check whether the current stage is a fade-in stage; if it is, the previous stage's output is faded into this stage's:

Implementation
RGB1 = toRGB(nhwf, level,use_wscale=True)  # output image of the last conv block
# fade-in stage?
if isTransit:
    nhw3 = trans_alpha * RGB1 + (1 - trans_alpha) * RGB0  # smooth transition from RGB0 to RGB1
else:
    nhw3 = RGB1

return nhw3

where the fade-in coefficient satisfies 0 <= trans_alpha <= 1 and grows linearly with training progress (in the training loop below it is simply steps / max_iters).

(4) Full generator definition

The complete generator definition is given below. The discriminator works on the same principles; a sketch of its building blocks follows the generator code, and the full version is in my GitHub repo.

def Generator_PG(latents,level,reuse = False,isTransit = False,trans_alpha = 0.0):
    """
    :param latents: 输入分布
    :param level: 网络等级(阶段)
    :param reuse: 变量复用
    :param isTransit: 是否fade_in
    :param trans_alpha: 过度系数
    :return: 生成图片
    """
    """
        说明:(1)Generator构成:scale_2 + scale_3~level + toRGB , 其中toRGB层将全部特征图合成RGB
              (2) 过渡阶段: ① 本阶段RGB将融合上一阶段RGB输出。对于上一阶段RGB处理层而言,通过特征图上采样匹配大小,再toRGB再融合。
                           ② 上一阶段toRGB的卷积核参数对于上采样后的特征图依然有效
    """
    # ******************************* 构造PG生成器 ************************************
    with tf.variable_scope('generator',reuse=reuse):
        # ******** 构造二级初始架构 ******************
        with tf.variable_scope('scale_%d'%(2)):
            nf = PN(latents)
            # 论文:CONV4x4+CONV3x3,这里CONV4x4采用FC替代(参考论文源码)
            with tf.variable_scope('Dense0' ):
                nf = dense(nf,fmaps=fn(2)*4*4,gain=np.sqrt(2)/4,use_wscale=True)# Dense0:[N,512] to [N,4*4*512}
                nhwf = tf.reshape(nf,[-1, 4, 4,fn(2)])# reshape:[N,4*4*512} to [N,4,4,512]
                nhwf = PN(lrelu(add_bias(nhwf)))
            with tf.variable_scope('CONV1'):
                nhwf = PN(lrelu(add_bias(conv2d(nhwf,fmaps=fn(2), kernel=3, use_wscale=True))))

        # ********* 构造拓扑架构(3~level) *********************
        for scale in range(3,level+1):
            if scale == level and isTransit: # 在最后卷积层新建之前,获取当前输出图片并上采样
                RGB0 = upsampling2d(nhwf)  # 上采样
                RGB0 = toRGB(RGB0,scale-1,use_wscale=True)# toRGB
            with tf.variable_scope('scale_%d'%scale):
                nhwf = G_CONV_BLOCK(nhwf,scale,use_wscale=True)# 卷积层拓展

        # ******************* toRGB *****************************
        RGB1 = toRGB(nhwf, level,use_wscale=True)  # 获取最后卷积层输出图像
        # 判断是否为过渡阶段
        if isTransit:
            nhw3 = trans_alpha * RGB1 + (1 - trans_alpha) * RGB0  # 由RGB0平滑过渡到RGB1
        else:
            nhw3 = RGB1

        return nhw3
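
For reference, a level-5 (32x32) fade-in generator is built as Generator_PG(latents, level=5, reuse=False, isTransit=True, trans_alpha=0.3). The discriminator mirrors this structure; below is a minimal sketch of its fromRGB layer and conv block. The exact names and layout are assumptions based on the paper and the naming conventions above; see the GitHub repo for the real versions.

def fromRGB(nhw3, level, use_wscale=False):
    # 1x1 conv: 3 RGB channels -> fn(level) feature maps (mirror of toRGB)
    with tf.variable_scope('level_%d_fromRGB' % level):
        return lrelu(add_bias(conv2d(nhw3, fmaps=fn(level), kernel=1, use_wscale=use_wscale)))

def D_CONV_BLOCK(nhwf, level, use_wscale=False):
    """
    CONV0 + CONV1 + downsampling ~= pyrDown (mirror of G_CONV_BLOCK)
    """
    with tf.variable_scope('CONV0'):
        nhwf = lrelu(add_bias(conv2d(nhwf, fmaps=fn(level), kernel=3, use_wscale=use_wscale)))
    with tf.variable_scope('CONV1'):
        nhwf = lrelu(add_bias(conv2d(nhwf, fmaps=fn(level - 1), kernel=3, use_wscale=use_wscale)))
    with tf.variable_scope('downscale2d'):
        return downsampling2d(nhwf)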

5. Evaluating Generated Image Quality: Sliced Wasserstein Distance (SWD)

From the paper: image quality is measured by extracting 7x7 patch descriptors from each level of the Laplacian pyramids of real and generated images, then comparing the two descriptor sets with the sliced Wasserstein distance (SWD). A small SWD at a given pyramid level means the two image sets look similar at that scale.
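
The swd module used in section 6 implements the paper's patch-descriptor version. A minimal NumPy sketch of the core idea only (random 1-D projections plus sorting; the parameter defaults here are illustrative):

import numpy as np

def sliced_wasserstein(A, B, n_dirs=4, n_repeats=128):
    # A, B: [num_descriptors, dim] patch descriptors from real / fake pyramid levels
    dists = []
    for _ in range(n_repeats):
        dirs = np.random.randn(A.shape[1], n_dirs)                 # random projection directions
        dirs /= np.sqrt(np.sum(dirs ** 2, axis=0, keepdims=True))  # normalize to unit length
        pA = np.sort(A.dot(dirs), axis=0)                          # project and sort:
        pB = np.sort(B.dot(dirs), axis=0)                          # sorted samples ~ 1-D CDFs
        dists.append(np.mean(np.abs(pA - pB)))                     # 1-D Wasserstein-1 distance
    return np.mean(dists)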

6. Multi-Stage Training in TensorFlow

The PGGAN paper also gives some tricks used for training on CelebA-HQ (appendix A.1 of the paper); we follow them in the TensorFlow implementation.
Because TensorFlow's computation graph is static, each stage is trained to completion and its parameters saved, then the graph is rebuilt for the next stage and the previous stage's parameters restored. Take care that the restored variables are matched correctly. The training code is given below; note that every stage must end by clearing the graph (tf.reset_default_graph()).
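
The essential pattern, distilled (a sketch using the variable names introduced below; the full function that follows is the real version):

# stage k: rebuild the graph at the new level, restore only the variables that overlap with stage k-1
tf.reset_default_graph()
# ... build Generator_PG / Discriminator_PG at the new level ...
old_saver = tf.train.Saver(var_list=old_vars)   # variables that already existed in stage k-1
saver = tf.train.Saver(var_list=stage_vars)     # everything belonging to the current stage
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # freshly added layers start from scratch
    old_saver.restore(sess, tf.train.latest_checkpoint(old_model_path))
    # ... train this stage ...
    saver.save(sess, model_path + '/network.ckpt')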

import time
import os
from ops import *
import utils as us
import tfr_tools as tfr
import sliced_wasserstein_distance as swd
os.environ['CUDA_VISIBLE_DEVICES']='0'


def PGGAN(  id ,     # PG stage index
            latents_size, # latent vector size
            batch_size, # batch size
            lowest,# lowest network level
            highest,# highest network level
            level,# target network level
            isTransit, # fade-in stage?
            epochs, # number of training epochs
            data_size, # dataset size
                ):
    #-------------------- hyperparameters --------------------------#
    learning_rate = 0.001
    lam_gp = 10
    lam_eps = 0.001
    beta1 = 0.0
    beta2 = 0.99
    max_iters = int(epochs * data_size / batch_size)
    n_critic = 1  # discriminator updates per generator update

    #---------- (1) create directories and model paths -------------#

    # current model path
    model_path = './ckpt/PG%d_level%d_%s' % (id,level, isTransit)
    us.MKDIR(model_path)
    # previous-stage model path
    if isTransit:
        old_model_path = r'./ckpt/PG%d_level%d_%s/' % (id-1,level - 1, not isTransit)  # previous stage's stable model
    else:
        old_model_path = r'./ckpt/PG%d_level%d_%s/' % (id-1,level, not isTransit)  # this level's fade-in model

    #--------------------- (2) define inputs and outputs --------------#

    # image resolution
    res = int(2 ** level)
    # latent input
    latents = tf.placeholder(name='latents', shape=[None, latents_size], dtype=tf.float32)
    # data input
    real_images = tf.placeholder(name='real_images', shape=[None, res, res, 3], dtype=tf.float32)
    # training step counter
    train_steps = tf.Variable(0, trainable=False, name='train_steps', dtype=tf.float32)  # equals the number of generator updates

    # generator and discriminator outputs
    fake_images = Generator_PG(latents=latents, level=level, reuse=False, isTransit=isTransit,
                               trans_alpha=train_steps / max_iters)
    d_real_logits = Discriminator_PG(RGB=real_images, level=level, reuse=False, isTransit=isTransit,
                                     trans_alpha=train_steps / max_iters)
    d_fake_logits = Discriminator_PG(RGB=fake_images, level=level, reuse=True, isTransit=isTransit,
                                     trans_alpha=train_steps / max_iters)


    #------------ (3) Wasserstein distance and losses --------------#
    # Wasserstein distance
    wass_dist = tf.reduce_mean(d_real_logits-d_fake_logits)

    # G and D losses
    d_loss = -wass_dist  # discriminator loss
    g_loss = -tf.reduce_mean(d_fake_logits)  # generator loss

    # WGAN-GP gradient penalty
    alpha_dist = tf.contrib.distributions.Uniform(low=0., high=1.)  # uniform distribution on [0,1]
    alpha = alpha_dist.sample((batch_size, 1, 1, 1))
    interpolated = real_images + alpha * (fake_images - real_images)  # interpolate between real and generated samples
    inte_logit = Discriminator_PG(RGB=interpolated, level=level, reuse=True, isTransit=isTransit,
                                  trans_alpha=train_steps / max_iters)  # discriminator output at the interpolates

    # discriminator gradients
    gradients = tf.gradients(inte_logit, [interpolated, ])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    slopes_m = tf.reduce_mean(slopes)
    # gradient-penalty term
    gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)
    # add the penalty to d_loss
    d_loss += gradient_penalty * lam_gp

    # epsilon drift penalty (keeps the real logits from drifting away from zero)
    d_loss += tf.reduce_mean(tf.square(d_real_logits)) * lam_eps

    # ------------ (4) collect trainable variables --------------#
    # all trainable G/D variables
    train_vars = tf.trainable_variables()
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    ShowParasList(d_vars, g_vars, level, isTransit)  # log the parameters

    # this stage's conv-layer variables (excluding the RGB layers)
    d_vars_c = [var for var in d_vars if 'fromRGB' not in var.name]  # discriminator/scale_(0~level)/
    g_vars_c = [var for var in g_vars if 'toRGB' not in var.name]  # generator/scale_(0~level)/

    # previous stage's conv-layer variables (excluding the RGB layers)
    d_vars_old = [var for var in d_vars_c if 'scale_%d' % level not in var.name]  # discriminator/scale_(0~level-1)/
    g_vars_old = [var for var in g_vars_c if 'scale_%d' % level not in var.name]  # generator/scale_(0~level-1)/

    # RGB-layer variables of this and the previous stage
    d_vars_rgb = [var for var in d_vars if 'fromRGB' in var.name]  # discriminator/level_*_fromRGB/
    g_vars_rgb = [var for var in g_vars if 'toRGB' in var.name]  # generator/level_*_toRGB/

    # previous stage's RGB-layer variables
    d_vars_rgb_old = [var for var in d_vars_rgb if
                      'level_%d_fromRGB' % level not in var.name]  # discriminator/level_(level-1)_fromRGB/
    g_vars_rgb_old = [var for var in g_vars_rgb if
                      'level_%d_toRGB' % level not in var.name]  # generator/level_(level-1)_toRGB/

    # all previous-stage variables
    old_vars = d_vars_old + g_vars_old + d_vars_rgb_old + g_vars_rgb_old

    # ------------ (5) optimizers --------------#
    # Adam optimizers for G and D
    d_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                         beta1=beta1,
                                         beta2=beta2).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                         beta1=beta1,
                                         beta2=beta2).minimize(g_loss, var_list=g_vars, global_step=train_steps)
    # to keep learning smooth across stages, the Adam moment estimates are saved as well
    all_vars = tf.all_variables()
    adam_vars = [var for var in all_vars if 'Adam' in var.name]
    adam_vars_old = [var for var in adam_vars if 'level_%d' % level not in var.name and 'scale_%d' % level not in var.name]

    # ------------ (6) model saving and restoring ------------------#
    # saver for all of this stage's variables
    saver = tf.train.Saver(d_vars + g_vars + adam_vars,max_to_keep=3)
    # saver for the previous stage's variables
    if level > lowest:
        VARS_MATCH(old_model_path, old_vars)  # verify the variables match the checkpoint
        old_saver = tf.train.Saver(old_vars + adam_vars_old)

    # ------------ (7) dataset input pipeline (TFRecords) --------------#
    # read TFR
    [num, data, label] = tfr.Reading_TFR(sameName=r'./TFR/celeba_%dx%d-*'%(res,res) ,
                                         isShuffle=False, datatype=tf.float32, labeltype=tf.int8)
    # get batch
    [num_batch, data_batch, label_batch] = tfr.Reading_Batch_TFR(num, data, label, data_size=res*res*3,
                                                                 label_size=1, isShuffle=False, batchSize=batch_size)

    # ------------------ (8) training loop ---------------------#
    # GPU configuration
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # training records
    losses = []
    GenLog = []
    Wass = []
    SWD = []

    # load the dataset's descriptor sets (for SWD)
    if res>=16:
        # descriptors of the training data
        DESC = us.PICKLE_LOADING(r'./DESC.desc')

    # start the session
    with tf.Session(config=config) as sess:

        # initialize global and local variables
        init = (tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init)

        # start the coordinator
        coord = tf.train.Coordinator()
        # start the queue-runner threads
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # restore previous-stage parameters
        if level>lowest:
            if isTransit:  # fade-in stage
                old_saver.restore(sess, tf.train.latest_checkpoint(old_model_path))  # restore the previous model
                print('Loaded previous-stage parameters...')
            else:  # stable stage
                saver.restore(sess, tf.train.latest_checkpoint(old_model_path))  # keep training the same architecture

        # iterate
        time_start = time.time()  # start timing
        for steps in range(1,max_iters+1):
            # fade-in coefficient
            trans_alpha = steps / max_iters

            # standard-normal latent input
            z = np.random.normal(size=(batch_size, latents_size))
            # fetch a minibatch
            minibatch = sess.run(data_batch)
            # reshape
            minibatch = np.reshape(minibatch,[-1,res,res,3]).astype(np.float32)
            # preview the minibatch
            # us.CV2_IMSHOW_NHWC_RAMDOM(minibatch, 1, 9, 3, 3, 'minibatch', 0)

            # fade-in processing of the dataset
            if isTransit:
                # minibatch_low = us.lpf_nhwc(minibatch)
                # minibatch_input = trans_alpha * minibatch + (1 - trans_alpha) * minibatch_low  # blended dataset input
                trans_res = int(0.5*res+0.5*trans_alpha*res)
                minibatch_input = us.upsize_nhwc(us.downsize_nhwc(minibatch,(trans_res,trans_res)),(res,res))
            else:
                minibatch_input = minibatch
            # normalize to [-1,1]
            minibatch_input = minibatch_input*2-1

            # train the discriminator
            for i in range(n_critic):
                sess.run(d_train_opt, feed_dict={real_images: minibatch_input, latents: z})

            # train the generator
            sess.run(g_train_opt, feed_dict={latents: z})

            # record training info
            [d_loss2,g_loss2,wass_dist2,slopes2] = sess.run([d_loss,g_loss,wass_dist,slopes_m],
                                                            feed_dict={real_images: minibatch_input, latents: z})

            # record generated samples
            z = np.random.normal(size=[9, latents_size])
            gen_samples = sess.run(fake_images, feed_dict={latents: z})
            us.CV2_IMSHOW_NHWC_RAMDOM((gen_samples+1)/2, 1, 9, 3, 3, 'GEN', 10)

            # print progress
            print('level:%d(%dx%d)..' % (level, res, res),
                  'isTrans:%s..' % isTransit,
                  'step:%d/%d..' % (sess.run(train_steps), max_iters),
                  'Discriminator Loss: %.4f..' % (d_loss2),
                  'Generator Loss: %.4f..' % (g_loss2),
                  'Wasserstein:%.3f..'% wass_dist2,
                  'Slopes:%.3f..'%slopes2)


            # log training info
            if steps % 10 == 0:
                # (1) record the losses
                losses.append([steps, d_loss2, g_loss2])
                Wass.append([steps,wass_dist2])

            # if steps % 50 == 0:
                # (2) record generated samples
                # GenLog.append(gen_samples[0:9])

            # SWD evaluation
            if steps % 1000 == 0 and res>=16:
                # generate 2^13 fake samples
                FAKES = []
                for i in range(64):
                    z = np.random.normal(size=[128, latents_size])
                    fakes = sess.run(fake_images, feed_dict={latents: z})
                    FAKES.append(fakes)
                FAKES = np.concatenate(FAKES, axis=0)
                FAKES = (FAKES + 1) / 2
                # compute SWD against the matching Laplacian-pyramid level of the dataset
                if res >16:
                    FAKES = us.hpf_nhwc(FAKES)  # keep the high-frequency band
                d_desc = swd.get_descriptors_for_minibatch(FAKES, 7, 64)  # extract patch descriptors
                del FAKES
                d_desc = swd.finalize_descriptors(d_desc)
                swd2 = swd.sliced_wasserstein_distance(d_desc, DESC[str(res)], 4, 128) * 1e3  # SWD * 1e3
                SWD.append([steps,swd2])
                print('current SWD of generated samples (x1e3):', swd2, '...')
                del d_desc

            # save the model
            if steps % 1000 == 0:
                saver.save(sess, model_path + '/network.ckpt', global_step=steps)  # save checkpoint

        # stop the threads
        coord.request_stop()
        coord.join(threads)

        # stop timing
        us.CV2_ALL_CLOSE()
        time_end = time.time()
        print('Training finished, elapsed: %.2f s' % (time_end - time_start))

    # save the records
    us.PICKLE_SAVING(np.array(losses),'./trainlog/losses_%dx%d_trans_%s'%(res,res,isTransit))
    us.PICKLE_SAVING(np.array(Wass), './trainlog/Wass_%dx%d_trans_%s' % (res, res, isTransit))
    # us.PICKLE_SAVING(GenLog, './trainlog/Genlog_%dx%d_trans_%s' % (res, res, isTransit))
    if res>=16:
        us.PICKLE_SAVING(np.array(SWD),'./trainlog/SWD_%dx%d_trans_%s'%(res,res,isTransit))

    # clear the graph
    tf.reset_default_graph()

#********************************************************* main *******************************************************#
if __name__ == '__main__':
    # hyperparameters
    latents_size = 512
    batch_size = 16
    lowest = 2
    highest = 7
    epochs = 10
    data_size = 30000

    us.MKDIR('ckpt')
    us.MKDIR('structure')
    us.MKDIR('trainlog')

    # progressive growing
    time0 = time.time()  # start timing
    PGGAN(0,latents_size,batch_size,  lowest, highest, level=2, isTransit=False,epochs=epochs,data_size=data_size)
    PGGAN(1,latents_size, batch_size, lowest, highest, level=3, isTransit=True, epochs=epochs, data_size=data_size)
    PGGAN(2,latents_size, batch_size, lowest, highest, level=3, isTransit=False, epochs=epochs, data_size=data_size)
    PGGAN(3,latents_size, batch_size, lowest, highest, level=4, isTransit=True, epochs=epochs, data_size=data_size)
    PGGAN(4,latents_size, batch_size, lowest, highest, level=4, isTransit=False, epochs=epochs, data_size=data_size)
    PGGAN(5,latents_size, batch_size, lowest, highest, level=5, isTransit=True, epochs=epochs, data_size=data_size)
    PGGAN(6,latents_size, batch_size, lowest, highest, level=5, isTransit=False, epochs=epochs, data_size=data_size)
    PGGAN(7,latents_size, batch_size, lowest, highest, level=6, isTransit=True, epochs=epochs, data_size=data_size)
    PGGAN(8,latents_size, batch_size, lowest, highest, level=6, isTransit=False, epochs=epochs, data_size=data_size)
    PGGAN(9,latents_size, batch_size, lowest, highest, level=7, isTransit=True, epochs=epochs, data_size=data_size)
    PGGAN(10,latents_size, batch_size, lowest, highest, level=7, isTransit=False, epochs=epochs, data_size=data_size)
    time1 = time.time()  # stop timing
    print('Total training time: %.2f..'%(time1-time0))



Reposted from blog.csdn.net/Ephemeroptera/article/details/89193727