numpy和pytorch中的元素拼接操作:stack,concatenat,cat

numpy中对ndarray的拼接

numpy.concatenate

将多个矩阵沿着一个已经存在的维度进行拼接

需求:两个张量的维度的数量需要相同,同时除了拼接的维度,其它维度的形状需要相同。

np.concatenate((,),axis=)
'''
第一个参数为要拼接的矩阵组成的元组
第二个参数为要拼接的维度
'''

a = np.arange(3*3).reshape((3,3))
b = np.arange(3*4).reshape((3,4))

a,b
(array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]),
 array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]))

np.concatenate([a,b],axis=1)
array([[ 0,  1,  2,  0,  1,  2,  3],
       [ 3,  4,  5,  4,  5,  6,  7],
       [ 6,  7,  8,  8,  9, 10, 11]])

numpy.stack

将具有相同维度和形状的矩阵,在一个新的维度上进行堆叠

需求:矩阵具有完全相同的尺寸

np.stack((,), axis=)
'''
参数和上面的函数相同
'''


x1 = np.arange(9).reshape((3,3))
x2 = np.arange(10,19,1).reshape((3,3))

y2 = np.stack((x1,x2),axis=0)

输出:
    [[[ 0  1  2]
      [ 3  4  5]
      [ 6  7  8]]

     [[10 11 12]
      [13 14 15]
      [16 17 18]]]

    'y2.shape': (2,3,3)

PS:

np.hstack(tup) = np.concatenate(tup, axis=1)

np.vstack(tup) = np.concatenate(tup, axis=0)

pytorch中对ndarray的拼接

pytorch中同样存在两个函数和numpy中的两种操作相同,分别为:

torch.cat((,), dim=) -> numpy.concatenate((,), axis=)

torch.stack((,), dim=) - > numpy.stack((,), axis=)

区别:

torch对tensor张量进行拼接,函数的第二个参数为dim=

numpy对ndarray进行拼接,函数的第二个参数为axis=

猜你喜欢

转载自blog.csdn.net/c_procomer/article/details/126030724