Mean Teachers for semi-supervised learning


[Figure: the Mean Teacher architecture, with a student model and a teacher model]
The overall network architecture consists of two parts, a student model and a teacher model:

  1. The student model's parameters are learned directly, by gradient descent on the loss function.

  2. The teacher model's parameters are an exponential moving average (EMA) of the student model's parameters (see the sketch below).
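
A minimal sketch of this two-model setup in PyTorch; the backbone network and helper name are illustrative, not from the paper:

```python
import copy
import torch.nn as nn

def create_student_teacher(backbone: nn.Module):
    """Build a student/teacher pair from one backbone architecture.
    The teacher starts as an exact copy of the student and is never
    updated by gradient descent, only by the EMA rule described below."""
    student = backbone
    teacher = copy.deepcopy(backbone)
    for p in teacher.parameters():
        p.requires_grad_(False)  # exclude the teacher from backprop
    return student, teacher

# Usage with a toy backbone (hypothetical 10-feature, 3-class classifier):
student, teacher = create_student_teacher(nn.Linear(10, 3))
```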

How the student model's parameters are updated:

They are updated by gradient descent on the loss function.
The loss function has two parts:

The first part is a supervised loss, which fits the labeled training data;

The second part is an unsupervised consistency loss, which keeps the student model's predictions as close as possible to the teacher model's. Because the teacher's parameters are a moving average of the student's parameters, its predictions on any new data should not jitter much.
If the model is on the right track, the two models' predictions should already be close and change slowly, so moving the student in the direction that brings the two predictions together also moves it toward the ground-truth model.
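
A minimal sketch of this two-part loss in PyTorch: the cross-entropy supervised term and the MSE consistency term over softmax outputs follow the paper's setup, while the function name and the fixed weight `w` are illustrative (the paper ramps this weight up during training):

```python
import torch
import torch.nn.functional as F

def mean_teacher_loss(student_logits, teacher_logits, labels, labeled_mask, w=1.0):
    """Supervised loss on labeled samples plus consistency loss on all samples."""
    # Part 1: supervised classification loss, only where labels exist.
    sup = F.cross_entropy(student_logits[labeled_mask], labels[labeled_mask])
    # Part 2: consistency loss; the teacher's prediction is a fixed target,
    # so gradients are blocked from flowing into the teacher.
    cons = F.mse_loss(F.softmax(student_logits, dim=1),
                      F.softmax(teacher_logits.detach(), dim=1))
    return sup + w * cons
```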

How the teacher model's parameters are updated:

The teacher's parameters are the moving average of the student's parameters:

$$\theta_{t}^{\prime} = \alpha\, \theta_{t-1}^{\prime} + (1 - \alpha)\, \theta_{t}$$

where $\theta_{t}$ is the student's parameters at step $t$ and $\alpha$ is the EMA decay rate.
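
This update is a few lines of PyTorch; a minimal sketch, assuming `student` and `teacher` share the same architecture (the paper also ramps $\alpha$ up early in training, which is omitted here):

```python
import torch

@torch.no_grad()
def update_teacher(student, teacher, alpha=0.99):
    """theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t,
    applied parameter by parameter."""
    for t_p, s_p in zip(teacher.parameters(), student.parameters()):
        t_p.mul_(alpha).add_(s_p, alpha=1 - alpha)
```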

Basic process

Suppose a training batch contains samples X1 and X2, where X1 is labeled data (with label z1) and X2 is unlabeled data. The training process is as follows:

  1. Feed the batch into the student network to obtain its predictions ys1, ys2;

  2. For the labeled data X1, construct the supervised classification loss L1(z1, ys1);

  3. Feed the same batch into the teacher model to obtain its predictions yt1, yt2;

  4. Construct the unsupervised consistency loss L2. The paper uses the MSE consistency cost:

$$J(x, \theta) = \mathbb{E}_{x, \eta^{\prime}, \eta}\left[\, \left\lVert f(x, \theta^{\prime}, \eta^{\prime}) - f(x, \theta, \eta) \right\rVert^{2} \,\right]$$

where $\theta^{\prime}$ denotes the teacher's parameters and $\eta$, $\eta^{\prime}$ are the noise (e.g. augmentation or dropout) applied on the student and teacher sides;

  5. Perform gradient descent on the total loss L1 + L2 to update the student model's parameters, then update the teacher model's parameters through the moving average $\theta_{t}^{\prime} = \alpha\, \theta_{t-1}^{\prime} + (1 - \alpha)\, \theta_{t}$. A full training step combining these pieces is sketched below.
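
Putting steps 1-5 together, one training step could look like the following sketch, reusing the `mean_teacher_loss` and `update_teacher` helpers from above (the optimizer and batch layout are illustrative):

```python
def train_step(student, teacher, optimizer, x, labels, labeled_mask,
               w=1.0, alpha=0.99):
    # Steps 1 and 3: forward the batch through both models (the paper
    # applies independent noise/augmentation to each model's input).
    student_logits = student(x)
    teacher_logits = teacher(x)
    # Steps 2 and 4: supervised loss L1 plus consistency loss L2.
    loss = mean_teacher_loss(student_logits, teacher_logits,
                             labels, labeled_mask, w)
    # Step 5: gradient descent updates the student only...
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ...then the teacher follows via the exponential moving average.
    update_teacher(student, teacher, alpha)
    return loss.item()
```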

Origin: blog.csdn.net/weixin_42764932/article/details/112979993