Pytorch增加矩阵维度

原文地址

分类目录——Pytorch

Pytorch通过.unsqueeze(int)方法来增加1个维度,传的int值为增加的维度的索引。下面通过程序来说明其用法。

  • 生成测试数据

    import torch
    t1 = torch.tensor([1,2,3])
    
  • 进行维度增加

    print(t1.unsqueeze(0))
    # tensor([[1, 2, 3]])
    print(t1.unsqueeze(1))
    # tensor([[1],
    #         [2],
    #         [3]])
    print(t1.unsqueeze(2))	# 报错
    # IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
    
  • 从上面可能不容易观察传入索引的作用,观察一下增加后的size()

    print(t1.size())
    print(t1.unsqueeze(0).size())
    print(t1.unsqueeze(1).size())
    # torch.Size([3])
    # torch.Size([1, 3])
    # torch.Size([3, 1])
    

    当前维度是1,增加1维后维度变成2,索引值也就只能由两个(0和1),所以上面以2为索引时报错了

  • 用二维数据进行测试

    t2 = torch.tensor([[1,2,3],[4,5,6]])
    print(t2.size())
    # torch.Size([2, 3])
    print(t2.unsqueeze(0).size())
    # torch.Size([1, 2, 3])
    print(t2.unsqueeze(1).size())
    # torch.Size([2, 1, 3])
    print(t2.unsqueeze(2).size())
    # torch.Size([2, 3, 1])
    print(t2.unsqueeze(3).size())
    # IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
    

    2+1=3,索引可以是0,1,2,当索引时3的时候报错

  • 我想看看增加到三维是个什么样的

    print(t2.unsqueeze(0))
    # tensor([[[1, 2, 3],
    #          [4, 5, 6]]])
    print(t2.unsqueeze(1))
    # tensor([[[1, 2, 3]],
    #
    #         [[4, 5, 6]]])
    
发布了102 篇原创文章 · 获赞 68 · 访问量 5120

猜你喜欢

转载自blog.csdn.net/BBJG_001/article/details/104231386