pytorch张量维度变换详解:view、squeeze、transpose

目录

1 view函数

1.1 指定变换后的维度

1.2 自动推理变换后的维度

1.3 将tensor展平成一维

2 reshape函数

2.1 指定变换后的维度

2.2 自动推理转换后的维度

2.3 将tensor展平成一维

2.4 使用tensor.reshape变换

3 squeeze函数

3.1 torch.squeeze去除所有为1的维度

3.2 torch.squeeze指定dim去除

3.3 tensor.squeeze去除为1的维度

4 unsqueeze函数

4.1 torch.unsqueeze指定dim插入新维度

4.2 tensor.unsqueeze指定dim插入新维度

5 transpose函数

5.1 torch.transpose转置指定维度

5.2 tensor.transpose转置指定维度

6 expand函数

7 repeat函数

8 permute函数


Pytorch张量维度变化是在构建模型过程中常用且重要的操作,本文从实际应用触发,详细介绍常用的维度变化方法,这些方法包含view、reshap、squeeze、unsqueeze、transpose等。

1 view函数

Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。

view函数的操作对象是Tensor类型,返回的对象类型也为Tensor类型

    def view(self, *size: _int) -> Tensor: ...

更便于理解的表示形式:

view(参数a,参数b,…),其中,总的参数个数表示将张量重构后的维度。

1.1 指定变换后的维度

通过手工指定,将一个一维tensor变换为3*8维的tensor

import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 
                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])

a2 = a1.view(3, 8)
print(a1)
print(a2)
print(a1.shape)
print(a2.shape)

运行程序显示如下:

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16],
        [17, 18, 19, 20, 21, 22, 23, 24]])
torch.Size([24])
torch.Size([3, 8])

1.2 自动推理变换后的维度

如果某个参数为-1,则表示该维度取决于其它维度,由Pytorch自己补充

import torch

a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])

a4 = a3.view(4, -1)
a5 = a3.view(2, 3, -1)
a6 = a3.view(-1, 3, 2)

print(a3)
print(a4)
print(a5)
print(a6)
print(a3.shape)
print(a4.shape)
print(a5.shape)
print(a6.shape)

 运行程序显示如下:

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24])
tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18],
        [19, 20, 21, 22, 23, 24]])
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],
        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],
        [[ 7,  8],
         [ 9, 10],
         [11, 12]],
        [[13, 14],
         [15, 16],
         [17, 18]],
        [[19, 20],
         [21, 22],
         [23, 24]]])
torch.Size([24])
torch.Size([4, 6])
torch.Size([2, 3, 4])
torch.Size([4, 3, 2])

1.3 将tensor展平成一维

import torch

a7 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                   [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
a8 = a6.view(-1)
print(a7)
print(a8)
print(a7.shape)
print(a8.shape)

 运行程序显示如下:

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24])
torch.Size([2, 12])
torch.Size([24])

2 reshape函数

返回与 input张量数据大小一样、给定 shape的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。

torch.reshape(input, shape) → [Tensor]

也可以直接在Tensor上使用reshape,形式如下:

tensor.reshape(shape) → [Tensor]

2.1 指定变换后的维度

import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = torch.reshape(a1, (3, 4))
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)

运行程序显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

2.2 自动推理转换后的维度

import torch

a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a4 = torch.reshape(a1, (-1, 6))
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)

运行程序显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([2, 6])
tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])

2.3 将tensor展平成一维

import torch

a5 = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
a6 = torch.reshape(a1, (-1,))
print(a5.shape)
print(a5)
print(a6.shape)
print(a6)

运行程序显示如下:

torch.Size([2, 6])
tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])
torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

2.4 使用tensor.reshape变换

improt torch

a7 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a8 = a7.reshape(6, 2)
a9 = a7.reshape(-1, 3)
a10 = a9.reshape(-1)
print(a7.shape)
print(a7)
print(a8.shape)
print(a8)
print(a9.shape)
print(a9)
print(a10.shape)
print(a10)

运行结果显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([6, 2])
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
torch.Size([4, 3])
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

3 squeeze函数

将input张量中所有维度数据为1的维度给移除掉。指定了dim,如果dim对应维度的值不为1 ,则保持不变,为1则移除该维度。

 torch.squeeze(input, dim=None) → [Tensor]

 也可以在tensor上直接使用squeeze,形式如下:

 tensor.squeeze(dim=None) → [Tensor]

