用自己的数据(kitti)训练测试faster rcnn(一)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014256231/article/details/79801665

Faster用的是PASCAL VOC数据集,所以将kitti数据改为PASCAL VOC的形式是最高效的办法。

1. 准备工作

  • 下载kitti数据集:
    数据集:Download left color images of object data set (12 GB)
    标注文件:Download training labels of object data set (5 MB)
  • 创建VOC2007文件夹,在文件夹下创建三个文件夹,命名如下:
    这里写图片描述
    JPEGImages文件夹用来存放所有数据集,Annotations文件夹用于放标注信息,在ImageSets文件夹内再创建三个文件夹,如下:
    这里写图片描述
  • 将下载的数据解压放在VOC2007下,图片放入JPEGImages里,标注信息解压后的label_2文件夹在VOC2007下,与Annotations等平行。

这就是所有的准备工作,接下来将kitti转换为pascal形式。

2. 转换数据集

2.1 kitti标注信息说明

kitti数据集中标注信息是存放在txt文本中的,包含如下信息:

Car 0.00 0 -1.67 642.24 178.50 680.14 208.68 1.38 1.49 3.32 2.41 1.66 34.98 -1.60
Car 0.00 0 -1.75 685.77 178.12 767.02 235.21 1.50 1.62 3.89 3.27 1.67 21.18 -1.60

这里写图片描述
根据上图,需要使用的只有类别‘Car’和物体外框的坐标‘387.63 181.54 423.81 203.12’,其余的字段都可以忽略。
我们分三步转换数据:

2.2 转换kitti类别

我忽略了’DontCare’和’Misc’类,也忽略了’Cyclist’类,因为图片中这一类过于小,标注信息也不准确,肉眼都难以分辨。我还合并’Person_sitting’和’Pedestrian’统一为’Pedestrian’。你可以根据自己需求修改。修改后可以看到txt文件里的内容已经变化。

#!/usr/bin/env python  
# -*- coding: UTF-8 -*-  

# modify_annotations_txt.py  
import glob  
import string  

txt_list = glob.glob('./label_2/*.txt') # 存储Labels文件夹所有txt文件路径  
def show_category(txt_list):  
    category_list= []  
    for item in txt_list:  
        try:  
            with open(item) as tdf:  
                for each_line in tdf:  
                    labeldata = each_line.strip().split(' ') # 去掉前后多余的字符并把其分开  
                    category_list.append(labeldata[0]) # 只要第一个字段,即类别  
        except IOError as ioerr:  
            print('File error:'+str(ioerr))  
    print(set(category_list)) # 输出集合  

def merge(line):  
    each_line=''  
    for i in range(len(line)):  
        if i!= (len(line)-1):  
            each_line=each_line+line[i]+' '  
        else:  
            each_line=each_line+line[i] # 最后一条字段后面不加空格  
    each_line=each_line+'\n'  
    return (each_line)  

print('before modify categories are:\n')  
show_category(txt_list)  

for item in txt_list:  
    new_txt=[]  
    try:  
        with open(item, 'r') as r_tdf:  
            for each_line in r_tdf:  
                labeldata = each_line.strip().split(' ')

                '''if labeldata[0] in ['Truck','Van','Tram','Car']: # 合并汽车类  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'car')  
                if labeldata[0] in ['Person_sitting','Cyclist','Pedestrian']: # 合并行人类  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'pedestrian')'''
                #print type(labeldata[4])
                if labeldata[4] == '0.00':
                    labeldata[4] = labeldata[4].replace(labeldata[4],'1.00')
                if labeldata[5] == '0.00':
                    labeldata[5] = labeldata[5].replace(labeldata[5],'1.00')
                if labeldata[0] == 'Truck':  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'truck')
                if labeldata[0] == 'Van':  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'van') 
                if labeldata[0] == 'Tram':  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'tram')
                if labeldata[0] == 'Car':  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'car')
                #if labeldata[0] == 'Cyclist':  
                    #labeldata[0] = labeldata[0].replace(labeldata[0],'cyclist')
                if labeldata[0] in ['Person_sitting','Pedestrian']: # 合并行人类  
                    labeldata[0] = labeldata[0].replace(labeldata[0],'pedestrian')
                if labeldata[0] == 'Cyclist':  
                    continue
                if labeldata[0] == 'DontCare': # 忽略Dontcare类  
                    continue  
                if labeldata[0] == 'Misc': # 忽略Misc类  
                    continue  
                new_txt.append(merge(labeldata)) # 重新写入新的txt文件  
        with open(item,'w+') as w_tdf: # w+是打开原文件将内容删除,另写新内容进去  
            for temp in new_txt:  
                w_tdf.write(temp)  
    except IOError as ioerr:  
        print('File error:'+str(ioerr))  

