tensorlfow 可视化

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_24548569/article/details/82085973

为了方便TensorFlow程序的理解、调试与优化,TensorFlow发布了可视化工具TensorBoard,我们可以使用TensorBoard来展示TensorFlow的图像,绘制图像生成的定量指标图以及附加数据。

使用tf.summary记录要记录的标量,使用tf.merge_all_summaries合并所有的summary,避免逐个操作每个记录,当使用session运行之后生成数据,就可以使用tf.train.Summarywriter保存记录。

下面直接给出一个简单的例子:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data")

with tf.name_scope("input"):
    X = tf.placeholder(tf.float32, [None, 784], name="x-input")
    Y = tf.placeholder(tf.int64, [None], name="y-input")
with tf.name_scope("matmul"):
    W = tf.Variable(tf.zeros([784, 10]), name="weight")
    b = tf.Variable(tf.zeros([10]), name="bias")
    Y_hat = tf.matmul(X, W) + b
with tf.name_scope("loss"):
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=Y, logits=Y_hat)
    # 添加记录点,记录loss的变化
    tf.summary.scalar("loss", cross_entropy)

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
with tf.name_scope("accuracy"):
    correct_prediction = tf.equal(tf.argmax(Y_hat, 1), Y)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # 添加记录点,记录accuracy的变化
    tf.summary.scalar("accuracy", accuracy)

# 合并所有summary
merged = tf.summary.merge_all()

init = tf.global_variables_initializer()
with tf.Session() as sess:

    writer = tf.summary.FileWriter("log", sess.graph)
    sess.run(init)

    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={X: batch_xs, Y: batch_ys})
        if i % 100 == 0:
            summary, acc = sess.run([merged, accuracy], feed_dict={X: mnist.test.images, Y: mnist.test.labels})
            print("iter %d accuracy: %.3f" % (i, acc))
            # 保存记录
            writer.add_summary(summary, i)
    writer.close()

记录数据保存在log文件夹,我们使用TensorBoard可视化这些数据。
在代码目录中运行命令:

tensorboard --logdir log

根据输出的提示信息在浏览器中输入相应的地址,页面展示了loss和accuracy的变化曲线:
loss曲线

accuracy

选择GRAPHS查看该任务的数据流图。

注意,如果本文使用的MNIST数据集下载不了,使用该链接下载:
链接: https://pan.baidu.com/s/1tDAZJvfjlfzAaUo694pUrA 密码: wez6
然后解压缩到代码目录。

猜你喜欢

转载自blog.csdn.net/qq_24548569/article/details/82085973