mnist手写字体识别

1、获取mnist数据集,得到正确的数据格式

mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

2、定义网络大小:图片的大小是28*28,784个像素点,输入神经元为784个,输出0~9个数,输出神经元为10个

n_input =784
n_layer1 = 10
examples_to_show = 10 #显示的测试图像个数
x_data = tf.placeholder(tf.float32,[None,n_input])
y_data = tf.placeholder(tf.float32,[None,n_layer1])
 
3、添加层函数
inputsize:输入神经元的个数weights:权重biases:偏置值activation_function:激活函数
输出:该层的进行激活后的神经元,下一个层的输入
def addlayer(inputsize,weights,biases,activation_function=None):
    output = tf.add(tf.matmul(x_data,weights),biases)
    if activation_function == None:
        return tf.nn.sigmoid(output)
    else:
        return tf.nn.softmax(output)
 
4、构建网络
#预测输出
添加隐藏层
y_pre=addlayer(x_data,layer1_weights,layer1_biases,activation_function=tf.nn.softmax)
y_true=y_data
#反向
cross_entropy=-tf.reduce_sum(y_true * tf.log(y_pre))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(1000):
     batch_xs,batch_ys =mnist.train.next_batch(100)
     sess.run(train_step,feed_dict={x_data:batch_xs , y_data:batch_ys})
     if (i%50==0):
        print ("loss : ",sess.run(cross_entropy,feed_dict={x_data:batch_xs , y_data:batch_ys}))
        # 这个越大越好
        print ("prediction acc : ",compute_acc(mnist.test.images[:100], mnist.test.labels[:100]))
5、测试集计算模型预测准确率
def compute_acc(x_input_test,y_true_test):
    y_pre_test = sess.run(y_pre,feed_dict={x_data:x_input_test, y_data:y_true_test})
    correct_prediction=tf.equal(tf.argmax(y_pre_test,1),tf.argmax(y_true_test,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    result = sess.run(accuracy, feed_dict={x_data:x_input_test , y_data:y_true_test})
    return result
6、结果

猜你喜欢

转载自www.cnblogs.com/wyx501/p/10440299.html