PyTorch 实现标签平滑（Label Smoothing）

class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss against a label-smoothed target distribution.

    The target class receives probability ``1 - smoothing`` and the remaining
    ``smoothing`` mass is spread uniformly over the other ``num_class - 1``
    classes. The loss is the cross-entropy between that distribution and
    ``log_softmax(pred)`` along the class dimension, averaged over all
    remaining positions.

    Args:
        smoothing: Probability mass moved off the target class.
        dim: Class dimension of ``pred``. ``target`` must have the shape of
            ``pred`` with this dimension removed.
    """

    def __init__(self, smoothing=0.05, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean label-smoothed cross-entropy.

        Args:
            pred: Unnormalized logits with classes along ``self.dim``.
            target: Integer (int64) class indices whose shape equals the
                shape of ``pred`` with ``self.dim`` removed.

        Returns:
            A scalar tensor.

        Raises:
            ZeroDivisionError: If ``pred`` has only one class along
                ``self.dim`` (the off-target mass has nowhere to go).
        """
        log_prob = pred.log_softmax(dim=self.dim)
        # Every class-axis operation goes through self.dim so that inputs
        # other than (N, C) work, e.g. (N, T, C) with dim=-1 or
        # (N, C, H, W) with dim=1.
        num_class = log_prob.size(self.dim)
        with torch.no_grad():
            # The smoothed target is a constant; it must not receive gradient.
            true_dist = torch.full_like(log_prob, self.smoothing / (num_class - 1))
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_prob, dim=self.dim))

猜你喜欢

转载自：https://blog.csdn.net/qq_55542491/article/details/130882950
今日推荐