pytorch --- cat、stack、split的用法

可以用torch.cat方法和torch.stack方法将多个张量合并,可以用torch.split方法把一个张量分割成多个张量。

torch.cat和torch.stack有略微的区别,torch.cat是连接,不会增加维度,而torch.stack是堆叠,会增加维度。

举例如下:

cat

torch.cat(tensors,dim=0,out=None)→ Tensor

A = torch.tensor([[1, 2, 3],
        		  [4, 5, 6],
        		  [7, 8, 9]])
print("A shape: {}" .format(A.shape))

B = torch.tensor([[12, 22, 33],
        		  [44, 55, 66],
        		  [77, 88,99]])
print("B shape: {}" .format(B.shape))
A shape: torch.Size([3, 3])
B shape: torch.Size([3, 3])

dim = 0

按照维度0进行拼接,不新增维度,(3, 3)和(3, 3) cat后维度为:(6, 3)

result1 = torch.cat((A, B), 0)
print("result1 shape: {}".format(result1.shape))
print(result1)
result1 shape: torch.Size([6, 3])
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [12, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])

dim = 1

按照维度1进行拼接,不新增维度,(3, 3)和(3, 3) cat后维度为:(3, 6)

result2 = torch.cat((A, B), 1)
print("result2 shape: {}".format(result2.shape))
print(result2)
result2 shape: torch.Size([3, 6])
tensor([[ 1,  2,  3, 12, 22, 33],
        [ 4,  5,  6, 44, 55, 66],
        [ 7,  8,  9, 77, 88, 99]])

stack

dim=0

按照维度0进行拼接,会新增一个维度,(3, 3)和(3, 3) stack后维度为:(2, 3, 3)

result3 = torch.stack((A, B), dim=0)
print("result3 shape: {}".format(result3.shape))
print(result3)
result3 shape: torch.Size([2, 3, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[12, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])

dim=1

按照维度1进行拼接,会新增一个维度,(3, 3)和(3, 3) stack后维度为:(3, 2, 3)

result4 = torch.stack((A, B), dim=1)
print("result4 shape: {}".format(result4.shape))
print(result4)
result4 shape: torch.Size([3, 2, 3])
tensor([[[ 1,  2,  3],
         [12, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])

dim=2

按照维度2进行拼接,因为会新增一个维度,所以2不会索引越界,(3, 3)和(3, 3) stack后维度为:(3, 3, 2)

result5 = torch.stack((A, B), dim=2)
print("result5 shape: {}".format(result5.shape))
print(result5)
result5 shape: torch.Size([3, 3, 2])
tensor([[[ 1, 12],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])

split

q, k, v = torch.split(result3, split_size_or_sections=1, dim=1)
print(q, q.shape)
print(k, k.shape)
print(v, v.shape)

print()
print(torch.stack((q, k, v), 1), torch.stack((q, k, v), 1).shape)
print(torch.cat((q, k, v), 1), torch.cat((q, k, v), 1).shape)
tensor([[[ 1,  2,  3]],

        [[12, 22, 33]]]) torch.Size([2, 1, 3])
tensor([[[ 4,  5,  6]],

        [[44, 55, 66]]]) torch.Size([2, 1, 3])
tensor([[[ 7,  8,  9]],

        [[77, 88, 99]]]) torch.Size([2, 1, 3])

tensor([[[[ 1,  2,  3]],

         [[ 4,  5,  6]],

         [[ 7,  8,  9]]],


        [[[12, 22, 33]],

         [[44, 55, 66]],

         [[77, 88, 99]]]]) torch.Size([2, 3, 1, 3])
         
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[12, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]]) torch.Size([2, 3, 3])
q1, q2 = torch.split(result3, split_size_or_sections=[2, 1], dim=1)
print(q1, q1.shape)
print(q2, q2.shape)
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[12, 22, 33],
         [44, 55, 66]]]) torch.Size([2, 2, 3])
         
tensor([[[ 7,  8,  9]],

        [[77, 88, 99]]]) torch.Size([2, 1, 3])

猜你喜欢

转载自blog.csdn.net/qq_42363032/article/details/126599975