【仿真基本功】【PyTorch】张量变形相关函数的使用方法及说明

PyTorch中经常用到Tensor的变形。近期写GNN相关代码时,大量出现相关操作。为了避免以后仍需重复测试,特将相关内容、测试结果及个人理解整理在下面。自用的同时服务大家。同时也欢迎大家收藏和帮助我一起完善,谢谢各位~

0. 本文考虑的基本张量(Tensor)

import torch
a = torch.randn([3, 2, 4, 5])
b = torch.randn([3, 2, 5])

1. Tensor的形状

想知道Tensor的形状,可以使用shape函数

print(a.shape)  # 输出结果:torch.Size([3, 2, 4, 5])

2. Tensor的维度置换

Tensor的维度置换可以使用permute()。

(1) 当Tensor仅有其中两个维度需要置换时,可以使用transpose()。

虽然此时transpose()的速度与permute()相同(经测试验证),但是permute()中需要把Tensor的所有维度写全,transpose()中只需要写需要置换的两个维度即可。

例如: a.transpose(1, 3) 与 a.permute(0, 3, 2, 1) 效果相同,但后者必须把a的四个维度写全

print(a.transpose(1, 3).shape)  # 输出结果:torch.Size([3, 5, 4, 2])
print(a.transpose(3, 1).shape)  # transpose()的两个参数的顺序不影响结果,这里输出结果与上面相同(个人强迫症,喜欢把小的放在前面)

(2) 当Tensor有超过两个维度需要置换时,可以使用permute()或者多个transpose()。

例如: a.transpose(1, 3).transpose(2, 3) 与 a.permute(0, 3, 1, 2) 效果相同但建议使用permute(),因为对较大规模的四维张量测试(使用小规模张量时考虑运行速度的意义不大)后发现,permute()比transpose()快50%左右。

3. Tensor的变形

(1) Tensor的一般变形: reshape()
(2) Tensor的临时变形: view()

注:如果只需要使用一次Tensor变形后的结果,可以使用view()而非reshape(),这样可以节省内存。因为reshape()需要开辟新的内存,而view()不需要。
具体可参考: PyTorch:view() 与 reshape() 区别详解

print(a.reshape(3*2, 4*5).shape)  # 输出结果:torch.Size([6, 20])
print(a.view(3*2, 4*5).shape)  # 此时的形状为[6, 20], 数据与Tensor a共享
c = a.view(3*2, 4*5)  # 如果将view()处理后的结果赋值给其他变量,那么仍然需要开辟新的内存存储变形后的结果,此时view()的优势将不复存在

4. Tensor的拼接

在实际应用中,有时需要将张量的某一维度进行拓展。

(1) 一般的拼接:将待拼接的ensor用[]放在一起(可以超过两个)。要求:各待拼接Tensor除了待拼接的维度,其余维度的大小均相同。

注:Tensor的拼接不可以仅使用[],因为拼接后每个元素的类型为Tensor,拼接结果的类型却是List,不符合需求。

d = b.reshape(3, 2, 1, 5)  # d -> shape: [3, 2, 1, 5]
e1 = torch.cat([a, d], dim=2)  # e -> shape: [3, 2, 4+1, 5]

# 更简洁高效的写法
e2 = torch.cat([a, b.view([3, 2, 1, 5)], dim=2)  # 结果与e1相同,但是更节省内存

(2) 通过拼接实现某一维度的扩展

f = torch.cat([b.view(3, 2, 1, 5)]*4, dim=2)  # f -> shape: [3, 2, 4, 5]

5. (本文核心内容)Tensor的拼接与变形

这里介绍“如何保证Tensor在各种拼接和变形后,各维度数据不发生错乱”,这是本文要介绍的核心内容,也是个人经过测试后给出的万无一失的办法

任务:将shape为 [ 3 , 2 , 5 ] [3, 2, 5] [3,2,5]的Tensor b复制扩展为shape为 [ 3 , 2 , 4 , 5 ] [3, 2, 4, 5] [3,2,4,5]的Tensor,然后reshape为 [ 2 ∗ 4 ∗ 3 , 5 ] [2*4*3, 5] [243,5]的Tensor

方法1

c = torch.cat([b.view(3, 2, 1, 5)]*4, dim=2)  # c -> shape: [3, 2, 4, 5]
d = c.permute([1, 2, 0, 3]).reshape(-1, 5)  # d -> shape: [2*4*3, 5]

注:第二行代码不能直接reshape,否则得到的Tensor的shape实际上是 [ 3 ∗ 2 ∗ 4 , 5 ] [3*2*4, 5] [324,5], 此时在与其他同(或类)shape的Tensor拼接或运算时,维度与数据会出现错乱。
例如,还有一个shape是 [ 2 ∗ 4 ∗ 3 , 5 ] [2*4*3, 5] [243,5]的Tensor与d相加。如果维度错乱,那么就不是对应元素相加了;而转换为 [ 2 , 4 , 3 , 5 ] [2,4,3,5] [2,4,3,5]的shape又额外增加代码。

方法2

c = torch.cat([b]*4, dim=0)  # c-> shape:[4*3, 2, 5]
d = c.transpose(0, 1).reshape(-1, 5)  # d -> shape: [2*4*3, 5]

注:这里的第二行代码同样不能直接reshape,否则得到的Tensor的shape实际上是 [ 4 ∗ 3 ∗ 2 , 5 ] [4*3*2, 5] [432,5]

讲解重点之前,我们先以b为例说明。b的shape是 [ 3 , 2 , 5 ] [3, 2, 5] [3,2,5], 将b具体表示出来(randn随机生成的一个例子)就是

  tensor([[[ 0.3543, -0.9587, -0.6313,  1.5067,  1.4628],
           [-0.0671,  1.1080,  0.5200, -0.2528,  0.2759]],
          [[ 0.1023, -1.7001,  0.0717,  0.2326,  0.1111],
           [-0.8022,  0.6989, -0.6247, -1.1926, -0.3376]],
          [[ 0.4788,  0.3146,  0.4460, -0.0280, -1.0335],
           [-2.4860,  0.7232,  0.5325,  0.4981, -0.0081]]])

由上面的例子可以看出,shape中的“5”是指 [ 5 , ] [5, ] [5,]的向量,shape中的“2”是指有2个 [ 5 , ] [5, ] [5,]的向量,shape中的“3”是指有3个(2个 [ 5 , ] [5, ] [5,]的向量)。

下面我们再回头关注方法2里的第一行代码。使用

torch.cat([b]*n, dim=i)

形式的代码,就是将b的第i个维度的大小(d_i)乘上了n,表示第i个维度有n个 [d_i, ] 的量,因此乘的形式是 nd_i 而非 d_in。此时由reshape()进行变形,才能得到正确的 [ . . . , n , d i , . . . ] [..., n, d_i, ...] [...,n,di,...] 而非 [ . . . , d i , n , . . . ] [..., d_i, n, ...] [...,di,n,...]。( [ . . . , n , d i , . . . ] [..., n, d_i, ...] [...,n,di,...]表示有 …个(n个(d_i个(…的向量))))

如果对你有用,请帮忙点个赞鼓励我创作哦~

猜你喜欢

转载自blog.csdn.net/AbaloneVH/article/details/128942839
今日推荐