3.1 torch.squeeze去除所有为1的维度

import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 1, 4)
a3 = torch.squeeze(a2)

print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)

 运行结果显示如下:(a2的第二个维度被移除)

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

3.2 torch.squeeze指定dim去除为1的维度

import torch

a4 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a5 = a1.reshape(3, 1, 4)
a6 = torch.squeeze(a5, 0)
a7 = torch.squeeze(a5, 1)

print(a4.shape)
print(a4)
print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
print(a7.shape)
print(a7)

运行结果显示如下:(a5的第一个维度不为1,所以保持不变;a5的第二个维度为1,所以被移除)

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

3.3 tensor.squeeze指定dim去除为1的维度

import torch

a8 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a9 = a8.reshape(3, 1, 4)
a10 = a9.squeeze()
a11 = a9.squeeze(0)
a12 = a9.squeeze(1)

print(a8.shape)
print(a8)
print(a9.shape)
print(a9)
print(a10.shape)
print(a10)
print(a11.shape)
print(a11)
print(a12.shape)
print(a12)

运行结果显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

4 unsqueeze函数

在给定的 dim 维度位置插入一个新的维度,维度数值为 1,dim 的范围在 [-dim()-1, dim()+1),包首不包尾

torch.unsqueeze(input, dim) → [Tensor]

 也可以在tensor上直接使用unsqueeze,形式如下:

torch.unsqueeze(dim) → [Tensor]

4.1 torch.unsqueeze指定dim插入新维度

import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 4)
a3 = torch.unsqueeze(a2, 0)
a4 = torch.unsqueeze(a2, 2)

print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)

运行结果显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
torch.Size([1, 3, 4])
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]]])
torch.Size([3, 4, 1])
tensor([[[ 1],
         [ 2],
         [ 3],
         [ 4]],

        [[ 5],
         [ 6],
         [ 7],
         [ 8]],

        [[ 9],
         [10],
         [11],
         [12]]])

4.2 tensor.unsqueeze指定dim插入新维度

import torch

a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a6 = a5.reshape(3, 4)
a7 = a6.unsqueeze(0)
a8 = a6.unsqueeze(1)

print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
print(a7.shape)
print(a7)
print(a8.shape)
print(a8)

运行结果显示如下:

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
torch.Size([1, 3, 4])
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]]])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],

        [[ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12]]])

5 transpose函数

返回 input 张量的转置,dim0与dim1交换位置

torch.transpose(input, dim0, dim1) → [Tensor]

  也可以在tensor上直接使用unsqueeze,形式如下:

tensor.transpose(dim0, dim1) → [Tensor]

