First look at the function parameters:
torch.flatten(input, start_dim=0, end_dim=-1)
input: the tensor to be "flattened".
start_dim: the first dimension to flatten.
end_dim: the last dimension to flatten (inclusive).
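To make the roles of start_dim and end_dim concrete, here is a minimal sketch that predicts the output shape. Dimensions start_dim through end_dim (inclusive) are multiplied together into a single dimension, and negative indices count from the end. The helper flatten_shape is hypothetical and written only for illustration. It assumes a tensor with at least one dimension, and it is checked against the real torch.flatten.

```python
import math
import torch

def flatten_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: predict torch.flatten's output shape (assumes len(shape) >= 1)."""
    ndim = len(shape)
    # Negative indices count from the end, e.g. -1 is the last dimension
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError("need 0 <= start_dim <= end_dim < ndim after normalization")
    # Dimensions start..end (inclusive) collapse into one whose size is their product
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

t = torch.arange(24).reshape(2, 3, 4)  # illustrative [2, 3, 4] tensor
for s, e in [(0, -1), (1, -1), (0, 1), (1, 1), (-2, -1)]:
    predicted = flatten_shape(t.shape, s, e)
    actual = tuple(torch.flatten(t, start_dim=s, end_dim=e).shape)
    print((s, e), predicted, actual)
```

For a [2, 3, 4] tensor, (1, -1) gives (2, 12), (0, 1) gives (6, 4), and (1, 1) leaves the shape unchanged because only a single dimension is "merged".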
First, with the default values start_dim=0 and end_dim=-1, this function flattens input into a one-dimensional tensor of shape [n], where n is the number of elements in input.
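A quick check of the default behaviour, using a tensor built with torch.arange purely for illustration:

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)  # illustrative 3-D tensor holding 0..23
flat = torch.flatten(t)  # start_dim=0, end_dim=-1 by default
print(flat.shape)  # torch.Size([24])
print(flat.shape[0] == t.numel())  # True: n equals the number of elements
# Elements come out in row-major order, matching reshape(-1)
print(torch.equal(flat, t.reshape(-1)))  # True
```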
What if we want to set the start and end dimensions ourselves?
Let's first look at the shape of an example tensor t:
import torch
t = torch.tensor([[[1, 2, 2, 1],
                   [3, 4, 4, 3],
                   [1, 2, 3, 4]],
                  [[5, 6, 6, 5],
                   [7, 8, 8, 7],
                   [5, 6, 7, 8]]])
print(t, t.shape)
Output:
tensor([[[1, 2, 2, 1],
         [3, 4, 4, 3],
         [1, 2, 3, 4]],

        [[5, 6, 6, 5],
         [7, 8, 8, 7],
         [5, 6, 7, 8]]]) torch.Size([2, 3, 4])
We can see that the outermost square bracket contains two elements, so the first value of the shape is 2. Similarly, each second-level bracket contains three elements, so the second value is 3, and each innermost bracket contains four elements, so the third value is 4.
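The bracket-counting rule can be written down directly. The helper nested_shape below is hypothetical, a sketch that assumes the nested list is rectangular (every sublist at a given level has the same length), since it only inspects the first element at each level.

```python
import torch

def nested_shape(data):
    """Hypothetical helper: read a rectangular nested list's shape level by level."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))  # number of elements inside this bracket level
        if not data:             # an empty list has no deeper level to inspect
            break
        data = data[0]           # descend one bracket level
    return tuple(shape)

data = [[[1, 2, 2, 1], [3, 4, 4, 3], [1, 2, 3, 4]],
        [[5, 6, 6, 5], [7, 8, 8, 7], [5, 6, 7, 8]]]
print(nested_shape(data))        # (2, 3, 4)
print(torch.tensor(data).shape)  # torch.Size([2, 3, 4])
```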
Sample code:
x = torch.flatten(t, start_dim=1)  # merge dimension 1 through the last (end_dim defaults to -1)
print(x, x.shape)
y = torch.flatten(t, start_dim=0, end_dim=1)  # merge dimensions 0 and 1
print(y, y.shape)
Output:
tensor([[1, 2, 2, 1, 3, 4, 4, 3, 1, 2, 3, 4],
        [5, 6, 6, 5, 7, 8, 8, 7, 5, 6, 7, 8]]) torch.Size([2, 12])
tensor([[1, 2, 2, 1],
        [3, 4, 4, 3],
        [1, 2, 3, 4],
        [5, 6, 6, 5],
        [7, 8, 8, 7],
        [5, 6, 7, 8]]) torch.Size([6, 4])
It can be seen that when start_dim = 1 and end_dim = -1, dimension 1 through the last dimension are flattened and merged into one, so the shape [2, 3, 4] becomes [2, 12] (3 × 4 = 12). When start_dim = 0 and end_dim = 1, dimensions 0 and 1 are merged, so [2, 3, 4] becomes [6, 4] (2 × 3 = 6). The torch.nn.Flatten module and the torch.Tensor.flatten method in PyTorch perform the same operation as the torch.flatten function above; note that torch.nn.Flatten defaults to start_dim=1, so it keeps the first (batch) dimension.
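The function, method, and module forms can be compared side by side. This sketch uses torch.arange to build an illustrative [2, 3, 4] tensor and checks that all three agree when they flatten the same dimensions.

```python
import torch
import torch.nn as nn

t = torch.arange(24).reshape(2, 3, 4)  # illustrative [2, 3, 4] tensor

a = torch.flatten(t, start_dim=1)  # function form
b = t.flatten(start_dim=1)         # Tensor method form
c = nn.Flatten()(t)                # module form; its defaults are start_dim=1, end_dim=-1
d = nn.Flatten(start_dim=0)(t)     # module configured like torch.flatten's defaults

print(a.shape, b.shape, c.shape)                # torch.Size([2, 12]) for each
print(torch.equal(a, b) and torch.equal(a, c))  # True
print(d.shape)                                  # torch.Size([24])
```

Because the module keeps dimension 0 by default, it can sit inside an nn.Sequential and flatten each sample while leaving the batch dimension alone.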