MeanShift 目标跟踪

MeanShift算法,又称为均值漂移算法,采用基于颜色特征的核密度估计,寻找局部最优,使得跟踪过程中对目标旋转,小范围遮挡不敏感。

MeanShift 原理

MeanShift的本质是一个迭代的过程,在一组数据的密度分布中,使用无参密度估计寻找到局部极值(不需要事先知道样本数据的概率密度分布函数,完全依靠对样本点的计算)。

在d维空间中,任选一个点,然后以这个点为圆心,h为半径做一个高维球,因为有d维,d可能大于2,所以是高维球。落在这个球内的所有点和圆心都会产生一个向量,向量是以圆心为起点落在球内的点位终点。然后把这些向量都相加。相加的结果就是下图中黄色箭头表示的MeanShift向量:

然后,再以这个MeanShift 向量的终点为圆心,继续上述过程,又可以得到一个MeanShift 向量:

不断地重复这样的过程,可以得到一系列连续的MeanShift 向量,这些向量首尾相连,最终可以收敛到概率密度最大得地方(一个点):

从上述的过程可以看出,MeanShift 算法的过程就是:从起点开始,一步步到达样本特征点的密度中心。

MeanShift 跟踪步骤

1.获取待跟踪对象

获取初始目标框(RoI)位置信息(x,y,w,h),截取 RoI图像区域

# 初始化RoI位置信息 
track_window = (c,r,w,h) 
# 截取图片RoI
roi = img[r:r+h, c:c+w]

2.转换颜色空间

将BGR格式的RoI图像转换为HSV格式,对 HSV格式的图像进行滤波,去除低亮度和低饱和度的部分。

在 HSV 颜色空间中要比在 BGR 空间中更容易表示一个特定颜色。在 OpenCV 的 HSV 格式中,H(色度)的取值范围是 [0,179], S(饱和度)的取值范围 [0,255],V(亮度)的取值范围 [0,255]。

# 转换到HSV
hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
# 设定滤波的阀值
lower = np.array([0.,130.,32.])
upper = np.array([180.,255.,255.])
# 根据阀值构建掩模
mask = cv2.inRange(hsv,lower, upper)

3.获取色调统计直方图

# 获取色调直方图
roi_hist = cv2.calcHist([hsv_roi],[0],mask,[180],[0,180])
# 直方图归一化
cv2.normalize(roi_hist,roi_hist,0,180,cv2.NORM_MINMAX)

cv2.calcHist的原型为:

cv2.calcHist(images, channels, mask, histSize, ranges[, hist[, accumulate ]])  
  • images: 待统计的图像,必须用方括号括起来,

  • channels:用于计算直方图的通道,这里使用色度通道

  • mask:滤波掩模

  • histSize:表示这个直方图分成多少份(即多少个直方柱)

  • ranges:表示直方图中各个像素的值的范围

4.在新的一帧中寻找跟踪对象

# 读入目标图片
ret, frame = cap.read()
# 转换到HSV
hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
# 获取目标图片的反向投影
dst = cv2.calcBackProject([hsv],[0],roi_hist,[0,180],1)
# 定义迭代终止条件
term_crit = ( cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1 )
# 计算得到迭代次数和目标位置
ret, track_window = cv2.meanShift(dst, track_window, term_crit)

meanShift 函数原型

def meanShift(probImage, window, criteria)
  • probImage:输入反向投影直方图

  • window:需要移动的矩形(ROI)

  • criteria:对meanshift迭代过程进行控制的初始参量

其中,criteria参数如下:

  • type:判定迭代终止的条件类型:

    • COUNT:按最大迭代次数作为求解结束标志

    • EPS:按达到某个收敛的阈值作为求解结束标志

    • COUNT + EPS:两个条件达到一个就算结束

  • maxCount:具体的最大迭代的次数

  • epsilon:具体epsilon的收敛阈值

反向投影

反向投影图输出的是一张概率密度图,与输入图像大小相同,每一个像素值代表了输入图像上对应点属于目标对象的概率,像素点越亮,代表这个点属于目标物体的概率越大。

跟踪目标:

跟踪目标在下一帧中的反向投影:

MeanShift 跟踪器

import numpy as np
import cv2

