初始化你的pytorch模型

版权声明:本文为博主原创文章,转载请联系作者 https://blog.csdn.net/u013832707/article/details/80584768

相关资料

在设计好卷积神经网络模型后,面临的第一个问题就是如何进行初始化。如此博主查阅了一些资料,如下:
关于weight initialization的讨论
以及在该讨论下一些答主给出的例子:
https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
https://github.com/pytorch/examples/blob/master/dcgan/main.py#L131
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
博主采用的方式是pytorch文档中给出的VGG源码中使用的初始化方式。可以说只要是CNN,都可以采用该方式进行初始化。

初始化方法

在定义模型的类中添加子函数,直接调用即可

def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

猜你喜欢

转载自blog.csdn.net/u013832707/article/details/80584768