tensorflow中多线程批量读取csv文件

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/XUEER88888888888888/article/details/86665015
import tensorflow as tf
import os


def csvread(filelist):
    """
    读取csv文件
    :param filelist: 文件路f径+名字的列表
    :return: 读取的内容
    """
    # 1.构造文件的队列
    file_queue = tf.train.string_input_producer(filelist)

    # 2.构造csv阅读器读取队列数据(按一行)
    reader = tf.TextLineReader()
    key,value = reader.read(file_queue)

    print("key的值:",key)
    print("value的值:",value)

    #3.对每行内容解码
    #
    records = [["None"],["None"]]
    example,lable = tf.decode_csv(value ,record_defaults=records)
    #print(example)
    #print(lable)

    #4.想要读取多个数据,就需要批处理
    example_batch,lable_batch = tf.train.batch([example,lable],batch_size=9,num_threads=1,capacity=9)

    #print(example_batch)
    return   example_batch,lable_batch


#批处理大小,跟队列,数据的数量没有影响,只决定 这批次处理多少数据


if __name__ == "__main__":
    # 1.找到文件,放入列表  路径+名字  ->列表当中
    file_name = os.listdir("./data/csvdata/")

    filelist = [os.path.join("./data/csvdata/",file) for file in file_name ]
    example,lable = csvread(filelist)

   #开启会话运行结果
    with tf.Session() as sess:
       #定义一个线程协调器
       coord = tf.train.Coordinator()

       #开启读文件的线程
       threads = tf.train.start_queue_runners(sess,coord=coord)

       #打印读取的内容
       print(sess.run([example,lable]))

       #回收子线程
       coord.request_stop()
       coord.join(threads)









猜你喜欢

转载自blog.csdn.net/XUEER88888888888888/article/details/86665015