目标检测--RFBNet训练自己制作数据集出现loss=nan问题的解决方法

之前用RFBNet进行目标检测,采用的数据集是VOC2007和VOC2012。最近用在自己的数据集进行训练,由于我的数据集格式跟VOC格式不一样,根据网上的经验,我就开始将自己的数据集制作成VOC格式的方便训练。但自己的数据集和标准的数据集质量真心不能比,有很多问题,花费了好多时间在数据处理上。。。
我遇到的问题主要是数据集的问题,而RFBNet是基于SSD的,所以SSD的如果出现这个问题大概率是一样的,当然其他目标检测网络也可以参考。

训练自己数据集

loss=nan问题
在制作完自己数据集后,训练RFBNet的时候,出现loss_l=nan的情况。
在这里插入图片描述
同时还出现RuntimeWarning:

Code/RFBNet-pytorch0.4.0/utils/box_utils.py:84: RuntimeWarning: invalid value encountered in true_divide
  return area_i / (area_a[:, np.newaxis] + area_b - area_i)

说明在代码的utils/box_utils.py的match_iou方法返回值中除法遇到了除数为0的情况。
在这里插入图片描述
解决方法
在网上查了一番,发现大家也都遇到类似的情况。
出现nan情况一般有以下集中可能:

  1. 数据问题,例如目标检测中可能出现bbox位置为(0,0),可能xmin>xmax
  2. 网络结构问题
  3. 训练问题等

在逼乎上某一大佬根据他自己经验给出一回答,将batch_size 调成1,shuffle调成False,查看到底哪些数据存在问题。受他启发,我也进行检查,将train_RFB.pybatch_size 调成1,shuffle调成False,ngpu、num_workers调成1,发现网络只是在某一iteration的时候出现loss=nan,所以基本确定是数据集的问题。

在这里插入图片描述
确定了是数据集的问题,还得确定是数据集的哪里出现了问题。这个过程实在是有点大海捞针的感觉,我试过将xml中bndbox值转成int,因为VOC中坐标值都是Int,object的name中特殊的字符映射到正常的字符等等。。。但这些都没能解决问题。

就在今天早上,读到一篇[微信推送]帮了我大忙。里面讲到bndbox的iou计算问题,其中有xmax-xmin,ymax-ymin。我就猜想会不会是这个问题。于是我先在VOC2007中检测是否存xmin>xmax的情况,结果发现没有!
然后我又在我的数据集中查找,结果发现,我去,1万条数据中有2000多条存在xmin>xmax!!!

检测xmin>xmax程序

import os
import xml.etree.ElementTree as ET 
xml_dir = './Annotations'
def compare_min_max(xml_dir):
    xmls = os.listdir(xml_dir)
    xmls.sort()
    flag = 0
    count = 0
    for xml in xmls:
        xml_path = os.path.join(xml_dir, xml)
        tree = ET.parse(xml_path)
        root = tree.getroot()
        for elem in root.findall('object'):
            xmin = elem.find('bndbox').find('xmin').text
            ymin = elem.find('bndbox').find('ymin').text
            xmax = elem.find('bndbox').find('xmax').text
            ymax = elem.find('bndbox').find('ymax').text
            if int(ymin) > int(ymax) or int(xmin) > int(xmax):
                print('min > max in file:',xml_path)
                flag = 1
        if flag == 1:
            count += 1
            flag = 0
    print('{} files that min > max'.format(count))
    print('finish comparision...')

if __name__ == '__main__':
    compare_min_max(xml_dir)

在发现自己数据集存在这个问题后,我就重新制作了一遍数据集,这次数据集没有出现xmin(ymin)>xmax(ymax)的情况了。

制作数据集程序

##将数据集中img和rext中的信息转成VOC annotation的 xml格式
import os
import cv2
import io
import pandas as pd

img_path = r'./JPEGImages'
rect_path = r'./rect'
xml_path = r'./Annotations'

#read images ,get image's w,h c and name
def read_image(filename):
    img = cv2.imread(filename)
    h, w, c = img.shape
    basename = os.path.basename(filename)
    return h, w, c, basename

#读取rect的txt中第一个空行之前的内容
def file_reader(filename):
    with open(filename) as f:
        for line in f:
            if line and line != '\n':
                yield line
            else:
                break
