前言:DeLiGAN是计算机视觉顶会CVPR2017发表的一篇论文,本文将结合Python源代码学习DeLiGAN中的核心内容。DeLiGAN最大的贡献就是将生成对抗网络(GANs)的输入潜空间编码为混合模型(高斯混合模型),从而使得生成对抗网络(GANs)在数量有限但具有多样性的训练数据上表现出较好的性能;同时,在初始得分(Inception Score)的基础上提出了改进版初始得分(modified Inception Score)用于评测生成样本的类内多样性。
论文地址:https://arxiv.org/abs/1706.02071
源码地址:https://github.com/val-iisc/deligan
一、概述
(一)关键科学问题:DeLiGAN要解决的关键科学问题是传统的生成对抗网络(GANs)需要大量的训练数据才能学习到跨图片模态的多样性(diversity across the image modality,根据全文内容,感觉这里理解为数据集中包含很多不同类别图像,生成对抗网络要学习到每一个类别的类内多样性比较困难),在训练数据有限且类内、类间具有多样性的情况下,传统的生成对抗网络(GANs)的效果并不好。
(二)解决方法:将生成对抗网络中的输入潜空间参数化为混合模型,通过对抗训练的方式学习生成对抗网络(GANs)的参数以及该潜空间混合模型的参数,从而提高生成对抗网络(GANs)对有限数量的、且具有类内多样性的训练数据的学习能力。
(三)几个概念的说明
(1)intra-class diversity:类内多样性(inter-class:类间多样性)
(2)modified version of inception score:改进版初始得分(m-IS)
二、核心方法解读(结合Python源码)
(一)输入噪声采样的潜空间
将GANs的输入噪声采样潜空间重构为高斯混合模型
其中,表示高斯分布下噪声样本z的概率
因为无法从训练中得到每一个高斯分布的权重(也就是上面公式中的),所以将所有的高斯分布的权重置为,则上面公式变为:
为了采样噪声样本,从个高斯分布中随机选取一个,利用“重参数化技巧(reparameterization trick)”,将噪声样本表示为被选中的第个高斯分布的参数和以及一个辅助噪声变量的确定函数,即:
Toy源代码中的zin表示的是,zsig表示,从正太分布或均匀分布中采样
这样,噪声样本采样就转换为从高斯分布采样,
因为训练生成器的原来目标是
训练生成器的新目标就变为:
最终,通过对抗训练学习生成对抗网络(GANs)的参数的同时,根据生成器的损失反馈到混合模型的梯度来训练高斯模型的参数和,
参数初始化设置:从均匀分布随机采样,为非零初始值0.2。
同时,因为上面公式中具有局部最优,为了防止生成器在训练过程中为了生成更多的高概率区域的样本而不断减小趋于0,在生成器的损失函数中引入正则化项,则生成器的新的损失函数的公式表示为:
(二)改进版初始得分(Inception Score)
初始得分(Inception Score):将生成图像输入到一个训练好的具有inception结构的分类器中就会得到一个条件标签分布,当足够真实时,就会得到一个多峰(peaky)的标签分布,即应该具有较低的熵值。同时,我们也希望生成的图片能涵盖所有的类别,即应该具有较高的熵值。这两个要求整合为初始得分(Inception Score)的测量标准:
该公式表示形式使模型具有较高初始分值但却导致低熵条件类别分布,但是我们需要每个类别的图像都具有多样性,因此使用交叉熵作为训练好的Inception模型的输出结果,因此改进版初始得分为:
三、Python源代码
dg_toy.py
import tensorflow as tf
import numpy as np
import os
import time
from random import randint
import cv2
import matplotlib.pylab as Plot
batchsize=50
results_dir='../results/toy'
def linear(x,output_dim, name="linear"):
""" Linear Layer for 2d input x """
w=tf.get_variable(name+"/w", [x.get_shape()[1], output_dim])
b=tf.get_variable(name+"/b", [output_dim], initializer=tf.constant_initializer(0.0))
return tf.matmul(x,w)+b
# 判别器设置1个隐层(含32个神经元),即df_dim=32
def discriminator(image, reuse=False):
""" Discriminator function description """
with tf.variable_scope('disc', reuse=reuse):
h0 = tf.tanh(linear(image,df_dim,'d_l1'))
h1 = linear(h0, 1, 'd_l2')
return tf.nn.sigmoid(h1), h1
# 生成器设置1个隐层(含32个神经元),即gf_dim=32
def generator(z, n):
""" Generator function description """
with tf.variable_scope('gen'+str(n)):
#z = tf.tanh(linear(z, batchsize,'g_l0')) # Uncomment for testing GAN++ model
h1 = tf.tanh(linear(z, gf_dim,'g_l1'))
h2 = linear(h1, 2, 'g_l2')
return tf.nn.tanh(h2)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
imageshape = [2]
z_dim = 2
gf_dim = 32 # 生成器全连接隐层的神经元数量
#gf_dim = 32*50 # Uncomment this line for testing Nx-GAN
df_dim = 32 # 判别器全连接隐层的神经元数量
learningrate = 0.0001
beta1 = 0.5
# Taking Inputs for the graph as placeholders
images = tf.placeholder(tf.float32, [batchsize] + imageshape, name="real_images")
z = tf.placeholder(tf.float32, [None, z_dim], name="z")
lr1 = tf.placeholder(tf.float32, name="lr")
zin = tf.get_variable("g_z", [batchsize, z_dim],initializer=tf.random_uniform_initializer(-1,1))
# zin对应论文公式(8)中的\mu_{i}
zsig = tf.get_variable("g_sig", [batchsize, z_dim],initializer=tf.constant_initializer(0.02))
# zsig对应论文公式(8)中的\sigma_{2},论文中说是0.2
inp = tf.add(zin,tf.mul(z,zsig)) #Uncomment this line for testing the DeliGAN
# inp也就是转化为混合模型后,采样得到的最终要输入到生成器中的z
#moe = tf.eye(batchsize) #Uncomment this line for testing the MoE-GAN
#inp = tf.concat_v2([moe, z],1) #Uncomment this line for testing the MoE-GAN
# Calling the generator Function for different Models
#G = generator(z[:1],0) #Uncomment this line when testing Ensemble-GAN
G = generator(inp,0) #Uncomment this line for testing DeliGAN, MoE-GAN
#G = generator(z,0) #Uncomment this line for testing GAN and Nx-GAN
#for n in range(batchsize-1): #Uncomment this line when testing Ensemble-GAN
# g = generator(z[n+1:n+2],n+1) #Uncomment this line when testing Ensemble-GAN
#G = tf.concat_v2([g,G],0) #Uncomment this line when testing Ensemble-GAN
lab = tf.where(G[:,0]<0)
D_prob, D_logit = discriminator(images)
D_fake_prob, D_fake_logit = discriminator(G, reuse=True)
# Defining Losses
sig_loss = 0.1*tf.reduce_mean(tf.square(zsig-1))
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logit, tf.ones_like(D_logit)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logit, tf.zeros_like(D_fake_logit)))
gloss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logit, tf.ones_like(D_fake_logit)))
gloss1 = gloss+sig_loss
dloss = d_loss_real + d_loss_fake
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'd_' in var.name]
g_vars = [var for var in t_vars if 'g_' in var.name]
# 随机生成两类符合正太分布的训练数据
data = np.random.normal(0,0.3,(200,2)) # Comment this line when using multimodal (i.e. Uncomment for unimodal data)
data1 = np.random.normal(0,0.3,(200,2)) # Comment this line when using multimodal (i.e. Uncomment for unimodal data)
#data = np.random.normal(0.6,0.15,(200,2)) # Uncomment this line for multimodal data
#data1 = np.random.normal(-0.6,0.15,(200,2)) # Uncomment this line for multimodal data
data = np.vstack((data,data1))
data = data.reshape([-1,2])
# Optimization
d_optim = tf.train.AdamOptimizer(lr1, beta1=beta1).minimize(dloss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(lr1, beta1=beta1).minimize(gloss1, var_list=g_vars)
tf.initialize_all_variables().run()
saver = tf.train.Saver(max_to_keep=10)
counter = 1
start_time = time.time()
data_size = data.shape[0]
display_z = np.random.normal(0, 1.0, [batchsize, z_dim]).astype(np.float32) #Uncomment this line for using a mixture of normal prior
#display_z = np.random.uniform(-1.0, 1.0, [batchsize, z_dim]).astype(np.float32) #Uncomment this line for using a mixture of uniform distributions prior
seed = 1
rng = np.random.RandomState(seed)
train = True
thres=1.0
count=0
t1=0.73
for epoch in xrange(8000):
batch_idx = data_size/batchsize
batch = data[rng.permutation(data_size)]
if count<-1000:
t1=max(t1-0.005, 0.70)
lr = learningrate
for idx in xrange(batch_idx):
batch_images = batch[idx*batchsize:(idx+1)*batchsize]
batch_z = np.random.normal(0, 1.0, [batchsize, z_dim]).astype(np.float32)
batch_z = np.random.uniform(-1.0, 1.0, [batchsize, z_dim]).astype(np.float32)
# 这里的batch_z相当于论文中公式(8)的\epsilon,论文是同从正太分布N(0,1)中采样,这里会被后面的batch_z覆盖掉
# Threshold to decide the which phase to run (generator or discrminator phase)
if count>10:
thres=min(thres+0.01, 1.0)
count=0
if count<-150 and thres>t1:
thres=max(thres-0.001, t1)
count=0
# Training each phase based on the value of thres and gloss
for k in xrange(5):
if gloss.eval({z: batch_z})>thres:
sess.run([g_optim],feed_dict={z: batch_z, lr1:lr})
count+=1
else:
sess.run([d_optim],feed_dict={ images: batch_images, z: batch_z, lr1:lr })
count-=1
counter += 1
# Printing training status periodically
if counter % 300 == 0:
print("Epoch: [%2d] [%4d/%4d] time: %4.4f, " % (epoch, idx, batch_idx, time.time() - start_time,))
sdata = sess.run(G,feed_dict={ z: display_z })
errD_fake = d_loss_fake.eval({z: display_z})
errD_real = d_loss_real.eval({images: batch_images})
errG = gloss.eval({z: display_z})
sl = sig_loss.eval({z: display_z})
print('D_real: ', errD_real)
print('D_fake: ', errD_fake)
print('G_err: ', errG)
print('zloss: ', sl)
# Plotting the generated samples and the training data
if counter % 1000 == 0:
f, (ax1,ax2, ax3) = Plot.subplots(1, 3)
ax1.set_autoscale_on(False)
ax2.set_autoscale_on(False)
lab1 = lab.eval({z:display_z})
gen = G.eval({z:display_z})
ax1.scatter(gen[:,0], gen[:,1]);
#ax1.scatter(gen[lab1,0], gen[lab1,1], color='r'); # Uncomment this line when testing with multimodal data
ax1.set_title('Generated samples')
ax1.set_aspect('equal')
ax1.axis([-1,1,-1,1])
ax2.scatter(batch[:,0], batch[:,1])
lab_ = batch[batch[:,0]<-0.1]
#ax2.scatter(lab_[:,0], lab_[:,1], color='r'); # Uncomment this line when testing with multimodal data
ax2.set_title('Training samples')
ax2.set_aspect('equal')
ax2.axis([-1,1,-1,1])
f.savefig(results_dir + '/plot' + str(counter) + ".png")
saver.save(sess, os.getcwd()+results_dir+'/train/',global_step=counter)