tensorflow csv 读取数据

import tensorflow as tf


def read_file():

    # 定义文件位置
    file_list= ["./data.csv"]

    # 创建队列
    file_queue = tf.train.string_input_producer(file_list)

    # 创建阅读器
    reader = tf.TextLineReader()

    # 读取数据  key 文件的名字 value 值 每次读取的时候默认只读取一行数据
    key, value = reader.read(file_queue)


    """
    record_defaults  
        1. 指定有几列
        2. 指定数据的类型
        3. 指定默认值
    """
    record_defaults = [["null"], ["null"], ["null"], ["null"], ["null"]]
    name1, name2, name3, name4, name5 = tf.decode_csv(value, record_defaults=record_defaults)

    # 通过tf.train.batch来批处理数据
    name1_batch, name2_batch, name3_batch, name4_batch, name5_batch = tf.train.batch(
                                                [name1, name2, name3, name4, name5],
                                                batch_size=3,   # 每次读取几个
                                                num_threads=1,  # 开启几个线程
                                                capacity=3)     # 队列的大小

    return name1_batch, name2_batch, name3_batch, name4_batch, name5_batch


if __name__ == "__main__":

    name1_batch, name2_batch, name3_batch, name4_batch, name5_batch = read_file()

    with tf.Session() as sess:
        # 线程协调员
        coord = tf.train.Coordinator()

        # 启动工作线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        print(sess.run(name1_batch))
        print("\n")
        print(sess.run(name2_batch))
        print("\n")
        print(sess.run(name3_batch))
        print("\n")
        print(sess.run(name4_batch))
        print("\n")
        print(sess.run(name5_batch))

        # 关闭线程
        coord.request_stop()

        # 回收线程
        coord.join(threads)



猜你喜欢

转载自blog.csdn.net/weixin_40639095/article/details/88988527