搭建一个简单的神经网络demo

一、准备工作

1)运行环境已经安装了python3.5以上版本

2)已经安装好了tensorflow

3)已经安装好了numpy

4)已经安装好了matplotlib

二、开始搭建神经网络

1)创建一个层

#inputs:
#in_size:行
#out_size:列
#activation_fuction:
def add_later(inputs,in_size,out_size,activation_fuction=None):

#系数
    with tf.name_scope('layer'):
        with tf.name_scope('Weights'):
            Weights = tf.Variable(tf.random_normal([in_size,out_size]),name="W")

        #偏置
        with tf.name_scope('biases'):
            biases = tf.Variable(tf.zeros([1,out_size])+0.1,name='B')
        #行列式的乘法
        with tf.name_scope('Wx_plus_b'):
            Wx_plus_b = tf.matmul(inputs,Weights)+biases

        #如果是None 不用激活函数
        if activation_fuction is None:
            outputs= Wx_plus_b
        else:
            outputs = activation_fuction(Wx_plus_b)#如果不是None,用激活函数
        return outputs

2)创建训练数据

#-1到1,300个点,竖起来
x_data = np.linspace(-1,1,300)[:,np.newaxis]
#产生一个与x_data大小形状相等的噪点
noise = np.random.normal(0,0.05,x_data.shape)
#生产出y,使得看起来比较乱
#y=x^2-0.5+c
y_data = np.square(x_data)-0.5+noise

3)设置全局变量

#相当于一个全局变量
with tf.name_scope('inputs'):
    xs=tf.placeholder(tf.float32,[None,1],name='x_input')
    ys=tf.placeholder(tf.float32,[None,1],name='y_input')

4)搭建神经网络

#第一层神经网络
l1=add_later(xs,1,10,activation_fuction=tf.nn.relu)
#第二层神经网络
prediction = add_later(l1,10,1,activation_fuction=None)
with tf.name_scope('loss'):
    loss = tf.reduce_mean(#求平均
        tf.reduce_sum(#求和
            tf.square(ys - prediction),#求差值
                      reduction_indices=[1]))
with tf.name_scope('train_step'):
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

5)激活神经网络

init = tf.initialize_all_variables()
sess=tf.Session()
writer = tf.summary.FileWriter("logs/",sess.graph)
sess.run(init)

6)显示效果

#生成图片框
fig = plt.figure()
#图片狂的大小,默认1,1,1代表在整个图片显示
ax= fig.add_subplot(1,1,1)
ax.scatter(x_data,y_data)
plt.ion()
plt.show()

7)训练

for i in range(1000):
    #训练
    sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
    if i% 10==0:
        #print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
        try:
            ax.lines.remove(lines[0])
        except Exception:
            pass

        prediction_value = sess.run(prediction,feed_dict={xs:x_data})

        lines = ax.plot(x_data,prediction_value,'r-',lw=5)
        plt.pause(0.1)
        print(i)

三、效果展示

四、查看神经网络图模型

1)切换到logs的上级目录,在命令行输入:tensorboard --logdir='logs/' 如下图:

2)在浏览器地址栏输入http://localhost:6006 可看到下图:

猜你喜欢

转载自blog.csdn.net/weixin_39527812/article/details/83032612