pytorch:nn.BatchNormal

这里用的是batchnormal1d

import torch
from torch import nn

x = torch.rand([2,3,16]) #2张照片,3个通道,每个通道16个pixel
batch_normal = nn.BatchNorm1d(3) #括号里面的数字要与通道数一样
normal_result=batch_normal(x)

print(normal_result.size())

print(batch_normal.running_mean) #每个batch上16*2=32个数字的平均值,一共3个平均值

print(batch_normal.running_var) #每个batch上16*2=32个数字的方差,一共3个方差

在这里插入图片描述
在这里插入图片描述
我们可以看到batch_normal的mean和var和直接算出来的mean和var不太一样,并且batch_normal的mean刚好是直接算出来的mean的1/10.
这个与batchnormal的momentum参数有关。
在这里插入图片描述
把momentum设置成1,那么batch_normal的mean和var就和直接算出来的一样了:
在这里插入图片描述

发布了43 篇原创文章 · 获赞 1 · 访问量 749

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104801799