学习笔记:用tensorflow训练自己的图片

感谢这位作者,以下记录是来自于https://blog.csdn.net/qq_36631272/article/details/79173035的,我看到比较好,就转记录到自己的博客了,如果有侵权,立马删掉

在训练mnist数据的时候,根据书本上的内容都可以很好很快的编辑并跑出来,但是一旦换成自己的文件夹,就很头疼,毕竟mnist里面一个read_data解决你所有的输入问题,然而在现实中,该read_data是要自己编辑的,本文主要针对非ont_hot数据,如何利用tensorflow搭起网络并跑通自己的数据,话不多说,直接上代码。

python版本:2.7

tensorflow 版本:1.2.0

[html]  view plain  copy
  1. #!/usr/bin/env python2  
  2. # -*- coding: utf-8 -*-  
  3. """  
  4. Created on Thu Jan 25 11:28:55 2018  
  5.   
  6. @author:huangxd  
  7. """  
  8. """  
  9. vision:python3  
  10. author:huangxd  
  11. """  
  12. import os  
  13. import math    
  14. import numpy as np    
  15. import tensorflow as tf  
  16.   
  17. #生成图片路径和标签list  
  18. #train_dir='C:/Users/hxd/Desktop/tensorflow_study/Alexnet_dr'  
  19. zeroclass = []    
  20. label_zeroclass = []    
  21. oneclass = []    
  22. label_oneclass = []    
  23. twoclass = []    
  24. label_twoclass = []    
  25. threeclass = []    
  26. label_threeclass = []  
  27. fourclass = []  
  28. label_fourclass = []  
  29. fiveclass = []  
  30. label_fiveclass = []  
  31. #s1 获取路径下所有图片名和路径,存放到对应列表并贴标签  
  32. def get_files(file_dir,ratio):  
  33.     for file in os.listdir(file_dir+'/0'):    
  34.         zeroclass.append(file_dir +'/0'+'/'+ file)     
  35.         label_zeroclass.append(0)    
  36.     for file in os.listdir(file_dir+'/1'):    
  37.         oneclass.append(file_dir +'/1'+'/'+file)    
  38.         label_oneclass.append(1)    
  39.     for file in os.listdir(file_dir+'/2'):    
  40.         twoclass.append(file_dir +'/2'+'/'+ file)     
  41.         label_twoclass.append(2)    
  42.     for file in os.listdir(file_dir+'/3'):    
  43.         threeclass.append(file_dir +'/3'+'/'+file)    
  44.         label_threeclass.append(3)        
  45.     for file in os.listdir(file_dir+'/4'):    
  46.         fourclass.append(file_dir +'/4'+'/'+file)    
  47.         label_fourclass.append(4)        
  48.     for file in os.listdir(file_dir+'/5'):    
  49.         fiveclass.append(file_dir +'/5'+'/'+file)    
  50.         label_fiveclass.append(5)  
  51. #s2 对生成图片路径和标签list打乱处理(img和label)  
  52.     image_list=np.hstack((zeroclass, oneclass, twoclass, threeclass, fourclass, fiveclass))  
  53.     label_list=np.hstack((label_zeroclass, label_oneclass, label_twoclass, label_threeclass, label_fourclass, label_fiveclass))  
  54.     #shuffle打乱  
  55.     temp = np.array([image_list, label_list])  
  56.     temp = temp.transpose()  
  57.     np.random.shuffle(temp)  
  58.     #将所有的img和lab转换成list  
  59.     all_image_list=list(temp[:,0])  
  60.     all_label_list=list(temp[:,1])  
  61.     #将所得List分为2部分,一部分train,一部分val,ratio是验证集比例  
  62.     n_sample = len(all_label_list)    
  63.     n_val = int(math.ceil(n_sample*ratio))   #验证样本数    
  64.     n_train = n_sample - n_val   #训练样本数    
  65.     
  66.     tra_images = all_image_list[0:n_train]  
  67.     tra_labels = all_label_list[0:n_train]    
  68.     tra_labels = [int(float(i)) for i in tra_labels]    
  69.     val_images = all_image_list[n_train:]    
  70.     val_labels = all_label_list[n_train:]  
  71.     val_labels = [int(float(i)) for i in val_labels]      
  72.     return tra_images,tra_labels,val_images,val_labels  
  73. #生成batch  
  74. #s1:将上面的list传入get_batch(),转换类型,产生输入队列queue因为img和lab    
  75. #是分开的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()从队列中读取图像    
  76. #   image_W, image_H, :设置好固定的图像高度和宽度    
  77. #   设置batch_size:每个batch要放多少张图片    
  78. #   capacity:一个队列最大多少  
  79.   
  80. def get_batch(image,label,image_W,image_H,batch_size,capacity):  
  81.     #转换类型  
  82.     image=tf.cast(image,tf.string)  
  83.     label=tf.cast(label,tf.int32)  
  84.     #入队  
  85.     input_queue=tf.train.slice_input_producer([image,label])  
  86.     label=input_queue[1]  
  87.     image_contents=tf.read_file(input_queue[0]) #读取图像  
  88.     #s2图像解码,且必须是同一类型  
  89.     image=tf.image.decode_jpeg(image_contents,channels=3)  
  90.     #s3预处理,主要包括旋转,缩放,裁剪,归一化  
  91.     image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)    
  92.     image = tf.image.per_image_standardization(image)  
  93.     #s4生成batch  
  94.   
  95.     image_batch, label_batch = tf.train.batch([image, label],    
  96.                                                 batch_sizebatch_size,    
  97.                                                 num_threads32,     
  98.                                                 capacity = capacity)  
  99.     #重新排列label,行数为[batch_size]    
  100.     label_batch = tf.reshape(label_batch, [batch_size])    
  101.     #image_batch = tf.cast(image_batch, tf.float32)    
  102.     return image_batch, label_batch  

