【Pytorch】 Dice系数与Dice Loss损失函数实现

由于 Dice系数是图像分割中常用的指标,而在Pytoch中没有官方的实现,下面通过自己的想法并结合网上的一些参考进行详细实现。

先来看一个我在网上看到的一个版本。

def diceCoeff(pred, gt, smooth=1, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = 2 * (intersection + smooth) / (unionset + smooth)

    return loss.sum() / N

整体思路就是运用dice的计算公式 (2 * A∩B) / (A∪B)。下面来分析一下可能存在的问题:

smooth参数是用来防止分母除0的,但是如果smooth=1的话,会使得dice的计算结果略微偏高,看下面的测试代码。

第一种情况:预测和标签完全一样

# shape = torch.Size([1, 3, 4, 4])
'''
1 0 0= bladder
0 1 0 = tumor
0 0 1= background 
'''
pred = torch.Tensor([[
        [[0, 1, 1, 0],
         [1, 0, 0, 1],
         [1, 0, 0, 1],
         [0, 1, 1, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [0, 1, 1, 0],
         [0, 1, 1, 0],
         [1, 0, 0, 1]]]])
    
gt = torch.Tensor([[
        [[0, 1, 1, 0],
         [1, 0, 0, 1],
         [1, 0, 0, 1],
         [0, 1, 1, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [0, 1, 1, 0],
         [0, 1, 1, 0],
         [1, 0, 0, 1]]]])


dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)
print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))
print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))

# 输出结果
smooth=1 : dice=1.050
smooth=1e-5 : dice=1.0

我们最后预测的是一个3分类的分割图,第一类是baldder, 第二类是tumor, 第三类是背景。我们先假设bladder的预测pred和gt一样,计算bladder的dice值,发现当smooth=1的时候,dice偏高, 而smooth=1e-5时dice比较合理。

解决办法:我想这里应该更改代码的实现方式,用下面的计算公式替换之前的,因为之前加smooth的位置有问题。

# loss = 2 * (intersection + smooth) / (unionset + smooth)  # 之前的

loss = (2 * intersection + smooth) / (unionset + smooth)

替换后的dice如下:

def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2 * intersection + smooth) / (unionset + smooth)

    return loss.sum() / N

上面用到的测试数据进行验证结果如下:dice计算正确 

# smooth=1 : dice=1.0
# smooth=1e-5 : dice=1.0

第二种情况:预测的结果不在标签中 

如下面的代码,我们假设预测的pred中有一部分bladder,但gt中没有bladder,看计算出的dice值如何。

'''
    1 0 0= bladder
    0 1 0 = tumor
    0 0 1= background 
    '''
    pred = torch.Tensor([[
        [[0, 1, 1, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]]])

    gt = torch.Tensor([[
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]]])

    dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
    dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)
    print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))
    print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))

# 输出结果
smooth=1 : dice=0.3333
smooth=1e-5 : dice=5e-06

从结果可以看到,smooth=1时的dice值为0.3333;而 smooth=1e-5时的dice值接近于0,较为合理。

dice的另一种计算方式:这里参考肾脏肿瘤挑战赛提供的dice计算方法。

def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * tp) / (2 * tp + fp + fn)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    tp = torch.sum(gt_flat * pred_flat, dim=1)
    fp = torch.sum(pred_flat, dim=1) - tp
    fn = torch.sum(gt_flat, dim=1) - tp
    loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return loss.sum() / N

整理代码

def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2 * intersection + smooth) / (unionset + smooth)

    return loss.sum() / N



def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * tp) / (2 * tp + fp + fn)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    tp = torch.sum(gt_flat * pred_flat, dim=1)
    fp = torch.sum(pred_flat, dim=1) - tp
    fn = torch.sum(gt_flat, dim=1) - tp
    loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return loss.sum() / N
class DiceLoss(nn.Module):
    __name__ = 'dice_loss'

    def __init__(self, activation='sigmoid'):
        super(DiceLoss, self).__init__()
        self.activation = activation

    def forward(self, y_pr, y_gt):
        return 1 - diceCoeffv2(y_pr, y_gt, activation=self.activation)

总结:上面是这几天对dice以及dice loss的一些思考和实现。

原创文章 7 获赞 3 访问量 2296

猜你喜欢

转载自blog.csdn.net/baidu_36511315/article/details/105217674