ハンドティアリング/手書き/BNレイヤーの実装/バッチノルム/バッチ正規化 python torch pytorch

計算プロセス

畳み込みニューラル ネットワークでは、BN 層によって入力される特徴マップの次元は (N、C、H、W) であり、出力特徴マップの次元も (N、C、H、W) です。N は、バッチ サイズ C は
チャネル数を表します
H
は特徴マップの高さ W を表します
特徴マップの幅を表します

チャネル次元でバッチ正規化を行う必要があります。
バッチでは、
すべての特徴マップの同じ位置にあるチャネルのすべての要素を使用して平均と分散を計算し、
計算された平均と分散を使用してチャネルを更新します。対応する特徴マップ。新しい特徴マップを生成します。

以下の図に示すように、
4 つのオレンジ色の特徴マップについて、すべての要素の平均と分散を計算し、それらを使用して 4 つの特徴マップ内の要素を更新します (元の要素から平均を引いたものを分散で割った値)。
![[添付ファイル/BN 図.png]]

コード

def my_batch_norm_2d_detail(features, eps=1e-5):
    '''
        这个函数的写法是为了帮助理解 BatchNormalization 具体运算过程
        实际使用时这样写会比较慢
    '''
    
    n,c,h,w = features.shape
    features_copy = features.clone()
    running_var = torch.randn(c)
    running_mean = torch.randn(c)
    for ci in range(c):# 分别 处理每一个通道
        mean = 0 # 均值
        var = 0 # 方差
        
        _sum = 0 
        # 对一个 batch 中,特征图相同位置 channel 的每一个元素求和
        for ni in range(n):            
            for hi in range(h):
                for wi in range(w):
                    _sum += features[ni,ci, hi, wi]
        mean = _sum / (n * h * w) 
        running_mean[ci] = mean
        

        _sum = 0
        # 对一个 batch 中,特征图相同位置 channel 的每一个元素求平方和,用于计算方差 
        for ni in range(n):            
            for hi in range(h):
                for wi in range(w):
                    _sum += (features[ni,ci, hi, wi] - mean) ** 2
        var = _sum / (n * h * w )
        running_var[ci] = _sum / (n * h * w - 1)

        # 更新元素
        for ni in range(n):            
            for hi in range(h):
                for wi in range(w):
                    features_copy[ni,ci, hi, wi] = (features_copy[ni,ci, hi, wi] - mean) / torch.sqrt(var + eps) 
        
    return features_copy, running_mean, running_var

if __name__ == "__main__":


    torch.set_printoptions(precision=7)

    torch_bn = nn.BatchNorm2d(4)  # 设置 channel 数
    torch_bn.momentum = None
    features = torch.randn(4, 4, 2, 2) # (N,C,H,W)
        
    torch_bn_output = torch_bn(features)    
    my_bn_output, running_mean, running_var = my_batch_norm_2d_detail(features)        
            
    print(torch.allclose(torch_bn_output, my_bn_output))
    print(torch.allclose(torch_bn.running_mean, running_mean))
    print(torch.allclose(torch_bn.running_var, running_var))

予防

分散計算

トレーニング プロセス中に、分散には 2 つの異なる計算方法があることに注意してください。

トレーニング中に、バイアス付き分散を使用して特徴マップを更新し
、running_var の計算では不偏分散を使用します。
ここに画像の説明を挿入

関連リンク

関係者による手書きのBN

"""
Comparison of manual BatchNorm2d layer implementation in Python and
nn.BatchNorm2d

@author: ptrblck
"""

import torch
import torch.nn as nn


def compare_bn(bn1, bn2):
    err = False
    if not torch.allclose(bn1.running_mean, bn2.running_mean):
        print('Diff in running_mean: {} vs {}'.format(
            bn1.running_mean, bn2.running_mean))
        err = True

    if not torch.allclose(bn1.running_var, bn2.running_var):
        print('Diff in running_var: {} vs {}'.format(
            bn1.running_var, bn2.running_var))
        err = True

    if bn1.affine and bn2.affine:
        if not torch.allclose(bn1.weight, bn2.weight):
            print('Diff in weight: {} vs {}'.format(
                bn1.weight, bn2.weight))
            err = True

        if not torch.allclose(bn1.bias, bn2.bias):
            print('Diff in bias: {} vs {}'.format(
                bn1.bias, bn2.bias))
            err = True

    if not err:
        print('All parameters are equal!')


class MyBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input


# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)

compare_bn(my_bn, bn)  # weight and bias should be different
# Load weight and bias
my_bn.load_state_dict(bn.state_dict())
compare_bn(my_bn, bn)

# Run train
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    torch.allclose(out1, out2)
    print('Max diff: ', (out1 - out2).abs().max())

# Run eval
my_bn.eval()
bn.eval()
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    torch.allclose(out1, out2)
    print('Max diff: ', (out1 - out2).abs().max())

おすすめ

転載: blog.csdn.net/SugerOO/article/details/130029642
おすすめ