CenterLoss

CenterLoss

CenterLoss是干嘛的?

我们知道:对于图像的分类问题,我们通常使用softmax来计算损失在这里插入图片描述

上图就是我们使用softmax对手写数字图像进行不同数字进行颜色的分类

如果我们在softmax损失的基础上添加上一个centerloss损失,那么对于手写数字图像的分类就如下图所示了

在这里插入图片描述

那么centerloss是怎么把图像分类成上面的这个样子的呢?

在这里插入图片描述

在训练过程中,我们同时更新中心并最小化深度特征与其对应的类中心之间的距离

在这里插入图片描述

通过联合监督,不仅扩大了类间特征差异,而且减少了类内特征差异

至此,我们大体知道了centerloss的作用了,那么centerLoss具体是如何实现的呢?在这里插入图片描述

对于公式的解读:

  • c y i c_{y_i} cyi表示第 y i {y_i} yi个类别的特征中心
  • x i x_i xi表示特征值

公式的大体意思是:训练中,每一个batch的样本的特征与当前类别中心的距离的平方和越小越好(也就是类内距越小越好)

在这里插入图片描述

比如上图:红蓝两个样本,其中X所代表的就是类别的中心点 c y i c_{y_i} cyi

centerloss就是将离散的点向中心点聚拢从而增大了类内的间距

centerLoss和Softmax的结合

在这里插入图片描述

代码实现

class CenterLoss(nn.Module):
    def __init__(self, num_class, num_feature):
        super().__init__()
        self.center = nn.Parameter(torch.rand(num_class, num_feature))
    def forward(self, features, targets):
        batch_size = features.size(0)
        loss = torch.sum((features - torch.index_select(self.center,0,targets))**2)/batch_size
        return loss

在这里插入图片描述

Guess you like

Origin blog.csdn.net/qq_38973721/article/details/113179096