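# 版权声明:如使用此博客内容,请经过作者同意,谢谢 https://blog.csdn.net/qq_40994943/article/details/85173929
# (Copyright notice: please obtain the author's permission before reusing this blog content. Thank you.)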
import os
import tensorflow as tf
from PIL import Image
#保存数据到tfrecords文件
def convert2tfr(path, name, classes=3):
    """Pack a folder of class-labelled images into a TFRecord file.

    Images are read from ``<path>/<index>/`` for every ``index`` in
    ``range(classes)``; that sub-directory index is stored as the label.
    Each image is converted to RGB, resized to 32x32 (the size that
    read_and_decode and tfr2bmp reshape records back to) and stored as raw
    bytes.

    Args:
        path: root directory holding one sub-directory per class, named
            ``0``, ``1``, ..., ``classes - 1``.
        name: output file name without extension; ``.tfrecords`` is appended.
        classes: number of class sub-directories to read (default 3, the
            value that used to be hard-coded).

    Each record has an int64 feature ``label`` and a bytes feature
    ``img_raw``.
    """
    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(name + '.tfrecords') as writer:
        for index in range(classes):
            class_path = os.path.join(path, str(index))
            # os.listdir order is arbitrary; sorting keeps record order reproducible.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if not os.path.isfile(img_path):
                    continue  # skip nested directories; Image.open cannot read them
                # Close the underlying file handle as soon as the pixels are extracted.
                with Image.open(img_path) as img:
                    img_raw = img.convert("RGB").resize((32, 32)).tobytes()
                # Wrap label and raw image bytes in an Example record.
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                }))
                writer.write(example.SerializeToString())
#tensorflow运行读取图片batch函数
def read_and_decode(filename, batch_size, num_classes=81):
    """Build a TF1 queue-based pipeline that yields shuffled image/label batches.

    Records are parsed with the schema written by convert2tfr: an int64
    ``label`` and raw uint8 ``img_raw`` bytes, reshaped to 32x32x3. Images
    are converted to grayscale and rescaled to float32 values in
    [-0.5, 0.5].

    Args:
        filename: path of a .tfrecords file, or a list/tuple of such paths.
        batch_size: number of examples per batch.
        num_classes: one-hot depth for the labels. Defaults to 81 (the value
            that used to be hard-coded). convert2tfr writes labels
            ``0 .. classes - 1`` (3 by default), so pass the matching class
            count; a label >= num_classes produces an all-zero row.

    Returns:
        (img_batch, label_batch):
        - img_batch is a float32 tensor of shape [batch_size, 32, 32, 1].
        - label_batch is a float32 one-hot tensor of shape
          [batch_size, num_classes].

    The returned tensors are fed by queue runners; call
    tf.train.start_queue_runners before evaluating them.
    """
    filenames = list(filename) if isinstance(filename, (list, tuple)) else [filename]
    filename_queue = tf.train.string_input_producer(filenames)  # create a filename queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # (key, serialized record)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [32, 32, 3])  # must match the 32x32 RGB size used when writing
    img = tf.image.rgb_to_grayscale(img)  # 3 channels -> 1 channel: 32x32x1
    # Scale uint8 [0, 255] to float32 [-0.5, 0.5].
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = features['label']  # already int64 per the parse schema above
    img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=batch_size,
                                                    capacity=1000 + batch_size * 3,
                                                    min_after_dequeue=1000)
    label_batch = tf.one_hot(label_batch, depth=num_classes)
    return img_batch, label_batch
#解析tfrecords文件变成图片源
def tfr2bmp(filename, dir, num):
    """Decode the first ``num`` records of a TFRecord file into .bmp images.

    Each record is parsed with the schema written by convert2tfr. Its
    ``img_raw`` bytes are reshaped to a 32x32x3 uint8 image and saved as
    ``read_img/<dir>/<i>_Label_<label>.bmp``. The decoded array and its
    label are also printed.

    Args:
        filename: path of the .tfrecords file to read.
        dir: sub-directory of ``read_img`` that receives the images. The
            name shadows the builtin but is kept because callers pass it as
            a keyword.
        num: number of records to extract. The input queue has no epoch
            limit, so if ``num`` exceeds the number of records the file is
            read again from the start and images repeat.
    """
    out_dir = os.path.join('read_img', dir)
    # makedirs creates read_img and the sub-directory in one call. The isdir
    # guard lets the function be re-run without an "already exists" error.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    filename_queue = tf.train.string_input_producer([filename])  # feed the file into a queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # (key, serialized record)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [32, 32, 3])
    label = tf.cast(features['label'], tf.int32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(num):
                example, lbl = sess.run([image, label])  # fetch one decoded record
                img = Image.fromarray(example, 'RGB')
                img.save(os.path.join(out_dir, str(i) + '_Label_' + str(lbl) + '.bmp'))
                print(example, lbl)
        finally:
            # Always stop and join the reader threads, even if saving fails.
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    # Step 1 (optional): pack class-labelled image folders into a TFRecord file.
    # convert2tfr('G:/tfrecord_test/', 'tfrecord_test')
    # convert2tfr('test_img/', 'tai_test1')

    # Step 2 (optional): build the batched input pipeline from a TFRecord file.
    # read_and_decode('tai_test.tfrecords', batch_size)
    # read_and_decode('tai_train.tfrecords', batch_size)

    # Step 3: decode records back into .bmp images on disk.
    source_tfr = 'tfrecord_test.tfrecords'
    tfr2bmp(filename=source_tfr, dir='test_img', num=15)
    # tfr2bmp(filename='tai_train.tfrecords', dir='train_img', num=48961)
# Reference: adapted from the blog post (博客) whose URL is given at the top of this file.