利用unsqueeze做矩阵乘法

import torch
if __name__ == '__main__':
    a = torch.FloatTensor([1,2,3])
    b = torch.FloatTensor([4,5,6])
    x = torch.unsqueeze(a,dim=1)
    y = torch.unsqueeze(b,dim=0)
    print(x)
    print(y)
    print(x*y)

output:

tensor([[1.],
        [2.],
        [3.]])
tensor([[4., 5., 6.]])
tensor([[ 4.,  5.,  6.],
        [ 8., 10., 12.],
        [12., 15., 18.]])

发布了41 篇原创文章 · 获赞 44 · 访问量 7648

猜你喜欢

转载自blog.csdn.net/tailonh/article/details/105319407