pytorch 交换tensor的维度

使用permute函数,可以重新排列tensor的维度


示例

二维张量

import torch

# 生成一个shape为[2, 3]的tensor, 设shape为[dim0, dim1]
x = torch.range(1, 6).reshape([2, 3])

print(x)
'''
tensor([[1., 2., 3.],
        [4., 5., 6.]])
'''

# 将shape由[dim0, dim1]重新排列为[dim1, dim0],即将shape从[2, 3]变为[3, 2]
x = x.permute(1, 0)

print(x)
'''
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])
'''

print(x.shape) # [3, 2]

三维张量

import torch

# 生成一个shape为[2, 3, 4]的tensor, 设shape为[dim0, dim1, dim2]
x = torch.range(1, 24).reshape([2, 3, 4])

print(x)
'''
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],

        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]])
'''

# 将shape由[dim0, dim1, dim2]重新排列为[dim2, dim0, dim1],即将shape从[2, 3, 4]变为[4, 2, 3]
x = x.permute(2, 0, 1)

print(x)
'''
tensor([[[ 1.,  5.,  9.],
         [13., 17., 21.]],

        [[ 2.,  6., 10.],
         [14., 18., 22.]],

        [[ 3.,  7., 11.],
         [15., 19., 23.]],

        [[ 4.,  8., 12.],
         [16., 20., 24.]]])
'''

print(x.shape) # [4, 2, 3]

猜你喜欢

转载自blog.csdn.net/weixin_46566663/article/details/127642683