论文笔记之:Conditional Generative Adversarial Nets

yizhihongxing

论文笔记之: Conditional Generative Adversarial Nets

简介

Conditional Generative Adversarial Nets,简称CGAN,是一种生成对抗网络(GAN)的扩展。相对于传统的GAN,CGAN在输入噪声向量的基础上,额外输入了条件信息,使得生成的结果能够针对条件信息的不同而变化,具有更好的灵活性。CGAN最初由Mirza和Osalind等人于2014年提出,并且在实验中证明,相对于一般GAN,能够生成更为精细、多样化的结果。

实现原理

CGAN的实现基于一般的GAN,其核心思路是通过训练两个神经网络实现生成器和判别器的博弈:

  • 生成器负责产生假数据样本;
  • 判别器负责区分真实数据和假数据,并给出判断真实性的概率。

本文提出的CGAN在此基础上,通过将条件向量和噪声向量放入生成器与判别器,从而创造了一个特殊的生成对抗结构。简单来说,就是给一个GAN加入了一个条件变量。这个条件变量可以理解成是任意类型的辅助信息:比如图像中的标注、文本中的关键字,或者是其他结构化数据,这些信息将作为输入与噪声共同作为生成器的输入,从而生成对应的图片或文本。

代码实现

使用Keras框架可以快速实现一个基于MNIST数据集的CGAN。

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Embedding, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np


