论文笔记之:Conditional Generative Adversarial Nets

论文笔记之: Conditional Generative Adversarial Nets

简介

Conditional Generative Adversarial Nets,简称CGAN,是一种生成对抗网络(GAN)的扩展。相对于传统的GAN,CGAN在输入噪声向量的基础上,额外输入了条件信息,使得生成的结果能够针对条件信息的不同而变化,具有更好的灵活性。CGAN最初由Mirza和Osalind等人于2014年提出,并且在实验中证明,相对于一般GAN,能够生成更为精细、多样化的结果。

实现原理

CGAN的实现基于一般的GAN,其核心思路是通过训练两个神经网络实现生成器和判别器的博弈:

  • 生成器负责产生假数据样本;
  • 判别器负责区分真实数据和假数据,并给出判断真实性的概率。

本文提出的CGAN在此基础上,通过将条件向量和噪声向量放入生成器与判别器,从而创造了一个特殊的生成对抗结构。简单来说,就是给一个GAN加入了一个条件变量。这个条件变量可以理解成是任意类型的辅助信息:比如图像中的标注、文本中的关键字,或者是其他结构化数据,这些信息将作为输入与噪声共同作为生成器的输入,从而生成对应的图片或文本。

代码实现

使用Keras框架可以快速实现一个基于MNIST数据集的CGAN。

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Embedding, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np


def build_generator(noise_dim, img_shape, num_classes):
    model = Sequential()

    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=noise_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())

    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization())

    model.add(Conv2D(img_shape[2], kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    noise = Input(shape=(noise_dim,))
    label = Input(shape=(1,), dtype='int32')

    label_embedding = Flatten()(Embedding(num_classes, noise_dim)(label))

    model_input = Concatenate()([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

def build_discriminator(img_shape, num_classes):
    model = Sequential()

    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Dropout(rate=0.25))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization())
    model.add(Activation("relu"))

    model.add(Flatten())
    img = Input(shape=img_shape)

    features = model(img)

    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))

    input_features = Concatenate()([features, label_embedding])

    validity = Dense(1, activation="sigmoid")(input_features)

    return Model([img, label], validity)

def build_cgan(generator, discriminator):

    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

    z = Input(shape=(noise_dim,))
    label = Input(shape=(1,))
    img = generator([z, label])

    discriminator.trainable = False
    validity = discriminator([img, label])
    cgan = Model([z, label], validity)
    cgan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return cgan

