論文名:グループ正規化
論文アドレス:https ://arxiv.org/abs/1803.08494
以前の記事では、BN(バッチ正規化)、リンク、およびLN(レイヤー正規化)、リンクが導入されました。今日は、GN(グループ正規化)について簡単に説明しましょう。視野では、実際にはBNが最も一般的に使用されていますが、BNにも欠点があり、通常は比較的大きなバッチサイズが必要です。下の図に示すように、青い線はBNを表します。バッチサイズが16未満の場合、エラーは大幅に増加します(ただし、バッチサイズが16より大きい場合、効果は向上します)。比較的大規模なネットワークの場合、またはGPUメモリが十分でない場合、通常、より大きなバッチサイズを設定することはできません。この場合、GNを使用できます。下図のように、バッチサイズのサイズはGNに影響を与えないため、バッチサイズを小さい値に設定するとGNを使用できます。
BN、LN、GNのいずれであっても、式は同じで、減算された平均E(x)E(x)です。E (x )、標準偏差で割った値V ar(x)+ ϵ \ sqrt {Var(x)+ \ epsilon}V a r (x )+ϵ。ここで、ϵ \ epsilonϵは非常に少量です(デフォルトは1 0 − 5 10 ^ {-5}1 0− 5)、分母がゼロになるのを防ぎます。および2つのトレーニング可能なパラメーターβ、γ\ beta、\ gammaβ 、γ。違いは、操作する次元/次元です
。y = x − E(x)V ar(x)+ ϵ ∗γ+βy= \ frac {x --E(x)} {\ sqrt {Var(x) + \ epsilon}} \ ast \ gamma + \ betay=V a r (x )+ϵ。バツ−E (x)∗c+b
GN(Group Normalization)の操作は、 num _ groups = 2 num \ _groups=2と仮定して次の図に示されています。n u m _ g r o u p s=2元の論文のデフォルト値は32です。batch_sizeとは関係がないため、1つのサンプルの状況を直接調べます。レイヤーの出力がxxx、num _ groups num\_groupsチャネルチャネルに沿ったnum _ g r o u p sc h a n nel方向はnum_groups num\_groupsに分割されn u m _ g r o u p s、次にそれぞれの平均と分散を計算し、次に上記の式に従って計算します。これは非常に簡単です。
私の理解が正しいかどうかを確認するために、Pytorchを使用して簡単な実験を行い、確率変数を作成し、公式のGNメソッドを使用して自分で実装したGNメソッドと比較して結果が同じかどうかを確認しましょう。
import torch
import torch.nn as nn
def group_norm(x: torch.Tensor,
num_groups: int,
num_channels: int,
eps: float = 1e-5,
gamma: float = 1.0,
beta: float = 0.):
assert divmod(num_channels, num_groups)[1] == 0
channels_per_group = num_channels // num_groups
new_tensor = []
for t in x.split(channels_per_group, dim=1):
var_mean = torch.var_mean(t, dim=[1, 2, 3], unbiased=False)
var = var_mean[0]
mean = var_mean[1]
t = (t - mean[:, None, None, None]) / torch.sqrt(var[:, None, None, None] + eps)
t = t * gamma + beta
new_tensor.append(t)
new_tensor = torch.cat(new_tensor, dim=1)
return new_tensor
def main():
num_groups = 2
num_channels = 4
eps = 1e-5
img = torch.rand(2, num_channels, 2, 2)
print(img)
gn = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)
r1 = gn(img)
print(r1)
r2 = group_norm(img, num_groups, num_channels, eps)
print(r2)
if __name__ == '__main__':
main()
公式手法と自己実施手法を比較すると、下図の左側が公式GN手法による結果、右側が自己実施GN手法による結果である。
明らかに、結果は公式の結果とまったく同じであり、これも私の理解が正しいことを示しています。