在TensorFlow上利用mnist数据集训练神经网络模型

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 在tensorflow的log日志等级如下:
# - 0:显示所有日志(默认等级)
# - 1:显示info、warning和error日志
# - 2:显示warning和error信息
# - 3:显示error日志信息
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'

# MNIST_data是个.zip压缩文件 保存在跟py文件同样路径下
mnist_set = input_data.read_data_sets('./MNIST_data', one_hot=True)


batch_size = 32  # batch_size 批大小,根据自己的gup内存大小设置
batch_num = mnist_set.train.num_examples // batch_size  # 每批样本数量 // 在python语法中表示整除 向下取整

#  定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y_data = tf.placeholder(tf.float32,[None,10])

# build a simple network
weight = tf.Variable(initial_value=tf.truncated_normal([784,10]))
bias = tf.Variable(initial_value=tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,weight) + bias)

# train stage
# loss
loss = tf.reduce_mean(tf.square(y_data - prediction))
# optimize
train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#  test stage
# tf.argmax 返回一维中最大值的位置
# 比如
# a = [[2,3,2,0]
#     [3,5,2,9]
#     [2,7,6,2]]
# tf.argmax(a,0) 0 表示按列处理
#  结果[1,2,2,1]
# tf.argmax(a,1) 1 表示按行处理
#  结果[1,3,1]
correct_prediction = tf.equal(tf.argmax(prediction,1),tf.argmax(y_data,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(1,51):
        for batch in range(batch_num):
            batch_x, batch_y = mnist_set.train.next_batch(batch_size)
            sess.run(train_op,feed_dict={x:batch_x,y_data:batch_y})
        acc = sess.run(accuracy,feed_dict={x:mnist_set.test.images,y_data:mnist_set.test.labels})
        print('第%d代:测试正确率为:%s' % (epoch, acc))


结果如下

Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz
第1代:测试正确率为:0.2348
第2代:测试正确率为:0.3211
第3代:测试正确率为:0.4589
第4代:测试正确率为:0.5411
第5代:测试正确率为:0.583
第6代:测试正确率为:0.6128
第7代:测试正确率为:0.6417
第8代:测试正确率为:0.6665
第9代:测试正确率为:0.6897
第10代:测试正确率为:0.7075
第11代:测试正确率为:0.7209
第12代:测试正确率为:0.7311
第13代:测试正确率为:0.7406
第14代:测试正确率为:0.7484
第15代:测试正确率为:0.7543
第16代:测试正确率为:0.7597
第17代:测试正确率为:0.7635
第18代:测试正确率为:0.7679
第19代:测试正确率为:0.7717
第20代:测试正确率为:0.7748
第21代:测试正确率为:0.7787
第22代:测试正确率为:0.7807
第23代:测试正确率为:0.7855
第24代:测试正确率为:0.7919
第25代:测试正确率为:0.8064
第26代:测试正确率为:0.8317
第27代:测试正确率为:0.8476
第28代:测试正确率为:0.8571
第29代:测试正确率为:0.8641
第30代:测试正确率为:0.8676
第31代:测试正确率为:0.8712
第32代:测试正确率为:0.8734
第33代:测试正确率为:0.8757
第34代:测试正确率为:0.8775
第35代:测试正确率为:0.8799
第36代:测试正确率为:0.8819
第37代:测试正确率为:0.8828
第38代:测试正确率为:0.8842
第39代:测试正确率为:0.8861
第40代:测试正确率为:0.8866
第41代:测试正确率为:0.8883
第42代:测试正确率为:0.889
第43代:测试正确率为:0.8893
第44代:测试正确率为:0.8899
第45代:测试正确率为:0.8918
第46代:测试正确率为:0.8919
第47代:测试正确率为:0.8919
第48代:测试正确率为:0.8934
第49代:测试正确率为:0.8934
第50代:测试正确率为:0.8943

猜你喜欢

转载自blog.csdn.net/ruguowoshiyu/article/details/81974312