Source code for antspynet.utilities.denseunet_utilities


import tensorflow.keras.backend as K

from tensorflow.keras.layers import Layer, InputSpec

from tensorflow.keras import initializers

[docs]class Scale(Layer): """ Custom layer used in the Dense U-net class for normalization which learns a set of weights and biases for scaling the input data. Arguments --------- axis : integer Specifies which axis to normalize. momentum : scalar Value used for computation of the exponential average of the mean and standard deviation. """ def __init__(self, axis=-1, momentum=0.9, **kwargs): self.momentum = momentum self.axis = axis super(Scale, self).__init__(**kwargs) def build(self, input_shape): self.input_spec = [InputSpec(shape=input_shape)] output_shape = (int(input_shape[self.axis]),) gamma_initializer = initializers.Ones() beta_initializer = initializers.Zeros() self.gamma = K.variable(gamma_initializer(output_shape)) self.beta = K.variable(beta_initializer(output_shape)) self.trainable_weights = [self.gamma, self.beta] def call(self, inputs, mask=None): input_shape = self.input_spec[0].shape broadcast_shape = [1] * len(input_shape) broadcast_shape[self.axis] = input_shape[self.axis] output = (K.reshape(self.gamma, broadcast_shape) * inputs + K.reshape(self.beta, broadcast_shape)) return(output) def get_config(self): config = {"momentum": self.momentum, "axis": self.axis} base_config = super(Scale, self).get_config() return dict(list(base_config.items()) + list(config.items()))