生成式对抗网络GAN-入门篇
近年来,GAN逐渐成为了最热门的神经网络框架。在阅读了若干篇关于GAN的文章后,希望用尽量通俗易懂的语言对GAN的基本工作原理,最优解推导,训练方法,存在的问题以及应用做一个简单的总结并与大家交流,可作为GAN的一个入门参考。
1 写在GAN之前
生成模型和判别模型是机器学习领域的两个重要内容。判别模型是通过寻找不同类别数据之间的分界面来实现数据的分类;而生成模型则通过是估计数据的分布来生成新的数据,当然也可以用于分类。生成模型的构建方法可以大致分为两类,即显示地估计样本分布和隐式地估计样本分布[1],也人有将其称为从人理解数据的角度建模和从机器理解数据的角度建模[2]。
样本分布的显示估计一般包括分布假设和参数学习,首先根据对数据的分析人为地假设数据服从某种分布形式,然后通过极大似然估计等方法选择令样本数据似然度最大的分布参数,即选择最可能产生所有训练样本的分布。如图1所示,通过显示地估计数据样本的分布,可以将样本分布显示地表达成人类可以理解的分布形式。
样本分布的隐式估计则不需要假设数据的分布形式,如图2所示,直接根据所采集的样本数据训练生成模型,机器将根据自己对于数据的理解建立生成模型并产生新的数据。这样得到的生成模型无法对于样本的分布进行显示地表达,对于人类来说不具有可解释性,但是它所产生的样本确实人类可以理解的。在GAN提出之前,这种方法一般需要使用马尔科夫链进行模型训练,效率较低。
图1:样本分布的显示估计。(图片来源:Goodfellow. (2016)[1])
图2:样本分布的隐式估计。(图片来源:Goodfellow. (2016)[1])
2 什么是GAN
2.1 GAN的基本思想
生成对抗网络GAN(Generative Adversarial networks)是受了博弈论的启发,目的是希望通过对抗学习,有效地估计样本数据的真实分布
举例来说,GAN中判别模型和生成模型之间的博弈可以看成是警察和假币制造者之间的博弈。生成模型就好比是一个假币制作者,希望通过制造高仿的假币从中获利。而判别模型就好比是警察,能够专业地判断是纸币的真假。警察的判别能力和假币制造者的造假能力都在双方的博弈中不断提高,而当假币制造者制造的假币能够跟真的一模一样时,警察也就无法再分别真假了。
2.2 GAN的计算框架
如图3,判别模型
图3: GAN的计算框架.(图片来源:WANG Kun-Feng et al.(2017)[2])
2.3 GAN的目标函数
2.3.1 判别模型的目标函数
在GAN中,判别模型
其中,
2.3.2 最小最大博弈问题
生成模型则要尽可能地逼近真实的数据分布,令产生的伪样本更接近真样本以骗过判别器,因此可以假设生成模型和判别模型之间为零和博弈,令两者的目标函数正好相反,形式如下:
因此,也可以得到GAN的整体目标函数,将GAN问题表示生成如下的极小-极大优化问题:
2.3.3 非饱和博弈
但是,在实际应用中采用最小最大问题博弈问题得到的目标函数的效果并不够好。在最小最大博弈问题中,判别模型要最大化目标函数,生成模型要最小化目标函数,但是当判别函数以较高的置信度拒绝了由生成模型产生的样本时,生成模型的更新梯度将很小,对于生成模型的训练存在梯度消失的情况,具体分析如下:
因为在最小最大博弈中,生成模型的目标函数为:
在更新G时第一项为常数,当判别函数以较高的置信度否定了由生成模型产生的样本即
此时,当判别函数以较高的置信度否定了由生成模型产生的样本即
非饱和博弈可以保证当生成模型
2.4 GAN的最优解
可证明GAN具有唯一的全局最优解
-
首先证明,当生成模型
G 固定时,最优的判别模型为:D∗G(x)=pdata(x)pdata(x)+pg(x) 当
G 固定时,判别模型D 的目标函数J(D) 可写成:max J(D)≡Ex∼pdata[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]=∫xpdata(x)logD(x)dx+∫zpz(z)log(1−D(G(z)))dz=∫xpdata(x)logD(x)+pg(x)log(1−D(x)) dx 对于任意的非零实数
a,b 且实数y∈[0,1] ,表达式alog(y)+blog(1−y) 可在aa+b 处取得最大值。因此,类比于
J(D) 的表达式,当G 固定时,D∗G(x)=pdata(x)pdata(x)+pg(x) 可令目标函数J(D) 取得最大值。 -
再证明,当且仅当
pg=pdata 时,目标函数可取得全局最优值当
D 固定时,可得GAN的整体目标函数如下:minG C(G)≡minGJ(D∗,G) =Ex∼pdata[logD∗(x)]+Ez∼pz(z)[log(1−D∗(G(z)))] =Ex∼pdata[logD∗(x)]+Ex∼pg(x)[log(1−D∗(x)) =Ex∼pdata[logpdata(x)pdata(x)+pg(x)]+Ex∼pg(x)log(pg(x)pdata(x)+pg(x)) =Ex∼pdata[logpdata(x)pdata(x)+pg(x)]+Ex∼pg(x)log(pg(x)pdata(x)+pg(x))−Ex∼pdata[log12]−Ex∼pg[log12]+log[14] =Ex∼pdata[logpdata(x)pdata(x)+pg(x)2]+Ex∼pg(x)log(pg(x)pdata(x)+pg(x)2)−log4 =KL(pdata||pdata+pg2)+KL(pg||pdata+pg2)−log4 =2JSD(pdata||pg)−log4 ≥−log4
其中,
虽然,在理论分析中,GAN存在最优解,但是在具体的实现中往往无法达到。
3 GAN的训练
3.1 训练方法
在GAN的学习中,我们需要训练判别模型
3.2 训练过程
通过判别模型和生成模型的交替更新,判别模型区分真伪样本的能力不断提高,同时生成模型的分布
越来越靠近真实的数据分布
观察图4(b),可以了解判别模型
观察图4(b)(c)(d),可以了解生成模型的更新。判别模型会指导生成模型的更新,令生成模型的分布向着更可能被判别模型判别为真实数据的方向更新,直至生成模型能完全地拟合真实分布。
图4: GAN的训练过程.
3.3 DCGAN的框架
DCGAN[4]是如今用于构建GAN网络的常见且有效的框架,DCGAN的主要特点在于:
- 采用了全卷积网络的结构,去除了池化层和最后的全连接层。其中,生成模型的网络结构如图5所示,首先将100维的随机向量通过全连接映射成8096维的向量,然后重构成
4∗4∗1024 的特征图,随后通过一连串的转置卷积或小数步长卷积运算,最终得到64∗64∗3 的RGB输出结果。 - 使用批量归一化算法:深度学习的模型在训练过程中每一层的参数都在发生着变化,而前面层训练参数的更新将导致后面层输入数据分布的变化,因此每一层的网络都要去适应变化的数据分布,这将影响网络的训练速度。批量归一化算法通过对深层网络中每一层的输入都进行归一化,减少内部协变量转移,大大地加快了深度神经网络的训练,而且还把归一化的步骤也作为模型训练架构的一部分来实现。批量归一化的效果还包括可以采用较高的学习率,对于网络的初始情况不用太在意,在一定情况况下也可以起到正则化的作用,并减轻了对Dropout的需求。
图5: 采用DCGAN框架的生成网络.(图片来源:Radford et al.(2015)[4])
4 GAN的效果与问题
4.1 GAN的应用
- GAN最直接的应用就是在生成任务上。如图6所示,Twitter利用GAN由低分辨率图像生成高分辨率的图像[5];如图 7所示,根据草图完成图像[6]。
- GAN在处理多模态输出问题(对于单一输入有多种可行的输出)上也非常有效,如图8所示,将GAN用于视频预测[7]。
- 除此之外,GAN在半监督学习和强化学习中也都有应用。
图6:由低分辨率图像生成高分辨率的图像.(图片来源:Ledig et al. (2016))
图7:图与图之间的转化.(图片来源:Isola et al. (2016))
图8:视频预测.(图片来源:Lotter et al. (2015))
4.2 不收敛
可能存在不收敛的情况是GAN训练的一大问题。常见的机器学习问题可以表示为单个目标函数的优化,利用优化算法可以沿着令目标函数下降的方向更新迭代,直至找到最优解。而GAN是一个博弈问题,需要寻找博弈问题的均衡点,其涉及两个目标函数的优化,在优化过程中,虽然我们可以分别地令单个目标函数下降,但却也可能同时使得另一个目标函数上升。因此,在GAN的优化过程中,有时我们可以到达博弈的均衡点,有时却可能绕着均衡点不断打转甚至发散。
4.2.1 模式奔溃
模式奔溃是指习得的生成模型
如图9,目标的真实数据分布为一个二维的混合高斯分布,但是在训练过程中,生成模型仅仅能产生其中一部分的数据。这是由于生成模型
图9: 模式奔溃.(图片来源:Metz et al. (2016))
图10:文本-图像转化任务重的模式奔溃.(图片来源:Reed et al. (2016a))
4.3 生成模型的评估困难
生成模型的性能无法得到科学定量的评估。人类可以较容易地判断生成模型所产生照片是否正确是否真实,但是机器还无法做到。
4.3.1 间接采样似然度方法ISL
Bengio等[10]提出过一种生成模型质量的评估方法,并将称其为间接采样似然度方法(Indirect Sampling Likelihood)。ISL方法的主要过程包括如下:
- 利用待评估的生成模型产生样本集合
S ; - 根据样本集合
S ,训练并估计生成模型的分布P ; - 计算测试数据(真实数据)在分布
P 下的似然度。
测试数据集的似然度可以在一定程度上反映生成模型分布与真实数据分布的相似度,因此可用测试数据似然度来衡量生成模型的质量。在GAN的首篇文章中[3],Goodfellow等也采取了类似的评价方法,首先利用生成模型
参考文献
[1] Goodfellow I. NIPS 2016 tutorial: Generative adversarial networks[J]. arXiv preprint arXiv:1701.00160, 2016.
[2] 王坤峰, 苟超, 段艳杰, 等. 生成式对抗网络 GAN 的研究进展与展望[J]. 自动化学报, 2017, 43(3): 321-332.
[3] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.
[4] Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.
[5] Ledig C, Theis L, Huszár F, et al. Photo-realistic single image super-resolution using a generative adversarial network[J]. arXiv preprint arXiv:1609.04802, 2016.
[6] Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks[J]. arXiv preprint arXiv:1611.07004, 2016.
[7] Lotter W, Kreiman G, Cox D. Unsupervised learning of visual structure using predictive generative networks[J]. arXiv preprint arXiv:1511.06380, 2015.
[8] Metz L, Poole B, Pfau D, et al. Unrolled generative adversarial networks[J]. arXiv preprint arXiv:1611.02163, 2016.
[9] Reed S, van den Oord A, Kalchbrenner N, et al. Generating interpretable images with controllable structure[J]. 2016.
[10] Breuleux O, Bengio Y, Vincent P. Unlearning for better mixing[J]. Universite de Montreal/DIRO, 2010.
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:生成式对抗网络GAN-入门篇 - Python技术站