Tensorflow学习笔记:读取二进制文件、读写TFRecord文件

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/WilliamCode/article/details/85268850

#图像基本知识
    OpenCV已经学过

#图片操作目的:
    增加图片数据的统一性:大小与格式统一
    缩小图片数据量,防止增加开销

#图片操作:放大或缩小
    tf.image.resize_images(images,size)
        image:4-D数组[batch,length,width,depth] 或者 3-D数组[length,width,depth]的图片数据
        size:一维的int32张量[new_length,new_width]

#图片读取API
    tf.WholeFileReader()将文件全部内容作为值的读取器
        return :读取器实例
        read方法(file_queue)输出是一个文件名(key)和该文件内容(值value)
    
    tf.image.decode_jpeg(contents)将jpeg文件解码为uint8类型的张量
        return    uint8类型张量[height,length,channels]
    tf.image.decode_png(contests)将png文件解码为uint8或者uint16类型的张量

#图片批量读取示例:

def picreader(filelist):
    """
    args:list of picture file to read
    return: batch of read result
    """
    #1、构造文件阅读队列
    file_queue = tf.train.string_input_producer(filelist)


    #2、构造文件阅读器读取文件
    reader = tf.WholeFileReader()
    key , raw_data = reader.read(file_queue)

    print(raw_data)
    #3、构造文件解码器
    read_result = tf.image.decode_jpeg(raw_data)
    print(read_result)

    #4、统一图片大小
    read_result = tf.image.resize_images(read_result, [200,200])

    #固定样本形状,否则无法放入队列
    read_result.set_shape([200,200,3])
    print(read_result)

    #5、进行批处理
    read_result_batch = tf.train.batch([read_result], batch_size = 5, num_threads = 1, capacity = 7)

    #6、返回数据
    return read_result_batch

import os
import tensorflow as tf

if __name__ == "__main__":
    dir_file_list = os.listdir(".\\pic_data\\")
    dir_file_list = ["C:\\Users\\xie\\pic_data\\" + i for i in dir_file_list]
    print(dir_file_list)
    
    read_result = picreader(dir_file_list)

    with tf.Session() as sess:

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord = coord)
        for i in range(3):
            print(sess.run(read_result).shape)

        coord.request_stop()
        coord.join(threads)

#tf的切片操作
    a = [lable,imageData1,imageData2 ...]
    lable = tf.slice(a, [0], [1])
    image = tf.slice(a, [1], [size_of_data])


#读取二进制文件代码:
    def binreader(filelist):
    
    #1、构造文件列表
    file_queue = tf.train.string_input_producer(filelist)

    #2、构造阅读器进行阅读
    reader = tf.FixedLengthRecordReader(1024 * 3 + 1)
    key, value = reader.read(file_queue)

    #3、构造解码器进行解码
    decoded_data = tf.decode_raw(value, tf.uint8)

    #4、分割出图片和标签

    lable = tf.slice(decoded_data, [0], [1])
    image = tf.slice(decoded_data, [1], [3072])

     #5、改变image大小
    image = tf.reshape(image, [32, 32, 3])

    #6、批处理
    batch_lable, batch_image = tf.train.batch([lable,image], batch_size = 10, num_threads = 1, capacity = 100)

    print(batch_lable, batch_image)

    return batch_lable, batch_image

import tensorflow as tf 
import os

if __name__ == '__main__':
    
    #获取目标路径下文件列表
    dir_file_list = os.listdir('C:\\Users\\xie\\binary_data\\')
    
    #选择适当文件
    dir_file_list = ['C:\\Users\\xie\\binary_data\\' + i for i in dir_file_list \
                        if i[-3:] == 'bin']

    read_result = binreader(dir_file_list)

    with tf.Session() as sess:

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord = coord)

        print(sess.run(read_result))

        coord.request_stop()
        coord.join(threads)


#tf自带的文件格式TFRecords
    为了将特征值和标签值存储在一个标签内部,方便存取和移动
    用Example协议块(一个类似字典的格式)存储样本
#写入TFRecord文件
    1、建立TFrecord存储器
    tf.python_io.TFRecordWriter(path)
    return 文件写入器实例
    
    方法:write(string)写入一个字符串记录(一个example)
        close()
#代码示例:

def binreader(filelist):
    
    #1、构造文件列表
    file_queue = tf.train.string_input_producer(filelist)

    #2、构造阅读器进行阅读
    reader = tf.FixedLengthRecordReader(1024 * 3 + 1)
    key, value = reader.read(file_queue)

    #3、构造解码器进行解码
    decoded_data = tf.decode_raw(value, tf.uint8)

    #4、分割出图片和标签

    lable = tf.slice(decoded_data, [0], [1])
    image = tf.slice(decoded_data, [1], [3072])

     #5、改变image大小
    image = tf.reshape(image, [32, 32, 3])

    #6、批处理
    batch_lable, batch_image = tf.train.batch([lable,image], batch_size = 10, num_threads = 1, capacity = 100)

    print(batch_lable, batch_image)

    return batch_lable, batch_image


