Mnist1--Simple

# http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html     # JiKeXueYuan (Geek Academy) tutorial
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)       # The download is slow and can fail; you can download the files in advance and put them in the MNIST_data/ folder
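# Note (not part of the original post): this script targets TensorFlow 1.x. The
# tensorflow.examples.tutorials.mnist module was removed in TensorFlow 2.x, so under
# TF2 you would need the tf.compat.v1 API plus another loader such as
# tf.keras.datasets.mnist.load_data().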

import tensorflow as tf
x = tf.placeholder(tf.float32, [None, 784])    # each image is a flattened 28x28 = 784 vector
W = tf.Variable(tf.zeros([784,10]))            # weights: 784 inputs -> 10 classes
b = tf.Variable(tf.zeros([10]))                # bias, shape [10]; broadcast over the batch

y  = tf.nn.softmax(tf.matmul(x,W) + b)      # softmax normalizes the scores into probabilities; y is the prediction
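# For reference: softmax turns the raw scores z = xW + b into a probability
# distribution, softmax(z)_i = exp(z_i) / sum_j exp(z_j), so each row of y sums to 1.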
y_ = tf.placeholder(tf.float32, [None,10])  # y_: the ground-truth one-hot labels

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))    # cross-entropy loss, summed over the batch
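# A more numerically stable alternative (an assumption, not part of the original code)
# feeds the raw logits to TensorFlow's built-in op, which avoids log(0) when a softmax
# output underflows to zero:
#   logits = tf.matmul(x, W) + b
#   cross_entropy = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))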

optimization = tf.train.GradientDescentOptimizer(0.01)   # gradient descent with learning rate 0.01 updates the weights
train_step   = optimization.minimize(cross_entropy)
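# Each training step applies W <- W - 0.01 * dL/dW and b <- b - 0.01 * dL/db,
# where L is the cross-entropy defined above.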

# Model evaluation
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))      # argmax returns the index of the largest value along axis 1
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))     # cast the booleans to floats, then take the mean
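# Illustrative example (hypothetical values): if a row of y peaks at index 3 and the
# matching row of y_ is one-hot at index 3, tf.equal yields True, which is cast to 1.0;
# averaging these 0/1 values over all test images gives the accuracy.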

# Initialize the variables
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

for i in range(1000):
    # batch_xs: (100, 784), batch_ys: (100, 10), <class 'numpy.ndarray'>
    # mnist.test.images: (10000, 784), mnist.test.labels: (10000, 10), <class 'numpy.ndarray'>
    # batch_x, batch_y = mnist.test.next_batch(100)    # batch_x: (100, 784), batch_y: (100, 10)
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print(acc)
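
Evaluating on the full 10,000-image test set inside every training step is slow; a minimal variation (not in the original) evaluates only once, after training finishes:

final_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print("final test accuracy:", final_accuracy)   # the original tutorial reports roughly 91% for this model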




Reprinted from blog.csdn.net/qq_34638161/article/details/81037650