dot或者matmul表示矩阵相乘;
multiply或者*表示矩阵中的对应元素相乘;
示例:
import numpy as np # a = torch.randn(2, 3, 4) # b = torch.randn(4, 4) # c = torch.matmul(a, b) # print(c) # print(c.size()) a = np.arange(0, 9).reshape(3, 3) b = np.arange(0, 9).reshape(3, 3) np.ma c = np.multiply(a, b) d = a * b e = np.dot(a, b) print(a) print(b) print(c) print(d) print(e)
输出:
[[0 1 2]
[3 4 5]
[6 7 8]]
[[0 1 2]
[3 4 5]
[6 7 8]]
[[ 0 1 4]
[ 9 16 25]
[36 49 64]]
[[ 0 1 4]
[ 9 16 25]
[36 49 64]]
[[ 15 18 21]
[ 42 54 66]
[ 69 90 111]]
Process finished with exit code 0