TensorFlow: A Simple MNIST Classification Example

Code:

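import tensorflow as tf
# Written for TensorFlow 1.x: this tutorials module was dropped in TensorFlow 2.x,
# where tf.placeholder and tf.Session (used below) survive only under tf.compat.v1
from tensorflow.examples.tutorials.mnist import input_data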


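# Load the dataset from the MNISt_data folder in the current directory (downloaded there if missing);
# one_hot=True encodes each label as a 10-element one-hot vector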
mnist = input_data.read_data_sets("MNISt_data", one_hot=True)

Output:

Extracting MNISt_data/train-images-idx3-ubyte.gz
Extracting MNISt_data/train-labels-idx1-ubyte.gz
Extracting MNISt_data/t10k-images-idx3-ubyte.gz
Extracting MNISt_data/t10k-labels-idx1-ubyte.gz
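With one_hot=True, each label arrives as a 10-element vector holding a single 1 at the digit's index, which is why the label placeholder y below has shape [None, 10]. The following minimal NumPy sketch shows that encoding. The one_hot helper is hypothetical and is not the function TensorFlow uses internally.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Hypothetical helper: encode integer labels as one-hot float rows."""
    labels = np.asarray(labels)
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    encoded[np.arange(labels.size), labels] = 1.0  # exactly one 1 per row
    return encoded

y = one_hot([7, 2, 1])
print(y.shape)               # → (3, 10)
print(y[0])                  # → [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
print(np.argmax(y, axis=1))  # → [7 2 1]
```

Taking the argmax of each row recovers the digit. The accuracy check later in the code relies on the same idea through tf.argmax(y, 1).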

Code:

# Size of each batch; a batch is fed in as a matrix with batch_size rows
batch_size = 100
# Number of batches per epoch (mnist.train holds 55,000 images, so 550 batches of 100)
n_batch = mnist.train.num_examples // batch_size


# Define two placeholders: x holds flattened 28 x 28 = 784-pixel images,
# y holds the one-hot labels for the 10 digit classes
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])


# Create a simple neural network:
# 784 input neurons, no hidden layer, 10 output neurons with softmax
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([1, 10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)


# Quadratic cost function (mean squared error)
loss = tf.reduce_mean(tf.square(y - prediction))


# Minimize the loss with gradient descent, learning rate 0.2
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()



# The results are stored in a boolean tensor: True where tf.argmax(y, 1) equals tf.argmax(prediction, 1), False otherwise
# tf.argmax(t, 1) returns the index of the largest value in each row of t (along axis 1)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# Compute the accuracy:
# tf.cast(correct_prediction, tf.float32) converts True/False to 1.0/0.0, so their mean is the fraction of correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


with tf.Session() as sess:
    sess.run(init)
    # 21 epochs in total
    for epoch in range(21):
        # n_batch batches per epoch
        for batch in range(n_batch):
            # Get the next batch of images and labels
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})

        # Test-set accuracy after each epoch
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
        print("Iter" + str(epoch) + ", Testing Accuracy" + str(acc))
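In plain NumPy, the prediction, loss, correct_prediction and accuracy nodes reduce to a few lines. The following sketch is for illustration only. The helper names are hypothetical, and the made-up toy logits use 3 classes instead of 10 so the numbers are easy to follow.

```python
import numpy as np

def softmax(logits):
    # Subtract each row's max for numerical stability; the result is unchanged.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def quadratic_loss(y, prediction):
    # Mean over every element, like tf.reduce_mean(tf.square(y - prediction)).
    return np.mean(np.square(y - prediction))

def accuracy(y, prediction):
    # True where the index of the largest value in each row matches.
    correct = np.argmax(y, axis=1) == np.argmax(prediction, axis=1)
    return correct.astype(np.float32).mean()

# Two samples, three classes (toy values, not MNIST data).
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.2]])
y = np.array([[1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0]])
prediction = softmax(logits)
print(prediction.sum(axis=1))       # each row sums to 1 (up to rounding)
print(accuracy(y, prediction))      # → 0.5
print(quadratic_loss(y, prediction))
```

The first sample is classified correctly (class 0) and the second is not (class 1 is predicted, class 2 is correct), so the accuracy is 0.5. Like tf.reduce_mean with no axis argument, np.mean averages over all elements, so the quadratic cost is divided by (batch size × number of classes).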

Output:

Iter0, Testing Accuracy0.8331
Iter1, Testing Accuracy0.8715
Iter2, Testing Accuracy0.8811
Iter3, Testing Accuracy0.8885
Iter4, Testing Accuracy0.8938
Iter5, Testing Accuracy0.8967
Iter6, Testing Accuracy0.9005
Iter7, Testing Accuracy0.9022
Iter8, Testing Accuracy0.9043
Iter9, Testing Accuracy0.9048
Iter10, Testing Accuracy0.9062
Iter11, Testing Accuracy0.907
Iter12, Testing Accuracy0.908
Iter13, Testing Accuracy0.9088
Iter14, Testing Accuracy0.9099
Iter15, Testing Accuracy0.9113
Iter16, Testing Accuracy0.911
Iter17, Testing Accuracy0.9124
Iter18, Testing Accuracy0.9131
Iter19, Testing Accuracy0.914
Iter20, Testing Accuracy0.9136
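The train_step op hides the gradient computation. The following NumPy sketch follows the same recipe: a softmax output layer with no hidden layer, the quadratic cost, plain mini-batch gradient descent, and 21 epochs of shuffled batches of 100. It makes several assumptions. Random synthetic data replaces MNIST (20 features, with labels produced by a hidden random linear map). The gradients are derived by hand instead of by TensorFlow's automatic differentiation. The learning rate is 1.0 rather than 0.2 because the synthetic features are scaled differently. The accuracies it prints are therefore not comparable with the MNIST results above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumption: synthetic stand-in for MNIST, 20 features and 10 classes,
# labelled by a hidden random linear map so the task is learnable.
n_features, n_classes = 20, 10
true_w = rng.normal(size=(n_features, n_classes))

def make_split(n):
    x = rng.normal(size=(n, n_features)).astype(np.float32)
    labels = np.argmax(x @ true_w, axis=1)
    y = np.eye(n_classes, dtype=np.float32)[labels]  # one-hot rows
    return x, y

train_x, train_y = make_split(2000)
test_x, test_y = make_split(500)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x, W, b):
    return softmax(x @ W + b)

def quadratic_loss(y, p):
    return np.mean(np.square(y - p))

def accuracy(y, p):
    return np.mean(np.argmax(y, axis=1) == np.argmax(p, axis=1))

def gradients(x, y, W, b):
    """Hand-derived gradients of mean((y - softmax(xW + b))**2) w.r.t. W and b."""
    p = forward(x, W, b)
    g = 2.0 * (p - y) / y.size                            # dL/dp
    dz = p * (g - np.sum(g * p, axis=1, keepdims=True))  # back through softmax
    return x.T @ dz, dz.sum(axis=0, keepdims=True)

# Zero initialization, as in the TensorFlow code.
W = np.zeros((n_features, n_classes), dtype=np.float32)
b = np.zeros((1, n_classes), dtype=np.float32)
learning_rate, batch_size = 1.0, 100
n_batch = len(train_x) // batch_size

initial_loss = quadratic_loss(train_y, forward(train_x, W, b))
for epoch in range(21):
    order = rng.permutation(len(train_x))  # reshuffle every epoch
    for i in range(n_batch):
        idx = order[i * batch_size:(i + 1) * batch_size]
        dW, db = gradients(train_x[idx], train_y[idx], W, b)
        W -= learning_rate * dW
        b -= learning_rate * db
    acc = accuracy(test_y, forward(test_x, W, b))
    print(f"Iter {epoch}, Testing Accuracy {acc:.4f}")

final_loss = quadratic_loss(train_y, forward(train_x, W, b))
```

With all weights at zero, every softmax output is 0.1. The starting loss is therefore (0.81 + 9 × 0.01) / 10 = 0.09 whatever the data.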

