yolov3 map 计算

1.代码 作者原版

git clone https://github.com/pjreddie/darknet.git

 本文的代码包:

链接: https://pan.baidu.com/s/1O6YCyJLKXH9MrqnLAYmlug 提取码: dj4g

2.生成测试结果

2.1 修改data文件   

我这里把所有数据都训练了,为了统计训练的map,所以valid 的路径也放的是train路径

train.txt 内容

2.2 测试命令及结果文件

命令

# 对应调整 data cfg weights 文件

./darknet detector valid kitti_data/kitti.data cfg/yolov3-kitti.cfg backupp/yolov3-kitti_2000.weights

我这个电脑测试7481张文件共计用时624s,每张平均需要0.083411309s.

结果文件

存放在 results 中。每个类别自成一个txt文件。每个类别文件(唯一)包含所有测试图片。有几个类别就有几txt。

以 Car.txt 举例

每行含义:

图片命名,分数,框的位置(左上右下)

2.3 结果文件--->mAP 计算需要的形式

结果文件形式如上图;

mAP需要的形式:

每个图片都形成一个标签文件,保存在 detection-results 文件夹下。

其中一个标签文件格式展示:

转换代码:

参考修改自 https://blog.csdn.net/qq_32761549/article/details/90054023#4__54

# detection_transfer.py
import os


def creat_mapping_dic(result_txt, threshold=0.0):  # 设置一个阈值,用来删掉置信度低的预测框信息

    mapping_dic = {}  # 创建一个字典,用来存放信息
    txt = open(result_txt, 'r').readlines()  # 按行读取TXT文件

    for info in txt:  # 提取每一行
        info = info.split()  # 将每一行(每个预测框)的信息切分开

        photo_name = info[0]  # 图片名称
        probably = float(info[1])  # 当前预测框的置信度
        if probably < threshold:
            continue
        else:
            xmin = int(float(info[2]))
            ymin = int(float(info[3]))
            xmax = int(float(info[4]))
            ymax = int(float(info[5]))

            position = [xmin, ymin, xmax, ymax,probably]

            if photo_name not in mapping_dic:  # mapping_dic的每个元素的key值为图片名称,value为一个二维list,其中存放当前图片的若干个预测框的位置
                mapping_dic[photo_name] = []

            mapping_dic[photo_name].append(position)

    return mapping_dic


def creat_result_txt(raw_txt_path, target_path, threshold=0.0):  # raw_txt_path为yolo按类输出的TXT的路径 target_path 为转换后的TXT存放路径

    all_files = os.listdir(raw_txt_path)  # 获取所以的原始txt

    for each_file in all_files:  # 遍历所有的原始txt文件,each_file为一个文件名,例如‘car.txt’

        each_file_path = os.path.join(raw_txt_path, each_file)  # 获取当前txt的路径
        map_dic = creat_mapping_dic(each_file_path, threshold=threshold)  # 对当前txt生成map_dic

        for each_map in map_dic:  # 遍历当前存放信息的字典
            target_txt = each_map + '.txt'  # 生成目标txt文件名
            target_txt_path = os.path.join(target_path, target_txt)  # 生成目标txt路径

            if target_txt not in os.listdir(target_path):
                txt_write = open(target_txt_path, 'w')  # 如果目标路径下没有这个目标txt文件,则创建它,即模式设置为“覆盖”
            else:
                txt_write = open(target_txt_path, 'a')  # 如果目标路径下有这个目标txt文件,则将模式设置为“追加”

            class_name = each_file[:-4]  # 获取当前原始txt的类名
            # txt_write.write(class_name)  # 对目标txt写入类名
            # txt_write.write('\n')  # 换行

            for info in map_dic[each_map]:  # 遍历某张图片的所有预测框信息
                txt_write.write(class_name)  # 对目标txt写入类名
                txt_write.write(' '+str(info[4])+' '+str(info[0]) + ' ' + str(info[1]) +
                                ' ' + str(info[2]) + ' ' + str(info[3]) + ' ')  # 写入预测框信息
                txt_write.write('\n')  # 换行


creat_result_txt('/home/studieren/lunwen/darknet/results',
                 '/home/studieren/lunwen/darknet/result_23',
                 threshold=0.1)
# 第一个文件路径,结果文件。   输入路径
# 第二个文件路径,生成单独标签的保存路径。   输出路径
# 阈值,筛选阈值以上的保存。

2.4 gt 标签 转换成 mAP 计算需要的形式

gt标签的形式:

label_02

源kitti labels 经过合并的标签。只有 car pedestrain and cyclist。

mAP需要的形式:

将文件保存在 ground-truth 文件夹下

转换代码:

# gt_transfer.py
import glob
import os
import string

txt_list = glob.glob('/home/studieren/lunwen/darknet/label_2/*.txt')

for item in txt_list:
    new_txt = []
    try:
        with open(item, 'r') as r_tdf:
            for each_line in r_tdf:
                label = each_line.strip().split(' ')
                info1=int(float(label[4]))
                info2=int(float(label[5]))
                info3=int(float(label[6]))
                info4=int(float(label[7]))
                simple_label=str(label[0])+' '+str(info1)+' '+str(info2)+' '+str(info3)+' '+str(info4)+'\n'

                new_txt.append(simple_label)  # 重新写入新的txt文件

            with open(item, 'w+') as w_tdf:  # w+是打开原文件将内容删除,另写新内容进去
                for temp in new_txt:
                    w_tdf.write(temp)

    except IOError as ioerr:
        print('File error:' + str(ioerr))

2.5 mAP计算

代码链接:https://github.com/Cartucho/mAP

参考自:https://blog.csdn.net/weixin_44791964/article/details/104695264#mAP_96

使用方法:

1.将生成好的detection-results, 放入input 中 ,注意文件名需要 是 detection-results。 

2.将生成好的ground-truth, 放入input 中 ,注意文件名需要 是 ground-truth

3.在 mAP-master 文件夹下,运行 python main.py

生成的结果:

看出来模型效果不是很好,因为只迭代了2000步。

测试了一下8800步的。

3. 步骤

1. 在darknet文件夹下,使用命令
./darknet detector valid kitti_data/kitti.data cfg/yolov3-kitti.cfg backupp/yolov3-kitti_2000.weights
生成结果默认保存在results 文件中
2.更改输入地址和输出地址
运行detection_transfer.py,并将结果保存在 input/detection-results 文件夹下。
3.更改gt_label地址
运行gt_transfer.py,并将结果保存在 input/ground-truth 文件夹下。
4.查看结果文件,保存在outputs文件夹下。

可以配合 darknet yolov3 训练 kitti数据集 https://blog.csdn.net/qq_40297851/article/details/104937740  食用。

还在码具体训练过程,包括log文件生成loss的图。不定期更新也可能草稿箱吃灰了。

猜你喜欢

转载自blog.csdn.net/qq_40297851/article/details/106999045