# TensorFlow 批量读取多个 CSV 文件

# TensorFlow 批量读取多个 CSV 文件

#!/usr/bin/python
# -*- coding:utf-8 -*-
import tensorflow as tf
import os

def csvfile(fileist, batch_size=9, num_threads=1, capacity=9):
    """Build a TF1 queue pipeline that yields (example, label) batches from CSV files.

    Each line of every file is decoded into two string columns; a missing
    field falls back to the string 'None'.

    Args:
        fileist: list of CSV file paths to read. (The parameter keeps its
            original spelling so existing keyword callers still work.)
        batch_size: rows per batch returned by tf.train.batch (default 9).
        num_threads: enqueue threads used by tf.train.batch (default 1).
        capacity: maximum number of elements held in the batch queue
            (default 9).

    Returns:
        (example_batch, label_batch) tensors. Evaluating them requires the
        queue runners to be started in the session
        (tf.train.start_queue_runners).
    """
    # Bug fix: the original body read the module-level global `filelist`
    # instead of this parameter, so it only worked when run as __main__.
    file_queue = tf.train.string_input_producer(fileist)
    reader = tf.TextLineReader()
    # Each read() produces one line of text; the key output is not needed.
    _, value = reader.read(file_queue)
    # Two string-typed columns; 'None' is used when a field is empty.
    record_defaults = [['None'], ['None']]
    example, label = tf.decode_csv(value, record_defaults=record_defaults)
    example_batch, label_batch = tf.train.batch(
        [example, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
    )
    return example_batch, label_batch

if __name__ == '__main__':
    # Directory that holds the CSV files (relative to the current working
    # directory). The original listed "../data" but then joined the names onto
    # an unrelated hard-coded absolute Windows path
    # ("F:/window7x64/.../deeplearning/data/"); using a single directory keeps
    # the listed names and the paths actually read consistent.
    data_dir = "../data"
    # os.listdir also returns subdirectories and non-CSV files, which the CSV
    # reader cannot parse, so keep only *.csv entries (sorted for a stable order).
    listname = sorted(name for name in os.listdir(data_dir)
                      if name.lower().endswith(".csv"))
    print(listname)
    if not listname:
        raise SystemExit("No .csv files found in {}".format(data_dir))
    # Kept as the module-level name `filelist` for compatibility with code
    # that refers to it.
    filelist = [os.path.join(data_dir, name) for name in listname]
    print(filelist)
    example_batch, label_batch = csvfile(filelist)

    with tf.Session() as sess:
        # Coordinator used to stop the reader threads together.
        coord = tf.train.Coordinator()
        # Start the queue-runner threads that feed the file/batch queues.
        threads = tf.train.start_queue_runners(sess, coord=coord)
        try:
            print(sess.run([example_batch, label_batch]))
        finally:
            # Always shut the reader threads down, even if sess.run raises.
            coord.request_stop()
            coord.join(threads)

# 猜你喜欢

# 转载自 https://blog.csdn.net/u011243684/article/details/85224061