# Exponential moving average (EMA) of model weights: keep an exponentially
# weighted running average of the weights over recent training updates and use
# that average for evaluation. This can alleviate catastrophic forgetting during
# fine-tuning, where the new data gradually pulls the model weights away from
# the data distribution learned during training, so the model forgets the prior
# knowledge it had learned before.
class EMA():
    """Exponential moving average (EMA) of a model's trainable parameters.

    Typical usage: call ``register()`` once after building the model, call
    ``update()`` after every optimizer step, and wrap evaluation in
    ``apply_shadow()`` / ``restore()`` so that evaluation runs on the averaged
    weights while training continues from the un-averaged ones.

    Only parameters whose ``requires_grad`` is true are tracked.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs (e.g. a ``torch.nn.Module``).
        decay: weight given to the previous average on each update; the
            current weights get ``1 - decay``.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay  # decay rate
        self.shadow = {}  # name -> averaged (EMA) weights
        self.backup = {}  # name -> training weights saved by apply_shadow()

    def _trainable(self):
        """Yield ``(name, param)`` for every parameter that requires grad."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                yield name, param

    @staticmethod
    def _lookup(store, name, hint):
        """Return ``store[name]`` or raise ``KeyError`` explaining what to call."""
        try:
            return store[name]
        except KeyError:
            raise KeyError(f"no entry for parameter {name!r}; {hint}") from None

    def register(self):
        """Initialize the shadow weights with a copy of the current weights."""
        for name, param in self._trainable():
            self.shadow[name] = param.data.clone()

    def update(self):
        """Fold the current weights into the running average.

        Computes ``(1 - decay) * current + decay * shadow`` per parameter.

        Raises:
            KeyError: if a trainable parameter has no shadow entry
                (``register()`` was not called for it).
        """
        for name, param in self._trainable():
            previous = self._lookup(self.shadow, name, "call register() first")
            # The arithmetic already produces a fresh tensor, so no extra
            # clone is needed.
            self.shadow[name] = (
                (1.0 - self.decay) * param.data + self.decay * previous
            )

    def apply_shadow(self):
        """Load the averaged weights into the model (call before evaluation).

        The current training weights are saved so ``restore()`` can bring
        them back. Calling this again before ``restore()`` keeps the first
        saved copy, so the training weights are never overwritten by the
        averaged ones.

        Raises:
            KeyError: if a trainable parameter has no shadow entry
                (``register()`` was not called for it).
        """
        # Resolve every shadow entry before touching the model so a missing
        # entry cannot leave the model half swapped.
        pending = [
            (name, param, self._lookup(self.shadow, name, "call register() first"))
            for name, param in self._trainable()
        ]
        for name, param, average in pending:
            if name not in self.backup:
                self.backup[name] = param.data.clone()
            # Copy values instead of rebinding param.data to the shadow
            # tensor: rebinding would alias the two, and any in-place write to
            # the parameter while the shadow is applied would corrupt the
            # stored average.
            param.data.copy_(average)

    def restore(self):
        """Put the saved training weights back (call after evaluation).

        Raises:
            KeyError: if a trainable parameter has no saved backup
                (``apply_shadow()`` was not called beforehand).
        """
        pending = [
            (param, self._lookup(self.backup, name, "call apply_shadow() first"))
            for name, param in self._trainable()
        ]
        for param, saved in pending:
            param.data.copy_(saved)
        self.backup = {}
# Initialization: wrap the model and copy its current trainable weights as the
# starting shadow (EMA) weights.
ema = EMA(model, decay=0.999)
ema.register()
# During training: once the parameters have been updated, update the shadow
# weights in step with them.
def train():
    """Apply the optimizer step, then refresh the EMA shadow weights."""
    optimizer.step()
    ema.update()
# Before evaluation, apply the shadow weights; after evaluation, restore the
# model's original parameters.
def evaluate():
    """Evaluate with the EMA shadow weights, then put the training weights back."""
    ema.apply_shadow()
    try:
        pass  # evaluate
    finally:
        # Restore even if evaluation raises; otherwise the model would stay on
        # the shadow weights and training would continue from them.
        ema.restore()