pytorchのtorch.mul()とtorch.mm()の違い
torch.mul(a、b)は、行列aとbの対応するビットの乗算です。aとbの次元は等しくなければなりません。たとえば、aの次元は(1、2)で、bの次元は(1、2)は(1、2)の行列です
torch.mm(a、b)は、行列aとbの行列乗算です。たとえば、aの次元は(1、2)、bの次元は(2、3)、返される行列はです。 (1、3)
import torch
a = torch.rand(1, 2)
b = torch.rand(1, 2)
c = torch.rand(2, 3)
print(torch.mul(a, b)) # 返回 1*2 的tensor
print(torch.mm(a, c)) # 返回 1*3 的tensor
print(torch.mul(a, c)) # 由于a、b维度不同,报错