Keras (tf.keras) implementation of focal loss and DMI loss

Without further ado, straight to the point.

Focal loss

The original paper mainly addresses class imbalance and the differing difficulty of examples in classification.
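For reference, the focal loss from the paper is FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class: the (1 - p_t)^γ factor down-weights easy, well-classified examples, and α_t re-weights classes against imbalance.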

Loss function implementation:

from tensorflow.keras import backend as K

def categorical_focal_loss(gamma=2.0, alpha=0.25):
    def categorical_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: a tensor of the same shape as `y_pred` (one-hot labels)
        :param y_pred: a tensor resulting from a softmax
        :return: output tensor
        """
        # Scale predictions so that the class probabilities of each sample sum to 1
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)

        # Clip the prediction values to prevent NaNs and Infs
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate cross entropy
        cross_entropy = -y_true * K.log(y_pred)

        # Calculate focal loss: the (1 - p)^gamma factor down-weights easy examples
        loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

        # Compute the mean loss over classes for each sample in the mini-batch
        return K.mean(loss, axis=1)

    return categorical_focal_loss_fixed
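A minimal usage sketch (the toy model below is only illustrative; the input shape and class count are made up, and any softmax classifier works):

import tensorflow as tf

# Hypothetical toy classifier, just to show how the loss is wired in
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(100,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# categorical_focal_loss(...) returns a (y_true, y_pred) -> loss function,
# which is exactly what compile() expects
model.compile(optimizer='adam',
              loss=categorical_focal_loss(gamma=2.0, alpha=0.25),
              metrics=['accuracy'])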

DMI loss

Introductory article

Implementation:

import tensorflow as tf

def dmi_loss(y_true, y_pred):
    """
    :param y_true: one-hot ground-truth labels, shape (batch, num_classes)
    :param y_pred: softmax scores, shape (batch, num_classes)
    """
    # Transpose the labels to (num_classes, batch)
    y_true = tf.transpose(y_true, perm=[1, 0])
    # Joint matrix of labels and predictions, shape (num_classes, num_classes)
    mat = tf.matmul(y_true, y_pred)
    # Negative log-determinant; the 0.001 keeps the log away from zero
    loss = -1.0 * tf.math.log(tf.math.abs(tf.linalg.det(mat)) + 0.001)
    return loss
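A minimal usage sketch, again with a made-up toy model; dmi_loss already has the (y_true, y_pred) signature Keras expects:

import tensorflow as tf

num_classes = 10

# Hypothetical toy classifier, just to show how the loss is wired in
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(100,)),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])

model.compile(optimizer='adam', loss=dmi_loss, metrics=['accuracy'])

# Note: keep the batch large enough that every class appears; if a class is
# missing from a batch, its row of mat is all zeros and det(mat) is 0.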

The loss may run to a negative value during training: the entries of mat are sums over the whole batch, so |det(mat)| can exceed 1, making -log(|det(mat)| + 0.001) negative. It still works.

However, DMI loss does hurt recall in the high-precision region, as the PR curves below show.

[Figure: PR curve with cross-entropy loss]
[Figure: PR curve with DMI loss]
Although the two reach similar accuracy, the PR curve with cross entropy is clearly better than the one with DMI.

Below is my official account; welcome to scan the QR code, and thank you for your attention and support!

Official account name: Python into the pit NLP

This official account is mainly dedicated to sharing knowledge about natural language processing, machine learning, coding algorithms, and Python. I am just a beginner; I hope we can all make progress together as I record my study and work. Welcome to exchange and share.

Source: blog.csdn.net/lovoslbdy/article/details/107169797