原文链接: https://zhuanlan.zhihu.com/p/31575074

在梳理CNN经典模型的过程中,我理解到其实经典模型演进中的很多创新点都与改善模型计算复杂度紧密相关,因此今天就让我们对卷积神经网络的复杂度分析简单总结一下下。

本文主要关注的是针对模型本身的复杂度分析(其实并不是很复杂啦~)。如果想要进一步评估模型在计算平台上的理论计算性能,则需要了解 Roofline Model 的相关理论,欢迎阅读本文的进阶版: Roofline Model与深度学习模型的性能分析


1. 时间复杂度

即模型的运算次数,可用 卷积神经网络的复杂度分析 衡量。

1.1 单个卷积层的时间复杂度

卷积神经网络的复杂度分析

  • 卷积神经网络的复杂度分析 每个卷积核输出特征图 卷积神经网络的复杂度分析 的边长
  • 卷积神经网络的复杂度分析 每个卷积核 卷积神经网络的复杂度分析 的边长
  • 卷积神经网络的复杂度分析 每个卷积核的通道数,也即输入通道数,也即上一层的输出通道数。
  • 卷积神经网络的复杂度分析 本卷积层具有的卷积核个数,也即输出通道数。
  • 可见,每个卷积层的时间复杂度由输出特征图面积 卷积神经网络的复杂度分析 、卷积核面积 卷积神经网络的复杂度分析 、输入 卷积神经网络的复杂度分析 和输出通道数 卷积神经网络的复杂度分析 完全决定。
  • 其中,输出特征图尺寸本身又由输入矩阵尺寸 卷积神经网络的复杂度分析 、卷积核尺寸 卷积神经网络的复杂度分析卷积神经网络的复杂度分析卷积神经网络的复杂度分析 这四个参数所决定,表示如下:

卷积神经网络的复杂度分析

  • 注1:为了简化表达式中的变量个数,这里统一假设输入和卷积核的形状都是正方形。
  • 注2:严格来讲每层应该还包含 1 个 卷积神经网络的复杂度分析 参数,这里为了简洁就省略了。

1.2 卷积神经网络整体的时间复杂度

卷积神经网络的复杂度分析

  • 卷积神经网络的复杂度分析 神经网络所具有的卷积层数,也即网络的深度
  • 卷积神经网络的复杂度分析 神经网络第 卷积神经网络的复杂度分析 个卷积层
  • 卷积神经网络的复杂度分析 神经网络第 卷积神经网络的复杂度分析 个卷积层的输出通道数 卷积神经网络的复杂度分析 ,也即该层的卷积核个数。
  • 对于第 卷积神经网络的复杂度分析 个卷积层而言,其输入通道数 卷积神经网络的复杂度分析 就是第 卷积神经网络的复杂度分析 个卷积层的输出通道数。
  • 可见,CNN整体的时间复杂度并不神秘,只是所有卷积层的时间复杂度累加而已。
  • 简而言之,层内连乘,层间累加。

示例:用 Numpy 手动简单实现二维卷积

假设 Stride = 1, Padding = 0, img 和 kernel 都是 np.ndarray.

def conv2d(img, kernel):
    height, width, in_channels = img.shape
    kernel_height, kernel_width, in_channels, out_channels = kernel.shape
    out_height = height - kernel_height + 1
    out_width = width - kernel_width + 1
    feature_maps = np.zeros(shape=(out_height, out_width, out_channels))
    for oc in range(out_channels):              # Iterate out_channels (# of kernels)
        for h in range(out_height):             # Iterate out_height
            for w in range(out_width):          # Iterate out_width
                for ic in range(in_channels):   # Iterate in_channels
                    patch = img[h: h + kernel_height, w: w + kernel_width, ic]
                    feature_maps[h, w, oc] += np.sum(patch * kernel[:, :, ic, oc])

    return feature_maps

2. 空间复杂度

空间复杂度包括模型的参数数量(模型本身的体积)和每层输出的特征图大小(会影响模型运行时的内存占用情况)。

卷积神经网络的复杂度分析

  • 可见,网络的参数量只与卷积核的尺寸 卷积神经网络的复杂度分析 、通道数 卷积神经网络的复杂度分析 、网络的深度 卷积神经网络的复杂度分析 相关。而与输入数据的大小无关
  • 当我们需要裁剪模型时,由于卷积核的尺寸通常已经很小,而网络的深度又与模型的能力紧密相关,不宜过多削减,因此模型裁剪通常最先下手的地方就是通道数。


3. 复杂度对模型的影响

  • 时间复杂度决定了模型的训练/预测时间。如果复杂度过高,则会导致模型训练和预测耗费大量时间,既无法快速的验证想法和改善模型,也无法做到快速的预测。
  • 空间复杂度决定了模型的参数数量。由于维度诅咒的限制,模型的参数越多,训练模型所需的数据量就越大,而现实生活中的数据集通常不会太大,这会导致模型的训练更容易过拟合。


4. Inception 系列模型是如何优化复杂度的

通过五个小例子说明模型的演进过程中是如何优化复杂度的。

4.1 卷积神经网络的复杂度分析 中的 卷积神经网络的复杂度分析 卷积降维

