TensorFlow: saver_save

Code:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


# Load the MNIST dataset (downloaded to MNIST_data/ if not already present)
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

Output:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
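Because of one_hot=True, each label is stored as a 10-element vector with a 1 at the index of the digit and 0 everywhere else, which is why the label placeholder below has shape [None,10]. A minimal pure-Python sketch of that encoding (the helper name `one_hot` is hypothetical, not part of the TensorFlow API):

```python
def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```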

Code:

# 100 images per batch
batch_size = 100
# Number of batches per epoch (the standard split has 55,000 training images, so 55000 // 100 = 550)
n_batch = mnist.train.num_examples // batch_size

# Define two placeholders: flattened 28x28 images (784 pixels) and one-hot labels
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

# Create a simple neural network with 784 neurons in the input layer and 10 neurons in the output layer
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x,W)+b
prediction = tf.nn.softmax(logits)

# Quadratic cost function (alternative)
# loss = tf.reduce_mean(tf.square(y-prediction))
# Cross-entropy cost; this op applies softmax internally, so pass the raw logits rather than prediction
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=logits))
# Use gradient descent with a learning rate of 0.2
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#Initialize variables
init = tf.global_variables_initializer()

# Compare predicted and true classes; the result is a boolean tensor
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))# argmax(..., 1) returns the index of the largest value in each row
# Compute the accuracy as the fraction of True values
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))


# Define a Saver; with no arguments it saves all variables
saver = tf.train.Saver()


with tf.Session() as sess:
    sess.run(init)
    for epoch in range(11):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    # Save the trained variables to net/my_net.ckpt (the net/ directory must already exist)
    saver.save(sess,'net/my_net.ckpt')

Output:

Iter 0,Testing Accuracy 0.8237
Iter 1,Testing Accuracy 0.8937
Iter 2,Testing Accuracy 0.9018
Iter 3,Testing Accuracy 0.906
Iter 4,Testing Accuracy 0.9089
Iter 5,Testing Accuracy 0.9111
Iter 6,Testing Accuracy 0.9118
Iter 7,Testing Accuracy 0.9128
Iter 8,Testing Accuracy 0.9147
Iter 9,Testing Accuracy 0.916
Iter 10,Testing Accuracy 0.9168
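The loss and accuracy ops in the script can be mirrored in plain Python for a handful of samples, which makes it easier to see what they compute. In this sketch the helpers `softmax`, `cross_entropy_from_logits`, `argmax` and `accuracy` are illustrative names, not TensorFlow functions, and they work on Python lists rather than batched tensors. The example also shows why `tf.nn.softmax_cross_entropy_with_logits` should receive raw logits: giving it an already-softmaxed vector applies softmax twice, and because the second softmax only sees values between 0 and 1, its output stays close to uniform, so even a confident, correct prediction keeps a large loss.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; this does not change the result
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_logits(labels, logits):
    # Per-sample quantity behind softmax_cross_entropy_with_logits:
    # -sum(labels * log(softmax(logits)))
    probs = softmax(logits)
    return -sum(l * math.log(p) for l, p in zip(labels, probs))

def argmax(values):
    # Index of the largest value, like tf.argmax along one row
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(labels_batch, predictions_batch):
    # Fraction of samples whose predicted class matches the true class
    correct = [argmax(l) == argmax(p) for l, p in zip(labels_batch, predictions_batch)]
    return sum(correct) / len(correct)

label = [0.0, 1.0, 0.0]
logits = [1.0, 4.0, 0.5]

correct_loss = cross_entropy_from_logits(label, logits)
# Passing softmax output where logits are expected applies softmax a second time
double_softmax_loss = cross_entropy_from_logits(label, softmax(logits))
print(correct_loss, double_softmax_loss)

print(accuracy([[0, 1, 0], [1, 0, 0]], [[0.1, 0.7, 0.2], [0.2, 0.5, 0.3]]))  # 0.5
```

Both calls rank the true class highest, yet the double-softmax loss comes out several times larger than the correctly computed one, which weakens the training signal.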

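Conceptually, `saver.save` records the current value of each variable under its name, and `saver.restore` later loads those values back into same-named variables of a rebuilt graph. The plain-Python round trip below is only an analogy under that assumption: the helpers `save_variables` and `restore_variables` are hypothetical, and TensorFlow's real checkpoint is a binary format (an `.index` file plus `.data-*` shards, with the graph structure in a `.meta` file), not JSON.

```python
import json
import os
import tempfile

def save_variables(variables, path):
    # Write every variable's current value, keyed by its name
    with open(path, "w") as f:
        json.dump(variables, f)

def restore_variables(path):
    # Read the name -> value mapping back from disk
    with open(path) as f:
        return json.load(f)

# Hypothetical "trained" values standing in for W and b
trained = {"W": [[0.1, -0.2], [0.3, 0.4]], "b": [0.05, -0.05]}
path = os.path.join(tempfile.mkdtemp(), "my_net.json")
save_variables(trained, path)
restored = restore_variables(path)
print(restored == trained)  # True
```

In TensorFlow 1.x the matching call is `saver.restore(sess, 'net/my_net.ckpt')`, run after rebuilding the same variables; restored variables do not need to be initialized first.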