《TensorFlow深度学习应用实践》学习笔记之TFRecords文件的创建和读取

TFRecords是TensorFlow专用的数据文件格式。其中包含tf.train.Example协议内存块(protocol buffer),这是包含特征值和数据内容的一种数据格式。通过tf.python.io.TFRecordWriter类,可以获取相应的数据并将其填入到Example协议内存块中,最终生成TFRecords文件。

任何Feature中包含着FloatList,或者ByteList,或者Int64List这三种数据格式中的几种,TFRecords通过包含着二进制文件的数据文件,将特征和标签进行保存以便于TensorFlow读取

将图片和对应标签写入TFRecords文件的代码:
 

def int64_feature(value):                                                 #[]输入为list
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  #方括号表示输入为列表 转化为二进制形式
 
def convert_to_tfrecord(images_list,labels_list,save_dir,name):
    '''
    images_list: 图片文件的路径列表
    labels_list: 标签的列表
    save_dir: 用来保存tfrecords文件的路径
    name: tfrecords文件的名字
    '''
    filename=os.path.join(save_dir,name+'.tfrecords')
    n_samples=len(labels_list)
    writer=tf.python_io.TFRecordWriter(filename)  #实例化并传入保存文件路径 写入到文件中
    print('\nTransform start......')
    for i in np.arange(0,n_samples):
        try:
            image=io.imread(images_list[i])
            image_raw=image.tostring()
            label=int(labels_list[i])
            example=tf.train.Example(features=tf.train.Features(feature={    #协议内存块
                'label':int64_feature(label),
                'image_raw':bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('Could not read:',images_list[i])
    writer.close()
    print('Transform done!')

读取相应TFRecords文件:
tf.train.shuffle_batch函数的参数解释

'''
Args:
tensors: 入队列表向量或字典向量The list or dictionary of tensors to enqueue.
batch_size: 每次入队出队的数量The new batch size pulled from the queue.
capacity: 队列中最大的元素数量An integer. The maximum number of elements in the queue.
min_after_dequeue: 在一次出队以后对列中最小元素数量Minimum number elements in the queue after a dequeue, used to ensure a level of mixing of elements.
num_threads: 向量列表入队的线程数The number of threads enqueuing tensor_list.
seed: 队列中shuffle的种子Seed for the random shuffling within the queue.
enqueue_many: 向量列表中的每个向量是否是单个实例Whether each tensor in tensor_list is a single example.
shapes: (Optional) The shapes for each example. Defaults to the inferred shapes for tensor_list.
allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue.
shared_name: (Optional) If set, this queue will be shared under the given name across multiple sessions.
name: (Optional) A name for the operations.
'''
def read_and_decode(tfrecords_file,batch_size):
    # 返回输出队列,QueueRunner加入到当前图中的QUEUE_RUNNER收集器
    filename_queue=tf.train.string_input_producer([tfrecords_file])
 
    reader=tf.TFRecordReader()        #实例化读取器
    _,serialized_example=reader.read(filename_queue) #返回队列当中的下一个键值对tensor
 
    # 输入标量字符串张量,输出字典映射向量tensor和稀疏向量值
    img_features=tf.parse_single_example(serialized_example,
                                         features={
                                             'label':tf.FixedLenFeature([],
                                                                        tf.int64),
                                             'image_raw':tf.FixedLenFeature([],
                                                                            tf.string),
                                         })
    image=tf.decode_raw(img_features['image_raw'],tf.uint8) #解析字符向量tensor为实数,需要有相同长度
    image=tf.reshape(image,[227,227,3])
    label=tf.cast(img_features['label'],tf.int32)
 
    #从TFRecords中读取数据,保证内容和标签同步
    image_batch,label_batch=tf.train.shuffle_batch([image,label],
                                                   batch_size=batch_size,
                                                   min_after_dequeue=100,
                                                   num_threads=64,
                                                   capacity=200)
    return image_batch,tf.reshape(label_batch,[batch_size])

猜你喜欢

转载自blog.csdn.net/shiheyingzhe/article/details/82312891
今日推荐