TFRecords是TensorFlow专用的数据文件格式。其中包含tf.train.Example协议内存块(protocol buffer),这是包含特征值和数据内容的一种数据格式。通过tf.python_io.TFRecordWriter类,可以获取相应的数据并将其填入到Example协议内存块中,最终生成TFRecords文件。
每个Feature包含FloatList、BytesList、Int64List这三种数据格式中的一种。TFRecords以二进制文件的形式将特征和标签一起保存,以便于TensorFlow读取。
将图片和对应标签写入TFRecords文件的代码:
def int64_feature(value):
    """Wrap an int (or a list/tuple of ints) in a tf.train.Feature.

    The original inline comment said the input is a list, but the code
    always wrapped the argument in ``[value]`` and would fail for an
    actual list; this accepts both a scalar and a sequence.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]  # Int64List requires a sequence of ints
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a BytesList."""
    # BytesList expects a sequence, so the raw bytes are placed in a one-element list.
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def convert_to_tfrecord(images_list, labels_list, save_dir, name):
    """Write images and their labels into a single TFRecords file.

    Args:
        images_list: list of image file paths (readable by skimage.io.imread).
        labels_list: list of integer-convertible labels, parallel to images_list.
        save_dir:    directory in which to create the .tfrecords file.
        name:        base name of the output file (``.tfrecords`` is appended).

    Images that fail to read are reported and skipped; the writer is always
    closed, even if an unexpected error escapes the loop.
    """
    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels_list)
    writer = tf.python_io.TFRecordWriter(filename)  # TF1.x writer instance
    print('\nTransform start......')
    try:
        for i in np.arange(0, n_samples):
            try:
                image = io.imread(images_list[i])
                # tobytes() replaces tostring(), which is deprecated and
                # removed in NumPy 2.0.
                image_raw = image.tobytes()
                label = int(labels_list[i])
                # Example protocol buffer: one record = one image + one label.
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': int64_feature(label),
                    'image_raw': bytes_feature(image_raw),
                }))
                writer.write(example.SerializeToString())
            except IOError as e:
                # Report the failure (the original swallowed `e`) and skip
                # this sample rather than aborting the whole conversion.
                print('Could not read:', images_list[i])
                print('Error:', e)
    finally:
        # Guarantee the file is flushed/closed even on unexpected errors.
        writer.close()
    print('Transform done!')
读取相应TFRecords文件:
tf.train.shuffle_batch函数的参数解释
'''
Args:
tensors: 入队列表向量或字典向量The list or dictionary of tensors to enqueue.
batch_size: 每次入队出队的数量The new batch size pulled from the queue.
capacity: 队列中最大的元素数量An integer. The maximum number of elements in the queue.
min_after_dequeue: 在一次出队以后队列中的最小元素数量Minimum number elements in the queue after a dequeue, used to ensure a level of mixing of elements.
num_threads: 向量列表入队的线程数The number of threads enqueuing tensor_list.
seed: 队列中shuffle的种子Seed for the random shuffling within the queue.
enqueue_many: 向量列表中的每个向量是否是单个实例Whether each tensor in tensor_list is a single example.
shapes: (Optional) The shapes for each example. Defaults to the inferred shapes for tensor_list.
allow_smaller_final_batch: (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue.
shared_name: (Optional) If set, this queue will be shared under the given name across multiple sessions.
name: (Optional) A name for the operations.
'''
def read_and_decode(tfrecords_file, batch_size, image_shape=(227, 227, 3)):
    """Read a TFRecords file and return shuffled (image, label) batches.

    Args:
        tfrecords_file: path to the .tfrecords file produced by
            convert_to_tfrecord (records hold 'label' int64 and
            'image_raw' bytes features).
        batch_size: number of examples per batch.
        image_shape: (height, width, channels) to reshape each decoded
            image to. Defaults to (227, 227, 3), matching the original
            hard-coded value — every stored image must have exactly
            this many bytes.

    Returns:
        (image_batch, label_batch): uint8 images of shape
        [batch_size, *image_shape] and int32 labels of shape [batch_size].
    """
    # string_input_producer builds a filename queue; its QueueRunner is
    # added to the graph's QUEUE_RUNNERS collection (TF1.x input pipeline).
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    # read() returns the next (key, serialized Example) pair from the queue.
    _, serialized_example = reader.read(filename_queue)
    # Parse one serialized Example into a dict of tensors.
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    # decode_raw reinterprets the byte string as a flat uint8 vector;
    # all records must therefore have the same byte length.
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    image = tf.reshape(image, list(image_shape))
    label = tf.cast(img_features['label'], tf.int32)
    # shuffle_batch keeps images and labels paired while shuffling.
    # NOTE(review): capacity (200) is only min_after_dequeue + 100; the TF
    # docs recommend min_after_dequeue + (num_threads + margin) * batch_size
    # — confirm this is large enough for the batch sizes used by callers.
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        min_after_dequeue=100,
        num_threads=64,
        capacity=200)
    return image_batch, tf.reshape(label_batch, [batch_size])