tensorflow保存读取TFRecords文件

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Chang_Shuang/article/details/80229759

# Convert data to features

def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: The integer to store; it is placed in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: The bytes object to store; it is placed in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Imported here because this script section runs before the file's only
# import block and needs glob/os/shuffle.
import glob
import os
from random import shuffle

shuffle_data = True  # shuffle the addresses before saving
cat_dog_train_path = 'train/*.jpg'

# Read addresses and labels from the 'train' folder.
addrs = glob.glob(cat_dog_train_path)
# 0 = Cat, 1 = Dog.  Only the file name is inspected, so a directory component
# that happens to contain "cat" cannot mislabel a dog image.
labels = [0 if 'cat' in os.path.basename(addr) else 1 for addr in addrs]

# Shuffle addresses and labels together so every label stays with its image.
# The `addrs` check skips the step when nothing matched: zip(*c) on an empty
# list yields nothing and could not be unpacked into two names.
if shuffle_data and addrs:
    c = list(zip(addrs, labels))
    shuffle(c)
    addrs, labels = zip(*c)

# Divide the data into 60% train, 20% validation, and 20% test

# Compute the two cut points once, from the number of addresses, so that the
# address and label sequences are always sliced at exactly the same indices.
num_examples = len(addrs)
train_end = int(0.6 * num_examples)  # first 60% -> training
val_end = int(0.8 * num_examples)    # next 20% -> validation, rest -> test

train_addrs = addrs[:train_end]
train_labels = labels[:train_end]

val_addrs = addrs[train_end:val_end]
val_labels = labels[train_end:val_end]

test_addrs = addrs[val_end:]
test_labels = labels[val_end:]

# Write data into a TFRecords file

# Imported here because this script section runs before the file's only
# import block.
import sys

import tensorflow as tf

train_filename = 'train.tfrecords'  # address to save the TFRecords file

# NOTE(review): `load_image` is not defined in the visible part of this file.
# The reading code decodes every record as float32 and reshapes it to
# [224, 224, 3], so it presumably has to return a float32 array of that
# shape -- confirm against its definition.

# Open the TFRecords file.  The context manager closes the writer even if an
# image fails to load part-way through the loop.
with tf.python_io.TFRecordWriter(train_filename) as writer:
    for i, (addr, label) in enumerate(zip(train_addrs, train_labels)):
        # Print how many images are saved every 1000 images.
        if not i % 1000:
            print('Train data: {}/{}'.format(i, len(train_addrs)))
            sys.stdout.flush()

        # Load the image.
        img = load_image(addr)

        # Create a feature.  tobytes() returns the same raw bytes as the
        # deprecated ndarray.tostring() (assuming `img` is a NumPy array).
        feature = {'train/label': _int64_feature(label),
                   'train/image': _bytes_feature(tf.compat.as_bytes(img.tobytes()))}

        # Create an example protocol buffer.
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize to string and write on the file.
        writer.write(example.SerializeToString())

sys.stdout.flush()

# Validation / test export.
#
# In the original this whole section was wrapped in a triple-quoted string,
# i.e. disabled, and the typographic quote characters used as delimiters made
# the file unparseable.  The two near-identical loops now share one helper,
# and the export stays disabled unless WRITE_VAL_AND_TEST is switched on.
WRITE_VAL_AND_TEST = False


def _write_split(split_name, display_name, split_addrs, split_labels):
    """Write one data split to ``<split_name>.tfrecords``.

    Args:
        split_name: Lower-case split name; used for the output file name and
            as the prefix of the ``<split_name>/label`` and
            ``<split_name>/image`` feature keys.
        display_name: Name printed in the progress message.
        split_addrs: Image paths of the split.
        split_labels: Integer labels, in the same order as ``split_addrs``.
    """
    # Local imports: this helper sits above the file's only import block.
    import sys

    import tensorflow as tf

    filename = '{}.tfrecords'.format(split_name)  # address to save the file

    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for i, (addr, label) in enumerate(zip(split_addrs, split_labels)):
            # Print how many images are saved every 1000 images.
            if not i % 1000:
                print('{} data: {}/{}'.format(display_name, i, len(split_addrs)))
                sys.stdout.flush()

            # Load the image.
            # NOTE(review): `load_image` is not defined in the visible part of
            # this file -- it must be provided before enabling the export.
            img = load_image(addr)

            # Create a feature.  tobytes() returns the same raw bytes as the
            # deprecated ndarray.tostring() (assuming a NumPy array).
            feature = {
                '{}/label'.format(split_name): _int64_feature(label),
                '{}/image'.format(split_name): _bytes_feature(
                    tf.compat.as_bytes(img.tobytes())),
            }

            # Create an example protocol buffer.
            example = tf.train.Example(features=tf.train.Features(feature=feature))

            # Serialize to string and write on the file.
            writer.write(example.SerializeToString())

    sys.stdout.flush()


if WRITE_VAL_AND_TEST:
    # Write validation and test data into TFRecords files.
    _write_split('val', 'Val', val_addrs, val_labels)
    _write_split('test', 'Test', test_addrs, test_labels)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

data_path = 'train.tfrecords'  # TFRecords file to read back

# Read `data_path` through a TF 1.x input queue and display a few shuffled
# batches with matplotlib.
with tf.Session() as sess:
    # Schema of one record: raw image bytes plus an integer label.
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}

    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)

    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['train/image'], tf.float32)

    # Cast label data into int32
    label = tf.cast(features['train/label'], tf.int32)

    # Reshape image data into the original shape
    image = tf.reshape(image, [224, 224, 3])

    # Any preprocessing here ...

    # Creates batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30,
                                            num_threads=1, min_after_dequeue=10)

    # Initialize all global and local variables (the num_epochs counter of the
    # filename queue is a local variable).
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        for batch_index in range(5):
            img, lbl = sess.run([images, labels])

            # The records were decoded as float32; cast to uint8 for display.
            img = img.astype(np.uint8)

            # Show the first 6 images of the batch in a 2 x 3 grid.
            for j in range(6):
                plt.subplot(2, 3, j + 1)
                plt.imshow(img[j, ...])
                plt.title('cat' if lbl[j] == 0 else 'dog')

            plt.show()
    except tf.errors.OutOfRangeError:
        # With num_epochs=1 the queue runs dry once every record has been
        # read, which can happen before five batches on a small file.
        print('Reached the end of {} before reading 5 batches'.format(data_path))
    finally:
        # Stop the threads and wait for them, even if a batch failed.
        coord.request_stop()
        coord.join(threads)
    # No explicit sess.close(): leaving the `with` block closes the session.

猜你喜欢

转载自blog.csdn.net/Chang_Shuang/article/details/80229759