tensorflow实战之数据加载

在进行深度学习开发之前,我们都必须面对的是数据加载问题。如何加载我们自己的数据,是我们不得不面对的一个问题,本篇以数据加载作为我们tensorflow实战的开始,教你手把手实现自己的模型训练。


目录

一、tensorflow常见的数据集格式

二、内存数据

2.1、数据集说明

2.2、生成样本数据

三、TFRecord数据

四、Dataset数据集

4.1、生成Dataset对象

4.2、在Session中使用Dataset数据集

五、总结


 

一、tensorflow常见的数据集格式

  • 内存数据:该类数据集是通过直接读取数据,并通过注入机制来进行数据的加载,通常来说如果数据较多时会造成数据加载过慢,而且非常消耗内存。所以一般建议少量数据集时使用该方法;
  • TFRecord数据:这种数据是通过队列管道的方式来加载数据,通常我们会将数据先制作成TFRecord格式然后在进行加载,这种数据加载模式非常适合有大量训练的数据。所以一般建议如果数据集较多时可以考虑这在方法;
  • Dataset数据:特别强调这是下x1.4版本后的新特性,也是官方比较推荐的一种加载方法,他通过性能更高的输入管道进行加载数据,在后面部分我会着重介绍他的方法。这里,一般建议使用这种方法加载数据(tfrecord看起来太麻烦了,不知道大家有没有这种感受...);
  • tf.keras接口数据:只支持keras语法的数据,这里不详细说明了。

看到这里大家会有一定的想法了,对于前两种方法,由于其局限性和一定的阅读困难,最主要是tfrecord写起来太麻烦了。

二、内存数据

在读取图片的过程中,如果图像数据集较小,则可以直接全部读取,如果数据集较多,则可能消耗大量的内存。这时候可以考虑边读边取,但如果频繁的进行读取操作可能会影响性能。这里可以采用队列的方式进行,即使用两个线程并发:一个线程用于取数据进行训练;一个线程用于读数据到内存中。

2.1、数据集说明

下载链接:https://download.pytorch.org/tutorial/hymenoptera_data.zip

这是取自于imageNet的非常小的子集。其训练集和验证集的数目见下表:

类别 训练集 验证集
蜜蜂 121 83
蚂蚁 124

7

数据集的结构如下:

2.2、生成样本数据

2.2.1、加载文件路径和标签

读取文件夹下的图片路径和对应的标签,并存储到lfilenames 和labelsnames的list中,并对其进行shuffle操作。

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
import numpy as np
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

def load_sample(sample_dir,shuffleflag = True):
    '''递归读取文件。只支持一级。返回文件名、数值标签、数值对应的标签名'''
    print ('loading sample  dataset..')
    lfilenames = []
    labelsnames = []
    for (dirpath, dirnames, filenames) in os.walk(sample_dir):#递归遍历文件夹
        for filename in filenames:                            #遍历所有文件名
            #print(dirnames)
            filename_path = os.sep.join([dirpath, filename])
            lfilenames.append(filename_path)               #添加文件名
            labelsnames.append( dirpath.split('\\')[-1] )#添加文件名对应的标签

    lab= list(sorted(set(labelsnames)))  #生成标签名称列表
    labdict=dict( zip( lab  ,list(range(len(lab)))  )) #生成字典

    labels = [labdict[i] for i in labelsnames]
    if shuffleflag == True:
        return shuffle(np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)
    else:
        return (np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)

2.2.1、队列中批次读取数据

读取批次数据的具体步骤:

  1. 使用tf.train.slice_input_producer函数生成队列;
  2. 加载数据并进行预处理;
  3. 使用tf.train.batch将预处理后的数据变成批次数据。
def get_batches(image,label,input_w,input_h,channels,batch_size):

    queue = tf.train.slice_input_producer([image,label])  #使用tf.train.slice_input_producer实现一个输入的队列
    label = queue[1]                                        #从输入队列里读取标签

    image_c = tf.read_file(queue[0])                        #从输入队列里读取image路径

    image = tf.image.decode_bmp(image_c,channels)           #按照路径读取图片

    image = tf.image.resize_image_with_crop_or_pad(image,input_w,input_h) #修改图片大小


    image = tf.image.per_image_standardization(image) #图像标准化处理,(x - mean) / adjusted_stddev

    image_batch,label_batch = tf.train.batch([image,label],#调用tf.train.batch函数生成批次数据
               batch_size = batch_size,
               num_threads = 64)

    images_batch = tf.cast(image_batch,tf.float32)   #将数据类型转换为float32

    labels_batch = tf.reshape(label_batch,[batch_size])#修改标签的形状shape
    return images_batch,labels_batch

