Pytorch の指数移動平均 (EMA)

EMAの紹介

EMA (指数移動平均) は、モデル パラメーターや勾配などを更新するためによく使用されます。

EMA の利点は、モデルの堅牢性を向上できることです (以前のモデルの重み情報を組み込む)。

コード例

以下の例として、yolov7/utils/torch_utils.pyコードを取り上げます。

class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)

ModelEMA クラスの __init__ 関数の概要

__init__ 関数の入力パラメータの概要

  • モデル: EMA 戦略を使用してパラメーターを更新する必要があるモデル
  • Decay: 重み付けされた重み、デフォルトは 0.9999
  • 更新: モデル パラメーターの更新/反復時間

__init__ 関数の初期化の概要

最初にモデルをディープコピーします

"""
创建EMA模型

model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间

is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model

"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()

次に、更新回数を初期化します。トレーニングを最初から開始する場合、このパラメータは 0 です。

self.updates = updates

最後に、重み付けされた重みの減衰 (ここでは指数関数的変化) を計算する式を定義します。

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

ModelEMA クラスの update() 関数の概要

この関数を呼び出すと更新と減衰が更新され、

self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)

現在のモデルのパラメータを取り出して、EMA モデルのパラメータを更新する準備をします。

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

EMA モデル パラメーターと現在のモデル パラメーターの加重合計は、EMA モデルの新しいパラメーターとして使用されます。

for k, v in self.ema.state_dict().items():
    if v.dtype.is_floating_point:
        v *= d
        v += (1. - d) * msd[k].detach()

参考記事

[コード解釈] pytorch で EMA を使用する - プログラマー募集

[錬金術スキル] 指数移動平均 (EMA) の原理と PyTorch の実装

歴史から学べ!機械学習における EMA の応用 - Zhihu

おすすめ

転載: blog.csdn.net/qq_38964360/article/details/131482442