参数:

  • input ([Tensor] 输入的张量
  • dim0 ([int] 第一个要转置的维度
  • dim1 ([int] 第二个要转置的维度

5.1 torch.transpose转置指定维度

import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(4, 3, 1)
a3 = torch.transpose(a2, 0, 1)
a4 = torch.transpose(a2, 1, 2)

print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)

运行结果显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([4, 3, 1])
tensor([[[ 1],
         [ 2],
         [ 3]],
        [[ 4],
         [ 5],
         [ 6]],
        [[ 7],
         [ 8],
         [ 9]],
        [[10],
         [11],
         [12]]])
torch.Size([3, 4, 1])
tensor([[[ 1],
         [ 4],
         [ 7],
         [10]],
        [[ 2],
         [ 5],
         [ 8],
         [11]],
        [[ 3],
         [ 6],
         [ 9],
         [12]]])
torch.Size([4, 1, 3])
tensor([[[ 1,  2,  3]],
        [[ 4,  5,  6]],
        [[ 7,  8,  9]],
        [[10, 11, 12]]])

5.2 tensor.transpose转置指定维度

import torch

a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a6 = a1.reshape(4, 3, 1)
a7 = a6.transpose(0, 1)
a8 = a6.transpose(1, 2)

print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
print(a7.shape)
print(a7)
print(a8.shape)
print(a8)

运行结果显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([4, 3, 1])
tensor([[[ 1],
         [ 2],
         [ 3]],
        [[ 4],
         [ 5],
         [ 6]],
        [[ 7],
         [ 8],
         [ 9]],
        [[10],
         [11],
         [12]]])
torch.Size([3, 4, 1])
tensor([[[ 1],
         [ 4],
         [ 7],
         [10]],
        [[ 2],
         [ 5],
         [ 8],
         [11]],
        [[ 3],
         [ 6],
         [ 9],
         [12]]])
torch.Size([4, 1, 3])
tensor([[[ 1,  2,  3]],
        [[ 4,  5,  6]],
        [[ 7,  8,  9]],
        [[10, 11, 12]]])

6 expand函数

返回张量的新视图,其某个维度 size 扩展到更大的 size,如果当前维度 size 为 -1 ,表示当前维度 size 保持不变。

Tensor也可以扩展到更多的维度,新的会追加在最前面。对于新维度,大小不能设置为 -1;

扩展张量不会分配新内存,而只会在现有张量上创建一个新视图。任何大小为1的维度都可以扩展为任意值,而无需分配新内存。

Tensor.expand( *sizes) → [Tensor]

参数:

  • sizes (torch.Size or [int] – 指定维度复制的次数
import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 1, 4, 1)

# 维度为 1 的 size 可以扩展成什么任意的 size
a3 = a2.expand(3, 5, 4, 2)

# -1 表示对应的维度size不变,但如果第一个维度3扩展成6则会报错,维度不为1不能扩展
a4 = a2.expand(-1, 5, -1, -1)

# 可以扩展新的维度,但只会放到最前面,不能放到后面(会报错)且不能设置为-1
a5 = a2.expand(2, -1, 5, -1, -1)

print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
print(a5.shape)
print(a5)

运行结果显示如下 :(维度不为1则不能扩展)

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 1, 4, 1])
tensor([[[[ 1],
          [ 2],
          [ 3],
          [ 4]]],
        [[[ 5],
          [ 6],
          [ 7],
          [ 8]]],
        [[[ 9],
          [10],
          [11],
          [12]]]])
torch.Size([3, 5, 4, 2])
tensor([[[[ 1,  1],
          [ 2,  2],
          [ 3,  3],
          [ 4,  4]],
         [[ 1,  1],
          [ 2,  2],
          [ 3,  3],
          [ 4,  4]],
         [[ 1,  1],
          [ 2,  2],
          [ 3,  3],
          [ 4,  4]],
         [[ 1,  1],
          [ 2,  2],
          [ 3,  3],
          [ 4,  4]],
         [[ 1,  1],
          [ 2,  2],
          [ 3,  3],
          [ 4,  4]]],
        [[[ 5,  5],
          [ 6,  6],
          [ 7,  7],
          [ 8,  8]],
         [[ 5,  5],
          [ 6,  6],
          [ 7,  7],
          [ 8,  8]],
         [[ 5,  5],
          [ 6,  6],
          [ 7,  7],
          [ 8,  8]],
         [[ 5,  5],
          [ 6,  6],
          [ 7,  7],
          [ 8,  8]],
         [[ 5,  5],
          [ 6,  6],
          [ 7,  7],
          [ 8,  8]]],
        [[[ 9,  9],
          [10, 10],
          [11, 11],
          [12, 12]],
         [[ 9,  9],
          [10, 10],
          [11, 11],
          [12, 12]],
         [[ 9,  9],
          [10, 10],
          [11, 11],
          [12, 12]],
         [[ 9,  9],
          [10, 10],
          [11, 11],
          [12, 12]],
         [[ 9,  9],
          [10, 10],
          [11, 11],
          [12, 12]]]])
torch.Size([3, 5, 4, 1])
tensor([[[[ 1],
          [ 2],
          [ 3],
          [ 4]],
         [[ 1],
          [ 2],
          [ 3],
          [ 4]],
         [[ 1],
          [ 2],
          [ 3],
          [ 4]],
         [[ 1],
          [ 2],
          [ 3],
          [ 4]],
         [[ 1],
          [ 2],
          [ 3],
          [ 4]]],
        [[[ 5],
          [ 6],
          [ 7],
          [ 8]],
         [[ 5],
          [ 6],
          [ 7],
          [ 8]],
         [[ 5],
          [ 6],
          [ 7],
          [ 8]],
         [[ 5],
          [ 6],
          [ 7],
          [ 8]],
         [[ 5],
          [ 6],
          [ 7],
          [ 8]]],
        [[[ 9],
          [10],
          [11],
          [12]],
         [[ 9],
          [10],
          [11],
          [12]],
         [[ 9],
          [10],
          [11],
          [12]],
         [[ 9],
          [10],
          [11],
          [12]],
         [[ 9],
          [10],
          [11],
          [12]]]])
