tflearn自定义损失函数

创建一个对象,实现__call__方法

class weighted_cross_entropy(object):
    """Class-weighted softmax cross-entropy loss for tflearn.

    tflearn accepts any callable as a loss and invokes it as
    ``loss(y_pred, y_true)`` (see tflearn/layers/estimator.py), so an
    instance of this class can be passed directly to ``regression(...)``.
    """

    # Per-class weights; order must match the channel order of y_pred.
    # NOTE(review): hard-coded for a 4-class problem — confirm against the model.
    DEFAULT_WEIGHTS = (0.21008659, 0.26289699, 0.28279202, 0.24422441)

    def __init__(self, weights=None):
        # Optional override keeps the original zero-argument construction
        # (weighted_cross_entropy()) fully backward compatible.
        self.weights = list(weights) if weights is not None else list(self.DEFAULT_WEIGHTS)

    def __call__(self, y_pred, y_true):
        """Compute the weighted cross-entropy.

        Args:
            y_pred: logits Tensor with shape
                [batch_size, image_width, image_height, num_classes]
                (e.g. the score from the unet conv10 layer).
            y_true: one-hot ground-truth Tensor, same shape as y_pred
                (already float; cast before tf.one_hot if starting from ints).

        Returns:
            A scalar loss Tensor.
        """
        # `dim` was the deprecated keyword name for this argument and was
        # removed in later TensorFlow releases; `axis` is the supported name.
        prob = tf.nn.softmax(y_pred, axis=-1)
        # Clip so tf.log never sees 0 — unclipped softmax outputs can
        # saturate to 0 and yield -inf loss / NaN gradients.
        prob = tf.clip_by_value(prob, 1e-10, 1.0)
        # Mean over every element (batch, spatial, and class axes); the
        # per-class weights broadcast along the last axis.
        loss = -tf.reduce_mean(y_true * tf.log(prob) * self.weights)
        return loss

方法体内的格式可以参照tflearn -> objectives.py来写。注意传入的y_pred和y_true都是float类型的,如上中如果要使用one_hot就需要强转类型。

# Verbatim reference implementation quoted from tflearn -> objectives.py.
# Note: _EPSILON and _FLOATX are module-level constants defined inside
# tflearn (not shown here), so this snippet is not runnable standalone.
def categorical_crossentropy(y_pred, y_true):
    """ Categorical Crossentropy.

    Computes cross entropy between y_pred (logits) and y_true (labels).

    Measures the probability error in discrete classification tasks in which
    the classes are mutually exclusive (each entry is in exactly one class).
    For example, each CIFAR-10 image is labeled with one and only one label:
    an image can be a dog or a truck, but not both.

    `y_pred` and `y_true` must have the same shape `[batch_size, num_classes]`
    and the same dtype (either `float32` or `float64`). It is also required
    that `y_true` (labels) are binary arrays (For example, class 2 out of a
    total of 5 different classes, will be define as [0., 1., 0., 0., 0.])

    Arguments:
        y_pred: `Tensor`. Predicted values.
        y_true: `Tensor` . Targets (labels), a probability distribution.

    """
    with tf.name_scope("Crossentropy"):
        # Normalize predictions so each row sums to 1 along the class axis.
        y_pred /= tf.reduce_sum(y_pred,
                                reduction_indices=len(y_pred.get_shape())-1,
                                keep_dims=True)
        # manual computation of crossentropy
        # Clip into (epsilon, 1-epsilon) so tf.log never sees 0 or 1 exactly.
        y_pred = tf.clip_by_value(y_pred, tf.cast(_EPSILON, dtype=_FLOATX),
                                  tf.cast(1.-_EPSILON, dtype=_FLOATX))
        cross_entropy = - tf.reduce_sum(y_true * tf.log(y_pred),
                               reduction_indices=len(y_pred.get_shape())-1)
        # Per-example cross-entropy averaged over the batch.
        return tf.reduce_mean(cross_entropy)

在loss中新建一个对象传进去

# Pass an INSTANCE of the custom-loss class (note the trailing "()") so
# tflearn's regression layer receives a callable with a __call__ method.
# `conv10` is the final conv layer of the model (defined elsewhere).
network = regression(conv10, optimizer='adam',
                     loss=weighted_cross_entropy(),
                     learning_rate=5e-4)

源码:

tflearn -> layers -> estimator.py -> regression

# Verbatim excerpt from tflearn -> layers -> estimator.py -> regression,
# showing how the `loss` argument is dispatched:
#   - a string looks up a built-in objective,
#   - any object with __call__ is invoked as loss(y_pred, y_true),
#   - otherwise a pre-built Tensor is accepted as-is.
# Building other ops (loss, training ops...)
if isinstance(loss, str):
    loss = objectives.get(loss)(incoming, placeholder)
# Check if function
elif hasattr(loss, '__call__'):
    try:
        # `incoming` plays the role of y_pred, `placeholder` of y_true.
        loss = loss(incoming, placeholder)
    except Exception as e:
        print(str(e))
        print('Reminder: Custom loss function arguments must be defined as: '
              'custom_loss(y_pred, y_true).')
        exit()
elif not isinstance(loss, tf.Tensor):
    raise ValueError("Invalid Loss type.")

在elif中,如果loss不是字符串,同时有__call__方法,那么就通过调用该方法来实现损失的计算。该方法的作用自行百度。

猜你喜欢

转载自blog.csdn.net/Asun0204/article/details/79102228