Ten methods of PyTorch weight initialization

PyTorch provides the commonly used initialization functions in torch.nn.init. Here is a brief introduction to them, for easy reference and use.

The introduction is divided into two parts:

1. The Xavier and Kaiming families;

2. Other distribution-based methods

The Xavier initialization method comes from the paper "Understanding the difficulty of training deep feedforward neural networks".

The formula is derived from the principle of keeping the variance consistent across layers, and the initial distribution comes in two flavors: a uniform distribution and a normal distribution.

1. Xavier uniform distribution

torch.nn.init.xavier_uniform_(tensor, gain=1)

The Xavier uniform method draws values from a uniform distribution U(−a, a), where a = gain * sqrt(6 / (fan_in + fan_out)).

The gain factor is set according to the type of activation function, for example:

nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

PS: This initialization method is also known as Glorot initialization.
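
A minimal sketch of applying Xavier uniform initialization to every Linear layer of a model (the layer sizes here are arbitrary, for illustration only):

import torch.nn as nn

def init_xavier(m):
    # initialize every Linear layer: Xavier uniform for weights, zeros for biases
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.constant_(m.bias, 0.0)

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
model.apply(init_xavier)  # apply() calls init_xavier on every submodule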

2. Xavier normal distribution

torch.nn.init.xavier_normal_(tensor, gain=1)

The Xavier normal method draws values from a normal distribution with

mean = 0, std = gain * sqrt(2 / (fan_in + fan_out))
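
For illustration, a small sketch on a standalone weight tensor (the 256 x 128 shape is an arbitrary assumption); the observed standard deviation should be close to gain * sqrt(2 / (fan_in + fan_out)):

import torch
import torch.nn as nn

w = torch.empty(256, 128)            # hypothetical weight: fan_out=256, fan_in=128
nn.init.xavier_normal_(w, gain=1.0)
print(w.std())                       # roughly sqrt(2 / (128 + 256)) ≈ 0.072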

 

The Kaiming initialization method comes from the paper "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification". Its formula is also derived from the principle of variance consistency. Kaiming initialization was proposed as an improvement for cases where Xavier initialization performs poorly with ReLU-type activation functions; please refer to the paper for details.

3. Kaiming uniform distribution

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

This draws from a uniform distribution U(−bound, bound), where bound = sqrt(6 / ((1 + a^2) * fan_in)).

a - the slope of the negative half-axis of the activation function (0 for ReLU).

mode - either fan_in or fan_out; fan_in keeps the variance consistent during forward propagation, while fan_out keeps the variance consistent during backward propagation.

nonlinearity - either relu or leaky_relu; the default is leaky_relu.

nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
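
A fuller sketch for a convolutional layer (the channel counts and kernel size are arbitrary examples):

import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3)   # hypothetical conv layer
# fan_in mode preserves the variance of activations in the forward pass under ReLU
nn.init.kaiming_uniform_(conv.weight, mode='fan_in', nonlinearity='relu')
nn.init.constant_(conv.bias, 0.0)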

4. Kaiming normal distribution

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

This draws from a normal distribution N(0, std) with mean 0, where std = sqrt(2 / ((1 + a^2) * fan_in)).

a - the slope of the negative half-axis of the activation function (0 for ReLU).

mode - either fan_in or fan_out; fan_in keeps the variance consistent during forward propagation, while fan_out keeps the variance consistent during backward propagation.

nonlinearity - either relu or leaky_relu; the default is leaky_relu.

nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
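
A minimal sketch initializing a whole (hypothetical) convolutional model in one pass; fan_out with ReLU is the combination used by many ResNet-style implementations:

import torch.nn as nn

def init_kaiming(m):
    # Kaiming normal for conv/linear weights, zeros for biases
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
model.apply(init_kaiming)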

5. Uniform distribution initialization

torch.nn.init.uniform_(tensor, a=0, b=1)

Fills the tensor with values drawn from the uniform distribution U(a, b).
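
A minimal sketch (the tensor shape and bounds are arbitrary assumptions):

import torch
import torch.nn as nn

w = torch.empty(3, 5)               # hypothetical weight tensor
nn.init.uniform_(w, a=-0.1, b=0.1)  # values drawn from U(-0.1, 0.1)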

 

6. Normal distribution initialization

torch.nn.init.normal_(tensor, mean=0, std=1)

Fills the tensor with values drawn from the normal distribution N(mean, std); the defaults are mean=0, std=1.
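
A minimal sketch (the shape and std are arbitrary assumptions):

import torch
import torch.nn as nn

w = torch.empty(3, 5)                   # hypothetical weight tensor
nn.init.normal_(w, mean=0.0, std=0.02)  # values drawn from N(0, 0.02)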

 

7. Constant initialization

torch.nn.init.constant_(tensor, val)

Fills the tensor with the constant value val, e.g. nn.init.constant_(w, 0.3).
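
A minimal sketch, e.g. for a bias vector (the length 10 is an arbitrary assumption):

import torch
import torch.nn as nn

b = torch.empty(10)          # hypothetical bias vector
nn.init.constant_(b, 0.3)    # every element becomes 0.3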

 

8. Identity matrix initialization

torch.nn.init.eye_(tensor)

Initializes a two-dimensional tensor to the identity matrix.
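
A minimal sketch (the 5 x 5 shape is an arbitrary assumption; the tensor must be 2-dimensional):

import torch
import torch.nn as nn

w = torch.empty(5, 5)   # hypothetical 2-D weight
nn.init.eye_(w)
print(w)                # 5 x 5 identity matrix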

 

9. Orthogonal initialization

torch.nn.init.orthogonal_(tensor, gain=1)

Makes the tensor (semi-)orthogonal; see the paper "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. (2013).
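
A minimal sketch; orthogonal initialization is often used for recurrent weight matrices (the RNN sizes here are arbitrary assumptions):

import torch.nn as nn

rnn = nn.RNN(64, 128)                            # hypothetical RNN
nn.init.orthogonal_(rnn.weight_hh_l0, gain=1.0)  # 128 x 128 hidden-to-hidden weight becomes orthogonal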

 

10. Sparse initialization

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

Fills the tensor with values drawn from the normal distribution N(0, std), then sets a fraction of the elements in each column to zero.

sparsity - the fraction of elements in each column to be set to zero.

nn.init.sparse_(w, sparsity=0.1)
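
A minimal sketch (the 100 x 20 shape is an arbitrary assumption; the tensor must be 2-dimensional):

import torch
import torch.nn as nn

w = torch.empty(100, 20)                   # hypothetical 2-D weight
nn.init.sparse_(w, sparsity=0.1, std=0.01)
# about 10% of the entries in each column are zero, the rest come from N(0, 0.01)
print((w == 0).float().mean())             # roughly 0.1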

11. Calculate the gain

torch.nn.init.calculate_gain(nonlinearity, param=None)
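
Returns the recommended gain value for the given nonlinearity (for example, sqrt(2) for relu); param is the optional parameter of the nonlinearity, such as the negative slope of leaky_relu. A minimal sketch (the layer sizes are arbitrary assumptions):

import torch.nn as nn

gain_relu = nn.init.calculate_gain('relu')               # sqrt(2) ≈ 1.414
gain_leaky = nn.init.calculate_gain('leaky_relu', 0.2)   # sqrt(2 / (1 + 0.2**2))
nn.init.xavier_uniform_(nn.Linear(128, 64).weight, gain=gain_relu)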


Origin blog.csdn.net/weixin_36670529/article/details/105257050