基于python的数据集扩充增强

前言

  数据增强技术在深度学习中得到了广泛的应用,它能够有效地扩充训练数据集的大小,提高模型的泛化能力,同时也能够有效地防止过拟合现象的发生。在本篇中,将讲解一种基于 Python 和 OpenCV 库实现的数据增强方法,并提供一个示例代码。
  首先,需要安装 OpenCV 库及其 Python 接口。OpenCV(Open Source Computer Vision Library)是一个开源计算机视觉库,提供了计算机视觉和机器视觉中常用的算法和工具。在 Python 中使用 OpenCV 库可以方便地对图像进行处理、分析和识别。

pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

  接下来,通过 Python 实现一个数据增强函数 augment_data(img),它将对输入的图像进行多种处理,包括水平翻转、缩放、旋转、添加高斯噪音和调整对比度和亮度等。
在这里插入图片描述

  下面是具体的代码实现:

代码如下:

import cv2
import numpy as np
import os
import glob

# 数据增强函数
def augment_data(img):
    rows,cols,_ = img.shape

    # 水平翻转图像
    if np.random.random() > 0.5:
        img = cv2.flip(img, 1)
        img_name = os.path.splitext(save_path)[0] + "_flip.png"
        cv2.imwrite(img_name, img)
        print("Saved augmented image:", img_name)

    # 随机缩放图像
    scale = np.random.uniform(0.9, 1.1)
    M = cv2.getRotationMatrix2D((cols/2, rows/2), 0, scale)
    img_transformed = cv2.warpAffine(img, M, (cols, rows))
    img_name = os.path.splitext(save_path)[0] + "_transform.png"
    cv2.imwrite(img_name, img_transformed)
    print("Saved augmented image:", img_name)

    # 随机旋转图像
    angle = np.random.randint(-10, 10)
    M = cv2.getRotationMatrix2D((cols/2, rows/2), angle, 1)
    img_rotated = cv2.warpAffine(img, M, (cols, rows))
    img_name = os.path.splitext(save_path)[0] + "_rotated.png"
    cv2.imwrite(img_name, img_rotated)
    print("Saved augmented image:", img_name)

    # 添加高斯噪音
    mean = 0
    std = np.random.uniform(5, 15)
    noise = np.zeros(img.shape, np.float32)
    cv2.randn(noise, mean, std)
    noise = np.uint8(noise)
    img_noisy = cv2.add(img, noise)
    img_name = os.path.splitext(save_path)[0] + "_noisy.png"
    cv2.imwrite(img_name, img_noisy)
    print("Saved augmented image:", img_name)

    # 随机调整对比度和亮度
    alpha = np.random.uniform(0.8, 1.2)
    beta = np.random.randint(-10, 10)
    img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)
    img_name = os.path.splitext(save_path)[0] + "_contrast.png"
    cv2.imwrite(img_name, img_contrast)
    print("Saved augmented image:", img_name)

    return img


# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = "data"
save_dir = "result"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.png")):

    img = cv2.imread(img_path)

    # 获取保存增强后的图片文件名
    img_name = os.path.basename(img_path)
    save_path = os.path.join(save_dir, img_name)

    # 数据增强
    augment_data(img)

    # 保存原始图片
    cv2.imwrite(save_path, img)
    print("Saved original image:", save_path)

讲解

首先,创建一个名为 result 的文件夹用于存储增强后的图像。然后,利用 glob 库对 data 文件夹中的所有图像进行遍历,并通过 cv2.imread 函数读取图像文件。获取保存增强后的图片的路径和文件名,这可以通过 os.path.basename 函数和 os.path.join 函数实现。接下来,调用 augment_data 函数对读取的图像进行数据增强,并保存增强后的图片。通过 OpenCV 库提供的 cv2.imwrite 函数,可以将增强后的每个图像保存为不同的文件,文件名中加入 _flip、_transform、_rotated、_noisy 和 _contrast 等后缀表示不同的增强方式。

综上,讲解了通过 Python 和 OpenCV 库实现了一个简单的数据增强方法。通过对图像进行水平翻转、缩放、旋转、添加噪音和调整对比度和亮度等操作,扩充了训练数据集的大小,用于提高了模型的泛化能力和鲁棒性。

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43788282/article/details/131168198
今日推荐