用RANSAC算法实现干扰严重的直线拟合(续)求点线距离

用RANSAC算法实现干扰严重的直线拟合~_扫地僧1234的博客-CSDN博客直线拟合无非是最小二乘,但是如果有一些干扰项可能效果就不好了,可以采用随机抽样一致性算法来找局内项,从而排队局外项的干扰,拟合出理想的直线。https://blog.csdn.net/ogebgvictor/article/details/125359704

 接上一篇“用RANSAC算法实现干扰严重的直线拟合”, 上一篇的末尾没有说是如何求点线距离的。今天补充一下。首先说明,上一篇的代码里面实际上是先求点A到直线的垂线与该直线的交点B,然后求AB两点的距离,就是直线距离。这个方法其实是绕远了,完全没必要这么算。

如下图,求C点到直线的距离,CE与直线垂直,E就是C到直线的垂线的交点,CE的长度就是C点到直线的距离。但是我们实际上并不需要去求出这个交点,现在由C点做一个垂直于x轴的线与直线相交于D(其实就是直接把C点的x坐标带到直线方程中,求到的y就是D点的y坐标,而D点的x坐标跟C的x坐标是一样的),就会发现CE=CD乘cos(θ),而θ就是直线的倾角。

 直线方程有两种形式:

形式1.y=ax+b,a=tan(θ),当然我们不需要先由a求θ=arctan(θ),再求cos(θ) 。而是直接由下式计算出

原因很简单啦,如下图cos(θ)可不就是1除以斜边

 形式2.公式如下,斜率就是vy/vx啦

cos(θ)由下式求出,跟上面其实是一样的。

 

接下来上代码,先给出第一版代码,这里叫get_distance_line_and_p_batch1是区别于上一篇文章里用的get_distance_line_and_p_batch,并且get_distance_line_and_p_batch还返回了交点的坐标,所以这里为了保持一致,多返回了两个0。

def get_distance_line_and_p_batch1(vx, vy, p0, p1_list):
    cos_theta = vx / math.sqrt(vx ** 2 + vy ** 2)  # 求cos(θ)

    p1_list = p1_list.tolist()

    # y = (x - x0) * vy / vx + y0就是直线方程啦,把点集p1_list的每一个点的x坐标值代入,就求得了其垂线(垂直于x轴)在直线上的点的y坐标了
    p2_y_list = [(x - p0[0]) * vy / vx + p0[1] for x, y in p1_list]

    # 然后就是按照图示的方法由y坐标的差值乘以余弦值得到点到直线的距离了
    distance_list = [abs((p1[1] - y2) * cos_theta) for p1, y2 in zip(p1_list, p2_y_list)]
    distance_array = np.array(distance_list)


    return distance_array, 0, 0

并且调用的地方要改一下函数名,这里就先不上代码了。 

 耗时0.02秒。不对啊,之前那个繁琐的方法只耗了0.009秒,现在改的简单了,怎么还耗时更多了?(接下来会尝试几种不同的情况,目的是为了分享一下python里面到底怎么充分利用numpy来加快运行速度,嫌啰嗦的可以直接滑过~~)

 可能会怀疑上面那个p1_list.tolist()浪费时间了,每次都tolist了一下,其实它花不了多少时间。

我们可以给它加个耗时检测

会发现它耗时极短,不过这里总共花了0.03秒,又多了一点,是因为print是耗时的~~ 

 接下来再分享一个可能经常犯错的地方,比如像下面这样,我不用p1_list.tolist(),p1_list本身是一个numpy的ndarray,它也是可以遍历的麻,那我直接遍历它不就得了,所以我把p1_list.tolist()注释掉了,上代码。

def get_distance_line_and_p_batch1(vx, vy, p0, p1_list):
    cos_theta = vx / math.sqrt(vx ** 2 + vy ** 2)  # 求cos(θ)

    # p1_list = p1_list.tolist()

    # y = (x - x0) * vy / vx + y0就是直线方程啦,把点集p1_list的每一个点的x坐标值代入,就求得了其垂线(垂直于x轴)在直线上的点的y坐标了
    p2_y_list = [(x - p0[0]) * vy / vx + p0[1] for x, y in p1_list]

    # 然后就是按照图示的方法由y坐标的差值乘以余弦值得到点到直线的距离了
    distance_list = [abs((p1[1] - y2) * cos_theta) for p1, y2 in zip(p1_list, p2_y_list)]
    distance_array = np.array(distance_list)


    return distance_array, 0, 0

