Source code for antspynet.architectures.create_resunet_model


import tensorflow as tf

import tensorflow.keras.backend as K

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dropout, BatchNormalization, Add,
                                     ThresholdedReLU, Concatenate, Dense,
                                     Conv2D, Conv2DTranspose,
                                     MaxPooling2D, UpSampling2D,
                                     Conv3D, Conv3DTranspose,
                                     MaxPooling3D, UpSampling3D)

from tensorflow.keras import initializers
from tensorflow.keras import regularizers

def create_resunet_model_2d(input_image_size,
                            number_of_outputs=1,
                            number_of_filters_at_base_layer=32,
                            bottle_neck_block_depth_schedule=(3, 4),
                            convolution_kernel_size=(3, 3),
                            deconvolution_kernel_size=(2, 2),
                            dropout_rate=0.0,
                            weight_decay=0.0,
                            mode='classification'
                           ):
    """
    2-D implementation of the Resnet + U-net deep learning architecture.

    Creates a keras model of the U-net + ResNet deep learning architecture for
    image segmentation and regression with the paper available here:

            https://arxiv.org/abs/1608.04117

    This particular implementation was ported from the following python
    implementation:

            https://github.com/veugene/fcn_maker/

    Arguments
    ---------
    input_image_size : tuple of length 3
        Used for specifying the input tensor shape.  The shape (or dimension)
        of that tensor is the image dimensions followed by the number of
        channels (e.g., red, green, and blue).  The batch size (i.e., number
        of training images) is not specified a priori.  Every known spatial
        dimension must be a multiple of 2**(k + 2), where k is the number of
        non-zero entries of bottle_neck_block_depth_schedule (16 for the
        default schedule); ``None`` dimensions are not checked.

    number_of_outputs : integer
        Meaning depends on the mode.  For 'classification' this is the number
        of segmentation labels.  For 'regression' this is the number of
        outputs.  'classification' puts a softmax on the final layer, so with
        channels-last data and a single output every prediction is 1; use at
        least two labels.

    number_of_filters_at_base_layer : integer
        Number of filters at the beginning and end of the 'U'.  Doubles at
        each descending/ascending layer.

    bottle_neck_block_depth_schedule : tuple
        Number of bottleneck blocks at each encoding level (one long skip
        connection per level); the decoder mirrors the schedule.  An entry of
        0 removes that level.  An entry of 1 is rejected because a decoding
        level needs one block to merge its long skip connection and a later
        block to upsample.

    convolution_kernel_size : tuple of length 2
        Kernel size of the preprocessing, postprocessing and simple-block
        convolutions (bottleneck blocks always use 1x1 kernels).

    deconvolution_kernel_size : tuple of length 2
        Kernel size of the transposed convolutions in the upsampling blocks.

    dropout_rate : scalar
        Float between 0 and 1.  When positive, dropout is applied at the end
        of the residual branch of every simple and bottleneck block.

    weight_decay : scalar
        Weighting parameter for L2 regularization of the main-path
        convolution kernels.  The 1x1 shortcut projections and the decoder's
        1x1 channel-matching convolutions are not regularized.  Default = 0.0.

    mode : string
        'classification' or 'regression'.  Default = 'classification'.

    Returns
    -------
    Keras model
        A 2-D Keras model defining the network.

    Raises
    ------
    ValueError
        If mode is not 'classification' or 'regression', if an entry of
        bottle_neck_block_depth_schedule equals 1, or if a spatial dimension
        of input_image_size is not a multiple of the required power of two.

    Example
    -------
    >>> model = create_resunet_model_2d((128, 128, 1))
    >>> model.summary()
    """

    def skip_connection(source, target, merge_mode='sum'):
        # Merge a shortcut / long-skip tensor into the main path: element-wise
        # sum by default, otherwise concatenation along the channel axis.
        layer_list = [source, target]
        if merge_mode == 'sum':
            return Add()(layer_list)
        channel_axis = -1 if K.image_data_format() == 'channels_last' else 1
        return Concatenate(axis=channel_axis)(layer_list)

    def simple_block_2d(block_input, number_of_filters, downsample=False,
                        upsample=False, convolution_kernel_size=(3, 3),
                        deconvolution_kernel_size=(2, 2), weight_decay=0.0,
                        dropout_rate=0.0):
        # Pre-activation residual block:
        #   BN -> ReLU -> [max-pool] -> conv -> [transpose conv + 2x upsample]
        #   -> [dropout], summed with a shortcut resampled to the same size.
        number_of_output_filters = number_of_filters

        output = BatchNormalization()(block_input)
        output = ThresholdedReLU(theta=0)(output)
        if downsample:
            output = MaxPooling2D(pool_size=(2, 2))(output)
        output = Conv2D(filters=number_of_filters,
                        kernel_size=convolution_kernel_size,
                        padding='same',
                        kernel_regularizer=regularizers.l2(weight_decay))(output)
        if upsample:
            output = Conv2DTranspose(filters=number_of_filters,
                                     kernel_size=deconvolution_kernel_size,
                                     padding='same',
                                     kernel_initializer=initializers.he_normal(),
                                     kernel_regularizer=regularizers.l2(weight_decay))(output)
            output = UpSampling2D(size=(2, 2))(output)
        if dropout_rate > 0.0:
            output = Dropout(rate=dropout_rate)(output)

        # Resample the shortcut to the residual branch's size.  Without
        # resampling it is passed through unchanged; every call site below
        # sets either downsample or upsample.
        shortcut = block_input
        if downsample:
            shortcut = Conv2D(filters=number_of_output_filters,
                              kernel_size=(1, 1),
                              strides=(2, 2),
                              padding='same')(shortcut)
        elif upsample:
            shortcut = Conv2DTranspose(filters=number_of_output_filters,
                                       kernel_size=(1, 1),
                                       padding='same')(shortcut)
            shortcut = UpSampling2D(size=(2, 2))(shortcut)

        return skip_connection(shortcut, output)

    def bottle_neck_block_2d(block_input, number_of_filters, downsample=False,
                             upsample=False, deconvolution_kernel_size=(2, 2),
                             weight_decay=0.0, dropout_rate=0.0):
        # Pre-activation bottleneck block built from 1x1 convolutions; the
        # residual branch ends with 4 * number_of_filters channels.
        number_of_output_filters = number_of_filters * 4

        output = block_input
        if downsample:
            output = BatchNormalization()(output)
            output = ThresholdedReLU(theta=0)(output)
            output = Conv2D(filters=number_of_filters,
                            kernel_size=(1, 1),
                            strides=(2, 2),
                            kernel_initializer=initializers.he_normal(),
                            kernel_regularizer=regularizers.l2(weight_decay))(output)
        output = BatchNormalization()(output)
        output = ThresholdedReLU(theta=0)(output)
        output = Conv2D(filters=number_of_filters,
                        kernel_size=(1, 1),
                        kernel_initializer=initializers.he_normal(),
                        kernel_regularizer=regularizers.l2(weight_decay))(output)
        output = BatchNormalization()(output)
        output = ThresholdedReLU(theta=0)(output)
        if upsample:
            output = Conv2DTranspose(filters=number_of_filters,
                                     kernel_size=deconvolution_kernel_size,
                                     padding='same',
                                     kernel_initializer=initializers.he_normal(),
                                     kernel_regularizer=regularizers.l2(weight_decay))(output)
            output = UpSampling2D(size=(2, 2))(output)
        output = Conv2D(filters=number_of_output_filters,
                        kernel_size=(1, 1),
                        kernel_initializer=initializers.he_normal(),
                        kernel_regularizer=regularizers.l2(weight_decay))(output)
        if dropout_rate > 0.0:
            output = Dropout(rate=dropout_rate)(output)

        # Resample / project the shortcut so it matches the residual branch.
        shortcut = block_input
        if downsample:
            shortcut = Conv2D(filters=number_of_output_filters,
                              kernel_size=(1, 1),
                              strides=(2, 2),
                              padding='same')(shortcut)
        elif upsample:
            shortcut = Conv2DTranspose(filters=number_of_output_filters,
                                       kernel_size=(1, 1),
                                       padding='same')(shortcut)
            shortcut = UpSampling2D(size=(2, 2))(shortcut)
        elif number_of_filters != number_of_output_filters:
            # True for any positive filter count, so the shortcut is always
            # projected by a 1x1 convolution here.
            shortcut = Conv2D(filters=number_of_output_filters,
                              kernel_size=(1, 1),
                              padding='valid')(shortcut)

        return skip_connection(shortcut, output)

    # Validate arguments before any layer is created.
    if mode == 'classification':
        conv_activation = 'softmax'
    elif mode == 'regression':
        conv_activation = 'linear'
    else:
        raise ValueError('mode must be either `classification` or `regression`.')

    if any(depth == 1 for depth in bottle_neck_block_depth_schedule):
        raise ValueError('Entries of bottle_neck_block_depth_schedule must be 0 or '
                         'at least 2: a decoding level needs one block to merge its '
                         'long skip connection and a later block to upsample.')

    # The network halves the resolution once in the simple block, once per
    # non-empty level and once in the transition path, and the decoder must
    # undo every halving exactly.
    number_of_active_levels = sum(1 for depth in bottle_neck_block_depth_schedule
                                  if depth > 0)
    required_multiple = 2 ** (number_of_active_levels + 2)
    input_shape = tuple(input_image_size)
    if K.image_data_format() == 'channels_last':
        spatial_dimensions = input_shape[:-1]
    else:
        spatial_dimensions = input_shape[1:]
    for size in spatial_dimensions:
        if size is not None and size % required_multiple != 0:
            raise ValueError(
                'Spatial dimensions of input_image_size must be multiples of {0} '
                'for this bottle_neck_block_depth_schedule; got {1}.'.format(
                    required_multiple, input_shape))

    inputs = Input(shape=input_image_size)

    # Preprocessing layer.  Its output is merged back in by the final long
    # skip connection.
    preprocessed = Conv2D(filters=number_of_filters_at_base_layer,
                          kernel_size=convolution_kernel_size,
                          activation='relu',
                          padding='same',
                          kernel_initializer=initializers.he_normal(),
                          kernel_regularizer=regularizers.l2(weight_decay))(inputs)

    # Encoding initialization path (halves the resolution).  Its output is
    # not used as a long skip connection.
    model = simple_block_2d(preprocessed, number_of_filters_at_base_layer,
                            downsample=True,
                            convolution_kernel_size=convolution_kernel_size,
                            deconvolution_kernel_size=deconvolution_kernel_size,
                            weight_decay=weight_decay,
                            dropout_rate=dropout_rate)

    # Encoding main path.  The first block of each level downsamples; the
    # output of the level's last block is saved for a long skip connection.
    level_outputs = []
    for level, depth in enumerate(bottle_neck_block_depth_schedule):
        number_of_filters = number_of_filters_at_base_layer * 2**level
        for block_index in range(depth):
            model = bottle_neck_block_2d(model,
                                         number_of_filters=number_of_filters,
                                         downsample=(block_index == 0),
                                         deconvolution_kernel_size=deconvolution_kernel_size,
                                         weight_decay=weight_decay,
                                         dropout_rate=dropout_rate)
        if depth > 0:
            level_outputs.append(model)

    # Transition path: down one more resolution and straight back up.
    number_of_levels = len(bottle_neck_block_depth_schedule)
    number_of_filters = number_of_filters_at_base_layer * 2**number_of_levels
    model = bottle_neck_block_2d(model,
                                 number_of_filters=number_of_filters,
                                 downsample=True,
                                 deconvolution_kernel_size=deconvolution_kernel_size,
                                 weight_decay=weight_decay,
                                 dropout_rate=dropout_rate)
    model = bottle_neck_block_2d(model,
                                 number_of_filters=number_of_filters,
                                 upsample=True,
                                 deconvolution_kernel_size=deconvolution_kernel_size,
                                 weight_decay=weight_decay,
                                 dropout_rate=dropout_rate)

    # Decoding main path, mirroring the encoder from the deepest level up.
    # The first block of each level is merged with that level's long skip
    # connection (most recently saved first); the last block upsamples.
    for level in reversed(range(number_of_levels)):
        depth = bottle_neck_block_depth_schedule[level]
        number_of_filters = number_of_filters_at_base_layer * 2**level
        for block_index in range(depth):
            model = bottle_neck_block_2d(model,
                                         number_of_filters=number_of_filters,
                                         upsample=(block_index == depth - 1),
                                         deconvolution_kernel_size=deconvolution_kernel_size,
                                         weight_decay=weight_decay,
                                         dropout_rate=dropout_rate)
            if block_index == 0:
                model = Conv2D(filters=(number_of_filters * 4),
                               kernel_size=(1, 1),
                               padding='same')(model)
                model = skip_connection(level_outputs.pop(), model)

    # Decoding initialization path (restores the input resolution).
    model = simple_block_2d(model, number_of_filters_at_base_layer,
                            upsample=True,
                            convolution_kernel_size=convolution_kernel_size,
                            deconvolution_kernel_size=deconvolution_kernel_size,
                            weight_decay=weight_decay,
                            dropout_rate=dropout_rate)

    # Postprocessing layer, merged with the preprocessing output.
    model = Conv2D(filters=number_of_filters_at_base_layer,
                   kernel_size=convolution_kernel_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer=initializers.he_normal(),
                   kernel_regularizer=regularizers.l2(weight_decay))(model)
    model = skip_connection(preprocessed, model)
    model = BatchNormalization()(model)
    model = ThresholdedReLU(theta=0)(model)

    outputs = Conv2D(filters=number_of_outputs,
                     kernel_size=(1, 1),
                     activation=conv_activation,
                     kernel_regularizer=regularizers.l2(weight_decay))(model)

    resunet_model = Model(inputs=inputs, outputs=outputs)
    return resunet_model
