Understanding the torch.nn.Linear() function

import torch

x = torch.randn(128, 20)  # input has shape (128, 20): 128 samples, 20 features each
m = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)

# torch.mm(x, torch.t(m.weight)) + m.bias is equivalent to the following:
ans = torch.mm(x, m.weight.t()) + m.bias
print('ans.shape:\n', ans.shape)

print(torch.equal(ans, output))

Output:

m.weight.shape:
torch.Size([30, 20])
m.bias.shape:
torch.Size([30])
output.shape:
torch.Size([128, 30])
ans.shape:
torch.Size([128, 30])
True
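
Exact comparison with torch.equal worked in the run above, but floating-point results from two different computation paths are not guaranteed to match bit for bit. As a hedged sketch, assuming the default float32 dtype, the same check can be written with a tolerance (the seed and the atol value here are illustrative choices, not part of the original post):

```python
import torch

torch.manual_seed(0)  # illustrative seed, only for reproducibility

x = torch.randn(128, 20)
m = torch.nn.Linear(20, 30)

output = m(x)
# Manual version of the same affine map: x @ weight^T + bias
ans = torch.mm(x, m.weight.t()) + m.bias

# Compare shapes exactly and values within a small tolerance
shapes_match = ans.shape == output.shape
close = torch.allclose(ans, output, atol=1e-5)
print(shapes_match, close)
```

torch.allclose is the more forgiving choice when the two results may have been rounded in a different order.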

Why is m.weight.shape = (30, 20)?

A: Because the linear transformation formula is:

$$y = xA^T + b$$

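A weight A of shape (30, 20), i.e. (out_features, in_features), is created first and then transposed during the actual computation. The transposed weight has shape (20, 30), so it can be matrix-multiplied with x of shape (128, 20), giving an output of shape (128, 30).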
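
The shape reasoning above can be sketched by building the weight by hand. This is a minimal illustration, not the library's internal code: the names A, b, y_manual and y_linear are hypothetical, and torch.nn.functional.linear is used only as a reference that takes the weight in the same (out_features, in_features) layout.

```python
import torch
import torch.nn.functional as F

in_features, out_features = 20, 30

x = torch.randn(128, in_features)
A = torch.randn(out_features, in_features)  # stored as (30, 20), like m.weight
b = torch.randn(out_features)               # stored as (30,), like m.bias

# Without the transpose the inner dimensions do not line up: (128, 20) @ (30, 20)
try:
    x @ A
    untransposed_ok = True
except RuntimeError:
    untransposed_ok = False

# With the transpose: (128, 20) @ (20, 30) -> (128, 30); b broadcasts over rows
y_manual = x @ A.t() + b

# F.linear takes the weight untransposed and applies the same formula
y_linear = F.linear(x, A, b)

print(untransposed_ok)
print(y_manual.shape)
print(torch.allclose(y_manual, y_linear, atol=1e-4))
```

Storing the weight as (out_features, in_features) means each row of A holds the coefficients for one output feature.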
---------------------
Author: m0_37586991
Source: CSDN
Original: https://blog.csdn.net/m0_37586991/article/details/87861418
Copyright: This is the blogger's original article. If you reproduce it, please include a link to the original post.

Origin www.cnblogs.com/jfdwd/p/11068544.html