即插即用系列!| MedAugment: 用于图像分类和分割的自动数据增强插件(附Pytorch源码)

导读

今天主要向大家介绍一种名为 MedAugment 的自动数据增强方法,旨在将自动数据增强技术引入医学图像分析领域。该方法通过将增强空间划分为两种:

  • 像素增强空间
  • 空间增强空间

从而解决了自然图像和医学图像之间的差异。论文提出了一种新颖的操作采样策略,用于从增强空间中采样数据增强操作。为了证明 MedAugment 的性能和泛化性能,作者在四个分类数据集和三个分割数据集上进行了广泛的实验,并表明 MedAugment 优于大多数最先进的数据增强方法,例如比较主流的 AutoAugmentRandAugment 等自动数据增强方法。

虽然说是用于医学图像领域,但完全可以稍作修改照搬到自然图像即可。

动机

为什么不直接使用现在的自动数据增强方法?作者认为,这些方法最初设计用于自然图像,并不能直接应用于医学图像分析领域。此外,大部分自动数据增强方法最初也是针对图像分类任务设计的,而在医学图像分析领域,图像分割是一个核心任务。因此,目前缺乏一种适用于医学图像分析的通用且强大的自动数据增强方法,本文提出了 MedAugment 来填补这一空白。

方法

整体框架介绍

MedAugment 的原理其实很简单,其框架实现如上图1所示。在该方法中,作者设计了两个增强空间 Ap 和 As,分别包含六个和八个数据增强操作。这样我们便有14种数据增强操作(可根据自己的数据集特点自由发挥)。

为了更好地适应医学图像分析领域,这里还开发了一种新颖的操作采样策略。MedAugment 由 N 个增强分支和一个单独的分支组成(用于保留原始图像信息)。每个增强分支由顺序执行的 M = {2, 3}个数据增强操作组成。通过调整一个超参数,即增强级别 l = 5 l = 5 l=5,我们可以精确控制 MedAugment 的增强程度。

此外,作者还设计了一种新的映射方法,使得 l l l 不仅控制每个数据增强操作的最大幅度,还控制了相应的触发概率。总的来说,MedAugment 在三个方面引入了随机性,即采样策略、操作类型和顺序、触发概率。对于每个分支,MedAugment 首先使用采样策略从增强空间中采样数据增强操作,然后对采样的操作进行随机排序,并按顺序执行这些操作。至于具体的数据增强操作幅度和概率由超参数 l l l 控制,对于每个图像,操作的幅度控制在最大幅度范围内均匀采样。具体的方法实现可以参考算法1:

增强空间

为了适应医学图像分析领域的特点,作者对具体的数据增强方法进行了细心的设置。首先,从常见的数据增强操作开始,然后根据 MIA 领域的特点进行筛选,排除了不适合医学图像的操作,如反转、均衡化和反转等操作,这些操作可能破坏医学图像中的细节和特征。接下来,我们将数据增强操作分为像素级和空间级操作,并构建了两个增强空间,即像素增强空间Ap和空间增强空间As。Ap和As分别包括与像素和空间相关的数据增强操作。需要注意的是,Ap中的数据增强操作不适用于掩膜。【源码是基于 Albumentations 软件包实现了这些数据增强操作,大家也可以针对自身任务特点去设计】

采样策略

由于医学图像对亮度等属性非常敏感,并且我们观察到Ap中连续的操作可能会导致不真实的输出医学图像,因此本文设计了一种新颖的操作采样策略,用于从Ap和As中采样操作。

具体而言,我们随机采样每个分支的M个数据增强操作,其中从Ap中采样的操作数量不超过一个。我们经过权衡决定了M的取值范围。对于连续的数据增强操作,我们需要谨慎考虑连续操作的数量。使用更多的连续操作可能进一步提高模型的泛化能力,但过多的连续操作可能生成与原始图像差距较大的图像。

因此,我们确定M的上限为3。由于本文方法并行使用数据增强操作,将M设置为1没有意义,因为这会退化为单个操作而没有组合效果。基于这些考虑,作者设计了M = {2, 3}。给定M = {2, 3},我们生成了四种从Ap和As中采样的组合方式,分别是1 + 2、0 + 3、1 + 1和0 + 2。这个采样组合的数量解释了为什么分支数N = 4。

为了更好的扩展,我们也可以将 MedAugment 中的 N 设计为可扩展到其他值,并使用替换采样的方式进行采样。同时,单独的分支也可以被屏蔽。当N = 1并屏蔽单独的分支时,MedAugment 可以执行一对一的数据增强。通过图2的比较可以看出,我们的MedAugment更适合医学图像,而现有的方法可能会生成不真实的增强图像。在最差的情况下,一些增强图像被认为是无意义的,因为存在过多的“噪音”,或者几乎没有有用的信息。这些增强图像可能在深度学习模型上能够被正确识别,但从医学的角度来看是没有意义的。

