零基础学习tensorflow----模型的保存与载入(一)

废活不多说,直接上代码,代码是博主一个一个敲得,每一行都加了注释。
如果你想真正学透一门学问,必须要求亲手实践它,正所谓,好记性不如烂笔头嘛。
光看视频模型是没用的,敲代码才是王道。
下面放一个最简单的tensorflow里的模型载入和保存。
ps.敲的代码最好都运行一遍

import tensorflow as tf 
from tensorflow.examples.tutprials.mnist import input_data

#载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
#定义批次的大小--64个
batch_size = 64
#计算有多少个批次
n_batch = mnist.train.num_example//batch_size

#定义两个placeholder
#给模型输入起名为x-input
x = tf.placeholder(tf.float32,[None,784],name='x-input')
#给模型输入起名为y-input
y = tf.placeholder(tf.float32,[None,10],name='y-input')

#建立一个简单的神经网络,输入层有784个神经元,输出层10个神经元
W = tf.Variable(tf.truncated_normal([784,10],stddev=0.1))
b = tf.Variable(tf.zero([10])+0.1)
#给模型的输出起名为output
prediction = tf.nn.softmax(tf.matmul(x,W)+b,name='output')
#交叉熵代价函数
loss = tf.lossrd.softmax_cross_entropy(y,prediction)
#使用Adam优化器,给优化器operation起名为train
train_step = tf.train.AdamOptimizer(0.001).minimize(loss,name='train')

#初始化变量
init = tf.global_variable_initializer()

#求准确率
#tf.argmax(y,1)中的1表示取y中第一个维度中的最大值所在的位置
#tf.equal表示比较两个值是否相等,相等返回ture,不相等返回False
#最后correct——prediction是一个布尔型的列表
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#tf.cast表示数据格式转换,把布尔类型转化为float类型,True变成1.0,False变成0.0
#tf.reduce_mean求平均值
#最后accuracy为准确率
#给准确率tensor起名为accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32),name='accuracy')


#定义Saver用于保存模型
saver = tf.train.Session() 
with tf.Session() as sess:
	#变量初始化
	sess.run(init)
	#运行11个周期
	for batch in range(n_batch):
		#获取一个批次的数据和标签
		batch_xs,batch_ys = mnist.train.next_batch(batch_size)
		#喂到模型中做训练
		sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys})
	#每个周期计算一次测试集准确率
	acc=sess.run(accuracy,feed_dic={x:mnist.test.images,y:mnist.test.labels})
	#打印信息
	print("Iter "+str(epoch)+",Testing accuracy"+str(acc))
	#保存模型
	saver.save(sess,'models/my_model.ckpt')
发布了76 篇原创文章 · 获赞 14 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_36444039/article/details/102238292
今日推荐