Using TensorFlow's Dataset API (tf.data.Dataset)

Reading data through a dataset involves three basic steps:

  1. Define how the dataset is constructed, e.g. tf.data.TFRecordDataset(input_files)
  2. Define an iterator, e.g. one_shot_iterator or initializable_iterator
  3. Call get_next() to obtain the tensors

Example:

import tensorflow as tf

def parser(record):
    # Parse a single serialized tf.train.Example into fixed-length features
    features = tf.parse_single_example(
        record,
        features={
            'feat1': tf.FixedLenFeature([], tf.int64),
            'feat2': tf.FixedLenFeature([], tf.int64)
        }
    )

    return features['feat1'], features['feat2']

# The data source can be an in-memory tensor or files:
# - for a tensor, use tf.data.Dataset.from_tensor_slices(input_data)
# - for text files, use tf.data.TextLineDataset(input_files)
input_files = ['file1', 'file2']
dataset = tf.data.TFRecordDataset(input_files)
# TFRecords hold serialized binary records, so each one must be parsed into
# the desired format; map() applies the parsing function to every element
dataset = dataset.map(parser)

# Fetch data through an iterator
iterator = dataset.make_one_shot_iterator()
feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run([feat1, feat2]))
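
The comments above mention two other common dataset constructors. A minimal sketch of how they are used (the data and file names here are made up for illustration):

import tensorflow as tf

# From an in-memory tensor: each element along the first dimension
# becomes one record
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

with tf.Session() as sess:
    for _ in range(len(input_data)):
        print(sess.run(x))  # prints 1, 2, 3, 5, 8 one value at a time

# From text files: every line of every file becomes one string record
text_dataset = tf.data.TextLineDataset(['file1.txt', 'file2.txt'])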

If the input data needs to be supplied dynamically (for example through a placeholder), use make_initializable_iterator():

input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
# TFRecords hold serialized binary records; parse each one with map()
dataset = dataset.map(parser)

# Fetch data through an iterator
iterator = dataset.make_initializable_iterator()
feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
    # The iterator must be initialized first; this is also where the file
    # list is fed in
    sess.run(iterator.initializer, feed_dict={input_files: ['file1', 'file2']})
    # The dataset size is unknown in advance, so loop until the data is
    # exhausted, at which point get_next() raises tf.errors.OutOfRangeError
    while True:
        try:
            sess.run([feat1, feat2])
        except tf.errors.OutOfRangeError:
            break
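
Because the file list is fed through a placeholder, the same pipeline can be re-initialized on a different set of files without rebuilding the graph. A minimal sketch, continuing inside the same "with tf.Session() as sess:" block (the validation file names are hypothetical):

    # Re-initializing the iterator restarts the pipeline on a new file list
    sess.run(iterator.initializer, feed_dict={input_files: ['val_file1', 'val_file2']})
    while True:
        try:
            sess.run([feat1, feat2])
        except tf.errors.OutOfRangeError:
            break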

Some higher-level dataset operations:

dataset.map(func)             # apply func to every element
dataset.shuffle(buffer_size)  # randomly shuffle using a buffer of buffer_size elements
dataset.batch(batch_size)     # combine batch_size elements into one batch
dataset.repeat(N)             # repeat the dataset N times (i.e. N epochs)
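
As a quick standalone sketch of how these operations chain together (using a small in-memory dataset instead of TFRecord files):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(list(range(10)))
dataset = dataset.map(lambda x: x * 2)     # transform every element
dataset = dataset.shuffle(buffer_size=10)  # shuffle within a 10-element buffer
dataset = dataset.batch(4)                 # batches of 4 (the last may be smaller)
dataset = dataset.repeat(2)                # two passes over the data

iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(batch))
        except tf.errors.OutOfRangeError:
            break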

Their use in a full training pipeline is shown in the following example:

import tensorflow as tf

# match_filenames_once resolves the glob patterns into lists of file names
# (it creates local variables, which must be initialized before use)
train_files = tf.train.match_filenames_once('tfrecords/train_file-*')
test_file = tf.train.match_filenames_once('tfrecords/test_file-*')

# Stubs standing in for the real preprocessing, network, and loss functions
def preprocess_for_train(image, height, width, bbox):
    pass

def build_net(input): pass

def calc_loss(logit, label): pass

def parse(record):
    # Each record stores the raw image bytes together with its label and shape
    features = tf.parse_single_example(
        record,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64)
        }
    )

    # decode_raw turns the byte string into a flat uint8 tensor; the stored
    # dimensions then restore the image's shape. set_shape() only annotates
    # static shape information and cannot take tensors, so tf.reshape is
    # used to apply the dynamic dimensions.
    image_data = tf.decode_raw(features['image'], tf.uint8)
    image_data = tf.reshape(
        image_data,
        tf.stack([features['height'], features['width'], features['channels']])
    )
    label = features['label']

    return image_data, label

if __name__ == '__main__':
    image_size = 299
    batch_size = 100
    shuffle_buffer = 10000

    # Read the training TFRecord files
    dataset = tf.data.TFRecordDataset(train_files)
    # Parse each TFRecord into (image, label)
    dataset = dataset.map(parse)
    # Preprocess each image
    dataset = dataset.map(
        lambda image, label: (preprocess_for_train(image, image_size, image_size, None), label)
    )
    # Shuffle the data; shuffle_buffer is the size of the buffer from which
    # elements are randomly drawn
    dataset = dataset.shuffle(shuffle_buffer)
    # Combine batch_size consecutive elements into one batch (without
    # batch(), one element is returned at a time)
    dataset = dataset.batch(batch_size)
    num_epoches = 10
    # Repeat the data num_epoches times; since shuffle() was applied first,
    # the order differs from epoch to epoch
    dataset = dataset.repeat(num_epoches)

    iterator = dataset.make_initializable_iterator()
    image_batch, label_batch = iterator.get_next()

    learning_rate = 0.01
    # Build the network to get the logits
    logit = build_net(image_batch)
    # Compute the loss
    loss = calc_loss(logit, label_batch)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    test_dataset = tf.data.TFRecordDataset(test_file)
    test_dataset = test_dataset.map(parse).map(
        lambda image, label: (tf.image.resize_images(image, [image_size, image_size]), label)
    )
    test_dataset = test_dataset.batch(batch_size)

    test_iterator = test_dataset.make_initializable_iterator()
    test_image_batch, test_label_batch = test_iterator.get_next()

    # Note: calling build_net again creates a second, independent set of
    # variables; a real model should share weights with the training network
    # (see the sketch after this example)
    test_logit = build_net(test_image_batch)
    prediction = tf.argmax(test_logit, 1)

    with tf.Session() as sess:
        # match_filenames_once creates local variables, so local variables
        # must be initialized; the iterator needs iterator.initializer;
        # global variables are initialized as usual
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer(),
                  iterator.initializer])

        # Train until the shuffled/batched/repeated dataset is exhausted
        while True:
            try:
                sess.run(train_step)
            except tf.errors.OutOfRangeError:
                break

        sess.run(test_iterator.initializer)
        test_results = []
        test_labels = []
        while True:
            try:
                pred, label = sess.run([prediction, test_label_batch])
                test_results.extend(pred)
                test_labels.extend(label)
            except tf.errors.OutOfRangeError:
                break

        # Accuracy: the fraction of test predictions that match the labels
        correct = [float(y == y_) for (y, y_) in zip(test_results, test_labels)]
        acc = sum(correct) / len(correct)

        print(acc)
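
A caveat about the example above: calling build_net twice builds two independent sets of variables, so the test network would not actually use the trained weights. A common fix, sketched here under the assumption that build_net creates its variables with tf.get_variable, is to share them through a variable scope:

def build_net(input, reuse=False):
    # Variables created with tf.get_variable inside this scope are shared
    # between the training and test graphs when reuse=True
    with tf.variable_scope('model', reuse=reuse):
        pass  # network definition goes here

logit = build_net(image_batch)
test_logit = build_net(test_image_batch, reuse=True)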

Reposted from blog.csdn.net/a13602955218/article/details/80766292