DCGAN网络的结构:

tensorflow 2.0 学习 (十六)生成对抗网络 GAN网络与WGAN网络

 

 代码包括:

数据:

  1 import tensorflow as tf
  2 import multiprocessing
  3 
  4 
  5 def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
  6     @tf.function
  7     def _map_fn(img):
  8         img = tf.image.resize(img, [resize, resize])
  9         img = tf.clip_by_value(img, 0, 255)
 10         img = img / 127.5 - 1
 11 
 12         return img
 13 
 14     dataset = disk_image_batch_dataset(img_paths, batch_size, drop_remainder=drop_remainder,
 15                                        map_fn=_map_fn, shuffle=shuffle, repeat=repeat)
 16     img_shape = (resize, resize, 3)
 17     len_dataset = len(img_paths) // batch_size
 18 
 19     return dataset, img_shape, len_dataset
 20 
 21 
 22 def batch_dataset(dataset,
 23                   batch_size,
 24                   drop_remainder=True,
 25                   n_prefetch_batch=1,
 26                   filter_fn=None,
 27                   map_fn=None,
 28                   n_map_threads=None,
 29                   filter_after_map=False,
 30                   shuffle=True,
 31                   shuffle_buffer_size=None,
 32                   repeat=None):
 33     # set defaults
 34     if n_map_threads is None:
 35         n_map_threads = multiprocessing.cpu_count()
 36 
 37     if shuffle and shuffle_buffer_size is None:
 38         shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048
 39 
 40     # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
 41     if shuffle:
 42         dataset = dataset.shuffle(shuffle_buffer_size)
 43 
 44     if not filter_after_map:
 45         if filter_fn:
 46             dataset = dataset.filter(filter_fn)
 47 
 48         if map_fn:
 49             dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
 50 
 51     else:  # [*] this is slower
 52         if map_fn:
 53             dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
 54 
 55         if filter_fn:
 56             dataset = dataset.filter(filter_fn)
 57 
 58     dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
 59     dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
 60 
 61     return dataset
 62 
 63 
 64 def memory_data_batch_dataset(memory_data,
 65                               batch_size,
 66                               drop_remainder=True,
 67                               n_prefetch_batch=1,
 68                               filter_fn=None,
 69                               map_fn=None,
 70                               n_map_threads=None,
 71                               filter_after_map=False,
 72                               shuffle=True,
 73                               shuffle_buffer_size=None,
 74                               repeat=None):
 75     """Batch dataset of memory data.
 76     Parameters
 77     ----------
 78     memory_data : nested structure of tensors/ndarrays/lists
 79     """
 80 
 81     dataset = tf.data.Dataset.from_tensor_slices(memory_data)
 82     dataset = batch_dataset(dataset, batch_size,
 83                             drop_remainder=drop_remainder,
 84                             n_prefetch_batch=n_prefetch_batch,
 85                             filter_fn=filter_fn,
 86                             map_fn=map_fn,
 87                             n_map_threads=n_map_threads,
 88                             filter_after_map=filter_after_map,
 89                             shuffle=shuffle,
 90                             shuffle_buffer_size=shuffle_buffer_size,
 91                             repeat=repeat)
 92 
 93     return dataset
 94 
 95 
 96 def disk_image_batch_dataset(img_paths,
 97                              batch_size,
 98                              labels=None,
 99                              drop_remainder=True,
100                              n_prefetch_batch=1,
101                              filter_fn=None,
102                              map_fn=None,
103                              n_map_threads=None,
104                              filter_after_map=False,
105                              shuffle=True,
106                              shuffle_buffer_size=None,
107                              repeat=None):
108     """Batch dataset of disk image for PNG and JPEG.
109     Parameters
110     ----------
111         img_paths : 1d-tensor/ndarray/list of str
112         labels : nested structure of tensors/ndarrays/lists
113     """
114 
115     if labels is None:
116         memory_data = img_paths
117 
118     else:
119         memory_data = (img_paths, labels)
120 
121     def parse_fn(path, *label):
122         img = tf.io.read_file(path)
123         img = tf.image.decode_png(img, 3)  # fix channels to 3
124         return (img,) + label
125 
126     if map_fn:  # fuse `map_fn` and `parse_fn`
127         def map_fn_(*args):
128             return map_fn(*parse_fn(*args))
129     else:
130         map_fn_ = parse_fn
131 
132     dataset = memory_data_batch_dataset(memory_data,
133                                         batch_size,
134                                         drop_remainder=drop_remainder,
135                                         n_prefetch_batch=n_prefetch_batch,
136                                         filter_fn=filter_fn,
137                                         map_fn=map_fn_,
138                                         n_map_threads=n_map_threads,
139                                         filter_after_map=filter_after_map,
140                                         shuffle=shuffle,
141                                         shuffle_buffer_size=shuffle_buffer_size,
142                                         repeat=repeat)
143 
144     return dataset

