利用Transfer Learning Toolkit训练自己的数据集(ubuntu系统)(二:训练模型)

上一篇讲到了环境搭建,这一篇我将继续分析训练模型的过程

首先用的数据集是kitti格式的,刚开始我也是一脸蒙蔽,因为之前用的oc和coco等数据集格式,然后为就看了一下kitti里面的东西
在这里插入图片描述
一行有15个数据,然后我就去百度了一下kitti数据的参数解释,以及看了官网给的解释
在这里插入图片描述
具体如下

第1个字符串:代表物体类别
第2个数:代表物体是否被截断
从0(非截断)到1(截断)浮动,其中truncated指离开图像边界的对象
第3个数:代表物体是否被遮挡
整数0,1,2,3表示被遮挡的程度
0:完全可见  1:小部分遮挡  2:大部分遮挡 3:完全遮挡(unknown)
第4个数:alpha,物体的观察角度,范围:-pi~pi
第5~8这4个数:物体的2维边界框
xmin,ymin,xmax,ymax
第9~11这3个数:3维物体的尺寸
高、宽、长(单位:米)
第12~14这3个数:3维物体的位置
 x,y,z(在照相机坐标系下,单位:米)
第15个数:3维物体的空间方向:rotation_y
在照相机坐标系下,物体的全局方向角(物体前进方向与相机坐标系x轴的夹角),范围:-pi~pi

然后我就寻思着是不是可以直接通过voc数据集转化过去呢?答案肯定是可以的
因为voc可以转到coco数据集,博主之前也弄过yolo4
博主从网上找了一些口罩的数据集(里面包含了label的标签),然后为通过如下的代码

import xml.etree.ElementTree as ET
import os

base_xml_dir = "./label1/"
xml_list = os.listdir(base_xml_dir)
kitti_saved_dir = "./label_2/"


def convert_annotation(file_name):
    in_file = open(base_xml_dir + file_name)
    tree = ET.parse(in_file)
    root = tree.getroot()

    with open(kitti_saved_dir + file_name[:-4] + '.txt', 'w') as f:
        for obj in root.iter('object'):
            cls = obj.find('name').text
            xmlbox = obj.find('bndbox')
            """
                第5~8这4个数:物体的2维边界框
                xmin,ymin,xmax,ymax
            """
            xmin, ymin, xmax, ymax = xmlbox.find('xmin').text, xmlbox.find('ymin').text, \
                                     xmlbox.find('xmax').text, xmlbox.find('ymax').text
            f.write(cls + " " + '0.0' + " " + '0' + " " + '1.0' + " " + str(xmin) + '.0' + " "
                    + str(ymin) + '.0' + " " + str(xmax) + '.0' + " " + str(ymax) + '.0' + " " +
                    str((int(str(ymax)) - int(str(ymin)))/int(1000) )+ " " + str((int(str(xmax)) - int(str(xmin)) )/int(1000))+ " " + '0.1' + " " + '1.0' + " " + '0.0' + " " + '1.0' + " " + '0.0' + '\n')


for i in xml_list:
    convert_annotation(i)

代码其实很简单,把我们voc数据集得不到的东西全部用来填入,但这里需要注意,填入的是整形还是浮点型(官网给的除了第一个是字符窗,第三个是整形其余到是浮点型)

至于voc数据集怎么制作,可以去看我的tensorflow objection api 的文章
至此数据集准备完毕,接下来我们需要看他的文件夹是怎么弄的
进入tlt-experiments,然后新建data文件夹

cd /workspace/tlt-experiments
sudo mkdir data

然后进入data文件,新建一个testing和training文件夹

在这里插入图片描述
在这里插入图片描述
文件夹名字可以自己改,后期我们可以自定义
image_2放的是图片
label_2放的是txt文件

接着我们进入jupyter(就算刚才打开的浏览器)

在这里插入图片描述
第三行的key需要我们自己输入,就是文章一里面提到的
在这里插入图片描述
然后直接跳到下图的位置
在这里插入图片描述
接着找到下图的文件夹,打开ssd_tfrecords_kitti_trainval.txt
在这里插入图片描述

