Source code for antspynet.architectures.create_cycle_gan_model

import tensorflow as tf

import tensorflow.keras.backend as K

from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import (Input, Concatenate, Dense, Activation,
                                     BatchNormalization, Reshape, Dropout,
                                     Flatten, LeakyReLU, Conv2D, Conv3D,
                                     UpSampling2D, UpSampling3D)
from tensorflow.keras.optimizers import Adam

from ..utilities import InstanceNormalization

import numpy as np
import os

import matplotlib.pyplot as plot

import ants

class CycleGanModel(object):
    """
    Cycle GAN model

    Cycle generative adversarial network from the paper:

      https://arxiv.org/pdf/1703.10593

    and ported from the Keras (python) implementation:

      https://github.com/eriklindernoren/Keras-GAN/blob/master/cyclegan/cyclegan.py

    Arguments
    ---------
    input_image_size : tuple
        Used for specifying the input tensor shape.  The shape (or dimension) of
        that tensor is the image dimensions followed by the number of channels
        (e.g., red, green, and blue).

    lambda_cycle_loss_weight : float
        Weight of the cycle-consistency loss in the combined generator loss.

    lambda_identity_loss_weight : float
        Weight of the identity-mapping loss in the combined generator loss.

    number_of_filters_at_base_layer : tuple
        Number of filters at the base layer for the generator and discriminator,
        respectively.

    Returns
    -------
    Keras model
        A Keras model defining the network.
    """

    def __init__(self, input_image_size,
                 lambda_cycle_loss_weight=10.0,
                 lambda_identity_loss_weight=1.0,
                 number_of_filters_at_base_layer=(32, 64)):
        super(CycleGanModel, self).__init__()

        self.input_image_size = input_image_size
        self.number_of_channels = self.input_image_size[-1]
        self.discriminator_patch_size = None
        self.lambda_cycle_loss_weight = lambda_cycle_loss_weight
        self.lambda_identity_loss_weight = lambda_identity_loss_weight
        self.number_of_filters_at_base_layer = number_of_filters_at_base_layer

        self.dimensionality = None
        if len(self.input_image_size) == 3:
            self.dimensionality = 2
        elif len(self.input_image_size) == 4:
            self.dimensionality = 3
        else:
            raise ValueError("Incorrect size for input_image_size.")

        optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

        # Build and compile the discriminators for domains A and B, then freeze
        # them so that only the generators are updated through the combined model.
        self.discriminatorA = self.build_discriminator()
        self.discriminatorA.compile(loss='mse',
                                    optimizer=optimizer,
                                    metrics=['acc'])
        self.discriminatorA.trainable = False

        self.discriminatorB = self.build_discriminator()
        self.discriminatorB.compile(loss='mse',
                                    optimizer=optimizer,
                                    metrics=['acc'])
        self.discriminatorB.trainable = False

        # Build u-net like generators
        self.generatorAtoB = self.build_generator()
        self.generatorBtoA = self.build_generator()

        imageA = Input(shape=input_image_size)
        imageB = Input(shape=input_image_size)

        fake_imageA = self.generatorBtoA(imageB)
        fake_imageB = self.generatorAtoB(imageA)

        reconstructed_imageA = self.generatorBtoA(fake_imageB)
        reconstructed_imageB = self.generatorAtoB(fake_imageA)

        identity_imageA = self.generatorBtoA(imageA)
        identity_imageB = self.generatorAtoB(imageB)

        # Check images
        validityA = self.discriminatorA(fake_imageA)
        validityB = self.discriminatorB(fake_imageB)

        # Combined model
        self.combined_model = Model(inputs=[imageA, imageB],
                                    outputs=[validityA, validityB,
                                             reconstructed_imageA, reconstructed_imageB,
                                             identity_imageA, identity_imageB])
        self.combined_model.compile(loss=['mse', 'mse',
                                          'mae', 'mae',
                                          'mae', 'mae'],
                                    loss_weights=[1.0, 1.0,
                                                  self.lambda_cycle_loss_weight,
                                                  self.lambda_cycle_loss_weight,
                                                  self.lambda_identity_loss_weight,
                                                  self.lambda_identity_loss_weight],
                                    optimizer=optimizer)

    def build_generator(self):

        def build_encoding_layer(input, number_of_filters, kernel_size=4):
            encoder = input
            if self.dimensionality == 2:
                encoder = Conv2D(filters=number_of_filters,
                                 kernel_size=kernel_size,
                                 strides=2,
                                 padding='same')(encoder)
            else:
                encoder = Conv3D(filters=number_of_filters,
                                 kernel_size=kernel_size,
                                 strides=2,
                                 padding='same')(encoder)
            encoder = LeakyReLU(alpha=0.2)(encoder)
            encoder = InstanceNormalization()(encoder)
            return(encoder)

        def build_decoding_layer(input, skip_input, number_of_filters,
                                 kernel_size=4, dropout_rate=0.0):
            decoder = input
            if self.dimensionality == 2:
                decoder = UpSampling2D(size=2)(decoder)
                decoder = Conv2D(filters=number_of_filters,
                                 kernel_size=kernel_size,
                                 strides=1,
                                 padding='same',
                                 activation='relu')(decoder)
            else:
                decoder = UpSampling3D(size=2)(decoder)
                decoder = Conv3D(filters=number_of_filters,
                                 kernel_size=kernel_size,
                                 strides=1,
                                 padding='same',
                                 activation='relu')(decoder)
            if dropout_rate > 0.0:
                decoder = Dropout(rate=dropout_rate)(decoder)
            decoder = LeakyReLU(alpha=0.2)(decoder)
            # Skip connection from the corresponding encoding layer (u-net style).
            decoder = Concatenate()([decoder, skip_input])
            return(decoder)

        input = Input(shape=self.input_image_size)

        encoding_layers = list()
        encoding_layers.append(build_encoding_layer(input,
            int(self.number_of_filters_at_base_layer[0])))
        encoding_layers.append(build_encoding_layer(encoding_layers[0],
            int(self.number_of_filters_at_base_layer[0] * 2)))
        encoding_layers.append(build_encoding_layer(encoding_layers[1],
            int(self.number_of_filters_at_base_layer[0] * 4)))
        encoding_layers.append(build_encoding_layer(encoding_layers[2],
            int(self.number_of_filters_at_base_layer[0] * 8)))

        decoding_layers = list()
        decoding_layers.append(build_decoding_layer(encoding_layers[3], encoding_layers[2],
            int(self.number_of_filters_at_base_layer[0] * 4)))
        decoding_layers.append(build_decoding_layer(decoding_layers[0], encoding_layers[1],
            int(self.number_of_filters_at_base_layer[0] * 2)))
        decoding_layers.append(build_decoding_layer(decoding_layers[1], encoding_layers[0],
            int(self.number_of_filters_at_base_layer[0])))

        if self.dimensionality == 2:
            decoding_layers.append(UpSampling2D(size=2)(decoding_layers[-1]))
            decoding_layers[-1] = Conv2D(filters=self.number_of_channels,
                                         kernel_size=4,
                                         strides=1,
                                         padding='same',
                                         activation='tanh')(decoding_layers[-1])
        else:
            decoding_layers.append(UpSampling3D(size=2)(decoding_layers[-1]))
            decoding_layers[-1] = Conv3D(filters=self.number_of_channels,
                                         kernel_size=4,
                                         strides=1,
                                         padding='same',
                                         activation='tanh')(decoding_layers[-1])

        generator = Model(inputs=input, outputs=decoding_layers[-1])

        return(generator)

    def build_discriminator(self):

        def build_layer(input, number_of_filters, kernel_size=4, normalization=True):
            layer = input
            if self.dimensionality == 2:
                layer = Conv2D(filters=number_of_filters,
                               kernel_size=kernel_size,
                               strides=2,
                               padding='same')(layer)
            else:
                layer = Conv3D(filters=number_of_filters,
                               kernel_size=kernel_size,
                               strides=2,
                               padding='same')(layer)
            layer = LeakyReLU(alpha=0.2)(layer)
            if normalization:
                layer = InstanceNormalization()(layer)
            return(layer)

        input = Input(shape=self.input_image_size)

        layers = list()
        layers.append(build_layer(input,
            int(self.number_of_filters_at_base_layer[1])))
        layers.append(build_layer(layers[0],
            int(self.number_of_filters_at_base_layer[1] * 2)))
        layers.append(build_layer(layers[1],
            int(self.number_of_filters_at_base_layer[1] * 4)))
        layers.append(build_layer(layers[2],
            int(self.number_of_filters_at_base_layer[1] * 8)))

        validity = None
        if self.dimensionality == 2:
            validity = Conv2D(filters=1,
                              kernel_size=4,
                              strides=1,
                              padding='same')(layers[3])
        else:
            validity = Conv3D(filters=1,
                              kernel_size=4,
                              strides=1,
                              padding='same')(layers[3])

        # PatchGAN discriminator:  the output is a grid of per-patch validity
        # scores whose spatial size is recorded for building training targets.
        if self.discriminator_patch_size is None:
            self.discriminator_patch_size = K.int_shape(validity)[1:]

        discriminator = Model(inputs=input, outputs=validity)

        return(discriminator)

    def train(self, X_trainA, X_trainB, number_of_epochs, batch_size=128,
              sample_interval=None, sample_file_prefix='sample'):

        # PatchGAN targets:  ones for real patches, zeros for fake patches.
        valid = np.ones((batch_size, *self.discriminator_patch_size))
        fake = np.zeros((batch_size, *self.discriminator_patch_size))

        for epoch in range(number_of_epochs):

            indicesA = np.random.randint(0, X_trainA.shape[0] - 1, batch_size)
            imagesA = X_trainA[indicesA]

            indicesB = np.random.randint(0, X_trainB.shape[0] - 1, batch_size)
            imagesB = X_trainB[indicesB]

            # train discriminators:  each discriminator sees real images from its
            # own domain and translated (fake) images from the other domain.
            fake_imagesA = self.generatorBtoA.predict(imagesB)
            fake_imagesB = self.generatorAtoB.predict(imagesA)

            dA_loss_real = self.discriminatorA.train_on_batch(imagesA, valid)
            dA_loss_fake = self.discriminatorA.train_on_batch(fake_imagesA, fake)

            dB_loss_real = self.discriminatorB.train_on_batch(imagesB, valid)
            dB_loss_fake = self.discriminatorB.train_on_batch(fake_imagesB, fake)

            d_loss = list()
            for i in range(len(dA_loss_real)):
                d_loss.append(0.25 * (dA_loss_real[i] + dA_loss_fake[i] +
                                      dB_loss_real[i] + dB_loss_fake[i]))

            # train generators
            g_loss = self.combined_model.train_on_batch([imagesA, imagesB],
                [valid, valid, imagesA, imagesB, imagesA, imagesB])

            print("Epoch ", epoch, ": [Discriminator loss: ", d_loss[0],
                  " acc: ", d_loss[1], "] ", "[Generator loss: ", g_loss[0],
                  ", ", np.mean(g_loss[1:3]), ", ", np.mean(g_loss[3:5]),
                  ", ", np.mean(g_loss[5:7]), "]")

            if self.dimensionality == 2:
                if sample_interval is not None:
                    if epoch % sample_interval == 0:

                        # Do a 2x3 grid
                        #
                        # imageA | translated( imageA ) | reconstructed( imageA )
                        # imageB | translated( imageB ) | reconstructed( imageB )

                        indexA = np.random.randint(0, X_trainA.shape[0] - 1, 1)
                        indexB = np.random.randint(0, X_trainB.shape[0] - 1, 1)

                        imageA = X_trainA[indexA,:,:,:]
                        imageB = X_trainB[indexB,:,:,:]

                        X = list()
                        X.append(imageA)
                        X.append(self.generatorAtoB.predict(X[0]))
                        X.append(self.generatorBtoA.predict(X[1]))
                        X.append(imageB)
                        X.append(self.generatorBtoA.predict(X[3]))
                        X.append(self.generatorAtoB.predict(X[4]))

                        plot_images = np.concatenate(X)
                        # Rescale from [-1, 1] (tanh output) to [0, 1] for display.
                        plot_images = 0.5 * plot_images + 0.5

                        titles = ['Original', 'Translated', 'Reconstructed']

                        figure, axes = plot.subplots(2, 3)
                        count = 0
                        for i in range(2):
                            for j in range(3):
                                axes[i, j].imshow(plot_images[count])
                                axes[i, j].set_title(titles[j])
                                axes[i, j].axis('off')
                                count += 1

                        image_file_name = sample_file_prefix + "_iteration" + str(epoch) + ".jpg"
                        dir_name = os.path.dirname(sample_file_prefix)
                        if dir_name and not os.path.exists(dir_name):
                            os.mkdir(dir_name)

                        figure.savefig(image_file_name)
                        plot.close()
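
The module above only defines the model class; a minimal usage sketch follows. The array names (domainA_images, domainB_images), the 128x128 single-channel image size, and the epoch/batch settings are illustrative assumptions rather than part of the library. Any pair of unpaired image arrays can be substituted, provided they are scaled to [-1, 1] to match the generators' tanh output.

if __name__ == "__main__":
    # Illustrative sketch:  a 2-D CycleGAN for unpaired 128x128 single-channel
    # images.  Random arrays stand in for real data here; replace them with
    # actual image batches scaled to [-1, 1].
    cycle_gan_model = CycleGanModel(input_image_size=(128, 128, 1))

    domainA_images = np.random.uniform(-1.0, 1.0, size=(16, 128, 128, 1))
    domainB_images = np.random.uniform(-1.0, 1.0, size=(16, 128, 128, 1))

    # Short run with a small batch size just to exercise the training loop.
    cycle_gan_model.train(domainA_images, domainB_images,
                          number_of_epochs=2, batch_size=4)

    # After training, generatorAtoB / generatorBtoA translate between domains.
    translated_to_B = cycle_gan_model.generatorAtoB.predict(domainA_images[:1])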