Deep Learning Practice (10): Using PyTorch for 3D Medical Image Segmentation

MedicalZoo paper: Deep learning in medical image analysis: a comparative analysis of multi-modal brain-MRI segmentation with 3D deep neural networks
Code is open source: MedicalZooPytorch

1. Project Introduction

  The rise of deep networks in computer vision has provided state-of-the-art solutions to problems where classical image processing techniques underperform. Deep neural networks (DNNs) can justifiably claim superior performance on generalized image recognition tasks, including problems such as object detection, image classification and segmentation, activity recognition, optical flow, and pose estimation.

  Along with the rise of computer vision, there has been a keen interest in its application in the field of medical imaging. Although medical imaging data is not so readily available, DNNs seem to be ideal candidates for modeling such complex high-dimensional data.

  Recently, Imperial College London launched a course related to COVID-19, and many studies have attempted to detect COVID-19 automatically from 3D CT scans with deep networks. Even so, data for this specific application is still scarce. It is clear that artificial intelligence will have a huge impact on medicine through medical imaging.

  As we will see, medical images are usually three- or four-dimensional. Another reason this field attracts so much attention is its direct impact on human life. In the United States, medical error is the third leading cause of death after heart disease and cancer, so the top three causes of human death are all related to medicine and medical imaging. That is why it is estimated that by 2023, artificial intelligence and deep learning in medical imaging will create a new market worth more than $1 billion.

  This work sits at the intersection of these two worlds: deep neural networks and medical imaging. In this post, I will tackle medical image segmentation, focusing on magnetic resonance images; it is one of the most popular tasks because it has the most well-structured datasets that anyone can access. Collecting medical data online is not as simple as it sounds, so a collection of links is provided at the end of the article to start your journey.

  This paper presents some preliminary results from an open-source library under development called MedicalZoo, which can be found here.

2. Requirements for 3D medical image segmentation

  Volumetric segmentation of 3D medical images is essential for diagnosis, monitoring, and treatment planning. I will only be using magnetic resonance images (MRI). Manual segmentation requires anatomical knowledge, and it is expensive and time-consuming; it may also be inaccurate due to human error. Automatic volume segmentation, by contrast, can save physicians time and provide accurate, reproducible results for further analysis.

  I'll start by describing the fundamentals of MR imaging, since understanding your input data is critical for training a deep architecture. Then an overview of the 3D-UNet, which can be used effectively for this task, is provided.

3. Medical images and MRI

  Medical imaging attempts to reveal internal structures hidden by skin and bone, as well as to diagnose and treat disease. Medical magnetic resonance (MR) imaging uses signals from hydrogen nuclei to create images. In the case of the hydrogen nucleus: when it is exposed to an external magnetic field, denoted B0, the magnetic moment, or spin, aligns with the direction of the magnetic field like a compass needle.

  The net magnetization is tipped into the transverse plane by an additional radio-frequency pulse that is strong enough and applied for long enough. Immediately after excitation, the magnetization precesses in that plane, and this rotating magnetization induces the MR signal in the receiving coil. However, the MR signal fades rapidly due to two separate relaxation processes that reduce the magnetization and return it to the stable pre-excitation state, giving rise to the so-called T1 and T2 MR images. T1 relaxation is related to the nuclei releasing their excess energy to the surroundings, while T2 relaxation refers to the phenomenon in which the individual magnetization vectors dephase and begin to cancel each other out. The two phenomena are completely independent. As a consequence, different tissues appear with different intensities, as shown in the figure below.
[Figure: different tissue intensities in T1- and T2-weighted MR images]

4. 3D Medical Image Representation

  Since medical images represent 3D structures, one way to process them is to take slices of the 3D volume and apply conventional 2D sliding convolutions, as shown in the figure below. Suppose the red rectangle is a 5x5 image patch, which can be represented as a matrix of intensity values. The patch is convolved with a 3x3 kernel; the kernel slides over the entire 2D grid (a medical image slice), and at each position we compute a cross-correlation. The result for the 5x5 patch is stored in a 3x3 matrix (no padding, for illustration) and propagated to the next layer of the network.

  Alternatively, the slices can be viewed much like the feature maps of an intermediate layer. In deep architectures we usually have multiple feature maps, which in practice form a 3D tensor. If there is reason to believe that there are patterns along the extra dimension, performing a 3D sliding convolution is the better choice, and this is exactly the case with medical images. Just as a 2D convolution encodes the spatial relationships of objects in a 2D plane, a 3D convolution can describe the spatial relationships of objects in 3D space. Since 2D representations are suboptimal for medical images, we will use 3D convolutional networks in this post.
  Medical image slices can be viewed as multiple feature maps in an intermediate layer, with the difference that they have strong spatial relationships.
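
  To make the difference concrete, here is a minimal PyTorch sketch (the shapes are illustrative, not taken from the dataset): a 2D convolution operates on a single slice, while a 3D convolution slides over the whole volume.

import torch
import torch.nn as nn

# A single MRI slice: (batch, channels, height, width)
slice_2d = torch.randn(1, 1, 144, 144)
conv2d = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
print(conv2d(slice_2d).shape)    # torch.Size([1, 8, 144, 144])

# A full volume: (batch, channels, depth, height, width)
volume_3d = torch.randn(1, 1, 64, 144, 144)
conv3d = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
print(conv3d(volume_3d).shape)   # torch.Size([1, 8, 64, 144, 144])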

5. 3D-Unet model

  In our example we will use the well-accepted 3D U-Net. The latter (code) extends the successful idea of the symmetric 2D U-Net, which yields impressive results in RGB-related tasks such as semantic segmentation. The model has an encoder (contracting path) and a decoder (synthesis path), each with four resolution steps. In the encoder path, each layer contains two 3×3×3 convolutions, each followed by a rectified linear unit (ReLU), and then a 2×2×2 max pooling with a stride of 2 in each dimension. In the decoder path, each layer consists of a 2×2×2 transposed convolution with a stride of 2 in each dimension, followed by two 3×3×3 convolutions, each followed by a ReLU. Shortcut skip connections from layers of equal resolution in the encoder path provide essential high-resolution features to the decoder path. In the last layer, a 1×1×1 convolution reduces the number of output channels to the number of labels. By doubling the number of channels already before max pooling, a bottleneck is avoided. 3D batch normalization is introduced before each ReLU: during training, each batch is normalized with its own mean and standard deviation, and the global statistics are updated using these values; this is followed by a layer that explicitly learns scale and bias. The figure below illustrates the network structure.

[Figure: 3D U-Net architecture]
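
  To make the structure more concrete, here is a minimal sketch of one encoder level in PyTorch (a hand-written illustration, not the actual MedicalZoo implementation): two 3×3×3 convolutions with batch normalization before each ReLU, followed by 2×2×2 max pooling, with the pre-pooling features kept for the skip connection.

import torch.nn as nn

class EncoderBlock3D(nn.Module):
    """One contracting-path level of a 3D U-Net (illustrative sketch)."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.features(x)       # kept for the skip connection to the decoder
        return self.pool(skip), skip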

5.1 Loss function: Dice Loss

  Due to the inherent class imbalance of the task, cross-entropy cannot always provide a good solution. Specifically, the cross-entropy loss examines each pixel individually, comparing the class prediction (the pixel vector along the channel dimension) with the one-hot encoded target vector. Because it evaluates each pixel vector individually and then averages over all pixels, it implicitly assumes that every pixel contributes equally to learning. This can be a problem if the classes are represented unevenly in the images, as the most prevalent class can dominate training.

  The four classes we try to distinguish in brain MRI have very different frequencies in the images (e.g. air has far more voxels than the other tissues). This is why the Dice loss is employed. It is based on the Dice coefficient, which is essentially a measure of the overlap between two samples. The measure ranges from 0 to 1, where a Dice coefficient of 1 denotes perfect and complete overlap. Dice loss was originally developed for binary classification, but it can be generalized to multiple classes. Please feel free to use our multi-class implementation of Dice loss.
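
  The library ships its own multi-class implementation (lib.losses3D.DiceLoss, used in the script below); as a rough sketch of the idea, a soft multi-class Dice loss could look like this:

import torch
import torch.nn.functional as F

def soft_dice_loss(logits, target, num_classes, eps=1e-6):
    # logits: (N, C, D, H, W) raw network outputs; target: (N, D, H, W) integer labels.
    probs = F.softmax(logits, dim=1)
    target_onehot = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)                        # sum over batch and spatial dimensions
    intersection = torch.sum(probs * target_onehot, dims)
    cardinality = torch.sum(probs + target_onehot, dims)
    dice_per_class = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice_per_class.mean()         # average over classes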

5.2 Medical Imaging Data

  Deep architectures require a large number of training samples before they can produce useful generalizable representations, and labeled training data is typically expensive and difficult to produce. This is why we increasingly see medical imaging data being produced with new generative learning techniques. In addition, the training data must be representative of the data the network will encounter in the future: if the training samples come from a distribution different from the one encountered in the real world, the network's generalization performance will be lower than expected.

  Since we focus on automatic segmentation of brain MRI, it is necessary to briefly introduce the basic structures of the brain that DNNs try to distinguish: a) white matter (WM), b) gray matter (GM), c) cerebrospinal fluid (CSF). The figure below illustrates segmented tissue in a brain MRI slice.
[Figure: segmented tissues (WM, GM, CSF) in a brain MRI slice]

5.2.1 2017 I-Seg Medical Image Data Challenge

  The first year of life is the most dynamic period of postnatal human brain development, with rapid tissue growth and the development of various cognitive and motor functions. This early period is critical in many neurodevelopmental and neuropsychiatric disorders, such as schizophrenia and autism, and it is receiving more and more attention. Accurate segmentation of infant brain MRI into white matter (WM), gray matter (GM), and cerebrospinal fluid (CSF) during this critical period is therefore fundamental for studying both normal and abnormal early brain development.

  The purpose of this dataset is to promote automatic segmentation algorithms for brain MRIs of 6-month-old infants. The challenge was held in conjunction with MICCAI 2017, and a total of 21 international teams participated. The dataset contains 10 densely annotated training images (labeled by experts) and 13 test images. No test labels are provided; you can only see your score after uploading the results to the official website. Each subject has one T1-weighted and one T2-weighted image.

  The first subject will be used for testing. The original MR volumes have a size of 256x192x144. For the 3D-UNet, sampled subvolumes of size 128x128x64 are used. The resulting training dataset consists of 500 subvolumes. For the validation set, 10 random subvolumes from one subject are used.
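
  As an illustration of this subvolume sampling (the volume and subvolume sizes come from the description above; the function itself is a simplified sketch, not the library's actual data loader):

import numpy as np

def sample_subvolume(t1, t2, label, size=(128, 128, 64)):
    # t1, t2, label: arrays of shape (256, 192, 144); returns one random crop.
    x = np.random.randint(0, t1.shape[0] - size[0] + 1)
    y = np.random.randint(0, t1.shape[1] - size[1] + 1)
    z = np.random.randint(0, t1.shape[2] - size[2] + 1)
    crop = (slice(x, x + size[0]), slice(y, y + size[1]), slice(z, z + size[2]))
    # Stack T1 and T2 as two input channels, matching --inChannels 2 in the script below.
    inputs = np.stack([t1[crop], t2[crop]], axis=0)
    return inputs, label[crop]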

  Dataset download

6. Medical Zoo

  Our goal is to implement in PyTorch an open-source medical image segmentation library consisting of state-of-the-art 3D deep neural networks, and data loaders for the most common medical datasets. The first stable version of our repository is expected to be released soon.

  We strongly believe in open and reproducible deep learning research. To reproduce our results, both the code and materials for this work can be found in this repository. This project started as a master's thesis and is currently being further developed.

6.1 Implementation Details

  We use the PyTorch framework, which is considered one of the most widely adopted deep learning research tools. All experiments use stochastic gradient descent with a batch size of 1, a learning rate of 1e-3, and a weight decay of 1e-8. The repository includes tests with which you can easily reproduce our results, so you can use the code, models, and data loaders.
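
  In plain PyTorch, that optimizer configuration corresponds to something like the following sketch (in the library itself, the optimizer is created inside medzoo.create_model):

import torch.nn as nn
import torch.optim as optim

model = nn.Conv3d(2, 4, kernel_size=3, padding=1)   # stand-in for the actual 3D network
optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-8)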

  Recently, we added TensorBoard visualization support to our PyTorch code. This amazing feature helps keep your sanity intact and lets you monitor your model's training process. Below you can see an example that tracks training statistics, the Dice coefficient and loss, as well as per-class scores, to get an idea of how the model behaves.
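
  A minimal example of this kind of logging with PyTorch's SummaryWriter (the scalar values below are dummy placeholders, not actual results):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='../runs/')   # matches the --log_dir default in the script below
for epoch in range(3):                       # placeholder loop; real values come from the Trainer
    writer.add_scalar('train/loss', 0.5 / (epoch + 1), epoch)
    writer.add_scalar('val/dice_coeff', 0.70 + 0.05 * epoch, epoch)
    writer.add_scalars('val/per_class_dice',
                       {'air': 0.95, 'csf': 0.70, 'gm': 0.65, 'wm': 0.68}, epoch)
writer.close()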

6.2 Code

  Let's put all the described modules together and set up an experiment with a short script (for illustration) from MedicalZoo.

# Python libraries
import argparse
import os

# Lib files
import lib.medloaders as medical_loaders
import lib.medzoo as medzoo
import lib.train as train
import lib.utils as utils
from lib.losses3D import DiceLoss

def main():
    args = get_arguments()
    utils.make_dirs(args.save)

    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(args,
                                                                                               path='.././datasets')
    model, optimizer = medzoo.create_model(args)
    criterion = DiceLoss(classes=args.classes)

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = train.Trainer(args, model, criterion, optimizer, train_data_loader=training_generator,
                            valid_data_loader=val_generator, lr_scheduler=None)
    print("START TRAINING...")
    trainer.training()


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=4)
    parser.add_argument('--dataset_name', type=str, default="iseg2017")
    parser.add_argument('--dim', nargs="+", type=int, default=(64, 64, 64))
    parser.add_argument('--nEpochs', type=int, default=200)
    parser.add_argument('--classes', type=int, default=4)
    parser.add_argument('--samples_train', type=int, default=1024)
    parser.add_argument('--samples_val', type=int, default=128)
    parser.add_argument('--inChannels', type=int, default=2)
    parser.add_argument('--inModalities', type=int, default=2)
    parser.add_argument('--threshold', default=0.1, type=float)
    parser.add_argument('--terminal_show_freq', default=50)
    parser.add_argument('--augmentation', action='store_true', default=False)
    parser.add_argument('--normalization', default='full_volume_mean', type=str,
                        help='Tensor normalization options: max_min, full_volume_mean, brats, max, mean',
                        choices=('max_min', 'full_volume_mean', 'brats', 'max', 'mean'))
    parser.add_argument('--split', default=0.8, type=float, help='Select percentage of training data(default: 0.8)')
    parser.add_argument('--lr', default=1e-2, type=float,
                        help='learning rate (default: 1e-2)')
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--loadData', default=True)
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--model', type=str, default='VNET',
                        choices=('VNET', 'VNET2', 'UNET3D', 'DENSENET1', 'DENSENET2', 'DENSENET3', 'HYPERDENSENET'))
    parser.add_argument('--opt', type=str, default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--log_dir', type=str,
                        default='../runs/')

    args = parser.parse_args()

    args.save = '../saved_models/' + args.model + '_checkpoints/' + args.model + '_{}_{}_'.format(
        utils.datestr(), args.dataset_name)
    return args


if __name__ == '__main__':
    main()
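
  Note that this script's default --model is VNET; passing --model UNET3D selects the 3D U-Net described above. Likewise, the --lr default here is 1e-2, so pass --lr 1e-3 to match the learning rate mentioned in section 6.1.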

6.3 Experimental results

  Below you can see the training and validation Dice loss curves of the model. It is important to monitor your model's performance and tune the parameters to obtain such smooth training curves; they make the efficiency of the model easy to judge.
[Figure: training and validation Dice loss curves]
  Surprisingly, the model achieves a Dice coefficient of roughly 93% on the validation subvolumes. Last but not least, let's look at some visual predictions of the 3D-UNet on the validation set. We show only a representative slice here, even though the prediction is a 3D volume. By sampling multiple subvolumes of the MRI, we can combine them to form a complete 3D segmentation. Note that subvolume sampling also acts as a form of data augmentation.

Unnormalized last-layer pre-activations from the trained 3D-UNet. The network learns highly semantic, task-related content that corresponds to brain structures in the input.
  Our prediction versus the ground truth. Which one do you think is the ground truth? Take a closer look before you decide. It's worth noting that only a middle axial slice is shown here, although the prediction is a 3D volume. One can observe that the network predicts the air voxels almost perfectly, while it has trouble distinguishing the tissue boundaries. But let's double check and find out which is which!
[Figure: prediction vs. ground truth]
  By now, I'm sure you can tell the prediction from the ground truth. If you're not sure, see the end of the article.

  The per-class TensorBoard curves mentioned in section 6.1 are shown below.

[Figure: per-class Dice scores logged in TensorBoard]
  It is clear that different tissues reach different accuracies, even from the beginning of training. For example, the air voxels in the validation set start with high scores because air is the most dominant class in this imbalanced dataset. Gray matter, on the other hand, starts from the lowest value because it is the hardest to distinguish and has fewer training examples.

7. Summary

  This post illustrates some of the features of the MedicalZoo PyTorch library. Deep learning models will provide society with impressive medical imaging solutions.

  In this post, the basic concepts of medical imaging and MRI were reviewed, along with how they are represented and used in deep learning architectures. Then an efficient, widely accepted 3D architecture (U-Net) and the Dice loss function for handling class imbalance were described. Finally, combining all of the above and using the library's scripts, the preliminary results of our experimental analysis on brain MRI were presented. These results demonstrate the efficiency of 3D architectures and the potential of deep learning in medical image analysis.


Origin blog.csdn.net/muye_IT/article/details/125210968