也是可以运行的哦,这次变的更慢了!(这次是没有print的),上次那个还0.02秒,现在变0.03秒了,所以要注意,不要随便去遍历numpy的数组,numpy的数组最好是用批量操作。或者要么就转成list再遍历。总体的运行时间是:numpy批操作<list遍历<numpy遍历

 既然说到numpy批操作是最快的,那么接下来就改为直接用numpy的批操作了,上代码,遍历的代码都注释掉了,放在这里只是为了说明它们的含义,替换它们的只有一行,只有一行。就不多说了,这一行代码干的事情跟之前是一样的。提一句,p1_list[:, 0] - p0[0]这个其实是用到了广播机制,这是实现批操作的常用技俩~~

def get_distance_line_and_p_batch1(vx, vy, p0, p1_list):
    cos_theta = vx / math.sqrt(vx ** 2 + vy ** 2)  # 求cos(θ)

    # p1_list = p1_list.tolist()
    # # y = (x - x0) * vy / vx + y0就是直线方程啦,把点集p1_list的每一个点的x坐标值代入,就求得了其垂线(垂直于x轴)在直线上的点的y坐标了
    # p2_y_list = [(x - p0[0]) * vy / vx + p0[1] for x, y in p1_list]
    # 
    # # 然后就是按照图示的方法由y坐标的差值乘以余弦值得到点到直线的距离了
    # distance_list = [abs((p1[1] - y2) * cos_theta) for p1, y2 in zip(p1_list, p2_y_list)]
    # distance_array = np.array(distance_list)

    distance_array = abs((p1_list[:, 0] - p0[0]) * vy / vx + p0[1] - p1_list[:, 1]) * cos_theta

    return distance_array, 0, 0

 上效果图,这0秒。。。不知道今天是咋了,可能时间太短统计的不准了,昨天我试的还是0.009秒,总之跟之前那个先求点到直线的垂线交点,再求点和交点的长度的办法相比,速度起码是原来的10倍!

 最后上一个完整代码

import cv2 as cv
import numpy as np
import math
import time


def get_distance_line_and_p_batch1(vx, vy, p0, p1_list):
    cos_theta = vx / math.sqrt(vx ** 2 + vy ** 2)  # 求cos(θ)

    # p1_list = p1_list.tolist()
    # # y = (x - x0) * vy / vx + y0就是直线方程啦,把点集p1_list的每一个点的x坐标值代入,就求得了其垂线(垂直于x轴)在直线上的点的y坐标了
    # p2_y_list = [(x - p0[0]) * vy / vx + p0[1] for x, y in p1_list]
    #
    # # 然后就是按照图示的方法由y坐标的差值乘以余弦值得到点到直线的距离了
    # distance_list = [abs((p1[1] - y2) * cos_theta) for p1, y2 in zip(p1_list, p2_y_list)]
    # distance_array = np.array(distance_list)

    distance_array = abs((p1_list[:, 0] - p0[0]) * vy / vx + p0[1] - p1_list[:, 1]) * cos_theta

    return distance_array, 0, 0


def get_distance_line_and_p_batch(vx, vy, p0, p1_list):
    """
    求点到直线距离
    :param vx:
    :param vy:
    :param p0: 线上点
    :param p1: 线外点
    :return:
    """
    # print("vx = {}, vy = {}, p0 = {}, p1 = {}".format(vx, vy, p0, p1))
    x0, y0 = p0

    B0 = [vy * x0 - vx * y0] * len(p1_list)
    B1 = [vx * p1[0] + vy * p1[1] for p1 in p1_list]

    A = np.array([[vy, -vx], [vx, vy]], dtype=np.float32)
    B = np.array([B0, B1], dtype=np.float32)
    ret, X = cv.solve(A, B, flags=cv.DECOMP_LU)
    x2 = X[0, :]
    y2 = X[1, :]

    x1 = np.array([p1[0] for p1 in p1_list], dtype=np.float32)
    y1 = np.array([p1[1] for p1 in p1_list], dtype=np.float32)

    return ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5, x2, y2