print('\nafter modify categories are:\n')  
show_category(txt_list)

2.3 转换标注信息格式:txt到xml

对原始txt文件进行上述处理后,接下来需要将标注文件从txt转化为xml,并去掉标注信息中用不上的部分,只留下3类,还有把坐标值从float型转化为int型,最后所有生成的xml文件要存放在Annotations文件夹中。

#!/usr/bin/env python  
# -*- coding: UTF-8 -*-  
# txt_to_xml.py  
# 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML  
from xml.dom.minidom import Document  
import cv2  
import os  

def generate_xml(name,split_lines,img_size,class_ind):  
    doc = Document()  # 创建DOM文档对象  

    annotation = doc.createElement('annotation')  
    doc.appendChild(annotation)  

    title = doc.createElement('folder')  
    title_text = doc.createTextNode('VOC2007')#这里修改了文件夹名  
    title.appendChild(title_text)  
    annotation.appendChild(title)  

    img_name=name+'.jpg'#要用jpg格式  

    title = doc.createElement('filename')  
    title_text = doc.createTextNode(img_name)  
    title.appendChild(title_text)  
    annotation.appendChild(title)  

    source = doc.createElement('source')  
    annotation.appendChild(source)  

    title = doc.createElement('database')  
    title_text = doc.createTextNode('The VOC2007 Database')#修改为VOC  
    title.appendChild(title_text)  
    source.appendChild(title)  

    title = doc.createElement('annotation')  
    title_text = doc.createTextNode('PASCAL VOC2007')#修改为VOC  
    title.appendChild(title_text)  
    source.appendChild(title)  

    size = doc.createElement('size')  
    annotation.appendChild(size)  

    title = doc.createElement('width')  
    title_text = doc.createTextNode(str(img_size[1]))  
    title.appendChild(title_text)  
    size.appendChild(title)  

    title = doc.createElement('height')  
    title_text = doc.createTextNode(str(img_size[0]))  
    title.appendChild(title_text)  
    size.appendChild(title)  

    title = doc.createElement('depth')  
    title_text = doc.createTextNode(str(img_size[2]))  
    title.appendChild(title_text)  
    size.appendChild(title)  

    for split_line in split_lines:  
        line=split_line.strip().split()  
        if line[0] in class_ind:  
            object = doc.createElement('object')  
            annotation.appendChild(object)  

            title = doc.createElement('name')  
            title_text = doc.createTextNode(line[0])  
            title.appendChild(title_text)  
            object.appendChild(title)  

            title = doc.createElement('difficult')  
            title_text = doc.createTextNode('0')  
            title.appendChild(title_text)  
            object.appendChild(title)  

            bndbox = doc.createElement('bndbox')  
            object.appendChild(bndbox)  
            title = doc.createElement('xmin')  
            title_text = doc.createTextNode(str(int(float(line[4]))))  
            title.appendChild(title_text)  
            bndbox.appendChild(title)  
            title = doc.createElement('ymin')  
            title_text = doc.createTextNode(str(int(float(line[5]))))  
            title.appendChild(title_text)  
            bndbox.appendChild(title)  
            title = doc.createElement('xmax')  
            title_text = doc.createTextNode(str(int(float(line[6]))))  
            title.appendChild(title_text)  
            bndbox.appendChild(title)  
            title = doc.createElement('ymax')  
            title_text = doc.createTextNode(str(int(float(line[7]))))  
            title.appendChild(title_text)  
            bndbox.appendChild(title)  

    # 将DOM对象doc写入文件  
    f = open('Annotations/'+name+'.xml','w')  
    f.write(doc.toprettyxml(indent = ''))  
    f.close()  

