pytorch-Tensor维度变换

▪ View/reshape
▪ Squeeze/unsqueeze
▪ Transpose/t/permute
▪ Expand/repeat

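view/reshape 这两个API功能基本一样：都在不改变元素总数的前提下改变张量形状。区别在于 view 要求张量的内存布局与新形状兼容（例如经过 transpose/permute 后的张量通常需要先调用 .contiguous()），否则会报错；reshape 在可能时返回视图，否则会复制数据。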

import torch
a = torch.rand(4,1,28,28)
a
tensor([[[[0.6938, 0.3445, 0.6771,  ..., 0.9931, 0.9991, 0.4440],
          [0.5946, 0.8147, 0.2261,  ..., 0.7437, 0.0628, 0.8737],
          [0.6451, 0.0994, 0.0188,  ..., 0.2206, 0.5386, 0.8999],
          ...,
          [0.7386, 0.1898, 0.0403,  ..., 0.4875, 0.0487, 0.5001],
          [0.7543, 0.5100, 0.1647,  ..., 0.2497, 0.9027, 0.4331],
          [0.7490, 0.0041, 0.3491,  ..., 0.2113, 0.5284, 0.9640]]],


        [[[0.4882, 0.0085, 0.0299,  ..., 0.2456, 0.2243, 0.3239],
          [0.1939, 0.6237, 0.0720,  ..., 0.1124, 0.9007, 0.8293],
          [0.6409, 0.1809, 0.7432,  ..., 0.1125, 0.1982, 0.0389],
          ...,
          [0.2706, 0.3279, 0.1497,  ..., 0.7635, 1.0000, 0.6049],
          [0.8983, 0.2384, 0.5930,  ..., 0.6344, 0.7969, 0.9916],
          [0.5655, 0.2785, 0.7651,  ..., 0.3556, 0.3760, 0.5805]]],


        [[[0.1863, 0.4082, 0.7172,  ..., 0.4033, 0.9391, 0.0029],
          [0.2518, 0.9401, 0.6461,  ..., 0.1238, 0.0022, 0.4185],
          [0.0726, 0.6612, 0.8562,  ..., 0.4416, 0.2241, 0.8930],
          ...,
          [0.5502, 0.8952, 0.1731,  ..., 0.7085, 0.7992, 0.5871],
          [0.8164, 0.2080, 0.0844,  ..., 0.8432, 0.5732, 0.6478],
          [0.6396, 0.0602, 0.0611,  ..., 0.2161, 0.7656, 0.0967]]],


        [[[0.4451, 0.9379, 0.3999,  ..., 0.9933, 0.0786, 0.4820],
          [0.8398, 0.0777, 0.1195,  ..., 0.0305, 0.5948, 0.0538],
          [0.7693, 0.7614, 0.5999,  ..., 0.9883, 0.1184, 0.9129],
          ...,
          [0.1805, 0.6104, 0.9388,  ..., 0.3651, 0.2686, 0.9733],
          [0.5440, 0.7886, 0.5747,  ..., 0.5571, 0.9515, 0.1942],
          [0.1981, 0.7278, 0.4696,  ..., 0.5120, 0.0964, 0.0606]]]])
a.shape
torch.Size([4, 1, 28, 28])
a.view(4, 28*28) # 展平后丢失了原始维度信息，需要自己记住原形状 [4,1,28,28] 才能还原
tensor([[0.6938, 0.3445, 0.6771,  ..., 0.2113, 0.5284, 0.9640],
        [0.4882, 0.0085, 0.0299,  ..., 0.3556, 0.3760, 0.5805],
        [0.1863, 0.4082, 0.7172,  ..., 0.2161, 0.7656, 0.0967],
        [0.4451, 0.9379, 0.3999,  ..., 0.5120, 0.0964, 0.0606]])
a.reshape(4, 28*28).shape
torch.Size([4, 784])
a.view(4, 28*28).shape
torch.Size([4, 784])
a.view(4*28, 28).shape
torch.Size([112, 28])
a.view(4*1,28,28).shape
torch.Size([4, 28, 28])
b=a.view(4, 784)
b
tensor([[0.6938, 0.3445, 0.6771,  ..., 0.2113, 0.5284, 0.9640],
        [0.4882, 0.0085, 0.0299,  ..., 0.3556, 0.3760, 0.5805],
        [0.1863, 0.4082, 0.7172,  ..., 0.2161, 0.7656, 0.0967],
        [0.4451, 0.9379, 0.3999,  ..., 0.5120, 0.0964, 0.0606]])

Squeeze vs. unsqueeze（squeeze 删除大小为1的维度，unsqueeze 插入一个大小为1的新维度）

# -4,-3,-2,-1
#  0, 1, 2, 3
#  4, 1,28,28
a.shape
torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape # 正索引：在原第0维之前（左侧）插入
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape # 负索引：在原第-1维之后（右侧）插入
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(4).shape # 原张量没有第4维，新维度直接插在最后
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-4).shape # 负索引：在原第-4维之后（右侧）插入
torch.Size([4, 1, 1, 28, 28])
a.unsqueeze(-5).shape # 原张量没有第-5维，新维度直接插在最前
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(5).shape # 越界：合法范围是 [-a.dim()-1, a.dim()]，即 [-5, 4]
---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-23-b54eab361a50> in <module>
----> 1 a.unsqueeze(5).shape


IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
import torch
a = torch.tensor([1.2, 2.3])
print(a)
print(a.shape)
print(a.dim())
tensor([1.2000, 2.3000])
torch.Size([2])
1
b = a.unsqueeze(-1)
b
tensor([[1.2000],
        [2.3000]])
b.dim()
2
c = a.unsqueeze(0)
c
tensor([[1.2000, 2.3000]])
c.dim()
2

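unsqueeze的例子：把形状为 [32] 的向量 b（例如每个通道一个值的偏置）变成 [1, 32, 1, 1]，使它的第1维与形状为 [4, 32, 14, 14] 的 f 的通道维对齐，之后即可通过广播与 f 进行运算。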

b=torch.rand(32)
print(b)
print("-"*70)
print(b.shape)
tensor([0.0243, 0.3869, 0.2683, 0.5434, 0.2612, 0.6457, 0.7067, 0.8813, 0.0767,
        0.2134, 0.8361, 0.9873, 0.0582, 0.5941, 0.3209, 0.3008, 0.4303, 0.8667,
        0.6951, 0.8412, 0.4237, 0.0509, 0.4851, 0.0748, 0.3602, 0.2068, 0.7432,
        0.8633, 0.3234, 0.5924, 0.6898, 0.5217])
----------------------------------------------------------------------
torch.Size([32])
f = torch.rand(4, 32, 14, 14)
f.shape
torch.Size([4, 32, 14, 14])
c=b.unsqueeze(1)
c.shape
torch.Size([32, 1])
c=c.unsqueeze(2)
c.shape
torch.Size([32, 1, 1])
c=c.unsqueeze(0)
c.shape
torch.Size([1, 32, 1, 1])
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
b.shape
torch.Size([1, 32, 1, 1])

squeeze的例子：squeeze() 不带参数时删除所有大小为1的维度；squeeze(dim) 只在该维度大小为1时删除它，否则形状保持不变。

b.shape
torch.Size([1, 32, 1, 1])
b.squeeze().shape # 删除所有大小为1的维度
torch.Size([32])
b.squeeze(0).shape
torch.Size([32, 1, 1])
b.squeeze(-1).shape
torch.Size([1, 32, 1])
b.squeeze(1).shape # 第1维大小为32而不是1，无法删除，形状不变
torch.Size([1, 32, 1, 1])
b.squeeze(-4).shape
torch.Size([32, 1, 1])

Expand / repeat

▪ Expand：广播（broadcasting），只返回新的视图、不复制数据；只有大小为1的维度才能被扩展
▪ Repeat：真正复制数据（memory copied），会分配新的内存

a = torch.rand(4,32,14,14)
print("a.shape:",a.shape)
print("b.shape:",b.shape)
a.shape: torch.Size([4, 32, 14, 14])
b.shape: torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape # -1 表示该维度保持原大小不变
torch.Size([1, 32, 1, 1])
b.expand(-1,32,-1,-4).shape # -4 没有意义：旧版本 PyTorch 会返回下面这种错误的形状（bug），新版本可能直接报错，不要这样使用
torch.Size([1, 32, 1, -4])
b.shape
torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape # 参数是各维度的重复次数：1*4=4, 32*32=1024, 1*1=1, 1*1=1
torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape
torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape
torch.Size([4, 32, 32, 32])

矩阵转置与维度交换（t / transpose / permute）

b.shape
torch.Size([1, 32, 1, 1])
b.t() # .t()只适用于维度不超过2的张量（二维矩阵转置，一维原样返回）
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-90-a21ff70d06f2> in <module>
----> 1 b.t() # .t()只适合二维矩阵


RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
a.shape
# 0 1  2  3
# 4 32 14 14
torch.Size([4, 32, 14, 14])
a1 = a.transpose(1,3) # 交换第1维和第3维
a1.shape
torch.Size([4, 14, 14, 32])
a=torch.rand(4,3,28,28)
a.transpose(1,3).shape
torch.Size([4, 28, 28, 3])
b=torch.rand(4,3,28,32)
b.transpose(1,3).shape
torch.Size([4, 32, 28, 3])
b.transpose(1,3).transpose(1,2).shape # 连续两次交换：[4,3,28,32] → [4,32,28,3] → [4,28,32,3]
torch.Size([4, 28, 32, 3])

用 permute 一步到位完成多个维度的重排：参数依次给出新张量每个位置对应的原维度索引。例如 (0,2,3,1) 把 [4,3,28,32] 变为 [4,28,32,3]，与上面两次 transpose 的结果相同。

b.permute(0,2,3,1).shape
torch.Size([4, 28, 32, 3])

猜你喜欢

转载自blog.csdn.net/MasterCayman/article/details/109395065