- input: Input, the type is Tensor.
- start_dim: The starting dimension of the flattening.
- end_dim: The end dimension of the flattening.
import torch
a = torch.ones(2,3,4,5)
b = torch.flatten(a,start_dim=0,end_dim=2)
# 从0维开始往后推,推到第2维。所以最后应该是:(2*3*4,5)
print(b.shape)
b = torch.flatten(a,end_dim=2)
# 默认为0
print(b.shape)
b = torch.flatten(a,start_dim=-1)
# 从最后一维往后退,不变
print(b.shape)
b = torch.flatten(a,end_dim=-1)
# 推到最后一维,展平
print(b.shape)
Result:
torch.Size([24, 5])
torch.Size([24, 5])
torch.Size([2, 3, 4, 5])
torch.Size([120])