Directorio de artículos
prefacio
Antes de hablar de la transformación lineal lineal, veamos un ejemplo de transformación matricial.
from __future__ import print_function
import torch
in_features = torch.tensor([2,2,2,2], dtype=torch.float32)
weight_matrix = torch.tensor([
[5,5,5],
[3,3,3],
[4,4,4],
[2,2,2]
], dtype=torch.float32)
out_features = in_features.matmul(weight_matrix)
print(out_features)
Impresión:
tensor ([28., 28., 28.])
El ejemplo crea un tensor unidimensional llamado in_features, un tensor bidimensional de peso_matrix matriz de peso. Luego, use la función matmul() para realizar una operación de multiplicación de matrices que produzca un tensor unidimensional. Mapea un tensor 1D con cuatro elementos a un tensor 1D con tres elementos.
Así es como funciona Lineal. Usan una matriz de ponderación para mapear un espacio in_feature a un espacio out_feature.
Principio: Aplicar una transformación lineal a los datos de entrada: y = xA^T + b.
prototipo de función
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
Descripción de parámetros
in_features: se refiere al tamaño del tensor bidimensional de entrada, es decir, el tamaño en la entrada [batch_size, size].
out_features: se refiere al tamaño del tensor bidimensional de salida, es decir, la forma del tensor bidimensional de salida es [batch_size, out_features].
Volviendo al ejemplo de multiplicación de matrices al principio del artículo, la estructura de tensor unidimensional de in_features creada es [1,4], luego 4 es el tamaño de la entrada de tensor bidimensional por Linear, y la estructura de el tensor bidimensional de la matriz de peso weight_matrix es [4 ,3], entonces 3 aquí es el tamaño del tensor bidimensional generado por Linear. De acuerdo con las reglas del álgebra lineal de la multiplicación de matrices, cuando pasamos in_features = 4 y out_features = 3 a la función Linear(), la clase PyTorch LinearLayer creará automáticamente una matriz de peso de 4 x 3.
Ejemplo
m = torch.nn.Linear(4, 3)
input = torch.tensor([2,2,2,2], dtype=torch.float32)
output = m(input)
print(output)
imprimir: