Automatically tracking a moving target with Python and OpenCV (DaSiamRPN)

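1. The code has a few bugs. In particular, watch the return value of `cv2.findContours`: it returns three values (image, contours, hierarchy) in OpenCV 3.x but only two (contours, hierarchy) in OpenCV 2.x and 4.x.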

 

2. The program uses the DaSiamRPN tracking network; the tracking quality is good and the frame rate (fps) is high.

 

3. The code is as follows:

'''
Read video frames, detect moving foreground targets, and track them with DaSiamRPN.
'''
import cv2
import torch
import numpy as np
from os.path import realpath, dirname, join
from net import SiamRPNvot
from run_SiamRPN import SiamRPN_init, SiamRPN_track
from utils import get_axis_aligned_bbox, cxy_wh_2_rect



# Build the DaSiamRPN (VOT variant) network, load the pretrained weights file
# 'SiamRPNVOT.model' that sits next to this script, then put the network in
# eval mode and move it to the GPU.
net = SiamRPNvot()
net.load_state_dict(
    torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model'))
)
net.eval()
net.cuda()


def _largest_contour_box(fgmask):
    """Return the bounding box (x, y, w, h) of the largest contour in *fgmask*, or None."""
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 2.x/4.x; [-2] selects the contour list in all of them.
    contours = cv2.findContours(fgmask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return None
    # Compare by enclosed area: with CHAIN_APPROX_SIMPLE the number of stored
    # points says little about size (a large rectangle compresses to 4 points).
    return cv2.boundingRect(max(contours, key=cv2.contourArea))


def videoTrack(source=0, min_area=400):
    """Detect a moving foreground object and follow it with DaSiamRPN.

    Frames are read from ``source``. A MOG2 background subtractor highlights
    moving regions. When tracking is requested, the largest foreground contour
    whose bounding box exceeds ``min_area`` pixels seeds the DaSiamRPN
    tracker. The tracker then updates the target every frame, and its box is
    drawn in the "track" window.

    Keyboard controls (the "track" window must have focus):
        S / s        start tracking the largest moving object
        R / r        drop the current track and acquire a new target
        P / p        stop tracking
        Esc / Q / q  quit

    Args:
        source: Anything accepted by ``cv2.VideoCapture`` (a camera index or a
            video file path). Defaults to camera 0, as before.
        min_area: A contour's bounding box must be strictly larger than this
            area (in pixels) to be used as the initial target.
    """
    cap = cv2.VideoCapture(source)
    fgbg = cv2.createBackgroundSubtractorMOG2()
    # Loop-invariant: build the erosion kernel once rather than every frame.
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

    startTrack = False
    restartTrack = False
    stopTrack = False
    isTracking = False
    state = None

    try:
        ret, frame = cap.read()
        while ret:
            # Foreground detection, then erosion to remove speckle noise.
            fgmask = fgbg.apply(frame)
            fgmask = cv2.erode(fgmask, element)

            if startTrack:
                box = _largest_contour_box(fgmask)
                if box is not None:
                    x, y, w, h = box
                    if w * h > min_area:
                        # SiamRPN_init / cxy_wh_2_rect use the target *centre*,
                        # not the top-left corner of the bounding box.
                        target_pos = np.array([x + w / 2, y + h / 2])
                        target_sz = np.array([w, h])
                        state = SiamRPN_init(frame, target_pos, target_sz, net)
                        isTracking = True
                        startTrack = False

            if isTracking:
                state = SiamRPN_track(state, frame)
                res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
                res = [int(v) for v in res]
                cv2.rectangle(frame, (res[0], res[1]),
                              (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3)

            # Restart: on the first frame drop the current track; on the next
            # frame (no longer tracking) re-arm target acquisition.
            if restartTrack:
                if isTracking:
                    isTracking = False
                else:
                    startTrack = True
                    restartTrack = False

            if stopTrack:
                isTracking = False
                stopTrack = False

            cv2.imshow("track", frame)

            # Keep only the low byte so the comparison does not depend on extra
            # bits some platforms/backends set; accept both letter cases.
            key = cv2.waitKey(10) & 0xFF
            if key in (ord('S'), ord('s')):  # S: start tracking
                print("------------开始跟踪-----------------")
                startTrack = True
            elif key in (ord('R'), ord('r')):  # R: restart tracking
                print("------------重新跟踪-----------------")
                restartTrack = True
            elif key in (ord('P'), ord('p')):  # P: stop tracking
                print("------------停止跟踪-----------------")
                stopTrack = True
            elif key in (27, ord('Q'), ord('q')):  # Esc / Q: quit
                break

            ret, frame = cap.read()
    finally:
        # Always release the capture device and close the display window.
        cap.release()
        cv2.destroyAllWindows()










def _main():
    """Script entry point: run the interactive tracker with its default settings."""
    videoTrack()


if __name__ == '__main__':
    _main()

 

Author stats: 7 original articles · 4 likes · 4,303 views

Guess you like

Source: blog.csdn.net/fan_nlnl/article/details/85245992