【debug】PyTorch CTC_Loss returns nan

1. NaN values in the feature

This happened to me once when the parameters of max_pool2d were set incorrectly.
You can inspect the feature's maximum value with print(feature.max()).
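print(feature.max()) only shows a single number, so it can help to count the NaN and Inf entries directly. The sketch below does this with torch.isnan and torch.isinf. The helper name report_non_finite and the example tensor shapes are hypothetical and exist only for illustration.

```python
import torch

def report_non_finite(name: str, t: torch.Tensor) -> bool:
    """Hypothetical helper: print a short report if `t` has NaN/Inf, return True if so."""
    n_nan = torch.isnan(t).sum().item()  # number of NaN entries
    n_inf = torch.isinf(t).sum().item()  # number of +Inf / -Inf entries
    if n_nan or n_inf:
        print(f"{name}: {n_nan} NaN, {n_inf} Inf values")
        return True
    return False

# Example: a feature map with one NaN injected on purpose
feature = torch.randn(2, 3, 4, 4)
feature[0, 0, 0, 0] = float("nan")
has_bad = report_non_finite("feature", feature)   # prints "feature: 1 NaN, 0 Inf values"
clean = report_non_finite("clean", torch.randn(2, 3))  # prints nothing
```

Running a check like this on the output of each layer (for example right after max_pool2d) narrows down where the NaN first appears. If the NaN only shows up during the backward pass, torch.autograd.set_detect_anomaly(True) can help locate the operation that produced it.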

2. A zero in target_lengths

PyTorch now has a built-in CTC loss, nn.CTCLoss. Its usage, taken from the example in the official documentation:


>>> import torch
>>> import torch.nn as nn
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
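nn.CTCLoss also accepts the targets as a single 1-D tensor that holds every label back to back, instead of the padded (N, S) tensor used above. In that form the length of the target tensor must equal target_lengths.sum(). Below is a minimal sketch; the batch size and the per-sample label lengths are made up for illustration.

```python
import torch
import torch.nn as nn

T, C, N = 50, 20, 4  # input length, number of classes (incl. blank 0), batch size

log_probs = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full((N,), T, dtype=torch.long)

# Made-up label length per sample; every value is > 0 here
target_lengths = torch.tensor([5, 3, 7, 2], dtype=torch.long)

# Concatenated form: one 1-D tensor with all labels back to back,
# so its length is exactly target_lengths.sum()
targets = torch.randint(1, C, (int(target_lengths.sum()),), dtype=torch.long)
assert targets.shape[0] == target_lengths.sum().item()

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
loss.backward()
print(loss.item())
```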

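Where:

  • target holds the labels. In the padded form used above it has shape (N, S), and only the first target_lengths[i] entries of row i are used. If the labels are instead concatenated into one 1-D tensor, its length must equal target_lengths.sum().
    Note: if target_lengths contains a 0, the loss becomes nan. A 0 means that an image contains no characters at all.
    Fix: filter out images with len(label) == 0 in the dataset.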
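One way to apply this fix is to drop empty labels when the dataset is built, so a zero never reaches target_lengths. The class name TextLineDataset and the (image, label_ids) sample layout are hypothetical; adapt them to your own data pipeline.

```python
import torch
from torch.utils.data import Dataset

class TextLineDataset(Dataset):
    """Hypothetical OCR dataset: `samples` is a list of (image_tensor, label_ids)."""

    def __init__(self, samples):
        # Drop samples with an empty label: they would give target_length == 0
        self.samples = [(img, lbl) for img, lbl in samples if len(lbl) > 0]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, lbl = self.samples[idx]
        # Return the image, its label ids, and the label length for CTCLoss
        return img, torch.tensor(lbl, dtype=torch.long), len(lbl)

raw = [
    (torch.zeros(1, 32, 100), [3, 5, 2]),
    (torch.zeros(1, 32, 100), []),   # image with no characters
    (torch.zeros(1, 32, 100), [7]),
]
ds = TextLineDataset(raw)
print(len(ds))  # 2
```

Filtering once in __init__ keeps __getitem__ simple. A cheap extra safeguard is to assert (target_lengths > 0).all() right before calling the loss.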


Reposted from blog.csdn.net/u011622208/article/details/105647379