PyTorch——解决报错“RuntimeError: running_mean should contain *** elements not ***”

1 问题描述

在使用PyTorch编程的时候,经常遇到一种报错就是:“RuntimeError: running_mean should contain *** elements not ***”;

这次我具体的报错信息是:

  File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
    exponential_average_factor, self.eps)
  File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 192 elements not 768

从最后一行的报错信息,可以看到:进行求均值元素的总数应该是192而不是768;

2 解决方案

我们可以继续看看上一条提示信息:“File "/home/songyuc/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm

有一个值得注意的信息是batch_norm,而我们的模型中也刚好使用了BN的操作,所以应该是BN的设置出现了问题,

我们回到代码定位的部分进行查看,需要查看的是BN初始化设置的代码,然后看到了下面的代码:

modules = [nn.Sequential(
    nn.Conv2d(in_channels, OUT_CHANNELS, 1, groups=1, bias=False),
    nn.BatchNorm2d(in_channels),
    nn.ReLU()),

我们可以看到,果然,BatchNorm2d的输入通道数与前一层Conv2d的输出通道数不一致,而这里的OUT_CHANNELS=192,in_channels=192,所以造成了这种维度的不一致,所以才会报错;

所以,我们需要根据自己模型的设计,将BN层与Conv层的输出维度保持一致。

发布了277 篇原创文章 · 获赞 76 · 访问量 30万+

猜你喜欢

转载自blog.csdn.net/songyuc/article/details/104431857