Source code for antspynet.utilities.brain_extraction

import numpy as np
import tensorflow as tf

import ants

[docs]def brain_extraction(image, modality="t1v0", antsxnet_cache_directory=None, verbose=False): """ Perform brain extraction using U-net and ANTs-based training data. "NoBrainer" is also possible where brain extraction uses U-net and FreeSurfer training data ported from the https://github.com/neuronets/nobrainer-models Arguments --------- image : ANTsImage input image (or list of images for multi-modal scenarios). modality : string Modality image type. Options include: * "t1": T1-weighted MRI---ANTs-trained. Update from "t1v0". * "t1v0": T1-weighted MRI---ANTs-trained. * "t1nobrainer": T1-weighted MRI---FreeSurfer-trained: h/t Satra Ghosh and Jakub Kaczmarzyk. * "t1combined": Brian's combination of "t1" and "t1nobrainer". One can also specify "t1combined[X]" where X is the morphological radius. X = 12 by default. * "flair": FLAIR MRI. * "t2": T2 MRI. * "bold": 3-D BOLD MRI. * "fa": Fractional anisotropy. * "t1t2infant": Combined T1-w/T2-w infant MRI h/t Martin Styner. * "t1infant": T1-w infant MRI h/t Martin Styner. * "t2infant": T2-w infant MRI h/t Martin Styner. antsxnet_cache_directory : string Destination directory for storing the downloaded template and model weights. Since these can be resused, if is None, these data will be downloaded to a ~/.keras/ANTsXNet/. verbose : boolean Print progress to the screen. Returns ------- ANTs probability brain mask image. Example ------- >>> probability_brain_mask = brain_extraction(brain_image, modality="t1") """ from ..architectures import create_unet_model_3d from ..utilities import get_pretrained_network from ..utilities import get_antsxnet_data from ..architectures import create_nobrainer_unet_model_3d from ..utilities import decode_unet classes = ("background", "brain") number_of_classification_labels = len(classes) channel_size = 1 if isinstance(image, list): channel_size = len(image) if antsxnet_cache_directory == None: antsxnet_cache_directory = "ANTsXNet" input_images = list() if channel_size == 1: input_images.append(image) else: input_images = image if input_images[0].dimension != 3: raise ValueError( "Image dimension must be 3." ) if "t1combined" in modality: # Need to change with voxel resolution morphological_radius = 12 if '[' in modality and ']' in modality: morphological_radius = int(modality.split("[")[1].split("]")[0]) brain_extraction_t1 = brain_extraction(image, modality="t1", antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) brain_mask = ants.iMath_get_largest_component( ants.threshold_image(brain_extraction_t1, 0.5, 10000)) brain_mask = ants.morphology(brain_mask,"close",morphological_radius).iMath_fill_holes() brain_extraction_t1nobrainer = brain_extraction(image * ants.iMath_MD(brain_mask, radius=morphological_radius), modality = "t1nobrainer", antsxnet_cache_directory=antsxnet_cache_directory, verbose=verbose) brain_extraction_combined = ants.iMath_fill_holes( ants.iMath_get_largest_component(brain_extraction_t1nobrainer * brain_mask)) brain_extraction_combined = brain_extraction_combined + ants.iMath_ME(brain_mask, morphological_radius) + brain_mask return(brain_extraction_combined) if modality != "t1nobrainer": ##################### # # ANTs-based # ##################### weights_file_name_prefix = None if modality == "t1v0": weights_file_name_prefix = "brainExtraction" elif modality == "t1": weights_file_name_prefix = "brainExtractionT1v1" elif modality == "t2": weights_file_name_prefix = "brainExtractionT2" elif modality == "flair": weights_file_name_prefix = "brainExtractionFLAIR" elif modality == "bold": weights_file_name_prefix = "brainExtractionBOLD" elif modality == "fa": weights_file_name_prefix = "brainExtractionFA" elif modality == "t1t2infant": weights_file_name_prefix = "brainExtractionInfantT1T2" elif modality == "t1infant": weights_file_name_prefix = "brainExtractionInfantT1" elif modality == "t2infant": weights_file_name_prefix = "brainExtractionInfantT2" else: raise ValueError("Unknown modality type.") if verbose == True: print("Brain extraction: retrieving model weights.") weights_file_name = get_pretrained_network(weights_file_name_prefix, antsxnet_cache_directory=antsxnet_cache_directory) if verbose == True: print("Brain extraction: retrieving template.") reorient_template_file_name_path = get_antsxnet_data("S_template3", antsxnet_cache_directory=antsxnet_cache_directory) reorient_template = ants.image_read(reorient_template_file_name_path) resampled_image_size = reorient_template.shape number_of_filters = (8, 16, 32, 64) mode = "classification" if modality == "t1": number_of_filters = (16, 32, 64, 128) number_of_classification_labels = 1 mode = "sigmoid" unet_model = create_unet_model_3d((*resampled_image_size, channel_size), number_of_outputs=number_of_classification_labels, mode=mode, number_of_filters=number_of_filters, dropout_rate=0.0, convolution_kernel_size=3, deconvolution_kernel_size=2, weight_decay=1e-5) unet_model.load_weights(weights_file_name) if verbose == True: print("Brain extraction: normalizing image to the template.") center_of_mass_template = ants.get_center_of_mass(reorient_template) center_of_mass_image = ants.get_center_of_mass(input_images[0]) translation = np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template) xfrm = ants.create_ants_transform(transform_type="Euler3DTransform", center=np.asarray(center_of_mass_template), translation=translation) batchX = np.zeros((1, *resampled_image_size, channel_size)) for i in range(len(input_images)): warped_image = ants.apply_ants_transform_to_image(xfrm, input_images[i], reorient_template) warped_array = warped_image.numpy() batchX[0,:,:,:,i] = (warped_array - warped_array.mean()) / warped_array.std() if verbose == True: print("Brain extraction: prediction and decoding.") predicted_data = unet_model.predict(batchX, verbose=verbose) probability_images_array = decode_unet(predicted_data, reorient_template) if verbose == True: print("Brain extraction: renormalize probability mask to native space.") xfrm_inv = xfrm.invert() probability_image = xfrm_inv.apply_to_image(probability_images_array[0][number_of_classification_labels-1], input_images[0]) return(probability_image) else: ##################### # # NoBrainer # ##################### if verbose == True: print("NoBrainer: generating network.") model = create_nobrainer_unet_model_3d((None, None, None, 1)) weights_file_name = get_pretrained_network("brainExtractionNoBrainer", antsxnet_cache_directory=antsxnet_cache_directory) model.load_weights(weights_file_name) if verbose == True: print("NoBrainer: preprocessing (intensity truncation and resampling).") image_array = image.numpy() image_robust_range = np.quantile(image_array[np.where(image_array != 0)], (0.02, 0.98)) threshold_value = 0.10 * (image_robust_range[1] - image_robust_range[0]) + image_robust_range[0] thresholded_mask = ants.threshold_image(image, -10000, threshold_value, 0, 1) thresholded_image = image * thresholded_mask image_resampled = ants.resample_image(thresholded_image, (256, 256, 256), use_voxels=True) image_array = np.expand_dims(image_resampled.numpy(), axis=0) image_array = np.expand_dims(image_array, axis=-1) if verbose == True: print("NoBrainer: predicting mask.") brain_mask_array = np.squeeze(model.predict(image_array, verbose=verbose)) brain_mask_resampled = ants.copy_image_info(image_resampled, ants.from_numpy(brain_mask_array)) brain_mask_image = ants.resample_image(brain_mask_resampled, image.shape, use_voxels=True, interp_type=1) spacing = ants.get_spacing(image) spacing_product = spacing[0] * spacing[1] * spacing[2] minimum_brain_volume = round(649933.7/spacing_product) brain_mask_labeled = ants.label_clusters(brain_mask_image, minimum_brain_volume) return(brain_mask_labeled)