Medical Segmentation System Based on Axial Attention & FCN-UNet

1. Research background and significance

Medical image segmentation is an important task in medical image processing. Its goal is to accurately delineate different tissues or lesion areas in medical images. Medical image segmentation has important application value in clinical diagnosis, treatment planning, and disease monitoring. However, because of the complexity of medical images and noise interference, traditional segmentation methods often struggle to meet accuracy and efficiency requirements.

In recent years, deep learning has made significant progress in medical image segmentation, and methods based on Convolutional Neural Networks (CNNs) have been widely used. However, traditional CNN methods have some problems when processing medical images. First, medical images usually have large sizes and complex structures, so traditional CNNs easily lose detailed information when extracting features. Second, there are complex spatial relationships between different tissues or lesion areas in medical images, and traditional CNN methods often fail to capture these relationships.

To solve these problems, this study proposes a medical segmentation system based on Axial attention and FCN-UNet. Axial attention is a novel attention mechanism that can effectively capture spatial relationships in medical images. Specifically, Axial attention applies an attention mechanism along the different dimensions of the feature map and performs a weighted fusion of the resulting features, thereby improving their expressive power. In addition, this study uses the FCN-UNet network structure, which can effectively extract detailed information from medical images and has strong feature-reconstruction capabilities.

The significance of this study is mainly reflected in the following aspects:

First, the medical segmentation system based on Axial attention and FCN-UNet can improve the accuracy and efficiency of medical image segmentation. By introducing the Axial attention mechanism, the system can better capture the spatial relationships in medical images, thereby improving the accuracy of segmentation results. At the same time, the network structure of FCN-UNet can effectively extract detailed information in the image and further improve the accuracy of segmentation. In addition, the system has high computational efficiency and can complete the segmentation task in a short time.

Secondly, the medical segmentation system based on Axial attention and FCN-UNet has strong generalization ability. Because of the diversity and complexity of medical images, traditional segmentation methods often struggle to adapt to different types of medical images. By combining the Axial attention mechanism with the FCN-UNet network structure, the system proposed in this study adapts better to different types of medical images and generalizes well.

Finally, the medical segmentation system based on Axial attention and FCN-UNet has important application value for clinical diagnosis and treatment. Accurate medical image segmentation results can provide doctors with more comprehensive and accurate information, helping them to make more accurate diagnoses and formulate more reasonable treatment plans. In addition, the system can also be used for disease monitoring and research, helping researchers better understand the development and change process of the disease.

In summary, the medical segmentation system based on Axial attention and FCN-UNet has important research background and significance. This system can improve the accuracy and efficiency of medical image segmentation, has strong generalization ability, and has important application value for clinical diagnosis and treatment. The results of this study will provide strong support for further research and applications in the field of medical image segmentation.

2. Picture demonstration

(Demonstration screenshots: 2.png, 3.png, 4.png)

3. Video demonstration

Medical segmentation system based on Axial attention & FCN-UNet (video demonstration on Bilibili)

4. Introduction to Unet

The Unet network structure is a deep learning architecture commonly used for image segmentation tasks. It was first proposed in 2015 by Olaf Ronneberger and colleagues. The design of Unet was inspired by the needs of biomedical image segmentation, especially cell segmentation tasks. Unet's name comes from its U-shaped network structure, composed of a symmetric encoder and decoder, which makes it a very effective image segmentation model.

One of Unet's main features is its use of "skip connections". These connections allow information to be passed directly from the encoder to the decoder, helping to preserve the spatial context of the image. Let us examine the structure and working principles of Unet in detail.

Structure of Unet

The network structure of Unet can be divided into two main parts: encoder and decoder. The encoder is responsible for gradually reducing the spatial resolution of the input image and extracting feature information, while the decoder is responsible for gradually increasing the resolution and converting the feature information into a segmentation result with the same resolution as the input image.

Encoder

Encoders usually consist of convolutional layers and pooling layers. The convolutional layer is used to extract feature information of the image, while the pooling layer is used to reduce the resolution while retaining important information. The combination of these layers allows the encoder to gradually capture global and local features of the image, thereby generating higher-level representations.

In Unet, the encoder usually consists of four stages, each of which reduces the resolution of the image by half. For example, the first stage usually reduces the resolution of the input image from 256x256 to 128x128, the second stage reduces it to 64x64, and so on. Each stage consists of a series of convolution and pooling operations to extract features.
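As a schematic illustration only (2D for simplicity, and not code from this project), one encoder stage could be written in PyTorch as follows; the shape comments show how pooling halves the resolution, matching the 256x256 to 128x128 example above.

import torch
import torch.nn as nn

# One hypothetical encoder stage: convolutions extract features,
# max pooling halves the spatial resolution.
stage = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)
pool = nn.MaxPool2d(2)

x = torch.randn(1, 3, 256, 256)   # a 256x256 RGB input
features = stage(x)               # (1, 64, 256, 256): features at full resolution
downsampled = pool(features)      # (1, 64, 128, 128): resolution halved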

Decoder

The task of the decoder is to convert the feature maps generated by the encoder into a segmentation result at the same resolution as the input image. Decoders usually consist of transposed convolution layers (sometimes called deconvolution layers) and skip connections. Transposed convolution layers are used to increase the resolution, while skip connections allow information to be passed from the encoder to the decoder.

The decoder in Unet is likewise divided into four stages, each of which doubles the resolution of the feature map. Each decoder stage is combined with the feature map of the corresponding encoder stage to recover spatial context information. This combination is usually achieved through channel concatenation or pixel-wise merging operations.

Skip Connections

An important feature of Unet is skip connections, which connect the feature maps of the encoder and decoder to help the decoder better understand the contextual information of the image. Skip connections help solve the vanishing gradient problem in deep networks and allow the model to better capture features at different scales.

In skip connections, each decoder stage is connected to the corresponding encoder stage. This means that the first stage of the decoder will be connected to the fourth stage of the encoder, the second stage of the decoder will be connected to the third stage of the encoder, and so on. This type of connection enables the decoder to utilize more information to generate more accurate segmentation results.
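To make this concrete, here is a minimal 2D sketch of one decoder step together with its skip connection (illustrative only; the 3D version used by this project appears in Section 5):

import torch
import torch.nn as nn

# One hypothetical decoder step: transposed convolution doubles the
# resolution, then the result is concatenated with the matching encoder
# feature map (the skip connection).
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

decoder_in = torch.randn(1, 128, 64, 64)      # features from the previous decoder stage
encoder_skip = torch.randn(1, 64, 128, 128)   # features from the matching encoder stage

upsampled = up(decoder_in)                    # (1, 64, 128, 128): resolution doubled
merged = torch.cat([upsampled, encoder_skip], dim=1)  # (1, 128, 128, 128): channel concatenation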

Application fields of Unet

Unet has achieved excellent results in many image segmentation tasks, including medical image segmentation, road segmentation, semantic segmentation, etc. Here are some examples of Unet applications in different fields:

Medical image segmentation: Unet is widely used in tasks such as cell nucleus segmentation, lesion detection, and organ segmentation in medical images. It helps doctors analyze and diagnose medical images more easily.

Road segmentation: Unet can be used to segment roads, buildings and other features from satellite or aerial images, which is very important for urban planning and geographic information systems (GIS).

Semantic segmentation: Unet can be used to assign each pixel in an image to a different semantic category, such as segmenting roads, vehicles, pedestrians, etc., which is helpful for autonomous driving and intelligent transportation systems.

Remote sensing image segmentation: Unet can be used to segment features in remote sensing images, such as forests, lakes, farmland, etc., helping to monitor environmental changes and resource management.

Variants and improvements of Unet

Since Unet was first proposed, researchers have proposed many variants and improved versions of it to adapt to different application needs and solve different problems. Some common Unet variants include:

Improved network architecture: Some variants adopt deeper and wider network architectures to further improve performance, for example by using ResNet or EfficientNet structures.

Multi-scale feature fusion: In order to better capture multi-scale features, some variants introduce attention mechanisms or multi-scale feature fusion modules.

Data augmentation and regularization: To improve model robustness, some variants introduce more sophisticated data augmentation and regularization techniques.

Semi-supervised learning: Some variations attempt to leverage small amounts of labeled data and large amounts of unlabeled data through semi-supervised learning to improve model performance.

In summary, Unet is a very powerful and flexible image segmentation architecture that has achieved remarkable success in multiple fields. Its design philosophy and structure make it one of the preferred models for many image segmentation tasks, and also provides useful inspiration for deep learning research. As the field of deep learning continues to grow, we can expect to see more Unet-based innovations and improvements.

5. Core code explanation

5.1 model.py


import torch
import torch.nn as nn
import torch.nn.functional as F


class AxialAttention(nn.Module):
    def __init__(self, in_channels, heads=4):
        super(AxialAttention, self).__init__()
        self.heads = heads
        # Scale by the per-head channel dimension, as in scaled dot-product attention
        self.scale = (in_channels // heads) ** -0.5
        self.query = nn.Conv1d(in_channels, in_channels, 1, groups=heads)
        self.key = nn.Conv1d(in_channels, in_channels, 1, groups=heads)
        self.value = nn.Conv1d(in_channels, in_channels, 1, groups=heads)

    def forward(self, x):
        B, C, H, W, D = x.size()
        # Flatten the spatial axes so the 1x1 Conv1d projections see (B, C, N), N = H*W*D
        x_flat = x.view(B, C, -1)
        queries = self.query(x_flat).view(B, self.heads, C // self.heads, -1)
        keys = self.key(x_flat).view(B, self.heads, C // self.heads, -1)
        values = self.value(x_flat).view(B, self.heads, C // self.heads, -1)

        # (B, heads, N, N) attention over the flattened voxel positions
        # (note: this attends over all positions jointly, which is memory-heavy
        # for large volumes; Section 7 discusses the per-axis alternative)
        attn_scores = (queries.transpose(-2, -1) @ keys) * self.scale
        attn_probs = F.softmax(attn_scores, dim=-1)
        out = values @ attn_probs.transpose(-2, -1)
        out = out.contiguous().view(B, C, H, W, D)
        return out

class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_gn=False, axial_attention=False):
        super(UNetBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(out_channels // 2, out_channels) if use_gn else nn.BatchNorm3d(out_channels)
        self.norm2 = nn.GroupNorm(out_channels // 2, out_channels) if use_gn else nn.BatchNorm3d(out_channels)
        self.relu = nn.LeakyReLU(0.01)
        self.axial_attention = AxialAttention(out_channels) if axial_attention else None

    def forward(self, x):
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        if self.axial_attention:
            x = x + self.axial_attention(x)
        return x

......

The program file is named model.py and mainly contains the following parts:

  1. AxialAttention module: This module defines a self-attention module built from 1D (Conv1d) projections, which computes attention scores and attention probabilities over the flattened feature volume.

  2. UNetBlock module: This module defines a single U-Net block, consisting of two convolutional layers, two normalization layers, and an activation function, with an optional AxialAttention module.

  3. UNet module: This module defines the complete U-Net model, including the Encoder and Decoder parts. The Encoder part includes 5 UNetBlocks, and the Decoder part includes 4 UNetBlocks. You can choose whether to use the AxialAttention module.

  4. Create model: At the end of the file, a UNet model object is created and the model structure is printed.

This program file mainly implements the definition and creation of the U-Net model, in which the AxialAttention module is used to enhance the attention mechanism of the model.
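As a quick sanity check (not part of model.py itself), the model can be instantiated on a small dummy volume to verify the output shape; the constructor arguments below are illustrative assumptions, not the values printed by model.py.

import torch

# Hypothetical smoke test for the UNet defined in model.py.
model = UNet(in_channels=4, num_classes=3, use_gn=True, axial_attention_levels=[1, 2])
x = torch.randn(1, 4, 32, 32, 32)   # (batch, channels, H, W, D)
with torch.no_grad():
    y = model(x)
print(y.shape)                      # expected: torch.Size([1, 3, 32, 32, 32])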

5.2 predict.py


# The AxialAttention and UNetBlock classes in predict.py are identical to the
# definitions shown in model.py (Section 5.1) and are omitted here.

# Define the complete U-Net model
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes, use_gn=False, axial_attention_levels=[]):
        super(UNet, self).__init__()

        # Encoder
        self.enc1 = UNetBlock(in_channels, 32, use_gn)
        self.enc2 = UNetBlock(32, 64, use_gn)
        self.enc3 = UNetBlock(64, 128, use_gn)
        self.enc4 = UNetBlock(128, 256, use_gn, axial_attention=(1 in axial_attention_levels))
        self.enc5 = UNetBlock(256, 512, use_gn, axial_attention=(2 in axial_attention_levels))

        # Decoder
        self.dec1 = UNetBlock(512 + 256, 256, use_gn, axial_attention=(3 in axial_attention_levels))
        self.dec2 = UNetBlock(256 + 128, 128, use_gn, axial_attention=(4 in axial_attention_levels))
        self.dec3 = UNetBlock(128 + 64, 64, use_gn)
        self.dec4 = UNetBlock(64 + 32, 32, use_gn)

        # Final Convolution
        self.final_conv = nn.Conv3d(32, num_classes, kernel_size=1)

......

The program file is named predict.py; it is a model prediction program for image segmentation. The program uses the TensorFlow and PyTorch libraries to build and load the model, and the matplotlib library for image display.

The Axial Attention module and UNet model are defined in the program. The Axial Attention module uses 1D self-attention to extract image features. The UNet model is a complete U-Net model for image segmentation.

The program also defines auxiliary functions, such as loading datasets, checking environment versions, and performing image preprocessing and normalization.

In the main function, the program first checks that the environment versions meet the requirements, then loads the image dataset to be predicted and sets the dataset configuration. It then loads the trained model and uses it to predict segmentations for the input images. Finally, it displays the original image, the ground-truth labels, and the prediction results.
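Since the main function of predict.py is not reproduced above, the following is only a hypothetical sketch of the prediction flow it describes; the weight file name, input shape, and preprocessing are assumptions.

import torch
import matplotlib.pyplot as plt

# Hypothetical prediction flow using the UNet defined above.
model = UNet(in_channels=4, num_classes=3, use_gn=True, axial_attention_levels=[1, 2])
model.load_state_dict(torch.load('best_weights.pth', map_location='cpu'))  # assumed file name
model.eval()

volume = torch.randn(1, 4, 32, 32, 32)        # stands in for a preprocessed input volume
with torch.no_grad():
    pred = model(volume).argmax(dim=1)        # per-voxel class prediction

# Display one slice of the input and the prediction, as the text describes.
plt.subplot(1, 2, 1); plt.imshow(volume[0, 0, :, :, 16], cmap='gray'); plt.title('input')
plt.subplot(1, 2, 2); plt.imshow(pred[0, :, :, 16]); plt.title('prediction')
plt.show()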

5.3 train.py

# Define the U-Net model. Its constructor (encoder, decoder, and final 1x1
# convolution) is identical to the UNet shown in predict.py above and is
# omitted here.
class UNet(nn.Module):
    ......

    def forward(self, x):
        # Encoder
        x1 = self.enc1(x)
        x2 = self.enc2(F.max_pool3d(x1, 2))
        x3 = self.enc3(F.max_pool3d(x2, 2))
        x4 = self.enc4(F.max_pool3d(x3, 2))
        x5 = self.enc5(F.max_pool3d(x4, 2))

        # Decoder
        x = F.interpolate(x5, scale_factor=2, mode='trilinear', align_corners=True)
        x = self.dec1(torch.cat([x, x4], dim=1))
        x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
        x = self.dec2(torch.cat([x, x3], dim=1))
        x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
        x = self.dec3(torch.cat([x, x2], dim=1))
        x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
        x = self.dec4(torch.cat([x, x1], dim=1))

        # Final convolution to the class logits
        return self.final_conv(x)

The program file is named train.py and mainly contains the following parts:

  1. The AxialAttention and UNetBlock modules are defined to build the encoder and decoder parts of the U-Net model.
  2. The complete U-Net model is defined, including the encoder and decoder parts.
  3. Functions are defined for reading jpg and png images and normalizing them.
  4. Functions are defined for loading the image dataset.
  5. In the main function, the image and annotation datasets are first loaded, then shuffled and split into training and test sets.
  6. The decoder part of the U-Net model is built, with skip connections to the outputs of the pre-trained model.
  7. The model is compiled and the training parameters are set.
  8. The model is trained and the best weight file is saved (a minimal sketch of such a loop follows this list).
  9. The complete model is saved locally.
  10. Accuracy and loss curves are plotted during training.
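
The training loop itself is elided above. As a rough, hypothetical sketch only (the optimizer, loss, data shapes, and file names are assumptions, and the actual train.py may differ), a minimal PyTorch loop over the model defined above could look like this:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Dummy data standing in for the real dataset (all shapes are assumptions):
images = torch.randn(4, 4, 32, 32, 32)                    # (N, channels, H, W, D)
labels = torch.randint(0, 2, (4, 3, 32, 32, 32)).float()  # 3 binary target channels
train_loader = DataLoader(TensorDataset(images, labels), batch_size=2)

model = UNet(in_channels=4, num_classes=3, use_gn=True, axial_attention_levels=[1, 2])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()     # sigmoid-based, matching region-style targets

best_loss = float('inf')
for epoch in range(10):
    model.train()
    running = 0.0
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        running += loss.item()
    if running < best_loss:            # save the best weights, as step 8 describes
        best_loss = running
        torch.save(model.state_dict(), 'best_weights.pth')
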
5.4 ui.py

This is a program file for a medical image segmentation system implemented using PyQt5 and TensorFlow. The Axial Attention module and U-Net model are defined in the program, and PyQt5 is used to build the graphical user interface. The program mainly implements the following functions:

  1. Load the dataset: Load the image dataset through the load_dataset function.
  2. Image preprocessing: Read the image file through the read_jpg and read_png functions, and use the normal_img function to normalize the image.
  3. Set configuration: Set the configuration of the data set through the set_config function, including the number of multi-threads and batch_size.
  4. Model loading and prediction: Load the trained U-Net model through the load_model function, and use the predict function to predict image segmentation.
  5. Image processing and display: Perform image processing on the prediction results, including grayscale, binarization and contour drawing, and use the showimg function to display the processed image in the graphical interface.
  6. Graphical interface design: Use QtDesigner to design a graphical user interface, including controls such as labels, buttons, and text boxes, and use the setupUi function to initialize and layout the interface.
  7. Event binding and thread starting: Bind the button click event to the corresponding slot function through the connect function, create a thread object through the Thread_1 class, and start the thread through its start function (illustrated in the sketch below).

This program implements a simple medical image segmentation system. Users can select image files for segmentation prediction and display the prediction results in the interface.
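The exact UI code is not shown here, so the following is only an illustrative PyQt5 sketch of the thread-and-signal wiring that steps 6 and 7 describe; the class, slot, and helper names are assumptions modeled on the Thread_1 and showimg names mentioned above.

from PyQt5.QtCore import QThread, pyqtSignal

def run_prediction(image_path):
    # Stub standing in for model loading and prediction (see predict.py above).
    return image_path

# Illustrative worker thread, loosely modeled on the Thread_1 class described above.
class SegmentationThread(QThread):
    finished = pyqtSignal(object)      # emits the prediction result to the UI

    def __init__(self, image_path):
        super().__init__()
        self.image_path = image_path

    def run(self):
        self.finished.emit(run_prediction(self.image_path))

# Typical event binding inside the UI class (names are assumptions):
#   self.pushButton.clicked.connect(self.on_predict_clicked)
#
#   def on_predict_clicked(self):
#       path, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Select image')
#       self.thread = SegmentationThread(path)
#       self.thread.finished.connect(self.showimg)   # showimg displays the result
#       self.thread.start()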

6. Overall structure of the system

Overall function and architecture overview:

This project is a medical image segmentation system based on Axial attention & FCN-UNet. It contains multiple program files, each responsible for different functions. The main program files include model.py, predict.py, train.py, ui.py, data\check.py, tools\check_img.py, tools\check_seg.py and tools\url_get.py.

The model.py file defines the AxialAttention module and UNet module, which are used to build the encoder and decoder parts of the U-Net model. It implements the definition and creation of the U-Net model, in which the AxialAttention module is used to enhance the attention mechanism of the model.

The predict.py file is a model prediction program for image segmentation. It loads a trained image segmentation model, uses the model to predict segmentation of the input image, and displays the prediction results.

The train.py file is used to train the image segmentation model. It loads image and annotation datasets, builds a U-Net model, and sets training parameters for model training.

The ui.py file uses PyQt5 to build a graphical user interface. Users can select image files for segmentation prediction and display the prediction results in the interface.

The data\check.py file processes the image files in a specified folder and saves the results to several other folders.

The tools\check_img.py file is used to reduce noise and compress images in the specified folder.

The tools\check_seg.py file is used to process the segmentation results in the specified folder.

The tools\url_get.py file is used to download image files from the specified URL address.

The following table summarizes the function of each file:

File | Function
model.py | Defines the AxialAttention and UNet modules and builds the U-Net model
predict.py | Prediction program: loads the trained model and runs segmentation prediction on images
train.py | Trains the segmentation model: loads the dataset, builds the model, and sets the training parameters
ui.py | PyQt5 graphical user interface for segmentation prediction and result display
data\check.py | Processes the image files in a specified folder and saves the results to other folders
tools\check_img.py | Denoises and compresses the images in a specified folder
tools\check_seg.py | Processes the segmentation results in a specified folder
tools\url_get.py | Downloads image files from a specified URL

7. Improving nnU-Net

The nnU-Net core is a 3D U-Net operating on patches of size 128×128×128. The network has an encoder-decoder structure, with skip connections linking the two paths.

The encoder consists of five levels of same-resolution convolutional layers, with convolutional (strided) downsampling between them. The decoder follows the same structure, using transposed-convolution upsampling and concatenating skip features from the same level of the encoder branch. Each convolution is followed by a Leaky ReLU (lReLU) with slope 0.01 and batch normalization. The mpMRI volumes are concatenated and used as a 4-channel input.

nnU-Net applies region-based training: instead of predicting three mutually exclusive tumor sub-regions, the network predicts three partially overlapping regions, matching the provided segmentation labels: enhancing tumor (ET, the original region), tumor core (TC = ET + necrotic tumor, NT), and whole tumor (WT = ET + NT + peritumoral edema, ED).

The softmax in the last layer of the network is replaced by a sigmoid, treating each voxel's region memberships as independent binary classifications (a multi-label rather than multi-class problem).

Since the metrics computed for the public and private leaderboards are based on these regions, region-based training can improve performance. Additional sigmoid outputs are added at every resolution except the two lowest levels, applying deep supervision and improving gradient propagation to the early layers. The number of convolutional filters starts at 32 and doubles each time the resolution is halved, up to a maximum of 320.
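To make the region-based scheme concrete, here is a small illustrative sketch (not code from this project); the integer label convention follows the common BraTS encoding, which is an assumption here.

import torch
import torch.nn.functional as F

def labels_to_regions(seg):
    # seg: integer label map; assumed convention: 1 = necrotic tumor (NT),
    # 2 = peritumoral edema (ED), 4 = enhancing tumor (ET).
    et = (seg == 4)
    tc = et | (seg == 1)        # tumor core = ET + NT
    wt = tc | (seg == 2)        # whole tumor = ET + NT + ED
    return torch.stack([et, tc, wt], dim=1).float()

# With sigmoid outputs, each region channel is an independent binary problem,
# so a per-channel binary cross-entropy applies:
logits = torch.randn(2, 3, 8, 8, 8)                       # network output (3 region channels)
target = labels_to_regions(torch.randint(0, 5, (2, 8, 8, 8)))
loss = F.binary_cross_entropy_with_logits(logits, target)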

Bigger Network and GN

In the course of the continuous development and improvement of the Unet architecture, researchers have explored ways to improve its performance and applicability (see, e.g., "Expanding the U-Net for Image Segmentation"). An important modification copes with more complex tasks and larger datasets by increasing the size of the network: the number of convolution kernels in the encoder is doubled while the number in the decoder is kept unchanged. This asymmetric network expansion helps improve Unet's representation capability.

The initial version of Unet was designed to solve tasks such as medical image segmentation and therefore had a relatively small network size. However, with the rapid development of the field of deep learning and the increase of computing resources, researchers have begun to apply Unet to a wider range of application fields, such as natural image segmentation, satellite image analysis, and intelligent transportation systems. In order to accommodate these increasing application demands, expanding the scale of Unet has become critical.

A common modification is to increase the number of convolution kernels in the encoder. The task of the encoder is to gradually reduce the resolution of the input image and extract features at different scales. By increasing the number of convolution kernels, the encoder can capture image information in more detail, improving the representation ability of the model. This is important for processing images with fine detail and complex structures. The number of convolution kernels in the decoder, meanwhile, remains unchanged, keeping the decoder efficient while it restores the resolution of the image.

Furthermore, combined with larger-scale datasets, this increase in network capacity allows better modeling of various data types. The availability of large-scale datasets enables deep learning models to generalize better to different scenarios and data distributions. For example, the size of medical image datasets has increased significantly over the past few years, making modifications that expand Unet's capacity more reasonable. Modern Unet models therefore often have more parameters and higher complexity than the original Unet, to cope with growing data and task complexity. The "GN" in this section's title refers to replacing batch normalization with group normalization, whose statistics do not depend on the batch size and which therefore behaves better with the very small batches typical of 3D segmentation (this corresponds to the use_gn option in the code above).

Axial attention decoder

The continuous development and improvement of the Unet architecture has long been a research focus in deep learning. One of the more recent improvements to Unet is the introduction of the Axial attention decoder. Proposed in "Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation" and related work, the Axial attention decoder represents an innovative application of the self-attention mechanism, especially for multi-dimensional data such as 3D image segmentation tasks. This section introduces the background of Axial attention and its application in Unet.

Self-attention, the core mechanism of the Transformer, is a groundbreaking idea that allows models to learn adaptive attention distributions over their input. Initially, self-attention achieved great success in natural language processing, and it has since been widely adopted in computer vision. However, a major challenge of self-attention is that its computational cost scales quadratically with the size of the input sequence, which makes it infeasible for large-scale image data. The problem is especially severe for 3D data with an extra dimension.

Axial attention was proposed as an effective way to tame the computational complexity of attention on multi-dimensional data. Its core idea is to apply self-attention independently along each axis of the input data. For example, for 3D image data, self-attention can be applied along the x, y, and z axes separately. With this separated application, the cost of each attention step grows with the length of one axis rather than with the square of the total number of positions, greatly reducing the computational burden and making self-attention feasible in a much wider range of applications.
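To make the per-axis idea concrete, here is a minimal sketch (not the code used in this project) that applies standard multi-head self-attention independently along each spatial axis of a (B, C, H, W, D) tensor:

import torch
import torch.nn as nn

class AxialSelfAttention3D(nn.Module):
    # Minimal per-axis self-attention sketch for (B, C, H, W, D) tensors.
    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def attend_along(self, x, axis):
        # Move the chosen spatial axis to the sequence position, channels last.
        perm = [0] + [d for d in (2, 3, 4) if d != axis] + [axis, 1]
        inv = [perm.index(i) for i in range(5)]
        xp = x.permute(*perm)                     # (B, other1, other2, L, C)
        b, o1, o2, length, c = xp.shape
        seq = xp.reshape(b * o1 * o2, length, c)  # one sequence per line along the axis
        out, _ = self.attn(seq, seq, seq)         # cost grows with L, not with H*W*D
        return out.reshape(b, o1, o2, length, c).permute(*inv)

    def forward(self, x):
        for axis in (2, 3, 4):                    # H, W, and D axes in turn
            x = x + self.attend_along(x, axis)    # residual per-axis attention
        return x

attn = AxialSelfAttention3D(dim=64, heads=4)
y = attn(torch.randn(1, 64, 8, 8, 8))             # output shape: (1, 64, 8, 8, 8)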

In Unet, Axial attention is introduced into the decoder to process the output of the transposed-convolution upsampling. This means that while restoring image resolution, the model can effectively capture correlations along the different dimensions through the Axial attention mechanism, and thus better understand the structure and semantics of the image. Axial attention decoders handle multi-dimensional data better than traditional decoders, especially 3D data, where the performance gain is significant.

In summary, the Axial attention decoder is an important improvement to the Unet architecture that overcomes the computational complexity problem of self-attention on multi-dimensional data. By applying self-attention along each axis separately, Axial attention enables Unet to better handle complex segmentation tasks, including medical images, natural images, and 3D data. This innovative approach provides interesting ideas for deep learning research and opens up more possibilities for future image segmentation tasks.

8. System integration

Below are the complete source code, an environment deployment video tutorial, the dataset, and the custom UI interface.


Reference Blog "Medical Segmentation System Based on Axial Attention & FCN-UNet"


Origin blog.csdn.net/cheng2333333/article/details/134999065