TFRecord: write & read

This post draws on another blogger's write-up.

Overview

When training a convolutional neural network, preprocessing the images ahead of time and caching them on disk, then accessing them randomly through this intermediate file, noticeably speeds up training and avoids reprocessing the same images over and over.

Write

Writing goes through the tf.train.Example protocol buffer.
The code below comes from a function I wrote.

import os
import tensorflow as tf

def create_tfrecord(result, sess):
    """
    Serialize images and labels into a TFRecord file (here, the validation split).
    Args:
        result: dictionary mapping split names to lists of image files
        sess: the session used to run the JPEG-decoding sub-graph

    """
    path = FLAGS.tfrecord_dir
    if not tf.gfile.Exists(path):
        tf.gfile.MakeDirs(path)
    tf_filename = os.path.join(path, 'validation.tfrecord')

    # Small sub-graph that decodes a JPEG string into a float image tensor.
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding()

    writer = tf.python_io.TFRecordWriter(tf_filename)
    for index_val, file in enumerate(result['validation']):
        tf.logging.info('writing example %d of the validation split', index_val)
        # The label file shares the image's base name, with a .txt extension.
        name, _ = os.path.splitext(file)
        label = get_labels_array(name + '.txt')
        input_image_array = create_input_tensor(file, sess, jpeg_data_tensor,
                                                decoded_image_tensor)
        # Store both arrays as raw bytes; their shapes are restored when reading.
        input_image_string = input_image_array.tostring()
        label_string = label.tostring()
        example = tf.train.Example(features=tf.train.Features(
                feature={
                        'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_string])),
                        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[input_image_string]))
                        }))
        writer.write(example.SerializeToString())
    writer.close()
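
create_tfrecord relies on a few helpers that are not shown in this post (add_jpeg_decoding, create_input_tensor, get_labels_array). For completeness, here is a minimal sketch of what they could look like; this is my guess at their shape, not the original code, and it assumes the labels are whitespace-separated integers in a .txt file:

import numpy as np
import tensorflow as tf

def add_jpeg_decoding():
    """Build a small graph that turns a JPEG string into a float32 image."""
    jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput')
    decoded = tf.image.decode_jpeg(jpeg_data, channels=3)
    decoded = tf.image.convert_image_dtype(decoded, tf.float32)
    return jpeg_data, decoded

def create_input_tensor(file, sess, jpeg_data_tensor, decoded_image_tensor):
    """Read one image file and run it through the decoding sub-graph."""
    image_data = tf.gfile.GFile(file, 'rb').read()
    return sess.run(decoded_image_tensor,
                    feed_dict={jpeg_data_tensor: image_data})

def get_labels_array(txt_path):
    """Load the per-image label vector as int64, matching the
    tf.decode_raw(..., tf.int64) call in read_tfrecord below."""
    return np.loadtxt(txt_path, dtype=np.int64)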

Read

Reading is more involved: you also have to set up a filename queue and worker threads.
Note that the function below uses multiple threads.

def read_tfrecord(file_name, batch):
    # Queue of input file names; the reader dequeues and parses records from it.
    filename_queue = tf.train.string_input_producer([file_name])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    feature = tf.parse_single_example(serialized_example,
                                      features={
                                              'label': tf.FixedLenFeature([], tf.string),
                                              'image': tf.FixedLenFeature([], tf.string),
                                              })
    # Undo the tostring() calls from the writer and restore the original shapes.
    labels = tf.decode_raw(feature['label'], tf.int64)
    labels = tf.reshape(labels, [26])
    images = tf.decode_raw(feature['image'], tf.float32)
    images = tf.reshape(images, [1080, 1440, 3])
    # Rescale the float image in [0, 1] back to an integer image.
    images = tf.image.convert_image_dtype(images, tf.uint8)
    if batch > 1:
        # shuffle_batch runs num_threads enqueuing threads in the background.
        images, labels = tf.train.shuffle_batch([images, labels],
                                                batch_size=batch,
                                                capacity=500,
                                                num_threads=2,
                                                min_after_dequeue=10)

    return images, labels

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    # To regenerate the TFRecord file, build the image lists with
    # create_image_lists(...) and call create_tfrecord(result, sess) instead.
    file_name = r'G:\GraduateStudy\Smoke Recognition\Newdata\Tfrecord\validation.tfrecord'
    image, label = read_tfrecord(file_name, 8)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Start the queue-runner threads behind string_input_producer and
        # shuffle_batch; without them sess.run() would block forever.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(2):
                img, lab = sess.run([image, label])
                print(img.shape, lab.shape)
        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()

        coord.join(threads)
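
As a side note, the queue-runner pipeline above belongs to the TensorFlow 1.x API; newer releases deprecate string_input_producer and start_queue_runners in favor of tf.data. A rough tf.data equivalent of read_tfrecord, a sketch using the same shapes and dtypes rather than part of the original post, could look like:

def read_tfrecord_dataset(file_name, batch):
    def _parse(serialized):
        # Same feature spec and decoding as read_tfrecord above.
        feature = tf.parse_single_example(serialized, features={
                'label': tf.FixedLenFeature([], tf.string),
                'image': tf.FixedLenFeature([], tf.string),
                })
        labels = tf.reshape(tf.decode_raw(feature['label'], tf.int64), [26])
        images = tf.reshape(tf.decode_raw(feature['image'], tf.float32),
                            [1080, 1440, 3])
        images = tf.image.convert_image_dtype(images, tf.uint8)
        return images, labels

    # Shuffling and batching replace tf.train.shuffle_batch; the dataset
    # manages its own input threads, so no Coordinator is needed.
    dataset = (tf.data.TFRecordDataset([file_name])
               .map(_parse)
               .shuffle(buffer_size=500)
               .batch(batch))
    return dataset.make_one_shot_iterator().get_next()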


Reposted from blog.csdn.net/hunt_ing/article/details/80543141