【Pytorch】关于torch.matmul和torch.bmm的输出tensor数值不一致问题

发现

对于torch.matmul和torch.bmm,都能实现对于batch的矩阵乘法:

a = torch.rand((2,3,10))
b = torch.rand((2,2,10))
### matmal()
res1 = torch.matmul(a,b.transpose(1,2))
print res1
"""
...
[torch.FloatTensor of size 2x3x2]
"""
### bmm()
res2 = torch.bmm(a,b.transpose(1,2))
print res2
"""
...
[torch.FloatTensor of size 2x3x2]
"""
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
从打印出来的几位数字来看,嗯,是一样的,但是在用等式(或者torch.eq())检查是却发现了问题,竟然有很多不一样的元素

print torch.eq(res1,res2)
"""
(0 ,.,.) =
1 0
1 1
1 1

(1 ,.,.) =
0 1
1 1
1 1
[torch.ByteTensor of size 2x3x2]
"""
1
2
3
4
5
6
7
8
9
10
11
12
13
将一样的数值在ipython直接输出(print会截断位数)

>>>res1[0,0,0]
2.229752540588379
>>>res2[0,0,0]
2.229752540588379
1
2
3
4
再来看看不一样的

>>>res1[0,0,1]
3.035151720046997
>>>res2[0,0,1]
3.035151481628418
1
2
3
4
可以看到从小数点后位7位开始两个输出值出现了差异!

结论

所以说在tensor的同样操作下,出现不一致结果(精度上)的可能性很大,在做相等条件判断时需要注意,即使同样的输入同样的操作可能出现不一样的结果。
之后又尝试对于a,b的位置进行交换,竟然发现即使是同一个函数操作,如matmal(),matmul(a,b.transpose(1,2))和matmul(b,a.transpose(1,2)).transpose(1,2)结果也存在不一样的元素。
---------------------
作者:Laox1ao
来源:CSDN
原文:https://blog.csdn.net/laox1ao/article/details/79159303
版权声明:本文为博主原创文章,转载请附上博文链接!

猜你喜欢

转载自www.cnblogs.com/jfdwd/p/11068704.html