论文笔记之: Conditional Generative Adversarial Nets
简介
Conditional Generative Adversarial Nets,简称CGAN,是一种生成对抗网络(GAN)的扩展。相对于传统的GAN,CGAN在输入噪声向量的基础上,额外输入了条件信息,使得生成的结果能够针对条件信息的不同而变化,具有更好的灵活性。CGAN最初由Mirza和Osindero于2014年提出,并且在实验中证明,相对于一般GAN,能够生成更为精细、多样化的结果。
实现原理
CGAN的实现基于一般的GAN,其核心思路是通过训练两个神经网络实现生成器和判别器的博弈:
- 生成器负责产生假数据样本;
- 判别器负责区分真实数据和假数据,并给出判断真实性的概率。
本文提出的CGAN在此基础上,通过将条件向量和噪声向量放入生成器与判别器,从而创造了一个特殊的生成对抗结构。简单来说,就是给一个GAN加入了一个条件变量。这个条件变量可以理解成是任意类型的辅助信息:比如图像中的标注、文本中的关键字,或者是其他结构化数据,这些信息将作为输入与噪声共同作为生成器的输入,从而生成对应的图片或文本。
代码实现
使用Keras框架可以快速实现一个基于MNIST数据集的CGAN。
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate,
                                     Conv2D, Dense, Dropout, Embedding, Flatten,
                                     Input, Reshape, UpSampling2D, ZeroPadding2D)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np
def build_generator(noise_dim, img_shape, num_classes):
    """Build the conditional generator.

    Args:
        noise_dim: Length of the latent noise vector.
        img_shape: Output image shape, e.g. (28, 28, 1); only the channel
            count (img_shape[2]) is used for the final conv layer.
        num_classes: Number of condition classes for the label embedding.

    Returns:
        A Keras Model mapping [noise, label] -> generated image in [-1, 1].
    """
    model = Sequential()
    # BUGFIX: the model input is Concatenate([noise, label_embedding]),
    # which is 2 * noise_dim wide — the original declared input_dim=noise_dim
    # and crashed with a shape mismatch at model(model_input).
    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=noise_dim * 2))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization())
    model.add(UpSampling2D())  # 7x7 -> 14x14
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())
    model.add(UpSampling2D())  # 14x14 -> 28x28
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())
    # tanh keeps outputs in [-1, 1], matching the training-data scaling.
    model.add(Conv2D(img_shape[2], kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    noise = Input(shape=(noise_dim,))
    label = Input(shape=(1,), dtype='int32')
    # Embed the integer label into a noise_dim-sized dense vector, then
    # concatenate it with the noise as the conditional input.
    label_embedding = Flatten()(Embedding(num_classes, noise_dim)(label))
    model_input = Concatenate()([noise, label_embedding])
    img = model(model_input)
    return Model([noise, label], img)
def build_discriminator(img_shape, num_classes):
    """Build the conditional discriminator.

    Args:
        img_shape: Input image shape, e.g. (28, 28, 1).
        num_classes: Number of condition classes for the label embedding.

    Returns:
        A Keras Model mapping [image, label] -> probability the pair is real.

    Note: requires Dropout and ZeroPadding2D, which the original article's
    import block omitted.
    """
    model = Sequential()
    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
    model.add(Activation("relu"))
    model.add(Dropout(rate=0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    # Pad 7x7 -> 8x8 so the remaining stride-2 convs divide evenly.
    model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
    model.add(BatchNormalization())
    model.add(Activation("relu"))
    model.add(Dropout(rate=0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))
    model.add(Dropout(rate=0.25))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))
    model.add(Flatten())

    img = Input(shape=img_shape)
    features = model(img)
    label = Input(shape=(1,), dtype='int32')
    # Embed the label to the flattened image size so label information
    # carries comparable width to the conv features before the head.
    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))
    input_features = Concatenate()([features, label_embedding])
    validity = Dense(1, activation="sigmoid")(input_features)
    return Model([img, label], validity)
