基于Tensorflow框架的tfrecord文件的生成与读取(附有详细的注释,易理解)

版权声明:本文为博主原创文章,未经博主允许不得转载 https://blog.csdn.net/qq_35240640/article/details/89452884

基于Tensorflow框架的tfrecord文件的生成与读取

一. 根据自己已有数据集,生成tfrecord文件,附有详细的注释:

我所用的数据集只有猫狗两类,下面是部分示例图:
在这里插入图片描述

#coding=utf-8
import cv2 
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

def writeTFRecord(imgDir, recordPath, imgH, imgW, encoderLabel, recordName):
    '''
         创建tfrecord文件
    param:
        imgDir: 预处理图片路径
        recordPath: 预保存tfrecord文件路径
        imgH: 为了方便训练,将原图resize为指定的高
        imgW: 指定的宽
        encoderLabel: 编码标签,将标签以数字的形式呈现
        recordName: 预保存tfrecord文件名,eg:test.tfrecord
    '''
    imgPathList = [os.path.join(imgDir, img) for img in os.listdir(imgDir)]  #将所有图片路径绝对路径放入list中
    np.random.shuffle(imgPathList)  #打乱imgList中的图片顺序
    
    if not os.path.exists(recordPath):
        os.makedirs(recordPath)
    writer = tf.python_io.TFRecordWriter(path=os.path.join(recordPath, recordName))  #建立TFRecord存储器
    
    for i, ip in enumerate(imgPathList):
        img = cv2.imread(ip, cv2.IMREAD_COLOR)
    #     if(i == 0):  #显示第0张图
    #         cv2.imshow(ip, img)
    #         cv2.waitKey(0)
        if(type(img).__name__ != 'NoneType'):
            imgName = ip.split('\\')[-1].split('.')[0]  #name eg:cat.0.jpg
            label = encoderLabel[imgName]  #根据标签编码字典,找到其对应的标签数字
            
            img = cv2.resize(img, (imgW, imgH))  #将图片统一调成指定大小
            imgByte = img.tobytes()  #将图片转化为字节类型
            
            example = tf.train.Example(features=tf.train.Features(feature={
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[imgByte])),
                'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                }))  #创建 Example 对象,并且将 Feature 一一对应填充进去
    #             x = tf.train.Example(features=tf.train.Features(feature={'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), 
    #                                             'll': tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}))
    #             print(x)    #举了一个例子,用于了解上述方法的作用及输出结果
            
            writer.write(example.SerializeToString())  #将 example序列化成 string 类型,然后写入
    writer.close()        

'''创建tfrecord文件的参数设置部分'''            
encoderLabel = {'dog':0, 'cat':1}   #编码标签,将标签以数字的形式呈现
imgH = 80   #图片数据集的高宽
imgW = 100
imgDir = r'F:\kaggle\val'   #图片存放路径,我的图片命名形式:F:\kaggle\val\dog.12335.jpg,前面是路径
recordPath = r'E:\Learning\tensorflowPractice\dataProcess\tfrecord\catDog' #准备保存生成的tfrecord文件的路径
recordName = 'val.tfrecord'    #tfrecord文件名
'''传入参数,开始创建'''
writeTFRecord(imgDir, recordPath, imgH, imgW, encoderLabel, recordName)
生成tfrecord文件:

在这里插入图片描述

二. 读取生成的tfrecord文件

#coding=utf-8
import cv2 
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

def readTFRecord(tfrecordFile, numClass, imgW, imgH, imgChannel):
    '''
         读取tfrecord文件
    param:
        tfrecordFile: tfrecord绝对文件名
        numClass: 标签种类数
        imgW、imgH: 定义的图片宽高
        imgChannel: 图片通道数
    return:
        (img, label)
    '''
    fileNameQueue = tf.train.string_input_producer([tfrecordFile])  #创建文件名队列(文件名list,epochs:传入多少批次,shuffle:每一epoch内是否打乱)
    reader  = tf.TFRecordReader()
    _, serializedExample = reader.read(fileNameQueue)
    features = tf.parse_single_example(serializedExample, 
                                       features={
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                           'labels': tf.FixedLenFeature([], tf.int64)
                                           })
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.cast(img, tf.float32)
    img = tf. reshape(img, [imgH, imgW, imgChannel])
     
    img = img / 255
    img = tf.subtract(img, 0.5) #减去0.5
    img = tf.multiply(img, 2.5)  #乘以2.5
     
    label = features['labels']
    label_one_hot = slim.one_hot_encoding(labels=label, num_classes=numClass)
    return img, label_one_hot  
      
#参数设置         
recordPath = r'E:\Learning\tensorflowPractice\dataProcess\tfrecord\catDog' #准备保存生成的tfrecord文件的路径 
tfrecordFile = os.path.join(recordPath, 'test.tfrecord')  
numClass = 2  #标签种类
imgChannel = 3  #图片通道数
imgH = 80   #图片数据集的高宽
imgW = 100
 
'''读取tfrecord文件,并返回(图像,标签)'''
img, label = readTFRecord(tfrecordFile, numClass, imgW, imgH, imgChannel)
 
 #下面是显示读取的图片的简单示例
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=10, capacity=2000,
                                                min_after_dequeue=1000)
   
init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    threads = tf.train.start_queue_runners(sess=sess)  # 使用start_queue_runners之后,才会开始填充队列
    im, l = sess.run([img_batch, label_batch])
    cv2.imshow('img', im[3])
    cv2.waitKey(0)

猜你喜欢

转载自blog.csdn.net/qq_35240640/article/details/89452884