半监督学习之Mean teachers

半监督学习Mean teachers

在这里插入图片描述
网络整体的架构包括两个部分student model和teacher model:

  1. student model的网络参数通过学习,梯度下降获得。

  2. teacher model的网络参数通过student model的网络参数的moving average得到。

student model的网络参数更新方法:

通过损失函数的梯度下降更新参数得到。
其中损失函数包括两个部分:

第一部分是有监督损失函数,保证有标签训练数据拟合;

第二部分是无监督损失函数,主要是保证student model的预测结果和teacher model的预测结果尽量的相似。因为teacher model的参数是student model的网络参数的moving average,所以,对于任何新来的数据,预测结果都不应该有太大的抖动。
如果如果模型是正确的,那么前后两个模型的预测标签应该是接近的,并且变化较小的,那么使模型向使两个模型预测结果接近的方向移动,就是向groudtruth model移动。

teacher model的网络参数的更新方法:

通过student model网络参数的moving average得到
θ t ′ = α θ t − 1 ′ + ( 1 − α ) θ t \theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t} θt=αθt1+(1α)θt

基本流程

假设有一批训练样本X1,X2,其中X1使有标签数据(对应标签是z1),X2使无标签数据。具体的训练过程如下:

  1. 把这一批样本作为student网络输入,然后分别得到输出的标签:ys1,ys2;

  2. 构造对于有标签数据X1的损失函数,有标签分类损失函数L1(z1,ys1);

  3. 把这批数据作为teacher model的输入,得到输出的标签yt1,yt2;

  4. 构造无监督损失函数L2,论文中采用MSE损失函数: J ( x , θ ) = E x , η ′ , η [ ∣ ∣ f ( x , θ ′ , η ′ ) − f ( x , θ , η ) ∣ ∣ 2 ] J(x, \theta)=E_{x, \eta ^{\prime}}, \eta \left[ ||f(x, \theta ^{\prime}, \eta ^{\prime})-f(x, \theta , \eta)||^{2}\right] J(x,θ)=Ex,η,η[f(x,θ,η)f(x,θ,η)2]

  5. 总损失函数L1+L2梯度下降,更新student model的网络参数,通过moving average更新teacher model的网络参数 θ t ′ = α θ t − 1 ′ + ( 1 − α ) θ t \theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t} θt=αθt1+(1α)θt

猜你喜欢

转载自blog.csdn.net/weixin_42764932/article/details/112979993