def create_resunet_model_3d(input_image_size,
                            number_of_outputs=1,
                            number_of_filters_at_base_layer=32,
                            bottle_neck_block_depth_schedule=(3, 4),
                            convolution_kernel_size=(3, 3, 3),
                            deconvolution_kernel_size=(2, 2, 2),
                            dropout_rate=0.0,
                            weight_decay=0.0,
                            mode='classification'
                           ):
    """
    3-D implementation of the Resnet + U-net deep learning architecture.

    Creates a keras model of the U-net + ResNet deep learning architecture for
    image segmentation and regression with the paper available here:

            https://arxiv.org/abs/1608.04117

    This particular implementation was ported from the following python
    implementation:

            https://github.com/veugene/fcn_maker/

    Arguments
    ---------
    input_image_size : tuple of length 4
        Used for specifying the input tensor shape.  The shape (or dimension)
        of that tensor is the image dimensions followed by the number of
        channels (e.g., red, green, and blue).  The batch size (i.e., number
        of training images) is not specified a priori.  Every known spatial
        dimension must be a multiple of 2**(k + 2), where k is the number of
        non-zero entries of bottle_neck_block_depth_schedule (16 for the
        default schedule); ``None`` dimensions are not checked.

    number_of_outputs : integer
        Meaning depends on the mode.  For 'classification' this is the number
        of segmentation labels.  For 'regression' this is the number of
        outputs.  'classification' puts a softmax on the final layer, so with
        channels-last data and a single output every prediction is 1; use at
        least two labels.

    number_of_filters_at_base_layer : integer
        Number of filters at the beginning and end of the 'U'.  Doubles at
        each descending/ascending layer.

    bottle_neck_block_depth_schedule : tuple
        Number of bottleneck blocks at each encoding level (one long skip
        connection per level); the decoder mirrors the schedule.  An entry of
        0 removes that level.  An entry of 1 is rejected because a decoding
        level needs one block to merge its long skip connection and a later
        block to upsample.

    convolution_kernel_size : tuple of length 3
        Kernel size of the preprocessing, postprocessing and simple-block
        convolutions (bottleneck blocks always use 1x1x1 kernels).

    deconvolution_kernel_size : tuple of length 3
        Kernel size of the transposed convolutions in the upsampling blocks.

    dropout_rate : scalar
        Float between 0 and 1.  When positive, dropout is applied at the end
        of the residual branch of every simple and bottleneck block.

    weight_decay : scalar
        Weighting parameter for L2 regularization of the main-path
        convolution kernels.  The 1x1x1 shortcut projections and the
        decoder's 1x1x1 channel-matching convolutions are not regularized.
        Default = 0.0.

    mode : string
        'classification' or 'regression'.  Default = 'classification'.

    Returns
    -------
    Keras model
        A 3-D Keras model defining the network.

    Raises
    ------
    ValueError
        If mode is not 'classification' or 'regression', if an entry of
        bottle_neck_block_depth_schedule equals 1, or if a spatial dimension
        of input_image_size is not a multiple of the required power of two.

    Example
    -------
    >>> model = create_resunet_model_3d((128, 128, 128, 1))
    >>> model.summary()
    """

    def skip_connection(source, target, merge_mode='sum'):
        # Merge a shortcut / long-skip tensor into the main path: element-wise
        # sum by default, otherwise concatenation along the channel axis.
        layer_list = [source, target]
        if merge_mode == 'sum':
            return Add()(layer_list)
        channel_axis = -1 if K.image_data_format() == 'channels_last' else 1
        return Concatenate(axis=channel_axis)(layer_list)

    def simple_block_3d(block_input, number_of_filters, downsample=False,
                        upsample=False, convolution_kernel_size=(3, 3, 3),
                        deconvolution_kernel_size=(2, 2, 2), weight_decay=0.0,
                        dropout_rate=0.0):
        # Pre-activation residual block:
        #   BN -> ReLU -> [max-pool] -> conv -> [transpose conv + 2x upsample]
        #   -> [dropout], summed with a shortcut resampled to the same size.
        number_of_output_filters = number_of_filters

        output = BatchNormalization()(block_input)
        output = ThresholdedReLU(theta=0)(output)
        if downsample:
            output = MaxPooling3D(pool_size=(2, 2, 2))(output)
        output = Conv3D(filters=number_of_filters,
                        kernel_size=convolution_kernel_size,
                        padding='same',
                        kernel_regularizer=regularizers.l2(weight_decay))(output)
        if upsample:
            output = Conv3DTranspose(filters=number_of_filters,
                                     kernel_size=deconvolution_kernel_size,
                                     padding='same',
                                     kernel_initializer=initializers.he_normal(),
                                     kernel_regularizer=regularizers.l2(weight_decay))(output)
            output = UpSampling3D(size=(2, 2, 2))(output)
        if dropout_rate > 0.0:
            output = Dropout(rate=dropout_rate)(output)

        # Resample the shortcut to the residual branch's size.  Without
        # resampling it is passed through unchanged; every call site below
        # sets either downsample or upsample.
        shortcut = block_input
        if downsample:
            shortcut = Conv3D(filters=number_of_output_filters,
                              kernel_size=(1, 1, 1),
                              strides=(2, 2, 2),
                              padding='same')(shortcut)
        elif upsample:
            shortcut = Conv3DTranspose(filters=number_of_output_filters,
                                       kernel_size=(1, 1, 1),
                                       padding='same')(shortcut)
            shortcut = UpSampling3D(size=(2, 2, 2))(shortcut)

        return skip_connection(shortcut, output)

    def bottle_neck_block_3d(block_input, number_of_filters, downsample=False,
                             upsample=False, deconvolution_kernel_size=(2, 2, 2),
                             weight_decay=0.0, dropout_rate=0.0):
        # Pre-activation bottleneck block built from 1x1x1 convolutions; the
        # residual branch ends with 4 * number_of_filters channels.
        number_of_output_filters = number_of_filters * 4

        output = block_input
        if downsample:
            output = BatchNormalization()(output)
            output = ThresholdedReLU(theta=0)(output)
            output = Conv3D(filters=number_of_filters,
                            kernel_size=(1, 1, 1),
                            strides=(2, 2, 2),
                            kernel_initializer=initializers.he_normal(),
                            kernel_regularizer=regularizers.l2(weight_decay))(output)
        output = BatchNormalization()(output)
        output = ThresholdedReLU(theta=0)(output)
        output = Conv3D(filters=number_of_filters,
                        kernel_size=(1, 1, 1),
                        kernel_initializer=initializers.he_normal(),
                        kernel_regularizer=regularizers.l2(weight_decay))(output)
        output = BatchNormalization()(output)
        output = ThresholdedReLU(theta=0)(output)
        if upsample:
            output = Conv3DTranspose(filters=number_of_filters,
                                     kernel_size=deconvolution_kernel_size,
                                     padding='same',
                                     kernel_initializer=initializers.he_normal(),
                                     kernel_regularizer=regularizers.l2(weight_decay))(output)
            output = UpSampling3D(size=(2, 2, 2))(output)
        output = Conv3D(filters=number_of_output_filters,
                        kernel_size=(1, 1, 1),
                        kernel_initializer=initializers.he_normal(),
                        kernel_regularizer=regularizers.l2(weight_decay))(output)
        if dropout_rate > 0.0:
            output = Dropout(rate=dropout_rate)(output)

        # Resample / project the shortcut so it matches the residual branch.
        shortcut = block_input
        if downsample:
            shortcut = Conv3D(filters=number_of_output_filters,
                              kernel_size=(1, 1, 1),
                              strides=(2, 2, 2),
                              padding='same')(shortcut)
        elif upsample:
            shortcut = Conv3DTranspose(filters=number_of_output_filters,
                                       kernel_size=(1, 1, 1),
                                       padding='same')(shortcut)
            shortcut = UpSampling3D(size=(2, 2, 2))(shortcut)
        elif number_of_filters != number_of_output_filters:
            # True for any positive filter count, so the shortcut is always
            # projected by a 1x1x1 convolution here.
            shortcut = Conv3D(filters=number_of_output_filters,
                              kernel_size=(1, 1, 1),
                              padding='valid')(shortcut)

        return skip_connection(shortcut, output)

    # Validate arguments before any layer is created.
    if mode == 'classification':
        conv_activation = 'softmax'
    elif mode == 'regression':
        conv_activation = 'linear'
    else:
        raise ValueError('mode must be either `classification` or `regression`.')

    if any(depth == 1 for depth in bottle_neck_block_depth_schedule):
        raise ValueError('Entries of bottle_neck_block_depth_schedule must be 0 or '
                         'at least 2: a decoding level needs one block to merge its '
                         'long skip connection and a later block to upsample.')

    # The network halves the resolution once in the simple block, once per
    # non-empty level and once in the transition path, and the decoder must
    # undo every halving exactly.
    number_of_active_levels = sum(1 for depth in bottle_neck_block_depth_schedule
                                  if depth > 0)
    required_multiple = 2 ** (number_of_active_levels + 2)
    input_shape = tuple(input_image_size)
    if K.image_data_format() == 'channels_last':
        spatial_dimensions = input_shape[:-1]
    else:
        spatial_dimensions = input_shape[1:]
    for size in spatial_dimensions:
        if size is not None and size % required_multiple != 0:
            raise ValueError(
                'Spatial dimensions of input_image_size must be multiples of {0} '
                'for this bottle_neck_block_depth_schedule; got {1}.'.format(
                    required_multiple, input_shape))

    inputs = Input(shape=input_image_size)

    # Preprocessing layer.  Its output is merged back in by the final long
    # skip connection.
    preprocessed = Conv3D(filters=number_of_filters_at_base_layer,
                          kernel_size=convolution_kernel_size,
                          activation='relu',
                          padding='same',
                          kernel_initializer=initializers.he_normal(),
                          kernel_regularizer=regularizers.l2(weight_decay))(inputs)

    # Encoding initialization path (halves the resolution).  Its output is
    # not used as a long skip connection.
    model = simple_block_3d(preprocessed, number_of_filters_at_base_layer,
                            downsample=True,
                            convolution_kernel_size=convolution_kernel_size,
                            deconvolution_kernel_size=deconvolution_kernel_size,
                            weight_decay=weight_decay,
                            dropout_rate=dropout_rate)

    # Encoding main path.  The first block of each level downsamples; the
    # output of the level's last block is saved for a long skip connection.
    level_outputs = []
    for level, depth in enumerate(bottle_neck_block_depth_schedule):
        number_of_filters = number_of_filters_at_base_layer * 2**level
        for block_index in range(depth):
            model = bottle_neck_block_3d(model,
                                         number_of_filters=number_of_filters,
                                         downsample=(block_index == 0),
                                         deconvolution_kernel_size=deconvolution_kernel_size,
                                         weight_decay=weight_decay,
                                         dropout_rate=dropout_rate)
        if depth > 0:
            level_outputs.append(model)

    # Transition path: down one more resolution and straight back up.
    number_of_levels = len(bottle_neck_block_depth_schedule)
    number_of_filters = number_of_filters_at_base_layer * 2**number_of_levels
    model = bottle_neck_block_3d(model,
                                 number_of_filters=number_of_filters,
                                 downsample=True,
                                 deconvolution_kernel_size=deconvolution_kernel_size,
                                 weight_decay=weight_decay,
                                 dropout_rate=dropout_rate)
    model = bottle_neck_block_3d(model,
                                 number_of_filters=number_of_filters,
                                 upsample=True,
                                 deconvolution_kernel_size=deconvolution_kernel_size,
                                 weight_decay=weight_decay,
                                 dropout_rate=dropout_rate)

    # Decoding main path, mirroring the encoder from the deepest level up.
    # The first block of each level is merged with that level's long skip
    # connection (most recently saved first); the last block upsamples.
    for level in reversed(range(number_of_levels)):
        depth = bottle_neck_block_depth_schedule[level]
        number_of_filters = number_of_filters_at_base_layer * 2**level
        for block_index in range(depth):
            model = bottle_neck_block_3d(model,
                                         number_of_filters=number_of_filters,
                                         upsample=(block_index == depth - 1),
                                         deconvolution_kernel_size=deconvolution_kernel_size,
                                         weight_decay=weight_decay,
                                         dropout_rate=dropout_rate)
            if block_index == 0:
                model = Conv3D(filters=(number_of_filters * 4),
                               kernel_size=(1, 1, 1),
                               padding='same')(model)
                model = skip_connection(level_outputs.pop(), model)

    # Decoding initialization path (restores the input resolution).
    model = simple_block_3d(model, number_of_filters_at_base_layer,
                            upsample=True,
                            convolution_kernel_size=convolution_kernel_size,
                            deconvolution_kernel_size=deconvolution_kernel_size,
                            weight_decay=weight_decay,
                            dropout_rate=dropout_rate)

    # Postprocessing layer, merged with the preprocessing output.
    model = Conv3D(filters=number_of_filters_at_base_layer,
                   kernel_size=convolution_kernel_size,
                   activation='relu',
                   padding='same',
                   kernel_initializer=initializers.he_normal(),
                   kernel_regularizer=regularizers.l2(weight_decay))(model)
    model = skip_connection(preprocessed, model)
    model = BatchNormalization()(model)
    model = ThresholdedReLU(theta=0)(model)

    outputs = Conv3D(filters=number_of_outputs,
                     kernel_size=(1, 1, 1),
                     activation=conv_activation,
                     kernel_regularizer=regularizers.l2(weight_decay))(model)

    resunet_model = Model(inputs=inputs, outputs=outputs)
    return resunet_model