超参数映射

为了仅使用一个超参数 l l l 来控制 MedAugment,最后设计了一种新颖的映射方法,用于确定每个操作的最大幅度MA和概率PA。我们逐个考虑每个操作的映射关系,以确保MA适用于医学图像。我们观察到,对于像Posterize这样的操作,医学图像对幅度特别敏感。当剩余位数减少时,增强图像的质量会迅速下降。因此,本文通过广泛的实验精心设计了这类数据增强操作的幅度,以确保增强图像仍然具有重要性。

此外,给定 l l l,可以在表1中找到Ap和As中每个数据增强操作的l和MA之间的映射关系。需要注意的是,没有幅度的操作在表中表示为“-”。我们将 l l l 设置为5,但可以根据需要扩展到l = {1, 2, 3, 4, 5},以考虑可扩展性。l越大,增强的程度越强。对于概率PA,Ap和As中的所有数据增强操作都遵循相同的规则,即PA = 0.2l。这样设计的目的是使得概率随着增强级别的增加而增加,从而增加了应用增强操作的可能性。

数据增强

更多数据增强的原理及实现,欢迎添加小编微信:cv_huber,备注“数据增强”,领取对应的 PDF 文档:

源码实现

import albumentations as A
import torch
import math
import random
import os
import cv2
import shutil
import numpy as np
import argparse
from torchvision import transforms
from PIL import Image


def make_odd(num):
    num = math.ceil(num)
    if num % 2 == 0:
        num += 1
    return num


def med_augment(data_path, name, level, number_branch, mask_i=False, shield=False):
    if mask_i:
        image_path = f"{
      
      data_path}{
      
      name}"
        mask_path = f"{
      
      image_path}_mask"
        output_path = f"{
      
      os.path.dirname(os.path.dirname(data_path))}/medaugment/{
      
      name}/"
        out_mask = f"{
      
      os.path.dirname(os.path.dirname(data_path))}/medaugment/{
      
      name}_mask/"
    else:
        image_path = data_path + name
        output_path = f"{
      
      os.path.dirname(os.path.dirname(os.path.dirname(data_path)))}/medaugment/training/{
      
      name}/"

    transform = A.Compose([
        A.ColorJitter(brightness=0.04 * level, contrast=0, saturation=0, hue=0, p=0.2 * level),
        A.ColorJitter(brightness=0, contrast=0.04 * level, saturation=0, hue=0, p=0.2 * level),
        A.Posterize(num_bits=math.floor(8 - 0.8 * level), p=0.2 * level),
        A.Sharpen(alpha=(0.04 * level, 0.1 * level), lightness=(1, 1), p=0.2 * level),
        A.GaussianBlur(blur_limit=(3, make_odd(3 + 0.8 * level)), p=0.2 * level),
        A.GaussNoise(var_limit=(2 * level, 10 * level), mean=0, per_channel=True, p=0.2 * level),
        A.Rotate(limit=4 * level, interpolation=1, border_mode=0, value=0, mask_value=None, rotate_method='largest_box',
                 crop_border=False, p=0.2 * level),
        A.HorizontalFlip(p=0.2 * level),
        A.VerticalFlip(p=0.2 * level),
        A.Affine(scale=(1 - 0.04 * level, 1 + 0.04 * level), translate_percent=None, translate_px=None, rotate=None,
                 shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,
                 keep_ratio=True, p=0.2 * level),
        A.Affine(scale=None, translate_percent=None, translate_px=None, rotate=None,
                 shear={
    
    'x': (0, 2 * level), 'y': (0, 0)}
                 , interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,
                 keep_ratio=True, p=0.2 * level),  # x
        A.Affine(scale=None, translate_percent=None, translate_px=None, rotate=None,
                 shear={
    
    'x': (0, 0), 'y': (0, 2 * level)}
                 , interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,
                 keep_ratio=True, p=0.2 * level),
        A.Affine(scale=None, translate_percent={
    
    'x': (0, 0.02 * level), 'y': (0, 0)}, translate_px=None, rotate=None,
                 shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,
                 keep_ratio=True, p=0.2 * level),
        A.Affine(scale=None, translate_percent={
    
    'x': (0, 0), 'y': (0, 0.02 * level)}, translate_px=None, rotate=None,
                 shear=None, interpolation=1, mask_interpolation=0, cval=0, cval_mask=0, mode=0, fit_output=False,
                 keep_ratio=True, p=0.2 * level)
    ])

    for j, file_name in enumerate(os.listdir(image_path)):
        if file_name.endswith(".png") or file_name.endswith(".jpg"):
            file_path = os.path.join(image_path, file_name)
            file_n, file_s = file_name.split(".")[0], file_name.split(".")[1]
            image = cv2.imread(file_path)
            if mask_i: mask = cv2.imread(f"{
      
      mask_path}/{
      
      file_n}_mask.{
      
      file_s}")
            strategy = [(1, 2), (0, 3), (0, 2), (1, 1)]
            for i in range(number_branch):
                if number_branch != 4:
                    employ = random.choice(strategy)
                else:
                    index = random.randrange(len(strategy))
                    employ = strategy.pop(index)
                level, shape = random.sample(transform[:6], employ[0]), random.sample(transform[6:], employ[1])
                img_transform = A.Compose([*level, *shape])
                random.shuffle(img_transform.transforms)
                if not os.path.exists(output_path): os.makedirs(output_path)
                if mask_i:
                    transformed = img_transform(image=image, mask=mask)
                    transformed_image, transformed_mask = transformed['image'], transformed['mask']
                    cv2.imwrite(f"{
      
      output_path}/{
      
      file_n}_{
      
      i+1}.{
      
      file_s}", transformed_image)
                    cv2.imwrite(f"{
      
      out_mask}/{
      
      file_n}_{
      
      i+1}_mask.{
      
      file_s}", transformed_mask)
                else:
                    transformed = img_transform(image=image)
                    transformed_image = transformed['image']
                    cv2.imwrite(f"{
      
      output_path}/{
      
      file_n}_{
      
      i+1}.{
      
      file_s}", transformed_image)
                if not shield:
                    cv2.imwrite(f"{
      
      output_path}/{
      
      file_n}_{
      
      number_branch+1}.{
      
      file_s}", image)
                    if mask_i: cv2.imwrite(f"{
      
      out_mask}/{
      
      file_n}_{
      
      number_branch+1}_mask.{
      
      file_s}", mask)


