【Pytorch】 RuntimeError: 1 solo se admiten lotes de objetivos espaciales (tensores 3D) pero obtuvieron objetivos de si

cada blog, cada lema: Simplemente vive tu vida porque no vivimos dos veces.

0. Prefacio

Encontré un error al usar Pytorch para entrenar la red y calcular la pérdida

1. Texto

1.1 Revisión de errores

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [2, 10, 256, 256]

Como se muestra arriba, espero pasar un Tensor de 3 dimensiones y, como resultado, obtengo un Tensor de 4 dimensiones.
Código:

# 损失函数
c_loss = nn.CrossEntropyLoss()
print('预测值:', d0.shape)  
print('标签值: ', labels_v.shape)  

loss0 = c_loss(d0, labels_v.long())

forma:
Inserte la descripción de la imagen aquí
Descripción:

  1. El tensor en pytorch se almacena de esta manera (lote, canal, altura, ancho) .
  2. En Tensorflow, se almacena en (lote, alto, ancho, canal) .

De acuerdo con la idea de calcular la pérdida en Tensorflow / keras, la forma del valor predicho y el valor de la etiqueta deberían ser iguales. Debería ser posible calcular la pérdida, ¿por qué se informa un error? ? ?
¿Pytorch solo puede calcular el tensor 3D?
En realidad no ...

1.2 Solución

Pytorch y Tensorflow son ligeramente diferentes en términos de cálculo de la función de pérdida.
En nuestra (muy amable) tarea, algo como esto:

1.2.1 Tensorflow

Entrenamiento del modelo:
Inserte la descripción de la imagen aquí
codificación de etiqueta one-hot: a
Inserte la descripción de la imagen aquí
través de estos dos pasos, podemos calcular la pérdida entre la etiqueta y el resultado de predicción generado por el modelo .

1.2.2 En pytorch

En Pytorch, "no necesitamos" realizar una codificación one-hot en la etiqueta, y necesitamos comprimir la dimensión del canal.
A saber:
Valor previsto: (256,256,10)
Valor de etiqueta: (256,256) La
pérdida se puede calcular directamente.
Entre ellos, el valor de la etiqueta corresponde al número de categorías
. Como se muestra en la figura siguiente, podemos calcular la pérdida
Inserte la descripción de la imagen aquí

referencias

[1] https://discuss.pytorch.org/t/runtimeerror-only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-dimension-4/82098
[2] https : //discuss.pytorch.org/t/runtimeerror-1only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-size-1-3-96-128/95030/4
[3] https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html
[4] https://www.cnblogs.com/gshang/p/13854889.html
[5] https: / /jianzhuwang.blog.csdn.net/article/details/110955851

Supongo que te gusta

Origin blog.csdn.net/weixin_39190382/article/details/114433884
Recomendado
Clasificación