基于生成式对抗网络的半监督深度学习增量自标记方法

1. 摘要

    深度神经网络的成功部分归功于大规模标记良好的训练数据。然而,随着现代数据集规模的不断增长,标签信息的获取变得极为困难。该论文提出了一种基于生成式对抗网络的半监督深度学习增量自标记方法,该方法通过不断地将未标记的数据分配虚拟标签来促进训练过程。具体来说,在分配虚拟标签的过程中,该论文引入了一种基于时间的自标记方法。然后,为了在训练过程中动态地向数据分配更多的虚拟标签,该论文采用了分阶段的增量标签筛选和更新方法。最后,该论文进一步引入了平衡因子项(balance factor Term, BT),平衡训练过程中样本的信息损失。

2. 简介

    对于使用生成式对抗网络(GANs)的半监督分类,大多数网络都是通过修改常规的GAN鉴别器来产生k个输出对应于k个类。为了进一步利用未标记的数据进行训练,通常由生成器生成一个额外的第(k + 1)个类,以增强鉴别器的鉴别能力。后者能提取更多的信息特征,用于区分真实数据和虚假数据。

    该论文致力于探索一种增量式自标记方法(ISL -GAN),并将其嵌入到稳健的SSL(SSL)框架中,以提高GAN在分类领域的性能。

3. 方法

    首先,在标签预测的正确性上面,大部分的训练数据,包括标记数据和未标记数据,在训练过程中都得到了正确的预测。为了进一步检验模型对噪声标记的鲁棒性,该论文在模型训练中加入了一些错误标记的样本,发现一定比例的标签错误确实会影响最终测试的准确性。

    下面介绍该论文提出的模型。如图1所示,所提出的模型由两部分组成:第一部分是基于一致性的半监督GAN模型。第二部分负责给未标记数据分配虚拟标签,每隔一定的epoch间隔,对可信度高的数据分配一个虚拟标签来更新已标记的训练数据集。

在这里插入图片描述

图1. 增量式自标记GAN (isli -GAN)的研究进展。灰色和橙色表示模型的两个部分。不同的形状表示不同标签的输入数据,蓝色表示已标记的数据,灰色表示未标记的数据

    已知对于网络不同训练阶段输出稳定的样本,误分类的概率较低,容易误分类的样本往往出现在分类边缘附近,这必然会导致样本输出的不稳定性。考虑到这一点,为了保持每个训练样本相对稳定和安全的虚拟标签,该论文选择计算多个历史输出的平均来保证稳定性。

    考虑到通过半监督方式学习样本最终的数据标签预测正确率很高,该论文使用该方法并且通过设置未标记数据的虚拟标签来更新训练数据集。如果一个未标记的样本多次被分到了同一个类别标签,那么就把这个类别标签作为该样本的虚拟标签。在训练过程中,对一个未标记样本分配一个虚拟标签,可以有效地增加SSL中标记样本的数量,从而增加分类准确率。

    在这里,我们需要为分配了虚拟标签的样本设置一个可信度的阈值。这个阈值不可以过低,也不能过高。如果设置的过高,比如说是100%的可信度,将会导致:当我们用这一部分数据更新模型的时候算出来的损失值为0,无法进一步更新模型。过低则导致这一部分数据不可信,这和我们拿它当作训练数据的思想是违背的。因此,为了增加已拟合低损耗模型的虚拟标记样本的贡献,该论文将平衡因子项(BT) exp(pin)引入到原交叉熵损耗CE中,最终的监督损耗可由下面公式表示(详细的模型损失函数请见原论文,这里只给出了引入了平衡因子项的公式以供参考):
在这里插入图片描述
pi表示了样本属于第i类的概率,yi代表第i类的one-hot编码值。超参数n是控制损失权重的平衡因子,默认为2.0。

4. 结论

    实验结果表明,该方法可以得到MNIST、CIFAR-10和SVHN数据集的最新SSL结果。特别地,该论文的模型在样本标记较少条件下表现良好。对于只有1000个标记为CIFAR-10的图像的数据集,可以实现11.2%的测试误差,对于500和1000个标记为SVHN的数据集,可以实现几乎相同的性能,测试误差为3.5%。
在这里插入图片描述

更多有趣资讯扫码关注 BBIT
发布了6 篇原创文章 · 获赞 0 · 访问量 39

猜你喜欢

转载自blog.csdn.net/ShenggengLin/article/details/105301847
今日推荐