第三章 利用TensorFlow Object Detection API实现摄像头实时物体检测

通过第二章,我们已经在Ubuntu 16.04上利用Google的TensorFlow Object Detection API实现了对图片中物体的检测。本章在此基础上修改代码,捕捉摄像头视频流,并对视频流进行实时物体检测。

1、安装opencv-python

安装opencv直接在终端执行:

sudo pip install opencv-python

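查看opencv-python的版本,执行:

python -c "import cv2; print(cv2.__version__)"

(注:pkg-config --modversion opencv 查询的是系统中通过源码或apt安装的OpenCV,并不能反映通过pip安装的opencv-python的版本。)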

下图表示安装成功,我的版本是3.3.1。

下面写一个测试代码实现摄像头实时捕捉: 

import cv2
import numpy as np

# Minimal webcam test: show live frames from the default camera until 'q' is pressed.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Fail fast with a clear message instead of crashing later on a None frame.
    raise IOError("Cannot open camera 0")
try:
    while True:
        # Grab one frame; ret is False when no frame could be read
        # (camera unplugged, busy, or end of stream).
        ret, frame = cap.read()
        if not ret:
            break
        # Show the frame.
        cv2.imshow("capture", frame)
        # waitKey also services the GUI event loop; press 'q' to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close windows, even on error or interrupt.
    cap.release()
    cv2.destroyAllWindows()

同样在Jupyter中点击Run All,或者按Shift+Enter运行,就会显示出摄像头视频流;按q键退出。

2、修改代码实现实时物体检测 

代码根据research/object_detection目录下的object_detection_tutorial.ipynb修改而来。其中未标注“代码不变”的步骤需要修改,即下面的(2)、(4)、(7)(原文以红色字体标出)。

(1)导入各种包,代码不变

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

# The code in this tutorial uses TensorFlow 1.x graph-mode APIs
# (tf.GraphDef, tf.gfile, tf.Session), which need TF >= 1.9 and are not
# available at the top level of TF 2.x.
# Compare only the leading major/minor numbers: StrictVersion raises
# ValueError on release-candidate strings such as "1.13.0-rc1".
import re
_tf_major_minor = tuple(int(part) for part in re.findall(r'\d+', tf.__version__)[:2])
if _tf_major_minor < (1, 9):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
if _tf_major_minor >= (2, 0):
    raise ImportError('This code uses TensorFlow 1.x APIs; please install TensorFlow 1.x (v1.9.* or later).')

(2)增加导入cv包,以及获取摄像头设备号

import cv2

# Open the default camera (device index 0). Fail fast if it is unavailable,
# instead of crashing later when cap.read() returns no frame.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise IOError("Cannot open camera 0")

 (3)从utils模块引入label_map_util和visualization_utils

这两个包很关键,label_map_util用于后面获取图像标签和类别,visualization_utils用于可视化。代码不变

from utils import label_map_util
from utils import visualization_utils as vis_util

(4)获取预训练模型

模型下载地址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

里面包含很多预训练好的模型,并给出了各个模型的速度(speed)和精度(mAP)作为参考;本文使用的是在COCO数据集(90个类别)上训练的模型。

我选择了速度较快的ssd_mobilenet_v1_coco ,下载后解压到research/object_detection路径。

然后添加模型路径:

# All paths are resolved against the current working directory, which is
# expected to be research/object_detection (where the model was unpacked).
CWD_PATH = os.getcwd()

# Frozen inference graph of the downloaded ssd_mobilenet_v1_coco model.
PATH_TO_CKPT = os.path.join(
    CWD_PATH,
    'ssd_mobilenet_v1_coco_2017_11_17',
    'frozen_inference_graph.pb',
)

# Label map file used to attach the correct label string to each box.
PATH_TO_LABELS = os.path.join(CWD_PATH, 'data', 'mscoco_label_map.pbtxt')

# Number of COCO classes the label map is truncated to.
NUM_CLASSES = 90

(5)加载模型,代码不变

# Load the frozen model (a serialized GraphDef protobuf) into its own tf.Graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    # Read the raw protobuf bytes; the file is closed before parsing.
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
    od_graph_def = tf.GraphDef()
    od_graph_def.ParseFromString(serialized_graph)
    # name='' imports without a name prefix, so tensors keep their original
    # names (e.g. 'image_tensor:0') for the later get_tensor_by_name calls.
    tf.import_graph_def(od_graph_def, name='')

(6)加载label map,代码不变

# Parse the label map file, keep at most NUM_CLASSES entries (using their
# display names), and index them for the visualizer, which uses
# category_index to turn predicted class ids into labels.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=NUM_CLASSES,
    use_display_name=True,
)
category_index = label_map_util.create_category_index(categories)

 (7)最后核心代码

# Run the detector on every camera frame until 'q' is pressed or the camera
# stops delivering frames. The camera and windows are always released.
try:
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Tensor handles do not change between frames, so look them up once.
            # Input: the model expects images of shape [1, None, None, 3].
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where an object was detected.
            boxes_t = detection_graph.get_tensor_by_name('detection_boxes:0')
            # Confidence of each detection; shown on the image with the class label.
            scores_t = detection_graph.get_tensor_by_name('detection_scores:0')
            classes_t = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections_t = detection_graph.get_tensor_by_name('num_detections:0')
            while True:
                ret, image_np = cap.read()
                if not ret:
                    # No frame (camera closed/unplugged): stop instead of
                    # feeding None to the model.
                    break
                # OpenCV delivers BGR frames; the detection models are trained
                # on RGB images, so convert before inference.
                image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
                # Add the batch dimension: [H, W, 3] -> [1, H, W, 3].
                image_np_expanded = np.expand_dims(image_rgb, axis=0)
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes_t, scores_t, classes_t, num_detections_t],
                    feed_dict={image_tensor: image_np_expanded})
                # Draw boxes and labels in place on the BGR frame used for display.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

                cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    break
finally:
    cap.release()
    cv2.destroyAllWindows()

3、测试

附完整代码:

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

# The code in this tutorial uses TensorFlow 1.x graph-mode APIs
# (tf.GraphDef, tf.gfile, tf.Session), which need TF >= 1.9 and are not
# available at the top level of TF 2.x.
# Compare only the leading major/minor numbers: StrictVersion raises
# ValueError on release-candidate strings such as "1.13.0-rc1".
import re
_tf_major_minor = tuple(int(part) for part in re.findall(r'\d+', tf.__version__)[:2])
if _tf_major_minor < (1, 9):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
if _tf_major_minor >= (2, 0):
    raise ImportError('This code uses TensorFlow 1.x APIs; please install TensorFlow 1.x (v1.9.* or later).')
import cv2

# Open the default camera (device index 0). Fail fast if it is unavailable,
# instead of crashing later when cap.read() returns no frame.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise IOError("Cannot open camera 0")
from utils import label_map_util
from utils import visualization_utils as vis_util
# All paths are resolved against the current working directory, which is
# expected to be research/object_detection (where the model was unpacked).
CWD_PATH = os.getcwd()

# Frozen inference graph of the downloaded ssd_mobilenet_v1_coco model.
PATH_TO_CKPT = os.path.join(
    CWD_PATH,
    'ssd_mobilenet_v1_coco_2017_11_17',
    'frozen_inference_graph.pb',
)

# Label map file used to attach the correct label string to each box.
PATH_TO_LABELS = os.path.join(CWD_PATH, 'data', 'mscoco_label_map.pbtxt')

# Number of COCO classes the label map is truncated to.
NUM_CLASSES = 90
# Load the frozen model (a serialized GraphDef protobuf) into its own tf.Graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    # Read the raw protobuf bytes; the file is closed before parsing.
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
    od_graph_def = tf.GraphDef()
    od_graph_def.ParseFromString(serialized_graph)
    # name='' imports without a name prefix, so tensors keep their original
    # names (e.g. 'image_tensor:0') for the later get_tensor_by_name calls.
    tf.import_graph_def(od_graph_def, name='')
# Parse the label map file, keep at most NUM_CLASSES entries (using their
# display names), and index them for the visualizer, which uses
# category_index to turn predicted class ids into labels.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=NUM_CLASSES,
    use_display_name=True,
)
category_index = label_map_util.create_category_index(categories)
# Run the detector on every camera frame until 'q' is pressed or the camera
# stops delivering frames. The camera and windows are always released.
try:
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Tensor handles do not change between frames, so look them up once.
            # Input: the model expects images of shape [1, None, None, 3].
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where an object was detected.
            boxes_t = detection_graph.get_tensor_by_name('detection_boxes:0')
            # Confidence of each detection; shown on the image with the class label.
            scores_t = detection_graph.get_tensor_by_name('detection_scores:0')
            classes_t = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections_t = detection_graph.get_tensor_by_name('num_detections:0')
            while True:
                ret, image_np = cap.read()
                if not ret:
                    # No frame (camera closed/unplugged): stop instead of
                    # feeding None to the model.
                    break
                # OpenCV delivers BGR frames; the detection models are trained
                # on RGB images, so convert before inference.
                image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
                # Add the batch dimension: [H, W, 3] -> [1, H, W, 3].
                image_np_expanded = np.expand_dims(image_rgb, axis=0)
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes_t, scores_t, classes_t, num_detections_t],
                    feed_dict={image_tensor: image_np_expanded})
                # Draw boxes and labels in place on the BGR frame used for display.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

                cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    break
finally:
    cap.release()
    cv2.destroyAllWindows()

 结果图:

猜你喜欢

转载自blog.csdn.net/hunzhangzui9837/article/details/82857743