[tensorflow应用之路]如何用少量标注训练样本?将GAN用于半监督(上)

我们用tensorflow应用于实际项目中时,常常会遇到一种情况:我们有很多的数据,但是只有很少的标注。因为标注需要很多时间。这时我们可能会想到用半监督(semi-supervise)的方法训练数据。但是半监督需要将无标签(unlabeled)的数据用于训练中,这是一个很困难的事情。恰好,最近有一种很火的方法——生成对抗网络(Generative Adversarial Nets,GAN)——中有关于半监督的方法应用,并且在论文中得到了很好的效果,我们参考它们尝试设计自己的半监督网络。
首先聊一聊GAN的发展历史。2016年,ImprovedTechniquesforTrainingGANs将GAN用于生成样本和半监督中(它并不是首例GAN,但它的代码引用是最多的),设计了两类损失包括监督损失和无监督损失,达到了比较好的训练精度。后来,随着时间的推移,人们发现KL散度用于度量GAN这种低维映射到高维的网络损失时,有个理论上的巨大漏洞,容易导致梯度消失,所以出现了WGAN;随后,又有人发现WGAN的权值剪切法会导致梯度极端分布,所以出现了WGAN-GPWGAN-CT等方法,都是使用梯度惩罚项来实现Lipschitz连续。
下面我们看看它们的具体实现。代码参考

SSGAN

SSGAN是所有介绍的方法中,最“老”的一种。但是,后续的半监督方法无一不参考了它的思想——将损失分为监督损失和无监督损失。

L=Lsupervised+LunsupervisedLsupervised=Ex,y pdata(x,y)logpmodel(y|x,y<K+1)Lunsupervised=Ex pdata(x)log[1pmodel(y=K+1|x)]Ex Glog[pmodel(y=K+1|x)]

论文源码请点这里
结合上式与代码,我们看到SGAN使用无标签样本的方法就是将其用于K+1类。假设我们的判别网络为discriminator,生成网络有,此时有:

G_img = generator('gen', z, reuse=False)
d_logits_r, layer_out_r = discriminator('dis', x, reuse=False)
d_logits_f, layer_out_f = discriminator('dis', G_img, reuse=True)

# caculate the unsupervised loss
d_loss_r=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_r[:, -1]),logits=d_logits_r[:, -1]))
d_loss_f=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_f[:, -1])*0.9, logits=d_logits_f[:, -1]))
d_loss_f1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_f[:, -1]),logits=d_logits_f[:, -1]))

# feature match
f_match = []
for i in range(4):
    f_match += [tf.reduce_mean(tf.multiply(layer_out_f[i]-layer_out_r[i], layer_out_f[i]-layer_out_r[i]))]

# caculate the supervised loss
s_label = tf.concat([label, tf.zeros(shape=(batch_size,1))], axis=1)

s_l_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=s_label, logits=d_logits_r))

d_loss = d_loss_r + d_loss_f + s_l_r*flag+d_regular
g_loss = d_loss_f1+0.1*tf.reduce_mean(f_match,0)

其中,s_l_r是有标签的损失,d_loss是判别器损失,g_loss是生成器损失。当有标签时,我们将d_loss中的flag设为1,没有时,设为0。
为什么虚拟标签需要乘以0.9,请参照论文内的One-sided label smoothing方法。feature matching 也一样。

WGAN-CT

wgan-ct相对于WGAN-GP(WGAN的改进型,略去不讲,请参考论文),使用最后两层网络的梯度惩罚。论文中说这样可以减少真实数据的不连续性(未完全理解)。之所以选择这个网络,是因为这个方法也做了半监督实验。代码如下(论文源码):

G_img = generator('gen', z, reuse=False)
d_logits_r1, d_logits_r11,d_logits_r12 = discriminator_with_dropout('dis', x, reuse=False)
d_logits_r2, d_logits_r21,_ = discriminator_with_dropout('dis', x, reuse=True)
d_logits_f, _ ,d_logits_f2 = discriminator_with_dropout('dis', G_img, reuse=True)

# caculate the unsupervised loss
logits_r, logits_f = tf.nn.softmax(d_logits_r1), tf.nn.softmax(d_logits_f)
d_loss_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_r[:, -1]), logits=d_logits_r1[:, -1]))
d_loss_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_f[:, -1]), logits=d_logits_f[:, -1]))
                                                     logits=tf.reduce_max(d_logits_r1, -1))
loss_ct=tf.square(d_logits_r1-d_logits_r2)
loss_ct_=0.1*tf.reduce_mean(tf.square(d_logits_r11-d_logits_r21))
CT=loss_ct+loss_ct_

# caculate the supervised loss
s_label = tf.concat([label, tf.zeros(shape=(batch_size,1))], axis=1)
s_l_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=s_label, logits=d_logits_r1))
d_l_1, d_l_2 = d_loss_r + d_loss_f, s_l_r
d_loss = d_loss_r + d_loss_f + s_l_r*flag +0.1*tf.reduce_mean(CT)
g_loss = tf.square(tf.reduce_mean(d_logits_f2,0)-tf.reduce_mean(d_logits_r12,0))

all_vars = tf.global_variables()
for v in all_vars:
    print(v)
all_vars = tf.global_variables()
g_vars = [v for v in all_vars if 'gen' in v.name]
d_vars = [v for v in all_vars if 'dis' in v.name]
opt_d = tf.train.AdamOptimizer(lr).minimize(d_loss, var_list=d_vars)
opt_g = tf.train.AdamOptimizer(lr).minimize(g_loss, var_list=g_vars)

和SSGAN一样,使用flag作为控制带标签和不带标签的开关。其中CT项是该方法的创新,具体方法为:通过加入dropout,由同一输入得到不同的输出,然后将网络最后两层的输出做差分,作为梯度惩罚项。

MINIST

我们拿minist(深度学习界的果蝇)做一下实验,使用同样的判别器和生成器,对比SSGAN和WGAN-CT的结果。


首先,作为比较结果,我们得到仅使用带标签数据的训练成果。
这里写图片描述
其中,每1000次迭代(iteration)更新一次带标签数据,带标签数据一共100个。一次迭代的batch size为50,即一次训练的样本数为50000。


然后,使用同样的判别网络,与数据输入方式,SSGAN的训练精度为
这里写图片描述
可以看到,最高精度反而降低了。
以下是SSGAN的生成样本。
这里写图片描述
和WGAN中讨论的一样,生成的样本面临多样性不足的问题(训练多次后有所缓解)。


最后,是WGAN-CT的半监督测试精度结果:
这里写图片描述
相对于带监督结果,精度只提高了一点。
以下是生成的样本
这里写图片描述
多样性不足的问题有所缓解。这个实验我没有做满50000次,因为CT-GAN的由于需要计算梯度,所以反向训练比较耗时,感兴趣的读者可以参考我提供的代码地址,自己使用minist复现结果。

总结

两种半监督方法的效果不是很好,比不上无监督方法。其实这也很好理解,因为它们只是沿用了传统GAN的思路定义了损失(实际上只是多出来一个虚拟类),但是并没有思考生成出来的图片如何进一步为判别器所用,提升判别器的精度。
下一步需要做的:
1.Cifar数据集测试.
2.思考两种方法的损失改进方法,使生成的图片能够用于判别器的分类中(不仅仅是真实和虚假两类)。
3.增加VAT方法.

最后,祝您身体健康,再见!

猜你喜欢

转载自blog.csdn.net/h8832077/article/details/79485378
今日推荐