有关Pytorch中BatchNorm2d权重加载的问题
>>> import torch.nn as nn
>>> bn = nn.BatchNorm2d(3)
>>> list(bn.parameters())
[Parameter containing: tensor([1., 1., 1.], requires_grad=True),
Parameter containing: tensor([0., 0., 0.], requires_grad=True]
疑问:
bn的参数不应该是四个么(weight、bias、running_mean、running_var)?这里怎么是两个(weight、bias)呢?
原因是:在pytorch 中仅仅是可学习参数(可微分的)才能作为parameter
对于running_mean 和 running_var 来说仅仅是在forward时候,由input进行计算mean和var,再以momentum进行更新
为此这两个值为每次训练结束的最后一个forward才更新的值,所以不算是可学习的参数
如果想要获取到running_mean 和 running_var值,获取方式如下:
>>> bn.state_dict() # get OrderedDict