PaddlePaddle learning - paddle.nn.Flatten

paddle.nn.Flatten

paddle.nn.Flatten(start_axis=1, stop_axis=-1)

Function:
Flattens a contiguous range of a Tensor's dimensions, from start_axis through stop_axis inclusive, into a single dimension.
Parameters:
start_axis (int, optional) - the first dimension to flatten. Default: 1.
stop_axis (int, optional) - the last dimension to flatten (inclusive). Default: -1, i.e. the last dimension.
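The shape rule can be sketched in plain Python. `flattened_shape` below is a hypothetical helper, not part of Paddle; it assumes both axes are already valid for the given shape and simply multiplies together the sizes of axes start_axis through stop_axis.

```python
from math import prod


def flattened_shape(shape, start_axis=1, stop_axis=-1):
    """Shape after flattening axes start_axis..stop_axis (inclusive).

    Hypothetical helper for illustration; assumes the axes are already valid.
    """
    rank = len(shape)
    # Negative axes count from the end, so -1 is the last axis.
    if start_axis < 0:
        start_axis += rank
    if stop_axis < 0:
        stop_axis += rank
    # The merged axis holds as many elements as the axes it replaces.
    merged = prod(shape[start_axis:stop_axis + 1])
    return list(shape[:start_axis]) + [merged] + list(shape[stop_axis + 1:])


print(flattened_shape([5, 2, 3, 4]))         # defaults -> [5, 24]
print(flattened_shape([5, 2, 3, 4], 1, 2))   # -> [5, 6, 4]
print(flattened_shape([5, 2, 3, 4], 0, -1))  # -> [120]
```

With the defaults, every axis except the first (usually the batch axis) is merged, which is why paddle.nn.Flatten is commonly placed between convolutional and fully connected layers.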

For example, take a four-dimensional tensor of shape [5, 2, 3, 4]. With the default start_axis=1, flattening starts at the axis of size 2; with the default stop_axis=-1 (a negative axis counts backward from the end, so -1 is the last axis), it ends at the axis of size 4. paddle.nn.Flatten() with its default arguments therefore merges the three axes of sizes 2, 3 and 4 into one axis of size 24, giving shape [5, 24]. The code below instead sets stop_axis=2, so only the axes of sizes 2 and 3 are merged, giving shape [5, 6, 4]:

import paddle
import numpy as np

inp_np = np.ones([5, 2, 3, 4]).astype('float32')
inp = paddle.to_tensor(inp_np)
# Merge axes 1 and 2 (sizes 2 and 3) into a single axis of size 6
flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
flatten_res = flatten(inp)
print(flatten_res.shape)  # [5, 6, 4]
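Flattening changes only how elements are addressed, not their order. Assuming the row-major (C-order) element order that reshape-style operations use, element [i0, i1, i2, i3] of the [5, 2, 3, 4] input becomes element [i0, i1 * 3 + i2, i3] of the [5, 6, 4] result. The pure-Python sketch below checks that mapping for all 120 elements; `flat_offset` is a hypothetical helper, not a Paddle function.

```python
from itertools import product


def flat_offset(index, shape):
    """Row-major (C-order) offset of a multi-dimensional index (hypothetical helper)."""
    offset = 0
    for i, n in zip(index, shape):
        offset = offset * n + i
    return offset


in_shape = [5, 2, 3, 4]
out_shape = [5, 6, 4]  # axes 1 and 2 merged into one axis of size 2 * 3 = 6

# Map every input index [i0, i1, i2, i3] to [i0, i1 * 3 + i2, i3] and
# confirm that both indices refer to the same row-major position.
checked = 0
for i0, i1, i2, i3 in product(*(range(n) for n in in_shape)):
    out_index = (i0, i1 * 3 + i2, i3)
    assert flat_offset((i0, i1, i2, i3), in_shape) == flat_offset(out_index, out_shape)
    checked += 1

print(checked, "elements keep their row-major position")
```

Because no element moves, flattening a contiguous range behaves like a reshape to the computed shape.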

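Source code of `flatten_`, which its docstring describes as the in-place version of the `flatten` API. Note that this functional form defaults to start_axis=0, whereas paddle.nn.Flatten defaults to start_axis=1: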

def flatten_(x, start_axis=0, stop_axis=-1, name=None):
    """
    Inplace version of ``flatten`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_tensor_flatten`.
    """
    if not (isinstance(x, Variable)):
        raise ValueError("The input x should be a Tensor")

    x_dim = len(x.shape)
    if not (isinstance(start_axis, int)) or (
            start_axis > x_dim - 1) or start_axis < -x_dim:
        raise ValueError(
            "The start_axis should be a int, and in range [-rank(x), rank(x))")
    if not (isinstance(stop_axis, int)) or (
            stop_axis > x_dim - 1) or stop_axis < -x_dim:
        raise ValueError(
            "The stop_axis should be a int, and in range [-rank(x), rank(x))")
    if start_axis < 0:
        start_axis = start_axis + x_dim
    if stop_axis < 0:
        stop_axis = stop_axis + x_dim
    if start_axis > stop_axis:
        raise ValueError("The stop_axis should be larger than stat_axis")

    dy_out, _ = _C_ops.flatten_contiguous_range_(x, 'start_axis', start_axis,
                                                 'stop_axis', stop_axis)
    return dy_out
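The argument checks above can be exercised without Paddle. The sketch below is a hypothetical standalone helper, `normalize_flatten_axes`, that applies the same three rules (each axis must be an int in [-rank, rank), negative axes are shifted by the rank, and start_axis must not exceed stop_axis) and returns the normalized pair; it is not part of Paddle's API, and its last error message is reworded.

```python
def normalize_flatten_axes(rank, start_axis, stop_axis):
    """Apply the same axis checks as flatten_ (hypothetical standalone helper).

    Returns the non-negative (start_axis, stop_axis) pair or raises ValueError.
    """
    for name, axis in (("start_axis", start_axis), ("stop_axis", stop_axis)):
        # A valid axis is an int in the half-open range [-rank, rank).
        if not isinstance(axis, int) or axis > rank - 1 or axis < -rank:
            raise ValueError(
                f"The {name} should be a int, and in range [-rank(x), rank(x))")
    # Convert negative axes to their non-negative equivalents.
    if start_axis < 0:
        start_axis += rank
    if stop_axis < 0:
        stop_axis += rank
    # The range to flatten must not be empty.
    if start_axis > stop_axis:
        raise ValueError("The stop_axis should not be smaller than start_axis")
    return start_axis, stop_axis


print(normalize_flatten_axes(4, 1, -1))  # (1, 3)
try:
    normalize_flatten_axes(4, 3, 1)
except ValueError as err:
    print("ValueError:", err)
```

Since the check is start_axis > stop_axis, start_axis == stop_axis is accepted: flattening a single axis is allowed and leaves the shape unchanged.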


Origin blog.csdn.net/m0_66478571/article/details/122611283