TFRecords数据的生成与读取

TFRecords 是 TensorFlow 官方支持的标准数据格式,可以将所有信息(包括图片数据和标注信息)写入同一个 tfrecords 文件中,便于统一管理数据。TFRecords 是二进制文件,需要使用特定的方式进行写入和读取。

需要将下图中的image图片文件和anno.txt文件生成为tfrecords格式。


anno.txt中标注数据如下(每行以空格分隔,依次为:图片路径、类别标签、xmin、ymin、xmax、ymax):

E:/anno/image/Drivingdecordr_01.jpg 1 648 512 793 549
E:/anno/image/Drivingdecordr_02.jpg 1 1330 446 1388 477
E:/anno/image/Drivingdecordr_03.jpg 1 26 509 95 543
E:/anno/image/Drivingdecordr_04.jpg 1 1430 437 1503 466
E:/anno/image/Drivingdecordr_05.jpg 1 1582 521 1714 567
E:/anno/image/Drivingdecordr_06.jpg 1 888 502 1053 547
E:/anno/image/Drivingdecordr_07.jpg 1 1419 453 1493 478
E:/anno/image/Drivingdecordr_08.jpg 1 31 495 102 530
E:/anno/image/Drivingdecordr_09.jpg 1 862 471 989 508
以下代码中,trans2tfrecords() 将标注文件及其对应的图片转换为 tfrecords 格式数据,read_tfrecords() 则用于在训练模型时读取 tfrecords 格式数据。
# -*- coding: utf-8 -*-

import tensorflow as tf
import cv2
import sys
import random

def get_data(data_dir):
    """Parse an annotation list file into a list of example dicts.

    Every non-blank line must contain six whitespace-separated fields, in
    this order: ``filename label xmin ymin xmax ymax``.

    Args:
        data_dir: Path of the annotation text file (a file path, despite
            the parameter name).

    Returns:
        A list with one dict per annotated line, in file order. Each dict
        has the keys ``'filename'`` (str), ``'label'`` (int) and ``'box'``
        (a dict with float values under ``'xmin'``, ``'ymin'``, ``'xmax'``
        and ``'ymax'``).

    Raises:
        IndexError: If a non-blank line has fewer than six fields.
        ValueError: If the label or a coordinate is not numeric.
    """
    dataset = []
    # The context manager guarantees the handle is closed even when a
    # malformed line raises part-way through the file.
    with open(data_dir, 'r') as imagelist:
        for line in imagelist:
            info = line.split()
            if not info:
                # Blank lines (typically a trailing newline) carry no data.
                continue
            box = {
                'xmin': float(info[2]),
                'ymin': float(info[3]),
                'xmax': float(info[4]),
                'ymax': float(info[5]),
            }
            dataset.append({
                'filename': info[0],
                'label': int(info[1]),
                'box': box,
            })
    return dataset

def _int64_feature(value):
    """Build a ``tf.train.Feature`` carrying int64 data.

    A ``value`` that is not already a list is wrapped in a one-element
    list before being stored.
    """
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)


def _float_feature(value):
    """Build a ``tf.train.Feature`` carrying float data.

    A ``value`` that is not already a list is wrapped in a one-element
    list before being stored.
    """
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)


def _bytes_feature(value):
    """Build a ``tf.train.Feature`` carrying bytes data.

    A ``value`` that is not already a list is wrapped in a one-element
    list before being stored.
    """
    values = value if isinstance(value, list) else [value]
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)

