利用TensorFlow Object Detection API的预训练模型训练自己的数据

利用TensorFlow Object Detection API的预训练模型训练自己的数据

1.前言介绍

  • pb文件为训练好的模型,可以直接拿来使用
  • ckpt文件就是预训练模型,用来训练自己的数据

2.前期准备

  • 准备一个保存收集图片的文件夹,包含Image和Annotations,分别用来保存图片和标注后的xml文件
  • 另外准备一个文件夹放训练有关的数据,里面包含三个下属文件data,export,model,分别用来存放训练可用的数据,生成的最终模型,训练产生的文件

2.1环境搭建

  • 配置Tensorflow环境,Windows或Ubuntu都可

2.2数据准备

  1. 根据自己训练需要收集所需要的图片

  2. 将所收集的图片进行排序后进行筛选然后再排序

    如果是处理自己采集的数据集,一定要先排序再筛选!!否则可能会遗漏掉一些本该筛选的图片在标注时增加自己的工作量

    我用的方法是按顺序对所有文件进行重命名

    import os
    i = 1
    for filename in os.listdir('D:/DataCollection/hand_data/Image/test/'):
    	newname = str(i) + '.jpg'
    	print(newname)
    	os.rename('D:/DataCollection/hand_data/Image/test/'+filename, 'D:/DataCollection/hand_data/Image/test/'+newname)
    	i += 1
    
  3. 对排序后的图片进行标注

    标注图片用的软件是labelImg,可以选择标注的图片位置(Image),以及生成的xml文件保存的位置即Annotations文件夹

    W是标注 D是下一张 A是上一张 空格保存

  4. 格式转换

    这里的生成的csv以及tfrecord文件都放在data文件夹下

    图片需转换成tensorflow可以识别的格式

    • 先由xml转为csv

      """
      将文件夹内所有XML文件的信息记录到CSV文件中
      """
      
      import os
      import glob
      import pandas as pd
      import xml.etree.ElementTree as ET
       
      os.chdir('E:/tensorflow/hand_data_new/hand_data/test')  #xml文件保存路径 使用时需改为自己的路径
      path = 'E:/tensorflow/hand_data_new/hand_data/test'
      
      
      def xml_to_csv(path):
          xml_list = []
          for xml_file in glob.glob(path + '/*.xml'):
              tree = ET.parse(xml_file)
              root = tree.getroot()
              print('test')
              for member in root.findall('object'):
                  value = (root.find('filename').text,
                           int(root.find('size')[0].text),
                           int(root.find('size')[1].text),
                           member[0].text,
                           int(member[4][0].text),
                           int(member[4][1].text),
                           int(member[4][2].text),
                           int(member[4][3].text)
                           )
                  xml_list.append(value)
          column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
          xml_df = pd.DataFrame(xml_list, columns=column_name)
          return xml_df
      
      
      def main():
          image_path = path
          xml_df = xml_to_csv(image_path)
          xml_df.to_csv('E:/tensorflow/hand_set/data/eval.csv', index=None)  #得到的csv文件保存路径
          print('Successfully converted xml to csv.')
      
      main()
      
    • 然后将csv文件转为tfrecord

      from __future__ import division
      from __future__ import print_function
      from __future__ import absolute_import
      
      import os
      import io
      import pandas as pd
      import tensorflow as tf
      
      from PIL import Image
      from object_detection.utils import dataset_util
      from collections import namedtuple, OrderedDict
      
      flags = tf.app.flags
      
      flags.DEFINE_string('csv_input', 'E:/tensorflow/hand_set/data/eval.csv', 'Path to the CSV input')#csv文件
      flags.DEFINE_string('output_path', 'E:/tensorflow/hand_set/data/eval.record', 'Path to output TFRecord')#TFRecord文件
      flags.DEFINE_string('image_dir', 'E:/tensorflow/hand_data_new/hand_data/Image/TEST', 'Path to images')#对应的图片位置
      
      FLAGS = flags.FLAGS
      
      # TO-DO replace this with label map
      #从1开始根据自己训练的类别数和标签来写
      def class_text_to_int(row_label):
          if row_label == 'DOWN':
              return 1
          elif row_label == 'FIVE':
              return 2
          else:
              None
      
      def split(df, group):
          data = namedtuple('data', ['filename', 'object'])
          gb = df.groupby(group)
          return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
      
      
      def create_tf_example(group, path):
          with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
      
              encoded_jpg = fid.read()
      
          encoded_jpg_io = io.BytesIO(encoded_jpg)
          image = Image.open(encoded_jpg_io)
          width, height = image.size
      
          filename = group.filename.encode('utf8')
          image_format = b'jpg'
          xmins = []
          xmaxs = []
          ymins = []
          ymaxs = []
          classes_text = []
          classes = []
      
          for index, row in group.object.iterrows():
              xmins.append(row['xmin'] / width)
              xmaxs.append(row['xmax'] / width)
              ymins.append(row['ymin'] / height)
              ymaxs.append(row['ymax'] / height)
              classes_text.append(row['class'].encode('utf8'))
              classes.append(class_text_to_int(row['class']))
      
          tf_example = tf.train.Example(features=tf.train.Features(feature={
              
              
      
              'image/height': dataset_util.int64_feature(height),
      
              'image/width': dataset_util.int64_feature(width),
      
              'image/filename': dataset_util.bytes_feature(filename),
      
              'image/source_id': dataset_util.bytes_feature(filename),
      
              'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      
              'image/format': dataset_util.bytes_feature(image_format),
      
              'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      
              'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      
              'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      
              'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      
              'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      
              'image/object/class/label': dataset_util.int64_list_feature(classes),
      
          }))
      
          return tf_example
      
      
      def main(_):
      
          writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
      
          path = os.path.join(FLAGS.image_dir)
      
          examples = pd.read_csv(FLAGS.csv_input)
      
          grouped = split(examples, 'filename')
      
          for group in grouped:
      
              tf_example = create_tf_example(group, path)
      
              writer.write(tf_example.SerializeToString())
      
      
          writer.close()
      
          output_path = os.path.join(os.getcwd(), FLAGS.output_path)
      
          print('Successfully created the TFRecords: {}'.format(output_path))
      
      
      
      if __name__ == '__main__':
      
          tf.app.run()
  5. 训练数据准备完以后还需要准备一个pbtxt文件

    例如hand.pbtxt,放在data文件夹里

    内容如下,根据自己的类别数而定

    item {
      id: 1
      name: 'DOWN'
    }
    
    item{
      id: 2
      name: 'FIVE'
    }
    

