TensorFlow 训练检查点：tf.train.Saver


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]

Iterate x <- log(x + 1) starting from x = 1, log x to TensorBoard at every
step, and save checkpoints with tf.train.Saver every SAVE_EVERY steps plus
once after the final step.
"""

import os

import tensorflow as tf

# Checkpoints are written as <CKPT_PREFIX>-<global_step>.* inside CKPT_DIR.
CKPT_DIR = './myvar-model'
CKPT_PREFIX = os.path.join(CKPT_DIR, 'myvar-model')
NUM_STEPS = 49   # update steps run after the initial step 0
SAVE_EVERY = 5   # checkpoint interval, in steps

g1 = tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        my_var = tf.Variable(1, dtype=tf.float32)
    with tf.name_scope("global_step"):
        # The step counter is bookkeeping, not a model parameter, so keep it
        # out of the trainable-variables collection.
        my_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    with tf.name_scope("update"):
        # x <- log(x + 1) * 1 (the factor 1 is a no-op scale).
        varop = tf.assign(my_var, tf.multiply(tf.log(tf.add(my_var, 1)), 1))
        stepop = tf.assign_add(my_step, 1)
        addop = tf.group([varop, stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar', my_var)
    with tf.name_scope("global_ops"):
        init = tf.global_variables_initializer()
        merged_summaries = tf.summary.merge_all()
    # Build the Saver together with the rest of the graph; with no var_list
    # it covers the variables defined above (my_var and my_step).
    saver = tf.train.Saver()

# Create the checkpoint directory up front rather than relying on the saver
# to do it.
if not os.path.isdir(CKPT_DIR):
    os.makedirs(CKPT_DIR)

with tf.Session(graph=g1) as sess:
    writer = tf.summary.FileWriter('sum_vars', sess.graph)
    try:
        sess.run(init)
        # Step 0: record the initial value before any update.
        step, var, summary = sess.run([my_step, my_var, merged_summaries])
        writer.add_summary(summary, global_step=step)
        print("%s %s" % (step, var))
        # Steps 1..NUM_STEPS.
        for i in range(1, NUM_STEPS + 1):
            sess.run(addop)
            step, var, summary = sess.run([my_step, my_var, merged_summaries])
            writer.add_summary(summary, global_step=step)
            print("%s %s" % (step, var))
            if i % SAVE_EVERY == 0:
                saver.save(sess, CKPT_PREFIX, global_step=i)
        # Make sure the final state is checkpointed, without saving the same
        # step twice when NUM_STEPS is itself a multiple of SAVE_EVERY.
        if NUM_STEPS % SAVE_EVERY != 0:
            saver.save(sess, CKPT_PREFIX, global_step=NUM_STEPS)
    finally:
        # Release the event file even if a step above raised.
        writer.flush()
        writer.close()

38 0.0512373
39 0.04996785
40 0.048759546
41 0.04760808
42 0.04650955
43 0.045460388
44 0.04445735
45 0.04349747
46 0.042578023
47 0.041696515
48 0.040850647
49 0.04003831

上面的代码使用 tf.train.Saver 将数据流图中的变量保存到二进制检查点文件。下面的代码在此基础上增加了恢复功能：如果已存在检查点文件，则先恢复变量，再从保存时的步数继续运行。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:[email protected]

Resumable version of the x <- log(x + 1) iteration. If a checkpoint exists
in CKPT_DIR it is restored and the loop continues from the saved step;
otherwise the loop starts at step 0. A checkpoint is written every
SAVE_EVERY steps while the step is <= LAST_SAVE_STEP, and the loop stops
before MAX_STEPS.
"""

import os

import tensorflow as tf

# Checkpoints are written as <CKPT_PREFIX>-<global_step>.* inside CKPT_DIR.
CKPT_DIR = './myvar-model'
CKPT_PREFIX = os.path.join(CKPT_DIR, 'myvar-model')
MAX_STEPS = 100       # loop runs for steps init_step .. MAX_STEPS - 1
SAVE_EVERY = 5        # checkpoint interval, in steps
LAST_SAVE_STEP = 50   # no checkpoints are written after this step

g1 = tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        my_var = tf.Variable(1, dtype=tf.float32)
    with tf.name_scope("global_step"):
        # Bookkeeping counter; kept out of the trainable-variables collection.
        my_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    with tf.name_scope("update"):
        # x <- log(x + 1) * 1 (the factor 1 is a no-op scale).
        varop = tf.assign(my_var, tf.multiply(tf.log(tf.add(my_var, 1)), 1))
        stepop = tf.assign_add(my_step, 1)
        addop = tf.group([varop, stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar', my_var)
    with tf.name_scope("global_ops"):
        init = tf.global_variables_initializer()
        merged_summaries = tf.summary.merge_all()
    # Build the Saver together with the rest of the graph; with no var_list
    # it covers my_var and my_step, so the step counter is checkpointed too.
    saver = tf.train.Saver()

# Create the checkpoint directory up front so the first save on a fresh
# working directory does not depend on the saver creating it.
if not os.path.isdir(CKPT_DIR):
    os.makedirs(CKPT_DIR)

with tf.Session(graph=g1) as sess:
    writer = tf.summary.FileWriter('sum_vars', sess.graph)
    try:
        sess.run(init)

        # If an earlier run left a checkpoint, restore it and continue.
        init_step = 0
        ckpt = tf.train.get_checkpoint_state(os.path.abspath(CKPT_DIR))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # my_step is restored from the checkpoint, so read the resume
            # step from it instead of parsing the "-<step>" file-name suffix.
            init_step = int(sess.run(my_step))
            print("读取检查点文件...")
        for i in range(init_step, MAX_STEPS):
            step, var, summary = sess.run([my_step, my_var, merged_summaries])
            writer.add_summary(summary, global_step=step)
            print("%s %s %s" % (step, var, init_step))
            # Save before the update, so checkpoint i holds the state at step i.
            if i % SAVE_EVERY == 0 and i <= LAST_SAVE_STEP:
                print("保存检查点文件")
                saver.save(sess, CKPT_PREFIX, global_step=i)
            sess.run(addop)
    finally:
        # Release the event file even if a step above raised.
        writer.flush()
        writer.close()

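上面的代码第一次运行时（此时 myvar-model 目录中还没有检查点文件），会在 step 为 0、5、10……50 时保存检查点。从第二次运行开始，程序会先读取最新的检查点文件（step=50）并恢复变量，循环从 step=50 继续。由于只在 step≤50 时保存检查点，之后每次运行都会从 step=50 重新开始。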

第二次运行时的输出如下：

读取检查点文件...
50 0.03925755 50
保存检查点文件
51 0.038506564 50
52 0.037783686 50
53 0.03708737 50
54 0.036416177 50
55 0.035768777 50
56 0.03514393 50
...
...
...
93 0.021334965 50
94 0.02111056 50
95 0.02089082 50
96 0.0206756 50
97 0.020464761 50
98 0.020258171 50
99 0.020055704 50

猜你喜欢

转载自blog.51cto.com/13959448/2326699