生成对抗网络(GAN)之 Basic Theory 学习笔记

  前言:最近学习了李宏毅生成对抗网络篇(2018年)的视频(视频地址:李宏毅对抗生成网络(GAN)国语教程(2018)),因为截止今天(3.23),2020版还未讲到生成对抗网络,因此选择18年。本次学习笔记主要为Basic Theory部分,主要讲解GAN的数学原理。
  GAN又称生成对抗网络,是由Ian Goodfellow等人在2014年提出的一种训练策略,论文地址https://arxiv.org/abs/1406.2661,也称为对抗学习(Adversarial Learning),其主要由两个部分组成,一个是生成器(或者说是采样器)(Generator),另一个是判别器(Discriminator)。生成器主要负责从低维度简单的分布(例如正态分布、均匀分布等)随机采样,并映射到一个复杂高维度空间中,判别器的目标则是根据生成器生成的分布和真实世界中采样的样本进行区分,尽可能的给真实世界的样本高分,给生成器生成的假数据低分。一般来说,判别器可以当做一个0-1分类器。


1、GAN基本实现原理

  GAN的生成原理如下图所示:

生成对抗网络(GAN)之 Basic Theory 学习笔记

  结合图,给出一个形式化的描述。给定一个低维度随机分布(例如Normal Distribution)作为Noise,从中随机采样一个样本 XX,生成器首先通过计算将其生成高维度空间中的样本,例如如果是图像生成任务,则生成器的输出 G(X)G(X) 表示一张图像。然后由判别器获取到生成器生成的图像 G(X)G(X) 以及从真实世界中随机采样的真实图像 II,判别器分别为其进行打分 D(G(X))D(G(X))D(I)D(I),目标则是使得 D(G(X))=0D(G(X)) = 0D(I)=1D(I) = 1。生成器与判别器是一组相互制衡的Object,判别器尽可能的来区分哪些样本是来自真实世界,哪些来自生成器制造的假数据;而生成器则是尽可能生成一些看起来很像真实世界的数据来迷惑判别器。

  一个比较直观的例子就是小偷与警察,这在许多博客中也有所提及。初始化时候,小偷的偷盗水平很差,警察的办案能力也不高,但始终有一个条件就是警察的水平总会比小偷高一点,因此警察总会抓住一些水平很低的小偷。小偷为了生存则不断提高自己的偷盗水平和反侦察能力,警察也发现许多小偷的水平太高了来提升自己的办案水平。如此进行下去。当然我们并不希望小偷的水平能够达到警察都分辨不出来的底部,但是对于GAN来讲,最终的结果则是我们更希望得到一个强大的生成器,因为它能够在判别器不断的指导下来生成非常真实的东西,以至于连判别器都没有办法区分它们。

  为了能够形象的描述,我们假设蓝色曲线表示的是生成器生成的复杂的高维度分布,绿色表示实际真实分布,红色则表示判别器的判别分布。判别器目标就是在生成器生成的分布部分给与低分,在真实分布部分给与高分,而随着不断地迭代,生成样本会在真实样本附近来回的震荡,其会经过所有可能会使得判别器给分非常高但并非是真是样本分布的部分,最终理想状态是右下角的图,两者分布完全吻合。
生成对抗网络(GAN)之 Basic Theory 学习笔记

2、GAN数学基础

  通过对GAN的理解,我们会产生疑问,生成器与判别器如何去衡量它们的好坏?首先我们从统计学角度分析。
  我们知道,生成器输入的是从低维度分布中随机采样的噪声,也就是先验概率分布,例如我们可以选择正态分布。而对于图像等这一类数据往往是位于高纬度空间,且真正有价值、可读的图像只是其中一小部分,如下图所示。绿色的部分就是生成器生成的高维度空间的分布PG(xz)=PG(x=G(z))P_G(x|z)=P_G(x=G(z)),最左侧浅蓝色则是潜在未知的真实分布Pdata(x)P_{data}(x),我们更希望能够让这两个分布距离Div(PG,Pdata)Div(P_G,P_{data})越小,即寻找一个最优的判别器以满足 G=arg minGDiv(PG,Pdata)G^*=argmin_{G}Div(P_G,P_{data}) 。但事实上,我们无法直接去计算两者。

