半监督学习Mean teachers
网络整体的架构包括两个部分student model和teacher model:
-
student model的网络参数通过学习,梯度下降获得。
-
teacher model的网络参数通过student model的网络参数的moving average得到。
student model的网络参数更新方法:
通过损失函数的梯度下降更新参数得到。
其中损失函数包括两个部分:
第一部分是有监督损失函数,保证有标签训练数据拟合;
第二部分是无监督损失函数,主要是保证student model的预测结果和teacher model的预测结果尽量的相似。因为teacher model的参数是student model的网络参数的moving average,所以,对于任何新来的数据,预测结果都不应该有太大的抖动。
如果如果模型是正确的,那么前后两个模型的预测标签应该是接近的,并且变化较小的,那么使模型向使两个模型预测结果接近的方向移动,就是向groudtruth model移动。
teacher model的网络参数的更新方法:
通过student model网络参数的moving average得到
θ t ′ = α θ t − 1 ′ + ( 1 − α ) θ t \theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t} θt′=αθt−1′+(1−α)θt
基本流程
假设有一批训练样本X1,X2,其中X1使有标签数据(对应标签是z1),X2使无标签数据。具体的训练过程如下:
-
把这一批样本作为student网络输入,然后分别得到输出的标签:ys1,ys2;
-
构造对于有标签数据X1的损失函数,有标签分类损失函数L1(z1,ys1);
-
把这批数据作为teacher model的输入,得到输出的标签yt1,yt2;
-
构造无监督损失函数L2,论文中采用MSE损失函数: J ( x , θ ) = E x , η ′ , η [ ∣ ∣ f ( x , θ ′ , η ′ ) − f ( x , θ , η ) ∣ ∣ 2 ] J(x, \theta)=E_{x, \eta ^{\prime}}, \eta \left[ ||f(x, \theta ^{\prime}, \eta ^{\prime})-f(x, \theta , \eta)||^{2}\right] J(x,θ)=Ex,η′,η[∣∣f(x,θ′,η′)−f(x,θ,η)∣∣2]
-
总损失函数L1+L2梯度下降,更新student model的网络参数,通过moving average更新teacher model的网络参数 θ t ′ = α θ t − 1 ′ + ( 1 − α ) θ t \theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t} θt′=αθt−1′+(1−α)θt