This means the batch contains two pictures, and each picture has two channels. The numbers indicate which channel each value belongs to.
The code processes one channel at a time. For channel 1 it computes the mean, the population standard deviation, and the sample standard deviation over all pictures in the batch, and then it does the same for channel 2.
Next, it applies the normalization formula to standardize the data (using the population variance plus eps), and it updates the running mean and running variance (using the sample variance).
The following is the reference formula and code:
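As a sketch of what the code computes, the per-channel formulas can be written out as below. They are reconstructed from the code itself, so the notation is an assumption introduced here: m is the number of values in channel c across the batch (batch × height × width), the momentum 0.1 and eps = 1e-5 are the values used in the code, and the output matches nn.BatchNorm2d only while its weight and bias are still at their default initial values of 1 and 0.

```latex
\begin{aligned}
\mu_c &= \frac{1}{m}\sum_{n,h,w} x_{n,c,h,w} \\
\sigma_c^2 &= \frac{1}{m}\sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2
  && \text{(population variance)} \\
s_c^2 &= \frac{1}{m-1}\sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2
  && \text{(sample variance)} \\
\hat{x}_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}},
  \qquad \epsilon = 10^{-5} \\
\text{running\_mean}_c &\leftarrow 0.9\,\text{running\_mean}_c + 0.1\,\mu_c \\
\text{running\_var}_c &\leftarrow 0.9\,\text{running\_var}_c + 0.1\,s_c^2
\end{aligned}
```

In the example below m = 8 (2 pictures, each 2 × 2 per channel), so the sample variance is 8/7 times the population variance.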
import numpy as np
import torch.nn as nn
import torch


def bn_process(feature, mean, var):
    """Normalize `feature` in place, channel by channel, and update the running stats."""
    feature_shape = feature.shape
    for i in range(feature_shape[1]):  # iterate over the channels
        # [batch, channel, height, width]
        feature_t = feature[:, i, :, :]
        # Statistics are computed one channel at a time, starting with channel 0.
        mean_t = feature_t.mean()
        # population standard deviation (divides by n)
        std_t1 = feature_t.std()
        # sample standard deviation (divides by n - 1)
        std_t2 = feature_t.std(ddof=1)

        # bn process
        # remember to add eps so the result matches PyTorch
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)
        # update the running mean and variance (momentum = 0.1)
        mean[i] = mean[i] * 0.9 + mean_t * 0.1
        var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
    print(feature)


# Randomly generate a feature tensor with batch=2, channel=2, height=width=2
# [batch, channel, height, width]
torch.manual_seed(1)
feature1 = torch.randn(2, 2, 2, 2)
# Initialize the running mean and variance
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
print(feature1.numpy())

# Note: pass a copy() so that the original tensor's data is not modified in place
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)
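As a compact cross-check of the loop above, the same per-channel computation can be sketched in vectorized NumPy by reducing over every axis except the channel axis. This is only an illustration: `batch_norm_2d` is a hypothetical helper name, the momentum of 0.1 and eps of 1e-5 are taken from the code above, and the function returns new arrays instead of modifying its inputs.

```python
import numpy as np


def batch_norm_2d(x, running_mean, running_var, momentum=0.1, eps=1e-5):
    """Training-mode batch norm for an array shaped [batch, channel, height, width].

    Hypothetical helper for illustration only.
    """
    axes = (0, 2, 3)  # reduce over everything except the channel axis
    mean = x.mean(axis=axes)
    var_population = x.var(axis=axes)      # divides by n, used to normalize
    var_sample = x.var(axis=axes, ddof=1)  # divides by n - 1, used for the running variance

    # Reshape the statistics to [1, C, 1, 1] so they broadcast across each channel.
    mean_b = mean[None, :, None, None]
    var_b = var_population[None, :, None, None]
    x_hat = (x - mean_b) / np.sqrt(var_b + eps)

    new_running_mean = (1 - momentum) * running_mean + momentum * mean
    new_running_var = (1 - momentum) * running_var + momentum * var_sample
    return x_hat, new_running_mean, new_running_var


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 2, 2, 2))
x_hat, running_mean, running_var = batch_norm_2d(x, np.zeros(2), np.ones(2))

# Each channel of the normalized output has mean close to 0 and variance close to 1.
print(x_hat.mean(axis=(0, 2, 3)))
print(x_hat.var(axis=(0, 2, 3)))
print(running_mean, running_var)
```

The normalized variance is slightly below 1 because eps is added to the denominator, which is also why the hand-written loop needs the same eps to match PyTorch.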
Reference links:
https://blog.csdn.net/wzk4869/article/details/127261308
https://blog.csdn.net/qq_37541097/article/details/104434557