tensflow自定义损失函数

三、自定义损失函数

标准的损失函数并不合适所有场景,有些实际的背景需要采用自己构造的损失函数,Tensorflow 也提供了丰富的基础函数供自行构建。
例如下面的例子:当预测值(y_pred)比真实值(y_true)大时,使用 (y_pred-y_true)*loss_more 作为 loss,反之,使用 (y_true-y_pred)*loss_less

loss = tf.reduce_sum(tf.where(tf.greater(y_pred, y_true), (y_pred-y_true)*loss_more,(y_true-y_pred)*loss_less))

tf.greater(x, y):判断 x 是否大于 y,当维度不一致时广播后比较
tf.where(condition, x, y):当 condition 为 true 时返回 x,否则返回 y 
tf.reduce_mean():沿维度求平均
tf.reduce_sum():沿维度相加
tf.reduce_prod():沿维度相乘
tf.reduce_min():沿维度找最小
tf.reduce_max():沿维度找最大
使用 Tensorflow 提供的方法可自行构造想要的损失函数。

猜你喜欢

转载自blog.csdn.net/qq_41853536/article/details/83348377