tf.kerasカスタム損失関数

統計では、フーバー損失はロバスト回帰で使用される損失関数であり、二乗誤差損失よりもデータの外れ値の影響を受けません。分類のバリアントも時々使用されます。

def huber_fn(y_true, y_pred):
    error = y_true - y_pred
    is_small_error = tf.abs(error) < 1
    squared_loss = tf.square(error) / 2
    linear_loss  = tf.abs(error) - 0.5
    return tf.where(is_small_error, squared_loss, linear_loss)

カスタム損失関数の戻り値は平均損失はなくベクトルであり、各要素はインスタンスに対応することに注意してくださいこれの利点は、Kerasがウェイトをパスclass_weightまたはsample_weight調整できることです。

huber_fn(y_valid, y_pred)
<tf.Tensor: id=4894, shape=(3870, 1), dtype=float64, numpy=
array([[0.10571115],
       [0.03953311],
       [0.02417886],
       ...,
       [0.00039475],
       [0.00245003],
       [0.12238744]])>

おすすめ

転載: www.cnblogs.com/yaos/p/12746391.html