卷积神经网络当中, 卷积运算是尤其是计算敏感的, 尤其是在端上设备中, 对于性能的要求更为苛刻。对于卷积优化的方法也有很多种,本文便针对近年来最常见的优化方法Winograd做一个简单总结。

相关资料

winograd算法最早是1980年由Terry Winograd提出的,当时并没有引起太大的轰动。在CVPR'16会议上,Lavin等人[1]提出了利用winogrd加速卷积运算,于是winograd加速卷积优化在算法圈里火了一把。网上较多的实现版本为andravin实现的py版本[2]。目前cudnn中计算卷积就使用了该方法。

[1] "Fast Algorithms for Convolutional Neural Networks" Lavin and Gray, CVPR 2016.
[2] https://github.com/andravin/wincnn

算法

在winograd算法下,对于一维卷积,当输出为m,卷积核长为r,要对应的乘法数量:

\[u(F(m,r)) = m+r+1
\]

将一维卷积扩展到二维,如果输出维度为mxn,卷积核维度为rxs,则对应的乘法数量:

\[u(F(m * n,r * s)) = u(F(m,r)) * u(F(n,s)) = (m+r-1) * (n+s-1)
\]

对一个矩阵大小为4 * 4的输入,卷积核大小为3 * 3,对应的输出为2 * 2,正常计算的情况下,滑动窗口或者im2col的计算方法的乘法次数为2*2*3*3 = 36次,而当使用winograd时,对应的乘法次数为$ u(F(2*2,3*3)) = (2+3-1) * (2+3-1)=16 $,乘法次数明显减少。

假设对应的一维输入为[d0,d1,d2,d3],对应的卷积为[g0,g1,g2],对应的输出为[m0,m1,m2],那么:

\[F(2,3) = \begin{bmatrix}d0 & d1 & d2\\\\d1 & d2 &d3\end{bmatrix} \begin{bmatrix}g0\\\\g1\\\\g3\end{bmatrix} = \begin{bmatrix}m1+m2+m3\\\\m2-m3-m4\end{bmatrix}
\]

其中:

\[m1 = (d0-d1)g0
\]

\[m2 = 0.5(d1+d2)(g0+g1+g2)
\]

\[m3 = 0.5(d2-d1)(g0-g1+g2)
\]

\[m4 = (d1-d3)g2
\]

这种计算方式需要2+3-1=4次乘法,4次加法。写成矩阵乘法的形式即为:

\[Y = A^T \left[\left[Gg\right] \odot \left[B^Td\right]\right]
\]

其中$ \odot $表示 element-wise multiplication. 对于F(2,3),以上矩阵分别为:

\[B^{T}=\begin{bmatrix} 1 &0&-1 &0 \\ 0&1 &1 &0 \\ 0&-1 &1 &0 \\ 0& 1& 0& -1 \end{bmatrix}
\]

\[G=\begin{bmatrix} 1 & 0 & 0\\ 0.5& 0.5 &0.5 \\ 0.5& -0.5 &0.5 \\ 0& 0 &1 \end{bmatrix}
\]

\[A^{T}=\begin{bmatrix} 1 & 1 & 1 & 0\\ 0 & 1& -1 & -1 \end{bmatrix}
\]

\[g=\begin{bmatrix} g_{0} &g_{1} &g_{2} \end{bmatrix}^{T}
\]

\[d=\begin{bmatrix} d_{0} &d_{1} &d_{2}&d_{3} \end{bmatrix}^{T}
\]

扩展为二维的形式即为:

\[Y = A^T \left[\left[GgG^T\right] \odot \left[B^TdB\right]\right]A
\]

注意

  1. 以上描述的 Winograd 算法只展示了在二维的图像 (更确切的说是 tile) 上的过程, 具体在 ConvNet 的多个 channel 的情况, 直接逐个 channel 按照上述方法计算完然后相加即可;
  2. 按照 1. 的思路, 在计算多个 channel 的时候, 仍然有可减少计算次数的地方.
  3. 按照 2. 的思路, Winograd 在目前使用越来越多的 depthwise conv 中其优势不明显了.
  4. 在 tile 较大的时候, Winograd 方法不适用, 因为, 在做 inverse transform 的时候的计算开销抵消了 Winograd 带来的计算节省.
    Winograd 会产生误差