基于TensorFlow的SSD车辆检测-1

版权声明:本文为博主原创文章,转载请注明出处 https://blog.csdn.net/shuzfan/article/details/79034555

此系列博客是用来学习Tensorflow和Python的,由于是新手上车,如有错误之处希望大家不吝指出。

整个项目可以从百度云下载:链接:https://pan.baidu.com/s/1f2JPJpE7m5M2kSifMP0-Lw 密码:9p8v

一. 训练数据准备

在训练数据准备环节,主要包含下面三块内容:

  • 怎样解析用于车辆检测训练的KITTI数据集
  • 怎样进行数据扩张来增大训练数据的多样性
  • 怎样在训练阶段为模型供给batch训练数据

1. 读取KITTI数据集

首先到KITTI官网http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=2d下载车辆检测数据集。

具体地,只用下载下面3个压缩包:(需要提供邮箱以获得下载链接)

KITTI数据集采用了一张图片对应一个标注文件的形式,其中标注文件是TXT格式,内容为N行15列,每一列都使用空格隔开。这15列的内容是:

列号 名称 描述
1 类别 目标类别,共8类:’Car’, ‘Van’, ‘Truck’,’Pedestrian’, ‘Person_sitting’, ‘Cyclist’, ‘Tram’, ‘Misc’ 或者 ‘DontCare’
2 是否有截断 指目标是否超出图像边界,0: (non-truncated), 1: (truncated)
3 遮挡情况 0 = fully visible, 1 = partly occluded 2 = largely occluded, 3 = unknown
4 目标观测角度 范围[-pi..pi]
5-8 目标bbox 坐标从0开始,[left, top, right, bottom]
9-11 3D维度 3D object dimensions: height, width, length (in meters)
12-14 3D空间坐标 D object location x,y,z in camera coordinates (in meters)
15 Y轴旋转角 Rotation ry around Y-axis in camera coordinates [-pi..pi]
16 置信度得分 仅用于Test,浮点数,用于绘制p/r曲线

备注:‘DontCare’表示忽略的未标记区域,这可能是因为超出了激光扫描仪的工作范围。测试时,位于该部分区域的结果会自动被忽略。训练时可以同样将此部分忽略,防止在此区域不断地引起Hard Mining操作。

由于这里只进行车辆检测,因此标注信息中我们暂时只关注类别和BBox信息。此外,将’Car’, ‘Van’, ‘Truck’这3类合并为正样本目标,其余区域作为背景区域。

首先,我们需要批量的读取每一个标注文件:

# readKITTI.py 用于解析KITTI数据集

import os

# 获取指定后缀名的文件列表
def get_filelist(path,ext):
    # 获取某个文件夹下的所有文件
    filelist_temp  = os.listdir(path)
    filelist = []
    # 通过比较后缀,选中所有TXT标注文件
    for i in filelist_temp:
        if os.path.splitext(i)[1] == ext:
            filelist.append(os.path.splitext(i)[0])
    return filelist

# 解析标注文件并返回目标的bounding box信息,维度Nx4
def get_bbox(filename):
    bbox = []
    # 判断文件是否存在
    if os.path.exists(filename):
        with open(filename) as fi:
            label_data = fi.readlines()
        # 依次读取每一行标注信息
        for l in label_data:
            data = l.split()
            # 如果存在车辆目标则记录bounding box
            if data[0] in ['Van','Car','Truck']:
                bbox.append((float(data[4]),float(data[5]),
                    float(data[6]),float(data[7])))
    return bbox

# 批量获取标注文件的bounding box信息
def get_bboxlist(rootpath,imagelist):
    bboxlist = []
    for i in imagelist:
        bboxlist.append(get_bbox(rootpath + i +'.txt'))
    return bboxlist

通过调用上述函数,我们便可以读取KITTI数据集为我们需要的形式:

import readKITTI

IMAGE_DIR = './image/training/image_2/'
LABEL_DIR = './label/training/label_2/'

