torch.argmax() function

argmax function: torch.argmax(input, dim=None, keepdim=False) returns the indices of the maximum values along the specified dimension. The documentation describes dim as "the dimension to reduce": dimension dim is removed from the output, and each remaining entry holds the index (along dim) of the maximum value. If dim is None, the index of the maximum in the flattened input is returned.
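The sketch below shows what the three parameters in the signature do. It uses a small fixed tensor chosen only for illustration. The calls compare dim=None (index into the flattened tensor), an explicit dim, and keepdim=True, which keeps the reduced dimension with size 1 instead of dropping it.

```python
import torch

# Small fixed tensor (illustrative values) so the output is reproducible
x = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 3.0, 4.0]])

# dim=None (default): index into the flattened tensor [1, 5, 2, 7, 3, 4]
flat_idx = torch.argmax(x)
print(flat_idx)            # tensor(3)

# dim=1 removes dimension 1: shape (2, 3) -> (2,)
row_idx = torch.argmax(x, dim=1)
print(row_idx)             # tensor([1, 0])

# keepdim=True keeps dimension 1 with size 1: shape (2, 3) -> (2, 1)
row_idx_kept = torch.argmax(x, dim=1, keepdim=True)
print(row_idx_kept.shape)  # torch.Size([2, 1])
```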

1) dim selects which dimension is reduced. For a two-dimensional matrix, dim=0 reduces over the rows, so the result has one index per column; dim=1 reduces over the columns, so the result has one index per row. The number of dimensions does not matter: for a tensor of shape (d0, d1, ..., dn−1), dim=0 corresponds to d0, the first dimension, dim=1 to d1, the second dimension, and so on.

2) Knowing which dimension dim refers to is not enough; you also need to know what happens to the tensor when that dim is passed to the function.
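As a rough check of points 1) and 2), the sketch below reduces an illustrative 4-D tensor of shape (2, 3, 4, 5) along each dim in turn. It confirms that the reduced dimension disappears from the output shape. It also confirms that every returned index is a valid position along that dimension.

```python
import torch

# Illustrative 4-D tensor with shape (d0, d1, d2, d3) = (2, 3, 4, 5)
x = torch.randn(2, 3, 4, 5)

shapes = {}
for dim in range(x.dim()):
    out = torch.argmax(x, dim=dim)
    # The reduced dimension disappears; the other sizes keep their order
    assert tuple(out.shape) == tuple(x.shape[:dim]) + tuple(x.shape[dim + 1:])
    # Every entry is a position along the reduced dimension
    assert int(out.min()) >= 0 and int(out.max()) < x.shape[dim]
    shapes[dim] = tuple(out.shape)

print(shapes)
# {0: (3, 4, 5), 1: (2, 4, 5), 2: (2, 3, 5), 3: (2, 3, 4)}
```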

Example 1: Two-dimensional array

import torch

# randn draws random values, so your printed output will differ from the one below
x = torch.randn(2, 4)
print(x)
'''
tensor([[ 1.2864, -0.5955,  1.5042,  0.5398],
        [-1.2048,  0.5106, -2.0288,  1.4782]])
'''

# y0 holds the index of the maximum along dim=0 (i.e. for each column)
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([0, 1, 0, 1])
'''

# y1 holds the index of the maximum along dim=1 (i.e. for each row)
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([2, 3])
'''
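To make the column/row reading concrete, the sketch below hard-codes the values printed in Example 1 so the run is reproducible. It computes the same indices by calling argmax on each column and each row separately. This is an illustration of the equivalence, not how PyTorch computes it internally.

```python
import torch

# The values printed in Example 1, written out so the run is reproducible
x = torch.tensor([[ 1.2864, -0.5955,  1.5042,  0.5398],
                  [-1.2048,  0.5106, -2.0288,  1.4782]])

# dim=0: one index per column, taken over the 2 rows
per_column = torch.stack([torch.argmax(x[:, j]) for j in range(x.shape[1])])
# dim=1: one index per row, taken over the 4 columns
per_row = torch.stack([torch.argmax(x[i, :]) for i in range(x.shape[0])])

print(per_column)  # tensor([0, 1, 0, 1])
print(per_row)     # tensor([2, 3])
assert torch.equal(per_column, torch.argmax(x, dim=0))
assert torch.equal(per_row, torch.argmax(x, dim=1))
```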

Example 2: Three-dimensional array

# random values again; your output will differ from the one shown
x = torch.randn(2, 4, 5)
print(x)
'''
tensor([[[-1.2204, -0.6428, -0.2278,  0.5589,  1.1589],
         [ 0.4235,  1.9663,  0.5055, -1.3472,  1.3523],
         [ 1.4220,  0.7886, -1.0821,  0.6268, -0.9465],
         [-0.3950,  1.3275,  0.3369,  1.0224, -0.9944]],

        [[ 0.6024, -0.2604, -0.8631,  0.8113, -0.3140],
         [ 0.3487, -0.1941, -0.3955, -0.1719, -1.3734],
         [ 0.2467, -0.4268, -1.3428,  0.7346,  1.0932],
         [-0.5799,  0.0976, -1.9403, -0.2643,  0.7657]]])
'''

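# dim=0: eliminate the first dimension, i.e. collapse the two [4*5] matrices into one by comparing the upper and lower [4*5] matrices element by element at corresponding positions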
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([[1, 1, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1]])
'''
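With dim=0 on a (2, 4, 5) tensor there are only two candidates at every position, x[0] and x[1]. The result is therefore 1 exactly where the lower matrix is larger. The sketch below checks this on seeded random data; the seed is arbitrary. It relies on argmax returning the first index when values tie, so a strict ">" yields 0 on ties as well.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability
x = torch.randn(2, 4, 5)

# 1 where the second [4*5] matrix wins, 0 otherwise (ties go to index 0)
manual = (x[1] > x[0]).long()

assert torch.equal(manual, torch.argmax(x, dim=0))
print(manual.shape)  # torch.Size([4, 5])
```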

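# dim=1: eliminate the second dimension, i.e. collapse the four [2*5] matrices into one (within each 4*5 block, the 4 rows are compared column by column)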
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([[2, 1, 1, 3, 1],
        [0, 3, 1, 0, 2]])
'''
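The dim=1 case can be mimicked by walking through the four [2*5] slices x[:, i, :] and keeping a running maximum. The sketch below does this on seeded random data. The names best_val and best_idx are just illustrative. The strict ">" keeps the earlier index on ties, matching argmax's first-occurrence rule.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability
x = torch.randn(2, 4, 5)

# Start with slice 0 as the current best at every (batch, column) position
best_val = x[:, 0, :].clone()
best_idx = torch.zeros(2, 5, dtype=torch.long)
for i in range(1, x.shape[1]):
    better = x[:, i, :] > best_val   # strict ">" keeps the first index on ties
    best_val = torch.where(better, x[:, i, :], best_val)
    best_idx = torch.where(better, torch.full_like(best_idx, i), best_idx)

assert torch.equal(best_idx, torch.argmax(x, dim=1))
print(best_idx.shape)  # torch.Size([2, 5])
```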

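# dim=2: eliminate the third dimension; each row of 5 values is reduced to the index of its maximum
y2 = torch.argmax(x, dim=2)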
print(y2)
'''
tensor([[4, 1, 0, 1],
        [3, 0, 4, 4]])
'''
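A common next step is to use the indices to fetch the maximum values themselves. The sketch below keeps the reduced dimension with keepdim=True so the result can be passed to torch.gather. It then cross-checks against torch.max(x, dim=2), which returns the values and indices together. The data is seeded random input chosen only for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability
x = torch.randn(2, 4, 5)

idx = torch.argmax(x, dim=2, keepdim=True)    # shape (2, 4, 1)
# gather picks x[b, r, idx[b, r, 0]] along dim 2, i.e. the maximum of each row
row_max = torch.gather(x, 2, idx).squeeze(2)  # shape (2, 4)

values, indices = torch.max(x, dim=2)         # values and indices in one call
assert torch.equal(row_max, values)
assert torch.equal(idx.squeeze(2), indices)
assert torch.equal(row_max, x.amax(dim=2))
```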


Reposted from: blog.csdn.net/qq_40507857/article/details/123357918