大家好,欢迎来到专栏《百战GAN》,我们在公众号已经输出了非常多的GAN相关的理论,这一次我们开设《百战GAN》专栏,在这个专栏里,我们会进行算法的核心思想讲解,代码的详解,模型的训练等内容。

作者&编辑 | 言有三

本文资源与生成结果展示

【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

本文篇幅:5000字

背景要求:会使用Python,Tensorflow或者Pytorch

附带资料:项目推荐,版本包括Pytorch+Tensorflow

同步平台:有三AI知识星球(一周内)

1 项目背景

生成对抗网络如今在计算机视觉的很多领域中都被广泛应用,需要每一个学习深度学习相关技术的算法人员掌握,我们公众号和知识星球讲述了非常多的理论知识,在这个《百战GAN》专栏中,我们会配合各类实战案例来帮助大家进行提升,本次项目开发需要以下环境:

(1) Linux系统或者windows系统,使用Linux效率更高。

(2) 安装好的Tensorflow,CPU或者GPU训练都可以。

2 原理简介

今天我们要实践的模型是DCGAN和CGAN,DCGAN是第一个全卷积GAN,麻雀虽小,五脏俱全,最适合新人实践。

【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

DCGAN的生成器和判别器都采用了4层的网络结构。生成器网络结构如上图所示,输入为1×100的向量,然后经过一个全连接层学习,reshape为4×4×1024的张量,再经过4个上采样的反卷积网络层,生成64×64的图,各层的配置如下:

【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

判别器输入64×64大小的图,经过4次卷积,分辨率降低为4×4的大小,每一个卷积层的配置如下:

【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

DCGAN并不能控制生成图片的类别,条件GAN(CGAN)则使用了条件控制变量作为输入,是几乎后续所有性能强大的GAN的基础。网络结构如下,其中的y就是条件变量。

【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

对于生成器来说,输入包括z和y,两者会进行拼接后作为输入。对于判别器来说,输入包括了x和y,两者会进行拼接后作为输入,当然为了和z以及x进行拼接,y需要做一些维度变换,即reshape操作。

关于它们的理论更加详细的讲解,大家可以移步有三AI知识星球,或者自行阅读论文。

3 模型训练

接下来我们进行实践,选择tensorflow框架,下面详解具体的工程代码,主要包括:

(1) 生成器和判别器模型的定义。

(2) 损失和优化目标的定义。

3.1 DCGAN类定义

首先我们需要定义一个类,设计好输入输出,__init__函数如下:

# 模型定义

class DCGAN(object):

    def __init__(self, sess, input_height=108, input_width=108, crop=True,

         batch_size=64, sample_num = 64, output_height=64, output_width=64,

         y_dim=None, z_dim=100, gf_dim=64, df_dim=64,

         gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',

         max_to_keep=1,

         input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):

其中参数解释如下:sess表示TensorFlow session,batch_size即批处理大小;z_dim是噪声的维度,默认为100;y_dim是一个可选的条件变量,比如分类标签,用于CGAN;gf_dim是生成器第一个卷积层的通道数;df_dim是判别器第一个卷积层的通道数;gfc_dim是生成器全连接层维度;dfc_dim是判别器全连接层维度;c_dim是输入图像维度,灰度图为1,彩色图为3。

从上述代码可以看出,初始化函数__init__中配置了训练输入图尺寸,批处理大小,输出图尺寸,生成器的输入维度,以及生成器和判别的卷积层和全连接层的若干维度变量。