2.3模型准备

下载Tensorflow模型

下载地址:https://github.com/tensorflow/models

下载protoc

下载地址:https://github.com/protocolbuffers/protobuf/releases

下载预训练模型

下载地址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md

  1. 利用protoc编译TensorFlow Object Detection API,转换为py文件
  2. 建立一个专门的文件夹来保存预训练模型,记住下载路径,之后会用到里面的ckpt文件
  3. 在下载的Tensorflow模型的文件下找到models\research\object_detection\samples\configs,在里面找到自己所用的预训练模型对应的config文件,拷贝一份放在最初建立的model文件夹下

3.训练过程

3.1修改配置文件(config文件)

​ 以下仅仅是列出最主要的修改,其他有关训练配置可根据实际情况再做调整

  • 改成自己训练的类别数量

    num_classes: 4
    
  • 根据自己的机器性能适当修改也可以不改

    batch_size: 24
    
  • 改成所用的预训练模型路径

    fine_tune_checkpoint: "E:/tensorflow/pretrained_models/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/model.ckpt"
    
  • 训练所需的tfrecord文件路径,测试的则改为测试集的路径

      tf_record_input_reader {
        input_path: "E:/tensorflow/hand_set/data/train.record"
      }
    
  • 标签映射文件,即pbtxt文件位置,训练与测试共用一个

     label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"
    

3.2开始训练

执行语句

python E:/tensorflow/models/research/object_detection/legacy/train.py --train_dir=E:/tensorflow/hand_set/model/model5 --pipeline_config_path=E:/tensorflow/hand_set/model/ssd_mobilenet_v2_quantized_300x300_coco.config --logtostderr
  • train.py在下载的Tensorflow模型文件夹下
  • train_dir是训练时的数据保存位置,放在最初建立的model文件夹下,因为我训练了多个模型,因此我在model文件夹下建立了多个子文件夹命名为model1等等,例如此例子我将该模型保存在model/model5中
  • pipeline是config文件的位置,我放在model文件下

3.3保存模型

最初建立的三个文件夹data是用来存放数据集的,而model是训练时的数据,主要包括各个检查点对应的能够生成模型的ckpt文件,以及训练过程中的信息,而export就是保存我们导出的模型

执行语句

python E:/tensorflow/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path E:/tensorflow/hand_set/model/ssd_mobilenet_v2_coco.config  --trained_checkpoint_prefix E:/tensorflow/hand_set/model/model4/model.ckpt-5997  --output_directory E:/tensorflow/hand_set/export/model4/
  • export_inference_graph.py在下载的Tensorflow模型文件夹下
  • pipeline_config_path位置同上
  • trained_checkpoint_prefix选择效果最好的检查点来生成模型,一般选择最新的
  • output_directory模型保存路径
  • 最后生成的pb文件就是我们可以用的模型

3.4Tensorboard实时查看训练效果

win+r,输入cmd,执行语句,输入刚刚训练保存模型的绝对路径,记得输入绝对路径不容易出错

tensorboard --logdir=E:\tensorflow\hand_set\model\model5

然后在浏览器里输入https://localhost:6006 就可以查看训练效果了

4.测试结果

  • 在Tensorflow模型文件夹下tensorflow\models\research\object_detection 找到object_detection_tutorial.ipynb文件,将代码复制出来

  • 将模型修改为我们自己训练的模型地址,即pb文件的地址

    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    PATH_TO_FROZEN_GRAPH =
    'E:/tensorflow/hand_set/export/model4/frozen_inference_graph.pb'
  • pbtxt文件地址也改为我们自己的文件地址

    PATH_TO_LABELS = os.path.join('E:/tensorflow/hand_set/data', 'object_detection.pbtxt')
  • 设置测试图片路径

    PATH_TO_TEST_IMAGES_DIR = 'test_images'
    TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
  • 也可以改为摄像头实时测试

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            while  True:
                ret, image = capture.read()
                if ret is True:
                    image_np_expanded = np.expand_dims(image, axis=0)
                    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                    scores = detection_graph.get_tensor_by_name('detection_scores:0')
                    classes = detection_graph.get_tensor_by_name('detection_classes:0')
                    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    
                    (boxes,scores,classes,num_detections)=sess.run([boxes, scores, classes, num_detections],
                                                                    feed_dict={
          
          image_tensor: image_np_expanded})
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        min_score_thresh=0.6, #置信度
                        use_normalized_coordinates=True,
                        line_thickness=4
                    )
                    c = cv.waitKey(5)
                    if c == 27:  # ESC
                        break
                    cv.imshow("Demo", image)
                else:
                    break
            cv.waitKey(0)
            cv.destoryAllWindows()
  • 运行代码

猜你喜欢

转载自blog.csdn.net/Lianhaiyan_zero/article/details/107638516