DARTS Code Analysis (PyTorch)

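I have recently been reading the DARTS code. One of its files, operations.py, defines the candidate operations that can be applied on the edges between nodes.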

import torch
import torch.nn as nn

OPS = {
    'none': lambda C, stride, affine: Zero(stride),
    'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
    'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
    'skip_connect': lambda C, stride, affine: \
        Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
    'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}
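
Each entry in OPS is a factory: looking up a name and calling the result with (C, stride, affine) returns a fresh module. The sketch below illustrates that pattern only. TOY_OPS and build_candidates are hypothetical names that are not in operations.py, and the built-in nn.MaxPool2d, nn.AvgPool2d and nn.Identity stand in for the PoolBN and Identity classes discussed later.

```python
import torch
import torch.nn as nn

# Hypothetical, trimmed-down registry with the same call signature as OPS.
# Built-in PyTorch modules stand in for the classes in operations.py.
TOY_OPS = {
    'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride, 1),
    'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(
        3, stride, 1, count_include_pad=False),
    'skip_connect': lambda C, stride, affine: nn.Identity(),
}


def build_candidates(C, stride, affine=False):
    """Instantiate every registered operation for one edge (illustrative helper)."""
    return {name: factory(C, stride, affine) for name, factory in TOY_OPS.items()}


x = torch.randn(2, 16, 8, 8)            # (batch, channels, height, width)
ops = build_candidates(C=16, stride=1)

# With stride 1 and padding 1, every 3x3 candidate keeps the input shape,
# which is what lets the outputs of different operations be combined.
shapes = {name: tuple(op(x).shape) for name, op in ops.items()}
print(shapes)
```

Storing lambdas rather than module instances means every lookup builds a new, independently parameterized module, so the same name can be reused on many edges.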

The OPS dictionary first defines 10 operations. The classes behind them are explained in turn:

  • class PoolBN(nn.Module):
        """
        AvgPool or MaxPool - BN
        """
        def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
            """
            Args:
                pool_type: 'max' or 'avg'
            """
            super().__init__()
            if pool_type.lower() == 'max':
                self.pool = nn.MaxPool2d(kernel_size, stride, padding)
            elif pool_type.lower() == 'avg':
                self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
            else:
                raise ValueError(f"unknown pool_type: {pool_type!r}")
    
            self.bn = nn.BatchNorm2d(C, affine=affine)
    
        def forward(self, x):
            out = self.pool(x)
            out = self.bn(out)
            return out

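    This is the pooling operation. It supports max pooling and average pooling, each followed by batch normalization. count_include_pad=False means the zero-padded positions are not counted when computing the average.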

  • class Identity(nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            return x

    This implements the skip connection ('skip_connect' uses it when stride == 1).

  • class FactorizedReduce(nn.Module):
        """
        Reduce feature map size by factorized pointwise(stride=2).
        """
        def __init__(self, C_in, C_out, affine=True):
            super().__init__()
            self.relu = nn.ReLU()
            self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
            self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(C_out, affine=affine)
    
        def forward(self, x):
            x = self.relu(x)
            out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
            out = self.bn(out)
            return out

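    This reduces the feature map size. Two 1x1 convolutions with stride 2 each produce C_out // 2 channels (the second one reads the input shifted by one pixel), and their outputs are concatenated, so the height and width are halved. The number of channels stays the same when C_in == C_out, as is the case for 'skip_connect'.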
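
Two details from the list above can be checked on small tensors: what count_include_pad=False changes at the border, and how the two stride-2 branches of FactorizedReduce end up with matching shapes. The following is a minimal sketch using only built-in PyTorch layers. The tensor sizes and variable names are chosen for illustration, and the input is assumed to have even height and width.

```python
import torch
import torch.nn as nn

# --- count_include_pad: effect on a border pixel -------------------------
x = torch.ones(1, 1, 4, 4)
avg_excl = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
avg_incl = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=True)

# The 3x3 window at the top-left corner covers 4 real pixels and 5 padded zeros.
corner_excl = avg_excl(x)[0, 0, 0, 0].item()   # sum 4 divided by 4 real pixels
corner_incl = avg_incl(x)[0, 0, 0, 0].item()   # sum 4 divided by all 9 positions
print(corner_excl, corner_incl)

# --- FactorizedReduce: shapes of the two shifted branches ----------------
C = 8
conv1 = nn.Conv2d(C, C // 2, 1, stride=2, padding=0, bias=False)
conv2 = nn.Conv2d(C, C // 2, 1, stride=2, padding=0, bias=False)

feat = torch.randn(2, C, 16, 16)
a = conv1(feat)                  # samples rows/cols 0, 2, 4, ...
b = conv2(feat[:, :, 1:, 1:])    # samples rows/cols 1, 3, 5, ...
out = torch.cat([a, b], dim=1)   # C // 2 + C // 2 channels

out_shape = tuple(out.shape)
print(out_shape)
```

With an odd height or width the two branches would produce different spatial sizes and torch.cat would fail, which is why the even-size assumption matters here.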


Reposted from www.cnblogs.com/yqpy/p/11453074.html