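Today I would like to share the internal computation logic of recurrent neural networks, taking LSTM as the object of study. This post works through the Keras source code and draws on another blogger's post to analyze it in detail. That post is https://www.cnblogs.com/wangduo/p/6773601.html?utm_source=itdadao&utm_medium=referral; it is a classic and very thorough write-up, so please take the time to read it through and think it over. Before we dig in, I assume you already have a little background in RNNs, a little linear algebra, and a little familiarity with common deep learning concepts.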

OK, let's get started.

 

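[Figure: LSTM structure diagram from the blog post cited above. Image caption: "Understand the inner structure of the recurrent neural network black box in one article: LSTM"]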

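The image above is taken from that blog post and shows the overall framework of the LSTM structure. The actual structure is somewhat more complex; we will go through it in detail shortly, adjusting this diagram as we go.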

1. One-hot encoding

公 [0 0 0 0 1 0 0 0 0 0]
主 [0 0 0 1 0 0 0 0 0 0]
很 [0 0 1 0 0 0 0 0 0 0]
漂 [0 1 0 0 0 0 0 0 0 0] 
亮 [1 0 0 0 0 0 0 0 0 0]
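
The vectors listed above could be produced with a few lines of plain Python. This is only a minimal sketch: it assumes a 10-character vocabulary in which the five characters occupy indices 4 down to 0 (read off the listing), and `char_to_index` and `one_hot` are illustrative names of my own.

```python
# Hypothetical 10-character vocabulary; only the five characters of the
# example sentence matter here, placed so the vectors match the listing above.
vocab_size = 10
char_to_index = {"亮": 0, "漂": 1, "很": 2, "主": 3, "公": 4}


def one_hot(index, size):
    """Return a list of `size` zeros with a single 1 at position `index`."""
    vector = [0] * size
    vector[index] = 1
    return vector


sentence = "公主很漂亮"
encoded = [one_hot(char_to_index[ch], vocab_size) for ch in sentence]

for ch, vector in zip(sentence, encoded):
    print(ch, vector)
```

The result is a list of 5 rows with 10 entries each, one row per character of the sentence.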

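Suppose we have the sentence “公主很漂亮” ("The princess is very beautiful"). After one-hot encoding it becomes a tensor of shape=(5, 10): five characters, and we assume the vocabulary contains 10 characters in total, hence (5, 10). Inside the LSTM, the sentence is processed like this: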

[Figure: the sentence fed through the LSTM one character per time slice]

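What the basic recurrent neural network does is predict 主 from 公, 很 from 主, 漂 from 很, and 亮 from 漂. Each step of predicting the next item from the previous one is called a "time slice" operation, so “公主很漂亮” is split into 5 time slices, usually called "time_step".

Concretely: the input x1=“主” goes through the LSTM and h1 yields “很”; this h1 is the "short-term memory". c1 yields a state (a tensor), and this state c1 is the "long-term memory". Next, h1 is combined with x2 (this is not a simple addition; we will come back to what "combined" means) and takes part in the computation for that time slice, and c1 takes part in it as well. After the LSTM, h2 yields “漂” and c2 yields a new state. And so on, in a loop!

To sum up: the current input, combined with the previous step's "short-term memory" and the previous step's "long-term memory", passes through the LSTM cell and produces the next "short-term memory" and the next "long-term memory". That is what a recurrent neural network does.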
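
The loop described above can be sketched as follows. Note that `toy_cell` is a stand-in I made up purely to show what flows in and out at each time slice; it is not the real LSTM gate math. The zero initial states and the tiny `units = 4` are assumptions for illustration.

```python
import numpy as np


def toy_cell(x_t, h_prev, c_prev):
    """Stand-in for an LSTM cell (NOT the real gate math): it only shows
    which values go in and which come out at each time slice."""
    c_t = 0.5 * c_prev + 0.5 * np.tanh(x_t.sum() + h_prev)
    h_t = np.tanh(c_t)
    return h_t, c_t


time_steps, vocab_size, units = 5, 10, 4
x = np.eye(vocab_size)[[4, 3, 2, 1, 0]]   # the one-hot sentence, shape (5, 10)
h = np.zeros((1, units))                  # short-term memory, assumed to start at zero
c = np.zeros((1, units))                  # long-term memory, assumed to start at zero

for t in range(time_steps):
    # Current input + previous h + previous c  ->  new h + new c
    h, c = toy_cell(x[t:t + 1], h, c)     # x[t:t+1] keeps the shape (1, 10)

print(h.shape, c.shape)
```

The only thing carried from one iteration to the next is the pair (h, c), which is exactly the "short-term" and "long-term" memory described above.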

 

Good. Now let's redraw this diagram with the source code in hand:

Below is the source code of Keras's LSTMCell; take a look if you are interested.

class LSTMCell(Layer):
    """Cell class for the LSTM layer.

    # Arguments
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use
            (see [activations](../activations.md)).
            Default: hyperbolic tangent (`tanh`).
            If you pass `None`, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step
            (see [activations](../activations.md)).
            Default: hard sigmoid (`hard_sigmoid`).
            If you pass `None`, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs
            (see [initializers](../initializers.md)).
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix,
            used for the linear transformation of the recurrent state
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Setting it to true will also force `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.]
            (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix
            (see [constraints](../constraints.md)).
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.
        implementation: Implementation mode, either 1 or 2.
            Mode 1 will structure its operations as a larger number of
            smaller dot products and additions, whereas mode 2 will
            batch them into fewer, larger operations. These modes will
            have different performance profiles on different hardware and
            for different applications.
    """

    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 **kwargs):
        super(LSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.state_size = (self.units, self.units)
        self.output_size = self.units
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * 4,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.kernel_i = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
        self.kernel_o = self.kernel[:, self.units * 3:]

        self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
        self.recurrent_kernel_f = (
            self.recurrent_kernel[:, self.units: self.units * 2])
        self.recurrent_kernel_c = (
            self.recurrent_kernel[:, self.units * 2: self.units * 3])
        self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

        if self.use_bias:
            self.bias_i = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_c = self.bias[self.units * 2: self.units * 3]
            self.bias_o = self.bias[self.units * 3:]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
        self.built = True

    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return h, [h, c]

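The core parameters are explained below.

# Assume the input sentence has shape=(row, col); you can think of it as (5, 10)

self.units = units    # number of units (neurons) in the cell
self.activation = activations.get(activation)    # tanh activation function
self.recurrent_activation = activations.get(recurrent_activation)    # sigmoid-type activation (hard_sigmoid by default)

# Create a weight tensor of shape=(col, 4*units), used to project the current step's input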
self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

# Create a weight tensor of shape=(units, 4*units), used to project the previous step's output
self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)


# There are four intermediate quantities (i, f, c, o), so the kernel is split into four column blocks, one for each
self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]

self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = (
    self.recurrent_kernel[:, self.units: self.units * 2])
self.recurrent_kernel_c = (
    self.recurrent_kernel[:, self.units * 2: self.units * 3])
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

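# Previous step's output h, shape=(1, units)
h_tm1 = states[0]
# Previous step's state (long-term memory), shape=(1, units)
c_tm1 = states[1]


# Without dropout, all four gates receive the same current-step input x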
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs

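Now let's look at the LSTM computation logic.

Here is a step that the blog post above does not mention:

# Matrix-multiply the current input. The whole sentence has shape=(m, n), but each
# time slice is a single row of shape=(1, n); with units=128,
# x_i has shape=(1, 128), and likewise for x_f, x_c, x_o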
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)

# Alias the previous step's output once per gate (no recurrent dropout here)
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1

# Key point 1: three steps. Matrix-multiply the previous output, add the projected current input, then apply the activation
i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))

h = o * self.activation(c)

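Suppose the current time is t2, i.e. the third time slice. (Here * denotes matrix multiplication.)

For the current step's input:

x2(shape=(1, n))  *  kernel(shape=(n, units)) = A(shape=(1, units))

For the previous step's output:

h1(shape=(1, units)) * recurrent_kernel(shape=(units, units)) = B(shape=(1, units))

Adding the two:

A + B = C(shape=(1, units))

In other words, the previous step's output and the current input have already gone through the "black box" computation, namely matrix multiplication, right at the input stage.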
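
The shape bookkeeping above can be checked with a small NumPy sketch. The weights and h1 are random stand-ins (my assumption, only the shapes matter), and the variable names simply mirror the names used in the text.

```python
import numpy as np

rng = np.random.default_rng(0)
n, units = 10, 128                        # vocabulary size and number of units

x2 = np.zeros((1, n))
x2[0, 2] = 1.0                            # one-hot input for this time slice
h1 = rng.standard_normal((1, units))      # previous output (random stand-in)
kernel_i = rng.standard_normal((n, units))
recurrent_kernel_i = rng.standard_normal((units, units))

A = x2 @ kernel_i              # (1, n) x (n, units)         -> (1, units)
B = h1 @ recurrent_kernel_i    # (1, units) x (units, units) -> (1, units)
C = A + B                      # element-wise sum            -> (1, units)

print(A.shape, B.shape, C.shape)
```

Both products land in the same (1, units) space, which is why they can simply be added.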

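We can see that every formula in key point 1 contains a term of this form:

x_ + K.dot(h_tm1_, self.recurrent_kernel_)

where: x_ = K.dot(inputs_, self.kernel_)

So the diagram should really look like this (note: the dark red X is matrix multiplication, and the orange X is element-wise multiplication):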

[Figure: revised LSTM cell diagram with the input-side matrix multiplications added]

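There are two differences from the original blog post:

1. The details of the input stage have been added. This is the core of how the "black box" works, and it is the starting point of the data for every later operation along the red line.

2. The position of C has been added and the candidate C has been removed (I think the candidate C in the original post is misleading... at least it misled me).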

 

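Let's walk through the source code step by step!

(1) First, although our (assumed) input is (m, n)=(5, 10), it is split into 5 time slices, so each input x has shape (1, 10).

So x2 has shape=(1, n). Let's fix the input shape first, and assume units=128.

(2) Processing the current step's input!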

input_dim = input_shape[-1]
self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

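input_dim is n.

This creates a kernel of shape=(10, 128*4). Why four blocks? Because each of the four quantities i, f, c, o gets its own (10, 128) matrix (which also shows that the "black boxes" of the four steps each learn their own weights).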

self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]

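This is the key: the four quantities map to four different slices of the tensor, so during backpropagation each slice receives its own gradient and is updated separately.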
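
To make the slicing concrete, here is a NumPy sketch that mirrors the four slice expressions from the source. The `kernel` below is filled with consecutive numbers rather than trained weights (an assumption for illustration), so it is easy to see that the slices are disjoint column blocks.

```python
import numpy as np

n, units = 10, 128
kernel = np.arange(n * units * 4, dtype=float).reshape(n, units * 4)

kernel_i = kernel[:, :units]
kernel_f = kernel[:, units:units * 2]
kernel_c = kernel[:, units * 2:units * 3]
kernel_o = kernel[:, units * 3:]

slices = [kernel_i, kernel_f, kernel_c, kernel_o]
print([s.shape for s in slices])

# Putting the four slices back side by side recovers the original matrix,
# so the slices are disjoint column blocks of one weight tensor.
rebuilt = np.concatenate(slices, axis=1)
print(np.array_equal(rebuilt, kernel))
```

Each slice has shape (10, 128), and together they cover the whole (10, 512) kernel without overlap.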

Recap so far: x2.shape=(1, 10), each kernel slice has shape=(10, 128), and inputs_i and its siblings are simply x2.

x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)

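By the rules of matrix multiplication, after x2 and the kernel pass through X, the output is

the contribution of the current step's input: input1.shape=(1, 128), i.e. (1, units)

(3) Processing the previous step's output!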

self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)
self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = (
    self.recurrent_kernel[:, self.units: self.units * 2])
self.recurrent_kernel_c = (
    self.recurrent_kernel[:, self.units * 2: self.units * 3])
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

Up to here the handling is the same as for the kernel above; the only difference is that each recurrent_kernel slice has shape=(128, 128). Why? Because

i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))

h_tm1_i is the previous step's output, whose shape=(1, 128), so the kernel slice here must have shape=(128, 128); otherwise the multiplication would not be possible.

K.dot(h_tm1_i, self.recurrent_kernel_i)

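These four dot operations are the matrix multiplications applied to the previous step's output; the programmer simply wrote all the follow-up operations on the same lines, which makes them tiring to read.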
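
As a side note, the four separate products can also be computed as one large product followed by slicing, which is what the `else` branch (implementation 2) of the source quoted earlier does. The NumPy sketch below checks that the two routes agree when dropout is ignored; the weights are random stand-ins and the variable names are my own.

```python
import numpy as np

rng = np.random.default_rng(0)
units = 128
h_tm1 = rng.standard_normal((1, units))
recurrent_kernel = rng.standard_normal((units, units * 4))

# Per-gate products, as in the implementation == 1 branch.
per_gate = [h_tm1 @ recurrent_kernel[:, k * units:(k + 1) * units]
            for k in range(4)]

# One large product followed by slicing, as in the other branch.
z = h_tm1 @ recurrent_kernel
fused = [z[:, k * units:(k + 1) * units] for k in range(4)]

same = all(np.allclose(a, b) for a, b in zip(per_gate, fused))
print(same)
```

With dropout enabled the two branches are no longer identical, because the first one applies a separate mask per gate while the second applies a single mask.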

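Recap again: the previous step's output h1.shape=(1, 128) and each recurrent_kernel slice has shape=(128, 128). By the rules of matrix multiplication, after h1 and recurrent_kernel pass through X, the output is

the contribution of the previous step's output: input2.shape=(1, 128), i.e. (1, units)

(4) input1 and input2 go through a + (addition) and become the red line, i.e. input.shape=(1, 128)

At this point the preprocessing of the current input and the previous output is complete. This is the core black-box operation of the whole LSTM, and it works on the same principle as a fully connected layer (Dense) or an embedding layer (Embedding).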
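
The link to Dense and Embedding can be illustrated directly: multiplying a one-hot row by a weight matrix (what a bias-free fully connected layer computes) picks out exactly one row of that matrix (what an embedding lookup does). A minimal NumPy sketch, with random stand-in weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n, units = 10, 128
kernel_i = rng.standard_normal((n, units))

index = 2                       # position of the 1 in the one-hot vector
x = np.zeros((1, n))
x[0, index] = 1.0

dense_style = x @ kernel_i                  # matrix multiplication, as in a bias-free Dense layer
lookup_style = kernel_i[index:index + 1]    # row lookup, as in an embedding table

print(np.allclose(dense_style, lookup_style))
```

So for one-hot inputs, the input-side matrix multiplication of the LSTM behaves like looking up a learned 128-dimensional vector per character.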

(5) Go on! Have a sip of tea~

(6) Computing the four intermediate outputs; it is the same code again!

i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
h = o * self.activation(c)

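Here input stands for each gate's own red-line value from step (4), and sigmoid stands for recurrent_activation (hard_sigmoid by default):

i = sigmoid(input)    shape=(1, 128)

f = sigmoid(input)    shape=(1, 128)

c = f * c1 (previous state) + i * tanh(input)    shape=(1, 128)     note that * here means element-wise multiplication

o = sigmoid(input)    shape=(1, 128)

h = o * tanh(c)       shape=(1, 128)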
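
Putting steps (2) to (6) together, one time slice can be sketched in NumPy as below. This is my own re-implementation for illustration, not the Keras code: it follows the gate order i, f, c, o of the source, ignores dropout, uses random stand-in weights, and assumes the usual Keras 2.x definition of `hard_sigmoid` (0.2 * x + 0.5 clipped to [0, 1]).

```python
import numpy as np


def hard_sigmoid(x):
    # Piecewise-linear approximation of the sigmoid: 0.2 * x + 0.5 clipped to [0, 1].
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)


def lstm_step(x_t, h_tm1, c_tm1, kernel, recurrent_kernel, bias):
    """One time slice; the gate order i, f, c, o follows the slicing in the source."""
    units = h_tm1.shape[-1]
    z = x_t @ kernel + h_tm1 @ recurrent_kernel + bias
    z_i, z_f, z_c, z_o = (z[:, k * units:(k + 1) * units] for k in range(4))
    i = hard_sigmoid(z_i)
    f = hard_sigmoid(z_f)
    c = f * c_tm1 + i * np.tanh(z_c)      # element-wise products
    o = hard_sigmoid(z_o)
    h = o * np.tanh(c)
    return h, c


rng = np.random.default_rng(0)
n, units = 10, 128
kernel = rng.standard_normal((n, units * 4)) * 0.1
recurrent_kernel = rng.standard_normal((units, units * 4)) * 0.1
bias = np.zeros(units * 4)
bias[units:units * 2] = 1.0               # unit_forget_bias=True: forget-gate bias starts at 1

x_t = np.zeros((1, n))
x_t[0, 2] = 1.0                           # one-hot input for this time slice
h, c = lstm_step(x_t, np.zeros((1, units)), np.zeros((1, units)),
                 kernel, recurrent_kernel, bias)
print(h.shape, c.shape)
```

Because o lies in [0, 1] and tanh lies strictly between -1 and 1, every entry of h stays strictly between -1 and 1.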

 

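(7) The result!

return h, [h, c]

As you can see, the returned value contains h twice (the current step's output) plus c, the current step's state.

The first h corresponds to the h at the top of the diagram, and [h, c] corresponds to the h and c on the right.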

 

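(8) Some parameters explained

The LSTM layer has two parameters, return_state and return_sequences, both False by default.

In that case it returns h with shape=(None, 128), where None is the batch dimension.

If return_state=True, it returns h(shape=(None, 128)), h(shape=(None, 128)), c(shape=(None, 128)).

If return_sequences=True as well, it returns h(shape=(None, 5, 128)), h(shape=(None, 128)), c(shape=(None, 128)).

The first h has changed: it now contains the h of every time slice, whereas otherwise only the h of the last time slice is returned.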
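
The effect of the two flags on the output shapes can be mimicked with a toy layer in NumPy. To be clear, `run_lstm` is a hypothetical helper of my own and not the Keras API: it uses random weights, a plain sigmoid, no bias and no dropout, and only reproduces which tensors come back and with what shapes.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def run_lstm(x, units, return_sequences=False, return_state=False, seed=0):
    """Toy LSTM layer over a batch x of shape (batch, time_steps, n).
    It mimics only the output shapes of the two flags, using random weights."""
    rng = np.random.default_rng(seed)
    batch, time_steps, n = x.shape
    kernel = rng.standard_normal((n, units * 4)) * 0.1
    recurrent_kernel = rng.standard_normal((units, units * 4)) * 0.1
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    outputs = []
    for t in range(time_steps):
        z = x[:, t, :] @ kernel + h @ recurrent_kernel
        z_i, z_f, z_c, z_o = np.split(z, 4, axis=1)
        i, f, o = sigmoid(z_i), sigmoid(z_f), sigmoid(z_o)
        c = f * c + i * np.tanh(z_c)
        h = o * np.tanh(c)
        outputs.append(h)                 # keep the h of every time slice
    out = np.stack(outputs, axis=1) if return_sequences else h
    return (out, h, c) if return_state else out


x = np.random.default_rng(1).standard_normal((2, 5, 10))   # batch of 2 sentences
last_h = run_lstm(x, 128)
seq, h, c = run_lstm(x, 128, return_sequences=True, return_state=True)
print(last_h.shape, seq.shape, h.shape, c.shape)
```

With both flags set, the last entry of the sequence output along the time axis is the same tensor as the returned state h.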

 

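A teaser: this conventional LSTM structure can only make two kinds of prediction,

N:N prediction (return_sequences=True) and N:1 prediction (return_sequences=False), whereas the biggest problem in machine translation is that it is many-to-many with different lengths (N:M).

For example: 吃完饭=have dinner (3:2), 你在开玩笑吧=are you kidding (6:3)

What can we do?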