使用 tfrecord 制作自己的数据集 (附上源代码)

相信很多刚入手深度学习的人,最早接触的程序就是Mnist 手写数字的识别。Mnist 数据集都已被事先整理好,我们只需拿来用即可。但是如何制作自己的数据集,相信很多刚入门的人还是会一团雾水。作为刚入门不久的小白,我也花了很长时间才完整地制作了自己的数据集。制作自己的数据集,大概可以分为这么几步:

Step1.首先要去收集自己的数据吧,可以是自己拍的图片,也可以是那种网上爬虫爬下来的图片。

Step2.建议最好将爬下来的图片重新命名,再拿去训练,这样图片数据看起来比较整齐。特别是对有强迫症的同学来说,这是很重要的,总感觉名字不统一会觉得怪怪的。命名可以采用 name1,name2,name3.......这种形式。具体如何命名,我在之前的博客中也有详细介绍过,有需要的同学可以参考看下  点击打开链接    当然不改名字的话,也没什么影响,只是读取图片时需要采用不同的方法就好。

Step3. 接下来就是读取图片,在读取图片时也有些需要注意的细节。我在另一篇博客中给出了详细的介绍  点击打开链接 

并制作成tfrecord形式,具体代码如下

import tensorflow as tf  
import numpy as np  
import os  
import cv2
from skimage import transform 
import skimage.io as io  

#%%
#def rename(file_dir,name):
#    '''将网上爬下来的图片重命名(更好的观看)'''
#    i=0
#    for file in os.listdir(file_dir):  #获取该路径文件下的所有图片
#        src = os.path.join(os.path.abspath(file_dir), file) #目标文件夹+图片的名称
#        dst = os.path.join(os.path.abspath(file_dir),  name+str(i) + '.jpg')#目标文件夹+新的图片的名称
#        os.rename(src, dst)
#        i=i+1        
#rename(file_dir+'/roses','rose')
#rename(file_dir+'/sunflowers','sunflower')

def get_files(file_dir):
    """Collect full image paths and class labels under `file_dir`.

    Expects two sub-directories: `roses` (label 0) and `sunflowers`
    (label 1). The combined listing is shuffled so training batches mix
    both classes.

    Args:
        file_dir: root directory containing `roses/` and `sunflowers/`.

    Returns:
        image_list: list of image file paths (str), shuffled.
        label_list: list of int labels (0 = rose, 1 = sunflower),
            aligned element-wise with image_list.
    """
    roses = []
    label_roses = []
    sunflowers = []
    label_sunflowers = []

    # Label convention: rose -> 0.
    for file in os.listdir(file_dir + '/roses'):
        roses.append(file_dir + '/roses' + '/' + file)
        label_roses.append(0)

    # Label convention: sunflower -> 1.
    for file in os.listdir(file_dir + '/sunflowers'):
        sunflowers.append(file_dir + '/sunflowers' + '/' + file)
        label_sunflowers.append(1)
    print('There are %d roses \n There are %d sunflowers' %(len(roses), len(sunflowers)))

    # Merge roses and sunflowers into one (image, label) listing.
    image_list = list(roses) + list(sunflowers)
    label_list = list(label_roses) + list(label_sunflowers)

    # Shuffle images and labels together with a single random permutation.
    # (The original stacked both lists into one np.array, which silently
    # coerced the int labels to strings and required converting them back;
    # permuting indices keeps the labels as ints throughout.)
    perm = np.random.permutation(len(image_list))
    image_list = [image_list[i] for i in perm]
    label_list = [int(label_list[i]) for i in perm]
    return image_list, label_list
  
    
#%%    
def int64_feature(value):
  """Wrap an int (or list of ints) as an Int64List tf.train.Feature.

  Used for the integer label when building a tf.train.Example proto.
  """
  values = value if isinstance(value, list) else [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  
def bytes_feature(value):
  """Wrap a raw bytes string as a BytesList tf.train.Feature.

  Used for the serialized image pixels when building a tf.train.Example.
  """
  byte_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=byte_list)
  

  
