The torch.max() function: returns the maximum value along a dimension and the index of that maximum value

Today, while studying TTSR, I kept running into the following line of code. I noticed that max() can return two values, so I decided to relearn this function.

R_lv3_star, R_lv3_star_arg = torch.max(R_lv3, dim=1) #[N, H*W]
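To see what that line does, here is a minimal sketch. The shape of R_lv3 is an assumption for illustration: a 3-D tensor of shape [N, K, H*W], with small hypothetical sizes N=2, K=5, H*W=6 (not taken from TTSR). Taking the max over dim=1 collapses the K axis, leaving one best score and one index per position, which matches the [N, H*W] comment.

```python
import torch

# Hypothetical sizes, chosen only for illustration (not taken from TTSR)
N, K, HW = 2, 5, 6

# Assumed layout: R_lv3[n, k, p] is the score of candidate k for position p in sample n
R_lv3 = torch.randn(N, K, HW)

# Reduce over dim=1: one max score and the index of the winning candidate per (n, p)
R_lv3_star, R_lv3_star_arg = torch.max(R_lv3, dim=1)

print(R_lv3_star.shape)      # torch.Size([2, 6])
print(R_lv3_star_arg.shape)  # torch.Size([2, 6])
print(R_lv3_star_arg.dtype)  # torch.int64
```

The index tensor is what makes the two-value form useful: it records which candidate won, not just its score.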


 1. Basic usage:

The first is the basic usage of torch.max(): pass in a tensor, and it returns the maximum value over all of its elements.

torch.max(input) → Tensor

Example:

>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)
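A small sketch of the same call with a fixed tensor (the values are copied from the example above so the result is reproducible). It shows that the result is a 0-dimensional tensor, which .item() turns into a plain Python number, and that the method form a.max() gives the same result.

```python
import torch

a = torch.tensor([[0.6763, 0.7445, -2.2369]])

m = torch.max(a)      # 0-dim tensor holding the largest element
print(m)              # tensor(0.7445)
print(m.dim())        # 0
print(m.item())       # a plain Python float

# The method form is equivalent
print(torch.equal(a.max(), m))  # True
```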

 2. In-depth usage:

torch.max(input, dim, keepdim=False, *, out=None)

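Returns a named tuple (values, indices), where values holds the maximum value of each slice of input along dimension dim, and indices holds the index of each maximum value found.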
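Because the result is a named tuple, it can be unpacked like the TTSR line does, or its fields can be read by name. A minimal sketch with a small hand-written tensor (values chosen for illustration, with no ties in any row):

```python
import torch

a = torch.tensor([[1.0, 5.0, 3.0],
                  [7.0, 2.0, 4.0]])

result = torch.max(a, dim=1)

# Access by field name ...
print(result.values)   # tensor([5., 7.])
print(result.indices)  # tensor([1, 0])

# ... or unpack like an ordinary tuple
values, indices = result
```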

Parameters

  • input (Tensor) – the input tensor.

  • dim (int) – the dimension to reduce.

  • keepdim (bool) – whether the output tensor has dim retained or not. Default: False.
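The effect of keepdim is easiest to see through the output shapes. In this sketch (the 3×4 tensor is only an illustration), keepdim=False removes the reduced dimension, while keepdim=True keeps it with size 1, which lets the result broadcast directly against the input, for example to subtract each row's maximum.

```python
import torch

a = torch.arange(12.0).reshape(3, 4)

v, i = torch.max(a, dim=1)                   # dim 1 is removed
vk, ik = torch.max(a, dim=1, keepdim=True)   # dim 1 is kept with size 1

print(v.shape)   # torch.Size([3])
print(vk.shape)  # torch.Size([3, 1])

# The keepdim=True result broadcasts against a: subtract each row's max
shifted = a - vk
```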

Keyword Arguments

out (tuple, optional) – the result tuple of two output tensors (max, max_indices). The maximum values and the indices are each returned as a tensor, and together they form a tuple (Tensor, LongTensor).
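A minimal sketch of the out argument, assuming you want to write into pre-allocated tensors (the tensor values are illustrative). The index tensor has to be int64, i.e. a LongTensor, to match the tuple type above.

```python
import torch

a = torch.tensor([[1.0, 5.0, 3.0],
                  [7.0, 2.0, 4.0]])

# Pre-allocate the two result tensors; indices must be int64 (LongTensor)
max_vals = torch.empty(2)
max_idx = torch.empty(2, dtype=torch.long)

torch.max(a, 1, out=(max_vals, max_idx))

print(max_vals)  # tensor([5., 7.])
print(max_idx)   # tensor([1, 0])
```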

Example:

torch.max(a, 0) returns the largest element in each column along with its index (the row index of the largest element in that column). The maximum values and the indices are each a tensor, and together they form a tuple (Tensor, LongTensor):

import torch

a = torch.randn(4, 4)
print(a)
print(torch.max(a, 0))


tensor([[ 0.7439,  2.2739, -2.7576, -0.0676],
        [-0.7755, -0.6696,  0.3009, -1.4939],
        [-0.9244,  2.7325,  1.7982,  1.2904],
        [-0.9091, -0.1857, -1.3392, -1.2928]])
torch.return_types.max(
values=tensor([0.7439, 2.7325, 1.7982, 1.2904]),
indices=tensor([0, 2, 2, 2]))
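Since indices[j] is the row index of the maximum in column j, it can be fed to torch.gather to pull the maxima back out of a, which is a quick sanity check of the column-wise meaning. This sketch hard-codes the tensor printed above so the output is reproducible.

```python
import torch

a = torch.tensor([[ 0.7439,  2.2739, -2.7576, -0.0676],
                  [-0.7755, -0.6696,  0.3009, -1.4939],
                  [-0.9244,  2.7325,  1.7982,  1.2904],
                  [-0.9091, -0.1857, -1.3392, -1.2928]])

values, indices = torch.max(a, 0)

# indices[j] is a row index for column j; gather along dim 0 picks a[indices[j], j]
picked = a.gather(0, indices.unsqueeze(0)).squeeze(0)

print(indices)                      # tensor([0, 2, 2, 2])
print(torch.equal(picked, values))  # True
```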

torch.max(a, 1) returns the largest element in each row along with its index (the column index of the largest element in that row):

>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
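Two related points, sketched with the tensor printed above: for a 2-D tensor, dim=-1 (the last dimension) means the same as dim=1, and if only the positions are needed, torch.argmax returns the same indices (this tensor has no ties within a row, so the comparison is unambiguous).

```python
import torch

a = torch.tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
                  [ 1.1949, -1.1127, -2.2379, -0.6702],
                  [ 1.5717, -0.9207,  0.1297, -1.8768],
                  [-0.6172,  1.0036, -0.6060, -0.2432]])

values, indices = torch.max(a, 1)

# For a 2-D tensor, dim=-1 (the last dimension) is the same as dim=1
values_last, indices_last = torch.max(a, -1)

# If only the positions are needed, torch.argmax gives the same indices
print(indices)                                   # tensor([3, 0, 0, 1])
print(torch.equal(indices, torch.argmax(a, 1)))  # True
```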


Origin blog.csdn.net/weixin_43135178/article/details/123257024