def train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs, batch_size):

    (X_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

    X_train = np.expand_dims(X_train, axis=3)
    X_train = (X_train/127.5) - 1
    y_train_one = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Embed label to label embedding
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train[idx]
        labels_one = tf.keras.utils.to_categorical(labels, num_classes=num_classes)
        gen_labels = np.random.randint(0, num_classes, batch_size)
        gen_labels_one = tf.keras.utils.to_categorical(gen_labels, num_classes=num_classes)
        gen_noise = np.random.normal(0, 1, (batch_size, noise_dim))
        gen_imgs = generator.predict([gen_noise, gen_labels])

        d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = cgan.train_on_batch([gen_noise, gen_labels], np.ones((batch_size, 1)))

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

if __name__ == '__main__':

    noise_dim = 100
    img_rows, img_cols, channels = 28, 28, 1
    num_classes = 10
    img_shape = (img_rows, img_cols, channels)

    generator = build_generator(noise_dim, img_shape, num_classes)
    discriminator = build_discriminator(img_shape, num_classes)
    cgan = build_cgan(generator, discriminator)

    train(generator, discriminator, cgan, noise_dim, img_shape, num_classes, epochs=30000, batch_size=32)

总结

CGAN是生成对抗网络GAN的一个扩展,可以使用条件变量来控制生成数据的特征。其基本原理与GAN相似,但对网络结构和训练方法上进行了改进。目前,CGAN已被成功应用于图片生成、风格迁移等领域,在生成多样性和可控性方面都有很好的表现。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:论文笔记之:Conditional Generative Adversarial Nets - Python技术站

(0)
上一篇 2023年3月28日
下一篇 2023年3月28日

相关文章

  • select2中文帮助文档动态设置选中值

    以下是关于select2中文帮助文档动态设置选中值的完整攻略: select2简介 select2是一个基于jQuery的下拉框插件,它支持搜索、多选、远程数据加载等功能。select2可以在浏览器和Node.js环境中使用。 动态设置选中值 以下是如何使用select2动态设置选中值的步骤: 获取select2对象 设置选中值 触发change事件 示例1…

    other 2023年5月6日
    00
  • Python学习之面向对象编程详解

    Python学习之面向对象编程详解攻略 1. 理解面向对象编程的概念 在初学Python时,我们经常听到“面向对象编程”,但很少有人真正理解它的含义。面向对象编程(OOP)是一种编程方法,它将程序中的数据和方法组合成对象,并通过对象之间的交互来实现程序的功能。 OOP具有下面三个主要特性: 封装:将对象的状态和行为封装在一个单独的单元内,从而隔离了内部细节并…

    other 2023年6月27日
    00
  • EXCEL坐标轴怎么自定义设置?

    EXCEL中的坐标轴可以自定义设置,包括调整坐标轴刻度、坐标轴标签、坐标轴位置等。下面,我们将提供详细的攻略指导。 一、自定义设置坐标轴 1.1 调整坐标轴刻度 首先,右键单击图表中的坐标轴,选择格式化坐标轴选项。在弹出的格式化轴选项中,可以调整刻度尺寸、主刻度和次刻度之间的间距等。 示例1:调整坐标轴主刻度和次刻度之间的间距 在图表中选择一个坐标轴,右键单…

    other 2023年6月25日
    00
  • 【matlab】膨胀

    【matlab】膨胀 什么是膨胀? 膨胀是图像处理中的一种形态学运算,用于扩大和增强图像中物体的大小。它可以消除小的空洞(孔洞)或缝隙,并连接或分离物体。在数字图像处理中,常常使用膨胀与腐蚀(Erosion)共同构成对图像进行形态学滤波的操作。 膨胀的作用 对于二值图像,膨胀的作用主要有两种: 消除小的空洞(孔洞)或缝隙。在二值图像处理中,通常将物体标记为“…

    其他 2023年3月28日
    00
  • Java创建型设计模式之抽象工厂模式(Abstract Factory)

    Java创建型设计模式之抽象工厂模式(Abstract Factory) 抽象工厂模式是一种创建型设计模式,它提供了一种创建一系列相关或相互依赖对象的接口,而无需指定具体实现类。抽象工厂模式通过将对象的创建委托给工厂类来实现,从而实现了客户端与具体实现类的解耦。 结构 抽象工厂模式由以下几个关键组件组成: 抽象工厂(Abstract Factory):定义了…

    other 2023年10月15日
    00
  • Android实现ListView左右滑动删除和编辑

    Android实现ListView左右滑动删除和编辑攻略 在Android中实现ListView左右滑动删除和编辑功能可以通过以下步骤完成: 步骤1:添加依赖库 首先,在项目的build.gradle文件中添加以下依赖库: dependencies { implementation ‘com.android.support:recyclerview-v7:2…

    other 2023年9月6日
    00
  • w3wp.exe占用cpu过高的解决方法

    w3wp.exe占用CPU过高的解决方法 问题描述 在使用IIS部署Web应用程序的过程中,经常会遇到w3wp.exe进程占用CPU过高的问题。当进程占用率过高时,服务器的性能会下降,导致用户访问体验不佳。 解决方法 以下是几个可以尝试的解决方法: 1. 调整应用程序池的性能选项 进入IIS管理器,在左侧窗口中选择“应用程序池”,然后在右侧窗口中选择要修改的…

    other 2023年6月25日
    00
  • 明日之后重启灯塔奇遇任务通关步骤 重启灯塔任务攻略

    明日之后重启灯塔奇遇任务通关步骤 重启灯塔任务攻略 任务起点 需要注意的是,重启灯塔任务需要完成“触类旁通”任务,也就是在云端集市中购买“废墟痕迹”,交给黑店的一个NPC后开放重启灯塔任务。 任务前置要求 为了完成重启灯塔任务,你需要: 在游戏内达到等级25级以上 拥有足够的装备 拥有一定数量的药品 任务步骤 1. 与NPC对话 首先,前往尼斯湖附近,与那里…

    other 2023年6月27日
    00
合作推广
合作推广
分享本页
返回顶部