def convert_to_tfrecord(images, labels, save_dir, name):
    '''Convert all images and labels into one tfrecord file.

    Args:
        images: list of image file paths, string type
        labels: list of labels, int type
        save_dir: the directory to save the tfrecord file, e.g.: '/home/folder1/'
        name: the name of the tfrecord file, string type, e.g.: 'train'
    Return:
        no return
    Note:
        converting needs some time, be patient...
    '''
    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels)

    if len(images) != n_samples:
        # Bug fix: the original formatted this message with images.shape[0],
        # which raises AttributeError because `images` is a plain list.
        raise ValueError('Images size %d does not match label size %d.'
                         % (len(images), n_samples))

    # Transforming takes some time depending on the size of your data.
    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    try:
        for i in np.arange(0, n_samples):
            try:
                # cv2 reads images as BGR; convert to RGB before serializing.
                image = cv2.imread(images[i])
                if image is None:
                    # cv2.imread returns None (no exception) on unreadable
                    # files, so raise explicitly to hit the skip path below.
                    raise IOError('cv2.imread returned None')
                image = cv2.resize(image, (208, 208))
                b, g, r = cv2.split(image)
                rgb_image = cv2.merge([r, g, b])

                image_raw = rgb_image.tostring()
                label = int(labels[i])
                example = tf.train.Example(features=tf.train.Features(feature={
                                'label': int64_feature(label),
                                'image_raw': bytes_feature(image_raw)}))
                writer.write(example.SerializeToString())
            except IOError as e:
                # Best-effort: skip unreadable images instead of aborting.
                print('Could not read:', images[i])
                print('error: %s' % e)
                print('Skip it!\n')
    finally:
        # Always close the writer so a valid (possibly partial) file remains.
        writer.close()
    print('Transform done!')


Step4. 生成tfrecord后,接着便是tfrecord的读取了  (只是在训练过程中才调用这个函数,主要的作用就是将tfrecord格式解码)

def read_and_decode(tfrecords_file, batch_size):  
    '''Read and decode a tfrecord file, generating (image, label) batches.

    Uses the TF1 queue-based input pipeline: a filename queue feeds a
    TFRecordReader, and tf.train.batch assembles shuffca batches via
    background threads (a tf.train.start_queue_runners call is required
    at session time — presumably done by the training loop; verify).

    Args: 
        tfrecords_file: path to the tfrecord file 
        batch_size: number of images in each batch 
    Returns: 
        image: 4D tensor - [batch_size, width, height, channel] 
        label: 1D tensor - [batch_size] 
    '''  
    # Make an input queue from the tfrecord file.
    filename_queue = tf.train.string_input_producer([tfrecords_file])  
      
    reader = tf.TFRecordReader()  
    _, serialized_example = reader.read(filename_queue)  
    # Parse one serialized Example; must match the feature keys/types
    # written by convert_to_tfrecord ('label': int64, 'image_raw': bytes).
    img_features = tf.parse_single_example(  
                                        serialized_example,  
                                        features={  
                                               'label': tf.FixedLenFeature([], tf.int64),  
                                               'image_raw': tf.FixedLenFeature([], tf.string),  
                                               })  
    # Recover the raw uint8 pixel buffer produced by tostring() on write.
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)  
      
    ##########################################################  
    # you can put data augmentation here, I didn't use it  
    ##########################################################  
    # The writer stored 208x208 RGB images; change this shape if your
    # dataset uses a different size (decode_raw yields a flat vector).
      
    image = tf.reshape(image, [208, 208,3])  
    label = tf.cast(img_features['label'], tf.float32)      
    # Per-image standardization: zero mean / unit variance per image.
    image = tf.image.per_image_standardization(image)  
    image_batch, label_batch = tf.train.batch([image, label],  
                                                batch_size= batch_size,  
                                                num_threads= 64,   
                                                capacity = 2000)  # reader threads and max queue capacity
    return image_batch, tf.reshape(label_batch, [batch_size])  



猜你喜欢

转载自blog.csdn.net/qq_24193303/article/details/80000893
今日推荐