TensorFlow目标检测——通过鼠标点击图片即输出label

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/luoxueqian/article/details/82705521

AI最常见是应用是在图像识别上,即输入一张图像,输出该图像对应的类别。

Google开源了一些TensorFlow Object Detection API,下面以ssd_mobilenet_v1_coco_2017_11_17模型为例,介绍如何在TensorFlow中进行目标检测。

目标检测的代码结构如下:

目标检测的数据集是COCO数据集,包含90种可检测的目标,存储在mscoco_label_map.pbtxt文件中

打开如下:

首先打开模型及label文件

# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'   #使用模型
MODEL_FILE = MODEL_NAME + '.tar.gz'
#DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'     模型下载地址

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90    #识别类型个数

#opener = urllib.request.URLopener()
#opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())      #.pb文件为训练出的模型

#load graph
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)   #PATH_TO_LABELS文件中包含90种标签
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)   #获取90种可识别的种类
category_index = label_map_util.create_category_index(categories)   #包含90个字典1:{'id':1, 'name':'dog'}

定义函数load_image_into_numpy_array,将图片转换成矩阵

#将图片转换成矩阵
def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size    #获取输入图片的大小,即长宽
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)    #图片转换成矩阵

定义函数verification_box_label,用来分辨鼠标所点击的坐标处属于哪种类型

扫描二维码关注公众号,回复: 3569735 查看本文章
#分别输入鼠标坐标,图片中box的坐标(有多少目标会生成多少个box)及box所对应的标签
#此函数只考虑box不相交的情况
#mouse_position[0]为鼠标点击处x坐标,mouse_position[1]为鼠标点击处y坐标
def verification_box_label(mouse_position, boxes_position, class_label):
    boxes_num = len(boxes_position)    #计算box的个数
    #print(boxes_position)
    #print(boxes_position[0][0])
    flage = 1
    class_null = "no result"
    for i in range(boxes_num):    #用循环计算坐标点属于哪个box
        if (mouse_position[0] >= boxes_position[i][0] and mouse_position[0] <= boxes_position[i][1] and mouse_position[1] >= boxes_position[i][2] and mouse_position[1] <= boxes_position[i][3]):
            flage = 0
            class_out = class_label[i]
    if flage:
        return class_null
    else:
        return class_out

文件mouse_location.py,实现点击鼠标,输出鼠标所在处的坐标

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

#im = Image.open("result.jpg")
#plt.imshow(im)
#pos=plt.ginput(1)
#print(pos)
#x_position = pos[0][0]
#y_position = pos[0][1]
#print(pos[0][0])
#print(pos[0][1])
#box_position = np.zeros((2))

def mouse_event(event, x, y, flags, param):
    global box_position
    # 通过event判断具体是什么事件,这里是左键按下
    font = cv2.FONT_HERSHEY_SIMPLEX  # 定义字体
    if event == cv2.EVENT_LBUTTONDOWN:
        print((x, y))
        box_position = np.array([x, y])

def box(img):
    cv2.namedWindow("img")          #构建窗口
    cv2.setMouseCallback("img", mouse_event)   #回调绑定窗口
    cv2.imshow("img",img)        #显示图像
    #if cv2.EVENT_RBUTTONDOWN:
        #cv2.waitKey(1000)
    cv2.waitKey(0)  # 等待鼠标操作
    #print(box_position)
    cv2.destroyAllWindows()         #关闭窗口

实现目标识别

PATH_TO_TEST_IMAGES_DIR = 'test_images'
img_count = 0
ls = os.listdir(PATH_TO_TEST_IMAGES_DIR)
for i_num in ls:
    if os.path.isfile(os.path.join(PATH_TO_TEST_IMAGES_DIR, i_num)):
        img_count += 1  # 统计文件夹中图片的张数

#测试图片路径
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, img_count+1) ]

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

#run
with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Definite input and output Tensors for detection_graph
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.方框
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represent how level of confidence for each of the objects.百分数
    # Score is shown on the result image, together with the class label.识别出来的类别
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    for image_path in TEST_IMAGE_PATHS:
      image = Image.open(image_path)
      # the array based representation of the image will be used later in order to prepare the
      # result image with boxes and labels on it.调用函数,将图片转换成矩阵
      image_np = load_image_into_numpy_array(image)    #矩阵形式的图片
      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
      image_np_expanded = np.expand_dims(image_np, axis=0)
      # Actual detection.
      (boxes, scores, classes, num) = sess.run(
          [detection_boxes, detection_scores, detection_classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})
      # Visualization of the results of a detection.
      #box_location = visualization_utils.box_size
      image_box, box_location, class_nn = visualization_utils.visualize_boxes_and_labels_on_image_array(
          image_np,
          np.squeeze(boxes),
          np.squeeze(classes).astype(np.int32),
          np.squeeze(scores),
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8)
      #print(class_nn)   #分别表示left, right, top, bottom
      (r, g, b)=cv2.split(image_box)
      image_box=cv2.merge([b,g,r])
      mouse_location.box(image_box)
      x_y_position = mouse_location.box_position   #表示鼠标的横纵坐标
      #调用函数识别鼠标点击坐标所属类别
      mouse_position_label = verification_box_label(x_y_position, box_location, class_nn)
      print(mouse_position_label)
      #cv2.imshow('capture.jpg', image_box)
      #cv2.waitKey(0)    #等待鼠标操作
      #cv2.imwrite('result.jpg', image_box)
cv2.destroyAllWindows()

检测结果:

点击猫所在的方框,输出结果为鼠标坐标及label

ssd模型运行速度快,但对于一张图上多种目标,可能会出现漏识别的现象

下一篇介绍ImageNet图像识别模型

猜你喜欢

转载自blog.csdn.net/luoxueqian/article/details/82705521