PyTorch: solución al problema de pérdida de entrenamiento del modelo (Pérdida) NaN

Bienvenido a seguir mi CSDN: https://spike.blog.csdn.net/La
dirección de este artículo: https://spike.blog.csdn.net/article/details/133378367

Durante el entrenamiento del modelo, si ocurre un problema de NaN, afectará seriamente el proceso de retropropagación de Loss, por lo que es necesario agregar algunos valores pequeños al procesamiento para evitar afectar los resultados del entrenamiento del modelo.

Por ejemplo, pérdida de entropía cruzada sigmoid_cross_entropy, incluida la función logarítmica (log), cuando la entrada es 0 al calcular el valor de registro, provocará un desbordamiento, por lo que se debe agregar una restricción de valor mínimo (como 1e-8) para evitar el desbordamiento. .

Fórmula de entropía cruzada:

L ( y , y ^ ) = − 1 N ∑ i = 1 N [ yi log ⁡ ( y ^ i ) + ( 1 − yi ) log ⁡ ( 1 − y ^ i ) ] L(y, \hat{y} ) = -\frac{1}{N} \sum_{i=1}^N [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) ]L ( y ,y^)=norte1yo = 1norte[ yyoiniciar sesión (y^yo)+( 1yyo)iniciar sesión ( 1y^yo)]

Curva logarítmica:

registro

Ahora mismo:

# 额外增加 eps,可以避免数值溢出
def sigmoid_cross_entropy(logits, labels, eps=1e-8):
    logits = logits.float()
    log_p = torch.log(torch.sigmoid(logits)+eps)
    log_not_p = torch.log(torch.sigmoid(-logits)+eps)
    loss = -labels * log_p - (1 - labels) * log_not_p
    return loss

La entropía cruzada sigmoidea es una función de pérdida de uso común que se utiliza para medir la diferencia entre los resultados de predicción del modelo y las etiquetas reales en problemas de clasificación binaria. Su función es optimizar los parámetros del modelo para que el modelo se ajuste mejor a los datos y mejorar la precisión de la clasificación.

Ejemplo: Cómo resolver la pérdida que se convierte en nan debido al uso de torch.log()

Supongo que te gusta

Origin blog.csdn.net/u012515223/article/details/133378367
Recomendado
Clasificación