Tensorflow调用目标检测模型并显示

Tensorflow调用目标检测模型并显示

import numpy as np
import tensorflow as tf
import cv2 as cv

# Load the frozen TensorFlow 1.x detection graph from disk.
# tf.gfile.GFile is used because tf.gfile.FastGFile is a deprecated alias.
with tf.gfile.GFile('exported_model/frozen_inference_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Read the image before creating the session so a bad path fails fast.
# cv.imread returns None (it does not raise) when the file cannot be read.
image_path = 'mouth_dataset/test/1_107.jpg'
img = cv.imread(image_path)
if img is None:
    raise FileNotFoundError('Could not read image: ' + image_path)

# Preprocess: resize to 300x300 and convert OpenCV's BGR channel order to RGB.
inp = cv.resize(img, (300, 300))
inp = cv.cvtColor(inp, cv.COLOR_BGR2RGB)

with tf.Session() as sess:
    # A Session created without an explicit graph uses the default graph,
    # which is also where import_graph_def places the imported nodes.
    tf.import_graph_def(graph_def, name='')
    graph = sess.graph

    # Run the model on a batch containing the single preprocessed image.
    num_detections, scores, boxes = sess.run(
        [graph.get_tensor_by_name('num_detections:0'),
         graph.get_tensor_by_name('detection_scores:0'),
         graph.get_tensor_by_name('detection_boxes:0')],
        feed_dict={'image_tensor:0': np.expand_dims(inp, axis=0)})

# Visualize detected bounding boxes on the original (un-resized) image.
# Each box is (ymin, xmin, ymax, xmax) as fractions of the image size, so it
# is scaled by the original height/width rather than the 300x300 input.
height, width = img.shape[:2]
count = int(num_detections[0])
for score, box in zip(scores[0][:count], boxes[0][:count]):
    # Keep only detections above the 0.3 confidence threshold.
    if float(score) <= 0.3:
        continue
    ymin, xmin, ymax, xmax = (float(v) for v in box)
    top_left = (int(xmin * width), int(ymin * height))
    bottom_right = (int(xmax * width), int(ymax * height))
    cv.rectangle(img, top_left, bottom_right, (255, 255, 0), thickness=1)

cv.imshow('TensorFlow', img)
cv.waitKey()
# Release the HighGUI window once the user has pressed a key.
cv.destroyAllWindows()
  • 修改读取模型处的路径 `exported_model/frozen_inference_graph.pb`,换成自己训练好的模型;
  • 修改读取图片处的路径 `mouth_dataset/test/1_107.jpg`,换成自己需要检测的图片。

猜你喜欢

转载自blog.csdn.net/gaoqing_dream163/article/details/112675907