[tf]当summary 遇到 placeholder

程序里面一般这么写

tf.summary.scalar('dis_loss',loss)
...
tf.summary.scalar('gene_loss',loss)

但是我写了一个RL的程序,需要sess run三次,这样就不能使用sum_op = tf.summary.merge_all()来收集所有的summary了,因为不同的阶段,feed的不一样,而mearge_all的话,又需要feed之前全部的数据,等于为了执行一个summary操作需要把所有的计算图重新计算一遍,而且如果强行sess.run不feed的话就会出现错误什么什么placeholder feed ,但是没有。

解决方法一

sess.run的结果使用summary.value.add添加。

summary_writer = tf.summary.FileWriter(LOGDIR)
summary = tf.Summary()
...
diss_loss = sess.run(diss_loss_op)
...
gene_loss = sess.run(gene_loss_op)
...
summary.value.add(tag="diss_loss", simple_value=diss_loss)
summary.value.add(tag="gene_loss", simple_value=gene_loss)
# step代表横轴坐标
summary_writer.add_summary(summary, step)

解决方法二

直接把 tf.summary.scalar()的返回值写进去。

def train():

    test_val = tf.placeholder(tf.float32, name='tmp1')
    summary_1 = tf.summary.scalar('tmp1', test_val)

    test_val2 = tf.placeholder(tf.float32, name='tmp2')
    summary_2 = tf.summary.scalar('tmp2', test_val2)

    sess = tf.InteractiveSession()

    train_writer = tf.summary.FileWriter('be_polite', sess.graph)
    tf.global_variables_initializer().run()

    summary, val1 = sess.run([summary_1, test_val], feed_dict={'tmp1:0': 1.0})
    train_writer.add_summary(summary)
    print('Val1: %f' % val1)

    summary2, val2 = sess.run([summary_2, test_val2], feed_dict={'tmp2:0': 2.0})
    train_writer.add_summary(summary2)
    print('Val2: %f' % val2)

猜你喜欢

转载自blog.csdn.net/weixin_34128411/article/details/87337836