Tensorflow官方文档学习理解 (四)-深入MNIST

Tensorflow基于一个高效的C++模块进行运算,与这个模块的连接叫做session。一般而言,使用Tensorflow程序的流程是先创建一个图,然后在session中加载它。这里,我们使用更加方便的InteractiveSession类。通过它,你可以更加灵活地构建你的代码。他能让你在运行图的时候,插入一些构建计算图的操作。这能给使用交互文本shell如IPython带来便利。如果没有InteractiveSession的话,你需要在开始session和加载图之前,构建整个计算图。

sess = tf.InteractiveSession()

和之前一样,我们需要先来创建计算图的输入图片和输出类别。

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

接下来再定义权重和偏置:

w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

Variable需要在session之前初始化,才能在session中使用,初始化需要初始值(本例当中全为0),并传入赋值给Variable,这个操作可以一次性完成。

sess.run(tf.global_variables_initializer())

之后再引入softmax激活函数:

y = tf.nn.softmax(tf.matmul(x, w) + b)

使用交叉熵最小化损失函数:

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

使用梯度下降算法对其进行训练:

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

接下来对其进行迭代训练:

for i in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x:batch[0], y_:batch[1]})

之后验证模型准确率:

correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels}))

总代码如下:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.global_variables_initializer())
y = tf.nn.softmax(tf.matmul(x, w) + b)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
for i in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x:batch[0], y_:batch[1]})
correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels}))

 

 

 

 

 

 

 

 

 

猜你喜欢

转载自blog.csdn.net/weixin_39059031/article/details/82850453
今日推荐