Rethinking Boundary Detection in Deep Learning Models for Medical Image Segmentation, IPMI 2023
Paper: https://arxiv.org/abs/2305.00678
Code: https://github.com/xiaofang007/CTO
Introduction
In this paper, we propose a novel network architecture, CTO (Convolution, Transformer, and Operator), which combines convolutional neural networks, vision Transformers, and explicit boundary detection operators to achieve high-precision image segmentation while maintaining an optimal balance between accuracy and efficiency.
CTO follows the standard encoder-decoder segmentation paradigm, where the encoder adopts a popular CNN backbone to capture local semantic information and uses a lightweight ViT auxiliary network to integrate long-range dependencies. To enhance boundary learning, the paper further proposes a boundary-guided decoder, which uses the boundary mask obtained from a dedicated boundary detection operator as explicit supervision to guide the decoding process.
Convolution, Transformer, and Operator (CTO)
CTO follows the encoder-decoder paradigm and employs skip connections to aggregate low-level features from the encoder into the decoder. The encoder is composed of a mainstream CNN and an auxiliary ViT; the decoder employs a boundary detection operator to guide its learning process.
- A dual-stream encoder combines a convolutional neural network and a lightweight vision Transformer, capturing local feature dependencies and long-range dependencies between image patches, respectively.
- Operator-guided decoders use a boundary detection operator (for example, Sobel) to guide the learning process through generated boundary masks; the entire model is trained in an end-to-end fashion.
Dual-Stream Encoder
CTO first builds a convolutional stream, with Res2Net chosen as the backbone network, to capture local feature dependencies.
CTO uses a lightweight Vision Transformer-based auxiliary stream to capture long-range dependencies between image patches. Specifically, it consists of multiple parallel lightweight Transformer blocks that take feature maps of different scales as input. All Transformer blocks share a similar structure, consisting of a patch embedding layer and Transformer encoder layers.
The patch embedding layer of LightViT converts the input feature map into embedding vectors, turning the spatial dimensions into a sequence dimension. The Transformer encoder layers then model the embedded patches with a self-attention mechanism to capture long-range dependencies among them. By introducing self-attention, LightViT can effectively model the interactions among patches and extract the global context of the image.
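As a rough illustration of how such a patch-embedding plus self-attention step works (a generic sketch built from standard PyTorch layers, not the paper's LightViT code; all names below are illustrative):

```python
import torch
import torch.nn as nn

class TinyPatchAttention(nn.Module):
    def __init__(self, in_chans=256, embed_dim=128, patch=4, heads=4):
        super().__init__()
        # Patch embedding: a strided conv maps each patch x patch block to one token.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch, stride=patch)
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, heads, batch_first=True)

    def forward(self, x):                      # x: (B, C, H, W)
        tok = self.proj(x)                     # (B, D, H/p, W/p)
        B, D, h, w = tok.shape
        seq = tok.flatten(2).transpose(1, 2)   # (B, N, D): spatial dims -> sequence dim
        seq = self.norm(seq)
        out, _ = self.attn(seq, seq, seq)      # global interactions between all patches
        return out.transpose(1, 2).reshape(B, D, h, w)

feat = torch.randn(2, 256, 64, 64)
print(TinyPatchAttention()(feat).shape)        # torch.Size([2, 128, 16, 16])
```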
Boundary-Guided Decoder
The boundary-guided decoder uses a gradient-operator module to extract the boundary information of foreground objects. The boundary-enhanced features are then integrated with multi-level encoder features through the boundary injection module, aiming to characterize both intra-class and inter-class consistency in the feature space and enrich the representational power of the features. This allows the decoder to better exploit boundary information when generating segmentation results, producing more accurate segmentations.
Boundary Enhanced Module (BEM)
The boundary enhanced module takes high-level features and low-level features as input, extracts boundary information, and filters out information irrelevant to the boundary. The Sobel operator is applied in the horizontal direction (Gx) and the vertical direction (Gy) to obtain gradient maps. Specifically, the paper adopts two fixed-parameter 3×3 convolutions, applied with a stride of 1. These two convolutions are defined as:

$$G_x = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}, \qquad G_y = \begin{bmatrix} 1 & 2 & 1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix}$$
These two convolutions are then applied to the input feature map F to obtain the gradient maps Mx and My. Next, the gradient maps are normalized by the sigmoid function and fused with the input feature map to obtain the edge-enhanced feature map Fe:

$$F_e = \sigma(M_{xy}) \odot F$$

where ⊙ denotes element-wise multiplication, σ(·) is the sigmoid function, and Mxy is the concatenation of Mx and My along the channel dimension. The edge-enhanced feature maps are then fused directly by simple stacked convolutional layers. Finally, the output feature map is supervised by the ground-truth boundary map, which suppresses edge responses inside objects and produces the boundary-enhanced features.
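A minimal sketch of the BEM idea, assuming depthwise Sobel filtering followed by sigmoid gating and a small convolutional fusion head (class and variable names are illustrative, not the authors' exact module; the repository's `get_sobel`/`run_sobel` functions in the key code below show the actual operator):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BEMSketch(nn.Module):
    def __init__(self, channels):
        super().__init__()
        kx = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])
        ky = kx.t()
        # One fixed 3x3 Sobel kernel per direction, applied per channel (depthwise).
        self.register_buffer("wx", kx.expand(channels, 1, 3, 3).clone())
        self.register_buffer("wy", ky.expand(channels, 1, 3, 3).clone())
        self.fuse = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1), nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True), nn.Conv2d(channels, 1, 1))

    def forward(self, f):
        mx = F.conv2d(f, self.wx, padding=1, groups=f.size(1))   # horizontal gradient
        my = F.conv2d(f, self.wy, padding=1, groups=f.size(1))   # vertical gradient
        fe = torch.sigmoid(torch.sqrt(mx ** 2 + my ** 2)) * f    # edge-enhanced features
        return self.fuse(fe)                                     # supervised by the GT boundary map

print(BEMSketch(64)(torch.randn(1, 64, 32, 32)).shape)           # torch.Size([1, 1, 32, 32])
```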
Boundary Inject Module (BIM)
The boundary-enhanced features obtained by BEM can be used as prior knowledge to improve the image representation ability of the features generated by the encoder. For BIM, a dual-path boundary fusion scheme is introduced to facilitate the representation ability of foreground and background features. Specifically, BIM receives two inputs: a channel-level connection of boundary-enhanced features with corresponding features from the encoder network, and features from the previous decoder layer. These two inputs are then fed into the BIM, which contains two separate paths to facilitate feature representation for foreground and background, respectively.
- For the foreground path, we concatenate the two inputs directly along the channel dimension, and then apply a series of Conv-BN-ReLU (convolution, batch normalization, ReLU activation) layers to obtain foreground features.
- For the background path, a background attention component is designed to selectively focus on background information.
The foreground path produces the foreground feature Ffg, and the background path produces the background feature Fbg. The foreground attention map is obtained by applying a sigmoid to the feature map from the previous decoder layer, and the background attention map is obtained by subtracting the foreground attention map from 1. Finally, the foreground feature Ffg, the background feature Fbg, and the previous decoder layer's features are concatenated to form the output of this layer.
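A hedged sketch of this dual-path injection, where the foreground attention is the sigmoid of the (channel-averaged) previous decoder feature and the background attention is its complement; channel sizes, the channel averaging, and all names are illustrative assumptions, not the authors' exact layer configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def cbr(cin, cout):
    return nn.Sequential(nn.Conv2d(cin, cout, 3, padding=1),
                         nn.BatchNorm2d(cout), nn.ReLU(inplace=True))

class BIMSketch(nn.Module):
    def __init__(self, enc_ch, dec_ch):
        super().__init__()
        self.fg = cbr(enc_ch + dec_ch, dec_ch)   # foreground path
        self.bg = cbr(enc_ch + dec_ch, dec_ch)   # background path
        self.out = cbr(dec_ch * 3, dec_ch)

    def forward(self, enc_boundary_feat, prev_dec):
        # enc_boundary_feat: encoder feature already concatenated with the boundary-enhanced
        # feature; prev_dec: feature from the previous decoder layer.
        prev_dec = F.interpolate(prev_dec, enc_boundary_feat.shape[2:],
                                 mode="bilinear", align_corners=False)
        fg_att = torch.sigmoid(prev_dec.mean(dim=1, keepdim=True))  # foreground attention
        bg_att = 1.0 - fg_att                                       # background attention
        x = torch.cat([enc_boundary_feat, prev_dec], dim=1)
        f_fg = self.fg(x) * fg_att
        f_bg = self.bg(x) * bg_att
        return self.out(torch.cat([f_fg, f_bg, prev_dec], dim=1))

bim = BIMSketch(enc_ch=64, dec_ch=64)
print(bim(torch.randn(1, 64, 32, 32), torch.randn(1, 64, 16, 16)).shape)  # (1, 64, 32, 32)
```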
Loss Function
CTO is a multi-task model that performs interior segmentation and boundary detection, and an overall loss function is defined to jointly optimize the two tasks:

$$\mathcal{L}_{total} = \mathcal{L}_{seg} + \lambda\,\mathcal{L}_{bnd}$$

where λ is a weight balancing the two terms. The overall loss consists of the main interior segmentation loss L_seg and the boundary loss L_bnd. For the boundary detection loss, only the prediction from the BEM is considered; this module takes the encoder's high-level and low-level feature maps as input.
Interior Segmentation Loss
L_seg is a weighted sum of the cross-entropy loss L_CE and the average IoU loss L_mIoU:

$$\mathcal{L}_{seg} = \lambda_{1}\,\mathcal{L}_{CE} + \lambda_{2}\,\mathcal{L}_{mIoU}$$

where λ1 and λ2 are weighting coefficients.
Boundary Loss
The boundary loss L_bnd accounts for the severe class imbalance between foreground (boundary) and background pixels in boundary detection, so the Dice loss is used:

$$\mathcal{L}_{bnd} = 1 - \frac{2\sum_{i} p_i\, g_i + \epsilon}{\sum_{i} p_i + \sum_{i} g_i + \epsilon}$$

where p_i and g_i are the predicted and ground-truth boundary values at pixel i, and ε is a smoothing term.
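Putting the pieces together, here is a sketch of the joint objective in PyTorch. The function names (`soft_iou_loss`, `dice_loss`, `cto_loss`), the weighting factors `lambda_iou`/`lambda_bnd`, and the exact soft-IoU formulation are assumptions for illustration; the repository's trainer presumably also applies the segmentation loss to each of the deeply supervised decoder outputs.

```python
import torch
import torch.nn.functional as F

def soft_iou_loss(logits, target_onehot, eps=1e-6):
    prob = torch.softmax(logits, dim=1)
    inter = (prob * target_onehot).sum(dim=(2, 3))
    union = (prob + target_onehot - prob * target_onehot).sum(dim=(2, 3))
    return (1.0 - (inter + eps) / (union + eps)).mean()

def dice_loss(logits, target, eps=1e-6):
    prob = torch.sigmoid(logits)
    inter = (prob * target).sum(dim=(2, 3))
    return (1.0 - (2 * inter + eps)
            / (prob.sum(dim=(2, 3)) + target.sum(dim=(2, 3)) + eps)).mean()

def cto_loss(seg_logits, seg_target, bnd_logits, bnd_target,
             lambda_iou=1.0, lambda_bnd=1.0):
    # seg_logits: (B, C, H, W), seg_target: (B, H, W) class indices
    # bnd_logits / bnd_target: (B, 1, H, W) boundary prediction and GT boundary map
    n_cls = seg_logits.size(1)
    onehot = F.one_hot(seg_target, n_cls).permute(0, 3, 1, 2).float()
    l_seg = F.cross_entropy(seg_logits, seg_target) + lambda_iou * soft_iou_loss(seg_logits, onehot)
    l_bnd = dice_loss(bnd_logits, bnd_target)
    return l_seg + lambda_bnd * l_bnd
```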
Experiments
Key Code
CTO_net.py
# https://github.com/xiaofang007/CTO/blob/main/CTOTrainer/network/CTO_net.py
import math
from math import log

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# res2net50_v1b_26w_4s, BaseNetwork, and FeedForward2D are provided by other modules in the CTO repository.

class ConvBNR(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
super(ConvBNR, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=dilation, dilation=dilation, bias=bias),
nn.BatchNorm2d(planes),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class Conv1x1(nn.Module):
def __init__(self, inplanes, planes):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(inplanes, planes, 1)
self.bn = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class EAM(nn.Module):
def __init__(self):
super(EAM, self).__init__()
self.reduce1 = Conv1x1(256, 64)
self.reduce4 = Conv1x1(512, 256)
self.block = nn.Sequential(
ConvBNR(320 + 64, 256, 3),
ConvBNR(256, 256, 3),
nn.Conv2d(256, 1, 1))
def forward(self, x1, x11, p2):
size = x1.size()[2:]
x1 = self.reduce1(x1)
x11 = self.reduce1(x11)
p2 = self.reduce4(p2)
p2 = F.interpolate(p2, size, mode='bilinear', align_corners=False)
out = torch.cat((x1, x11), dim=1)
out = torch.cat((out, p2), dim=1)
out = self.block(out)
return out
class EFM(nn.Module):
def __init__(self, channel):
super(EFM, self).__init__()
t = int(abs((log(channel, 2) + 1) / 2))
k = t if t % 2 else t + 1
self.conv2d = ConvBNR(channel, channel, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, c, att):
if c.size() != att.size():
att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
x = c * att + c
x = self.conv2d(x)
wei = self.avg_pool(x)
wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
wei = self.sigmoid(wei)
x = x * wei
return x
class BasicConv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_planes, out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class DM(nn.Module):
def __init__(self):
super(DM, self).__init__()
self.predict3 = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
nn.Conv2d(64, 1, kernel_size=1)
)
self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
def forward(self, xr, dualattention):
crop_3 = F.interpolate(dualattention, xr.size()[2:], mode='bilinear', align_corners=False)
re3_feat = self.predict3(torch.cat([xr, crop_3], dim=1))
x = -1*(torch.sigmoid(crop_3)) + 1
x = x.expand(-1, 64, -1, -1).mul(xr)
x = F.relu(self.ra2_conv2(x))
x = F.relu(self.ra2_conv3(x))
ra3_feat = self.ra2_conv4(x)
x = ra3_feat + crop_3 + re3_feat
return x
class _DAHead(nn.Module):
def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
super(_DAHead, self).__init__()
self.aux = aux
inter_channels = in_channels // 4
self.conv_p1 = nn.Sequential(
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.conv_c1 = nn.Sequential(
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.pam = _PositionAttentionModule(inter_channels, **kwargs)
self.cam = _ChannelAttentionModule(**kwargs)
self.conv_p2 = nn.Sequential(
nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.conv_c2 = nn.Sequential(
nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
nn.ReLU(True)
)
self.out = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)
if aux:
self.conv_p3 = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)
self.conv_c3 = nn.Sequential(
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)
def forward(self, x):
feat_p = self.conv_p1(x)
feat_p = self.pam(feat_p)
feat_p = self.conv_p2(feat_p)
feat_c = self.conv_c1(x)
feat_c = self.cam(feat_c)
feat_c = self.conv_c2(feat_c)
feat_fusion = feat_p + feat_c
outputs = []
fusion_out = self.out(feat_fusion)
outputs.append(fusion_out)
if self.aux:
p_out = self.conv_p3(feat_p)
c_out = self.conv_c3(feat_c)
outputs.append(p_out)
outputs.append(c_out)
return tuple(outputs)
def run_sobel(conv_x, conv_y, input):
g_x = conv_x(input)
g_y = conv_y(input)
g = torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
return torch.sigmoid(g) * input
def get_sobel(in_chan, out_chan):
'''
filter_x = np.array([
[3, 0, -3],
[10, 0, -10],
[3, 0, -3],
]).astype(np.float32)
filter_y = np.array([
[3, 10, 3],
[0, 0, 0],
[-3, -10, -3],
]).astype(np.float32)
'''
filter_x = np.array([
[1, 0, -1],
[2, 0, -2],
[1, 0, -1],
]).astype(np.float32)
filter_y = np.array([
[1, 2, 1],
[0, 0, 0],
[-1, -2, -1],
]).astype(np.float32)
filter_x = filter_x.reshape((1, 1, 3, 3))
filter_x = np.repeat(filter_x, in_chan, axis=1)
filter_x = np.repeat(filter_x, out_chan, axis=0)
filter_y = filter_y.reshape((1, 1, 3, 3))
filter_y = np.repeat(filter_y, in_chan, axis=1)
filter_y = np.repeat(filter_y, out_chan, axis=0)
filter_x = torch.from_numpy(filter_x)
filter_y = torch.from_numpy(filter_y)
filter_x = nn.Parameter(filter_x, requires_grad=False)
filter_y = nn.Parameter(filter_y, requires_grad=False)
conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
conv_x.weight = filter_x
conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
conv_y.weight = filter_y
sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
return sobel_x, sobel_y
class GlobalFilter(nn.Module):
def __init__(self, dim=32, h=64, w=33, fp32fft=True):
super().__init__()
self.complex_weight = nn.Parameter(
torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
)
self.w = w
self.h = h
self.fp32fft = fp32fft
def forward(self, x):
        _, _, a, b = x.size()  # a = height, b = width
x = x.permute(0, 2, 3, 1).contiguous()
if self.fp32fft:
dtype = x.dtype
x = x.to(torch.float32)
x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)
x = x * weight
x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho")
if self.fp32fft:
x = x.to(dtype)
x = x.permute(0, 3, 1, 2).contiguous()
return x
class ERB(nn.Module):
def __init__(self, in_channels, out_channels):
super(ERB, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.bn = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x, relu=True):
x = self.conv1(x)
res = self.conv2(x)
res = self.bn(res)
res = self.relu(res)
res = self.conv3(res)
if relu:
return self.relu(x + res)
else:
return x+res
class _PositionAttentionModule(nn.Module):
""" Position attention module"""
def __init__(self, in_channels, **kwargs):
super(_PositionAttentionModule, self).__init__()
self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
self.alpha = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, _, height, width = x.size()
feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
feat_c = self.conv_c(x).view(batch_size, -1, height * width)
attention_s = self.softmax(torch.bmm(feat_b, feat_c))
feat_d = self.conv_d(x).view(batch_size, -1, height * width)
feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
out = self.alpha * feat_e + x
return out
class _ChannelAttentionModule(nn.Module):
"""Channel attention module"""
def __init__(self, **kwargs):
super(_ChannelAttentionModule, self).__init__()
self.beta = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, _, height, width = x.size()
feat_a = x.view(batch_size, -1, height * width)
feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
attention = torch.bmm(feat_a, feat_a_transpose)
attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
attention = self.softmax(attention_new)
feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
out = self.beta * feat_e + x
return out
class EAM(nn.Module):
    # Note: this second EAM definition overrides the one above and is the version instantiated by CTO.
    def __init__(self):
        super(EAM, self).__init__()
self.reduce1 = Conv1x1(256, 64)
self.reduce4 = Conv1x1(2048, 256)
self.block = nn.Sequential(
ConvBNR(256 + 64, 256, 3),
ConvBNR(256, 256, 3),
nn.Conv2d(256, 1, 1))
def forward(self, x4, x1):
size = x1.size()[2:]
x1 = self.reduce1(x1)
x4 = self.reduce4(x4)
x4 = F.interpolate(x4, size, mode='bilinear', align_corners=False)
out = torch.cat((x4, x1), dim=1)
out = self.block(out)
return out
def attention(query, key, value):
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
query.size(-1)
)
p_attn = F.softmax(scores, dim=-1)
p_val = torch.matmul(p_attn, value)
return p_val, p_attn
class MultiHeadedAttention(nn.Module):
"""
Take in model size and number of heads.
"""
def __init__(self, patchsize, d_model):
super().__init__()
self.patchsize = patchsize
self.query_embedding = nn.Conv2d(
d_model, d_model, kernel_size=1, padding=0
)
self.value_embedding = nn.Conv2d(
d_model, d_model, kernel_size=1, padding=0
)
self.key_embedding = nn.Conv2d(
d_model, d_model, kernel_size=1, padding=0
)
self.output_linear = nn.Sequential(
nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
nn.BatchNorm2d(d_model),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, x):
        b, c, h, w = x.size()  # e.g. (8, 256, 64, 64)
d_k = c // len(self.patchsize)
output = []
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
attentions = []
for (width, height), query, key, value in zip(
self.patchsize,
torch.chunk(_query, len(self.patchsize), dim=1),
torch.chunk(_key, len(self.patchsize), dim=1),
torch.chunk(_value, len(self.patchsize), dim=1),
):
            # number of non-overlapping patches along each spatial dimension
            out_w, out_h = w // width, h // height
## 1) embedding and reshape
query = query.view(b, d_k, out_h, height, out_w, width)
query = (
query.permute(0, 2, 4, 1, 3, 5)
.contiguous()
.view(b, out_h * out_w, d_k * height * width)
)
key = key.view(b, d_k, out_h, height, out_w, width)
key = (
key.permute(0, 2, 4, 1, 3, 5)
.contiguous()
.view(b, out_h * out_w, d_k * height * width)
)
value = value.view(b, d_k, out_h, height, out_w, width)
value = (
value.permute(0, 2, 4, 1, 3, 5)
.contiguous()
.view(b, out_h * out_w, d_k * height * width)
)
y, _ = attention(query, key, value)
# 3) "Concat" using a view and apply a final linear.
y = y.view(b, out_h, out_w, d_k, height, width)
y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
attentions.append(y)
output.append(y)
output = torch.cat(output, 1)
self_attention = self.output_linear(output)
return self_attention
class TransformerBlock(nn.Module):
"""
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
"""
def __init__(self, patchsize, in_channel=256):
super().__init__()
self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
self.feed_forward = FeedForward2D(
in_channel=in_channel, out_channel=in_channel
)
def forward(self, rgb):
self_attention = self.attention(rgb)
output = rgb + self_attention
output = output + self.feed_forward(output)
return output
class PatchTrans(BaseNetwork):
    def __init__(self, in_channel, in_size):
        super(PatchTrans, self).__init__()
        self.in_size = in_size
        patchsize = [
            (32, 32),
            (16, 16),
            (8, 8),
            (4, 4),
        ]
self.t = TransformerBlock(patchsize, in_channel=in_channel)
def forward(self, enc_feat):
output = self.t(enc_feat)
return output
class multi(nn.Module):
    # A simplified EFM-style gating block (defined here but not used by the CTO class below).
    def __init__(self, channel):
        super(multi, self).__init__()
t = int(abs((log(channel, 2) + 1) / 2))
k = t if t % 2 else t + 1
self.conv2d = ConvBNR(channel, channel, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, c, att):
if c.size() != att.size():
att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
x = c * att
#x = self.conv2d(x)
#wei = self.avg_pool(x)
#wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
#wei = self.sigmoid(wei)
#x = x * wei
return x
class CTO(nn.Module):
def __init__(self,seg_classes):
super(CTO, self).__init__()
self.resnet = res2net50_v1b_26w_4s(pretrained=True)
# if self.training:
# self.initialize_weights()
self.fft = GlobalFilter(dim = 3 , h=256, w=129, fp32fft= True)
self.multi_trans = PatchTrans(in_channel=256,in_size=64)
self.num_class = seg_classes
self.eam = EAM()
self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
self.upsample_3 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
self.erb_db_1 = ERB(256, self.num_class)
self.erb_db_2 = ERB(512, self.num_class)
self.erb_db_3 = ERB(1024, self.num_class)
self.erb_db_4 = ERB(2048, self.num_class)
self.head = _DAHead(2048+256, 2048, aux=False)
self.reduce1 = Conv1x1(256, 64)
self.reduce2 = Conv1x1(512, 64)
self.reduce3 = Conv1x1(1024, 64)
self.reduce4 = Conv1x1(2048, 64)
self.reduce5 = Conv1x1(2048, 1)
self.dm1 = DM()
self.dm2 = DM()
self.dm3 = DM()
self.dm4 = DM()
self.predictor1 = nn.Conv2d(64, self.num_class, 1)
self.predictor2 = nn.Conv2d(64, self.num_class, 1)
self.predictor3 = nn.Conv2d(64, self.num_class, 1)
self.predictor4 = nn.Conv2d(64, self.num_class, 1)
# def initialize_weights(self):
# model_state = torch.load('./models/resnet50-19c8e357.pth')
# self.resnet.load_state_dict(model_state, strict=False)
def forward(self, x):
        fft_fea = self.fft(x)  # global-filter (FFT) branch on the 3-channel input; not used further in this forward pass
x1, x2, x3 ,x4= self.resnet(x)#[16, 256, 64, 64] [16, 512, 32, 32] [16, 1024, 16, 16] [16, 2048, 8, 8]
trans = self.multi_trans(x1)#16,256,64,64
s1 = run_sobel(self.sobel_x1, self.sobel_y1, x1)
s4 = run_sobel(self.sobel_x4, self.sobel_y4, x4)
edge = self.eam(s4, s1)
edge_att = torch.sigmoid(edge)#[16, 1, 64, 64]
trans = F.interpolate(trans,x4.size()[2:], mode='bilinear', align_corners=False)#256,8,8
dual_attention = self.head(torch.cat([trans, x4], dim=1))[0] #2048,8,8
x1a = x1*edge_att
edge_att2 = F.interpolate(edge_att, x2.size()[2:], mode='bilinear', align_corners=False)
x2a = x2*edge_att2
edge_att3 = F.interpolate(edge_att, x3.size()[2:], mode='bilinear', align_corners=False)
x3a = x3*edge_att3
#x1a = self.efm1(x1, edge_att)
#x2a = self.efm2(x2, edge_att)
# x3a = self.efm3(x3, edge_att)
# x4a = self.efm4(x4, edge_att)
        x1r = self.reduce1(x1a)                      # 64 x 64 x 64
        x2r = self.reduce2(x2a)                      # 64 x 32 x 32
        x3r = self.reduce3(x3a)                      # 64 x 16 x 16
        dual_attention = self.reduce4(dual_attention)
        c3 = self.dm3(x3r, dual_attention)           # 64 x 16 x 16
        c2 = self.dm2(x2r, c3)                       # 64 x 32 x 32
        c1 = self.dm1(x1r, c2)                       # 64 x 64 x 64
o3 = self.predictor3(c3)
o3 = F.interpolate(o3, scale_factor=16, mode='bilinear', align_corners=False)
o2 = self.predictor2(c2)
o2 = F.interpolate(o2, scale_factor=8, mode='bilinear', align_corners=False)
o1 = self.predictor1(c1)
o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
oe = F.interpolate(edge_att, scale_factor=4, mode='bilinear', align_corners=False)
return o3, o2, o1, oe
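For reference, a quick smoke test of the network defined above. It assumes the repository's Res2Net backbone (and its pretrained weights) plus BaseNetwork/FeedForward2D are importable, and that the input is 256×256, since GlobalFilter is constructed with h=256, w=129:

```python
if __name__ == "__main__":
    model = CTO(seg_classes=9).eval()
    with torch.no_grad():
        o3, o2, o1, oe = model(torch.randn(1, 3, 256, 256))
    # o1 is the finest-resolution interior prediction; oe is the 1-channel boundary map.
    print(o3.shape, o2.shape, o1.shape, oe.shape)
```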