A detailed guide to torch.gather()

Official documentation

Link: https://pytorch.org/docs/stable/generated/torch.gather.html

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
# Gathers values of input along the axis specified by dim; the output has the same shape as index.
  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

For a 3-D tensor, the output is defined as:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
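
As a quick check of the dim == 0 rule above, here is a minimal sketch (my own example, not from the official docs):

import torch

x = torch.arange(24).reshape(2, 3, 4)        # input, shape (2, 3, 4)
idx = torch.randint(0, 2, (2, 3, 4))         # indices along dim 0, same shape as the output
out = torch.gather(x, dim=0, index=idx)      # out has the same shape as idx

# out[i][j][k] == x[idx[i][j][k]][j][k]
for i in range(2):
    for j in range(3):
        for k in range(4):
            assert out[i, j, k] == x[idx[i, j, k], j, k]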

Illustration

Reference: https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms

torch.gather creates a new tensor from the input tensor by taking values from each row along the input dimension dim. The values in the torch.LongTensor passed as index specify which value to take from each "row". The output tensor has the same number of dimensions as the index tensor. The picture below explains it more clearly:

Figure 1: torch.gather in the 2-D case
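
Since the image may not render here, a minimal 2-D sketch (my own numbers) showing the same idea:

import torch

t = torch.tensor([[1, 2],
                  [3, 4]])

# dim=1: out[i][j] = t[i][index[i][j]], i.e. pick columns within each row
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
# tensor([[1, 1],
#         [4, 3]])

# dim=0: out[i][j] = t[index[i][j]][j], i.e. pick rows within each column
torch.gather(t, 0, torch.tensor([[0, 1], [1, 0]]))
# tensor([[1, 4],
#         [3, 2]])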

Usage in label smoothing

Reference: https://www.pythonfixing.com/2021/11/fixed-label-smoothing-in-pytorch.html

Suppose the ground-truth label is $(0, 1, 0)$ and the label-smoothing coefficient is $\alpha = 0.2$. The result one might imagine is $(0.1, 0.8, 0.1)$, which is what the reference above produces. In practice, however, neither PyTorch nor TensorFlow implements it that way: both produce the smoothed label $(0.2/3,\ 0.8 + 0.2/3,\ 0.2/3)$ and then compute the cross entropy against it. Correction: the original paper does label smoothing exactly the way PyTorch does; the "imagined" version is misleading!
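
To make the difference concrete, here is a small sketch (my own illustration, assuming $K = 3$ classes and $\alpha = 0.2$) of both target distributions:

import torch

alpha, K = 0.2, 3
onehot = torch.tensor([0., 1., 0.])                       # GT = (0, 1, 0)

# What PyTorch / TensorFlow actually use: spread alpha over all K classes
smoothed = (1 - alpha) * onehot + alpha / K
# tensor([0.0667, 0.8667, 0.0667])  ==  (0.2/3, 0.8 + 0.2/3, 0.2/3)

# The "imagined" version: spread alpha only over the K - 1 wrong classes
naive = (1 - alpha) * onehot + (1 - onehot) * alpha / (K - 1)
# tensor([0.1000, 0.8000, 0.1000])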

import torch
import torch.nn as nn

# Label smoothing cross entropy, equivalent to nn.CrossEntropyLoss(label_smoothing=...)
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self):
        super(LabelSmoothingCrossEntropy, self).__init__()

    def forward(self, x, target, smoothing=0.1):
        confidence = 1. - smoothing
        logprobs = x.log_softmax(dim=-1)                               # (N, C)
        # gather the log-probability of the ground-truth class for each sample
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)                                 # per-sample cross entropy, no mean
        # uniform part: average negative log-probability over all classes
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + smoothing * smooth_loss
        return loss.mean()

ce0 = LabelSmoothingCrossEntropy()
ce1 = nn.CrossEntropyLoss(label_smoothing=0.1)

predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0],
                             [0, 0.9, 0.2, 0.2, 1],
                             [1, 0.2, 0.7, 0.9, 1]])
label = torch.LongTensor([2, 1, 0])
out0 = ce0(predict, label)  # tensor(1.3096)
out1 = ce1(predict, label)  # tensor(1.3096)
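
The two results match because expanding the cross entropy against the smoothed target $(1-\alpha)\cdot\text{onehot} + \alpha/K$ gives exactly $(1-\alpha)\cdot\text{NLL} + \alpha\cdot\text{mean}(-\log p)$, which is what the custom class computes. A quick check (assuming the code above has run):

print(torch.allclose(out0, out1))  # True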


Reposted from blog.csdn.net/Huang_Fj/article/details/124234947