该数据生成的是bool型,非one_hot编码,系统自带的mnist编码是one_hot编码,大家可以先去了解下这块东西

得到数据之后,接下来就是网络的搭建,我在这里将模型单独定义出来,方便后期的网络修正。

[html]  view plain  copy
  1. #!/usr/bin/env python2  
  2. # -*- coding: utf-8 -*-  
  3.   
  4. """  
  5. Spyder Editor  
  6. This is a temporary script file.  
  7. filename:DR_alexnet.py  
  8. creat time:2018年1月16日  
  9. author:huangxudong  
  10. """  
  11. import tensorflow as tf  
  12. import numpy as np  
  13. #define different layer function  
  14. def maxPoolLayer(x,kHeight,kWidth,strideX,strideY,name,padding="SAME"):  
  15.     return tf.nn.max_pool(x,ksize=[1,kHeight,kWidth,1],strides=[1,strideX,strideY,1],padding=padding,name=name)  
  16.   
  17. def dropout(x,keepPro,name=None):  
  18.     return tf.nn.dropout(x,keepPro,name)  
  19.   
  20. def LRN(x,R,alpha,beta,name=None,bias=1.0):                      #局部相应归一化  
  21.     return tf.nn.local_response_normalization(x,depth_radius=R,alpha=alpha,  
  22.                                               beta=beta,bias=bias,name=name)  
  23. def fcLayer(x,inputD,outputD,reluFlag,name):  
  24.     with tf.variable_scope(name) as scope:  
  25.         w=tf.get_variable("w",shape=[inputD,outputD])   #shape就是变量维度  
  26.         b=tf.get_variable("b",[outputD])  
  27.         out=tf.nn.xw_plus_b(x,w,b,name=scope.name)  
  28.         if reluFlag:  
  29.             return tf.nn.relu(out)  
  30.         else:  
  31.             return out  
  32. def convLayer(x,kHeight,kWidth,strideX,strideY,featureNum,name,padding="SAME",groups=1):  
  33.     """convolution"""  
  34.     channel=int(x.get_shape()[-1])       #x数组的最后一个数  
  35.     conv=lambda a,b: tf.nn.conv2d(a,b,strides=[1,strideY,strideX,1],padding=padding)   #匿名函数  
  36.     with tf.variable_scope(name) as scope:  
  37.         w=tf.get_variable("w",shape=[kHeight,kWidth,channel/groups,featureNum])  
  38.         b=tf.get_variable("b",shape=[featureNum])  
  39.         xNew=tf.split(value=x,num_or_size_splits=groups,axis=3)  
  40.         wNew=tf.split(value=w,num_or_size_splits=groups,axis=3)  
  41.         featureMap=[conv(t1,t2) for t1,t2 in zip(xNew,wNew)]  
  42.         mergeFeatureMap=tf.concat(axis=3,values=featureMap)  
  43.         out=tf.nn.bias_add(mergeFeatureMap,b)  
  44.  #       print(mergeFeatureMap.get_shape().as_list(),out.shape)  
  45.         return tf.nn.relu(out,name=scope.name)  #卷积激活一起完成,out大小和mergeFeatureMap一样,不需要reshape  
  46. class alexNet(object):  
  47.     """alexNet model"""  
  48.     def __init__(self,x,keepPro,classNum,skip,modelPath="bvlc_alexnet.npy"):  
  49.         self.X=x  
  50.         self.KEEPPRO=keepPro                 #表示类名  
  51.         self.CLASSNUM=classNum  
  52.         self.SKIP=skip  
  53.         self.MODELPATH=modelPath  
  54.         #build CNN  
  55.         self.buildCNN()  
  56.     def buildCNN(self):             #重点,模型搭建  2800*2100  
  57.         x1=tf.reshape(self.X,shape=[-1,512,512,3])  
  58. #        print(x1.shape)  
  59.         conv1=convLayer(x1,7,7,3,3,256,"conv1","VALID")    #169*169  
  60.         lrn1=LRN(conv1,2,2e-05,0.75,"norm1")  
  61.         pool1=maxPoolLayer(lrn1,3,3,2,2,"pool1","VALID")    #84*84  
  62.   
  63.         conv2=convLayer(pool1,3,3,1,1,512,"conv2","VALID")    #82*82  
  64.         lrn2=LRN(conv2,2,2e-05,0.75,"norm2")  
  65.         pool2=maxPoolLayer(lrn2,3,3,2,2,"pool2","VALID")       #40*40  
  66.       
  67.         conv3=convLayer(pool2,3,3,1,1,1024,"conv3","VALID")    #38*38      
  68.         conv4=convLayer(conv3,3,3,1,1,1024,"conv4","VALID")   #36*36  
  69.   
  70.         conv5=convLayer(conv4,3,3,2,2,512,"conv5","VALID")    #17*17  
  71.         pool5=maxPoolLayer(conv5,3,3,2,2,"pool5","VALID")     #8*8  
  72. #        print(pool5.shape)  
  73.         fcIn=tf.reshape(pool5,[-1,512*8*8])  
  74.         fc1=fcLayer(fcIn,512*8*8,4096,True,"fc6")  
  75.         dropout1=dropout(fc1,self.KEEPPRO)  
  76.   
  77.         fc2=fcLayer(dropout1,4096,4096,True,"fc7")  
  78.         dropout2=dropout(fc2,self.KEEPPRO)  
  79.   
  80.         self.fc3=fcLayer(dropout2,4096,self.CLASSNUM,True,"fc8")  
