TF0007、神经网络分类MNIST数据集

minst数据集下载地址

链接:https://pan.baidu.com/s/1ka0L6MHfeFWiqGOeJjm3Tw 
提取码:nv0g 

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import warnings
warnings.filterwarnings('ignore')
# 载入数据集,我已经把数据集下载到本地
mnist = input_data.read_data_sets(train_dir=r'C:\Users\zx\Desktop\MLCodes\MyAIRoad\MNIST_data',one_hot=True)
Extracting C:\Users\zx\Desktop\MLCodes\MyAIRoad\MNIST_data\train-images-idx3-ubyte.gz
Extracting C:\Users\zx\Desktop\MLCodes\MyAIRoad\MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:\Users\zx\Desktop\MLCodes\MyAIRoad\MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:\Users\zx\Desktop\MLCodes\MyAIRoad\MNIST_data\t10k-labels-idx1-ubyte.gz
# 批次大小,即每次放多少条数据到网络进行训练
batch_size = 64
# 计算训练集里一共有多少个批次
n_batch = mnist.train.num_examples // batch_size

# 定义两个placeholder
#用来输入训练数据
x = tf.placeholder(tf.float32,[None,784])
#用来输入训练标签
y = tf.placeholder(tf.float32,[None,10])

# 创建的神经网络:784-10-10 输入层784,隐藏层10,输出层10
#1.输入层到隐藏层
W1 = tf.Variable(tf.truncated_normal([784,10]
                                    , stddev=0.1
                                   ))
b1 = tf.Variable(tf.zeros([10])+0.1)
#激活函数使用softmax
l1 = tf.nn.softmax(tf.matmul(x,W1)+b1)

#2.隐藏层到输出层
W2 = tf.Variable(tf.truncated_normal([10,10]
                                     ,stddev=0.1))

b2 = tf.Variable(tf.zeros([10])+0.1)
prediction = tf.nn.softmax(tf.matmul(l1,W2)+b2)


# 二次代价函数
loss = tf.losses.mean_squared_error(y, prediction)
#设置学习率
lr = 0.3
# 使用梯度下降法
train = tf.train.GradientDescentOptimizer(lr).minimize(loss)

# 结果存放在一个布尔型列表中 tf.argmax(y,1) 取y第1个维度中最大值
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# 求准确率,cast 把布尔值转换成0 1
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())
    # 周期epoch:所有数据训练一次,就是一个周期
    for epoch in range(200):
        for batch in range(n_batch):
            # 获取一个批次的数据和标签
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
        # 每训练一个周期做一次测试
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        #每训练20次打印一次
        if epoch%20==0:
            print("epoch " + str(epoch+1) + ",Testing Accuracy " + str(acc))
epoch 1,Testing Accuracy 0.1867
epoch 21,Testing Accuracy 0.4955
epoch 41,Testing Accuracy 0.624
epoch 61,Testing Accuracy 0.7113
epoch 81,Testing Accuracy 0.7516
epoch 101,Testing Accuracy 0.781
epoch 121,Testing Accuracy 0.8003
epoch 141,Testing Accuracy 0.8072
epoch 161,Testing Accuracy 0.8139
epoch 181,Testing Accuracy 0.8182
发布了23 篇原创文章 · 获赞 1 · 访问量 3350

猜你喜欢

转载自blog.csdn.net/ABCDABCD321123/article/details/104577699