def trans2tfrecords(tf_filename, dataset):
    """Write every example of ``dataset`` into a single TFRecord file.

    Args:
        tf_filename: Path of the TFRecord file to create.
        dataset: Sequence of dicts with the keys ``'filename'``,
            ``'label'`` and ``'box'`` (the latter holding ``'xmin'``,
            ``'ymin'``, ``'xmax'`` and ``'ymax'``).

    Each record is a serialized ``tf.train.Example`` with three features:
    ``'image'`` (bytes returned by ``extract``), ``'label'`` (int64) and
    ``'roi'`` (four floats ordered xmin, ymin, xmax, ymax). One progress
    line per example is written to stdout.
    """
    total = len(dataset)
    with tf.python_io.TFRecordWriter(tf_filename) as record_writer:
        for position, item in enumerate(dataset, start=1):
            sys.stdout.write('\r>> Converting image %d/%d\n' % (position, total))

            raw_pixels = extract(item['filename'])
            label_id = item['label']
            bbox = item['box']
            corners = [bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']]

            # An Example wraps a Features message, which in turn maps
            # feature names to individual Feature values.
            feature_map = {
                'image': _bytes_feature(raw_pixels),
                'label': _int64_feature(label_id),
                'roi': _float_feature(corners),
            }
            record = tf.train.Example(
                features=tf.train.Features(feature=feature_map))
            record_writer.write(record.SerializeToString())
	
def extract(filename):
    """Load an image with OpenCV and return its pixel buffer as bytes.

    Args:
        filename: Path of the image file to load.

    Returns:
        The decoded image array flattened to a ``bytes`` object. The
        array's shape is not stored in the result, so a reader must know
        the image dimensions to rebuild the array.

    Raises:
        IOError: If OpenCV could not read or decode ``filename``.
    """
    image = cv2.imread(filename)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of an AttributeError.
        raise IOError('Failed to read image: %s' % filename)
    # tobytes() replaces the deprecated ndarray.tostring() alias.
    return image.tobytes()
	
def read_tfrecords(tfrecord_file, image_shape=(128, 128, 3)):
    """Build graph ops that read one example from a TFRecord file.

    Args:
        tfrecord_file: Path of the TFRecord file to read.
        image_shape: Shape the decoded image bytes are reshaped to. Every
            stored image must contain exactly this many uint8 values,
            otherwise the reshape fails when the tensor is evaluated. The
            writer in this file stores images at their original size, so
            pass the real image shape (or resize images before writing).

    Returns:
        A tuple ``(image, label, roi)`` of tensors:
        ``image`` is float32 with shape ``image_shape``, scaled as
        ``(pixel - 127.5) / 128``; ``label`` is the stored int64 label
        cast to float32; ``roi`` is a float32 tensor of length 4.
    """
    # shuffle=True shuffles the filenames in the queue; with a single file
    # it does not change the order of the records inside that file.
    filename_queue = tf.train.string_input_producer([tfrecord_file], shuffle=True)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Feature names and types must mirror what was written to the file.
    image_features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'roi': tf.FixedLenFeature([4], tf.float32),
        })

    # The image was stored as raw bytes; decode to uint8 and restore shape.
    image = tf.decode_raw(image_features['image'], tf.uint8)
    image = tf.reshape(image, list(image_shape))
    image = (tf.cast(image, tf.float32) - 127.5) / 128

    label = tf.cast(image_features['label'], tf.float32)
    roi = tf.cast(image_features['roi'], tf.float32)
    return image, label, roi
	
if __name__ == '__main__':
    # Annotation list: one "filename label xmin ymin xmax ymax" line per image.
    data_dir = 'E:/anno/anno.txt'

    output_dir = 'E:/anno'
    tf_filename = '%s/train_data.tfrecord' % output_dir

    # Must be a real boolean: any non-empty string (even 'False') is truthy.
    shuffling = True

    dataset = get_data(data_dir)
    if shuffling:
        # Randomize example order before writing the records.
        random.shuffle(dataset)

    trans2tfrecords(tf_filename, dataset)
    print('Finished converting dataset!')

    # NOTE: this only builds the reading ops; nothing is evaluated here.
    image, label, roi = read_tfrecords(tf_filename)
    
	

猜你喜欢

转载自blog.csdn.net/ghy_111/article/details/79624402