PyTorch 自定义网络

class Model(nn.Module):
    """Template CNN: three convolutional stages followed by a classifier head.

    Only ``conv1`` is defined. ``conv2``, ``conv3`` and ``fc`` are empty
    ``nn.Sequential`` containers (an empty container returns its input
    unchanged) and are meant to be filled in with your own layers.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(
            # 3 input channels -> 16 output channels, 5x5 kernel, stride 1.
            # padding=2 keeps the spatial size unchanged for a 5x5 kernel;
            # the padding mode is left at its default (zero padding).
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5,
                      stride=1, padding=2, bias=False),
            nn.ReLU(),
            # num_features must equal the number of incoming channels.
            # affine=False means this layer has no learnable weight/bias
            # (both attributes are None).
            nn.BatchNorm2d(num_features=16, affine=False),
        )
        self.conv2 = nn.Sequential(
            # Add your own layers here.
        )
        self.conv3 = nn.Sequential(
            # Add your own layers here.
        )
        self.fc = nn.Sequential(
            # Add your own layers here.
        )

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor fed to ``conv1``.

        Returns:
            The output of ``fc`` applied to the flattened feature maps.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Collapse everything after the first dimension into one feature
        # axis. flatten() also copes with non-contiguous or empty inputs,
        # which view() rejects.
        x = x.flatten(1)
        x = self.fc(x)
        return x

    def initialize_weights(self):
        """Re-initialize every Conv2d, BatchNorm2d and Linear submodule.

        * Conv2d: Xavier-normal weights, zero bias.
        * BatchNorm2d: weight set to 1, bias set to 0.
        * Linear: weights drawn from N(0, 0.1), zero bias.

        Parameters that do not exist (``bias=False`` or ``affine=False``)
        are skipped instead of raising ``AttributeError``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # weight and bias are None when the layer was built with
                # affine=False (as in conv1 above).
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

猜你喜欢

转载自blog.csdn.net/qq_55542491/article/details/130871651