Verifying torch.bmm()

Signature from the official documentation

torch.bmm(input, mat2, out=None) → Tensor

Both inputs to bmm must be 3-dimensional tensors. Any other number of dimensions raises an error:

import torch
# Two 4-D tensors; bmm only accepts 3-D inputs, so the call below fails.
a = torch.randn(4,2,2,3)
b = torch.randn(4,2,3,5)
c = torch.bmm(a,b)

Traceback (most recent call last):
  File "/Users/XXX/Desktop/MyCode/xxx.py", line 1436, in <module>
    c = torch.bmm(a,b)
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
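
If the data really is 4-dimensional, there are two common workarounds. The sketch below is only illustrative and reuses the shapes from the failing example; the variable names c_matmul and c_bmm are made up for this demonstration. It assumes either that torch.matmul is acceptable (it treats all leading dimensions as batch dimensions), or that the leading dimensions can be flattened into one batch dimension before calling bmm.

```python
import torch

a = torch.randn(4, 2, 2, 3)
b = torch.randn(4, 2, 3, 5)

# Option 1: torch.matmul treats every leading dimension as a batch dimension.
c_matmul = torch.matmul(a, b)  # shape (4, 2, 2, 5)

# Option 2: flatten the leading dimensions into a single batch dimension,
# call bmm on the resulting 3-D tensors, then restore the original layout.
a3 = a.reshape(-1, 2, 3)  # shape (8, 2, 3)
b3 = b.reshape(-1, 3, 5)  # shape (8, 3, 5)
c_bmm = torch.bmm(a3, b3).reshape(4, 2, 2, 5)

print(c_matmul.shape)
print(torch.allclose(c_matmul, c_bmm, atol=1e-5))
```

Both options compute the same per-matrix products; which one reads better depends on whether the surrounding code already works with 3-D batches.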

The following example demonstrates how bmm is used:

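import torch

# Stack five 3x4 matrices; the i-th one is filled with the constant i+1.
a = torch.stack([torch.ones(3, 4) * (i + 1) for i in range(5)], dim=0)
# Swap the last two dimensions to transpose every matrix in the batch.
b = a.transpose(1, 2)
# a.shape: (5, 3, 4)
# b.shape: (5, 4, 3)

c = torch.bmm(a, b)
# c.shape: (5, 3, 3)
print(c)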

tensor([[[  4.,   4.,   4.],
         [  4.,   4.,   4.],
         [  4.,   4.,   4.]],

        [[ 16.,  16.,  16.],
         [ 16.,  16.,  16.],
         [ 16.,  16.,  16.]],

        [[ 36.,  36.,  36.],
         [ 36.,  36.,  36.],
         [ 36.,  36.,  36.]],

        [[ 64.,  64.,  64.],
         [ 64.,  64.,  64.],
         [ 64.,  64.,  64.]],

        [[100., 100., 100.],
         [100., 100., 100.],
         [100., 100., 100.]]])

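The code stacks five 3×4 tensors along a new batch dimension, so each transposed slice is 4×3.
The slices of a are constant matrices filled with 1, 2, 3, 4 and 5 respectively. Let $I \in \mathbb{R}^{m \times n}$ denote the all-ones matrix (not the identity matrix) and $I' \in \mathbb{R}^{m \times m}$ the square all-ones matrix. Then, for any scalar $a$:

$$I I^{T} = n I', \qquad I' \in \mathbb{R}^{m \times m},$$
$$(aI)(aI)^{T} = n a^{2} I'.$$

Here $n = 4$, so every entry of the product is $4a^{2}$, which gives 4, 16, 36, 64 and 100 for $a = 1, 2, 3, 4, 5$. This is exactly the printed result. In other words, bmm performs a matrix multiplication between the two matrices that share the same batch index, and matrices at different batch indices do not interact.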
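
As a quick sanity check of this conclusion, the sketch below compares the bmm result against an explicit per-batch loop and against the closed form $n a^{2} I'$. The names c_loop, scales and c_formula are introduced here purely for illustration; the shapes follow the example above (m = 3, n = 4, batch size 5).

```python
import torch

m, n, batch = 3, 4, 5

# Same construction as the example: the i-th slice is filled with i + 1.
a = torch.stack([torch.ones(m, n) * (i + 1) for i in range(batch)], dim=0)
b = a.transpose(1, 2)
c = torch.bmm(a, b)

# Reference 1: multiply each pair of matrices separately, one batch index at a time.
c_loop = torch.stack([a[i] @ b[i] for i in range(batch)], dim=0)

# Reference 2: the closed form n * a^2 * (all-ones m x m matrix) for a = 1..5.
scales = torch.arange(1, batch + 1, dtype=torch.float32)
c_formula = (n * scales ** 2).view(batch, 1, 1) * torch.ones(m, m)

print(torch.allclose(c, c_loop))
print(torch.allclose(c, c_formula))
```

If bmm mixed matrices across batch indices, the comparison with the per-batch loop would fail.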

Reposted from blog.csdn.net/weixin_45703452/article/details/119891510