干货 | 这可能全网最好的BatchNorm详解

文章来自:公众号【机器学习炼丹术】。求关注~

其实关于BN层,我在之前的文章“梯度爆炸”那一篇中已经涉及到了,但是鉴于面试经历中多次问道这个,这里再做一个更加全面的讲解。

Internal Covariate Shift(ICS)

Batch Normalization的原论文作者给了Internal Covariate Shift一个较规范的定义:在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift。

这里做一个简单的数学定义,对于全链接网络而言,第i层的数学表达可以体现为:
\(Z^i=W^i\times input^i+b^i\)
\(input^{i+1}=g^i(Z^i)\)

第一个公式就是一个简单的线性变换;

第二个公式是表示一个激活函数的过程。

【怎么理解ICS问题】
我们知道,随着梯度下降的进行,每一层的参数\(W^i,b^i\)都会不断地更新,这意味着\(Z^i\)的分布也不断地改变,从而\(input^{i+1}\)的分布发生了改变。这意味着,除了第一层的输入数据不改变,之后所有层的输入数据的分布都会随着模型参数的更新发生改变,而每一层就要不停的去适应这种数据分布的变化,这个过程就是Internal Covariate Shift。

BN解决的问题

【ICS带来的收敛速度慢】
因为每一层的参数不断发生变化,从而每一层的计算结果的分布发生变化,后层网络不断地适应这种分布变化,这个时候会让整个网络的学习速度过慢。

【梯度饱和问题】
因为神经网络中经常会采用sigmoid,tanh这样的饱和激活函数(saturated actication function),因此模型训练有陷入梯度饱和区的风险。解决这样的梯度饱和问题有两个思路:第一种就是更为非饱和性激活函数,例如线性整流函数ReLU可以在一定程度上解决训练进入梯度饱和区的问题。另一种思路是,我们可以让激活函数的输入分布保持在一个稳定状态来尽可能避免它们陷入梯度饱和区,这也就是Normalization的思路。

Batch Normalization

batchNormalization就像是名字一样,对一个batch的数据进行normalization。

现在假设一个batch有3个数据,每个数据有两个特征:(1,2),(2,3),(0,1)

如果做一个简单的normalization,那么就是计算均值和方差,把数据减去均值除以标准差,变成0均值1方差的标准形式。

对于第一个特征来说:
\(\mu=\frac{1}{3}(1+2+0)=1\)
\(\sigma^2=\frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67\)

【通用公式】
\(\mu=\frac{1}{m}\sum_{i=1}^m{Z}\)
\(\sigma^2=\frac{1}{m}\sum_{i=1}^m(Z-\mu)\)
\(\hat{Z}=\frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}\)

其中m表示一个batch的数量。

\(\epsilon\)是一个极小数,防止分母为0。

目前为止,我们做到了让每个特征的分布均值为0,方差为1。这样分布都一样,一定不会有ICS问题

如同上面提到的,Normalization操作我们虽然缓解了ICS问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。每一层的分布都相同,所有任务的数据分布都相同,模型学啥呢

【0均值1方差数据的弊端】

数据表达能力的缺失;

通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域。(线性区域和饱和区域都不理想,最好是非线性区域)

为了解决这个问题,BN层引入了两个可学习的参数\(\gamma\)\(\beta\),这样,经过BN层normalization的数据其实是服从\(\beta\)均值,\(\gamma^2\)方差的数据。

所以对于某一层的网络来说,我们现在变成这样的流程:

\(Z=W\times input^i+b\)

\(\hat{Z}=\gamma \times \frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta\)

\(input^{i+1}=g(\hat{Z})\)

(上面公式中,省略了\(i\),总的来说是表示第i层的网络层产生第i+1层输入数据的过程)

测试阶段的BN

我们知道BN在每一层计算的\(\mu\)\(\sigma^2\) 都是基于当前batch中的训练数据,但是这就带来了一个问题:我们在预测阶段,有可能只需要预测一个样本或很少的样本,没有像训练样本中那么多的数据,这样的\(\sigma^2\)\(\mu\)要怎么计算呢?

利用训练集训练好模型之后,其实每一层的BN层都保留下了每一个batch算出来的\(\mu\)\(\sigma^2\).然后呢利用整体的训练集来估计测试集的\(\mu_{test}\)\(\sigma_{test}^2\)
\(\mu_{test}=E(\mu_{train})\)
\(\sigma_{test}^2=\frac{m}{m-1}E(\sigma_{train}^2)\)
然后再对测试机进行BN层:

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zwgdjw.html