版权声明:本文为博主原创文章,未经博主允许不得转载 https://blog.csdn.net/qq_35240640/article/details/89452884
基于Tensorflow框架的tfrecord文件的生成与读取
一. 根据自己已有数据集,生成tfrecord文件,附有详细的注释:
我所用的数据集只有猫狗两类,下面是部分示例图:
#coding=utf-8
import cv2
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
def writeTFRecord(imgDir, recordPath, imgH, imgW, encoderLabel, recordName):
'''
创建tfrecord文件
param:
imgDir: 预处理图片路径
recordPath: 预保存tfrecord文件路径
imgH: 为了方便训练,将原图resize为指定的高
imgW: 指定的宽
encoderLabel: 编码标签,将标签以数字的形式呈现
recordName: 预保存tfrecord文件名,eg:test.tfrecord
'''
imgPathList = [os.path.join(imgDir, img) for img in os.listdir(imgDir)] #将所有图片路径绝对路径放入list中
np.random.shuffle(imgPathList) #打乱imgList中的图片顺序
if not os.path.exists(recordPath):
os.makedirs(recordPath)
writer = tf.python_io.TFRecordWriter(path=os.path.join(recordPath, recordName)) #建立TFRecord存储器
for i, ip in enumerate(imgPathList):
img = cv2.imread(ip, cv2.IMREAD_COLOR)
# if(i == 0): #显示第0张图
# cv2.imshow(ip, img)
# cv2.waitKey(0)
if(type(img).__name__ != 'NoneType'):
imgName = ip.split('\\')[-1].split('.')[0] #name eg:cat.0.jpg
label = encoderLabel[imgName] #根据标签编码字典,找到其对应的标签数字
img = cv2.resize(img, (imgW, imgH)) #将图片统一调成指定大小
imgByte = img.tobytes() #将图片转化为字节类型
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[imgByte])),
'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
})) #创建 Example 对象,并且将 Feature 一一对应填充进去
# x = tf.train.Example(features=tf.train.Features(feature={'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
# 'll': tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}))
# print(x) #举了一个例子,用于了解上述方法的作用及输出结果
writer.write(example.SerializeToString()) #将 example序列化成 string 类型,然后写入
writer.close()
'''创建tfrecord文件的参数设置部分'''
encoderLabel = {'dog':0, 'cat':1} #编码标签,将标签以数字的形式呈现
imgH = 80 #图片数据集的高宽
imgW = 100
imgDir = r'F:\kaggle\val' #图片存放路径,我的图片命名形式:F:\kaggle\val\dog.12335.jpg,前面是路径
recordPath = r'E:\Learning\tensorflowPractice\dataProcess\tfrecord\catDog' #准备保存生成的tfrecord文件的路径
recordName = 'val.tfrecord' #tfrecord文件名
'''传入参数,开始创建'''
writeTFRecord(imgDir, recordPath, imgH, imgW, encoderLabel, recordName)
生成tfrecord文件:
二. 读取生成的tfrecord文件
#coding=utf-8
import cv2
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
def readTFRecord(tfrecordFile, numClass, imgW, imgH, imgChannel):
'''
读取tfrecord文件
param:
tfrecordFile: tfrecord绝对文件名
numClass: 标签种类数
imgW、imgH: 定义的图片宽高
imgChannel: 图片通道数
return:
(img, label)
'''
fileNameQueue = tf.train.string_input_producer([tfrecordFile]) #创建文件名队列(文件名list,epochs:传入多少批次,shuffle:每一epoch内是否打乱)
reader = tf.TFRecordReader()
_, serializedExample = reader.read(fileNameQueue)
features = tf.parse_single_example(serializedExample,
features={
'img_raw': tf.FixedLenFeature([], tf.string),
'labels': tf.FixedLenFeature([], tf.int64)
})
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.cast(img, tf.float32)
img = tf. reshape(img, [imgH, imgW, imgChannel])
img = img / 255
img = tf.subtract(img, 0.5) #减去0.5
img = tf.multiply(img, 2.5) #乘以2.5
label = features['labels']
label_one_hot = slim.one_hot_encoding(labels=label, num_classes=numClass)
return img, label_one_hot
#参数设置
recordPath = r'E:\Learning\tensorflowPractice\dataProcess\tfrecord\catDog' #准备保存生成的tfrecord文件的路径
tfrecordFile = os.path.join(recordPath, 'test.tfrecord')
numClass = 2 #标签种类
imgChannel = 3 #图片通道数
imgH = 80 #图片数据集的高宽
imgW = 100
'''读取tfrecord文件,并返回(图像,标签)'''
img, label = readTFRecord(tfrecordFile, numClass, imgW, imgH, imgChannel)
#下面是显示读取的图片的简单示例
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=10, capacity=2000,
min_after_dequeue=1000)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess) # 使用start_queue_runners之后,才会开始填充队列
im, l = sess.run([img_batch, label_batch])
cv2.imshow('img', im[3])
cv2.waitKey(0)