imagelist = readKITTI.get_filelist(IMAGE_DIR,'.png')
bboxlist  = readKITTI.get_bboxlist(LABEL_DIR,imagelist)

2. 数据扩张

在深度学习模型的训练过程中,数据扩张(Data Augmentation)通常都会被使用。其中,随机缩放、所及裁剪、随机翻转应当是使用最广泛的且行之有效的手段。(至于对比度调整、颜色调整、PCA这些东西,还真不好说。)

对于目标检测而言,相当重要的一点是:对图像进行调整的同时,也要保证目标bounding box的有效性与正确性。

缩放

为了后续模型训练的时候可以使用Batch,通常我们会将输入图像固定到统一尺寸,因此图像resize并调整颜色统一是必不可少的。

# imAugment.py 提供一些用于数据扩张的函数

import cv2

# 将图像按照指定尺寸进行缩放,同时处理boundingbox以及颜色信息
def imresize(in_img,in_bbox,out_w,out_h,is_color = True):
    # 判断是否是字符串
    if isinstance(in_img,str):
        in_img = cv2.imread(in_img)
    # 获取图像宽度与高度
    height, width = in_img.shape[:2]
    out_img = cv2.resize(in_img,(out_w, out_h))
    # 调整图像颜色
    if is_color == True and in_img.ndim == 2 :
        out_img = cv2.cvtColor(out_img, cv2.COLOR_GRAY2BGR)
    elif is_color == False and in_img.ndim == 3 :
        out_img = cv2.cvtColor(out_img, cv2.COLOR_BGR2GRAY)
    # 调整bounding box
    s_h = out_h / height
    s_w = out_w / width
    out_bbox = []
    for i in in_bbox:
        out_bbox.append((i[0]*s_w, i[1]*s_h, i[2]*s_w, i[3]*s_h))
    return out_img, out_bbox

水平翻转

对于车辆检测,垂直翻转没有必要,我们这里只进行水平翻转,并对应的翻转bounding box。

# imAugment.py 提供一些用于数据扩张的函数

# 将图像进行水平翻转,同时处理boundingbox
def immirror(in_img,in_bbox):
    # 判断是否是字符串
    if isinstance(in_img,str):
        in_img = cv2.imread(in_img)
    # 图像水平翻转
    out_img = cv2.flip(in_img,1)
    # 获取图像宽度
    width = out_img.shape[1]
    # 重新调整目标在翻转后图像上的位置
    out_bbox = []
    for i in in_bbox:
        out_bbox.append((width - i[0], i[1], width-i[2], i[3]))
    return out_img, out_bbox

随机裁剪

随机裁剪其实有很多约束和注意事项,主要有下面几点:

  • 需要指定最小裁剪块的大小。否则如果裁剪块过小,则不适用于训练。
  • 过小的图像不应当再被裁剪
  • 由于我们无法准确的形容一个被裁剪掉一块的目标是否还是一个有效的可被识别的目标,因此我们的裁剪区域应当包含所有目标的bounding box。
# imAugment.py 提供一些用于数据扩张的函数