def write_to_tfrecords(batch_lable, batch_image):
    """
    将文件的特征值和目标值存入tfrecords文件中
    :param batch_lable 10个标签
    :param batch_image 10个特征值
    """
    #1、构造TFRecorder存储器
    writer = tf.python_io.TFRecordWriter("./binary_data/TFRecord/cifar.tfrecords")

    #2、循环将所有样本写入文件,每张图片都要构造example协议
    for i in range(10):
        #取出第i个样本的特征值和标签值
        image = batch_image[i].eval().tostring()
        lable = batch_lable[i].eval()[0]

        #构造Example
        example = tf.train.Example(features = tf.train.Features(feature = {
            "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
            "lable":tf.train.Feature(int64_list = tf.train.Int64List(value = [lable]))}))
        writer.write(example.SerializeToString())
        print("%d times of store is over"%i)

    writer.close()

import tensorflow as tf 
import os

if __name__ == '__main__':
    
    #获取目标路径下文件列表
    dir_file_list = os.listdir('C:\\Users\\xie\\binary_data\\')
    
    #选择适当文件
    dir_file_list = ['C:\\Users\\xie\\binary_data\\' + i for i in dir_file_list \
                        if i[-3:] == 'bin']

    read_result = binreader(dir_file_list)

    with tf.Session() as sess:

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord = coord)

        sess.run(read_result)
        print("start to store")
        write_to_tfrecords(read_result[0],read_result[1])
        print("end of store")
        coord.request_stop()
        coord.join(threads)


#读取tfrecords文件(多了一步解析)
    tf.parse_single_example(serialized,    #读出的内容
                features = None,#dict字典数据,键位读取的名字,值为FixedLenFeature
        return:一个键值对组成的字典,键为读取的名字
#示例代码:

def binreader(filelist):
    
    #1、构造文件列表
    file_queue = tf.train.string_input_producer(filelist)

    #2、构造阅读器进行阅读
    reader = tf.FixedLengthRecordReader(1024 * 3 + 1)
    key, value = reader.read(file_queue)

    #3、构造解码器进行解码
    decoded_data = tf.decode_raw(value, tf.uint8)

    #4、分割出图片和标签

    lable = tf.slice(decoded_data, [0], [1])
    image = tf.slice(decoded_data, [1], [3072])

     #5、改变image大小
    image = tf.reshape(image, [32, 32, 3])

    #6、批处理
    batch_lable, batch_image = tf.train.batch([lable,image], batch_size = 10, num_threads = 1, capacity = 100)

    print(batch_lable, batch_image)

    return batch_lable, batch_image


def write_to_tfrecords(batch_lable, batch_image):
    """
    将文件的特征值和目标值存入tfrecords文件中
    :param batch_lable 10个标签
    :param batch_image 10个特征值
    """
    #1、构造TFRecorder存储器
    writer = tf.python_io.TFRecordWriter("./binary_data/TFRecord/cifar.tfrecords")

    #2、循环将所有样本写入文件,每张图片都要构造example协议
    for i in range(10):
        #取出第i个样本的特征值和标签值
        image = batch_image[i].eval().tostring()
        lable = batch_lable[i].eval()[0]

        #构造Example
        example = tf.train.Example(features = tf.train.Features(feature = {
            "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
            "lable":tf.train.Feature(int64_list = tf.train.Int64List(value = [lable]))}))
        writer.write(example.SerializeToString())
        print("%d times of store is over"%i)

    writer.close()

def read_from_tfrecords():
    #1、构造文件队列
    file_queue = tf.train.string_input_producer(['C:\\Users\\xie\\binary_data\\TFRecord\\cifar.tfrecords'])
    
    #2、构造文件阅读器
    reader = tf.TFRecordReader()

    #3、读取队列,value也是一个样本的序列化值
    key, value = reader.read(file_queue)

    features = tf.parse_single_example(value, 
                    features = {'image':tf.FixedLenFeature([],tf.string),
                                'lable':tf.FixedLenFeature([],tf.int64)})
    

    #4、解码:当且仅当取出string类型时,需要解码
    image = tf.decode_raw(features["image"], tf.uint8)
    lable = features["lable"]

    
    image_reshape = tf.reshape(image, [32, 32, 3])
    
    
    image_batch, lable_batch = tf.train.batch([image_reshape, lable], batch_size = 10, num_threads = 1, capacity = 10)

    return image_batch, lable_batch

import tensorflow as tf 
import os

if __name__ == '__main__':
    
    #获取目标路径下文件列表
    dir_file_list = os.listdir('C:\\Users\\xie\\binary_data\\')
    
    #选择适当文件
    dir_file_list = ['C:\\Users\\xie\\binary_data\\' + i for i in dir_file_list \
                        if i[-3:] == 'bin']

    image_batch, lable_batch = read_from_tfrecords()

    with tf.Session() as sess:

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord = coord)
        print(sess.run([image_batch, lable_batch]))
        
        
        coord.request_stop()
        coord.join(threads)

猜你喜欢

转载自blog.csdn.net/WilliamCode/article/details/85268850