一,tensorflow中使用自己的数据集的流程

# -*- coding: utf-8 -*-

'''

准备tf训练集的流程:
    1,使用操作系统os中的方法,得到训练图片和标注(图片)的"文件名字列表",
      相当于在给训练文件集合(库)建立起来一套"索引(名字)"
      image_filename_lists
      label_filename_lists
      这两个列表结构的第一个维度必须一样,也就是:有多少条样本数据(行)就有对应的标签数据(行)
    2,调用tf.train.slice_input_producer()函数,依据名字(索引列表),建立一个队列
      filename_queue=tf.train.slice_input_producer([image_filename_lists,label_filename_lists])
      该函数返回image_filename_lists和label_filename_lists对应的队列:queue列表
      filename_queue[0]与image_filename对应;
      filename_queue[1]与image_filename对应
      使用队列的理由是:可以对这些名字列表中的元素进行随机化处理;
                      第二个理由是,可以使用多线程机制.
    3,我们可以以队列的filename_queue[i]为基础,定义tf数据读取的操作符:例如:
      image_contents = tf.read_file(input_queue[0]) #定义读取内容的操作符
      image = tf.image.decode_jpeg(image_contents, channels=1) #图像预处理操作符
      这一步就相当于:利用索引从数据库中获取对应的内容.执行数据库的底层读写操作.
    4,上面定义了对从库中读取内容后的预处理操作符,定义的是:从索引得到“单个”文件并进行处理的

      流程,但是还没有组成batch批量数据,接下来,把利用索引从库中读取的内容,
      并按照batch_size的要求"封装,打包",也就是把batch_size个预处理后的图片内容,
      放在一个张量中,也相当于:从总表中得到包含batch_size这么多"行"的"子表".
      这个运算符是由image_batch = tf.train.batch()函数完成的
      image_batch = tf.train.batch([image], batch_size=batch_size)
      这个操作的输入参数是上一步对文件内容进行预处理得到的结果的操作符:image
      函数返回的是:获得一组(batch_size个)数据的操作符.
    5,使用控制台的run函数,这是按照以上操作符描述的流程,读取数据,得到一组batch_size数据:
      picture_batch = sess.run(image_batch)
      
需要强调的是:以上的操作都是定义的在tf中将要执行的一些列操作符号,在定义操作符的时候,
            这些操作并没有被执行,他们相当于用语言,描绘一张tf使用的运行流程图
            直到最后的sess.run(op操作符),才执行该流程图中定义的一些列关联的运算.
'''
import tensorflow as tf
import numpy as np  
import os
import matplotlib.pyplot as plt
tf.reset_default_graph()
 
def show_image(image):
    #显示单张图片
    image = np.transpose(image,(2,0,1)) #需要改变一下维度
    plt.imshow(image[0],cmap='Greys_r')                                  
    plt.show()
 
def show_batch_image(batch_image):
    #显示一个batch的图片
    for i in range(batch_size):
        image = batch_image[i]
        show_image(image)
           
image_W = 16
image_H = 16

 
train_dir = "..\\test_data\\" #文件所在的路径
image_list = []
for file in os.listdir(train_dir):  
    image_list.append(os.path.join(train_dir, file))  
for i in range(len(image_list)):
    print(image_list[i])
    
sample_num = len(image_list)  #样本的总数量
epoch_num = 3    #用全部样本迭代的次数
batch_size = 5   #
batch_total = int(sample_num/batch_size)+1 #一个epoch有的batch数目
#给予列表,产生"文件名字"队列,这里只有一个元素image_list,那也得用[]括起来
input_queue = tf.train.slice_input_producer([image_list],shuffle = False)  #生成文件名队列,由于shuffle = True,随机产生一个文件名路径
#从文将名字队列的输出端口,读取文件名字并调用tf.read_file(队列端口),读取文件内容数据
image_contents = tf.read_file(input_queue[0])   #根据上面的文件名路径,读取样本(处于编码下)
#把读取的内容进行jpg解码
image = tf.image.decode_jpeg(image_contents, channels=1)    #进行解码,由于读取的为灰度图,channels=1
#crop_image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H) #采用切割的方式改变大小
#对图片进行裁切
crop_image = tf.image.resize_images(image, [image_W,image_H],method=0)  #采用放缩的方式改变大小,method=0为双线性插值
#对图片规范化
standard_image = tf.image.per_image_standardization(crop_image)   # 标准化数据
#把从文件名队列,到文件内容,再到用文件内容构造成训练batch数据,相当于
#上面的操作定义了tf中文件读取的操作,下面定义了构造batch的操作.
image_batch = tf.train.batch([crop_image], batch_size=batch_size, num_threads=2, capacity=64, allow_smaller_final_batch=False) #生成batch
 
with tf.Session() as sess:
    coord = tf.train.Coordinator() #创建线程管理器
    threads = tf.train.start_queue_runners(sess, coord) #启动线程
    try:
        for i in range(epoch_num):  # 每一轮迭代
            print ('epoch is %d'%(i+1))
            for j in range(batch_total): #batch遍历,在所有样本中选择batch_size个
                #print ('batch num is %d'%(j+1))
                #print ('input_queue:%s'%(sess.run(input_queue)))     
                #print ('image_contents:%s'%(sess.run(image_contents)))    
                #picture, crop_picture,standard_picture = sess.run([image,crop_image,standard_image])
                #show_image(picture)
                #show_image(crop_picture)
                #show_image(standard_picture)
                picture_batch = sess.run(image_batch) #让tf执行以上定义的操作
                print(picture_batch.shape) #打印从tf返回到py环境的结果,这里返回shape
                #(3, 64, 64, 1) 3是batch_size,也就是说,返回的结果包含batch_size张图片
                #每张图片是:64x64x1的图片
#                show_batch_image(picture_batch )
    except tf.errors.OutOfRangeError:  #如果读取到文件队列末尾会抛出此异常
        print("done! now lets kill all the threads……")
    finally:
        # 协调器coord发出所有线程终止信号
        coord.request_stop()
        print('all threads are asked to stop!')
    coord.join(threads) #把开启的线程加入主线程,等待

猜你喜欢

转载自blog.csdn.net/yanlizhong62/article/details/84248823