torch.bmm(input, mat2, out=None) → Tensor
The inputs to bmm must be 3-dimensional; any other number of dimensions raises an error:
import torch
a = torch.Tensor(4,2,2,3)
b = torch.Tensor(4,2,3,5)
c = torch.bmm(a,b)
Traceback (most recent call last):
File "/Users/XXX/Desktop/MyCode/xxx.py", line 1436, in <module>
c = torch.bmm(a,b)
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
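For inputs with more than three dimensions, torch.matmul is the usual alternative: it broadcasts over all leading batch dimensions instead of erroring out. A minimal sketch with the same shapes as above:

```python
import torch

# torch.matmul treats every dimension before the last two as a batch
# dimension, so 4-D inputs work where torch.bmm would raise an error.
a = torch.ones(4, 2, 2, 3)
b = torch.ones(4, 2, 3, 5)
c = torch.matmul(a, b)  # batched matmul over the leading (4, 2) dims
print(c.shape)          # torch.Size([4, 2, 2, 5])
```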
Below we demonstrate how bmm is used:
import torch
a = torch.stack( [torch.ones(3,4)*torch.tensor(i+1) for i in range(5)], dim=0)
b = a.transpose(1,2)
#a.shape: (5,3,4)
#b.shape: (5,4,3)
c = torch.bmm(a,b)
#c.shape: (5,3,3)
print(c)
tensor([[[  4.,   4.,   4.],
         [  4.,   4.,   4.],
         [  4.,   4.,   4.]],

        [[ 16.,  16.,  16.],
         [ 16.,  16.,  16.],
         [ 16.,  16.,  16.]],

        [[ 36.,  36.,  36.],
         [ 36.,  36.,  36.],
         [ 36.,  36.,  36.]],

        [[ 64.,  64.,  64.],
         [ 64.,  64.,  64.],
         [ 64.,  64.,  64.]],

        [[100., 100., 100.],
         [100., 100., 100.],
         [100., 100., 100.]]])
In the code we stacked five 3×4 tensors; each corresponding transpose is 4×3. The slices of a are constant matrices, running from all-1s up to all-5s. Recall that for the all-ones matrix $I \in \mathbb{R}^{m \times n}$:

$I I^{T} = n I'$, where $I' \in \mathbb{R}^{m \times m}$ is the all-ones matrix, and therefore

$(aI)(aI)^{T} = n \cdot a^{2} \cdot I'$

The output above is exactly this with $n = 4$ and $a = 1, 2, 3, 4, 5$. So bmm performs a matrix multiplication between the two matrices sharing the same batch index; matrices at different batch indices do not interact at all!
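This batch-independence can be checked directly: a small sketch that rebuilds the example and compares each slice of the bmm result against both the closed-form value $n \cdot a^{2}$ and an ordinary per-slice matrix product.

```python
import torch

# Rebuild the example: five constant 3x4 matrices with values 1..5.
a = torch.stack([torch.ones(3, 4) * (i + 1) for i in range(5)], dim=0)
c = torch.bmm(a, a.transpose(1, 2))

for i in range(5):
    # Closed form from the derivation: n * a^2, with n = 4 columns.
    expected = torch.full((3, 3), 4.0 * (i + 1) ** 2)
    assert torch.equal(c[i], expected)
    # Each batch slice is just an ordinary matrix product of slice i
    # with slice i; no cross-batch terms appear.
    assert torch.equal(c[i], a[i] @ a[i].T)
print("all batch slices verified")
```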