pytorch中判断某一个Tensor的维度数是几

  • 1. 正常情况下,判断一个Tensor的维度是通过 tensor.shape 进行判断,例如:
import torch

x = torch.randint(low=-10,high=10,size=(3,6))
print(x)
print(x.shape)

 

由此可以判断是2维(第1维有3个元素,第2维有6个元素)。

  • 2. 但有时候,需要在代码中判断某一个Tensor到底是1维还是2维还是3维等,该如何判断呢?可以通过 len(tensor.shape) 判断,例如:
import torch

x = torch.randint(low=-10,high=10,size=(3,6))
print(x)
print(x.shape)
print(len(x.shape))

 

则会输出维度是2。

猜你喜欢

转载自blog.csdn.net/m0_46483236/article/details/123723034