权重梯度的计算

现在,我们终于来到了BPTT算法的最后一步:计算每个权重的梯度。

首先,我们计算误差函数E对权重矩阵W的梯度∂EW

循环神经网络的训练(2)

上图展示了我们到目前为止,在前两步中已经计算得到的量,包括每个时刻t 循环层的输出值st,以及误差项δt

回忆一下我们在文章零基础入门深度学习(3) - 神经网络和反向传播算法介绍的全连接网络的权重梯度计算算法:只要知道了任意一个时刻的误差项δt,以及上一个时刻循环层的输出值st−1,就可以按照下面的公式求出权重矩阵在t时刻的梯度∇WtE

 

WtE=⎡⎣⎢⎢⎢⎢⎢⎢δt1st−11δt2st−11..δtnst−11δt1st−12δt2st−12δtnst−12.........δt1st−1nδt2st−1nδtnst−1n⎤⎦⎥⎥⎥⎥⎥⎥(式5)

 

式5中,δti表示t时刻误差项向量的第i个分量;st−1i表示t-1时刻循环层第i个神经元的输出值。

我们下面可以简单推导一下式5

我们知道:

 

nett=⎡⎣⎢⎢⎢⎢⎢nett1nett2..nettn⎤⎦⎥⎥⎥⎥⎥==Uxt+Wst−1Uxt+⎡⎣⎢⎢⎢⎢w11w21..wn1w12w22wn2.........w1nw2nwnn⎤⎦⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢st−11st−12..st−1n⎤⎦⎥⎥⎥⎥⎥Uxt+⎡⎣⎢⎢⎢⎢⎢w11st−11+w12st−12...w1nst−1nw21st−11+w22st−12...w2nst−1n..wn1st−11+wn2st−12...wnnst−1n⎤⎦⎥⎥⎥⎥⎥(44)(45)(46)

 

因为对W求导与Uxt无关,我们不再考虑。现在,我们考虑对权重项wji求导。通过观察上式我们可以看到wji只与nettj有关,所以:

 

Ewji==∂Enettjnettjwjiδtjst−1i(47)(48)

 

按照上面的规律就可以生成式5里面的矩阵。

我们已经求得了权重矩阵W在t时刻的梯度∇WtE,最终的梯度∇WE是各个时刻的梯度之和

 

WE==∑i=1tWiE⎡⎣⎢⎢⎢⎢⎢⎢δt1st−11δt2st−11..δtnst−11δt1st−12δt2st−12δtnst−12.........δt1st−1nδt2st−1nδtnst−1n⎤⎦⎥⎥⎥⎥⎥⎥+...+⎡⎣⎢⎢⎢⎢⎢⎢δ11s01δ12s01..δ1ns01δ11s02δ12s02δ1ns02.........δ11s0nδ12s0nδ1ns0n⎤⎦⎥⎥⎥⎥⎥⎥(式6)(49)(50)

 

式6就是计算循环层权重矩阵W的梯度的公式。

----------数学公式超高能预警----------

前面已经介绍了∇WE的计算方法,看上去还是比较直观的。然而,读者也许会困惑,为什么最终的梯度是各个时刻的梯度之和呢?我们前面只是直接用了这个结论,实际上这里面是有道理的,只是这个数学推导比较绕脑子。感兴趣的同学可以仔细阅读接下来这一段,它用到了矩阵对矩阵求导、张量与向量相乘运算的一些法则。

我们还是从这个式子开始:

 

nett=Uxt+Wf(nett−1)

 

因为Uxt与W完全无关,我们把它看做常量。现在,考虑第一个式子加号右边的部分,因为W和f(nett−1)都是W的函数,因此我们要用到大学里面都学过的导数乘法运算:

 

(uv)′=uv+uv

 

因此,上面第一个式子写成:

 

∂nettW=∂WWf(nett−1)+Wf(nett−1)∂W

 

我们最终需要计算的是∇WE

 

WE===∂EWE∂nett∂nettWδTtWWf(nett−1)+δTtWf(nett−1)∂W(式7)(51)(52)(53)

 

我们先计算式7加号左边的部分。∂WW矩阵对矩阵求导,其结果是一个四维张量(tensor),如下所示:

 

WW===⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂w11∂Ww21∂W..∂wn1∂Ww12∂Ww22∂Wwn2∂W.........∂w1nWw2nWwnnW⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂w11∂w11∂w11∂w21..∂w11∂wn1∂w11∂w12∂w11∂w22∂w11∂wn2.........∂w11∂1nw11∂2nw11∂nn⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂w12∂w11∂w12∂w21..∂w12∂wn1∂w12∂w12∂w12∂w22∂w12∂wn2.........∂w12∂1nw12∂2nw12∂nn⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢10..0000.........000⎤⎦⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢00..0100.........000⎤⎦⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥(54)(55)(56)

 

