TensorFlow实现MNIST逻辑回归 + TensorBoard可视化

一、代码

# coding=utf-8

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#数据集
mnist = input_data.read_data_sets("MNIST_DATA/", one_hot=True)# 读取MNIST,独热编码

#定义模型
x = tf.placeholder(tf.float32, [None, 784], name='X')# 输入x
y = tf.placeholder(tf.float32, [None, 10], name='Y')# 标签y

W = tf.Variable(tf.zeros([784, 10]), name='W')# 学习变量、权重
b = tf.Variable(tf.zeros([10]), name='b')# 偏置

with tf.name_scope("wx_b") as scope:
    y_hat = tf.nn.softmax(tf.matmul(x,W)+b)# 多元线性回归

w_h = tf.summary.histogram("weights", W)# 权重随时间变化
b_h = tf.summary.histogram("biases", b)# 偏置随时间变化

with tf.name_scope("cross-entropy") as scope:
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_hat))# 定义交叉熵和损失函数
    tf.summary.scalar('cross-entropy', loss)# 随时间变化的损失函数

with tf.name_scope("Train") as scope:
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)# 简单梯度下降算法优化器,学习速率0.01

#训练
batch_size = 100# 每一批的训练量
max_epochs = 100# 总迭代次数
correct_prediction = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))# 正确预测
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 准确率
init = tf.global_variables_initializer()# 初始化变量

merged_summary_op = tf.summary.merge_all()# 生成摘要

with tf.Session() as sess:
    sess.run(init)
    summary_writer = tf.summary.FileWriter('graphs', sess.graph)# 将摘要与图形写入graphs目录

    for epoch in range(max_epochs):
        loss_avg = 0
        num_of_batch = int(mnist.train.num_examples / batch_size)# 第几批
        for i in range(num_of_batch):
            batch_xs, batch_ys = mnist.train.next_batch(100)# 从数据集中取出100个输入和标签
            _, l, summary_str = sess.run([optimizer,loss,merged_summary_op], feed_dict={x:batch_xs, y:batch_ys})# 运行优化器、损失函数、摘要,并馈送数据
            loss_avg += 1
            summary_writer.add_summary(summary_str, epoch*num_of_batch + i)# 添加摘要数据
            loss_avg = loss_avg / num_of_batch
            print('Epoch {0}: Loss {1}'.format(epoch, loss_avg))
        print('Done')

        # 评估
        print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))

accuracy=0.9091

二、命令行运行TensorBoard可视化:

tensorboard --logdir=graphs

三、效果

  1. GRAPHS 模型图
    在这里插入图片描述
  2. HISTOGRAMS 直方图
    在这里插入图片描述
    在这里插入图片描述
  3. DISTRIBUTIONS 分布在这里插入图片描述
    在这里插入图片描述
  4. SCALARS 标量
    在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/lly1122334/article/details/87787658