目标检测的训练数据一般包括图像和对应的标注xml文件,这里以四边形标注目标,如下:
转换为tfrecord文件
def _int64_feature(value):
    """
    Wrap a single integer in a tf.train.Feature.

    :param value: one integer; it is stored as a one-element Int64List
    :return: a tf.train.Feature whose int64_list holds ``value``
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """
    Wrap a single byte string in a tf.train.Feature.

    Text values (``unicode`` on Python 2, ``str`` on Python 3) are encoded as
    UTF-8 first, because a BytesList only accepts byte strings. Values that
    are already byte strings are passed through unchanged.

    :param value: a byte string, or text that will be UTF-8 encoded
    :return: a tf.train.Feature whose bytes_list holds the (encoded) value
    """
    # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3, so this
    # check only touches text and never re-encodes raw bytes.
    if isinstance(value, type(u'')):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def read_xml_gtbox_and_label(xml_path):
    """
    Parse one annotation xml and collect the image size and the labelled boxes.

    :param xml_path: path of the annotation xml file
    :return: a tuple (img_height, img_width, gtbox_label).
             img_height / img_width are read from the <size> element and stay
             None when no <height> / <width> child is found.
             gtbox_label is an int32 numpy array with one row per <bndbox>:
             every child value of the <bndbox>, in document order, followed by
             the label id looked up in NAME_LABEL_MAP. For quadrilateral
             annotations this is presumably
             [x1, y1, x2, y2, x3, y3, x4, y4, label] - confirm against the xml.
    """
    img_width = None
    img_height = None
    rows = []

    for element in ET.parse(xml_path).getroot():
        if element.tag == 'size':
            for dim in element:
                if dim.tag == 'width':
                    img_width = int(dim.text)
                elif dim.tag == 'height':
                    img_height = int(dim.text)
        elif element.tag == 'object':
            # The label is reset for every object, so a <bndbox> that is not
            # preceded by a <name> inside the same object trips the assert.
            label = None
            for field in element:
                if field.tag == 'name':
                    # The element text is encoded to a utf-8 byte string before
                    # the lookup; the original author notes this is needed when
                    # class names are Chinese. For English class names the
                    # text could be used directly instead.
                    category = field.text.encode("utf-8")
                    label = NAME_LABEL_MAP[category]
                elif field.tag == 'bndbox':
                    row = [int(node.text) for node in field]
                    assert label is not None, 'label is none, error'
                    rows.append(row + [label])

    gtbox_label = np.array(rows, dtype=np.int32)
    return img_height, img_width, gtbox_label
def convert_pascal_to_tfrecord():
    """
    Convert every annotated image found under FLAGS.VOC_dir into one tfrecord.

    Each sample is written as one tf.train.Example. An Example carries a
    tf.train.Features message, which is a dictionary of tf.train.Feature
    values (each one a BytesList, FloatList or Int64List). Here the keys are:

    - img_name: image file name (bytes)
    - img_height / img_width: size of the decoded image (int64)
    - img: raw bytes of the array returned by cv2.imread
    - gtboxes_and_label: raw bytes of the int32 box/label array
    - num_objects: number of rows in that array (int64)

    Samples whose image file is missing or cannot be decoded are skipped with
    a message. The height and width stored in the record are taken from the
    decoded image, because readers reshape the raw ``img`` bytes with exactly
    these two values; a warning is printed when the xml disagrees.

    :return: None; the result is the file
             FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'
    """
    xml_path = FLAGS.VOC_dir + FLAGS.xml_dir
    image_path = FLAGS.VOC_dir + FLAGS.image_dir
    save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'

    # List the annotation files once; the total is only needed for the
    # progress bar and does not have to be recomputed per sample.
    xml_files = glob.glob(xml_path + '/*.xml')
    total = len(xml_files)

    # To write a compressed file instead, build the writer with
    # options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # and create the reader with the same options.
    writer = tf.python_io.TFRecordWriter(path=save_path)
    try:
        for count, xml in enumerate(xml_files):
            # to avoid path error in different development platform
            xml = xml.replace('\\', '/')
            # splitext keeps dots that are part of the base name ("a.b.xml" -> "a.b").
            img_name = os.path.splitext(xml.split('/')[-1])[0] + FLAGS.img_format
            img_path = image_path + '/' + img_name

            if not os.path.exists(img_path):
                print('{} is not exist!'.format(img_path))
                continue

            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports an unreadable file by returning None.
                print('{} cannot be decoded, skipped!'.format(img_path))
                continue

            xml_height, xml_width, gtbox_label = read_xml_gtbox_and_label(xml)
            img_height, img_width = int(img.shape[0]), int(img.shape[1])
            if (xml_height, xml_width) != (img_height, img_width):
                print('{}: xml size {}x{} differs from image size {}x{}, '
                      'using the image size'.format(img_name, xml_height, xml_width,
                                                    img_height, img_width))

            feature = tf.train.Features(feature={
                'img_name': _bytes_feature(img_name),
                'img_height': _int64_feature(img_height),
                'img_width': _int64_feature(img_width),
                'img': _bytes_feature(img.tobytes()),
                'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
                'num_objects': _int64_feature(gtbox_label.shape[0])
            })
            example = tf.train.Example(features=feature)
            # TFRecordWriter.write takes the serialized (string) form of the Example.
            writer.write(example.SerializeToString())

            view_bar('Conversion progress', count + 1, total)
    finally:
        # Always close the writer so buffered records reach the file.
        writer.close()

    print('\nConversion is complete!')
检查tfrecord文件是否有问题
import os
import tensorflow as tf
import sys
# Python 2 only: switch the interpreter-wide default string encoding to utf-8.
# Keep references to the current standard streams first, so they can be put
# back after sys is reloaded.
# NOTE(review): presumably needed because reload(sys) can replace redirected
# streams (IDE / notebook) - confirm.
stdi, stdo, stde = sys.stdin, sys.stdout, sys.stderr
# Reloading sys makes sys.setdefaultencoding callable again (it is removed
# from the module at interpreter start-up on Python 2).
reload(sys)
sys.setdefaultencoding('utf-8')
# Restore the stream objects saved above.
sys.stdin, sys.stdout, sys.stderr = stdi, stdo, stde
def read_single_example_and_decode(filename_queue):
    """
    Read one serialized Example from the queue and decode its fields.

    :param filename_queue: queue of tfrecord file names, passed to
                           TFRecordReader.read
    :return: a tuple (img_name, img, gtboxes_and_label, num_objects) of tensors;
             img is uint8 reshaped to [img_height, img_width, 3],
             gtboxes_and_label is int32 reshaped to [-1, 9],
             num_objects is cast to int32
    """
    # If the file was written with ZLIB compression, build the reader with
    # options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # instead of the plain reader below.
    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(filename_queue)

    # Every field was stored as a single scalar: either bytes or int64.
    string_scalar = tf.FixedLenFeature([], tf.string)
    int64_scalar = tf.FixedLenFeature([], tf.int64)
    feature_spec = {
        'img_name': string_scalar,
        'img_height': int64_scalar,
        'img_width': int64_scalar,
        'img': string_scalar,
        'gtboxes_and_label': string_scalar,
        'num_objects': int64_scalar
    }
    parsed = tf.parse_single_example(serialized=serialized_example,
                                     features=feature_spec)

    img_name = parsed['img_name']
    height = tf.cast(parsed['img_height'], tf.int32)
    width = tf.cast(parsed['img_width'], tf.int32)

    # The image was stored as raw array bytes, so decode_raw turns it back
    # into a flat uint8 tensor; the shape is only known at run time.
    flat_img = tf.decode_raw(parsed['img'], tf.uint8)
    img = tf.reshape(flat_img, shape=[height, width, 3])

    flat_boxes = tf.decode_raw(parsed['gtboxes_and_label'], tf.int32)
    gtboxes_and_label = tf.reshape(flat_boxes, [-1, 9])

    num_objects = tf.cast(parsed['num_objects'], tf.int32)
    return img_name, img, gtboxes_and_label, num_objects
# Sanity check: read the tfrecord back through the queue-based input pipeline
# and print the image name of each decoded sample.
directory = os.path.join('/home/yantianwang/rdfpn/data/tfrecord', 'hangtian_ship_train.tfrecord')
if not os.path.exists(directory):
    print('不存在')
    # Without the file the input queue can never be filled, so stop here
    # instead of failing later with a queue error.
    sys.exit(1)

filename_tensorlist = tf.train.match_filenames_once(directory)  # list of matching files
filename_queue = tf.train.string_input_producer(filename_tensorlist)  # file name input queue
img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue)  # decode one sample
img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = tf.train.batch(
    [img_name, img, gtboxes_and_label, num_objects],
    batch_size=1,
    capacity=100,
    num_threads=16,
    dynamic_pad=True)

# match_filenames_once stores its result in a local variable, so local
# variables have to be initialized as well.
init = (tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for step in range(10000):
            print(step,sess.run(img_name_batch))
    finally:
        # Stop and join the queue-runner threads even when a read fails.
        coord.request_stop()
        coord.join(threads)
如果测试的代码运行无误,那么就说明tfrecord文件没有问题。如果出现问题:
PaddingFIFOQueue '_2_batch/padding_fifo_queue' is closed and has insufficient elements (requested 1, current size 0)
那么很大程度上说明你准备的数据有问题,需要检查一下样本和相应的xml文件有无问题,比如xml文件中记录的图像长宽与图像不一致、目标的标注超过了图像的范围等等。
读取tfrecord文件生成batch
import tensorflow as tf
import os
from data.io import image_preprocess
def read_single_example_and_decode(filename_queue):
    """
    Pull one record from the tfrecord queue and turn it back into tensors.

    :param filename_queue: queue of tfrecord file names consumed by the reader
    :return: (img_name, img, gtboxes_and_label, num_objects); img is a uint8
             tensor of shape [img_height, img_width, 3], gtboxes_and_label an
             int32 tensor of shape [-1, 9], num_objects an int32 scalar
    """
    # For a ZLIB-compressed file, create the reader with
    # options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB).
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    bytes_field = tf.FixedLenFeature([], tf.string)
    int_field = tf.FixedLenFeature([], tf.int64)
    record = tf.parse_single_example(
        serialized=serialized_example,
        features={
            'img_name': bytes_field,
            'img_height': int_field,
            'img_width': int_field,
            'img': bytes_field,
            'gtboxes_and_label': bytes_field,
            'num_objects': int_field
        }
    )

    img_name = record['img_name']
    rows = tf.cast(record['img_height'], tf.int32)
    cols = tf.cast(record['img_width'], tf.int32)

    # Raw array bytes -> flat uint8 tensor -> image with a run-time shape.
    pixels = tf.decode_raw(record['img'], tf.uint8)
    img = tf.reshape(pixels, shape=[rows, cols, 3])

    # Raw int32 bytes -> rows of 9 values each.
    box_values = tf.decode_raw(record['gtboxes_and_label'], tf.int32)
    gtboxes_and_label = tf.reshape(box_values, [-1, 9])

    num_objects = tf.cast(record['num_objects'], tf.int32)
    return img_name, img, gtboxes_and_label, num_objects
def read_and_prepocess_single_img(filename_queue, shortside_len, is_training):
    """
    Decode one sample and apply the image preprocessing.

    The image is cast to float32, three per-channel constants are subtracted,
    and image and boxes are resized via image_preprocess.short_side_resize.
    When is_training is true, image_preprocess.random_flip_left_right is
    applied afterwards.

    :param filename_queue: queue of tfrecord file names
    :param shortside_len: passed to short_side_resize as target_shortside_len
    :param is_training: Python truth value selecting the training-only flip
    :return: (img_name, img, gtboxes_and_label, num_objects)
    """
    img_name, raw_img, gtboxes_and_label, num_objects = \
        read_single_example_and_decode(filename_queue)

    float_img = tf.cast(raw_img, tf.float32)
    # NOTE(review): presumably per-channel pixel means in the channel order
    # produced by cv2.imread - confirm.
    img = float_img - tf.constant([103.939, 116.779, 123.68])

    # The resize is common to training and evaluation.
    img, gtboxes_and_label = image_preprocess.short_side_resize(
        img_tensor=img,
        gtboxes_and_label=gtboxes_and_label,
        target_shortside_len=shortside_len)

    if is_training:
        img, gtboxes_and_label = image_preprocess.random_flip_left_right(
            img_tensor=img,
            gtboxes_and_label=gtboxes_and_label)

    return img_name, img, gtboxes_and_label, num_objects
def next_batch(dataset_name, batch_size, shortside_len, is_training):
    """
    Build the queue-based input pipeline and return one batch of tensors.

    :param dataset_name: one of the supported dataset names listed below
    :param batch_size: number of samples per batch
    :param shortside_len: target short-side length used when resizing images
    :param is_training: selects the '<dataset_name>_train*' files (and the
                        training preprocessing) when true, otherwise the
                        '<dataset_name>_test*' files
    :return: (img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch)
    :raises ValueError: if dataset_name is not a supported name
    """
    # Add the name of your own dataset here.
    supported = ('ship', 'spacenet', 'pascal', 'coco', 'hangtian_ship')
    if dataset_name not in supported:
        raise ValueError('dataSet name must be one of: ' + ', '.join(supported))

    suffix = '_train*' if is_training else '_test*'
    # The pattern is relative to the current working directory.
    pattern = os.path.join('../data/tfrecord', dataset_name + suffix)

    print('tfrecord path is -->', os.path.abspath(pattern))
    filename_tensorlist = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer(filename_tensorlist)

    img_name, img, gtboxes_and_label, num_obs = read_and_prepocess_single_img(
        filename_queue, shortside_len, is_training=is_training)

    # dynamic_pad pads each tensor to the largest shape in the batch, which is
    # needed because images and box lists differ in size between samples.
    img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = \
        tf.train.batch(
            [img_name, img, gtboxes_and_label, num_obs],
            batch_size=batch_size,
            capacity=100,
            num_threads=16,
            dynamic_pad=True)
    return img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch
该部分代码包括了对数据的处理