Source code for antspynet.architectures.create_simple_classification_with_spatial_transformer_network_model


from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Activation, Flatten,
                          Conv2D, MaxPooling2D,
                          Conv3D, MaxPooling3D)

from ..utilities import (SpatialTransformer2D, SpatialTransformer3D)

import numpy as np
import tensorflow as tf

def create_simple_classification_with_spatial_transformer_network_model_2d(input_image_size,
                                                                            resampled_size=(30, 30),
                                                                            number_of_classification_labels=10):
    """
    2-D implementation of the spatial transformer network.

    Creates a keras model of the spatial transformer network:

            https://arxiv.org/abs/1506.02025

    based on the following python Keras model:

            https://github.com/oarriaga/STN.keras/blob/master/src/models/STN.py

    The model has two parts.  A localization branch (pooling, convolution
    and dense layers) regresses 6 values from the input image.  Those values
    and the input image are handed to a ``SpatialTransformer2D`` layer
    (affine transform, linear interpolation), and the resampled output feeds
    a small convolutional classifier ending in a softmax.

    Arguments
    ---------
    input_image_size : tuple of length 3
        Used for specifying the input tensor shape.  The shape (or dimension)
        of that tensor is the image dimensions followed by the number of
        channels (e.g., red, green, and blue).  The batch size (i.e., number
        of training images) is not specified a priori.

    resampled_size : tuple of length 2
        Resampled size of the transformed input images.

    number_of_classification_labels : integer
        Number of units in the final dense layer.

    Returns
    -------
    Keras model
        A 2-D Keras model defining the network.

    Example
    -------
    >>> model = create_simple_classification_with_spatial_transformer_network_model_2d((128, 128, 1))
    >>> model.summary()
    """

    number_of_localization_units = 50
    number_of_transform_parameters = 6

    inputs = Input(shape=input_image_size)

    # Localization branch: two (pool -> conv) stages, then a dense layer.
    features = inputs
    for _ in range(2):
        features = MaxPooling2D(pool_size=(2, 2))(features)
        features = Conv2D(filters=20, kernel_size=(5, 5))(features)
    features = Flatten()(features)
    features = Dense(units=number_of_localization_units)(features)
    features = Activation('relu')(features)

    # Initial state of the regression layer: an all-zero kernel, and a bias
    # equal to the flattened 2x3 matrix with ones on its main diagonal, so the
    # layer initially outputs that matrix regardless of its input.
    initial_kernel = np.zeros((number_of_localization_units,
                               number_of_transform_parameters), dtype='float32')
    initial_bias = np.eye(2, 3, dtype='float32').flatten()

    transform_parameters = Dense(
        units=number_of_transform_parameters,
        kernel_initializer=tf.constant_initializer(initial_kernel),
        bias_initializer=tf.constant_initializer(initial_bias))(features)

    # Resample the original input using the regressed parameters.
    outputs = SpatialTransformer2D(resampled_size=resampled_size,
                                   transform_type="affine",
                                   interpolator_type="linear")([inputs, transform_parameters])

    # Classification head applied to the resampled image.
    outputs = Conv2D(filters=32, kernel_size=(3, 3), padding='same')(outputs)
    outputs = Activation('relu')(outputs)
    outputs = MaxPooling2D(pool_size=(2, 2))(outputs)

    outputs = Conv2D(filters=32, kernel_size=(3, 3))(outputs)
    outputs = Activation('relu')(outputs)
    outputs = MaxPooling2D(pool_size=(2, 2))(outputs)

    outputs = Flatten()(outputs)
    outputs = Dense(units=256)(outputs)
    outputs = Activation('relu')(outputs)

    outputs = Dense(units=number_of_classification_labels)(outputs)
    outputs = Activation('softmax')(outputs)

    return Model(inputs=inputs, outputs=outputs)
