AutoGAN: Neural Architecture Search for Generative Adversarial Networks

生成对抗网络(GAN)可以用于生成、风格迁移、数据增强、超分辨率等任务。今天介绍一篇 ICCV 2019 的 paper: “AutoGAN: Neural Architecture Search for Generative Adversarial Networks”。这篇文章第一次把 NAS 和 GAN 结合,想要用神经网络结构搜索(NAS)的方法搜一个GAN 的网络结构。

作者是来自 Texas A&M University 和 MIT-IBM Watson AI Lab 的 Xinyu Gong, Shiyu Chang, Yifan Jiang, Zhangyang Wang,代码开源在这里

作者指出,将 NAS 和 GAN 结合,会遇到下面的问题:
问题1: GAN 由生成器 G 和判别器 D 组成。那么应该先固定住其中一个的结构,搜另一个吗?还是说这两个应该一起同时搜呢?

如果先固定一个搜另一个,可能会导致两者之间的不平衡;而如果两个一起搜,GAN 的训练不稳定,可能会遇到 Mode collapse 之类的问题。

作者的解决方法是: 只搜 G ,但是 D 并不是一成不变的。当 G 变的越来越深的时候, D 也会堆叠一些预先确定的 block 来变深。

问题2: 没有一个好的评价指标来给搜索过程提供反馈。
GAN 常用的指标是 Inception score (IS) 和 FID score,由于 FID score 计算比较慢,作者就选择用 Inception score 作为 RL 的 award。

搜索空间

下面介绍 AutoGAN 的搜索空间。作者采用了 Multi-Level Architecture Search (MLAS) 的策略,就是生成器由不同的 cell 组成,在搜索的时候一个 cell 对应一个 RNN controller 。搜索空间如下图所示:
ICCV 2019| Auto GAN 论文解读,神经网络结构搜索 + 生成对抗网络
第 s 个 cell 的搜索空间就可以由一个形状为 (s+5) 的元组 (skip1,...,skips,C,N,U,SC)(skip_1, ..., skip_s, C, N, U, SC) 来表示。其中 skipiskip_i 表示当前的 cell 和前面坐标为 i1i-1 的 cell 之间的 skip connection。C 表示卷积 block 的类型,有**函数放在 conv 前/后两种;N 表示 normalization 的类型,有 BN / IN / 不加 normalization 三种;U 表示上采样操作的类型,有双线性差值、最近邻差值、stride 为 2 的反卷积三种;SC 表示 cell 内部要不要加 shortcut 连接。

这个搜索空间其实还算是比较简单的,一个 cell 里面就有一次上采样,再过两个 conv,中间加一些 skip-connection。

搜索策略

下面介绍 AutoGAN 的搜索策略:作者用的是 RL + RNN controller 的方法,训练的时候一共要更新两组参数:一个是 RNN controller 的参数 θtheta,一个是 GAN 的生成器和判别器的参数 ωomega

由于训练 GAN 的时候是不稳定的,如果模型已经 collapsed 的话,就没有必要继续训练了。作者根据经验归纳出一个结论:如果训练 loss 的方差变得比较小,那么很可能就是发生了 mode collapse。作者提出了一种 dynamic-resetting 的策略:用一个滑动窗口来存储生成器和判别器的 training loss,如果方差小于一个阈值,当前 GAN 的训练就会终止,生成器和判别器的参数会重新初始化。不过 RNN controller 的参数还是会保留的。

整个训练过程如下图所示:
ICCV 2019| Auto GAN 论文解读,神经网络结构搜索 + 生成对抗网络
可以看出,训练过程分为两步:

  1. 固定住 θtheta,只更新 GAN 的参数 ωomega。从 RNN controller 得到出一堆候选的结构,用 hinge adversarial loss 来训练 GAN。同时会根据 training loss 的计算来提前终止已经发生 mode collapse 的模型。

  2. 固定住 GAN 的参数 ωomega,只更新 RNN 的参数 θtheta。作者用的是一个 LSTM,首先采样出 K 个 child models,然后计算对应的 inception score 作为 reward,随后用强化学习的方式更新 LSTM 的权重。

Experiments

作者用的数据集是 CIFAR-10 (分辨率 32 x 32) 和 STL-10 (分辨率 48 x 48)。作者在 CIFAR-10 上搜到的生成器结构如下图所示:
ICCV 2019| Auto GAN 论文解读,神经网络结构搜索 + 生成对抗网络
可以看出,这个搜出来的结构倾向于把**函数加在卷积的前面、使用双线性插值而不是反卷积、不使用 normalization 运算、加很多 skipping connections 运算。

和其他方法的对比如下表所示:
ICCV 2019| Auto GAN 论文解读,神经网络结构搜索 + 生成对抗网络
作者指出,本文方法的搜索空间和 SN-GAN 比较像,因此和 SN-GAN 对比能看出来确实比人工设计的方法好。而有些方法用到了本文的搜索空间中没有的运算,例如 WGAN-GP 用到了 Wasserstein loss ,和他们比其实不太公平。而 SN-GAN 是基于 ResNet Block 的,并且在判别器部分移除了 BN 。这么看的话,本文搜出来的结构主要是多了 cell 之间的 skip connection。

作者为了证明搜出来的结构没有过拟合 CIFAR 这个小数据集,还提供了在 STL-10 上面的结果。作者保留在 CIFAR 上得到的结构不变,在 STL-10 上面重新训练,得到的结果如下表所示:
ICCV 2019| Auto GAN 论文解读,神经网络结构搜索 + 生成对抗网络
这个结果在 FIN 指标上比 Improving MMD GAN 要好。

AutoGAN 在 CIFAR 数据集上生成得到的图片效果如下图所示:
ICCV 2019| Auto GAN 论文解读,神经网络结构搜索 + 生成对抗网络

总结

作者指出,AutoGAN 还有很大的提升空间,都是后面可以继续做的点:

  1. 搜索空间还可以更大。比如说加入一些 attention 的候选 block,还可以加入 Wasserstein loss 等等;
  2. 可以换到更大分辨率的数据集,如 ImageNet 。不过在低分辨率的 CIFAR 10 上搜索要 43 个小时,换到大图上肯定要尝试更高效的搜索算法;
  3. 本文其实没有怎么搜判别器 D,可以研究怎么得到更好的判别器 D;
  4. 如何和 Conditional GAN 结合,在训练的时候引入标签信息。

综合来看,这个 NAS + GAN 的坑还有很多可以填的地方。本文的贡献主要在于第一次把 NAS + GAN 这种东西搞 work。

如果有什么理解不到位的地方,欢迎在评论区指正。