TensorFlow training pipeline: iterating over data with tf.data

A TensorFlow training pipeline that reads its data from CSV (or plain-text) files.

The main parts of the code are:

  1. Create the dataset: create_dataset
    1. Operation applied to each line the dataset reads: decode_line
  2. Create the training and validation iterators: create_iterator
  3. Create the prediction iterator: create_iterator_for_predict
  4. Data preprocessing
    1. Preprocess features and labels together: reprocessing_feature_label
    2. Preprocess features only (features can be preprocessed on their own without disturbing the feature–label correspondence): reprocessing_feature
  5. Test iterating over the data during training: train_validation_test
  6. Test iterating over the data during prediction: predict_test

# encoding: utf-8
"""
@author: mry
@contact: [email protected]
@time: 2020/03/02 23:50
@file: dataset_FromStringHandle.py
@desc: 
        训练集 使用 make_one_shot_iterator 
            并不是只对数据集循环一次 而是循环dataset指定的次数
            如果创建dataset时不指定循环次数,那么就可以无限循环
        验证集 使用 make_initializable_iterator 

        测试集 使用 make_one_shot_iterator 
好处: 验证集 不会打断训练集 
"""

import pandas as pd 
import numpy as np 
import tensorflow as tf 
import time


# Per-line parser used inside the dataset pipeline: declares the type of every
# CSV column and splits the parsed columns into features and labels.
def decode_line(line,len=2):
    """Parse one CSV text line into a ``(features_0, labels)`` pair.

    Args:
        line: string tensor holding one CSV row.
        len: number of columns in the row. Every column is parsed as a float
            column with default 0.0. (The name shadows the builtin ``len`` but
            is kept because callers pass it by keyword.)

    Returns:
        Tuple ``(features_0, labels)`` taken from columns 0 and 1; any further
        columns are parsed but not returned.
    """
    column_defaults = [[0.0] for _ in range(len)]
    parsed_columns = tf.decode_csv(line, record_defaults=column_defaults)
    return parsed_columns[0], parsed_columns[1]


# Build the tf.data pipeline: which files to read, batch size, number of
# passes over the data, and whether to shuffle.
def create_dataset(filename,batch_size=32,is_shuffle=False,n_repeats=0,
                   shuffle_buffer_size=10000):
    """Create a batched ``(features_0, labels)`` dataset from CSV files.

    Args:
        filename: a single CSV path or a list of CSV paths. The first line of
            every file is treated as a header and skipped.
        batch_size: number of rows per batch.
        is_shuffle: whether to shuffle rows before batching.
        n_repeats: number of passes over the data; ``0`` (the default) or any
            non-positive value repeats indefinitely.
        shuffle_buffer_size: buffer size passed to ``Dataset.shuffle`` when
            ``is_shuffle`` is true (10000 was the previously hard-coded value).

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(features_0, labels)``.
    """
    # from_tensor_slices slices along the first dimension, so a bare path
    # string (a scalar) must be wrapped in a list first.
    if isinstance(filename, (str, bytes)):
        filename = [filename]
    dataset = tf.data.Dataset.from_tensor_slices(filename)
    # One TextLineDataset per file (rather than one over all files) so that
    # skip(1) removes the header line of *every* file, not only the first.
    dataset = dataset.flat_map(lambda f: tf.data.TextLineDataset(f).skip(1))

    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)
    else:
        dataset = dataset.repeat()
    dataset = dataset.map(lambda x: decode_line(x, len=2))

    if is_shuffle:
        # NOTE(review): shuffling after repeat() mixes rows across epoch
        # boundaries; move shuffle before repeat if strict epochs matter.
        dataset = dataset.shuffle(shuffle_buffer_size)
    return dataset.batch(batch_size)