接下来,我们知道st−1=f(nett−1),它是一个列向量。我们让上面的四维张量与这个向量相乘,得到了一个三维张量,再左乘行向量δTt,最终得到一个矩阵:

 

δTtWWf(nett−1)======δTtWWst−1δTt⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢10..0000.........000⎤⎦⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢00..0100.........000⎤⎦⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢st−11st−12..st−1n⎤⎦⎥⎥⎥⎥⎥δTt⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢⎢st−110..0⎤⎦⎥⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢⎢st−120..0⎤⎦⎥⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥[δt1δt2...δtn]⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢⎢st−110..0⎤⎦⎥⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢⎢st−120..0⎤⎦⎥⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢δt1st−11δt2st−11..δtnst−11δt1st−12δt2st−12δtnst−12.........δt1st−1nδt2st−1nδtnst−1n⎤⎦⎥⎥⎥⎥⎥⎥∇WtE(57)(58)(59)(60)(61)(62)

 

接下来,我们计算式7加号右边的部分:

 

δTtWf(nett−1)∂W====δTtWf(nett−1)∂nett−1∂nett−1∂WδTtWf′(nett−1)∂nett−1∂WδTt∂nett∂nett−1∂nett−1∂WδTt−1∂nett−1∂W(63)(64)(65)(66)

 

于是,我们得到了如下递推公式:

 

WE======∂EWE∂nett∂nettWWtE+δTt−1∂nett−1∂WWtE+∇Wt−1E+δTt−2∂nett−2∂WWtE+∇Wt−1E+...+∇W1Ek=1tWkE(67)(68)(69)(70)(71)(72)

 

这样,我们就证明了:最终的梯度∇WE是各个时刻的梯度之和。

----------数学公式超高能预警解除----------

同权重矩阵W类似,我们可以得到权重矩阵U的计算方法。

 

UtE=⎡⎣⎢⎢⎢⎢⎢⎢δt1xt1δt2xt1..δtnxt1δt1xt2δt2xt2δtnxt2.........δt1xtmδt2xtmδtnxtm⎤⎦⎥⎥⎥⎥⎥⎥(式8)

 

式8是误差函数在t时刻对权重矩阵U的梯度。和权重矩阵W一样,最终的梯度也是各个时刻的梯度之和:

 

UE=∑i=1tUiE

 

具体的证明这里就不再赘述了,感兴趣的读者可以练习推导一下。

RNN的梯度爆炸和消失问题

不幸的是,实践中前面介绍的几种RNNs并不能很好的处理较长的序列。一个主要的原因是,RNN在训练中很容易发生梯度爆炸梯度消失,这导致训练时梯度不能在较长序列中一直传递下去,从而使RNN无法捕捉到长距离的影响。

为什么RNN会产生梯度爆炸和消失问题呢?我们接下来将详细分析一下原因。我们根据式3可得:

 

δTk=∥δTk∥⩽⩽δTti=kt−1Wdiag[f′(neti)]∥δTt∥∏i=kt−1∥W∥∥diag[f′(neti)]∥∥δTt∥(βWβf)tk(73)(74)(75)

 

上式的β定义为矩阵的模的上界。因为上式是一个指数函数,如果t-k很大的话(也就是向前看很远的时候),会导致对应的误差项的值增长或缩小的非常快,这样就会导致相应的梯度爆炸梯度消失问题(取决于β大于1还是小于1)。

通常来说,梯度爆炸更容易处理一些。因为梯度爆炸的时候,我们的程序会收到NaN错误。我们也可以设置一个梯度阈值,当梯度超过这个阈值的时候可以直接截取。

梯度消失更难检测,而且也更难处理一些。总的来说,我们有三种方法应对梯度消失问题:

  1. 合理的初始化权重值。初始化权重,使每个神经元尽可能不要取极大或极小值,以躲开梯度消失的区域。
  2. 使用relu代替sigmoid和tanh作为**函数。原理请参考上一篇文章零基础入门深度学习(4) - 卷积神经网络**函数一节。
  3. 使用其他结构的RNNs,比如长短时记忆网络(LTSM)和Gated Recurrent Unit(GRU),这是最流行的做法。我们将在以后的文章中介绍这两种网络。