使用permute函数,可以重新排列tensor的维度
示例
二维张量
import torch
# 生成一个shape为[2, 3]的tensor, 设shape为[dim0, dim1]
x = torch.range(1, 6).reshape([2, 3])
print(x)
'''
tensor([[1., 2., 3.],
[4., 5., 6.]])
'''
# 将shape由[dim0, dim1]重新排列为[dim1, dim0],即将shape从[2, 3]变为[3, 2]
x = x.permute(1, 0)
print(x)
'''
tensor([[1., 4.],
[2., 5.],
[3., 6.]])
'''
print(x.shape) # [3, 2]
三维张量
import torch
# 生成一个shape为[2, 3, 4]的tensor, 设shape为[dim0, dim1, dim2]
x = torch.range(1, 24).reshape([2, 3, 4])
print(x)
'''
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
'''
# 将shape由[dim0, dim1, dim2]重新排列为[dim2, dim0, dim1],即将shape从[2, 3, 4]变为[4, 2, 3]
x = x.permute(2, 0, 1)
print(x)
'''
tensor([[[ 1., 5., 9.],
[13., 17., 21.]],
[[ 2., 6., 10.],
[14., 18., 22.]],
[[ 3., 7., 11.],
[15., 19., 23.]],
[[ 4., 8., 12.],
[16., 20., 24.]]])
'''
print(x.shape) # [4, 2, 3]