tf.gather_nd 用法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/G66565906/article/details/84949512

tf.gather_nd的函数原型是:

def gather_nd(params, indices, name=None)

根据定义, 其主要功能是根据indices描述的索引,提取params上的元素, 重新构建一个tensor

在谈论该函数之前,我们先来看一下 索引的概念,

在一维数组中,元素的索引即该元素在数组中序号,通常序号从0开始标记

如数组 ary=[1,2,3,4];

元素2的索引 为 1, 元素的引用可表示为 [1]

元素3的索引为 2,  元素的引用可表示为 [2]

那么二维数组呢? 类似地

对于二维 ary=[ [1,2], [3,4] ]

元素 [1,2]  在一维中的索引为 [0],   元素 1 的索引 则表示为 [0,0], 元素 2 的索引 则表示为 [0,1], 

因此 gather_nd 实现了根据指定的 参数 indices 来提取params 的元素重建出一个tensor,

还是以上面的二维数组为例

[0,0] 表示 的是 1,

[0,1] 表示的是 2

当indices 为  [[0,0],[0,1]] 时, 该函数的输出则为  [1,2]

即 indices 中 表示索引的 部分 被提取到的值替换

那么当indices 为[ [ [ [  [1,1] ] ] ] ] 时 函数输出是什么呢 ? 用元素 替换掉 表示索引的那一部分, 即可得到 [ [ [ [ 4  ] ] ] ]

猜你喜欢

转载自blog.csdn.net/G66565906/article/details/84949512