用 TensorFlow 读取目录下的图片：传入的参数必须是由图片文件路径组成的 list，然后用 tf.convert_to_tensor 转成 tf.string 类型的张量。
# Turn the Python list of image file paths into a string tensor.
path = tf.convert_to_tensor(value=path, dtype=tf.string)
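通过 tf.train.string_input_producer 生成文件名队列：shuffle 表示是否打乱图片的顺序，num_epochs 表示每张图片被加载几次。训练过程中一般有 epoch 的概念，即整个训练集被计算几轮。注意：设置 num_epochs 后 TensorFlow 会用一个局部变量来计数，所以在启动队列线程之前必须先运行 tf.local_variables_initializer()。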
# Filename queue: files come out in shuffled order, each one num_epochs=2 times.
file_queue = tf.train.string_input_producer(
    string_tensor=path,
    num_epochs=2,
    shuffle=True,
)
接下来使用 tf.WholeFileReader 创建 image_reader，调用它的 read 函数从文件队列中读取图片内容。每次 read 返回一对 (key, value)：key 是文件名，value 是这一个文件的完整原始内容（即一张图片的编码字节，而不是全部训练图像数据）。
# WholeFileReader yields (file name, entire file contents) for each dequeued path.
image_reader = tf.WholeFileReader()
key, image = image_reader.read(queue=file_queue)
接下来解码图像数据：JPEG 图片用 tf.image.decode_jpeg，其他格式可根据图片类型选用 decode_png、decode_bmp 等。
# Decode the raw JPEG bytes into an image tensor.
image = tf.image.decode_jpeg(contents=image)
最后创建会话（Session），启动队列线程，循环获取并显示图像数据。当文件队列把 num_epochs 轮数据都输出完后，会抛出 tf.errors.OutOfRangeError，循环随之结束。
with tf.Session() as sess:
    # num_epochs is counted in a local variable, so local variables must be
    # initialized before the queue runner threads start.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            plt.figure()  # must be called; a bare `plt.figure` does nothing
            # Each eval() dequeues and decodes the next file.
            plt.imshow(image.eval())
            plt.show()
    except tf.errors.OutOfRangeError:
        # Raised once the filename queue has produced all of its epochs.
        print('done')
    finally:
        coord.request_stop()
    # Wait for the queue runner threads to finish after stop was requested.
    coord.join(threads)
图片显示使用 matplotlib.pyplot 模块，在 Debian/Ubuntu（Python 2）上的安装步骤如下：
sudo apt-get install python-tk
sudo pip2 install -i https://pypi.tuna.tsinghua.edu.cn/simple matplotlib
运行后显示的图像如下：
下面贴上调试后的完整代码：
#-*- encoding:utf-8 -*-
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
# Report the TensorFlow version; the queue-based input pipeline used below
# (tf.train.string_input_producer, start_queue_runners) is a TF 1.x API.
print(tf.__version__)
def readimg(file_path):
    """Read a single JPEG file, print debug info about it and display it.

    Args:
        file_path: path of a JPEG image file.
    """
    # Open in binary mode: JPEG data is raw bytes, and under Python 3 a
    # text-mode read would try to decode it as UTF-8. The context manager
    # also makes sure the file handle is closed.
    with tf.gfile.GFile(file_path, 'rb') as f:
        image_raw = f.read()
    img = tf.image.decode_jpeg(image_raw)  # Tensor
    with tf.Session() as sess:
        print(type(image_raw))
        print(type(img))
        # Run the decode op once and reuse the resulting array instead of
        # evaluating the graph a second time for display.
        img_value = sess.run(img)
        print(img_value)
        plt.figure(1)
        plt.imshow(img_value)
        plt.show()
def file_name(file_dir):
    """Walk *file_dir* recursively and print each directory's contents.

    For every visited directory three lines are printed: the directory
    path, the list of its sub-directory names, and the list of its file
    names (in the order produced by os.walk).
    """
    for current, subdirs, names in os.walk(file_dir):
        for item in (current, subdirs, names):
            print(item)
def file_name2(file_dir, ext='.jpg'):
    """Return the paths of all files under *file_dir* with extension *ext*.

    The directory tree is walked recursively with os.walk. The extension
    comparison is case-insensitive, so "photo.JPG" matches the default
    '.jpg'. Each visited directory is printed as a debug trace.

    Args:
        file_dir: root directory to search.
        ext: extension to keep, including the leading dot (default '.jpg').

    Returns:
        List of matching paths, each joined onto the directory it was found in.
        Empty if nothing matches or *file_dir* does not exist.
    """
    wanted = ext.lower()
    matches = []
    for root, _dirs, files in os.walk(file_dir):
        # `name` rather than `file`, to avoid shadowing the Python 2 builtin.
        matches.extend(
            os.path.join(root, name)
            for name in files
            if os.path.splitext(name)[1].lower() == wanted
        )
        # Function-call form works on both Python 2 and Python 3.
        print(root)
    return matches
def readimg2(path, shuffle=True, num_epochs=2):
    """Display JPEG images from *path* using a TF 1.x queue-based pipeline.

    Args:
        path: list of JPEG file paths (e.g. the result of file_name2).
        shuffle: whether the filename queue shuffles the file order.
        num_epochs: how many times each file is produced; after that the
            queue raises OutOfRangeError and the display loop stops.

    Raises:
        ValueError: if *path* is empty.
    """
    # Fail early with a clear message; an empty list would otherwise only
    # surface as an assertion error inside a queue runner thread.
    if len(path) == 0:
        raise ValueError('readimg2: path must contain at least one image file')
    path = tf.convert_to_tensor(path, dtype=tf.string)
    file_queue = tf.train.string_input_producer(
        path, shuffle=shuffle, num_epochs=num_epochs)
    image_reader = tf.WholeFileReader()
    # read() returns (file name, raw contents of one file); only the contents are used.
    _, image = image_reader.read(file_queue)
    image = tf.image.decode_jpeg(image)
    with tf.Session() as sess:
        # num_epochs is counted in a local variable, which must be initialized
        # before the queue runner threads start.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                plt.figure()  # must be called; a bare `plt.figure` does nothing
                # Each eval() dequeues and decodes the next file.
                plt.imshow(image.eval())
                plt.show()
        except tf.errors.OutOfRangeError:
            # Raised once all num_epochs passes over the files are consumed.
            print('done')
        finally:
            coord.request_stop()
        coord.join(threads)
def load_img(path_queue, height=224, width=224):
    """Read the next file from *path_queue* and decode it as a 3-channel JPEG.

    Args:
        path_queue: filename queue to read from (e.g. the output of
            tf.train.string_input_producer).
        height: expected image height in pixels (default 224).
        width: expected image width in pixels (default 224).

    Returns:
        Image tensor reshaped to (height, width, 3).

    Note:
        tf.reshape does NOT resize the image. It only reinterprets the
        decoded pixels with the given static shape. Every image must already
        be exactly height x width. Any other size fails at evaluation time
        with an InvalidArgumentError, or is silently scrambled if the pixel
        count happens to match.
    """
    reader = tf.WholeFileReader()
    _, value = reader.read(path_queue)  # (file name, raw file contents)
    img = tf.image.decode_jpeg(value, channels=3)
    img = tf.reshape(img, shape=(height, width, 3))
    return img
import sys

# Directory to scan: first command-line argument, or the original demo path.
image_dir = sys.argv[1] if len(sys.argv) > 1 else '/home/jyf/jyf/python/loadimage'
path = file_name2(image_dir)
# Without any image, path[0] below would raise an IndexError.
if not path:
    raise SystemExit('No .jpg files found under %s' % image_dir)
# readimg(path[0])  # alternative: show only the first image
print(path[0])
readimg2(path)