MS-Model [3]: Medical Transformer


Foreword

The paper discussed here was published at MICCAI 2021 and achieved good results in several medical image segmentation challenges that year. This post mainly introduces the structure and related content of the Medical Transformer proposed in that paper.

Original paper link: Medical Transformer: Gated Axial-Attention for Medical Image Segmentation


1. Abstract & Introduction

1.1. Abstract

Convolutional architectures have an inherent inductive bias (an inductive bias is a set of assumptions that makes the learning algorithm prefer solutions with certain properties, so the model's predictions are biased toward those solutions), and they lack an understanding of long-range dependencies in the image.

The article proposes using a transformer for medical image segmentation. The difficulty to be solved is that, on image tasks, transformers need larger training datasets than convolutional neural networks, while a common problem in medical image processing is that data and annotated datasets are insufficient.

The main contributions of this paper are:

  • A gated position-sensitive axial attention mechanism that works well on smaller datasets is proposed
  • An efficient local-global (LOGO) training method is introduced

1.2. Introduction

In ConvNets, each convolution kernel only attends to a local subset of pixels in the whole image and forces the network to focus on local patterns rather than global context. Although supplementary tricks such as image pyramids, atrous convolution and attention mechanisms were proposed later, the problem is still not completely solved.

Since the background of an image is scattered, learning long-range dependencies between the pixels that belong to the background helps the network avoid misclassifying a background pixel as part of the mask, thereby reducing false positives (treating 0 as background and 1 as the segmentation mask). Likewise, when the segmentation mask is large, learning long-range dependencies between the pixels belonging to the mask also helps make efficient predictions.

In digital image processing, segmentation masks are mainly used for:

  • Extracting a region of interest: multiply a pre-made region-of-interest mask with the image to be processed to obtain the image of the region of interest; image values inside the region remain unchanged, while values outside the region become 0 (see the sketch after this list)
  • Shielding: use a mask to shield certain areas of the image so that they do not participate in processing or in the computation of processing parameters, or so that processing or statistics are applied only to the shielded area
  • Structural feature extraction: use similarity measures or image-matching methods to detect and extract structural features in the image that are similar to the mask
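As a concrete illustration of the first use, extracting a region of interest is just an element-wise multiplication. A minimal NumPy sketch (the image and mask here are made up for illustration):

```python
import numpy as np

# Hypothetical grayscale image and a binary mask of the same size
image = np.random.rand(256, 256)       # image to be processed
mask = np.zeros_like(image)
mask[64:192, 64:192] = 1.0             # 1 inside the region of interest, 0 outside

# Values inside the ROI stay unchanged, values outside become 0
roi = image * mask
```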

Motivation for this article:

  • The convolutional layers of a traditional CNN lack the ability to model long-range dependencies in the image (even though stacking pooling layers enlarges the receptive field, it causes a large loss of structural information), whereas the Transformer performs well at capturing long-range dependencies.
  • Since the scarcity of labeled medical data is a bottleneck, and Transformer architectures usually require a large amount of data to perform well, this paper proposes a gated axial attention structure to address this problem (mainly by extending the existing architecture with an additional control mechanism in the self-attention module)
  • In addition, to further improve the performance of the Transformer, the article proposes a local-global training strategy (specifically, the whole image and individual patches are processed to learn global and local features, respectively)

2. Medical Transformer (MedT)

2.1. Model structure

MedT has two branches, a global branch and a local branch, and the input to both branches is the feature map extracted by an initial conv block. This block has 3 conv layers, each followed by batch normalization and a ReLU activation function.
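A minimal PyTorch sketch of such an initial conv block, assuming a 3-channel input and 64 output channels (the channel sizes are my assumption; the paper only specifies three conv layers, each followed by batch normalization and ReLU):

```python
import torch.nn as nn

class InitialConvBlock(nn.Module):
    """Three conv layers, each followed by batch norm and ReLU, producing
    the feature map that is fed to both the global and the local branch."""
    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(3):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
```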

The overall structure of the network is shown in the figure below: it is a U-shaped structure with two branches, whose Encoder and Decoder are organized as follows:

  • In the Encoder of both branches, Transformer layers are used
    • That is, this article only uses the self-attention Transformer mechanism in the Encoder part of the U-Net-like structure, and unlike other CV methods that rely on pre-trained weights from large datasets, this method does not require pre-training
    • The Encoder block is shown in Figure (b) below: a $1 \times 1$ convolutional layer (followed by batch normalization) and two multi-head attention blocks, one operating along the height axis and the other along the width axis
      • Each multi-head attention block consists of 8 gated axial attention heads
      • The output of the multi-head attention blocks is passed through another $1 \times 1$ convolutional layer and added to the residual input map to produce the output attention map
  • In the Decoder of both branches, conv blocks are used (a minimal sketch of a decoder block follows the figure below)
    • Each Decoder block contains a convolutional layer, followed by an upsampling layer and a ReLU activation function
  • There are skip connections between the corresponding Encoder and Decoder blocks of both branches

[Figure: overall MedT architecture with global and local branches; (b) shows the encoder (gated axial transformer) block]
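A minimal sketch of one decoder block as described above, i.e. a conv layer followed by upsampling and ReLU, with an optional skip connection from the matching encoder block; the kernel size and upsampling mode are my assumptions:

```python
import torch.nn as nn
import torch.nn.functional as F

class DecoderBlock(nn.Module):
    """Conv layer followed by 2x upsampling and ReLU, as used in both branches."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, skip=None):
        x = F.relu(F.interpolate(self.conv(x), scale_factor=2,
                                 mode="bilinear", align_corners=False))
        if skip is not None:
            x = x + skip          # skip connection from the corresponding encoder block
        return x
```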

2.2. Attention

2.2.1. Self-Attention Overview

Given an input feature map $x \in \mathbb{R}^{C_{in} \times H \times W}$ with height $H$, width $W$ and $C_{in}$ channels, the output of the self-attention layer $y \in \mathbb{R}^{C_{out} \times H \times W}$ is computed, with the help of projections of the input, using the following formula:

$$y_{ij} = \sum_{h=1}^{H} \sum_{w=1}^{W} \text{softmax}\left(q_{ij}^{T} k_{hw}\right) v_{hw}$$

Parameter meaning:

  • The input $x$ is projected to obtain queries $q = W_Q x$, keys $k = W_K x$ and values $v = W_V x$
  • $q_{ij}, k_{ij}, v_{ij}$ denote the query, key and value at any position $i \in \{1, \dots, H\}$ and $j \in \{1, \dots, W\}$
  • The projection matrices $W_Q, W_K, W_V \in \mathbb{R}^{C_{in} \times C_{out}}$ are learnable
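A direct, unoptimized sketch of this 2D self-attention for a single feature map, following the formula above (the tensor shapes and the absence of scaling or multiple heads are simplifications for illustration):

```python
import torch

def self_attention_2d(x, W_Q, W_K, W_V):
    """x: (C_in, H, W); W_Q, W_K, W_V: (C_out, C_in) learnable projections."""
    C_in, H, W = x.shape
    flat = x.reshape(C_in, H * W)                   # every pixel becomes a token
    q, k, v = W_Q @ flat, W_K @ flat, W_V @ flat    # each (C_out, H*W)

    attn = torch.softmax(q.t() @ k, dim=-1)         # (H*W, H*W): every pixel attends to every pixel
    y = v @ attn.t()                                # weighted sum of values, (C_out, H*W)
    return y.reshape(-1, H, W)
```

Note the $(HW) \times (HW)$ attention matrix; this is exactly the cost discussed below.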

Limitations of the self-attention mechanism:

  • Unlike convolution, the self-attention mechanism can capture non-local information from the entire feature map, but this computation of pairwise similarities is very expensive
    • When ViT was proposed, every token in the Transformer computes attention with every other token, so there are $(hw)^2$ computations, which is a very large amount of computation (a quick comparison follows this list). For other content about ViT, please refer to my other blog: CV-Model [6]: Vision Transformer
  • Because no positional information is introduced, the Transformer by itself has no ability to express positions
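To make the cost concrete, a back-of-the-envelope count of attention pairs for a 128 × 128 feature map, comparing full self-attention with the axial factorization introduced in the next subsection:

```python
H, W = 128, 128
full_pairs = (H * W) ** 2          # every pixel attends to every other pixel
axial_pairs = H * W * (H + W)      # each pixel attends only along its own row and column
print(full_pairs)                  # 268435456
print(axial_pairs)                 # 4194304
```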

2.2.2. Axial-Attention

[Figure: axial attention mechanism]

  • In order to overcome the high computational complexity, the traditional self-attention module is split into two attention modules, one along the width and one along the height, called axial attention, which greatly reduces the computational complexity
    • The receptive field of axial attention is the $W$ (or $H$) pixels in the same row (or column) as the target pixel
    • Axial attention applied on the height and width axes effectively mimics the original self-attention mechanism with better computational efficiency
    • An axial attention layer propagates information along one specific axis. To capture global information, two axial attention layers are used consecutively, one for the height axis and one for the width axis, and both adopt the multi-head attention mechanism
  • In order to increase the ability to express position, a position embedding needs to be added, i.e., a one-hot position vector is passed through a fully connected embedding layer to generate a position encoding; this fully connected layer is trainable
    • This position encoding was originally added only to $Q$; now it is added to all three of $Q$, $K$ and $V$

After adding axial attention and the multiple position-encoding tricks, the attention mechanism is as follows (the article gives the attention along the width direction $w$; the attention along the height direction $h$ is similar):

$$y_{ij} = \sum_{w=1}^{W} \text{softmax}\left(q_{ij}^{T} k_{iw} + q_{ij}^{T} r^{q}_{iw} + k_{iw}^{T} r^{k}_{iw}\right)\left(v_{iw} + r^{v}_{iw}\right)$$

Parameter meaning:

  • $w$ indexes the positions along the width axis of the corresponding row
  • $y_{ij}$ denotes the output at a specific location
  • $r^q, r^k, r^v \in \mathbb{R}^{W \times W}$ denote the positional embedding matrices of the width-wise axial attention model

The attention along the height direction is defined analogously.
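A simplified single-head sketch of the width-axis attention above, with learnable positional terms for $q$, $k$ and $v$. The indexing of the positional embeddings and the channels-last layout are my own simplifications to make the shapes work out; the actual layer is multi-head and is also applied along the height axis:

```python
import torch
import torch.nn as nn

class WidthAxialAttention(nn.Module):
    """Axial attention along the width axis with positional terms for q, k, v.
    Positional embeddings are stored as (W, W, C) tensors so that the dot
    products in the formula type-check; this is an illustrative sketch."""
    def __init__(self, channels, width):
        super().__init__()
        self.to_q = nn.Linear(channels, channels, bias=False)
        self.to_k = nn.Linear(channels, channels, bias=False)
        self.to_v = nn.Linear(channels, channels, bias=False)
        self.r_q = nn.Parameter(torch.randn(width, width, channels) * 0.02)
        self.r_k = nn.Parameter(torch.randn(width, width, channels) * 0.02)
        self.r_v = nn.Parameter(torch.randn(width, width, channels) * 0.02)

    def forward(self, x):
        # x: (H, W, C) -- a single feature map, channels last for readability
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        content = torch.einsum("ijc,iwc->ijw", q, k)          # q_ij . k_iw
        pos_q = torch.einsum("ijc,jwc->ijw", q, self.r_q)     # query positional term
        pos_k = torch.einsum("iwc,jwc->ijw", k, self.r_k)     # key positional term
        attn = torch.softmax(content + pos_q + pos_k, dim=-1)

        out = torch.einsum("ijw,iwc->ijc", attn, v) \
            + torch.einsum("ijw,jwc->ijc", attn, self.r_v)    # values get r^v too
        return out
```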

2.2.3. Gated Axial-Attention

However, the above trick requires a large amount of data for training; a small amount of data is not enough to learn the position embeddings of $Q$, $K$ and $V$ accurately, and medical datasets are usually small.

In this case, inaccurate position embeddings will have a negative impact on the network's accuracy, so this article proposes a gating mechanism to control the degree of this impact, modifying the above formula as follows:

$$y_{ij} = \sum_{w=1}^{W} \text{softmax}\left(q_{ij}^{T} k_{iw} + G_Q\, q_{ij}^{T} r^{q}_{iw} + G_K\, k_{iw}^{T} r^{k}_{iw}\right)\left(G_{V1}\, v_{iw} + G_{V2}\, r^{v}_{iw}\right)$$

Here $G_Q, G_K, G_{V1}, G_{V2}$ are all learnable gating parameters. When a position embedding is learned inaccurately, the corresponding gate $G$ becomes small, and otherwise it becomes large, which is what gives the mechanism its "gated" character. As long as the relative position is the same, the position embedding should be the same across different samples, because a position embedding carries only positional information and no semantic information.

In general, if a relative position code is learned accurately, the gating mechanism will give it a higher weight relative to those codes that were not learned accurately.
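The gated version only adds learnable scalar gates on top of the previous WidthAxialAttention sketch; how they are initialized is my assumption (the paper describes them simply as learnable parameters):

```python
import torch
import torch.nn as nn

class GatedWidthAxialAttention(WidthAxialAttention):
    """Extends the WidthAxialAttention sketch above with learnable gates that
    can down-weight position embeddings that are not learned accurately."""
    def __init__(self, channels, width):
        super().__init__(channels, width)
        self.g_q = nn.Parameter(torch.ones(1))
        self.g_k = nn.Parameter(torch.ones(1))
        self.g_v1 = nn.Parameter(torch.ones(1))
        self.g_v2 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        content = torch.einsum("ijc,iwc->ijw", q, k)
        pos_q = torch.einsum("ijc,jwc->ijw", q, self.r_q)
        pos_k = torch.einsum("iwc,jwc->ijw", k, self.r_k)
        # gates scale the positional terms in both the similarity and the values
        attn = torch.softmax(content + self.g_q * pos_q + self.g_k * pos_k, dim=-1)
        return self.g_v1 * torch.einsum("ijw,iwc->ijc", attn, v) \
             + self.g_v2 * torch.einsum("ijw,jwc->ijc", attn, self.r_v)
```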

2.3. Local-Global Training

The Transformer can perform image segmentation in a patch-wise way, that is, a complete image is cut into multiple patches, and each patch together with its corresponding mask is used as one sample to train the transformer, which is very fast.

The problem, however, is that a lesion in an image may be larger than a patch, so a patch may look strange because it is entirely filled by the lesion. This prevents the network from learning any information or dependencies between the pixels of different patches.

The idea of the Local-Global part is somewhat like multi-scale thinking. The network is divided into two branches:

  • The first branch is the global branch, which works on the original-resolution image without special processing. The strategy is to pass through fewer blocks (two transformer blocks, then the Decoder) in order to capture longer-range dependencies
    • The number of gated axial transformer layers is reduced because it was found that the first few blocks of the proposed transformer model are already sufficient to model long-range dependencies
  • The second branch is the local branch, in which the image is divided into a $4 \times 4$ grid of patches; each patch is fed through the transformer blocks independently, with no connection between patches, and the output feature maps of the $4 \times 4$ patches are put back together by a concat operation
    • Each patch is fed forward through the network and its output feature map is re-sampled according to its location to obtain the branch's output feature map

The output feature maps of the two branches are added and passed through a $1 \times 1$ convolutional layer to produce the output segmentation mask. This strategy, in which the shallower model operates on the global context of the image and the deeper model operates on the patches, improves the overall segmentation performance (a sketch of this forward pass is given below).
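A sketch of the local-global forward pass under the assumption that both branches preserve the spatial size and channel count of their input; the branch modules and channel numbers are placeholders, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class LoGoHead(nn.Module):
    """Global branch sees the whole feature map; local branch processes a
    4x4 grid of patches independently and the outputs are placed back at
    their locations; the sum of both passes through a 1x1 conv."""
    def __init__(self, global_branch, local_branch, channels):
        super().__init__()
        self.global_branch = global_branch   # shallower: fewer transformer blocks
        self.local_branch = local_branch     # deeper: operates on patches
        self.head = nn.Conv2d(channels, 1, kernel_size=1)

    def forward(self, x):
        # x: (B, C, H, W) feature map from the initial conv block
        B, C, H, W = x.shape
        g = self.global_branch(x)

        ph, pw = H // 4, W // 4
        local_out = torch.zeros_like(g)
        for i in range(4):
            for j in range(4):
                patch = x[:, :, i*ph:(i+1)*ph, j*pw:(j+1)*pw]
                # each patch is processed independently, then placed back
                local_out[:, :, i*ph:(i+1)*ph, j*pw:(j+1)*pw] = self.local_branch(patch)

        return torch.sigmoid(self.head(g + local_out))   # output segmentation mask
```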

2.4. Loss function

MedT uses a binary cross-entropy (CE) loss between predictions and ground truth to train the network

$$\mathcal{L}_{CE}(p, \hat{p}) = -\frac{1}{wh}\sum_{x=0}^{w-1}\sum_{y=0}^{h-1}\Big(p(x,y)\log \hat{p}(x,y) + \big(1 - p(x,y)\big)\log\big(1 - \hat{p}(x,y)\big)\Big)$$

Parameter meaning:

  • $w, h$ are the width and height of the image
  • $p(x, y)$ denotes the ground-truth value of the pixel at location $(x, y)$ in the image
  • $\hat{p}(x, y)$ denotes the output prediction at the specific location $(x, y)$
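A minimal PyTorch version of this pixel-wise loss (equivalent in spirit to torch.nn.functional.binary_cross_entropy; the clamp is added for numerical stability):

```python
import torch

def bce_loss(p_hat, p, eps=1e-7):
    """p_hat: predicted mask probabilities, p: ground-truth mask, both (H, W)."""
    p_hat = p_hat.clamp(eps, 1 - eps)
    return -(p * torch.log(p_hat) + (1 - p) * torch.log(1 - p_hat)).mean()
```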

Summary

The authors explore the use of transformer-based encoder architectures to segment medical images without any pre-training:

  1. A gated position-sensitive axial attention mechanism suitable for smaller datasets is proposed, and a gated axial attention layer is proposed as a building block of a multi-head attention model for network encoders.
  2. An efficient local-global (LOGO) training method is introduced, in which the authors use the same network architecture to train whole images and image patches. Global branches help learn high-level features in the network by modeling long-term dependencies, while local branches focus on finer-grained features by operating on patches.
  3. MedT (Medical Transformer), with gated axial attention as the main building block of its encoder, is proposed, and the LOGO strategy is used for training.
  4. Extensive experiments on three different medical image segmentation datasets show improved performance over ConvNets and other Transformer-based architectures.


Origin blog.csdn.net/HoraceYan/article/details/128745781