这里用的是batchnormal1d
import torch
from torch import nn
x = torch.rand([2,3,16]) #2张照片,3个通道,每个通道16个pixel
batch_normal = nn.BatchNorm1d(3) #括号里面的数字要与通道数一样
normal_result=batch_normal(x)
print(normal_result.size())
print(batch_normal.running_mean) #每个batch上16*2=32个数字的平均值,一共3个平均值
print(batch_normal.running_var) #每个batch上16*2=32个数字的方差,一共3个方差
我们可以看到batch_normal的mean和var和直接算出来的mean和var不太一样,并且batch_normal的mean刚好是直接算出来的mean的1/10.
这个与batchnormal的momentum参数有关。
把momentum设置成1,那么batch_normal的mean和var就和直接算出来的一样了: