[PyTorch] RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size

every blog every motto: Just live your life cause we don’t live twice.

0. Preface

While training a network with PyTorch, I ran into an error when computing the loss.

1. Main text

1.1 Error review

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [2, 10, 256, 256]

As the message says, the loss expects a 3-dimensional target tensor, but I passed in a 4-dimensional one.
Code:

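```python
import torch.nn as nn

# Loss function
c_loss = nn.CrossEntropyLoss()
print('Prediction:', d0.shape)    # d0: network output (logits)
print('Label: ', labels_v.shape)  # labels_v: one-hot label, torch.Size([2, 10, 256, 256])

loss0 = c_loss(d0, labels_v.long())  # raises the RuntimeError above
```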
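The snippet above depends on tensors from the training loop. A self-contained way to reproduce the failure is sketched below, with random data standing in for the real network output and labels (the names d0 and labels_v are reused from the post; batch, num_classes, class_idx and the random values are made up). The exact wording of the error depends on the PyTorch version, but in each case it is raised as a RuntimeError.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the post's tensors:
# d0 = logits (batch, classes, height, width), labels_v = one-hot label of the same shape
batch, num_classes, height, width = 2, 10, 256, 256
d0 = torch.randn(batch, num_classes, height, width)
class_idx = torch.randint(0, num_classes, (batch, height, width))
# F.one_hot puts the class axis last, so move it to dim 1 to match d0
labels_v = nn.functional.one_hot(class_idx, num_classes).permute(0, 3, 1, 2)

c_loss = nn.CrossEntropyLoss()
try:
    c_loss(d0, labels_v.long())  # 4D integer target: rejected
    error = None
except RuntimeError as exc:
    error = exc
print(type(error).__name__, error)
```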

[Figure: printed shapes of the prediction and the label]
Notes:

  1. In PyTorch, image tensors are laid out as (batch, channel, height, width).
  2. In TensorFlow, the default layout is (batch, height, width, channel).
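Moving between the two layouts is only a reordering of axes. A minimal sketch with a hypothetical random batch, using `permute` to go from TensorFlow's channel-last order to PyTorch's channel-first order and back:

```python
import torch

# Hypothetical batch in TensorFlow's default layout: (batch, height, width, channel)
x_nhwc = torch.randn(2, 256, 256, 10)

# Reorder the axes to PyTorch's (batch, channel, height, width)
x_nchw = x_nhwc.permute(0, 3, 1, 2)
print(x_nchw.shape)  # torch.Size([2, 10, 256, 256])

# The inverse permutation restores the original layout
x_back = x_nchw.permute(0, 2, 3, 1)
print(torch.equal(x_back, x_nhwc))  # True
```

`permute` returns a view without copying data; call `.contiguous()` afterwards if a later operation needs contiguous memory.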

Following the TensorFlow/Keras way of computing a loss, the prediction and the label should have the same shape, so it should be possible to compute the loss. Why does it raise an error?
Can PyTorch only handle 3D tensors?
Not really...

1.2 Solution

PyTorch and TensorFlow differ slightly in the label format their loss functions expect.
For our multi-class segmentation task, the two workflows look like this:

1.2.1 In TensorFlow

Model training:
[Figure: model training code]
Label one-hot encoding:
[Figure: one-hot encoding of the label]
After these two steps, the label has the same shape as the model's prediction, so the loss between them can be computed directly.
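To see what the TensorFlow recipe computes, it can be re-created in PyTorch. This is a sketch of the idea, not the Keras API itself: random logits and class indices (hypothetical names and shapes chosen to match the post) are turned into a one-hot label with the same shape as the prediction, and categorical cross-entropy is taken by hand as the negative sum over classes of label times log-probability, averaged over pixels.

```python
import torch
import torch.nn.functional as F

num_classes = 10
# Hypothetical logits in PyTorch layout: (batch, classes, height, width)
logits = torch.randn(2, num_classes, 256, 256)
# Hypothetical per-pixel class indices: (batch, height, width)
class_idx = torch.randint(0, num_classes, (2, 256, 256))

# Step 1 (Keras style): one-hot encode the label to the prediction's shape.
# F.one_hot appends the class axis last, so move it to dim 1.
one_hot = F.one_hot(class_idx, num_classes).permute(0, 3, 1, 2).float()

# Step 2: categorical cross-entropy = -sum_c y_c * log p_c, averaged over all pixels
log_probs = F.log_softmax(logits, dim=1)
loss_one_hot = -(one_hot * log_probs).sum(dim=1).mean()
print(loss_one_hot.item())
```

The result matches `F.cross_entropy(logits, class_idx)` up to floating-point error, which is why PyTorch can skip the one-hot step and take the class indices directly.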

1.2.2 In PyTorch

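In PyTorch, we do not need to one-hot encode the label. Instead, the label must have no channel dimension: if it was one-hot encoded, that dimension has to be collapsed back into class indices.
That is:
Predicted value: (2, 10, 256, 256), i.e. (batch, channel, height, width)
Label value: (2, 256, 256), i.e. (batch, height, width)
With these shapes, the loss can be calculated directly.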
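Collapsing the channel dimension can be sketched in two common cases, assuming hypothetical random labels: a one-hot label becomes class indices via `argmax` over the class axis, and a single-channel mask of shape (batch, 1, height, width) that already stores class indices only needs that axis dropped with `squeeze(1)`.

```python
import torch
import torch.nn.functional as F

num_classes = 10
# Hypothetical ground truth: class index per pixel, (batch, height, width)
class_idx = torch.randint(0, num_classes, (2, 256, 256))

# Case 1: one-hot label in PyTorch layout (batch, classes, height, width)
labels_one_hot = F.one_hot(class_idx, num_classes).permute(0, 3, 1, 2)
# The position of the 1 along the class axis is the pixel's class
labels_idx = labels_one_hot.argmax(dim=1)
print(labels_idx.shape)  # torch.Size([2, 256, 256])

# Case 2: single-channel mask (batch, 1, height, width) holding class indices
mask_1ch = class_idx.unsqueeze(1)
labels_from_mask = mask_1ch.squeeze(1)  # drop the size-1 channel axis
print(labels_from_mask.shape)  # torch.Size([2, 256, 256])
```

Using `squeeze(1)` rather than a bare `squeeze()` avoids accidentally dropping the batch axis when the batch size is 1.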
Here each value in the label is the class index of that pixel (0 to 9 for 10 classes), stored as an integer (long) tensor. With the label in this form, the loss is computed without error:
[Figure: the loss computed successfully]
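A corrected, self-contained version of the original call is sketched below. d0 and labels_v are reused from the post but filled with random data here (an assumption), and labels_v now holds one class index per pixel. As documented for `nn.CrossEntropyLoss`, the loss is `log_softmax` over the class axis followed by `nll_loss`, which the last lines check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 10
# Hypothetical stand-ins: logits (batch, classes, H, W) and integer labels (batch, H, W)
d0 = torch.randn(2, num_classes, 256, 256)
labels_v = torch.randint(0, num_classes, (2, 256, 256))  # dtype int64 (long)

c_loss = nn.CrossEntropyLoss()
loss0 = c_loss(d0, labels_v)  # 4D input with 3D target: accepted
print(loss0.item())

# Same value built from its two documented pieces
same = F.nll_loss(F.log_softmax(d0, dim=1), labels_v)
print(torch.allclose(loss0, same, atol=1e-6))
```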



Origin blog.csdn.net/weixin_39190382/article/details/114433884