pytorch 的 cat函数的详解

cat是concatnate的意思,也就是说进行 张量 的拼接

实例:

>>> import torch
>>> A=torch.ones(2,3)    #2x3的张量(矩阵)
>>> A
tensor([[1., 1., 1.],
        [1., 1., 1.]])
>>> A1=2*torch.ones(4,3)  #4x3的张量(矩阵)
>>> A1
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])

>>> A2=torch.cat((A,A1),0)  #按维数0(行)拼接
>>> A2
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
>>>

>>> A.shape
torch.Size([2, 3])
>>> A1.shape
torch.Size([4, 3])
>>> A2.shape
torch.Size([6, 3])
>>>

猜你喜欢

转载自blog.csdn.net/Vertira/article/details/131484743