TensorFlow: Visualizing the Training Process

1. Introduction

This section shows the model training process graphically: we plot the model's prediction every few training steps and watch it fit the data step by step until it settles on the final curve.

2. Visualize the training process

2.1. Import the necessary modules

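import tensorflow.compat.v1 as tf  # TF1-style graph API (placeholders, Session)
tf.disable_v2_behavior()           # lets the TF1 code below also run under TensorFlow 2.x
import numpy as np
import matplotlib.pyplot as plt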

2.2. Define the add_layer function

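def add_layer(inputs, in_size, out_size, activation_functional=None):
    # weight matrix of shape (in_size, out_size), drawn from a standard normal
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # one bias per output unit, initialized to 0.1 rather than 0
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # affine transform: (batch, in_size) x (in_size, out_size) -> (batch, out_size)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    # linear layer if no activation is given, otherwise apply it element-wise
    if activation_functional is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_functional(Wx_plus_b)
    return outputs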
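To make explicit what add_layer computes, here is a minimal NumPy sketch of the same forward pass, outputs = activation(inputs x Weights + biases), using the 1-10-1 shapes of this tutorial. The helpers dense_forward and relu are hypothetical names introduced only for illustration; unlike add_layer, this version creates no TensorFlow variables and is never trained.

```python
import numpy as np


def dense_forward(inputs, weights, biases, activation=None):
    """NumPy analogue of add_layer's forward pass (hypothetical helper)."""
    # (N, in) @ (in, out) + (1, out) broadcasts to (N, out)
    wx_plus_b = inputs @ weights + biases
    return wx_plus_b if activation is None else activation(wx_plus_b)


def relu(z):
    # element-wise max(z, 0), the same function as tf.nn.relu
    return np.maximum(z, 0.0)


rng = np.random.default_rng(0)
x = np.linspace(-1, 1, 300)[:, np.newaxis]   # shape (300, 1)
W1 = rng.normal(size=(1, 10))
b1 = np.zeros((1, 10)) + 0.1
W2 = rng.normal(size=(10, 1))
b2 = np.zeros((1, 1)) + 0.1

hidden = dense_forward(x, W1, b1, relu)      # hidden layer, shape (300, 10)
out = dense_forward(hidden, W2, b2)          # linear output, shape (300, 1)
print(hidden.shape, out.shape)               # (300, 10) (300, 1)
```

Passing activation=None gives a purely linear layer, which is why the output layer of the network is built that way.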

2.3. Construct the data

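x_data = np.linspace(-1, 1, 300)[:, np.newaxis]   # 300 points in [-1, 1], shape (300, 1)
noise = np.random.normal(0, 0.05, x_data.shape)   # Gaussian noise, std 0.05
y_data = np.square(x_data) - 0.5 + noise          # noisy samples of y = x^2 - 0.5

plt.scatter(x_data, y_data)   # preview the data (close the window to continue)
plt.show()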
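Because np.random.normal is unseeded, the data above changes on every run. A minimal sketch, assuming NumPy 1.17 or later for np.random.default_rng (the seed 42 is arbitrary), generates the same 300 noisy samples of y = x^2 - 0.5 each time and checks that the added noise has a standard deviation of roughly 0.05:

```python
import numpy as np

rng = np.random.default_rng(42)                   # seeded generator: same data every run
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]   # column vector, shape (300, 1)
noise = rng.normal(0, 0.05, x_data.shape)         # Gaussian noise, std 0.05
y_data = np.square(x_data) - 0.5 + noise          # noisy samples of the parabola

clean = np.square(x_data) - 0.5                   # noise-free target curve
print(x_data.shape, y_data.shape)                 # (300, 1) (300, 1)
print(float(np.std(y_data - clean)))              # sample std of the noise, near 0.05
```

The [:, np.newaxis] reshape matters: the placeholders defined in the next step expect inputs of shape [None, 1], not a flat vector.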

2.4. Build the model and the training step

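# placeholders for a batch of inputs and targets; None = any batch size
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

# 1 input -> 10 ReLU hidden units -> 1 linear output
l1 = add_layer(xs, 1, 10, activation_functional=tf.nn.relu)
prediction = add_layer(l1, 10, 1, activation_functional=None)

# mean over samples of the per-sample sum of squared errors
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=1))
# one step of plain gradient descent with learning rate 0.1
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)   # initialize Weights and biases before training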
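The train_step op hides the actual arithmetic. As an illustration of what GradientDescentOptimizer(0.1).minimize(loss) does for this particular 1-10-1 ReLU network, here is a hand-written NumPy sketch of full-batch gradient descent with manual backpropagation. It is not how TensorFlow computes gradients internally (TensorFlow uses automatic differentiation); the function name train_numpy, the seeds and the NumPy initialization are assumptions made for this example.

```python
import numpy as np


def train_numpy(x, y, hidden=10, lr=0.1, steps=1000, seed=0):
    """Hypothetical helper: full-batch gradient descent on
    pred = relu(x @ W1 + b1) @ W2 + b2 with the tutorial's loss.
    Returns the final parameters and the loss recorded before each step."""
    rng = np.random.default_rng(seed)
    W1 = rng.normal(size=(x.shape[1], hidden))
    b1 = np.zeros((1, hidden)) + 0.1
    W2 = rng.normal(size=(hidden, y.shape[1]))
    b2 = np.zeros((1, y.shape[1])) + 0.1
    n = x.shape[0]
    history = []
    for _ in range(steps):
        # forward pass
        z1 = x @ W1 + b1
        a1 = np.maximum(z1, 0.0)
        pred = a1 @ W2 + b2
        err = pred - y
        # mean over samples of the per-sample sum of squared errors
        history.append(np.mean(np.sum(err ** 2, axis=1)))
        # backward pass: d(loss)/d(pred) = 2 * err / n
        d_pred = 2.0 * err / n
        dW2 = a1.T @ d_pred
        db2 = d_pred.sum(axis=0, keepdims=True)
        d_z1 = (d_pred @ W2.T) * (z1 > 0)   # ReLU passes gradient only where z1 > 0
        dW1 = x.T @ d_z1
        db1 = d_z1.sum(axis=0, keepdims=True)
        # plain gradient-descent update: parameter -= lr * gradient
        W1 -= lr * dW1
        b1 -= lr * db1
        W2 -= lr * dW2
        b2 -= lr * db2
    return (W1, b1, W2, b2), history


rng = np.random.default_rng(1)
x = np.linspace(-1, 1, 300)[:, np.newaxis]
y = np.square(x) - 0.5 + rng.normal(0, 0.05, x.shape)
params, history = train_numpy(x, y)
print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
```

Since the network has a single output column, the inner sum over axis 1 has one term and the loss reduces to the ordinary mean squared error.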

2.5. Visualization

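fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data)   # the real data stays on screen throughout
plt.ion()                    # interactive mode: plt.show() no longer blocks
plt.show()
lines = None                 # handle to the current red prediction curve
for i in range(1000):
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        # remove the previous prediction curve before drawing the new one
        if lines is not None:
            lines[0].remove()
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})
        lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
        plt.pause(1)         # redraw the figure and wait 1 s
plt.ioff()
plt.show()                   # keep the final fit on screen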
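Removing and re-plotting the curve every 50 steps works, but a common alternative is to draw the red line once and update its y-values in place with Line2D.set_ydata. The sketch below uses a hypothetical helper animate_fit and made-up prediction snapshots (they are not training output), and it forces the non-interactive Agg backend so it runs without a display; in the interactive setting above you would keep plt.ion() and call plt.pause() after each update instead.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def animate_fit(x, y, snapshots):
    """Scatter the data once, then update a single red curve in place
    for each prediction snapshot (arrays shaped like y)."""
    fig, ax = plt.subplots()
    ax.scatter(x, y)
    (line,) = ax.plot(np.ravel(x), np.ravel(snapshots[0]), "r-", lw=5)
    for pred in snapshots[1:]:
        line.set_ydata(np.ravel(pred))  # reuse the same Line2D object
        fig.canvas.draw_idle()          # request a redraw (a GUI backend would refresh)
    return fig, ax, line


# Made-up snapshots moving from a flat line toward the true curve.
x = np.linspace(-1, 1, 300)[:, np.newaxis]
target = np.square(x) - 0.5
preds = [np.zeros_like(target), 0.5 * target, target]
fig, ax, line = animate_fit(x, target, preds)
print(len(ax.lines))  # 1: only one prediction curve ever exists
```

Updating one artist avoids creating and destroying a Line2D on every refresh, so there is no stale curve to track down if a removal is skipped.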

Origin blog.csdn.net/weixin_37763870/article/details/105519980