张量的详细操作

目录

一、张量的维度索引

二、张量的维度切片

三、张量的维度变换

(1)view() 和 reshape() 变换维度

​编辑

(2)unsqueeze() 增加新的数据维度

(3)squeeze() 缩减数据维度

(4)expand()和 repeat()在某维度上扩展数据

(5)transpose()和 permute()进行张量的维度调整


神经网络的数据存储中都使用张量(Tensor),那张量又是什么呢?

        张量这一概念的核心在于,它是一个数据容器。它包含的数据几乎总是数值数据,因此它是数字的容器。张量是矩阵向任意维度的推广,注意,张量的维度(dimension)通常叫作轴(axis)。标量(0D 张量),向量(1D 张量),矩阵(2D 张量),3D 张量与更高维张量

一、张量的维度索引


张量的索引是从第零维度开始的。

例子:

创建一个四维的张量:torch.Tensor(2,3,64,64) 此时,这个张量可以表示两张边长为64的正方形彩色图像,具体来说,张量的第零维表示图像的数量;第一维表示图像的颜色通道(3即为彩色图片,代表RGB三通道);第二维和第三维代表图像的高度和宽度。此张量的索引代码如下:
 

import torch

a = torch.Tensor(2, 3, 64, 64)
# 通过.shape的方法查看当前张量的形状
print(a.shape)
print(a[0].shape)
print(a[0][0].shape)
print(a[0][0][0].shape)
print(a[0][0][0][0].shape)
# 输出
# torch.Size([2, 3, 64, 64]) # 图像的形状
# torch.Size([3, 64, 64]) # 取到第一张图像,形状为 [3, 64, 64]
# torch.Size([64, 64]) # 取到第一张图像的第一个颜色通道, 形状为[64, 64]
# torch.Size([64]) # 取到第一张图像的第一个颜色通道的第一列像素值,形状为64
# torch.Size([]) # 取到第一张图像的第一个颜色通道的第一个像素值,形状为0(因为是标量)

另外,需要注意的是pytorch也支持负索引,使用方法与python中的负索引相同。

二、张量的维度切片


维度的索引是取到某维度上的全部数据。只想要某维度上的部分数据应该怎么做?这就是切片的作用。

切片方法的格式为:tensor[ first : last : step] first与last为切片的起始和结束位置,取值方法是按照step的间隔进行左闭右开的取值;当间隔为1时,step可以默认不写;当取到该维度的所有数据时,使用冒号即可。

例子:

import torch

a = torch.Tensor(2, 3, 64, 64)
# 通过.shape的方法查看当前张量的形状
print(a.shape)
print(a[1:2, :, :, :].shape)
print(a[ : , : , 0:32, 0:32].shape)
print(a[ : , : , 0:32:2, 0:32:2].shape)
print(a[ : , : , : : 2, : : 2].shape)
# 输出
# torch.Size([2, 3, 64, 64]) # 图像的形状
# torch.Size([1, 3, 64, 64]) # 取到第二张图像
# torch.Size([2, 3, 32, 32]) # 取到两张图像1/4大小的左上角子图
# torch.Size([2, 3, 16, 16]) # 取到两张图像1/4大小的左上角子图后,在子图上隔点取样
# torch.Size([2, 3, 32, 32]) # 在原图上隔点取样

三、张量的维度变换


图片是三维的数据维度(颜色通道,高度,宽度),但是神经网络层能接受的数据维度是二维,此时维度是不匹配的,因此需要将图像的空间维度打平成向量。下面介绍pytorch中一些常见的维度变换方法。

(1)view() 和 reshape() 变换维度

import torch

a = torch.Tensor(2, 3, 32, 32)

print(a.view(2, 3, 32 * 32))
print(a.reshape(2, 3, 32 * 32))
print(a.reshape(2, 3, -1))

# 结果:torch.Size([2, 3, 1024])
# torch.Size([2, 3, 1024])
# torch.Size([2, 3, 1024])

