MNIST数据集分类简单版本

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.compat.v1.disable_eager_execution()

# Load the dataset. one_hot=True encodes labels as one-hot vectors,
# e.g. 1 -> 0100000000, 5 -> 0000010000.
mnist = input_data.read_data_sets("mnist_data", one_hot=True)

# Batch size: 64 samples per training step (16/32/64 are common choices).
batch_size = 64
# Number of batches in one pass over the training set.
n_batch = mnist.train.num_examples // batch_size

# Build a single-layer 784-10 network.
# Placeholders for flattened images and one-hot labels.
x = tf.compat.v1.placeholder(tf.float32, [None, 784])
y = tf.compat.v1.placeholder(tf.float32, [None, 10])

# Weights drawn from a standard normal distribution; biases start at zero.
W = tf.Variable(tf.random.normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Softmax output layer.
prediction = tf.nn.softmax(tf.matmul(x, W) + b)

# Quadratic (mean-squared-error) cost.
loss = tf.compat.v1.losses.mean_squared_error(y, prediction)
# Plain gradient descent, learning rate 0.5.
train = tf.compat.v1.train.GradientDescentOptimizer(0.5).minimize(loss)

# Boolean vector: does the predicted class match the label?
# tf.argmax(t, 1) returns the index of the max entry in each row
# (axis 0 would give the max index per column).
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = mean of the booleans cast to float.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # One epoch = one full pass over the training data.
    for epoch in range(21):
        for _ in range(n_batch):
            # Fetch one batch of images and labels.
            xs, ys = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={x: xs, y: ys})
        # Evaluate on the full test set after each epoch.
        acc = sess.run(accuracy,
                       feed_dict={x: mnist.test.images,
                                  y: mnist.test.labels})
        print("Iter" + str(epoch) + ".Testing Accuracy" + str(acc))

运行结果
Extracting mnist_data\train-images-idx3-ubyte.gz
Extracting mnist_data\train-labels-idx1-ubyte.gz
Extracting mnist_data\t10k-images-idx3-ubyte.gz
Extracting mnist_data\t10k-labels-idx1-ubyte.gz
Iter0.Testing Accuracy0.2995
Iter1.Testing Accuracy0.4193
Iter2.Testing Accuracy0.5015
Iter3.Testing Accuracy0.5526
Iter4.Testing Accuracy0.5839
Iter5.Testing Accuracy0.6046
Iter6.Testing Accuracy0.6198
Iter7.Testing Accuracy0.6359
Iter8.Testing Accuracy0.653
Iter9.Testing Accuracy0.678
Iter10.Testing Accuracy0.6983
Iter11.Testing Accuracy0.7145
Iter12.Testing Accuracy0.725
Iter13.Testing Accuracy0.7313
Iter14.Testing Accuracy0.7397
Iter15.Testing Accuracy0.7452
Iter16.Testing Accuracy0.7503
Iter17.Testing Accuracy0.7551
Iter18.Testing Accuracy0.7605
Iter19.Testing Accuracy0.7651
Iter20.Testing Accuracy0.7684

16行把正态分布改成 tf.truncated_normal([784,10], stddev=0.1) 学习效果明显有提升,初始化方式对训练有影响,一般采用truncated_normal这种方式,偏置一般设为0或者0.1,如果用上面的函数把标准差改为0.1同样学习效果提升比较快,所以,和初始化数据的方差也有关系。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.compat.v1.disable_eager_execution()

# Load the dataset. one_hot=True encodes labels as one-hot vectors,
# e.g. 1 -> 0100000000, 5 -> 0000010000.
mnist = input_data.read_data_sets("mnist_data", one_hot=True)

# Batch size: 64 samples per training step (16/32/64 are common choices).
batch_size = 64
# Number of batches in one pass over the training set.
n_batch = mnist.train.num_examples // batch_size

# Build a single-layer 784-10 network.
# Placeholders for flattened images and one-hot labels.
x = tf.compat.v1.placeholder(tf.float32, [None, 784])
y = tf.compat.v1.placeholder(tf.float32, [None, 10])

# Truncated-normal init with a small stddev converges much faster here
# than a standard normal; biases start at zero.
W = tf.Variable(tf.random.truncated_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]))

# Softmax output layer.
prediction = tf.nn.softmax(tf.matmul(x, W) + b)

# Quadratic (mean-squared-error) cost.
loss = tf.compat.v1.losses.mean_squared_error(y, prediction)
# Plain gradient descent, learning rate 0.5.
train = tf.compat.v1.train.GradientDescentOptimizer(0.5).minimize(loss)

# Boolean vector: does the predicted class match the label?
# tf.argmax(t, 1) returns the index of the max entry in each row
# (axis 0 would give the max index per column).
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = mean of the booleans cast to float.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # One epoch = one full pass over the training data.
    for epoch in range(21):
        for _ in range(n_batch):
            # Fetch one batch of images and labels.
            xs, ys = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={x: xs, y: ys})
        # Evaluate on the full test set after each epoch.
        acc = sess.run(accuracy,
                       feed_dict={x: mnist.test.images,
                                  y: mnist.test.labels})
        print("Iter" + str(epoch) + ".Testing Accuracy" + str(acc))

运行结果
Extracting mnist_data\train-images-idx3-ubyte.gz
Extracting mnist_data\train-labels-idx1-ubyte.gz
Extracting mnist_data\t10k-images-idx3-ubyte.gz
Extracting mnist_data\t10k-labels-idx1-ubyte.gz
Iter0.Testing Accuracy0.8815
Iter1.Testing Accuracy0.8982
Iter2.Testing Accuracy0.9051
Iter3.Testing Accuracy0.9101
Iter4.Testing Accuracy0.9132
Iter5.Testing Accuracy0.9149
Iter6.Testing Accuracy0.9166
Iter7.Testing Accuracy0.9185
Iter8.Testing Accuracy0.9193
Iter9.Testing Accuracy0.9193
Iter10.Testing Accuracy0.9207
Iter11.Testing Accuracy0.9213
Iter12.Testing Accuracy0.9223
Iter13.Testing Accuracy0.922
Iter14.Testing Accuracy0.9227
Iter15.Testing Accuracy0.9229
Iter16.Testing Accuracy0.9229
Iter17.Testing Accuracy0.9236
Iter18.Testing Accuracy0.9241
Iter19.Testing Accuracy0.9241
Iter20.Testing Accuracy0.9244

猜你喜欢

转载自blog.csdn.net/weixin_44823313/article/details/112510896
今日推荐