OpenCV学习笔记(二):SVM+HOG实现的行人检测

因为一个项目的需求接触到OpenCV里的SVM和HOG特征算法,根据网上的教程一个博客,给自己准备了一个关于行人检测demo,里面也有一些代码也是参考网上的demo,这里大致记录下demo的代码和自己的遇到的一些小问题。
参考博客/文章:
HOG+SVM行人检测
目标检测的图像特征提取之(一)HOG特征
python+opencv3.4.0 实现HOG+SVM行人检测
软件环境:
Python:3.6.3
OpenCV:opencv-contrib-python 3.4.3.18
做这个demo的时候也遇到了一些小问题,大部分都是数据集的问题
训练出来的模型检测不出来图片中的人:因为最开始我使用的正样本的尺寸很大,居中裁剪后的128*64的图片里人已经不完整了
训练出来的模型检测一张图片要很久:正样本数据集里面有一些数据图片里面的人不完整或者没有人,目前做的demo检测一张图片的时间大概在200ms左右,其实在视屏检测中由于帧数低于30,画面看上去有点卡顿,看起来很不舒服,但是官方给出的训练好的检测器大概也是这个时间,所以我认为达到这个时间至少说明训练过程和训练数据集没有太大问题,只是好与不够好的区别。
训练用的数据集是我自己整理了后的数据集,把官方的一些数据集合了一下,增大了正负样本的数据量,以求训练处更精准的模型,但是事实证明好像影响不大…
数据集:https://pan.baidu.com/s/14tieF7WBMTr_J-BUzvYJqQ

训练代码:

import os
import myLogger
import numpy as np
import cv2 as cv
import random
import time


def load_data(pos_path, neg_path):
    pos_list = []
    neg_list = []

    pos = os.listdir(pos_path)
    neg = os.listdir(neg_path)

    for pos_p in pos:
        pos_list.append(os.path.join(pos_path, pos_p))
    for neg_p in neg:
        neg_list.append(os.path.join(neg_path, neg_p))

    return pos_list, neg_list


