论文笔记之:Conditional Generative Adversarial Nets

论文笔记之: Conditional Generative Adversarial Nets

简介

Conditional Generative Adversarial Nets,简称CGAN,是一种生成对抗网络(GAN)的扩展。相对于传统的GAN,CGAN在输入噪声向量的基础上,额外输入了条件信息,使得生成的结果能够针对条件信息的不同而变化,具有更好的灵活性。CGAN最初由Mirza和Osalind等人于2014年提出,并且在实验中证明,相对于一般GAN,能够生成更为精细、多样化的结果。

实现原理

CGAN的实现基于一般的GAN,其核心思路是通过训练两个神经网络实现生成器和判别器的博弈:

  • 生成器负责产生假数据样本;
  • 判别器负责区分真实数据和假数据,并给出判断真实性的概率。

本文提出的CGAN在此基础上,通过将条件向量和噪声向量放入生成器与判别器,从而创造了一个特殊的生成对抗结构。简单来说,就是给一个GAN加入了一个条件变量。这个条件变量可以理解成是任意类型的辅助信息:比如图像中的标注、文本中的关键字,或者是其他结构化数据,这些信息将作为输入与噪声共同作为生成器的输入,从而生成对应的图片或文本。

代码实现

使用Keras框架可以快速实现一个基于MNIST数据集的CGAN。

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Embedding, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np