if __name__ == '__main__':  
    class_ind=('van', 'tram', 'car', 'pedestrian', 'truck')#修改为了5类  
    cur_dir=os.getcwd()  
    labels_dir=os.path.join(cur_dir,'label_2')  
    for parent, dirnames, filenames in os.walk(labels_dir): # 分别得到根目录,子目录和根目录下文件     
        for file_name in filenames:  
            full_path=os.path.join(parent, file_name) # 获取文件全路径  
            #print full_path  
            f=open(full_path)  
            split_lines = f.readlines()  
            name= file_name[:-4] # 后四位是扩展名.txt,只取前面的文件名  
            #print name  
            img_name=name+'.jpg'   
            img_path=os.path.join('/Users/claire/Faster-RCNN_TF/data/VOCdevkit2007/VOC2007/JPEGImages',img_name) # 路径需要自行修改              
            #print img_path  
            img_size=cv2.imread(img_path).shape  
            generate_xml(name,split_lines,img_size,class_ind)  
print('all txts has converted into xmls')  

可以看到,Annotations文件夹里已经都是转换后的标注信息了。
这里写图片描述

2.4生成训练验证集和测试集列表

这个工具是用Python3写的,所以执行的时候要进入文件目录后,执行:
python3 create_train_test_txt.py

# create_train_test_txt.py  
# encoding:utf-8  
import pdb  
import glob  
import os  
import random  
import math  

def get_sample_value(txt_name, category_name):  
    label_path = './label_2/'  
    txt_path = label_path + txt_name+'.txt'  
    try:  
        with open(txt_path) as r_tdf:  
            if category_name in r_tdf.read():  
                return ' 1'  
            else:  
                return '-1'  
    except IOError as ioerr:  
        print('File error:'+str(ioerr))  

txt_list_path = glob.glob('./label_2/*.txt')  
txt_list = []  

for item in txt_list_path:  
    temp1,temp2 = os.path.splitext(os.path.basename(item))  
    txt_list.append(temp1)  
txt_list.sort()  
print(txt_list, end = '\n\n')  

# 有博客建议train:val:test=8:1:1,先尝试用一下  
num_trainval = random.sample(txt_list, math.floor(len(txt_list)*9/10.0)) # 可修改百分比  
num_trainval.sort()  
print(num_trainval, end = '\n\n')  

num_train = random.sample(num_trainval,math.floor(len(num_trainval)*8/9.0)) # 可修改百分比  
num_train.sort()  
print(num_train, end = '\n\n')  

num_val = list(set(num_trainval).difference(set(num_train)))  
num_val.sort()  
print(num_val, end = '\n\n')  

num_test = list(set(txt_list).difference(set(num_trainval)))  
num_test.sort()  
print(num_test, end = '\n\n')  

pdb.set_trace()  

Main_path = './ImageSets/Main/'  
train_test_name = ['trainval','train','val','test']  
category_name = ['van', 'tram', 'car', 'pedestrian', 'truck']#修改类别  

# 循环写trainvl train val test  
for item_train_test_name in train_test_name:  
    list_name = 'num_'  
    list_name += item_train_test_name  
    train_test_txt_name = Main_path + item_train_test_name + '.txt'   
    try:  
        # 写单个文件  
        with open(train_test_txt_name, 'w') as w_tdf:  
            # 一行一行写  
            for item in eval(list_name):  
                w_tdf.write(item+'\n')  
        # 循环写Car Pedestrian Cyclist  
        for item_category_name in category_name:  
            category_txt_name = Main_path + item_category_name + '_' + item_train_test_name + '.txt'  
            with open(category_txt_name, 'w') as w_tdf:  
                # 一行一行写  
                for item in eval(list_name):  
                    w_tdf.write(item+' '+ get_sample_value(item, item_category_name)+'\n')  
    except IOError as ioerr:  
        print('File error:'+str(ioerr)) 

用于faster rcnn训练的Pascal VOC格式的数据集总共就是三大块:首先是JPEGImages文件夹,放入了所有png图片;然后是Annotations文件夹,上述步骤已经生成了相应的xml文件;最后就是imagesSets文件夹,里面有一个Main子文件夹,这个文件夹存放的是训练验证集,测试集的相关列表文件。运行完代码会自动生成txt文件,其中训练测试部分的比例可以自行修改。如下图所示:
这里写图片描述

到此,所有数据准备工作就完成了。接下来是代码的修改。

猜你喜欢

转载自blog.csdn.net/u014256231/article/details/79801665
今日推荐