python学习笔记3

# Softmax-regression classifier for MNIST using the TensorFlow 1.x graph API.
# (Author's note: ran successfully.)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset from the "MNIST_data" directory; labels are one-hot encoded.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Hyper-parameters.
batch_size = 100       # samples per mini-batch
LEARNING_RATE = 0.2    # gradient-descent step size
N_EPOCHS = 21          # number of training epochs
# Loss selection:
#   False -> quadratic cost (mean squared error on the softmax probabilities),
#   True  -> softmax cross-entropy, generally the better fit for classification.
USE_CROSS_ENTROPY = False

# Mini-batches per epoch (floor division: a trailing partial batch is not counted).
n_batch = mnist.train.num_examples // batch_size

# Placeholders: images flattened to 784 features each, and 10-class one-hot labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# Single-layer network: logits = xW + b; softmax turns them into class probabilities.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
prediction = tf.nn.softmax(logits)

if USE_CROSS_ENTROPY:
    # This op applies softmax internally, so it must receive the raw logits;
    # passing `prediction` here would apply softmax twice.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
else:
    # Quadratic cost: mean squared error between labels and predicted probabilities.
    loss = tf.reduce_mean(tf.square(y - prediction))

# Plain gradient descent on the selected loss.
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

# Op that initialises all variables.
initial = tf.global_variables_initializer()

# Boolean vector: True where the predicted class matches the label.
# argmax(..., 1) returns the index of the largest entry in each row.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# Accuracy: cast True -> 1.0 and False -> 0.0, then take the mean.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(initial)
    for epoch in range(N_EPOCHS):
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})

        # Evaluate on the full test set after every epoch.
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter" + str(epoch) + ",Test Accuracy" + str(acc))
# 使用二次代价函数时的运行结果:
Iter0,Test Accuracy0.8304
Iter1,Test Accuracy0.8704
Iter2,Test Accuracy0.8813
Iter3,Test Accuracy0.8883
Iter4,Test Accuracy0.895
Iter5,Test Accuracy0.8968
Iter6,Test Accuracy0.8992
Iter7,Test Accuracy0.9022
Iter8,Test Accuracy0.9037
Iter9,Test Accuracy0.9054
Iter10,Test Accuracy0.9068
Iter11,Test Accuracy0.9071
Iter12,Test Accuracy0.9074
Iter13,Test Accuracy0.909
Iter14,Test Accuracy0.9093
Iter15,Test Accuracy0.9107
Iter16,Test Accuracy0.9122
Iter17,Test Accuracy0.9128
Iter18,Test Accuracy0.9127
Iter19,Test Accuracy0.9143
Iter20,Test Accuracy0.9136
# 使用交叉熵代价函数时的运行结果（测试准确率比二次代价函数略高）:
Iter0,Test Accuracy0.8242
Iter1,Test Accuracy0.8831
Iter2,Test Accuracy0.8993
Iter3,Test Accuracy0.9048
Iter4,Test Accuracy0.9088
Iter5,Test Accuracy0.9093
Iter6,Test Accuracy0.9122
Iter7,Test Accuracy0.9131
Iter8,Test Accuracy0.9146
Iter9,Test Accuracy0.9164
Iter10,Test Accuracy0.9172
Iter11,Test Accuracy0.918
Iter12,Test Accuracy0.9201
Iter13,Test Accuracy0.9192
Iter14,Test Accuracy0.9196
Iter15,Test Accuracy0.921
Iter16,Test Accuracy0.9211
Iter17,Test Accuracy0.92
Iter18,Test Accuracy0.9209
Iter19,Test Accuracy0.9214
Iter20,Test Accuracy0.922

猜你喜欢

转载自blog.csdn.net/qq_39683287/article/details/80469986