数据增强(Data Augmentation)

版权声明:偷我的我会气死的, 希望你去访问我的个人主页:crazyang.top https://blog.csdn.net/yzy_1996/article/details/82223241

在训练过程中,网络优化是一方面,数据集的优化又是另一方面。数据集会存在各类样本不均匀的情况,也就是各类样本的数量不一样,有的甚至差别很大。为了让模型具有更强的鲁棒性,采用Data Augmentation是一个不错的选择。

常用的方法

  • Color Jittering:对颜色的数据增强:图像亮度、饱和度、对比度变化(此处对色彩抖动的理解不知是否得当)
  • PCA Jittering:首先按照RGB三个颜色通道计算均值和标准差,再在整个训练集上计算协方差矩阵,进行特征分解,得到特征向量和特征值。参见论文
  • Random Scale:尺度变换
  • Random Crop:采用随机图像差值方式,对图像进行裁剪、缩放;包括Scale Jittering方法(VGG及ResNet模型使用)或者尺度和长宽比增强变换
  • Horizontal/Vertical Flip:水平/垂直翻转
  • Shift:平移变换
  • Rotation/Reflection:旋转/仿射变换
  • Noise:高斯噪声、模糊处理
  • Label shuffle:类别不平衡数据的增广,参见海康威视ILSVRC2016的report

ImageDataGenerator()函数

这是Keras提供的一个自动增强的函数

ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    rotation_range=0.,
    width_shift_range=0.,
    height_shift_range=0.,
    shear_range=0.,
    zoom_range=0.,
    channel_shift_range=0.,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=False,
    vertical_flip=False,
    rescale=None,
    dim_ordering=K.image_dim_ordering())

参数解释,所有参数不一定都需要包括,可以不写。

featurewise_center:布尔值,使输入数据集去中心化(均值为0)
samplewise_center:布尔值,使输入数据的每个样本均值为0
featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化
samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差
zca_whitening:布尔值,对输入数据施加ZCA白化
rotation_range:整数,数据提升时图片随机转动的角度
width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度
shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)
zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]
channel_shift_range:浮点数,随机通道偏移的幅度
fill_mode:;‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理
cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值
horizontal_flip:布尔值,进行随机水平翻转
vertical_flip:布尔值,进行随机竖直翻转
rescale: 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其他变换之前)
dim_ordering:‘tf’和‘th’之一,规定数据的维度顺序。‘tf’模式下数据的形状为samples, width, height, channels,‘th’下形状为(samples, channels, width, height).该参数的默认值是Keras配置文件~/.keras/keras.json的image_dim_ordering值,如果你从未设置过的话,就是'th'

有了这个函数,那怎么具体对一张图片或者批量化处理

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
datagen = ImageDataGenerator(
    rotation_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

img = load_img('lena.jpg')
x = img_to_array(img)
x = x.reshape((1,) + x.shape)
i = 0
for batch in datagen.flow(x,
    batch_size=1,
    save_to_dir='data/preview',   #保存在这个文件夹下
    save_prefix='lena',
    save_format='jpg'):
    i += 1
    if i > 20:  #生成20张图
        break

其中flow函数的参数解释

X:样本数据,秩应为4.在黑白图像的情况下channel轴的值为1,在彩色图像情况下值为3
batch_size:整数,默认32
shuffle:布尔值,是否随机打乱数据,默认为True
save_to_dir:None或字符串,该参数能让你将提升后的图片保存起来,用以可视化
save_prefix:字符串,保存提升后图片时使用的前缀, 仅当设置了save_to_dir时生效
save_format:”png”或”jpeg”之一,指定保存图片的数据格式,默认”jpeg”
yields:形如(x,y)的tuple,x是代表图像数据的numpy数组.y是代表标签的numpy数组.该迭代器无限循环.
seed: 整数,随机数种子

以上所有详情可以参考keras官方文档
如果你对每一步的效果不够清楚,请参考参数演示

PCA Jittering方法

PCA是主成分分析,
根据AlexNet论文当中的

import numpy as np
import os
from PIL import Image, ImageOps
import random
from scipy import misc
import imageio

def PCA_Jittering(path):
    img_list = os.listdir(path)
    img_num = len(img_list)

    for i in range(img_num):
        img_path = os.path.join(path, img_list[i])
        img = Image.open(img_path)    

        img = np.asanyarray(img, dtype = 'float32')

        img = img / 255.0
        img_size = img.size // 3    #转换为单通道
        img1 = img.reshape(img_size, 3)

        img1 = np.transpose(img1)   #转置
        img_cov = np.cov([img1[0], img1[1], img1[2]])    #协方差矩阵
        lamda, p = np.linalg.eig(img_cov)     #得到上述协方差矩阵的特征向量和特征值

        #p是协方差矩阵的特征向量
        p = np.transpose(p)    #转置回去

        #生成高斯随机数
        alpha1 = random.gauss(0,3)
        alpha2 = random.gauss(0,3)
        alpha3 = random.gauss(0,3)

        #lamda是协方差矩阵的特征值
        v = np.transpose((alpha1*lamda[0], alpha2*lamda[1], alpha3*lamda[2]))     #转置

        #得到主成分
        add_num = np.dot(p,v)

        #在原图像的基础上加上主成分
        img2 = np.array([img[:,:,0]+add_num[0], img[:,:,1]+add_num[1], img[:,:,2]+add_num[2]])

        #现在是BGR,要转成RBG再进行保存
        img2 = np.swapaxes(img2,0,2)
        img2 = np.swapaxes(img2,0,1)
        save_name = 'pre'+str(i)+'.png'
        save_path = os.path.join(path, save_name)
        misc.imsave(save_path,img2)

        #plt.imshow(img2)
        #plt.show()

PCA_Jittering('testpic')

猜你喜欢

转载自blog.csdn.net/yzy_1996/article/details/82223241