机器学习 - 解决梯度消失的方法(BatchNorm, Relu, residual network)

1. Batch Normalization

1.1 简介

Batch Normalization作为最近一年来DL的重要成果,已经广泛被证明其有效性和重要性。本节是对论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》的导读。

机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。

接下来一步一步的理解什么是BN。

**为什么深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢?**很多论文都是解决这个问题的,比如ReLU激活函数,再比如Residual Network,BN本质上也是解释并从某个不同的角度来解决这个问题的。

1.2 “Internal Covariate Shift”问题

从论文名字可以看出,BN是用来解决“Internal Covariate Shift”问题的,那么首先得理解什么是“Internal Covariate Shift”?

论文首先说明Mini-Batch SGD相对于One Example SGD的两个优势:梯度更新方向更准确;并行计算速度快;(为什么要说这些?因为BatchNorm是基于Mini-Batch SGD的);然后吐槽下SGD训练的缺点:超参数调起来很麻烦。(作者隐含意思是用BN就能解决很多SGD的缺点)

接着引入covariate shift的概念:如果ML系统实例集合<X,Y>中的输入值X的分布老是变,这不符合IID假设,网络模型很难稳定的学规律。对于深度学习这种包含很多隐层的网络结构,在训练过程中,因为各层参数不停在变化,所以每个隐层都会面临covariate shift的问题,也就是在训练过程中,隐层的输入分布老是变来变去,这就是所谓的“Internal Covariate Shift”,Internal指的是深层网络的隐层,是发生在网络内部的事情,而不是covariate shift问题只发生在输入层。

然后提出了BatchNorm的基本思想:能不能让每个隐层节点的激活输入分布固定下来呢?这样就避免了“Internal Covariate Shift”问题了。

BN不是凭空拍脑袋拍出来的好点子,它是有启发来源的:之前的研究表明如果在图像处理中对输入图像进行白化(Whiten)操作的话——所谓白化,就是对输入数据分布变换到0均值,单位方差的正态分布——那么神经网络会较快收敛,那么BN作者就开始推论了:图像是深度神经网络的输入层,做白化能加快收敛,那么其实对于深度网络来说,其中某个隐层的神经元是下一层的输入,意思是其实深度神经网络的每一个隐层都是“输入层”,不过是相对下一层来说而已,那么能不能对每个隐层都做白化呢?这就是启发BN产生的原初想法,而BN也确实就是这么做的,可以理解为对深层神经网络每个隐层神经元的激活值做白化操作。

1.3 BatchNorm的本质思想

BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值( W x + b Wx+b Wx+b)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(死区)(对于Sigmoid函数来说,意味着激活输入值是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。

其实一句话就是:**对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。**因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,也就是说收敛地快。BN说到底就是这么个机制,方法很简单,道理很深刻。
图1
在这里插入图片描述

​​​
由图1知,64%的概率x落在[-1,1]的范围内,95%的概率x落在[-2,2]的范围内。这就把x值限制在了图2 sigmoid函数不饱和的地方。

但是很明显,看到这里,稍微了解神经网络的读者一般会提出一个疑问:如果都通过BN,那么不就跟把非线性函数替换成线性函数效果相同了?这意味着什么?我们知道,如果是多层的线性函数变换其实这个深层是没有意义的,因为多层线性网络跟一层线性网络是等价的。所以BN为了保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。核心思想应该是想找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。当然,这是我的理解,论文作者并未明确这样说。但是很明显这里的scale和shift操作是会有争议的,因为按照论文作者论文里写的理想状态,就会又通过scale和shift操作把变换后的x调整回未变换的状态,那不是饶了一圈又绕回去原始的“Internal Covariate Shift”问题里去了吗,感觉论文作者并未能够清楚地解释scale和shift操作的理论原因。

1.4 BatchNorm的训练过程

上面是对BN的抽象分析和解释,具体在Mini-Batch SGD下做BN怎么做?其实论文里面这块写得很清楚也容易理解。

假设对于一个深层神经网络来说,其中的两层全连接如下:
在这里插入图片描述

其中,第t-1层有m个神经元, 输出为 y t − 1 y^{t-1} yt1, 进行计算 x t = W y t − 1 + b x^t = Wy^{t-1} +b xt=Wyt1+b之后得到第t层的n个神经元的输入。因为使用了小批量梯度下降,所以每次都是同时计算BATCHSIZE个 x t x^t xt。Batch Normalization就是要对每一个t层神经元的输入在这个BATCH上做正则化。

要对每个隐层神经元的激活值做BN,可以想象成每个隐层又加上了一层BN操作层,它位于激活值 W x + b Wx+b Wx+b获得之后,非线性函数变换之前,其图示如下:
在这里插入图片描述

具体操作流程:
在这里插入图片描述

1.5 BatchNorm的测试过程

BN在训练的时候可以根据Mini-Batch里的若干训练实例进行激活数值调整,但是在测试的过程中,很明显输入就只有一个实例,看不到Mini-Batch其它实例,那么这时候怎么对输入做BN呢?因为很明显一个实例是没法算实例集合求出的均值和方差的。

既然没有从Mini-Batch数据里可以得到的统计量,那就想其它办法来获得这个统计量,就是均值和方差。可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在测试的时候直接用全局统计量即可。

有了均值和方差,每个隐层神经元也已经有对应训练好的Scaling参数和Shift参数,就可以在推导的时候对每个神经元的激活数据计算BN进行变换了。

1.6 BatchNorm的好处

① 极大提升了训练速度,收敛过程大大加快;(因为解决了梯度消失的问题)

② 增加分类效果,一种解释是这是类似于Dropout的一种防止过拟合的正则化表达方式,所以不用Dropout也能达到相当的效果;

③ 另外调参过程也简单多了,对于初始化要求没那么高,而且可以使用大的学习率等。

1.7 如何防止过拟合

Batch Normalization的主要作用是加快网络的训练速度。如果硬要说是防止过拟合,可以这样理解:

BN每次的mini-batch的数据都不一样,但是每次的mini-batch的数据都会对moving mean和moving variance产生作用,可以认为是引入了噪声,这就可以认为是进行了data augmentation,而data augmentation被认为是防止过拟合的一种方法。因此,可以认为用BN可以防止过拟合。

2. Relu 激活函数

Q: 为什么通常Relu比sigmoid和tanh强?
A:

  • 对于深层的网络而言,Sigmoid和tanh函数反向传播的过程中,饱和区域非常平缓,接近于0,容易出现梯度消失的问题,减缓收敛速度。Relu的gradient大多数情况下是常数,有助于解决深层网络的收敛问题。

  • ReLU会使一部分神经元的输出为0,这样就造成了网络的稀疏性(类似dropout),并且减少了参数的相互依存关系,缓解了过拟合问题的发生

  • 采用sigmoid等函数,算激活函数时(指数运算),计算量大 ;反向传播求误差梯度时,求导涉及除法,计算量也大。而采用Relu激活函数,整个过程的计算量节省很多

sigmoid / tanh比较常见于全连接层,relu常见于卷积层。

3. Residual Network

猜你喜欢

转载自blog.csdn.net/weixin_41332009/article/details/113834419