一、张量的操作
1.1张量拼接与切分
1.1.1 torch.cat()
torch.cat(tensors, dim=0, out=None)
功能: 将张量按维度dim进行拼接
- tensors: 张量序列
- dim: 要拼接的维度
t = torch.ones((2, 3))
t_0 = torch.cat([t, t], dim=0)
t_1 = torch.cat([t, t, t], dim=1)
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
1.1.2 torch.stack()
torch.stack(tensors, dim=0, out=None)
功能: 在新创建的维度dim上进行拼接
- tensors: 张量序列
- dim: 要拼接的维度
t = torch.ones((2, 3))
t_stack = torch.stack([t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
注意:
- cat()不会扩张张量的维度,而stack()则会
- 当dim指定为0,那么已有的维度会后移,如原(2,3)的t,t1=torch.stack([t, t], dim=0),t1维度为(2,2,3)
1.1.3 torch.chunk()
torch.chunk(input, chunks, dim=0)
功能: 将张量按维度dim进行平均切分
返回值: 张量列表
注意事项: 若不能整除, 最后一份张量小于其他张量
- input: 要切分的张量
- chunks: 要切分的份数
- dim: 要切分的维度
a = torch.ones((2, 7)) # 7
print(a)
list_of_tensors = torch.chunk(a, dim=1, chunks=3) # 3
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx + 1, t, t.shape))
1.1.4 torch.split()
torch.split(tensor,split_size_or_sections, dim=0)
功能: 将张量按维度dim进行切分
返回值: 张量列表
- tensor: 要切分的张量
- split_size_or_sections: 为int时, 表示每一份的长度; 为list时, 按list元素切分
- dim : 要切分的维度
t = torch.ones((2, 5))
# list_of_tensors = torch.split(t, 2, dim=1)
# for idx, t in enumerate(list_of_tensors):
# print("第{}个张量:{}, shape is {}".format(idx + 1, t, t.shape))
list_of_tensors = torch.split(t, [2, 1, 2], dim=1)
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
1.2 张量索引
1.2.1 torch.index_select()
torch.index_select(input, dim, index, out=None)
功能: 在维度dim上,按index索引数据
返回值: 依index索引数据拼接的张量
- input: 要索引的张量
- dim: 要索引的维度
- index: 要索引数据的序号
1.2.2 torch.masked_select()
torch.masked_select(input, mask, out=None)
功能: 按mask中的True进行索引
返回值: 一维张量
- input: 要索引的张量
- mask: 与input同形状的布尔类型张量
1.3 张量变换
1.3.1 torch.reshape()
torch.reshape(input, shape)
功能: 变换张量形状
注意事项: 当张量在内存中是连续时,新张量与input共享数据内存
- input: 要变换的张量
- shape: 新张量的形状
1.3.2 torch.transpose()
torch.transpose(input, dim0, dim1)
功能: 交换张量的两个维度
- input: 要变换的张量
- dim0: 要交换的维度
- dim1: 要交换的维度
1.3.3 torch.t()
torch.t(input)
功能:2维张量转置,对矩阵而言,等价于 torch.transpose(input, 0, 1)
1.3.4 torch.squeeze()
torch.squeeze(input, dim=None, out=None)
功能: 压缩长度为1的维度(轴)
- dim: 若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;
1.3.5 torch.unsqueeze()
torch.unsqueeze(input, dim, out=None)
功能: 依据dim扩展维度
- dim: 扩展的维度