Getting Started with TensorFlow: Reading CSV Files

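1. An error in the TensorFlow code for reading .csv files

In machine learning programming, data input has always been a sticking point for me. In my early studies I relied on the MNIST handwritten digit project, which ships with its own data input code. As an aside, MNIST is a great starting point for newcomers to programming deep and convolutional neural networks, because you don't need to think about how to feed in the data; you only rewrite the core training code.

But data input and preprocessing are something every machine learning enthusiast has to face sooner or later. Enter a single machine learning competition and you will find that knowing how to read .csv files is essential.

What is a .csv file

A quick look at the concept: comma-separated values (CSV, sometimes also called character-separated values, because the delimiter need not be a comma) is a file format that stores tabular data (numbers and text) as plain text.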
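
To make the definition concrete, here is a minimal sketch that parses CSV text with Python's standard csv module rather than TensorFlow. It assumes Python 3, and the table contents are made up for illustration.

```python
import csv
import io

# A small made-up table stored as plain text, one row per line.
text = "1,2,3,4\n2,3,43,5\n4,5,56,6\n"

# csv.reader yields one list of strings per line, split on the delimiter.
rows = [[int(cell) for cell in row] for row in csv.reader(io.StringIO(text))]
print(rows)  # [[1, 2, 3, 4], [2, 3, 43, 5], [4, 5, 56, 6]]

# The same parser handles other separators through the `delimiter` argument.
spaced = list(csv.reader(io.StringIO("1 2 3 4\n"), delimiter=" "))
print(spaced)  # [['1', '2', '3', '4']]
```

Every cell comes back as a string, so converting to a numeric type is a separate step.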

Let's not dwell on this. Here is the code from the official documentation (slightly adjusted for the .csv files I created):

import tensorflow as tf

filename_queue = tf.train.string_input_producer(["file3.csv", "file4.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1]]
col1, col2, col3, col4 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.concat(0, [col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  #for i in range(1200):
    # Retrieve a single instance:
    #example, label = sess.run([features])
  print(sess.run([features]))

  coord.request_stop()
  coord.join(threads)

The .csv files I created contain four columns of integers, which is why the code above was adjusted slightly. Here are their contents:
file3.csv

1,2,3,4
2,3,43,5
4,5,56,6
3,4,5,4

file4.csv

1,2,3,4
2,3,4,45
2,3,3,4
3,45,5,3

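Both files have four rows of four values each. Since the definition of CSV mentions comma-separated values, I separated the values on each line with commas. I did try separating them with spaces, but the code raised an error. That is expected: tf.decode_csv splits on commas unless its field_delim argument says otherwise.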
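
As a rough mental model of what record_defaults and the delimiter do, here is a pure-Python sketch. The function decode_line is hypothetical and is not TensorFlow's implementation; it ignores quoting and assumes each default is a one-element list, as in the code above.

```python
def decode_line(line, record_defaults, field_delim=","):
    """Rough pure-Python model of CSV decoding with per-column defaults."""
    fields = line.rstrip("\n").split(field_delim)
    if len(fields) != len(record_defaults):
        raise ValueError(
            "expected %d fields, got %d" % (len(record_defaults), len(fields)))
    decoded = []
    for field, (default,) in zip(fields, record_defaults):
        # The default's Python type decides how the field is converted;
        # an empty field falls back to the default itself.
        decoded.append(type(default)(field) if field != "" else default)
    return decoded

defaults = [[1], [1], [1], [1]]
print(decode_line("2,3,43,5", defaults))                   # [2, 3, 43, 5]
print(decode_line("2,,43,5", defaults))                    # [2, 1, 43, 5]
print(decode_line("2 3 43 5", defaults, field_delim=" "))  # [2, 3, 43, 5]
```

In this model, a space-separated line read with the default comma delimiter is seen as a single field and rejected, which is in the same spirit as the failure I hit.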

The output is as follows:

/home/umbrella/.pyenv/versions/2.7.15/bin/python /home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv1.py
Traceback (most recent call last):
  File "/home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv1.py", line 13, in <module>
    features = tf.concat(0, [col1, col2, col3, col4])
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1187, in concat
    tensor_shape.scalar())
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/framework/tensor_shape.py", line 844, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (4,) and () are incompatible

Process finished with exit code 1

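I glanced at my watch: the supermarket closes in half an hour. Given my habit of refusing to stop until a problem is solved, if I can't fix this in time I'll have to buy my water from the vending machine in the dormitory compound.


2. The correct code

I made it to the supermarket and bought my water.

I found two versions of the tf.concat signature online:
The wrong version (possibly because it is outdated):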

tf.concat(concat_dim, values, name='concat')

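The correct version:

tf.concat(values, concat_dim, name='concat')

concat_dim specifies the dimension along which to concatenate, counting from 0: 0 is the first dimension, 1 is the second, and so on. (TensorFlow 1.0 and later name this parameter axis.)

The code in the official documentation uses the former, passing concat_dim = 0 as the first argument to concat, but that is wrong for the version installed here! With the newer signature, the 0 is read as values and the four-element list as the dimension, which must be a scalar; hence the error "Shapes (4,) and () are incompatible".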
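
To illustrate what concatenating along a dimension means, and why scalars have to be wrapped in a list first, here is a small pure-Python sketch on nested lists. The helpers concat and stack are hypothetical stand-ins written for this post; they only mimic the behaviour of tf.concat and tf.stack on simple inputs.

```python
def concat(values, axis):
    """Join nested lists along `axis`, where axis 0 is the outermost level."""
    if axis == 0:
        joined = []
        for value in values:
            if not isinstance(value, list):
                raise ValueError("cannot concatenate scalars; wrap each one in a list")
            joined.extend(value)
        return joined
    # For a deeper axis, pair up matching rows and join them one level down.
    return [concat(list(rows), axis - 1) for rows in zip(*values)]

def stack(values):
    """Place the inputs side by side along a new outermost dimension."""
    return list(values)

col1, col2, col3, col4 = 1, 2, 3, 4  # scalars, like one decoded CSV row

print(concat([[col1], [col2], [col3], [col4]], 0))  # [1, 2, 3, 4]
print(stack([col1, col2, col3, col4]))              # [1, 2, 3, 4]
print(concat([[[1, 2]], [[3, 4]]], 0))              # [[1, 2], [3, 4]]
print(concat([[[1, 2]], [[3, 4]]], 1))              # [[1, 2, 3, 4]]
```

Concatenation joins along a dimension that already exists, so scalars must first be given one; stacking creates the new dimension itself, so it accepts the scalars directly.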

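Modifying the code accordingly gives the following. Note that each column is wrapped in a list, because decode_csv returns scalars and tf.concat cannot concatenate scalars: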

import tensorflow as tf

filename_queue = tf.train.string_input_producer(["file3.csv", "file4.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

record_defaults = [[1], [1], [1], [1]]
col1, col2, col3, col4 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.concat([[col1], [col2], [col3], [col4]], 0)
#features = tf.stack([col1, col2, col3, col4])  # alternative: comment out the line above and use this one


init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    print(sess.run([features]))
#   for i in range(9):
#       print(sess.run([features]))

    coord.request_stop()
    coord.join(threads)

The output is as follows:

/home/umbrella/.pyenv/versions/2.7.15/bin/python /home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv2.py
2018-07-18 00:29:40.698139: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
[array([1, 2, 3, 4], dtype=int32)]

Process finished with exit code 0

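As you can see, the first line of file3.csv was printed. To print all the data, replace the print(...) line with the two commented-out lines near the end of the script. The result is: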

/home/umbrella/.pyenv/versions/2.7.15/bin/python /home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv2.py
2018-07-18 00:33:10.401338: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
[array([1, 2, 3, 4], dtype=int32)]
[array([ 2,  3, 43,  5], dtype=int32)]
[array([ 4,  5, 56,  6], dtype=int32)]
[array([3, 4, 5, 4], dtype=int32)]
[array([1, 2, 3, 4], dtype=int32)]
[array([ 2,  3,  4, 45], dtype=int32)]
[array([2, 3, 3, 4], dtype=int32)]
[array([ 3, 45,  5,  3], dtype=int32)]
2018-07-18 00:33:10.419660: W tensorflow/core/kernels/queue_base.cc:277] _0_input_producer: Skipping cancelled enqueue attempt with queue not closed
Traceback (most recent call last):
  File "/home/umbrella/桌面/study_readdata1/day01/readcsv/readcsv2.py", line 26, in <module>
    print(sess.run([features]))
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/home/umbrella/.pyenv/versions/2.7.15/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: <exception str() failed>

Process finished with exit code 1

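As you can see, all eight rows across the two files were printed, but an error is raised on the ninth read. Some runs print only the contents of file3.csv. I have not resolved this yet (see "Queues in TensorFlow": https://blog.csdn.net/huachao1001/article/details/78083125 ).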
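
One thing worth knowing here is that tf.train.string_input_producer takes num_epochs (default None, meaning it cycles forever) and shuffle (default True, meaning the file order can change between runs). The sketch below is a pure-Python model of that behaviour, assuming in-memory stand-ins for the two files. filename_producer and read_rows are hypothetical helpers, not TensorFlow code, and the model does not explain the InvalidArgumentError above.

```python
import random

def filename_producer(filenames, num_epochs=None, shuffle=True, seed=None):
    """Yield filenames epoch by epoch; endless when num_epochs is None."""
    rng = random.Random(seed)
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        names = list(filenames)
        if shuffle:
            rng.shuffle(names)
        for name in names:
            yield name
        epoch += 1

# Hypothetical in-memory stand-ins for file3.csv and file4.csv.
FILES = {
    "file3.csv": ["1,2,3,4", "2,3,43,5", "4,5,56,6", "3,4,5,4"],
    "file4.csv": ["1,2,3,4", "2,3,4,45", "2,3,3,4", "3,45,5,3"],
}

def read_rows(producer):
    """Read every line of every file the producer hands out."""
    for name in producer:
        for line in FILES[name]:
            yield [int(x) for x in line.split(",")]

# One epoch, no shuffling: the loop ends by itself after eight rows,
# so no hard-coded iteration count is needed.
producer = filename_producer(["file3.csv", "file4.csv"], num_epochs=1, shuffle=False)
rows = list(read_rows(producer))
print(len(rows))  # 8
print(rows[0])    # [1, 2, 3, 4]
print(rows[-1])   # [3, 45, 5, 3]
```

In TensorFlow itself, the usual counterpart is to pass num_epochs=1, run tf.local_variables_initializer() in the session, and loop until tf.errors.OutOfRangeError is raised instead of hard-coding the number of reads. I have not verified that this removes the error in my setup.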

Reposted from blog.csdn.net/umbrellalalalala/article/details/81091175