Weight Initialization in PyTorch

This post covers initializing model parameters in PyTorch, following the official forum discussion on weight initialization.

Official forum link: https://discuss.pytorch.org/t/weight-initilzation/157/3

Define a standalone weights_init function whose argument m is a torch.nn.Module (or your own subclass of nn.Module), then call net.apply() to initialize the parameters.

m.__class__.__name__ returns the class name of the module.
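For instance, a trivial check (an illustrative snippet, not from the original post) shows what that attribute returns:

import torch.nn as nn

m = nn.Conv2d(3, 8, kernel_size=3)
print(m.__class__.__name__)   # prints 'Conv2d'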

DCGAN GitHub link

import torch.nn as nn

# Weight initialization code from DCGAN
def weights_init(m):
    classname = m.__class__.__name__      # class name of the layer, e.g. ConvTranspose2d
    if classname.find('Conv') != -1:      # str.find returns -1 when the substring is absent, so != -1 means the name contains 'Conv'
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
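String matching on the class name works, but an equivalent formulation dispatching on the module type with isinstance is arguably more robust. A sketch of that alternative (the name weights_init_by_type is made up for illustration, not part of the DCGAN code):

import torch.nn as nn

def weights_init_by_type(m):
    # Same initialization as above, dispatched on the module type
    # instead of on the class-name string.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)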

# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
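        # Note: nz (latent vector size), ngf (generator feature-map width)
        # and nc (number of image channels) are hyperparameters defined
        # earlier in the DCGAN tutorial (typically nz=100, ngf=64, nc=3).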
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


Now we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)
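To sanity-check that apply() actually reached the layers, one can inspect the statistics of an initialized convolution. A quick sketch, assuming the single-device case above (with a DataParallel wrapper it would be netG.module.main):

# The first ConvTranspose2d's weights should now have roughly
# mean 0 and std 0.02.
w = netG.main[0].weight
print(w.mean().item(), w.std().item())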

Using apply in torch.nn.Module.apply(fn)

See apply under torch.nn in the official PyTorch documentation.

torch.nn.Module.apply(fn)
# Recursively applies fn to every submodule of the nn.Module, and to the module itself
# Commonly used to initialize model parameters
# fn is a handle to the initialization function; it takes an nn.Module (or your own nn.Module subclass) as its argument
# fn (Module -> None) – function to be applied to each submodule
# Returns:  self
# Return type:  Module
import torch
import torch.nn as nn
def init_weights(m):
    classname = m.__class__.__name__
    print(m)
    if classname.find('Linear') != -1:
        m.weight.data.fill_(1.0)
        print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(3, 3))
net.apply(init_weights)

Output:

Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
apply(fn) applies fn recursively to every submodule of the model and is mainly used for parameter initialization. Note the order in the output above: fn is applied to each child first and to the module itself last, and the trailing Sequential is the return value (apply returns self) echoed by the interactive session.

To use apply(), first define a parameter-initialization function.

Then define your network, instantiate the model, and call apply(); the conv layers and the BatchNorm layers can each be initialized according to their type, as the sketch below shows.
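Putting the workflow together, a minimal end-to-end sketch (the toy network and the init_params name are made up for illustration, not from the original post):

import torch.nn as nn

def init_params(m):
    # Conv layers and BatchNorm layers are initialized differently,
    # dispatched on the module type; other layers are left untouched.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)
net.apply(init_params)   # recursively initializes every submodule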

Reference:

PyTorch weight-initilzation (official forum thread)

PyTorch usage notes (2): parameter initialization

DCGAN Tutorial: learning PyTorch by building a project


Reposted from blog.csdn.net/xrinosvip/article/details/86503130