# Build the training + validation iterators and a single handle-switched
# next_element that reads from whichever iterator's handle is fed.
def create_iterator(training_filenames=None,validation_filenames=None,
    handle=None, training_batch_size=10, validation_batch_size=5):
    '''
    Create training and validation iterators that share one ``next_element``.

    ``next_element`` is driven by the string ``handle`` placeholder: feed it
    ``sess.run(training_iterator.string_handle())`` to read a training batch,
    or the validation iterator's handle to read a validation batch.

    Args:
        training_filenames: CSV path(s) for the training set.
        validation_filenames: CSV path(s) for the validation set.
        handle: scalar ``tf.string`` placeholder that selects the iterator.
            Pass your own placeholder so that you can feed it. If omitted, one
            is created in the current default graph, but it is not returned.
        training_batch_size: batch size of the training dataset.
        validation_batch_size: batch size of the validation dataset.

    Returns:
        next_element: ``(features_0, labels)`` tensors for the next batch.
        training_iterator: one-shot iterator over the training set, which
            create_dataset repeats indefinitely.
        validation_iterator: initializable iterator over the validation set;
            run its ``initializer`` to restart it from the beginning.
    '''
    # The placeholder used to be built in the default-argument expression,
    # i.e. once at import time, in whatever graph was default then. Building
    # it here puts it in the graph the caller is currently constructing.
    if handle is None:
        handle = tf.placeholder(tf.string, shape=[])

    # Training dataset
    training_dataset = create_dataset(training_filenames,
                                      batch_size=training_batch_size,
                                      is_shuffle=False)
    # Validation dataset
    validation_dataset = create_dataset(validation_filenames,
                                        batch_size=validation_batch_size,
                                        is_shuffle=False)
    training_iterator = training_dataset.make_one_shot_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()

    # A handle-controlled iterator: it reads from whichever concrete iterator
    # the fed string handle refers to (training or validation).
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_dataset.output_types, training_dataset.output_shapes)
    next_element = iterator.get_next()

    return next_element,training_iterator,validation_iterator

# Joint preprocessing of features and labels.
def reprocessing_feature_label(features_0,labels):
    """Reshape a batch of features and labels.

    Args:
        features_0: feature tensor for a batch (presumably one scalar feature
            per example, as produced by decode_line — TODO confirm for other
            inputs).
        labels: label tensor for a batch.

    Returns:
        ``(features_0, labels)``: the features reshaped to ``[-1, 1]`` and the
        labels flattened to ``[-1]``.
    """
    # Reshape the features themselves; the previous code reshaped ``labels``
    # here, silently replacing the features with the labels.
    features_0 = tf.reshape(features_0, [-1, 1])
    labels = tf.reshape(labels, [-1])
    return features_0, labels

# Feature-only preprocessing: it can be applied to the features alone without
# disturbing the feature/label pairing.
def reprocessing_feature(features_0):
    """Return ``features_0`` reshaped to ``[-1, 1, 1]``.

    The leading dimension is inferred, so N feature values become an
    ``[N, 1, 1]`` tensor.
    """
    target_shape = [-1, 1, 1]
    reshaped = tf.reshape(features_0, target_shape)
    return reshaped


# Build the iterator used for evaluation / prediction.
def create_iterator_for_predict(filenames=None, batch_size=5):
    '''
    Create a single-pass iterator over an evaluation or prediction set.

    The CSV must contain a labels column even for prediction. If real labels
    are unknown, add a placeholder labels column beforehand, because
    create_dataset always decodes two columns per line.

    Args:
        filenames: CSV path(s) to read; the header line of each is skipped.
        batch_size: rows per batch (default 5, the previously fixed value).

    Returns:
        next_element: ``(features_0, labels)`` tensors. The dataset is read
        exactly once, so running them after the last batch raises
        ``tf.errors.OutOfRangeError``.
    '''
    predict_dataset = create_dataset(filenames, batch_size=batch_size,
                                     is_shuffle=False, n_repeats=1)
    predict_iterator = predict_dataset.make_one_shot_iterator()
    return predict_iterator.get_next()


