torch.bmm()


Function introduction

torch.bmm(input, mat2, *, out=None) → Tensor
  • Parameters:

    • input (Tensor) – the first batch of matrices to be multiplied
    • mat2 (Tensor) – the second batch of matrices to be multiplied
  • Performs a batch matrix-matrix product of the matrices stored in input and mat2

  • input and mat2 must both be 3-D tensors containing the same number of matrices

    • If input is a tensor of shape [b, n, m]

    • and mat2 is a tensor of shape [b, m, p],

    • then the result is a tensor of shape [b, n, p]

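The shape rule above corresponds to the formula given in the PyTorch documentation: each output matrix is the ordinary matrix product of the matching pair of input matrices. Written out element by element, with i indexing the batch and k running over the shared dimension m (the index names r, c, k are chosen here for illustration):

```latex
\text{out}_i = \text{input}_i \,@\, \text{mat2}_i, \qquad i = 0, \dots, b-1

\text{out}[i, r, c] = \sum_{k=0}^{m-1} \text{input}[i, r, k] \, \text{mat2}[i, k, c],
\qquad 0 \le r < n, \; 0 \le c < p
```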

Example

>>> import torch
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
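A small sketch (not from the original post) that checks torch.bmm against two equivalent formulations: a Python loop over torch.mm, and torch.einsum. The seed, the variable names, and the tolerance atol=1e-5 are my own choices for illustration. The last part relies on the note in the PyTorch documentation that bmm does not broadcast, while torch.matmul does.

```python
import torch

torch.manual_seed(0)  # fixed seed so the random inputs are reproducible

# Shapes follow the rule above: [b, n, m] x [b, m, p] -> [b, n, p]
b, n, m, p = 10, 3, 4, 5
input = torch.randn(b, n, m)
mat2 = torch.randn(b, m, p)

res = torch.bmm(input, mat2)
print(res.size())  # torch.Size([10, 3, 5])

# Reference 1: multiply each pair of matrices separately with torch.mm,
# then stack the b results along a new batch dimension 0.
loop_res = torch.stack([torch.mm(input[i], mat2[i]) for i in range(b)])

# Reference 2: the same contraction written with einsum,
# summing over the shared index k for every batch index i.
einsum_res = torch.einsum("ink,ikp->inp", input, mat2)

# The results should agree up to floating-point rounding.
print(torch.allclose(res, loop_res, atol=1e-5))
print(torch.allclose(res, einsum_res, atol=1e-5))

# bmm does not broadcast: the batch sizes must match exactly.
try:
    torch.bmm(torch.randn(2, 3, 4), torch.randn(1, 4, 5))
except RuntimeError as err:
    print("bmm rejected mismatched batch sizes:", err)

# torch.matmul broadcasts a batch dimension of size 1 instead.
print(torch.matmul(torch.randn(2, 3, 4), torch.randn(1, 4, 5)).size())  # torch.Size([2, 3, 5])
```

The loop and einsum versions are only there to make the semantics concrete; bmm computes all b products in a single call.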


Origin blog.csdn.net/qq_52852138/article/details/129745956