Use the pre-trained model of TensorFlow Object Detection API to train your own data


1. Introduction

  • The pb file is a frozen, trained model that can be used directly for inference
  • The ckpt files are checkpoints of the pre-trained model, used as the starting point for training on your own data

2. Preliminary preparation

  • Prepare a folder for the collected pictures, containing two subfolders: Image for the pictures themselves and Annotations for the labeled xml files
  • In addition, prepare a folder for the training-related data, containing three subfolders: data, export, and model, which hold the prepared training data, the final exported model, and the files generated during training, respectively (see the layout sketch below)
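
  For reference, a possible layout using the example paths that appear later in this article (the top-level folder names are illustrative):

    hand_data/
      Image/          # the collected pictures
      Annotations/    # the xml files produced by labelling
    hand_set/
      data/           # csv files, tfrecord files, and the pbtxt label map
      model/          # the config file plus everything generated during training
      export/         # the exported pb model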

2.1 Environment Construction

  • Configure the TensorFlow environment, on either Windows or Ubuntu
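
    The scripts in this article use TF1-style APIs (tf.app.flags, tf.python_io, tf.Session), so a TensorFlow 1.x installation is assumed; a minimal setup sketch (the package versions are illustrative):

      pip install tensorflow==1.15 pillow pandas
      pip install labelImg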

2.2 Data preparation

  1. Collect the required pictures according to your own training needs

  2. Organize the collected pictures: filter them first, then rename them in order

    If you are processing a data set you collected yourself, be sure to filter first and then rename!! Otherwise some pictures that should have been filtered out may be missed, which increases the workload when labeling

    The method I use is to rename all files sequentially:

    import os

    # Rename every picture in the folder to a sequential number: 1.jpg, 2.jpg, ...
    # (the folder should not already contain files with these target names)
    img_dir = 'D:/DataCollection/hand_data/Image/test/'
    i = 1
    for filename in sorted(os.listdir(img_dir)):
        newname = str(i) + '.jpg'
        print(newname)
        os.rename(img_dir + filename, img_dir + newname)
        i += 1
    
  3. Annotate the sorted images

    The software used to label images is labelImg. In it you choose the folder of images to label (Image) and set the save directory for the generated xml files to the Annotations folder

    Useful shortcuts: W creates a bounding box, D goes to the next image, A goes to the previous image, and Ctrl+S saves

  4. format conversion

    The csv and tfrecord files generated here are placed in the data folder

    The annotations and images need to be converted into a format that TensorFlow can read

    • First convert from xml to csv

      """
      将文件夹内所有XML文件的信息记录到CSV文件中
      """
      
      import os
      import glob
      import pandas as pd
      import xml.etree.ElementTree as ET
       
      os.chdir('E:/tensorflow/hand_data_new/hand_data/test')  #xml文件保存路径 使用时需改为自己的路径
      path = 'E:/tensorflow/hand_data_new/hand_data/test'
      
      
      def xml_to_csv(path):
          xml_list = []
          for xml_file in glob.glob(path + '/*.xml'):
              tree = ET.parse(xml_file)
              root = tree.getroot()
              print('test')
              for member in root.findall('object'):
                  value = (root.find('filename').text,
                           int(root.find('size')[0].text),
                           int(root.find('size')[1].text),
                           member[0].text,
                           int(member[4][0].text),
                           int(member[4][1].text),
                           int(member[4][2].text),
                           int(member[4][3].text)
                           )
                  xml_list.append(value)
          column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
          xml_df = pd.DataFrame(xml_list, columns=column_name)
          return xml_df
      
      
      def main():
          image_path = path
          xml_df = xml_to_csv(image_path)
          xml_df.to_csv('E:/tensorflow/hand_set/data/eval.csv', index=None)  #得到的csv文件保存路径
          print('Successfully converted xml to csv.')
      
      main()
      
    • Then convert the csv file to tfrecord

      from __future__ import division
      from __future__ import print_function
      from __future__ import absolute_import
      
      import os
      import io
      import pandas as pd
      import tensorflow as tf
      
      from PIL import Image
      from object_detection.utils import dataset_util
      from collections import namedtuple, OrderedDict
      
      flags = tf.app.flags
      
      flags.DEFINE_string('csv_input', 'E:/tensorflow/hand_set/data/eval.csv', 'Path to the CSV input')  # the csv file
      flags.DEFINE_string('output_path', 'E:/tensorflow/hand_set/data/eval.record', 'Path to output TFRecord')  # the TFRecord file to write
      flags.DEFINE_string('image_dir', 'E:/tensorflow/hand_data_new/hand_data/Image/TEST', 'Path to images')  # the corresponding image folder
      
      FLAGS = flags.FLAGS
      
      # TO-DO replace this with label map
      # Starting from 1, fill this in according to your own classes and labels;
      # the ids must match the pbtxt label map prepared in step 5
      def class_text_to_int(row_label):
          if row_label == 'DOWN':
              return 1
          elif row_label == 'FIVE':
              return 2
          else:
              return None
      
      def split(df, group):
          data = namedtuple('data', ['filename', 'object'])
          gb = df.groupby(group)
          return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
      
      
      def create_tf_example(group, path):
          with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
              encoded_jpg = fid.read()

          encoded_jpg_io = io.BytesIO(encoded_jpg)
          image = Image.open(encoded_jpg_io)
          width, height = image.size
      
          filename = group.filename.encode('utf8')
          image_format = b'jpg'
          xmins = []
          xmaxs = []
          ymins = []
          ymaxs = []
          classes_text = []
          classes = []
      
          for index, row in group.object.iterrows():
              xmins.append(row['xmin'] / width)
              xmaxs.append(row['xmax'] / width)
              ymins.append(row['ymin'] / height)
              ymaxs.append(row['ymax'] / height)
              classes_text.append(row['class'].encode('utf8'))
              classes.append(class_text_to_int(row['class']))
      
          tf_example = tf.train.Example(features=tf.train.Features(feature={
              'image/height': dataset_util.int64_feature(height),
              'image/width': dataset_util.int64_feature(width),
              'image/filename': dataset_util.bytes_feature(filename),
              'image/source_id': dataset_util.bytes_feature(filename),
              'image/encoded': dataset_util.bytes_feature(encoded_jpg),
              'image/format': dataset_util.bytes_feature(image_format),
              'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
              'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
              'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
              'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
              'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
              'image/object/class/label': dataset_util.int64_list_feature(classes),
          }))
      
          return tf_example
      
      
      def main(_):
          writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
          path = os.path.join(FLAGS.image_dir)
          examples = pd.read_csv(FLAGS.csv_input)
          grouped = split(examples, 'filename')
          for group in grouped:
              tf_example = create_tf_example(group, path)
              writer.write(tf_example.SerializeToString())
          writer.close()
          output_path = os.path.join(os.getcwd(), FLAGS.output_path)
          print('Successfully created the TFRecords: {}'.format(output_path))


      if __name__ == '__main__':
          tf.app.run()
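
    The flag defaults above already point at the example paths, so the script can be run as-is; run it once for the training set and once for the evaluation set. If it is saved as, say, generate_tfrecord.py (the file name is arbitrary), the paths can also be overridden on the command line, e.g. (paths illustrative):

      python generate_tfrecord.py --csv_input=E:/tensorflow/hand_set/data/train.csv --output_path=E:/tensorflow/hand_set/data/train.record --image_dir=E:/tensorflow/hand_data_new/hand_data/Image/TRAIN
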
  5. After the training data is prepared, a pbtxt label map file also needs to be prepared

    For example, hand.pbtxt, placed in the data folder

    The content is as follows, extended according to your own number of categories; the ids and names must match class_text_to_int in the conversion script above

    item {
      id: 1
      name: 'DOWN'
    }
    
    item {
      id: 2
      name: 'FIVE'
    }
    

2.3 Model preparation

Download the TensorFlow models repository

Download address: https://github.com/tensorflow/models

Download protoc

Download address: https://github.com/protocolbuffers/protobuf/releases

Download the pretrained model

Download address: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md

  1. Use protoc to compile the TensorFlow Object Detection API's .proto files into .py files (see the command after this list)
  2. Create a dedicated folder to save the pre-trained model and remember the path; the ckpt files inside it will be used later
  3. In the downloaded TensorFlow models repository, find models\research\object_detection\samples\configs, locate the config file corresponding to the pre-trained model you are using, and copy it into the model folder created at the beginning
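
For reference, the compilation command from the official installation instructions, run from the models/research directory (on some Windows protoc versions the wildcard must be expanded file by file):

protoc object_detection/protos/*.proto --python_out=.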

3. Training process

3.1 Modify the configuration file (config file)

The following lists only the most important modifications; other training settings can be adjusted according to your actual situation

  • Change num_classes to the number of categories you are training

    num_classes: 4
    
  • Adjust batch_size according to your machine's performance, or leave it unchanged

    batch_size: 24
    
  • Change fine_tune_checkpoint to the path of the pre-trained model you are using

    fine_tune_checkpoint: "E:/tensorflow/pretrained_models/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/model.ckpt"
    
  • Set the path of the tfrecord file required for training, and likewise the path of the evaluation set in the eval section (combined sketch after this list)

      tf_record_input_reader {
        input_path: "E:/tensorflow/hand_set/data/train.record"
      }
    
  • Set label_map_path, i.e. the location of the pbtxt file; the same path is used by both training and evaluation

     label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"
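
Putting these together, the input reader blocks of the config typically look like this (a sketch using the example paths above; the eval block mirrors the eval.record generated earlier):

    train_input_reader: {
      tf_record_input_reader {
        input_path: "E:/tensorflow/hand_set/data/train.record"
      }
      label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"
    }

    eval_input_reader: {
      tf_record_input_reader {
        input_path: "E:/tensorflow/hand_set/data/eval.record"
      }
      label_map_path: "E:/tensorflow/hand_set/data/object_detection.pbtxt"
    }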
    

3.2 Start training

Execute the statement:

python E:/tensorflow/models/research/object_detection/legacy/train.py --train_dir=E:/tensorflow/hand_set/model/model5 --pipeline_config_path=E:/tensorflow/hand_set/model/ssd_mobilenet_v2_quantized_300x300_coco.config --logtostderr
  • train.py is under the downloaded TensorFlow models folder (models/research/object_detection/legacy)
  • train_dir is where data is saved during training; it is placed under the model folder created at the beginning. Because I have trained several models, I created subfolders named model1, model2, and so on under the model folder; in this example the model is saved in model/model5
  • pipeline_config_path is the location of the config file, which I put under the model folder

3.3 Save the model

Of the three folders created at the beginning, data stores the data sets; model holds the data produced during training, mainly the ckpt files for each checkpoint (from which a model can be generated) plus the logs of the training process; export is where our exported models are saved

Execute the statement:

python E:/tensorflow/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path E:/tensorflow/hand_set/model/ssd_mobilenet_v2_coco.config  --trained_checkpoint_prefix E:/tensorflow/hand_set/model/model4/model.ckpt-5997  --output_directory E:/tensorflow/hand_set/export/model4/
  • export_inference_graph.py is under the downloaded TensorFlow models folder (models/research/object_detection)
  • pipeline_config_path is the same as above
  • trained_checkpoint_prefix selects the checkpoint from which to generate the model; generally choose the latest one
  • output_directory is the model save path
  • The generated pb file is the model we can use directly

3.4 Tensorboard to view the training effect in real time

Press Win+R, type cmd, and execute the statement below with the absolute path of the model directory just used for training; entering the absolute path makes mistakes less likely

tensorboard --logdir=E:\tensorflow\hand_set\model\model5

Then open http://localhost:6006 in the browser to view the training progress

4. Test results

  • Find the object_detection_tutorial.ipynb file under models\research\object_detection in the downloaded TensorFlow models folder, and copy its code out

  • Change the model path to our own trained model, that is, the address of the pb file

    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    PATH_TO_FROZEN_GRAPH = 'E:/tensorflow/hand_set/export/model4/frozen_inference_graph.pb'
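
    In the notebook, this pb file is loaded into a detection_graph roughly as follows; it is shown here because the camera example further below relies on detection_graph being defined (standard TF1 loading code):

    import tensorflow as tf

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
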
  • The pbtxt file path is likewise changed to our own file's address

    PATH_TO_LABELS = os.path.join('E:/tensorflow/hand_set/data', 'object_detection.pbtxt')
  • Set test image path

    PATH_TO_TEST_IMAGES_DIR = 'test_images'
    TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
  • The test can also be changed to run in real time from a camera; the sketch below assumes the notebook's detection_graph, category_index, and vis_util are already defined, and adds the imports and VideoCapture needed to run it

    import cv2 as cv
    import numpy as np

    capture = cv.VideoCapture(0)  # open the default camera

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            while True:
                ret, image = capture.read()
                if ret is True:
                    # add a batch dimension: (H, W, 3) -> (1, H, W, 3)
                    image_np_expanded = np.expand_dims(image, axis=0)
                    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                    scores = detection_graph.get_tensor_by_name('detection_scores:0')
                    classes = detection_graph.get_tensor_by_name('detection_classes:0')
                    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

                    (boxes, scores, classes, num_detections) = sess.run(
                        [boxes, scores, classes, num_detections],
                        feed_dict={image_tensor: image_np_expanded})

                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        min_score_thresh=0.6,  # confidence threshold
                        use_normalized_coordinates=True,
                        line_thickness=4
                    )
                    cv.imshow("Demo", image)
                    c = cv.waitKey(5)
                    if c == 27:  # ESC to quit
                        break
                else:
                    break
            capture.release()
            cv.destroyAllWindows()
  • Run the code

Origin: blog.csdn.net/Lianhaiyan_zero/article/details/107638516