神经网络过拟合问题

在训练数据不够多,网络结构很复杂,或者overtraining时,可能会产生过拟合问题。

一般我们会将整个数据集分为训练集training data、validation data,testing data。这个validation data是什么?它其实就是用来避免过拟合的,在训练过程中,我们通常用它来确定一些超参数(比如根据validation data上的accuracy来确定提前终止的epoch大小、根据validation data确定learning rate等等)。那为啥不直接在testing data上做这些呢?因为如果在testing data做这些,那么随着训练的进行,我们的网络实际上就是在一点一点地overfitting我们的testing data,导致最后得到的testing accuracy没有任何参考意义。因此,training data的作用是计算梯度更新权重,validation data如上所述,testing data则给出一个accuracy以判断网络的好坏。

那么过拟合的直观解释为,随着训练过程的进行,模型复杂度增加,在training data上的error渐渐减小,但是在验证集上的error却反而渐渐增大——因为训练出来的网络过拟合了训练集,对训练集外的数据效果不好。也就是如过产生了过拟合问题,那么用training data得到的准确率同testing data得到的准确率相差非常大。

                                                                

那么为了防止过拟合问题,可用的方法有:得到更大的数据集,正则化方法,在网络层dropout一下。下面主要对dropout和正则化方法做讨论。

1.dropout

                                    

dropout的实质就是随机的让每层的一些神经元不工作以减少模型的复杂度。

它为什么有助于防止过拟合呢?可以简单地这样解释,运用了dropout的训练过程,相当于训练了很多个只有半数隐层单元的神经网络(后面简称为“半数网络”),每一个这样的半数网络,都可以给出一个分类结果,这些结果有的是正确的,有的是错误的。随着训练的进行,大部分半数网络都可以给出正确的分类结果,那么少数的错误分类结果就不会对最终结果造成大的影响。

2.正则化

L2正则化就是在代价函数后面再加上一个正则化项:

C0代表原始的代价函数,后面那一项就是L2正则化项,它是这样来的:所有参数w的平方的和,除以训练集的样本大小n。λ就是正则项系数,权衡正则项与C0项的比重。另外还有一个系数1/2,1/2经常会看到,主要是为了后面求导的结果方便,后面那一项求导会产生一个2,与1/2相乘刚好凑整。

L2正则化项是怎么避免overfitting的呢?我们推导一下看看,先求导:

可以发现L2正则化项对b的更新没有影响,但是对于w的更新有影响:

在不使用L2正则化时,求导结果中w前系数为1,现在w前面系数为 1−ηλ/n ,因为η、λ、n都是正的,所以 1−ηλ/n小于1,它的效果是减小w,这也就是权重衰减(weight decay)的由来当然考虑到后面的导数项,w最终的值可能增大也可能减小

另外,需要提一下,对于基于mini-batch的随机梯度下降,w和b更新的公式跟上面给出的有点不同:

对比上面w的更新公式,可以发现后面那一项变了,变成所有导数加和,乘以η再除以m,m是一个mini-batch中样本的个数。

到目前为止,我们只是解释了L2正则化项有让w“变小”的效果,但是还没解释为什么w“变小”可以防止overfitting?一个所谓“显而易见”的解释就是:更小的权值w,从某种意义上说,表示网络的复杂度更低,对数据的拟合刚刚好(这个法则也叫做奥卡姆剃刀),而在实际应用中,也验证了这一点,L2正则化的效果往往好于未经正则化的效果。当然,对于很多人(包括我)来说,这个解释似乎不那么显而易见,所以这里添加一个稍微数学一点的解释(引自知乎):

过拟合的时候,拟合函数的系数往往非常大,为什么?如下图所示,过拟合,就是拟合函数需要顾忌每一个点,最终形成的拟合函数波动很大。在某些很小的区间里,函数值的变化很剧烈。这就意味着函数在某些小区间里的导数值(绝对值)非常大,由于自变量值可大可小,所以只有系数足够大,才能保证导数值很大。

而正则化是通过约束参数的范数使其不要太大,所以可以在一定程度上减少过拟合情况。

 L1正则化:

在原始的代价函数后面加上一个L1正则化项,即所有权重w的绝对值的和,乘以λ/n(这里不像L2正则化项那样,需要再乘以1/2,具体原因上面已经说过。)

同样先计算导数:

上式中sgn(w)表示w的符号。那么权重w的更新规则为:

比原始的更新规则多出了η * λ * sgn(w)/n这一项。当w为正时,更新后的w变小。当w为负时,更新后的w变大——因此它的效果就是让w往0靠,使网络中的权重尽可能为0,也就相当于减小了网络复杂度,防止过拟合。

另外,上面没有提到一个问题,当w为0时怎么办?当w等于0时,|W|是不可导的,所以我们只能按照原始的未经正则化的方法去更新w,这就相当于去掉η*λ*sgn(w)/n这一项,所以我们可以规定sgn(0)=0,这样就把w=0的情况也统一进来了。(在编程的时候,令sgn(0)=0,sgn(w>0)=1,sgn(w<0)=-1)

下一篇将讨论优化器(optimizer)问题。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

'''
##简单版本
##只有两层网络,输入层和输出层,loss加入正则化项
'''

mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

batch_size = 100
n_batch = mnist.train.num_examples // batch_size

x_train = tf.placeholder(tf.float32,[None,784])
y_train = tf.placeholder(tf.float32,[None,10])

w = tf.Variable(tf.zeros([784,10]))
bias = tf.Variable(tf.zeros([1,10]))
y = tf.nn.softmax(tf.matmul(x_train,w) + bias)

loss_mse = tf.reduce_mean(tf.square(y - y_train))   #MSE代价函数
#loss_cr = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_train,logits=y)) #交叉熵代价函数
tf.add_to_collection("losses",loss_mse)
#获取整个模型的损失函数,tf.get_collection("losses")返回集合中定义的损失
#将整个集合中的损失相加得到整个模型的损失函数
loss = tf.add_n(tf.get_collection("losses"))

#train = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
#train = tf.train.AdadeltaOptimizer(1e-3).minimize(loss)
train = tf.train.MomentumOptimizer(0.1,0.9).minimize(loss)

init = tf.global_variables_initializer()

correct = tf.equal(tf.argmax(y_train,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    loss_mse_list = []
    loss_list = []
    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train,feed_dict={x_train:batch_xs,y_train:batch_ys}) 
            loss_mse_list.append(sess.run(loss_mse,feed_dict={x_train:batch_xs,y_train:batch_ys}))
            loss_list.append(sess.run(loss,feed_dict={x_train:batch_xs,y_train:batch_ys}))
            
        acc = sess.run(accuracy,feed_dict={x_train:mnist.test.images,y_train:mnist.test.labels})
        print('iteration ', str(epoch),' accuracy: ',acc)
    plt.plot(loss_mse_list)
    plt.plot(loss_list)
    plt.legend(('loss_mse','loss'))

猜你喜欢

转载自www.cnblogs.com/54hys/p/10199359.html