pytorch切片与索引

import torch
import numpy as np

切片与索引

a = torch.randint(1,10,[4,3,5,5]) #4张图片,3个通道,长和宽都是28个pixel
a[0].shape #第一张图片的尺寸 0表示第一张
torch.Size([3, 5, 5])
a[1,2].shape #第二张图片,第三个通道的尺寸
torch.Size([5, 5])
a[:,1] #生成一个4*28*28的tensor 取出每张照片第一个通道的所有内容
tensor([[[9, 8, 7, 7, 8],
         [1, 5, 6, 6, 3],
         [5, 4, 5, 9, 5],
         [5, 5, 9, 6, 8],
         [4, 3, 5, 9, 9]],

        [[7, 8, 7, 9, 6],
         [9, 7, 5, 7, 3],
         [7, 1, 9, 7, 1],
         [5, 7, 9, 6, 8],
         [4, 4, 2, 4, 5]],

        [[8, 6, 7, 4, 8],
         [8, 2, 8, 1, 1],
         [3, 5, 4, 5, 1],
         [7, 5, 5, 3, 3],
         [3, 2, 5, 9, 1]],

        [[4, 4, 4, 3, 7],
         [9, 9, 9, 4, 9],
         [3, 8, 1, 9, 4],
         [1, 4, 8, 7, 3],
         [1, 8, 1, 4, 6]]])
a[0,0,0,0].dim()
0
a[:2].shape  #前两张图片的全部内容
torch.Size([2, 3, 5, 5])
a[:2,:,1,:2]#前两张图片,所有通道的第2行中的前两列
tensor([[[2, 1],
         [4, 7],
         [8, 6]],

        [[7, 5],
         [9, 4],
         [1, 4]]])
a.index_select(0,torch.tensor([1])) #0表示第0维(也就是4张图片) torch.tensor([0])表示4张图片中的哪张照片,(1表示是第二张照片)
tensor([[[[2, 8, 8, 7, 8],
          [2, 7, 4, 2, 5],
          [9, 6, 4, 8, 2],
          [2, 3, 6, 3, 8],
          [5, 2, 1, 3, 7]],

         [[7, 8, 7, 9, 6],
          [9, 7, 5, 7, 3],
          [7, 1, 9, 7, 1],
          [5, 7, 9, 6, 8],
          [4, 4, 2, 4, 5]],

         [[9, 8, 3, 9, 7],
          [8, 9, 4, 2, 5],
          [3, 5, 1, 1, 5],
          [4, 1, 9, 8, 1],
          [1, 2, 3, 5, 9]]]])
a.index_select(1,torch.tensor([0,2])).shape #1表示第二个维度,即通道,[0,2]表示第一个通道和第三个通道 那么整个就表示取出所有图片的第一个和第三个通道的所有信息
torch.Size([4, 2, 5, 5])
a[0,...,-1] #第一张图片的所有通道的所有行的最后一列
tensor([[9, 2, 3, 8, 9],
        [8, 3, 5, 8, 9],
        [2, 4, 6, 5, 1]])
a[...,2,2]#所有图片的所有通道的第三行第三列 即1表示第一张图片的第一个通道的第三行第三列,5表示第一张图片的第二个通道的第三行第三列
tensor([[1, 5, 6],
        [4, 9, 1],
        [6, 4, 4],
        [6, 1, 1]])
  • 掩码
mask = a.ge(5) #a里面的元素是不是大于等于5 是的话返回true 不是的话返回false
torch.masked_select(a,mask)
tensor([5, 6, 9, 6, 6, 8, 5, 8, 9, 5, 9, 9, 9, 8, 7, 7, 8, 5, 6, 6, 5, 5, 9, 5,
        5, 5, 9, 6, 8, 5, 9, 9, 8, 7, 8, 8, 8, 5, 6, 6, 6, 7, 5, 5, 7, 8, 8, 7,
        8, 7, 5, 9, 6, 8, 6, 8, 5, 7, 7, 8, 7, 9, 6, 9, 7, 5, 7, 7, 9, 7, 5, 7,
        9, 6, 8, 5, 9, 8, 9, 7, 8, 9, 5, 5, 5, 9, 8, 5, 9, 9, 8, 9, 5, 5, 5, 6,
        5, 6, 6, 9, 9, 5, 7, 5, 8, 6, 7, 8, 8, 8, 5, 5, 7, 5, 5, 5, 9, 6, 5, 5,
        5, 8, 6, 6, 9, 5, 9, 6, 5, 5, 5, 9, 7, 7, 6, 6, 7, 9, 9, 6, 7, 7, 9, 9,
        9, 9, 8, 9, 8, 7, 8, 6, 6, 8, 8, 6, 7, 6, 7, 6, 5, 5, 7, 5])
  • flatten index
src = torch.tensor([[1,2,3],[4,5,6]])
src
tensor([[1, 2, 3],
        [4, 5, 6]])
src.take(torch.tensor([0,3,5])) #take的话,就是先把src拉平,然后按照torch.tensor的索引取出对应的元素
tensor([1, 4, 6])
发布了43 篇原创文章 · 获赞 1 · 访问量 762

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104679520