- 调用
torch.einsum('i, ijk, j->jk', c_ws, u_minus_c, trainCellVol)
时, 提示了如下错误
File "/Users/yczhang/opt/anaconda3/envs/fealpy/lib/python3.8/site-packages/torch/functional.py", line 241, in einsum
return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm_out
根据错误的提示, 可能是 c_ws
, u_minus_c
, trainCellVol
的数据类型不一致造成的.
并注意, pytorch 中的 Tensor
数据 float64
是 double
,而 float32
是标准的 float
.
然后查看下各数据的类型:
c_ws.type()
'torch.DoubleTensor'
u_minus_c.type()
'torch.FloatTensor'
trainCellVol.type()
'torch.DoubleTensor'
那么, 将 c_ws.float()
, 或将 u_minus_c.double()
是都可以的, 即下面的两种方式:
a0 = torch.einsum('i, ijk, j->jk', c_ws, u_minus_c.double(), trainCellVol)
a0 = torch.einsum('i, ijk, j->jk', c_ws.float(), u_minus_c, trainCellVol)