def create_simple_classification_with_spatial_transformer_network_model_3d(input_image_size,
                                                                            resampled_size=(30, 30, 30),
                                                                            number_of_classification_labels=10):
    """
    3-D implementation of the spatial transformer network.

    Creates a keras model of the spatial transformer network:

            https://arxiv.org/abs/1506.02025

    based on the following python Keras model:

            https://github.com/oarriaga/STN.keras/blob/master/src/models/STN.py

    The model has two parts.  A localization branch (pooling, convolution
    and dense layers) regresses the 12 values of a 3x4 matrix from the input
    image.  Those values and the input image are handed to a
    ``SpatialTransformer3D`` layer (affine transform, linear interpolation),
    and the resampled output feeds a small convolutional classifier ending in
    a softmax.

    Arguments
    ---------
    input_image_size : tuple of length 4
        Used for specifying the input tensor shape.  The shape (or dimension)
        of that tensor is the image dimensions followed by the number of
        channels (e.g., red, green, and blue).  The batch size (i.e., number
        of training images) is not specified a priori.

    resampled_size : tuple of length 3
        Resampled size of the transformed input images.

    number_of_classification_labels : integer
        Number of units in the final dense layer.

    Returns
    -------
    Keras model
        A 3-D Keras model defining the network.

    Example
    -------
    >>> model = create_simple_classification_with_spatial_transformer_network_model_3d((128, 128, 128, 1))
    >>> model.summary()
    """

    number_of_localization_units = 50

    # A 3-D affine transform is described by a 3x4 matrix, i.e. 12 values.
    # The regression layer, its kernel and its bias must all agree on this
    # number; the kernel and bias below are both derived from it.
    affine_matrix_shape = (3, 4)
    number_of_transform_parameters = affine_matrix_shape[0] * affine_matrix_shape[1]

    def get_initial_weights_3d(output_size):
        # Zero kernel plus a bias equal to the flattened 3x4 matrix with ones
        # on its main diagonal: the layer initially outputs that matrix
        # regardless of its input (presumably read by SpatialTransformer3D as
        # the identity affine transform -- confirm against that layer).
        bias = np.eye(*affine_matrix_shape, dtype='float32').flatten()
        kernel = np.zeros((output_size, number_of_transform_parameters),
                          dtype='float32')
        return [kernel, bias]

    inputs = Input(shape=input_image_size)

    # Localization branch: two (pool -> conv) stages, then a dense layer.
    localization = inputs
    localization = MaxPooling3D(pool_size=(2, 2, 2))(localization)
    localization = Conv3D(filters=20, kernel_size=(5, 5, 5))(localization)
    localization = MaxPooling3D(pool_size=(2, 2, 2))(localization)
    localization = Conv3D(filters=20, kernel_size=(5, 5, 5))(localization)

    localization = Flatten()(localization)
    localization = Dense(units=number_of_localization_units)(localization)
    localization = Activation('relu')(localization)

    weights = get_initial_weights_3d(output_size=number_of_localization_units)

    # Fix: this layer previously had 6 units (copied from the 2-D model)
    # while being initialized with a (50, 12) kernel and a 12-element bias.
    # The unit count now matches the 12 affine parameters.
    localization = Dense(units=number_of_transform_parameters,
                         kernel_initializer=tf.constant_initializer(weights[0]),
                         bias_initializer=tf.constant_initializer(weights[1]))(localization)

    # Resample the original input using the regressed parameters.
    outputs = SpatialTransformer3D(resampled_size=resampled_size,
                                   transform_type="affine",
                                   interpolator_type="linear")([inputs, localization])

    # Classification head applied to the resampled image.
    outputs = Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same')(outputs)
    outputs = Activation('relu')(outputs)
    outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)

    outputs = Conv3D(filters=32, kernel_size=(3, 3, 3))(outputs)
    outputs = Activation('relu')(outputs)
    outputs = MaxPooling3D(pool_size=(2, 2, 2))(outputs)

    outputs = Flatten()(outputs)
    outputs = Dense(units=256)(outputs)
    outputs = Activation('relu')(outputs)

    outputs = Dense(units=number_of_classification_labels)(outputs)
    outputs = Activation('softmax')(outputs)

    stn_model = Model(inputs=inputs, outputs=outputs)

    return stn_model