TFRecord is a training-sample storage format designed by TensorFlow. Packing training samples into TFRecord files speeds up reading them during training, so the first step in training a network is usually to package your own training set into TFRecord format. This article introduces two ways of packaging TFRecord files; the main difference between them is the size of the generated file.
Method 1: Use a common image processing library to read and decode the image, then store the decoded pixels as a raw byte string. This is the method most commonly found on the Internet.
write to tfrecord file
import os

import numpy as np
import tensorflow as tf
from PIL import Image


def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single byte string in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def data_to_tfrecord(images, labels, filename):
    """Save images and labels into a TFRecord file (method 1: raw pixels).

    Every image is decoded with PIL, converted to a uint8 array, and the
    array's raw bytes are stored together with the label and the three
    dimensions of the decoded array.

    Args:
        images: list of image file paths.
        labels: the label of each image in ``images``, in the same order;
            each label is converted with ``int``.
        filename: name of the TFRecord file to create.

    Returns:
        None. If ``filename`` already exists, nothing is written.
    """
    if os.path.isfile(filename):
        print("%s exists" % filename)
        return
    print("Converting data into %s ..." % filename)

    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for label, img_file in zip(labels, images):
            # Read and decode the image through PIL; the handle is released
            # as soon as the pixels have been copied out.
            with Image.open(img_file) as pil_img:
                pixels = np.asarray(pil_img, np.uint8)
                # NOTE: the names follow the feature keys, not the geometry.
                # 'img_width' receives shape[0] and 'img_height' receives
                # shape[1]; the reader reshapes in this same order, so the
                # stored layout is kept unchanged for compatibility.
                width, height, channel = pixels.shape
                img_raw = pixels.tobytes()  # decoded pixels as a byte string

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'label': _int64_feature(int(label)),
                        'img_raw': _bytes_feature(img_raw),
                        'img_width': _int64_feature(width),
                        'img_height': _int64_feature(height),
                        'img_channel': _int64_feature(channel),
                    }
                )
            )
            writer.write(example.SerializeToString())
read tfrecord file
import numpy as np
import tensorflow as tf
import tensorlayer as tl


