Aprendizaje Pytorch (3) Capa lineal

prefacio

Antes de hablar de la transformación lineal lineal, veamos un ejemplo de transformación matricial.

from __future__ import print_function
import torch

in_features = torch.tensor([2,2,2,2], dtype=torch.float32)

weight_matrix = torch.tensor([
    [5,5,5],
    [3,3,3],
    [4,4,4],
    [2,2,2]
], dtype=torch.float32)

out_features = in_features.matmul(weight_matrix)
print(out_features)

Impresión:
tensor ([28., 28., 28.])

El ejemplo crea un tensor unidimensional llamado in_features, un tensor bidimensional de peso_matrix matriz de peso. Luego, use la función matmul() para realizar una operación de multiplicación de matrices que produzca un tensor unidimensional. Mapea un tensor 1D con cuatro elementos a un tensor 1D con tres elementos.

Así es como funciona Lineal. Usan una matriz de ponderación para mapear un espacio in_feature a un espacio out_feature.
Principio: Aplicar una transformación lineal a los datos de entrada: y = xA^T + b.

prototipo de función

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

Descripción de parámetros

inserte la descripción de la imagen aquí
in_features: se refiere al tamaño del tensor bidimensional de entrada, es decir, el tamaño en la entrada [batch_size, size].
out_features: se refiere al tamaño del tensor bidimensional de salida, es decir, la forma del tensor bidimensional de salida es [batch_size, out_features].
Volviendo al ejemplo de multiplicación de matrices al principio del artículo, la estructura de tensor unidimensional de in_features creada es [1,4], luego 4 es el tamaño de la entrada de tensor bidimensional por Linear, y la estructura de el tensor bidimensional de la matriz de peso weight_matrix es [4 ,3], entonces 3 aquí es el tamaño del tensor bidimensional generado por Linear. De acuerdo con las reglas del álgebra lineal de la multiplicación de matrices, cuando pasamos in_features = 4 y out_features = 3 a la función Linear(), la clase PyTorch LinearLayer creará automáticamente una matriz de peso de 4 x 3.

Ejemplo

m = torch.nn.Linear(4, 3)
input = torch.tensor([2,2,2,2], dtype=torch.float32)
output = m(input)
print(output)

imprimir:
inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/weixin_44901043/article/details/123765600
Recomendado
Clasificación