上面便是网络的搭建,搭好之后还需要将模型加载出来:

[html]  view plain  copy
  1. def loadModel(self,sess):  
  2.     """load model"""  
  3.     wDict=np.load(self.MODELPATH,encoding="bytes").item()  
  4.     for name in wDict:  
  5.         if name not in self.SKIP:  
  6.             with tf.variable_scope(name, reuse = True):  
  7.                 for p in wDict[name]:  
  8.                     if len(p.shape) == 1:  
  9.                         #bias  
  10.                         sess.run(tf.get_variable('b', trainable = False).assign(p))  
  11.                     else:  
  12.                         #weights  
  13.                         sess.run(tf.get_variable('w', trainable = False).assign(p))  

训练模型的时候,维数一定要匹配,同时要了解你自己的数据的格式,和读取的类型,一个one_hot编码用的函数和非one_hot用的函数完全不一样,这也是我当时一直出现问题的原因。

扫描二维码关注公众号,回复: 1444518 查看本文章

[html]  view plain  copy
  1. #!/usr/bin/env python2  
  2. # -*- coding: utf-8 -*-  
  3. """  
  4. Created on Thu Jan 25 11:32:40 2018  
  5.   
  6. @author: huangxudong  
  7. """  
  8. import dr_alexnet  
  9. import tensorflow as tf  
  10. import read_data2  
  11.   
  12. #定义网络超参数  
  13. learning_rate=0.01  
  14. train_iters=2000  
  15. batch_size=5  
  16. capacity=256  
  17. display_step=10  
  18. #读取数据  
  19. tra_list,tra_labels,val_list,val_labels=read_data2.get_files('/home/bigvision/Desktop/DR_model',0.2)  
  20. tra_list_batch,tra_label_batch=read_data2.get_batch(tra_list,tra_labels,512,512,batch_size,capacity)  
  21. val_list_batch,val_label_batch=read_data2.get_batch(val_list,val_labels,512,512,batch_size,capacity)  
  22.   
  23. #定义网络参数  
  24. n_class=6       #标记维度  
  25. dropout=0.75  
  26. skip=[]  
  27. #输入占位符  
  28. x=tf.placeholder(tf.float32,[None,786432])  #2800*2100*3,512*512*3  
  29. y=tf.placeholder(tf.int32,[None])  
  30. #print(y.shape)  
  31. keep_prob=tf.placeholder(tf.float32)  #dropout  
  32.   
  33.   
  34. ''''构建模型,定义损失函数和优化器'''''  
  35. pred=dr_alexnet.alexNet(x,dropout,n_class,skip)  
  36. #定义损失函数和优化器  
  37. cost=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=pred.fc3))  
  38. optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)  
  39. #评估函数,优化函数  
  40. correct_pred=tf.nn.in_top_k(pred.fc3,y,1)  #1表示列上去最大,0是行,这个地方如果是one_hot就是tf.argmax  
  41. accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))    #改类型  
  42.   
  43.   
  44. '''训练模型'''  
  45. init=tf.global_variables_initializer()   #初始化所有变量  
  46.   
  47. with tf.Session() as sess:  
  48.     sess.run(init)  
  49.     coord=tf.train.Coordinator()        
  50.     threadstf.train.start_queue_runners(coord=coord)      
  51.     step=1  
  52.     #开始训练,达到最大训练次数  
  53.     while step*batch_size<train_iters:         
  54.         batch_x,batch_y=tra_list_batch.eval(session=sess),tra_label_batch.eval(session=sess)  
  55.         batch_x=batch_x.reshape((batch_size,786432))  
  56.         batch_y=batch_y.T  
  57.           
  58.         sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})  
  59.         if step%display_step==2:              
  60.             #计算损失值和准确度,输出  
  61.             loss,acc=sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob:1.})  
  62.             print("Iter"+str(step*batch_size)+",Minibatch Loss="+ "{:.6f}".format(loss)+", Training Acc"+ "{:.5f}".format(acc))  
  63.         step+=1  
  64.     print("Optimization Finished!")  
  65.     coord.request_stop()       
  66.     coord.join(threads)            #多线程进行batch送入  

feed_dict字典读取数据的时候不能是tensor类型,必须是list,numpy类型(还有一个忘了),所以在送入batch数据的时候加入了.eval(session.sess),当初这块也是磨了很久。希望以后不在犯错

猜你喜欢

转载自blog.csdn.net/zr940326/article/details/80556301