读取 TensorFlow Object Detection 的 TFRecord 文件

数据太多，原始图片和标签（label）已经找不到了，手头只剩下 TFRecord 文件，因此记录一下读取过程。

# -*- coding: utf-8 -*-

import cv2
import os
import tensorflow as tf
import numpy as np

slim = tf.contrib.slim

# Command-line options. The default TFRecord path points at the original
# author's local dataset; override it with --tfrecord_path.
flags = tf.app.flags
FLAG = flags.FLAGS
flags.DEFINE_string(
    'tfrecord_path',
    'F:/F_Datasets/Car/plate_bbox/ReadTFrecord/plate_green.record',
    'path to tfrecord file')
flags.DEFINE_integer('resize_height', 800,
                     'resize height of image')
flags.DEFINE_integer('resize_width', 800,
                     'resize width of image')


def print_data(image, resized_image, label, height, width, xmins, ymins, xmaxs, ymaxs,
               max_consecutive_errors=1000, report_every=499):
    """Run the queue-based input pipeline once and print every decoded record.

    Reading stops when the input is exhausted (``tf.errors.OutOfRangeError``),
    or after more than ``max_consecutive_errors`` consecutive failed reads.

    Args:
        image: Decoded image tensor (as returned by ``read_tfrecord``).
        resized_image: Resized image tensor. It is accepted for interface
            compatibility but is not fetched or printed.
        label: Class-label tensor.
        height: Tensor holding the stored ``image/height`` value.
        width: Tensor holding the stored ``image/width`` value.
        xmins, ymins, xmaxs, ymaxs: Bounding-box coordinate tensors.
        max_consecutive_errors: Give up after more than this many consecutive
            failed ``sess.run`` calls.
        report_every: The first failure of a run of failures is always
            reported. After that, a report is printed every ``report_every``
            consecutive failures.
    """
    fetches = [image, label, height, width, xmins, ymins, xmaxs, ymaxs]
    with tf.Session() as sess:
        # Without these three initializers, DatasetDataProvider with
        # num_epochs=1 raises an error (presumably because the epoch counter
        # is kept in a local variable).
        sess.run(tf.tables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        i = -1
        error_cnt = 0
        try:
            while True:
                i += 1
                try:
                    (print_image, print_label, print_height, print_width,
                     print_xmins, print_ymins, print_xmaxs, print_ymaxs) = sess.run(fetches)
                except tf.errors.OutOfRangeError:
                    # The single epoch is finished: every record has been read.
                    break
                except tf.errors.OpError as e:
                    # A record failed to decode. Keep going, but do not spin forever.
                    error_cnt += 1
                    if error_cnt == 1 or error_cnt % report_every == 0:
                        print("***********************************")
                        print("***********************************")
                        print(str(i) + " image couldn't read")
                        print(e)
                        print("***********************************")
                        print("***********************************")
                    if error_cnt > max_consecutive_errors:
                        break
                    continue

                error_cnt = 0
                print("______________________image({})___________________".format(i))
                print("image shape is: ", print_image.shape)
                print("image label is: ", print_label)
                print("image height is ", print_height)
                print("image width is: ", print_width)
                print("image xmins is: ", print_xmins)
                print("image ymins is: ", print_ymins)
                print("image xmaxs is: ", print_xmaxs)
                print("image ymaxs is: ", print_ymaxs)
        finally:
            # Always shut the reader threads down, even if the loop raised.
            coord.request_stop()
            coord.join(threads)

def reshape_same_size(image, output_height, output_width):
    """Resize an image to a fixed size with nearest-neighbour interpolation.

    Args:
        image: A 3-D image `Tensor` (height, width, channels).
        output_height: The height of the image after preprocessing
            (Python int or int32 scalar tensor).
        output_width: The width of the image after preprocessing
            (Python int or int32 scalar tensor).

    Returns:
        resized_image: A 3-D tensor of shape
            [output_height, output_width, channels].
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)

    # resize_nearest_neighbor operates on 4-D batches, so add a batch axis...
    batched = tf.expand_dims(image, 0)
    resized = tf.image.resize_nearest_neighbor(
        batched, [output_height, output_width], align_corners=False)
    # ...and remove only that axis. A bare tf.squeeze() would also drop any
    # other size-1 dimension (a single-channel image, or a 1-pixel output
    # side), so the result would no longer be 3-D.
    return tf.squeeze(resized, axis=[0])


def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800):
    """Build a TF-Slim input pipeline over an object-detection TFRecord.

    The data is read exactly once (``num_epochs=1``), in file order
    (``shuffle=False``). Each ``tf.train.Example`` is decoded into an image,
    its stored size, class labels and bounding boxes.

    Args:
        tfrecord_path: Data source handed to ``slim.dataset.Dataset``
            (a single TFRecord file path by default).
        num_samples: Number of records; stored on the slim ``Dataset``.
        num_classes: Number of classes; stored on the slim ``Dataset``.
        resize_height: Height of the bilinearly resized image.
        resize_width: Width of the bilinearly resized image.

    Returns:
        Tuple ``(resized_image, label, image, height, width, xmins, ymins,
        xmaxs, ymaxs)`` of tensors. ``image`` is decoded with 3 channels.
        ``resized_image`` is that image resized to
        ``[resize_height, resize_width]``. The label and box tensors come from
        variable-length features (presumably one entry per object).
    """
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string),
        'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string),
        'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/width': tf.FixedLenFeature([], tf.int64, default_value=0),
        # Per-object fields have a variable number of values per record, so
        # they are parsed as VarLenFeature rather than FixedLenFeature([]).
        'image/object/class/label': tf.VarLenFeature(tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
    }

    # Note: do not pass shape=[] to the per-object Tensor handlers
    # (label / bbox). They are variable-length, and a scalar shape makes
    # decoding raise an error.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
        'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
        'width': slim.tfexample_decoder.Tensor('image/width', shape=[]),
        'label': slim.tfexample_decoder.Tensor('image/object/class/label'),
        'xmins': slim.tfexample_decoder.Tensor('image/object/bbox/xmin'),
        'ymins': slim.tfexample_decoder.Tensor('image/object/bbox/ymin'),
        'xmaxs': slim.tfexample_decoder.Tensor('image/object/bbox/xmax'),
        'ymaxs': slim.tfexample_decoder.Tensor('image/object/bbox/ymax'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    items_to_descriptions = {
        'image': 'Test items_to_descriptions image',
        'label': 'Test items_to_descriptions label'}

    dataset = slim.dataset.Dataset(
        data_sources=tfrecord_path,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        # Previously built but then discarded in favour of None.
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
    )

    provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
                                                              num_readers=1,
                                                              shuffle=False,
                                                              common_queue_capacity=256,
                                                              common_queue_min=128,
                                                              seed=None,
                                                              num_epochs=1)
    image, label, height, width, xmins, ymins, xmaxs, ymaxs = provider.get(
        ['image', 'label', 'height', 'width', 'xmins', 'ymins', 'xmaxs', 'ymaxs'])
    # resize_bilinear needs a 4-D batch. Wrap the image, resize, then drop
    # only the batch axis; a bare squeeze would also drop a size-1 output side.
    resized_image = tf.squeeze(
        tf.image.resize_bilinear([image], size=[resize_height, resize_width]),
        axis=[0])
    return resized_image, label, image, height, width, xmins, ymins, xmaxs, ymaxs




def main():
    """Build the input pipeline from the command-line flags and dump every record."""
    pipeline = read_tfrecord(
        tfrecord_path=FLAG.tfrecord_path,
        resize_height=FLAG.resize_height,
        resize_width=FLAG.resize_width,
    )
    (resized_image, label, image, height, width,
     xmins, ymins, xmaxs, ymaxs) = pipeline
    print_data(image, resized_image, label, height, width,
               xmins, ymins, xmaxs, ymaxs)


if __name__ == '__main__':
    main()
发布了161 篇原创文章 · 获赞 71 · 访问量 11万+

猜你喜欢

转载自blog.csdn.net/yxpandjay/article/details/101279626