生成式对抗网络GAN-入门篇

近年来,GAN逐渐成为了最热门的神经网络框架。在阅读了若干篇关于GAN的文章后,希望用尽量通俗易懂的语言对GAN的基本工作原理,最优解推导,训练方法,存在的问题以及应用做一个简单的总结并与大家交流,可作为GAN的一个入门参考。

1 写在GAN之前

生成模型和判别模型是机器学习领域的两个重要内容。判别模型是通过寻找不同类别数据之间的分界面来实现数据的分类;而生成模型则通过是估计数据的分布来生成新的数据,当然也可以用于分类。生成模型的构建方法可以大致分为两类,即显示地估计样本分布和隐式地估计样本分布[1],也人有将其称为从人理解数据的角度建模和从机器理解数据的角度建模[2]。

样本分布的显示估计一般包括分布假设和参数学习,首先根据对数据的分析人为地假设数据服从某种分布形式,然后通过极大似然估计等方法选择令样本数据似然度最大的分布参数,即选择最可能产生所有训练样本的分布。如图1所示,通过显示地估计数据样本的分布,可以将样本分布显示地表达成人类可以理解的分布形式。

样本分布的隐式估计则不需要假设数据的分布形式,如图2所示,直接根据所采集的样本数据训练生成模型,机器将根据自己对于数据的理解建立生成模型并产生新的数据。这样得到的生成模型无法对于样本的分布进行显示地表达,对于人类来说不具有可解释性,但是它所产生的样本确实人类可以理解的。在GAN提出之前,这种方法一般需要使用马尔科夫链进行模型训练,效率较低。


生成式对抗网络GAN-入门篇
图1:样本分布的显示估计。(图片来源:Goodfellow. (2016)[1])


生成式对抗网络GAN-入门篇
图2:样本分布的隐式估计。(图片来源:Goodfellow. (2016)[1])

2 什么是GAN

2.1 GAN的基本思想

生成对抗网络GAN(Generative Adversarial networks)是受了博弈论的启发,目的是希望通过对抗学习,有效地估计样本数据的真实分布pdata。GAN由两个互相博弈的生成模型G和判别模型D组成,其中,生成模型G需要尽可能地估计样本的真实分布,并据此产生新的样本,而判别模型G需要尽可能地区分出样本是来自真实分布还是生成模型。在博弈的过程中,两个模型通过交替地学习,相互博弈,不断地提高自己的生成和判别能力,当判别模型的能力提高到一定程度,且无法正确区分样本的来源时,可以认为生成模型已习得了样本的真实分布,此时也达到了博弈的一个均衡点。

举例来说,GAN中判别模型和生成模型之间的博弈可以看成是警察和假币制造者之间的博弈。生成模型就好比是一个假币制作者,希望通过制造高仿的假币从中获利。而判别模型就好比是警察,能够专业地判断是纸币的真假。警察的判别能力和假币制造者的造假能力都在双方的博弈中不断提高,而当假币制造者制造的假币能够跟真的一模一样时,警察也就无法再分别真假了。

2.2 GAN的计算框架

如图3,判别模型D同时接受真实样本x(真样本,标注为1)和生成模型产生的样本G(z)(假样本,标注为0),判别模型要尽可能地判别出样本的真伪;生成模型G的输入为随机变量 z ,表示的是从随机变量到近似样本空间pg的一种映射,尽可能地产生符合真实数据分布的样本G(z);根据判别模型的判别结果和样本的真实情况,对于判别模型和生成模型进行迭代更新。


生成式对抗网络GAN-入门篇
图3: GAN的计算框架.(图片来源:WANG Kun-Feng et al.(2017)[2])

2.3 GAN的目标函数

2.3.1 判别模型的目标函数

在GAN中,判别模型D是一个二分类器,要尽可能地将真样本(源于真实数据)判为1,将伪样本(由生成模型产生)判为0,因此,可得目标函数如下:

