tf.scatter_update和tf.batch_scatter_update

tf.scatter_update

函数定义:

tf.scatter_update(
    ref,
    indices,
    updates,
    use_locking=True,
    name=None
)

需要说明的是,updates.shape = [*indices.shape, *ref.shape[1:]], upadtes的shape不一定与ref的shape相等。

测试:

a = tf.Variable([[1, 2, 3, 4], [5, 6, 7, 8]])
indices = [[0, 1], [1, 0]]
updates = [[[1, 1, 1, 1], [2, 3, 4, 5]], [[2, 2, 2, 2], [3, 3, 3, 3]]]
b = tf.scatter_update(a, indices, updates)

sess = tf.InteractiveSession()
print(sess.run([b]))

结果:

[array([[3, 3, 3, 3],
        [2, 2, 2, 2]])]

结果说明:

重点是解读indices的含义

indices的值指定ref中被替换的对象,以上面测试为例,indices中的0,1分别指定a中的a[0]、a[1]将被替换。

indices的值对应的index指定ref中的相应替代值为update[index],以上为例,indices[0][0]为1,则a[1]将被替换为updates[0][0]。

此外indices中重复出现的值将被多次替换,至于结果是不确定的。


tf.batch_scatter_update

函数定义:

与tf.scatter_update相同

测试:

d = tf.Variable([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = [[1, 1], [1, 0]]
updates = [[[1, 1], [2, 2]], [[3, 3], [4, 4]]]
e = tf.batch_scatter_update(d, indices, updates)
sess.run(tf.global_variables_initializer())
print(sess.run([e]))

结果:

[array([[[0, 1],
         [2, 2]],
 
        [[4, 4],
         [3, 3]]])]

结果说明:

tf.batch_scatter_update与tf.scatter_update类似,只是在进行值替换时,tf.scatter_update中ref替换对象由indices的值指定,而在tf.batch_scatter_update中由indices的值和对应的index[:-1]共同指定。

以上面为例 :indices[0][1]为1,则d[0][1]]替换为updates[0][1],其中d[0][1]中的1是indices[0][1]的值。

猜你喜欢

转载自blog.csdn.net/lssc4205/article/details/87806841
今日推荐