view() 和reshape()都可以对某张量进行维度的变化,reshape()方法的鲁棒性更强,更推荐使用。此外,view() 和reshape()接受的参数都是变换后的维度大小,在设置变换后维度的参数时,如果只剩一个维度没有给予,可直接使用-1来代替,pytorch会根据之前已设置的维度自动推导出最后未给予的维度。最后,这里需要注意的是变换后的总维度数量必须与变换前相等,否则报错。变换后新形状的总元素数量与原始张量的总元素数量也需要相同

#     if total_elements == torch.tensor(new_shape).prod():
#         rgb_reshaped = torch.reshape(rgb, new_shape)
#         print(rgb_reshaped.shape)  # 输出 torch.Size([1024, 3])
#     else:
#         print("无法调整形状,总元素数量不匹配。")

(2)unsqueeze() 增加新的数据维度


数据的增加需要在原始张量表示的基础上扩张维度来存储新增加的数据。 unsqueeze() 方法用来增加数据维度的,接受的参数含义是在哪个维度之前增加新维度,这个参数也支持负索引。

例子:

import torch

a = torch.Tensor(2, 3, 64, 64)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(2).shape)
print(a.unsqueeze(-1).shape)

# output
# torch.Size([1, 2, 3, 64, 64])
# torch.Size([2, 1, 3, 64, 64])
# torch.Size([2, 3, 1, 64, 64])
# torch.Size([2, 3, 64, 64, 1])

(3)squeeze() 缩减数据维度


减少维度的方法是squeeze(),接受的参数是要进行维度缩减的维度索引,注意,缩减的维度值必须等于1,否则不能进行缩减,而且程序不报错。

例子:

import torch

a = torch.Tensor(2, 1, 64, 64)
print(a.squeeze(1).shape)
print(a.squeeze(2).shape) 

# output
# torch.Size([2, 64, 64])
# torch.Size([2, 1, 64, 64]) 

(4)expand()和 repeat()在某维度上扩展数据


expand()可以在某维度上进行数据扩展,扩展的方法是复制原始数据。注意,expand()方法不能扩展维度大于1的维度,否则报错。因为其扩展方式是复制,当维度大于1时,expand()方法不清楚应该复制哪个数据。

例子:

import torch

a = torch.Tensor(2, 1, 64, 64) 
print(a.shape)
print(a.expand(2,3,64,64).shape)
print(a.expand(2,3,65,65).shape)

# output
# torch.Size([2, 1, 64, 64])
# torch.Size([2, 3, 64, 64])
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# Input In [22], in <cell line: 6>()
#       4 print(a.shape)
#       5 print(a.expand(2,3,64,64).shape)
# ----> 6 print(a.expand(2,3,65,65).shape)

# RuntimeError: The expanded size of the tensor (65) must match the existing size (64) 
# at non-singleton dimension 3 Target sizes: [2, 3, 65, 65]. Tensor sizes: [2, 1, 64, 64]

repeat()也可以在某维度上进行数据扩展,但是其接受的参数含义与expand()函数不同。repeat()函数接受的是在该维度上复制全部数据的次数.

例子:

import torch

a = torch.Tensor(2, 1, 64, 64) 
print(a.shape)
print(a.repeat(1,3,1,1).shape)
print(a.repeat(3,3,3,3).shape)

# output
# torch.Size([2, 1, 64, 64])
# torch.Size([2, 3, 64, 64])
# torch.Size([6, 3, 192, 192])

(5)transpose()和 permute()进行张量的维度调整


transpose()可以通过指定张量中某两个维度的索引,来对这两个维度的数据进行交换维度操作,示例如下:

import torch

a = torch.Tensor(2, 3, 64, 64) 
print(a.shape)
print(a.transpose(0, 1).shape) 

# output
# torch.Size([2, 3, 64, 64])
# torch.Size([3, 2, 64, 64])

猜你喜欢

转载自blog.csdn.net/qq_46684028/article/details/133133226
今日推荐