第五章 TensorFlow Object Detection API+双目视觉=实时物体定位+检测

本章是第三章与第四章的结合，旨在实现对物体的实时定位与检测。

直接上代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 16 17:28:57 2018

@author: bjw
"""
import os
import cv2
import sys
#import time
#import argparse
#import multiprocessing
import numpy as np
import tensorflow as tf
#from matplotlib import pyplot as plt
import camera_configs

# Enable inline matplotlib only when running inside IPython (Jupyter/Spyder).
# Under a plain `python` interpreter get_ipython() returns None (and IPython
# may not be installed at all), which would otherwise crash the script here.
try:
    from IPython import get_ipython
    _ipython = get_ipython()
except ImportError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')

# Presumably lets the object_detection package one directory up be imported.
sys.path.append("..")

# Only the disparity window ("depth") is created. The "left"/"right" preview
# windows are disabled, so they are not positioned either; to re-enable them,
# create them with cv2.namedWindow("left") / cv2.namedWindow("right") and place
# them with cv2.moveWindow("left", 0, 0) / cv2.moveWindow("right", 640, 0).
cv2.namedWindow("depth")
# Trackbars polled by the capture loop for the StereoBM parameters:
# "num" (numDisparities = 16 * num) and "blockSize".
cv2.createTrackbar("num", "depth", 2, 10, lambda x: None)
cv2.createTrackbar("blockSize", "depth", 5, 255, lambda x: None)

# Mouse handler for the "depth" window: a left click prints the reprojected
# 3-D point under the cursor (its Z component is the distance at that pixel).
def callbackFunc(e, x, y, f, p):
    """Print the module-level ``threeD`` entry at pixel (x, y) on a left click.

    Parameters follow OpenCV's mouse-callback signature: ``e`` is the event
    type, ``x``/``y`` the pixel position, ``f`` the event flags and ``p`` the
    user data (unused here).
    """
    if e != cv2.EVENT_LBUTTONDOWN:
        return
    point = threeD[y][x]
    print(point)

# Route mouse events on the disparity window to callbackFunc.
cv2.setMouseCallback("depth", callbackFunc, None)

# Open video device 0; the capture loop splits each frame into left/right views.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Without this check the read loop simply stops on the first failed read
    # and the script ends with no indication of what went wrong.
    sys.exit("Could not open video device 0 (stereo camera).")

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# Model and label-map files are resolved relative to the working directory.
CWD_PATH = os.getcwd()
# Frozen SSD-MobileNet v1 (COCO) inference graph.
PATH_TO_CKPT = os.path.join(
    CWD_PATH, 'ssd_mobilenet_v1_coco_2017_11_17', 'frozen_inference_graph.pb')

# Label map used to attach a readable class name to each detection box.
PATH_TO_LABELS = os.path.join(CWD_PATH, 'data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90

# Build the {class_id: category} lookup consumed by the visualizer.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=NUM_CLASSES,
    use_display_name=True,
)
category_index = label_map_util.create_category_index(categories)

# Deserialize the frozen GraphDef and import it, unprefixed, into its own graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as model_file:
        od_graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(od_graph_def, name='')

def _annotate_depths(image, boxes, scores, points_3d, max_boxes, min_score):
    """Draw the reprojected Z value at the centre of each confident detection.

    Args:
        image: BGR image to draw on (modified in place).
        boxes: detector boxes as normalized [ymin, xmin, ymax, xmax] rows.
        scores: confidence for each row of ``boxes``.
        points_3d: per-pixel (X, Y, Z) array from cv2.reprojectImageTo3D for
            the same view as ``image``.
        max_boxes: only the first ``max_boxes`` detections are considered.
        min_score: detections scoring at or below this value are skipped.
    """
    height, width = image.shape[:2]
    for box, score in zip(boxes[:max_boxes], scores[:max_boxes]):
        if score <= min_score:
            continue
        ymin, xmin, ymax, xmax = box
        # Box centre in pixels, clamped so a box touching the bottom/right
        # border cannot index one past the last row/column.
        y = min(int((ymin + ymax) / 2 * height), height - 1)
        x = min(int((xmin + xmax) / 2 * width), width - 1)
        depth = points_3d[y][x]
        # NOTE(review): the "mm" suffix assumes camera_configs.Q comes from a
        # calibration done in millimetres -- confirm.
        cv2.putText(image, str(depth[2]) + "mm", (x, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2,
                    cv2.LINE_AA)


# Files written when 's' is pressed: the grayscale rectified pair and the
# 8-bit disparity map, saved in the current working directory.
path_BM_left = os.path.join(os.getcwd(), 'BM_left.png')
path_BM_right = os.path.join(os.getcwd(), 'BM_right.png')
path_BM_depth = os.path.join(os.getcwd(), 'BM_depth.png')

# Shared by the depth overlay and vis_util so both annotate the same boxes.
MAX_BOXES_TO_DRAW = 20
MIN_SCORE_THRESH = 0.5

# The StereoBM block size used to be hard-coded to 31 while the trackbar was
# ignored; start the trackbar at 31 so the default output is unchanged and
# the slider now actually takes effect.
cv2.setTrackbarPos("blockSize", "depth", 31)

try:
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Look the graph's input/output tensors up once, not every frame.
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box is one detected object; each score is its confidence.
            boxes_t = detection_graph.get_tensor_by_name('detection_boxes:0')
            scores_t = detection_graph.get_tensor_by_name('detection_scores:0')
            classes_t = detection_graph.get_tensor_by_name('detection_classes:0')

            while True:
                ret1, frame = cap.read()
                if not ret1:
                    break

                # Left and right views arrive side by side: shrink the frame
                # to 1280x360 and split it into two 640x360 halves.
                imagedst = cv2.resize(frame, (1280, 360),
                                      interpolation=cv2.INTER_LINEAR)
                frame1 = imagedst[0:360, 0:640]
                frame2 = imagedst[0:360, 640:1280]

                # Rectify each view with the precomputed calibration maps.
                img1_rectified = cv2.remap(frame1, camera_configs.left_map1,
                                           camera_configs.left_map2,
                                           cv2.INTER_LINEAR)
                img2_rectified = cv2.remap(frame2, camera_configs.right_map1,
                                           camera_configs.right_map2,
                                           cv2.INTER_LINEAR)

                # StereoBM works on single-channel images.
                imgL = cv2.cvtColor(img1_rectified, cv2.COLOR_BGR2GRAY)
                imgR = cv2.cvtColor(img2_rectified, cv2.COLOR_BGR2GRAY)

                # numDisparities must be a positive multiple of 16, so a "num"
                # of 0 is raised to 1; blockSize must be odd and at least 5.
                num = max(cv2.getTrackbarPos("num", "depth"), 1)
                blockSize = cv2.getTrackbarPos("blockSize", "depth")
                if blockSize % 2 == 0:
                    blockSize += 1
                blockSize = max(blockSize, 5)

                # Block-matching disparity (StereoSGBM is an alternative).
                stereo = cv2.StereoBM_create(numDisparities=16 * num,
                                             blockSize=blockSize)
                disparity = stereo.compute(imgL, imgR)

                # 8-bit copy for display; dst=None so `disparity` is never
                # used as the output buffer of a different dtype.
                disp = cv2.normalize(disparity, None, alpha=0, beta=255,
                                     norm_type=cv2.NORM_MINMAX,
                                     dtype=cv2.CV_8U)
                # StereoBM disparities are fixed-point (scaled by 16); undo
                # that before reprojecting. Z of each point is the distance.
                # threeD stays a module-level name because callbackFunc reads it.
                threeD = cv2.reprojectImageTo3D(
                    disparity.astype(np.float32) / 16., camera_configs.Q)
                cv2.imshow("depth", disp)

                # The detector takes a [1, H, W, 3] batch. OpenCV frames are
                # BGR; the model presumably expects RGB (the TF Object
                # Detection API tutorial feeds RGB images), so convert for
                # inference only and keep drawing on the BGR image.
                rgb = cv2.cvtColor(img1_rectified, cv2.COLOR_BGR2RGB)
                image_np_expanded = np.expand_dims(rgb, axis=0)
                boxes, scores, classes = sess.run(
                    [boxes_t, scores_t, classes_t],
                    feed_dict={image_tensor: image_np_expanded})

                final_box = np.squeeze(boxes)
                final_score = np.squeeze(scores)
                _annotate_depths(img1_rectified, final_box, final_score,
                                 threeD, MAX_BOXES_TO_DRAW, MIN_SCORE_THRESH)

                # Draw boxes and class labels on the left view.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    img1_rectified,
                    final_box,
                    np.squeeze(classes).astype(np.int32),
                    final_score,
                    category_index,
                    use_normalized_coordinates=True,
                    max_boxes_to_draw=MAX_BOXES_TO_DRAW,
                    min_score_thresh=MIN_SCORE_THRESH,
                    line_thickness=8)
                cv2.imshow('object detection',
                           cv2.resize(img1_rectified, (640, 360)))

                # One event pump per frame so a key press cannot be consumed
                # by a second waitKey call. 'q' quits, 's' saves snapshots.
                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                if key == ord("s"):
                    cv2.imwrite(path_BM_left, imgL)
                    cv2.imwrite(path_BM_right, imgR)
                    cv2.imwrite(path_BM_depth, disp)
finally:
    # Release the camera and close windows even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()

测试结果:

因为是在晚上拍摄的，视差图效果不太好。

猜你喜欢

转载自blog.csdn.net/hunzhangzui9837/article/details/82860134