TensorFlow Neural Network Framework (Lesson 3, Part 3-2: Simple MNIST Classification for Handwritten Digit Recognition)

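import tensorflow as tf  # TensorFlow 1.x API (placeholders and sessions)
from tensorflow.examples.tutorials.mnist import input_data  # deprecated MNIST helper (see the warnings below)
# Download (if needed) and extract MNIST into ./MNIST_data; one_hot=True encodes each label as a 10-element vector
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)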
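# Size of each mini-batch
batch_size = 100
# Total number of batches per epoch
n_batch = mnist.train.num_examples // batch_size
# Define two placeholders
x = tf.placeholder(tf.float32, [None, 784])  # samples: 28x28 images flattened to 784 pixels
y = tf.placeholder(tf.float32, [None, 10])   # one-hot labels
# A simple network with no hidden layer; it can later be improved by adding hidden layers
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)
# Quadratic cost function (mean squared error)
loss = tf.reduce_mean(tf.square(y - prediction))
# Plain gradient descent with learning rate 0.2
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
# Initialize the variables
init = tf.global_variables_initializer()
# Boolean list: True where the predicted class (index of the largest output) matches the label
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = fraction of True values
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))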
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):  # 21 epochs, numbered 0 to 20
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # After each epoch, evaluate on the full test set
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter" + str(epoch) + ",Testing Accuracy " + str(acc))
WARNING:tensorflow:From <ipython-input-2-7a5fbd2e6494>:1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-images-idx3-ubyte.gz
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Iter0,Testing Accuracy 0.8331
Iter1,Testing Accuracy 0.8699
Iter2,Testing Accuracy 0.8813
Iter3,Testing Accuracy 0.8877
Iter4,Testing Accuracy 0.8935
Iter5,Testing Accuracy 0.8972
Iter6,Testing Accuracy 0.9004
Iter7,Testing Accuracy 0.9015
Iter8,Testing Accuracy 0.9038
Iter9,Testing Accuracy 0.9053
Iter10,Testing Accuracy 0.906
Iter11,Testing Accuracy 0.9074
Iter12,Testing Accuracy 0.9088
Iter13,Testing Accuracy 0.909
Iter14,Testing Accuracy 0.9099
Iter15,Testing Accuracy 0.9112
Iter16,Testing Accuracy 0.9117
Iter17,Testing Accuracy 0.9123
Iter18,Testing Accuracy 0.9128
Iter19,Testing Accuracy 0.9134
Iter20,Testing Accuracy 0.9135
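The graph built above is softmax regression trained on a quadratic cost. To make its arithmetic concrete, here is a rough NumPy sketch of the same computation: forward pass, loss, one gradient-descent update and accuracy. It is an illustration, not the author's code. The helper names (`softmax`, `forward`, `quadratic_loss`, `gradients`, `sgd_step`, `accuracy`) are hypothetical. The hand-derived gradient also stands in for the automatic differentiation that `GradientDescentOptimizer(...).minimize(loss)` performs.

```python
import numpy as np

# Hypothetical helpers mirroring the TensorFlow graph; a sketch, not the original code.

def softmax(z):
    # Subtract the row max for numerical stability (the result is unchanged)
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x, W, b):
    # Same as tf.nn.softmax(tf.matmul(x, W) + b)
    return softmax(x @ W + b)

def quadratic_loss(y, p):
    # Same as tf.reduce_mean(tf.square(y - p)): mean over all N*10 entries
    return np.mean((y - p) ** 2)

def gradients(x, y, W, b):
    # Hand-derived gradient (assumption: replaces TensorFlow's autodiff)
    p = forward(x, W, b)
    n, c = y.shape
    g = 2.0 * (p - y) / (n * c)                          # dLoss/dp
    dz = p * (g - np.sum(g * p, axis=1, keepdims=True))  # back through softmax
    return x.T @ dz, dz.sum(axis=0)                      # dLoss/dW, dLoss/db

def sgd_step(x, y, W, b, lr=0.2):
    # One plain gradient-descent update, like GradientDescentOptimizer(0.2)
    dW, db = gradients(x, y, W, b)
    return W - lr * dW, b - lr * db

def accuracy(y, p):
    # tf.equal(tf.argmax(y,1), tf.argmax(p,1)), cast to float, then mean
    return np.mean(np.argmax(y, axis=1) == np.argmax(p, axis=1))
```

With `W` and `b` initialized to zeros, as in the TensorFlow code, every softmax output is 0.1. The initial quadratic loss on any batch is therefore (0.9² + 9 × 0.1²) / 10 = 0.09.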
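The final test accuracy is 91.35%. The program still leaves plenty of room for optimization, for example: (1) add hidden layers; (2) train for more epochs.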
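The first suggestion, adding a hidden layer, mainly changes the forward pass. Below is a minimal NumPy sketch of that change. It assumes one hidden layer of 100 units with a tanh activation. Both are arbitrary choices for illustration, as are the hypothetical helper names `init_params` and `forward_hidden`.

```python
import numpy as np

def init_params(n_in=784, n_hidden=100, n_out=10, seed=0):
    # Assumed sizes and init scale; weights are random, not zero (see note below)
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0.0, 0.1, (n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, 0.1, (n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }

def softmax(z):
    # Row-wise softmax with max subtraction for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward_hidden(x, params):
    # Hidden layer: tanh(x W1 + b1); tanh is an assumed activation choice
    h = np.tanh(x @ params["W1"] + params["b1"])
    # Output layer: softmax(h W2 + b2), the same form as the original model
    return softmax(h @ params["W2"] + params["b2"])
```

The hidden-layer weights are drawn at random rather than set to zeros as in the single-layer model. If every weight started at zero, all hidden units would compute the same value and receive the same gradient, so they could never learn different features.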


Reposted from blog.csdn.net/u011473714/article/details/80805459