torch: torch.flatten()

  • input: The tensor to flatten (type Tensor).
  • start_dim: The first dimension to flatten (default 0).
  • end_dim: The last dimension to flatten, inclusive (default -1, i.e. the last dimension).
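As a small illustration of how the two dimension parameters interact, the sketch below flattens everything except the first dimension, and then only a middle range. The tensor shape and the variable names (`x`, `per_sample`, `middle`) are chosen here for illustration only.

```python
import torch

x = torch.ones(2, 3, 4, 5)

# Keep dim 0 and flatten dims 1 through the last one: 3*4*5 = 60
per_sample = torch.flatten(x, start_dim=1)
print(per_sample.shape)

# Flatten only dims 1 and 2 (end_dim is inclusive): 3*4 = 12
middle = torch.flatten(x, start_dim=1, end_dim=2)
print(middle.shape)

# For this tensor, flattening from dim 1 gives the same values as reshape(2, -1)
print(torch.equal(per_sample, x.reshape(2, -1)))
```

Because end_dim is inclusive, `start_dim=1, end_dim=2` merges two dimensions into one rather than stopping before dim 2.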
import torch

a = torch.ones(2,3,4,5)

b = torch.flatten(a,start_dim=0,end_dim=2)
# Flatten from dim 0 through dim 2, so the result should be (2*3*4, 5)
print(b.shape)

b = torch.flatten(a,end_dim=2)
# start_dim defaults to 0, so this is the same as the call above
print(b.shape)

b = torch.flatten(a,start_dim=-1)
# Starts at the last dim and end_dim defaults to -1, so only one dim is covered and the shape is unchanged
print(b.shape)

b = torch.flatten(a,end_dim=-1)
# start_dim defaults to 0 and end_dim is the last dim, so everything is flattened into 1-D
print(b.shape)

Result:

torch.Size([24, 5])
torch.Size([24, 5])
torch.Size([2, 3, 4, 5])
torch.Size([120])
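The four shapes above follow one rule: the sizes from start_dim through end_dim (inclusive) are multiplied into a single dimension, and the dimensions on either side are left alone. Below is a minimal pure-Python sketch of that rule. `flattened_shape` is a hypothetical helper written for this illustration, not part of PyTorch, and it assumes a shape with at least one dimension.

```python
from math import prod  # Python 3.8+


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: predict the shape after flattening start_dim..end_dim."""
    ndim = len(shape)
    for d in (start_dim, end_dim):
        if not -ndim <= d < ndim:
            raise IndexError(f"dimension {d} is out of range for {ndim} dims")
    # Map negative indices to their non-negative equivalents
    start = start_dim % ndim
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    # Dims before start and after end are kept; the range in between is merged
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


shape = (2, 3, 4, 5)
print(flattened_shape(shape, start_dim=0, end_dim=2))
print(flattened_shape(shape, end_dim=2))
print(flattened_shape(shape, start_dim=-1))
print(flattened_shape(shape, end_dim=-1))
```

With start_dim=-1 the merged range contains only the last dimension, which is why the third result keeps the original shape.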

Origin blog.csdn.net/jokerxsy/article/details/105968782