None作为ndarray或tensor的索引

None作为ndarray或tensor的索引作用是增加维度,与 pytorch中的 torch.unsqueeze() 或 tensorflow 中的tf.expand_dims() 作用相同

例子:

In [5]: t=torch.from_numpy(np.arange(12).reshape(3,4))

In [6]: t
Out[6]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [7]: t.dim()
Out[7]: 2

In [8]: t[:,None,:]
Out[8]:
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])

In [9]: arr=np.arange(12).reshape(3,4)

In [10]: arr
Out[10]:
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

In [11]: arr[:,None,:]
Out[11]:
array([[[ 0,  1,  2,  3]],

       [[ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11]]])

In [12]: arr[:,None,None,:]
Out[12]:
array([[[[ 0,  1,  2,  3]]],


       [[[ 4,  5,  6,  7]]],


       [[[ 8,  9, 10, 11]]]])

In [13]: arr[:,None,None,:].shape
Out[13]: (3, 1, 1, 4)

In [14]: t
Out[14]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [15]: t.unsqueeze(1)
Out[15]:
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])

In [16]: t
Out[16]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [17]: t.unsqueeze(1)==t[:,None,:]
Out[17]:
tensor([[[True, True, True, True]],

        [[True, True, True, True]],

        [[True, True, True, True]]])


发布了67 篇原创文章 · 获赞 27 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/xpy870663266/article/details/105361313
今日推荐