[Notes] Training YOLOv3 on Your Own Dataset (3): Tips and Training Log Visualization

Environment: Ubuntu 18.04, GTX 1080

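4 Tips

4.1 When .weights files are saved

Stock darknet saves a .weights file every 100 iterations while the iteration count is below 1000, and every 10000 iterations after that. This is painful when you have few classes or a small dataset, so you can edit the source to save every 1000 iterations instead.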

gedit darknet/examples/detector.c

On line 138 (the line number may differ in your checkout), change if(i%10000==0 || (i < 1000 && i%100 == 0)) to if(i%1000==0), then recompile. A .weights file is now saved every 1000 iterations.
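
To see what the change does to the save schedule, the two conditions can be mirrored in a small Python sketch. The function names and the iteration limits are illustrative and not part of darknet, and the sketch assumes the iteration counter starts at 1.

```python
def saves_original(i):
    # Mirrors the stock C condition: i%10000==0 || (i < 1000 && i%100 == 0)
    return i % 10000 == 0 or (i < 1000 and i % 100 == 0)


def saves_modified(i):
    # Mirrors the edited C condition: i%1000==0
    return i % 1000 == 0


def checkpoint_iterations(should_save, max_iter):
    # List every iteration in 1..max_iter at which a .weights file is written.
    return [i for i in range(1, max_iter + 1) if should_save(i)]


if __name__ == "__main__":
    # With a 5000-iteration run, the stock rule stops saving after iteration 900.
    print(checkpoint_iterations(saves_original, 5000))
    # The modified rule saves at 1000, 2000, 3000, 4000 and 5000.
    print(checkpoint_iterations(saves_modified, 5000))
```

For a short run this shows why the stock rule is awkward: between iteration 900 and iteration 10000 no numbered .weights file is written at all.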

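4.2 Resuming training from a checkpoint

To resume training, make sure the yolov3.backup file under backup/ still exists, then pass it as the weights argument. darknet names this file after the .cfg file, so check the actual file name in backup/ before running the command.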

./darknet detector train cfg/voc.data cfg/yolov3-voc.cfg backup/yolov3.backup -gpus 0 | tee new_train_yolov3.log
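
The choice between resuming and starting fresh can be sketched in Python. This is only an illustration: pick_start_weights and build_train_command are hypothetical helpers, the paths are the ones used in these notes, and the GPU flag and the tee redirection are left out.

```python
import os


def pick_start_weights(backup_path, pretrained_path):
    # Resume from the checkpoint if it exists, otherwise fall back to the
    # pretrained convolutional weights.
    return backup_path if os.path.isfile(backup_path) else pretrained_path


def build_train_command(weights, data="cfg/voc.data", cfg="cfg/yolov3-voc.cfg"):
    # Argument order follows the training command shown in these notes.
    return ["./darknet", "detector", "train", data, cfg, weights]


if __name__ == "__main__":
    weights = pick_start_weights("backup/yolov3.backup", "darknet53.conv.74")
    print(" ".join(build_train_command(weights)))
```

The printed command can be pasted into a shell from the darknet directory, with the GPU option and log redirection added as needed.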

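4.3 NaN

It is normal for nan values to make up about 30% of the on-screen output during training. If the share is much higher, or everything is nan, something has gone wrong with training.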
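
Instead of judging the share by eye, it can be estimated from the saved log. The sketch below is a rough check that counts lines rather than screen area; nan_fraction is a hypothetical helper and the sample lines are made up for illustration.

```python
def nan_fraction(lines):
    # Share of non-empty log lines that contain "nan".
    total = 0
    with_nan = 0
    for line in lines:
        if not line.strip():
            continue
        total += 1
        if "nan" in line:
            with_nan += 1
    return with_nan / total if total else 0.0


if __name__ == "__main__":
    # Made-up sample lines; in practice pass an open log file instead.
    sample = [
        "Region 82 Avg IOU: 0.31, Class: 0.53, Obj: 0.44",
        "Region 94 Avg IOU: -nan, Class: -nan, Obj: -nan",
        "Region 106 Avg IOU: 0.28, Class: 0.49, Obj: 0.40",
        "1: 1452.9, 1452.9 avg, 0.0 rate, 1.8 seconds, 64 images",
    ]
    print(nan_fraction(sample))  # one of four lines contains nan
```

Because a file object iterates over its lines, nan_fraction(open_file) works the same way on a real log.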
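Reducing nan
① Increase batch in darknet/cfg/yolov3.cfg; note that the value should be a multiple of 32.
② Enlarge the dataset; a 200-image training set is too small. If you cannot collect more images, use data augmentation, which can grow the dataset to more than ten times its original size.
For other problems, see the notes by more experienced users that the original post links to.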
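
As one example of data augmentation, a horizontal flip doubles the dataset, provided the labels are flipped together with the images. The sketch below covers only the label side and assumes the usual darknet label format of "class x_center y_center width height" with coordinates normalized to 0..1; hflip_yolo_label is a hypothetical helper, and mirroring the image itself (with any image library) is not shown.

```python
def hflip_yolo_label(line):
    # Parse one label line: class id plus normalized box centre and size.
    cls, xc, yc, w, h = line.split()
    # A horizontal flip mirrors the box centre around x = 0.5;
    # y, width and height are unchanged.
    flipped_xc = 1.0 - float(xc)
    return "{} {:.6f} {:.6f} {:.6f} {:.6f}".format(
        cls, flipped_xc, float(yc), float(w), float(h))


def hflip_yolo_labels(text):
    # Flip every non-empty line of a label file's contents.
    return "\n".join(hflip_yolo_label(l) for l in text.splitlines() if l.strip())


if __name__ == "__main__":
    print(hflip_yolo_label("0 0.25 0.5 0.1 0.2"))
```

Other transforms (crops, rotations, color jitter) need their own label handling, so each one should be checked against a few images before it is used to enlarge the training set.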

5 Visualizing the training log

This requires that the training log xxx.log was saved during training:

./darknet detector train cfg/voc.data cfg/yolov3-voc.cfg darknet53.conv.74 -gpus 0 | tee train_yolov3.log

After training finishes, run the script below.

# -*- coding: utf-8 -*-
# @Time    : 2018/12/30 16:26
# @Author  : lazerliu
# @File    : vis_yolov3_log.py
# @Func    : Visualize the YOLOv3 training log. Put this script in the same directory as the log file and run it.

import pandas as pd
import matplotlib.pyplot as plt
import os

# ================== Settings you may need to change ======================#
g_log_path = "train_yolov3.log"  # change this to your training log file name
# ==========================================================================#

def extract_log(log_file, new_log_file, key_word):
    '''
    :param log_file: path of the full training log
    :param new_log_file: output file that receives only the useful lines
    :param key_word: keep only lines that contain this keyword
    :return: None
    '''
    # The with statement closes both files, so no explicit close() is needed.
    with open(log_file, "r") as f, open(new_log_file, "w") as train_log:
        for line in f:
            # skip multi-GPU sync messages
            if "Syncing" in line:
                continue
            # skip lines that contain nan
            if "nan" in line:
                continue
            if key_word in line:
                train_log.write(line)


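def drawAvgLoss(loss_log_path):
    '''
    :param loss_log_path: file with the extracted loss lines
    :return: None; draws the average-loss curve
    '''
    # Count lines with a context manager so the file handle is closed
    # (the "rU" mode was removed in Python 3.11).
    with open(loss_log_path, "r") as f:
        line_cnt = sum(1 for _ in f)
    # Skip the first 500 lines of the extracted loss log.
    # on_bad_lines needs pandas >= 1.3; older versions use error_bad_lines=False.
    result = pd.read_csv(loss_log_path, skiprows=[iter_num for iter_num in range(line_cnt) if iter_num < 500],
                         on_bad_lines="skip",
                         names=["loss", "avg", "rate", "seconds", "images"])
    # The "avg" field looks like " 12.34 avg"; keep only the number.
    result["avg"] = result["avg"].str.split(" ").str.get(1)
    result["avg"] = pd.to_numeric(result["avg"])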

    fig = plt.figure(1, figsize=(6, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(result["avg"].values, label="Avg Loss", color="#ff7043")
    ax.legend(loc="best")
    ax.set_title("Avg Loss Curve")
    ax.set_xlabel("Batches")
    ax.set_ylabel("Avg Loss")


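def drawIOU(iou_log_path):
    '''
    :param iou_log_path: file with the extracted IOU lines
    :return: None; draws the IOU curve
    '''
    with open(iou_log_path, "r") as f:
        line_cnt = sum(1 for _ in f)
    # Keep every 39th line and skip the first 5000 lines. The original wrote
    # `x % 39 != 0 | (x < 5000)`; `|` binds tighter than `!=`, so it did not
    # express that condition. `or` does.
    # Adjust the column names if your log lines have a different field count.
    # on_bad_lines needs pandas >= 1.3; older versions use error_bad_lines=False.
    result = pd.read_csv(iou_log_path, skiprows=[x for x in range(line_cnt) if (x % 39 != 0 or x < 5000)],
                         on_bad_lines="skip",
                         names=["Region Avg IOU", "Class", "Obj", "No Obj", "Avg Recall", "count"])
    # The field looks like "Region Avg IOU: 0.31"; keep only the number.
    result["Region Avg IOU"] = result["Region Avg IOU"].str.split(": ").str.get(1)

    result["Region Avg IOU"] = pd.to_numeric(result["Region Avg IOU"])

    # Work on a copy so the smoothing below never writes into the DataFrame.
    result_iou = result["Region Avg IOU"].to_numpy(copy=True)
    # Smooth the IOU curve: where neighbours differ by more than 0.2,
    # replace the current point with their mean.
    for i in range(len(result_iou) - 1):
        iou = result_iou[i]
        iou_next = result_iou[i + 1]
        if abs(iou - iou_next) > 0.2:
            result_iou[i] = (iou + iou_next) / 2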

    fig = plt.figure(2, figsize=(6, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(result_iou, label="Region Avg IOU", color="#ff7043")
    ax.legend(loc="best")
    ax.set_title("Avg IOU Curve")
    ax.set_xlabel("Batches")
    ax.set_ylabel("Avg IOU")


if __name__ == "__main__":
    loss_log_path = "train_log_loss.txt"
    iou_log_path = "train_log_iou.txt"
    if not os.path.exists(g_log_path):
        # exit with a message instead of failing silently
        raise SystemExit("training log not found: " + g_log_path)
    # the extracted files are reused if they already exist
    if not os.path.exists(loss_log_path):
        extract_log(g_log_path, loss_log_path, "images")
    if not os.path.exists(iou_log_path):
        extract_log(g_log_path, iou_log_path, "IOU")
    drawAvgLoss(loss_log_path)
    drawIOU(iou_log_path)
    plt.show()
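
To make the parsing step concrete, here is a minimal sketch that extracts the numbers from a single loss line with a regular expression instead of pandas. It assumes loss lines of the form "1: 1452.927612, 1452.927612 avg, 0.000000 rate, 1.877576 seconds, 64 images"; parse_loss_line and LOSS_RE are hypothetical names, and the pattern should be adjusted if your log differs.

```python
import re

# iteration: loss, average loss avg, ...
LOSS_RE = re.compile(r"^\s*(\d+):\s*([\d.]+),\s*([\d.]+) avg,")


def parse_loss_line(line):
    # Return (iteration, loss, avg_loss), or None if the line does not match.
    # Lines whose loss is nan do not match the numeric pattern and give None.
    m = LOSS_RE.match(line)
    if m is None:
        return None
    return int(m.group(1)), float(m.group(2)), float(m.group(3))


if __name__ == "__main__":
    line = "1: 1452.927612, 1452.927612 avg, 0.000000 rate, 1.877576 seconds, 64 images"
    print(parse_loss_line(line))
```

Parsing line by line like this also makes it easy to plot against the real iteration number rather than the row index.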

[End]

Reposted from blog.csdn.net/csdn_zhishui/article/details/85397380