标签平滑Label Smoothing


前言

对于分类问题,lable常常是one-hot编码的,即[0,0,1,0,0]形式。全概率1和0鼓励所属类别和其他类别之间的差距尽可能加大,然而,在分类问题中不同种类的类别不一定完全没有相似的特征,不能这样一杆子打死。
对于我们常用的交叉熵损失函数,我们需要用预测分布q去拟合真实分布p,现在我们来看一下拟合one-hot的分布所带来的问题:
1)例如,输出为[0.1,0.7,0.1,0.1],由于要使得Loss尽可能小,会让模型尽可能的调整为[0,1,0,0],尤其针对像交叉熵这类loss,一旦output有些偏差,loss值可能就往无穷大走了,就逼迫模型去接近真实的label。无法保证模型的泛化能力,容易造成过拟合。适当调整label,让两端的极值往中间凑凑,可以增加泛化性能。
2) 如果此时的标签还是错的话,会让Loss很大很大,模型调整按此Loss调整会使得前面正确的学习都变成功亏一篑。交叉熵损失函数如图所示:
在这里插入图片描述
在这种情况下,模型越好错误样本的影响越大。即本该用0.8去算Loss,却用了0.05去算。


Label Smoothing

在这里插入图片描述

将原本[0,1,0,0]的标签变成诸如[0.1,0.7,0.1,0.1]的标签。这样的好处有两个:
1)是更加符合实际情况的分布,即不同样本之间其实是有相似的特征的,具体每一类之间有多少相似性,需要按照实际情况来考虑分配。也可以直接用公式。也有相关的论文,做了让模型自己学习应该怎么分配标签的工作。
在这里插入图片描述

2)Label Smoothing是分类问题中解决noisy label一种方法。当样本标签错误时,由于采样这种方法,让Loss不至于像原来那般大。例如:三个类别,α=0.1,err_Loss=5x1+2x0+0.2x0=5,smooth_err_Loss=5x0.8+2x0.1+0.2x0.1=4.22。α越大平滑的效果会更明显。

代码实现

#onehot_labels 是[0,0,1,0,0]向量形式
new_onehot_labels = onehot_labels * (1 - label_smoothing)
                           + label_smoothing / num_classes

pytorch实现损失函数:

#1.Use a function to get smooth label
def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method

    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist

#2.Make CrossEntropyLoss support k-hot/smoothed targets.
"""
Loss = CrossEntropyLoss(NonSparse=True, ...)
. . .
data = ...
labels = ...

outputs = model(data)

smooth_label = smooth_one_hot(labels, ...)
loss = (outputs, smooth_label)
"""

NVIDIA有个官方的实现:

# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.nn as nn


class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, padding_idx, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.
        :param padding_idx: index of the PAD token
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothing, self).__init__()
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = torch.nn.functional.log_softmax(x, dim=-1,
                                                   dtype=torch.float32)

        non_pad_mask = (target != self.padding_idx)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)[non_pad_mask]
        smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask]
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.sum()


后续相关

后续如果有时间,会继续针对错误样本处理方法,讲解Bi -Tempered Logistic Loss

猜你喜欢

转载自blog.csdn.net/qq_41917697/article/details/112792943