Verwendung der Funktion flatten() in PyTorch

1. Nutzung

Flatten Layer wird hauptsächlich zum „Flatten“ der Eingabe verwendet, das heißt, um mehrdimensionale Eingaben in eine Dimension umzuwandeln, und wird in Faltungsschichten zu verwendet Übergang der vollständig verbundenen Schicht. Dies hat keinen Einfluss auf die Stapelgröße. Es kann so verstanden werden, dass das Array mit hohem Breitengrad entsprechend der x-Achse oder der y-Achse in ein eindimensionales Array gestreckt wird.

2. Parameter

      1.start_dim (optionaler Parameter): Geben Sie die Dimension an, ab der mit der Abflachung des Tensors begonnen werden soll. Standardmäßig ist start_dim auf 0 gesetzt, was bedeutet, dass die Reduzierung ab der ersten Dimension (normalerweise der Stapelgröße) beginnt. Bei Festlegung auf einen anderen ganzzahligen Wert beginnt die Reduzierung ab der angegebenen Dimension.

       2.end_dim (optionaler Parameter): Geben Sie die Dimension an, in der der abgeflachte Tensor enden soll. Standardmäßig ist end_dim auf -1 eingestellt, was eine Reduzierung auf die letzte Dimension bedeutet. Bei Festlegung auf einen anderen ganzzahligen Wert endet die Reduzierung bei der angegebenen Dimension.

3. Beispiele

 (1). Definieren Sie zunächst zufällig Daten x, die die Normalverteilung (2, 3, 4) erfüllen.

import torch 

x = torch.randn(2,3,4)
print(x)
x = x.flatten(0)
print(x)

------------------------------------
tensor([[[ 0.1281,  1.6878,  0.2301, -0.0721],
         [ 1.2374, -0.6929,  1.1186,  0.4372],
         [ 0.5122,  1.4653, -0.1673,  0.7258]],

        [[ 0.2772, -1.9994, -1.2284,  0.2764],
         [-0.0451, -0.9195,  0.5749,  0.1942],
         [ 0.8539, -0.0434, -0.7313,  0.0234]]])
tensor([ 0.1281,  1.6878,  0.2301, -0.0721,  1.2374, -0.6929,  1.1186,  0.4372,
         0.5122,  1.4653, -0.1673,  0.7258,  0.2772, -1.9994, -1.2284,  0.2764,
        -0.0451, -0.9195,  0.5749,  0.1942,  0.8539, -0.0434, -0.7313,  0.0234])

Zu diesem Zeitpunkt beträgt die Dimension von x 2 × 3 × 4 = 24, und die Ergebnisse von x = flatten (0) und x = flatten () sind dieselben.

 (2).

import torch 

x = torch.randn(2,3,4)
print(x)
x = x.flatten(1)
print(x)

===========================================
tensor([[[-0.7137, -0.0859, -1.5284,  0.7284],
         [ 0.8425,  0.3606,  1.7639,  0.1848],
         [ 0.4040, -1.6575,  1.9134, -1.0787]],

        [[ 0.6981,  1.3494, -0.5817, -1.1824],
         [-0.4972,  0.4179,  2.1742, -0.2462],
         [ 0.2429, -1.9315, -0.3497,  0.7190]]])
tensor([[-0.7137, -0.0859, -1.5284,  0.7284,  0.8425,  0.3606,  1.7639,  0.1848,
          0.4040, -1.6575,  1.9134, -1.0787],
        [ 0.6981,  1.3494, -0.5817, -1.1824, -0.4972,  0.4179,  2.1742, -0.2462,
          0.2429, -1.9315, -0.3497,  0.7190]])

Zu diesem Zeitpunkt wird x von Dimension 1 erweitert und die endgültige x-Dimension ist (2, 3 × 4), was (2, 12) ist.

Hinweis: Der Wertebereich der Parameter start_dim und end_dim sollte zwischen -x.dim() <= start_dim <= end_dim < x.dim() liegen.

Guess you like

Origin blog.csdn.net/m0_62278731/article/details/134263429