用tensorflow训练自己的图片——1、读取数据

很多同学(针对新手)在训练mnist数据的时候,根据书本上的内容都可以很好很快的编辑并跑出来,但是一旦换成自己的文件夹,就很头疼,毕竟mnist里面一个read_data解决你所有的输入问题,然而在现实中,该read_data是要自己编辑的,本文主要针对非ont_hot数据,如何利用tensorflow搭起网络并跑通自己的数据,话不多说,直接上代码。

python版本:2.7

tensorflow 版本:1.1.0

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 11:28:55 2018

@author:huangxd
"""
"""
vision:python3
author:huangxd
"""
import os
import math  
import numpy as np  
import tensorflow as tf

#生成图片路径和标签list
#train_dir='C:/Users/hxd/Desktop/tensorflow_study/Alexnet_dr'
zeroclass = []  
label_zeroclass = []  
oneclass = []  
label_oneclass = []  
twoclass = []  
label_twoclass = []  
threeclass = []  
label_threeclass = []
fourclass = []
label_fourclass = []
fiveclass = []
label_fiveclass = []
#s1 获取路径下所有图片名和路径,存放到对应列表并贴标签
def get_files(file_dir,ratio):
    for file in os.listdir(file_dir+'/0'):  
        zeroclass.append(file_dir +'/0'+'/'+ file)   
        label_zeroclass.append(0)  
    for file in os.listdir(file_dir+'/1'):  
        oneclass.append(file_dir +'/1'+'/'+file)  
        label_oneclass.append(1)  
    for file in os.listdir(file_dir+'/2'):  
        twoclass.append(file_dir +'/2'+'/'+ file)   
        label_twoclass.append(2)  
    for file in os.listdir(file_dir+'/3'):  
        threeclass.append(file_dir +'/3'+'/'+file)  
        label_threeclass.append(3)      
    for file in os.listdir(file_dir+'/4'):  
        fourclass.append(file_dir +'/4'+'/'+file)  
        label_fourclass.append(4)      
    for file in os.listdir(file_dir+'/5'):  
        fiveclass.append(file_dir +'/5'+'/'+file)  
        label_fiveclass.append(5)
#s2 对生成图片路径和标签list打乱处理(img和label)
    image_list=np.hstack((zeroclass, oneclass, twoclass, threeclass, fourclass, fiveclass))
    label_list=np.hstack((label_zeroclass, label_oneclass, label_twoclass, label_threeclass, label_fourclass, label_fiveclass))
    #shuffle打乱
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)
    #将所有的img和lab转换成list
    all_image_list=list(temp[:,0])
    all_label_list=list(temp[:,1])
    #将所得List分为2部分,一部分train,一部分val,ratio是验证集比例
    n_sample = len(all_label_list)  
    n_val = int(math.ceil(n_sample*ratio))   #验证样本数  
    n_train = n_sample - n_val   #训练样本数  
  
    tra_images = all_image_list[0:n_train]
    tra_labels = all_label_list[0:n_train]  
    tra_labels = [int(float(i)) for i in tra_labels]  
    val_images = all_image_list[n_train:]  
    val_labels = all_label_list[n_train:]
    val_labels = [int(float(i)) for i in val_labels]    
    return tra_images,tra_labels,val_images,val_labels
#生成batch
#s1:将上面的list传入get_batch(),转换类型,产生输入队列queue因为img和lab  
#是分开的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()从队列中读取图像  
#   image_W, image_H, :设置好固定的图像高度和宽度  
#   设置batch_size:每个batch要放多少张图片  
#   capacity:一个队列最大多少

def get_batch(image,label,image_W,image_H,batch_size,capacity):
    #转换类型
    image=tf.cast(image,tf.string)
    label=tf.cast(label,tf.int32)
    #入队
    input_queue=tf.train.slice_input_producer([image,label])
    label=input_queue[1]
    image_contents=tf.read_file(input_queue[0]) #读取图像
    #s2图像解码,且必须是同一类型
    image=tf.image.decode_jpeg(image_contents,channels=3)
    #s3预处理,主要包括旋转,缩放,裁剪,归一化
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)  
    image = tf.image.per_image_standardization(image)
    #s4生成batch

    image_batch, label_batch = tf.train.batch([image, label],  
                                                batch_size= batch_size,  
                                                num_threads= 32,   
                                                capacity = capacity)
    #重新排列label,行数为[batch_size]  
    label_batch = tf.reshape(label_batch, [batch_size])  
    #image_batch = tf.cast(image_batch, tf.float32)  
    return image_batch, label_batch
该数据生成的是bool型,非one_hot编码,系统自带的mnist编码是one_hot编码,大家可以先去了解下这块东西

猜你喜欢

转载自blog.csdn.net/qq_36631272/article/details/79173035