【IPMI 2023】Rethinking Boundary Detection in Deep Learning Models for Medical Image Segmentation

Rethinking Boundary Detection in Deep Learning Models for Medical Image Segmentation, IPMI 2023

Interpretation: new work from Hao Chen's team at the Hong Kong University of Science and Technology, IPMI 2023 | CTO: Rethinking the role of boundary detection in medical image segmentation (qq.com)

Paper:  https://arxiv.org/abs/2305.00678

Code:  https://github.com/xiaofang007/CTO

Introduction

In this paper, we propose a novel network architecture, CTO (Convolution, Transformer, and Operator), which combines convolutional neural networks, vision Transformers, and explicit boundary detection operators to achieve high-precision image segmentation with an optimal balance between accuracy and efficiency.

CTO follows the standard encoder-decoder segmentation paradigm, where the encoder network adopts the popular CNN backbone structure to capture local semantic information, and uses a lightweight ViT auxiliary network to integrate long-range dependencies. To enhance the boundary learning ability, this paper further proposes a boundary-guided decoder network, which uses the boundary mask obtained from a dedicated boundary detection operation as explicit supervision to guide the decoding learning process.

Convolution, Transformer, and Operator (CTO)

CTO follows the encoder-decoder paradigm and employs skip connections to aggregate low-level features from the encoder into the decoder. The encoder network is composed of mainstream CNN and auxiliary ViT. The decoder network employs a boundary detection operator to guide its learning process.

  • A dual-stream encoder, which combines a convolutional neural network and a lightweight visual Transformer, captures image local feature dependencies and long-range feature dependencies between image patches, respectively.
  • An operator-guided decoder, which uses a boundary detection operator (e.g., Sobel) to guide the learning process through generated boundary masks; the entire model is trained in an end-to-end fashion.

Dual-Stream Encoder

CTO first builds a convolutional stream, with Res2Net chosen as the backbone network, to capture local feature dependencies.

CTO then uses a lightweight Vision Transformer (LightViT) auxiliary stream to capture long-range dependencies between different image patches. Specifically, it consists of multiple parallel lightweight Transformer blocks that receive feature maps of different scales as input. All Transformer blocks share a similar structure, comprising a patch embedding layer and Transformer encoding layers.

The patch embedding layer of LightViT converts the input feature map into embedding vectors, flattening the spatial dimensions into a sequence dimension. The Transformer encoding layer then models the patch embeddings with a self-attention mechanism to capture long-range dependencies between different patches. By introducing self-attention, LightViT can effectively model the interactions among patches and extract the global context of the image.

Boundary-Guided Decoder

The boundary-guided decoder uses a gradient operator module to extract boundary information of foreground objects. The boundary-enhanced features are then integrated with the multi-level encoder features (via the boundary inject module described below), aiming to simultaneously characterize intra-class and inter-class consistency in the feature space and enrich the representation ability of the features. This enables the decoder to better exploit boundary information when generating predictions, leading to more accurate segmentation results.

Boundary Enhanced Module (BEM)

The boundary enhanced module takes high-level features and low-level features as input, extracts boundary information, and filters out information irrelevant to the boundary. The Sobel operator is applied in the horizontal direction (Gx) and the vertical direction (Gy) to obtain gradient maps. Specifically, this paper adopts two 3×3 fixed-parameter convolutions applied with a stride of 1. These two convolution kernels are defined as:
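These are the standard Sobel filters; they match the active filter_x and filter_y in the get_sobel function of the released code:

G_x = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}, \qquad G_y = \begin{bmatrix} 1 & 2 & 1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix}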

Then, these two convolutions are applied to the input feature map to obtain the gradient maps Mx and My. Next, the gradient maps are normalized by the sigmoid function and fused with the input feature map to obtain the edge-enhanced feature map Fe:
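In a rough form consistent with the symbol description that follows and with the run_sobel function in the released code:

F_e = \sigma(M_{xy}) \otimes F

where F is the input feature map; in the released code, the fused gradient response is computed as the magnitude \sqrt{M_x^2 + M_y^2} of the two gradient maps.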

Here, the \otimes symbol denotes element-wise multiplication, \sigma denotes the sigmoid function, and M_xy is the concatenation of Mx and My along the channel dimension. The edge-enhanced feature maps are then fused directly with simple stacked convolutional layers. Finally, the output feature map is supervised by the ground-truth (GT) boundary map, which suppresses edge responses inside objects and yields boundary-enhanced features.

Boundary Inject Module (BIM)

The boundary-enhanced features obtained by BEM can serve as prior knowledge to improve the representation ability of the features generated by the encoder. BIM introduces a dual-path boundary fusion scheme to strengthen the representation of foreground and background features. Specifically, BIM receives two inputs: a channel-wise concatenation of the boundary-enhanced features with the corresponding features from the encoder network, and the features from the previous decoder layer. These two inputs are fed into the BIM, which contains two separate paths that enhance the feature representation of the foreground and the background, respectively.

  • For the foreground path, we concatenate the two inputs directly along the channel dimension, and then apply a series of Conv-BN-ReLU (convolution, batch normalization, ReLU activation) layers to obtain foreground features.
  • For the background path, a background attention component is designed to selectively focus on background information.

The foreground path produces the foreground feature Ffg, and the background path produces the background feature Fbg.

The foreground attention map is obtained by applying a sigmoid to the feature map of the previous decoder layer; the background attention map is 1 minus the foreground attention map. Finally, the foreground feature Ffg, the background feature Fbg, and the previous decoder layer's features are concatenated to form the output of this layer.
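As a concrete illustration, below is a minimal, hypothetical sketch of this dual-path fusion in PyTorch. The class and parameter names (BIMSketch, in_ch, dec_ch, out_ch) are illustrative and not the authors' exact BIM; the DM module in the key code below implements a related reverse-attention variant.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BIMSketch(nn.Module):
    # in_ch: channels of the (boundary-enhanced + encoder) concatenation
    # dec_ch: channels of the previous decoder layer's features
    def __init__(self, in_ch, dec_ch, out_ch=64):
        super().__init__()
        # Foreground path: Conv-BN-ReLU over the two concatenated inputs.
        self.fg = nn.Sequential(
            nn.Conv2d(in_ch + dec_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
        # Background path: same fusion, gated by the background attention below.
        self.bg = nn.Sequential(
            nn.Conv2d(in_ch + dec_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
        self.out = nn.Conv2d(out_ch * 2 + dec_ch, out_ch, 1)

    def forward(self, enc_bnd_feat, dec_feat):
        # enc_bnd_feat: channel-wise concat of boundary-enhanced and encoder features
        # dec_feat:     features from the previous decoder layer
        dec_feat = F.interpolate(dec_feat, enc_bnd_feat.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([enc_bnd_feat, dec_feat], dim=1)
        fg = self.fg(x)                                              # foreground feature F_fg
        # Foreground attention from the previous decoder features (channel-averaged here);
        # background attention is 1 minus the foreground attention.
        fg_att = torch.sigmoid(dec_feat.mean(dim=1, keepdim=True))
        bg = self.bg(x) * (1.0 - fg_att)                             # background feature F_bg
        return self.out(torch.cat([fg, bg, dec_feat], dim=1))

Gating only the background path with 1 - foreground attention lets the module suppress responses outside objects, while the foreground path retains the full boundary-enhanced detail.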

Loss Function

CTO is a multi-task model covering interior segmentation and boundary detection, and it defines an overall loss function to jointly optimize the two tasks:
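In a typical form (the exact weighting is specified in the paper):

L_total = L_seg + \lambda \cdot L_bnd

with \lambda balancing the two terms.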

The overall loss consists of the main interior segmentation loss L_seg and the boundary loss L_bnd. For the boundary detection loss, only the predictions from the BEM are considered; this module takes the encoder's high-level and low-level feature maps as input.

Interior Segmentation Loss

L_seg is the weighted sum of the cross-entropy loss L_CE and the average IoU loss L_mIoU:
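As a sketch, with \lambda_1 and \lambda_2 standing in for the weights used in the paper:

L_seg = \lambda_1 \cdot L_CE + \lambda_2 \cdot L_mIoU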

Boundary Loss

The boundary loss L_bnd accounts for the class imbalance between foreground (boundary) and background pixels in boundary detection, so the Dice loss is used:
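In its standard soft form, over predicted boundary probabilities p_i and ground-truth boundary labels g_i, with a small smoothing constant \epsilon (the paper's exact constant may differ):

L_bnd = 1 - (2 \sum_i p_i g_i + \epsilon) / (\sum_i p_i + \sum_i g_i + \epsilon)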

Experiments

 

Key Code

CTO_net.py

# https://github.com/xiaofang007/CTO/blob/main/CTOTrainer/network/CTO_net.py
# Imports needed by the code below (res2net50_v1b_26w_4s, FeedForward2D, and
# BaseNetwork are defined in other modules of the CTO repository):
import math
from math import log

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBNR(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(ConvBNR, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=dilation, dilation=dilation, bias=bias),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class Conv1x1(nn.Module):
    def __init__(self, inplanes, planes):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(inplanes, planes, 1)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        return x


class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(512, 256)
        self.block = nn.Sequential(
            ConvBNR(320 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x1, x11, p2):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)
        x11 = self.reduce1(x11)
        p2 = self.reduce4(p2)
        p2 = F.interpolate(p2, size, mode='bilinear', align_corners=False)
        out = torch.cat((x1, x11), dim=1)
        out = torch.cat((out, p2), dim=1)
        out = self.block(out)

        return out



class EFM(nn.Module):
    def __init__(self, channel):
        super(EFM, self).__init__()
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att + c
        x = self.conv2d(x)
        wei = self.avg_pool(x)
        wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        wei = self.sigmoid(wei)
        x = x * wei

        return x

class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class DM(nn.Module):
    def __init__(self):
        super(DM, self).__init__()
        self.predict3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=1)
        )
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, xr, dualattention):
        crop_3 = F.interpolate(dualattention, xr.size()[2:], mode='bilinear', align_corners=False)
        re3_feat = self.predict3(torch.cat([xr, crop_3], dim=1))
        x = -1*(torch.sigmoid(crop_3)) + 1
        x = x.expand(-1, 64, -1, -1).mul(xr)
        x = F.relu(self.ra2_conv2(x))
        x = F.relu(self.ra2_conv3(x))
        ra3_feat = self.ra2_conv4(x)
        x = ra3_feat + crop_3 + re3_feat


        return x


class _DAHead(nn.Module):
    def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DAHead, self).__init__()
        self.aux = aux
        inter_channels = in_channels // 4
        self.conv_p1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.pam = _PositionAttentionModule(inter_channels, **kwargs)
        self.cam = _ChannelAttentionModule(**kwargs)
        self.conv_p2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.out = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, nclass, 1)
        )
        if aux:
            self.conv_p3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )
            self.conv_c3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )

    def forward(self, x):
        feat_p = self.conv_p1(x)
        feat_p = self.pam(feat_p)
        feat_p = self.conv_p2(feat_p)

        feat_c = self.conv_c1(x)
        feat_c = self.cam(feat_c)
        feat_c = self.conv_c2(feat_c)

        feat_fusion = feat_p + feat_c

        outputs = []
        fusion_out = self.out(feat_fusion)
        outputs.append(fusion_out)
        if self.aux:
            p_out = self.conv_p3(feat_p)
            c_out = self.conv_c3(feat_c)
            outputs.append(p_out)
            outputs.append(c_out)

        return tuple(outputs)

def run_sobel(conv_x, conv_y, input):
    g_x = conv_x(input)
    g_y = conv_y(input)
    g = torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
    return torch.sigmoid(g) * input

def get_sobel(in_chan, out_chan):
    '''
    filter_x = np.array([
        [3, 0, -3],
        [10, 0, -10],
        [3, 0, -3],
    ]).astype(np.float32)
    filter_y = np.array([
        [3, 10, 3],
        [0, 0, 0],
        [-3, -10, -3],
    ]).astype(np.float32)
    '''
    filter_x = np.array([
        [1, 0, -1],
        [2, 0, -2],
        [1, 0, -1],
    ]).astype(np.float32)
    filter_y = np.array([
        [1, 2, 1],
        [0, 0, 0],
        [-1, -2, -1],
    ]).astype(np.float32)
    filter_x = filter_x.reshape((1, 1, 3, 3))
    filter_x = np.repeat(filter_x, in_chan, axis=1)
    filter_x = np.repeat(filter_x, out_chan, axis=0)

    filter_y = filter_y.reshape((1, 1, 3, 3))
    filter_y = np.repeat(filter_y, in_chan, axis=1)
    filter_y = np.repeat(filter_y, out_chan, axis=0)

    filter_x = torch.from_numpy(filter_x)
    filter_y = torch.from_numpy(filter_y)
    filter_x = nn.Parameter(filter_x, requires_grad=False)
    filter_y = nn.Parameter(filter_y, requires_grad=False)
    conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_x.weight = filter_x
    conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_y.weight = filter_y
    sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
    sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
    return sobel_x, sobel_y

class GlobalFilter(nn.Module):
    def __init__(self, dim=32, h=64, w=33, fp32fft=True):
        super().__init__()
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
        self.w = w
        self.h = h
        self.fp32fft = fp32fft

    def forward(self, x):
        _, _, a, b = x.size()  # a, b: spatial height and width
        x = x.permute(0, 2, 3, 1).contiguous()

        if self.fp32fft:
            dtype = x.dtype
            x = x.to(torch.float32)

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho")

        if self.fp32fft:
            x = x.to(dtype)

        x = x.permute(0, 3, 1, 2).contiguous()

        return x

class ERB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ERB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, relu=True):
        x = self.conv1(x)
        res = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        if relu:
            return self.relu(x + res)
        else:
            return x+res

class _PositionAttentionModule(nn.Module):
    """ Position attention module"""

    def __init__(self, in_channels, **kwargs):
        super(_PositionAttentionModule, self).__init__()
        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        feat_c = self.conv_c(x).view(batch_size, -1, height * width)
        attention_s = self.softmax(torch.bmm(feat_b, feat_c))
        feat_d = self.conv_d(x).view(batch_size, -1, height * width)
        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
        out = self.alpha * feat_e + x

        return out


class _ChannelAttentionModule(nn.Module):
    """Channel attention module"""

    def __init__(self, **kwargs):
        super(_ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_a = x.view(batch_size, -1, height * width)
        feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
        attention = torch.bmm(feat_a, feat_a_transpose)
        attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
        attention = self.softmax(attention_new)

        feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
        out = self.beta * feat_e + x

        return out
        
# Note: this second EAM definition overrides the earlier one; it is the version
# used by CTO.forward below (called as self.eam(s4, s1)).
class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(2048, 256)
        self.block = nn.Sequential(
            ConvBNR(256 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x4, x1):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)
        x4 = self.reduce4(x4)
        x4 = F.interpolate(x4, size, mode='bilinear', align_corners=False)
        out = torch.cat((x4, x1), dim=1)
        out = self.block(out)

        return out

def attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
        query.size(-1)
    )
    p_attn = F.softmax(scores, dim=-1)
    p_val = torch.matmul(p_attn, value)
    return p_val, p_attn

class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.BatchNorm2d(d_model),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        b, c, h, w = x.size()  # e.g., 8, 256, 64, 64
        d_k = c // len(self.patchsize)
        output = []
        _query = self.query_embedding(x)  # (b, c, h, w)
        _key = self.key_embedding(x)      # (b, c, h, w)
        _value = self.value_embedding(x)  # (b, c, h, w)
        attentions = []
        for (width, height), query, key, value in zip(
            self.patchsize,
            torch.chunk(_query, len(self.patchsize), dim=1),
            torch.chunk(_key, len(self.patchsize), dim=1),
            torch.chunk(_value, len(self.patchsize), dim=1),
        ):
            # split the spatial dims into non-overlapping (height x width) patches
            out_w, out_h = w // width, h // height
            # 1) embedding and reshape
            query = query.view(b, d_k, out_h, height, out_w, width)
            query = (
                query.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            key = key.view(b, d_k, out_h, height, out_w, width)
            key = (
                key.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            value = value.view(b, d_k, out_h, height, out_w, width)
            value = (
                value.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            y, _ = attention(query, key, value)

            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, out_h, out_w, d_k, height, width)
            y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
            attentions.append(y)
            output.append(y)

        output = torch.cat(output, 1)
        self_attention = self.output_linear(output)

        return self_attention



class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, in_channel=256):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
        self.feed_forward = FeedForward2D(
            in_channel=in_channel, out_channel=in_channel
        )

    def forward(self, rgb):
        self_attention = self.attention(rgb)
        output = rgb + self_attention
        output = output + self.feed_forward(output)
        return output

class PatchTrans(BaseNetwork):
    def __init__(self, in_channel, in_size):  # in CTO below: in_channel=256, in_size=64
        super(PatchTrans, self).__init__()
        self.in_size = in_size

        # multi-scale patch sizes for the parallel attention branches
        patchsize = [
            (32, 32),
            (16, 16),
            (8, 8),
            (4, 4),
        ]

        self.t = TransformerBlock(patchsize, in_channel=in_channel)

    def forward(self, enc_feat):
        output = self.t(enc_feat)
        return output

class multi(nn.Module):
    def __init__(self, channel):
        super(multi, self).__init__()
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att 
        #x = self.conv2d(x)
        #wei = self.avg_pool(x)
        #wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        #wei = self.sigmoid(wei)
        #x = x * wei

        return x

class CTO(nn.Module):
    def __init__(self,seg_classes):
        super(CTO, self).__init__()
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        # if self.training:
        # self.initialize_weights()
        self.fft = GlobalFilter(dim=3, h=256, w=129, fp32fft=True)  # sized for 256x256 inputs

        self.multi_trans = PatchTrans(in_channel=256, in_size=64)

        self.num_class = seg_classes
        self.eam = EAM()
        self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
        self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
        self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
        self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
        
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        self.upsample_3 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        
        self.erb_db_1 = ERB(256, self.num_class)
        self.erb_db_2 = ERB(512, self.num_class)
        self.erb_db_3 = ERB(1024, self.num_class)
        self.erb_db_4 = ERB(2048, self.num_class)
        
        self.head = _DAHead(2048+256, 2048, aux=False)

        

        self.reduce1 = Conv1x1(256, 64)
        self.reduce2 = Conv1x1(512, 64)
        self.reduce3 = Conv1x1(1024, 64)
        self.reduce4 = Conv1x1(2048, 64)
        self.reduce5 = Conv1x1(2048, 1)

        self.dm1 = DM()
        self.dm2 = DM()
        self.dm3 = DM()
        self.dm4 = DM()

        self.predictor1 = nn.Conv2d(64, self.num_class, 1)
        self.predictor2 = nn.Conv2d(64, self.num_class, 1)
        self.predictor3 = nn.Conv2d(64, self.num_class, 1)
        self.predictor4 = nn.Conv2d(64, self.num_class, 1)

    # def initialize_weights(self):
    # model_state = torch.load('./models/resnet50-19c8e357.pth')
    # self.resnet.load_state_dict(model_state, strict=False)

    def forward(self, x):
        fft_fea = self.fft(x)  # (B, 3, 256, 256); not used further in this forward pass
        x1, x2, x3 ,x4= self.resnet(x)#[16, 256, 64, 64]  [16, 512, 32, 32]   [16, 1024, 16, 16]   [16, 2048, 8, 8]
        
        trans = self.multi_trans(x1)#16,256,64,64
        
        s1 = run_sobel(self.sobel_x1, self.sobel_y1, x1)
        s4 = run_sobel(self.sobel_x4, self.sobel_y4, x4)
       
        edge = self.eam(s4, s1)
        edge_att = torch.sigmoid(edge)#[16, 1, 64, 64]
        
        trans = F.interpolate(trans,x4.size()[2:], mode='bilinear', align_corners=False)#256,8,8
        dual_attention = self.head(torch.cat([trans, x4], dim=1))[0]  #2048,8,8
        
        x1a = x1*edge_att
        edge_att2 = F.interpolate(edge_att, x2.size()[2:], mode='bilinear', align_corners=False)
        x2a = x2*edge_att2
        edge_att3 = F.interpolate(edge_att, x3.size()[2:], mode='bilinear', align_corners=False)
        x3a = x3*edge_att3
        
        #x1a = self.efm1(x1, edge_att)
        #x2a = self.efm2(x2, edge_att)
       # x3a = self.efm3(x3, edge_att)
       # x4a = self.efm4(x4, edge_att)
        
        x1r = self.reduce1(x1a)  # 64, 64, 64
        x2r = self.reduce2(x2a)  # 64, 32, 32
        x3r = self.reduce3(x3a)  # 64, 16, 16
        
        dual_attention = self.reduce4(dual_attention)  # 64, 8, 8
       
        c3 = self.dm3(x3r, dual_attention)  # 64, 16, 16
        c2 = self.dm2(x2r, c3)              # 64, 32, 32
        c1 = self.dm1(x1r, c2)              # 64, 64, 64
        

        o3 = self.predictor3(c3)
        o3 = F.interpolate(o3, scale_factor=16, mode='bilinear', align_corners=False)
        o2 = self.predictor2(c2)
        o2 = F.interpolate(o2, scale_factor=8, mode='bilinear', align_corners=False) 
        o1 = self.predictor1(c1)
        o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
        oe = F.interpolate(edge_att, scale_factor=4, mode='bilinear', align_corners=False)

        return  o3, o2, o1, oe
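For reference, a minimal smoke test of the full network; this is a sketch that assumes the remaining repository modules (res2net50_v1b_26w_4s, FeedForward2D, BaseNetwork) are importable and that the pretrained Res2Net weights can be downloaded. The expected shapes follow from the shape comments in forward for a 256×256 input.

# Hypothetical usage sketch, not part of the released file.
import torch

model = CTO(seg_classes=2).eval()
with torch.no_grad():
    o3, o2, o1, oe = model(torch.randn(1, 3, 256, 256))
# o3, o2, o1: multi-scale segmentation logits of shape (1, 2, 256, 256);
# oe: boundary attention map of shape (1, 1, 256, 256)
print(o3.shape, o2.shape, o1.shape, oe.shape)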

Origin: blog.csdn.net/m0_61899108/article/details/131155202