keras--实现加权交叉熵(多分类)

在做图像分割任务时由于背景类别占比很大,导致网络倾向于预测背景,虽然准确率很高,但是目标区域完全没有被预测,因此考虑修改loss函数交叉熵,将背景类别的权重降低。

实现交叉熵计算

交叉熵的计算原理直接看一个例子:
在这里插入图片描述
下面基于keras和tensorflow实现交叉熵,假设有3个样本,共4类(0123),y_pred为网络的输出(logits,即未softmax)

import numpy as np
import keras.backend as K
import tensorflow as tf
from keras.utils import to_categorical

label = np.array([1, 0, 3])
y_true = to_categorical(label)
# [[0. 1. 0. 0.]
#  [1. 0. 0. 0.]
#  [0. 0. 0. 1.]]
y_pred = np.array(
    [[2.8883,  0.1760,  1.0774, 1.05],
     [1.1216, -0.0562,  0.0660, 0.001],
     [-1.3939, -0.0967,  3.5853, 1.58]]
)
y_pred = K.constant(y_pred)

对其softmax,每行之和为1,在对其取log

soft_pred = K.softmax(y_pred) # softmax
# [[0.71995354 0.04779337 0.11771739 0.11453571]
#  [0.50453496 0.1553743  0.17556988 0.16452082]
#  [0.03178228 0.11629253 0.23000678 0.6219184 ]]
log_pred = K.log(soft_pred)
# [[-0.32856858 -3.0408685  -2.1394684  -2.1668687 ]
#  [-0.68411815 -1.8619182  -1.7397182  -1.8047183 ]
#  [-3.4488466  -2.1516464  -1.4696465  -0.4749464 ]]

自己的交叉熵和tensorflow的交叉熵比较(注意这里是3个样本的交叉熵,要计算整个batch的交叉熵还要求和)。

my_loss = -K.sum(y_true*log_pred, axis=1)
# [3.0408685  0.68411815 0.4749464 ]
tf_loss = tf.nn.softmax_cross_entropy_with_logits(y_true, y_pred)
# [3.0408685  0.68411815 0.4749464 ]

加权交叉熵

接着以上的交叉熵进一步实现加权交叉熵,加权就是为了给不同类别加上权重,使得网络重视样本量较少的类别。现在我们有四个类别,为每个类别赋一个权重,例如第三个类别样本最少,我们赋予权重4,类别一样本多,赋予1。然后将类别权重和one-hot做一个点乘来确定3个样本的分别属于哪个类别并赋予权重,这里用到了python的广播机制,它会将class_weights扩展到和y_true一个维度,然后再点乘,此时得到weights为[1.5 1. 2. ],即为三个样本的权重

class_weights = tf.constant([[1.0, 1.5, 4.0, 2.0]])  # 4个类别的权重
weights = tf.reduce_sum(class_weights * y_true, axis=1) 

将计算得到的样本权重和上面的交叉熵点乘就可以得到带权交叉熵(unweighted_losses即你自己实现的my_loss或者用tf求的tf_loss),对比不带权的结果可以看到加权后第三个样本的loss(属于第四类)变大了,即网络更重视第四个类别

weighted_losses = unweighted_losses * weights
# weighted_losses [4.561302   0.68411815 4.3189106 ]
# unweighted_losses [3.0408685  0.68411815 0.4749464 ]

还可以使用skleaern来计算class_weights然后再计算加权交叉熵,详见

参考:https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy

发布了83 篇原创文章 · 获赞 4 · 访问量 5343

猜你喜欢

转载自blog.csdn.net/weixin_43486780/article/details/105629077