def build_cgan(generator, discriminator, noise_dim=None):
    """Compile the discriminator and assemble the stacked CGAN.

    Args:
        generator: Model mapping [noise, label] -> image.
        discriminator: Model mapping [image, label] -> validity; compiled
            here, then frozen inside the stacked model.
        noise_dim: Latent vector length. Defaults to the width of the
            generator's noise input, so existing two-argument callers keep
            working. (The original read a module-level global that only
            exists when run as a script.)

    Returns:
        The compiled combined model mapping [noise, label] -> validity.
    """
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=Adam(0.0002, 0.5),
                          metrics=['accuracy'])
    if noise_dim is None:
        # generator.input is [noise, label]; shape is (None, noise_dim).
        noise_dim = int(generator.input[0].shape[-1])
    z = Input(shape=(noise_dim,))
    label = Input(shape=(1,))
    img = generator([z, label])
    # Freeze the discriminator for the stacked model only: the generator is
    # trained against a fixed critic; trainable is baked in at compile time,
    # so the standalone discriminator compiled above still updates.
    discriminator.trainable = False
    validity = discriminator([img, label])
    cgan = Model([z, label], validity)
    cgan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return cgan
def train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs, batch_size):
    """Alternating adversarial training of the CGAN on MNIST.

    Each iteration trains the discriminator on one real and one generated
    batch, then trains the generator (via the frozen-discriminator stack)
    to label its fakes as real.

    Args:
        generator, discriminator, cgan: Models from the build_* helpers.
        noise_dim: Latent vector length fed to the generator.
        img_shape: Unused here; kept for interface compatibility.
        num_classes: Number of condition classes (MNIST digits: 10).
        epochs: Number of training iterations (one batch each).
        batch_size: Samples per batch.
    """
    (X_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
    # Add channel axis and rescale uint8 pixels from [0, 255] to [-1, 1]
    # to match the generator's tanh output. Cast to float32 explicitly
    # (the original divided uint8 data in place and also built three
    # one-hot arrays that were never used — removed).
    X_train = np.expand_dims(X_train, axis=3).astype("float32")
    X_train = X_train / 127.5 - 1.0

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # ---------------------
        #  Train Discriminator
        # ---------------------
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        # Reshape labels to (batch, 1) to match Input(shape=(1,)).
        imgs, labels = X_train[idx], y_train[idx].reshape(-1, 1)

        gen_noise = np.random.normal(0, 1, (batch_size, noise_dim))
        gen_labels = np.random.randint(0, num_classes, (batch_size, 1))
        gen_imgs = generator.predict([gen_noise, gen_labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], valid)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], fake)
        # Average real/fake losses (and accuracies) for reporting.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------
        # Generator wins when the frozen discriminator calls its fakes real.
        g_loss = cgan.train_on_batch([gen_noise, gen_labels], valid)

        # Plot the progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
if __name__ == '__main__':
    # Hyperparameters: 100-dim latent noise, 28x28 grayscale MNIST images,
    # 10 digit classes as the conditioning labels.
    noise_dim = 100
    img_rows, img_cols, channels = 28, 28, 1
    num_classes = 10
    img_shape = (img_rows, img_cols, channels)
    # Wire up the three models, then run the adversarial loop.
    # NOTE(review): 30000 iterations at batch 32 is a long run — reduce
    # for a smoke test.
    generator = build_generator(noise_dim, img_shape, num_classes)
    discriminator = build_discriminator(img_shape, num_classes)
    cgan = build_cgan(generator, discriminator)
    train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs=30000, batch_size=32)
总结
CGAN是生成对抗网络GAN的一个扩展,可以使用条件变量来控制生成数据的特征。其基本原理与GAN相似,但对网络结构和训练方法上进行了改进。目前,CGAN已被成功应用于图片生成、风格迁移等领域,在生成多样性和可控性方面都有很好的表现。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:论文笔记之:Conditional Generative Adversarial Nets - Python技术站