tensorflow学习笔记——简单MNIST数据集分类

首先在这个网站下载MNIST数据集。下载后的数据如图所示:

代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据
mnist = input_data.read_data_sets("E:/mnist",one_hot=True)
#每个批次大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples//batch_size #整除

x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

网络部分:网络搭建比较灵活,不同的层数、节点数等都会对训练出的模型产生影响,这里只有一层隐藏层,节点为10。

#输入层到隐藏层
W = tf.Variable(tf.zeros([784,10]))#这里W初始化为0,可以更快收敛
b = tf.Variable(tf.zeros([10]))
Wx_plus_b_L1 = tf.matmul(x,W) + b

#隐藏层到输出层
W2 = tf.Variable(tf.random_normal([10,10]))#隐藏层不能初始化为0
b2 = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(Wx_plus_b_L1,W2)+b2)

训练并测试:

#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#梯度下降法训练
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)#学习率为0.2

init = tf.global_variables_initializer()

#求准确率
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#equl判断是否相等,argmax返回张量最大值的索引
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #将布尔型转换为浮点型

with tf.Session() as sess:
    sess.run(init)
    #迭代训练20次
    for epoch in range(20):
        for batch in range(n_batch):
            #训练集数据与标签
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + " Accuracy" + str(acc))

训练结果如下:

Iter 0 Accuracy0.8866
Iter 1 Accuracy0.9057
Iter 2 Accuracy0.91
Iter 3 Accuracy0.9127
Iter 4 Accuracy0.9156
Iter 5 Accuracy0.9187
Iter 6 Accuracy0.9186
Iter 7 Accuracy0.9201
Iter 8 Accuracy0.9227
Iter 9 Accuracy0.9225
Iter 10 Accuracy0.9243
Iter 11 Accuracy0.9244
Iter 12 Accuracy0.9254
Iter 13 Accuracy0.925
Iter 14 Accuracy0.9249
Iter 15 Accuracy0.9261
Iter 16 Accuracy0.9274
Iter 17 Accuracy0.9259
Iter 18 Accuracy0.9266
Iter 19 Accuracy0.9261

 之前写的准确率比较低,后面又写了一个版本,准确率有98%。大家想看也可以点这里

发布了33 篇原创文章 · 获赞 148 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_40692109/article/details/104094160