max  J(D)(θD,θG)Expdata[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中,x为源于真实数据分布 pdata的真样本,z为源于先验分布 p(z)的随机噪声,G(z)为生成模型产生的伪样本,判别模型 D要尽可能地令 D(x)=1,D(G(x))=0.

2.3.2 最小最大博弈问题

生成模型则要尽可能地逼近真实的数据分布,令产生的伪样本更接近真样本以骗过判别器,因此可以假设生成模型和判别模型之间为零和博弈,令两者的目标函数正好相反,形式如下:

max  J(G),J(G)=J(D)

 min  J(G),J(G)=J(D)

因此,也可以得到GAN的整体目标函数,将GAN问题表示生成如下的极小-极大优化问题:


minθG maxθDJ(D,G)Expdata[logD(x)]+Ezpz(z)[log(1D(G(z)))]

2.3.3 非饱和博弈

但是,在实际应用中采用最小最大问题博弈问题得到的目标函数的效果并不够好。在最小最大博弈问题中,判别模型要最大化目标函数,生成模型要最小化目标函数,但是当判别函数以较高的置信度拒绝了由生成模型产生的样本时,生成模型的更新梯度将很小,对于生成模型的训练存在梯度消失的情况,具体分析如下:

因为在最小最大博弈中,生成模型的目标函数为:

min J(G)Expdata[logD(x)]+Ezpz(z)[log(1D(G(z)))]

在更新G时第一项为常数,当判别函数以较高的置信度否定了由生成模型产生的样本即 D(G(z)为较小值(接近0)时,log(1D(G(z)))关于 D的导数为1,因此 J(G)θG较小,容易产生梯度消失,则导致生成模型无法得到有效的学习。为了解决这种梯度消失的情况,可以定义新的生成模型的目标函数如下:

JG=12EzlogD(G(z))

此时,当判别函数以较高的置信度否定了由生成模型产生的样本即 D(G(z)为较小值(接近0)时, log(D(G(z)))关于D的导数也为较大值(接近 ),可以保证生成模型得到有效的学习。

非饱和博弈可以保证当生成模型 G在博弈中处于劣势时,仍能够得到较大梯度和充分的学习,弥补了最小最大博弈中的不足。

2.4 GAN的最优解

可证明GAN具有唯一的全局最优解 pg=pdata

  1. 首先证明,当生成模型 G固定时,最优的判别模型为:

    DG(x)=pdata(x)pdata(x)+pg(x)

    G固定时,判别模型 D的目标函数 J(D)可写成:


    max J(D)Expdata[logD(x)]+Ezpz(z)[log(1D(G(z)))]=xpdata(x)logD(x)dx+zpz(z)log(1D(G(z)))dz=xpdata(x)logD(x)+pg(x)log(1D(x)) dx

    对于任意的非零实数 a,b且实数 y[0,1],表达式 alog(y)+blog(1y)可在 aa+b处取得最大值。

    因此,类比于 J(D)的表达式,当 G固定时,DG(x)=pdata(x)pdata(x)+pg(x)可令目标函数 J(D)取得最大值。

  2. 再证明,当且仅当 pg=pdata时,目标函数可取得全局最优值

    D固定时,可得GAN的整体目标函数如下:

    minG C(G)minGJ(D,G) =Expdata[logD(x)]+Ezpz(z)[log(1D(G(z)))] =Expdata[logD(x)]+Expg(x)[log(1D(x)) =Expdata[logpdata(x)pdata(x)+pg(x)]+Expg(x)log(pg(x)pdata(x)+pg(x)) =Expdata[logpdata(x)pdata(x)+pg(x)]+Expg(x)log(pg(x)pdata(x)+pg(x))Expdata[log12]Expg[log12]+log[14]  =Expdata[logpdata(x)pdata(x)+pg(x)2]+Expg(x)log(pg(x)pdata(x)+pg(x)2)log4  =KL(pdata||pdata+pg2)+KL(pg||pdata+pg2)log4  =2JSD(pdata||pg)log4  log4

其中,KL( || ),JSD( || ) 分别表示 KL散度和 JSD散度,表示两个分布之间的差距;当且仅当pg=pdata时,JSD(pdata||pg)=0,目标函数取得全局最优值 log4

虽然,在理论分析中,GAN存在最优解,但是在具体的实现中往往无法达到。

3 GAN的训练

3.1 训练方法

在GAN的学习中,我们需要训练判别模型 D令目标函数最大化,训练生成模型 G令目标函数最小化,因此无法同时训练两个模型。可以采用交替的优化方法:先固定生成生成模型,更新判别模型,令目标函数最大化;再固定判别模型,更新生成模型,令目标函数最小化,直至收敛。在2014的文章[3]中,Goodfellow等推荐令判别模型D先更新k步,再让生成模型G更新1步;但在2016的文章[1]中,Goodfellow又认为两个模型交替更新且每次各更新一步的效果最好。

3.2 训练过程

通过判别模型和生成模型的交替更新,判别模型区分真伪样本的能力不断提高,同时生成模型的分布pg
越来越靠近真实的数据分布pdata,判别模型和生成模型的更新变化可以表示如图4所示。

观察图4(b),可以了解判别模型D的判别标准。根据第二部分的证明,固定G时最优的判别模型为 D(x)=pdata(x)pdata(x)+pg(x)。在图4(b)中,对于分布于左侧的样本 xpg0,因此比例 pdata(x)pdata(x)+pg(x)1,该侧的样本 x更可能是真实样本;对于分布于右侧的样本 xpdata0,因此比例 pdata(x)pdata(x)+pg(x)0,该侧的样本x更可能是由生成模型产生的样本。

观察图4(b)(c)(d),可以了解生成模型的更新。判别模型会指导生成模型的更新,令生成模型的分布向着更可能被判别模型判别为真实数据的方向更新,直至生成模型能完全地拟合真实分布。


生成式对抗网络GAN-入门篇
图4: GAN的训练过程. z为随机噪声,x表示样本(包括真实样本和生成模型产生的样本),黑色箭头表示映射 x=G(z),黑色虚线表示真实数据分布 pdata,绿色实线表示生成模型的分布 pg,蓝色虚线表示判别模型。(图片来源:Goodfellow et al. (2014b)[3])

3.3 DCGAN的框架

DCGAN[4]是如今用于构建GAN网络的常见且有效的框架,DCGAN的主要特点在于:

  • 采用了全卷积网络的结构,去除了池化层和最后的全连接层。其中,生成模型的网络结构如图5所示,首先将100维的随机向量通过全连接映射成8096维的向量,然后重构成441024的特征图,随后通过一连串的转置卷积或小数步长卷积运算,最终得到 64643的RGB输出结果。
  • 使用批量归一化算法:深度学习的模型在训练过程中每一层的参数都在发生着变化,而前面层训练参数的更新将导致后面层输入数据分布的变化,因此每一层的网络都要去适应变化的数据分布,这将影响网络的训练速度。批量归一化算法通过对深层网络中每一层的输入都进行归一化,减少内部协变量转移,大大地加快了深度神经网络的训练,而且还把归一化的步骤也作为模型训练架构的一部分来实现。批量归一化的效果还包括可以采用较高的学习率,对于网络的初始情况不用太在意,在一定情况况下也可以起到正则化的作用,并减轻了对Dropout的需求。


生成式对抗网络GAN-入门篇
图5: 采用DCGAN框架的生成网络.(图片来源:Radford et al.(2015)[4])

4 GAN的效果与问题

4.1 GAN的应用

  • GAN最直接的应用就是在生成任务上。如图6所示,Twitter利用GAN由低分辨率图像生成高分辨率的图像[5];如图 7所示,根据草图完成图像[6]。
  • GAN在处理多模态输出问题(对于单一输入有多种可行的输出)上也非常有效,如图8所示,将GAN用于视频预测[7]。
  • 除此之外,GAN在半监督学习和强化学习中也都有应用。


生成式对抗网络GAN-入门篇
图6:由低分辨率图像生成高分辨率的图像.(图片来源:Ledig et al. (2016))


生成式对抗网络GAN-入门篇
图7:图与图之间的转化.(图片来源:Isola et al. (2016))

生成式对抗网络GAN-入门篇

图8:视频预测.(图片来源:Lotter et al. (2015))

4.2 不收敛

可能存在不收敛的情况是GAN训练的一大问题。常见的机器学习问题可以表示为单个目标函数的优化,利用优化算法可以沿着令目标函数下降的方向更新迭代,直至找到最优解。而GAN是一个博弈问题,需要寻找博弈问题的均衡点,其涉及两个目标函数的优化,在优化过程中,虽然我们可以分别地令单个目标函数下降,但却也可能同时使得另一个目标函数上升。因此,在GAN的优化过程中,有时我们可以到达博弈的均衡点,有时却可能绕着均衡点不断打转甚至发散。

4.2.1 模式奔溃

模式奔溃是指习得的生成模型 G会将不同的输入 z都映射至一个固定的输出 x。GAN中经常出现部分模式奔溃的现象,生成的多幅图像具有相同的颜色、纹理或者物体,使得输出的多样性受到了的限制。

如图9,目标的真实数据分布为一个二维的混合高斯分布,但是在训练过程中,生成模型仅仅能产生其中一部分的数据。这是由于生成模型 G仅仅收敛到了判别模型 D认为更可能属于真实样本分布的部分区域,即此时的 G仅仅是一个局部最优解而不是全部最优解[8]。如图10,在这个由文本生成图片的任务中[9],GAN的输出多样性较差。当然,可以利用minibatch features和unrolled GAN等方法缓解模式崩溃情况。

生成式对抗网络GAN-入门篇

图9: 模式奔溃.(图片来源:Metz et al. (2016))

生成式对抗网络GAN-入门篇

图10:文本-图像转化任务重的模式奔溃.(图片来源:Reed et al. (2016a))

4.3 生成模型的评估困难

生成模型的性能无法得到科学定量的评估。人类可以较容易地判断生成模型所产生照片是否正确是否真实,但是机器还无法做到。

4.3.1 间接采样似然度方法ISL

Bengio等[10]提出过一种生成模型质量的评估方法,并将称其为间接采样似然度方法(Indirect Sampling Likelihood)。ISL方法的主要过程包括如下:

  • 利用待评估的生成模型产生样本集合 S
  • 根据样本集合 S,训练并估计生成模型的分布 P
  • 计算测试数据(真实数据)在分布 P下的似然度。

测试数据集的似然度可以在一定程度上反映生成模型分布与真实数据分布的相似度,因此可用测试数据似然度来衡量生成模型的质量。在GAN的首篇文章中[3],Goodfellow等也采取了类似的评价方法,首先利用生成模型 G产生大量样本,然后用高斯帕森窗(Gaussian Parzen window)方法来估计测试数据集在生成模型分布 pg下的对数似然度。但是,当生成模型 G产生的样本数目不够多,或者样本的多样性较差时无法较好地估计分布 pg,此时利用对数似然度进行模型评估则不够准确。Goodfellow也认为如今还没有十分有效的方法可以对生成模型进行准确的定量评估。

参考文献

[1] Goodfellow I. NIPS 2016 tutorial: Generative adversarial networks[J]. arXiv preprint arXiv:1701.00160, 2016.

[2] 王坤峰, 苟超, 段艳杰, 等. 生成式对抗网络 GAN 的研究进展与展望[J]. 自动化学报, 2017, 43(3): 321-332.

[3] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.

[4] Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.

[5] Ledig C, Theis L, Huszár F, et al. Photo-realistic single image super-resolution using a generative adversarial network[J]. arXiv preprint arXiv:1609.04802, 2016.

[6] Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks[J]. arXiv preprint arXiv:1611.07004, 2016.

[7] Lotter W, Kreiman G, Cox D. Unsupervised learning of visual structure using predictive generative networks[J]. arXiv preprint arXiv:1511.06380, 2015.

[8] Metz L, Poole B, Pfau D, et al. Unrolled generative adversarial networks[J]. arXiv preprint arXiv:1611.02163, 2016.

[9] Reed S, van den Oord A, Kalchbrenner N, et al. Generating interpretable images with controllable structure[J]. 2016.

[10] Breuleux O, Bengio Y, Vincent P. Unlearning for better mixing[J]. Universite de Montreal/DIRO, 2010.