CycleGAN 对颜色、纹理等外观层面的转换效果较好，而对多样性高、变化大的转换（如几何变换）效果不好。

代码

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10
GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10
GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10
GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

# Notebook setup cell (Jupyter/Kaggle): `%matplotlib inline` is IPython magic,
# not plain Python, so this file only runs as a notebook.
import tensorflow as tf
import glob
from matplotlib import pyplot as plt
%matplotlib inline
# Lets tf.data choose the parallelism of the dataset map() calls below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
import os
# Display the dataset folders (trainA / trainB / testA / testB are read below).
os.listdir('../input/apple2orange/apple2orange')

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

# Domain-A training image paths (apples, per the dataset layout). glob's
# ordering depends on the file system; the dataset is shuffled later anyway.
imgs_A = glob.glob('../input/apple2orange/apple2orange/trainA/*.jpg')

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

# Domain-B training image paths (oranges). Not paired with imgs_A.
imgs_B = glob.glob('../input/apple2orange/apple2orange/trainB/*.jpg')

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

# Held-out test image paths for each domain; used for visual previews.
test_A = glob.glob('../input/apple2orange/apple2orange/testA/*.jpg')
test_B = glob.glob('../input/apple2orange/apple2orange/testB/*.jpg')

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10

def read_jpg(path):
    """Read the file at ``path`` and decode it as a 3-channel JPEG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)
def normalize(input_image):
    """Cast to float32 and rescale pixel values from [0, 255] to [-1, 1]."""
    as_float = tf.cast(input_image, tf.float32)
    return as_float / 127.5 - 1
def load_image(image_path):
    """Load one JPEG, resize it to 256x256 and scale it to [-1, 1]."""
    return normalize(tf.image.resize(read_jpg(image_path), (256, 256)))
BUFFER_SIZE = 200


def _make_pipeline(paths):
    """Turn a list of image paths into a cached, shuffled dataset of 1-image batches.

    Each path goes through ``load_image``; decoded images are cached after the
    first pass and reshuffled with a buffer of ``BUFFER_SIZE`` on each iteration.
    """
    return (tf.data.Dataset.from_tensor_slices(paths)
            .map(load_image, num_parallel_calls=AUTOTUNE)
            .cache()
            .shuffle(BUFFER_SIZE)
            .batch(1))


train_a = _make_pipeline(imgs_A)
train_b = _make_pipeline(imgs_B)
test_a = _make_pipeline(test_A)
test_b = _make_pipeline(test_B)

# Unpaired data: zip yields (domain-A batch, domain-B batch) tuples and stops
# at the end of the shorter dataset.
data_train = tf.data.Dataset.zip((train_a, train_b))
data_test = tf.data.Dataset.zip((test_a, test_b))
# Show one training image from each domain side by side (left: A, right: B).
# The second image is a domain-B photo, not a mask, so it is named img_b.
plt.figure(figsize=(6, 3))
for img_a, img_b in zip(train_a.take(1), train_b.take(1)):
    plt.subplot(1, 2, 1)
    # array_to_img rescales the [-1, 1] floats into a displayable PIL image.
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img_a[0]))
    plt.subplot(1, 2, 2)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img_b[0]))

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10
实例归一化

# Notebook shell command: install TensorFlow Addons, which provides the
# InstanceNormalization layer used by the model blocks below.
# NOTE(review): TensorFlow Addons is no longer actively developed; confirm it
# still installs for the TensorFlow version in use.
!pip install tensorflow_addons
import tensorflow_addons as tfa
# Channels in the generator output (3, matching the decoded RGB JPEG inputs).
OUTPUT_CHANNELS = 3
def downsample(filters, size, apply_batchnorm=True):
    """Build a stride-2 convolutional block that halves the spatial size.

    Layers, in order: Conv2D (stride 2, 'same' padding, no bias), an optional
    InstanceNormalization, then LeakyReLU. No kernel_initializer is passed,
    so Keras' default initializer is used.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        apply_batchnorm: whether to insert the normalization layer. Despite the
            name (kept for backward compatibility), the layer inserted is
            ``tfa.layers.InstanceNormalization``, not batch normalization.

    Returns:
        A ``tf.keras.Sequential`` block.
    """
    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               use_bias=False))

    if apply_batchnorm:
        result.add(tfa.layers.InstanceNormalization())

    # The activation is applied whether or not a normalization layer is present;
    # otherwise un-normalized blocks (the first layer of both the generator and
    # the discriminator) would have no nonlinearity at all.
    result.add(tf.keras.layers.LeakyReLU())

    return result
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block that doubles the spatial size.

    Layers, in order: Conv2DTranspose (stride 2, 'same' padding, no bias),
    InstanceNormalization, an optional Dropout(0.5), then ReLU.

    Args:
        filters: number of output channels.
        size: kernel size of the transposed convolution.
        apply_dropout: whether to insert Dropout(0.5) after normalization.

    Returns:
        A ``tf.keras.Sequential`` block.
    """
    stages = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        use_bias=False),
        tfa.layers.InstanceNormalization(),
    ]
    if apply_dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    stages.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stages)
