PyTorch张量操作(拼接、切分、索引、维度变换)

张量拼接

使用torch.cat()

功能:将张量按维度dim进行拼接
示例:

t = torch.ones((2, 3))

t_0 = torch.cat([t, t], dim=0)
t_1 = torch.cat([t, t, t], dim=1)

print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

结果:

t_0:tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])

使用torch.stack()

功能:新创建一个维度dim,然后在新维度dim上进行拼接
示例:

t = torch.ones((2, 3))

t_stack = torch.stack([t, t, t], dim=0)

print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

结果:

t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([3, 2, 3])

对比

cat不会拓展维度,而stack一定会创建一个新维度,会拓展维度

张量切分

使用torch.chunk()

功能:按照指定维度dim进行平均切分,如果不能整除,那最后一份张量的结果小于平均张量,最后一份张量在dim维度上的大小为余数
示例:

a = torch.ones((2, 7))
list_of_tensors = torch.chunk(a, dim=1, chunks=3)

for idx, t in enumerate(list_of_tensors):
	print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

结果:

第1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])

使用torch.split()

功能:可设置切分每一份的长度,当长度为整型int时,表示每一份的长度,当长度为列表list时,表示按list元素切分。如果list中元素和大小不等于原维度,就会报错。
示例:

t = torch.ones((2, 5))

list_of_tensors = torch.split(t, [2, 1, 2], dim=1)

for idx, t in enumerate(list_of_tensors):
	print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

结果:

第1个张量:tensor([[1., 1.],
        [1., 1.]]), shape is torch.Size([2, 2])
第2个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])
第3个张量:tensor([[1., 1.],
        [1., 1.]]), shape is torch.Size([2, 2])

张量索引

使用torch.index_select()

功能:在维度dim上,按照index索引数据,返回由index索引到的数据拼接的张量(index数据类型是指定的,必须是torch.long)
示例:

t = torch.randint(0, 9, size=(3, 3))

idx = torch.tensor([0, 2], dtype=torch.long)
t_select = torch.index_select(t, dim=0, index=idx)

print("t:\n{}\nt_select:\n{}".format(t, t_select))

结果:

t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
t_select:
tensor([[4, 5, 0],
        [2, 5, 8]])

使用torch.masked_select()

功能:按照mask中的true进行索引,mask是与待索引张量形状相同的布尔型张量。
示例:

t = torch.randint(0, 9, size=(3, 3))

mask = t.le(5)  # 张量小于等于5的元素(ge,le:greater than or equal, less than or equal)
t_select = torch.masked_select(t, mask)

print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

结果:

t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
mask:
tensor([[1, 1, 1],
        [1, 0, 1],
        [1, 1, 0]], dtype=torch.uint8)
t_select:
tensor([4, 5, 0, 5, 1, 2, 5])

张量维度变换

使用torch.reshape()

功能:变换张量形状,如果张量在内存中是连续的,那么新张量与原张量共享内存。shape以元组形式写出。
示例:

t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1, 2, 2))  # -1表示自动计算剩下的维度,这里就是2
print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

# 印证共享内存
t[0] = 1024
print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
print("t.data 内存地址:{}".format(id(t.data)))
print("t_reshape.data 内存地址:{}".format(id(t_reshape.data)))

结果:

t:tensor([5, 4, 2, 6, 7, 3, 1, 0])
t_reshape:
tensor([[[5, 4],
         [2, 6]],

        [[7, 3],
         [1, 0]]])
t:tensor([1024,    4,    2,    6,    7,    3,    1,    0])
t_reshape:
tensor([[[1024,    4],
         [   2,    6]],

        [[   7,    3],
         [   1,    0]]])
t.data 内存地址:2924900028352
t_reshape.data 内存地址:2924900028352

使用torch.transpose()

功能:交换张量的两个维度。而torch.t()仅限二维张量,多用于矩阵转置。如果想多个维度重组就用torch.permute()。
示例:

t = torch.rand((2, 3, 4))

t_transpose = torch.transpose(t, dim0=1, dim1=2)

print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))

结果:

t shape:torch.Size([2, 3, 4])
t_transpose shape: torch.Size([2, 4, 3])

使用torch.squeeze()

功能:压缩张量中长度为1的维度。如果不指定dim,默认移除所有长度为1的维度,如果指定dim,当且仅当该维度长度为1时可被移除。torch.unsqueeze()则是根据dim扩展维度。
示例:

t = torch.rand((1, 2, 3, 1))

t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)

print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)

结果:

torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])
发布了33 篇原创文章 · 获赞 45 · 访问量 2509

猜你喜欢

转载自blog.csdn.net/nstarLDS/article/details/104610593