2.2.3、在Session中使用数据集

通过在静态图Session中启动一个带有协调器的队列线程来获取数据,具体如下:

(image,label),labelsnames = load_sample(data_dir)   #载入文件名称与标签
batch_size = 16
image_batches,label_batches = get_batches(image,label,28,28,1,batch_size)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)  #初始化

    coord = tf.train.Coordinator()          #开启列队
    threads = tf.train.start_queue_runners(sess = sess,coord = coord)
    try:
        for step in np.arange(10):
            if coord.should_stop():
                break
            images,label = sess.run([image_batches,label_batches]) #注入数据

    except tf.errors.OutOfRangeError:
        print("Done!!!")
    finally:
        coord.request_stop()

    coord.join(threads)                             #关闭列队

完整代码链接:https://github.com/kingqiuol/learning_tensorflow/blob/master/data/load_imagedata.py

执行上述完整代码后的结果如下:

三、TFRecord数据

TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

3.1、生成TFRecord数据集

这里我们使用第二节中的数据集进行演示,同时使用load_sample函数来加载数据集的路径和标签数据,在此可以参考上一节。在使用TFRecord数据集之前我们需要将我们的数据制作成TFRecord的格式方便后续的训练。具体实现流程如下:

  1. 按照load_sample读取的数据进行读取图片;
  2. 将读取的图片和标签进行打包组合在一起;
  3. 使用TFRecordWriter对象的write方法将图片和标签写入文件中。
def makeTFRec(filenames,labels): #定义函数生成TFRecord
    writer= tf.python_io.TFRecordWriter("mydata.tfrecords") #通过tf.python_io.TFRecordWriter 写入到TFRecords文件
    for i in tqdm( range(0,len(labels) ) ):
        img=Image.open(filenames[i])
        img = img.resize((256, 256))
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
                #存放图片的标签label
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),
                #存放具体的图片
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            })) #example对象对label和image数据进行封装

        writer.write(example.SerializeToString())  #序列化为字符串
    writer.close()  #数据集制作完成

3.2、在队列中批量读取数据

通常在训练集中我们需要对数据进行乱序操作,并按指定的方法组合,而在测试集中,只需一次加载,不需要乱序和批次组合。具体实现如下:

def read_and_decode(filenames,flag = 'train',batch_size = 3):
    #根据文件名生成一个队列
    if flag == 'train':
        filename_queue = tf.train.string_input_producer(filenames)#默认已经是shuffle并且循环读取
    else:
        filename_queue = tf.train.string_input_producer(filenames,num_epochs = 1,shuffle = False)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example, #取出包含image和label的feature对象
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    #tf.decode_raw可以将字符串解析成图像对应的像素数组
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [256,256,3])
    #
    label = tf.cast(features['label'], tf.int32)

    if flag == 'train':
        image = tf.cast(image, tf.float32) * (1. / 255) - 0.5     #归一化
        img_batch, label_batch = tf.train.batch([image, label],   #还可以使用tf.train.shuffle_batch进行乱序批次
                                                batch_size=batch_size, capacity=20)

        return img_batch, label_batch

    return image, label

3.3、在Session中使用数据集

TFRecordfilenames = ["mydata.tfrecords"]
image, label =read_and_decode(TFRecordfilenames,flag='test')  #以测试的方式打开数据集

