グループ正規化の詳細な説明

論文名:グループ正規化
論文アドレス:https ://arxiv.org/abs/1803.08494

以前の記事では、BN(バッチ正規化)、リンク、およびLN(レイヤー正規化)、リンクが導入されました。今日は、GN(グループ正規化)について簡単に説明しましょう。視野では、実際にはBNが最も一般的に使用されていますが、BNにも欠点があり、通常は比較的大きなバッチサイズが必要です。下の図に示すように、青い線はBNを表します。バッチサイズが16未満の場合、エラーは大幅に増加します(ただし、バッチサイズが16より大きい場合、効果は向上します)。比較的大規模なネットワークの場合、またはGPUメモリが十分でない場合、通常、より大きなバッチサイズを設定することはできません。この場合、GNを使用できます。下図のように、バッチサイズのサイズはGNに影響を与えないため、バッチサイズを小さい値に設定するとGNを使用できます。

グループ規範
BN、LN、GNのいずれであっても、式は同じで、減算された平均E(x)E(x)です。E x 、標準偏差で割った値V ar(x)+ ϵ \ sqrt {Var(x)+ \ epsilon}V a r x +ϵ ここで、ϵ \ epsilonϵは非常に少量です(デフォルトは1 0 − 5 10 ^ {-5}1 05)、分母がゼロになるのを防ぎます。および2つのトレーニング可能なパラメーターβ、γ\ beta、\ gammaβ γ違いは、操作する次元/次元です
。y = x − E(x)V ar(x)+ ϵ ∗γ+βy= \ frac {x --E(x)} {\ sqrt {Var(x) + \ epsilon}} \ ast \ gamma + \ betay=V a r x +ϵ バツE xc+b

GN(Group Normalization)の操作は、 num _ groups = 2 num \ _groups=2と仮定して次の図に示されています。n u m _ g r o u p s=2元の論文のデフォルト値は32です。batch_sizeとは関係がないため、1つのサンプルの状況を直接調べます。レイヤーの出力がxxxnum _ groups num\_groupsチャネルチャネル沿ったnum _ g r o u p sc h a n nel方向はnum_groups num\_groups分割されn u m _ g r o u p s、次にそれぞれの平均と分散を計算し、次に上記の式に従って計算します。これは非常に簡単です。

おやすみなさい

私の理解が正しいかどうかを確認するために、Pytorchを使用して簡単な実験を行い、確率変数を作成し、公式のGNメソッドを使用して自分で実装したGNメソッドと比較して結果が同じかどうかを確認しましょう。

import torch
import torch.nn as nn


def group_norm(x: torch.Tensor,
               num_groups: int,
               num_channels: int,
               eps: float = 1e-5,
               gamma: float = 1.0,
               beta: float = 0.):
    assert divmod(num_channels, num_groups)[1] == 0
    channels_per_group = num_channels // num_groups

    new_tensor = []
    for t in x.split(channels_per_group, dim=1):
        var_mean = torch.var_mean(t, dim=[1, 2, 3], unbiased=False)
        var = var_mean[0]
        mean = var_mean[1]
        t = (t - mean[:, None, None, None]) / torch.sqrt(var[:, None, None, None] + eps)
        t = t * gamma + beta
        new_tensor.append(t)

    new_tensor = torch.cat(new_tensor, dim=1)
    return new_tensor


def main():
    num_groups = 2
    num_channels = 4
    eps = 1e-5

    img = torch.rand(2, num_channels, 2, 2)
    print(img)

    gn = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)
    r1 = gn(img)
    print(r1)

    r2 = group_norm(img, num_groups, num_channels, eps)
    print(r2)


if __name__ == '__main__':
    main()

公式手法と自己実施手法を比較すると、下図の左側が公式GN手法による結果、右側が自己実施GN手法による結果である。

vs
明らかに、結果は公式の結果とまったく同じであり、これも私の理解が正しいことを示しています。

おすすめ

転載: blog.csdn.net/qq_37541097/article/details/118016048