Exponential Moving Average (EMA) in PyTorch

EMA Introduction

EMA, the exponential moving average, is commonly used to maintain smoothed (averaged) versions of quantities such as model parameters or gradients during training.

The advantage of EMA is that it can improve the robustness of the model, since the averaged weights incorporate information from earlier training steps.
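
As a minimal conceptual sketch (not YOLOv7 code; the function name and values below are purely illustrative), the EMA of a sequence blends each new value with the previous average:

def ema_update(ema_value, new_value, decay=0.9999):
    # new EMA = decay * old EMA + (1 - decay) * new observation
    return decay * ema_value + (1 - decay) * new_value

# Smoothing a short sequence, with a smaller decay so the effect is visible
ema = 0.0
for x in (1.0, 2.0, 3.0):
    ema = ema_update(ema, x, decay=0.9)
print(ema)  # drifts slowly toward the recent values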

Code example

The code below is taken from yolov7/utils/torch_utils.py:

import math
from copy import deepcopy

import torch

# is_parallel() and copy_attr() are helper functions defined elsewhere in torch_utils.py

class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
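
For context, here is a small sketch (not from the YOLOv7 repo) of how ModelEMA is typically wired into a training loop; it assumes ModelEMA and the is_parallel helper are available from torch_utils.py, and uses a toy model and random data as placeholders:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                   # toy stand-in for a real detection model
ema = ModelEMA(model)                      # create the EMA copy once, after model init / GPU / DDP setup
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(100):                       # toy training loop with random data
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                      # blend the current weights into the EMA copy

# Validation and checkpointing would then typically use ema.ema instead of model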

Introduction to the __init__ function of the ModelEMA class

Introduction to the input parameters of the __init__ function

  • model: the model whose parameters are tracked with the EMA strategy
  • decay: the EMA decay factor, i.e. the weight given to the existing EMA value; the default is 0.9999
  • updates: the number of EMA updates performed so far (the parameter-update/iteration count)

Introduction to the initialization performed by __init__

First, deep-copy the model:

"""
创建EMA模型

model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间

is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model

"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
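
For reference, is_parallel in the same torch_utils.py is essentially a type check; the version below is reproduced from memory, so treat it as an approximation:

import torch.nn as nn

def is_parallel(model):
    # True if the model is wrapped in DataParallel or DistributedDataParallel
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)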

Next, initialize the update counter. When training starts from scratch, this parameter is 0:

self.updates = updates

Finally, define the function that computes the decay factor from the update count (it grows exponentially toward the target decay):

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
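
To get a feel for this schedule, the short sketch below (not part of the original code) evaluates the lambda at a few update counts; with decay=0.9999, d starts near 0 and ramps up toward 0.9999 as training progresses:

import math

decay = 0.9999
d = lambda x: decay * (1 - math.exp(-x / 2000))

for updates in (1, 100, 1000, 10000):
    print(updates, round(d(updates), 4))
# 1     -> ~0.0005  (the EMA follows the raw model almost entirely at first)
# 100   -> ~0.0488
# 1000  -> ~0.3934
# 10000 -> ~0.9932  (the EMA changes very slowly late in training)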

Introduction to the update() function of the ModelEMA class

Each time this function is called, the update counter is incremented and the decay factor is recomputed:

self.updates += 1
## d grows as updates increases, meaning that as training iterations accumulate, the EMA model's weights lean more and more toward their previous (EMA) values
d = self.decay(self.updates)

Fetch the current model's state_dict, ready for updating the EMA model's parameters:

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

Each floating-point entry of the EMA state_dict is then replaced by a weighted sum of its previous value and the corresponding current-model value:

for k, v in self.ema.state_dict().items():
    if v.dtype.is_floating_point:
        v *= d
        v += (1. - d) * msd[k].detach()
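
As a tiny numeric sanity check (values chosen arbitrarily), the per-element update is just the standard weighted average:

d = 0.99
ema_param_old, model_param = 1.0, 0.0
ema_param_new = d * ema_param_old + (1 - d) * model_param
print(ema_param_new)  # 0.99: the EMA moves only slightly toward the current weights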

