PyTorch function tests

These are mainly tests of operations whose results I was not sure of in advance.

1. Function test

1.1. Testing .sum(dim=(m, n))

import torch
f = torch.arange(4 * 5 * 6).view(1, 1, 4, 5, 6)
f.shape
Out[19]: torch.Size([1, 1, 4, 5, 6])
f
Out[20]: 
tensor([[[[[  0,   1,   2,   3,   4,   5],
           [  6,   7,   8,   9,  10,  11],
           [ 12,  13,  14,  15,  16,  17],
           [ 18,  19,  20,  21,  22,  23],
           [ 24,  25,  26,  27,  28,  29]],
          [[ 30,  31,  32,  33,  34,  35],
           [ 36,  37,  38,  39,  40,  41],
           [ 42,  43,  44,  45,  46,  47],
           [ 48,  49,  50,  51,  52,  53],
           [ 54,  55,  56,  57,  58,  59]],
          [[ 60,  61,  62,  63,  64,  65],
           [ 66,  67,  68,  69,  70,  71],
           [ 72,  73,  74,  75,  76,  77],
           [ 78,  79,  80,  81,  82,  83],
           [ 84,  85,  86,  87,  88,  89]],
          [[ 90,  91,  92,  93,  94,  95],
           [ 96,  97,  98,  99, 100, 101],
           [102, 103, 104, 105, 106, 107],
           [108, 109, 110, 111, 112, 113],
           [114, 115, 116, 117, 118, 119]]]]])
g = torch.arange(6).view(1, 1, 1, 1, 6)
g.shape
Out[22]: torch.Size([1, 1, 1, 1, 6])
g
Out[23]: tensor([[[[[0, 1, 2, 3, 4, 5]]]]])
k = f*g
k.shape
Out[25]: torch.Size([1, 1, 4, 5, 6])
k
Out[26]: 
tensor([[[[[  0,   1,   4,   9,  16,  25],
           [  0,   7,  16,  27,  40,  55],
           [  0,  13,  28,  45,  64,  85],
           [  0,  19,  40,  63,  88, 115],
           [  0,  25,  52,  81, 112, 145]],
          [[  0,  31,  64,  99, 136, 175],
           [  0,  37,  76, 117, 160, 205],
           [  0,  43,  88, 135, 184, 235],
           [  0,  49, 100, 153, 208, 265],
           [  0,  55, 112, 171, 232, 295]],
          [[  0,  61, 124, 189, 256, 325],
           [  0,  67, 136, 207, 280, 355],
           [  0,  73, 148, 225, 304, 385],
           [  0,  79, 160, 243, 328, 415],
           [  0,  85, 172, 261, 352, 445]],
          [[  0,  91, 184, 279, 376, 475],
           [  0,  97, 196, 297, 400, 505],
           [  0, 103, 208, 315, 424, 535],
           [  0, 109, 220, 333, 448, 565],
           [  0, 115, 232, 351, 472, 595]]]]])
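The product k = f*g works through broadcasting: shapes are aligned from the trailing dimension, and size-1 dimensions of g are stretched to match f. A minimal sketch that checks this, assuming a PyTorch version that provides torch.broadcast_shapes (added in 1.8). The name k_1d is hypothetical and is used only for illustration.

```python
import torch

# Rebuild the tensors from the session above.
f = torch.arange(4 * 5 * 6).view(1, 1, 4, 5, 6)
g = torch.arange(6).view(1, 1, 1, 1, 6)

# Broadcasting aligns shapes from the last dimension; size-1 dims are stretched.
print(torch.broadcast_shapes(f.shape, g.shape))  # torch.Size([1, 1, 4, 5, 6])

k = f * g

# Only g's last dimension is non-trivial, so a plain 1-D tensor of shape (6,)
# broadcasts the same way and gives an identical product.
k_1d = f * torch.arange(6)
print(torch.equal(k, k_1d))  # True

# Spot check: entry [..., i, j, l] is f[..., i, j, l] * l, e.g. 33 * 3 = 99.
print(k[0, 0, 1, 0, 3].item())  # 99
```

Because g contributes only along the last axis, every entry of k is just the matching entry of f multiplied by its last-axis index. That is why the first column of k is all zeros.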
           
n = k.sum(dim=(2,3))
n.shape
Out[31]: torch.Size([1, 1, 6])
n
Out[32]: tensor([[[   0, 1160, 2360, 3600, 4880, 6200]]])
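To double-check the result above, the sketch below compares the tuple form of dim with reducing one dimension at a time. It also cross-checks the values against hand arithmetic. The keepdim argument is part of torch.sum but was not used in the original test. The names n_seq and l are hypothetical helpers.

```python
import torch

f = torch.arange(4 * 5 * 6).view(1, 1, 4, 5, 6)
g = torch.arange(6).view(1, 1, 1, 1, 6)
k = f * g

# Reducing over a tuple of dims removes all of them in one call.
n = k.sum(dim=(2, 3))
print(n.shape)  # torch.Size([1, 1, 6])

# Same result as reducing one dim at a time. Reduce the higher index first so
# the lower index still refers to the same original dimension.
n_seq = k.sum(dim=3).sum(dim=2)
print(torch.equal(n, n_seq))  # True

# keepdim=True keeps the reduced dims as size-1 axes instead of dropping them.
print(k.sum(dim=(2, 3), keepdim=True).shape)  # torch.Size([1, 1, 1, 1, 6])

# Hand check: for last-axis index l, the 20 summed values of f are
# 30*a + 6*b + l (a in 0..3, b in 0..4), totalling 20*l + 1140;
# multiplying by g's entry l gives l * (20*l + 1140).
l = torch.arange(6)
print(torch.equal(n[0, 0], l * (20 * l + 1140)))  # True
```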

1.2. Testing .sum(dim=-1)

This code continues from the session above. The dim argument gives the dimension(s) to reduce (eliminate), and dim=-1 reduces the last dimension of the tensor it is called on.
The tensor k above has 5 dimensions, with indices 0, 1, 2, 3, 4.
The first call reduced dimensions 2 and 3 and kept dimensions 0, 1 and 4, so n has 3 dimensions. Calling sum(dim=-1) on n then reduces its last dimension, which is original dimension 4. Only dimensions 0 and 1 remain.
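Here is a short sketch of the negative-index rule. A negative dim is counted from the end of the tensor that sum is called on, not from the original 5-D tensor. The names o_neg, o_pos and o_once are hypothetical and used only for illustration.

```python
import torch

f = torch.arange(4 * 5 * 6).view(1, 1, 4, 5, 6)
g = torch.arange(6).view(1, 1, 1, 1, 6)
k = f * g
n = k.sum(dim=(2, 3))  # shape (1, 1, 6): original dims 0, 1, 4 remain

# On n (3 dims), dim=-1 resolves to n.dim() - 1 = 2, i.e. original dim 4.
print(n.dim())  # 3
o_neg = n.sum(dim=-1)
o_pos = n.sum(dim=2)
print(torch.equal(o_neg, o_pos))  # True

# Equivalently, all three dims can be reduced on k in a single call.
o_once = k.sum(dim=(2, 3, 4))
print(o_once.shape, torch.equal(o_neg, o_once))  # torch.Size([1, 1]) True

# By contrast, dim=-1 applied to k itself reduces k's last axis instead.
print(k.sum(dim=-1).shape)  # torch.Size([1, 1, 4, 5])
```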

o = n.sum(dim=-1)
o.shape
Out[34]: torch.Size([1, 1])
o
Out[35]: tensor([[18200]])
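As a sanity check on 18200, the per-index values l * (20*l + 1140) from the previous step can be summed by hand. The sketch also shows that calling sum() with no dim reduces everything to a 0-dim tensor, while the two-step reduction keeps shape (1, 1). The name total is hypothetical.

```python
import torch

f = torch.arange(4 * 5 * 6).view(1, 1, 4, 5, 6)
g = torch.arange(6).view(1, 1, 1, 1, 6)
k = f * g
o = k.sum(dim=(2, 3)).sum(dim=-1)

# Hand check: sum of l * (20*l + 1140) over l = 0..5 is
# 20 * (0+1+4+9+16+25) + 1140 * (0+1+2+3+4+5) = 20*55 + 1140*15.
print(20 * 55 + 1140 * 15)  # 18200

# Reducing every dimension (no dim argument) gives the same total,
# but as a 0-dim tensor rather than shape (1, 1).
total = k.sum()
print(total, total.shape)  # tensor(18200) torch.Size([])
print(o.shape, o.item() == total.item())  # torch.Size([1, 1]) True
```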


Origin blog.csdn.net/juluwangriyue/article/details/123487892