def build_generator(noise_dim, img_shape, num_classes):
    model = Sequential()

    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=noise_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())

    model.add(Conv2D(img_shape[2], kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    noise = Input(shape=(noise_dim,))
    label = Input(shape=(1,), dtype='int32')

    label_embedding = Flatten()(Embedding(num_classes, noise_dim)(label))

    model_input = Concatenate()([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

def build_discriminator(img_shape, num_classes):
    model = Sequential()

    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Flatten())
    img = Input(shape=img_shape)

    features = model(img)

    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))

    input_features = Concatenate()([features, label_embedding])

    validity = Dense(1, activation="sigmoid")(input_features)

    return Model([img, label], validity)

def build_cgan(generator, discriminator):

    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

    z = Input(shape=(noise_dim,))
    label = Input(shape=(1,))
    img = generator([z, label])

    discriminator.trainable = False
    validity = discriminator([img, label])
    cgan = Model([z, label], validity)
    cgan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return cgan

def train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs, batch_size):

    (X_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

    X_train = np.expand_dims(X_train, axis=3)
    X_train = (X_train/127.5) - 1
    y_train_one = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Embed label to label embedding
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train[idx]
        labels_one = tf.keras.utils.to_categorical(labels, num_classes=num_classes)
        gen_labels = np.random.randint(0, num_classes, batch_size)
        gen_labels_one = tf.keras.utils.to_categorical(gen_labels, num_classes=num_classes)
        gen_noise = np.random.normal(0, 1, (batch_size, noise_dim))
        gen_imgs = generator.predict([gen_noise, gen_labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = cgan.train_on_batch([gen_noise, gen_labels], np.ones((batch_size, 1)))

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

if __name__ == '__main__':

    noise_dim = 100
    img_rows, img_cols, channels = 28, 28, 1
    num_classes = 10
    img_shape = (img_rows, img_cols, channels)

    generator = build_generator(noise_dim, img_shape, num_classes)
    discriminator = build_discriminator(img_shape, num_classes)
    cgan = build_cgan(generator, discriminator)

    train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs=30000, batch_size=32)

总结

CGAN是生成对抗网络GAN的一个扩展,可以使用条件变量来控制生成数据的特征。其基本原理与GAN相似,但对网络结构和训练方法上进行了改进。目前,CGAN已被成功应用于图片生成、风格迁移等领域,在生成多样性和可控性方面都有很好的表现。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:论文笔记之:Conditional Generative Adversarial Nets - Python技术站

(0)
上一篇 2023年3月28日
下一篇 2023年3月28日

相关文章

  • 好用到飞起的12个jupyterlab插件

    好用到飞起的12个JupyterLab插件 JupyterLab作为目前最流行的交互式计算环境之一,优秀的插件也是其受欢迎的重要原因之一。在这篇文章中,我们将介绍12个在JupyterLab中非常有用的插件,这些插件可以帮助你更加高效地开发和使用JupyterLab。 1. jupyterlab-git jupyterlab-git可以将JupyterLab…

    其他 2023年3月28日
    00
  • 搞懂redis协议resp

    当然,我可以为您提供有关“搞懂Redis协议RESP”的完整攻略,以下是详细说明: 什么是Redis协议RESP? Redis协议RESP(REdis Serialization Protocol)是种二进制协议,用于在Redis客户端和服务器间传输数据。RESP协议是一种简单的协议,它使用文本协议的形式来传输二进制数据。 RESP协议设计目标是简、快速和可…

    other 2023年5月7日
    00
  • androidcursor浅析

    androidcursor浅析 在Android开发中,常常需要对数据库进行操作。Android提供了一个SQLite数据库用于本地存储。如果要实现数据的增删改查,需要使用Android提供的SQLiteOpenHelper类,它封装了对SQLite数据库的操作,但是我们更多的时候会使用Cursor来获取数据库的内容。 什么是Cursor 打个比方,我们把它…

    其他 2023年3月29日
    00
  • js弹出窗口代码大全(详细整理)

    js弹出窗口代码大全(详细整理) JavaScript弹出窗口经常被用于在页面中显示重要信息或提供用户交互。本文将详细介绍JS弹出窗口的各种用法和代码示例。 alert弹窗 alert弹窗是JS中最常见的弹窗形式,它用于在页面中显示一段提示信息,用户需要点击确认按钮才能继续操作。 alert("这是一个alert弹窗!"); confir…

    其他 2023年3月28日
    00
  • java二叉树面试题详解

    Java二叉树面试题详解 简介 二叉树是一种非常重要的数据结构,常被用于算法设计与面试问答中。本文将详细探讨Java二叉树面试题相关知识以及解决方案。 常见问题 如何构建一个二叉树? 构建二叉树的方法有很多,但最基础的方法是通过节点类来实现。定义一个Node类来表示二叉树的节点,每个节点包括三个属性:value、left和right。其中,value表示节点…

    other 2023年6月27日
    00
  • css类选择器的使用方法详解

    CSS类选择器的使用方法详解 1. 什么是类选择器? CSS类选择器是一种用于选中具有相同类名的元素的选择器。它以.开头,后跟类名,可以选择多个元素并对其应用相同的样式。 2. 如何使用类选择器? 2.1 在HTML中定义类名 在HTML标签的class属性中定义类名,并为多个元素分配相同的类名。例如: <p class="highlight…

    other 2023年6月28日
    00
  • Ajax加载外部页面弹出层效果实现方法

    当通过Ajax请求获取HTML页面时,我们希望将其以弹出层的形式展示出来,而不是让其跳转到新页面。这种效果可以使用一下几个步骤实现: 步骤一:添加页面元素 首先需要在页面中添加一些HTML元素,包括弹出层和触发弹出层的按钮。 <!– 弹出层 –> <div id="modal"> <div class=&…

    other 2023年6月25日
    00
  • j-link固件烧录以及使用j-flash向arm硬件板下载固件程序

    j-link固件烧录以及使用j-flash向arm硬件板下载固件程序 本文主要介绍j-link固件烧录以及使用j-flash向arm硬件板下载固件程序的方法及相关操作流程。 j-link固件烧录 j-link是一款功能强大的调试器,已经成为了大部分arm开发人员的首选工具。在使用j-link时,可能会遇到固件版本过低或者需要更新固件的情况。下面介绍j-lin…

    其他 2023年3月28日
    00
合作推广
合作推广
分享本页
返回顶部