焦点損失とdmi損失のKeras(tf.keras)実装

言葉の少ない男

焦点損失

原文は
主に分類の不均衡と分類問題の分類難易度の違いの問題を解決します

損失関数の実現:

from tensorflow.keras import backend as K
def categorical_focal_loss_fixed(y_true, y_pred):
            """
            :param y_true: A tensor of the same shape as `y_pred`
            :param y_pred: A tensor resulting from a softmax
            :return: Output tensor.
            """

            # Scale predictions so that the class probas of each sample sum to 1
            y_pred /= K.sum(y_pred, axis=-1, keepdims=True)

            # Clip the prediction value to prevent NaN's and Inf's
            epsilon = K.epsilon()
            y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

            # Calculate Cross Entropy
            cross_entropy = -y_true * K.log(y_pred)

            # Calculate Focal Loss
            loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

            # Compute mean loss in mini_batch
            return K.mean(loss, axis=1)

DMIの損失

紹介記事

成し遂げる:

import tensorflow as tf
def dmi_loss(y_true, y_pred):
            """
            y_true为onehot真实标签
            y_pred为softmax后分数
            """
            y_true = tf.transpose(y_true, perm=[1, 0])
            mat = tf.matmul(y_true, y_pred)
            loss = -1.0 * tf.math.log(tf.math.abs(tf.linalg.det(mat)) + 0.001)
            return loss

損失関数は負の値まで実行される可能性がありますが、機能します。

ただし、dmiは、高精度でのリコールにも悪影響を及ぼします。

クロスエントロピー損失を
dmi損失を使用する
使用するdmi損失を使用する:
ここに写真の説明を挿入
2つの精度に大きな違いはありませんが、クロスエントロピーのpr曲線は明らかにdmiよりも優れています。

以下は公式アカウントです。QRコードのスキャンを歓迎します。ご清聴ありがとうございました。サポートありがとうございます。

公式アカウント名:Python into the pit NLP
公開なし
この公式アカウントは、主に自然言語処理、機械学習、コーディングアルゴリズム、およびPythonの知識共有に特化しています。私はただのサイドディッシュです。私の勉強と仕事のプロセスを記録しながら、みんなが一緒に進歩できることを願っています。交換と共有へようこそ。

おすすめ

転載: blog.csdn.net/lovoslbdy/article/details/107169797