pytorch索引与切片

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/3/13 14:27
# @Author  : zhoujianwen
# @Email   : [email protected]
# @File    : IndexSlice.py
# @Describe: 索引与切片
import torch
import numpy as np

# cnn:[b,c,h,w]
a = torch.rand(4,3,28,28)
print("a[1].shape",a[1].shape)
a[1].shape torch.Size([3, 28, 28])
# 第0张图片第0个通道的size
print("a[0,0].shape",a[0,0].shape)
a[0,0].shape torch.Size([28, 28])
# 第0张图片第0个通道的第二行第四列的像素点,dim=0的标量数据
print("a[0,0,2,4]:",a[0,0,2,4])
print("a[0,0,2,4].size:",a[0,0,2,4].shape)
print("a[0,0,2,4].dim:",a[0,0,2,4].dim())
a[0,0,2,4]: tensor(0.3461)
a[0,0,2,4].size: torch.Size([])
a[0,0,2,4].dim: 0
# select first/last N
print("a.shape:",a.shape)
a.shape: torch.Size([4, 3, 28, 28])
# [0,2)只包含第0、1张图片,不包含第2张图片,左闭右开
print("a[:2].shape:",a[:2].shape)
a[:2].shape: torch.Size([2, 3, 28, 28])
# 读取第[0,2)张图片的第[0,1)个通道的所有像素点的数据
print("a[:2,:1,:,:].shape:",a[:2,:1,:,:].shape)
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
# 读取第[0,2)张图片的第[1,3)个通道开始,也就是取G,B通道
print("a[:2,1:,:,:].shape:",a[:2,1:,:,:].shape)
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
# 读取第[0,2)张图片的第[-1 -> 3)个通道开始,也就是取B通道
print("a[:2,-1:,:,:].shape:",a[:2,-1:,:,:].shape)
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])
# [n:],n<-x,[:n],x->n
b = [1,2,3]
print(b[-1:])  # [3]
print(b[:2])  # [1,2]
print(b[2:])  # [3]
[3]
[1, 2]
[3]
# select by steps,其实只有一种通用形式:start:end:step
# 隔行采样,0:28:等同于0:28:1
print("a[:,:,0:28:2,0:28:2].shape:",a[:,:,0:28:2,0:28:2].shape)
print("a[:,:,::2,::2].shape",a[:,:,::2,::2].shape)
a[:,:,0:28:2,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape torch.Size([4, 3, 14, 14])
# select by specific index
print("a.shape:",a.shape)
a.shape: torch.Size([4, 3, 28, 28])
# 选择第0个维度的第[0,2]张图片
print("a.index_select:",a.index_select(0,torch.tensor([0,2])))
a.index_select: tensor([[[[0.8546, 0.2013, 0.1062,  ..., 0.9824, 0.8863, 0.6017],
          [0.1386, 0.2514, 0.8634,  ..., 0.7499, 0.7272, 0.8322],
          [0.3031, 0.3741, 0.5038,  ..., 0.7924, 0.9043, 0.2345],
          ...,
          [0.8459, 0.2774, 0.1076,  ..., 0.0398, 0.6477, 0.9998],
          [0.3281, 0.1745, 0.5973,  ..., 0.8342, 0.8901, 0.3011],
          [0.1842, 0.2550, 0.6711,  ..., 0.1121, 0.9892, 0.3081]],

         [[0.8326, 0.5537, 0.0351,  ..., 0.9903, 0.9098, 0.4415],
          [0.7563, 0.1376, 0.1964,  ..., 0.0684, 0.6294, 0.8427],
          [0.8452, 0.0330, 0.6098,  ..., 0.7506, 0.7094, 0.8238],
          ...,
          [0.7471, 0.3504, 0.6913,  ..., 0.8323, 0.3782, 0.5122],
          [0.1563, 0.4442, 0.3153,  ..., 0.9618, 0.5851, 0.0769],
          [0.5881, 0.7677, 0.9798,  ..., 0.8064, 0.4111, 0.7886]],

         [[0.2338, 0.6615, 0.1037,  ..., 0.3131, 0.0601, 0.8218],
          [0.4265, 0.0386, 0.4150,  ..., 0.5856, 0.6884, 0.5182],
          [0.2303, 0.2400, 0.7040,  ..., 0.0359, 0.8518, 0.6753],
          ...,
          [0.6395, 0.4236, 0.0155,  ..., 0.2507, 0.5653, 0.7536],
          [0.6709, 0.6526, 0.7172,  ..., 0.7923, 0.9099, 0.2377],
          [0.9862, 0.2967, 0.9797,  ..., 0.7989, 0.8265, 0.5555]]],


        [[[0.7573, 0.8412, 0.4920,  ..., 0.1342, 0.5042, 0.6848],
          [0.9763, 0.7163, 0.7308,  ..., 0.1032, 0.8580, 0.9857],
          [0.5480, 0.2225, 0.9607,  ..., 0.7664, 0.9186, 0.6998],
          ...,
          [0.8516, 0.2715, 0.9604,  ..., 0.0824, 0.6400, 0.0049],
          [0.2225, 0.0882, 0.7353,  ..., 0.5928, 0.5812, 0.5691],
          [0.6565, 0.1949, 0.6019,  ..., 0.5297, 0.2580, 0.4920]],

         [[0.9223, 0.5609, 0.7168,  ..., 0.5464, 0.0719, 0.1355],
          [0.8929, 0.3358, 0.0200,  ..., 0.3294, 0.2123, 0.1627],
          [0.7752, 0.0918, 0.6133,  ..., 0.9695, 0.3872, 0.4596],
          ...,
          [0.3844, 0.5066, 0.9669,  ..., 0.3451, 0.5119, 0.0080],
          [0.3395, 0.8516, 0.4613,  ..., 0.2481, 0.1935, 0.3335],
          [0.7209, 0.6026, 0.2995,  ..., 0.1239, 0.8646, 0.7569]],

         [[0.7907, 0.0318, 0.4738,  ..., 0.4866, 0.4233, 0.8879],
          [0.8737, 0.4157, 0.8674,  ..., 0.7644, 0.6824, 0.2374],
          [0.3484, 0.4345, 0.0616,  ..., 0.3273, 0.6805, 0.2046],
          ...,
          [0.1819, 0.9985, 0.7422,  ..., 0.5651, 0.6958, 0.7335],
          [0.0575, 0.5193, 0.7036,  ..., 0.4359, 0.0778, 0.2884],
          [0.9213, 0.3569, 0.7103,  ..., 0.0056, 0.5919, 0.4703]]]])
# 选择第1个维度的第[1,2]个通道的图片
print("a.index_select(1,torch.tensor([1,2])).shape:",a.index_select(1,torch.tensor([1,2])).shape)
a.index_select(1,torch.tensor([1,2])).shape: torch.Size([4, 2, 28, 28])
# 选择第2个维度的第[0,28)行
print("a.index_select(2,torch.arange((28)).shape:",a.index_select(2,torch.arange(28)).shape)
a.index_select(2,torch.arange((28)).shape: torch.Size([4, 3, 28, 28])
# 选择第2个维度的第[0,8)行
print("a.index_select(2,torch.arange((8)).shape:",a.index_select(2,torch.arange(8)).shape)
a.index_select(2,torch.arange((8)).shape: torch.Size([4, 3, 8, 28])
# 选择所有维度信息, ...代表任意多的意思,等价于:,:,:,:  ,两个::代表隔行
print("a[...].shape:",a[...].shape)
print("a[:,:,:,:].shape:",a[:,:,:,:].shape)
a[...].shape: torch.Size([4, 3, 28, 28])
a[:,:,:,:].shape: torch.Size([4, 3, 28, 28])
# 当有...出现时,右边的索引需要理解为最右边,
# ...会根据shape的实际情况推测a的维度,比如a[0,...,::2],中间的...就代表C和H,::2 隔列采样
print("a[0,...,::2].shape:",a[0,...,::2].shape)
print("a[0,:,:,::2].shape:",a[0,:,:,::2].shape)
print("a[0].shape:",a[0].shape)
a[0,...,::2].shape: torch.Size([3, 28, 14])
a[0,:,:,::2].shape: torch.Size([3, 28, 14])
a[0].shape: torch.Size([3, 28, 28])
# 选择第0张图片的所有维度信息
print("a[0,...].shape:",a[0,...].shape)
a[0,...].shape: torch.Size([3, 28, 28])
# 选择所有图片的第一个通道的维度信息
print("a[:,1,...]:",a[:,1,...].shape)
print("a[:,1]:",a[:,1].shape)
a[:,1,...]: torch.Size([4, 28, 28])
a[:,1]: torch.Size([4, 28, 28])
# 选择所有图片的第2列像素的维度信息
print("a[...,:2].shape:",a[...,:2].shape)
a[...,:2].shape: torch.Size([4, 3, 28, 2])
# select by mask , .masked_select()
x = torch.randn(3,4)
print("x:",x)
mask = x.ge(0.5)
print("mask:",mask)
x: tensor([[ 0.6774,  1.4763, -0.5385, -0.8852],
        [ 0.5032, -0.2443,  2.2794, -0.4565],
        [-0.2374, -0.6779, -1.1302, -0.3287]])
mask: tensor([[ True,  True, False, False],
        [ True, False,  True, False],
        [False, False, False, False]])
# 大于0.5的元素个数是根据内容才能确定的,比如取出所有概率大于0.5的物体
print("torch.masked_select(x,mask):",torch.masked_select(x,mask))
torch.masked_select(x,mask): tensor([0.6774, 1.4763, 0.5032, 2.2794])
# dim=1的不定长度,长度取决于有多少个元素大于0.5
print("torch.masked_select(x,mask).shape:",torch.masked_select(x,mask).shape)
torch.masked_select(x,mask).shape: torch.Size([4])
# select by flatten index
# 先把[2,3],dim=2 -> [6],dim=1
src = torch.tensor([[4,3,5],[6,7,8]])
print(torch.take(src,torch.tensor([0,2,5])))
tensor([4, 5, 8])
发布了40 篇原创文章 · 获赞 0 · 访问量 1531

猜你喜欢

转载自blog.csdn.net/zhoumoon/article/details/104894736