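Generative adversarial networks (GANs) were proposed by Ian Goodfellow and co-authors in the 2014 paper Generative Adversarial Nets. In principle, a GAN can be summarized as a game played between a generator and a discriminator over the whole course of training.
Roles of the two modules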
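- Discriminator network: intuitively, a simple neural network whose input is an image and whose output is a probability used to judge real versus fake (a value above 0.5 means real, a value below 0.5 means fake).
- Generator network: likewise a neural network model, whose input is a vector of random numbers Z and whose output is an image.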
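Training goals of the two modules
- Goal of the discriminator: to tell whether an input image comes from the real sample set or from the fake sample set. If the output is close to 1 for a real sample and close to 0 for a fake one, the discriminator is doing its job well.
- Goal of the generator: the generator manufactures samples, and its goal is to become so good at it that the discriminator can no longer tell whether a sample is real or fake.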
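For reference, the two goals above correspond to the minimax value function given in the Generative Adversarial Nets paper. Here D(x) is the discriminator's probability that x is real, G(z) is the generator's output for noise z, and p_data and p_z denote the data and noise distributions:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator pushes V up by scoring real samples near 1 and generated samples near 0, while the generator pushes V down by making D(G(z)) large.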
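Training a GAN
Note that the generator and the discriminator can be regarded as two completely separate models: two independent neural networks, each with its own parameters.
The overall approach to training two such models is therefore to update them separately, in alternation. Because there are two networks, they are hard to train in a single step, so we alternate between them. Let us look at each in turn.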
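First, initialize a generator at random (it will certainly not be the best generator yet). Feeding it a batch of random vectors yields a batch of fake samples. Because this is not the final generator, it is at a disadvantage at this point: its samples are poor, and the discriminator can probably spot them as fakes easily. Set that aside for now. We have this fake sample set, and the real sample set is always available, so we assign the labels by hand. Since we want the output for real samples to be as close to 1 as possible and the output for fake samples to be as close to 0 as possible, every real sample gets class label 1 and every fake sample gets class label 0. The discriminator is then trained on this labelled data like an ordinary binary classifier.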
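For the generator, recall the goal: to produce samples that are as realistic as possible. How do we know whether the samples from the initial generator are realistic? By sending them to the discriminator, so training the generator requires the discriminator as well. If we used the generator alone, how would we train it, and where would the error come from? On its own there is none. But if we attach the discriminator after the generator, we can tell real from fake, and that gives us an error signal. Training the generator is therefore really training the chained generator-discriminator network, as shown in the figure. Now consider the samples: we have the original noise vectors Z, and hence the fake samples generated from them. The key point is that we now set the labels of these fake samples to 1, that is, we treat them as real samples while training the generator. Only in this way can they fool the discriminator, and only in this way do the generated samples gradually approach real ones. During this step only the generator's parameters are updated.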
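The two labelling conventions described above can be sketched with plain Python. This is only an illustration: the discriminator outputs below are made-up numbers, and `binary_cross_entropy` and `mean` are hypothetical helpers, not part of the tutorial's code.

```python
import math


def binary_cross_entropy(prob, label):
    """Cross-entropy between a predicted probability and a 0/1 label."""
    eps = 1e-12  # guards against log(0)
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))


def mean(values):
    return sum(values) / len(values)


# Hypothetical discriminator outputs (probability of "real") for one step.
d_on_real = [0.9, 0.8, 0.95]  # D applied to real samples
d_on_fake = [0.2, 0.1, 0.3]   # D applied to generated samples G(z)

# Discriminator step: real samples are labelled 1, generated samples 0.
d_loss = (mean([binary_cross_entropy(p, 1) for p in d_on_real]) +
          mean([binary_cross_entropy(p, 0) for p in d_on_fake]))

# Generator step: the same generated samples are now labelled 1, so the
# loss is large while the discriminator still rejects them.
g_loss = mean([binary_cross_entropy(p, 1) for p in d_on_fake])

print(round(d_loss, 4), round(g_loss, 4))
```

With these numbers the discriminator is winning, so its loss is small and the generator's loss is large; as D(G(z)) rises toward 1 the generator's loss falls.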
Now for the code. We train on two datasets:
- MNIST
- CelebA
to generate handwritten digits and faces.
First, download the datasets:
import math
import os
import hashlib
from urllib.request import urlretrieve
import zipfile
import gzip
import shutil

data_dir = './data'

# NOTE: DLProgress, _unzip and _ungzip are not defined in this listing;
# they are assumed to come from the project's helper code.

def download_extract(database_name, data_path):
    """
    Download and extract database
    :param database_name: Database name
    :param data_path: Directory to download into and extract under
    """
    DATASET_CELEBA_NAME = 'celeba'
    DATASET_MNIST_NAME = 'mnist'

    if database_name == DATASET_CELEBA_NAME:
        url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
        hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
        extract_path = os.path.join(data_path, 'img_align_celeba')
        save_path = os.path.join(data_path, 'celeba.zip')
        extract_fn = _unzip
    elif database_name == DATASET_MNIST_NAME:
        url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
        extract_path = os.path.join(data_path, 'mnist')
        save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
        extract_fn = _ungzip
    else:
        raise ValueError('Unknown database name: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)

    # Verify the download against the expected MD5 checksum
    with open(save_path, 'rb') as f:
        file_hash = hashlib.md5(f.read()).hexdigest()
    assert file_hash == hash_code, \
        '{} file is corrupted. Remove the file and try again.'.format(save_path)

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise err

    # Remove compressed data
    os.remove(save_path)

# download mnist
download_extract('mnist', data_dir)
# download celeba
download_extract('celeba', data_dir)
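The checksum step above reads the whole archive into memory before hashing it. For a large archive, a chunked variant may be gentler on memory. The following is a minimal sketch using only the standard library; `md5_of_file` is a hypothetical helper name, and the temporary file stands in for a downloaded archive.

```python
import hashlib
import os
import tempfile


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        # iter() with a sentinel stops when read() returns an empty bytes object
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


# Self-check on a throwaway file: chunked and one-shot hashing must agree.
payload = b'hello gan' * 1000
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    tmp_path = tmp.name

chunked = md5_of_file(tmp_path, chunk_size=64)
whole = hashlib.md5(payload).hexdigest()
os.remove(tmp_path)

print(chunked == whole)
```

In `download_extract`, the result of such a helper could be compared with `hash_code` in place of the one-shot `hashlib.md5(...)` call.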
Let's first see what the MNIST and CelebA datasets look like.
# the number of images
show_n_images = 16

%matplotlib inline
import math
import os
from glob import glob

import numpy as np
from matplotlib import pyplot
from PIL import Image

# NOTE: get_image is not defined in this listing; it is assumed to come from
# the project's helper code and to return one image as an array.

def get_batch(image_files, width, height, mode):
    data_batch = np.array(
        [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))

    return data_batch

def images_square_grid(images, mode):
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size*save_size],
        (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))

    return new_im

mnist_images = get_batch(glob(os.path.join(data_dir, 'mnist/*.jpg'))[:show_n_images], 28, 28, 'L')
pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray')
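The square-grid step can be hard to picture from the reshape call alone. Below is a NumPy-only sketch of the same idea on random data. It is an illustration under assumed shapes (10 grayscale 28 x 28 images), not the helper's PIL-based code, and it tiles images row by row rather than reproducing the paste order used above.

```python
import math

import numpy as np

# Hypothetical batch: 10 grayscale images of 28 x 28 pixels.
rng = np.random.RandomState(0)
images = rng.uniform(-1, 1, size=(10, 28, 28, 1)).astype(np.float32)

# Largest square grid that fits: floor(sqrt(10)) = 3, so 9 images are kept.
side = math.floor(np.sqrt(images.shape[0]))
kept = images[:side * side]

# (9, 28, 28, 1) -> (rows, cols, height, width, channels)
grid = kept.reshape(side, side, 28, 28, 1)

# Tile into one 2-D canvas: reorder to (rows, height, cols, width), then merge.
canvas = grid[..., 0].transpose(0, 2, 1, 3).reshape(side * 28, side * 28)

print(side, grid.shape, canvas.shape)
```

The tenth image is dropped because it does not fit into the 3 x 3 grid.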
MNIST:
show_n_images = 9

celeba_images = get_batch(glob(os.path.join(data_dir, 'img_align_celeba/*.jpg'))[:show_n_images], 28, 28, 'RGB')
pyplot.imshow(images_square_grid(celeba_images, 'RGB'))
CelebA:
Now we start building the network.
I recommend training on a GPU; TensorFlow 1.1.0 is the preferred version.
from distutils.version import LooseVersion
import warnings
import tensorflow as tf

# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer. You are using {}'.format(tf.__version__)
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
Next, we build the inputs.
def model_inputs(image_width, image_height, image_channels, z_dim):
    # Real images
    inputs_real = tf.placeholder(tf.float32, (None, image_width, image_height, image_channels), name='input_real')

    # Input z
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')

    # Learning rate
    learning_rate = tf.placeholder(tf.float32, name='lr')

    return inputs_real, inputs_z, learning_rate
Building the Discriminator
def discriminator(images, reuse=False):
    """
    Create the discriminator network
    :param images: Tensor of input image(s)
    :param reuse: Boolean if the weights should be reused
    :return: Tuple of (tensor output of the discriminator, tensor logits of the discriminator)
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # Leaky ReLU coefficient
        alpha = 0.2

        # Dropout keep probability
        keep_prob = 0.8

        # Input layer: 28 * 28 * color channels
        x1 = tf.layers.conv2d(images, 128, 5, strides=2, padding='same',
                              kernel_initializer=tf.contrib.layers.xavier_initializer(seed=2))
        # No batch norm on the first layer
        # Leaky ReLU
        relu1 = tf.maximum(alpha * x1, x1)
        # Dropout
        drop1 = tf.nn.dropout(relu1, keep_prob=keep_prob)
        # 14 * 14 * 128

        # Layer 2
        x2 = tf.layers.conv2d(drop1, 256, 5, strides=2, padding='same',
                              kernel_initializer=tf.contrib.layers.xavier_initializer(seed=2))
        # Batch norm
        bn2 = tf.layers.batch_normalization(x2, training=True)
        # Leaky ReLU
        relu2 = tf.maximum(alpha * bn2, bn2)
        drop2 = tf.nn.dropout(relu2, keep_prob=keep_prob)
        # 7 * 7 * 256

        # Layer 3
        x3 = tf.layers.conv2d(drop2, 512, 5, strides=2, padding='same',
                              kernel_initializer=tf.contrib.layers.xavier_initializer(seed=2))
        bn3 = tf.layers.batch_normalization(x3, training=True)
        relu3 = tf.maximum(alpha * bn3, bn3)
        drop3 = tf.nn.dropout(relu3, keep_prob=keep_prob)
        # 4 * 4 * 512

        # Output
        # Flatten (NOTE: this reads relu3, so drop3 is unused, as in the original)
        flatten = tf.reshape(relu3, (-1, 4 * 4 * 512))
        logits = tf.layers.dense(flatten, 1)
        # Activation
        out = tf.nn.sigmoid(logits)

    return out, logits
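The shape comments in the discriminator (14, 7, 4) follow from how `padding='same'` interacts with the stride: the output size is the input size divided by the stride, rounded up. A small sketch to check them, where `conv_same_output_size` is a hypothetical helper name:

```python
import math


def conv_same_output_size(size, stride):
    """Spatial size after a convolution with padding='same'."""
    return math.ceil(size / stride)


size = 28  # input images are 28 x 28
sizes = []
for _ in range(3):  # three stride-2 convolutions
    size = conv_same_output_size(size, stride=2)
    sizes.append(size)

flatten_units = size * size * 512  # the last layer has 512 channels

print(sizes, flatten_units)  # [14, 7, 4] 8192
```

This is why the flatten step reshapes to `4 * 4 * 512`: a different input resolution would need a different constant there.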
Next comes the Generator
def generator(z, out_channel_dim, is_train=True):
    """
    Create the generator network
    :param z: Input z
    :param out_channel_dim: The number of channels in the output image
    :param is_train: Boolean if generator is being used for training
    :return: The tensor output of the generator
    """
    with tf.variable_scope('generator', reuse=not is_train):
        # First fully connected layer
        x0 = tf.layers.dense(z, 4 * 4 * 512)
        # Reshape
        x0 = tf.reshape(x0, (-1, 4, 4, 512))
        # Batch norm
        bn0 = tf.layers.batch_normalization(x0, training=is_train)
        # ReLU
        relu0 = tf.nn.relu(bn0)
        # 4 * 4 * 512

        # Transposed convolution
        x1 = tf.layers.conv2d_transpose(relu0, 256, 4, strides=1, padding='valid')
        bn1 = tf.layers.batch_normalization(x1, training=is_train)
        relu1 = tf.nn.relu(bn1)
        # 7 * 7 * 256

        x2 = tf.layers.conv2d_transpose(relu1, 128, 3, strides=2, padding='same')
        bn2 = tf.layers.batch_normalization(x2, training=is_train)
        relu2 = tf.nn.relu(bn2)
        # 14 * 14 * 128

        # Last transposed convolution, without batch norm
        logits = tf.layers.conv2d_transpose(relu2, out_channel_dim, 3, strides=2, padding='same')
        out = tf.tanh(logits)

    return out
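The generator's shape comments (4 to 7 to 14 to 28) can be checked the same way. The sketch below uses the usual output-length rule for transposed convolutions with no extra output padding; `deconv_output_size` is a hypothetical helper name.

```python
def deconv_output_size(size, kernel, stride, padding):
    """Spatial size after a transposed convolution (no output padding)."""
    if padding == 'valid':
        return size * stride + max(kernel - stride, 0)
    if padding == 'same':
        return size * stride
    raise ValueError('unsupported padding: {}'.format(padding))


# (kernel, stride, padding) for the three transposed convolutions above
layers = [(4, 1, 'valid'), (3, 2, 'same'), (3, 2, 'same')]

size = 4  # after reshaping the dense output to 4 x 4 x 512
sizes = [size]
for kernel, stride, padding in layers:
    size = deconv_output_size(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [4, 7, 14, 28]
```

The first layer uses `padding='valid'` with a 4 x 4 kernel to get from 4 to 7, since 7 cannot be reached from 4 with a stride-2 `'same'` layer.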
Then we define the loss. Label smoothing is applied here to the labels of the real samples.
def model_loss(input_real, input_z, out_channel_dim):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param out_channel_dim: The number of channels in the output image
    :return: A tuple of (discriminator loss, generator loss)
    """
    g_model = generator(input_z, out_channel_dim, is_train=True)

    d_model_real, d_logits_real = discriminator(input_real, reuse=False)

    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

    # Label smoothing: real labels become 1 - smooth instead of 1
    smooth = 0.1
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                labels=tf.ones_like(d_model_real) * (1 - smooth)))

    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_model_fake)))

    # Generator loss: generated samples are labelled 1
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.ones_like(d_model_fake)))

    d_loss = d_loss_real + d_loss_fake

    return d_loss, g_loss
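To see what the smoothed label does, the per-example loss can be reproduced in plain Python. The sketch below uses the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)) for a logit x and label z. `stable_sigmoid_bce` is a hypothetical name, and the logits are made-up values.

```python
import math


def stable_sigmoid_bce(logit, label):
    """Sigmoid cross-entropy computed from a logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def naive_sigmoid_bce(logit, label):
    """Same quantity via an explicit sigmoid; fine for moderate logits."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


smooth = 0.1
confident = 6.0            # discriminator is very sure the sample is real
calibrated = math.log(9)   # sigmoid(log 9) = 0.9, which matches the smoothed label

for name, logit in [('confident', confident), ('calibrated', calibrated)]:
    hard = stable_sigmoid_bce(logit, 1.0)
    soft = stable_sigmoid_bce(logit, 1.0 - smooth)
    print(name, round(hard, 4), round(soft, 4))
```

With the hard label 1 the very confident logit has the lower loss, whereas with the smoothed label 0.9 it is penalized more than the calibrated one. That is the intended effect: the discriminator is discouraged from becoming overconfident on real samples.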
Next we define the optimization step. Because the networks use batch normalization, the moving-average update ops collected in tf.GraphKeys.UPDATE_OPS have to run together with the training ops, which is why the optimizers are created under tf.control_dependencies. See the tf.layers.batch_normalization documentation if this is unfamiliar.
def model_opt(d_loss, g_loss, learning_rate, beta1):
    """
    Get optimization operations
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param learning_rate: Learning Rate Placeholder
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Split the trainable variables by the variable scope they were created in
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
    g_vars = [var for var in t_vars if var.name.startswith('generator')]

    # Batch norm update ops must run alongside the training ops
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt
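The `beta1` parameter is the decay rate of Adam's first-moment estimate, an exponential moving average of the gradients. The sketch below isolates that average, deliberately leaving out Adam's bias correction and second moment, to show how a smaller `beta1` reacts faster when the gradient changes sign. The gradient sequence is made up and `first_moment_trace` is a hypothetical helper.

```python
def first_moment_trace(gradients, beta1):
    """Exponential moving average of gradients (Adam's first moment, uncorrected)."""
    m = 0.0
    trace = []
    for g in gradients:
        m = beta1 * m + (1.0 - beta1) * g
        trace.append(m)
    return trace


# Hypothetical gradient sequence that flips sign halfway through.
grads = [1.0] * 5 + [-1.0] * 5

fast = first_moment_trace(grads, beta1=0.5)
slow = first_moment_trace(grads, beta1=0.9)

# Index 5 is the first step after the sign flip.
print(round(fast[5], 4), round(slow[5], 4))
```

With `beta1=0.5` the average follows the new gradient direction on the first step after the flip, while with `beta1=0.9` it still points the old way. That responsiveness is one plausible reason GAN code often uses values around 0.5, as the hyperparameters in this tutorial do.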
The network modules, the loss functions, and the optimization step are now all defined, so we can start training. The training procedure is defined as follows.
import numpy as np

# NOTE: show_generator_output is not defined in this listing; it is assumed
# to come from the project notebook and to display generator samples.

def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape, data_image_mode):
    """
    Train the GAN
    :param epoch_count: Number of epochs
    :param batch_size: Batch Size
    :param z_dim: Z dimension
    :param learning_rate: Learning Rate
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches
    :param data_shape: Shape of the data
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    """
    losses = []
    samples = []

    input_real, input_z, lr = model_inputs(data_shape[1], data_shape[2], data_shape[3], z_dim)

    d_loss, g_loss = model_loss(input_real, input_z, data_shape[-1])

    # Pass the placeholder so that the learning rate fed at run time is the one used
    d_opt, g_opt = model_opt(d_loss, g_loss, lr, beta1)

    steps = 0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epoch_count):
            for batch_images in get_batches(batch_size):
                steps += 1

                # Reshape the images before passing them to the discriminator
                batch_images = batch_images.reshape(batch_size,
                                                    data_shape[1],
                                                    data_shape[2],
                                                    data_shape[3])
                # Rescale the data to the range [-1, 1]
                # (get_batches is assumed to yield values in [-0.5, 0.5])
                batch_images = batch_images * 2

                # Sample the noise
                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                # Run the optimizers: discriminator first, then generator
                _ = sess.run(d_opt, feed_dict={input_real: batch_images,
                                               input_z: batch_z,
                                               lr: learning_rate})
                _ = sess.run(g_opt, feed_dict={input_real: batch_images,
                                               input_z: batch_z,
                                               lr: learning_rate})

                if steps % 10 == 0:
                    train_loss_d = d_loss.eval({input_real: batch_images, input_z: batch_z})
                    train_loss_g = g_loss.eval({input_real: batch_images, input_z: batch_z})

                    losses.append((train_loss_d, train_loss_g))

                    print("Epoch {}/{}...".format(epoch_i + 1, epoch_count),
                          "Discriminator Loss: {:.4f}...".format(train_loss_d),
                          "Generator Loss: {:.4f}".format(train_loss_g))

                if steps % 100 == 0:
                    show_generator_output(sess, 25, input_z, data_shape[-1], data_image_mode)
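Two small details in the loop are easy to miss: real images are multiplied by 2 so that they share the [-1, 1] range of the generator's tanh output, and the noise is drawn uniformly from [-1, 1]. The sketch below illustrates both in plain Python. It assumes, without confirmation from the listing, that the batch helper maps 0..255 pixel values to [-0.5, 0.5] before the loop sees them; `to_tanh_range` and `sample_noise` are hypothetical names.

```python
import random


def to_tanh_range(pixel_value):
    """Map a 0..255 pixel value to [-1, 1], the output range of tanh."""
    centered = pixel_value / 255 - 0.5  # assumed helper output: [-0.5, 0.5]
    return centered * 2                 # the "* 2" step in the training loop


def sample_noise(batch_size, z_dim, rng):
    """Draw a batch of noise vectors uniformly from [-1, 1]."""
    return [[rng.uniform(-1, 1) for _ in range(z_dim)] for _ in range(batch_size)]


rng = random.Random(0)
batch_z = sample_noise(batch_size=4, z_dim=100, rng=rng)

print(to_tanh_range(0), to_tanh_range(255), len(batch_z), len(batch_z[0]))
```

If the real images and the generated images did not share a value range, the discriminator could tell them apart from the range alone.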
Starting the training: hyperparameter settings
For MNIST
import helper  # project helper module that provides Dataset

batch_size = 64
z_dim = 100
learning_rate = 0.001
beta1 = 0.5
epochs = 2

mnist_dataset = helper.Dataset('mnist', glob(os.path.join(data_dir, 'mnist/*.jpg')))
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, mnist_dataset.get_batches,
          mnist_dataset.shape, mnist_dataset.image_mode)
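The training results are shown below.
At the start the network's parameters are poor, so the generated handwritten digits naturally look bad.
As training proceeds the outlines gradually become clearer, as shown below. By the end:
The outlines of the digits are basically clear and recognizable. This is only the result of two epochs; with more training time the results should improve further.
We also show the training results on the CelebA face dataset.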
batch_size = 32
z_dim = 100
learning_rate = 0.001
beta1 = 0.4
epochs = 1

celeba_dataset = helper.Dataset('celeba', glob(os.path.join(data_dir, 'img_align_celeba/*.jpg')))
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, celeba_dataset.get_batches,
          celeba_dataset.shape, celeba_dataset.image_mode)
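At the start of training:
After one epoch:
The outlines of the faces are basically clear.
Here we implemented a DCGAN in the simplest possible way. The principles are not explained in much detail, the hyperparameters may not be well chosen, and the training is not thorough enough, but I hope this helps you quickly get a handle on implementing a simple DCGAN.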
Unless otherwise noted, articles on this site are original. If you repost this article, please credit the source: 利用tensorflow训练简单的生成对抗网络GAN - Python技术站