图解PyTorch中的torch.gather函数

1 背景

去年我理解了torch.gather()用法,今年看到又给忘了,索性把自己的理解梳理出来,方便今后遗忘后快速上手。

官方文档:

官方文档对torch.gather()的定义非常简洁

定义:从原tensor中获取指定dim和指定index的数据

看到这个核心定义,我们很容易想到gather()基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是 高度自定义化的,可乱序的

2 实战

我们找个3x3的二维矩阵做个实验

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[ 3,  4,  5],
		[ 6,  7,  8],
		[ 9, 10, 11]])

2.1 输入行向量index,并替换行索引(dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

输出结果

tensor([[9, 7, 5]])

过程如图所示
在这里插入图片描述

2.2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5, 4, 3]])

过程如图所示
在这里插入图片描述

2.3 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5],
        [7],
        [9]])

过程如图所示
在这里插入图片描述

2.4 输入二维矩阵index,并替换列索引(dim=1)

index = torch.tensor([[0, 2], 
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[3, 5],
        [7, 8]])

过程同上
注意计算过程中行号的变化

3 在强化学习DQN中的使用

在PyTorch官网DQN页面的代码中,是这样获取 Q ( S t , a ) Q(S_t,a) Q(St,a)

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

其中 Q ( S t ) Q(S_t) Q(St),即policy_net(state_batch)为shape=(128, 2)的二维表,动作数为2
在这里插入图片描述

而我们通过神经网络输出的对应批量动作 Q ( S t , a ) Q(S_t,a) Q(St,a)
在这里插入图片描述
此时,使用gather()函数即可轻松获取批量状态对应批量动作 Q ( S t , a ) Q(S_t,a) Q(St,a)

3 总结

从以上典型案例,我们可以归纳出torch.gather()的使用要点

  • 输入index的shape等于输出value的shape
  • 输入index的索引值仅替换该index中对应dim的index值
  • 最终输出为替换index后在原tensor中的值

本文转载自知乎,原文网址: 图解PyTorch中的torch.gather函数

猜你喜欢

转载自blog.csdn.net/weixin_46707326/article/details/120424556