生成对抗网络(GAN)之 Basic Theory 学习笔记

2.1、最大似然估计与KL散度

  通过分析,GAN可以被认为是衡量生成的分布与真实分布的距离,我们希望这个距离尽可能的小。在统计学中,我们是已知先验分布 zN(μ,σ)zsim N(mu,sigma),从中随机采样一组样本 x1,x2,x3,...,xmx_1,x_2,x_3,...,x_m,我们可以根据这些采样来对高维度分布进行最大似然估计,即有一组参数
θ=arg maxθi=1mPG(xi,θ)=arg maxθi=1mlogPG(xi,θ)theta^*=argmax_{theta}prod_{i=1}^{m}P_G(x_i,theta)=argmax_{theta}sum_{i=1}^{m}logP_G(x_i,theta)

而上式可以再加上一个无关项 i=1mlogPdata(xi,θ)-sum_{i=1}^{m}logP_{data}(x_i,theta),也就是说:

θ=arg maxθi=1mlogPG(xi,θ)i=1mlogPdata(xi,θ)theta^*=argmax_{theta}sum_{i=1}^{m}logP_G(x_i,theta) - sum_{i=1}^{m}logP_{data}(x_i,theta)

θ=arg minθKL(PdataPG)theta^*=argmin_{theta}KL(P_{data}||P_G)

因此说,如果能够找到一个参数 θtheta^* 使得生成器生成的样本与真实样本的分布KL散度最小,这组参数就是我们所学习的目标。但是事实上KL散度是不对称的,并不能直接作为GAN学习的目标。

2.2、GAN目标函数与JS散度

  在上面我们提到GAN的目标是尽可能的让判别器分辨出真假数据,生成器则是尽可能欺骗判别器,在Ian Goodfellow的论文中,给出了比较清晰的训练目标函数,如下所示:

minGmaxDV(D,G)=ExPdata[logD(x)]+ExPz[1logD(G(z))]min_{G} max_{D} V(D,G)=mathbb{E}_{xsim P_{data}}[logD(x)] + mathbb{E}_{xsim P_{z}}[1-logD(G(z))]

这是一个被称为min max游戏的任务,也符合GAN的训练机制:先固定生成器 GG,寻找当前最优的 DGD_{G}^* 能够使得 maxDV(G,D)=V(G,DG)max_{D}V(G,D)=V(G,D_{G}^*),其次固定判别器 DD,寻找能够使得 G=arg minGDiv(PG,Pdata)=V(G,DG)G^*=argmin_{G}Div(P_G,P_{data}) = V(G^*,D_{G}^*) 。而事实上,这个min max游戏本质是最小化 PGP_{G}PdataP_{data} 的JS散度,推导如下图所示:

首先我们固定生成器,求最大化的 V(G,DG)V(G,D_{G}^*),此时 V(G,D)V(G,D) 可以看做是只与 DD 有关的一元函数,我们用积分来描述期望:

V(G,D)=xPdata(x)logD(x)dx+zPG(z)log(1D(G(z)))dzV(G,D) = int_{x}P_{data}(x)logD(x) dx+ int_{z}P_{G}(z)log(1-D(G(z)))dz

