tensorflow数据清洗

import tensorflow as tf
import numpy as np
import random
import os
import math


from matplotlib import pyplot as plt

def get_files(file_dir):
    """
    创建数据文件名列表

    :param file_dir:
    :return:image_list 所有图像文件名的列表,label_list 所有对应标贴的列表
    """
    #step1.获取图片,并贴上标贴
    #新建五个列表,存储文件夹下的文件名
    daisy=[]
    label_daisy=[]
    dandelion=[]
    label_dandelion = []
    roses=[]
    label_roses = []
    sunflowers=[]
    label_sunflowers = []
    tulips=[]
    label_tulips = []
    for file in os.listdir(file_dir+"/daisy"):
        daisy.append(file_dir+"/daisy"+"/"+file)
        label_daisy.append(0)

    for file in os.listdir(file_dir+"/dandelion"):
        dandelion.append(file_dir+"/dandelion"+"/"+file)
        label_dandelion.append(1)
    for file in os.listdir(file_dir+"/roses"):
        roses.append(file_dir+"/roses"+"/"+file)
        label_roses.append(2)
    for file in os.listdir(file_dir+"/sunflowers"):
        sunflowers.append(file_dir+"/sunflowers"+"/"+file)
        label_sunflowers.append(3)
    for file in os.listdir(file_dir+"/tulips"):
        tulips.append(file_dir+"/tulips"+"/"+file)
        label_tulips.append(4)

    #step2:对生成的图片路径和标签List做打乱处理
    #把所有图片跟标贴合并到一个列表list(img和lab)
    images_list=np.hstack([daisy,dandelion,roses,sunflowers,tulips])
    labels_list=np.hstack([label_daisy,label_dandelion,label_roses,label_sunflowers,label_tulips])

    #利用shuffle打乱顺序
    temp=np.array([images_list,labels_list]).transpose()
    np.random.shuffle(temp)
    # 从打乱的temp中再取出list(img和lab)
    image_list=list(temp[:,0])
    label_list=list(temp[:,1])
    label_list_new=[int(i) for i in label_list]

    # 将所得List分为两部分,一部分用来训练tra,一部分用来测试val
    # 测试样本数, ratio是测试集的比例
    ratio=0.3
    n_sample = len(label_list)
    n_val = int(math.ceil(n_sample * ratio))
    n_train = n_sample - n_val  # 训练样本数
    tra_images = image_list[0:n_train]
    tra_labels = label_list_new[0:n_train]
    #tra_labels = [int(float(i)) for i in tra_labels]  # 转换成int数据类型
    val_images = image_list[n_train:-1]
    val_labels = label_list_new[n_train:-1]
    #val_labels = [int(float(i)) for i in val_labels]  # 转换成int数据类型
    return tra_images, tra_labels, val_images, val_labels

    #return image_list,label_list_new

def get_batch(image, label, image_W, image_H,channel, batch_size, capacity):
    #step1:将上面生成的List传入get_batch() ,转换类型,产生一个输入队列queue
    #类型转换
    image=tf.cast(image,tf.string)
    label=tf.cast(label,tf.int32)
    #生成输入队列
    input_queue=tf.train.slice_input_producer([image,label])

扫描二维码关注公众号,回复: 9472710 查看本文章

    label=input_queue[1]
    image_contents=tf.read_file(input_queue[0])
    #print(image_contents)
    #step2:将图像解码,不同类型的图像不能混在一起,要么只用jpeg,要么只用png等
    images_value=tf.image.decode_jpeg(image_contents)
    #print(images_value)
    #step3:数据预处理,对图像进行旋转、缩放、裁剪、归一化等操作,让计算出的模型更健壮
    #image=tf.image.resize_image_with_crop_or_pad(images_value,image_W,image_H)
    #image=tf.image.resize_images(images_value,size=[200,200])
    image = tf.image.resize_images(images_value,size=[image_W,image_H])
    #image.set_shape(shape=[200, 200, 3])
    image.set_shape(shape=[image_W, image_H, channel])
    #print(image)
    # 对resize后的图片进行标准化处理
    image=tf.image.per_image_standardization(image)
    #step4:生成batch
    image_batch,label_batch=tf.train.batch([image,label],batch_size=batch_size,num_threads=1,capacity=capacity)
    # 重新排列label,行数为[batch_size]
    #print(label_batch)
    label_batch = tf.reshape(label_batch, [batch_size])
    #print(label_batch)
    image_batch = tf.cast(image_batch, tf.float32)
    return image_batch,label_batch


if __name__=="__main__":
    BATCH_SIZE = 2
    CAPACITY = 256
    IMG_W = 208
    IMG_H = 208
    # 读取文件所在路径
    mypath = "/home/sunxiaoming/PycharmProjects/data/flower_photos"
    image_list,label_list=get_files(mypath)
    print(len(image_list))
    print(len(label_list))
    image_batch,label_batch=get_batch(image_list,label_list,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)
    print(image_batch)
    with tf.Session() as sess:
        # 开启线程
        # 线程协调元
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        i=0
        while i<2:
            image,lable = sess.run([image_batch, label_batch])
            #image_array=np.array(image[i,:,:,:])
            for j in range(2):
                plt.imshow(image[j, :, :, :])
                plt.show()

            i+=1

        # 回收线程
        coord.request_stop()
        coord.join(threads)

    #with tf.Session() as sess:
        # 开启线程
        # 线程协调元
        #coord = tf.train.Coordinator()
        #threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        #i=0
        #while not coord.should_stop() and i < 2:


            #lable, image = sess.run([image_batch,label_batch])
            #print(type(image))
            #"""
                        #for j in np.arange(BATCH_SIZE):
               # print('label: %d' % lable[j])

                #plt.imshow(image[j, :, :, :])
                #plt.show()
            #i += 1

            #"""


        # 回收线程
        #coord.request_stop()
        #coord.join(threads)

发布了109 篇原创文章 · 获赞 22 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/qq_42233538/article/details/89290058