Matrix multiplication and broadcasting mechanism in torch

1. The broadcasting mechanism (broadcast)

1. Rules for two tensors to be "broadcastable":

  • Every tensor has at least one dimension.

  • When iterating over the dimension sizes, starting from the trailing (last) dimension, one of the following must hold for each dimension: (1) the sizes are equal, (2) one of the sizes is 1, or (3) one of the dimensions does not exist.

Example:

x=torch.empty((0,))
y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least one dimension


x=torch.empty(5,7,3)
y=torch.empty(5,7,3)
# identical shapes are always broadcastable


x=torch.empty(5,3,4,1)
y=torch.empty(  3,1,1)
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y's dimension does not exist
# All rules are satisfied, so x and y are broadcastable.

x=torch.empty(5,2,4,1)
y=torch.empty(  3,1,1)
# x and y are not broadcastable, because 2 != 3 in the 3rd trailing dimension

2. Calculation rules for broadcastable tensors:

If two tensors can be broadcast, the resulting tensor size is calculated as follows:

  • If x and y do not have the same number of dimensions, prepend 1s to the shape of the tensor with fewer dimensions until both shapes have the same length.

  • For each dimension size, the resulting dimension size is the maximum of the x and y sizes along that dimension.

Example:

>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])


>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
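
As a side note, if your PyTorch version provides torch.broadcast_shapes() (available in recent releases), it computes the resulting size directly from the shapes, which is a convenient way to check the rules above without allocating any tensors:

>>> torch.broadcast_shapes((5, 1, 4, 1), (3, 1, 1))
torch.Size([5, 3, 4, 1])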

Note 1: In-place operations do not allow the in-place tensor to change shape as a result of broadcasting.

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size() # in_place
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

>>> (x+y).size()
torch.Size([3, 3, 7])

Note 2: If two tensors do not have the same shape but are broadcastable and have the same number of elements, the introduction of broadcasting may result in backwards-incompatible changes.

Previously such an operation would produce a tensor of size torch.Size([4, 1]), but now it produces a tensor of size torch.Size([4, 4]). To help identify places in your code where backward incompatibilities may have been introduced by broadcasting, you can set torch.utils.backcompat.broadcast_warning.enabled to True, which will generate a Python warning in such cases.

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.

2. The various multiplication operations in torch

1. torch.dot(vec1,vec2)

         Computes the dot product of two vectors. It does not support broadcasting and requires the two one-dimensional tensors to have the same number of elements.

vec1(n), vec2(n) → torch.dot() → out (a scalar)

import torch

vec1 = torch.Tensor([1,2,3,4])
vec2 = torch.Tensor([5,6,7,8])
print(torch.dot(vec1, vec2))
# tensor(70.)
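
A quick sketch of the constraint above (vec3 is introduced here only for illustration): torch.dot() raises an error when the two 1-D tensors have different numbers of elements.

vec3 = torch.Tensor([1, 2, 3])
try:
    print(torch.dot(vec1, vec3))
except RuntimeError as e:
    print(e)
# prints an error saying the two tensors must have the same number of elements
# (exact wording depends on the PyTorch version)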

2. torch.mm(mat1,mat2)

        Computes the matrix product of two two-dimensional matrices. It does not support broadcasting, and the shapes of the two tensors must satisfy the requirements of matrix multiplication.

mat1(n*m), mat2(m*d) → torch.mm() → out(n*d)

mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
out = torch.mm(mat1, mat2)
print(out.shape)
# torch.Size([3, 5])

3. torch.bmm(mat1,mat2)

        Computes batched matrix multiplication of two 3-D tensors. It does not support broadcasting. Both inputs must be 3-D tensors, and their first dimensions must be equal (the batch dimension).

mat1(b*n*m), mat2(b*m*d) → torch.bmm() → out(b*n*d)

mat1 = torch.randn(2, 3, 4)
mat2 = torch.randn(2, 4, 5)
out = torch.bmm(mat1, mat2)
print(out.shape)
# torch.Size([2, 3, 5])
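
As a sanity check on the batch semantics (a small sketch reusing mat1, mat2 and out from above), torch.bmm() gives the same result as applying torch.mm() to each pair of matrices along the first (batch) dimension:

# out[i] = mat1[i] @ mat2[i] for every batch index i
out_loop = torch.stack([torch.mm(mat1[i], mat2[i]) for i in range(mat1.size(0))])
print(torch.allclose(out, out_loop))
# True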

4. torch.mv(mat, vec)

        Computes the product of a matrix and a vector (the matrix comes first, the vector second). Broadcasting is not supported, and the matrix and vector must satisfy the requirements of matrix multiplication.

mat(n*m), vec(m) → torch.mv() → out(n)

mat = torch.randn(3, 4)
vec = torch.randn(4)
output = torch.mv(mat, vec)
print(output.shape)
# torch.Size([3])

5. torch.mul(a,b)

        torch.multiply() is an alias for torch.mul();

        Computes element-wise (Hadamard) multiplication of two tensors, and supports broadcasting: as long as the shapes of a and b satisfy the broadcast rules, element-wise multiplication can be performed.

mat1, mat2 → torch.mul() → out

A = torch.randn(2,1,4)
B = torch.randn(3, 1) # matrix
print(torch.mul(A,B).shape)
# torch.Size([2, 3, 4])
b0 = 2 # scalar
print(torch.mul(A,b0).shape)
# torch.Size([2, 1, 4])
b1 = torch.tensor([1,2,3,4]) # row vector
print(torch.mul(A,b1).shape)
# torch.Size([2, 1, 4])
b2 = torch.Tensor([1,2,3]).reshape(-1,1) # column vector
print(torch.mul(A,b2).shape)
# torch.Size([2, 3, 4])

6. torch.matmul(mat1,mat2)

        It handles almost all matrix/vector multiplication cases and supports broadcasting; it can be thought of as a broadcasting version of torch.mm. The multiplication rule depends on the dimensions of the two tensors involved.

mat1(j*1*n*m), mat2(k*m*p) → torch.matmul() → out(j*k*n*p)

        In particular, for multidimensional inputs, matmul() uses the last two dimensions of each argument for the matrix multiplication, and all other dimensions can be regarded as batch dimensions.

mat1 = torch.randn(2,1,4,5)
mat2 = torch.randn(2,1,5,2)
out = torch.matmul(mat1, mat2)
print(out.shape)
# torch.Size([2, 1, 4, 2])
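
The example above uses identical batch dimensions; the sketch below (shapes chosen purely for illustration) exercises the broadcast case in the shape rule above, where the batch dimensions j*1 and k broadcast to j*k:

mat1 = torch.randn(2, 1, 4, 5)  # j=2, 1, n=4, m=5
mat2 = torch.randn(3, 5, 2)     # k=3, m=5, p=2
out = torch.matmul(mat1, mat2)  # batch dims (2, 1) and (3,) broadcast to (2, 3)
print(out.shape)
# torch.Size([2, 3, 4, 2])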

If both tensors are one-dimensional, this function behaves like torch.dot(), returning the dot product of the two one-dimensional tensors;

vec1 = torch.Tensor([1,2,3,4])
vec2 = torch.Tensor([5,6,7,8])
print(torch.matmul(vec1, vec2))
# tensor(70.)
print(torch.dot(vec1, vec2))
# tensor(70.)

If both tensors are two-dimensional, this function behaves like torch.mm(), returning the matrix product of the two two-dimensional matrices;

mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
out = torch.mm(mat1, mat2)
print(out.shape)
# torch.Size([3, 5])

out1 = torch.matmul(mat1, mat2)
print(out1.shape)
# torch.Size([3, 5])

If the first argument is a 2D tensor (matrix) and the second argument is a 1D tensor (vector), the matrix×vector product is returned. In this case the function behaves like torch.mv(), requiring the matrix and vector to satisfy the requirements of matrix multiplication;

mat = torch.randn(3, 4)
vec = torch.randn(4)
output = torch.mv(mat, vec)
print(output.shape)
# torch.Size([3])

output1 = torch.matmul(mat, vec)
print(output1.shape)
# torch.Size([3])

If the first argument is a one-dimensional tensor and the second argument is a two-dimensional tensor, a dimension of size 1 is prepended to the one-dimensional tensor (a broadcast), the matrix multiplication is performed, and the prepended dimension is removed from the result;

vec = torch.randn(4)
mat = torch.randn(4,2)
print(torch.matmul(vec, mat).shape)
# torch.Size([2])

7. Operators

The @ operator: it works like torch.matmul.

mat1 = torch.randn(2,1,4,5)
mat2 = torch.randn(2,1,5,2)
out = torch.matmul(mat1, mat2)
print(out.shape)
# torch.Size([2, 1, 4, 2])
out1 = mat1 @ mat2
print(out1.shape)
# torch.Size([2, 1, 4, 2])

The * operator: it works like torch.mul.

A = torch.randn(2,1,4)
B = torch.randn(3, 1) # matrix
print(torch.mul(A,B).shape)
# torch.Size([2, 3, 4])

print((A * B).shape)
# torch.Size([2, 3, 4])

8. Extension: torch.einsum(): Einstein summation convention

Here is a link: einsum is all you need!

        If, like me, you find it difficult to remember the names and signatures of functions that compute dot products, outer products, transposes, matrix-vector multiplications, and matrix-matrix multiplications in PyTorch/TensorFlow, then einsum notation is our lifesaver. The einsum notation is an elegant way to express the above operations, including complex tensor operations. Basically, einsum can be regarded as a domain-specific language. Once you understand and take advantage of einsum, you can write more compact and efficient code more quickly, in addition to the benefits of not having to memorize and frequently look up specific library functions. When einsum is not used, it is easy to introduce unnecessary tensor transformation or transposition operations, as well as intermediate tensors that can be omitted.
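
As a small taste of the notation (a minimal sketch; torch.einsum() is a standard PyTorch function, and the subscript strings below are ordinary einsum usage rather than anything taken from the linked article), the operations covered earlier can all be written as einsum expressions:

vec1 = torch.randn(4)
vec2 = torch.randn(4)
mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)
bat1 = torch.randn(2, 3, 4)
bat2 = torch.randn(2, 4, 5)

# dot product, equivalent to torch.dot(vec1, vec2)
print(torch.einsum('i,i->', vec1, vec2).shape)        # torch.Size([])
# matrix multiplication, equivalent to torch.mm(mat1, mat2)
print(torch.einsum('ij,jk->ik', mat1, mat2).shape)    # torch.Size([3, 5])
# batch matrix multiplication, equivalent to torch.bmm(bat1, bat2)
print(torch.einsum('bij,bjk->bik', bat1, bat2).shape) # torch.Size([2, 3, 5])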

Just a learning record!

Origin blog.csdn.net/panghuzhenbang/article/details/129732720