采用python实现torch.matmul


a=torch.ones(2,5,3)
b= torch.ones(1,3,4)

for i in range(2):
    for j in range(5):
        a[i,j,:]=i+j

c=torch.matmul(a,b)
print(a)
print(b)
print(c)
print(c.shape)
# pdb.set_trace()

c1=torch.zeros(2,5,4)
for i in range(2):
    for j in range(5):
        for k in range(4):
            tmp=0
            for k1 in range(3):
                tmp=tmp+a[i,j,k1]*b[0,k1,k]
            c1[i,j,k]=tmp

print(c1.shape)
print(c1)

结果:

tensor([[[0., 0., 0.],
         [1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.]],

        [[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.],
         [5., 5., 5.]]])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[ 0.,  0.,  0.,  0.],
         [ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.]],

        [[ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.],
         [15., 15., 15., 15.]]])
torch.Size([2, 5, 4])
torch.Size([2, 5, 4])
tensor([[[ 0.,  0.,  0.,  0.],
         [ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.]],

        [[ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.],

可见,可以采用类似的方法来实现。


        batch_T = torch.matmul(self.inv_delta_C, batch_C_prime_with_zeros)  # batch_size x F+3 x 2
        t3=time.time()
        # pdb.set_trace()
        batch_P_prime = torch.matmul(self.P_hat, batch_T)  # batch_size x n x 2
        t4=time.time()
        b2=torch.zeros(1,102400, 2)
        for i  in range(102400):
            if i%10==0:
                print(i)
            for j in range(2):
                tmp=0
                for k in range(964):
                    tmp=tmp+self.P_hat[i,k]*batch_T[0,k,j]
                b2[0,i,j]=tmp
        
        for i  in range(102400):
            for j in range(2):
                if b2[0,i,j]!=batch_P_prime[0,i,j]:
                    print(i,j,b2[0,i,j]-batch_P_prime[0,i,j])
                    pdb.set_trace()

不用库,确实特别慢啊,特别慢。

猜你喜欢

转载自blog.csdn.net/anlongstar/article/details/130743888
今日推荐