言葉の少ない男
焦点損失
原文は
主に分類の不均衡と分類問題の分類難易度の違いの問題を解決します
損失関数の実現:
from tensorflow.keras import backend as K
def categorical_focal_loss_fixed(y_true, y_pred):
"""
:param y_true: A tensor of the same shape as `y_pred`
:param y_pred: A tensor resulting from a softmax
:return: Output tensor.
"""
# Scale predictions so that the class probas of each sample sum to 1
y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
# Clip the prediction value to prevent NaN's and Inf's
epsilon = K.epsilon()
y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
# Calculate Cross Entropy
cross_entropy = -y_true * K.log(y_pred)
# Calculate Focal Loss
loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy
# Compute mean loss in mini_batch
return K.mean(loss, axis=1)
DMIの損失
成し遂げる:
import tensorflow as tf
def dmi_loss(y_true, y_pred):
"""
y_true为onehot真实标签
y_pred为softmax后分数
"""
y_true = tf.transpose(y_true, perm=[1, 0])
mat = tf.matmul(y_true, y_pred)
loss = -1.0 * tf.math.log(tf.math.abs(tf.linalg.det(mat)) + 0.001)
return loss
損失関数は負の値まで実行される可能性がありますが、機能します。
ただし、dmiは、高精度でのリコールにも悪影響を及ぼします。
クロスエントロピー損失を
使用する:dmi損失を使用する:
2つの精度に大きな違いはありませんが、クロスエントロピーのpr曲線は明らかにdmiよりも優れています。
以下は公式アカウントです。QRコードのスキャンを歓迎します。ご清聴ありがとうございました。サポートありがとうございます。
公式アカウント名:Python into the pit NLP
この公式アカウントは、主に自然言語処理、機械学習、コーディングアルゴリズム、およびPythonの知識共有に特化しています。私はただのサイドディッシュです。私の勉強と仕事のプロセスを記録しながら、みんなが一緒に進歩できることを願っています。交換と共有へようこそ。