TensorFlow Notes: Building a Neural Network

I. Basic Concepts

  1. Tensor: a multi-dimensional array; a tensor's dimensionality is described by its "rank".
    A rank-0 tensor is called a scalar and represents a single number;
    a rank-1 tensor is called a vector and represents a one-dimensional array;
    a rank-2 tensor is called a matrix and represents a two-dimensional array;
    rank 3 and above are simply called tensors. To read off a tensor's rank, count its nested square brackets: e.g., t=[[[...]]] is a rank-3 tensor (see the sketch after this list).
  2. An NN based on TensorFlow: represent data as tensors, build the neural network as a computation graph, execute the graph in a session, and optimize the weights (parameters) on the graph's edges to obtain the model.
  3. Data types: TensorFlow's data types include tf.float32, tf.int32, etc.
  4. Computation graph: the computational process of building the neural network; a graph that carries one or more computation nodes. It only builds the network and performs no computation.
  5. Session (tf.Session()): executes the node operations in the computation graph.
  6. Forward propagation: the computational process of building the model; it gives the model inference ability, so it can produce the corresponding output for a given set of inputs.
  7. Back propagation: trains the model parameters by applying gradient descent to all parameters, minimizing the NN model's loss function on the training data.
  8. Loss function: the gap between the computed prediction y and the known answer y_. Mean squared error (MSE), MSE(y, y_) = sum((y - y_)^2) / n, is one of the most common choices.
  9. Back-propagation training methods: take reducing the loss function as the optimization goal; options include gradient descent, the Momentum optimizer, and the Adam optimizer (all three appear, two commented out, in the example below).
  10. Learning rate: determines the magnitude of each parameter update.
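A minimal sketch (not part of the original notes) illustrating items 1, 4, and 5 above: the rank of a tensor can be read off from its nested brackets, building the graph alone computes nothing, and only a session actually runs the nodes.

import tensorflow as tf

scalar = tf.constant(3.0)                 # rank 0: a single number
vector = tf.constant([1.0, 2.0])          # rank 1: one pair of brackets
matrix = tf.constant([[1.0], [2.0]])      # rank 2: two nested pairs of brackets

# This only adds a node to the computation graph; nothing is computed yet.
result = vector + tf.constant([3.0, 4.0])
print(result)                  # prints a Tensor description, not the values

# The session executes the graph's nodes and returns concrete values.
with tf.Session() as sess:
    print(sess.run(result))    # [4. 6.]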

II. The Standard Recipe for Building a Neural Network

  1. Import modules and generate the simulated dataset.
  2. Forward propagation: define the inputs, parameters, and outputs.
  3. Back propagation: define the loss function and the back-propagation method.
  4. Create a session and train the model.

  The example in Section III below follows exactly these four steps; its comments #0 through #3 mark them.

III. A Complete Example

    Randomly generate the volumes and weights of 32 manufactured parts, train for 5000 rounds, and print the loss function every 500 rounds.

#coding:utf-8
#0 Import modules and generate the simulated dataset
import tensorflow as tf
import numpy as np
BATCH_SIZE = 8
seed = 23455

# Generate random numbers from the fixed seed so the dataset is reproducible
rng = np.random.RandomState(seed)
# rand returns a 32x2 matrix: 32 samples of (volume, weight) as the input dataset
X = rng.rand(32,2)
# For each row of X: if the two entries sum to less than 1, assign label 1, otherwise 0.
# These labels are the ground truth for the input dataset.
Y = [[int(x0 + x1 < 1)] for (x0, x1) in X]
print "X:\n", X
print "Y:\n", Y

#1 Define the network's inputs, parameters, and outputs; define the forward pass
x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))

W1 = tf.Variable(tf.random_normal([2,3], stddev=1, seed=1))
W2 = tf.Variable(tf.random_normal([3,1], stddev=1, seed=1))

a = tf.matmul(x, W1)
y = tf.matmul(a, W2)

#2 Define the loss function and the back-propagation method
loss = tf.reduce_mean(tf.square(y - y_))
#train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
train_step = tf.train.MomentumOptimizer(0.001, 0.9).minimize(loss)
#train_step = tf.train.AdamOptimizer(0.001).minimize(loss)

#3 Create a session and train for STEPS rounds
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # Print the current (untrained) parameters
    print "W1:\n", sess.run(W1)
    print "W2:\n", sess.run(W2)
    print "\n"

    # Train the model
    STEPS = 5000
    for i in range(STEPS):
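        # Slide a BATCH_SIZE window over the 32 samples, cycling back to the start via % 32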
        start = (i*BATCH_SIZE) % 32
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_:Y[start:end]})
        if i % 500 == 0:
            total_loss = sess.run(loss, feed_dict={x: X, y_: Y})
            print("After %d training step(s), loss on all data is %g" % (i, total_loss))

    # Print the parameter values after training
    print("\n")
    print("W1:\n", sess.run(W1))
    print("W2:\n", sess.run(W2))

"""
X:
[[0.83494319 0.11482951]
 [0.66899751 0.46594987]
 [0.60181666 0.58838408]
 [0.31836656 0.20502072]
 [0.87043944 0.02679395]
 [0.41539811 0.43938369]
 [0.68635684 0.24833404]
 [0.97315228 0.68541849]
 [0.03081617 0.89479913]
 [0.24665715 0.28584862]
 [0.31375667 0.47718349]
 [0.56689254 0.77079148]
 [0.7321604  0.35828963]
 [0.15724842 0.94294584]
 [0.34933722 0.84634483]
 [0.50304053 0.81299619]
 [0.23869886 0.9895604 ]
 [0.4636501  0.32531094]
 [0.36510487 0.97365522]
 [0.73350238 0.83833013]
 [0.61810158 0.12580353]
 [0.59274817 0.18779828]
 [0.87150299 0.34679501]
 [0.25883219 0.50002932]
 [0.75690948 0.83429824]
 [0.29316649 0.05646578]
 [0.10409134 0.88235166]
 [0.06727785 0.57784761]
 [0.38492705 0.48384792]
 [0.69234428 0.19687348]
 [0.42783492 0.73416985]
 [0.09696069 0.04883936]]
Y:
[[1], [0], [0], [1], [1], [1], [1], [0], [1], [1], [1], [0], [0], [0], [0], [0], [0], [1], [0], [0], [1], [1], [0], [1], [0], [1], [1], [1], [1], [1], [0], [1]]
W1:
[[-0.8113182   1.4845988   0.06532937]
 [-2.4427042   0.0992484   0.5912243 ]]
W2:
[[-0.8113182 ]
 [ 1.4845988 ]
 [ 0.06532937]]


After 0 training step(s), loss on all data is 5.13118
After 500 training step(s), loss on all data is 0.384391
After 1000 training step(s), loss on all data is 0.383592
After 1500 training step(s), loss on all data is 0.383562
After 2000 training step(s), loss on all data is 0.383561
After 2500 training step(s), loss on all data is 0.383561
After 3000 training step(s), loss on all data is 0.383561
After 3500 training step(s), loss on all data is 0.383561
After 4000 training step(s), loss on all data is 0.383561
After 4500 training step(s), loss on all data is 0.383561


W1:
[[-0.6131169   0.83128434  0.07560416]
 [-2.2572947  -0.1448612   0.56774247]]
W2:
[[-0.10436084]
 [ 0.7734161 ]
 [-0.04417417]]

"""


 


Reposted from blog.csdn.net/qust1508060414/article/details/81172424