用tensorflow训练自己的数据_3、训练模型

训练模型的时候,维数一定要匹配,同时要了解你自己的数据的格式,和读取的类型,一个one_hot编码用的函数和非one_hot用的函数完全不一样,这也是我当时一直出现问题的原因。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 11:32:40 2018

@author: huangxudong
"""
import dr_alexnet
import tensorflow as tf
import read_data2

#定义网络超参数
learning_rate=0.01
train_iters=2000
batch_size=5
capacity=256
display_step=10
#读取数据
tra_list,tra_labels,val_list,val_labels=read_data2.get_files('/home/bigvision/Desktop/DR_model',0.2)
tra_list_batch,tra_label_batch=read_data2.get_batch(tra_list,tra_labels,512,512,batch_size,capacity)
val_list_batch,val_label_batch=read_data2.get_batch(val_list,val_labels,512,512,batch_size,capacity)

#定义网络参数
n_class=6       #标记维度
dropout=0.75
skip=[]
#输入占位符
x=tf.placeholder(tf.float32,[None,786432])  #2800*2100*3,512*512*3
y=tf.placeholder(tf.int32,[None])
#print(y.shape)
keep_prob=tf.placeholder(tf.float32)  #dropout


''''构建模型,定义损失函数和优化器'''''
pred=dr_alexnet.alexNet(x,dropout,n_class,skip)
#定义损失函数和优化器
cost=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=pred.fc3))
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
#评估函数,优化函数
correct_pred=tf.nn.in_top_k(pred.fc3,y,1)  #1表示列上去最大,0是行,这个地方如果是one_hot就是tf.argmax
accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))    #改类型


'''训练模型'''
init=tf.global_variables_initializer()   #初始化所有变量

with tf.Session() as sess:
    sess.run(init)
    coord=tf.train.Coordinator()      
    threads= tf.train.start_queue_runners(coord=coord)    
    step=1
    #开始训练,达到最大训练次数
    while step*batch_size<train_iters:       
        batch_x,batch_y=tra_list_batch.eval(session=sess),tra_label_batch.eval(session=sess)
        batch_x=batch_x.reshape((batch_size,786432))
        batch_y=batch_y.T
        
        sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
        if step%display_step==2:            
            #计算损失值和准确度,输出
            loss,acc=sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob:1.})
            print("Iter"+str(step*batch_size)+",Minibatch Loss="+ "{:.6f}".format(loss)+", Training Acc"+ "{:.5f}".format(acc))
        step+=1
    print("Optimization Finished!")
    coord.request_stop()     
    coord.join(threads)            #多线程进行batch送入

feed_dict字典读取数据的时候不能是tensor类型,必须是list,numpy类型(还有一个忘了),所以在送入batch数据的时候加入了.eval(session.sess),当初这块也是磨了很久。希望以后不在犯错

本人新人,对大家有帮助的话就点赞哦

猜你喜欢

转载自blog.csdn.net/qq_36631272/article/details/79173280