1. 极大似然估计
GAN用到了极大似然估计(MLE),因此我们对MLE作简单介绍。
MLE的目标是从样本数据中估计出真实的数据分布情况,所用的方法是最大化样本数据在估计出的模型上的出现概率,也即选定使得样本数据出现的概率最大的模型,作为真实的数据分布。
将真实模型用参数θ theta θ 表示,则在模型θ theta θ 下,样本数据的出现概率(likelihood)是(1) ∏ i = 1 m p m o d e l ( x i ; θ ) prod_{i=1}^mp_{model}(x_i; theta) tag{1} i = 1 ∏ m p m o d e l ( x i ; θ ) ( 1 )
其中x i x_i x i 表示样本中的第i i i 个数据。
最大化(1)式的概率,求得满足条件的θ theta θ :θ ∗ = arg max θ ∏ i = 1 m p m o d e l ( x i ; θ ) = arg max θ ∑ i = 1 m log p m o d e l ( x i ; θ ) begin{aligned}theta^* & = argmax_thetaprod_{i=1}^mp_{model}(x_i; theta) \&= argmax_thetasum_{i=1}^mlog p_{model}(x_i; theta) \end{aligned} θ ∗ = arg θ max i = 1 ∏ m p m o d e l ( x i ; θ ) = arg θ max i = 1 ∑ m log p m o d e l ( x i ; θ )
还可以使用KL散度来代表MLE方法:θ ∗ = arg min θ D K L ( p d a t a ( x ) ∣ ∣ p m o d e l ( x ; θ ) = arg min θ { ∑ i = 1 m p d a t a ( x i ) log p d a t a ( x i ) − ∑ i = 1 m p d a t a ( x i ) log p m o d e l ( x i ; θ ) } = − arg min θ ∑ i = 1 m p d a t a ( x i ) log p m o d e l ( x i ; θ ) = arg max θ ∑ i = 1 m p d a t a ( x i ) log p m o d e l ( x i ; θ ) begin{aligned}theta^*&=argmin_theta D_{KL}(p_{data}(x) || p_{model}(x;theta)\& = argmin_thetaleft{ sum_{i=1}^mp_{data}(x_i)log p_{data}(x_i) - sum_{i=1}^mp_{data}(x_i)log p_{model}(x_i;theta) right}\& = -argmin_thetasum_{i=1}^mp_{data}(x_i)log p_{model}(x_i;theta) \& = argmax_thetasum_{i=1}^mp_{data}(x_i)log p_{model}(x_i;theta)end{aligned} θ ∗ = arg θ min D K L ( p d a t a ( x ) ∣ ∣ p m o d e l ( x ; θ ) = arg θ min { i = 1 ∑ m p d a t a ( x i ) log p d a t a ( x i ) − i = 1 ∑ m p d a t a ( x i ) log p m o d e l ( x i ; θ ) } = − arg θ min i = 1 ∑ m p d a t a ( x i ) log p m o d e l ( x i ; θ ) = arg θ max i = 1 ∑ m p d a t a ( x i ) log p m o d e l ( x i ; θ )
在实际上,我们无法得到数据的真实分布p d a t a p_{data} p d a t a ,但是可以从m m m 个数据的样本中近似得到一个估计p ^ d a t a hat{p}_{data} p ^ d a t a 。
为了便于理解KL散度,我们在下面对其进行简要介绍。
2. 相对熵,KL散度
两个概率分布P P P 和Q Q Q 的KL散度定义如下:D K L ( P ∣ ∣ Q ) = ∑ i P ( i ) log P ( i ) Q ( i ) D_{KL}(P||Q)=sum_iP(i)log{frac{P(i)}{Q(i)}} D K L ( P ∣ ∣ Q ) = i ∑ P ( i ) log Q ( i ) P ( i )
性质 :D K L ( P ∣ ∣ Q ) ≥ 0 D_{KL}(P||Q)ge0 D K L ( P ∣ ∣ Q ) ≥ 0
当且仅当P = Q P=Q P = Q 时,等号成立。(证明过程借用吉布斯不等式 :∑ i p i log p i ≥ ∑ i p i log q i sum_ip_ilog p_igesum_ip_ilog q_i ∑ i p i log p i ≥ ∑ i p i log q i ,证明吉布斯不等式会用到关系log x ≤ x − 1 log x le x - 1 log x ≤ x − 1 )
KL散度反映了两个分布P P P 和Q Q Q 的相似情况,KL散度越小,两个分布越相似。
KL散度是不对称的:D K L ( P ∣ ∣ Q ) ≠ D K L ( Q ∣ ∣ P ) D_{KL}(P||Q) quadneq D_{KL}(Q||P) D K L ( P ∣ ∣ Q ) ̸ = D K L ( Q ∣ ∣ P )
3. KL散度与交叉熵的关系
神经网络中常常使用交叉熵作为损失函数:L = − ∑ i y i log h i L = -sum_i y_ilog h_i L = − i ∑ y i log h i
其中y i y_i y i 是实际的标签值,h i h_i h i 是网络的输出值。
我们将y y y 和h h h 的KL散度展开,得到:D K L ( y ∣ ∣ h ) = ∑ i y i log y i h i = ∑ i y i log y i − ∑ i y i log h i = ∑ i y i log y i + L = C o n s t a n t + L begin{aligned}D_{KL}(y||h) & = sum_iy_ilog{frac{y_i}{h_i}}\& = sum_iy_ilog y_i - sum_iy_ilog h_i\& = sum_iy_ilog y_i + L\&= Constant + Lend{aligned} D K L ( y ∣ ∣ h ) = i ∑ y i log h i y i = i ∑ y i log y i − i ∑ y i log h i = i ∑ y i log y i + L = C o n s t a n t + L
因此,最小化KL散度,等价于最小化损失函数L L L 。也即交叉熵损失函数反应的是网络输出结果和样本实际标签结果的KL散度的大小,交叉熵越小,KL散度也越小,网络的输出结果越接近实际值 。
4. JS散度
对于两个分布P P P 和Q Q Q ,JS散度是:D J S ( P ∣ ∣ Q ) = 1 2 D K L ( P ∣ ∣ P + Q 2 ) + 1 2 D K L ( Q ∣ ∣ P + Q 2 ) D_{JS}(P||Q) = frac{1}{2}D_{KL}(P||frac{P+Q}{2}) + frac{1}{2}D_{KL}(Q||frac{P+Q}{2}) D J S ( P ∣ ∣ Q ) = 2 1 D K L ( P ∣ ∣ 2 P + Q ) + 2 1 D K L ( Q ∣ ∣ 2 P + Q )
JS散度是对称的,并且有界[ 0 , log 2 ] [0, log2] [ 0 , log 2 ] 。
5. GAN 框架
生成器 ,生成与训练集数据相同分布的样本;判别器 ,检查生成器生成的样本是真的还是假的。 The generator is trained to fool the discriminator.
判别器的损失函数
判别器的损失函数为:(2) J ( D ) ( θ ( D ) , θ ( G ) ) = − 1 2 E x ∼ p d a t a log D ( x ) − 1 2 E z ∼ p m o d e l log ( 1 − D ( G ( z ) ) ) J^{(D)}(theta^{(D)}, theta^{(G)})= -frac{1}{2}mathbb{E}_{xsim p_{data}}log D(x) - frac{1}{2}mathbb{E}_{zsim p_{model}}log (1-D(G(z)))tag{2} J ( D ) ( θ ( D ) , θ ( G ) ) = − 2 1 E x ∼ p d a t a log D ( x ) − 2 1 E z ∼ p m o d e l log ( 1 − D ( G ( z ) ) ) ( 2 )
上式其实就是一个交叉熵损失函数。GAN的判别器在训练的过程中,数据集包含两个部分,一部分是训练集的样本x x x ,对应的标签y = 1 y=1 y = 1 ,一部分是生成器生成的数据G ( z ) G(z) G ( z ) ,对应的标签y = 0 y=0 y = 0 ,因此判别器的训练集可以看做X = { x , G ( z ) } , Y = { 1 , 0 } X={x, G(z)}, Y={1, 0} X = { x , G ( z ) } , Y = { 1 , 0 } 。
训练集样本是X X X ,标签是Y Y Y ,网络输出是H H H ,则交叉熵损失函数为:(3) J = 1 m ∑ i = 1 m { − Y i log H i − ( 1 − Y i ) log ( 1 − H i ) } J = frac{1}{m} sum_{i=1}^m{-Y_ilog H_i - (1-Y_i)log(1-H_i)}tag{3} J = m 1 i = 1 ∑ m { − Y i log H i − ( 1 − Y i ) log ( 1 − H i ) } ( 3 )
与式(2)作比较,前一项的log H log H log H 等价于式(2)中的log D ( x ) log D(x) log D ( x ) ,后一项的log ( 1 − H i ) log(1-H_i) log ( 1 − H i ) 等价于式(2)中的log ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log ( 1 − D ( G ( z ) ) ) 。将x x x 看做包含了真实样本和生成器生成的数据G ( z ) G(z) G ( z ) 的新的训练集,则判别器的损失函数可以重新写作:(4) J ( D ) ( θ ( D ) , θ ( G ) ) = − 1 2 E x ∼ p d a t a log D ( x ) − 1 2 E x ∼ p m o d e l log ( 1 − D ( x ) ) = − 1 2 ∑ i p d a t a ( x i ) log D ( x i ) − 1 2 ∑ i p m o d e l ( x i ) log ( 1 − D ( x i ) ) begin{aligned}J^{(D)}(theta^{(D)}, theta^{(G)}) &= -frac{1}{2}mathbb{E}_{xsim p_{data}}log D(x) - frac{1}{2}mathbb{E}_{xsim p_{model}}log (1-D(x))\&= -frac{1}{2} sum_ip_{data}(x_i)log D(x_i) -frac{1}{2}sum_i p_{model}(x_i) log (1-D(x_i))end{aligned}tag{4} J ( D ) ( θ ( D ) , θ ( G ) ) = − 2 1 E x ∼ p d a t a log D ( x ) − 2 1 E x ∼ p m o d e l log ( 1 − D ( x ) ) = − 2 1 i ∑ p d a t a ( x i ) log D ( x i ) − 2 1 i ∑ p m o d e l ( x i ) log ( 1 − D ( x i ) ) ( 4 )
对上式关于D ( x ) D(x) D ( x ) 求导,并令导数为0,得到:D ∗ ( x ) = p d a t a ( x ) p d a t a ( x ) + p m o d e l ( x ) D^*(x) = frac{p_{data}(x)}{p_{data}(x)+p_{model}(x)} D ∗ ( x ) = p d a t a ( x ) + p m o d e l ( x ) p d a t a ( x )
生成器的损失函数
令J ( G ) = − J ( D ) J^{(G)}=-J^{(D)} J ( G ) = − J ( D ) ,则J ( G ) ( θ ( D ) , θ ( G ) ) = 1 2 E x ∼ p d a t a log D ( x ) + 1 2 E z ∼ p m o d e l log ( 1 − D ( G ( z ) ) ) = C o n s t a n t + 1 2 E z ∼ p m o d e l log ( 1 − D ( G ( z ) ) ) begin{aligned}J^{(G)}(theta^{(D)}, theta^{(G)}) &= frac{1}{2}mathbb{E}_{xsim p_{data}}log D(x) + frac{1}{2}mathbb{E}_{zsim p_{model}}log (1-D(G(z)))\& = Constant + frac{1}{2}mathbb{E}_{zsim p_{model}}log (1-D(G(z)))end{aligned} J ( G ) ( θ ( D ) , θ ( G ) ) = 2 1 E x ∼ p d a t a log D ( x ) + 2 1 E z ∼ p m o d e l log ( 1 − D ( G ( z ) ) ) = C o n s t a n t + 2 1 E z ∼ p m o d e l log ( 1 − D ( G ( z ) ) )
生成器没有直接 接受任何的训练集数据,训练集数据的信息是通过判别器学习后传递 过来的。