import random
# 将图像进行随机crop,同时处理boundingbox, min_wh为crop块的最小宽高
def imcrop(in_img,in_bbox,min_hw):
    # 判断是否是字符串
    if isinstance(in_img,str):
        in_img = cv2.imread(in_img)
    # 获取图像宽度与高度
    height, width = in_img.shape[:2]
    # 如果图像过小,则放弃crop
    if height <= min_hw and width <= min_hw:
        return in_img, in_bbox
    # 为了防止有效目标被crop截断,crop范围应包含所有目标
    # 下面寻找包含所有目标的最小矩形
    min_x1, min_y1, min_x2, min_y2 = width-1, height-1, 0, 0
    for i in in_bbox:
        min_x1 = min(min_x1,int(i[0]))
        min_y1 = min(min_y1,int(i[1]))
        min_x2 = max(min_x2,int(i[2]))
        min_y2 = max(min_y2,int(i[3]))

    # 根据最小包围框,再随机生成一个矩形框,并防止框超出图像范围
    rand_x1, rand_y1, rand_x2, rand_y2 = 0, 0, width, height
    if min_x1 <= 1:
        rand_x1 = 0
    else:
        rand_x1 = random.randint(0,min(min_x1,max(width - min_hw,1)))
    if min_y1 <= 1:
        rand_y1 = 0
    else:
        rand_y1 = random.randint(0,min(min_y1,max(height - min_hw,1)))
    if min_x2 >= width or rand_x1 + min_hw >= width:
        rand_x2 = width
    else:
        rand_x2 = random.randint(max(rand_x1+min_hw,min_x2),width)
    if min_y2 >= height or rand_y1 + min_hw >= height:
        rand_y2 = height
    else:
        rand_y2 = random.randint(max(rand_y1+min_hw,min_y2),height)

    # crop图像
    out_img = in_img[rand_y1:rand_y2-1,rand_x1:rand_x2-1]
    # 处理bounding box
    out_bbox = []
    for i in in_bbox:
        out_bbox.append((i[0]-rand_x1,i[1]-rand_y1,i[2]-rand_x1,i[3]-rand_y1))
    return out_img, out_bbox

下面给出效果图:(最上面的是原图,下面依次是水平翻转,缩放和随机裁剪)

这里写图片描述

3. Batch生成

训练阶段我们需要生成一个个batch用于训练,一般需要的参数设置包括:batchsize、训练图片的大小、颜色、是否shuffle数据、是否随机crop等。基于此,下面给出一个供给batch的代码:

# genBatch.py 用于训练阶段供给训练数据

# coding=utf-8
import random
import readKITTI
import imAugment
import cv2


class genBatch:
    image_dir, label_dir = [], []
    image_list, bbox_list = [], []
    initOK = False

    # 初始化读取数据
    def initdata(self, imagedir, labeldir):
        self.image_dir, self.label_dir = imagedir, labeldir
        self.image_list = readKITTI.get_filelist(imagedir,'.png')
        self.bbox_list  = readKITTI.get_bboxlist(labeldir,self.image_list)
        # 如果数据不为空且图片和label数量相匹配
        if len(self.image_list) > 0 and len(self.image_list) == len(self.bbox_list):
           self.initOK = True
        else:
            print("The amount of images is %d, while the amount of"
                    "corresponding label is %d"%(len(self.image_list),len(self.bbox_list)))
            self.initOK = False
        return self.initOK

    readPos = 0

    # 生成一个新的batch
    def genbatch(self,batchsize,newh,neww,iscolor=True,isshuffle=False,
                mirrorratio=0.0, cropratio=0.0):
        if self.initOK == False:
            print("The initdata() function must be successfully called first.")
            return []
        batch_data, batch_bbox = [], []
        for i in range(batchsize):
            # 当数据遍历一遍时
            if self.readPos >= len(self.image_list)-1:
                self.readPos = 0
                if isshuffle == True:
                    # 指定同一随机种子,保证图片和label采用同样的乱序
                    r_seed = random.random()
                    random.seed(r_seed)
                    random.shuffle(self.image_list)
                    random.seed(r_seed)
                    random.shuffle(self.bbox_list)
            img = cv2.imread(self.image_dir + self.image_list[self.readPos] + '.png')
            bbox = self.bbox_list[self.readPos]
            self.readPos += 1

            # 按照指定概率进行crop,切记裁剪应当发生在resize之前
            if cropratio > 0 and random.random() < cropratio:
                img, bbox = imAugment.imcrop(img,bbox,min(neww,newh))

            # 调整图像大小及颜色
            img, bbox = imAugment.imresize(img,bbox,neww,newh,iscolor)

            # 按照指定概率进行随机镜像
            if mirrorratio > 0 and random.random() < mirrorratio:
                img, bbox = imAugment.immirror(img,bbox)

            batch_data.append(img)
            batch_bbox.append(bbox)
        return batch_data, batch_bbox

猜你喜欢

转载自blog.csdn.net/shuzfan/article/details/79034555