ACSNet segmentation model construction

Original paper: Adaptive Context Selection for Polyp Segmentation
Source code: https://github.com/ReaFly/ACSNet

Straight to the point~~~

1. Basic modules
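
All of the snippets below assume the usual PyTorch imports, which the original post omits:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34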

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1):
        super(DecoderBlock, self).__init__()

        self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size,
                               stride=stride, padding=padding)

        self.conv2 = ConvBlock(in_channels // 4, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.upsample(x)
        return x


class SideoutBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(SideoutBlock, self).__init__()

        self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size,
                               stride=stride, padding=padding)

        self.dropout = nn.Dropout2d(0.1)

        self.conv2 = nn.Conv2d(in_channels // 4, out_channels, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout(x)
        x = self.conv2(x)

        return x
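
A quick shape check of these two blocks (my own example, not from the original post): DecoderBlock squeezes channels to in_channels // 4, expands to out_channels, then upsamples 2x; SideoutBlock maps a decoder feature to a single-channel prediction at the same resolution.

blk = DecoderBlock(in_channels=512, out_channels=256)
x = torch.randn(1, 512, 16, 16)
print(blk(x).shape)   # torch.Size([1, 256, 32, 32])

side = SideoutBlock(512, 1)
print(side(x).shape)  # torch.Size([1, 1, 16, 16])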

2. LCA module 

class LCA(nn.Module):
    def __init__(self):
        super(LCA, self).__init__()

    def forward(self, x, pred):  # x: 256,16,16  pred: 1,16,16
        residual = x
        score = torch.sigmoid(pred)
        dist = torch.abs(score - 0.5)  # distance from the 0.5 decision boundary
        att = 1 - (dist / 0.5)         # attention peaks where the prediction is most uncertain
        att_x = x * att  # 256,16,16
        out = att_x + residual  # 256,16,16

        return out
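
LCA turns the previous side output into hard-region attention: att is 1 where the sigmoid score is exactly 0.5 (the most uncertain pixels) and decays toward 0 as the score saturates, while the residual connection keeps the original features. A tiny numeric check (my own illustration):

pred = torch.tensor([-4.0, 0.0, 4.0])
score = torch.sigmoid(pred)             # ~[0.018, 0.500, 0.982]
att = 1 - torch.abs(score - 0.5) / 0.5  # ~[0.036, 1.000, 0.036]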

3. GCM module

class GCM(nn.Module):
    def __init__(self, in_channels, out_channels): #in_channels=512, out_channels=64
        super(GCM, self).__init__()
        pool_size = [1, 3, 5]
        out_channel_list = [256, 128, 64, 64]
        upsample_scale = [2, 4, 8, 16]
        GClist = []
        GCoutlist = []
        for ps in pool_size:
            GClist.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(ps),
                nn.Conv2d(in_channels, out_channels, 1, 1),
                nn.ReLU(inplace=True)))
        GClist.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1),
            nn.ReLU(inplace=True),
            NonLocalBlock(out_channels)))
        self.GCmodule = nn.ModuleList(GClist)
        for i in range(4):
            GCoutlist.append(nn.Sequential(nn.Conv2d(out_channels * 4, out_channel_list[i], 3, 1, 1),
                                           nn.ReLU(inplace=True),
                                           nn.Upsample(scale_factor=upsample_scale[i], mode='bilinear')))
        self.GCoutmodel = nn.ModuleList(GCoutlist)

    def forward(self, x):  # input x: 512,8,8
        xsize = x.size()[2:]
        global_context = []
        for i in range(len(self.GCmodule) - 1): #range(3)
            global_context.append(F.interpolate(self.GCmodule[i](x), xsize, mode='bilinear', align_corners=True))
        global_context.append(self.GCmodule[-1](x))
        global_context = torch.cat(global_context, dim=1)

        output = []
        for i in range(len(self.GCoutmodel)): #range(4)
            output.append(self.GCoutmodel[i](global_context))

        return output
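
GCM combines three PPM-style branches (adaptive average pooling to 1x1, 3x3, and 5x5) with a non-local branch (NonLocalBlock, defined in the next section), concatenates them into 4 * 64 = 256 channels, then projects and upsamples the result once per decoder stage. A shape check under the 8x8 input assumed in the comments:

gcm = GCM(512, 64)
outs = gcm(torch.randn(1, 512, 8, 8))
for o in outs:
    print(o.shape)
# torch.Size([1, 256, 16, 16])
# torch.Size([1, 128, 32, 32])
# torch.Size([1, 64, 64, 64])
# torch.Size([1, 64, 128, 128])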

4. NonLocalBlock module

class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): #in_channels=64
        super(NonLocalBlock, self).__init__()

        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                          kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(self.in_channels)
            )
            # nn.init.constant_(tensor, val) fills the tensor with the value val;
            # zero-initializing W means the block starts as an identity mapping (W_y = 0, so z = x).
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                               kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
            self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))

    def forward(self, x): #bs,64,8,8

        batch_size = x.size(0)

        # bs,64,8,8->bs,32,4,4->bs,32,16
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1) #bs,16,32

        # bs,64,8,8->bs,32,8,8->bs,32,64
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1) #bs,64,32

        # bs,64,8,8->bs,32,4,4->bs,32,16
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x) #bs,64,16
        f_div_C = F.softmax(f, dim=-1) #bs,64,16

        y = torch.matmul(f_div_C, g_x) #bs,64,32
        y = y.permute(0, 2, 1).contiguous() #bs,32,64
        y = y.view(batch_size, self.inter_channels, *x.size()[2:]) #bs,32,8,8
        W_y = self.W(y) #bs,64,8,8
        z = W_y + x #bs,64,8,8

        return z
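
This is the embedded-Gaussian non-local block of Wang et al.; sub_sample max-pools phi and g by 2x to shrink the attention matrix, and the zero-initialized W makes the block behave as an identity at the start of training. The block preserves the input shape:

nl = NonLocalBlock(64)
x = torch.randn(2, 64, 8, 8)
print(nl(x).shape)  # torch.Size([2, 64, 8, 8])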

5. SE module

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
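
This is the standard squeeze-and-excitation layer (Hu et al.): global average pooling squeezes each channel to a scalar, a bottleneck MLP (reduction 16) followed by a sigmoid produces per-channel gates, and the input is rescaled channel-wise. For example:

se = SELayer(192)  # bottleneck width 192 // 16 = 12
x = torch.randn(1, 192, 128, 128)
print(se(x).shape)  # torch.Size([1, 192, 128, 128])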

6. ASM module

class ASM(nn.Module):
    def __init__(self, in_channels, all_channels):
        super(ASM, self).__init__()
        self.non_local = NonLocalBlock(in_channels)
        self.selayer = SELayer(all_channels)

    def forward(self, lc, fuse, gc):
        fuse = self.non_local(fuse)
        fuse = torch.cat([lc, fuse, gc], dim=1)
        fuse = self.selayer(fuse)

        return fuse
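
Channel bookkeeping: ASM runs the non-local block on the decoder feature fuse (in_channels wide), concatenates it with the LCA output lc and the GCM output gc, and lets the SE layer adaptively re-weight the all_channels concatenated channels. For asm4 in the network below: 256 (lc4) + 512 (d5) + 256 (gc4) = 1024.

asm4 = ASM(in_channels=512, all_channels=1024)
lc4 = torch.randn(1, 256, 16, 16)
d5 = torch.randn(1, 512, 16, 16)
gc4 = torch.randn(1, 256, 16, 16)
print(asm4(lc4, d5, gc4).shape)  # torch.Size([1, 1024, 16, 16])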

7. ACSNet network structure

class ACSNet(nn.Module):
    def __init__(self, num_classes):
        super(ACSNet, self).__init__()

        self.resnet = resnet34(pretrained=False)
        
        # Encoder
        self.encoder1_conv = self.resnet.conv1
        self.encoder1_bn = self.resnet.bn1
        self.encoder1_relu = self.resnet.relu
        self.maxpool = self.resnet.maxpool
        self.encoder2 = self.resnet.layer1
        self.encoder3 = self.resnet.layer2
        self.encoder4 = self.resnet.layer3
        self.encoder5 = self.resnet.layer4

        # Decoder
        self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        self.decoder4 = DecoderBlock(in_channels=1024, out_channels=256)
        self.decoder3 = DecoderBlock(in_channels=512, out_channels=128)
        self.decoder2 = DecoderBlock(in_channels=256, out_channels=64)
        self.decoder1 = DecoderBlock(in_channels=192, out_channels=64)

        self.outconv = nn.Sequential(ConvBlock(64, 32, kernel_size=3, stride=1, padding=1),
                                      nn.Dropout2d(0.1),
                                      nn.Conv2d(32, num_classes, 1))

        # Sideout
        self.sideout2 = SideoutBlock(64, 1)
        self.sideout3 = SideoutBlock(128, 1)
        self.sideout4 = SideoutBlock(256, 1)
        self.sideout5 = SideoutBlock(512, 1)

        # local context attention module
        self.lca1 = LCA()
        self.lca2 = LCA()
        self.lca3 = LCA()
        self.lca4 = LCA()

        # global context module
        self.gcm = GCM(512, 64)

        # adaptive selection module
        self.asm4 = ASM(512, 1024)
        self.asm3 = ASM(256, 512)
        self.asm2 = ASM(128, 256)
        self.asm1 = ASM(64, 192)

    def forward(self, x):
        # x: 3,256,256
        e1 = self.encoder1_conv(x)  # 64,128,128
        e1 = self.encoder1_bn(e1)
        e1 = self.encoder1_relu(e1)
        e1_pool = self.maxpool(e1) # 64,64,64
        e2 = self.encoder2(e1_pool) # 64,64,64
        e3 = self.encoder3(e2)  # 128,32,32
        e4 = self.encoder4(e3)  # 256,16,16
        e5 = self.encoder5(e4)  # 512,8,8

        global_contexts = self.gcm(e5)
        # print(global_contexts[0].shape) [1, 256, 16, 16]
        # print(global_contexts[1].shape) [1, 128, 32, 32]
        # print(global_contexts[2].shape) [1, 64, 64, 64]
        # print(global_contexts[3].shape) [1, 64, 128, 128]
        
        d5 = self.decoder5(e5) # 512,8,8->512,16,16
        out5 = self.sideout5(d5) # 1,16,16
        lc4  = self.lca4(e4, out5) # 256,16,16
        gc4 = global_contexts[0]
        comb4 = self.asm4(lc4, d5, gc4) # 1024, 16, 16

        d4 = self.decoder4(comb4) # 256, 32, 32
        out4 = self.sideout4(d4) # 1, 32, 32
        lc3 = self.lca3(e3, out4) # 128, 32, 32
        gc3 = global_contexts[1]
        comb3 = self.asm3(lc3, d4, gc3) # 512,32,32

        d3 = self.decoder3(comb3)  # 128,64,64
        out3 = self.sideout3(d3) # 1,64,64
        lc2 = self.lca2(e2, out3) # 64,64,64
        gc2 = global_contexts[2]
        comb2 = self.asm2(lc2, d3, gc2)  # 256, 64, 64

        d2 = self.decoder2(comb2)  # 64,128,128
        out2 = self.sideout2(d2) # 1,128,128
        lc1 = self.lca1(e1, out2) # 64,128,128
        gc1 = global_contexts[3]
        comb1 = self.asm1(lc1, d2, gc1) # 192,128,128

        d1 = self.decoder1(comb1)  # 64,256,256
        out1 = self.outconv(d1)  # num_classes,256,256

        # return out1
        return torch.sigmoid(out1), torch.sigmoid(out2), torch.sigmoid(out3), \
            torch.sigmoid(out4), torch.sigmoid(out5)


if __name__ == '__main__':
    input_tensor = torch.randn((1, 3, 256, 256))
    model = ACSNet(num_classes=4)
    # out1 = model(input_tensor)
    # print(out1.shape)
    o1, o2, o3, o4, o5 = model(input_tensor)
    print(o1.shape, o2.shape, o3.shape, o4.shape, o5.shape)
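
With num_classes=4 this should print (the four side outputs are always single-channel):

torch.Size([1, 4, 256, 256]) torch.Size([1, 1, 128, 128]) torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 32, 32]) torch.Size([1, 1, 16, 16])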

8. Loss function (Deep Supervision Loss)

def DeepSupervisionLoss(pred, gt):
    # pred: (out1, ..., out5) at 256, 128, 64, 32, 16 resolution; gt: float mask of shape (B, 1, 256, 256)
    d0, d1, d2, d3, d4 = pred

    criterion = BceDiceLoss()

    loss0 = criterion(d0, gt) #256,256
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss1 = criterion(d1, gt) #128,128
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss2 = criterion(d2, gt) #64,64
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss3 = criterion(d3, gt) #32,32
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss4 = criterion(d4, gt) #16,16

    return loss0 + loss1 + loss2 + loss3 + loss4
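
BceDiceLoss comes from the original repository and is not defined in this post. A minimal sketch consistent with the sigmoid outputs above (BCE on probabilities plus a soft Dice term; my own reconstruction, not necessarily identical to the repo's version):

class BceDiceLoss(nn.Module):
    def __init__(self, smooth=1e-5):
        super(BceDiceLoss, self).__init__()
        self.bce = nn.BCELoss()
        self.smooth = smooth  # avoids division by zero in the Dice term

    def forward(self, pred, gt):
        # pred and gt are both in [0, 1] with shape (B, 1, H, W)
        bce = self.bce(pred, gt)
        inter = (pred * gt).sum()
        dice = 1 - (2 * inter + self.smooth) / (pred.sum() + gt.sum() + self.smooth)
        return bce + dice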

Original post: https://blog.csdn.net/m0_56247038/article/details/131423708