Tensorflow一个矩阵与多个矩阵相乘(不同维度的Tensor相乘)

TensorFlow的 matmul 已经支持了batch,对于高维向量的相乘(比如两个三维矩阵),tensorflow把前面的维度当成是batch,对最后两维进行普通的矩阵乘法。也就是说,最后两维之前的维度,都需要相同。比如 A.shape=(a, b, n, m)B.shape=(a, b, m, k),tf.matmul(A,B) 的结果 shape=(a,b,n,k)

有时候需要一个矩阵与多个矩阵相乘,也就是一个 2D Tensor 与一个 3D Tensor 相乘,比如 A.shape=(m, n)B.shape=(k, n, p),希望计算 A*B 得到一个 C.shape=(k, m, p) 的 Tensor,可以采取的思路为:

  1. B transpose(B,[1,0,2]),(n,k,p),维度交换
  2. B reshape(B,[n,k*p])
  3. C=AB (m,kp)
  4. C reshape(c,[m,k,p])
  5. C transpose(C,[1,0,2]),(k,m,p),维度交换

可以看下面一个例子(c为标准答案,g为最后的正确结果,e是错误的):

import tensorflow as tf
a = tf.reshape(tf.linspace(1.,6.,6),[2,3])
b = tf.reshape(tf.linspace(1.,24.,24),[2,3,4])
c = tf.matmul(tf.tile(tf.expand_dims(a,0),multiples=[2,1,1]),b)
d = tf.matmul(a,tf.reshape(b,[3,2*4]))
e = tf.reshape(d,[2,2,4])

f = tf.transpose(b,[1,0,2])
g = tf.matmul(a,tf.reshape(f,[3,-1]))
g = tf.reshape(g,[2,2,4])
g = tf.transpose(g,[1,0,2])


with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print('-------------')
    print(sess.run(c))
    print('-------------')
    print(sess.run(d))
    print(sess.run(e))
    print('-------------')
    print(sess.run(f))
    print(sess.run(g))

结果:

[[1. 2. 3.]
 [4. 5. 6.]]
[[[ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]]

 [[13. 14. 15. 16.]
  [17. 18. 19. 20.]
  [21. 22. 23. 24.]]]
-------------
[[[ 38.  44.  50.  56.]
  [ 83.  98. 113. 128.]]

 [[110. 116. 122. 128.]
  [263. 278. 293. 308.]]]
-------------
[[ 70.  76.  82.  88.  94. 100. 106. 112.]
 [151. 166. 181. 196. 211. 226. 241. 256.]]
[[[ 70.  76.  82.  88.]
  [ 94. 100. 106. 112.]]

 [[151. 166. 181. 196.]
  [211. 226. 241. 256.]]]
-------------
[[[ 1.  2.  3.  4.]
  [13. 14. 15. 16.]]

 [[ 5.  6.  7.  8.]
  [17. 18. 19. 20.]]

 [[ 9. 10. 11. 12.]
  [21. 22. 23. 24.]]]
[[[ 38.  44.  50.  56.]
  [ 83.  98. 113. 128.]]

 [[110. 116. 122. 128.]
  [263. 278. 293. 308.]]]
发布了42 篇原创文章 · 获赞 34 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/weixin_41024483/article/details/88536662
今日推荐