Deep learning和tensorflow学习记录(三十):Broadcasting

tensorflow支持Broadcasting,在进行加,乘等运算时要确保它们的操作数匹配。但是有一种特例是其中一个tensor的一个维度为一时,tensorflow会自动推断扩展维度以使操作数匹配。

import tensorflow as tf

sess = tf.Session()
a = tf.constant([[1., 2.], [3., 4.]])
b = tf.constant([[1.], [2.]])
tile_b = tf.tile(b, [1, 2])
c1 = a + tile_b
c2 = a + b
print(sess.run(a))
print(sess.run(b))
print(sess.run(tile_b))
print(sess.run(c1))
print(sess.run(c2))

输出:

[[1. 2.]
 [3. 4.]]
[[1.]
 [2.]]
[[1. 1.]
 [2. 2.]]
[[2. 3.]
 [5. 6.]]
[[2. 3.]
 [5. 6.]]

但是像下面这种情况,我们不想Broadcasting的时候,由于tensorflow默认自动推断,所以得不到想要的结果。

sess = tf.Session()
a = tf.constant([[1.], [2.]])
b = tf.constant([1., 2.])
c = tf.reduce_sum(a + b)
print(sess.run(a))
print(sess.run(b))
print(sess.run(c))

输出:

[[1.]
 [2.]]
[1. 2.]
12.0

避免这种情况就要显式指定希望在哪个维度上进行操作。

sess = tf.Session()
a = tf.constant([[1.], [2.]])
b = tf.constant([1., 2.])
c = tf.reduce_sum(a + b, 0)
print(sess.run(a))
print(sess.run(b))
print(sess.run(c))

输出:

[[1.]
 [2.]]
[1. 2.]
[5. 7.]

猜你喜欢

转载自blog.csdn.net/heiheiya/article/details/81092602
今日推荐