Neural Network Optimization Algorithms, Part 2 (Regularization and the Moving Average Model)

1. Further Neural Network Optimization: Overfitting and Regularization

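Overfitting means that once a model becomes too complex, it can "memorize" the random noise in each training example very well while forgetting to "learn" the general trend in the training data. As an extreme example, if a model has more parameters than there are training examples, then as long as the training examples do not contradict one another, the model can simply memorize the result for every training example and drive the loss function to 0.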

A very common way to avoid overfitting is regularization.

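Regularization adds to the loss function a term that measures model complexity by penalizing the size of the parameters, which keeps the model from fitting noise and reduces overfitting. With regularization, the loss becomes the sum of two terms. Suppose the loss function that measures the model's performance on the training data is J(θ). Then, instead of optimizing J(θ) directly,

we optimize

\large J ( \theta ) + \lambda R ( w )

where R(w) measures the complexity of the model and λ is the weight of the complexity loss in the total loss. In general, model complexity is determined only by the weights w. Two functions R(w) are commonly used to measure model complexity.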

1. The first is L1 regularization, computed as

\large R ( w ) = \| w \| _ { 1 } = \sum _ { i } \left| w _ { i } \right|

2. The other is L2 regularization, computed as

\large R ( w ) = \| w \| _ { 2 } ^ { 2 } = \sum _ { i } \left| w _ { i } ^ { 2 } \right|

 

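Whichever form of regularization is used, the basic idea is to limit the size of the weights so that the model cannot fit arbitrary random noise in the training data. The two forms nevertheless differ considerably:

  •  L1 regularization makes the parameters sparser, while L2 regularization does not. Sparser means that more parameters become exactly 0.
  •  The L1 penalty is not differentiable at 0, whereas the L2 penalty is differentiable everywhere. Optimizing a loss with L2 regularization is therefore simpler, and optimizing one with L1 regularization is more involved.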
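
One way to see why L1 produces exact zeros while L2 does not is to look at a one-parameter problem: minimize 0.5*(w - w0)^2 + λR(w) for a single weight. The closed-form minimizers used below are standard results (soft-thresholding for L1, proportional shrinkage for L2). The function names and sample values are illustrative assumptions and do not come from TensorFlow.

```python
def shrink_l1(w0, lam):
    """Minimizer of 0.5*(w - w0)**2 + lam*|w| (soft-thresholding)."""
    if w0 > lam:
        return w0 - lam
    if w0 < -lam:
        return w0 + lam
    return 0.0  # weights with |w0| <= lam are set exactly to zero


def shrink_l2(w0, lam):
    """Minimizer of 0.5*(w - w0)**2 + lam*w**2 (proportional shrinkage)."""
    return w0 / (1.0 + 2.0 * lam)


weights = [2.0, -0.3, 0.1, -1.5]  # hypothetical unregularized weights
lam = 0.5                         # hypothetical regularization weight

l1_result = [shrink_l1(w, lam) for w in weights]
l2_result = [shrink_l2(w, lam) for w in weights]

print(l1_result)  # [1.5, 0.0, 0.0, -1.0]
print(l2_result)  # [1.0, -0.15, 0.05, -0.75]
```

The L1 result contains exact zeros for the two small weights, while the L2 result only scales every weight toward zero without reaching it.
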
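import tensorflow as tf
# x and y_ are assumed placeholders for inputs and labels (not defined in the original snippet)
x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))
lambda_ = 0.001  # regularization weight (illustrative value); `lambda` itself is a Python keyword
w = tf.Variable(tf.random_normal([2,1],stddev = 1,seed = 1))
y = tf.matmul(x,w)
# Parentheses let the expression continue onto the next line
loss = (tf.reduce_mean(tf.square(y_ - y)) +
        tf.contrib.layers.l2_regularizer(lambda_)(w))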

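In the code above, loss is the loss function being defined, and it has two parts. The first part is the mean squared error introduced earlier, which measures the model's performance on the training data. The second part is the L2 regularization term.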

import tensorflow as tf
weights = tf.constant([[1.0,-2.0],[-3.0,4.0]])
with tf.Session() as sess:
    # Prints (|1|+|-2|+|-3|+|4|) * 0.5 = 5, where 0.5 is the weight of the regularization term
    print(sess.run(tf.contrib.layers.l1_regularizer(0.5)(weights)))
    # Prints (1^2 + (-2)^2 + (-3)^2 + 4^2) / 2 * 0.5 = 7.5
    print(sess.run(tf.contrib.layers.l2_regularizer(0.5)(weights)))
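
The two printed values can be reproduced by hand. The sketch below is a plain-Python stand-in, not TensorFlow code. It assumes, as the comments in the snippet state, that the L2 regularizer halves the sum of squares before scaling it. The function names are hypothetical.

```python
def l1_value(scale, weights):
    """scale * sum of absolute values over all entries."""
    return scale * sum(abs(w) for row in weights for w in row)


def l2_value(scale, weights):
    """scale * (sum of squares / 2) over all entries."""
    return scale * sum(w * w for row in weights for w in row) / 2.0


weights = [[1.0, -2.0], [-3.0, 4.0]]
print(l1_value(0.5, weights))  # 5.0
print(l2_value(0.5, weights))  # 7.5
```
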

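The code above shows how the L1 and L2 regularization computations differ. As the number of parameters in a neural network grows, however, defining the loss this way first makes the definition of loss very long, hard to read, and error-prone. More importantly, once the network structure becomes complex, the part that defines the network and the part that computes the loss may not live in the same function, so computing the loss directly from the variables becomes inconvenient. To solve this, one can use the collections that TensorFlow provides. The following code shows how to use a collection to compute the loss of a 5-layer neural network with L2 regularization.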

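import tensorflow as tf
# Get the weights on the edges of one layer and add their L2 regularization loss to the collection named 'losses'
def get_weight(shape,lambda_):
    # Create a variable (the parameter is named lambda_ because lambda is a Python keyword)
    var = tf.Variable(tf.random_normal(shape),dtype = tf.float32)
    # add_to_collection adds the L2 regularization loss of the new variable to the collection
    # The first argument 'losses' is the name of the collection; the second is the item to add
    tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(lambda_)(var))
    return var
x = tf.placeholder(tf.float32,shape = (None,2))
y_ = tf.placeholder(tf.float32,shape = (None,1))
batch_size = 8
 
# Number of nodes in each layer of the network
layer_dimension = [2,10,10,10,1]
# Number of layers in the network
n_layers = len(layer_dimension)
 
# This variable tracks the deepest layer reached during forward propagation; it starts as the input layer
cur_layer = x
# Number of nodes in the current layer
in_dimension = layer_dimension[0]
 
# Build the 5-layer fully connected network with a for loop
for i in range(1,n_layers):
    out_dimension = layer_dimension[i] # number of nodes in the next layer
    # Create the weight variable for this layer and add its L2 regularization loss to the collection on the graph
    weight = get_weight([in_dimension,out_dimension],0.001)
    bias = tf.Variable(tf.constant(0.1,shape = [out_dimension]))
    # Use the ReLU activation function
    cur_layer = tf.nn.relu(tf.matmul(cur_layer,weight) + bias)
    # Before moving on, set the input size of the next layer to the output size of this layer
    in_dimension = layer_dimension[i]
 
# While the forward pass was being defined, every L2 regularization loss was already added to the collection on the graph
# Here we only need to compute the loss that measures the model's performance on the data
mse_loss = tf.reduce_mean(tf.square(y_ - cur_layer))
 
# Add the mean squared error loss to the collection
tf.add_to_collection('losses',mse_loss)
 
# get_collection returns a list of all the elements in the collection.
# In this example the elements are the separate parts of the loss; adding them up gives the final loss
loss = tf.add_n(tf.get_collection('losses'))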

Step 1: generate and visualize the dataset

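#coding:utf-8
# Import modules and generate the dataset
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
BATCH_SIZE=30 	# feed 30 examples to the network at a time
seed=2
# Create a random number generator based on seed
rdm=np.random.RandomState(seed)
# Draw a 300 x 2 matrix of random numbers: 300 coordinate points (x0,x1) used as the input dataset
X=rdm.randn(300,2)
# For each row of X, assign the label 1 if the sum of squares of the two coordinates is less than 2, and 0 otherwise
# These labels are the correct answers for the dataset
Y_=[int(x0*x0+x1*x1<2) for (x0,x1) in X ]
# Map each element of Y_ to a color: 1 becomes 'red' and 0 becomes 'blue', so the classes are easy to tell apart in the plot
Y_c=[['red' if y else 'blue'] for y in Y_]
# Reshape X and the labels: -1 as the first element means that dimension is inferred from the second, which gives the number of columns; X becomes n x 2 and Y becomes n x 1
X=np.vstack(X).reshape(-1,2)
Y=np.vstack(Y_).reshape(-1,1)
print(X)
print(Y)
print(Y_c)
# Use plt.scatter to plot the points (x0,x1) given by column 0 and column 1 of each row of X, colored by the matching entry of Y_c (c is short for color)
plt.scatter(X[:,0],X[:,1],c=np.squeeze(Y_c))
plt.show()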

Output:

(tf1.5) zhangkf@Ubuntu2:~/tf/tf4$ python opt4.py 
[[ -4.16757847e-01  -5.62668272e-02]
 [ -2.13619610e+00   1.64027081e+00]
 [ -1.79343559e+00  -8.41747366e-01]
 [  5.02881417e-01  -1.24528809e+00]
 [ -1.05795222e+00  -9.09007615e-01]
 [  5.51454045e-01   2.29220801e+00]
 [  4.15393930e-02  -1.11792545e+00]
 [  5.39058321e-01  -5.96159700e-01]
 [ -1.91304965e-02   1.17500122e+00]
 [ -7.47870949e-01   9.02525097e-03]
 [ -8.78107893e-01  -1.56434170e-01]
 [  2.56570452e-01  -9.88779049e-01]
 [ -3.38821966e-01  -2.36184031e-01]
 [ -6.37655012e-01  -1.18761229e+00]
 [ -1.42121723e+00  -1.53495196e-01]
 [ -2.69056960e-01   2.23136679e+00]
 [ -2.43476758e+00   1.12726505e-01]
 [  3.70444537e-01   1.35963386e+00]
 [  5.01857207e-01  -8.44213704e-01]
 [  9.76147160e-06   5.42352572e-01]
 [ -3.13508197e-01   7.71011738e-01]
 [ -1.86809065e+00   1.73118467e+00]
 [  1.46767801e+00  -3.35677339e-01]
 [  6.11340780e-01   4.79705919e-02]
 [ -8.29135289e-01   8.77102184e-02]
 [  1.00036589e+00  -3.81092518e-01]
 [ -3.75669423e-01  -7.44707629e-02]
 [  4.33496330e-01   1.27837923e+00]
 [ -6.34679305e-01   5.08396243e-01]
 [  2.16116006e-01  -1.85861239e+00]
 [ -4.19316482e-01  -1.32328898e-01]
 [ -3.95702397e-02   3.26003433e-01]
 [ -2.04032305e+00   4.62555231e-02]
 [ -6.77675577e-01  -1.43943903e+00]
 [  5.24296430e-01   7.35279576e-01]
 [ -6.53250268e-01   8.42456282e-01]
 [ -3.81516482e-01   6.64890091e-02]
 [ -1.09873895e+00   1.58448706e+00]
 [ -2.65944946e+00  -9.14526229e-02]
 [  6.95119605e-01  -2.03346655e+00]
 [ -1.89469265e-01  -7.72186654e-02]
 [  8.24703005e-01   1.24821292e+00]
 [ -4.03892269e-01  -1.38451867e+00]
 [  1.36723542e+00   1.21788563e+00]
 [ -4.62005348e-01   3.50888494e-01]
 [  3.81866234e-01   5.66275441e-01]
 [  2.04207979e-01   1.40669624e+00]
 [ -1.73795950e+00   1.04082395e+00]
 [  3.80471970e-01  -2.17135269e-01]
 [  1.17353150e+00  -2.34360319e+00]
 [  1.16152149e+00   3.86078048e-01]
 [ -1.13313327e+00   4.33092555e-01]
 [ -3.04086439e-01   2.58529487e+00]
 [  1.83533272e+00   4.40689872e-01]
 [ -7.19253841e-01  -5.83414595e-01]
 [ -3.25049628e-01  -5.60234506e-01]
 [ -9.02246068e-01  -5.90972275e-01]
 [ -2.76179492e-01  -5.16883894e-01]
 [ -6.98589950e-01  -9.28891925e-01]
 [  2.55043824e+00  -1.47317325e+00]
 [ -1.02141473e+00   4.32395701e-01]
 [ -3.23580070e-01   4.23824708e-01]
 [  7.99179995e-01   1.26261366e+00]
 [  7.51964849e-01  -9.93760983e-01]
 [  1.10914328e+00  -1.76491773e+00]
 [ -1.14421297e-01  -4.98174194e-01]
 [ -1.06079904e+00   5.91666521e-01]
 [ -1.83256574e-01   1.01985473e+00]
 [ -1.48246548e+00   8.46311892e-01]
 [  4.97940148e-01   1.26504175e-01]
 [ -1.41881055e+00  -2.51774118e-01]
 [ -1.54667461e+00  -2.08265194e+00]
 [  3.27974540e+00   9.70861320e-01]
 [  1.79259285e+00  -4.29013319e-01]
 [  6.96197980e-01   6.97416272e-01]
 [  6.01515814e-01   3.65949071e-03]
 [ -2.28247558e-01  -2.06961226e+00]
 [  6.10144086e-01   4.23496900e-01]
 [  1.11788673e+00  -2.74242089e-01]
 [  1.74181219e+00  -4.47500876e-01]
 [ -1.25542722e+00   9.38163671e-01]
 [ -4.68346260e-01  -1.25472031e+00]
 [  1.24823646e-01   7.56502143e-01]
 [  2.41439629e-01   4.97425649e-01]
 [  4.10869262e+00   8.21120877e-01]
 [  1.53176032e+00  -1.98584577e+00]
 [  3.65053516e-01   7.74082033e-01]
 [ -3.64479092e-01  -8.75979478e-01]
 [  3.96520159e-01  -3.14617436e-01]
 [ -5.93755583e-01   1.14950057e+00]
 [  1.33556617e+00   3.02629336e-01]
 [ -4.54227855e-01   5.14370717e-01]
 [  8.29458431e-01   6.30621967e-01]
 [ -1.45336435e+00  -3.38017777e-01]
 [  3.59133332e-01   6.22220414e-01]
 [  9.60781945e-01   7.58370347e-01]
 [ -1.13431848e+00  -7.07420888e-01]
 [ -1.22142917e+00   1.80447664e+00]
 [  1.80409807e-01   5.53164274e-01]
 [  1.03302907e+00  -3.29002435e-01]
 [ -1.15100294e+00  -4.26522471e-01]
 [ -1.48147191e-01   1.50143692e+00]
 [  8.69598198e-01  -1.08709057e+00]
 [  6.64221413e-01   7.34884668e-01]
 [ -1.06136574e+00  -1.08516824e-01]
 [ -1.85040397e+00   3.30488064e-01]
 [ -3.15693210e-01  -1.35000210e+00]
 [ -6.98170998e-01   2.39951198e-01]
 [ -5.52949440e-01   2.99526813e-01]
 [  5.52663696e-01  -8.40443012e-01]
 [ -3.12270670e-01   2.14467809e+00]
 [  1.21105582e-01  -8.46828752e-01]
 [  6.04624490e-02  -1.33858888e+00]
 [  1.13274608e+00   3.70304843e-01]
 [  1.08580640e+00   9.02179395e-01]
 [  3.90296450e-01   9.75509412e-01]
 [  1.91573647e-01  -6.62209012e-01]
 [ -1.02351498e+00  -4.48174823e-01]
 [ -2.50545813e+00   1.82599446e+00]
 [ -1.71406741e+00  -7.66395640e-02]
 [ -1.31756727e+00  -2.02559359e+00]
 [ -8.22453750e-02  -3.04666585e-01]
 [ -1.59724130e-01   5.48946560e-01]
 [ -6.18375485e-01   3.78794466e-01]
 [  5.13251444e-01  -3.34844125e-01]
 [ -2.83519516e-01   5.38424263e-01]
 [  5.72509465e-02   1.59088487e-01]
 [ -2.37440268e+00   5.85199353e-02]
 [  3.76545911e-01  -1.35479764e-01]
 [  3.35908395e-01   1.90437591e+00]
 [  8.53644334e-02   6.65334278e-01]
 [ -8.49995503e-01  -8.52341797e-01]
 [ -4.79985112e-01  -1.01964910e+00]
 [ -7.60113841e-03  -9.33830661e-01]
 [ -1.74996844e-01  -1.43714343e+00]
 [ -1.65220029e+00  -6.75661789e-01]
 [ -1.06706712e+00  -6.52931145e-01]
 [ -6.12094750e-01  -3.51262461e-01]
 [  1.04547799e+00   1.36901602e+00]
 [  7.25353259e-01  -3.59474459e-01]
 [  1.49695179e+00  -1.53111111e+00]
 [ -2.02336394e+00   2.67972576e-01]
 [ -2.20644541e-03  -1.39291883e-01]
 [  3.25654693e-02  -1.64056022e+00]
 [ -1.15669917e+00   1.23403468e+00]
 [  1.02818490e+00  -7.21879726e-01]
 [  1.93315697e+00  -1.07079633e+00]
 [ -5.71381608e-01   2.92432067e-01]
 [ -1.19499989e+00  -4.87930544e-01]
 [ -1.73071165e-01  -3.95346401e-01]
 [  8.70840765e-01   5.92806797e-01]
 [ -1.09929731e+00  -6.81530644e-01]
 [  1.80066685e-01  -6.69310440e-02]
 [ -7.87749540e-01   4.24753672e-01]
 [  8.19885117e-01  -6.31118683e-01]
 [  7.89059649e-01  -1.62167380e+00]
 [ -1.61049926e+00   4.99939764e-01]
 [ -8.34515207e-01  -9.96959687e-01]
 [ -2.63388077e-01  -6.77360492e-01]
 [  3.27067038e-01  -1.45535944e+00]
 [ -3.71519124e-01   3.16096597e+00]
 [  1.09951013e-01  -1.91352322e+00]
 [  5.99820429e-01   5.49384465e-01]
 [  1.38378103e+00   1.48349243e-01]
 [ -6.53541444e-01   1.40883398e+00]
 [  7.12061227e-01  -1.80071604e+00]
 [  7.47598942e-01  -2.32897001e-01]
 [  1.11064528e+00  -3.73338813e-01]
 [  7.86146070e-01   1.94168696e-01]
 [  5.86204098e-01  -2.03872918e-02]
 [ -4.14408598e-01   6.73134124e-02]
 [  6.31798924e-01   4.17592731e-01]
 [  1.61517627e+00   4.25606211e-01]
 [  6.35363758e-01   2.10222927e+00]
 [  6.61264168e-02   5.35558351e-01]
 [ -6.03140792e-01   4.19576292e-02]
 [  1.64191464e+00   3.11697707e-01]
 [  1.45116990e+00  -1.06492788e+00]
 [ -1.40084545e+00   3.07525527e-01]
 [ -1.36963867e+00   2.67033724e+00]
 [  1.24845030e+00  -1.24572655e+00]
 [ -1.67168774e-01  -5.76610930e-01]
 [  4.16021749e-01  -5.78472626e-02]
 [  9.31887358e-01   1.46833213e+00]
 [ -2.21320943e-01  -1.17315562e+00]
 [  5.62669078e-01  -1.64515057e-01]
 [  1.14485538e+00  -1.52117687e-01]
 [  8.29789046e-01   3.36065952e-01]
 [ -1.89044051e-01  -4.49328601e-01]
 [  7.13524448e-01   2.52973487e+00]
 [  8.37615794e-01  -1.31682403e-01]
 [  7.07592866e-01   1.14053878e-01]
 [ -1.28089518e+00   3.09846277e-01]
 [  1.54829069e+00  -3.15828043e-01]
 [ -1.12590378e+00   4.88496666e-01]
 [  1.83094666e+00   9.40175993e-01]
 [  1.01871705e+00   2.30237829e+00]
 [  1.62109298e+00   7.12683273e-01]
 [ -2.08703629e-01   1.37617991e-01]
 [ -1.03352168e-01   8.48350567e-01]
 [ -8.83125561e-01   1.54538683e+00]
 [  1.45840073e-01  -4.00106056e-01]
 [  8.15206041e-01  -2.07492237e+00]
 [ -8.34437391e-01  -6.57718447e-01]
 [  8.20564332e-01  -4.89157001e-01]
 [  1.42496703e+00  -4.46857897e-01]
 [  5.21109431e-01  -7.08194380e-01]
 [  1.15553059e+00  -2.54530459e-01]
 [  5.18924924e-01  -4.92994911e-01]
 [ -1.08654815e+00  -2.30917497e-01]
 [  1.09801004e+00  -1.01787805e+00]
 [ -1.52939136e+00  -3.07987737e-01]
 [  7.80754356e-01  -1.05583964e+00]
 [ -5.43883381e-01   1.84301739e-01]
 [ -3.30675843e-01   2.87208202e-01]
 [  1.18952814e+00   2.12015479e-02]
 [ -6.54096803e-02   7.66115904e-01]
 [ -6.16350846e-02  -9.52897152e-01]
 [ -1.01446306e+00  -1.11526396e+00]
 [  1.91260068e+00  -4.52632031e-02]
 [  5.76909718e-01   7.17805695e-01]
 [ -9.38998998e-01   6.28775807e-01]
 [ -5.64493432e-01  -2.08780746e+00]
 [ -2.15050132e-01  -1.07502856e+00]
 [ -3.37972149e-01   3.43212732e-01]
 [  2.28253964e+00  -4.95778848e-01]
 [ -1.63962832e-01   3.71622161e-01]
 [  1.86521520e-01  -1.58429224e-01]
 [ -1.08292956e+00  -9.56625520e-01]
 [ -1.83376735e-01  -1.15980690e+00]
 [ -6.57768362e-01  -1.25144841e+00]
 [  1.12448286e+00  -1.49783981e+00]
 [  1.90201722e+00  -5.80383038e-01]
 [ -1.05491567e+00  -1.18275720e+00]
 [  7.79480054e-01   1.02659795e+00]
 [ -8.48666001e-01   3.31539648e-01]
 [ -1.49591353e-01  -2.42440600e-01]
 [  1.51197175e-01   7.65069481e-01]
 [ -1.91663052e+00  -2.22734129e+00]
 [  2.06689897e-01  -7.08763560e-02]
 [  6.84759969e-01  -1.70753905e+00]
 [ -9.86569665e-01   1.54353634e+00]
 [ -1.31027053e+00   3.63433972e-01]
 [ -7.94872445e-01  -4.05286267e-01]
 [ -1.37775793e+00   1.18604868e+00]
 [ -1.90382114e+00  -1.19814038e+00]
 [ -9.10065643e-01   1.17645419e+00]
 [  2.99210670e-01   6.79267178e-01]
 [ -1.76606800e-02   2.36040923e-01]
 [  4.94035871e-01   1.54627765e+00]
 [  2.46857508e-01  -1.46877580e+00]
 [  1.14709994e+00   9.55569845e-02]
 [ -1.10743873e+00  -1.76286141e-01]
 [ -9.82755667e-01   2.08668273e+00]
 [ -3.44623671e-01  -2.00207923e+00]
 [  3.03234433e-01  -8.29874845e-01]
 [  1.28876941e+00   1.34925462e-01]
 [ -1.77860064e+00  -5.00791490e-01]
 [ -1.08816157e+00  -7.57855553e-01]
 [ -6.43744900e-01  -2.00878453e+00]
 [  1.96262894e-01  -8.75896370e-01]
 [ -8.93609209e-01   7.51902355e-01]
 [  1.89693224e+00  -6.29079151e-01]
 [  1.81208553e+00  -2.05626574e+00]
 [  5.62704887e-01  -5.82070757e-01]
 [ -7.40029749e-02  -9.86496364e-01]
 [ -5.94722499e-01  -3.14811843e-01]
 [ -3.46940532e-01   4.11443516e-01]
 [  2.32639090e+00  -6.34053128e-01]
 [ -1.54409962e-01  -1.74928880e+00]
 [ -2.51957930e+00   1.39116243e+00]
 [ -1.32934644e+00  -7.45596414e-01]
 [  2.12608498e-02   9.10917515e-01]
 [  3.15276082e-01   1.86620821e+00]
 [ -1.82497623e-01  -1.82826634e+00]
 [  1.38955717e-01   1.19450165e-01]
 [ -8.18899200e-01  -3.32639265e-01]
 [ -5.86387955e-01   1.73451634e+00]
 [ -6.12751558e-01  -1.39344202e+00]
 [  2.79433757e-01  -1.82223127e+00]
 [  4.27017458e-01   4.06987749e-01]
 [ -8.44308241e-01  -5.59820113e-01]
 [ -6.00520405e-01   1.61487324e+00]
 [  3.94953220e-01  -1.20381347e+00]
 [ -1.24747243e+00  -7.75462496e-02]
 [ -1.33397514e-02  -7.68323250e-01]
 [  2.91234010e-01  -1.97330948e-01]
 [  1.07682965e+00   4.37410232e-01]
 [ -9.31978663e-02   1.35631416e-01]
 [ -8.82708822e-01   8.84744194e-01]
 [  3.83204463e-01  -4.16994149e-01]
 [  1.17796550e-01  -5.36685309e-01]
 [  2.48718458e+00  -4.51361054e-01]
 [  5.18836127e-01   3.64448005e-01]
 [ -7.98348729e-01   5.65779713e-03]
 [ -3.20934708e-01   2.49513550e-01]
 [  2.56308392e-01   7.67625083e-01]
 [  7.83020087e-01  -4.07063047e-01]
 [ -5.24891667e-01  -5.89808683e-01]
 [ -8.62531086e-01  -1.74287290e+00]]
[[1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]]
[['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], 
['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue']]
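
The labels can be spot-checked against the first rows of the printed matrix. The sketch below applies the rule from Step 1 (label 1 when x0*x0 + x1*x1 < 2, otherwise 0) to the first four points of the output, rounded here to four decimals. The variable names are illustrative.

```python
# First four rows of X from the output above, rounded to four decimals.
points = [
    (-0.4168, -0.0563),
    (-2.1362, 1.6403),
    (-1.7934, -0.8417),
    (0.5029, -1.2453),
]

# Same rule as in Step 1: inside the circle of radius sqrt(2) gives 1, otherwise 0.
labels = [int(x0 * x0 + x1 * x1 < 2) for (x0, x1) in points]
colors = ['red' if y else 'blue' for y in labels]

print(labels)  # [1, 0, 0, 1]
print(colors)  # ['red', 'blue', 'blue', 'red']
```

These agree with the first four entries of the label column and of the color list printed above.
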

Resulting plot:

Explanation:

plt.scatter(): draws the points (x, y) in the specified colors.
plt.scatter(x coordinates, y coordinates, c="color"), where c is short for color.

Step 2: complete code

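#coding:utf-8
# 0. Import modules and generate a simulated dataset
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
BATCH_SIZE = 30 
seed = 2 
# Create a random number generator based on seed
rdm = np.random.RandomState(seed)
# Draw a 300 x 2 matrix of random numbers: 300 coordinate points (x0,x1) used as the input dataset
X = rdm.randn(300,2)
# For each row of X, assign the label 1 if the sum of squares of the two coordinates is less than 2, and 0 otherwise
# These labels are the correct answers for the input dataset
Y_ = [int(x0*x0 + x1*x1 <2) for (x0,x1) in X]
# Map each element of Y_ to a color: 1 becomes 'red' and everything else 'blue', so the classes are easy to tell apart in the plot
Y_c = [['red' if y else 'blue'] for y in Y_]
# Reshape X and the labels: -1 as the first element means that dimension is inferred from the second, which gives the number of columns; X becomes n x 2 and Y_ becomes n x 1
X = np.vstack(X).reshape(-1,2)
Y_ = np.vstack(Y_).reshape(-1,1)
print(X)
print(Y_)
print(Y_c)
# Use plt.scatter to plot the points (x0,x1) given by column 0 and column 1 of each row of X, colored by the matching entry of Y_c (c is short for color)
plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c)) 
plt.show()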


# Define the network's input, parameters, and output, and define the forward pass
def get_weight(shape, regularizer):
	w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
	tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
	return w

def get_bias(shape):  
    b = tf.Variable(tf.constant(0.01, shape=shape)) 
    return b
	
x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))

w1 = get_weight([2,11], 0.01)	
b1 = get_bias([11])
y1 = tf.nn.relu(tf.matmul(x, w1)+b1)

w2 = get_weight([11,1], 0.01)
b2 = get_bias([1])
y = tf.matmul(y1, w2)+b2 


# Define the loss function
loss_mse = tf.reduce_mean(tf.square(y-y_))
loss_total = loss_mse + tf.add_n(tf.get_collection('losses'))


# Define the backpropagation method: without regularization
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_mse)

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	STEPS = 40000
	for i in range(STEPS):
		start = (i*BATCH_SIZE) % 300
		end = start + BATCH_SIZE
		sess.run(train_step, feed_dict={x:X[start:end], y_:Y_[start:end]})
		if i % 2000 == 0:
			loss_mse_v = sess.run(loss_mse, feed_dict={x:X, y_:Y_})
			print("After %d steps, loss is: %f" %(i, loss_mse_v))
	# xx and yy each run from -3 to 3 in steps of 0.01, producing a 2-D grid of coordinate points
	xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
	# Flatten xx and yy and stack them into a 2-column matrix, giving the set of grid points
	grid = np.c_[xx.ravel(), yy.ravel()]
	# Feed the grid points to the network; probs is the output
	probs = sess.run(y, feed_dict={x:grid})
	# Reshape probs to the shape of xx
	probs = probs.reshape(xx.shape)
	print("w1:\n" + str(sess.run(w1)))
	print("b1:\n" + str(sess.run(b1)))
	print("w2:\n" + str(sess.run(w2)))
	print("b2:\n" + str(sess.run(b2)))

plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c))
plt.contour(xx, yy, probs, levels=[.5])
plt.show()



# Define the backpropagation method: with regularization
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_total)

with tf.Session() as sess:
	init_op = tf.global_variables_initializer()
	sess.run(init_op)
	STEPS = 40000
	for i in range(STEPS):
		start = (i*BATCH_SIZE) % 300
		end = start + BATCH_SIZE
		sess.run(train_step, feed_dict={x: X[start:end], y_:Y_[start:end]})
		if i % 2000 == 0:
			loss_v = sess.run(loss_total, feed_dict={x:X,y_:Y_})
			print("After %d steps, loss is: %f" %(i, loss_v))

	xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
	grid = np.c_[xx.ravel(), yy.ravel()]
	probs = sess.run(y, feed_dict={x:grid})
	probs = probs.reshape(xx.shape)
	print("w1:\n" + str(sess.run(w1)))
	print("b1:\n" + str(sess.run(b1)))
	print("w2:\n" + str(sess.run(w2)))
	print("b2:\n" + str(sess.run(b2)))

plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c)) 
plt.contour(xx, yy, probs, levels=[.5])
plt.show()
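
Both training loops above select a mini-batch with start = (i*BATCH_SIZE) % 300, which walks through the 300 examples in windows of 30 and wraps around every 10 steps. Below is a small plain-Python sketch of that indexing. The helper name batch_bounds and the constant N_SAMPLES are hypothetical names introduced only for this illustration.

```python
BATCH_SIZE = 30
N_SAMPLES = 300  # size of the dataset generated in Step 1


def batch_bounds(step):
    """Return the (start, end) slice bounds used at a given training step."""
    start = (step * BATCH_SIZE) % N_SAMPLES
    return start, start + BATCH_SIZE


for step in [0, 1, 9, 10, 11]:
    print(step, batch_bounds(step))
# 0 (0, 30)
# 1 (30, 60)
# 9 (270, 300)
# 10 (0, 30)
# 11 (30, 60)
```

Because 300 is a multiple of 30, every window is full and no example is skipped in a pass over the data.
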

Result 1: without regularization

 

Result 2: with regularization

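Notes:

tf.add_to_collection(name, value) puts a value into the collection named 'name'; the collection is kept as a list.

tf.get_collection(key, scope=None) returns all the elements of the collection named 'key'
as a list, in the order in which they were added. The optional scope argument is a
name scope: if it is given, only the items in 'key' that belong to that scope are returned; if it is omitted, all items are returned.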
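
To make the collection pattern concrete, here is a plain-Python stand-in that mimics the behavior described in the notes with a dictionary of lists. It is only an analogy under that assumption, not how TensorFlow implements graph collections, and it ignores the scope argument. The loss values are made up.

```python
# name -> list of values, kept in insertion order
_collections = {}


def add_to_collection(name, value):
    """Append value to the list stored under name."""
    _collections.setdefault(name, []).append(value)


def get_collection(key):
    """Return a copy of the list stored under key (empty if the key is unknown)."""
    return list(_collections.get(key, []))


# Each layer registers its own penalty at the place where the layer is built...
add_to_collection('losses', 0.25)    # hypothetical L2 penalty of layer 1
add_to_collection('losses', 0.125)   # hypothetical L2 penalty of layer 2
# ...and the data loss is registered later, possibly in another function.
add_to_collection('losses', 0.5)     # hypothetical mean squared error

total = sum(get_collection('losses'))
print(get_collection('losses'))  # [0.25, 0.125, 0.5]
print(total)                     # 0.875
```

The point of the pattern is that the code computing the total never needs a reference to the individual weight variables; it only needs the collection name.
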

2. Further Neural Network Optimization: The Moving Average Model

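Another method that can make a model more robust on test data is the moving average model. When a neural network is trained with stochastic gradient descent, using a moving average model improves the final model's performance on test data to some extent in many applications.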

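TensorFlow provides tf.train.ExponentialMovingAverage to implement the moving average model. A decay rate (decay) must be supplied at initialization; it controls how fast the model is updated. The moving average maintains a shadow variable for every variable. The shadow variable starts at the variable's initial value, and each time the variable is updated, the shadow variable is updated to:

\large shadow\_variable = decay \times shadow\_variable + ( 1 - decay ) \times variable

(decay is the decay rate)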

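As the formula shows, decay determines how fast the model is updated: the larger decay is, the more stable the model. In practice, decay is usually set very close to 1, for example 0.99 or 0.999. To let the model update faster early in training, the moving average also provides a num_updates argument that sets decay dynamically. If num_updates is supplied when the moving average is initialized, the decay rate used at each step is

\large \min \left\{ decay , \frac { 1 + num\_updates } { 10 + num\_updates } \right\}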

import tensorflow as tf
 
# Define a variable whose moving average will be computed; its initial value is 0
v1 = tf.Variable(0,dtype = tf.float32)
# step simulates the number of training iterations and can be used to control the decay rate dynamically
step = tf.Variable(0,trainable = False)
 
# Define a moving average object, initialized with decay 0.99 and the variable step that controls the decay rate
ema = tf.train.ExponentialMovingAverage(0.99,step)
# Define the op that updates the moving averages; it takes a list, and every variable in the list is updated each time the op runs
maintain_averages_op = ema.apply([v1])
 
with tf.Session() as sess:
    # Initialize all variables
    tf.global_variables_initializer().run()
    # ema.average(v1) gives the value of the moving average. Right after initialization, both v1 and its moving average are 0
    print(sess.run([v1,ema.average(v1)]))
    
    # Set v1 to 5
    sess.run(tf.assign(v1,5))
    # Update the moving average of v1; the decay rate is min{0.99, (1+step)/(10+step)} = min{0.99, 0.1} = 0.1
    sess.run(maintain_averages_op)
    print(sess.run([v1,ema.average(v1)])) # prints [5.0,4.5]
    
    # Set step to 10000
    sess.run(tf.assign(step,10000))
    # Set v1 to 10
    sess.run(tf.assign(v1,10))
    # Update the moving average of v1; the decay rate is min{0.99, (1+step)/(10+step)} = 0.99
    # so the moving average of v1 becomes 0.99*4.5 + 0.01*10 = 4.555
    sess.run(maintain_averages_op)
    print(sess.run([v1,ema.average(v1)])) # prints [10.0,4.5549998]
    
    # Update the moving average again; the new value is 0.99*4.555 + 0.01*10 = 4.60945
    sess.run(maintain_averages_op)
    print(sess.run([v1,ema.average(v1)])) # prints [10.0,4.6094499]
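
The arithmetic in the comments can be replayed without TensorFlow. The plain-Python sketch below assumes the update rule and the dynamic decay described in this section. The function names effective_decay and update_shadow are hypothetical, and the results are 64-bit floats, so they match the printed 32-bit values only approximately.

```python
def effective_decay(decay, num_updates=None):
    """min(decay, (1 + num_updates) / (10 + num_updates)) when num_updates is given."""
    if num_updates is None:
        return decay
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


def update_shadow(shadow, value, decay):
    """shadow = decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1.0 - decay) * value


shadow = 0.0                         # shadow variable starts at v1's initial value
d = effective_decay(0.99, 0)         # step = 0 gives decay 0.1
shadow = update_shadow(shadow, 5.0, d)
print(d, shadow)                     # average is about 4.5

d = effective_decay(0.99, 10000)     # step = 10000 gives decay 0.99
shadow = update_shadow(shadow, 10.0, d)
print(d, shadow)                     # average is about 4.555

shadow = update_shadow(shadow, 10.0, d)
print(shadow)                        # about 4.60945
```

With a small step count the decay is tiny, so the average jumps most of the way to the new value; with a large step count the decay is capped at 0.99 and the average moves slowly.
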

As the code above shows, the moving average model makes the model more stable in the later stages of training.

Partly based on: https://blog.csdn.net/qq_32023541/article/details/79607000


Reposted from blog.csdn.net/abc13526222160/article/details/84959271