Plug and play series! | MedAugment: Automatic data augmentation plugin for image classification and segmentation (with PyTorch source code)

Overview

Today I introduce an automatic data augmentation method called MedAugment, which brings automatic data augmentation into the field of medical image analysis. The method divides the augmentation space into two types:

  • Pixel augmentation space
  • Spatial augmentation space

This split accounts for the differences between natural images and medical images. The paper also proposes a novel strategy for sampling augmentation operations from these spaces. To demonstrate the performance and generalization ability of MedAugment, the authors conduct extensive experiments on four classification datasets and three segmentation datasets, and show that MedAugment outperforms most state-of-the-art data augmentation methods, including mainstream automatic data augmentation methods such as AutoAugment and RandAugment.

Although MedAugment was designed for medical images, it can be adapted to natural images.

Motivation

Why not simply use existing automatic data augmentation methods? According to the authors, these methods were designed for natural images and cannot be applied directly to medical image analysis. In addition, most of them were designed for image classification, whereas image segmentation is a core task in medical image analysis. There is therefore no general and powerful automatic data augmentation method for medical image analysis, and the paper proposes MedAugment to fill this gap.

Method

Introduction to the overall framework

The principle of MedAugment is simple; its framework is shown in Figure 1. The authors design two augmentation spaces, Ap and As, containing six and eight augmentation operations respectively, for 14 operations in total (you can freely adjust them to suit the characteristics of your own dataset).

To better suit medical image analysis, the authors also develop a novel operation sampling strategy. MedAugment consists of N augmentation branches plus a separate branch that preserves the original image. Each augmentation branch applies M ∈ {2, 3} augmentation operations in sequence. A single hyperparameter, the augmentation level l (l = 5 by default), precisely controls how strongly MedAugment augments the data.
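To make the branch structure concrete, here is a minimal sketch (pure Python, no image I/O) of how one input file fans out into N augmented copies plus the original from the separate branch. The helper name output_names is hypothetical; the {stem}_{i}.{ext} naming mirrors the file names written by the source code below.

```python
def output_names(file_name: str, n_branches: int = 4, keep_original: bool = True) -> list[str]:
    """Return the file names produced for one input image.

    Branches 1..N hold augmented copies; name N+1 is the separate branch that
    stores the unmodified original (skipped when keep_original is False).
    """
    stem, ext = file_name.rsplit(".", 1)
    names = [f"{stem}_{i}.{ext}" for i in range(1, n_branches + 1)]
    if keep_original:
        names.append(f"{stem}_{n_branches + 1}.{ext}")
    return names


print(output_names("case01.png"))
# ['case01_1.png', 'case01_2.png', 'case01_3.png', 'case01_4.png', 'case01_5.png']
```

With the default N = 4, a training set of K images therefore grows to 5K images.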

In addition, the authors design a new mapping so that l controls not only the maximum magnitude of each augmentation operation but also its trigger probability. Overall, MedAugment introduces randomness in three places: the sampling strategy, the operation types and their order, and the trigger probability. For each branch, MedAugment first samples augmentation operations from the augmentation spaces using the sampling strategy, then shuffles the sampled operations and executes them in sequence. The magnitude range and trigger probability of each operation are determined by the hyperparameter l; for each image, the magnitude of an operation is sampled uniformly within its maximum magnitude range. The full procedure is given in Algorithm 1.
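The per-branch procedure can be sketched as follows. This is a simplified illustration, not the authors' Algorithm 1: it assumes the operations have already been sampled, uses toy operations on a 1-D list of intensities in place of the Albumentations transforms, and draws each magnitude from [0, max_per_level · l]; the per-operation max_per_level values are assumptions.

```python
import random

# Toy stand-in operations on a 1-D "image" (a list of intensities in [0, 1]).
def brightness(img, mag):
    return [min(1.0, v + mag) for v in img]

def contrast(img, mag):
    mean = sum(img) / len(img)
    return [min(1.0, max(0.0, mean + (v - mean) * (1 + mag))) for v in img]

def hflip(img, mag):
    return img[::-1]  # no magnitude: the argument is ignored

def run_branch(img, ops, level, rng):
    """Run one branch: shuffle the sampled ops, then apply them in order.

    Each op fires with probability 0.2 * level; when it fires, its magnitude
    is drawn uniformly from [0, max_per_level * level].
    """
    ops = list(ops)
    rng.shuffle(ops)                 # randomness: operation order
    p = 0.2 * level                  # randomness: trigger probability tied to l
    applied = []
    for name, fn, max_per_level in ops:
        if rng.random() < p:
            mag = rng.uniform(0.0, max_per_level * level)  # randomness: magnitude
            img = fn(img, mag)
            applied.append(name)
    return img, applied

rng = random.Random(0)
ops = [("brightness", brightness, 0.04), ("contrast", contrast, 0.04), ("hflip", hflip, 0.0)]
out, applied = run_branch([0.1, 0.5, 0.9], ops, level=5, rng=rng)
print(applied)  # at l = 5 the probability is 1.0, so all three ops fire, in shuffled order
```

Lowering the level both shrinks the magnitude range and makes each operation fire less often; at l = 0 nothing is applied.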

Augmentation space

To suit the characteristics of medical image analysis (MIA), the authors choose the augmentation operations carefully. They start from common data augmentation operations and filter them according to the characteristics of the MIA domain, excluding operations that are unsuitable for medical images, such as inversion and equalization, which may destroy fine details and features. The remaining operations are divided into pixel-level and spatial-level operations, forming two augmentation spaces: the pixel augmentation space Ap and the spatial augmentation space As. Operations in Ap are not applied to masks. [The source code implements these operations with the Albumentations package; you can also design your own according to the characteristics of your task.]
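The split between the two spaces, and why it matters for segmentation masks, can be sketched in plain Python. The operation labels below are illustrative names for the Albumentations transforms used in the source code (ColorJitter brightness/contrast, Posterize, Sharpen, GaussianBlur, GaussNoise; Rotate, the two flips, and Affine scale/shear/translate); they are not identifiers from the source. The functions targets and apply_op are hypothetical helpers.

```python
PIXEL_SPACE = ["brightness", "contrast", "posterize", "sharpness",
               "gaussian_blur", "gaussian_noise"]                      # Ap: 6 ops
SPATIAL_SPACE = ["rotate", "hflip", "vflip", "scale",
                 "shear_x", "shear_y", "translate_x", "translate_y"]    # As: 8 ops

def targets(op_name):
    """Which arrays an operation modifies in an (image, mask) pair."""
    if op_name in PIXEL_SPACE:
        return ("image",)          # intensity change only: the mask stays as-is
    if op_name in SPATIAL_SPACE:
        return ("image", "mask")   # geometry change: the mask must move with the image
    raise ValueError(f"unknown operation: {op_name}")

def apply_op(op_name, fn, image, mask):
    """Apply fn to the image, and to the mask only for spatial operations."""
    new_mask = fn(mask) if "mask" in targets(op_name) else mask
    return fn(image), new_mask

# 1-D toy example: flipping moves the mask, brightening leaves it untouched.
image, mask = [10, 20, 30], [0, 0, 1]
print(apply_op("hflip", lambda a: a[::-1], image, mask))                  # ([30, 20, 10], [1, 0, 0])
print(apply_op("brightness", lambda a: [v + 5 for v in a], image, mask))  # ([15, 25, 35], [0, 0, 1])
```

In Albumentations this distinction is built in: the pixel-level transforms are ImageOnlyTransform subclasses, while Rotate, the flips and Affine are DualTransform subclasses that also transform the mask passed to Compose.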

Sampling strategy

Medical images are very sensitive to properties such as brightness, and the authors observe that consecutive operations from Ap can produce unrealistic medical images. The paper therefore designs a novel strategy for sampling operations from Ap and As.

Specifically, M augmentation operations are sampled randomly for each branch, with at most one of them drawn from Ap. The range of M is a trade-off: chaining more operations may further improve the model's generalization, but chaining too many may produce images that are far from the original.

The upper limit of M is therefore set to 3. Because MedAugment applies its branches in parallel, setting M to 1 makes little sense: each branch would degenerate into a single operation with no combined effect. Based on these considerations, the authors choose M ∈ {2, 3}. With M ∈ {2, 3} and at most one operation from Ap, there are four possible combinations (operations from Ap + operations from As): 1+2, 0+3, 1+1 and 0+2. This is why the number of branches is N = 4.
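The count of four combinations follows directly from the two constraints. This small sketch (the function name sampling_combinations is hypothetical) enumerates every (Ap, As) split whose total is in M and that uses at most one pixel-level operation:

```python
def sampling_combinations(m_values=(2, 3), max_pixel_ops=1):
    """Enumerate (ops from Ap, ops from As) pairs.

    Each branch runs M operations in total (M taken from m_values), and at
    most max_pixel_ops of them may come from the pixel space Ap.
    """
    combos = []
    for m in m_values:
        for n_pixel in range(min(max_pixel_ops, m) + 1):
            combos.append((n_pixel, m - n_pixel))
    return combos


print(sampling_combinations())       # [(0, 2), (1, 1), (0, 3), (1, 2)]
print(len(sampling_combinations()))  # 4, hence N = 4 branches
```

Allowing M = 1 as well would add the splits (0, 1) and (1, 0), i.e. single-operation branches, which the authors rule out.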

For flexibility, N can also be set to other values, in which case the combinations are sampled with replacement. The separate branch that keeps the original image can also be masked (the shield argument in the source code); with N = 1 and that branch masked, MedAugment performs one-to-one data augmentation. The comparison in Figure 2 shows that MedAugment produces images better suited to medical imaging, whereas existing methods can generate unrealistic augmented images. In the worst case, some augmented images are nonsensical because they contain too much "noise" or too little useful information. Deep learning models may still recognize such images correctly, but they are meaningless from a medical point of view.
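The branch-assignment rule can be sketched as below. It follows the logic of the source code (each of the four combinations used once when N = 4, independent draws with replacement otherwise), written with random.sample instead of the source's repeated pop; assign_combinations and outputs_per_image are hypothetical names.

```python
import random

# (ops from Ap, ops from As) combinations, in the order used by the source code.
STRATEGY = [(1, 2), (0, 3), (0, 2), (1, 1)]

def assign_combinations(n_branches, rng):
    """Pick one combination per augmentation branch.

    N == 4: each combination is used exactly once (sampling without replacement).
    Any other N: combinations are drawn independently (sampling with replacement).
    """
    if n_branches == 4:
        return rng.sample(STRATEGY, k=4)
    return [rng.choice(STRATEGY) for _ in range(n_branches)]

def outputs_per_image(n_branches, shield=False):
    """Augmented copies plus, unless shielded, the separate branch holding the original."""
    return n_branches + (0 if shield else 1)

rng = random.Random(8)
print(assign_combinations(4, rng))        # a permutation of STRATEGY
print(assign_combinations(6, rng))        # six draws, repeats allowed
print(outputs_per_image(1, shield=True))  # 1: one-to-one augmentation
```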

Hyperparameter mapping

So that MedAugment can be controlled by a single hyperparameter l, the authors design a novel mapping that determines the maximum magnitude MA and the probability PA of each operation. The mapping is set operation by operation so that MA suits medical images. For operations such as Posterize, medical images are particularly sensitive to the magnitude: as the number of remaining bits decreases, the quality of the augmented image degrades rapidly. The magnitudes of such operations were therefore tuned through extensive experiments so that the augmented images remain relevant.
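To see why Posterize needs care, this sketch computes, for each level l, the number of bits kept by the source code's formula floor(8 - 0.8·l) and how many distinct gray levels an 8-bit image then retains. The bit mask is the usual PIL-style posterize (keep only the most significant bits); posterize_bits and posterize are hypothetical helper names.

```python
import math

def posterize_bits(level):
    """Bits kept by Posterize in the source code: floor(8 - 0.8 * level)."""
    return math.floor(8 - 0.8 * level)

def posterize(value, bits):
    """Keep only the `bits` most significant bits of an 8-bit intensity."""
    mask = (0xFF << (8 - bits)) & 0xFF
    return value & mask

for level in range(1, 6):
    bits = posterize_bits(level)
    distinct = len({posterize(v, bits) for v in range(256)})
    print(f"l={level}: {bits} bits -> {distinct} gray levels")
# l=1: 7 bits -> 128 gray levels
# l=2: 6 bits -> 64 gray levels
# l=3: 5 bits -> 32 gray levels
# l=4: 4 bits -> 16 gray levels
# l=5: 4 bits -> 16 gray levels
```

Each bit removed halves the number of gray levels, which is why the magnitude of this operation is kept conservative; note that with this formula l = 4 and l = 5 both keep 4 bits.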

Given l, the mapping between l and MA for each operation in Ap and As is listed in Table 1, where operations without a magnitude are marked "-". The level is set to l = 5 by default and can take any value in {1, 2, 3, 4, 5}; the larger l is, the stronger the augmentation. For the probability, all operations in Ap and As follow the same rule, PA = 0.2l, so the probability of applying an operation rises with the augmentation level (at l = 5, PA = 1, i.e. every sampled operation is applied).
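The values below are read off the Albumentations arguments in the source code; they are assumed, not verified, to match Table 1, and the dictionary keys are illustrative labels. The flips have no magnitude (the "-" entries) and are governed only by PA.

```python
import math

def make_odd(num):
    """Round up, then bump to the next odd integer (blur kernels must be odd)."""
    num = math.ceil(num)
    return num + 1 if num % 2 == 0 else num

def mapping(level):
    """Maximum magnitude (MA) per operation and the shared probability PA for a level l."""
    ma = {
        "brightness": 0.04 * level,                        # ColorJitter brightness
        "contrast": 0.04 * level,                          # ColorJitter contrast
        "posterize_bits": math.floor(8 - 0.8 * level),     # bits kept (fewer = stronger)
        "sharpen_alpha_max": 0.1 * level,                  # Sharpen alpha upper bound
        "blur_kernel_max": make_odd(3 + 0.8 * level),      # GaussianBlur kernel upper bound
        "noise_var_max": 10 * level,                       # GaussNoise variance upper bound
        "rotate_deg": 4 * level,                           # Rotate limit
        "scale": 0.04 * level,                             # Affine scale in [1 - s, 1 + s]
        "shear_deg": 2 * level,                            # Affine shear x / y
        "translate_frac": 0.02 * level,                    # Affine translate x / y
    }
    pa = 0.2 * level
    return ma, pa

for level in range(1, 6):
    ma, pa = mapping(level)
    print(level, round(pa, 2), ma["posterize_bits"], ma["blur_kernel_max"], ma["rotate_deg"])
```

A single call to mapping(level) yields every operation's setting, which is what lets one hyperparameter drive the whole pipeline.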


Source code implementation

import albumentations as A
import torch
import math
import random
import os
import cv2
import shutil
import numpy as np
import argparse


def make_odd(num):
    # Round up and bump to the next odd integer: Gaussian blur kernel sizes must be odd.
    num = math.ceil(num)
    if num % 2 == 0:
        num += 1
    return num


def med_augment(data_path, name, level, number_branch, mask_i=False, shield=False):
    # Resolve input/output folders; for segmentation (mask_i=True) masks are read and written too.
    if mask_i:
        image_path = f"{data_path}{name}"
        mask_path = f"{image_path}_mask"
        root = os.path.dirname(os.path.dirname(data_path))
        output_path = f"{root}/medaugment/{name}"
        out_mask = f"{root}/medaugment/{name}_mask"
    else:
        image_path = data_path + name
        root = os.path.dirname(os.path.dirname(os.path.dirname(data_path)))
        output_path = f"{root}/medaugment/training/{name}"

    # Every magnitude and every trigger probability (p = 0.2 * level) is derived from `level`.
    # Written against the Albumentations 1.x API; some arguments used here
    # (e.g. var_limit, value, cval, mode) were renamed in Albumentations 2.x.
    p = 0.2 * level
    # Pixel augmentation space Ap (6 ops): image-only transforms, masks are left untouched.
    pixel_space = [
        A.ColorJitter(brightness=0.04 * level, contrast=0, saturation=0, hue=0, p=p),  # brightness
        A.ColorJitter(brightness=0, contrast=0.04 * level, saturation=0, hue=0, p=p),  # contrast
        A.Posterize(num_bits=math.floor(8 - 0.8 * level), p=p),
        A.Sharpen(alpha=(0.04 * level, 0.1 * level), lightness=(1, 1), p=p),
        A.GaussianBlur(blur_limit=(3, make_odd(3 + 0.8 * level)), p=p),
        A.GaussNoise(var_limit=(2 * level, 10 * level), mean=0, per_channel=True, p=p),
    ]
    # Spatial augmentation space As (8 ops): applied identically to image and mask.
    affine_kwargs = dict(interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0,
                         fit_output=False, keep_ratio=True, p=p)
    spatial_space = [
        A.Rotate(limit=4 * level, interpolation=1, border_mode=0, value=0, mask_value=None,
                 rotate_method='largest_box', crop_border=False, p=p),
        A.HorizontalFlip(p=p),
        A.VerticalFlip(p=p),
        A.Affine(scale=(1 - 0.04 * level, 1 + 0.04 * level), **affine_kwargs),               # scale
        A.Affine(shear={'x': (0, 2 * level), 'y': (0, 0)}, **affine_kwargs),                 # shear x
        A.Affine(shear={'x': (0, 0), 'y': (0, 2 * level)}, **affine_kwargs),                 # shear y
        A.Affine(translate_percent={'x': (0, 0.02 * level), 'y': (0, 0)}, **affine_kwargs),  # translate x
        A.Affine(translate_percent={'x': (0, 0), 'y': (0, 0.02 * level)}, **affine_kwargs),  # translate y
    ]

    os.makedirs(output_path, exist_ok=True)
    if mask_i:
        os.makedirs(out_mask, exist_ok=True)

    for file_name in os.listdir(image_path):
        if not file_name.endswith((".png", ".jpg")):
            continue
        file_n, file_s = os.path.splitext(file_name)
        file_s = file_s[1:]  # extension without the leading dot
        image = cv2.imread(os.path.join(image_path, file_name))
        if mask_i:
            mask = cv2.imread(f"{mask_path}/{file_n}_mask.{file_s}")

        # (ops from Ap, ops from As) for each branch: 1+2, 0+3, 0+2, 1+1.
        strategy = [(1, 2), (0, 3), (0, 2), (1, 1)]
        for i in range(number_branch):
            if number_branch != 4:
                employ = random.choice(strategy)  # sample combinations with replacement
            else:
                employ = strategy.pop(random.randrange(len(strategy)))  # each combination used once
            ops = random.sample(pixel_space, employ[0]) + random.sample(spatial_space, employ[1])
            random.shuffle(ops)  # random execution order
            img_transform = A.Compose(ops)
            if mask_i:
                transformed = img_transform(image=image, mask=mask)
                cv2.imwrite(f"{output_path}/{file_n}_{i + 1}.{file_s}", transformed['image'])
                cv2.imwrite(f"{out_mask}/{file_n}_{i + 1}_mask.{file_s}", transformed['mask'])
            else:
                transformed = img_transform(image=image)
                cv2.imwrite(f"{output_path}/{file_n}_{i + 1}.{file_s}", transformed['image'])

        # Separate branch: keep the original image (and mask) once per file, unless shielded.
        if not shield:
            cv2.imwrite(f"{output_path}/{file_n}_{number_branch + 1}.{file_s}", image)
            if mask_i:
                cv2.imwrite(f"{out_mask}/{file_n}_{number_branch + 1}_mask.{file_s}", mask)


def generate_datasets(train_type, dataset, seed, level, number_branch):
    # Seed all random number generators so the augmented dataset is reproducible.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    if train_type == "classification":
        print('Executing data augmentation for image classification...')
        folder_path = f"./datasets/classification/{dataset}/"
        data_path = f"{folder_path}baseline/training/"
        # Class sub-folders are expected to be named n0, n1, ..., n{n-1}.
        n = len([name for name in os.listdir(data_path)
                 if os.path.isdir(os.path.join(data_path, name))])

        # Copy everything except the training split, then recreate empty class folders for the output.
        shutil.copytree(f"{folder_path}baseline", f"{folder_path}medaugment",
                        ignore=shutil.ignore_patterns("training"))
        for i in range(n):
            os.makedirs(f"{folder_path}medaugment/training/n{i}")

        for i in range(n):
            med_augment(data_path, f"n{i}", level, number_branch)
    else:
        print('Executing data augmentation for image segmentation...')
        folder_path = f"./datasets/segmentation/{dataset}/"
        data_path = f"{folder_path}baseline/"

        # Copy everything except the training images and masks, then create empty output folders.
        shutil.copytree(f"{folder_path}baseline", f"{folder_path}medaugment",
                        ignore=shutil.ignore_patterns("training", "training_mask"))
        os.makedirs(f"{folder_path}medaugment/training")
        os.makedirs(f"{folder_path}medaugment/training_mask")

        med_augment(data_path, "training", level, number_branch, mask_i=True)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
    group = parser.add_argument_group()
    group.add_argument('--dataset', required=True)
    group.add_argument('--train_type', choices=['classification', 'segmentation'], default='classification')
    group.add_argument('--level', help='Augmentation level', default=5, type=int, metavar='INT')
    group.add_argument('--number_branch', help='Number of branch', default=4, type=int, metavar='INT')
    group.add_argument('--seed', help='Seed', default=8, type=int, metavar='INT')
    args = parser.parse_args()
    generate_datasets(**vars(args))


if __name__ == '__main__':
    main()

Summary

Today I introduced MedAugment, a plug-and-play automatic data augmentation method that brings automatic data augmentation into medical image analysis and significantly reduces the reliance on experience when designing data augmentation. Through its augmentation spaces and sampling strategy, MedAugment addresses the differences between natural and medical images.

However, MedAugment and other state-of-the-art methods still have trouble balancing different evaluation metrics; for example, sensitivity and some other metrics may remain low. How to balance these metrics deserves further study. For example, multiple hyperparameters could be introduced and evaluated, and different metrics balanced by updating these hyperparameters during training. Small targets also need further study, which raises the question of how to emphasize the target to be segmented. Future work could apply different types and levels of augmentation depending on target size; for example, images with small objects could be magnified more often, and by a larger factor, than images with large objects.


Source: blog.csdn.net/CVHub/article/details/131625515