TensorFlow-GPU线性回归可视化代码,以及问题总结

通过TensorBoard将TensorFlow模型的训练过程进行可视化的展示出来,将训练的损失值随迭代次数的变化情况,以及神经网络的内部结构展示出来,以此更好的了解神经网络。

一、 建立图

   通过添加一个标量数据和一个直方图数据到log文件里,然后通过TensorBoard显示出来,第一步加到summary,第二步写入文件。

 将模型的生成值加入到直方图数据中(直方图名字为z),将损失函数加入到标量数据中(标量名字叫做loss_function)。

下面的代码就是在启动session之后创建一个summary_writer,在迭代中将summary的值运行出来,并且保存在文件里面

   代码如下:

# -*- coding: utf-8 -*-
# !/usr/bin/env python
# @Time    : 2019/5/16 9:47
# @Author  : xhh
# @Desc    :  线性回归的TensorBoard
# @File    : tensor_tensorBoard.py
# @Software: PyCharm

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

plotdata = {"batchsize":[], "loss":[]}

def moving_average(a, w=10):
    if len(a)<w:
        return a[:]
    return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]

#  模拟数据
train_X = np.linspace(-1, 1, 100)
train_Y = 2*train_X + np.random.randn(*train_X.shape)*0.3  # 加入了噪声

# 图形展示
plt.plot(train_X,train_Y,'ro',label="original data") # label数据标签
plt.legend()
plt.show()

tf.reset_default_graph()  # 重置会话

# 创建模型
# 占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")

# 前向结构
z = tf.multiply(X, W) +b
tf.summary.histogram('z',z)  #将预测值以直方图显示

# 反向优化
cost = tf.reduce_mean(tf.square(Y-z))
tf.summary.scalar('loss_function', cost)  #将损失以标量显示
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# 初始化变量
init = tf.global_variables_initializer()
# 参数设置
training_epochs = 20
display_step = 2
saver = tf.train.Saver() # 模型生成,并保存
savedir = "log/"

# 启动session
with tf.Session() as sess:
    sess.run(init)


    merged_summary_op = tf.summary.merge_all()  # 合并所以summary
    # 创建summary_writer,用于写文件
    summary_writer = tf.summary.FileWriter('log/mnist_with_summaries', sess.graph)

    for epoch in range(training_epochs):
        for (x, y) in zip(train_X,train_Y):
            sess.run(optimizer, feed_dict={X:x, Y:y})

            # 生成summary
            summary_str = sess.run(merged_summary_op, feed_dict={X:x, Y:y})
            summary_writer.add_summary(summary_str, epoch) # 将summary写入文件

        # 显示训练中的详细信息
        if epoch % display_step == 0:
            loss = sess.run(cost, feed_dict={X:train_X, Y:train_Y})
            print("Epoch:",epoch+1,"cost=", loss,"W=",sess.run(W),"b=",sess.run(b))
            if not (loss=="NA"):
                plotdata["batchsize"].append(epoch)
                plotdata["loss"].append(loss)

        print("finished!")
        saver.save(sess, savedir+"linermodel.cpkt")
        print("cost=",sess.run(cost, feed_dict={X:train_X, Y:train_Y}),"W=", sess.run(W),"b=",sess.run(b))

        # 图形显示
        plt.plot(train_X, train_Y, 'ro', label='Original data')
        plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
        plt.legend()
        plt.show()

        plotdata["avgloss"] = moving_average(plotdata["loss"])
        plt.figure(1)
        plt.subplot(211)
        plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
        plt.xlabel('Minibatch number')
        plt.ylabel('Loss')
        plt.title('Minibatch run vs. Training loss')
        plt.show()

        print("x=0.2, z=",sess.run(z, feed_dict={X: 0.2}))

最终的运行结果,下面贴出在建模过程中拟合线性模型的变化,以及它的损失值的变化:

下图三个折线图就是在拟合模型时损失函数的变化,分别是在第2,8,19次的变化。散点图就是最终拟合出来的模型。

下图就是在运行之后多出的可视化文件:

进入该文件夹,输入cmd,激活你带有TensorFlow-gpu版本的python环境

然后输入:

地址需要变化一下,后面的端口改成8080,我的默认的6006访问不了。

tensorboard --logdir F:\code\tensor_test\log\mnist_with_summaries --port=8080

最终运行后的结果如下:

然后打开谷歌浏览器(最好是谷歌),输入http://localhost:8080访问就可以看到了:

损失值随迭代次数的变化情况:

神经网络内部结构:

单击SCALARS,会看到之前创建的loss_function,点开后可以看到损失值随迭代次数的变化情况。如上图。

二、可能会出现的问题

(1)  在cmd时可能会出现这样的问题:

上面的是因为没有对应的TensorFlow-gpu的python所以访问不了。

(2)TensorFlow-GPU,python环境的问题

我的是下图的Python==3.6.2, tensorflow-gpu==1.13.1

python环境和tensorflow-gpu版本不兼容,所以出现下面的问题,最好将python环境换成了python=3.6.7,并且安装TensorFlow-gpu=1.13.1,之后运行成功了

(3)访问问题

在用127.0.0.1:6006,访问时出现拒绝访问,如下:

是因为本机的默认ip地址为localhost所以访问不了,换成localhost可以访问。

在或着可能出现6006端口访问不了,这时就需要在后面给上指定IP=8080,然后访问就OK了。

大家可以关注我和我小伙伴的公众号~~~这里有我和我的小伙伴不定时的更新一些python技术资料哦!!大家也可以留言,讨论一下技术问题,希望大家多多支持,关注一下啦,谢谢大家啦~~

       

  

猜你喜欢

转载自blog.csdn.net/weixin_39121325/article/details/90257354