Determine the number of dimensions of a Tensor in PyTorch

  • 1. Normally, you can check the dimensions of a Tensor with tensor.shape, for example:
import torch

# Random integers in [-10, 10), arranged as 3 rows by 6 columns
x = torch.randint(low=-10, high=10, size=(3, 6))
print(x)
print(x.shape)  # torch.Size([3, 6])

The output torch.Size([3, 6]) shows that x is 2-dimensional: it has size 3 along the first dimension and size 6 along the second.
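
Because x.shape is a torch.Size, which is a subclass of Python's tuple, the size of each individual dimension can also be read by indexing it. A minimal sketch, reusing the same x as above; the variable names rows and cols are only illustrative:

```python
import torch

x = torch.randint(low=-10, high=10, size=(3, 6))

# torch.Size behaves like a tuple: index it to get the size of each dimension
rows = x.shape[0]  # size along dimension 0
cols = x.shape[1]  # size along dimension 1
print(rows, cols)  # 3 6

# x.size() returns the same torch.Size as x.shape
print(x.size() == x.shape)  # True
```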

  • 2. Sometimes, though, the code itself has to decide whether a Tensor is 1-dimensional, 2-dimensional, 3-dimensional, and so on. In that case, use len(tensor.shape), for example:
import torch

x = torch.randint(low=-10, high=10, size=(3, 6))
print(x)
print(x.shape)
print(len(x.shape))  # number of dimensions of x

The last line prints 2, the number of dimensions of x.
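
As a sketch of how such a check might look in real code, the helper below branches on len(t.shape). The function name describe is hypothetical, chosen only for illustration. PyTorch tensors also provide the t.dim() method and the t.ndim attribute, which return the same count:

```python
import torch

def describe(t: torch.Tensor) -> str:
    """Hypothetical helper: branch on the number of dimensions of t."""
    ndims = len(t.shape)  # same value as t.dim() and t.ndim
    if ndims == 0:
        return "scalar"  # e.g. torch.tensor(5) has shape torch.Size([])
    if ndims == 1:
        return "1-dimensional"
    if ndims == 2:
        return "2-dimensional"
    return f"{ndims}-dimensional"

x = torch.randint(low=-10, high=10, size=(3, 6))
print(len(x.shape), x.dim(), x.ndim)  # 2 2 2
print(describe(torch.tensor(5)))      # scalar
print(describe(torch.zeros(4)))       # 1-dimensional
print(describe(x))                    # 2-dimensional
print(describe(torch.zeros(2, 3, 4))) # 3-dimensional
```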

Origin blog.csdn.net/m0_46483236/article/details/123723034