一、前言
前文简单的介绍了pytorch的数据类型,本文将主要介绍tensor的索引、切片等操作。
二、数据操作
1、索引
a = torch.randn(64,3,28,28) #以MNIST的数据维度为例
[In] a[0].shape #首张输入图片的维度信息
[Out] torch.Size([3,28,28])
[In] a[0,0].shape #首张图片的第一个通道的维度信息
[Out] torch.Size([28,28])
[In] a[0,0,27,27].shape #像素值
[Out] tensor(-1.5442)
[In] a.index_select(2,torch.arange(8)).shape #在维度2的位置,数值为0~7
[Out] torch.Size([64, 3, 8, 28])
b = torch.randn(2,3)
[Out]tensor([[-0.9997, -0.1196, -1.5024],
[-0.1027, -0.9639, 0.9655]])
[In] mask = b.ge(0.5) #在数组中比0.5大的数位True,反之为False
[Out] tensor([[False, False, False],
[False, False, True]])
[In] torch.masked_select(b,mask) #获取b中比0.5大的元素
[Out] tensor([0.9655])
2、切片
a = torch.randn(64,3,28,28)
[In] a[0:2].shape #在维度0的位置取0~1,手写体数字而言就是选取前两张图片的信息
[Out] torch.Size([2, 3, 28, 28])
[In] a[0:2,:1,:,:].shape #维度0选取0~1,维度1选取0,维度2选取0~27,维度3选取0~27
[Out] torch.Size([2, 1, 28, 28])
[In] a[:2,1:,:,:].shape #维度0选取0~1,维度1选取1,维度2选取0~27,维度3选取0~27
[Out] torch.Size([2, 2, 28 ,28])
[In] a[:2,-1:,:,:].shape #维度0选取0~1,维度1选取2,维度2选取0~27,维度3选取0~27
[Out] torch.Size([2, 1, 28, 28])
[In] a[:,:,:28:2,::2].shape #维度0全选,维度1全选,维度2在0~27步长位2选取,维度3相同
[Out] torch.Size([64, 3, 14, 14])
[In] a[...].shape #全选 a = a[...]
[Out] torch.Size([64, 3, 28, 28])
[In] a[0,...].shape #维度0取0,其他维度全选
[Out] torch.Size([3, 28, 28])
[In] a[:,1,...].shape #维度1取1,其他维度全选
[Out] torch.Size([64, 28, 28])
三、总结
数据的索引与切片较为简单,与数组相同。但是当这两个操作出现在项目中时,我会突然忘记或者搞混,所以在此记录一下,还是需要大量锻炼才行。翻过一座山又是一座山,下座山峰见。