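PyTorch weight initialization
Initializing model parameters
Official forum link: https://discuss.pytorch.org/t/weight-initilzation/157/3
Define a standalone weights_init function whose argument m is a torch.nn.Module (or a user-defined subclass of nn.Module).
Then call net.apply() to initialize the parameters.
m.__class__.__name__ returns the class name of the nn.Module.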
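To make the class-name lookup concrete, here is a minimal sketch of what m.__class__.__name__ and str.find return for a few common layers. The layer sizes are arbitrary and chosen only for illustration.

```python
import torch.nn as nn

# Layers chosen only for illustration; the sizes are arbitrary assumptions.
layers = [nn.Conv2d(3, 8, 3), nn.ConvTranspose2d(8, 3, 4), nn.BatchNorm2d(8), nn.ReLU()]

# Class name of each layer, exactly as an init function would see it.
names = [m.__class__.__name__ for m in layers]

# str.find returns the index of the first match, or -1 when the substring is absent.
conv_hits = [name.find('Conv') != -1 for name in names]

print(names)
print(conv_hits)
```

Because the match is a substring test, 'Conv' catches both Conv2d and ConvTranspose2d, which is why a single branch can cover every convolution variant.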
# DCGAN weight initialization code
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__  # class name of the layer, e.g. ConvTranspose2d
    if classname.find('Conv') != -1:  # find() returns -1 when the substring is absent, so != -1 means a match
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
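Matching on the class name is the approach used in the DCGAN code above. As an alternative sketch (not the tutorial's code), the same dispatch can be written with isinstance, which avoids accidental substring matches on user-defined class names. The helper name weights_init_by_type and the small test block are hypothetical.

```python
import torch.nn as nn

def weights_init_by_type(m):
    # Hypothetical helper: same values as weights_init, but dispatching on type.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

# Small block used only to exercise the function; the sizes are arbitrary.
block = nn.Sequential(nn.ConvTranspose2d(4, 8, 4, bias=False), nn.BatchNorm2d(8))
block.apply(weights_init_by_type)

print(block[1].bias)
```

The trade-off is that the type list has to name every layer class explicitly, whereas the substring test picks up any class whose name contains 'Conv'.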
# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
######################################################################
# Now, we can instantiate the generator and apply the ``weights_init``
# function. Check out the printed model to see how the generator object is
# structured.
# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all conv weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)
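One way to confirm that apply did what was intended is to inspect the weight statistics afterwards. The sketch below repeats weights_init so that it runs on its own and uses a two-layer stand-in (tiny_g, a hypothetical name with arbitrary channel sizes) rather than the full Generator, which depends on nz, ngf, nc and device being defined elsewhere.

```python
import torch
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)

# Stand-in for the generator: channel sizes are assumptions for illustration.
tiny_g = nn.Sequential(
    nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
    nn.Tanh(),
)
tiny_g.apply(weights_init)

conv_w = tiny_g[0].weight
bn = tiny_g[1]
# Conv weights should be close to mean 0 and standard deviation 0.02.
print(f"conv mean={conv_w.mean().item():.4f} std={conv_w.std().item():.4f}")
# BatchNorm scale should be close to 1 and its bias exactly 0.
print(f"bn weight mean={bn.weight.mean().item():.4f} bias max={bn.bias.abs().max().item():.1f}")
```

The printed values vary slightly with the seed, so they are best treated as a sanity check rather than exact targets.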
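Using torch.nn.Module.apply(fn)
torch.nn.Module.apply(fn)
# Recursively calls fn (here, weights_init) on every submodule of the nn.Module and on the module itself
# Commonly used to initialize model parameters
# fn is a reference to the initialization function; it takes an nn.Module (or a user-defined subclass of nn.Module) as its argument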
# fn (Module -> None) – function to be applied to each submodule
# Returns: self
# Return type: Module
import torch
import torch.nn as nn

def init_weights(m):
    classname = m.__class__.__name__
    print(m)
    if classname.find('Line') != -1:  # 'Line' is a substring of the class name 'Linear'
        m.weight.data.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(3, 3))
net.apply(init_weights)
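Output (as printed by an older PyTorch release; recent versions print the weights as Parameter containing: tensor(...)):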
Linear(in_features=2, out_features=2, bias=True)
1 1
1 1
[torch.FloatTensor of size (2,2)]
Linear(in_features=3, out_features=3, bias=True)
1 1 1
1 1 1
1 1 1
[torch.FloatTensor of size (3,3)]
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
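The output above lists the two Linear layers before the enclosing Sequential, which reflects the order in which apply walks the model: children first, then the module itself. A minimal sketch that records this order and checks the return value follows; record, visited and the nested layout are hypothetical names chosen for illustration.

```python
import torch.nn as nn

visited = []

def record(m):
    # Store the class name of each module in the order apply() visits it.
    visited.append(m.__class__.__name__)

# A nested model, so that the recursion into submodules is visible.
net = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.ReLU(), nn.Linear(2, 1)))
returned = net.apply(record)

print(visited)          # children come before their parent container
print(returned is net)  # apply returns the module it was called on
```

Since apply returns self, the call can be chained, for example when constructing and initializing a model in a single expression.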
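apply(fn): recursively applies the function fn to every submodule of the network model; it is mainly used for parameter initialization.
To use apply(), first define a parameter-initialization function.
Then define your own network, instantiate it, and call apply() on the instance to initialize the conv layers and the bn layers separately.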