def Generator():
    """Build a U-Net generator mapping a 256x256x3 image to a 256x256x3 image.

    Eight ``downsample`` blocks reduce the input to 1x1x512. Seven ``upsample``
    blocks bring it back to 128x128; after each one the output is concatenated
    with the encoder activation of the same resolution (skip connection).
    A final stride-2 Conv2DTranspose with tanh produces the 256x256 output in
    (-1, 1), the same range that ``normalize`` maps inputs to.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = tf.keras.layers.Input(shape=[256,256,3])

    # Encoder; shape comments are each block's output.
    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    # Decoder; shape comments include the concatenated skip channels.
    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

#    initializer = tf.random_normal_initializer(0., 0.02)
    # Output layer; no kernel_initializer is passed, so Keras' default is used.
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Drop the 1x1 bottleneck (it is x itself) and pair the rest deepest-first
    # with the upsampling blocks.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)
# Two independent generators: generator_x maps domain A -> B (apple -> orange),
# generator_y maps B -> A.
generator_x = Generator()   # a——>o
generator_y = Generator()   # o——>a
#tf.keras.utils.plot_model(generator_x, show_shapes=True, dpi=64)
def Discriminator():
    """Build a patch discriminator for 256x256x3 images.

    Three ``downsample`` blocks are followed by a padded stride-1 Conv2D,
    InstanceNormalization and LeakyReLU, then a final padded Conv2D with one
    filter and no activation. The output is a (bs, 30, 30, 1) map of raw
    logits, one score per image patch, which is why ``loss_object`` uses
    ``from_logits=True``.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
#    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

    # The first block skips normalization (third argument False).
    down1 = downsample(64, 4, False)(inp) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
    conv = tf.keras.layers.Conv2D(
               512, 4, strides=1,use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization()(conv)

    leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

    # No activation: the model outputs logits.
    last = tf.keras.layers.Conv2D(
               1, 4, strides=1)(zero_pad2)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)
# discriminator_x scores domain-A images (real A vs. generator_y output);
# discriminator_y scores domain-B images (real B vs. generator_x output).
discriminator_x = Discriminator()   # discriminator  a
discriminator_y = Discriminator()   # discriminator  o
#tf.keras.utils.plot_model(discriminator_x, show_shapes=True, dpi=64)
# Shared binary cross-entropy. from_logits=True because the discriminator's
# final Conv2D has no activation, so its outputs are raw logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator's real-vs-fake cross-entropy loss.

    Real outputs are pushed toward label 1 and generated outputs toward
    label 0; the two terms are summed without weighting.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = loss_object(tf.zeros_like(disc_generated_output),
                               disc_generated_output)
    return loss_on_real + loss_on_fake
def generator_loss(disc_generated_output):
    """Return the adversarial loss: generated outputs scored against label 1."""
    target_labels = tf.ones_like(disc_generated_output)
    return loss_object(target_labels, disc_generated_output)
# Weight of the cycle-consistency term relative to the adversarial loss.
LAMBDA = 7