# Smoke test of the training loop: alternate training batches and validation
# passes through the same handle-switched next_element.
def train_validation_test():
    """Alternate training and validation reads, printing every batch.

    Runs 50 rounds. Each round reads 10 training batches (the one-shot
    training iterator continues where the previous round stopped), then
    re-initializes the validation iterator and reads 5 validation batches
    from its beginning.
    """
    training_filenames = ['./t1.csv','./t3.csv']
    validation_filenames = ['./t2.csv']

    handle = tf.placeholder(tf.string, shape=[])
    next_element, training_iterator, validation_iterator = create_iterator(
        training_filenames, validation_filenames, handle)

    features_0, labels = next_element
    # Alternatively, preprocess both tensors:
    #   features_0, labels = reprocessing_feature_label(features_0, labels)
    features_0 = reprocessing_feature(features_0)

    # The context manager guarantees the session is closed when we are done.
    with tf.Session() as sess:
        # Each iterator's string handle is what gets fed into ``handle`` to
        # choose which dataset next_element reads from.
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())

        for _ in range(50):
            # Training batches: the one-shot iterator is never re-initialized.
            for _ in range(10):
                f_0, l = sess.run([features_0, labels],
                                  feed_dict={handle: training_handle})
                print(f_0)
                print(l)
            print('>'*20)
            # Validation pass: restart the validation iterator each round.
            sess.run(validation_iterator.initializer)
            for _ in range(5):
                f_0, l = sess.run([features_0, labels],
                                  feed_dict={handle: validation_handle})
                print(f_0)
                print(l)
            print('>'*50)
            time.sleep(0.5)


# Smoke test of prediction: read the prediction set once, batch by batch.
def predict_test():
    """Read ``./t2.csv`` once through the prediction iterator in its own graph.

    Returns:
        list: every label value read, in file order, as Python floats.
    """
    filenames = ['./t2.csv']
    g_xyk_predict = tf.Graph()
    with g_xyk_predict.as_default():
        next_element = create_iterator_for_predict(filenames)
        features_0, labels = next_element
        features_0, labels = reprocessing_feature_label(features_0, labels)
        # In a real pipeline the model goes here,
        # e.g. prediction = rnn(features_0)

    result_list = []
    # Bind the session explicitly to the prediction graph and close it on exit.
    with tf.Session(graph=g_xyk_predict) as sess:
        # In a real pipeline, restore the trained model here
        # (from a checkpoint or a .pb file).
        while True:
            try:
                l = sess.run(labels)
            except tf.errors.OutOfRangeError:
                # The single-pass prediction dataset is exhausted.
                break
            print(l)
            print('>'*30)
            result_list += l.tolist()
    print(result_list)
    return result_list


if __name__ =='__main__':
    # Run the training/validation demo first, then the prediction demo.
    for demo in (train_validation_test, predict_test):
        demo()

Input data

./t1.csv ./t2.csv ./t3.csv

Each file has two columns, a feature and a label (the column names do not matter; the header row is skipped).

[The original screenshot failed to upload, so the data is shown as a table instead.]

features label
0 0
1 1
2 2
3 3

Output

[[[0.]]
 [[1.]]
 [[2.]]
 [[3.]]
 [[4.]]
 [[5.]]
 [[6.]]
 [[7.]]
 [[8.]]
 [[9.]]]
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[[[10.]]
 [[11.]]
 [[12.]]
 [[13.]]
 [[14.]]
 [[15.]]
 [[16.]]
 [[17.]]
....
[0. 1. 2. 3. 4.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[5. 6. 7. 8. 9.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[10. 11. 12. 13. 14.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[15. 16. 17. 18. 19.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[20. 21. 22. 23. 24.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[25. 26. 27. 28. 29.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[30. 31. 32. 33. 34.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[35. 36. 37. 38. 39.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
[40. 41. 42. 43. 44.]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
....

 

Published an original article · won praise 0 · Views 11

Guess you like

Origin blog.csdn.net/weixin_43854495/article/details/104815418