PyTorch learning (3): Linear layer

Foreword

Before discussing the linear transformation performed by Linear, let's look at an example of a matrix multiplication.

import torch

# A 1D tensor with 4 elements (the input features)
in_features = torch.tensor([2, 2, 2, 2], dtype=torch.float32)

# A 4 x 3 weight matrix
weight_matrix = torch.tensor([
    [5, 5, 5],
    [3, 3, 3],
    [4, 4, 4],
    [2, 2, 2]
], dtype=torch.float32)

# Shape [4] times shape [4, 3] gives shape [3]
out_features = in_features.matmul(weight_matrix)
print(out_features)

Output:
tensor([28., 28., 28.])

The example creates a one-dimensional tensor, in_features, and a two-dimensional weight matrix, weight_matrix. The matmul() function then multiplies them, producing a one-dimensional tensor: a 1D tensor with four elements is mapped to a 1D tensor with three elements. Each output element is the dot product of the input with one column of the weight matrix, for example 2*5 + 2*3 + 2*4 + 2*2 = 28.
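As a small sketch of the shape rules at work here (the names out_1d and out_2d are illustrative only), the same multiplication can be done with the @ operator, both on the 1D input and on the same data viewed as a row vector of shape [1, 4]:

```python
import torch

in_features = torch.tensor([2, 2, 2, 2], dtype=torch.float32)
weight_matrix = torch.tensor([[5, 5, 5],
                              [3, 3, 3],
                              [4, 4, 4],
                              [2, 2, 2]], dtype=torch.float32)

# 1D input of shape [4]: the result has shape [3]
out_1d = in_features @ weight_matrix
# The same data as a row vector of shape [1, 4]: the result has shape [1, 3]
out_2d = in_features.unsqueeze(0) @ weight_matrix

print(out_1d.shape)  # torch.Size([3])
print(out_2d.shape)  # torch.Size([1, 3])
print(torch.equal(out_1d, out_2d.squeeze(0)))  # True
```

Both forms give the same three values; the row-vector form simply keeps an explicit leading dimension of size 1.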

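This is essentially how a Linear layer works: it uses a weight matrix to map an in_features-dimensional input space to an out_features-dimensional output space.
Principle: Linear applies a linear transformation to the input data, y = xA^T + b, where A is the layer's weight matrix and b is its bias vector.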
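A minimal sketch that checks the formula numerically: it computes x @ A^T + b by hand from the layer's own weight and bias and compares the result with calling the layer directly. The seed, batch size, and tolerance are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed so the random initialization is repeatable

layer = torch.nn.Linear(4, 3)
x = torch.randn(5, 4)  # a batch of 5 samples, 4 features each

# The weight A is stored with shape [out_features, in_features] = [3, 4]
print(layer.weight.shape)  # torch.Size([3, 4])
print(layer.bias.shape)    # torch.Size([3])

# y = x A^T + b, computed by hand
y_manual = x @ layer.weight.T + layer.bias
y_layer = layer(x)

# Equal up to floating-point rounding
print(torch.allclose(y_manual, y_layer, atol=1e-6))  # True
```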

Function prototype

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
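A short sketch of how the optional arguments in the prototype change the layer that gets created (the variable names are illustrative only):

```python
import torch

# Default: bias=True, so the layer has both a weight and a bias parameter
with_bias = torch.nn.Linear(4, 3)
# bias=False: the layer computes only y = x A^T
no_bias = torch.nn.Linear(4, 3, bias=False)
# dtype selects the data type of the parameters (here double precision)
double_layer = torch.nn.Linear(4, 3, dtype=torch.float64)

print([name for name, _ in with_bias.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in no_bias.named_parameters()])    # ['weight']
print(no_bias.bias)               # None
print(double_layer.weight.dtype)  # torch.float64
```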

Parameters

in_features: the size of each input sample, i.e. the last dimension of an input of shape [batch_size, in_features].
out_features: the size of each output sample; the output has shape [batch_size, out_features].
bias: if set to False, the layer does not learn the additive bias b (default: True).
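To illustrate the shapes described above, here is a sketch with a random batch; the batch sizes 8 and 2 x 5 are arbitrary. Linear also accepts extra leading dimensions, as long as the last dimension equals in_features:

```python
import torch

layer = torch.nn.Linear(in_features=4, out_features=3)

batch = torch.randn(8, 4)       # [batch_size, in_features]
print(layer(batch).shape)       # torch.Size([8, 3])

# Only the last dimension must equal in_features
seq_batch = torch.randn(2, 5, 4)
print(layer(seq_batch).shape)   # torch.Size([2, 5, 3])
```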
Going back to the matrix multiplication example at the beginning of this article: the in_features tensor has 4 elements (it can be viewed as a row vector of shape [1, 4]), so 4 corresponds to the in_features argument of Linear. The weight_matrix has shape [4, 3], so 3 corresponds to out_features. When we pass in_features=4 and out_features=3 to Linear(), PyTorch automatically creates the weight parameter. Note that it is stored with shape [out_features, in_features], i.e. 3 x 4, which is why the formula uses its transpose A^T, a 4 x 3 matrix with the same shape as weight_matrix.
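As a sketch connecting the two examples, the opening matrix multiplication can be reproduced with a Linear layer by disabling the bias and loading the transposed weight_matrix into the layer's weight. Overwriting the weight by hand like this is only for illustration; normally the weight is learned.

```python
import torch

weight_matrix = torch.tensor([[5, 5, 5],
                              [3, 3, 3],
                              [4, 4, 4],
                              [2, 2, 2]], dtype=torch.float32)  # shape [4, 3]

layer = torch.nn.Linear(4, 3, bias=False)
print(layer.weight.shape)  # torch.Size([3, 4]), i.e. [out_features, in_features]

# Load the transpose so that x @ layer.weight.T equals x @ weight_matrix
with torch.no_grad():
    layer.weight.copy_(weight_matrix.T)

x = torch.tensor([2, 2, 2, 2], dtype=torch.float32)
out = layer(x).detach()  # detach() drops the autograd history for printing
print(out)  # tensor([28., 28., 28.])
```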

Example

import torch

m = torch.nn.Linear(4, 3)  # maps 4 input features to 3 output features
x = torch.tensor([2, 2, 2, 2], dtype=torch.float32)
output = m(x)  # equivalent to x @ m.weight.T + m.bias
print(output)

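Output: a one-dimensional tensor with three elements. Unlike the matmul example, the exact values differ from run to run, because Linear initializes its weight and bias randomly.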
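Because the parameters are random, a sketch like the following (the helper build_and_run is hypothetical) fixes the seed with torch.manual_seed to make the output repeatable. It also checks the default initialization range, which the PyTorch documentation for nn.Linear gives as U(-1/sqrt(in_features), 1/sqrt(in_features)) for both weight and bias:

```python
import math
import torch

def build_and_run(seed):
    """Hypothetical helper: create Linear(4, 3) under a fixed seed and apply it."""
    torch.manual_seed(seed)  # fixes the random weight/bias initialization
    m = torch.nn.Linear(4, 3)
    x = torch.tensor([2, 2, 2, 2], dtype=torch.float32)
    with torch.no_grad():
        out = m(x)
    return m, out

m1, out1 = build_and_run(42)
m2, out2 = build_and_run(42)

print(out1.shape)               # torch.Size([3])
print(torch.equal(out1, out2))  # True: same seed, same initial parameters

# Default init bound: 1/sqrt(in_features), i.e. 0.5 for in_features = 4
bound = 1 / math.sqrt(4)
print(bool(m1.weight.abs().max() <= bound), bool(m1.bias.abs().max() <= bound))  # True True
```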


Source: blog.csdn.net/weixin_44901043/article/details/123765600