PyTorch中张量的操作:拼接、切分、比较、索引和变换

张量的拼接

torch.cat()

torch.cat(tensors, 
          dim=0, 
          out=None)

功能: 将张量按维度dim进行拼接

  • tensors: 张量序列
  • dim : 要拼接的维度
t = torch.ones((2, 3))
q = torch.zeros((2, 3))
t0 = torch.cat([t, q], 0)
t1 = torch.cat((t, q), dim=1)
print(t0, t0.shape)
print(t1, t1.shape)
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.]]) torch.Size([4, 3])
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.]]) torch.Size([2, 6])

torch.stack()

torch.stack(tensors, 
            dim=0, 
            out=None)

功能: 在新创建的维度dim上进行拼接

  • tensors:张量序列
  • dim :要拼接的维度
t = torch.ones((3, 4))
q = torch.zeros((3, 4))
t0 = torch.stack([t, q], dim=0)
t1 = torch.stack([t, q], dim=1)
t2 = torch.stack([t, q], dim=2)
print(t0, t0.shape)
print(t1, t1.shape)
print(t2, t2.shape)
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]]) torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [0., 0., 0., 0.]]]) torch.Size([3, 2, 4])
tensor([[[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]],

        [[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]],

        [[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]]]) torch.Size([3, 4, 2])

张量的切分

torch.chunk()

torch.chunk(input, 
            chunks, 
            dim=0)

功能: 将张量按维度dim进行平均切分
返回值: 张量列表
注意事项: 若不能整除,最后一份张量小于其他张量

  • input: 要切分的张量
  • chunks : 要切分的份数
  • dim : 要切分的维度
t = torch.ones((2, 7))
list_t = torch.chunk(t, chunks=3, dim=1)
for i, ten in enumerate(list_t):
    print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量:
tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
第3个张量:
tensor([[1.],
        [1.]]) torch.Size([2, 1])

torch.split()

torch.split(tensor, 
            split_size_or_sections, 
            dim=0)

功能: 将张量按维度dim进行切分
返回值: 张量列表

  • tensor: 要切分的张量
  • split_size_or_sections : 为int时,表示每一份的长度;为list时,按list元素切分
  • dim : 要切分的维度
t = torch.ones((2, 7))
list_t = torch.split(t, 3, dim=1)
for i, ten in enumerate(list_t):
    print("第{}个张量:\n{}".format(i+1, ten), ten.shape)

print("\n")

list_t = torch.split(t, [3, 4], dim=1)
for i, ten in enumerate(list_t):
    print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量:
tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
第3个张量:
tensor([[1.],
        [1.]]) torch.Size([2, 1])


第1个张量:
tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]]) torch.Size([2, 4])

张量的比较

torch.ge(),torch.gt(),torch.le(),torch.lt()

torch.ge(input, 
         other, 
         out=None)

功能: input中逐元素与other进行比较,满足:ge >=; gt >; le <=; lt <时,返回True
返回值: 与input同形状的布尔类型张量

  • input:被比较的张量
  • other:可以是张量,数值,布尔,input中逐元素与其进行比较
t = torch.randint(0, 10, [3, 3])
m = t.ge(5)
print(t)
print(m)
tensor([[1, 6, 5],
        [6, 5, 4],
        [0, 4, 4]])
tensor([[False,  True,  True],
        [ True,  True, False],
        [False, False, False]])

张量的索引

torch.index_select()

torch.index_select(input, 
                   dim, 
                   index, 
                   out=None)

功能: 在维度dim上,按index索引数据
返回值: 索引得到的数据拼接的张量

  • input: 要索引的张量
  • dim: 要索引的维度
  • index : 要索引数据的序号组成的张量,dtype须为torch.long
t = torch.randint(0, 10, [3, 3])
idx = torch.tensor([0, 2], dtype=torch.long)
sel = torch.index_select(t, 0, idx)
print(t)
print(sel)
tensor([[0, 7, 0],
        [8, 3, 1],
        [2, 7, 9]])
tensor([[0, 7, 0],
        [2, 7, 9]])

torch.masked_select()

