引入包,数据约定等

import numpy as np
import matplotlib.pyplot as plt
import input_data  #读取数据的一个工具文件,不影响理解
import tensorflow as tf


# 获取数据
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images

X = mnist.train.images[:, :]
batch_size = 64

#用来返回真实数据
def iterate_minibatch(x, batch_size, shuffle=True):
    indices = np.arange(x.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for i in range(0, x.shape[0]-1000, batch_size):
        temp = x[indices[i:i + batch_size], :]
        temp = np.array(temp) * 2 - 1
        yield np.reshape(temp, (-1, 28, 28, 1))

GAN对象结构

class GAN(object):
    def __init__(self):
        #初始函数,在这里对初始化模型
    def netG(self, z):
        #生成器模型
    def netD(self, x, reuse=False):
        #判别器模型

生成器函数

对随机值z(维度为1,100),进行包装,伪造,产生伪造数据。
包装过程概括为:全连接->reshape->反卷积
包装过程中使用了batch_normalization,Leaky ReLU,dropout,tanh等技巧

   #对随机值z(维度为1,100),进行包装,伪造,产生伪造数据。
    #包装过程概括为:全连接->reshape->反卷积
    #包装过程中使用了batch_normalization,Leaky ReLU,dropout,tanh等技巧
    def netG(self,z,alpha=0.01):
        with tf.variable_scope('generator') as scope:
            layer1 = tf.layers.dense(z, 4 * 4 * 512)  # 这是一个全连接层,输出 (n,4*4*512)
            layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
            # batch normalization
            layer1 = tf.layers.batch_normalization(layer1, training=True)  # 做BN标准化处理
            # Leaky ReLU
            layer1 = tf.maximum(alpha * layer1, layer1)
            # dropout
            layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

            # 4 x 4 x 512 to 7 x 7 x 256
            layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
            layer2 = tf.layers.batch_normalization(layer2, training=True)
            layer2 = tf.maximum(alpha * layer2, layer2)
            layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

            # 7 x 7 256 to 14 x 14 x 128
            layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
            layer3 = tf.layers.batch_normalization(layer3, training=True)
            layer3 = tf.maximum(alpha * layer3, layer3)
            layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

            # 14 x 14 x 128 to 28 x 28 x 1
            logits = tf.layers.conv2d_transpose(layer3, 1, 3, strides=2, padding='same')
            # MNIST原始数据集的像素范围在0-1,这里的生成图片范围为(-1,1)
            # 因此在训练时,记住要把MNIST像素范围进行resize
            outputs = tf.tanh(logits)

            return outputs

判别器函数

通过深度卷积+全连接的形式,判别器将输入分类为真数据,还是假数据。

    def netD(self, x, reuse=False,alpha=0.01):
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()
            layer1 = tf.layers.conv2d(x, 128, 3, strides=2, padding='same')
            layer1 = tf.maximum(alpha * layer1, layer1)
            layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

            # 14 x 14 x 128 to 7 x 7 x 256
            layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
            layer2 = tf.layers.batch_normalization(layer2, training=True)
            layer2 = tf.maximum(alpha * layer2, layer2)
            layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

            # 7 x 7 x 256 to 4 x 4 x 512
            layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
            layer3 = tf.layers.batch_normalization(layer3, training=True)
            layer3 = tf.maximum(alpha * layer3, layer3)
            layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

            # 4 x 4 x 512 to 4*4*512 x 1
            flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
            f = tf.layers.dense(flatten, 1)
            return f

初始化函数

有一个前置训练,将真实数据喂给判别器,训练判别器的鉴别能力

    # 有一个前置训练,将真实数据喂给判别器,训练判别器的鉴别能力
    def __init__(self):
        self.z = tf.placeholder(tf.float32, shape=[batch_size, 100], name='z')  # 随机输入值
        self.x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='real_x')  # 图片值

        self.fake_x = self.netG(self.z)  # 将随机输入,包装为伪造图片值

        self.pre_logits = self.netD(self.x, reuse=False)  # 判别器预训练时,判别器对真实数据的判别情况-未sigmoid处理
        self.real_logits = self.netD(self.x, reuse=True)  # 判别器对真实数据的判别情况-未sigmoid处理
        self.fake_logits = self.netD(self.fake_x, reuse=True)  # 判别器对伪造数据的判别情况-未sigmoid处理

        # 预训练时判别器,判别器将真实数据判定为真的得分情况。
        self.loss_pre_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.pre_logits,
                                                                                 labels=tf.ones_like(self.pre_logits)))
        # 训练时,判别器将真实数据判定为真,将伪造数据判定为假的得分情况。
        self.loss_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_logits,
                                                                             labels=tf.ones_like(self.real_logits))) + 
                      tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.zeros_like(self.fake_logits)))
        # 训练时,生成器伪造的数据,被判定为真实数据的得分情况。
        self.loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_logits,
                                                                             labels=tf.ones_like(self.fake_logits)))

        # 获取生成器和判定器对应的变量地址,用于更新变量
        t_vars = tf.trainable_variables()
        self.g_vars = [var for var in t_vars if var.name.startswith("generator")]
        self.d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