def load_train_samples(pos, neg):
    samples = []
    labels = []
    neg_list_for_corrct = []
    x_width = 64
    y_height = 128
    random.seed(1)
    for i in pos:
        img_matrix = cv.imread(i, cv.COLOR_BGR2GRAY)
        if img_matrix is not None:
            #取样本图片中间(128*64)部分
            if img_matrix.shape[0] > y_height and img_matrix.shape[1] > x_width:
                img_new = img_matrix[
                          (img_matrix.shape[0] - y_height) // 2:(img_matrix.shape[0] - y_height) // 2 + y_height,
                          (img_matrix.shape[1] - x_width) // 2:(img_matrix.shape[1] - x_width) // 2 + x_width]
                # img_new = cv.cvtColor(img_new, cv.COLOR_BGR2GRAY)
                samples.append(img_new)
                labels.append(1.)

    for i in neg:
        img_matrix = cv.imread(i, cv.COLOR_BGR2GRAY)
        if img_matrix is not None:
            neg_list_for_corrct.append(img_matrix)
            #随机获取10份(128*64)的图片:扩大训练数据集10倍
            for j in range(10):
                x = int(random.random() * (img_matrix.shape[1] - x_width))
                y = int(random.random() * (img_matrix.shape[0] - y_height))
                img_new = img_matrix[y:y + y_height, x:x + x_width]
                # img_new = cv.cvtColor(img_new,cv.COLOR_BGR2GRAY)
                samples.append(img_new)
                labels.append(-1.)

    labels_len = len(labels)
    labels = np.int32(labels)
    labels = np.resize(labels, (labels_len,))

    return samples, neg_list_for_corrct, labels


def extract_hog(samples):
    train = []
    colors_channel = 3
    # hog = cv.HOGDescriptor((64, 128), (16, 16), (8, 8), (8, 8), 9)
    hog = cv.HOGDescriptor()  #默认参数为训练128*64的数据集的参数
    for img in samples:
        if img.shape == (128, 64, colors_channel):
            descriptors = hog.compute(img)
            train.append(descriptors)
    train = np.float32(train)
    train = np.resize(train, (len(samples), 3780, 1))

    return train

def get_svm_detector(svm):
    '''
    导出可以用于cv2.HOGDescriptor()的SVM检测器,实质上是训练好的SVM的支持向量和rho参数组成的列表
    :param svm: 训练好的SVM分类器
    :return: SVM的支持向量和rho参数组成的列表,可用作cv2.HOGDescriptor()的SVM检测器
    '''
    sv = svm.getSupportVectors()
    rho, _, _ = svm.getDecisionFunction(0)
    sv = np.transpose(sv)
    return np.append(sv, [[-rho]], 0)


def train_SVM(train, labels, Logger):
    svm = cv.ml.SVM_create()
    svm.setCoef0(0.0)
    svm.setDegree(3)
    criteria = (cv.TERM_CRITERIA_MAX_ITER + cv.TERM_CRITERIA_EPS, 1000, 1e-3)
    svm.setTermCriteria(criteria)
    svm.setGamma(0)
    svm.setKernel(cv.ml.SVM_LINEAR)
    svm.setNu(0.5)
    svm.setP(0.1)  # for EPSILON_SVR, epsilon in loss function?
    svm.setC(0.01)  # From paper, soft classifier
    svm.setType(cv.ml.SVM_EPS_SVR)  # C_SVC # EPSILON_SVR # may be also NU_SVR # do regression task

    Logger.writeLog('Starting training svm...', level='info')
    time_start = time.time()
    svm.train(train, cv.ml.ROW_SAMPLE, labels)
    time_use = time.time() - time_start
    Logger.writeLog('Training done,use time:{}'.format(time_use), level='info')

    # return SVM检测器(直接用于HOG计算),SVM分类器
    return get_svm_detector(svm), svm


def train_correction(hog, samples, labels, neg_list, svm, Logger, train_correction_count):
    Logger.writeLog('{}.th correction is training...'.format(train_correction_count), level='info')
    start_time = time.time()
    count = 0
    labels = labels.tolist()
    for img in neg_list:
        rects, _ = hog.detectMultiScale(img, winStride=(4, 4), padding=(8, 8), scale=1.05)
        for (x, y, w, h) in rects:
            img_wrong = img[y:y + h, x:x + w]
            samples.append(cv.resize(img_wrong, (64, 128)))
            labels.append(-1)
            count += 1
    train = extract_hog(samples)

    labels_len = len(labels)
    labels = np.int32(labels)
    labels = np.resize(labels, (labels_len,))
    svm.train(train, cv.ml.ROW_SAMPLE, labels)

    Logger.writeLog('{}.th correction train is done,use time:{},wrong samples:{}'.format(train_correction_count,
                                                                                         int(time.time() - start_time),
                                                                                         count),
                    level='info')
    return get_svm_detector(svm), svm, samples, labels, count


if __name__ == '__main__':

    pos_path = r'F:/OpenCV/SVM_HOG_DATA/Train/pos_set'
    neg_path = r'F:/OpenCV/SVM_HOG_DATA/Train/neg_set'

    Logger = myLogger.LogHelper()
    Logger.writeLog("Program runing...", level='info')
    Program_start = time.time()

    #获取正负样本图片路径
    pos, neg = load_data(pos_path,neg_path)
    Logger.writeLog('load pos samples:' + str(len(pos)), level='info')
    Logger.writeLog('load neg samples:' + str(len(neg)), level='info')
    #加载正负样本数据(全样本数据,负样本图片库,全样本对应标签)
    samples, neg_list_for_corrct, labels = load_train_samples(pos, neg)

    #计算正负样本的HOG特征
    train = extract_hog(samples)
    Logger.writeLog('Size of feature vectors of samples:{}'.format(train.shape), level='info')
    Logger.writeLog('Size of labels of samples:{}'.format(labels.shape), level='info')

    #训练
    SVM_detector, svm = train_SVM(train, labels, Logger)
    train_correction_count = 0  #修正训练计数

    #用训练好的检测器检测负样本图片库,将错误识别加入全样本数据集重复训练
    while True:
        hog = cv.HOGDescriptor()
        hog.setSVMDetector(SVM_detector)
        train_correction_count += 1
        #参数:hog对象,所有样本,标签,原始负样本,svm分类器,日志对象,训练计数
        #返回:svm检测器,cvm分类器,所有样本,对应标签,错误样本数
        SVM_detector, svm, samples, labels, wrong_count = train_correction(hog,
                                                                           samples,
                                                                           labels,
                                                                           neg_list_for_corrct,
                                                                           svm,
                                                                           Logger,
                                                                           train_correction_count)

        # if wrong_count <= 100 or train_correction_count >= 10:
        if wrong_count <= 100:
            break

    hog.save('myHogDector.bin')
    Logger.writeLog('Program is done...,use time:{}'.format(time.time() - Program_start),
                    level='info')

测试代码:

import cv2
import time
import numpy as np
import os

def test_svm(hog):
    test_list = []
    test = os.listdir(r'F:/OpenCV/SVM_HOG_DATA/Test')
    for i in test:
        test_list.append(os.path.join(r'F:/OpenCV/SVM_HOG_DATA/Test', i))
    i = 0
    for f in test_list:
        i += 1
        print(i)
        img = cv2.imread(f,cv2.COLOR_BGR2GRAY)
        rects, _ = hog.detectMultiScale(img, winStride=(4, 4), padding=(8, 8), scale=1.05)
        for (x,y,w,h) in rects:
            cv2.rectangle(img,(x,y),(x+w,y+h),(0,0,255),2)
        cv2.imshow('{}'.format(i),img)
        if i>=10:
            break
    cv2.waitKey()

def test_svm_vidio(hog):
    cap = cv2.VideoCapture(0)
    while True:
        img = cap.read()[1]
        time_1 = time.time()
        rects, wei = hog.detectMultiScale(img, winStride=(4, 4), padding=(8, 8), scale=1.05)
        print(time.time() - time_1)
        for (x, y, w, h) in rects:
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
        cv2.imshow('a', img)
        if cv2.waitKey(33) == 27:
            break
    cv2.destroyAllWindows()


hog = cv2.HOGDescriptor()
hog.load('myHogDector.bin')
#官方自带的检测器
# hog.setSVMDetector(cv2.HOGDescriptor_getDefaultPeopleDetector())

#两种测试方式:1.测试数据集  2.视频测试
test_svm(hog)
# test_svm_vidio(hog)

日志代码:

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 20 10:19:31 2018

@author: Python
"""

import logging
import sys

class LogHelper:
    
    def __init__(self, name='LogHelper', 
                 setLevel=logging.DEBUG):
       self.logger = logging.getLogger(name)
       self.formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
       
       self.file_handler = logging.FileHandler(name+".log")
       self.file_handler.setFormatter(self.formatter)
       
       self.consle_handler = logging.StreamHandler(sys.stdout)
       self.consle_handler.setFormatter(self.formatter)
       
       self.logger.setLevel(setLevel)
       
       self.logger.addHandler(self.file_handler)
       self.logger.addHandler(self.consle_handler)
    
    def writeLog(self, info, level='debug'):
        if level == "critial":
            self.logger.critical(info)
        elif level == "error":
            self.logger.error(info)
        elif level == "warning":
            self.logger.warning(info)
        elif level == "info":
            self.logger.info(info)
        else:
            self.logger.debug(info)
            
    def removeLog(self):
        self.logger.removeHandler(self.file_handler)
        self.logger.removeHandler(self.consle_handler)
        
if __name__ == "__main__":
    logger = LogHelper()
    logger.writeLog("helloworld", level='error')
    logger.removeLog()

猜你喜欢

转载自blog.csdn.net/qq_36272641/article/details/85686027