def read_and_decode(filename):
    """Return tensors that read one (image, label) pair from a TFRecord file.

    The record is expected to carry the features 'label', 'img_raw',
    'img_width', 'img_height' and 'img_channel'; 'img_raw' holds raw uint8
    pixel bytes (method 1).

    Args:
        filename: path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)``: the image resized to 32x32 and the label
        cast to int32.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Read the stored fields back from the serialized example.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'img_width': tf.FixedLenFeature([], tf.int64),
            'img_height': tf.FixedLenFeature([], tf.int64),
            'img_channel': tf.FixedLenFeature([], tf.int64),
        },
    )

    # You can do more image distortion here for training data.
    # Type conversion: the dimensions are stored as int64.
    width = tf.cast(features['img_width'], tf.int32)
    height = tf.cast(features['img_height'], tf.int32)
    channel = tf.cast(features['img_channel'], tf.int32)

    # Convert from a byte string back to uint8 values.
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # Only a flat sequence is stored in the TFRecord file, so the shape has
    # to be restored. The order matches the writer, which stores shape[0]
    # under 'img_width' and shape[1] under 'img_height'.
    img = tf.reshape(img, [width, height, channel])
    # Resize all images to the same size.
    img = tf.image.resize_images(img, [32, 32])
    # Optional normalisation:
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return img, label


# Example to visualize data
img, label = read_and_decode("train.tfrecord")
img_batch, label_batch = tf.train.shuffle_batch(
    [img, label],
    batch_size=4,
    capacity=5000,
    min_after_dequeue=100,
    num_threads=1,
)
# Use the public accessor instead of the private ``_shape`` attribute.
print("img_batch : %s" % img_batch.get_shape())
print("label_batch : %s" % label_batch.get_shape())

init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(3):  # number of mini-batches (steps)
            print("Step %d" % i)
            val, l = sess.run([img_batch, label_batch])
            print(val.shape, l)
            tl.visualize.images2d(val, second=1, saveable=False,
                                  name='batch' + str(i), dtype=np.uint8,
                                  fig_idx=2020121)
    finally:
        # Always stop the queue-runner threads, even if a step fails.
        coord.request_stop()
        coord.join(threads)
    # The session is closed by the ``with`` block.
Method 2: Use tf.gfile.FastGFile to read the image file's bytes as they are (it seems that no decoding takes place), and store those bytes in the record.
This is the method I use; it is the one used to generate TFRecord files in the slim framework of TensorFlow on GitHub.
write to tfrecord file
import os

import numpy as np
import tensorflow as tf
from PIL import Image


def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single byte string in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def data_to_tfrecord(images, labels, filename):
    """Save images and labels into a TFRecord file (method 2: file bytes).

    The image is opened with PIL only to obtain the dimensions of the
    decoded array; the bytes stored under 'img_raw' are the contents of the
    image file itself, read through ``tf.gfile.FastGFile``.

    Args:
        images: list of image file paths.
        labels: the label of each image in ``images``, in the same order;
            each label is converted with ``int``.
        filename: name of the TFRecord file to create.

    Returns:
        None. If ``filename`` already exists, nothing is written.
    """
    if os.path.isfile(filename):
        print("%s exists" % filename)
        return
    print("Converting data into %s ..." % filename)

    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for label, img_file in zip(labels, images):
            with Image.open(img_file) as pil_img:
                # NOTE: the names follow the feature keys, not the geometry.
                # 'img_width' receives shape[0] and 'img_height' receives
                # shape[1]; the reader reshapes in this same order, so the
                # stored layout is kept unchanged for compatibility.
                width, height, channel = np.asarray(pil_img, np.uint8).shape

            # The difference from method 1: store the file's own bytes
            # rather than the decoded pixel buffer.
            with tf.gfile.FastGFile(img_file, 'rb') as img_handle:
                img_raw = img_handle.read()

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'label': _int64_feature(int(label)),
                        'img_raw': _bytes_feature(img_raw),
                        'img_width': _int64_feature(width),
                        'img_height': _int64_feature(height),
                        'img_channel': _int64_feature(channel),
                    }
                )
            )
            writer.write(example.SerializeToString())
read tfrecord file
import numpy as np
import tensorflow as tf
import tensorlayer as tl


def read_and_decode(filename):
    """Return tensors that read one (image, label) pair from a TFRecord file.

    The record is expected to carry the features 'label', 'img_raw',
    'img_width', 'img_height' and 'img_channel'; 'img_raw' holds the bytes
    of a JPEG file (method 2) and is decoded with ``tf.image.decode_jpeg``.

    Args:
        filename: path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)``: the image resized to 32x32 and the label
        cast to int32.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Read the stored fields back from the serialized example.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'img_width': tf.FixedLenFeature([], tf.int64),
            'img_height': tf.FixedLenFeature([], tf.int64),
            'img_channel': tf.FixedLenFeature([], tf.int64),
        },
    )

    # You can do more image distortion here for training data.
    # Type conversion: the dimensions are stored as int64.
    width = tf.cast(features['img_width'], tf.int32)
    height = tf.cast(features['img_height'], tf.int32)
    channel = tf.cast(features['img_channel'], tf.int32)

    # The difference from method 1: the stored bytes are an encoded image,
    # so they are decoded with decode_jpeg instead of decode_raw.
    img = tf.image.decode_jpeg(features['img_raw'])
    # The order matches the writer, which stores shape[0] under 'img_width'
    # and shape[1] under 'img_height'.
    img = tf.reshape(img, [width, height, channel])
    # Resize all images to the same size.
    img = tf.image.resize_images(img, [32, 32])
    # Optional normalisation:
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return img, label


# Example to visualize data
# NOTE(review): unlike the method 1 example ("train.tfrecord") this path has
# no extension -- confirm it matches the file that was actually written.
img, label = read_and_decode("train")
img_batch, label_batch = tf.train.shuffle_batch(
    [img, label],
    batch_size=4,
    capacity=5000,
    min_after_dequeue=100,
    num_threads=1,
)
# Use the public accessor instead of the private ``_shape`` attribute.
print("img_batch : %s" % img_batch.get_shape())
print("label_batch : %s" % label_batch.get_shape())

init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(3):  # number of mini-batches (steps)
            print("Step %d" % i)
            val, l = sess.run([img_batch, label_batch])
            print(val.shape, l)
            tl.visualize.images2d(val, second=1, saveable=False,
                                  name='batch' + str(i), dtype=np.uint8,
                                  fig_idx=2020121)
    finally:
        # Always stop the queue-runner threads, even if a step fails.
        coord.request_stop()
        coord.join(threads)
    # The session is closed by the ``with`` block.
The difference between the two ways
Although the two methods differ by only one or two lines of code, the generated TFRecord files differ greatly. Using the same image sample set of about 200 MB, the TFRecord file generated by method 1 is about 900 MB, while the one generated by method 2 is about 200 MB. Obviously there is a big difference in storage space. My personal guess is that the first scheme decodes each image and then stores the decoded pixels as raw bytes, whereas the second scheme does not decode the image and stores the file's bytes directly, which is why the image has to be decoded when it is read back. This is only my personal guess; if anyone understands this better, I would be grateful to be enlightened.