tensorflow实战二

tensorflow实战二

概述

本篇博文利用mnist数据集,利用最简单的单层神经网络识别手写数字。

代码实现

本博文使用了最简单的单层神经网络结构,输入层有784个输入神经元,输出层有10个神经元,采用one_hot形式读入数据。
话不多说,直接上代码

#-*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据集
minist = input_data.read_data_sets("minist",one_hot=True)

#定义批次大小,数据量太大,采用随机梯度下降法进行批次训练,一个批次暂定50张训练数据
batch_size = 50
batch_n = minist.train.num_examples//batch_size

#定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#创建一个简单神经网络
w1 = tf.Variable(tf.zeros([784,10]))
b1 = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,w1) + b1) #使用softmax作为激活函数

loss = tf.reduce_mean(tf.square(y-prediction)) #定义代价函数
train = tf.train.GradientDescentOptimizer(0.05).minimize(loss) #定义优化器

#计算准确率
correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range(50):
    for _ in range(batch_n):
        batch_xs,batch_ys = minist.train.next_batch(batch_size) #每次取50张图片
        sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
    acc = sess.run(accuracy,feed_dict={x:minist.test.images,y:minist.test.labels})
print "Iter"+str(step)+":"+str(acc)

sess.close()

结果和小结

上述神经网络的结构是最简单的网络,最终得到了正确率为0.9158 。大家可以从参数初始化、批次调节、迭代次数、激活函数、增加隐藏层这几种方法进行改进。
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_40504899/article/details/84985790