关于环境
- python 3.6.5
- tensorflow 1.14.0
- numpy 1.16.0
1.通过文件名读取数据的小demo
import tensorflow as tf
# print(tf.__version__)

# Demo: feed filenames (not file contents) through an input queue.
# Each dequeue yields one (image filename, label) pair.
images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
labels = [1, 2, 3, 4]

# Front end: slice_input_producer pairs up the two lists and builds a
# queue of single samples. num_epochs=2 emits every sample twice;
# shuffle=True randomizes the order.
[images, labels] = tf.train.slice_input_producer([images, labels],
                                                 num_epochs=2,
                                                 shuffle=True)

# Back end: nothing runs until the session executes the ops.
with tf.Session() as sess:
    # num_epochs is backed by a *local* variable, so local variables
    # must be initialized before the queue can produce anything.
    sess.run(tf.local_variables_initializer())
    # Start the queue-filling threads under a coordinator so they can
    # be stopped and joined cleanly when we are done.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # 4 samples x 2 epochs = exactly 8 dequeues.
    for _ in range(8):
        # Each run fetches the next (filename, label) pair; the
        # filename could then be handed to a file-reading op.
        print(sess.run([images, labels]))
    coord.request_stop()
    coord.join(threads)
2.通过文件路径真正读取文件数据
import tensorflow as tf

# Demo: read the actual file contents through a filename queue.
filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']

# string_input_producer builds a queue of filename strings
# (slice_input_producer, by contrast, yields tensor slices).
file_queue = tf.train.string_input_producer(filename,
                                            shuffle=True,
                                            num_epochs=2)
# WholeFileReader returns (filename, entire file contents) per read.
reader = tf.WholeFileReader()
key, value = reader.read(file_queue)

with tf.Session() as sess:
    # num_epochs is stored in a local variable; initialize locals first.
    sess.run(tf.local_variables_initializer())
    # Launch the queue-filling threads under a coordinator so they can
    # be stopped and joined cleanly.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # 3 files x 2 epochs = exactly 6 reads.
    for _ in range(6):
        print(sess.run([key, value]))
    coord.request_stop()
    coord.join(threads)
3.通过上一节打包的函数读取数据
import glob
import os
import pickle
import sys
import tarfile
import urllib
import urllib.request

import cv2
import numpy as np
def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Downloads the `tarball_url` and uncompresses it locally.

    Args:
      tarball_url: The URL of a gzip-compressed tarball file.
      dataset_dir: The directory where the downloaded archive is stored
        and into which it is extracted.
    """
    filename = tarball_url.split('/')[-1]
    # Make sure the target directory exists before writing into it.
    os.makedirs(dataset_dir, exist_ok=True)
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        # urlretrieve reporthook: print an in-place percentage bar.
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (
            filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

    # Requires `import urllib.request`; `import urllib` alone does not
    # guarantee the submodule is loaded in Python 3.
    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Extract in a context manager so the archive handle is closed.
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dataset_dir)
# CIFAR-10 class names; the list index is the numeric label id (0-9).
classification = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']
#按CIFAR-10官方推荐的方式反序列化一个batch文件,返回以bytes为键的字典
def unpickle(file):
    """Deserialize one CIFAR-10 batch file.

    Args:
      file: Path to a pickled batch file.

    Returns:
      The deserialized dict; with encoding='bytes' the keys are bytes
      (e.g. b"data", b"labels"), matching how the callers index it.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only use
    # this on the locally downloaded CIFAR-10 archive, never on
    # untrusted input.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_DIR = 'data'
# download_and_uncompress_tarball(DATA_URL, DATA_DIR)

# Location of the extracted CIFAR-10 python batches.
folders = r'E:\zhuomian\tf_read_write\data_manager/data/cifar-10-batches-py'
# Locate the training batch files (data_batch_1 .. data_batch_5).
trfiles = glob.glob(folders + "/data_batch*")

data = []
labels = []
for trfile in trfiles:  # avoid shadowing the builtin name `file`
    batch = unpickle(trfile)
    print(batch)
    # Each batch dict holds the raw pixel rows under b"data" and the
    # per-image class ids under b"labels".
    data += list(batch[b"data"])
    labels += list(batch[b"labels"])

# labels[i] is the class id of image i.
# Reshape the flat pixel rows into (N, channels, height, width) images.
imgs = np.reshape(data, [-1, 3, 32, 32])

for i in range(imgs.shape[0]):
    # One image, channel-first: (3, 32, 32).
    im_data = imgs[i, ...]
    # Move the channel axis last for image libraries: (32, 32, 3).
    im_data = np.transpose(im_data, [1, 2, 0])
    # CIFAR stores RGB; OpenCV expects/writes BGR, so convert first.
    im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
    # Output folder is named after the class; labels[i] is the class id.
    f = "{}/{}".format(r"E:\zhuomian\tf_read_write\data_manager/data/image/train", classification[labels[i]])
    # makedirs (not mkdir) so any missing parent directories are
    # created too instead of raising FileNotFoundError.
    if not os.path.exists(f):
        os.makedirs(f)
    # Write the image, named by its global index.
    cv2.imwrite("{}/{}.jpg".format(f, str(i)), im_data)
慢慢来吧,通过时间和精力的积累,水平自然会提高