def calc_cycle_loss(real_image, cycled_image):
    """Return ``LAMBDA`` times the mean absolute (L1) reconstruction error."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_error
# One Adam optimizer per network (learning rate 2e-4, beta_1 = 0.5), so each
# network keeps its own optimizer state.
generator_x_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
generator_y_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
def generate_images(model, test_input):
    """Plot the first image of ``test_input`` next to ``model``'s output for it.

    ``model`` is called with ``training=True``, so its Dropout layers stay
    active during the preview. Images are assumed to lie in [-1, 1] and are
    mapped to [0, 1] for display.
    """
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    panels = (('Input Image', test_input[0]), ('Predicted Image', prediction[0]))
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        # Map [-1, 1] back to [0, 1] for imshow.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()
@tf.function
def train_step(image_a, image_b):
    """Run one CycleGAN optimization step on an unpaired (A, B) batch.

    Roles, as used below:
      * generator_x maps A -> B and generator_y maps B -> A.
      * discriminator_x scores domain-A images; discriminator_y scores domain-B.

    Each generator is trained with its adversarial loss plus the shared
    cycle-consistency loss (no identity loss is used). Each discriminator is
    trained with its real-vs-fake loss. All four networks are updated in this
    step.

    Args:
        image_a: batch of domain-A images (presumably scaled to [-1, 1] by
            ``load_image``; TODO confirm for other callers).
        image_b: batch of domain-B images (same assumption).
    """
    # persistent=True because tape.gradient is called four times below.
    with tf.GradientTape(persistent=True) as tape:
        # A -> B -> A round trip.
        fake_b = generator_x(image_a, training=True)
        cycled_a = generator_y(fake_b, training=True)

        # B -> A -> B round trip.
        fake_a = generator_y(image_b, training=True)
        cycled_b = generator_x(fake_a, training=True)
        
        disc_real_a = discriminator_x(image_a, training=True)
        disc_real_b = discriminator_y(image_b, training=True)

        disc_fake_a = discriminator_x(fake_a, training=True)
        disc_fake_b = discriminator_y(fake_b, training=True)
        
        # Each generator is judged by the discriminator of its output domain.
        gen_x_loss = generator_loss(disc_fake_b)
        gen_y_loss = generator_loss(disc_fake_a)
    
        total_cycle_loss = (calc_cycle_loss(image_a, cycled_a) 
                               + calc_cycle_loss(image_b, cycled_b))
    
        # Total generator loss = adversarial loss + cycle-consistency loss.
        total_gen_x_loss = gen_x_loss + total_cycle_loss
        total_gen_y_loss = gen_y_loss + total_cycle_loss

        disc_x_loss = discriminator_loss(disc_real_a, disc_fake_a)
        disc_y_loss = discriminator_loss(disc_real_b, disc_fake_b)
  
    # Compute each loss's gradients w.r.t. only its own network's variables.
    generator_x_gradients = tape.gradient(total_gen_x_loss, 
                                        generator_x.trainable_variables)
    generator_y_gradients = tape.gradient(total_gen_y_loss, 
                                        generator_y.trainable_variables)
  
    discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
    
    # Apply the gradients with each network's optimizer.
    generator_x_optimizer.apply_gradients(zip(generator_x_gradients, 
                                              generator_x.trainable_variables))

    generator_y_optimizer.apply_gradients(zip(generator_y_gradients, 
                                              generator_y.trainable_variables))
  
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                  discriminator_x.trainable_variables))
  
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                  discriminator_y.trainable_variables))
def fit(train_ds, test_ds, epochs):
    """Train both generator/discriminator pairs for exactly ``epochs`` epochs.

    After every 5th epoch, and once more at the end if the last epoch was not
    a multiple of 5, ``generator_x`` (A -> B) translates one batch of
    domain-A images taken from ``test_ds`` and the result is plotted.

    Args:
        train_ds: dataset of ``(image_a, image_b)`` batches fed to ``train_step``.
        test_ds: dataset of ``(image_a, image_b)`` batches used for previews.
        epochs: number of full passes over ``train_ds``.
    """
    def _preview(epoch=None):
        # Draw a fresh test batch each time. An empty test_ds yields nothing,
        # so no preview is drawn instead of failing on an unbound variable.
        for sample_a, _ in test_ds.take(1):
            if epoch is not None:
                print("Epoch: ", epoch)
            generate_images(generator_x, sample_a)

    # Count epochs from 1 so that exactly `epochs` training passes run.
    for epoch in range(1, epochs + 1):
        for img_a, img_b in train_ds:
            train_step(img_a, img_b)
        print('.', end='')

        if epoch % 5 == 0:
            print()
            _preview(epoch)

    # Final preview, unless the last epoch already produced one.
    if epochs % 5 != 0:
        print()
        _preview()
# Train on the unpaired (domain A, domain B) training batches; data_test is
# only used for periodic visual previews.
EPOCHS = 100
fit(data_train, data_test,  EPOCHS)

GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10
GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10
GAN生成对抗网络-CycleGAN原理与基本实现-图像转换-10