tensorflow实现手写数字识别(MLP)

from  __future__ import print_function#即使是在python2版本也要像在Python3中使用print函数
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/",one_hot=True)#onehot对标签的标注,非onehot是1,2,3.onehot就是只有一个1其余全是0
import tensorflow as tf
#超参数(学习率,batch的大小,训练的轮数,多少轮展示一下loss)
learning_rate = 0.1
num_step = 500
batch_size = 128
display_step =100

#网络参数(有多少层网络,每层有多少个神经元,整个网络的输入是多少维度的,输出是多少维度的)
n_hidden_1 = 256
n_hidden_2 = 256
num_input = 784#(28*28)
num_class = 10

#图的输入
X = tf.placeholder("float",[None,num_input])
Y = tf.placeholder("float",[None,num_class])

#网络的权重和偏向,如果是两个隐层的话需要定义三个权重,包括输出层
weights={
    'h1':tf.Variable(tf.random_normal([num_input,n_hidden_1])),
    'h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_hidden_2,num_class]))
}

biase = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([num_class]))
}
#定义网络结构
def neural_net(x):
    layer_1 = tf.add(tf.matmul(x,weights['h1']),biase['b1'])
    layer_2 = tf.add(tf.matmul(layer_1,weights['h2']),biase['b2'])
    out_layer = tf.add(tf.matmul(layer_2,weights['out']),biase['out'])
    return out_layer

#模型输出处理
logits = neural_net(X)
prediction = tf.nn.softmax(logits)

#定义损失和优化器
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
train_op = optimizer.minimize(loss_op)

#评估模型准确率
correct_pred = tf.equal(tf.argmax(prediction,1),tf.argmax(Y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))

#初始化变量
init = tf.global_variables_initializer()
#开始训练
with tf.Session() as sess:
    sess.run(init)
    for step in range(1,num_step+1):
        batch_x,batch_y = mnist.train.next_batch(batch_size)
        if step % display_step == 0 or step == 1:
            loss,acc = sess.run([loss_op,accuracy],feed_dict={X:batch_x,Y:batch_y})
            print("step:{},loss:{},acc:{}".format(step,loss,acc))
    print("优化完成!")
    #训练完模型后,开始测试
    print("testing Accuracy:",sess.run(accuracy,feed_dict={X:mnist.test.images,Y:mnist.test.labels}))





猜你喜欢

转载自blog.csdn.net/pwtd_huran/article/details/80258418