Pytorch张量的拆分与拼接

Pytorch张量的拆分与拼接

预览

在 PyTorch 中,对张量 (Tensor) 进行拆分通常会用到两个函数:

  • torch.split [按块大小拆分张量]
  • torch.chunk [按块数拆分张量]

而对张量 (Tensor) 进行拼接通常会用到另外两个函数:

  • torch.cat [按已有维度拼接张量]
  • torch.stack [按新维度拼接张量]

1.张量的拆分

  • torch.split函数
 torch.split(tensor, split_size_or_sections, dim = 0)

块大小拆分张量
tensor 为待拆分张量
dim 指定张量拆分的所在维度,即在第几维对张量进行拆分。dim=0是按照行拆分,dim=1是按照列拆分。如果是三维向量的话,可以按照dim=2在矩阵的方向上划分。
split_size_or_sections 表示在 dim 维度拆分张量时每一块在该维度的尺寸大小 (int),或各块尺寸大小的列表 (list)
指定每一块的尺寸大小后,如果在该维度无法整除,则最后一块会取余数,尺寸较小一些
如:长度为 10 的张量,按单位长度 3 拆分,则前三块长度为 3,最后一块长度为 1
函数返回:所有拆分后的张量所组成的 tuple
函数并不会改变原 tensor

import torch
X = torch.randn(6, 2)
Y=torch.split(X, 2, dim = 0)
#返回一个元组tutle
(tensor([[-0.0039, -0.1259],
        [-0.7630,  1.3833]]), tensor([[-0.7960,  0.2523],
        [-0.5351, -0.5850]]), tensor([[ 0.3403, -0.2898],
        [-0.3122, -0.7490]]))
Y=torch.split(X, 4, dim = 0)
#除不尽的取余数
(tensor([[ 1.4674,  0.7185],
        [ 0.4943,  1.4040],
        [-1.5243,  0.0566],
        [-1.2039, -0.3079]]), tensor([[-2.9470, -1.6064],
        [-0.8393, -0.5528]]))
  • torch.chunk 函数
torch.chunk(input, chunks, dim = 0)

块数拆分张量
input 为待拆分张量
dim 指定张量拆分的所在维度,即在第几维对张量进行拆分
chunks 表示在 dim 维度拆分张量时最后所分出的总块数 (int),根据该块数进行平均拆分
指定总块数后,如果在该维度无法整除,则每块长度向上取整,最后一块会取余数,尺寸较小一些,若余数恰好为 0,则会只分出 chunks - 1 块
如:
长度为 6 的张量,按 4 块拆分,则只分出三块,长度为 2 (6 / 4 = 1.5 → 2)
长度为 10 的张量,按 4 块拆分,则前三块长度为 3 (10 / 4 = 2.5 → 3),最后一块长度为 1
函数返回:所有拆分后的张量所组成的 tuple
函数并不会改变原 input

In [1]: X = torch.randn(6, 2)

In [2]: X
Out[2]:
tensor([[-0.3711,  0.7372],
        [ 0.2608, -0.1129],
        [-0.2785,  0.1560],
        [-0.7589, -0.8927],
        [ 0.1480, -0.0371],
        [-0.8387,  0.6233]])

In [3]: torch.chunk(X, 2, dim = 0)
Out[3]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560]]),
 tensor([[-0.7589, -0.8927],
         [ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [4]: torch.chunk(X, 3, dim = 0)
Out[4]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [5]: torch.chunk(X, 4, dim = 0)
Out[5]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [6]: Y = torch.randn(10, 2)

In [6]: Y
Out[6]:
tensor([[-0.9749,  1.3103],
        [-0.4138, -0.8369],
        [-0.1138, -1.6984],
        [ 0.7512, -0.3417],
        [-1.4575, -0.4392],
        [-0.2035, -0.2962],
        [-0.7533, -0.8294],
        [ 0.0104, -1.3582],
        [-1.5781,  0.8594],
        [ 0.0286,  0.7611]])

In [7]: torch.chunk(Y, 4, dim = 0)
Out[7]:
(tensor([[-0.9749,  1.3103],
         [-0.4138, -0.8369],
         [-0.1138, -1.6984]]),
 tensor([[ 0.7512, -0.3417],
         [-1.4575, -0.4392],
         [-0.2035, -0.2962]]),
 tensor([[-0.7533, -0.8294],
         [ 0.0104, -1.3582],
         [-1.5781,  0.8594]]),
 tensor([[0.0286, 0.7611]]))

这个函数还是很好理解的

2.张量的合并

可以用torch.cat和torch.stack方法将多个张量合并,但是torch.cat仅仅是张量的连接,不会增加维度,而torch.stack是堆叠,会增加维度。

  • cat方法
torch.cat(tensors, dim = 0, out = None)

已有维度拼接张量
tensors 为待拼接张量的序列,通常为 tuple
dim 指定张量拼接的所在维度,即在第几维对张量进行拼接,除该拼接维度外,其余维度上待拼接张量的尺寸必须相同
out 表示在拼接张量的输出,也可直接使用函数返回值
函数返回:拼接后所得到的张量
函数并不会改变原 tensors
在这里插入图片描述

- stack方法

torch.stack(tensors, dim = 0, out = None)

新维度拼接张量
tensors 为待拼接张量的序列,通常为 tuple
dim 指定张量拼接的新维度对应已有维度的插入索引,即在原来第几维的位置上插入新维度对张量进行拼接,待拼接张量在所有已有维度上的尺寸必须完全相同
out 表示在拼接张量的输出,也可直接使用函数返回值
函数返回:拼接后所得到的张量
函数并不会改变原 tensors

In [1]: x = torch.randn(2, 3)
In [2]: x
Out[2]:
tensor([[-0.0288,  0.6936, -0.6222],
        [ 0.8786, -1.1464, -0.6486]])

In [3]: torch.stack((x, x, x), dim = 0)
Out[3]:
tensor([[[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]],

        [[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]],

        [[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]]])

In [4]: torch.stack((x, x, x), dim = 0).shape
Out[4]: torch.Size([3, 2, 3])

In [5]: torch.stack((x, x, x), dim = 1)
Out[5]:
tensor([[[-0.0288,  0.6936, -0.6222],
         [-0.0288,  0.6936, -0.6222],
         [-0.0288,  0.6936, -0.6222]],

        [[ 0.8786, -1.1464, -0.6486],
         [ 0.8786, -1.1464, -0.6486],
         [ 0.8786, -1.1464, -0.6486]]])

In [6]: torch.stack((x, x, x), dim = 1).shape
Out[6]: torch.Size([2, 3, 3])

猜你喜欢

转载自blog.csdn.net/weixin_43427721/article/details/107208470