定义模型
简单起见,我们考虑一个无偏差项的循环神经网络 ,且**函数为恒等映射(ϕ ( x ) = x phi(x)=x ϕ ( x ) = x )。设时间步 t t t 的输入为单样本 x t ∈ R d boldsymbol{x}_t in mathbb{R}^d x t ∈ R d ,标签为 y t y_t y t ,那么隐藏状态 h t ∈ R h boldsymbol{h}_t in mathbb{R}^h h t ∈ R h 的计算表达式为
h t = W h x x t + W h h h t − 1 , boldsymbol{h}_t = boldsymbol{W}_{hx} boldsymbol{x}_t + boldsymbol{W}_{hh} boldsymbol{h}_{t-1}, h t = W h x x t + W h h h t − 1 ,
其中W h x ∈ R h × d boldsymbol{W}_{hx} in mathbb{R}^{h times d} W h x ∈ R h × d 和W h h ∈ R h × h boldsymbol{W}_{hh} in mathbb{R}^{h times h} W h h ∈ R h × h 是隐藏层权重参数。设输出层权重参数W q h ∈ R q × h boldsymbol{W}_{qh} in mathbb{R}^{q times h} W q h ∈ R q × h ,时间步t t t 的输出层变量o t ∈ R q boldsymbol{o}_t in mathbb{R}^q o t ∈ R q 计算为
o t = W q h h t . boldsymbol{o}_t = boldsymbol{W}_{qh} boldsymbol{h}_{t}. o t = W q h h t .
设时间步t t t 的损失为ℓ ( o t , y t ) ell(boldsymbol{o}_t, y_t) ℓ ( o t , y t ) 。时间步数为T T T 的损失函数L L L 定义为
L = 1 T ∑ t = 1 T ℓ ( o t , y t ) . L = frac{1}{T} sum_{t=1}^T ell (boldsymbol{o}_t, y_t). L = T 1 t = 1 ∑ T ℓ ( o t , y t ) .
我们将L L L 称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。
模型计算图
为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图所示。例如,时间步3的隐藏状态h 3 boldsymbol{h}_3 h 3 的计算依赖模型参数W h x boldsymbol{W}_{hx} W h x 、W h h boldsymbol{W}_{hh} W h h 、上一时间步隐藏状态h 2 boldsymbol{h}_2 h 2 以及当前时间步输入x 3 boldsymbol{x}_3 x 3 。
方法
刚刚提到,图6.3中的模型的参数是 W h x boldsymbol{W}_{hx} W h x , W h h boldsymbol{W}_{hh} W h h 和 W q h boldsymbol{W}_{qh} W q h 。与一般的反向传播类似,训练模型通常需要模型参数的梯度∂ L ∂ W h x frac{partial L}{partial boldsymbol{W}_{hx}} ∂ W h x ∂ L 、∂ L ∂ W h h frac{partial L}{partial boldsymbol{W}_{hh}} ∂ W h h ∂ L 和∂ L ∂ W q h frac{partial L}{partial boldsymbol{W}_{qh}} ∂ W q h ∂ L 。 根据上图的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。
首先,目标函数有关各时间步输出层变量的梯度∂ L ∂ o t ∈ R q frac{partial L}{partial boldsymbol{o}_t} in mathbb{R}^q ∂ o t ∂ L ∈ R q 很容易计算:
∂ L ∂ o t = ∂ ℓ ( o t , y t ) T ⋅ ∂ o t . frac{partial L}{partial boldsymbol{o}_t} = frac{partial ell (boldsymbol{o}_t, y_t)}{T cdot partial boldsymbol{o}_t}. ∂ o t ∂ L = T ⋅ ∂ o t ∂ ℓ ( o t , y t ) .
下面,我们可以计算目标函数有关模型参数W q h boldsymbol{W}_{qh} W q h 的梯度∂ L ∂ W q h ∈ R q × h frac{partial L}{partial boldsymbol{W}_{qh}} in mathbb{R}^{q times h} ∂ W q h ∂ L ∈ R q × h 。根据计算图,L L L 通过o 1 , … , o T boldsymbol{o}_1, ldots, boldsymbol{o}_T o 1 , … , o T 依赖W q h boldsymbol{W}_{qh} W q h 。依据链式法则,
∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ . frac{partial L}{partial boldsymbol{W}{qh}} = sum_{t=1}^T text{prod}left(frac{partial L}{partial boldsymbol{o}_t}, frac{partial boldsymbol{o}_t}{partial boldsymbol{W}_{qh}}right) = sum_{t=1}^T frac{partial L}{partial boldsymbol{o}_t} boldsymbol{h}_t^top. ∂ W q h ∂ L = t = 1 ∑ T prod ( ∂ o t ∂ L , ∂ W q h ∂ o t ) = t = 1 ∑ T ∂ o t ∂ L h t ⊤ .
其次,我们注意到隐藏状态之间也存在依赖关系。 在计算图中,L L L 只通过o T boldsymbol{o}_T o T 依赖最终时间步T T T 的隐藏状态h T boldsymbol{h}_T h T 。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度∂ L ∂ h T ∈ R h frac{partial L}{partial boldsymbol{h}_T} in mathbb{R}^h ∂ h T ∂ L ∈ R h 。依据链式法则,我们得到
∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . frac{partial L}{partial boldsymbol{h}_T} = text{prod}left(frac{partial L}{partial boldsymbol{o}_T}, frac{partial boldsymbol{o}_T}{partial boldsymbol{h}_T} right) = boldsymbol{W}_{qh}^top frac{partial L}{partial boldsymbol{o}_T}. ∂ h T ∂ L = prod ( ∂ o T ∂ L , ∂ h T ∂ o T ) = W q h ⊤ ∂ o T ∂ L .
接下来对于时间步t < T t < T t < T , 在计算图中,L L L 通过h t + 1 boldsymbol{h}_{t+1} h t + 1 和o t boldsymbol{o}_t o t 依赖h t boldsymbol{h}_t h t 。依据链式法则, 目标函数有关时间步t < T t < T t < T 的隐藏状态的梯度∂ L ∂ h T ∈ R h frac{partial L}{partial boldsymbol{h}_T} in mathbb{R}^h ∂ h T ∂ L ∈ R h 需要按照时间步从大到小依次计算: ∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t frac{partial L}{partial boldsymbol{h}_t} = text{prod} (frac{partial L}{partial boldsymbol{h}_{t+1}}, frac{partial boldsymbol{h}_{t+1}}{partial boldsymbol{h}_t}) + text{prod} (frac{partial L}{partial boldsymbol{o}_t}, frac{partial boldsymbol{o}_t}{partial boldsymbol{h}_t} ) = boldsymbol{W}_{hh}^top frac{partial L}{partial boldsymbol{h}_{t+1}} + boldsymbol{W}_{qh}^top frac{partial L}{partial boldsymbol{o}_t} ∂ h t ∂ L = prod ( ∂ h t + 1 ∂ L , ∂ h t ∂ h t + 1 ) + prod ( ∂ o t ∂ L , ∂ h t ∂ o t ) = W h h ⊤ ∂ h t + 1 ∂ L + W q h ⊤ ∂ o t ∂ L
将上面的递归公式展开,对任意时间步1 ≤ t ≤ T 1 leq t leq T 1 ≤ t ≤ T ,我们可以得到目标函数有关隐藏状态梯度的通项公式
∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . frac{partial L}{partial boldsymbol{h}_t} = sum_{i=t}^T {left(boldsymbol{W}_{hh}^topright)}^{T-i} boldsymbol{W}_{qh}^top frac{partial L}{partial boldsymbol{o}_{T+t-i}}. ∂ h t ∂ L = i = t ∑ T ( W h h ⊤ ) T − i W q h ⊤ ∂ o T + t − i ∂ L .
由上式中的指数项可见,当时间步数 T T T 较大或者时间步 t t t 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含∂ L ∂ h T frac{partial L}{partial boldsymbol{h}_T} ∂ h T ∂ L 项的梯度,例如隐藏层中模型参数的梯度∂ L ∂ W h x ∈ R h × d frac{partial L}{partial boldsymbol{W}_{hx}} in mathbb{R}^{h times d} ∂ W h x ∂ L ∈ R h × d 和 ∂ L ∂ W h h ∈ R h × h frac{partial L } {partial boldsymbol{W}_{hh}} in mathbb{R}^{h times h} ∂ W h h ∂ L ∈ R h × h 。 在计算图中,L L L 通过h 1 , … , h T boldsymbol{h}_1, ldots, boldsymbol{h}_T h 1 , … , h T 依赖这些模型参数。 依据链式法则,我们有
∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ frac{partial L}{partial boldsymbol{W}_{hx}} = sum_{t=1}^T text{prod}left(frac{partial L}{partial boldsymbol{h}_t}, frac{partial boldsymbol{h}_t}{partial boldsymbol{W}_{hx}}right) = sum_{t=1}^T frac{partial L}{partial boldsymbol{h}_t} boldsymbol{x}t^top ∂ W h x ∂ L = t = 1 ∑ T prod ( ∂ h t ∂ L , ∂ W h x ∂ h t ) = t = 1 ∑ T ∂ h t ∂ L x t ⊤ ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ frac{partial L}{partial boldsymbol{W}_{hh}} = sum_{t=1}^T text{prod}left(frac{partial L}{partial boldsymbol{h}_t}, frac{partial boldsymbol{h}_t}{partial boldsymbol{W}_{hh}}right) = sum_{t=1}^T frac{partial L}{partial boldsymbol{h}_t} boldsymbol{h}_{t-1}^top ∂ W h h ∂ L = t = 1 ∑ T prod ( ∂ h t ∂ L , ∂ W h h ∂ h t ) = t = 1 ∑ T ∂ h t ∂ L h t − 1 ⊤
每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度∂ L ∂ h t frac{partial L}{partial boldsymbol{h}_t} ∂ h t ∂ L 被计算和存储,之后的模型参数梯度∂ L ∂ W h x frac{partial L}{partial boldsymbol{W}_{hx}} ∂ W h x ∂ L 和∂ L ∂ W h h frac{partial L}{partial boldsymbol{W}_{hh}} ∂ W h h ∂ L 的计算可以直接读取∂ L ∂ h t frac{partial L}{partial boldsymbol{h}_t} ∂ h t ∂ L 的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。 举例来说,参数梯度∂ L ∂ W h h frac{partial L}{partial boldsymbol{W}_{hh}} ∂ W h h ∂ L 的计算需要依赖隐藏状态在时间步t = 0 , … , T − 1 t = 0, ldots, T-1 t = 0 , … , T − 1 的当前值h t boldsymbol{h}_t h t (h 0 boldsymbol{h}_0 h 0 是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。
参考资料
反向传播