转自:https://blog.csdn.net/leviopku/article/details/81292192
Generative adversarial network
据有关媒体统计:CVPR2018的论文里,有三分之一的论文与GAN有关
由此可见,GAN在视觉领域的未来多年内,将是一片沃土(CVer们是时候入门GAN了)。而发现这片矿源的就是GAN之父,Goodfellow大神。
文末有基于keras的GAN代码,有助于理解GAN的原理
生成对抗网络GAN,是当今的一大热门研究方向。在2014年,被Goodfellow大神提出来,当时的G神还只是蒙特利尔大学的博士生而已。
GAN之父的主页:
http://www.iangoodfellow.com/
GAN的论文首次出现在NIPS2014上,论文地址如下:
https://arxiv.org/pdf/1406.2661.pdf
入坑GAN,首先需要理由,GAN能做什么,为什么要学GAN。
GAN的初衷就是生成不存在于真实世界的数据,类似于使得 AI具有创造力或者想象力。应用场景如下:
AI作家,AI画家等需要创造力的AI体;
将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有所谓的“想象力”,能脑补情节;
进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。
以上的场景都可以找到相应的paper。而且GAN的用处也远不止此,期待我们继续挖掘,是发论文的好方向哦
GAN的原理介绍
这里介绍的是原生的GAN算法,虽然有一些不足,但提供了一种生成对抗性的新思路。放心,我这篇博文不会堆一大堆公式,只会提供一种理解思路。
理解GAN的两大护法G和D
G是generator,生成器: 负责凭空捏造数据出来
D是discriminator,判别器: 负责判断数据是不是真数据
这样可以简单的看作是两个网络的博弈过程。在最原始的GAN论文里面,G和D都是两个多层感知机网络。首先,注意一点,GAN操作的数据不一定非得是图像数据,不过为了更方便解释,我在这里用图像数据为例解释以下GAN:
稍微解释以下上图,z是随机噪声(就是随机生成的一些数,也就是GAN生成图像的源头)。D通过真图和假图的数据(相当于天然label),进行一个二分类神经网络训练(想各位必再熟悉不过了)。G根据一串随机数就可以捏造一个“假图像”出来,用这些假图去欺骗D,D负责辨别这是真图还是假图,会给出一个score。比如,G生成了一张图,在D这里得分很高,那证明G是很成功的;如果D能有效区分真假图,则G的效果还不太好,需要调整参数。GAN就是这么一个博弈的过程。
那么,GAN是怎么训练呢?
根据GAN的训练算法,我画一张图:
GAN的训练在同一轮梯度反传的过程中可以细分为2步,先训练D在训练G;注意不是等所有的D训练好以后,才开始训练G,因为D的训练也需要上一轮梯度反传中G的输出值作为输入。
当训练D的时候,上一轮G产生的图片,和真实图片,直接拼接在一起,作为x。然后根据,按顺序摆放0和1,假图对应0,真图对应1。然后就可以通过,x输入生成一个score(从0到1之间的数),通过score和y组成的损失函数,就可以进行梯度反传了。(我在图片上举的例子是batch = 1,len(y)=2*batch,训练时通常可以取较大的batch)
当训练G的时候, 需要把G和D当作一个整体,我在这里取名叫做’D_on_G’。这个整体(下面简称DG系统)的输出仍然是score。输入一组随机向量,就可以在G生成一张图,通过D对生成的这张图进行打分,这就是DG系统的前向过程。score=1就是DG系统需要优化的目标,score和y=1之间的差异可以组成损失函数,然后可以反向传播梯度。注意,这里的D的参数是不可训练的。这样就能保证G的训练是符合D的打分标准的。这就好比:如果你参加考试,你别指望能改变老师的评分标准
需要注意的是,整个GAN的整个过程都是无监督的(后面会有监督性GAN比如cGAN),怎么理解这里的无监督呢?
这里,给的真图是没有经过人工标注的,你只知道这是真实的图片,比如全是人脸,而系统里的D并不知道来的图片是什么玩意儿,它只需要分辨真假。G也不知道自己生成的是什么玩意儿,反正就是学真图片的样子骗D。
正由于GAN的无监督,在生成过程中,G就会按照自己的意思天马行空生成一些“诡异”的图片,可怕的是D还能给一个很高的分数。比如,生成人脸极度扭曲的图片。这就是无监督目的性不强所导致的,所以在同年的NIPS大会上,有一篇论文conditional GAN就加入了监督性进去,将可控性增强,表现效果也好很多。
from future import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
import numpy as np
class GAN():
def init(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# Build the generator
self.generator = self.build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
# For the combined model we will only train the generator
self.discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
validity = self.discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
def build_generator(self):
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
# Rescale -1 to 1
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random batch of images
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Generate a batch of new images
gen_imgs = self.generator.predict(noise)
# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ---------------------
# Train Generator
# ---------------------
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
# Plot the progress
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()
if name == ‘main’:
gan = GAN()
gan.train(epochs=30000, batch_size=32, sample_interval=200)
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:生成对抗网络——GAN(一) - Python技术站