开始训练

gan = DCGAN()
#预训练时的梯度优化函数
d_pre_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_pre_D, var_list=gan.d_vars)
#判别器的梯度优化函数
d_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_D, var_list=gan.d_vars)
#预训练时的梯度优化函数
g_optim = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.4).minimize(gan.loss_G, var_list=gan.g_vars)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    #对判别器的预训练,训练了两个epoch
    for i in range(2):
        print('判别器初始训练,第' + str(i) + '次包')
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            loss_pre_D, _ = sess.run([gan.pre_logits, d_pre_optim],
                                     feed_dict={
                                         gan.x: x_batch
                                     })
    #训练5个epoch
    for epoch in range(5):
        print('对抗' + str(epoch) + '次包')
        avg_loss = 0
        count = 0
        for x_batch in iterate_minibatch(X, batch_size=batch_size):
            z_batch = np.random.uniform(-1, 1, size=(batch_size, 100))  # 随机起点值

            loss_D, _ = sess.run([gan.loss_D, d_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     gan.x: x_batch
                                 })

            loss_G, _ = sess.run([gan.loss_G, g_optim],
                                 feed_dict={
                                     gan.z: z_batch,
                                     # gan.x: np.zeros(z_batch.shape)
                                 })

            avg_loss += loss_D
            count += 1

        # 显示预测情况
        if True:
            avg_loss /= count
            z = np.random.normal(size=(batch_size, 100))
            excerpt = np.random.randint(100, size=batch_size)
            needTest = np.reshape(X[excerpt, :], (-1, 28, 28, 1))
            fake_x, real_logits, fake_logits = sess.run([gan.fake_x, gan.real_logits, gan.fake_logits],
                                                        feed_dict={gan.z: z, gan.x: needTest})
            # accuracy = (np.sum(real_logits > 0.5) + np.sum(fake_logits < 0.5)) / (2 * batch_size)
            print('real_logits')
            print(len(real_logits))
            print('fake_logits')
            print(len(fake_logits))
            print('ndiscriminator loss at epoch %d: %f' % (epoch, avg_loss))
            # print('ndiscriminator accuracy at epoch %d: %f' % (epoch, accuracy))
            print('----')
            print()

            # curr_img = np.reshape(trainimg[i, :], (28, 28))  # 28 by 28 matrix
            curr_img = np.reshape(fake_x[0], (28, 28))
            plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
            plt.show()
            curr_img2 = np.reshape(fake_x[10], (28, 28))
            plt.matshow(curr_img2, cmap=plt.get_cmap('gray'))
            plt.show()
            curr_img3 = np.reshape(fake_x[20], (28, 28))
            plt.matshow(curr_img3, cmap=plt.get_cmap('gray'))
            plt.show()

            curr_img4 = np.reshape(fake_x[30], (28, 28))
            plt.matshow(curr_img4, cmap=plt.get_cmap('gray'))
            plt.show()

            curr_img5 = np.reshape(fake_x[40], (28, 28))
            plt.matshow(curr_img5, cmap=plt.get_cmap('gray'))
            plt.show()
            # plt.figure(figsize=(28, 28))

            # plt.title("" + str(i) + "th Training Data "
            #           + "Label is " + str(curr_label))
            # print("" + str(i) + "th Training Data "
            #       + "Label is " + str(curr_label))

            # plt.scatter(X[:, 0], X[:, 1])
            # plt.scatter(fake_x[:, 0], fake_x[:, 1])
            # plt.show()

结果

GAN生成式对抗网络(三)——mnist数据生成

下载链接