GAN:

 1 import tensorflow as tf
 2 from tensorflow.keras import layers, Model
 3 
 4 
 5 class Generator(Model):
 6     # 生成器网络类
 7     def __init__(self):
 8         super(Generator, self).__init__()
 9         filter = 64
10         # 转置卷积层1,输出channel 为filter*8,核大小4,步长1,不使用padding,不使用偏置
11         self.conv1 = layers.Conv2DTranspose(filter*8, 4,1, 'valid', use_bias=False)
12         self.bn1 = layers.BatchNormalization()
13         # 转置卷积层2
14         self.conv2 = layers.Conv2DTranspose(filter * 4, 4, 2, 'same', use_bias=False)
15         self.bn2 = layers.BatchNormalization()
16         # 转置卷积层3
17         self.conv3 = layers.Conv2DTranspose(filter * 2, 4, 2, 'same', use_bias=False)
18         self.bn3 = layers.BatchNormalization()
19         # 转置卷积层4
20         self.conv4 = layers.Conv2DTranspose(filter * 1, 4, 2, 'same', use_bias=False)
21         self.bn4 = layers.BatchNormalization()
22         # 转置卷积层5
23         self.conv5 = layers.Conv2DTranspose(3, 4, 2, 'same', use_bias=False)
24 
25     def call(self, inputs, training=None):
26         x = inputs  # [z, 100]
27         # Reshape 乘4D 张量,方便后续转置卷积运算:(b, 1, 1, 100)
28         x = tf.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
29         x = tf.nn.relu(x)  # 激活函数
30         # 转置卷积-BN-激活函数:(b, 4, 4, 512)
31         x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
32         # 转置卷积-BN-激活函数:(b, 8, 8, 256)
33         x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
34         # 转置卷积-BN-激活函数:(b, 16, 16, 128)
35         x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
36         # 转置卷积-BN-激活函数:(b, 32, 32, 64)
37         x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
38         # 转置卷积-激活函数:(b, 64, 64, 3)
39         x = self.conv5(x)
40         x = tf.tanh(x)  # 输出x 范围-1~1,与预处理一致
41 
42         return x
43 
44 
class Discriminator(Model):
    """DCGAN discriminator: images [b, 64, 64, 3] -> one real/fake logit per image [b, 1]."""

    def __init__(self):
        super(Discriminator, self).__init__()
        width = 64
        # Five conv + BN stages ('valid' padding, no bias). For a 64x64 input the
        # spatial size goes 64 -> 31 -> 14 -> 6 -> 4 -> 2.
        self.conv1 = layers.Conv2D(width, 4, 2, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(width * 2, 4, 2, 'valid', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(width * 4, 4, 2, 'valid', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.conv4 = layers.Conv2D(width * 8, 3, 1, 'valid', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.conv5 = layers.Conv2D(width * 16, 3, 1, 'valid', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        # [b, 2, 2, 1024] -> [b, 1024]
        self.pool = layers.GlobalAveragePooling2D()
        # Already 2-D after pooling; kept so the layer structure is unchanged.
        self.flatten = layers.Flatten()
        # [b, 1024] -> [b, 1] logit
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        features = inputs
        for conv, norm in stages:
            # conv -> BN -> LeakyReLU
            features = tf.nn.leaky_relu(norm(conv(features), training=training))
        pooled = self.flatten(self.pool(features))
        return self.fc(pooled)

训练:

  1 import os
  2 import glob
  3 import numpy as np
  4 
  5 import tensorflow as tf
  6 from tensorflow import keras
  7 
  8 from GAN import Generator, Discriminator
  9 from Dataset import make_anime_dataset
 10 
 11 from PIL import Image
 12 import scipy.misc
 13 import matplotlib.pyplot as plt
 14 
 15 
 16 def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
 17     # 计算判别器的误差函数
 18     # 采样生成图片
 19     fake_image = generator(batch_z, is_training)
 20     # 判定生成图片
 21     d_fake_logits = discriminator(fake_image, is_training)
 22     # 判定真实图片
 23     d_real_logits = discriminator(batch_x, is_training)
 24     # 真实图片与1 之间的误差
 25     d_loss_real = celoss_ones(d_real_logits)
 26     # 生成图片与0 之间的误差
 27     d_loss_fake = celoss_zeros(d_fake_logits)
 28     # 合并误差
 29     loss = d_loss_fake + d_loss_real
 30 
 31     return loss
 32 
 33 
 34 def celoss_ones(logits):
 35     # 计算属于与标签为1 的交叉熵
 36     y = tf.ones_like(logits)
 37     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
 38 
 39     return tf.reduce_mean(loss)
 40 
 41 
 42 def celoss_zeros(logits):
 43     # 计算属于与便签为0 的交叉熵
 44     y = tf.zeros_like(logits)
 45     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
 46 
 47     return tf.reduce_mean(loss)
 48 
 49 
 50 def g_loss_fn(generator, discriminator, batch_z, is_training):
 51     # 采样生成图片
 52     fake_image = generator(batch_z, is_training)
 53     # 在训练生成网络时,需要迫使生成图片判定为真
 54     d_fake_logits = discriminator(fake_image, is_training)
 55     # 计算生成图片与1 之间的误差
 56     loss = celoss_ones(d_fake_logits)
 57 
 58     return loss
 59 
 60 
 61 def save_result(val_out, val_block_size, image_path, color_mode):
 62     def preprocess(img):
 63         img = ((img + 1.0) * 127.5).astype(np.uint8)
 64         # img = img.astype(np.uint8)
 65         return img
 66 
 67     preprocesed = preprocess(val_out)
 68     final_image = np.array([])
 69     single_row = np.array([])
 70 
 71     for b in range(val_out.shape[0]):
 72         # concat image into a row
 73         if single_row.size == 0:
 74             single_row = preprocesed[b, :, :, :]
 75         else:
 76             single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
 77 
 78         # concat image row to final_image
 79         if (b + 1) % val_block_size == 0:
 80             if final_image.size == 0:
 81                 final_image = single_row
 82             else:
 83                 final_image = np.concatenate((final_image, single_row), axis=0)
 84 
 85             # reset single row
 86             single_row = np.array([])
 87 
 88     if final_image.shape[2] == 1:
 89         final_image = np.squeeze(final_image, axis=2)
 90     im = Image.fromarray(final_image)
 91     im.save('exam11_final_image.png')
 92     # Image.save(final_image)
 93     # Image(final_image).save(image_path)
 94 
 95 
 96 d_losses, g_losses = [], []
 97 
 98 
 99 def draw():
100     plt.figure()
101     plt.plot(d_losses, 'b', label='generator')
102     plt.plot(g_losses, 'r', label='discriminator')
103     plt.xlabel('Epoch')
104     plt.ylabel('ACC')
105     plt.legend()
106     plt.savefig('exam11.1_train_test_VAE.png')
107     plt.show()
108 
109 
def main():
    """Train the DCGAN on face images and periodically save sample grids and weights.

    Each iteration of the outer loop ("epoch" below) is one training step:
    5 discriminator updates followed by 1 generator update.
    """
    batch_size = 64
    learning_rate = 0.0002
    z_dim = 100
    is_training = True
    epochs = 300  # number of training steps, not passes over the dataset

    # NOTE(review): the path separators appear to have been lost when this code was
    # pasted; point the pattern at the real image folder.
    img_paths = glob.glob(r'G:2020pythonfacesfaces*.jpg')
    print('images num:', len(img_paths))
    if not img_paths:
        raise FileNotFoundError('no training images matched the glob pattern')
    # Build the dataset; also returns the image shape.
    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size, resize=64)  # (64, 64, 64, 3) (64, 64, 3)
    sample = next(iter(dataset))  # (64, 64, 64, 3)
    print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())  # (64, 64, 64, 3) 1.0 -1.0
    dataset = dataset.repeat(100)  # repeat so the iterator does not run out
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(4, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(4, 64, 64, 3))
    # Separate Adam optimizers for the two networks.
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # To resume training, load 'exam11.1_generator.ckpt' / 'exam11.1_discriminator.ckpt' here.

    for epoch in range(epochs):
        # 1. Train the discriminator 5 times, each on a fresh real batch.
        for _ in range(5):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            # Gradients are computed and applied after leaving the tape context.
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # 2. Train the generator once; it needs only latent vectors, not a real batch.
        batch_z = tf.random.normal([batch_size, z_dim])
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            out_path = os.path.join('gan_images', 'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, out_path, color_mode='P')

            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))

        # Not nested under the logging branch: `epoch % 100 == 0` and
        # `epoch % 10000 == 1` can never both hold, so weights were never saved.
        if epoch % 10000 == 1:
            generator.save_weights('exam11.1_generator.ckpt')
            discriminator.save_weights('exam11.1_discriminator.ckpt')

    # Always keep the final weights.
    generator.save_weights('exam11.1_generator.ckpt')
    discriminator.save_weights('exam11.1_discriminator.ckpt')


if __name__ == '__main__':
    main()
    draw()

运行后没有得到结果,代码也没有报错;个人认为主要是受机器性能的限制。

WGAN-GP:

 1 import tensorflow as tf
 2 from tensorflow.keras import layers, Model
 3 
 4 
 5 class Generator(Model):
 6     def __init__(self):
 7         super(Generator, self).__init__()
 8         # z: [b, 100] => [b, 3*3*512] => [b, 3, 3, 512] => [b, 64, 64, 3]
 9         self.fc = layers.Dense(3*3*512)
10         self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
11         self.bn1 = layers.BatchNormalization()
12 
13         self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
14         self.bn2 = layers.BatchNormalization()
15         self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')
16 
17     def call(self, inputs, training=None):
18         # [z, 100] => [z, 3*3*512]
19         x = self.fc(inputs)
20         x = tf.reshape(x, [-1, 3, 3, 512])
21         x = tf.nn.leaky_relu(x)
22 
23         #
24         x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
25         x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
26         x = self.conv3(x)
27         x = tf.tanh(x)
28 
29         return x
30 
31 
class Discriminator(Model):
    """WGAN-GP critic: images [b, 64, 64, 3] -> unbounded score [b, 1]."""

    def __init__(self):
        super(Discriminator, self).__init__()

        # 'valid' convs with stride 3 shrink the spatial size 64 -> 20 -> 6 -> 1.
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        # NOTE(review): the WGAN-GP paper advises against batch norm in the critic
        # (it suggests layer norm instead); kept as-is here.
        self.bn3 = layers.BatchNormalization()

        # [b, h, w, c] -> [b, h*w*c] -> [b, 1]
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        # First conv has no BN; the next two are conv -> BN -> LeakyReLU.
        out = tf.nn.leaky_relu(self.conv1(inputs))
        for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            out = tf.nn.leaky_relu(norm(conv(out), training=training))
        return self.fc(self.flatten(out))
60 
61 
def main():
    """Smoke-test both networks on random inputs and print their outputs."""
    critic = Discriminator()
    gen = Generator()

    images = tf.random.normal([2, 64, 64, 3])
    latents = tf.random.normal([2, 100])

    print(critic(images))  # critic scores, one per image
    print(gen(latents).shape)  # generated image batch shape


if __name__ == '__main__':
    main()

训练代码:

  1 import os
  2 import glob
  3 import numpy as np
  4 
  5 import tensorflow as tf
  6 from tensorflow import keras
  7 
  8 from WGAN import Generator, Discriminator
  9 from Dataset import make_anime_dataset
 10 
 11 from PIL import Image
 12 import matplotlib.pyplot as plt
 13 
 14 
 15 def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
 16     # 计算D 的损失函数
 17     fake_image = generator(batch_z, is_training) # 假样本
 18     d_fake_logits = discriminator(fake_image, is_training) # 假样本的输出
 19     d_real_logits = discriminator(batch_x, is_training) # 真样本的输出
 20     # 计算梯度惩罚项
 21     gp = gradient_penalty(discriminator, batch_x, fake_image)
 22     # WGAN-GP D 损失函数的定义,这里并不是计算交叉熵,而是直接最大化正样本的输出
 23     # 最小化假样本的输出和梯度惩罚项
 24     loss = tf.reduce_mean(d_fake_logits) - tf.reduce_mean(d_real_logits) + 10. * gp
 25 
 26     return loss, gp
 27 
 28 
 29 def celoss_ones(logits):
 30     # 计算属于与标签为1 的交叉熵
 31     y = tf.ones_like(logits)
 32     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
 33 
 34     return tf.reduce_mean(loss)
 35 
 36 
 37 def celoss_zeros(logits):
 38     # 计算属于与便签为0 的交叉熵
 39     y = tf.zeros_like(logits)
 40     loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
 41 
 42     return tf.reduce_mean(loss)
 43 
 44 
 45 def gradient_penalty(discriminator, batch_x, fake_image):
 46     # 梯度惩罚项计算函数
 47     batchsz = batch_x.shape[0]
 48 
 49     # 每个样本均随机采样t,用于插值
 50     t = tf.random.uniform([batchsz, 1, 1, 1])
 51     # 自动扩展为x 的形状,[b, 1, 1, 1] => [b, h, w, c]
 52     t = tf.broadcast_to(t, batch_x.shape)
 53 
 54     # 在真假图片之间做线性插值
 55     interplate = t * batch_x + (1 - t) * fake_image
 56     # 在梯度环境中计算D 对插值样本的梯度
 57     with tf.GradientTape() as tape:
 58         tape.watch([interplate])  # 加入梯度观察列表
 59         d_interplote_logits = discriminator(interplate)
 60     grads = tape.gradient(d_interplote_logits, interplate)
 61 
 62     # 计算每个样本的梯度的范数:[b, h, w, c] => [b, -1]
 63     grads = tf.reshape(grads, [grads.shape[0], -1])
 64     gp = tf.norm(grads, axis=1)  # [b]
 65     # 计算梯度惩罚项
 66     gp = tf.reduce_mean((gp - 1.) ** 2)
 67 
 68     return gp
 69 
 70 
 71 def g_loss_fn(generator, discriminator, batch_z, is_training):
 72     # 生成器的损失函数
 73     fake_image = generator(batch_z, is_training)
 74     d_fake_logits = discriminator(fake_image, is_training)
 75     # WGAN-GP G 损失函数,最大化假样本的输出值
 76     loss = - tf.reduce_mean(d_fake_logits)
 77 
 78     return loss
 79 
 80 
 81 def save_result(val_out, val_block_size, image_path, color_mode):
 82     def preprocess(img):
 83         img = ((img + 1.0) * 127.5).astype(np.uint8)
 84         # img = img.astype(np.uint8)
 85         return img
 86 
 87     preprocesed = preprocess(val_out)
 88     final_image = np.array([])
 89     single_row = np.array([])
 90 
 91     for b in range(val_out.shape[0]):
 92         # concat image into a row
 93         if single_row.size == 0:
 94             single_row = preprocesed[b, :, :, :]
 95         else:
 96             single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
 97 
 98         # concat image row to final_image
 99         if (b + 1) % val_block_size == 0:
100             if final_image.size == 0:
101                 final_image = single_row
102             else:
103                 final_image = np.concatenate((final_image, single_row), axis=0)
104 
105             # reset single row
106             single_row = np.array([])
107 
108     if final_image.shape[2] == 1:
109         final_image = np.squeeze(final_image, axis=2)
110     im = Image.fromarray(final_image)
111     im.save('exam11_WGAN_final_image.png')
112     # Image.save(final_image)
113     # Image(final_image).save(image_path)
114 
115 
# Loss histories, intended to be filled during training and plotted by draw().
d_losses, g_losses = [], []


def draw():
    """Plot the recorded critic and generator losses, save the figure, and show it."""
    plt.figure()
    # The labels were previously swapped (d_losses was labelled 'generator').
    plt.plot(d_losses, 'b', label='discriminator')
    plt.plot(g_losses, 'r', label='generator')
    # The plotted values are losses, not accuracy.
    plt.xlabel('Step (x100)')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('exam11.2_train_test_VAE.png')
    plt.show()
128 
129 
def main():
    """Train the WGAN-GP on face images and periodically save sample grids and weights.

    Each iteration of the outer loop ("epoch" below) is one training step:
    one critic update followed by one generator update, sharing the same latent batch.
    """
    batch_size = 512
    # NOTE(review): the WGAN-GP paper uses lr=1e-4, Adam(beta_1=0, beta_2=0.9) and
    # 5 critic steps per generator step; the values here differ.
    learning_rate = 0.002
    z_dim = 100
    is_training = True
    epochs = 300  # number of training steps, not passes over the dataset

    # NOTE(review): the path separators appear to have been lost when this code was
    # pasted; point the pattern at the real image folder.
    img_paths = glob.glob(r'G:2020pythonfacesfaces*.jpg')
    print('images num:', len(img_paths))  # images num: 51223
    if not img_paths:
        raise FileNotFoundError('no training images matched the glob pattern')
    # Build the dataset; also returns the image shape.
    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size, resize=64)  # (512, 64, 64, 3) (64, 64, 3)
    sample = next(iter(dataset))  # (512, 64, 64, 3)
    print(sample.shape, tf.reduce_max(sample).numpy(), tf.reduce_min(sample).numpy())  # (512, 64, 64, 3) 1.0 -1.0
    dataset = dataset.repeat(100)  # repeat so the iterator does not run out
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))
    # Separate Adam optimizers for the two networks.
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # To resume training, load 'exam11.2_generator.ckpt' / 'exam11.2_discriminator.ckpt' here.

    for epoch in range(epochs):
        # Latent vectors are drawn from U(-1, 1).
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # Critic update
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # Generator update
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss), 'gp:', float(gp))
            # Sample from the same U(-1, 1) range used in training (was U(0, 1)).
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)

            fake_image = generator(z, training=False)
            out_path = os.path.join('images', 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, out_path, color_mode='P')

            # Record losses so draw() has something to plot.
            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))

        if epoch % 10000 == 1:
            generator.save_weights('exam11.2_generator.ckpt')
            discriminator.save_weights('exam11.2_discriminator.ckpt')

    # Always keep the final weights.
    generator.save_weights('exam11.2_generator.ckpt')
    discriminator.save_weights('exam11.2_discriminator.ckpt')


if __name__ == '__main__':
    main()
    draw()

同样没有得到结果,以后有条件再试一试。

这一部分对算法理解的要求较高,要看懂它需要花时间仔细研读;

我没有深入研究其原理,只是阅读了代码。