#开始一个会话读取数据
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())   #初始化本地变量,没有这句会报错
    #启动多线程
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)

    try:
        while True:
            example, examplelab = sess.run([image,label])#在会话中取出image和label
    except tf.errors.OutOfRangeError:
        print('Done Test -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)
        print("stop()")

完整代码链接:https://github.com/kingqiuol/learning_tensorflow/blob/master/data/load_imagedata.py

执行上述完整代码后的结果如下:

参考链接:

TensorFlow高效读取数据的方法

Tensorflow(一) TFRecord生成与读取

四、Dataset数据集

Dataset数据集是由tf.data.Dataset接口实现的,通过该接口使得tensorflow能够更加方便、快速的处理数据集。Dataset对象能直接对其上的数据进行相关乱序、迭代等操作(越来越和pytorch有点像了)。

Dataset对象的创建:

  • tf.data.Dataset.from_tensors:根据内存对象创建,且只有一个元素;
  • tf.data.Dataset.from_tensors_slices:根据内存对象创建,对象可以是list、dict、set、Numpy等;
  • tf.data.Dataset.from_generator:根据迭代器生成对象。

这几种方法比较类似,一般用的比较多的是第二种方法,建议着重掌握这种方法。

Dataset对象支持的操作:

1、dataset.shuffle(buffer_size,seed=None,reshuffle_each_iteration=None):将数据内部的元素顺序打乱

  • buffer_size:随机打乱元素排序的大小,一般越大越混乱。
  • seed:随机种子,一般不用管。
  • reshuffle_each_iteration:是否每次迭代都乱序。

2、dataset.repeat(count=None):生成重复的数据,count代表重复的次数

3、dataset.map(map_func,num_parallel_cell=None):通过map_func来将数据集中的每一个元素进行转换处理

  • map_func:处理函数
  • num_parallel_cell:并行处理的线程个数

4、dataset.batch(batch_size,drop_remainder):将数据集的元素按照批次进行组合

  • batch_size:批次大小。
  • drop_remainder:是否忽略批次组合后剩余的数据

5、dataset.prefetch(buffer_size):设置从数据集中取数据时的最大缓冲区。一般推荐将buffer_size设置为tf.data.experimental.AUTOTUNE,代表系统自动调节大小

一般来讲,处理数据的合理步骤为:创建Dataset对象->乱序数据集(shuffle)->重复数据集(repeat)->数据预处理(map)->设定批次(batch)->设定缓存(prefetch)

在训练过程中有时会出现某次数据不足的情况,一般造成这种情况的原因有:数据总数不能被batch_size整除,而在训练过程中的剩余数据也会进入训练。解决的方法主要有:第一种方法是对数据进行repeat操作,在进行batch设置;第二种方法是将batch函数中的drop_remainder参数设置为True,这样在训练过程中就会丢弃剩余数据,从而避免批次数据不足的情况

Dataset数据集的操作步骤:

(1)创建Dataset对象;

(2)对Dataset对象进行变换操作;

(3)创建Dataset迭代器;

(4)在会话Session中取数据。

4.1、生成Dataset对象

这里我们还是使用之前的数据集进行测试,同时使用load_sample函数来加载数据集的路径和标签数据,在此可以参考第二节。这里我们使用了上述的一些函数来进行创建,具体实现如下:

def _norm_image(image,size,ch=1,flattenflag = False):    #定义函数,实现归一化,并且拍平
    image_decoded = image/255.0
    if flattenflag==True:
        image_decoded = tf.reshape(image_decoded, [size[0]*size[1]*ch])
    return image_decoded

def dataset(directory,size,batchsize,random_rotated=False):#定义函数,创建数据集
    """ parse  dataset."""
    (filenames,labels),_ =load_sample(directory,shuffleflag=False) #载入文件名称与标签
    def _parseone(filename, label):                         #解析一个图片文件
        """ Reading and handle  image"""
        image_string = tf.read_file(filename)         #读取整个文件
        image_decoded = tf.image.decode_image(image_string)
        image_decoded = tf.image.resize(image_decoded, size)  #变化尺寸
        image_decoded = _norm_image(image_decoded,size)#归一化
        image_decoded = tf.cast(image_decoded,dtype=tf.float32)
        label = tf.cast(  tf.reshape(label, []) ,dtype=tf.int32  )#将label 转为张量
        return image_decoded, label

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))#生成Dataset对象
    dataset = dataset.map(_parseone)   #有图片内容的数据集

    dataset = dataset.batch(batchsize) #批次划分数据集

    return dataset

4.2、在Session中使用Dataset数据集

def getone(dataset):
    iterator = dataset.make_one_shot_iterator()			#生成一个迭代器
    one_element = iterator.get_next()					#从iterator里取出一个元素
    return one_element

sample_dir=r"./hymenoptera_data/train"
size = [96,96]
batchsize = 10
tdataset = dataset(sample_dir,size,batchsize)
print(tdataset.output_types)  #打印数据集的输出信息
print(tdataset.output_shapes)

one_element1 = getone(tdataset)				#从tdataset里取出一个元素

with tf.Session() as sess:	# 建立会话(session)
    sess.run(tf.global_variables_initializer())  #初始化

    try:
        for step in np.arange(1):
            image,label = sess.run(one_element1)

    except tf.errors.OutOfRangeError:           #捕获异常
        print("Done!!!")

完整代码链接:https://github.com/kingqiuol/learning_tensorflow/blob/master/data/load_imagedata.py

执行上述完整代码后的结果如下:

五、总结

这里我们对tensorflow数据的加载方式有了一定的了解,同时也建议大家在以后的模型搭建过程中尽量使用Dataset这种数据加载方式,在之后的章节中我将继续讲解数据加载的进阶,讲解在实际工程中的使用方法,哈哈哈,有兴趣的话可以用继续看看我的下一篇文章。

发布了69 篇原创文章 · 获赞 26 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/wxplol/article/details/104172736