PyTorch parameter initialization
- 1. Common initialization methods
  - 1) Uniform distribution initialization torch.nn.init.uniform_()
  - 2) Normal distribution initialization torch.nn.init.normal_()
  - 3) Constant initialization torch.nn.init.constant_()
  - 4) Xavier uniform distribution initialization
  - 5) Xavier normal distribution initialization
  - 6) Kaiming uniform distribution initialization
  - 7) Kaiming normal distribution initialization
  - 8) Identity initialization (useful, for example, when optimizing transformation matrices)
  - 9) Orthogonal initialization
  - 10) Custom initialization
1. Common initialization methods
1) Uniform distribution initialization torch.nn.init.uniform_()
torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
Fills the input tensor in place with values drawn from the uniform distribution $U(a, b)$ and returns it.
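A minimal sketch of uniform_ on a standalone tensor (the 3 x 4 shape and the bounds -0.1 and 0.1 are arbitrary choices for illustration). It shows that the function fills the tensor in place rather than allocating a new one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible samples

# A weight matrix shaped like nn.Linear(4, 3).weight: (out_features, in_features)
w = torch.empty(3, 4)

# Fill w in place with samples from U(-0.1, 0.1); the same tensor is returned
out = nn.init.uniform_(w, a=-0.1, b=0.1)

# Same storage: no new tensor was allocated
print(out.data_ptr() == w.data_ptr())
print(w)
```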
2) Normal distribution initialization torch.nn.init.normal_()
torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
Fills the input tensor with values drawn from the normal distribution $N(\text{mean}, \text{std}^2)$ with the given mean and standard deviation.
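A quick check of normal_, assuming an arbitrary std of 0.02 and a large 1000 x 1000 tensor so that the sample statistics land close to the requested values:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible samples

# Large tensor so the sample mean/std are close to the targets
w = torch.empty(1000, 1000)
nn.init.normal_(w, mean=0.0, std=0.02)

# Sample mean should be near 0 and sample std near 0.02
print(f"mean={w.mean().item():.5f}, std={w.std().item():.5f}")
```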
3) Constant initialization torch.nn.init.constant_()
torch.nn.init.constant_(tensor, val)
Fills the input tensor with the constant value val.
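A typical use of constant_ is zeroing biases. The sketch below applies it to both parameters of an illustrative nn.Linear(4, 3); the init functions run under no_grad internally, so they can be called directly on Parameters:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)

# Zero the bias and set every weight to the same value
nn.init.constant_(layer.bias, 0.0)
nn.init.constant_(layer.weight, 0.5)

print(layer.bias)    # all zeros
print(layer.weight)  # all 0.5
```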
4) Xavier uniform distribution initialization
torch.nn.init.xavier_uniform_(tensor, gain=1.0)
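Fills the input tensor with values sampled from the uniform distribution $U(-a, a)$, where

$$a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}$$

The gain depends on the activation function that follows the layer; torch.nn.init.calculate_gain(nonlinearity) returns the recommended value (for example, 1 for linear and sigmoid, 5/3 for tanh, and $\sqrt{2}$ for ReLU).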
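A sanity check of the Xavier uniform bound, assuming a hypothetical Linear weight of shape 128 x 256 (fan_in = 256, fan_out = 128) followed by tanh. The bound a = gain × sqrt(6 / (fan_in + fan_out)) is computed by hand and compared with the sampled values:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Weight of a hypothetical nn.Linear(in_features=256, out_features=128)
fan_in, fan_out = 256, 128
w = torch.empty(fan_out, fan_in)

# Recommended gain for tanh
gain = nn.init.calculate_gain("tanh")
nn.init.xavier_uniform_(w, gain=gain)

# Xavier uniform bound: a = gain * sqrt(6 / (fan_in + fan_out))
a = gain * math.sqrt(6.0 / (fan_in + fan_out))
print(f"gain={gain:.4f}, bound a={a:.4f}, max |w|={w.abs().max().item():.4f}")
```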
5) Xavier normal distribution initialization
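torch.nn.init.xavier_normal_(tensor, gain=1.0)
Fills the input tensor with values sampled from $N(0, \text{std}^2)$, where $\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}$.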
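For convolution weights, PyTorch computes the fans as channel count times receptive field size (kH × kW). A sketch with an illustrative 64 x 32 x 3 x 3 conv weight, comparing the sample standard deviation with gain × sqrt(2 / (fan_in + fan_out)):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Conv weight (out_channels, in_channels, kH, kW); receptive field = 3 * 3 = 9
w = torch.empty(64, 32, 3, 3)
nn.init.xavier_normal_(w)  # gain defaults to 1.0

fan_in, fan_out = 32 * 9, 64 * 9
expected_std = math.sqrt(2.0 / (fan_in + fan_out))
print(f"expected std={expected_std:.4f}, sample std={w.std().item():.4f}")
```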
6) Kaiming uniform distribution initialization
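torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
Fills the input tensor with values sampled from $U(-\text{bound}, \text{bound})$, where $\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}$. Here a is the negative slope of the rectifier that follows this layer (only used with 'leaky_relu'); mode='fan_in' preserves the magnitude of the variance of the weights in the forward pass, and mode='fan_out' preserves it in the backward pass.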
7) Kaiming normal distribution initialization
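torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
Fills the input tensor with values sampled from $N(0, \text{std}^2)$, where $\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}$; a, mode and nonlinearity have the same meaning as in kaiming_uniform_.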
[1] https://www.cxyzjd.com/article/CQUSongYuxin/110928126
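A small sketch of both Kaiming variants, assuming a hypothetical Linear weight of shape 256 x 512 (fan_in = 512, fan_out = 256) followed by ReLU, for which calculate_gain gives sqrt(2). It checks the uniform bound gain × sqrt(3 / fan_in) and the normal standard deviation gain / sqrt(fan_out):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Weight of a hypothetical nn.Linear(512, 256): fan_in = 512, fan_out = 256
w = torch.empty(256, 512)
gain = nn.init.calculate_gain("relu")  # sqrt(2)

# Kaiming uniform; mode='fan_in' preserves forward-pass variance
nn.init.kaiming_uniform_(w, a=0, mode="fan_in", nonlinearity="relu")
bound = gain * math.sqrt(3.0 / 512)  # gain * sqrt(3 / fan_in)
max_abs_uniform = w.abs().max().item()
print(f"bound={bound:.4f}, max |w|={max_abs_uniform:.4f}")

# Kaiming normal on the same tensor; mode='fan_out' preserves backward-pass variance
nn.init.kaiming_normal_(w, a=0, mode="fan_out", nonlinearity="relu")
std = gain / math.sqrt(256)  # gain / sqrt(fan_out)
sample_std = w.std().item()
print(f"expected std={std:.4f}, sample std={sample_std:.4f}")
```

With the default arguments (a=0, nonlinearity='leaky_relu'), the negative slope is 0, so the gain is also sqrt(2) and the result matches the ReLU setting.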
8) Identity initialization (useful, for example, when optimizing transformation matrices)
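torch.nn.init.eye_(tensor)
Fills a 2-D tensor with the identity matrix, so a Linear layer initialized this way passes as many of its inputs through unchanged as its shape allows. Tensors with any other number of dimensions raise a ValueError.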
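A minimal sketch with an illustrative square Linear layer without bias: after eye_, the layer computes the identity map:

```python
import torch
import torch.nn as nn

# Square linear layer, no bias, so the output is exactly x @ W.T
layer = nn.Linear(3, 3, bias=False)
nn.init.eye_(layer.weight)

x = torch.tensor([[1.0, 2.0, 3.0]])
print(layer(x))  # same values as x, since W = I
```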
9) Orthogonal initialization
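torch.nn.init.orthogonal_(tensor, gain=1)
Fills the input tensor with a (semi-)orthogonal matrix, scaled by gain. The tensor must have at least 2 dimensions; for tensors with more dimensions, the trailing dimensions are flattened.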
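For a non-square weight the result is semi-orthogonal. A sketch assuming an arbitrary 4 x 8 shape: the rows come out orthonormal, so W Wᵀ is numerically the 4 x 4 identity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 4 x 8 weight: fewer rows than columns, so the rows become orthonormal
w = torch.empty(4, 8)
nn.init.orthogonal_(w)  # gain=1

# W @ W^T should be (numerically) the 4 x 4 identity
gram = w @ w.T
print(torch.allclose(gram, torch.eye(4), atol=1e-5))
```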
10) Custom initialization
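When building a model, you will sometimes need to define your own initialization. For example, the residual block below (without batch normalization) initializes its two convolutions with Kaiming uniform initialization scaled down by 0.1 whenever res_scale is 1.0: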
import torch.nn as nn


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:

    ::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
    """

    def __init__(self, mid_channels=64, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        # if res_scale < 1.0, use the default initialization, as in EDSR.
        # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet.
        if res_scale == 1.0:
            self.init_weights()  # called directly inside __init__

    def init_weights(self):
        """Initialize weights for ResidualBlockNoBN.

        Initialization methods like `kaiming_init` are for VGG-style
        modules. For modules with residual paths, using smaller std is
        better for stability and performance. We empirically use 0.1.
        See more details in "ESRGAN: Enhanced Super-Resolution Generative
        Adversarial Networks"
        """
        # Initialize only the layers that need a custom initialization
        for m in [self.conv1, self.conv2]:
            nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
            m.weight.data *= 0.1  # shrink the weights (smaller std, see docstring)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        identity = x
        x = self.conv1(x)
        x = self.relu(x)
        out = self.conv2(x)
        return identity + out * self.res_scale
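Initializing inside __init__ works well for a single block. For a whole network, another common pattern is to write one function and pass it to Module.apply, which calls it on every submodule recursively. The sketch below is only an illustration: init_fn and the toy model are hypothetical, and the per-layer choices (scaled Kaiming for Conv2d, Xavier for Linear) are an assumption rather than a rule:

```python
import torch
import torch.nn as nn


def init_fn(m: nn.Module) -> None:
    """Hypothetical init function: scaled Kaiming for convs, Xavier for linears."""
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, a=0, mode="fan_in", nonlinearity="relu")
        with torch.no_grad():
            m.weight.mul_(0.1)  # same 0.1 scaling idea as init_weights above
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    # Other modules (ReLU, Flatten, the Sequential itself) are left untouched


# Toy model for 3 x 8 x 8 inputs
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 8 * 8, 10),
)
model.apply(init_fn)  # visits every submodule, then the model itself
```

Keeping the rules in one function makes it easy to change the initialization scheme of the whole network in a single place.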
In short, parameter initialization simply means assigning the desired values to a parameter. For example:
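import numpy as np
import torch

# Create a convolution layer; its weights use PyTorch's default Kaiming initialization
w = torch.nn.Conv2d(2, 2, 3, padding=1)
print(w.weight)
# First create a tensor holding the custom weights; for simplicity, set all weights to 1
ones = torch.from_numpy(np.ones([2, 2, 3, 3], dtype=np.float32))
# Of course, you can also skip NumPy and use torch.ones directly
ones = torch.ones((2, 2, 3, 3))
# Assign the tensor to the conv layer as its weight; it must be wrapped in
# torch.nn.Parameter first, otherwise PyTorch raises an error
w.weight = torch.nn.Parameter(ones)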
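One caveat: assigning torch.nn.Parameter(...) registers a brand-new Parameter object, so an optimizer created before the assignment keeps updating the old one. When the optimizer already exists, a sketch of the alternative is to copy the values into the existing Parameter under torch.no_grad() (the layer, optimizer and learning rate here are illustrative):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(2, 2, 3, padding=1)
optimizer = torch.optim.SGD(conv.parameters(), lr=0.1)

custom = torch.ones(2, 2, 3, 3)  # same shape as conv.weight: (out, in, kH, kW)

# Overwrite the values in place; no_grad keeps autograd from recording the copy
with torch.no_grad():
    conv.weight.copy_(custom)

# The optimizer still references the very same Parameter object
same_param = optimizer.param_groups[0]["params"][0] is conv.weight
print(same_param)
```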
[2] https://ptorch.com/docs/1/nn-init
[3] https://blog.csdn.net/goodxin_ie/article/details/84555805
For the theory behind these initialization methods, see:
[4] https://zhuanlan.zhihu.com/p/25110150