tensorflow学习2:将训练数据转为tfrecord

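目标检测的训练数据一般包括图像和对应的 XML 标注文件。这里用四边形标注目标:图像尺寸记录在 `<size>` 的 `<width>`、`<height>` 中;每个 `<object>` 包含类别 `<name>` 和 `<bndbox>`,`<bndbox>` 下依次给出四个顶点的坐标 x1, y1, x2, y2, x3, y3, x4, y4(代码只按顺序读取这些子节点的值)。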

转换为tfrecord文件

def _int64_feature(value):
    """Wrap a single integer ``value`` in a ``tf.train.Feature`` (Int64List)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
    """Wrap a single bytes ``value`` in a ``tf.train.Feature`` (BytesList)."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def read_xml_gtbox_and_label(xml_path, name_label_map=None):
    """Parse a VOC-style XML annotation whose objects are quadrilaterals.

    :param xml_path: the voc xml to read (anything ``ET.parse`` accepts:
        a file path or a file object).
    :param name_label_map: optional mapping from the UTF-8 encoded category
        name to its label; defaults to the module-level ``NAME_LABEL_MAP``.
    :return: ``(img_height, img_width, gtbox_label)``. Height and width come
        from the ``<size>`` element (``None`` when absent). ``gtbox_label``
        is an int32 array with one row per ``<bndbox>``: the bndbox child
        values in document order followed by the label, i.e.
        ``[x1, y1, x2, y2, x3, y3, x4, y4, label]`` for quadrilaterals.
    :raises ValueError: if an ``<object>`` has a ``<bndbox>`` but no ``<name>``.
    :raises KeyError: if a category name is not in the label map.
    """
    if name_label_map is None:
        name_label_map = NAME_LABEL_MAP

    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                elif child_item.tag == 'height':
                    img_height = int(child_item.text)

        elif child_of_root.tag == 'object':
            label = None
            object_boxes = []
            for child_item in child_of_root:
                if child_item.tag == 'name':
                    # The map is keyed by the UTF-8 encoded name. Under Python 2
                    # this is required for non-ASCII (e.g. Chinese) category names,
                    # because child_item.text is unicode there.
                    category = child_item.text.encode("utf-8")
                    label = name_label_map[category]
                elif child_item.tag == 'bndbox':
                    # [x1, y1, x2, y2, x3, y3, x4, y4], in document order.
                    object_boxes.append([int(node.text) for node in child_item])

            # Attach the label only after the whole <object> has been read,
            # so <name> may appear before or after <bndbox>.
            if object_boxes and label is None:
                raise ValueError(
                    'label is none, error: <object> without <name> in {}'.format(xml_path))
            for tmp_box in object_boxes:
                tmp_box.append(label)  # [x1, y1, x2, y2, x3, y3, x4, y4, label]
                box_list.append(tmp_box)

    gtbox_label = np.array(box_list, dtype=np.int32)
    return img_height, img_width, gtbox_label


def convert_pascal_to_tfrecord():
    '''
    Convert every xml/image pair under FLAGS.VOC_dir into a single TFRecord file.

    Each sample image becomes one tf.train.Example. Its ``features`` is built with
    tf.train.Features from a dict of tf.train.Feature values:
      - img_name          bytes
      - img_height        int64, taken from the xml <size>
      - img_width         int64, taken from the xml <size>
      - img               raw uint8 pixel bytes from cv2.imread
      - gtboxes_and_label raw int32 bytes of the box/label array
      - num_objects       int64, number of box rows
    _bytes_feature / _int64_feature wrap values into the matching list type.

    Samples are skipped with a message in three cases:
      - the image is missing;
      - cv2 cannot read the image;
      - the image size disagrees with the xml <size>. The reader reshapes the
        pixels using the xml height/width, so such a record could not be decoded.
    '''
    xml_path = FLAGS.VOC_dir + FLAGS.xml_dir
    image_path = FLAGS.VOC_dir + FLAGS.image_dir
    save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'

    # Glob once: the same list drives the loop and the progress-bar total.
    xml_files = glob.glob(xml_path + '/*.xml')
    total = len(xml_files)

    # For a compressed file, pass
    #   options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # (or GZIP) here, and build the reader with the same options.
    # The with-block guarantees the writer is flushed and closed.
    with tf.python_io.TFRecordWriter(path=save_path) as writer:
        for count, xml in enumerate(xml_files):
            # to avoid path error in different development platform
            xml = xml.replace('\\', '/')

            # splitext only strips the extension, so 'a.b.xml' -> 'a.b'.
            img_name = os.path.splitext(xml.split('/')[-1])[0] + FLAGS.img_format
            img_path = image_path + '/' + img_name

            if not os.path.exists(img_path):
                print('{} is not exist!'.format(img_path))
                continue

            img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml)

            img = cv2.imread(img_path)
            if img is None:
                print('{} cannot be read!'.format(img_path))
                continue
            if img.shape[:2] != (img_height, img_width):
                print('{}: image size {}x{} does not match xml size {}x{}, skipped!'.format(
                    img_path, img.shape[0], img.shape[1], img_height, img_width))
                continue

            # BytesList needs bytes. Under Python 2, str already is bytes;
            # under Python 3 the name must be encoded.
            name_bytes = img_name if isinstance(img_name, bytes) else img_name.encode('utf-8')

            feature = tf.train.Features(feature={
                'img_name': _bytes_feature(name_bytes),
                'img_height': _int64_feature(img_height),
                'img_width': _int64_feature(img_width),
                'img': _bytes_feature(img.tobytes()),
                'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
                'num_objects': _int64_feature(gtbox_label.shape[0])
            })

            example = tf.train.Example(features=feature)

            # write() takes a string, so the Example is serialized first.
            writer.write(example.SerializeToString())

            view_bar('Conversion progress', count + 1, total)

    print('\nConversion is complete!')

检查tfrecord文件是否有问题

import os
import tensorflow as tf

import sys

# Python 2 only: make UTF-8 the implicit str<->unicode codec, so the Chinese
# literals in this script can be mixed with unicode text.
# site.py removes sys.setdefaultencoding at start-up, hence reload(sys).
# reload(sys) can also reset the standard streams in some environments
# (e.g. IPython), so they are saved and restored around it.
# On Python 3 the default encoding is already 'utf-8', and neither the
# builtin reload() nor sys.setdefaultencoding exists, so the hack is skipped.
if sys.version_info[0] < 3:
    stdi, stdo, stde = sys.stdin, sys.stdout, sys.stderr
    reload(sys)  # noqa: F821 -- Python 2 builtin
    sys.setdefaultencoding('utf-8')
    sys.stdin, sys.stdout, sys.stderr = stdi, stdo, stde

def read_single_example_and_decode(filename_queue):
    """Read one serialized Example from ``filename_queue`` and decode its fields.

    If the TFRecord file was written with ZLIB compression, build the reader as
    ``tf.TFRecordReader(options=tf.python_io.TFRecordOptions(
    tf.python_io.TFRecordCompressionType.ZLIB))`` instead of the plain reader.

    Returns ``(img_name, img, gtboxes_and_label, num_objects)``:
      - img_name: scalar string tensor;
      - img: uint8 tensor reshaped to ``[img_height, img_width, 3]``;
      - gtboxes_and_label: int32 tensor reshaped to ``[-1, 9]``;
      - num_objects: int32 scalar.
    """
    feature_spec = {
        'img_name': tf.FixedLenFeature([], tf.string),
        'img_height': tf.FixedLenFeature([], tf.int64),
        'img_width': tf.FixedLenFeature([], tf.int64),
        'img': tf.FixedLenFeature([], tf.string),
        'gtboxes_and_label': tf.FixedLenFeature([], tf.string),
        'num_objects': tf.FixedLenFeature([], tf.int64),
    }

    # read() yields (key, value); only the serialized record is needed.
    _, record = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(serialized=record, features=feature_spec)

    height = tf.cast(parsed['img_height'], tf.int32)
    width = tf.cast(parsed['img_width'], tf.int32)

    # The pixels and boxes were stored as raw numpy bytes. decode_raw recovers
    # the flat buffers, and the image reshape uses the per-example (dynamic)
    # height and width.
    flat_pixels = tf.decode_raw(parsed['img'], tf.uint8)
    img = tf.reshape(flat_pixels, shape=[height, width, 3])

    flat_boxes = tf.decode_raw(parsed['gtboxes_and_label'], tf.int32)
    boxes = tf.reshape(flat_boxes, [-1, 9])

    return (parsed['img_name'], img, boxes,
            tf.cast(parsed['num_objects'], tf.int32))


 
directory = os.path.join('/home/yantianwang/rdfpn/data/tfrecord', 'hangtian_ship_train.tfrecord')
if not os.path.exists(directory):
    # Stop here: continuing would only fail later inside the input queue
    # with a much less obvious error.
    sys.exit('{} 不存在'.format(directory))
filename_tensorlist = tf.train.match_filenames_once(directory)   # matching file list (a local variable)
filename_queue = tf.train.string_input_producer(filename_tensorlist)  # file-name input queue
img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue)  # parse records
img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = tf.train.batch(
                       [img_name, img, gtboxes_and_label, num_objects],
                       batch_size = 1,
                       capacity=100,
                       num_threads=16,
                       dynamic_pad=True)
# match_filenames_once needs the local-variable initializer.
init = (tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for step in range(10000):
            print(step, sess.run(img_name_batch))
    finally:
        # Always stop and join the queue-runner threads. join() re-raises any
        # exception a reader thread hit (e.g. a failed reshape on bad data),
        # which is the real cause behind a "queue is closed" error.
        coord.request_stop()
        coord.join(threads)

如果测试代码运行无误,就说明 tfrecord 文件没有问题。如果出现如下错误:

PaddingFIFOQueue '_2_batch/padding_fifo_queue' is closed and has insufficient elements (requested 1, current size 0)

那么很大程度上说明你准备的数据有问题,需要检查样本和相应的 XML 文件,例如 XML 文件中记录的图像宽高与实际图像不一致、目标标注超出了图像范围等。

读取tfrecord文件生成batch

import tensorflow as tf
import os
from data.io import image_preprocess


def read_single_example_and_decode(filename_queue):
    """Pull one record from ``filename_queue`` and decode it into tensors.

    For a ZLIB-compressed TFRecord, construct the reader with
    ``options=tf.python_io.TFRecordOptions(
    tf.python_io.TFRecordCompressionType.ZLIB)``.

    Returns ``(img_name, img, gtboxes_and_label, num_objects)``:
      - img: uint8, shape ``[img_height, img_width, 3]``;
      - gtboxes_and_label: int32, shape ``[-1, 9]``;
      - num_objects: int32 scalar.
    """
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)  # (key, serialized record)

    scalar_int = tf.FixedLenFeature([], tf.int64)
    scalar_str = tf.FixedLenFeature([], tf.string)
    example = tf.parse_single_example(
        serialized=serialized,
        features={
            'img_name': scalar_str,
            'img_height': scalar_int,
            'img_width': scalar_int,
            'img': scalar_str,
            'gtboxes_and_label': scalar_str,
            'num_objects': scalar_int,
        })

    h = tf.cast(example['img_height'], tf.int32)
    w = tf.cast(example['img_width'], tf.int32)

    # Both byte fields hold raw numpy buffers: decode_raw restores the flat
    # values, and reshape restores the layout (the image with dynamic h/w).
    image = tf.reshape(tf.decode_raw(example['img'], tf.uint8), shape=[h, w, 3])
    boxes = tf.reshape(tf.decode_raw(example['gtboxes_and_label'], tf.int32), [-1, 9])
    count = tf.cast(example['num_objects'], tf.int32)

    return example['img_name'], image, boxes, count


def read_and_prepocess_single_img(filename_queue, shortside_len, is_training):
    """Decode one example and apply the train/eval preprocessing.

    The pixels are cast to float32 and per-channel means are subtracted.
    The image is then resized so its short side is ``shortside_len``, with
    the boxes transformed together by ``image_preprocess.short_side_resize``.
    A random left-right flip is added only when ``is_training`` is true.
    ``num_objects`` is passed through unchanged.
    """
    img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue)

    # Per-channel mean subtraction. These presumably are the VGG/ImageNet
    # means in BGR order, matching cv2.imread's channel order; confirm
    # against the pretrained backbone.
    channel_means = tf.constant([103.939, 116.779, 123.68])
    img = tf.cast(img, tf.float32) - channel_means

    # Resizing is common to training and evaluation.
    img, gtboxes_and_label = image_preprocess.short_side_resize(
        img_tensor=img, gtboxes_and_label=gtboxes_and_label,
        target_shortside_len=shortside_len)

    if is_training:
        img, gtboxes_and_label = image_preprocess.random_flip_left_right(
            img_tensor=img, gtboxes_and_label=gtboxes_and_label)

    return img_name, img, gtboxes_and_label, num_objects


def next_batch(dataset_name, batch_size, shortside_len, is_training,
               tfrecord_dir='../data/tfrecord'):
    """Build queue-based tensors that yield batches of preprocessed examples.

    :param dataset_name: prefix of the TFRecord files. It must be one of the
        supported names listed below; add your own dataset names there.
    :param batch_size: number of examples per batch.
    :param shortside_len: target short-side length for resizing.
    :param is_training: selects ``<dataset_name>_train*`` instead of
        ``<dataset_name>_test*`` files, and enables random flipping.
    :param tfrecord_dir: directory searched for the TFRecord files. The
        default is the original relative path, resolved against the current
        working directory.
    :return: ``(img_name_batch, img_batch, gtboxes_and_label_batch,
        num_obs_batch)``. Variable-size tensors are padded per batch
        (``dynamic_pad=True``).
    :raises ValueError: if ``dataset_name`` is not supported.
    """
    # Add your own dataset names here.
    valid_names = ('ship', 'spacenet', 'pascal', 'coco', 'hangtian_ship')
    if dataset_name not in valid_names:
        raise ValueError('dataSet name must be one of {}, got {!r}'.format(
            ', '.join(valid_names), dataset_name))

    split = 'train' if is_training else 'test'
    pattern = os.path.join(tfrecord_dir, dataset_name + '_' + split + '*')

    print('tfrecord path is -->', os.path.abspath(pattern))
    filename_tensorlist = tf.train.match_filenames_once(pattern)

    filename_queue = tf.train.string_input_producer(filename_tensorlist)

    img_name, img, gtboxes_and_label, num_obs = read_and_prepocess_single_img(
        filename_queue, shortside_len, is_training=is_training)
    img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = \
        tf.train.batch(
                       [img_name, img, gtboxes_and_label, num_obs],
                       batch_size=batch_size,
                       capacity=100,
                       num_threads=16,
                       dynamic_pad=True)
    return img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch

该部分代码包括了对数据的处理

猜你喜欢

转载自blog.csdn.net/Mr_health/article/details/88252652
今日推荐