def generate_datasets(train_type, dataset, seed, level, number_branch):

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    if train_type == "classification":
        print('Executing data augmentation for image classification...')
        data_path = f"./datasets/classification/{
      
      dataset}/baseline/training/"
        folder_path = f"./datasets/classification/{
      
      dataset}/"
        n = len([name for name in os.listdir(f"{
      
      folder_path}/baseline/training") if
                 os.path.isdir(os.path.join(f"{
      
      folder_path}/baseline/training", name))])

        for folder in ["medaugment"]:
            shutil.copytree(f"{
      
      folder_path}baseline", f"{
      
      folder_path}{
      
      folder}",
                            ignore=shutil.ignore_patterns("training"))
            training_folder_path = f"{
      
      folder_path}{
      
      folder}/training"
            os.makedirs(training_folder_path)
            for i in range(n):
                os.makedirs(f"{
      
      training_folder_path}/n{
      
      i}")

        for i in range(n):
            name = f"n{
      
      i}"
            med_augment(data_path, name, level, number_branch)
    else:
        print('Executing data augmentation for image segmentation...')
        data_path = f"./datasets/segmentation/{
      
      dataset}/baseline/"
        folder_path = f"./datasets/segmentation/{
      
      dataset}/"

        for folder in ["medaugment"]:
            shutil.copytree(f"{
      
      folder_path}baseline", f"{
      
      folder_path}{
      
      folder}",
                            ignore=shutil.ignore_patterns("training", "training_mask"))
            os.makedirs(f"{
      
      folder_path}{
      
      folder}/training")
            os.makedirs(f"{
      
      folder_path}{
      
      folder}/training_mask")

        folder_list = ["training"]
        for i in range(len(folder_list)):
            name = folder_list[i]
            med_augment(data_path, name, level, number_branch, mask_i=True)


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
    group = parser.add_argument_group()
    group.add_argument('--dataset', required=True)
    group.add_argument('--train_type', choices=['classification', 'segmentation'], default='classification')
    group.add_argument('--level', help='Augmentation level', default=5, type=int, metavar='INT')
    group.add_argument('--number_branch', help='Number of branch', default=4, type=int, metavar='INT')
    group.add_argument('--seed', help='Seed', default=8, type=int, metavar='INT')
    args = parser.parse_args()
    generate_datasets(**vars(args))


if __name__ == '__main__':
    main()

总结

今天为大家介绍了一种名为 MedAugment 的即插即用的自动数据增强方法,可以将自动数据增强引入到医学图像分析领域,显著降低了数据增强对经验的依赖。通过设计增强空间和采样策略,MedAugment 解决了自然图像和医学图像之间的差异。

然而,MedAugment 和其他最先进的方法在平衡不同评估指标方面仍存在一些问题,例如敏感性等指标可能较低。因此,可以进一步研究如何平衡不同指标。例如,可以引入和评估多个超参数,在训练过程中通过超参数更新来平衡不同指标。另外,小目标尺寸的问题需要进一步研究。这导致如何强调待分割的目标成为一个问题。进一步的研究可以根据目标的大小应用不同类型和级别的数据增强。例如,与较大的目标相比,具有较小目标的图像在增强过程中更有可能被放大,同时放大系数也可以更大。

猜你喜欢

转载自blog.csdn.net/CVHub/article/details/131625515