Pytorch中BatchNorm2d权重的加载

有关Pytorch中BatchNorm2d权重加载的问题

>>> import torch.nn as nn
>>> bn = nn.BatchNorm2d(3)
>>> list(bn.parameters())
[Parameter containing: tensor([1., 1., 1.], requires_grad=True), 
Parameter containing: tensor([0., 0., 0.], requires_grad=True]

疑问:

bn的参数不应该是四个么(weight、bias、running_mean、running_var)?这里怎么是两个(weight、bias)呢?

原因是:在pytorch 中仅仅是可学习参数(可微分的)才能作为parameter

对于running_mean 和 running_var 来说仅仅是在forward时候,由input进行计算mean和var,再以momentum进行更新

为此这两个值为每次训练结束的最后一个forward才更新的值,所以不算是可学习的参数

如果想要获取到running_mean 和 running_var值,获取方式如下:

>>> bn.state_dict()  # get OrderedDict

Supongo que te gusta

Origin blog.csdn.net/qq_38973721/article/details/120648289
Recomendado
Clasificación