tensorflow开发基本步骤

Tensorflow开发的基本步骤:

  • 定义Tensorflow输入节点
  1. 通过占位符定义:
    X = tf.placeholder("float")

    2.通过字典类型定义:

inputdict = {
    'x': tf.placeholder("float"),
    'y': tf.placeholder("float")
}

  3. 直接定义输入节点:

train_x = np.float32(np.linspace(-1,1,100))
  • 定义“学习参数”的变量
  • 定义“运算”
  • 优化函数,优化目标
  • 初始化所有变量
  • 迭代更新参数到最优解
  • 测试模型
  • 使用模型

2、模型保存与载入

  • 模型保存:
saver = tf.train.Saver()  #生成saver
saverdir = "log/"
with tf.Session() as sess:
    sess.run(init)
    print("Finished")
    saver.save(sess,saverdir+"linermodel.cpkt")
  • 模型载入:
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    saver.restore(sess2,saverdir+"linermodel.cpkt")
    print("x=0.2,z=",sess2.run(z,feed_dict={X:0.2}))

检查点(Checkpoint):Tensorflow训练模型时难免会出现中断的情况,希望能够将辛苦得到的中间参数保留下来,在训练中保存模型,习惯上称之为保存检查点。

 saver = tf.train.Saver(max_to_keep=1)  #生成saver
 saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))

猜你喜欢

转载自www.cnblogs.com/wyx501/p/10541524.html