图像增强库Albumentations使用总结

摘要

albumentations包是一种针对数据增强专门写的API,里面基本包含大量的数据增强手段,其特点:

1、Albumentations支持所有常见的计算机视觉任务,如分类、语义分割、实例分割、目标检测和姿态估计。

2、该库提供了一个简单统一的API,用于处理所有数据类型:图像(rbg图像、灰度图像、多光谱图像)、分割掩码、边界框和关键点。

3、该库包含70多种不同的增强功能,可以从现有数据中生成新的训练样本。

4、Albumentations快。我们对每个新版本进行基准测试,以确保增强功能提供最大的速度。

5、它与流行的深度学习框架(如PyTorch和TensorFlow)一起工作。顺便说一下,Albumentations是PyTorch生态系统的一部分。

6、由专家写的。作者既有生产计算机视觉系统的工作经验,也有参与竞争性机器学习的经验。许多核心团队成员是Kaggle Masters和Grandmasters。

7、该库广泛应用于工业、深度学习研究、机器学习竞赛和开源项目。

Albumentations 的 pip 安装

pip install albumentations

image.png

问题

在这个部分有个小问题,我的本机已经安装完opencv-python,然后我们再去安装albumentations的时候,出现了一个问题,就是我们的opencv-python阻止albumentations的安装,报错如下:# Could not install packages due to anEnvironmentError: [WinError 5] 拒绝访问,是因为在安装albumentations的时候还要安装opencv-python-headless,这个库和opencv冲突。

解决方式

对于这个问题的解决方式是我们在新的虚拟环境中,先安装albumentations,安装albumentations的时候,我们会伴随安装上一个opencv-python-headless,这个完全可以替代opencv-python的功能,所以我们就不用安装opencv了。

基准测试结果

测试使用ImageNet验证集的前2000张图像在Intel Xeon Gold 6140 CPU运行基准测试的结果。所有输出都被转换为带有np的连续NumPy数组。uint8数据类型。表格显示了在单个核上每秒可以处理的图像数量;越高越好。

image.png

Python and library versions: Python 3.8.6 (default, Oct 13 2020, 20:37:26) [GCC 8.3.0], numpy 1.19.2, pillow-simd 7.0.0.post3, opencv-python 4.4.0.44, scikit-image 0.17.2, scipy 1.5.2.

Spatial-level transforms(空间层次转换)

空间级转换将同时改变输入图像和附加目标,如掩模、边界框和关键点。下表显示了每个转换支持哪些附加目标。

Spatial-level transforms(空间层次转换) 空间级转换将同时改变输入图像和附加目标,如掩模、边界框和关键点。下表显示了每个转换支持哪些附加目标。

image.png

支持的列表

  • Blur
  • CLAHE
  • ChannelDropout
  • ChannelShuffle
  • ColorJitter
  • Downscale
  • Emboss
  • Equalize
  • FDA
  • FancyPCA
  • FromFloat
  • GaussNoise
  • GaussianBlur
  • GlassBlur
  • HistogramMatching
  • HueSaturationValue
  • ISONoise
  • ImageCompression
  • InvertImg
  • MedianBlur
  • MotionBlur
  • MultiplicativeNoise
  • Normalize
  • Posterize
  • RGBShift
  • RandomBrightnessContrast
  • RandomFog
  • RandomGamma
  • RandomRain
  • RandomShadow
  • RandomSnow
  • RandomSunFlare
  • RandomToneCurve
  • Sharpen
  • Solarize
  • Superpixels
  • ToFloat
  • ToGray
  • ToSepia

简单的使用案例

import albumentations as A
import cv2
 
import matplotlib.pyplot as plt
 
# Declare an augmentation pipeline
transform = A.Compose([
    A.RandomCrop(width=512, height=512),
    A.HorizontalFlip(p=0.8),
    A.RandomBrightnessContrast(p=0.5),
])
 
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
# Augment an image
transformed = transform(image=image)
transformed_image = transformed["image"]
plt.imshow(transformed_image)
plt.show()
复制代码

原始图像:

image.png

运行结果:

image.png

详细使用案例

