【手把手教学】只需三步搭建自己的目标检测器(基于SSD算法)

先看效果

在这里插入图片描述
(舍友乱入哈哈哈)
在这里插入图片描述
在这里插入图片描述

第一步——安装依赖库


库名称 版本(我的版本)
tensorflow 1.14.0
opencv 3.4.2
numpy 1.16.3
matplotlib 3.0.3

安装教程:

在终端输入

pip install 库名称==版本号 --user

如:

pip install tensorflow==1.14.0 --user

在这里插入图片描述

注意事项:

版本号不一定需要严格按照我的来,但是如果出现了报错AttributeError: module ‘xxxx’ has no attribute ‘xxxxx’,很有可能就是版本不一致的问题;

第二步——下载源码并解压

源码地址:

https://github.com/balancap/SSD-Tensorflow

点击这里下载并解压:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

解压模型文件:

打开checkpoints文件夹,如图:
在这里插入图片描述
将这个压缩包解压:
在这里插入图片描述

第三步——复制调用模型程序

测试图片:

这一步需要创建一个demo_test.py文件,并将下面的代码复制到这个文件里:

# demo_test.py
from notebooks import visualization
from preprocessing import ssd_vgg_preprocessing
from nets import ssd_vgg_300, ssd_common, np_methods
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import os
import math
import random
import numpy as np
import tensorflow as tf
import cv2


slim = tf.contrib.slim


gpu_options = tf.GPUOptions(allow_growth=True)

config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)

isess = tf.InteractiveSession(config=config)


net_shape = (300, 300)

data_format = 'NHWC'

img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))


image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(

    img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)

image_4d = tf.expand_dims(image_pre, 0)


reuse = True if 'ssd_net' in locals() else None

ssd_net = ssd_vgg_300.SSDNet()

with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):

    predictions, localisations, _, _ = ssd_net.net(
        image_4d, is_training=False, reuse=reuse)


ckpt_filename = './checkpoints/ssd_300_vgg.ckpt'


isess.run(tf.global_variables_initializer())

saver = tf.train.Saver()

saver.restore(isess, ckpt_filename)


ssd_anchors = ssd_net.anchors(net_shape)


def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):

    rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
                                                              feed_dict={img_input: img})

    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
        rpredictions, rlocalisations, ssd_anchors,
        select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)

    rclasses, rscores, rbboxes = np_methods.bboxes_sort(
        rclasses, rscores, rbboxes, top_k=400)

    rclasses, rscores, rbboxes = np_methods.bboxes_nms(
        rclasses, rscores, rbboxes, nms_threshold=nms_threshold)

    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)

    return rclasses, rscores, rbboxes


image_path = './cat.jpg' # 图片路径

img = mpimg.imread(image_path)

rclasses, rscores, rbboxes = process_image(img)  # 这里传入图片

# labeled_img = visualization.bboxes_draw_on_img(
#     img, rclasses, rscores, rbboxes, visualization.colors_plasma)  # 返回标注图片

visualization.plt_bboxes(img, rclasses, rscores, rbboxes)  # 展示(plt)标注图片

如图:

在这里插入图片描述

最后在这里修改需要测试的图片的路径,就可以啦:
在这里插入图片描述

测试视频:

我们创建一个detector.py程序:

# detector.pyimport cv2
from demo_test import process_image
from notebooks import visualization


class Detertor(object):

    def __init__(self, camera_index=0):

        self.camera_index = camera_index

    def Catch_Video(self, window_name='Detertor'):

        cv2.namedWindow(window_name)

        cap = cv2.VideoCapture(self.camera_index)

        while cap.isOpened():

            catch, frame = cap.read()  # 读取每一帧图片

            if not catch:

                raise Exception('Check if the camera if on.')

                break

            rclasses, rscores, rbboxes = process_image(frame)  # 这里传入图片

            labeled_img = visualization.bboxes_draw_on_img(
                frame, rclasses, rscores, rbboxes, visualization.colors_plasma)

            cv2.imshow(window_name, labeled_img)

            c = cv2.waitKey(10)
            if c & 0xFF == ord('q'):
                # 按q退出
                break

            if cv2.getWindowProperty(window_name, cv2.WND_PROP_AUTOSIZE) < 1:
                # 点x退出
                break
                                
        # 释放摄像头

        cap.release()

        cv2.destroyAllWindows()

if __name__ == "__main__":    

    detect = Detertor()    
    detect.Catch_Video()

大功告成

看一下效果:

图片测试效果:

在这里插入图片描述

视频测试效果:

在这里插入图片描述
这样就大功告成啦~

原理和论文及源码解析请看我的另外两篇博客:
SSD目标检测算法详解 (一)论文讲解
SSD目标检测算法详解 (二)代码详解

如果对你有帮助的话,记得点赞关注哦
在这里插入图片描述

发布了58 篇原创文章 · 获赞 117 · 访问量 6812

猜你喜欢

转载自blog.csdn.net/weixin_44936889/article/details/103751527