kitti_config {
    
    
  root_directory_path: "/workspace/tlt-experiments/data/training"
  image_dir_name: "image_2"    ##如果和我一样可以不用改,如果你上面自定义文件夹名字,这边需要改
  label_dir_name: "label_2"   ##
  image_extension: ".png"      #自己图片的格式
  partition_mode: "random"
  num_partitions: 2        
  val_split: 14     
  num_shards: 10
}
image_directory_path: "/workspace/tlt-experiments/data/training"

然后执行下方的三步,如果出错,赵找报错的原因,一般不会报错在这里插入图片描述
然后执行这里的四步,这里的第三步可能下载有点慢
在这里插入图片描述
然后我们打开下方文件夹的!cat $SPECS_DIR/ssd_train_resnet18_kitti.txt,改最重要的东西
在这里插入图片描述

random_seed: 42
ssd_config {
    
    
  aspect_ratios_global: "[1.0, 2.0, 0.5, 3.0, 1.0/3.0]"
  scales: "[0.05, 0.1, 0.25, 0.4, 0.55, 0.7, 0.85]"
  two_boxes_for_ar1: true
  clip_boxes: false
  loss_loc_weight: 0.8
  focal_loss_alpha: 0.25
  focal_loss_gamma: 2.0
  variances: "[0.1, 0.1, 0.2, 0.2]"
  arch: "resnet"    ## 网络的类型,如果是mobilenet_v2,需要自己改
  nlayers: 18
  freeze_bn: false
}
training_config {
    
    
  batch_size_per_gpu: 24    ## 建议改小,不然训练会报错
  num_epochs: 80
  learning_rate {
    
    
  soft_start_annealing_schedule {
    
    
    min_learning_rate: 5e-5
    max_learning_rate: 2e-2
    soft_start: 0.15
    annealing: 0.5
    }
  }
  regularizer {
    
    
    type: L1
    weight: 3e-06
  }
}
eval_config {
    
    
  validation_period_during_training: 10
  average_precision_mode: SAMPLE
  batch_size: 32       ## 建议改小,不然训练会报错
  matching_iou_threshold: 0.5
}
nms_config {
    
    
  confidence_threshold: 0.01
  clustering_iou_threshold: 0.6
  top_k: 200
}
augmentation_config {
    
    
  preprocessing {
    
    
    output_image_width: 1248
    output_image_height: 384
    output_image_channel: 3
    crop_right: 1248
    crop_bottom: 384
    min_bbox_width: 1.0
    min_bbox_height: 1.0
  }
  spatial_augmentation {
    
    
    hflip_probability: 0.5
    vflip_probability: 0.0
    zoom_min: 0.7
    zoom_max: 1.8
    translate_max_x: 8.0
    translate_max_y: 8.0
  }
  color_augmentation {
    
    
    hue_rotation_max: 25.0
    saturation_shift_max: 0.20000000298
    contrast_scale_max: 0.10000000149
    contrast_center: 0.5
  }
}
dataset_config {
    
    
  data_sources: {
    
    
    tfrecords_path: "/workspace/tlt-experiments/data/tfrecords/kitti_trainval/kitti_trainval*"
    image_directory_path: "/workspace/tlt-experiments/data/training"
  }
  image_extension: "png"      #3 图片格式


## 改为自己的标签名字
  target_class_mapping {
    
    
      key: "have_mask"
      value: "have_mask"
  }
  target_class_mapping {
    
    
      key: "no_mask"
      value: "no_mask"
  }
  
validation_fold: 0
}


改完记得保存

然后运行下面的2步
在这里插入图片描述
然后你改了网络,那么下图的文件路径也要改
在这里插入图片描述

如果到这里都没有问题,那么下面的基本都是ok的

在这里插入图片描述

教程至此基本结束了,感谢支持

猜你喜欢

转载自blog.csdn.net/weixin_44868057/article/details/107213001
今日推荐