Object detection with TensorFlow

Contents

  • Background of the project
  • Introduction to TensorFlow
  • Environment setup
  • Model selection
  • API usage instructions
  • Running the route
  • Summary

Background of the project

When our product manager saw that a competing product could label the objects in an image, they followed the usual principle of "whatever others have, we must have; whatever others lack, we should have all the more". They sent me a link to a website and said the feature was very simple and could definitely be implemented.

This is where the all-powerful Google came in: searching through the vast sea of information, I found TensorFlow, a machine learning framework that is currently very popular for deep learning (artificial intelligence).

Introduction to TensorFlow

Encyclopedia introduction: TensorFlow is the second-generation artificial intelligence learning system developed by Google based on DistBelief. It can be used in many machine learning and deep learning fields such as speech recognition or image recognition.

In plain terms: it is a deep learning and neural network framework whose core is written in C++ and driven from Python. It also supports Go, Java, and other languages.

Environment setup

  • Unix/Linux (I use a Mac)
  • Python 3.6
  • protoc 3.5.1
  • TensorFlow 1.7.0

1. Clone the repository

git clone https://github.com/guandeng/tensorflow.git

The directory layout is as follows:

└── tensorflow
    ├── Dockerfile
    ├── README.md
    ├── data
    │   ├── models
    │   ├── pbtxt
    │   └── tf_models
    ├── object_detection_api.py
    ├── server.py
    ├── sh
    │   ├── download_data.sh
    │   └── ods.sh
    ├── static
    ├── templates
    └── upload

  • data/models stores the detection models
  • data/pbtxt stores the label maps (object class names)
  • data/tf_models stores the tensorflow/models repository

2. Install the dependencies

pip3 install -r requirements.txt

3. Download the model

sh sh/download_data.sh

4. Add the research directory to the PYTHONPATH environment variable (run this from the repository root)

echo "export PYTHONPATH=\$PYTHONPATH:$(pwd)/data/tf_models/models/research" >> ~/.bashrc && source ~/.bashrc

5. Start the service

python3 server.py

If no error is reported, the environment has been set up successfully. Simple, isn't it? The following sections describe how the code calls fit together.
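
If the server does fail to start, a quick first check is whether the packages from steps 2 to 4 can be found from Python at all. The following is a minimal sketch and not part of the repository: the helper name `check_environment` is hypothetical, and the package names are simply the ones imported later in this article.

```python
import importlib.util
import sys


def check_environment(packages=("tensorflow", "object_detection", "PIL", "numpy")):
    """Return a dict mapping each package name to True if it can be imported."""
    status = {}
    for name in packages:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # find_spec can raise for broken or half-initialised modules
            status[name] = False
    return status


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    for name, found in check_environment().items():
        print("%s: %s" % (name, "ok" if found else "MISSING"))
```

If `object_detection` is reported as missing, the PYTHONPATH entry from step 4 is the first thing to look at.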

Model selection

I selected several of Google's pre-trained models for comparison:

Model name                                     Speed (ms)   COCO mAP
ssd_mobilenet_v1_coco                          30           21
ssd_mobilenet_v2_coco                          31           22
ssd_inception_v2_coco                          42           24
faster_rcnn_inception_resnet_v2_atrous_coco    620          37

  • Speed is the time the model takes to recognize the objects in an image; the smaller the value, the faster the recognition
  • mAP (mean average precision) measures how well the predicted classes and bounding boxes match the ground truth. The higher the value, the more accurate the network, and usually the larger (slower) the corresponding Speed value.
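
The bounding-box part of mAP rests on intersection over union (IoU): a predicted box is matched against a ground-truth box, and the prediction typically counts as correct only when the IoU reaches a chosen threshold (0.5 is a common choice). A minimal sketch of IoU, assuming boxes in the same (y1, x1, y2, x2) layout used later in this article; the function name `iou` is illustrative.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (y1, x1, y2, x2)."""
    ya1, xa1, ya2, xa2 = box_a
    yb1, xb1, yb2, xb2 = box_b
    # Height and width of the overlapping region (zero if the boxes are disjoint)
    inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))
    inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))
    inter = inter_h * inter_w
    area_a = (ya2 - ya1) * (xa2 - xa1)
    area_b = (yb2 - yb1) * (xb2 - xb1)
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


print(iou((0, 0, 1, 1), (0, 0, 1, 1)))  # 1.0
print(iou((0, 0, 1, 1), (2, 2, 3, 3)))  # 0.0
```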

To keep testing convenient, I chose the lightweight ssd_mobilenet as the object detection model this time.
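
The trade-off in the comparison above can be written as a small selection rule: take the most accurate model whose Speed value fits a given budget. This is only a sketch using the numbers from the table; `pick_model` and the budget values are made up for illustration.

```python
# (model name, speed, mAP) rows copied from the comparison table.
MODELS = [
    ("ssd_mobilenet_v1_coco", 30, 21),
    ("ssd_mobilenet_v2_coco", 31, 22),
    ("ssd_inception_v2_coco", 42, 24),
    ("faster_rcnn_inception_resnet_v2_atrous_coco", 620, 37),
]


def pick_model(max_speed):
    """Return the name of the highest-mAP model with speed <= max_speed, or None."""
    candidates = [m for m in MODELS if m[1] <= max_speed]
    if not candidates:
        return None
    return max(candidates, key=lambda m: m[2])[0]


print(pick_model(35))    # ssd_mobilenet_v2_coco
print(pick_model(1000))  # faster_rcnn_inception_resnet_v2_atrous_coco
```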

API usage instructions

Import the Python libraries

import numpy as np
import os
import tensorflow as tf
import json
import time
from PIL import Image
# Stay compatible with Python 2.7
try:
    import urllib.request as ulib
except ImportError:
    import urllib as ulib
import re
from object_detection.utils import label_map_util

Load the model

MODEL_NAME = 'data/models/ssd_mobilenet_v2_coco_2018_03_29'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data/pbtxt','mscoco_label_map.pbtxt')  # CWH: Add object_detection path
# Largest item.id in data/pbtxt/mscoco_label_map.pbtxt
NUM_CLASSES = 90
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  # Load the frozen model into the graph
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

Load the label map. The integer class IDs returned by the model are mapped to the text labels in the pbtxt file.

The format of mscoco_label_map.pbtxt is as follows:

item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
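
To make the ID-to-label mapping concrete, the sketch below builds the same kind of lookup from the text format shown above using only a regular expression. It illustrates the idea and is not how `label_map_util` works internally (the real utility parses the file as a protocol buffer); `parse_label_map` is a hypothetical name.

```python
import re

SAMPLE = '''
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
'''

# Matches one item block with its three fields in the order shown above.
ITEM_RE = re.compile(
    r'item\s*\{\s*name:\s*"(?P<name>[^"]*)"\s*id:\s*(?P<id>\d+)\s*'
    r'display_name:\s*"(?P<display_name>[^"]*)"\s*\}'
)


def parse_label_map(text):
    """Map each integer id to {'id': id, 'name': display_name}."""
    index = {}
    for match in ITEM_RE.finditer(text):
        item_id = int(match.group("id"))
        index[item_id] = {"id": item_id, "name": match.group("display_name")}
    return index


category_index = parse_label_map(SAMPLE)
print(category_index[1]["name"])  # person
```

The resulting dictionary has the same shape as the `category_index` used in the article's code: class ID in, display name out.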

# Load the labels
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
with detection_graph.as_default():
  config = tf.ConfigProto()
  # Allocate GPU memory on demand instead of reserving it all up front
  config.gpu_options.allow_growth = True
  # Keep the session open at module level so that get_objects() can reuse it
  sess = tf.Session(graph=detection_graph, config=config)
  image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
  # Bounding box coordinates of each detected object
  detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
  # Confidence score of each detection
  detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
  # Class ID of each detection
  detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
  # Number of detections
  num_detections = detection_graph.get_tensor_by_name('num_detections:0')

def get_objects(file_name, threshold=0.5):
    # Object and load_image_into_numpy_array are helpers used by this function;
    # their definitions are not shown in this article
    image = Image.open(file_name)
    # Only JPEG files are accepted
    if image.format != 'JPEG':
        result = {}
        result['status'] = 0
        result['msg'] = '%s is %s; the ODS system only allows jpeg or jpg' % (file_name, image.format)
        return result
    image_np = load_image_into_numpy_array(image)
    # Add a batch dimension in front of the image array
    image_np_expanded = np.expand_dims(image_np, axis=0)
    output = []
    # Run the detection and fetch the results
    (boxes, scores, classes, num) = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
    # Remove the dimensions of size 1
    classes = np.squeeze(classes).astype(np.int32)
    scores = np.squeeze(scores)
    boxes = np.squeeze(boxes)
    # Form values may arrive as strings, so convert before comparing
    threshold = float(threshold)
    for c in range(len(classes)):
        if scores[c] >= threshold:
            item = Object()
            item.class_name = category_index[classes[c]]['name']  # object name
            item.score = float(scores[c])  # confidence score
            # Box coordinates as fractions of the image height and width
            item.y1 = float(boxes[c][0])
            item.x1 = float(boxes[c][1])
            item.y2 = float(boxes[c][2])
            item.x2 = float(boxes[c][3])
            output.append(item)
    # Return the detections as JSON
    outputJson = json.dumps([ob.__dict__ for ob in output])
    return outputJson
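
Note that the box values returned by get_objects are fractions of the image height and width, not pixels, so a client that wants to draw the boxes has to scale them by the image size. A minimal sketch, assuming the JSON layout produced above; `to_pixels` and the 640x480 image size are hypothetical.

```python
import json


def to_pixels(detection, width, height):
    """Convert one detection's normalized box to integer pixel coordinates."""
    return {
        "class_name": detection["class_name"],
        "score": detection["score"],
        "left": int(round(detection["x1"] * width)),
        "top": int(round(detection["y1"] * height)),
        "right": int(round(detection["x2"] * width)),
        "bottom": int(round(detection["y2"] * height)),
    }


# Example payload in the same shape as the get_objects output.
payload = json.dumps([
    {"class_name": "bed", "score": 0.95,
     "x1": 0.0, "y1": 0.5, "x2": 0.5, "y2": 1.0},
])

for det in json.loads(payload):
    print(to_pixels(det, width=640, height=480))
```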

Running the route

The logic in server.py:

def image():
    startTime = time.time()
    if request.method=='POST':
        image_file = request.files['file']
        base_path = os.path.abspath(os.path.dirname(__file__))
        upload_path = os.path.join(base_path,'static/upload/')
        # Save the uploaded image file
        file_name = upload_path + image_file.filename
        image_file.save(file_name)
        # Score threshold for filtering detections (parsed as a float)
        threshold = request.form.get('threshold', 0.5, type=float)
        # Call the detection API
        objects = object_detection_api.get_objects(file_name, threshold)
        # Render the result in the template
        return render_template('index.html',json_data = objects,img=image_file.filename)

curl http://localhost:5000 | python -m json.tool

[
    {
        "y2": 0.9886252284049988,
        "class_name": "bed",
        "x2": 0.4297400414943695,
        "score": 0.9562674164772034,
        "y1": 0.5202791094779968,
        "x1": 0
    },
    {
        "y2": 0.9805927872657776,
        "class_name": "couch",
        "x2": 0.4395904541015625,
        "score": 0.6422878503799438,
        "y1": 0.5051193833351135,
        "x1": 0.00021047890186309814
    }
]
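
On the client side the response is ordinary JSON, so it can be consumed with the standard library alone. A sketch that keeps only detections above a stricter score and lists them best first; the sample string repeats the output above, and `best_detections` is an illustrative name.

```python
import json

RESPONSE = '''
[
    {"class_name": "bed", "score": 0.9562674164772034,
     "x1": 0, "y1": 0.5202791094779968,
     "x2": 0.4297400414943695, "y2": 0.9886252284049988},
    {"class_name": "couch", "score": 0.6422878503799438,
     "x1": 0.00021047890186309814, "y1": 0.5051193833351135,
     "x2": 0.4395904541015625, "y2": 0.9805927872657776}
]
'''


def best_detections(text, min_score=0.5):
    """Parse the JSON response and return detections sorted by score, best first."""
    detections = [d for d in json.loads(text) if d["score"] >= min_score]
    return sorted(detections, key=lambda d: d["score"], reverse=True)


for det in best_detections(RESPONSE, min_score=0.7):
    print(det["class_name"], round(det["score"], 2))  # bed 0.96
```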

To try it out, open this URL in a browser:

http://localhost:5000/upload

Summary

  • Running TensorFlow on a GPU can improve efficiency by orders of magnitude
  • You can try different models to compare speed and accuracy
  • This example also supports Python 2, but to keep up with the times Python 3 is recommended

You are probably curious about how to train a model on the objects you need to detect; look forward to the next article.

Origin http://43.154.161.224:23101/article/api/json?id=324932964&siteId=291194637