一. 用法
Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。
二. 参数
1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim
被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。
2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim
被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。
三. 实例
(1). 首先随机定义一个满足正态分布的(2,3,4)的数据x
import torch
x = torch.randn(2,3,4)
print(x)
x = x.flatten(0)
print(x)
------------------------------------
tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721],
[ 1.2374, -0.6929, 1.1186, 0.4372],
[ 0.5122, 1.4653, -0.1673, 0.7258]],
[[ 0.2772, -1.9994, -1.2284, 0.2764],
[-0.0451, -0.9195, 0.5749, 0.1942],
[ 0.8539, -0.0434, -0.7313, 0.0234]]])
tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372,
0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764,
-0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234])
此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。
(2).
import torch
x = torch.randn(2,3,4)
print(x)
x = x.flatten(1)
print(x)
===========================================
tensor([[[-0.7137, -0.0859, -1.5284, 0.7284],
[ 0.8425, 0.3606, 1.7639, 0.1848],
[ 0.4040, -1.6575, 1.9134, -1.0787]],
[[ 0.6981, 1.3494, -0.5817, -1.1824],
[-0.4972, 0.4179, 2.1742, -0.2462],
[ 0.2429, -1.9315, -0.3497, 0.7190]]])
tensor([[-0.7137, -0.0859, -1.5284, 0.7284, 0.8425, 0.3606, 1.7639, 0.1848,
0.4040, -1.6575, 1.9134, -1.0787],
[ 0.6981, 1.3494, -0.5817, -1.1824, -0.4972, 0.4179, 2.1742, -0.2462,
0.2429, -1.9315, -0.3497, 0.7190]])
此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)
注意:start_dim
和end_dim
参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim()
之间。