PyTorch: MRI brain scan image segmentation

Image segmentation is one of the most important tasks in medical image analysis, and it is often the first and most critical step in many clinical applications. In brain MRI analysis, image segmentation is commonly used to measure and visualize anatomical structures, analyze brain changes, and delineate pathological regions, as well as for surgical planning and image-guided intervention; segmentation is a prerequisite for most morphological analyses.

In this article we will show how to use QuickNAT to segment images of the human brain, using MONAI and PyTorch together with common Python libraries for data visualization and computation such as NumPy, TorchIO, and Matplotlib.

This article mainly covers the following aspects:

  • Setting up the dataset and exploring the data

  • Processing and preparing the data for model training

  • Creating a training loop

  • Evaluating the model and analyzing the results

The full code is provided at the end of this article.

Set the data directory

The first step in using MONAI is to set the MONAI_DATA_DIRECTORY environment variable to specify a data directory; if it is not set, a temporary directory is used.

 import os
 import tempfile

 directory = os.environ.get("MONAI_DATA_DIRECTORY")
 root_dir = tempfile.mkdtemp() if directory is None else directory
 print(root_dir)

Set up the dataset

One of the main challenges in extending CNN models to brain segmentation is the limited availability of manually annotated training data. The authors introduce a new training strategy that utilizes both large datasets without manual labels and small datasets with manual labels.

First, automatically generated segmentations are obtained for the large unlabeled datasets using existing software tools (e.g., FreeSurfer), and the network is pretrained on these auto-generated labels. In the second step, the network is fine-tuned on the smaller, manually annotated data [2].

The IXI dataset consists of unlabeled MRI T1 scans of 581 healthy subjects, collected from 3 different hospitals in London. The main disadvantage of this dataset is that its labels are not publicly available, so following the same approach as the research paper would mean using FreeSurfer to generate segmentations for these MRI T1 scans.

FreeSurfer is a software package for analyzing and visualizing brain structures. Download and installation instructions can be found on the FreeSurfer website. The full cortical reconstruction can be run directly with the "recon-all" command.
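For instance, a single subject's reconstruction could be launched from Python roughly like this (a sketch only; the subject ID and input path are hypothetical, and FreeSurfer must be installed with FREESURFER_HOME configured):

 import subprocess

 # run FreeSurfer's full cortical reconstruction pipeline for one subject
 subprocess.run(
     ["recon-all", "-s", "OAS1_0001", "-i", "/path/to/OAS1_0001_T1.nii", "-all"],
     check=True,
 )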

Although FreeSurfer is a very useful tool for exploiting large amounts of unlabeled data and training the network in a supervised manner, generating these labels takes up to 5 hours per scan, so here we use the OASIS dataset directly to train the model instead.

The OASIS dataset is a smaller dataset with publicly available manual annotations. OASIS is a project that makes brain neuroimaging datasets freely available to the scientific community. OASIS-1 is a cross-sectional dataset of 39 subjects and can be obtained as follows:

 resource \= "https://download.nrg.wustl.edu/data/oasis\_cross-sectional\_disc1.tar.gz"  
 md5 \= "c83e216ef8654a7cc9e2a30a4cdbe0cc"  
   
 compressed\_file \= os.path.join\(root\_dir, "oasis\_cross-sectional\_disc1.tar.gz"\)  
 data\_dir \= os.path.join\(root\_dir, "Oasis\_Data"\)  
 if not os.path.exists\(data\_dir\):  
  download\_and\_extract\(resource, compressed\_file, data\_dir, md5\)

Data exploration

If you open 'oasis_cross-sectional_disc1.tar.gz', you will find a separate folder for each subject. For example, for subject OAS1_0001_MR1 it looks like this:

Image data file path: disc1\OAS1_0001_MR1\PROCESSED\MPRAGE\T88_111\OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc.img

Label file: disc1\OAS1_0001_MR1\FSL_SEG\OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.img

 

Data loading and preprocessing

After downloading and extracting the dataset into the data directory, it needs to be restructured: we want the processed images in one folder and the corresponding labels in another. Follow the steps below to load the data:

Convert the .img files to .nii files and save them into new folders: create two new folders, Oasis_Data_Processed for the processed MRI T1 scans of each subject, and Oasis_Labels_Processed for the corresponding labels.

 new_path_data = root_dir + '/Oasis_Data_Processed/'
 if not os.path.exists(new_path_data):
     os.makedirs(new_path_data)

 new_path_labels = root_dir + '/Oasis_Labels_Processed/'
 if not os.path.exists(new_path_labels):
     os.makedirs(new_path_labels)

Then loop over the subjects and convert each scan:

 import nibabel as nib

 # subjects 8, 24 and 36 are skipped; some scans use 'n3' instead of 'n4' in the filename
 for i in [x for x in range(1, 43) if x != 8 and x != 24 and x != 36]:
     if i < 7 or i == 9:
         filename = root_dir + '/Oasis_Data/disc1/OAS1_000' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_000' + str(i) + '_MR1_mpr_n4_anon_111_t88_masked_gfc.img'
     elif i == 7:
         filename = root_dir + '/Oasis_Data/disc1/OAS1_000' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_000' + str(i) + '_MR1_mpr_n3_anon_111_t88_masked_gfc.img'
     elif i == 15 or i == 16 or i == 20 or i == 24 or i == 26 or i == 34 or i == 38 or i == 39:
         filename = root_dir + '/Oasis_Data/disc1/OAS1_00' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_00' + str(i) + '_MR1_mpr_n3_anon_111_t88_masked_gfc.img'
     else:
         filename = root_dir + '/Oasis_Data/disc1/OAS1_00' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_00' + str(i) + '_MR1_mpr_n4_anon_111_t88_masked_gfc.img'
     img = nib.load(filename)
     nib.save(img, filename.replace('.img', '.nii'))

The remaining restructuring code is omitted here; if you are interested, take a look at the complete code linked at the end. The next step is to read the image and label filenames:

 from glob import glob

 image_files = sorted(glob(os.path.join(root_dir + '/Oasis_Data_Processed', '*.nii')))
 label_files = sorted(glob(os.path.join(root_dir + '/Oasis_Labels_Processed', '*.nii')))
 files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(image_files, label_files)]

To visualize images together with their corresponding labels, we can use TorchIO, a Python library for loading, preprocessing, augmenting, and sampling multidimensional medical images for deep learning.

 import torchio

 image_filename = root_dir + '/Oasis_Data_Processed/OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc.nii'
 label_filename = root_dir + '/Oasis_Labels_Processed/OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.nii'
 subject = torchio.Subject(image=torchio.ScalarImage(image_filename), label=torchio.LabelMap(label_filename))
 subject.plot()

The next step is to split the data into three parts: training, validation, and test. Splitting the data into three distinct sets helps us build a reliable machine learning model and avoid overfitting.

We divide the whole dataset into three parts:

Train: 80%, Validation: 10%, Test: 10%

 import numpy as np
 from monai.data import partition_dataset

 train_inds, val_inds, test_inds = partition_dataset(data=np.arange(len(files)), ratios=[8, 1, 1], shuffle=True)

 train = [files[i] for i in sorted(train_inds)]
 val = [files[i] for i in sorted(val_inds)]
 test = [files[i] for i in sorted(test_inds)]

 print(f"Training count: {len(train)}, Validation count: {len(val)}, Test count: {len(test)}")

Because the model requires two-dimensional slices, we save each slice to a separate folder. The following two code cells save the coronal slices of each MRI volume in the training set in '.png' format.

 import matplotlib.pyplot as plt

 # Save coronal slices for training images
 dir = root_dir + '/TrainData'
 os.makedirs(os.path.join(dir, "Coronal"))
 path = root_dir + '/TrainData/Coronal/'

 for file in sorted(glob(os.path.join(root_dir + '/TrainData', '*.nii'))):
     image = torchio.ScalarImage(file)
     data = image.data
     filename = os.path.basename(file)
     filename = os.path.splitext(filename)
     for i in range(0, 208):
         slice = data[0, :, i]
         array = slice.numpy()
         data_dir = root_dir + '/TrainData/Coronal/' + filename[0] + '_slice' + str(i) + '.png'
         plt.imsave(fname=data_dir, arr=array, format='png', cmap=plt.cm.gray)

Similarly, the following saves the label slices:

 dir = root_dir + '/TrainLabels'
 os.makedirs(os.path.join(dir, "Coronal"))
 path = root_dir + '/TrainLabels/Coronal/'

 for file in sorted(glob(os.path.join(root_dir + '/TrainLabels', '*.nii'))):
     label = torchio.LabelMap(file)
     data = label.data
     filename = os.path.basename(file)
     filename = os.path.splitext(filename)
     for i in range(0, 208):
         slice = data[0, :, i]
         array = slice.numpy()
         data_dir = root_dir + '/TrainLabels/Coronal/' + filename[0] + '_slice' + str(i) + '.png'
         plt.imsave(fname=data_dir, arr=array, format='png')

Define image transforms for training and validation

In this example, we will use Dictionary Transforms, where the data is a Python dictionary.

 import natsort

 train_images_coronal = []
 for file in sorted(glob(os.path.join(root_dir + '/TrainData/Coronal', '*.png'))):
     train_images_coronal.append(file)
 train_images_coronal = natsort.natsorted(train_images_coronal)

 train_labels_coronal = []
 for file in sorted(glob(os.path.join(root_dir + '/TrainLabels/Coronal', '*.png'))):
     train_labels_coronal.append(file)
 train_labels_coronal = natsort.natsorted(train_labels_coronal)

 val_images_coronal = []
 for file in sorted(glob(os.path.join(root_dir + '/ValData/Coronal', '*.png'))):
     val_images_coronal.append(file)
 val_images_coronal = natsort.natsorted(val_images_coronal)

 val_labels_coronal = []
 for file in sorted(glob(os.path.join(root_dir + '/ValLabels/Coronal', '*.png'))):
     val_labels_coronal.append(file)
 val_labels_coronal = natsort.natsorted(val_labels_coronal)

 train_files_coronal = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(train_images_coronal, train_labels_coronal)]
 val_files_coronal = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(val_images_coronal, val_labels_coronal)]

Now we will apply the following transforms:

LoadImaged: Load image data and metadata. We use 'PILReader' to load image and label files. ensure_channel_first is set to True to convert the image array shape to channel-first.

Rotate90d: We rotate the images and labels because they were not oriented correctly when downloaded (k = 2 applies two 90-degree rotations, i.e., a 180-degree rotation).

ToTensord: Converts input images and labels to tensors.

NormalizeIntensityd: Normalizes the input.

 from monai.data import PILReader
 from monai.transforms import Compose, LoadImaged, NormalizeIntensityd, Rotate90d, ToTensord

 train_transforms = Compose(
     [
         LoadImaged(keys=['image', 'label'], reader=PILReader(converter=lambda image: image.convert("L")), ensure_channel_first=True),
         Rotate90d(keys=['image', 'label'], k=2),
         ToTensord(keys=['image', 'label']),
         NormalizeIntensityd(keys=['image'])
     ]
 )

 val_transforms = Compose(
     [
         LoadImaged(keys=['image', 'label'], reader=PILReader(converter=lambda image: image.convert("L")), ensure_channel_first=True),
         Rotate90d(keys=['image', 'label'], k=2),
         ToTensord(keys=['image', 'label']),
         NormalizeIntensityd(keys=['image'])
     ]
 )

MaskColorMap lets us define a mapping from pixel values to class labels in a single format. This step is essential in semantic segmentation, since we have to provide a binary feature for each possible class: one-hot encoding assigns the value 1 to the feature corresponding to each sample's original category and 0 elsewhere.

The OASIS-1 dataset has only 3 labels for brain structures; a more detailed segmentation would ideally be annotated with 28 cortical structures, as in the research paper. Labels for more brain structures, obtained using FreeSurfer, can be found in the OASIS-1 download instructions.

To segment more neuroanatomical structures, you would modify the model's num_classes parameter to match the number of labels, so that the model's output is a feature map with N channels, where N equals num_classes.

To simplify this tutorial, we will use the following labels, far fewer than FreeSurfer produces:

  • Label 0: Background

  • Label 1: LeftCerebralExterior

  • Label 2: LeftWhiteMatter

  • Label 3: LeftCerebralCortex

So the code for MaskColorMap is as follows:

 from enum import Enum

 class MaskColorMap(Enum):
     Background = 30
     LeftCerebralExterior = 91
     LeftWhiteMatter = 137
     LeftCerebralCortex = 215
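As a quick illustration of how this enum is used later in the training loop, here is a minimal sketch that integer-encodes a hypothetical 2x2 label patch:

 import numpy as np

 # hypothetical 2x2 grayscale label patch containing the four mask values above
 label_patch = np.array([[30, 91], [137, 215]])
 encoded = np.zeros_like(label_patch)
 for j, cls in enumerate(MaskColorMap):
     encoded[label_patch == cls.value] = j
 print(encoded)  # [[0 1], [2 3]]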

Datasets and Data Loading

Datasets and data loaders fetch data from storage and feed it to the training loop in batches. Here we use monai.data.Dataset to wrap the previously defined training and validation dictionaries and apply the corresponding transforms to the input data, and DataLoader to load the dataset into memory in batches. We will define a dataset and data loader for training and validation for each view.

To speed up the demonstration, we use torch.utils.data.Subset to create a subset at specified indices and train on only part of the data.

 import torch
 from monai.data import DataLoader, Dataset

 train_dataset_coronal = Dataset(data=train_files_coronal, transform=train_transforms)
 train_loader_coronal = DataLoader(train_dataset_coronal, batch_size=1, shuffle=True)

 val_dataset_coronal = Dataset(data=val_files_coronal, transform=val_transforms)
 val_loader_coronal = DataLoader(val_dataset_coronal, batch_size=1, shuffle=False)

 # We will use a subset of the dataset
 subset_train = list(range(90, len(train_dataset_coronal), 120))
 train_dataset_coronal_subset = torch.utils.data.Subset(train_dataset_coronal, subset_train)
 train_loader_coronal_subset = DataLoader(train_dataset_coronal_subset, batch_size=1, shuffle=True)

 subset_val = list(range(90, len(val_dataset_coronal), 50))
 val_dataset_coronal_subset = torch.utils.data.Subset(val_dataset_coronal, subset_val)
 val_loader_coronal_subset = DataLoader(val_dataset_coronal_subset, batch_size=1, shuffle=False)

Define the model

Given a set of MRI brain scans I = {I1, ..., In} and their corresponding segmentations S = {S1, ..., Sn}, we want to learn a function f_seg: I -> S. We represent this function as a fully convolutional neural network (F-CNN) model, called QuickNAT:

QuickNAT consists of three 2D F-CNNs operating on the coronal, axial, and sagittal views respectively, followed by an aggregation step that infers the final segmentation from the probability maps of the three networks. Each F-CNN has an encoder/decoder architecture with 4 encoders and 4 decoders separated by a bottleneck layer. The last layer is a classifier block with softmax. The architecture also includes residual connections within each encoder/decoder block.

 import torch.nn as nn

 class QuickNat(nn.Module):
     """
     A PyTorch implementation of QuickNAT.
     """

     def __init__(self, params):
         """
         :param params: {'num_channels': 1,
                         'num_filters': 64,
                         'kernel_h': 5,
                         'kernel_w': 5,
                         'stride_conv': 1,
                         'pool': 2,
                         'stride_pool': 2,
                         'num_classes': 28,
                         'se_block': False,
                         'drop_out': 0.2}
         """
         super(QuickNat, self).__init__()

         # EncoderBlock, DecoderBlock, DenseBlock, ClassifierBlock and the se module
         # are provided by the QuickNAT building-block modules (see the full code).
         self.encode1 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
         params["num_channels"] = params["num_filters"]
         self.encode2 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
         self.encode3 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
         self.encode4 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
         self.bottleneck = DenseBlock(params, se_block_type=se.SELayer.CSSE)
         params["num_channels"] = params["num_filters"] * 2
         self.decode1 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
         self.decode2 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
         self.decode3 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
         self.decode4 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
         params["num_channels"] = params["num_filters"]
         self.classifier = ClassifierBlock(params)

     def forward(self, input):
         """
         :param input: X
         :return: probability map
         """
         # encoder path: each encoder returns its output, a skip connection
         # and the max-pooling indices
         e1, out1, ind1 = self.encode1.forward(input)
         e2, out2, ind2 = self.encode2.forward(e1)
         e3, out3, ind3 = self.encode3.forward(e2)
         e4, out4, ind4 = self.encode4.forward(e3)

         bn = self.bottleneck.forward(e4)

         # decoder path with skip connections and unpooling indices
         d4 = self.decode4.forward(bn, out4, ind4)
         d3 = self.decode1.forward(d4, out3, ind3)
         d2 = self.decode2.forward(d3, out2, ind2)
         d1 = self.decode3.forward(d2, out1, ind1)
         prob = self.classifier.forward(d1)

         return prob

     def enable_test_dropout(self):
         """
         Enables test-time dropout for uncertainty estimation.
         """
         attr_dict = self.__dict__["_modules"]
         for i in range(1, 5):
             encode_block, decode_block = (
                 attr_dict["encode" + str(i)],
                 attr_dict["decode" + str(i)],
             )
             encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
             decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)

     @property
     def is_cuda(self):
         """
         Check if model parameters are allocated on the GPU.
         """
         return next(self.parameters()).is_cuda

     def save(self, path):
         """
         Save the model with its parameters to the given path. Conventionally the
         path should end with '*.model'.

         Inputs:
         - path: path string
         """
         print("Saving model... %s" % path)
         torch.save(self.state_dict(), path)

     def predict(self, X, device=0, enable_dropout=False):
         """
         Predicts the output after the model is trained.
         Inputs:
         - X: volume to be predicted
         """
         self.eval()
         print("tensor size before transformation", X.shape)

         if type(X) is np.ndarray:
             X = (
                 torch.tensor(X, requires_grad=False)
                 .type(torch.FloatTensor)
                 .cuda(device, non_blocking=True)
             )
         elif type(X) is torch.Tensor and not X.is_cuda:
             X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)

         print("tensor size ", X.shape)

         if enable_dropout:
             self.enable_test_dropout()

         with torch.no_grad():
             out = self.forward(X)

         max_val, idx = torch.max(out, 1)
         idx = idx.data.cpu().numpy()
         prediction = np.squeeze(idx)
         print("prediction shape", prediction.shape)
         del X, out, idx, max_val
         return prediction
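The class above implements a single view-specific network; the final QuickNAT segmentation is obtained by aggregating the probability maps of the three views. The paper's exact aggregation is not reproduced here, but a minimal sketch, assuming three aligned probability volumes of shape (num_classes, H, W, D) and a simple unweighted average, could look like this:

 import torch

 def aggregate_views(prob_coronal, prob_axial, prob_sagittal):
     # average the class probabilities across the three views,
     # then pick the most likely class per voxel
     final_prob = (prob_coronal + prob_axial + prob_sagittal) / 3.0
     return torch.argmax(final_prob, dim=0)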

Loss function

The training of a neural network requires a loss function to calculate the model error. The goal of training is to minimize the loss between the predicted output and the target output. Our model is optimized using a joint loss function of Dice Loss and Weighted Logistic Loss, where the weights compensate for high class imbalances in the data and encourage correct segmentation of anatomical boundaries.
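The exact loss implementation is part of the full code. As a minimal sketch, MONAI's DiceLoss can be combined with PyTorch's CrossEntropyLoss standing in for the weighted logistic loss; the per-class weights below are hypothetical placeholders, not values from the paper:

 import torch
 import torch.nn as nn
 from monai.losses import DiceLoss

 dice_loss = DiceLoss(to_onehot_y=True, softmax=True)
 # hypothetical per-class weights compensating for class imbalance
 ce_loss = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 1.0, 1.0, 1.0]))

 def loss_function(outputs, labels):
     # outputs: (N, C, H, W) logits; labels: (N, H, W) integer class indices
     return dice_loss(outputs, labels.unsqueeze(1)) + ce_loss(outputs, labels)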

Optimizer

The optimization algorithm iteratively updates the parameters of the model to minimize the value of the loss function. We set the following hyperparameters:

Learning rate: initially set to 0.1 and reduced by an order of magnitude every 10 epochs. This can be achieved with a learning rate scheduler.

Weight Decay: 0.0001.

Batch size: 1.

Momentum: Set to a high value of 0.95 to compensate for noisy gradients due to small batch sizes.
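Putting these hyperparameters together, one possible reading of the schedule above is SGD with a StepLR scheduler (a sketch; model_coronal is the QuickNat instance created in the full code):

 import torch.optim as optim

 optimizer = optim.SGD(model_coronal.parameters(), lr=0.1, momentum=0.95, weight_decay=0.0001)
 # reduce the learning rate by an order of magnitude every 10 epochs
 scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)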

Training the network

The model can now be trained. For QuickNAT, three models need to be trained on 2D slices from the three views (coronal, axial, sagittal), and the probability maps of the three models are then combined in an aggregation step to generate the final result. Here we only demonstrate training one F-CNN model on 2D slices of the coronal view, since the other two views are handled in the same way.

 # CHECKPOINT_DIR, CHECKPOINT_EXTENSION, BESTMODEL_DIR, save_checkpoint,
 # model_coronal, dice_metric and loss_function are defined in the full code.
 num_epochs = 20
 start_epoch = 1

 val_interval = 1

 train_loss_epoch_values = []
 val_loss_epoch_values = []

 best_ds_mean = -1
 best_ds_mean_epoch = -1

 ds_mean_train_values = []
 ds_mean_val_values = []
 # ds_LCE_values = []
 # ds_LWM_values = []
 # ds_LCC_values = []

 print("START TRAINING. : model name = ", "quicknat")

 for epoch in range(start_epoch, num_epochs):
     print("==== Epoch [" + str(epoch) + " / " + str(num_epochs) + "] DONE ====")

     checkpoint_name = CHECKPOINT_DIR + "/checkpoint_epoch_" + str(epoch) + "." + CHECKPOINT_EXTENSION
     print(checkpoint_name)
     state = {
         "epoch": epoch,
         "arch": "quicknat",
         "state_dict": model_coronal.state_dict(),
         "optimizer": optimizer.state_dict(),
         "scheduler": scheduler.state_dict(),
     }
     save_checkpoint(state=state, filename=checkpoint_name)

     print("\n==== Epoch [ %d / %d ] START ====" % (epoch, num_epochs))

     steps_per_epoch = len(train_dataset_coronal_subset) / train_loader_coronal_subset.batch_size

     model_coronal.train()
     train_loss_epoch = 0
     val_loss_epoch = 0
     step = 0

     predictions_train = []
     labels_train = []

     predictions_val = []
     labels_val = []

     for i_batch, sample_batched in enumerate(train_loader_coronal_subset):
         inputs = sample_batched['image'].type(torch.FloatTensor)
         labels = sample_batched['label'].type(torch.LongTensor)

         # print(f"Train Input Shape: {inputs.shape}")

         # map grayscale label values to class indices (integer encoding)
         labels = labels.squeeze(1)
         _img_channels, _img_height, _img_width = labels.shape
         encoded_label = np.zeros((_img_height, _img_width, 1)).astype(int)

         for j, cls in enumerate(MaskColorMap):
             encoded_label[np.all(labels == cls.value, axis=0)] = j

         labels = encoded_label
         labels = torch.from_numpy(labels)
         labels = torch.permute(labels, (2, 1, 0))

         # print(f"Train Label Shape: {labels.shape}")
         # plt.title("Train Label")
         # plt.imshow(labels[0, :, :])
         # plt.show()

         optimizer.zero_grad()
         outputs = model_coronal(inputs)
         loss = loss_function(outputs, labels)

         loss.backward()
         optimizer.step()
         scheduler.step()

         with torch.no_grad():
             _, batch_output = torch.max(outputs, dim=1)
             # print(f"Train Prediction Shape: {batch_output.shape}")
             # plt.title("Train Prediction")
             # plt.imshow(batch_output[0, :, :])
             # plt.show()

             predictions_train.append(batch_output.cpu())
             labels_train.append(labels.cpu())
             train_loss_epoch += loss.item()
             print(f"{step}/{len(train_dataset_coronal_subset) // train_loader_coronal_subset.batch_size}, Training_loss: {loss.item():.4f}")
             step += 1

     predictions_train_arr, labels_train_arr = torch.cat(predictions_train), torch.cat(labels_train)

     # print(predictions_train_arr.shape)

     dice_metric(predictions_train_arr, labels_train_arr)

     ds_mean_train = dice_metric.aggregate().item()
     ds_mean_train_values.append(ds_mean_train)
     dice_metric.reset()

     train_loss_epoch /= step
     train_loss_epoch_values.append(train_loss_epoch)
     print(f"Epoch {epoch + 1} Train Average Loss: {train_loss_epoch:.4f}")

     if (epoch + 1) % val_interval == 0:

         model_coronal.eval()
         step = 0

         with torch.no_grad():

             for i_batch, sample_batched in enumerate(val_loader_coronal_subset):
                 inputs = sample_batched['image'].type(torch.FloatTensor)
                 labels = sample_batched['label'].type(torch.LongTensor)

                 # print(f"Val Input Shape: {inputs.shape}")

                 labels = labels.squeeze(1)
                 integer_encoded_labels = []
                 _img_channels, _img_height, _img_width = labels.shape
                 encoded_label = np.zeros((_img_height, _img_width, 1)).astype(int)

                 for j, cls in enumerate(MaskColorMap):
                     encoded_label[np.all(labels == cls.value, axis=0)] = j

                 labels = encoded_label
                 labels = torch.from_numpy(labels)
                 labels = torch.permute(labels, (2, 1, 0))

                 # print(f"Val Label Shape: {labels.shape}")
                 # plt.title("Val Label")
                 # plt.imshow(labels[0, :, :])
                 # plt.show()

                 val_outputs = model_coronal(inputs)

                 val_loss = loss_function(val_outputs, labels)

                 predicted = torch.argmax(val_outputs, dim=1)

                 # print(f"Val Prediction Shape: {predicted.shape}")
                 # plt.title("Val Prediction")
                 # plt.imshow(predicted[0, :, :])
                 # plt.show()

                 predictions_val.append(predicted)
                 labels_val.append(labels)

                 val_loss_epoch += val_loss.item()
                 print(f"{step}/{len(val_dataset_coronal_subset) // val_loader_coronal_subset.batch_size}, Validation_loss: {val_loss.item():.4f}")
                 step += 1

             predictions_val_arr, labels_val_arr = torch.cat(predictions_val), torch.cat(labels_val)

             dice_metric(predictions_val_arr, labels_val_arr)
             # dice_metric_batch(predictions_val_arr, labels_val_arr)

             ds_mean_val = dice_metric.aggregate().item()
             ds_mean_val_values.append(ds_mean_val)
             # ds_mean_val_batch = dice_metric_batch.aggregate()
             # ds_LCE = ds_mean_val_batch[0].item()
             # ds_LCE_values.append(ds_LCE)
             # ds_LWM = ds_mean_val_batch[1].item()
             # ds_LWM_values.append(ds_LWM)
             # ds_LCC = ds_mean_val_batch[2].item()
             # ds_LCC_values.append(ds_LCC)

             dice_metric.reset()
             # dice_metric_batch.reset()

             if ds_mean_val > best_ds_mean:
                 best_ds_mean = ds_mean_val
                 best_ds_mean_epoch = epoch + 1
                 torch.save(model_coronal.state_dict(), os.path.join(BESTMODEL_DIR, "best_metric_model_coronal.pth"))
                 print("Saved new best metric model coronal")

             print(
                 f"Current Epoch: {epoch + 1} Current Mean Dice score is: {ds_mean_val:.4f}"
                 f"\nBest Mean Dice score: {best_ds_mean:.4f} "
                 f"at Epoch: {best_ds_mean_epoch}"
             )

         val_loss_epoch /= step
         val_loss_epoch_values.append(val_loss_epoch)
         print(f"Epoch {epoch + 1} Average Validation Loss: {val_loss_epoch:.4f}")

 print("FINISH.")

The code above is a fairly standard PyTorch training loop, so we won't explain it in detail.

Plot loss and accuracy curves

The training curve shows how well the model learns, and the validation curve shows how well it generalizes to unseen instances. We use Matplotlib to draw the graphs. You can also use TensorBoard, which makes it easier to understand and debug deep learning programs in real time; a logging sketch follows the plotting code below.

 # use the number of recorded epochs so the x-axis matches the recorded values
 epoch = range(1, len(train_loss_epoch_values) + 1)

 # Plot Loss Curves
 plt.figure(figsize=(18, 6))
 plt.subplot(1, 3, 1)
 plt.plot(epoch, train_loss_epoch_values, label='Training Loss')
 plt.plot(epoch, val_loss_epoch_values, label='Validation Loss')
 plt.title('Training and Validation Loss')
 plt.xlabel('Epoch')
 plt.legend()
 plt.show()

 # Plot Train Dice Coefficient Curve
 plt.figure(figsize=(18, 6))
 plt.subplot(1, 3, 2)
 x = [(i + 1) for i in range(len(ds_mean_train_values))]
 plt.plot(x, ds_mean_train_values, 'blue', label='Train Mean Dice Score')
 plt.title("Training Mean Dice Coefficient")
 plt.xlabel('Epoch')
 plt.ylabel('Mean Dice Score')
 plt.show()

 # Plot Validation Dice Coefficient Curve
 plt.figure(figsize=(18, 6))
 plt.subplot(1, 3, 3)
 x = [(i + 1) for i in range(len(ds_mean_val_values))]
 plt.plot(x, ds_mean_val_values, 'orange', label='Validation Mean Dice Score')
 plt.title("Validation Mean Dice Coefficient")
 plt.xlabel('Epoch')
 plt.ylabel('Mean Dice Score')
 plt.show()
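As mentioned above, the same curves can also be logged to TensorBoard for real-time inspection. A minimal sketch using the lists recorded in the training loop (the log directory name is arbitrary):

 from torch.utils.tensorboard import SummaryWriter

 writer = SummaryWriter(log_dir="runs/quicknat_coronal")
 for e, (tl, vl) in enumerate(zip(train_loss_epoch_values, val_loss_epoch_values), start=1):
     writer.add_scalar("Loss/train", tl, e)
     writer.add_scalar("Loss/validation", vl, e)
 writer.close()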

In the curves, we can see that the model is overfitting: the validation loss goes up while the training loss goes down. This is a common pitfall in deep learning, where the model ends up memorizing the training data and is unable to generalize to unseen data.

Tips to avoid overfitting:

  • Train with more data: Larger datasets reduce overfitting.

  • Data augmentation: If we cannot collect more data, we can apply data augmentation to artificially increase the size of the dataset (see the sketch after this list).

  • Adding regularization: Regularization is a technique that restricts the network from learning overly complex models, which may otherwise overfit.
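As an example of the data augmentation tip, a few random MONAI dictionary transforms could be appended to train_transforms. This is a sketch; the specific transforms and parameters are illustrative, not taken from the full code:

 from monai.transforms import RandFlipd, RandGaussianNoised, RandRotated

 augmentations = [
     # small random rotation; nearest-neighbour interpolation keeps label values integer
     RandRotated(keys=['image', 'label'], range_x=0.1, prob=0.5, mode=['bilinear', 'nearest']),
     # random flip applied identically to image and label
     RandFlipd(keys=['image', 'label'], spatial_axis=0, prob=0.5),
     # mild intensity noise on the image only
     RandGaussianNoised(keys=['image'], std=0.01, prob=0.3),
 ]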

Evaluating the network

How do we measure model performance? A successful prediction is one that maximizes the overlap between the prediction and the ground truth.

Two related but distinct metrics serve this goal: the Dice coefficient and Intersection over Union (IoU), the latter also known as the Jaccard index. Both metrics range from 0 (no overlap) to 1 (complete overlap).
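Concretely, for a predicted mask A and a ground-truth mask B:

 Dice(A, B) = 2|A ∩ B| / (|A| + |B|)

 IoU(A, B) = |A ∩ B| / |A ∪ B|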

Both metrics can be used in similar situations, but the difference is that Dice Score tends towards average performance, while IoU helps you understand worst-case performance. 

We can compute the metric class by class, or average it across all classes. Here we use monai.metrics.DiceMetric to calculate the score. A more general approach would be torchmetrics, but since the MONAI framework is used here, we rely on its built-in metric directly.
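A sketch of how the dice_metric object used in the training loop can be instantiated (the exact arguments in the full code may differ; the "mean_batch" reduction yields per-class scores):

 from monai.metrics import DiceMetric

 # mean Dice across classes, plus a per-class variant
 dice_metric = DiceMetric(include_background=False, reduction="mean")
 dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch")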

The behavior of the Dice score curve is rather unusual: the validation mean Dice score climbs above 1, which is impossible since the metric is bounded between 0 and 1. We were unable to identify the main cause of this behavior, but we recommend computing separate metrics for each class in multiclass problems and always providing visual examples for qualitative evaluation.

Result analysis

Finally, let's see how the model generalizes to unseen data. Almost everything the model predicts is left cerebral white matter, with some pixels of left cerebral cortex. Although the predictions look plausible, there is still plenty of room for improvement: our model is small, and a deeper model could be chosen to achieve better results.

 

Summary

In this article, we described how to train QuickNAT for the challenging task of brain segmentation. We followed, as closely as possible, the learning strategy the authors explain in their research paper; for convenience, this tutorial demonstrates only the simplest steps. The full code is available at: https://github.com/inesdv26/Brain-Segmentation

