Enlace de referencia Explicación detallada de nn.Linear () de PyTorch - douzujun - Blog Park (cnblogs.com)
Aquí hay una demostración de un tensor bidimensional completamente conectado:
De hecho, también puedes ingresar un tensor tridimensional, como se muestra a continuación:
from torch import nn
import torch
# in_features由输入张量的形状决定,out_features则决定了输出张量的形状
linear = nn.Linear(in_features=64 * 3, out_features=5)
# 10个 大小为7*64*3, 3个channel 的张量
a = torch.rand(10, 3, 7, 64 * 3)
print(a.shape) # torch.Size([10, 3, 7, 192])
print(linear.weight.shape) # torch.Size([5, 192])
b = linear(a)
print(b.shape) # torch.Size([10, 3, 7, 5])