def build_generator(noise_dim, img_shape, num_classes):
    model = Sequential()

    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=noise_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())

    model.add(Conv2D(img_shape[2], kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    noise = Input(shape=(noise_dim,))
    label = Input(shape=(1,), dtype='int32')

    label_embedding = Flatten()(Embedding(num_classes, noise_dim)(label))

    model_input = Concatenate()([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

def build_discriminator(img_shape, num_classes):
    model = Sequential()

    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Flatten())
    img = Input(shape=img_shape)

    features = model(img)

    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))

    input_features = Concatenate()([features, label_embedding])

    validity = Dense(1, activation="sigmoid")(input_features)

    return Model([img, label], validity)

def build_cgan(generator, discriminator):

    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

    z = Input(shape=(noise_dim,))
    label = Input(shape=(1,))
    img = generator([z, label])

    discriminator.trainable = False
    validity = discriminator([img, label])
    cgan = Model([z, label], validity)
    cgan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return cgan

def train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs, batch_size):

    (X_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

    X_train = np.expand_dims(X_train, axis=3)
    X_train = (X_train/127.5) - 1
    y_train_one = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Embed label to label embedding
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train[idx]
        labels_one = tf.keras.utils.to_categorical(labels, num_classes=num_classes)
        gen_labels = np.random.randint(0, num_classes, batch_size)
        gen_labels_one = tf.keras.utils.to_categorical(gen_labels, num_classes=num_classes)
        gen_noise = np.random.normal(0, 1, (batch_size, noise_dim))
        gen_imgs = generator.predict([gen_noise, gen_labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = cgan.train_on_batch([gen_noise, gen_labels], np.ones((batch_size, 1)))

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

if __name__ == '__main__':

    noise_dim = 100
    img_rows, img_cols, channels = 28, 28, 1
    num_classes = 10
    img_shape = (img_rows, img_cols, channels)

    generator = build_generator(noise_dim, img_shape, num_classes)
    discriminator = build_discriminator(img_shape, num_classes)
    cgan = build_cgan(generator, discriminator)

    train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs=30000, batch_size=32)

总结

CGAN是生成对抗网络GAN的一个扩展,可以使用条件变量来控制生成数据的特征。其基本原理与GAN相似,但对网络结构和训练方法上进行了改进。目前,CGAN已被成功应用于图片生成、风格迁移等领域,在生成多样性和可控性方面都有很好的表现。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:论文笔记之:Conditional Generative Adversarial Nets - Python技术站

(0)
上一篇 2023年3月28日
下一篇 2023年3月28日

相关文章

  • python利用后缀表达式实现计算器功能

    Python利用后缀表达式实现计算器功能攻略 后缀表达式(也称为逆波兰表达式)是一种将运算符放在操作数之后的表示方法。利用后缀表达式可以实现计算器功能,以下是详细的攻略。 步骤一:将中缀表达式转换为后缀表达式 创建一个空栈和一个空列表,用于存储运算符和后缀表达式。 从左到右遍历中缀表达式的每个字符。 如果遇到操作数(数字),将其添加到后缀表达式列表中。 如果…

    other 2023年8月5日
    00
  • proptypes使用

    当然,我很乐意为您提供有关“PropTypes使用”的完整攻略。以下是详细的步骤和两个示例: 1. 什么是PropTypes? PropTypes是React中的一个库,用于检查组件的属性是否符合预期。它可以帮助开发人员在开发过程中捕获错误,并提高代码的可维护性。 2. PropTypes使用 以下是使用PropTypes的步骤: 2.1 引入PropTyp…

    other 2023年5月6日
    00
  • Java三大特性之封装详解

    Java三大特性之封装详解 在Java中,封装是面向对象编程的三大特性之一。封装是指将数据和方法包装在一个单元中,通过访问修饰符来控制对数据的访问。封装的目的是隐藏内部实现细节,提供对外部的安全访问接口。 封装的优点 封装具有以下几个优点: 数据隐藏:封装可以将数据隐藏在类的内部,只暴露必要的接口给外部使用。这样可以防止外部直接访问和修改数据,保证数据的安全…

    other 2023年8月8日
    00
  • 如何实现bean初始化摧毁方法的注入

    实现bean初始化摧毁方法的注入,需要通过Spring的IOC容器实现。Spring提供了两种方式来实现bean的初始化和销毁方法的注入:使用注解和使用XML配置文件。 一、使用注解的方式: 使用注解@PostConstruct来指定bean初始化方法,使用@PreDestroy来指定bean销毁方法。 @Component public class MyB…

    other 2023年6月20日
    00
  • Android开发实现ListView点击展开收起效果示例

    Android开发实现ListView点击展开收起效果示例攻略 在Android开发中,实现ListView点击展开收起效果是一个常见的需求。下面将详细介绍如何实现这一效果,并提供两个示例说明。 步骤一:准备工作 首先,在XML布局文件中定义ListView和需要展开收起的子项布局。例如: <ListView android:id=\"@+i…

    other 2023年8月26日
    00
  • Java类加载初始化的过程及顺序

    下面我将详细讲解Java类加载初始化的过程及顺序。 Java类加载初始化的过程 Java的类加载过程一般分为三个部分:类加载、链接和初始化。其中类的加载是指将类的.class文件读入内存,并将其转化成方法区中的运行时数据结构;链接是将类的常量池中的符号引用转化成直接引用的过程,然后进行内存地址的检验,最后完成方法表的预备工作;初始化则是对类的静态变量进行初始…

    other 2023年6月20日
    00
  • Java线程中的常见方法(start方法和run方法)

    Java线程中的常见方法包括start()方法和run()方法,它们是Java多线程进行并发编程的基础。 start()方法 start()方法是启动线程的方法,它会在新的线程中执行run()方法。在调用start()方法后,JVM会自动调用run()方法,因此我们不应该直接调用run()方法。当线程启动后,start()方法就会返回,该方法不会等待线程执行…

    other 2023年6月27日
    00
  • extundelete教程(完整版)

    以下是详细讲解“extundelete教程(完整版)”的标准Markdown格式文本: extundelete教程(完整版) extundelete是一款用于恢复已删除文件的工具,适用于ext3和ext4文件系统。本攻略将介绍如何使用extundelete来恢复已删除的文件,包括安装、使用和示例说明等内容。 安装extundelete 在Ubuntu和Deb…

    other 2023年5月10日
    00
合作推广
合作推广
分享本页
返回顶部