推荐的几个开源实现

  1. znxlwm 使用InfoGAN的结构,卷积反卷积
  2. eriklindernoren 把mnist转成1维,label用了embedding
  3. wiseodd 直接从tensorflow代码转换过来的,数据集居然还用tf的数据集。。
  4. Yangyangii 转1维向量,全连接
  5. FangYang970206 提供了多标签作为条件的实现思路
  6. znxlwm 专门针对MNIST数据集的一个实现方法,转1维,比较接近原paper的实现方法

训练过程

  • 简述
    # z - 随机噪声
    # X - 输入数据
    # c - 输入的label

    # ===== 训练判别器D =====

    # 真数据输入到D中
    D_real = D(X, c) 
    # 真数据D的判断结果应尽可能接近1  
    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)   

    # 生成随机噪声
    z = torch.rand((batch_size, self.z_dim)) 
    # G生成的伪数据,这一步的c可以用已知的,也可以重新随机生成一些label,但总之这些c所生成的数据都是伪的
    G_sample = G(z, c)  
    # 伪数据输入到D中
    D_fake = D(G_sample , c)    
    # 伪数据D的判断结果应尽可能接近0
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)     
   
    # D的loss定义为上面两部分之和,即真数据要尽可能接近1,伪数据要尽可能接近0
    D_loss = D_loss_real + D_loss_fake 

    # 更新D的参数
    D_loss.backward()
    D_solver.step()

    # 在训练G之前把梯度清零,也可以不这么做
    reset_grad()
    
    # ===== 训练生成器G =====

    # 这里可以选择,有的实现是直接用上面的z
    z = Variable(torch.randn(mb_size, Z_dim))  
    # 这里可以选择用已知的c,或者重新采样
    c = 重新随机一些label  
    # 用G生成伪数据
    G_sample = G(z, c) 
    # 伪数据输入到D中              
    D_fake = D(G_sample, c)     
    # 此时计算的是G的Loss,伪数据D的判断结果应尽可能接近1,因为G要试图骗过D
    G_loss = nn.binary_cross_entropy(D_fake, ones_label)  
    
    # 更新G的参数
    G_loss.backward()
    G_solver.step()

一些坑

  1. 计算D和G的loss时最好分别用不同的随机噪声,否则有可能训练过程不会收敛,而且结果差
  2. 注意,训练的时候随机噪声的分布应该要保持和测试时的分布一致,不要一个用均匀分布,一个用正态分布

初步结果

哈哈哈看到终于训练出来像样的数字,还是有点小成就的
pytorch conditional GAN 调试笔记