TensorFlow Review Notes 4: The TensorBoard Visualization Tool

Part 1: The TensorFlow code

1. Define the summary nodes you want to monitor
  • Defining a single node:
# Example 1:
b_conv = bias_variable([bias_shape])
# This generates charts under TensorBoard's "Distributions" and "Histograms" tabs.
tf.summary.histogram('b_conv', b_conv)

# Example 2:
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_output))
# Monitor the cross-entropy
loss_scalar = tf.summary.scalar('loss', cross_entropy)
  • Defining a merged node covering all summaries:
summary_merged = tf.summary.merge_all()
# This line can go before the Session is created
2. Create a FileWriter to record the summaries

TensorBoard automatically scans every subdirectory under the given log directory.

train_writer = tf.summary.FileWriter(logs_path + 'train', sess.graph)
validation_writer = tf.summary.FileWriter(logs_path + 'validation')
# These lines can go right after the Session is created
3. Run the summary nodes and write the results to the files opened in step 2
  • Running individual summary nodes
(sum_accuracy_validation,
 sum_loss_validation,
 accuracy_current_validation) = sess.run([accuracy_scalar, loss_scalar, accuracy], feed_dict={x: mnist.validation.images, y_: mnist.validation.labels, keep_prob: 1.0})
# write to summary
validation_writer.add_summary(sum_accuracy_validation, step)
validation_writer.add_summary(sum_loss_validation, step)
  • Running all summary nodes at once
summary, accuracy_current_train = sess.run([summary_merged, accuracy], feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
# write to summary
train_writer.add_summary(summary, step)
  • Running a specified collection of nodes
    The trick is to pass the collections parameter when defining a summary node, which makes it easy to manage summary nodes as groups:
tf.summary.scalar('learning_rate', p_lr, collections=['train'])
tf.summary.scalar('loss', t_loss, collections=['train', 'test'])
# Afterwards you can merge a whole collection directly
s_training = tf.summary.merge_all('train')
Reference:
https://stackoverflow.com/questions/41940299/tensorflow-how-to-use-tf-train-summarywriter-inside-supervisor-loop-0-12-0rc1

  An alternative, slightly more involved approach is the one described in point 9 of https://www.cnblogs.com/lyc-seu/p/8647792.html: use get_collection to retrieve the defined collection of nodes and merge them yourself.
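Setting TensorFlow aside, the collections mechanism is essentially a key-to-ops registry: each summary op is registered under one or more collection keys, and merging a collection just gathers everything registered under that key. A minimal TensorFlow-free sketch of that idea (register and merge_all here are illustrative stand-ins, not TF API):

```python
# A TensorFlow-free sketch of the "collections" idea: each summary op is
# registered under one or more collection keys, and "merging" a collection
# simply gathers every op registered under that key.
from collections import defaultdict

_collections = defaultdict(list)  # collection key -> list of registered op names

def register(name, collections=('default',)):
    """Register a (hypothetical) summary op under the given collection keys."""
    for key in collections:
        _collections[key].append(name)
    return name

def merge_all(key='default'):
    """Return every op registered under `key`, like tf.summary.merge_all(key)."""
    return list(_collections[key])

register('learning_rate', collections=['train'])
register('loss', collections=['train', 'test'])

print(merge_all('train'))  # ['learning_rate', 'loss']
print(merge_all('test'))   # ['loss']
```

An op registered under several keys (like 'loss' above) shows up in every matching merge, which is exactly what makes the collections parameter convenient for separate train/test summary groups.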

The "same-chart" overlay trick
In the accuracy and loss charts, the training-set and validation-set curves appear in the same chart in different colors, which makes comparison much easier.
Getting curves to share a chart requires two things:
  • the curves must track the same summary node, e.g. the accuracy node or the loss node;
  • each curve's data must be recorded in a different directory, which you get by constructing two separate FileWriters.
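The two requirements above boil down to a directory layout: one run directory per writer, with the same tag logged in each. A TensorFlow-free sketch of that layout (write_scalar and the CSV format are illustrative, not real TensorBoard event files):

```python
# Two "writers" log the SAME tag ('accuracy') under DIFFERENT run
# directories; TensorBoard overlays one curve per run directory for
# each shared tag, giving the same-chart effect.
import os
import tempfile

def write_scalar(run_dir, tag, step, value):
    """Append a (step, value) record for `tag` under `run_dir` (illustrative)."""
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, tag + '.csv'), 'a') as f:
        f.write('%d,%f\n' % (step, value))

logs_path = tempfile.mkdtemp()
train_dir = os.path.join(logs_path, 'train')
validation_dir = os.path.join(logs_path, 'validation')

for step in range(3):
    write_scalar(train_dir, 'accuracy', step, 0.90 + 0.01 * step)       # train curve
    write_scalar(validation_dir, 'accuracy', step, 0.85 + 0.01 * step)  # validation curve

# Two run directories, one shared tag -> two overlaid curves in one chart.
print(sorted(os.listdir(logs_path)))  # ['train', 'validation']
```

Pointing `tensorboard --logdir` at the parent directory (logs_path here) is what lets it discover both runs and overlay them.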

Complete code:

# Define the cross-entropy
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_output))
# Monitor the cross-entropy
loss_scalar = tf.summary.scalar('loss', cross_entropy)

# Define accuracy
correct_prediction = tf.equal(tf.argmax(y_output, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Monitor accuracy
accuracy_scalar = tf.summary.scalar('accuracy', accuracy)

# Merge all summaries for TensorBoard visualization
summary_merged = tf.summary.merge_all()
# Log directory
logs_path = 'logs/'

# Start the session
with tf.Session() as sess :
    train_writer = tf.summary.FileWriter(logs_path+'train', sess.graph)
    validation_writer = tf.summary.FileWriter(logs_path+'validation')
    sess.run(tf.global_variables_initializer())

    # Hyperparameters
    EPOCH = 2
    SAMPLE_NUM = mnist.train.images.shape[0]
    BATCH_SIZE = 64

    for step in range(EPOCH) :
        for _ in range( int(SAMPLE_NUM/BATCH_SIZE)):
            batch = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

        # Compute the current training-set accuracy (on the last mini-batch)
        summary, accuracy_current_train = sess.run([summary_merged, accuracy], feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
        # Write to summary
        train_writer.add_summary(summary, step)

        # Compute the current validation-set accuracy (on the full validation set)
        (sum_accuracy_validation,
         sum_loss_validation,
         accuracy_current_validation) = sess.run([accuracy_scalar, loss_scalar, accuracy], feed_dict={x: mnist.validation.images, y_: mnist.validation.labels, keep_prob: 1.0})
        # Write to summary
        validation_writer.add_summary(sum_accuracy_validation, step)
        validation_writer.add_summary(sum_loss_validation, step)

Part 2: Launching TensorBoard

  1. Start the TensorBoard server:
tensorboard --logdir [your_dir]
# Or specify a port explicitly:
tensorboard --logdir [your_dir] --port=6007
  2. Open TensorBoard in a browser:
http://127.0.0.1:6006
or hostname:6006
# In practice, just follow whatever URL the TensorBoard server prints when it starts.

Recording other runtime statistics, such as memory usage during execution.

(My attempt on this machine was unsuccessful; it raised an error.)

# Define the run metadata
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
# Record metadata while running
summary, accuracy_current_train = sess.run([summary_merged, accuracy], feed_dict={x: batch[0], y_: batch[1]}, options=run_options, run_metadata=run_metadata)
# Write the metadata
train_writer.add_run_metadata(run_metadata, 'step%d' % step)


Reposted from blog.csdn.net/YiRanZhiLiPoSui/article/details/81142918