点云 数据增强(Data Augmentation):方法与python代码

数据集增强(Data Augmentation)是机器学习常用的数据预处理方法。例如,当手头的数据量太少时,可以人工生成一些有意义的数据用来训练,这种数据获取方法的突出优点是:成本低,效果好。另外,当用来分类的数据集有数据倾斜(skewed data)即某一类样本比另一类多很多时,可以这对样本较少的一类进行数据增强。

在图像领域,常用的数据增强方法有:旋转,镜像,缩放等。

而在激光点云中,常用的数据增强方法有:旋转,加噪声,降采样,不同程度的遮挡等。

这里暂时只考虑旋转和在每个点的坐标XYZ上加高斯噪声等。理论上也可以对回波强度加上噪声,但噪声的方差和均值很难把握,设置不对的话会起到相反的作用,因此这里先不考虑在回波强度上加噪声。事实上,当采用加噪声的方法进行数据增强时,必须仔细选择噪声的[方差]

具体代码如下:因为我这里主要考虑激光雷达采集到的路面交通对象,所以在旋转时只考虑了绕Z轴旋转。

# -*- coding: utf-8 -*-
#######################################
########## Data Augmentation ##########
#######################################

import numpy as np

###########
# 绕Z轴旋转 #
###########
# point: vector(1*3:x,y,z)
# rotation_angle: scaler 0~2*pi
def rotate_point (point, rotation_angle):
    point = np.array(point)
    cos_theta = np.cos(rotation_angle)
    sin_theta = np.sin(rotation_angle)
    rotation_matrix = np.array([[cos_theta, sin_theta, 0],
                                [-sin_theta, cos_theta, 0],
                                [0, 0, 1]])
    rotated_point = np.dot(point.reshape(-1, 3), rotation_matrix)
    return rotated_point

# point = np.array([1,2,3])
# rotated_point = rotate_point(point, 0.1*np.pi)
# print rotated_point


###########
# 在XYZ上加高斯噪声 #
###########
def jitter_point(point, sigma=0.01, clip=0.05):
    assert(clip > 0)
    point = np.array(point)
    point = point.reshape(-1,3)
    Row, Col = point.shape
    jittered_point = np.clip(sigma * np.random.randn(Row, Col), -1*clip, clip)
    jittered_point += point
    return jittered_point


# jittered_point = jitter_point(point)
# print jittered_point


###########
# Data Augmentation #
###########
def augment_data(point, rotation_angle, sigma, clip):
    return jitter_point(rotate_point(point, rotation_angle), sigma, clip)

point = np.array(point) 这一语句是将point转换为numpy数组,保证输入的List类型也能运行。

point = point.reshape(-1,3) 是将point变为行向量,考虑到输入有可能是列向量。

猜你喜欢

转载自blog.csdn.net/shaozhenghan/article/details/81265817
今日推荐