def fitline_ransac(pts, distance, confidence, correct_rate, fit_use_num=2):
    """
    拟合直线
    :param pts:
    :param distance:
    :param confidence: 置信度
    :param correct_rate: 估计局内项所占比率
    :param fit_use_num: 一次拟合所用点数
    :return:
    """
    iter_num = int(math.log(1 - confidence, 1 - correct_rate ** fit_use_num))
    pts = np.float32(pts)
    pts_list = pts.tolist()

    print("iter_num = {}".format(iter_num))

    inliers = None
    inliers_num = 0

    for i in range(iter_num):
        tmp_pts = pts[np.random.choice(pts.shape[0], fit_use_num)]
        tmp_vx, tmp_vy, tmp_x, tmp_y = cv.fitLine(tmp_pts, cv.DIST_L2, 0, 0.01, 0.01)
        tmp_vx = tmp_vx[0]
        tmp_vy = tmp_vy[0]
        tmp_x = tmp_x[0]
        tmp_y = tmp_y[0]

        # tmp_distance, _, _ = get_distance_line_and_p_batch(tmp_vx, tmp_vy, (tmp_x, tmp_y), pts)
        tmp_distance, _, _ = get_distance_line_and_p_batch1(tmp_vx, tmp_vy, (tmp_x, tmp_y), pts)

        tmp_inliers = np.uint8(tmp_distance <= distance)  # 小于等于阈值的就是局内项
        tmp_inliers_num = np.sum(tmp_inliers)

        if inliers_num < tmp_inliers_num:  # 如果当前这次抽样的局内项更大,那就采用当前这一次
            inliers_num = tmp_inliers_num
            inliers = tmp_inliers

    # 之前两个点的拟合只是为了排除局外项,再用所有局内项拟合一次,结果更精确
    tmp_pts = pts[inliers > 0]
    vx, vy, x, y = cv.fitLine(tmp_pts, cv.DIST_L2, 0, 0.01, 0.01)

    print("total_num = {}, inliers_num = {}".format(pts.shape[0], inliers_num))

    return vx, vy, x, y


src = None
xb = None
yb = None
vx = None
vy = None
draw_mode = 0


def fitline_callback(event, x, y, flags, param):
    global xb, yb, src, draw_mode, vx, vy

    if xb == x or yb == y:
        return

    # 模式0:画直线,模式1:橡皮擦,点一下擦个圆,模式2:直线拟合
    if event == cv.EVENT_MBUTTONDOWN:
        draw_mode += 1
        draw_mode %= 3
        return

    if draw_mode == 0:
        if event == cv.EVENT_LBUTTONDOWN:
            xb = x
            yb = y
        elif event == cv.EVENT_LBUTTONUP:
            cv.line(src, (xb, yb), (x, y), (255, 0, 0), 1, cv.LINE_AA)
            cv.imshow('src', src)
    elif draw_mode == 1:
        if event == cv.EVENT_LBUTTONDOWN:
            cv.circle(src, (x, y), 5, (0, 0, 0), -1, cv.LINE_AA)
            cv.imshow('src', src)
    elif draw_mode == 2:
        if event == cv.EVENT_LBUTTONDOWN:
            src_test = np.copy(src)
            gray = cv.cvtColor(src_test, cv.COLOR_BGR2GRAY)

            # 这边就是获得图上像素值不为0的点集坐标,为啥是::-1?因为行号是y坐标,列数是x坐标,得倒过来才行啦。
            # x,y是分开的,所以得用dstack把它们拼起来,这个跟torch.cat差不多哦,不明白的话可以看我的上一篇博客~~
            pts = np.dstack(np.where(gray > 0)[::-1]).reshape(-1, 2)

            start_time = time.time()
            vx, vy, x, y = cv.fitLine(pts, cv.DIST_L2, 0, 0.01, 0.01)
            end_time = time.time()
            print("耗时{}秒".format((end_time - start_time)))

            lefty = int((-x * vy / vx) + y)  # 这个就是x=0时,y是多少
            righty = int(((src.shape[1] - x) * vy / vx) + y)  # 这个是x等于图像宽度时,y是多少,参考我上面那个直线方程的图就明白了
            cv.line(src_test, (src.shape[1] - 1, righty), (0, lefty), (0, 255, 0), 1, cv.LINE_AA)
            cv.imshow('src_test', src_test)

            src_test = np.copy(src)

            start_time = time.time()
            vx, vy, x, y = fitline_ransac(pts, 2, 0.98, 0.6, 2)
            end_time = time.time()
            print("耗时{}秒".format((end_time - start_time)))

            lefty = int((-x * vy / vx) + y)
            righty = int(((src.shape[1] - x) * vy / vx) + y)
            cv.line(src_test, (src.shape[1] - 1, righty), (0, lefty), (0, 255, 0), 1, cv.LINE_AA)
            cv.imshow('src_test1', src_test)


def fitline_test():
    global src
    src = np.zeros((500, 800, 3), dtype=np.uint8)
    cv.namedWindow('src', cv.WINDOW_AUTOSIZE)
    cv.imshow('src', src)
    cv.setMouseCallback('src', fitline_callback, None)
    cv.waitKey(0)
    cv.destroyAllWindows()


if __name__ == '__main__':
    fitline_test()

时间不早了,求点到直线的垂线交点的就先不详述了,反正也不应该用它来求点到直线的距离啦,如果感兴趣的话,下次再水一期:求点到直线的垂线交点~~

猜你喜欢

转载自blog.csdn.net/ogebgvictor/article/details/125611005