VerticalFlip 围绕X轴垂直翻转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.VerticalFlip(always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Blur后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

image.png

Blur模糊输入图像

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Blur(blur_limit=15,always_apply=False, p=1)(image=image) 
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Blur后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

image.png

HorizontalFlip 围绕y轴水平翻转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.HorizontalFlip(always_apply=False, p=1)(image=image) 
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('HorizontalFlip后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

Flip水平,垂直或水平和垂直翻转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Flip(always_apply=False, p=1)(image=image) 
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Flip后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果有一定的随机性,如下图:

image.png

Transpose, 通过交换行和列来转置输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Transpose(always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Transpose后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

RandomCrop 随机裁剪

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomCrop(512, 512,always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('RandomCrop后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

RandomGamma 随机灰度系数

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomGamma(gamma_limit=(20, 20), eps=None, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('RandomGamma后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

RandomRotate90 将输入随机旋转90度,N次

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomRotate90(always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('RandomRotate90后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

ShiftScaleRotate 随机平移,缩放和旋转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpolation=1, border_mode=4, value=None, mask_value=None, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('ShiftScaleRotate后的图像')
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

CenterCrop 裁剪图像的中心部分

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.CenterCrop(256, 256, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("CenterCrop后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

image.png

GridDistortion网格失真

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.GridDistortion(num_steps=10, distort_limit=0.3,border_mode=4, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("GridDistortion后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

ElasticTransform 弹性变换

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.ElasticTransform(alpha=5, sigma=50, alpha_affine=50, interpolation=1, border_mode=4,always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("ElasticTransform后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

RandomGridShuffle把图像切成网格单元随机排列

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomGridShuffle(grid=(3, 3), always_apply=False, p=1) (image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("RandomGridShuffle后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

HueSaturationValue随机更改图像的颜色,饱和度和值

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("HueSaturationValue后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

PadIfNeeded 填充图像

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.PadIfNeeded(min_height=2048, min_width=2048, border_mode=4, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("PadIfNeeded后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

RGBShift,对图像RGB的每个通道随机移动值

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RGBShift(r_shift_limit=10, g_shift_limit=20, b_shift_limit=20, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("RGBShift后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

image.png

GaussianBlur 使用随机核大小的高斯滤波器对图像进行模糊处理

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.GaussianBlur(blur_limit=11, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("GaussianBlur后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

CLAHE自适应直方图均衡

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("CLAHE后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

ChannelShuffle随机重新排列输入RGB图像的通道

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.ChannelShuffle(always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("ChannelShuffle后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

InvertImg反色

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.InvertImg(always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("InvertImg后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

image.png

Cutout 随机擦除

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Cutout(num_holes=20, max_h_size=20, max_w_size=20, fill_value=0, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("Cutout后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

image.png

RandomFog随机雾化

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomFog(fog_coef_lower=0.3, fog_coef_upper=1, alpha_coef=0.08, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("RandomFog后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

GridDropout网格擦除

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.GridDropout(ratio=0.5, unit_size_min=None, unit_size_max=None, holes_number_x=None, holes_number_y=None,
                            shift_x=0, shift_y=0, always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("GridDropout后的图像")
plt.imshow(transformed_image)
plt.show()
复制代码

运行结果:

image.png

组合变换(Compose)

变换不仅可以单独使用,还可以将这些组合起来,这就需要用到 Compose 类,该类继承自 BaseCompose。Compose 类含有以下参数:

  • transforms:转换类的数组,list类型
  • bbox_params:用于 bounding boxes 转换的参数,BboxPoarams 类型
  • keypoint_params:用于 keypoints 转换的参数, KeypointParams 类型
  • additional_targets:key新target 名字,value 为旧 target 名字的 dict,如 {'image2': 'image'},dict 类型
  • p:使用这些变换的概率,默认值为 1.0
image3 = Compose([
        # 对比度受限直方图均衡
            #(Contrast Limited Adaptive Histogram Equalization)
        CLAHE(),
        # 随机旋转 90°
        RandomRotate90(),
        # 转置
        Transpose(),
        # 随机仿射变换
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        # 模糊
        Blur(blur_limit=3),
        # 光学畸变
        OpticalDistortion(),
        # 网格畸变
        GridDistortion(),
        # 随机改变图片的 HUE、饱和度和值
        HueSaturationValue()
    ], p=1.0)(image=image)['image']
复制代码

随机选择(OneOf)

它同Compose一样,都是做组合的,都有概率。区别就在于:Compose组合下的变换是要挨着顺序做的,而OneOf组合里面的变换是系统自动选择其中一个来做,而这里的概率参数p是指选定后的变换被做的概率。例:

image4 = Compose([
        RandomRotate90(),
        # 翻转
        Flip(),
        Transpose(),
        OneOf([
            # 高斯噪点
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            # 模糊相关操作
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([
            # 畸变相关操作
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            # 锐化、浮雕等操作
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),            
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=1.0)(image=image)['image']
复制代码

在程序中的使用

def get_transform(phase: str):
    if phase == 'train':
        return Compose([
            A.RandomResizedCrop(height=CFG.img_size, width=CFG.img_size),
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.HueSaturationValue(p=0.5),
            A.OneOf([
                A.RandomBrightnessContrast(p=0.5),
                A.RandomGamma(p=0.5),
            ], p=0.5),
            A.OneOf([
                A.Blur(p=0.1),
                A.GaussianBlur(p=0.1),
                A.MotionBlur(p=0.1),
            ], p=0.1),
            A.OneOf([
                A.GaussNoise(p=0.1),
                A.ISONoise(p=0.1),
                A.GridDropout(ratio=0.5, p=0.2),
                A.CoarseDropout(max_holes=16, min_holes=8, max_height=16, max_width=16, min_height=8, min_width=8, p=0.2)
            ], p=0.2),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])
    else:
        return Compose([
            A.Resize(height=CFG.img_size, width=CFG.img_size),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])
复制代码

猜你喜欢

转载自juejin.im/post/7080060840858091557