模型的保存和读取

# Keep at most the 3 most recent checkpoints (per the tf.train.Saver docs,
# older checkpoint files are deleted as newer ones are written).
saver = tf.train.Saver(max_to_keep=3)

with tf.Session() as sess:
    # Write the session's variables to disk. save() returns the checkpoint
    # path prefix it used; keep it so restore() reads exactly that checkpoint.
    saver_path = saver.save(sess, "save/model.ckpt")
    ...
    # Reload the variables from the checkpoint written above. Reusing the
    # returned prefix (instead of repeating the literal) keeps the save and
    # restore paths from drifting apart.
    saver.restore(sess, saver_path)
    # v1 is presumably a tf.Variable defined in code not shown here.
    print(sess.run(v1))
# Checkpoint path prefix shared by the save and restore calls below. With
# global_step=epoch, Saver.save writes "<prefix>-<epoch>" (e.g.
# "save/nets/cnn_mnist_basic.ckpt-10"), so restore must use the same form.
ckpt_prefix = "save/nets/cnn_mnist_basic.ckpt"

# NOTE(review): assumes `sess` is an open tf.Session created in elided code
# and that the elided training loop runs epoch over range(training_epochs),
# matching the `training_epochs - 1` used as the final epoch below.
if do_train == 1:
    ...
    # Save Net: every `save_step` epochs, and always on the final epoch so
    # the checkpoint restored by the do_train == 0 branch is guaranteed to
    # exist even when (training_epochs - 1) is not a multiple of save_step.
    if epoch % save_step == 0 or epoch == training_epochs - 1:
        saver.save(sess, ckpt_prefix, global_step=epoch)

print("OPTIMIZATION FINISHED")

if do_train == 0:
    # Restore the weights saved at the final training epoch. The "-" must
    # match the "<prefix>-<epoch>" name produced by saver.save above.
    epoch = training_epochs - 1
    saver.restore(sess, "%s-%d" % (ckpt_prefix, epoch))

    # keepratio=1 feeds a keep ratio of 1 at evaluation time (presumably the
    # dropout keep probability; confirm against the graph definition).
    test_acc = sess.run(accr, feed_dict={x: testing, y: testlabel, keepratio: 1})
    print("TEST ACCURACY: %.3f" % test_acc)

猜你喜欢

转载自blog.csdn.net/Miracle_520/article/details/85284784