基于tensorflow的MNIST探索(基于图像版本的实现与探索)——如何读取较大数据集进行训练(基于TFRecord)(二)

版权声明:转载需注明出处 https://blog.csdn.net/holmes_MX/article/details/81942537

注意:使用TFRecord是,存在两个实际的问题:

      1) 制作数据时,先制作train_data_label_list(内容为[ [imagePath, label], [imagePath, label], [], [], ,,, ]),然后对次数据进行打乱,否则即使采用tf.train.shuffle_batch给出的结果仍不是打乱的

        猜测原因:是由于每一类的数据过大,如果按照每一类这样的顺序读取图像写入TFRecord文件中,即使后继采用tf.train.shuffle_batch仍然由于每一类数据过大,shuffle到的仍是同一类,因此在读取数据时,即将数据打乱。

      2) 在训练时,由于采用TFRecord文件格式进行读取数据,需要配置TensorFlow中的tf.train.string_input_producer([filename_train], shuffle,num_epochs)来使用,在实际中如果指定num_epochs的数值,当数据从队列中读取完毕时,会抛出异常,此时处理异常程序会中止,而且在实践中,如果num_epochs指定数值,会导致训练不稳定,因此在实际训练或者测试时,num_epochs不指定数值,这样就会一直迭代下去,因此需要通过额外的参数来控制结束。

0. 写作目的

好记性不如烂笔头。

1. 前言

对于较小的数据处理时,我们可以完全读入内存,但是针对较大的数据集,如果一次性读入内存,将会出现一下两点问题:

1) 内存是否够用。

一些较大的数据集,如ImageNet数据集,一次性完全读入内存,需要较大的内存,即使内存足够,我们也不会使用该方法,原因如下第二点。

2) 程序的交互性差

如果我们将较大的数据集一次性读入时,首先在程序运行之前,我们第一步就要读入数据,如果在读入数据后,后面的代码出现了bug,修改bug后,仍需要较大的时间读入,浪费时间。(虽然对于Jupyter notebook或类似的环境修改一下代码,可以避免重新读入,但是读入的过程的仍在浪费时间,而且即使数据全都读入内存,但在一段时间内我们使用的数据仍只是一部分,因此对于内存也造成了一种资源浪费。)

因此,一种较好的方式就是,即用即读。在深度框架中,不同的深度学习框架都有自己的序列化格式,如caffe的LMDB文件,Tensorflow的TFrecord文件等。

在之前使用Caffe的LMDB文件时,我就有一种想法,以图像分类为例,可不可以先保存图像的路径,以及对应的标签,然后在使用时,进行即用即读。本文继上一篇博客---基于tensorflow的MNIST探索(基于图像版本的实现与探索)——如何读取较大数据集进行训练(一)之后,采用TFrecord格式进行训练。

2. 基于TFRecord格式的较大数据读入方式

2.1 制作TFRecord数据集

数据集的制作是前提,在学习前人制作的基础上[1],加入了自己的制作方式。具体思路为:首先得到train和test的图像绝对路径以及对应的标签,即train_data_label_list 和test_data_label_list,然后制作TFRecord数据集。具体代码如下:

import os
import numpy as np
import tensorflow as tf
from PIL import Image

os.chdir('../')
current_dir = os.getcwd()
data_dir = current_dir + '/mnist_image'
train_data_dir = data_dir + '/train'
test_data_dir = data_dir + '/test'

TF_dir = current_dir + '/TFdata'


def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
def _float_feature(value):
    return tf.train.Feature(float_list = tf.train.FloatList(value = [value]))
def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))



def createList(train = True):
    if train:
        train_data_label_list = []
        for className in os.listdir(train_data_dir):
            for imageName in os.listdir(train_data_dir + '/' + str(className)):
                tempImageName = train_data_dir + '/' + str(className) + '/' + imageName
                train_data_label_list.append( [tempImageName, str(className)] )
        return train_data_label_list
    else:
        test_data_label_list = []
        for className in os.listdir(test_data_dir):
            for imageName in os.listdir(test_data_dir + '/' + str(className) ):
                tempImageName = test_data_dir + '/' + str(className) + '/' + str(imageName)
                test_data_label_list.append( [tempImageName, str(className)] )
        return test_data_label_list

def createTF(root, train = True):
    if train:
        current_number = 0
        train_TF_dir = TF_dir + '/TFtrain/'
        if not os.path.exists( train_TF_dir ):
            os.makedirs( train_TF_dir )
        train_TF_dir = train_TF_dir + 'train.tfrecords'
        if os.path.exists(train_TF_dir):
            os.remove(train_TF_dir)

        train_TFWriter = tf.python_io.TFRecordWriter( train_TF_dir )
        train_data_label_list = createList(train=True)

        np.random.shuffle(train_data_label_list)
        for item in train_data_label_list:
            current_number += 1
            img = Image.open( item[0] )
            img = img.convert("RGB")
            img = np.array( img )
            height = img.shape[0]
            width = img.shape[1]
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature=
                                                                  {'label': _int64_feature(int(item[1])),
                                                                   'image': _bytes_feature(img_raw),
                                                                   'id': _int64_feature(int(current_number)),
                                                                   'height': _int64_feature(height),
                                                                   'width': _int64_feature(width)
                                                                   }))
            train_TFWriter.write(example.SerializeToString())
            if current_number % 1000 == 0:
                print("%d is done" % (current_number))
        print("train data is Done! train data number is: {}".format(current_number))
        train_TFWriter.close()
    else:
        current_number = 0
        test_TF_dir = TF_dir + '/TFtest/'
        if not os.path.exists(test_TF_dir):
            os.makedirs(test_TF_dir)
        test_TF_dir = test_TF_dir + 'test.tfrecords'
        if os.path.exists(test_TF_dir):
            os.remove(test_TF_dir)

        test_TFWriter = tf.python_io.TFRecordWriter(test_TF_dir)
        test_data_label_list = createList(train=False)
        np.random.shuffle(test_data_label_list)
        for item in test_data_label_list:
            current_number += 1
            img = Image.open(item[0])
            img = img.convert("RGB")
            img = np.array(img)
            height = img.shape[0]
            width = img.shape[1]
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature=
                                                                  {'label': _int64_feature(int(item[1])),
                                                                   'image': _bytes_feature(img_raw),
                                                                   'id': _int64_feature(int(current_number)),
                                                                   'height': _int64_feature(height),
                                                                   'width': _int64_feature(width)
                                                                   }))
            test_TFWriter.write(example.SerializeToString())
            if current_number % 1000 == 0:
                print("%d is done" % (current_number))
        print("test data is Done! test data number is: {}".format(current_number))
        test_TFWriter.close()

        
def main(_):
    createTF(train_data_dir)
    createTF(test_data_dir, False)
     
if __name__ == "__main__":
    tf.app.run()

2.2 然后进行训练

训练的网络与我的博客——基于tensorflow的MNIST探索(基于图像版本的实现与探索)——如何读取较大数据集进行训练(一)中的网络一样,参数也相同(tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9))。

这里直接给出训练的结果:

后面训练再也升不上来了(此问题为tf.losses.softmax_cross_entropy函数参数选择出现问题,应该为logist和y_true)。

 

 

[Reference]

[1] http://blog.csdn.net/wiinter_fdd/article/details/72835939

猜你喜欢

转载自blog.csdn.net/holmes_MX/article/details/81942537
今日推荐