卷积神经网络的复杂度分析(图像被压缩的惨不忍睹...)

  • InceptionV1 借鉴了 Network in Network 的思想,在一个 Inception Module 中构造了四个并行的不同尺寸的卷积/池化模块(上图左),有效的提升了网络的宽度。但是这么做也造成了网络的时间和空间复杂度的激增。对策就是添加 1 x 1 卷积(上图右红色模块)将输入通道数先降到一个较低的值,再进行真正的卷积。
  • 以 InceptionV1 论文中的 (3b) 模块为例,输入尺寸为 卷积神经网络的复杂度分析卷积神经网络的复杂度分析 卷积核 卷积神经网络的复杂度分析 个, 卷积神经网络的复杂度分析 卷积核 卷积神经网络的复杂度分析 个, 卷积神经网络的复杂度分析 卷积核 卷积神经网络的复杂度分析 个,卷积核一律采用 Same Padding 确保输出不改变尺寸。
  • 卷积神经网络的复杂度分析 卷积分支上加入 卷积神经网络的复杂度分析卷积神经网络的复杂度分析 卷积前后的时间复杂度对比如下式:

卷积神经网络的复杂度分析

  • 同理,在 卷积神经网络的复杂度分析 卷积分支上加入 卷积神经网络的复杂度分析卷积神经网络的复杂度分析 卷积前后的时间复杂度对比如下式:

卷积神经网络的复杂度分析

  • 可见,使用 卷积神经网络的复杂度分析 卷积降维可以降低时间复杂度3倍以上。该层完整的运算量可以在论文中查到,为 300 M,即 卷积神经网络的复杂度分析
  • 另外在空间复杂度上,虽然降维引入了三组 卷积神经网络的复杂度分析 卷积核的参数,但新增参数量仅占整体的 5%,影响并不大。

4.2 卷积神经网络的复杂度分析 中使用 卷积神经网络的复杂度分析 代替 卷积神经网络的复杂度分析

  • 全连接层可以视为一种特殊的卷积层,其卷积核尺寸 卷积神经网络的复杂度分析 与输入矩阵尺寸 卷积神经网络的复杂度分析 一模一样。每个卷积核的输出特征图是一个标量点,即 卷积神经网络的复杂度分析 。复杂度分析如下:

卷积神经网络的复杂度分析

  • 可见,与真正的卷积层不同,全连接层的空间复杂度与输入数据的尺寸密切相关。因此如果输入图像尺寸越大,模型的体积也就会越大,这显然是不可接受的。例如早期的VGG系列模型,其 90% 的参数都耗费在全连接层上。
  • InceptionV1 中使用的全局最大池化 GAP 改善了这个问题。由于每个卷积核输出的特征图在经过全局最大池化后都会直接精炼成一个标量点,因此全连接层的复杂度不再与输入图像尺寸有关,运算量和参数数量都得以大规模削减。复杂度分析如下:

卷积神经网络的复杂度分析

4.3 卷积神经网络的复杂度分析 中使用两个 卷积神经网络的复杂度分析 卷积级联替代 卷积神经网络的复杂度分析 卷积分支

卷积神经网络的复杂度分析
感受野不变

  • 根据上面提到的二维卷积输入输出尺寸关系公式,可知:对于同一个输入尺寸,单个 卷积神经网络的复杂度分析 卷积的输出与两个 卷积神经网络的复杂度分析 卷积级联输出的尺寸完全一样,即感受野相同。
  • 同样根据上面提到的复杂度分析公式,可知:这种替换能够非常有效的降低时间和空间复杂度。我们可以把辛辛苦苦省出来的这些复杂度用来提升模型的深度和宽度,使得我们的模型能够在复杂度不变的前提下,具有更大的容量,爽爽的。
  • 同样以 InceptionV1 里的 (3b) 模块为例,替换前后的 卷积神经网络的复杂度分析 卷积分支复杂度如下:

卷积神经网络的复杂度分析

4.4 卷积神经网络的复杂度分析 中使用 卷积神经网络的复杂度分析卷积神经网络的复杂度分析 卷积级联替代 卷积神经网络的复杂度分析 卷积

卷积神经网络的复杂度分析

  • InceptionV3 中提出了卷积的 Factorization,在确保感受野不变的前提下进一步简化。
  • 复杂度的改善同理可得,不再赘述。

4.5 卷积神经网络的复杂度分析 中使用 卷积神经网络的复杂度分析

卷积神经网络的复杂度分析

  • 我们之前讨论的都是标准卷积运算,每个卷积核都对输入的所有通道进行卷积。
  • Xception 模型挑战了这个思维定势,它让每个卷积核只负责输入的某一个通道,这就是所谓的 Depth-wise Separable Convolution。
  • 从输入通道的视角看,标准卷积中每个输入通道都会被所有卷积核蹂躏一遍,而 Xception 中每个输入通道只会被对应的一个卷积核扫描,降低了模型的冗余度。
  • 标准卷积与可分离卷积的时间复杂度对比:可以看到本质上是把连乘转化成为相加。

卷积神经网络的复杂度分析


5. 总结

通过上面的推导和经典模型的案例分析,我们可以清楚的看到其实很多创新点都是围绕模型复杂度的优化展开的,其基本逻辑就是乘变加。模型的优化换来了更少的运算次数和更少的参数数量,一方面促使我们能够构建更轻更快的模型(例如MobileNet),一方面促使我们能够构建更深更宽的网络(例如Xception),提升模型的容量,打败各种大怪兽,欧耶~