=x[Pdata(x)logD(x)+PG(x)log(1D(x)]dx=int_{x}[P_{data}(x)logD(x) + P_{G}(x)log(1-D(x)]dx

这里使用一次积分换元,将zz的积分换为对xx的积分。因为此时 arg maxDV(G,D)=arg maxDPdata(x)logD(x)+PG(x)log(1D(x))argmax_{D}V(G,D) = argmax_{D}P_{data}(x)logD(x) + P_{G}(x)log(1-D(x))

  令无关变量Pdata(x)=aP_{data}(x)=aPG(x)=bP_{G}(x)=b,则 f(D)=a×logD+b×log(1D)f(D)=atimes logD+btimes log(1-D),求导后得到极值点为 D=aa+bD=frac{a}{a+b},即

DG=arg maxDV(G,D)=Pdata(x)Pdata(x)+PG(x)D_{G}^*=argmax_{D}V(G,D) = frac{P_{data}(x)}{P_{data}(x) + P_{G}(x)}

代入到 V(G,D)V(G,D) 后,通过简单的变换,可以转换为JS散度的形式,如下图所示:
生成对抗网络(GAN)之 Basic Theory 学习笔记

所以说,min max游戏本质上是最小化JS散度,即min(2log2+2×JSD(PdataPG))min(-2log2+2times JSD(P_{data}||P_{G}))

  用图示来描述这个过程:假设有三个不同的生成器 G1,G2,G3G_1,G_2,G_3,其对应判别器生成的曲线为蓝色线条,而线条上的点到底边轴的距离即为 V(G,D)V(G,D),通过上面的公式推导,结合图我们很快理解,首先是找到蓝色线条的最大值点,这也每个生成器都对应一个最大值点;其次从所有生成器中寻找一个最小的最大值点。图中对应的就是 G3G_3

生成对抗网络(GAN)之 Basic Theory 学习笔记

2.3、GAN算法

  GAN算法如图所示:
生成对抗网络(GAN)之 Basic Theory 学习笔记
首先分别从真实数据和噪声中随机采样一组{x1,x2,...,xm}{x^{1},x^{2},...,x^{m}}{z1,z2,...,zm}{z^{1},z^{2},...,z^{m}},其中 mm 为batch_size。先固定生成器训练判别器。得到每个生成的数据 {x~1,x~2,...,x~m}{widetilde{x}^{1},widetilde{x}^{2},...,widetilde{x}^{m}},其中 x~i=G(zi)widetilde{x}^{i}=G(z^{i})。根据生成数据和真实数据的采样,训练判别器,目标最大化 V~widetilde{V},可采用梯度上升更新参数。其次固定判别器,训练生成器,最小化V~widetilde{V},梯度下降法更新参数。

  需要值得注意的是,虽然训练判别器和生成器使用的目标函数是一样,但一个是最大化,一个是最小化。另外在训练生成器时,可以简化目标函数为第二项 V=1mi=1mlog(1D(x~i))V=frac{1}{m}sum_{i=1}^{m}log(1-D(widetilde{x}^{i})),因为第一项相对于GG是一个常数项。使用这个作为目标函数的GAN被命名为MMGAN。还有,原文使用了另一个目标函数来优化生成器,如下图:
生成对抗网络(GAN)之 Basic Theory 学习笔记

也就是NSGAN,目标函数为 V=1mi=1mlog(D(x~i))V=-frac{1}{m}sum_{i=1}^{m}log(D(widetilde{x}^{i})) 其相比于MMGAN,其能够保持梯度的方向是不变,但梯度值会比较大,方便计算。

3、特别说明

  GAN在训练过程中有几点需要注意:
(1)GAN拟合速度比较慢,因为对于高维度空间的分布,GAN经常会生成一些肉眼无法理解的内容,而需要迭代非常多次才能达到比较稳定的范围内;
(2)判别器需要保证尽可能的或接近收敛,而生成器不能训练太强。可以想象,如果生成器训练的很强,而判别器还没有达到一个较好的结果,此时判别器就无法判别出谁是真实数据,谁是假数据,此时可能导致训练终止了但生成器生成的样本还是很糟糕。通常情况下,在每一次对抗过程中,判别器尽可能训练多次到达接近收敛,生成器训练1-3次。
(3)在实际对抗训练中,有一种策略可以实现简单的对抗模式。我们假设GAN用于噪声分类,即判断给定一个一组样本中,哪些是positive,哪些是negative。可以先使用生成器生成一组序列,表示对应每个样本是噪声的概率,换句话说就是生成一组它认为是噪声negative的样本集合。然后使用判别器去判别这些集合是不是噪声。判别器会从真实数据(positive)采样一部分数据,但给它们标记为negative,而生成器生成的数据标记为positive,这样如果判别器无法判别出究竟是哪个是positive,哪个是negative时(也就是判别器性能下降了),生成器就能够比较准确的找出噪声。
(4)GAN还有诸多变形,本文只是GAN的基本模型。例如生成器可以使用卷积神经网络来提取图像特征,使用RNN来提取文本类特征等。判别器因为本质上是一个二分类器,在深度学习中,只要保证有一定深度,也可以使用CNN、RNN等编码器。另外,也有将自编码器结合到GAN中,例如VAEGAN。