class MeanShiftTracer:

    def __init__(self, id):
        # Stop criteria for the iterative search algorithm.
        self._term_crit = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1)
        self._roi_hist = None
        self.predict_count = 0
        self.frame = None
        self.frame_begin_id = id
        self.frame_end_id = id
        self.roi_xywh = None

    def _log_last_correct(self, frame, frame_id, xywh):
        x, y, w, h = xywh
        self.correct_box = (x, y, w, h)
        self.correct_img = frame[y:y + h, x:x + w]
        self.correct_id = frame_id

    def correct(self, frame, frame_id, xywh):
        self._log_last_correct(frame,frame_id, xywh)
        self._refresh_roi(frame, frame_id, xywh)
        self.predict_count = 0

    def predict(self, frame, frame_id):
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        dst = cv2.calcBackProject([hsv], [0], self._roi_hist, [0, 180], 1) 
        ret, track_window = cv2.meanShift(dst, self.roi_xywh, self._term_crit)

        self._refresh_roi(frame, frame_id, track_window)
        self.predict_count += 1
        return track_window

    def _refresh_roi(self, frame, frame_id, xywh):
        x, y, w, h = xywh
        roi = frame[y:y + h, x:x + w]
        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv_roi, np.array((0., 60., 32.)), np.array((180., 255., 255.)))
        roi_hist = cv2.calcHist([hsv_roi], [0], mask, [180], [0, 180])
        cv2.normalize(roi_hist, roi_hist, 0, 180, cv2.NORM_MINMAX)

        self.roi_xywh = (x, y, w, h)
        self._roi_hist = roi_hist
        self.frame = frame
        self.frame_end_id = frame_id

    def get_roi_info(self):
        return {'correct_box': self.correct_box,
                'correct_img': self.correct_img,
                'correct_id': self.correct_id,
                'beginId': self.frame_begin_id,
                'endId': self.frame_end_id}

跟踪管理器

import numpy as np
import cv2

class TracerManager:

    def __init__(self, image_shape, trace_tool, trace_margin, max_predict):
        """
        :param image_shape: (height,width)
        :param trace_tool: MeanShiftTracer
        :param trace_margin: (0,0,30,50)(px)(left,top,right,bottom)
        :param max_predict: 3 (times)
        """

        self._tracers = []
        self._trace_tool = trace_tool
        self._max_predict = max_predict
        self._image_shape = image_shape
        self.trace_margin = trace_margin

    def _calc_iou(self, A, B):
        """
        :param A: [x1, y1, x2, y2]
        :param B: [x1, y1, x2, y2]
        :return: IoU
        """

        IoU = 0
        iw = min(A[2], B[2]) - max(A[0], B[0])
        if iw > 0:
            ih = min(A[3], B[3]) - max(A[1], B[1])
            if ih > 0:
                A_area = (A[2] - A[0]) * (A[3] - A[1])
                B_area = (B[2] - B[0]) * (B[3] - B[1])
                uAB = float(A_area + B_area - iw * ih)
                IoU = iw * ih / uAB

        return IoU

    def box_in_margin(self, box):
        in_bottom = (self._image_shape[0] - (box[1] + box[3])) < self.trace_margin[3]
        in_right = (self._image_shape[1] - (box[0] + box[2])) < self.trace_margin[2]
        return in_bottom or in_right

    def _get_box_tracer_iou(self, A, B):
        a = (A[0], A[1], A[0] + A[2], A[1] + A[3])
        b = (B[0], B[1], B[0] + B[2], B[1] + B[3])
        return self._calc_iou(a, b)

    def _check_over_trace(self):
        remove_tracer = []
        trace_info = []

        for t in self._tracers:
            if t.predict_count > self._max_predict:
                remove_tracer.append(t)
                if t.frame_end_id != t.frame_begin_id:
                    trace_info.append(t.get_roi_info())

        for t in remove_tracer:
            self._tracers.remove(t)

        return trace_info

    def _get_tracer(self, box):
        tracer = None
        maxIoU = 0

        for t in self._tracers:
            iou = self._get_box_tracer_iou(box, t.roi_xywh)
            if iou > maxIoU:
                tracer = t
                maxIoU = iou

        return tracer

    def update_tracer(self, frame, frame_id, boxes):
        trace_info = self._check_over_trace()

        for box in boxes:
            if self.box_in_margin(box):
                continue

            tracer = self._get_tracer(box)

            if tracer is not None:
                tracer.correct(frame, frame_id, box)
            else:
                tracer = self._trace_tool(frame_id)
                tracer.correct(frame, frame_id, box)
                self._tracers.append(tracer)

        return trace_info

    def trace(self, frame, frame_id):
        track_windows = []

        for t in self._tracers:
            window = t.predict(frame, frame_id)
            track_windows.append(window)

        return track_windows