#获取目标字符类别以及bbox
def get_object_bbox(filename):
    bbox = []
    data = io.StringIO(''.join(file_reader(filename)))
    dataframe = pd.read_csv(data, skiprows=2, header=None)
    for row in dataframe.iterrows():
        if isinstance(row[1][2], str):
            row[1][2] = row[1][2].strip() #delete space in string
        r = [row[1][2],row[1][5], row[1][6], row[1][7], row[1][8]]
        bbox.append(r)
    return bbox #shape(n,5) n number of bndboxes, each bndbox has the form[object, xmin, ymin,xmax,ymax]

def write_xml(h, w, c, bbox, basename):
    front, extend = os.path.splitext(basename)
    front += '.xml'
    full_path = os.path.join(xml_path, front)
    with open(full_path,'w') as f:
        f.write('<annotation>\n')
        f.write('    <folder>OHWME</folder>\n')
        f.write('    <filename>' + str(basename) + '</filename>\n')
        f.write('    <source>\n')
        f.write('        <database>MyDataBase</database>\n')
        f.write('        <annotation>PASCAL VOC2007</annotation>\n')
        f.write('        <image>f</image>\n')
        f.write('    </source>\n')
        f.write('    <size>\n')
        f.write('        <width>' + str(w) + '</width>\n')
        f.write('        <height>' + str(h) + '</height>\n')
        f.write('        <depth>' + str(c) + '</depth>\n')
        f.write('    </size>\n')
        f.write('    <segmented>0</segmented>\n')
        for b in bbox:
            object = b[0]
            if object == '/':
                object = r'\backslash'
            if object == '.':
                object = r'\dot'
            xmin = b[1]
            ymin = b[2]
            xmax = b[3]
            ymax = b[4]
            f.write('    <object>\n')
            f.write('        <name>' + str(object) + '</name>\n')
            f.write('        <pose>Unspecified</pose>\n')
            f.write('        <truncated>0</truncated>\n')
            f.write('        <difficult>0</difficult>\n')
            f.write('        <bndbox>\n')
            ##avoid xmin,ymin > xmax,ymax
            if int(xmin) > int(xmax): 
                xmax, xmin = xmin, xmax
            if int(ymin) > int(ymax):
                ymax, ymin = ymin, ymax
            #avoid (0,0) which would probaly result in nan
            if int(xmin) < 1:
                f.write('            <xmin>' + str(int(xmin + 1)) + '</xmin>\n')
            else:
                f.write('            <xmin>' + str(int(xmin)) + '</xmin>\n')
            if int(ymin) < 1:
                f.write('            <ymin>' + str(int(ymin + 1)) + '</ymin>\n')
            else:
                f.write('            <ymin>' + str(int(ymin)) + '</ymin>\n')
            if int(xmax < 1):
                f.write('            <xmax>' + str(int(xmax + 1)) + '</xmax>\n')
            else:
                f.write('            <xmax>' + str(int(xmax)) + '</xmax>\n')
            if int(ymax < 1):
                f.write('            <ymax>' + str(int(ymax + 1)) + '</ymax>\n')
            else:
                f.write('            <ymax>' + str(int(ymax)) + '</ymax>\n')
            f.write('        </bndbox>\n')
            f.write('    </object>\n')
        f.write('</annotation>')

if __name__ == '__main__':
    img_names = os.listdir(img_path)
    rect_names = os.listdir(rect_path)
    img_names.sort()
    rect_names.sort()

    for img_name, rect_name in zip(img_names, rect_names):
        full_image_path = os.path.join(img_path, img_name)
        full_rect_path = os.path.join(rect_path, rect_name)
        h, w, c, basename = read_image(full_image_path)
        bbox = get_object_bbox(full_rect_path)
        print('writing {}\\{}.xml'.format(xml_path,os.path.splitext(basename)[0]))
        write_xml(h, w, c, bbox, basename)

训练结果
重新制作数据集后,训练过程如下:
在这里插入图片描述
在这里插入图片描述
从图中可以看出loss_l L已经正常并且开始下降,说明数据集格式正确了,结果由于网络还在训练,所以还没有test结果,但至少说明开始正确训练了,至于mAP能有多少还得调参哈哈哈(请叫我调参侠)。

总结

目标检测中若遇到loss为nan的情况,
首先,检查数据集格式问题。如bbox的xmin,ymin是否大于xmax,ymax,或者坐标是否存在为0的情况。
其次,检查网络结构是否存在问题。
还有,训练的方法是否有问题。

希望我的经验能帮助遇到类似问题的朋友,少掉点头发,少走点弯路。

参考

https://www.zhihu.com/question/49346370
https://mp.weixin.qq.com/s/TMRDhDrf5rRRFIdGGL8Uhg

猜你喜欢

转载自blog.csdn.net/weixin_40313940/article/details/105915575