torch.Size([2, 3, 5, 4, 1])
tensor([[[[[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]]],
         [[[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]]],
         [[[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]]]],
        [[[[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]],
          [[ 1],
           [ 2],
           [ 3],
           [ 4]]],
         [[[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]],
          [[ 5],
           [ 6],
           [ 7],
           [ 8]]],
         [[[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]],
          [[ 9],
           [10],
           [11],
           [12]]]]])

7 repeat函数

根据指定维度复制张量,与 expand 不同的是,该方法会拷贝原张量的数据

Tensor.repeat( *sizes) → [Tensor]

参数:

  • sizes (torch.Size or [int] – 指定维度复制的次数 
import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
print(a1.storage().data_ptr())
a2 = a1.reshape(3, 1, 4)
print(a2.storage().data_ptr())
a3 = a2.expand(3, 3, -1)

# expand 操作后,张量的内存地址没变
print(a3.storage().data_ptr())

a4 = a2.repeat(2, 4, 1)

# repeat 操作后,张量的内存地址会改变
print(a4.storage().data_ptr())

print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)

 运行结果显示如下:

1974461518528
1974461518528
1974461518528
1974462302208
torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],
        [[ 5,  6,  7,  8]],
        [[ 9, 10, 11, 12]]])
torch.Size([3, 3, 4])
tensor([[[ 1,  2,  3,  4],
         [ 1,  2,  3,  4],
         [ 1,  2,  3,  4]],
        [[ 5,  6,  7,  8],
         [ 5,  6,  7,  8],
         [ 5,  6,  7,  8]],
        [[ 9, 10, 11, 12],
         [ 9, 10, 11, 12],
         [ 9, 10, 11, 12]]])
torch.Size([6, 4, 4])
tensor([[[ 1,  2,  3,  4],
         [ 1,  2,  3,  4],
         [ 1,  2,  3,  4],
         [ 1,  2,  3,  4]],
        [[ 5,  6,  7,  8],
         [ 5,  6,  7,  8],
         [ 5,  6,  7,  8],
         [ 5,  6,  7,  8]],
        [[ 9, 10, 11, 12],
         [ 9, 10, 11, 12],
         [ 9, 10, 11, 12],
         [ 9, 10, 11, 12]],
        [[ 1,  2,  3,  4],
         [ 1,  2,  3,  4],
         [ 1,  2,  3,  4],
         [ 1,  2,  3,  4]],
        [[ 5,  6,  7,  8],
         [ 5,  6,  7,  8],
         [ 5,  6,  7,  8],
         [ 5,  6,  7,  8]],
        [[ 9, 10, 11, 12],
         [ 9, 10, 11, 12],
         [ 9, 10, 11, 12],
         [ 9, 10, 11, 12]]])

8 permute函数

返回重新排列的张量

torch.permute(input, dims) → [Tensor]

 也可以在tensor上直接使用permute,形式如下: 

tensor.permute(dims) → [Tensor]

参数:

  • input ([Tensor] 要重新排列的张量
  • dims (tuple of python:int) 需要重排的维度索引数组
import torch

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 1, 4)
a3 = torch.permute(a2, (2, 0, 1))
a4 = torch.permute(a2, (1, 0, 2))
a5 = a2.permute(1, 2, 0)

print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
print(a5.shape)
print(a5)

运行结果显示如下:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1,  2,  3,  4]],
        [[ 5,  6,  7,  8]],
        [[ 9, 10, 11, 12]]])
torch.Size([4, 3, 1])
tensor([[[ 1],
         [ 5],
         [ 9]],
        [[ 2],
         [ 6],
         [10]],
        [[ 3],
         [ 7],
         [11]],
        [[ 4],
         [ 8],
         [12]]])
torch.Size([1, 3, 4])
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]]])
torch.Size([1, 4, 3])
tensor([[[ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11],
         [ 4,  8, 12]]])

猜你喜欢

转载自blog.csdn.net/lsb2002/article/details/132905346