TensorFlow之画图

# View more python learning tutorial on my Youtube and Youku channel!!!

# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
# Youku video tutorial: http://i.youku.com/pythontutorial

"""
Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly.
"""
from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs

# Make up some real data
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

##plt.scatter(x_data, y_data)
##plt.show()

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediction and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# important step
sess = tf.Session()
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()
sess.run(init)

# plot the real data
fig = plt.figure()#画出幕布,以便在上面画图
ax = fig.add_subplot(1,1,1)
#将幕布分为1行1列,然后从左往右从上到下在第1个子网格中画图
ax.scatter(x_data, y_data)#画出真实数据的点状图
plt.ion()
#如果不加这个,每画完一条线程序会暂停,加了后会一直画,如果一直画不是会很
#密密麻麻看不清吗,后面 有语句会移除当前线条防止出现该情况
plt.show()
#显示画好后的散点图,plt这个函数如果画完第一次整个程序就暂停了,
#如果想连续画图就要加上plt.ion()


for i in range(1000):
    # training
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        # to visualize the result and improvement
        try:#尝试以下语句
            ax.lines.remove(lines[0])#抹除画出来的第一条线
        except Exception:
            #如果没有的话,忽略第一次的错误,因为这时还没画第一条线
            pass
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})#预测值
        # plot the prediction
        lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
        #画出预测线,r-表红色,lw表粗度为5
        plt.pause(1)#每画完一条线暂停一秒

来源

构建图形,用散点图描述真实数据之间的关系。 (注意:plt.ion()用于连续显示。)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()#本次运行请注释,全局运行不要注释
plt.show()

散点图的结果为:

例子3 结果可视化

接下来,我们来显示预测数据。

每隔50次训练刷新一次图形,用红色、宽度为5的线来显示我们的预测数据和输入之间的关系,并暂停0.1s。

for i in range(1000):
    # training
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        # to visualize the result and improvement
        try:
            ax.lines.remove(lines[0])
        except Exception:
            pass
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})
        # plot the prediction
        lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
        plt.pause(0.1)

最后,机器学习的结果为:

例子3 结果可视化

猜你喜欢

转载自blog.csdn.net/weixin_40849273/article/details/81169508
今日推荐