torch.masked_select(input, 
                    mask, 
                    out=None)

功能: 按mask中的True进行索引
返回值: 一维张量

  • input: 要索引的张量
  • mask: 与input同形状的布尔类型张量
t = torch.randint(0, 10, [3, 3])
mask = t.ge(5) 
sel = torch.masked_select(t, mask)
print(t)
print(mask)
print(sel)
tensor([[1, 6, 5],
        [6, 5, 4],
        [0, 4, 4]])
tensor([[False,  True,  True],
        [ True,  True, False],
        [False, False, False]])
tensor([6, 5, 6, 5])

张量的变换

torch.reshape()

torch.reshape(input, 
              shape)

功能: 变换张量形状
注意事项: 当张量在内存中是连续的时,新张量与input共享数据内存。这种共享与out不同,out是整个tensor都共享内存,相当于别名;reshape是仅data共享内存。改变一个张量的数据,另一个张量会跟着改变

  • input: 要变换的张量
  • shape: 新张量的形状
t = torch.randperm(8)
re1 = torch.reshape(t, (2, 4))
re2 = torch.reshape(t, (-1, 4)) 
print(t)
print(re1)
print(re2)
t[0] = 100
re2[1, 1] = 100
print(id(t.data), id(re1.data), id(re2.data))
print(re1)
tensor([0, 7, 2, 6, 3, 5, 4, 1])
tensor([[0, 7, 2, 6],
        [3, 5, 4, 1]])
tensor([[0, 7, 2, 6],
        [3, 5, 4, 1]])
3039469824264 3039469824264 3039469824264
tensor([[100,   7,   2,   6],
        [  3, 100,   4,   1]])

torch.transpose()

torch.transpose(input, 
				dim0, 
				dim1)

功能: 交换张量的两个维度。在图像的预处理中常用,有时读取的图像数据是(c, h, w),但是我们常用的是(h, w, c),就需要用此方法把channel和width变换,再把width和height变换

  • input: 要变换的张量
  • dim0: 要交换的维度
  • dim1: 要交换的维度
t = torch.rand((2, 3, 4))
tr = torch.transpose(t, 1, 0)
print(t, t.shape)
print(tr, tr.shape)
tensor([[[0.4973, 0.0644, 0.7269, 0.9305],
         [0.4711, 0.1117, 0.1751, 0.4904],
         [0.9865, 0.7374, 0.9201, 0.5733]],

        [[0.4911, 0.4571, 0.9985, 0.7298],
         [0.5078, 0.0928, 0.1655, 0.8740],
         [0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([2, 3, 4])
tensor([[[0.4973, 0.0644, 0.7269, 0.9305],
         [0.4911, 0.4571, 0.9985, 0.7298]],

        [[0.4711, 0.1117, 0.1751, 0.4904],
         [0.5078, 0.0928, 0.1655, 0.8740]],

        [[0.9865, 0.7374, 0.9201, 0.5733],
         [0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([3, 2, 4])

torch.t()

torch.t(input)

功能: 2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)

torch.squeeze()

torch.squeeze(input, 
              dim=None, 
              out=None)

功能: 压缩长度为1的维度(轴)

  • dim: 若为None,移除所有长度为1的轴; 若指定维度,当且仅当该轴长度为1时,可以被移除
t = torch.rand((1, 2, 3, 1))
sq = torch.squeeze(t)
sq0 = torch.squeeze(t, 0)
sq1 = torch.squeeze(t, 1)
print(t.shape)
print(sq.shape)
print(sq0.shape)
print(sq1.shape)
torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])

torch.unsqueeze()

torch.usqueeze(input, 
               dim, 
               out=None)

功能:依据dim扩展维度

  • dim: 扩展的维度
t = torch.rand((2, 3))
sq = torch.unsqueeze(t, 0)
print(t.shape)
print(sq.shape)
torch.Size([2, 3])
torch.Size([1, 2, 3])
发布了9 篇原创文章 · 获赞 0 · 访问量 298

猜你喜欢

转载自blog.csdn.net/SakuraHimi/article/details/104560909