2 Reading Data in TensorFlow

Environment

  • python 3.6.5
  • tensorflow 1.14.0
  • numpy 1.16.0

1. A small demo: reading data from a list of filenames

import tensorflow as tf
# print(tf.__version__)
images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
labels = [1, 2, 3, 4]

# Build tensors: the filenames and their labels feed later processing.
# slice_input_producer builds the file queue (the input "front end").
[images, labels] = tf.train.slice_input_producer([images, labels],
                              num_epochs=2,  # number of passes over the list
                              shuffle=True)
# The "back end": run the graph in a session.
with tf.Session() as sess:
    # num_epochs creates a local variable; initialize it before running.
    sess.run(tf.local_variables_initializer())
    # Start the threads that fill the queue.
    tf.train.start_queue_runners(sess=sess)

    # 4 samples x 2 epochs = exactly 8 dequeue operations
    for i in range(8):
        # Each run dequeues one (filename, label) pair; the filename
        # can then be passed to a file-reading op.
        print(sess.run([images, labels]))
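
The loop above works only because it dequeues exactly 4 samples x 2 epochs = 8 elements; one more sess.run would raise tf.errors.OutOfRangeError. A more robust variant (a minimal sketch, not from the original post) uses tf.train.Coordinator to stop the queue-runner threads cleanly once the epochs are exhausted:

import tensorflow as tf

images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
labels = [1, 2, 3, 4]
[image, label] = tf.train.slice_input_producer([images, labels],
                                               num_epochs=2, shuffle=True)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    # A coordinator lets us stop and join the queue-runner threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run([image, label]))
    except tf.errors.OutOfRangeError:
        # Raised once num_epochs passes have been dequeued.
        print('epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)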

2. Actually reading data from file paths

import tensorflow as tf

filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']

# Create the file queue: slice_input_producer outputs tensors,
# while string_input_producer outputs a queue of filenames.
file_queue = tf.train.string_input_producer(filename,
                                            shuffle=True,
                                            num_epochs=2)
reader = tf.WholeFileReader()
# Read one whole file at a time from the queue.
key, value = reader.read(file_queue)

with tf.Session() as sess:
    # Initialize the local variable created by num_epochs.
    sess.run(tf.local_variables_initializer())
    # Start the threads that fill the file queue.
    tf.train.start_queue_runners(sess=sess)
    # 3 files x 2 epochs = exactly 6 reads
    for i in range(6):
        print(sess.run([key, value]))
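
WholeFileReader returns each file's entire contents as a single string. Since the inputs here are CSV files, a natural variation (a sketch that assumes each row holds two float columns; adjust record_defaults to the real schema) reads them line by line with tf.TextLineReader and parses the fields with tf.decode_csv:

import tensorflow as tf

filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']
file_queue = tf.train.string_input_producer(filename,
                                            shuffle=True,
                                            num_epochs=2)
# TextLineReader dequeues one line per read instead of a whole file.
reader = tf.TextLineReader()
key, value = reader.read(file_queue)
# record_defaults fixes the column types and fills in missing values;
# two float columns are assumed here purely for illustration.
col1, col2 = tf.decode_csv(value, record_defaults=[[0.0], [0.0]])

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    tf.train.start_queue_runners(sess=sess)
    print(sess.run([col1, col2]))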

3. Reading data with the download helper packaged in the previous section

import urllib.request
import os
import sys
import tarfile
import glob
import pickle
import numpy as np
import cv2

def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads the `tarball_url` and uncompresses it locally.
  Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()
  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)

classification = ['airplane',
                  'automobile',
                  'bird',
                  'cat',
                  'deer',
                  'dog',
                  'frog',
                  'horse',
                  'ship',
                  'truck']
# Unpickle a CIFAR-10 batch file (the dataset's native storage format).
def unpickle(file):
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict
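
# Each training batch unpickles to a dict whose keys include:
#   b"data":   a 10000 x 3072 uint8 array, one flattened 32x32x3 image per
#              row (all 1024 red bytes first, then green, then blue)
#   b"labels": a list of 10000 class ids in [0, 9]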

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_DIR = 'data'


# Run once to download and extract CIFAR-10 into DATA_DIR:
# download_and_uncompress_tarball(DATA_URL, DATA_DIR)

folders = r'E:\zhuomian\tf_read_write\data_manager/data/cifar-10-batches-py'
# Locate the training batch files (data_batch_1 ... data_batch_5) with glob.
trfiles = glob.glob(folders + "/data_batch*")

data = []
labels = []
for file in trfiles:
    dt = unpickle(file)
    # print(dt)  # caution: this dumps the entire batch dict
    # Collect the pixel rows and the corresponding labels.
    data += list(dt[b"data"])
    labels += list(dt[b"labels"])
# labels now holds the class id of every image.

# Reshape the flat rows into [N, 3, 32, 32]: N images, 3 channels, 32x32 pixels.
imgs = np.reshape(data, [-1, 3, 32, 32])

for i in range(imgs.shape[0]):
    # Take the i-th image.
    im_data = imgs[i, ...]
    # Move the channel axis to the end: [3, 32, 32] -> [32, 32, 3].
    im_data = np.transpose(im_data, [1, 2, 0])
    # Convert RGB to BGR, the channel order OpenCV writes.
    im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
    # Name the output directory after the class; labels[i] is the class id.
    f = "{}/{}".format(r"E:\zhuomian\tf_read_write\data_manager/data/image/train", classification[labels[i]])
    # Create the class directory (and any missing parents) if needed.
    if not os.path.exists(f):
        os.makedirs(f)
    # Write the image to disk, named by its index.
    cv2.imwrite("{}/{}.jpg".format(f, str(i)), im_data)
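
With the images written out as JPEGs grouped by class, the queue pipeline from section 1 can read them back. A minimal sketch (the glob pattern assumes the directory layout created above):

import glob
import tensorflow as tf

# Collect the JPEGs written by the loop above.
files = glob.glob(r'E:\zhuomian\tf_read_write\data_manager/data/image/train/*/*.jpg')

file_queue = tf.train.string_input_producer(files, shuffle=True, num_epochs=1)
reader = tf.WholeFileReader()
key, value = reader.read(file_queue)
# Decode the raw JPEG bytes into a [height, width, 3] uint8 tensor.
image = tf.image.decode_jpeg(value, channels=3)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    tf.train.start_queue_runners(sess=sess)
    img = sess.run(image)
    print(img.shape)  # (32, 32, 3)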

Take it slow; it all adds up through time and effort.