车辆监测与跟踪

检测与跟踪以1:1的比例交替进行。

import cv2
import numpy as np
import os.path
import Tracer

class car_detector:

    def __init__(self, cascade_file):
        if not os.path.isfile(cascade_file):
            raise RuntimeError("%s: not found" % cascade_file)
        self._cascade = cv2.CascadeClassifier(cascade_file)

    def _detect_cars(self, image):
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.equalizeHist(gray)
        cars = self._cascade.detectMultiScale(gray, scaleFactor=1.3, minNeighbors=15, minSize=(60, 60))
        return cars

    def _show_trace_object(self, infos):
        for info in infos:
            title = "%d - %d from frame: %d" % (info['beginId'], info['endId'], info['correct_id'])
            cv2.imshow(title, info['correct_img']) 
            cv2.waitKey(1)

    def _get_area_invalid_mark(self, img_shape, margin):
        area = np.zeros(img_shape,np.uint8)
        h, w = img_shape[:2]
        disable_bg_color = (0, 0, 80)
        disable_fg_color = (0, 0, 255)

        cv2.rectangle(area, (0, h-margin[3]), (w, h), disable_bg_color, -1)
        cv2.putText(area, "Invalid Region", (w-220, h-20), cv2.FONT_HERSHEY_SIMPLEX, 1, disable_fg_color, 2)
        return area

    def _show_trace_state(self, image, id, tracer, state, boxes, mark):
        image = cv2.addWeighted(mark, 0.5, image, 1, 0)
        title = 'frame : %s [%s]' % (state, id)
        colors = {'detect': (0, 255, 0), 'trace': (255, 255, 0), 'invalid': (150, 150, 150), 'title_bg': (0, 0, 0)}

        for (x, y, w, h) in boxes:
            if tracer.box_in_margin((x, y, w, h)):
                cv2.rectangle(image, (x, y), (x + w, y + h),colors['invalid'], 2)
            else:
                cv2.rectangle(image, (x, y), (x + w, y + h), colors[state], 2)

        cv2.rectangle(image, (10, 20), (250, 50), colors['title_bg'], -1)
        cv2.putText(image, title, (30, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.6, colors[state],2)
        cv2.imshow("result", image)
        cv2.waitKey(1)

    def trace_detect_video(self, video_path, trace_rate = 1):
        cap = cv2.VideoCapture(video_path)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 

        start_frame = 0
        invalid_margin = (0, 0, 0, 100)
        mark = self._get_area_invalid_mark((h, w, 3), invalid_margin)
        tracer = Tracer.TracerManager((h, w), Tracer.MeanShiftTracer, invalid_margin, trace_rate + 5)
        warm = False
        while True:
            ret, image = cap.read()
            start_frame += 1
            if not ret: return
            result = image.copy()

            if not warm or start_frame % (trace_rate + 1) == 0:
                warm = True
                cars = self._detect_cars(image)
                self._show_trace_state(result, start_frame, tracer, 'detect', cars, mark)

                trace_obj = tracer.update_tracer(image, start_frame, cars)
                self._show_trace_object(trace_obj)
            else:
                cars = tracer.trace(image, start_frame)
                self._show_trace_state(result, start_frame, tracer, 'trace', cars, mark)


if __name__ == "__main__":

    car_cascade_lbp_21 = './train/cascade_lbp_21/cascade.xml'
    video_path = "./test.mp4"

    detect = car_detector(car_cascade_lbp_21)
    detect.trace_detect_video(video_path)

MeanShift 算法的优缺点

优点:

  • 算法计算量不大,在目标区域已知的情况下完全可以做到实时跟踪;

  • 采用核函数直方图模型,对边缘遮挡、目标旋转、变形和背景运动不敏感。

缺点:

  • 跟踪过程中由于窗口宽度大小保持不变,框出的区域不会随着目标的扩大(或缩小)而扩大(或缩小);

  • 当目标速度较快时,跟踪效果不好;

  • 直方图特征在目标颜色特征描述方面略显匮乏,缺少空间信息;

猜你喜欢

转载自blog.csdn.net/lk3030/article/details/84108765
今日推荐