Source code for antspynet.utilities.custom_normalization_layers

import tensorflow as tf

from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import backend as K


[docs]class InstanceNormalization(Layer): """ Instance normalization layer. Normalize the activations of the previous layer at each step, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1. Taken from https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/layers/normalization/instancenormalization.py Arguments --------- axis: integer Integer specifying which axis should be normalized, typically the feature axis. For example, after a Conv2D layer with `channels_first`, set axis = 1. Setting `axis=-1L` will normalize all values in each instance of the batch. Axis 0 is the batch dimension for tensorflow backend so we throw an error if `axis = 0`. epsilon: float Small float added to variance to avoid dividing by zero. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g., `nn.relu`), this can be disabled since the scaling will be done by the next layer. beta_initializer : string Initializer for the beta weight. gamma_initializer : string Initializer for the gamma weight. beta_regularizer : string Optional regularizer for the beta weight. gamma_regularizer : string Optional regularizer for the gamma weight. beta_constraint : string Optional constraint for the beta weight. gamma_constraint : string Optional constraint for the gamma weight. Returns ------- Keras layer """ def __init__(self, axis=None, epsilon=1e-3, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs): super(InstanceNormalization, self).__init__(**kwargs) self.supports_masking = True self.axis = axis if self.axis == 0: raise ValueError('Axis cannot be zero') self.epsilon = epsilon self.center = center self.scale = scale self.beta_initializer = initializers.get(beta_initializer) self.gamma_initializer = initializers.get(gamma_initializer) self.beta_regularizer = regularizers.get(beta_regularizer) self.gamma_regularizer = regularizers.get(gamma_regularizer) self.beta_constraint = constraints.get(beta_constraint) self.gamma_constraint = constraints.get(gamma_constraint) def build(self, input_shape): dimensionality = len(input_shape) if (self.axis is not None) and (dimensionality == 2): raise ValueError('Cannot specify axis for rank 1 tensor') self.input_spec = InputSpec(ndim=dimensionality) if self.axis is None: shape = (1,) else: shape = (input_shape[self.axis],) if self.scale: self.gamma = self.add_weight(shape=shape, name='gamma', initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, constraint=self.gamma_constraint) else: self.gamma = None if self.center: self.beta = self.add_weight(shape=shape, name='beta', initializer=self.beta_initializer, regularizer=self.beta_regularizer, constraint=self.beta_constraint) else: self.beta = None self.built = True def call(self, inputs, training=None): input_shape = K.int_shape(inputs) reduction_axes = list(range(0, len(input_shape))) if self.axis is not None: del reduction_axes[self.axis] del reduction_axes[0] mean = K.mean(inputs, reduction_axes, keepdims=True) stddev = K.std(inputs, reduction_axes, keepdims=True) normed = (inputs - mean) / (stddev + self.epsilon) broadcast_shape = [1] * len(input_shape) if self.axis is not None: broadcast_shape[self.axis] = input_shape[self.axis] if self.scale: broadcast_gamma = K.reshape(self.gamma, broadcast_shape) normed = normed * broadcast_gamma if self.center: broadcast_beta = K.reshape(self.beta, broadcast_shape) normed = normed + broadcast_beta return normed def get_config(self): config = { 'axis': self.axis, 'epsilon': self.epsilon, 'center': self.center, 'scale': self.scale, 'beta_initializer': initializers.serialize(self.beta_initializer), 'gamma_initializer': initializers.serialize(self.gamma_initializer), 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 'beta_constraint': constraints.serialize(self.beta_constraint), 'gamma_constraint': constraints.serialize(self.gamma_constraint) } base_config = super(InstanceNormalization, self).get_config() return dict(list(base_config.items()) + list(config.items()))