antspymm.mm

    1__all__ = ['version',
    2    'mm_read',
    3    'mm_read_to_3d',
    4    'image_write_with_thumbnail',
    5    'nrg_format_path',
    6    'highest_quality_repeat',
    7    'match_modalities',
    8    'mc_resample_image_to_target',
    9    'nrg_filelist_to_dataframe',
   10    'merge_timeseries_data',
   11    'timeseries_reg',
   12    'merge_dwi_data',
   13    'outlierness_by_modality',
   14    'bvec_reorientation',
   15    'get_dti',
   16    'dti_reg',
   17    'mc_reg',
   18    'get_data',
   19    'get_models',
   20    'get_valid_modalities',
   21    'dewarp_imageset',
   22    'super_res_mcimage',
   23    'segment_timeseries_by_meanvalue',
   24    'get_average_rsf',
   25    'get_average_dwi_b0',
   26    'dti_template',
   27    't1_based_dwi_brain_extraction',
   28    'mc_denoise',
   29    'tsnr',
   30    'dvars',
   31    'slice_snr',
   32    'impute_fa',
   33    'trim_dti_mask',
   34    'dipy_dti_recon',
   35    'concat_dewarp',
   36    'joint_dti_recon',
   37    'middle_slice_snr',
   38    'foreground_background_snr',
   39    'quantile_snr',
   40    'mask_snr',
   41    'dwi_deterministic_tracking',
   42    'dwi_closest_peak_tracking',
   43    'dwi_streamline_pairwise_connectivity',
   44    'dwi_streamline_connectivity',
   45    'hierarchical_modality_summary',
   46    'tra_initializer',
   47    'neuromelanin',
   48    'resting_state_fmri_networks',
   49    'write_bvals_bvecs',
   50    'crop_mcimage',
   51    'mm',
   52    'write_mm',
   53    'mm_nrg',
   54    'mm_csv',
   55    'collect_blind_qc_by_modality',
   56    'alffmap',
   57    'alff_image',
   58    'down2iso',
   59    'read_mm_csv',
   60    'assemble_modality_specific_dataframes',
   61    'bind_wide_mm_csvs',
   62    'merge_mm_dataframe',
   63    'augment_image',
   64    'boot_wmh',
   65    'threaded_bind_wide_mm_csvs',
   66    'get_names_from_data_frame',
   67    'average_mm_df',
   68    'quick_viz_mm_nrg',
   69    'blind_image_assessment',
   70    'average_blind_qc_by_modality',
   71    'best_mmm',
   72    'nrg_2_bids',
   73    'bids_2_nrg',
   74    'parse_nrg_filename',
   75    'novelty_detection_svm',
   76    'novelty_detection_ee',
   77    'novelty_detection_lof',
   78    'novelty_detection_loop',
   79    'novelty_detection_quantile',
   80    'generate_mm_dataframe',
   81    'aggregate_antspymm_results',
   82    'aggregate_antspymm_results_sdf',
   83    'study_dataframe_from_matched_dataframe',
   84    'merge_wides_to_study_dataframe',
   85    'filter_image_files',
   86    'docsamson',
   87    'enantiomorphic_filling_without_mask',
   88    'wmh',
   89    'remove_elements_from_numpy_array',
   90    'score_fmri_censoring',
   91    'remove_volumes_from_timeseries',
   92    'loop_timeseries_censoring',
   93    'clean_tmp_directory',
   94    'validate_nrg_file_format',
   95    'ants_to_nibabel_affine',
   96    'dict_to_dataframe']
   97
   98from pathlib import Path
   99from pathlib import PurePath
  100import os
  101import pandas as pd
  102import math
  103import os.path
  104from os import path
  105import pickle
  106import sys
  107import numpy as np
  108import random
  109import functools
  110from operator import mul
  111from scipy.sparse.linalg import svds
  112from scipy.stats import pearsonr
  113import re
  114import datetime as dt
  115from collections import Counter
  116import tempfile
  117import uuid
  118import warnings
  119
  120from dipy.core.histeq import histeq
  121import dipy.reconst.dti as dti
  122from dipy.core.gradients import (gradient_table, gradient_table_from_gradient_strength_bvecs)
  123from dipy.io.gradients import read_bvals_bvecs
  124from dipy.segment.mask import median_otsu
  125from dipy.reconst.dti import fractional_anisotropy, color_fa
  126import nibabel as nib
  127
  128import ants
  129import antspynet
  130import antspyt1w
  131import siq
  132import tensorflow as tf
  133
  134from multiprocessing import Pool
  135import glob as glob
  136
  137antspyt1w.set_global_scientific_computing_random_seed(
  138    antspyt1w.get_global_scientific_computing_random_seed( )
  139)
  140
  141DATA_PATH = os.path.expanduser('~/.antspymm/')
  142
  143def version( ):
  144    """
  145    report versions of this package and primary dependencies
  146
  147    Arguments
  148    ---------
  149    None
  150
  151    Returns
  152    -------
  153    a dictionary with package name and versions
  154
  155    Example
  156    -------
  157    >>> import antspymm
  158    >>> antspymm.version()
  159    """
  160    import pkg_resources
  161    return {
  162              'tensorflow': pkg_resources.get_distribution("tensorflow").version,
  163              'antspyx': pkg_resources.get_distribution("antspyx").version,
  164              'antspynet': pkg_resources.get_distribution("antspynet").version,
  165              'antspyt1w': pkg_resources.get_distribution("antspyt1w").version,
  166              'antspymm': pkg_resources.get_distribution("antspymm").version
  167              }
  168
  169def nrg_filename_to_subjectvisit(s, separator='-'):
  170    """
  171    Extracts a pattern from the input string.
  172    
  173    Parameters:
  174    - s: The input string from which to extract the pattern.
  175    - separator: The separator used in the string (default is '-').
  176    
  177    Returns:
  178    - A string in the format of 'PREFIX-Number-Date'
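
    Example (illustrative; the path below is a hypothetical NRG-style filename):

    >>> nrg_filename_to_subjectvisit('/data/PPMI-100018-20210412-T1w-000.nii.gz')
    'PPMI-100018-20210412'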
  179    """
  180    parts = os.path.basename(s).split(separator)
    # Assuming the filename follows the pattern PREFIX-SubjectID-Date-...
    # (e.g. a PPMI-style NRG name), extract and rejoin the first three parts
  183    extracted = separator.join(parts[:3])
  184    return extracted
  185
  186
  187def validate_nrg_file_format(path, separator):
  188    """
  189    is your path nrg-etic?
  190    Validates if a given path conforms to the NRG file format, taking into account known extensions
  191    and the expected directory structure.
  192
  193    :param path: The file path to validate.
  194    :param separator: The separator used in the filename and directory structure.
  195    :return: A tuple (bool, str) indicating whether the path is valid and a message explaining the validation result.
  196
    :example:

    ntfn='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses-1/rsfMRI_RL/000/ANTsLIMBIC_sub08C105120Yr_ses-1_rsfMRI_RL_000.nii.gz'
    ntfngood='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses_1/rsfMRI_RL/000/ANTsLIMBIC-sub08C105120Yr-ses_1-rsfMRI_RL-000.nii.gz'

    validate_nrg_file_format(ntfngood, '-')
    print( validate_nrg_file_format(ntfn, '-') )
    print( validate_nrg_file_format(ntfn, '_') )
  205
  206    """
  207    import re    
  208
  209    def normalize_path(path):
  210        """
  211        Replace multiple repeated '/' with just a single '/'
  212        
  213        :param path: The file path to normalize.
  214        :return: The normalized file path with single '/'.
  215        """
  216        normalized_path = re.sub(r'/+', '/', path)
  217        return normalized_path
  218
  219    def strip_known_extension(filename, known_extensions):
  220        """
  221        Strips a known extension from the filename.
  222
  223        :param filename: The filename from which to strip the extension.
  224        :param known_extensions: A list of known extensions to strip from the filename.
  225        :return: The filename with the known extension stripped off, if found.
  226        """
  227        for ext in known_extensions:
  228            if filename.endswith(ext):
  229                # Strip the extension and return the modified filename
  230                return filename[:-len(ext)]
  231        # If no known extension is found, return the original filename
  232        return filename
  233
  234    import warnings
  235    if normalize_path( path ) != path:
  236        path = normalize_path( path )
        warnings.warn("The file path contained repeated slashes (e.g. ///); it was normalized with re.sub(r'/+', '/', path), but this may indicate an upstream issue.")
  238
  239    known_extensions = [".nii.gz", ".nii", ".mhd", ".nrrd", ".mha", ".json", ".bval", ".bvec"]
  240    known_extensions2 = [ext.lstrip('.') for ext in known_extensions]
  241    def get_extension(filename, known_extensions ):
  242        # List of known extensions in priority order
  243        for ext in known_extensions:
  244            if filename.endswith(ext):
  245                return ext.strip('.')
  246        return "Invalid extension"
  247    
  248    parts = path.split('/')
  249    if len(parts) < 7:  # Checking for minimum path structure
  250        return False, "Path structure is incomplete. Expected at least 7 components, found {}.".format(len(parts))
  251    
  252    # Extract directory components and filename
  253    directory_components = parts[1:-1]  # Exclude the root '/' and filename
  254    filename = parts[-1]
  255    filename_without_extension = strip_known_extension( filename, known_extensions )
  256    file_extension = get_extension( filename, known_extensions )
  257    
  258    # Validating file extension
  259    if file_extension not in known_extensions2:
  260        print( file_extension )
        return False, "Invalid file extension: {}. Expected one of: {}.".format(file_extension, ", ".join(known_extensions))
  262    
  263    # Splitting the filename to validate individual parts
  264    filename_parts = filename_without_extension.split(separator)
  265    if len(filename_parts) != 5:  # Expecting 5 parts based on the NRG format
  266        print( filename_parts )
  267        return False, "Filename does not have exactly 5 parts separated by '{}'. Found {} parts.".format(separator, len(filename_parts))
  268    
  269    # Reconstruct expected filename from directory components
  270    expected_filename_parts = directory_components[-5:]
  271    expected_filename = separator.join(expected_filename_parts)
  272    if filename_without_extension != expected_filename:
  273        print( filename_without_extension )
  274        print("--- vs expected ---")
  275        print( expected_filename )
  276        return False, "Filename structure does not match directory structure. Expected filename: {}.".format(expected_filename)
  277    
  278    # Validate directory structure against NRG format
    study_name, subject_id, session, modality = directory_components[-5:-1]
  280    if not all([study_name, subject_id, session, modality]):
  281        return False, "Directory structure does not follow NRG format. Ensure StudyName, SubjectID, Session (ses_x), and Modality are correctly specified."
  282    
  283    # If all checks pass
  284    return True, "The path conforms to the NRG format."
  285
  286
  287
  288def apply_transforms_mixed_interpolation(
  289    fixed,
  290    moving,
  291    transformlist,
  292    interpolator="linear",
  293    imagetype=0,
  294    whichtoinvert=None,
  295    mask=None,
  296    **kwargs
  297):
  298    """
  299    Apply ANTs transforms with mixed interpolation:
  300    - Linear interpolation inside `mask`
  301    - Nearest neighbor outside `mask`
  302
  303    Parameters
  304    ----------
  305    fixed : ANTsImage
  306        Fixed/reference image to define spatial domain.
  307
  308    moving : ANTsImage
  309        Moving image to be transformed.
  310
  311    transformlist : list of str
  312        List of filenames for transforms.
  313
  314    interpolator : str, optional
  315        Interpolator used inside the mask. Default is "linear".
  316
  317    imagetype : int
  318        Image type used by ANTs (0 = scalar, 1 = vector, etc.)
  319
  320    whichtoinvert : list of bool, optional
  321        List of booleans indicating which transforms to invert.
  322
  323    mask : ANTsImage
  324        Binary mask image indicating where to apply `interpolator` (e.g., "linear").
  325        Outside the mask, nearest neighbor is used.
  326
  327    kwargs : dict
  328        Additional arguments passed to `ants.apply_transforms`.
  329
  330    Returns
  331    -------
  332    ANTsImage
        Interpolated image combining `interpolator` inside the mask with nearest-neighbor interpolation outside it.
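
    Example
    -------
    A minimal sketch; the registration call and mask below are illustrative and
    not a required workflow for this helper.

    >>> import ants
    >>> fixed = ants.image_read(ants.get_ants_data('r16'))
    >>> moving = ants.image_read(ants.get_ants_data('r64'))
    >>> reg = ants.registration(fixed, moving, 'SyN')
    >>> mask = ants.get_mask(fixed)
    >>> mixed = apply_transforms_mixed_interpolation(
    ...     fixed, moving, reg['fwdtransforms'], mask=mask)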
  334    """
  335    if mask is None:
  336        raise ValueError("A binary `mask` image must be provided.")
  337
  338    # Apply linear interpolation inside the mask
  339    interp_linear = ants.apply_transforms(
  340        fixed=fixed,
  341        moving=moving,
  342        transformlist=transformlist,
  343        interpolator=interpolator,
  344        imagetype=imagetype,
  345        whichtoinvert=whichtoinvert,
  346        **kwargs
  347    )
  348
  349    # Apply nearest-neighbor interpolation everywhere
  350    interp_nn = ants.apply_transforms(
  351        fixed=fixed,
  352        moving=moving,
  353        transformlist=transformlist,
  354        interpolator="nearestNeighbor",
  355        imagetype=imagetype,
  356        whichtoinvert=whichtoinvert,
  357        **kwargs
  358    )
  359
  360    # Combine: linear * mask + nn * (1 - mask)
  361    mixed_result = (interp_linear * mask) + (interp_nn * (1 - mask))
  362
  363    return mixed_result
  364
  365def get_antsimage_keys(dictionary):
  366    """
  367    Return the keys of the dictionary where the values are ANTsImages.
  368
  369    :param dictionary: A dictionary to inspect
  370    :return: A list of keys for which the values are ANTsImages
  371    """
  372    return [key for key, value in dictionary.items() if isinstance(value, ants.core.ants_image.ANTsImage)]
  373
  374def to_nibabel(img: "ants.core.ants_image.ANTsImage") -> nib.Nifti1Image:
  375    """
  376    Convert an ANTsPy image to a Nibabel Nifti1Image in-memory, using correct spatial affine.
  377
  378    Parameters:
  379        img (ants.ANTsImage): An image from ANTsPy.
  380
  381    Returns:
  382        nib.Nifti1Image: The corresponding Nibabel image with spatial orientation in RAS.
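
    Example (illustrative; any ANTsImage works):

    >>> import ants
    >>> img = ants.image_read(ants.get_ants_data('r16'))
    >>> nii = to_nibabel(img)
    >>> nii.affine.shape
    (4, 4)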
  383    """
  384    array_data = img.numpy()  # get voxel data as NumPy array
  385    affine = ants_to_nibabel_affine(img)
  386    return nib.Nifti1Image(array_data, affine)
  387
  388def ants_to_nibabel_affine(ants_img):
  389    """
  390    Convert an ANTsPy image (in LPS space) to a Nibabel-compatible affine (in RAS space).
  391    Handles 2D, 3D, 4D input (only spatial dimensions are encoded in the affine).
  392    
  393    Returns:
  394        4x4 np.ndarray affine matrix in RAS space.
  395    """
  396    spatial_dim = ants_img.dimension
  397    spacing = np.array(ants_img.spacing)
  398    origin = np.array(ants_img.origin)
  399    direction = np.array(ants_img.direction).reshape((spatial_dim, spatial_dim))
  400    # Compute rotation-scale matrix
  401    affine_linear = direction @ np.diag(spacing)
  402    # Build full 4x4 affine with identity in homogeneous bottom row
  403    affine = np.eye(4)
  404    affine[:spatial_dim, :spatial_dim] = affine_linear
  405    affine[:spatial_dim, 3] = origin
  406    affine[3, 3]=1
  407    # Convert LPS -> RAS by flipping x and y
  408    lps_to_ras = np.diag([-1, -1, 1, 1])
  409    affine = lps_to_ras @ affine
  410    return affine
  411
  412
  413def dict_to_dataframe(data_dict, convert_lists=True, convert_arrays=True, convert_images=True, verbose=False):
  414    """
  415    Convert a dictionary to a pandas DataFrame, excluding items that cannot be processed by pandas.
  416
  417    :param data_dict: Dictionary to be converted.
    :param convert_lists: boolean; if True, numeric lists are summarized by their mean in a new '<key>_mean' column.
    :param convert_arrays: boolean; if True, numpy arrays are summarized by their mean in a new '<key>_mean' column.
    :param convert_images: boolean; if True, ANTsImages are summarized by their mean intensity in a new '<key>_mean' column.
    :param verbose: boolean
  422    :return: DataFrame representation of the dictionary.
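
    Example (illustrative toy dictionary):

    >>> d = {'subid': 'sub-001', 'snr': 12.5, 'fd': [0.1, 0.2, 0.3]}
    >>> df = dict_to_dataframe(d)
    >>> sorted(df.columns)
    ['fd_mean', 'snr', 'subid']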
  423    """
  424    processed_data = {}
  425    list_length = None
  426    def mean_of_list(lst):
        if not lst:  # check for an empty list
            return 0  # return 0 for an empty list
  429        all_numeric = all(isinstance(item, (int, float)) for item in lst)
  430        if all_numeric:
  431            return sum(lst) / len(lst)
  432        return None
  433    
  434    for key, value in data_dict.items():
  435        # Check if value is a scalar
  436        if isinstance(value, (int, float, str, bool)):
  437            processed_data[key] = [value]
  438        # Check if value is a list of scalars
  439        elif isinstance(value, list) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_lists:
  440            meanvalue = mean_of_list( value )
  441            newkey = key+"_mean"
  442            if verbose:
  443                print( " Key " + key + " is list with mean " + str(meanvalue) + " to " + newkey )
  444            if newkey not in data_dict.keys() and convert_lists:
  445                processed_data[newkey] = meanvalue
  446        elif isinstance(value, np.ndarray) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_arrays:
  447            meanvalue = value.mean()
  448            newkey = key+"_mean"
  449            if verbose:
  450                print( " Key " + key + " is nparray with mean " + str(meanvalue) + " to " + newkey )
  451            if newkey not in data_dict.keys():
  452                processed_data[newkey] = meanvalue
  453        elif isinstance(value, ants.core.ants_image.ANTsImage ) and convert_images:
  454            meanvalue = value.mean()
  455            newkey = key+"_mean"
  456            if newkey not in data_dict.keys():
  457                if verbose:
  458                    print( " Key " + key + " is antsimage with mean " + str(meanvalue) + " to " + newkey )
  459                processed_data[newkey] = meanvalue
  460            else:
  461                if verbose:
  462                    print( " Key " + key + " is antsimage with mean " + str(meanvalue) + " but " + newkey + " already exists" )
  463
  464    return pd.DataFrame.from_dict(processed_data)
  465
  466
  467def clean_tmp_directory(age_hours=1., use_sudo=False, extensions=[ '.nii', '.nii.gz' ], log_file_path=None):
  468    """
  469    Clean the /tmp directory by removing files and directories older than a certain number of hours.
  470    Optionally uses sudo and can filter files by extensions.
  471
  472    :param age_hours: Age in hours to consider files and directories for deletion.
  473    :param use_sudo: Whether to use sudo for removal commands.
  474    :param extensions: List of file extensions to delete. If None, all files are considered.
    :param log_file_path: Path to the log file. If None, no log is written.
  476
  477    # Usage
  478    # Example: clean_tmp_directory(age_hours=1, use_sudo=True, extensions=['.log', '.tmp'])
  479    """
  480    import os
  481    import platform
  482    import subprocess
  483    from datetime import datetime, timedelta
  484
    if not isinstance(age_hours, (int, float)):
        return
  487
    # Use the system temporary directory (fixed here to /tmp)
  489    tmp_dir = '/tmp'
  490
  491    # Set the log file path
  492    if log_file_path is not None:
  493        log_file = log_file_path
  494
  495    current_time = datetime.now()
  496    for item in os.listdir(tmp_dir):
  497        try:
  498            item_path = os.path.join(tmp_dir, item)
  499            item_stat = os.stat(item_path)
  500
  501            # Calculate the age of the file/directory
  502            item_age = current_time - datetime.fromtimestamp(item_stat.st_mtime)
  503            if item_age > timedelta(hours=age_hours):
  504                # Check for file extensions if provided
  505                if extensions is None or any(item.endswith(ext) for ext in extensions):
  506                    # Construct the removal command
  507                    rm_command = ['sudo', 'rm', '-rf', item_path] if use_sudo else ['rm', '-rf', item_path]
  508                    subprocess.run(rm_command)
  509
  510                if log_file_path is not None:
  511                    with open(log_file, 'a') as log:
  512                        log.write(f"{datetime.now()}: Deleted {item_path}\n")
  513        except Exception as e:
  514            if log_file_path is not None:
  515                with open(log_file, 'a') as log:
  516                    log.write(f"{datetime.now()}: Error deleting {item_path}: {e}\n")
  517
  518
  519
  520def docsamson(locmod, studycsv, outputdir, projid, sid, dtid, mysep, t1iid=None, verbose=True):
  521    """
  522    Processes image file names based on the specified imaging modality and other parameters.
  523
  524    The function selects file names from the provided dictionary `studycsv` based on the imaging modality.
  525    It supports various modalities like T1w, T2Flair, perf, NM2DMT, rsfMRI, DTI, and configures the filenames accordingly.
  526    The function can optionally print verbose output during processing.
  527
  528    Parameters:
  529    locmod (str): The imaging modality. Options include 'T1w', 'T2Flair', 'perf', 'NM2DMT', 'rsfMRI', 'DTI'.
  530    studycsv (dict): A dictionary with keys corresponding to imaging modalities and values as file names.
  531    outputdir (str): Base directory for output files.
  532    projid (str): Project identifier.
  533    sid (str): Subject identifier.
  534    dtid (str): Data acquisition time identifier.
  535    mysep (str): Separator used in file naming.
  536    t1iid (str, optional): Identifier related to T1-weighted images, used in naming output files when locmod is not 'T1w'.
  537    verbose (bool, optional): If True, prints detailed information during execution.
  538
  539    Returns:
  540    dict: A dictionary with keys 'modality', 'outprefix', and 'images'.
  541        - 'modality' (str): The imaging modality used.
  542        - 'outprefix' (str): The prefix for output file paths.
  543        - 'images' (list): A list of processed image file names.
  544
  545    Notes:
  546    - The function is designed to work within a specific workflow and might require adaptation for general use.
  547
  548    Examples:
  549    >>> result = docsamson('T1w', studycsv, outputdir, projid, sid, dtid, mysep)
  550    >>> print(result['modality'])
  551    'T1w'
  552    >>> print(result['outprefix'])
  553    '/path/to/output/directory/T1w/some_identifier'
  554    >>> print(result['images'])
  555    ['image1.nii', 'image2.nii']
  556    """
  557
  558    import os
  559    import re
  560
  561    myimgsInput = []
  562    myoutputPrefix = None
  563    imfns = ['filename', 'rsfid1', 'rsfid2', 'dtid1', 'dtid2', 'flairid']
  564    
  565    # Define image file names based on the modality
  566    if locmod == 'T1w':
  567        imfns=['filename']
  568    elif locmod == 'T2Flair':
  569        imfns=['flairid']
  570    elif locmod == 'perf':
  571        imfns=['perfid']
  572    elif locmod == 'pet3d':
  573        imfns=['pet3did']
  574    elif locmod == 'NM2DMT':
  575        imfns=[]
  576        for i in range(11):
  577            imfns.append('nmid' + str(i))
  578    elif locmod == 'rsfMRI':
  579        imfns=[]
  580        for i in range(4):
  581            imfns.append('rsfid' + str(i))
  582    elif locmod == 'DTI':
  583        imfns=[]
  584        for i in range(4):
  585            imfns.append('dtid' + str(i))
  586    else:
  587        raise ValueError("docsamson: no match of modality to filename id " + locmod )
  588
  589    # Process each file name
  590    for i in imfns:
  591        if verbose:
  592            print(i + " " + locmod)
  593        if i in studycsv.keys():
  594            fni = str(studycsv[i].iloc[0])
  595            if verbose:
  596                print(i + " " + fni + ' exists ' + str(os.path.exists(fni)))
  597            if os.path.exists(fni):
  598                myimgsInput.append(fni)
  599                temp = os.path.basename(fni)
  600                mysplit = temp.split(mysep)
  601                iid = re.sub(".nii.gz", "", mysplit[-1])
  602                iid = re.sub(".mha", "", iid)
  603                iid = re.sub(".nii", "", iid)
  604                iid2 = iid
  605                if locmod != 'T1w' and t1iid is not None:
  606                    iid2 = iid + "_" + t1iid
  607                else:
  608                    iid2 = t1iid
  609                myoutputPrefix = os.path.join(outputdir, projid, sid, dtid, locmod, iid, projid + mysep + sid + mysep + dtid + mysep + locmod + mysep + iid2)
  610    
  611    if verbose:
  612        print(locmod)
  613        print(myimgsInput)
  614        print(myoutputPrefix)
  615    
  616    return {
  617        'modality': locmod,
  618        'outprefix': myoutputPrefix,
  619        'images': myimgsInput
  620    }
  621
  622
  623def get_valid_modalities( long=False, asString=False, qc=False ):
  624    """
  625    return a list of valid modality identifiers used in NRG modality designation
  626    and that can be processed by this package.
  627
  628    long - return the long version
  629
  630    asString - concat list to string
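
    Example:

    >>> mods = get_valid_modalities()
    >>> 'T1w' in mods
    True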
  631    """
  632    if long:
  633        mymod = ["T1w", "NM2DMT", "rsfMRI", "rsfMRI_LR", "rsfMRI_RL", "rsfMRILR", "rsfMRIRL", "DTI", "DTI_LR","DTI_RL",  "DTILR","DTIRL","T2Flair", "dwi", "dwi_ap", "dwi_pa", "func", "func_ap", "func_pa", "perf", 'pet3d']
  634    elif qc:
  635        mymod = [ 'T1w', 'T2Flair', 'NM2DMT', 'DTI', 'DTIdwi','DTIb0', 'rsfMRI', "perf", 'pet3d' ]
  636    else:
  637        mymod = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI", "perf", 'pet3d' ]
  638    if not asString:
  639        return mymod
  640    else:
  641        mymodchar=""
  642        for x in mymod:
  643            mymodchar = mymodchar + " " + str(x)
  644        return mymodchar
  645
  646
  647def generate_mm_dataframe(
  648        projectID,
  649        subjectID,
  650        date,
  651        imageUniqueID,
  652        modality,
  653        source_image_directory,
  654        output_image_directory,
  655        t1_filename,
  656        flair_filename=[],
  657        rsf_filenames=[],
  658        dti_filenames=[],
  659        nm_filenames=[],
  660        perf_filename=[],
  661        pet3d_filename=[],
  662):
  663    """
  664    Generate a DataFrame for medical imaging data with extensive validation of input parameters.
  665
  666    This function creates a DataFrame containing information about medical imaging files,
  667    ensuring that filenames match expected patterns for their modalities and that all
  668    required images exist. It also validates the number of filenames provided for specific
  669    modalities like rsfMRI, DTI, and NM.
  670
  671    Parameters:
  672    - projectID (str): Project identifier.
  673    - subjectID (str): Subject identifier.
  674    - date (str): Date of the imaging study.
  675    - imageUniqueID (str): Unique image identifier.
  676    - modality (str): Modality of the imaging study.
  677    - source_image_directory (str): Directory of the source images.
  678    - output_image_directory (str): Directory for output images.
  679    - t1_filename (str): Filename of the T1-weighted image.
  680    - flair_filename (list): List of filenames for FLAIR images.
  681    - rsf_filenames (list): List of filenames for rsfMRI images.
  682    - dti_filenames (list): List of filenames for DTI images.
  683    - nm_filenames (list): List of filenames for NM images.
  684    - perf_filename (list): List of filenames for perfusion images.
  685    - pet3d_filename (list): List of filenames for pet3d images.
  686
  687    Returns:
  688    - pandas.DataFrame: A DataFrame containing the validated imaging study information.
  689
  690    Raises:
  691    - ValueError: If any validation checks fail or if the number of columns does not match the data.
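
    Example:
    A hypothetical call; the identifiers and paths below are placeholders and
    must point to existing NRG-organized files for the validation checks to pass.

    >>> studycsv = generate_mm_dataframe(
    ...     'PPMI', '100018', '20210412', '000', 'T1w',
    ...     '/data/NRG/', '/data/processed/',
    ...     t1_filename='/data/NRG/PPMI/100018/20210412/T1w/000/PPMI-100018-20210412-T1w-000.nii.gz')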
  692    """
  693    def check_pd_construction(data, columns):
  694        # Check if the length of columns matches the length of data in each row
  695        if all(len(row) == len(columns) for row in data):
  696            return True
  697        else:
  698            return False
  699    from os.path import exists
  700    valid_modalities = get_valid_modalities()
  701    if not isinstance(t1_filename, str):
  702        raise ValueError("t1_filename is not a string")
  703    if not exists(t1_filename):
  704        raise ValueError("t1_filename does not exist")
  705    if modality not in valid_modalities:
  706        raise ValueError('modality ' + str(modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
  707    # if not exists( output_image_directory ):
  708    #    raise ValueError("output_image_directory does not exist")
  709    if not exists( source_image_directory ):
  710        raise ValueError("source_image_directory does not exist")
  711    if len( rsf_filenames ) > 2:
  712        raise ValueError("len( rsf_filenames ) > 2")
  713    if len( dti_filenames ) > 3:
  714        raise ValueError("len( dti_filenames ) > 3")
  715    if len( nm_filenames ) > 11:
  716        raise ValueError("len( nm_filenames ) > 11")
  717    if len( rsf_filenames ) < 2:
  718        for k in range(len(rsf_filenames),2):
  719            rsf_filenames.append(None)
  720    if len( dti_filenames ) < 3:
  721        for k in range(len(dti_filenames),3):
  722            dti_filenames.append(None)
  723    if len( nm_filenames ) < 10:
  724        for k in range(len(nm_filenames),10):
  725            nm_filenames.append(None)
  726    # check modality names
  727    if not "T1w" in t1_filename:
  728        raise ValueError("T1w is not in t1 filename " + t1_filename)
  729    if flair_filename is not None:
  730        if isinstance(flair_filename,list):
  731            if (len(flair_filename) == 0):
  732                flair_filename=None
  733            else:
  734                print("Take first entry from flair_filename list")
  735                flair_filename=flair_filename[0]
  736    if flair_filename is not None and not "lair" in flair_filename:
            raise ValueError("flair is not in flair filename " + flair_filename)
  738    ## perfusion
  739    if perf_filename is not None:
  740        if isinstance(perf_filename,list):
  741            if (len(perf_filename) == 0):
  742                perf_filename=None
  743            else:
  744                print("Take first entry from perf_filename list")
  745                perf_filename=perf_filename[0]
  746    if perf_filename is not None and not "perf" in perf_filename:
  747            raise ValueError("perf_filename is not perf filename " + perf_filename)
  748
  749    if pet3d_filename is not None:
  750        if isinstance(pet3d_filename,list):
  751            if (len(pet3d_filename) == 0):
  752                pet3d_filename=None
  753            else:
  754                print("Take first entry from pet3d_filename list")
  755                pet3d_filename=pet3d_filename[0]
  756    if pet3d_filename is not None and not "pet" in pet3d_filename:
  757            raise ValueError("pet3d_filename is not pet filename " + pet3d_filename)
  758    
  759    for k in nm_filenames:
  760        if k is not None:
  761            if not "NM" in k:
                raise ValueError("NM is not in NM filename " + k)
  763    for k in dti_filenames:
  764        if k is not None:
  765            if not "DTI" in k and not "dwi" in k:
  766                raise ValueError("DTI/DWI is not dti filename " + k)
  767    for k in rsf_filenames:
  768        if k is not None:
  769            if not "fMRI" in k and not "func" in k:
  770                raise ValueError("rsfMRI/func is not rsfmri filename " + k)
    if perf_filename is not None:
        if not "perf" in perf_filename:
            raise ValueError("perf_filename is not a valid perfusion (perf) filename " + perf_filename)
  774    allfns = [t1_filename] + [flair_filename] + nm_filenames + dti_filenames + rsf_filenames + [perf_filename] + [pet3d_filename]
  775    for k in allfns:
  776        if k is not None:
  777            if not isinstance(k, str):
  778                raise ValueError(str(k) + " is not a string")
  779            if not exists( k ):
  780                raise ValueError( "image " + k + " does not exist")
  781    coredata = [
  782        projectID,
  783        subjectID,
  784        date,
  785        imageUniqueID,
  786        modality,
  787        source_image_directory,
  788        output_image_directory,
  789        t1_filename,
  790        flair_filename, 
  791        perf_filename,
  792        pet3d_filename]
  793    mydata0 = coredata +  rsf_filenames + dti_filenames
  794    mydata = mydata0 + nm_filenames
  795    corecols = [
  796        'projectID',
  797        'subjectID',
  798        'date',
  799        'imageID',
  800        'modality',
  801        'sourcedir',
  802        'outputdir',
  803        'filename',
  804        'flairid',
  805        'perfid',
  806        'pet3did']
  807    mycols0 = corecols + [
  808        'rsfid1', 'rsfid2',
  809        'dtid1', 'dtid2','dtid3']
  810    nmext = [
  811        'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5',
  812        'nmid6', 'nmid7','nmid8', 'nmid9', 'nmid10' #, 'nmid11'
  813    ]
  814    mycols = mycols0 + nmext
  815    if not check_pd_construction( [mydata], mycols ) :
  816#        print( mydata )
  817#        print( len(mydata ))
  818#        print( mycols )
  819#        print( len(mycols ))
  820        raise ValueError( "Error in generate_mm_dataframe: len( mycols ) != len( mydata ) which indicates a bad input parameter to this function." )
  821    studycsv = pd.DataFrame([ mydata ], columns=mycols)
  822    return studycsv
  823
  824import pandas as pd
  825from os.path import exists
  826
  827def validate_filename(filename, valid_keywords, error_message):
  828    """
  829    Validate if the given filename contains any of the specified keywords.
  830
  831    Parameters:
  832    - filename (str): The filename to validate.
  833    - valid_keywords (list): A list of keywords to look for in the filename.
  834    - error_message (str): The error message to raise if validation fails.
  835
  836    Raises:
  837    - ValueError: If none of the keywords are found in the filename.
  838    """
  839    if filename is not None and not any(keyword in filename for keyword in valid_keywords):
  840        raise ValueError(error_message)
  841
  842def validate_modality(modality, valid_modalities):
  843    if modality not in valid_modalities:
  844        valid_modalities_str = ', '.join(valid_modalities)
  845        raise ValueError(f'Modality {modality} not a valid mm modality: {valid_modalities_str}')
  846
  847def extend_list_to_length(lst, target_length, fill_value=None):
  848    return lst + [fill_value] * (target_length - len(lst))
  849
  850def generate_mm_dataframe_gpt(
  851        projectID, subjectID, date, imageUniqueID, modality, 
  852        source_image_directory, output_image_directory, t1_filename, 
  853        flair_filename=[], rsf_filenames=[], dti_filenames=[], nm_filenames=[], perf_filename=[] ):
  854    """
  855    see help for generate_mm_dataframe - same as this
  856    """
  857    def check_pd_construction(data, columns):
  858        return all(len(row) == len(columns) for row in data)
  859
  860    flair_filename.sort()
  861    rsf_filenames.sort()
  862    dti_filenames.sort()
  863    nm_filenames.sort()
  864    perf_filename.sort()
  865
  866    valid_modalities = get_valid_modalities()  
  867
  868    if not isinstance(t1_filename, str):
  869        raise ValueError("t1_filename is not a string")
  870    if not exists(t1_filename):
  871        raise ValueError("t1_filename does not exist")
  872
  873    validate_modality(modality, valid_modalities)
  874
  875    if not exists(source_image_directory):
  876        raise ValueError("source_image_directory does not exist")
  877
  878    rsf_filenames = extend_list_to_length(rsf_filenames, 2)
  879    dti_filenames = extend_list_to_length(dti_filenames, 2)
  880    nm_filenames = extend_list_to_length(nm_filenames, 11)
  881
  882    validate_filename(t1_filename, ["T1w"], "T1w is not in t1 filename " + t1_filename)
  883
  884    if flair_filename:
  885        flair_filename = flair_filename[0] if isinstance(flair_filename, list) else flair_filename
  886        validate_filename(flair_filename, ["lair"], "flair is not in flair filename " + flair_filename)
  887
  888    if perf_filename:
  889        perf_filename = perf_filename[0] if isinstance(perf_filename, list) else perf_filename
  890        validate_filename(perf_filename, ["perf"], "perf_filename is not a valid perfusion (perf) filename")
  891
  892    for k in nm_filenames:
  893        if k: validate_filename(k, ["NM"], "NM is not in NM filename " + k)
  894
  895    for k in dti_filenames:
  896        if k: validate_filename(k, ["DTI","dwi"], "DTI or dwi is not in DTI filename " + k)
  897
  898    for k in rsf_filenames:
  899        if k: validate_filename(k, ["fMRI","func"], "rsfMRI or func is not in rsfMRI filename " + k)
  900
  901    allfns = [t1_filename, flair_filename] + nm_filenames + dti_filenames + rsf_filenames + [perf_filename]
  902    for k in allfns:
  903        if k and not exists(k):
  904            raise ValueError("image " + k + " does not exist")
  905
  906    coredata = [projectID, subjectID, date, imageUniqueID, modality,
  907                source_image_directory, output_image_directory, t1_filename, 
  908                flair_filename, perf_filename]
  909    mydata0 = coredata + rsf_filenames + dti_filenames
  910    mydata = mydata0 + nm_filenames
  911
  912    corecols = ['projectID', 'subjectID', 'date', 'imageID', 'modality',
  913                'sourcedir', 'outputdir', 'filename', 'flairid', 'perfid']
  914    mycols0 = corecols + ['rsfid1', 'rsfid2', 'dtid1', 'dtid2']
  915    nmext = ['nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5',
  916             'nmid6', 'nmid7', 'nmid8', 'nmid9', 'nmid10', 'nmid11']
  917    mycols = mycols0 + nmext
  918
  919    if not check_pd_construction([mydata], mycols):
  920        print( mydata )
  921        print( mycols )
  922        raise ValueError("Error in generate_mm_dataframe: len(mycols) != len(mydata), indicating bad input parameters.")
  923
  924    studycsv = pd.DataFrame([mydata], columns=mycols)
  925    return studycsv
  926
  927
  928
  929def filter_columns_by_nan_percentage(df, max_nan_percentage=50.0):
  930    """
  931    Filter columns in a DataFrame based on a threshold for the percentage of NaN values.
  932
  933    Parameters
  934    ----------
  935    df : pandas.DataFrame
  936        The input DataFrame from which columns are to be filtered.
  937    max_nan_percentage : float, optional
  938        The maximum allowed percentage of NaN values in a column. Columns with a higher
  939        percentage of NaN values than this threshold will be removed from the DataFrame.
  940        The default is 50.0, which means columns with more than 50% NaN values will be removed.
  941
  942    Returns
  943    -------
  944    pandas.DataFrame
  945        A DataFrame with columns filtered based on the NaN values percentage criterion.
  946
  947    Examples
  948    --------
  949    >>> import pandas as pd
  950    >>> data = {'A': [1, 2, None, 4], 'B': [None, 2, 3, None], 'C': [1, 2, 3, 4]}
  951    >>> df = pd.DataFrame(data)
  952    >>> filtered_df = filter_columns_by_nan_percentage(df, 50.0)
  953    >>> print(filtered_df)
  954
  955    Notes
  956    -----
  957    The function calculates the percentage of NaN values in each column and filters out
  958    those columns where the percentage exceeds the `max_nan_percentage` threshold.
  959    """
  960    # Calculate the percentage of NaN values for each column
  961    nan_percentage = df.isnull().mean() * 100
  962
  963    # Filter columns where the percentage of NaN values is less than or equal to the threshold
  964    columns_to_keep = nan_percentage[nan_percentage <= max_nan_percentage].index
  965
  966    # Return the filtered DataFrame
  967    return df[columns_to_keep]
  968
  969
  970
  971def parse_nrg_filename( x, separator='-' ):
  972    """
  973    split a NRG filename into its named parts
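
    Example (illustrative; note the image ID must not include the file extension):

    >>> parse_nrg_filename('PPMI-100018-20210412-T1w-000')
    {'project': 'PPMI', 'subjectID': '100018', 'date': '20210412', 'modality': 'T1w', 'imageID': '000'}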
  974    """
  975    temp = x.split( separator )
  976    if len(temp) != 5:
  977        raise ValueError(x + " not a valid NRG filename")
  978    return {
  979        'project':temp[0],
  980        'subjectID':temp[1],
  981        'date':temp[2],
  982        'modality':temp[3],
  983        'imageID':temp[4]
  984    }
  985
  986
  987
  988def nrg_2_bids( nrg_filename ):
  989    """
  990    Convert an NRG filename to BIDS path/filename.
  991
  992    Parameters:
  993    nrg_filename (str): The NRG filename to convert.
  994
  995    Returns:
  996    str: The BIDS path/filename.
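
    Example (illustrative; the input is a hypothetical NRG-format path):

    >>> nrg_2_bids('/data/PPMI/100018/20210412/T1w/000/PPMI-100018-20210412-T1w-000.nii.gz')
    '/data/PPMI/100018/20210412/T1w/000/bids/sub-100018/ses-000/anat/sub-100018_ses-000_T1w.nii.gz'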
  997    """
  998
  999    # Split the NRG filename into its components
 1000    nrg_dirname, nrg_basename = os.path.split(nrg_filename)
 1001    nrg_suffix = '.' + nrg_basename.split('.',1)[-1]
 1002    nrg_basename = nrg_basename.replace(nrg_suffix, '') # remove ext
 1003    nrg_parts = nrg_basename.split('-')
 1004    nrg_subject_id = nrg_parts[1]
 1005    nrg_modality = nrg_parts[3]
 1006    nrg_repeat= nrg_parts[4]
 1007
 1008    # Build the BIDS path/filename
 1009    bids_dirname = os.path.join(nrg_dirname, 'bids')
 1010    bids_subject = f'sub-{nrg_subject_id}'
 1011    bids_session = f'ses-{nrg_repeat}'
 1012
 1013    valid_modalities = get_valid_modalities()
 1014    if nrg_modality is not None:
 1015        if not nrg_modality in valid_modalities:
 1016            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
 1017
 1018    if nrg_modality == 'T1w' :
 1019        bids_modality_folder = 'anat'
 1020        bids_modality_filename = 'T1w'
 1021
 1022    if nrg_modality == 'T2Flair' :
 1023        bids_modality_folder = 'anat'
 1024        bids_modality_filename = 'flair'
 1025
 1026    if nrg_modality == 'NM2DMT' :
 1027        bids_modality_folder = 'anat'
 1028        bids_modality_filename = 'nm2dmt'
 1029
 1030    if nrg_modality == 'DTI' or nrg_modality == 'DTI_RL' or nrg_modality == 'DTI_LR' :
 1031        bids_modality_folder = 'dwi'
 1032        bids_modality_filename = 'dwi'
 1033
 1034    if nrg_modality == 'rsfMRI' or nrg_modality == 'rsfMRI_RL' or nrg_modality == 'rsfMRI_LR' :
 1035        bids_modality_folder = 'func'
 1036        bids_modality_filename = 'func'
 1037
 1038    if nrg_modality == 'perf'  :
 1039        bids_modality_folder = 'perf'
 1040        bids_modality_filename = 'perf'
 1041
 1042    bids_suffix = nrg_suffix[1:]
 1043    bids_filename = f'{bids_subject}_{bids_session}_{bids_modality_filename}.{bids_suffix}'
 1044
 1045    # Return bids filepath/filename
 1046    return os.path.join(bids_dirname, bids_subject, bids_session, bids_modality_folder, bids_filename)
 1047
 1048
 1049def bids_2_nrg( bids_filename, project_name, date, nrg_modality=None ):
 1050    """
 1051    Convert a BIDS filename to NRG path/filename.
 1052
 1053    Parameters:
 1054    bids_filename (str): The BIDS filename to convert
 1055    project_name (str) : Name of project (i.e. PPMI)
 1056    date (str) : Date of image acquisition
 1057
 1058
 1059    Returns:
 1060    str: The NRG path/filename.
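
    Example (illustrative; the BIDS filename, project and date are placeholders, output shown with POSIX separators):

    >>> bids_2_nrg('sub-100018_ses-000_anat.nii.gz', 'PPMI', '20210412')
    'PPMI/100018/20210412/T1w/000/PPMI-100018-20210412-T1w-000.nii.gz'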
 1061    """
 1062
 1063    bids_dirname, bids_basename = os.path.split(bids_filename)
 1064    bids_suffix = '.'+ bids_basename.split('.',1)[-1]
 1065    bids_basename = bids_basename.replace(bids_suffix, '') # remove ext
 1066    bids_parts = bids_basename.split('_')
 1067    nrg_subject_id = bids_parts[0].replace('sub-','')
 1068    nrg_image_id = bids_parts[1].replace('ses-', '')
 1069    bids_modality = bids_parts[2]
 1070    valid_modalities = get_valid_modalities()
 1071    if nrg_modality is not None:
 1072        if not nrg_modality in valid_modalities:
 1073            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality: " + get_valid_modalities(asString=True))
 1074
 1075    if bids_modality == 'anat' and nrg_modality is None :
 1076        nrg_modality = 'T1w'
 1077
 1078    if bids_modality == 'dwi' and nrg_modality is None  :
 1079        nrg_modality = 'DTI'
 1080
 1081    if bids_modality == 'func' and nrg_modality is None  :
 1082        nrg_modality = 'rsfMRI'
 1083
 1084    if bids_modality == 'perf' and nrg_modality is None  :
 1085        nrg_modality = 'perf'
 1086
 1087    nrg_suffix = bids_suffix[1:]
 1088    nrg_filename = f'{project_name}-{nrg_subject_id}-{date}-{nrg_modality}-{nrg_image_id}.{nrg_suffix}'
 1089
 1090    return os.path.join(project_name, nrg_subject_id, date, nrg_modality, nrg_image_id,nrg_filename)
 1091
 1092def collect_blind_qc_by_modality( modality_path, set_index_to_fn=True ):
 1093    """
 1094    Collects blind QC data from multiple CSV files with the same modality.
 1095
 1096    Args:
 1097
    modality_path (str): A glob pattern matching the CSV files to collect (e.g. a path containing wildcards).
 1099
 1100    set_index_to_fn: boolean
 1101
 1102    Returns:
 1103    Pandas DataFrame: A DataFrame containing all the blind QC data from the CSV files.
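
    Example (illustrative; the glob pattern is a placeholder):

    >>> qcdf = collect_blind_qc_by_modality('/processed/qc/*_T1w_*blind*csv')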
 1104    """
 1105    import glob as glob
 1106    fns = glob.glob( modality_path )
 1107    fns.sort()
 1108    jdf = pd.DataFrame()
 1109    for k in range(len(fns)):
 1110        temp=pd.read_csv(fns[k])
 1111        if not 'filename' in temp.keys():
 1112            temp['filename']=fns[k]
 1113        jdf=pd.concat( [jdf,temp], axis=0, ignore_index=False )
    if set_index_to_fn:
        jdf = jdf.reset_index(drop=True)
        if "Unnamed: 0" in jdf.columns:
            holder=jdf.pop( "Unnamed: 0" )
        jdf = jdf.set_index('filename')
 1119    return jdf
 1120
 1121
 1122def outlierness_by_modality( qcdf, uid='filename', outlier_columns = ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi','reflection_err', 'EVR', 'msk_vol'], verbose=False ):
 1123    """
 1124    Calculates outlierness scores for each modality in a dataframe based on given outlier columns using antspyt1w.loop_outlierness() and LOF.  LOF appears to be more conservative.  This function will impute missing columns with the mean.
 1125
 1126    Args:
 1127    - qcdf: (Pandas DataFrame) Dataframe containing columns with outlier information for each modality.
 1128    - uid: (str) Unique identifier for a subject. Default is 'filename'.
 1129    - outlier_columns: (list) List of columns containing outlier information. Default is ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'].
 1130    - verbose: (bool) If True, prints information for each modality. Default is False.
 1131
 1132    Returns:
 1133    - qcdf: (Pandas DataFrame) Updated dataframe with outlierness scores for each modality in the 'ol_loop' and 'ol_lof' column.  Higher values near 1 are more outlying.
 1134
 1135    Raises:
 1136    - ValueError: If uid is not present in the dataframe.
 1137
 1138    Example:
 1139    >>> df = pd.read_csv('data.csv')
 1140    >>> outlierness_by_modality(df)
 1141    """
 1142    from PyNomaly import loop
 1143    from sklearn.neighbors import LocalOutlierFactor
 1144    qcdfout = qcdf.copy()
 1145    pd.set_option('future.no_silent_downcasting', True)
 1146    qcdfout.replace([np.inf, -np.inf], np.nan, inplace=True)
 1147    if uid not in qcdfout.keys():
 1148        raise ValueError( str(uid) + " not in dataframe")
 1149    if 'ol_loop' not in qcdfout.keys():
 1150        qcdfout['ol_loop']=math.nan
 1151    if 'ol_lof' not in qcdfout.keys():
 1152        qcdfout['ol_lof']=math.nan
 1153    didit=False
 1154    for mod in get_valid_modalities( qc=True ):
 1155        didit=True
 1156        lof = LocalOutlierFactor()
 1157        locsel = qcdfout["modality"] == mod
 1158        rr = qcdfout[locsel][outlier_columns]
 1159        column_means = rr.mean()
 1160        rr.fillna(column_means, inplace=True)
 1161        if rr.shape[0] > 1:
 1162            if verbose:
 1163                print("calc: " + mod + " outlierness " )
 1164            myneigh = np.min( [24, int(np.round(rr.shape[0]*0.5)) ] )
 1165            temp = antspyt1w.loop_outlierness(rr.astype(float), standardize=True, extent=3, n_neighbors=myneigh, cluster_labels=None)
 1166            qcdfout.loc[locsel,'ol_loop']=temp.astype('float64')
 1167            yhat = lof.fit_predict(rr)
 1168            temp = lof.negative_outlier_factor_*(-1.0)
 1169            temp = temp - temp.min()
 1170            yhat[ yhat == 1] = 0
 1171            yhat[ yhat == -1] = 1 # these are outliers
 1172            qcdfout.loc[locsel,'ol_lof_decision']=yhat
 1173            qcdfout.loc[locsel,'ol_lof']=temp/temp.max()
 1174    if verbose:
 1175        print( didit )
 1176    return qcdfout
 1177
 1178
 1179def nrg_format_path( projectID, subjectID, date, modality, imageID, separator='-' ):
 1180    """
 1181    create the NRG path on disk given the project, subject id, date, modality and image id
 1182
 1183    Arguments
 1184    ---------
 1185
 1186    projectID : string for the project e.g. PPMI
 1187
 1188    subjectID : string uniquely identifying the subject e.g. 0001
 1189
 1190    date : string for the date usually 20550228 ie YYYYMMDD format
 1191
 1192    modality : string should be one of T1w, T2Flair, rsfMRI, NM2DMT and DTI ... rsfMRI and DTI may also be DTI_LR, DTI_RL, rsfMRI_LR and rsfMRI_RL where the RL / LR relates to phase encoding direction (even if it is AP/PA)
 1193
 1194    imageID : string uniquely identifying the specific image
 1195
 1196    separator : default to -
 1197
 1198    Returns
 1199    -------
 1200    the path where one would write the image on disk
 1201
 1202    """
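    Example
    -------
    Illustrative identifiers; output shown with POSIX path separators.

    >>> nrg_format_path('PPMI', '100018', '20210412', 'T1w', '000')
    'PPMI/100018/20210412/T1w/000/PPMI-100018-20210412-T1w-000'
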
 1203    thedirectory = os.path.join( str(projectID), str(subjectID), str(date), str(modality), str(imageID) )
 1204    thefilename = str(projectID) + separator + str(subjectID) + separator + str(date) + separator + str(modality) + separator + str(imageID)
 1205    return os.path.join( thedirectory, thefilename )
 1206
 1207
 1208def get_first_item_as_string(df, column_name):
 1209    """
 1210    Check if the first item in the specified column of the DataFrame is a string.
 1211    If it is not a string, attempt to convert it to an integer and then to a string.
 1212
 1213    Parameters:
 1214    df (pd.DataFrame): The DataFrame to operate on.
 1215    column_name (str): The name of the column to check.
 1216
 1217    Returns:
 1218    str: The first item in the specified column, guaranteed to be returned as a string.
 1219    """
 1220    if isinstance(df[column_name].iloc[0], str):
 1221        return df[column_name].iloc[0]
 1222    else:
 1223        try:
 1224            return str(int(df[column_name].iloc[0]))
 1225        except ValueError:
 1226            raise ValueError("The value cannot be converted to an integer.")
 1227
 1228
 1229def study_dataframe_from_matched_dataframe( matched_dataframe, rootdir, outputdir, verbose=False ):
 1230    """
 1231    converts the output of antspymm.match_modalities dataframe (one row) to that needed for a study-driving dataframe for input to mm_csv
 1232
 1233    matched_dataframe : output of antspymm.match_modalities
 1234
 1235    rootdir : location for the input data root folder (in e.g. NRG format)
 1236
 1237    outputdir : location for the output data
 1238
 1239    verbose : boolean
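
    Example (hypothetical; assumes `matched` is a one-row output of antspymm.match_modalities):

    >>> studycsv = study_dataframe_from_matched_dataframe(
    ...     matched, '/data/NRG/', '/data/processed/')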
 1240    """
 1241    iext='.nii.gz'
 1242    from os.path import exists
 1243    musthavecols = ['projectID', 'subjectID','date','imageID','filename']
 1244    for k in range(len(musthavecols)):
 1245        if not musthavecols[k] in matched_dataframe.keys():
            raise ValueError('matched_dataframe is missing column ' + musthavecols[k] + ' in study_dataframe_from_matched_dataframe' )
 1247    csvrow=matched_dataframe.dropna(axis=1)
 1248    pid=get_first_item_as_string( csvrow, 'projectID'  )
 1249    sid=get_first_item_as_string( csvrow, 'subjectID'  ) # str(csvrow['subjectID'].iloc[0] )
 1250    dt=get_first_item_as_string( csvrow, 'date'  )  # str(csvrow['date'].iloc[0])
 1251    iid=get_first_item_as_string( csvrow, 'imageID'  ) # str(csvrow['imageID'].iloc[0])
 1252    nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
 1253    if not exists( nrgt1fn ) and iid == '0':
 1254        iid='000'
 1255        nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
 1256    if not exists( nrgt1fn ):
        raise ValueError("T1 " + nrgt1fn + " does not exist in study_dataframe_from_matched_dataframe")
 1258    flList=[]
 1259    dtList=[]
 1260    rsfList=[]
 1261    nmList=[]
 1262    perfList=[]
 1263    if 'flairfn' in csvrow.keys():
 1264        flid=get_first_item_as_string( csvrow, 'flairid' )
 1265        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
 1266        if not exists( nrgt2fn ) and flid == '0':
 1267            flid='000'
 1268            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
 1269        if verbose:
 1270            print("Trying " + nrgt2fn )
 1271        if exists( nrgt2fn ):
 1272            if verbose:
 1273                print("success" )
 1274            flList.append( nrgt2fn )
 1275    if 'perffn' in csvrow.keys():
 1276        flid=get_first_item_as_string( csvrow, 'perfid' )
 1277        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
 1278        if not exists( nrgt2fn ) and flid == '0':
 1279            flid='000'
 1280            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
 1281        if verbose:
 1282            print("Trying " + nrgt2fn )
 1283        if exists( nrgt2fn ):
 1284            if verbose:
 1285                print("success" )
 1286            perfList.append( nrgt2fn )
 1287    if 'dtfn1' in csvrow.keys():
 1288        dtid=get_first_item_as_string( csvrow, 'dtid1' )
 1289        dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
 1290        if len( dtfn1) == 0 :
 1291            dtid = '000'
 1292            dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
 1293        dtfn1=dtfn1[0]
 1294        if exists( dtfn1 ):
 1295            dtList.append( dtfn1 )
 1296    if 'dtfn2' in csvrow.keys():
 1297        dtid=get_first_item_as_string( csvrow, 'dtid2' )
 1298        dtfn2=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
 1299        if len( dtfn2) == 0 :
 1300            dtid = '000'
 1301            dtfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
 1302        dtfn2=dtfn2[0]
 1303        if exists( dtfn2 ):
 1304            dtList.append( dtfn2 )
 1305    if 'dtfn3' in csvrow.keys():
 1306        dtid=get_first_item_as_string( csvrow, 'dtid3' )
 1307        dtfn3=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
 1308        if len( dtfn3) == 0 :
 1309            dtid = '000'
 1310            dtfn3=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
 1311        dtfn3=dtfn3[0]
 1312        if exists( dtfn3 ):
 1313            dtList.append( dtfn3 )
 1314    if 'rsffn1' in csvrow.keys():
 1315        rsid=get_first_item_as_string( csvrow, 'rsfid1' )
 1316        rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
 1317        if len( rsfn1 ) == 0 :
 1318            rsid = '000'
 1319            rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
 1320        rsfn1=rsfn1[0]
 1321        if exists( rsfn1 ):
 1322            rsfList.append( rsfn1 )
 1323    if 'rsffn2' in csvrow.keys():
 1324        rsid=get_first_item_as_string( csvrow, 'rsfid2' )
        rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))
 1326        if len( rsfn2 ) == 0 :
 1327            rsid = '000'
 1328            rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))
 1329        rsfn2=rsfn2[0]
 1330        if exists( rsfn2 ):
 1331            rsfList.append( rsfn2 )
 1332    for j in range(11):
 1333        keyname="nmfn"+str(j)
 1334        keynameid="nmid"+str(j)
 1335        if keyname in csvrow.keys() and keynameid in csvrow.keys():
 1336            nmid=get_first_item_as_string( csvrow, keynameid )
 1337            nmsearchpath=os.path.join( rootdir, pid, sid, dt, 'NM2DMT', nmid, "*"+nmid+iext)
            nmfn=glob.glob( nmsearchpath )
            if len( nmfn ) > 0:
                nmfn=nmfn[0]
                if exists( nmfn ):
                    nmList.append( nmfn )
 1342    if verbose:
 1343        print("assembled the image lists mapping to ....")
 1344        print(nrgt1fn)
 1345        print("NM")
 1346        print(nmList)
 1347        print("FLAIR")
 1348        print(flList)
 1349        print("DTI")
 1350        print(dtList)
 1351        print("rsfMRI")
 1352        print(rsfList)
 1353        print("perf")
 1354        print(perfList)
 1355    studycsv = generate_mm_dataframe(
 1356        pid,
 1357        sid,
 1358        dt,
 1359        iid, # the T1 id
 1360        'T1w',
 1361        rootdir,
 1362        outputdir,
 1363        t1_filename=nrgt1fn,
 1364        flair_filename=flList,
 1365        dti_filenames=dtList,
 1366        rsf_filenames=rsfList,
 1367        nm_filenames=nmList,
 1368        perf_filename=perfList)
 1369    return studycsv.dropna(axis=1)
 1370
 1371def highest_quality_repeat(mxdfin, idvar, visitvar, qualityvar):
 1372    """
 1373    This function returns a subset of the input dataframe that retains only the rows
 1374    that correspond to the highest quality observation for each combination of ID and visit.
 1375
 1376    Parameters:
 1377    ----------
 1378    mxdfin: pandas.DataFrame
 1379        The input dataframe.
 1380    idvar: str
 1381        The name of the column that contains the ID variable.
 1382    visitvar: str
 1383        The name of the column that contains the visit variable.
 1384    qualityvar: str
 1385        The name of the column that contains the quality variable.
 1386
 1387    Returns:
 1388    -------
 1389    pandas.DataFrame
 1390        A subset of the input dataframe that retains only the rows that correspond
 1391        to the highest quality observation for each combination of ID and visit.
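
    Example
    -------
    A minimal sketch with hypothetical column names:

    >>> import pandas as pd
    >>> df = pd.DataFrame({'sid': ['A', 'A', 'B'],
    ...                    'visit': ['v1', 'v1', 'v1'],
    ...                    'quality': [0.5, 0.9, 0.7]})
    >>> highest_quality_repeat(df, 'sid', 'visit', 'quality').shape[0]
    2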
 1392    """
 1393    if visitvar not in mxdfin.columns:
 1394        raise ValueError("visitvar not in dataframe")
 1395    if idvar not in mxdfin.columns:
 1396        raise ValueError("idvar not in dataframe")
 1397    if qualityvar not in mxdfin.columns:
 1398        raise ValueError("qualityvar not in dataframe")
 1399
 1400    mxdfin[qualityvar] = mxdfin[qualityvar].astype(float)
 1401
 1402    vizzes = mxdfin[visitvar].unique()
 1403    uids = mxdfin[idvar].unique()
 1404    useit = np.zeros(mxdfin.shape[0], dtype=bool)
 1405
 1406    for u in uids:
 1407        losel = mxdfin[idvar] == u
 1408        vizzesloc = mxdfin[losel][visitvar].unique()
 1409
 1410        for v in vizzesloc:
 1411            losel = (mxdfin[idvar] == u) & (mxdfin[visitvar] == v)
 1412            mysnr = mxdfin.loc[losel, qualityvar]
 1413            myw = np.where(losel)[0]
 1414
 1415            if len(myw) > 1:
 1416                if any(~np.isnan(mysnr)):
 1417                    useit[myw[np.argmax(mysnr)]] = True
 1418                else:
 1419                    useit[myw] = True
 1420            else:
 1421                useit[myw] = True
 1422
 1423    return mxdfin[useit]
 1424
 1425
 1426def match_modalities( qc_dataframe, unique_identifier='filename', outlier_column='ol_loop', mysep='-', verbose=False ):
 1427    """
 1428    Find the best multiple modality dataset at each time point
 1429
    :param qc_dataframe: quality control data frame with one row per image, including the modality, outlierness and image-quality columns (e.g. snr, cnr, dimt) used below
    :param unique_identifier: the unique NRG filename for each image
 1432    :param outlier_column: outlierness score used to identify the best image (pair) at a given date
 1433    :param mysep (str, optional): the separator used in the image file names. Defaults to '-'.
 1434    :param verbose: boolean
 1435    :return: filtered matched modality data frame
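
    Example (a hedged sketch; qc_dataframe is assumed to be a per-image blind QC
    dataframe containing the columns referenced above):

        mmdf = match_modalities( qc_dataframe, outlier_column='ol_loop' )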
 1436    """
 1437    import pandas as pd
 1438    import numpy as np
 1439    qc_dataframe['filename']=qc_dataframe['filename'].astype(str)
 1440    qc_dataframe['ol_loop']=qc_dataframe['ol_loop'].astype(float)
 1441    qc_dataframe['ol_lof']=qc_dataframe['ol_lof'].astype(float)
 1442    qc_dataframe['ol_lof_decision']=qc_dataframe['ol_lof_decision'].astype(float)
 1443    mmdf = best_mmm( qc_dataframe, 'T1w', outlier_column=outlier_column )['filt']
 1444    fldf = best_mmm( qc_dataframe, 'T2Flair', outlier_column=outlier_column )['filt']
 1445    nmdf = best_mmm( qc_dataframe, 'NM2DMT', outlier_column=outlier_column )['filt']
 1446    rsdf = best_mmm( qc_dataframe, 'rsfMRI', outlier_column=outlier_column )['filt']
 1447    dtdf = best_mmm( qc_dataframe, 'DTI', outlier_column=outlier_column )['filt']
 1448    mmdf['flairid'] = None
 1449    mmdf['flairfn'] = None
 1450    mmdf['flairloop'] = None
 1451    mmdf['flairlof'] = None
 1452    mmdf['dtid1'] = None
 1453    mmdf['dtfn1'] = None
 1454    mmdf['dtntimepoints1'] = 0
 1455    mmdf['dtloop1'] = math.nan
 1456    mmdf['dtlof1'] = math.nan
 1457    mmdf['dtid2'] = None
 1458    mmdf['dtfn2'] = None
 1459    mmdf['dtntimepoints2'] = 0
 1460    mmdf['dtloop2'] = math.nan
 1461    mmdf['dtlof2'] = math.nan
 1462    mmdf['rsfid1'] = None
 1463    mmdf['rsffn1'] = None
 1464    mmdf['rsfntimepoints1'] = 0
 1465    mmdf['rsfloop1'] = math.nan
 1466    mmdf['rsflof1'] = math.nan
 1467    mmdf['rsfid2'] = None
 1468    mmdf['rsffn2'] = None
 1469    mmdf['rsfntimepoints2'] = 0
 1470    mmdf['rsfloop2'] = math.nan
 1471    mmdf['rsflof2'] = math.nan
 1472    for k in range(1,11):
 1473        myid='nmid'+str(k)
 1474        mmdf[myid] = None
 1475        myid='nmfn'+str(k)
 1476        mmdf[myid] = None
 1477        myid='nmloop'+str(k)
 1478        mmdf[myid] = math.nan
 1479        myid='nmlof'+str(k)
 1480        mmdf[myid] = math.nan
 1481    if verbose:
 1482        print( mmdf.shape )
 1483    for k in range(mmdf.shape[0]):
 1484        if verbose:
 1485            if k % 100 == 0:
 1486                progger = str( k ) # np.round( k / mmdf.shape[0] * 100 ) )
 1487                print( progger, end ="...", flush=True)
 1488        if dtdf is not None:
 1489            locsel = (dtdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
 1490            if sum(locsel) == 1:
 1491                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = dtdf["imageID"][locsel].values[0]
 1492                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = dtdf[unique_identifier][locsel].values[0]
 1493                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = dtdf[outlier_column][locsel].values[0]
 1494                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(dtdf['ol_lof_decision'][locsel].values[0])
 1495                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(dtdf['dimt'][locsel].values[0])
 1496            elif sum(locsel) > 1:
 1497                locdf = dtdf[locsel]
 1498                dedupe = locdf[["snr","cnr"]].duplicated()
 1499                locdf = locdf[~dedupe]
 1500                if locdf.shape[0] > 1:
 1501                    locdf = locdf.sort_values(outlier_column).iloc[:2]
 1502                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = locdf["imageID"].values[0]
 1503                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = locdf[unique_identifier].values[0]
 1504                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = locdf[outlier_column].values[0]
                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(locdf['ol_lof_decision'].values[0])
                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(locdf['dimt'].values[0])
 1507                if locdf.shape[0] > 1:
 1508                    mmdf.iloc[k, mmdf.columns.get_loc("dtid2")] = locdf["imageID"].values[1]
 1509                    mmdf.iloc[k, mmdf.columns.get_loc("dtfn2")] = locdf[unique_identifier].values[1]
 1510                    mmdf.iloc[k, mmdf.columns.get_loc("dtloop2")] = locdf[outlier_column].values[1]
                    mmdf.iloc[k, mmdf.columns.get_loc("dtlof2")] = float(locdf['ol_lof_decision'].values[1])
                    mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints2")] = float(locdf['dimt'].values[1])
 1513        if rsdf is not None:
 1514            locsel = (rsdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
 1515            if sum(locsel) == 1:
 1516                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = rsdf["imageID"][locsel].values[0]
 1517                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = rsdf[unique_identifier][locsel].values[0]
 1518                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = rsdf[outlier_column][locsel].values[0]
                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(rsdf['ol_lof_decision'][locsel].values[0])
 1520                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(rsdf['dimt'][locsel].values[0])
 1521            elif sum(locsel) > 1:
 1522                locdf = rsdf[locsel]
 1523                dedupe = locdf[["snr","cnr"]].duplicated()
 1524                locdf = locdf[~dedupe]
 1525                if locdf.shape[0] > 1:
 1526                    locdf = locdf.sort_values(outlier_column).iloc[:2]
 1527                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = locdf["imageID"].values[0]
 1528                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = locdf[unique_identifier].values[0]
 1529                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = locdf[outlier_column].values[0]
 1530                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(locdf['ol_lof_decision'].values[0])
                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(locdf['dimt'].values[0])
 1532                if locdf.shape[0] > 1:
 1533                    mmdf.iloc[k, mmdf.columns.get_loc("rsfid2")] = locdf["imageID"].values[1]
 1534                    mmdf.iloc[k, mmdf.columns.get_loc("rsffn2")] = locdf[unique_identifier].values[1]
 1535                    mmdf.iloc[k, mmdf.columns.get_loc("rsfloop2")] = locdf[outlier_column].values[1]
 1536                    mmdf.iloc[k, mmdf.columns.get_loc("rsflof2")] = float(locdf['ol_lof_decision'].values[1])
                    mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints2")] = float(locdf['dimt'].values[1])
 1538
 1539        if fldf is not None:
 1540            locsel = fldf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
 1541            if locsel.sum() == 1:
 1542                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = fldf['imageID'][locsel].values[0]
 1543                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = fldf[unique_identifier][locsel].values[0]
 1544                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = fldf[outlier_column][locsel].values[0]
 1545                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(fldf['ol_lof_decision'][locsel].values[0])
 1546            elif sum(locsel) > 1:
 1547                locdf = fldf[locsel]
 1548                dedupe = locdf[["snr","cnr"]].duplicated()
 1549                locdf = locdf[~dedupe]
 1550                if locdf.shape[0] > 1:
 1551                    locdf = locdf.sort_values(outlier_column).iloc[:2]
 1552                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = locdf["imageID"].values[0]
 1553                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = locdf[unique_identifier].values[0]
 1554                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = locdf[outlier_column].values[0]
 1555                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(locdf['ol_lof_decision'].values[0])
 1556
 1557        if nmdf is not None:
 1558            locsel = nmdf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
 1559            if locsel.sum() > 0:
 1560                locdf = nmdf[locsel]
 1561                for i in range(np.min( [10,locdf.shape[0]])):
 1562                    nmid = "nmid"+str(i+1)
                    mmdf.iloc[k, mmdf.columns.get_loc(nmid)] = locdf['imageID'].iloc[i]
                    nmfn = "nmfn"+str(i+1)
                    mmdf.iloc[k, mmdf.columns.get_loc(nmfn)] = locdf[unique_identifier].iloc[i]
                    nmloop = "nmloop"+str(i+1)
                    mmdf.iloc[k, mmdf.columns.get_loc(nmloop)] = locdf[outlier_column].iloc[i]
                    nmlof = "nmlof"+str(i+1)
                    mmdf.iloc[k, mmdf.columns.get_loc(nmlof)] = float(locdf['ol_lof_decision'].iloc[i])
 1570
 1571    mmdf['rsf_total_timepoints']=mmdf['rsfntimepoints1']+mmdf['rsfntimepoints2']
 1572    mmdf['dt_total_timepoints']=mmdf['dtntimepoints1']+mmdf['dtntimepoints2']
 1573    return mmdf
 1574
 1575
 1576def add_repeat_column(df, groupby_column):
 1577    """
 1578    Adds a 'repeat' column to the DataFrame that counts occurrences of each unique value
 1579    in the specified 'groupby_column'. The count increments from 1 for each identical entry.
 1580    
 1581    Parameters:
 1582    - df: pandas DataFrame.
 1583    - groupby_column: The name of the column to group by and count repeats.
 1584    
 1585    Returns:
 1586    - Modified pandas DataFrame with an added 'repeat' column.
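
    Example (illustrative):

    >>> import pandas as pd
    >>> df = pd.DataFrame({'sid': ['A', 'A', 'B']})
    >>> add_repeat_column(df, 'sid')['repeat'].tolist()
    [1, 2, 1]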
 1587    """
 1588    # Validate if the groupby_column exists in the DataFrame
 1589    if groupby_column not in df.columns:
 1590        raise ValueError(f"Column '{groupby_column}' does not exist in the DataFrame.")
 1591    
 1592    # Count the occurrences of each unique value in the specified column and increment from 1
 1593    df['repeat'] = df.groupby(groupby_column).cumcount() + 1
 1594    
 1595    return df
 1596
 1597def best_mmm( mmdf, wmod, mysep='-', outlier_column='ol_loop', verbose=False):
 1598    """
 1599    Selects the best repeats per modality.
 1600
 1601    Args:
    mmdf (pandas.DataFrame): the combined blind QC dataframe across modalities

    wmod (str): the modality of the image ( 'T1w', 'T2Flair', 'NM2DMT', 'rsfMRI', 'DTI')
 1603
 1604    mysep (str, optional): the separator used in the image file names. Defaults to '-'.
 1605
    outlier_column (str, optional): column name for the outlier score. Defaults to 'ol_loop'.
 1607
    verbose (bool, optional): default False
 1609
 1610    Returns:
 1611
    dict: a dictionary with keys 'raw' and 'filt'. 'raw' contains all the metadata for the selected modality and 'filt' contains the metadata filtered for the highest quality repeats.
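
    Example:

    A hedged sketch; qcdf is assumed to be a blind QC dataframe with a 'modality'
    column, NRG-style 'filename' entries and the chosen outlier column:

        best_t1 = best_mmm( qcdf, 'T1w', outlier_column='ol_loop' )['filt']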
 1613
 1614    """
 1615#    mmdf = mmdf.astype(str)
 1616    mmdf[outlier_column]=mmdf[outlier_column].astype(float)
 1617    msel = mmdf['modality'] == wmod
 1618    if wmod == 'rsfMRI':
 1619        msel1 = mmdf['modality'] == 'rsfMRI'
 1620        msel2 = mmdf['modality'] == 'rsfMRI_LR'
 1621        msel3 = mmdf['modality'] == 'rsfMRI_RL'
 1622        msel = msel1 | msel2
 1623        msel = msel | msel3
 1624    if wmod == 'DTI':
 1625        msel1 = mmdf['modality'] == 'DTI'
 1626        msel2 = mmdf['modality'] == 'DTI_LR'
 1627        msel3 = mmdf['modality'] == 'DTI_RL'
 1628        msel4 = mmdf['modality'] == 'DTIdwi'
 1629        msel5 = mmdf['modality'] == 'DTIb0'
 1630        msel = msel1 | msel2 | msel3 | msel4 | msel5
 1631    if sum(msel) == 0:
 1632        return {'raw': None, 'filt': None}
 1633    metasub = mmdf[msel].copy()
 1634
 1635    if verbose:
 1636        print(f"{wmod} {(metasub.shape[0])} pre")
 1637
 1638    metasub['subjectID']=None
 1639    metasub['date']=None
 1640    metasub['subjectIDdate']=None
 1641    metasub['imageID']=None
 1642    metasub['negol']=math.nan
 1643    for k in metasub.index:
 1644        temp = metasub.loc[k, 'filename'].split( mysep )
 1645        metasub.loc[k,'subjectID'] = str( temp[1] )
 1646        metasub.loc[k,'date'] = str( temp[2] )
 1647        metasub.loc[k,'subjectIDdate'] = str( temp[1] + mysep + temp[2] )
 1648        metasub.loc[k,'imageID'] = str( temp[4])
 1649
 1650
 1651    if 'ol_' in outlier_column:
 1652        metasub['negol'] = metasub[outlier_column].max() - metasub[outlier_column]
 1653    else:
 1654        metasub['negol'] = metasub[outlier_column]
 1655    if 'date' not in metasub.keys():
 1656        metasub['date']=None
 1657    metasubq = add_repeat_column( metasub, 'subjectIDdate' )
 1658    metasubq = highest_quality_repeat(metasubq, 'filename', 'date', 'negol')
 1659
 1660    if verbose:
 1661        print(f"{wmod} {metasubq.shape[0]} post")
 1662
 1663#    metasub = metasub.astype(str)
 1664#    metasubq = metasubq.astype(str)
 1665    metasub[outlier_column]=metasub[outlier_column].astype(float)
 1666    metasubq[outlier_column]=metasubq[outlier_column].astype(float)
 1667    return {'raw': metasub, 'filt': metasubq}
 1668
 1669def mm_read( x, standardize_intensity=False, modality='' ):
 1670    """
 1671    read an image from a filename - same as ants.image_read (for now)
 1672
 1673    standardize_intensity : boolean ; if True will set negative values to zero and normalize into the range of zero to one
 1674
 1675    modality : not used
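
    Example (the path below is a placeholder):

        img = mm_read( '/path/to/NRG/T1w/image.nii.gz', standardize_intensity=True )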
 1676    """
 1677    if x is None:
 1678        raise ValueError( " None passed to function antspymm.mm_read." )
 1679    if not isinstance(x,str):
 1680        raise ValueError( " Non-string passed to function antspymm.mm_read." )
 1681    if not os.path.exists( x ):
        raise ValueError( " file " + x + " does not exist." )
 1683    img = ants.image_read( x, reorient=False )
 1684    if standardize_intensity:
 1685        img[img<0.0]=0.0
 1686        img=ants.iMath(img,'Normalize')
 1687    if modality == "T1w" and img.dimension == 4:
 1688        print("WARNING: input image is 4D - we attempt a hack fix that works in some odd cases of PPMI data - please check this image: " + x, flush=True )
 1689        i1=ants.slice_image(img,3,0)
 1690        i2=ants.slice_image(img,3,1)
 1691        kk=np.concatenate( [i1.numpy(),i2.numpy()], axis=2 )
 1692        kk=ants.from_numpy(kk)
 1693        img=ants.copy_image_info(i1,kk)
 1694    return img
 1695
 1696def mm_read_to_3d( x, slice=None, modality='' ):
 1697    """
 1698    read an image from a filename - and return as 3d or None if that is not possible
 1699    """
 1700    img = ants.image_read( x, reorient=False )
 1701    if img.dimension <= 3:
 1702        return img
 1703    elif img.dimension == 4:
 1704        nslices = img.shape[3]
 1705        if slice is None:
 1706            sl = np.round( nslices * 0.5 )
 1707        else:
 1708            sl = slice
        if sl >= nslices:
 1710            sl = nslices-1
 1711        return ants.slice_image( img, axis=3, idx=int(sl) )
 1712    elif img.dimension > 4:
 1713        return img
 1714    return None
 1715
 1716def timeseries_n3(x):
 1717    """
 1718    Perform N3 bias field correction on a time-series image dataset using ANTsPy library.
 1719
 1720    This function processes a multi-dimensional image dataset, where the last dimension
 1721    represents different time points. It applies N3 bias field correction to each time point 
 1722    individually to correct intensity non-uniformity.
 1723
 1724    Parameters:
    x (ANTsImage): A multi-dimensional (typically 4D) ANTsImage in which the last dimension
                   represents time points. Each slice along this dimension is a separate
                   image to be corrected.

    Returns:
    ANTsImage: An image of the same shape as x, with N3 bias field correction
               applied to each time slice.
 1731
 1732    The function works as follows:
 1733    - Initializes an empty list `mimg` to store the corrected images.
 1734    - Determines the number of time points in the input image series.
 1735    - Iterates over each time point, extracting the image slice and applying N3 bias 
 1736      field correction.
 1737    - The corrected images are then appended to the `mimg` list.
 1738    - Finally, the list of corrected images is converted back into a multi-dimensional 
 1739      array and returned.
 1740
 1741    Example:
 1742    corrected_images = timeseries_n3(image_data)
 1743    """
 1744    mimg = []
 1745    n = len(x.shape) - 1
 1746    for kk in range(x.shape[n]):
 1747        temp = ants.slice_image(x, axis=n, idx=kk)
 1748        temp = ants.n3_bias_field_correction(temp, downsample_factor=2)
 1749        mimg.append(temp)
 1750    return ants.list_to_ndimage(x, mimg)
 1751
 1752def image_write_with_thumbnail( x,  fn, y=None, thumb=True ):
 1753    """
 1754    will write the image and (optionally) a png thumbnail with (optional) overlay/underlay
 1755    """
 1756    ants.image_write( x, fn )
 1757    if not thumb or x.components > 1:
 1758        return
 1759    thumb_fn=re.sub(".nii.gz","_3dthumb.png",fn)
 1760    if thumb and x.dimension == 3:
 1761        if y is None:
 1762            try:
 1763                ants.plot_ortho( x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1764            except:
 1765                pass
 1766        else:
 1767            try:
 1768                ants.plot_ortho( y, x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1769            except:
 1770                pass
 1771    if thumb and x.dimension == 4:
 1772        thumb_fn=re.sub(".nii.gz","_4dthumb.png",fn)
 1773        nslices = x.shape[3]
 1774        sl = np.round( nslices * 0.5 )
        if sl >= nslices:
 1776            sl = nslices-1
 1777        xview = ants.slice_image( x, axis=3, idx=int(sl) )
 1778        if y is None:
 1779            try:
 1780                ants.plot_ortho( xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1781            except:
 1782                pass
 1783        else:
 1784            if y.dimension == 3:
 1785                try:
 1786                    ants.plot_ortho(y, xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1787                except:
 1788                    pass
 1789    return
 1790
 1791def convert_np_in_dict(data_dict):
 1792    """
    Convert values in the dictionary from numpy float or int types to regular Python float or int.
 1794
 1795    :param data_dict: A dictionary with values of various types.
 1796    :return: Dictionary with numpy values converted.
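
    Example:

    >>> convert_np_in_dict({'a': np.float32(1.5), 'b': np.int64(2), 'c': 'x'})
    {'a': 1.5, 'b': 2, 'c': 'x'}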
 1797    """
 1798    converted_dict = {}
 1799    for key, value in data_dict.items():
 1800        if isinstance(value, (np.float32, np.float64)):
 1801            converted_dict[key] = float(value)
 1802        elif isinstance(value, (np.int8,  np.uint8, np.int16,  np.uint16, np.int32,  np.uint32, np.int64,  np.uint64)):
 1803            converted_dict[key] = int(value)
 1804        else:
 1805            converted_dict[key] = value
 1806    return converted_dict
 1807
 1808def mc_resample_image_to_target( x , y, interp_type='linear' ):
 1809    """
 1810    multichannel version of resample_image_to_target
 1811    """
 1812    xx=ants.split_channels( x )
 1813    yy=ants.split_channels( y )[0]
 1814    newl=[]
 1815    for k in range(len(xx)):
 1816        newl.append(  ants.resample_image_to_target( xx[k], yy, interp_type=interp_type ) )
 1817    return ants.merge_channels( newl )
 1818
 1819def nrg_filelist_to_dataframe( filename_list, myseparator="-" ):
 1820    """
 1821    convert a list of files in nrg format to a dataframe
 1822
 1823    Arguments
 1824    ---------
 1825    filename_list : globbed list of files
 1826
 1827    myseparator : string separator between nrg parts
 1828
 1829    Returns
 1830    -------
 1831
 1832    df : pandas data frame
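
    Example (a hedged sketch; the root path and glob pattern are hypothetical):

        files = glob.glob( '/data/NRG/*/*/*/*/*/*.nii.gz' )
        df = nrg_filelist_to_dataframe( files, myseparator='-' )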
 1833
 1834    """
 1835    def getmtime(x):
        x= dt.datetime.fromtimestamp(os.path.getmtime(x)).strftime("%Y-%m-%d %H:%M:%S")
 1837        return x
 1838    df=pd.DataFrame(columns=['filename','file_last_mod_t','else','sid','visitdate','modality','uid'])
 1839    df.set_index('filename')
 1840    df['filename'] = pd.Series([file for file in filename_list ])
    # record each file's last-modification time in df['file_last_mod_t'] via getmtime
 1842    df['file_last_mod_t'] = df['filename'].apply(lambda x: getmtime(x))
 1843    for k in range(df.shape[0]):
 1844        locfn=df['filename'].iloc[k]
 1845        splitter=os.path.basename(locfn).split( myseparator )
        df.loc[k, 'sid'] = splitter[1]
        df.loc[k, 'visitdate'] = splitter[2]
        df.loc[k, 'modality'] = splitter[3]
        temp = os.path.splitext(splitter[4])[0]
        df.loc[k, 'uid'] = os.path.splitext(temp)[0]
 1851    return df
 1852
 1853
 1854def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
 1855    """
 1856    merge time series data into space of reference_image
 1857
 1858    img_LR : image
 1859
 1860    img_RL : image
 1861
 1862    allow_resample : boolean
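
    Example (hedged; img_LR and img_RL are assumed to be 4D acquisitions of the
    same subject with opposite phase encoding):

        merged = merge_timeseries_data( img_LR, img_RL )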
 1863
 1864    """
 1865    # concatenate the images into the reference space
 1866    mimg=[]
 1867    for kk in range( img_LR.shape[3] ):
 1868        temp = ants.slice_image( img_LR, axis=3, idx=kk )
 1869        mimg.append( temp )
 1870    for kk in range( img_RL.shape[3] ):
 1871        temp = ants.slice_image( img_RL, axis=3, idx=kk )
 1872        if kk == 0:
 1873            insamespace = ants.image_physical_space_consistency( temp, mimg[0] )
 1874        if allow_resample and not insamespace :
 1875            temp = ants.resample_image_to_target( temp, mimg[0] )
 1876        mimg.append( temp )
 1877    return ants.list_to_ndimage( img_LR, mimg )
 1878
 1879def copy_spatial_metadata_from_3d_to_4d(spatial_img, timeseries_img):
 1880    """
 1881    Copy spatial metadata (origin, spacing, direction) from a 3D image to the
 1882    spatial dimensions (first 3) of a 4D image, preserving the 4th dimension's metadata.
 1883
 1884    Parameters
 1885    ----------
 1886    spatial_img : ants.ANTsImage
 1887        A 3D ANTsImage with the desired spatial metadata.
 1888    timeseries_img : ants.ANTsImage
 1889        A 4D ANTsImage to update.
 1890
 1891    Returns
 1892    -------
 1893    ants.ANTsImage
 1894        A 4D ANTsImage with updated spatial metadata.
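
    Example
    -------
    A minimal sketch using synthetic images:

    >>> ref3d = ants.make_image( (8, 8, 8), spacing=(2., 2., 2.) )
    >>> ts4d = ants.from_numpy( np.zeros( (8, 8, 8, 5) ) )
    >>> ts4d = copy_spatial_metadata_from_3d_to_4d( ref3d, ts4d )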
 1895    """
 1896    if spatial_img.dimension != 3:
 1897        raise ValueError("spatial_img must be a 3D ANTsImage.")
 1898    if timeseries_img.dimension != 4:
 1899        raise ValueError("timeseries_img must be a 4D ANTsImage.")
 1900    # Get 3D metadata
 1901    spatial_origin = list(spatial_img.origin)
 1902    spatial_spacing = list(spatial_img.spacing)
 1903    spatial_direction = spatial_img.direction  # 3x3
 1904    # Get original 4D metadata
 1905    ts_spacing = list(timeseries_img.spacing)
 1906    ts_origin = list(timeseries_img.origin)
 1907    ts_direction = timeseries_img.direction  # 4x4
 1908    # Replace only the first 3 entries for origin and spacing
 1909    new_origin = spatial_origin + [ts_origin[3]]
 1910    new_spacing = spatial_spacing + [ts_spacing[3]]
 1911    # Replace top-left 3x3 block of direction matrix, preserve last row/column
 1912    new_direction = ts_direction.copy()
 1913    new_direction[:3, :3] = spatial_direction
 1914    # Create updated image
 1915    updated_img = ants.from_numpy(
 1916        timeseries_img.numpy(),
 1917        origin=new_origin,
 1918        spacing=new_spacing,
 1919        direction=new_direction
 1920    )
 1921    return updated_img
 1922
 1923def timeseries_transform(transform, image, reference, interpolation='linear'):
 1924    """
 1925    Apply a spatial transform to each 3D volume in a 4D time series image.
 1926
 1927    Parameters
 1928    ----------
 1929    transform : ants transform object
 1930        Path(s) to ANTs-compatible transform(s) to apply.
 1931    image : ants.ANTsImage
 1932        4D input image with shape (X, Y, Z, T).
 1933    reference : ants.ANTsImage
 1934        Reference image to match in space.
 1935    interpolation : str
 1936        Interpolation method: 'linear', 'nearestNeighbor', etc.
 1937
 1938    Returns
 1939    -------
 1940    ants.ANTsImage
 1941        4D transformed image.
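
    Example
    -------
    A hedged sketch; tx, bold4d and ref3d are hypothetical (tx e.g. from
    ants.read_transform):

        warped4d = timeseries_transform( tx, bold4d, ref3d )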
 1942    """
 1943    if image.dimension != 4:
 1944        raise ValueError("Input image must be 4D (X, Y, Z, T).")
 1945    n_volumes = image.shape[3]
 1946    transformed_volumes = []
 1947    for t in range(n_volumes):
 1948        vol = ants.slice_image( image, 3, t )
 1949        transformed = ants.apply_ants_transform_to_image(
 1950            transform=transform,
 1951            image=vol,
 1952            reference=reference,
 1953            interpolation=interpolation
 1954        )
 1955        transformed_volumes.append(transformed.numpy())
 1956    # Stack along time axis and convert to ANTsImage
 1957    transformed_array = np.stack(transformed_volumes, axis=-1)
 1958    out_image = ants.from_numpy(transformed_array)
 1959    out_image = ants.copy_image_info(image, out_image)
 1960    out_image = copy_spatial_metadata_from_3d_to_4d(reference, out_image)
 1961    return out_image
 1962
 1963def timeseries_reg(
 1964    image,
 1965    avg_b0,
 1966    type_of_transform='antsRegistrationSyNRepro[r]',
 1967    total_sigma=1.0,
 1968    fdOffset=2.0,
 1969    trim = 0,
 1970    output_directory=None,
 1971    return_numpy_motion_parameters=False,
 1972    verbose=False, **kwargs
 1973):
 1974    """
 1975    Correct time-series data for motion.
 1976
 1977    Arguments
 1978    ---------
 1979    image: antsImage, usually ND where D=4.
 1980
    avg_b0: fixed b0 (reference) image; the registration target for each time point
 1982
 1983    type_of_transform : string
 1984            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
 1985            See ants registration for details.
 1986
 1987    fdOffset: offset value to use in framewise displacement calculation
 1988
 1989    trim : integer - trim this many images off the front of the time series
 1990
 1991    output_directory : string
 1992            output will be placed in this directory plus a numeric extension.
 1993
 1994    return_numpy_motion_parameters : boolean
 1995
 1996    verbose: boolean
 1997
 1998    kwargs: keyword args
 1999            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
 2000
 2001    Returns
 2002    -------
    dict containing the following key/value pairs:
 2004        `motion_corrected`: Moving image warped to space of fixed image.
 2005        `motion_parameters`: transforms for each image in the time series.
 2006        `FD`: Framewise displacement generalized for arbitrary transformations.
 2007
 2008    Notes
 2009    -----
 2010    Control extra arguments via kwargs. see ants.registration for details.
 2011
 2012    Example
 2013    -------
 2014    >>> import ants
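    >>> # a hedged sketch (ts4d: 4D antsImage, avg3d: 3D reference antsImage):
    >>> # reg = timeseries_reg( ts4d, avg3d, type_of_transform='antsRegistrationSyNRepro[r]' )
    >>> # reg['motion_corrected'], reg['FD']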
 2015    """
 2016    idim = image.dimension
 2017    ishape = image.shape
 2018    nTimePoints = ishape[idim - 1]
 2019    FD = np.zeros(nTimePoints)
 2020    if type_of_transform is None:
 2021        return {
 2022            "motion_corrected": image,
 2023            "motion_parameters": None,
 2024            "FD": FD
 2025        }
 2026
 2027    remove_it=False
 2028    if output_directory is None:
 2029        remove_it=True
 2030        output_directory = tempfile.mkdtemp()
 2031    output_directory_w = output_directory + "/ts_reg/"
 2032    os.makedirs(output_directory_w,exist_ok=True)
 2033    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
 2034    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
 2035    if verbose:
 2036        print('bold motcorr with ' + type_of_transform)
 2037        print(output_directory_w)
 2038        print(ofnG)
 2039        print(ofnL)
 2040        print("remove_it " + str( remove_it ) )
 2041
 2042    # get a local deformation from slice to local avg space
 2043    motion_parameters = list()
 2044    motion_corrected = list()
 2045    mask = ants.get_mask( avg_b0 )
 2046    centerOfMass = mask.get_center_of_mass()
 2047    npts = pow(2, idim - 1)
 2048    pointOffsets = np.zeros((npts, idim - 1))
 2049    myrad = np.ones(idim - 1).astype(int).tolist()
 2050    mask1vals = np.zeros(int(mask.sum()))
 2051    mask1vals[round(len(mask1vals) / 2)] = 1
 2052    mask1 = ants.make_image(mask, mask1vals)
 2053    myoffsets = ants.get_neighborhood_in_mask(
 2054        mask1, mask1, radius=myrad, spatial_info=True
 2055    )["offsets"]
 2056    mycols = list("xy")
 2057    if idim - 1 == 3:
 2058        mycols = list("xyz")
 2059    useinds = list()
 2060    for k in range(myoffsets.shape[0]):
 2061        if abs(myoffsets[k, :]).sum() == (idim - 2):
 2062            useinds.append(k)
 2063        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
 2064    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
 2065    if verbose:
 2066        print("Progress:")
 2067    counter = round( nTimePoints / 10 ) + 1
 2068    for k in range( nTimePoints):
        if verbose and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
 2070            myperc = round( k / nTimePoints * 100)
 2071            print(myperc, end="%.", flush=True)
 2072        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2073        temp = ants.iMath(temp, "Normalize")
 2074        txprefix = ofnL+str(k % 2).zfill(4)+"_"
 2075        if temp.numpy().var() > 0:
 2076            myrig = ants.registration(
 2077                    avg_b0, temp,
 2078                    type_of_transform='antsRegistrationSyNRepro[r]',
 2079                    outprefix=txprefix
 2080                )
 2081            if type_of_transform == 'SyN':
 2082                myreg = ants.registration(
 2083                    avg_b0, temp,
 2084                    type_of_transform='SyNOnly',
 2085                    total_sigma=total_sigma,
 2086                    initial_transform=myrig['fwdtransforms'][0],
 2087                    outprefix=txprefix,
 2088                    **kwargs
 2089                )
 2090            else:
 2091                myreg = myrig
 2092            fdptsTxI = ants.apply_transforms_to_points(
 2093                idim - 1, fdpts, myrig["fwdtransforms"]
 2094            )
 2095            if k > 0 and motion_parameters[k - 1] != "NA":
 2096                fdptsTxIminus1 = ants.apply_transforms_to_points(
 2097                    idim - 1, fdpts, motion_parameters[k - 1]
 2098                )
 2099            else:
 2100                fdptsTxIminus1 = fdptsTxI
 2101            # take the absolute value, then the mean across columns, then the sum
 2102            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
 2103            motion_parameters.append(myreg["fwdtransforms"])
 2104        else:
 2105            motion_parameters.append("NA")
 2106
 2107        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2108        if temp.numpy().var() > 0:
 2109            img1w = ants.apply_transforms( avg_b0,
 2110                temp,
 2111                motion_parameters[k] )
 2112            motion_corrected.append(img1w)
 2113        else:
 2114            motion_corrected.append(avg_b0)
 2115
 2116    motion_parameters = motion_parameters[trim:len(motion_parameters)]
 2117    if return_numpy_motion_parameters:
 2118        motion_parameters = read_ants_transforms_to_numpy( motion_parameters )
 2119
 2120    if remove_it:
 2121        import shutil
 2122        shutil.rmtree(output_directory, ignore_errors=True )
 2123
 2124    if verbose:
 2125        print("Done")
 2126    d4siz = list(avg_b0.shape)
 2127    d4siz.append( 2 )
 2128    spc = list(ants.get_spacing( avg_b0 ))
 2129    spc.append( ants.get_spacing(image)[3] )
 2130    mydir = ants.get_direction( avg_b0 )
 2131    mydir4d = ants.get_direction( image )
 2132    mydir4d[0:3,0:3]=mydir
 2133    myorg = list(ants.get_origin( avg_b0 ))
 2134    myorg.append( 0.0 )
 2135    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
 2136    return {
 2137        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected[trim:len(motion_corrected)]),
 2138        "motion_parameters": motion_parameters,
 2139        "FD": FD[trim:len(FD)]
 2140    }
 2141
 2142
 2143def merge_dwi_data( img_LRdwp, bval_LR, bvec_LR, img_RLdwp, bval_RL, bvec_RL ):
 2144    """
 2145    merge motion and distortion corrected data if possible
 2146
 2147    img_LRdwp : image
 2148
 2149    bval_LR : array
 2150
 2151    bvec_LR : array
 2152
 2153    img_RLdwp : image
 2154
 2155    bval_RL : array
 2156
 2157    bvec_RL : array
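
    Example (a hedged sketch; the inputs are assumed to be motion/distortion
    corrected LR and RL acquisitions with matching bvals/bvecs arrays):

        img, bval, bvec = merge_dwi_data( img_LRdwp, bval_LR, bvec_LR,
                                          img_RLdwp, bval_RL, bvec_RL )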
 2158
 2159    """
 2160    import warnings
 2161    insamespace = ants.image_physical_space_consistency( img_LRdwp, img_RLdwp )
 2162    if not insamespace :
        warnings.warn('not insamespace ... corrected image pair should occupy the same physical space; returning only the 1st set and will not join these data.')
 2164        return img_LRdwp, bval_LR, bvec_LR
 2165    
 2166    bval_LR = np.concatenate([bval_LR,bval_RL])
 2167    bvec_LR = np.concatenate([bvec_LR,bvec_RL])
 2168    # concatenate the images
 2169    mimg=[]
 2170    for kk in range( img_LRdwp.shape[3] ):
 2171            mimg.append( ants.slice_image( img_LRdwp, axis=3, idx=kk ) )
 2172    for kk in range( img_RLdwp.shape[3] ):
 2173            mimg.append( ants.slice_image( img_RLdwp, axis=3, idx=kk ) )
 2174    img_LRdwp = ants.list_to_ndimage( img_LRdwp, mimg )
 2175    return img_LRdwp, bval_LR, bvec_LR
 2176
 2177def bvec_reorientation( motion_parameters, bvecs, rebase=None ):
 2178    if motion_parameters is None:
 2179        return bvecs
 2180    n = len(motion_parameters)
 2181    if n < 1:
 2182        return bvecs
 2183    from scipy.linalg import inv, polar
 2184    from dipy.core.gradients import reorient_bvecs
 2185    dipymoco = np.zeros( [n,3,3] )
 2186    for myidx in range(n):
 2187        if myidx < bvecs.shape[0]:
 2188            dipymoco[myidx,:,:] = np.eye( 3 )
 2189            if motion_parameters[myidx] != 'NA':
 2190                temp = motion_parameters[myidx]
 2191                if len(temp) == 4 :
 2192                    temp1=temp[3] # FIXME should be composite of index 1 and 3
 2193                    temp2=temp[1] # FIXME should be composite of index 1 and 3
 2194                    txparam1 = ants.read_transform(temp1)
 2195                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
 2196                    txparam2 = ants.read_transform(temp2)
 2197                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
 2198                    Rinv = inv( np.dot( txparam2, txparam1 ) )
 2199                elif len(temp) == 2 :
 2200                    temp=temp[1] # FIXME should be composite of index 1 and 3
 2201                    txparam = ants.read_transform(temp)
 2202                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
 2203                    Rinv = inv( txparam )
 2204                elif len(temp) == 3 :
 2205                    temp1=temp[2] # FIXME should be composite of index 1 and 3
 2206                    temp2=temp[1] # FIXME should be composite of index 1 and 3
 2207                    txparam1 = ants.read_transform(temp1)
 2208                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
 2209                    txparam2 = ants.read_transform(temp2)
 2210                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
 2211                    Rinv = inv( np.dot( txparam2, txparam1 ) )
 2212                else:
 2213                    temp=temp[0]
 2214                    txparam = ants.read_transform(temp)
 2215                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
 2216                    Rinv = inv( txparam )
 2217                bvecs[myidx,:] = np.dot( Rinv, bvecs[myidx,:] )
 2218                if rebase is not None:
 2219                    # FIXME - should combine these operations
 2220                    bvecs[myidx,:] = np.dot( rebase, bvecs[myidx,:] )
 2221    return bvecs
 2222
 2223
 2224def distortion_correct_bvecs(bvecs, def_grad, A_img, A_ref):
 2225    """
 2226    Vectorized computation of voxel-wise distortion corrected b-vectors.
 2227
 2228    Parameters
 2229    ----------
 2230    bvecs : ndarray (N, 3)
 2231    def_grad : ndarray (X, Y, Z, 3, 3) containing rotations derived from the deformation gradient
 2232    A_img : ndarray (3, 3) direction matrix of the fixed image (target undistorted space)
 2233    A_ref : ndarray (3, 3) direction matrix of the moving image (being corrected)
 2234
 2235    Returns
 2236    -------
 2237    bvecs_5d : ndarray (X, Y, Z, N, 3)
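
    Example
    -------
    A hedged sketch with synthetic shapes:

        bvecs = np.random.randn( 12, 3 )
        defgrad = np.tile( np.eye(3), (4, 4, 4, 1, 1) )  # identity rotation at each voxel
        bvecs_5d = distortion_correct_bvecs( bvecs, defgrad, np.eye(3), np.eye(3) )
        # bvecs_5d.shape == (4, 4, 4, 12, 3)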
 2238    """
 2239    X, Y, Z = def_grad.shape[:3]
 2240    N = bvecs.shape[0]
 2241    # Combined rotation: R_voxel = A_ref.T @ A_img @ def_grad
 2242    A = A_ref.T @ A_img
 2243    R_voxel = np.einsum('ij,xyzjk->xyzik', A, def_grad)  # (X, Y, Z, 3, 3)
 2244    # Apply R_voxel.T @ bvecs
 2245    # First, reshape R_voxel: (X*Y*Z, 3, 3)
 2246    R_voxel_reshaped = R_voxel.reshape(-1, 3, 3)
 2247    # Rotate all bvecs for each voxel
 2248    # Output: (X*Y*Z, N, 3)
 2249    rotated = np.einsum('vij,nj->vni', R_voxel_reshaped, bvecs)
 2250    # Normalize
 2251    norms = np.linalg.norm(rotated, axis=2, keepdims=True)
 2252    rotated /= np.clip(norms, 1e-8, None)
 2253    # Reshape back to (X, Y, Z, N, 3)
 2254    bvecs_5d = rotated.reshape(X, Y, Z, N, 3)
 2255    return bvecs_5d    
 2256
 2257def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False ):
 2258    """
 2259    extract DTI data from a dipy tensormodel
 2260
 2261    reference_image : antsImage defining physical space (3D)
 2262
    tensormodel : from dipy e.g. the variable myoutx['dtrecon_LR_dewarp']['tensormodel'] if myoutx is produced by joint_dti_recon
 2264
 2265    upper_triangular: boolean otherwise use lower triangular coding
 2266
 2267    return_image : boolean return the ANTsImage form of DTI otherwise return an array
 2268
 2269    Returns
 2270    -------
 2271    either an ANTsImage (dim=X.Y.Z with 6 component voxels, upper triangular form)
 2272        or a 5D NumPy array (dim=X.Y.Z.3.3)
 2273
 2274    Notes
 2275    -----
 2276    DiPy returns lower triangular form but ANTs expects upper triangular.
 2277        Here, we default to the ANTs standard but could generalize in the future 
 2278        because not much here depends on ANTs standards of tensor data.
 2279        ANTs xx,xy,xz,yy,yz,zz
 2280        DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
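
    Example
    -------
    A hedged sketch; fit is assumed to be a fitted dipy tensor model (e.g.
    dipy.reconst.dti.TensorModel(gtab).fit(dwi_numpy)) and avg_b0 a 3D antsImage:

        dti_img = get_dti( avg_b0, fit, return_image=True )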
 2281
 2282    """
 2283    # make the DTI - see 
 2284    # https://dipy.org/documentation/1.7.0/examples_built/07_reconstruction/reconst_dti/#sphx-glr-examples-built-07-reconstruction-reconst-dti-py
 2285    # By default, in DIPY, values are ordered as (Dxx, Dxy, Dyy, Dxz, Dyz, Dzz)
 2286    # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2287    reoind = np.array([0,1,3,2,4,5]) # arrays are faster than lists
 2288    import dipy.reconst.dti as dti
 2289    dtiut = dti.lower_triangular(tensormodel.quadratic_form)
 2290    it = np.ndindex( reference_image.shape )
 2291    yyind=2
 2292    xzind=3
 2293    if upper_triangular:
 2294        yyind=3
 2295        xzind=2
 2296        for i in it: # convert to upper triangular
 2297            dtiut[i] = dtiut[i][ reoind ] # do we care if this is doing extra work?
 2298    if return_image:
 2299        dtiAnts = ants.from_numpy(dtiut,has_components=True)
 2300        ants.copy_image_info( reference_image, dtiAnts )
 2301        return dtiAnts
 2302    # copy these data into a tensor 
 2303    dtinp = np.zeros(reference_image.shape + (3,3), dtype=float)  
 2304    dtix = np.zeros((3,3), dtype=float)  
 2305    it = np.ndindex( reference_image.shape )
 2306    for i in it:
 2307        dtivec = dtiut[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2308        dtix[0,0]=dtivec[0]
 2309        dtix[1,1]=dtivec[yyind] # 2 for LT
 2310        dtix[2,2]=dtivec[5] 
 2311        dtix[0,1]=dtix[1,0]=dtivec[1]
 2312        dtix[0,2]=dtix[2,0]=dtivec[xzind] # 3 for LT
 2313        dtix[1,2]=dtix[2,1]=dtivec[4]
 2314        dtinp[i]=dtix
 2315    return dtinp
 2316
 2317def triangular_to_tensor( image, upper_triangular=True ):
 2318    """
 2319    convert triangular tensor image to a full tensor form (in numpy)
 2320
 2321    image : antsImage holding dti in either upper or lower triangular format 
 2322
 2323    upper_triangular: boolean
 2324
 2325    Note
 2326    --------
 2327    see get_dti function for more details
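
    Example
    -------
    A hedged sketch; dti_image is assumed to be a 6-component antsImage such as
    the output of get_dti( ..., return_image=True ):

        dt33 = triangular_to_tensor( dti_image )
        # dt33.shape == dti_image.shape + (3, 3)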
 2328    """
 2329    reoind = np.array([0,1,3,2,4,5]) # arrays are faster than lists
 2330    it = np.ndindex( image.shape )
 2331    yyind=2
 2332    xzind=3
 2333    if upper_triangular:
 2334        yyind=3
 2335        xzind=2
 2336    # copy these data into a tensor 
 2337    dtinp = np.zeros(image.shape + (3,3), dtype=float)
 2338    dtix = np.zeros((3,3), dtype=float)
 2339    it = np.ndindex( image.shape )
 2340    dtiut = image.numpy()
 2341    for i in it:
 2342        dtivec = dtiut[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2343        dtix[0,0]=dtivec[0]
 2344        dtix[1,1]=dtivec[yyind] # 2 for LT
 2345        dtix[2,2]=dtivec[5] 
 2346        dtix[0,1]=dtix[1,0]=dtivec[1]
 2347        dtix[0,2]=dtix[2,0]=dtivec[xzind] # 3 for LT
 2348        dtix[1,2]=dtix[2,1]=dtivec[4]
 2349        dtinp[i]=dtix
 2350    return dtinp
 2351
 2352
 2353def dti_numpy_to_image( reference_image, tensorarray, upper_triangular=True):
 2354    """
 2355    convert numpy DTI data to antsImage
 2356
 2357    reference_image : antsImage defining physical space (3D)
 2358
 2359    tensorarray : numpy array X,Y,Z,3,3 shape
 2360
 2361    upper_triangular: boolean otherwise use lower triangular coding
 2362
 2363    Returns
 2364    -------
 2365    ANTsImage
 2366
 2367    Notes
 2368    -----
 2369    DiPy returns lower triangular form but ANTs expects upper triangular.
 2370        Here, we default to the ANTs standard but could generalize in the future 
 2371        because not much here depends on ANTs standards of tensor data.
 2372        ANTs xx,xy,xz,yy,yz,zz
 2373        DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
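
    Example
    -------
    A hedged round-trip sketch together with triangular_to_tensor:

        dt33 = triangular_to_tensor( dti_image )
        dti_image2 = dti_numpy_to_image( ants.split_channels( dti_image )[0], dt33 )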
 2374
 2375    """
 2376    dtiut = np.zeros(reference_image.shape + (6,), dtype=float)  
 2377    dtivec = np.zeros(6, dtype=float)  
 2378    it = np.ndindex( reference_image.shape )
 2379    yyind=2
 2380    xzind=3
 2381    if upper_triangular:
 2382        yyind=3
 2383        xzind=2
 2384    for i in it:
 2385        dtix = tensorarray[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2386        dtivec[0]=dtix[0,0]
 2387        dtivec[yyind]=dtix[1,1] # 2 for LT
 2388        dtivec[5]=dtix[2,2]
 2389        dtivec[1]=dtix[0,1]
 2390        dtivec[xzind]=dtix[2,0] # 3 for LT
 2391        dtivec[4]=dtix[1,2]
 2392        dtiut[i]=dtivec
 2393    dtiAnts = ants.from_numpy( dtiut, has_components=True )
 2394    ants.copy_image_info( reference_image, dtiAnts )
 2395    return dtiAnts
 2396
 2397
 2398def deformation_gradient_optimized(warp_image, to_rotation=False, to_inverse_rotation=False):
 2399    """
 2400    Compute the deformation gradient tensor from a displacement (warp) field image.
 2401
 2402    This function computes the **deformation gradient** `F = ∂φ/∂x` where `φ(x) = x + u(x)` is the mapping
 2403    induced by the displacement field `u(x)` stored in `warp_image`.
 2404
 2405    The returned tensor field has shape `(x, y, z, dim, dim)` (for 3D), where each matrix represents 
 2406    the **Jacobian** of the transformation at that voxel. The gradient is computed in the physical space 
 2407    of the image using spacing and direction metadata.
 2408
 2409    Optionally, the deformation gradient can be projected onto the space of pure rotations using the polar
 2410    decomposition (via SVD). This is useful for applications like reorientation of tensors (e.g., DTI).
 2411
 2412    Parameters
 2413    ----------
 2414    warp_image : ants.ANTsImage
 2415        A vector-valued ANTsImage encoding the warp/displacement field. It must have `dim` components
 2416        (e.g., shape `(x, y, z, 3)` for 3D) representing the displacements in each spatial direction.
 2417        
 2418    to_rotation : bool, optional
 2419        If True, the deformation gradient will be replaced with its **nearest rotation matrix**
 2420        using the polar decomposition (`F → R`, where `F = R U`).
 2421        
 2422    to_inverse_rotation : bool, optional
 2423        If True, the deformation gradient will be replaced with the **inverse of the rotation**
 2424        (`F → R.T`), which is often needed for transforming tensors **back** to their original frame.
 2425
 2426    Returns
 2427    -------
 2428    F : np.ndarray
 2429        A NumPy array of shape `(x, y, z, dim, dim)` (or `(dim1, dim2, ..., dim, dim)` in general),
 2430        representing the deformation gradient tensor field at each voxel.
 2431
 2432    Raises
 2433    ------
 2434    RuntimeError
 2435        If `warp_image` is not an `ants.ANTsImage`.
 2436
 2437    Notes
 2438    -----
 2439    - The function computes gradients in physical space using the spacing of the image and applies 
 2440      the image direction matrix (`tdir`) to properly orient gradients.
 2441    - The gradient is added to the identity matrix to yield the deformation gradient `F = I + ∂u/∂x`.
 2442    - The polar decomposition ensures `F` is replaced with the closest rotation matrix (orthogonal, det=1).
 2443    - This is a **vectorized pure NumPy implementation**, intended for performance and simplicity.
 2444
 2445    Examples
 2446    --------
    >>> warp = ants.image_read( reg['fwdtransforms'][0] )  # displacement field from an ants.registration result
 2448    >>> F = deformation_gradient_optimized(warp)
 2449    >>> R = deformation_gradient_optimized(warp, to_rotation=True)
 2450    >>> Rinv = deformation_gradient_optimized(warp, to_inverse_rotation=True)
 2451    """
 2452    if not ants.is_image(warp_image):
 2453        raise RuntimeError("antsimage is required")
 2454    dim = warp_image.dimension
 2455    tshp = warp_image.shape
 2456    tdir = warp_image.direction
 2457    spc = warp_image.spacing
 2458    warpnp = warp_image.numpy()
 2459    gradient_list = [np.gradient(warpnp[..., k], *spc, axis=range(dim)) for k in range(dim)]
 2460    # This correctly calculates J.T, where dg[..., i, j] = d(u_j)/d(x_i)
 2461    dg = np.stack([np.stack(grad_k, axis=-1) for grad_k in gradient_list], axis=-1)
 2462    dg = (tdir @ dg).swapaxes(-1, -2)
 2463    dg += np.eye(dim)
 2464    if to_rotation or to_inverse_rotation:
 2465        U, s, Vh = np.linalg.svd(dg)
 2466        Z = U @ Vh
 2467        dets = np.linalg.det(Z)
 2468        reflection_mask = dets < 0
 2469        Vh[reflection_mask, -1, :] *= -1
 2470        Z[reflection_mask] = U[reflection_mask] @ Vh[reflection_mask]
 2471        dg = Z
 2472        if to_inverse_rotation:
 2473            dg = np.transpose(dg, axes=(*range(dg.ndim - 2), dg.ndim - 1, dg.ndim - 2))
 2474    new_shape = tshp + (dim,dim)
 2475    return np.reshape(dg, new_shape)
 2476
 2477
 2478def transform_and_reorient_dti( fixed, moving_dti, composite_transform, verbose=False, **kwargs):
 2479    """
 2480    Applies a transformation to a DTI image using an ANTs composite transform,
 2481    including local tensor reorientation via the Finite Strain method.
 2482
 2483    This function expects:
 2484    - Input `moving_dti` to be a 6-component ANTsImage (upper triangular format).
 2485    - `composite_transform` to point to an ANTs-readable transform file,
 2486      which maps points from `fixed` space to `moving` space.
 2487
 2488    Args:
 2489        fixed (ants.ANTsImage): The reference space image (defines the output grid).
 2490        moving_dti (ants.ANTsImage): The input DTI (6-component), to be transformed.
 2491        composite_transform (str): File path to an ANTs transform
 2492                                   (e.g., from `ants.read_transform` or a written composite transform).
 2493        verbose (bool): Whether to print verbose output during execution.
 2494        **kwargs: Additional keyword arguments passed to `ants.apply_transforms`.
 2495
 2496    Returns:
 2497        ants.ANTsImage: The transformed and reoriented DTI image in the `fixed` space,
 2498                        in 6-component upper triangular format.
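
    Example:
        A hedged sketch; reg is assumed to be the result of an ants.registration
        call whose forward warp field was written to disk:

            dti_fixed = transform_and_reorient_dti( fixed, moving_dti, reg['fwdtransforms'][0] )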
 2499    """
 2500    if moving_dti.dimension != 3:
 2501        raise ValueError('moving_dti must be 3-dimensional.')
 2502    if moving_dti.components != 6:
 2503        raise ValueError('moving_dti must have 6 components (upper triangular format).')
 2504
 2505    if verbose:
 2506        print("1. Spatially transforming DTI scalar components from moving to fixed space...")
 2507
 2508    # ants.apply_transforms resamples the *values* of each DTI component from 'moving_dti'
 2509    # onto the grid of 'fixed'.
 2510    # The output 'dtiw' will have the same spatial metadata (spacing, origin, direction) as 'fixed'.
 2511    # However, the tensor values contained within it are still oriented as they were in
 2512    # 'moving_dti's original image space, not 'fixed' image space, and certainly not yet reoriented
 2513    # by the local deformation.
 2514    dtsplit = moving_dti.split_channels()
 2515    dtiw_channels = []
 2516    for k in range(len(dtsplit)):
 2517        dtiw_channels.append( ants.apply_transforms( fixed, dtsplit[k], composite_transform, **kwargs ) )
 2518    dtiw = ants.merge_channels(dtiw_channels)
 2519    
 2520    if verbose:
 2521        print(f"   DTI scalar components resampled to fixed grid. Result shape: {dtiw.shape}")
 2522        print("2. Computing local rotation field from composite transform...")
 2523    
 2524    # Read the composite transform as an image (assumed to be a displacement field).
 2525    # The 'deformation_gradient_optimized' function is assumed to be 100% correct,
 2526    # meaning it returns the appropriate local rotation matrix field (R_moving_to_fixed)
 2527    # in (spatial_dims..., 3, 3 ) format when called with these flags.
 2528    wtx = ants.image_read(composite_transform)
 2529    R_moving_to_fixed_forward = deformation_gradient_optimized(
 2530        wtx,
 2531        to_rotation=False,      # This means the *deformation gradient* F=I+J is computed first.
 2532        to_inverse_rotation=True # This requests the inverse of the rotation part of F.
 2533    )
 2534
 2535    if verbose:
 2536        print(f"   Local reorientation matrices (R_moving_to_fixed_forward) computed. Shape: {R_moving_to_fixed_forward.shape}")
 2537        print("3. Converting 6-component DTI to full 3x3 tensors for vectorized reorientation...")
 2538    
 2539    # Convert `dtiw` (resampled, but still in moving-image-space orientation)
 2540    # from 6-components to full 3x3 tensor representation.
 2541    # dtiw2tensor_np will have shape (spatial_dims..., 3, 3).
 2542    dtiw2tensor_np = triangular_to_tensor(dtiw)
 2543    
 2544    if verbose:
 2545        print("4. Applying vectorized tensor reorientation (Finite Strain Method)...")
 2546    
 2547    # --- Vectorized Tensor Reorientation ---
 2548    # This replaces the entire `for i in it:` loop and its contents with efficient NumPy operations.
 2549
 2550    # Step 4.1: Rebase tensors from `moving_dti.direction` coordinate system to World Coordinates.
 2551    # D_world_moving_orient = moving_dti.direction @ D_moving_image_frame @ moving_dti.direction.T
 2552    # This transforms the tensor's components from being relative to `moving_dti`'s image axes
 2553    # (where they are currently defined) into absolute World (physical) coordinates.
 2554    D_world_moving_orient = np.einsum(
 2555        'ab, ...bc, cd -> ...ad',
 2556        moving_dti.direction,            # 3x3 matrix (moving_image_axes -> world_axes)
 2557        dtiw2tensor_np,                  # (spatial_dims..., 3, 3)
 2558        moving_dti.direction.T           # 3x3 matrix (world_axes -> moving_image_axes) - inverse of moving_dti.direction
 2559    )
 2560
 2561    # Step 4.2: Apply local rotation in World Coordinates (Finite Strain Reorientation).
 2562    # D_reoriented_world = R_moving_to_fixed_forward @ D_world_moving_orient @ (R_moving_to_fixed_forward).T
 2563    # This is the core reorientation step, transforming the tensor's orientation from
 2564    # the original `moving` space to the new `fixed` space, all within world coordinates.
 2565    D_world_fixed_orient = np.einsum(
 2566        '...ab, ...bc, ...cd -> ...ad',
 2567        R_moving_to_fixed_forward,      # (spatial_dims..., 3, 3) - local rotation
 2568        D_world_moving_orient,          # (spatial_dims..., 3, 3) - tensor in world space, moving_orient
 2569        np.swapaxes(R_moving_to_fixed_forward, -1, -2) # (spatial_dims..., 3, 3) - transpose of local rotation
 2570    )
 2571
 2572    # Step 4.3: Rebase reoriented tensors from World Coordinates to `fixed.direction` coordinate system.
 2573    # D_final_fixed_image_frame = (fixed.direction).T @ D_world_fixed_orient @ fixed.direction
 2574    # This transforms the tensor's components from absolute World (physical) coordinates
 2575    # back into `fixed.direction`'s image coordinate system.
 2576    final_dti_tensors_numpy = np.einsum(
 2577        'ba, ...bc, cd -> ...ad',
 2578        fixed.direction,                # Using `fixed.direction` here, but 'ba' indices specify to use its transpose.
 2579        D_world_fixed_orient,           # (spatial_dims..., 3, 3)
 2580        fixed.direction                 # 3x3 matrix (world_axes -> fixed_image_axes)
 2581    )
 2582
 2583    if verbose:
 2584        print("   Vectorized tensor reorientation complete.")
 2585
 2586    if verbose:
 2587        print("5. Converting reoriented full tensors back to 6-component ANTsImage...")
 2588    
 2589    # Convert the final (spatial_dims..., 3, 3) NumPy array of tensors back into a
 2590    # 6-component ANTsImage with the correct spatial metadata from `fixed`.
 2591    final_dti_image = dti_numpy_to_image(fixed, final_dti_tensors_numpy)
 2592    
 2593    if verbose:
 2594        print(f"Done. Final reoriented DTI image in fixed space generated. Shape: {final_dti_image.shape}")
 2595
 2596    return final_dti_image
 2597
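# A minimal illustrative sketch (assuming only numpy) of the finite strain
# reorientation used in steps 4.1-4.3 above, applied to a single tensor.  The
# local rotation R acts as D' = R D R^T, the same einsum pattern used in step 4.2.
#
#   import numpy as np
#   theta = np.pi / 2.0
#   R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
#                 [np.sin(theta),  np.cos(theta), 0.0],
#                 [0.0, 0.0, 1.0]])                     # 90-degree rotation about z
#   D = np.diag([1.7e-3, 0.3e-3, 0.3e-3])               # principal axis along x
#   R_field = R[None, ...]                              # a "field" with one voxel
#   D_field = D[None, ...]
#   D_rot = np.einsum('...ab, ...bc, ...cd -> ...ad',
#                     R_field, D_field, np.swapaxes(R_field, -1, -2))
#   # D_rot[0] is (up to floating point) diag(0.3e-3, 1.7e-3, 0.3e-3):
#   # the principal diffusion direction has been rotated from x to y.
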
 2598def dti_reg(
 2599    image,
 2600    avg_b0,
 2601    avg_dwi,
 2602    bvals=None,
 2603    bvecs=None,
 2604    b0_idx=None,
 2605    type_of_transform="antsRegistrationSyNRepro[r]",
 2606    total_sigma=3.0,
 2607    fdOffset=2.0,
 2608    mask_csf=False,
 2609    brain_mask_eroded=None,
 2610    output_directory=None,
 2611    verbose=False, **kwargs
 2612):
 2613    """
 2614    Correct time-series data for motion - with optional deformation.
 2615
 2616    Arguments
 2617    ---------
 2618        image: antsImage, usually ND where D=4.
 2619
  2620        avg_b0: fixed b0 image defining the reference space
  2621
  2622        avg_dwi: fixed average DWI image in the same space as the b0 image
 2623
 2624        bvals: bvalues (file or array)
 2625
 2626        bvecs: bvecs (file or array)
 2627
 2628        b0_idx: indices of b0
 2629
 2630        type_of_transform : string
 2631            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
 2632            See ants registration for details.
 2633
 2634        fdOffset: offset value to use in framewise displacement calculation
 2635
 2636        mask_csf: boolean
 2637
 2638        brain_mask_eroded: optional mask that will trigger mixed interpolation
 2639
 2640        output_directory : string
 2641            output will be placed in this directory plus a numeric extension.
 2642
 2643        verbose: boolean
 2644
 2645        kwargs: keyword args
  2646            extra arguments that control the details of the registration; see ants.registration for more.
 2647
 2648    Returns
 2649    -------
  2650    dict containing the following key/value pairs:
 2651        `motion_corrected`: Moving image warped to space of fixed image.
 2652        `motion_parameters`: transforms for each image in the time series.
 2653        `FD`: Framewise displacement generalized for arbitrary transformations.
 2654
 2655    Notes
 2656    -----
 2657    Control extra arguments via kwargs. see ants.registration for details.
 2658
 2659    Example
 2660    -------
 2661    >>> import ants
 2662    """
 2663
 2664    idim = image.dimension
 2665    ishape = image.shape
 2666    nTimePoints = ishape[idim - 1]
 2667    FD = np.zeros(nTimePoints)
 2668    if bvals is not None and bvecs is not None:
 2669        if isinstance(bvecs, str):
 2670            bvals, bvecs = read_bvals_bvecs( bvals , bvecs  )
 2671        else: # assume we already read them
 2672            bvals = bvals.copy()
 2673            bvecs = bvecs.copy()
 2674    if type_of_transform is None:
 2675        return {
 2676            "motion_corrected": image,
 2677            "motion_parameters": None,
 2678            "FD": FD,
 2679            'bvals':bvals,
 2680            'bvecs':bvecs
 2681        }
 2682
 2683    from scipy.linalg import inv, polar
 2684    from dipy.core.gradients import reorient_bvecs
 2685
 2686    remove_it=False
 2687    if output_directory is None:
 2688        remove_it=True
 2689        output_directory = tempfile.mkdtemp()
 2690    output_directory_w = output_directory + "/dti_reg/"
 2691    os.makedirs(output_directory_w,exist_ok=True)
 2692    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
 2693    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
 2694    if verbose:
 2695        print(output_directory_w)
 2696        print(ofnG)
 2697        print(ofnL)
 2698        print("remove_it " + str( remove_it ) )
 2699
  2700    if b0_idx is None:
  2701        # fall back to mean-value segmentation when b-values are unavailable
  2702        b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals'] if bvals is not None else segment_timeseries_by_meanvalue( image )['highermeans']
 2703
 2704    # first get a local deformation from slice to local avg space
 2705    # then get a global deformation from avg to ref space
 2706    ab0, adw = get_average_dwi_b0( image )
 2707    # mask is used to roughly locate middle of brain
 2708    mask = ants.threshold_image( ants.iMath(adw,'Normalize'), 0.1, 1.0 )
 2709    if brain_mask_eroded is None:
 2710        brain_mask_eroded = mask * 0 + 1
 2711    motion_parameters = list()
 2712    motion_corrected = list()
 2713    centerOfMass = mask.get_center_of_mass()
 2714    npts = pow(2, idim - 1)
 2715    pointOffsets = np.zeros((npts, idim - 1))
 2716    myrad = np.ones(idim - 1).astype(int).tolist()
 2717    mask1vals = np.zeros(int(mask.sum()))
 2718    mask1vals[round(len(mask1vals) / 2)] = 1
 2719    mask1 = ants.make_image(mask, mask1vals)
 2720    myoffsets = ants.get_neighborhood_in_mask(
 2721        mask1, mask1, radius=myrad, spatial_info=True
 2722    )["offsets"]
 2723    mycols = list("xy")
 2724    if idim - 1 == 3:
 2725        mycols = list("xyz")
 2726    useinds = list()
 2727    for k in range(myoffsets.shape[0]):
 2728        if abs(myoffsets[k, :]).sum() == (idim - 2):
 2729            useinds.append(k)
 2730        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
 2731    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
 2732
 2733
 2734    if verbose:
 2735        print("begin global distortion correction")
 2736    # initrig = tra_initializer(avg_b0, ab0, max_rotation=60, transform=['rigid'], verbose=verbose)
 2737    if mask_csf:
 2738        bcsf = ants.threshold_image( avg_b0,"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
 2739    else:
 2740        bcsf = ab0 * 0 + 1
 2741
 2742    initrig = ants.registration( avg_b0, ab0,'antsRegistrationSyNRepro[r]',outprefix=ofnG)
 2743    deftx = ants.registration( avg_dwi, adw, 'SyNOnly',
 2744        syn_metric='CC', syn_sampling=2,
 2745        reg_iterations=[50,50,20],
 2746        multivariate_extras=[ [ "CC", avg_b0, ab0, 1, 2 ]],
 2747        initial_transform=initrig['fwdtransforms'][0],
 2748        outprefix=ofnG
 2749        )['fwdtransforms']
 2750    if verbose:
 2751        print("end global distortion correction")
 2752
 2753    if verbose:
 2754        print("Progress:")
 2755    counter = round( nTimePoints / 10 ) + 1
 2756    for k in range(nTimePoints):
  2757        if verbose and nTimePoints > 0 and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
 2758            myperc = round( k / nTimePoints * 100)
 2759            print(myperc, end="%.", flush=True)
 2760        if k in b0_idx:
 2761            fixed=ants.image_clone( ab0 )
 2762        else:
 2763            fixed=ants.image_clone( adw )
 2764        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2765        temp = ants.iMath(temp, "Normalize")
 2766        txprefix = ofnL+str(k).zfill(4)+"rig_"
 2767        txprefix2 = ofnL+str(k % 2).zfill(4)+"def_"
 2768        if temp.numpy().var() > 0:
 2769            myrig = ants.registration(
 2770                    fixed, temp,
 2771                    type_of_transform='antsRegistrationSyNRepro[r]',
 2772                    outprefix=txprefix,
 2773                    **kwargs
 2774                )
 2775            if type_of_transform == 'SyN':
 2776                myreg = ants.registration(
 2777                    fixed, temp,
 2778                    type_of_transform='SyNOnly',
 2779                    total_sigma=total_sigma, grad_step=0.1,
 2780                    initial_transform=myrig['fwdtransforms'][0],
 2781                    outprefix=txprefix2,
 2782                    **kwargs
 2783                )
 2784            else:
 2785                myreg = myrig
 2786            fdptsTxI = ants.apply_transforms_to_points(
 2787                idim - 1, fdpts, myrig["fwdtransforms"]
 2788            )
 2789            if k > 0 and motion_parameters[k - 1] != "NA":
 2790                fdptsTxIminus1 = ants.apply_transforms_to_points(
 2791                    idim - 1, fdpts, motion_parameters[k - 1]
 2792                )
 2793            else:
 2794                fdptsTxIminus1 = fdptsTxI
 2795            # take the absolute value, then the mean across columns, then the sum
 2796            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
 2797            motion_parameters.append(myreg["fwdtransforms"])
 2798        else:
 2799            motion_parameters.append("NA")
 2800
 2801        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2802        if k in b0_idx:
 2803            fixed=ants.image_clone( ab0 )
 2804        else:
 2805            fixed=ants.image_clone( adw )
 2806        if temp.numpy().var() > 0:
 2807            motion_parameters[k]=deftx+motion_parameters[k]
 2808            img1w = apply_transforms_mixed_interpolation( avg_dwi,
 2809                ants.slice_image(image, axis=idim - 1, idx=k),
 2810                motion_parameters[k], mask=brain_mask_eroded )
 2811            motion_corrected.append(img1w)
 2812        else:
 2813            motion_corrected.append(fixed)
 2814
 2815    if verbose:
 2816        print("Reorient bvecs")
 2817    if bvecs is not None:
 2818        #    direction = target->GetDirection().GetTranspose() * img_mov->GetDirection().GetVnlMatrix();
 2819        rebase = np.dot( np.transpose( avg_b0.direction  ), ab0.direction )
 2820        bvecs = bvec_reorientation( motion_parameters, bvecs, rebase )
 2821
 2822    if remove_it:
 2823        import shutil
 2824        shutil.rmtree(output_directory, ignore_errors=True )
 2825
 2826    if verbose:
 2827        print("Done")
 2828    d4siz = list(avg_b0.shape)
 2829    d4siz.append( 2 )
 2830    spc = list(ants.get_spacing( avg_b0 ))
 2831    spc.append( 1.0 )
 2832    mydir = ants.get_direction( avg_b0 )
 2833    mydir4d = ants.get_direction( image )
 2834    mydir4d[0:3,0:3]=mydir
 2835    myorg = list(ants.get_origin( avg_b0 ))
 2836    myorg.append( 0.0 )
 2837    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
 2838    return {
 2839        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected),
 2840        "motion_parameters": motion_parameters,
 2841        "FD": FD,
 2842        'bvals':bvals,
 2843        'bvecs':bvecs
 2844    }
 2845
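# A hedged usage sketch for dti_reg; file names are placeholders and
# get_average_dwi_b0 (defined below) supplies the fixed b0/dwi reference pair:
#
#   import ants
#   dwi = ants.image_read('sub-01_dwi.nii.gz')          # hypothetical path
#   avg_b0, avg_dwi = get_average_dwi_b0(dwi)
#   reg = dti_reg(dwi, avg_b0=avg_b0, avg_dwi=avg_dwi,
#                 bvals='sub-01_dwi.bval', bvecs='sub-01_dwi.bvec',
#                 type_of_transform='antsRegistrationSyNRepro[r]', verbose=True)
#   moco_dwi = reg['motion_corrected']    # 4D motion/distortion corrected DWI
#   fd = reg['FD']                        # framewise displacement per volume
#   rotated_bvecs = reg['bvecs']          # bvecs reoriented to match each volume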
 2846
 2847def mc_reg(
 2848    image,
 2849    fixed=None,
 2850    type_of_transform="antsRegistrationSyNRepro[r]",
 2851    mask=None,
 2852    total_sigma=3.0,
 2853    fdOffset=2.0,
 2854    output_directory=None,
 2855    verbose=False, **kwargs
 2856):
 2857    """
  2858    Correct time-series data for motion - with optional deformation.
 2859
 2860    Arguments
 2861    ---------
 2862        image: antsImage, usually ND where D=4.
 2863
 2864        fixed: Fixed image to register all timepoints to.  If not provided,
 2865            mean image is used.
 2866
 2867        type_of_transform : string
 2868            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
 2869            See ants registration for details.
 2870
 2871        fdOffset: offset value to use in framewise displacement calculation
 2872
 2873        output_directory : string
  2874            output will be placed in this directory plus a numeric extension.
 2875
 2876        verbose: boolean
 2877
 2878        kwargs: keyword args
 2879            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
 2880
 2881    Returns
 2882    -------
  2883    dict containing the following key/value pairs:
 2884        `motion_corrected`: Moving image warped to space of fixed image.
 2885        `motion_parameters`: transforms for each image in the time series.
 2886        `FD`: Framewise displacement generalized for arbitrary transformations.
 2887
 2888    Notes
 2889    -----
 2890    Control extra arguments via kwargs. see ants.registration for details.
 2891
 2892    Example
 2893    -------
 2894    >>> import ants
 2895    >>> fi = ants.image_read(ants.get_ants_data('ch2'))
 2896    >>> mytx = ants.motion_correction( fi )
 2897    """
 2898    remove_it=False
 2899    if output_directory is None:
 2900        remove_it=True
 2901        output_directory = tempfile.mkdtemp()
 2902    output_directory_w = output_directory + "/mc_reg/"
 2903    os.makedirs(output_directory_w,exist_ok=True)
 2904    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
 2905    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
 2906    if verbose:
 2907        print(output_directory_w)
 2908        print(ofnG)
 2909        print(ofnL)
 2910
 2911    idim = image.dimension
 2912    ishape = image.shape
 2913    nTimePoints = ishape[idim - 1]
 2914    if fixed is None:
 2915        fixed = ants.get_average_of_timeseries( image )
 2916    if mask is None:
 2917        mask = ants.get_mask(fixed)
 2918    FD = np.zeros(nTimePoints)
 2919    motion_parameters = list()
 2920    motion_corrected = list()
 2921    centerOfMass = mask.get_center_of_mass()
 2922    npts = pow(2, idim - 1)
 2923    pointOffsets = np.zeros((npts, idim - 1))
 2924    myrad = np.ones(idim - 1).astype(int).tolist()
 2925    mask1vals = np.zeros(int(mask.sum()))
 2926    mask1vals[round(len(mask1vals) / 2)] = 1
 2927    mask1 = ants.make_image(mask, mask1vals)
 2928    myoffsets = ants.get_neighborhood_in_mask(
 2929        mask1, mask1, radius=myrad, spatial_info=True
 2930    )["offsets"]
 2931    mycols = list("xy")
 2932    if idim - 1 == 3:
 2933        mycols = list("xyz")
 2934    useinds = list()
 2935    for k in range(myoffsets.shape[0]):
 2936        if abs(myoffsets[k, :]).sum() == (idim - 2):
 2937            useinds.append(k)
 2938        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
 2939    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
 2940    if verbose:
 2941        print("Progress:")
 2942    counter = 0
 2943    for k in range(nTimePoints):
 2944        mycount = round(k / nTimePoints * 100)
 2945        if verbose and mycount == counter:
 2946            counter = counter + 10
 2947            print(mycount, end="%.", flush=True)
 2948        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2949        temp = ants.iMath(temp, "Normalize")
 2950        if temp.numpy().var() > 0:
 2951            myrig = ants.registration(
 2952                    fixed, temp,
 2953                    type_of_transform='antsRegistrationSyNRepro[r]',
 2954                    outprefix=ofnL+str(k).zfill(4)+"_",
 2955                    **kwargs
 2956                )
 2957            if type_of_transform == 'SyN':
 2958                myreg = ants.registration(
 2959                    fixed, temp,
 2960                    type_of_transform='SyNOnly',
 2961                    total_sigma=total_sigma,
 2962                    initial_transform=myrig['fwdtransforms'][0],
 2963                    outprefix=ofnL+str(k).zfill(4)+"_",
 2964                    **kwargs
 2965                )
 2966            else:
 2967                myreg = myrig
 2968            fdptsTxI = ants.apply_transforms_to_points(
 2969                idim - 1, fdpts, myreg["fwdtransforms"]
 2970            )
 2971            if k > 0 and motion_parameters[k - 1] != "NA":
 2972                fdptsTxIminus1 = ants.apply_transforms_to_points(
 2973                    idim - 1, fdpts, motion_parameters[k - 1]
 2974                )
 2975            else:
 2976                fdptsTxIminus1 = fdptsTxI
 2977            # take the absolute value, then the mean across columns, then the sum
 2978            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
 2979            motion_parameters.append(myreg["fwdtransforms"])
 2980            img1w = ants.apply_transforms( fixed,
 2981                ants.slice_image(image, axis=idim - 1, idx=k),
 2982                myreg["fwdtransforms"] )
 2983            motion_corrected.append(img1w)
 2984        else:
 2985            motion_parameters.append("NA")
 2986            motion_corrected.append(temp)
 2987
 2988    if remove_it:
 2989        import shutil
 2990        shutil.rmtree(output_directory, ignore_errors=True )
 2991
 2992    if verbose:
 2993        print("Done")
 2994    return {
 2995        "motion_corrected": ants.list_to_ndimage(image, motion_corrected),
 2996        "motion_parameters": motion_parameters,
 2997        "FD": FD,
 2998    }
 2999
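# A hedged usage sketch for mc_reg on a 4D time series (the file path is a
# placeholder); pass type_of_transform='SyN' to add a deformable stage:
#
#   import ants
#   bold = ants.image_read('sub-01_bold.nii.gz')
#   moco = mc_reg(bold, type_of_transform='antsRegistrationSyNRepro[r]', verbose=True)
#   corrected = moco['motion_corrected']
#   fd = moco['FD']
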
 3000def map_scalar_to_labels(dataframe, label_image_template):
 3001    """
 3002    Map scalar values from a DataFrame to associated integer image labels.
 3003
 3004    Parameters:
 3005    - dataframe (pd.DataFrame): A Pandas DataFrame containing a label column and scalar_value column.
 3006    - label_image_template (ants.ANTsImage): ANTs image with (at least some of) the same values as labels.
 3007
 3008    Returns:
 3009    - ants.ANTsImage: A label image with scalar values mapped to associated integer labels.
 3010    """
 3011
 3012    # Create an empty label image with the same geometry as the template
 3013    mapped_label_image = label_image_template.clone() * 0.0
 3014
 3015    # Loop through DataFrame and map scalar values to labels
 3016    for index, row in dataframe.iterrows():
 3017        label = int(row['label'])  # Assuming the DataFrame has a 'label' column
 3018        scalar_value = row['scalar_value']  # Replace with your column name
 3019        mapped_label_image[label_image_template == label] = scalar_value
 3020
 3021    return mapped_label_image
 3022
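# A small illustrative sketch of map_scalar_to_labels; label_image_template is
# assumed to be an ANTs image containing integer labels 1, 2 and 3:
#
#   import pandas as pd
#   df = pd.DataFrame({'label': [1, 2, 3], 'scalar_value': [0.5, 0.8, 1.2]})
#   scalar_img = map_scalar_to_labels(df, label_image_template)
#   # voxels labeled 1 now hold 0.5, voxels labeled 2 hold 0.8, and so on;
#   # labels absent from the DataFrame remain zero.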
 3023
 3024def template_figure_with_overlay(scalar_label_df, prefix, outputfilename=None, template='cit168', xyz=None, mask_dilation=25, padding=12, verbose=True):
 3025    """
 3026    Process and visualize images with mapped scalar values.
 3027
 3028    Parameters:
 3029    - scalar_label_df (pd.DataFrame): A Pandas DataFrame containing scalar values and labels.
 3030    - prefix (str): The prefix for input image files.
 3031    - template (str, optional): Template for selecting image data (default is 'cit168').
  3032    - xyz (list of int, optional): the [x, y, z] voxel indices of the slices to display.
 3033    - mask_dilation (int, optional): Dilation factor for creating a mask (default is 25).
 3034    - padding (int, optional): Padding value for the mapped images (default is 12).
 3035    - verbose (bool, optional): Enable verbose mode for printing (default is True).
 3036
 3037    Example Usage:
 3038    >>> scalar_label_df = pd.DataFrame({'label': [1, 2, 3], 'scalar_value': [0.5, 0.8, 1.2]})
 3039    >>> prefix = '../PPMI_template0_'
  3040    >>> template_figure_with_overlay(scalar_label_df, prefix, template='cit168', xyz=None, mask_dilation=25, padding=12, verbose=True)
 3041    """
 3042
 3043    # Template image paths
 3044    template_paths = {
 3045        'cit168': 'cit168lab.nii.gz',
 3046        'bf': 'bf.nii.gz',
 3047        'cerebellum': 'cerebellum.nii.gz',
 3048        'mtl': 'mtl.nii.gz',
 3049        'ctx': 'dkt_cortex.nii.gz',
 3050        'jhuwm': 'JHU_wm.nii.gz'
 3051    }
 3052
 3053    if template not in template_paths:
 3054        print( "Valid options:")
 3055        print( template_paths )
 3056        raise ValueError(f"Template option '{template}' does not exist.")
 3057
 3058    template_image_path = template_paths[template]
 3059    template_image = ants.image_read(f'{prefix}{template_image_path}')
 3060
 3061    # Load image data
 3062    edgeimg = ants.image_read(f'{prefix}edge.nii.gz')
 3063    dktimg = ants.image_read(f'{prefix}dkt_parcellation.nii.gz')
 3064    segimg = ants.image_read(f'{prefix}tissue_segmentation.nii.gz')
 3065    ttl = ''
 3066
 3067    # Load and process the template image
 3068    ventricles = ants.threshold_image(dktimg, 4, 4) + ants.threshold_image(dktimg, 43, 43)
 3069    seggm = ants.mask_image(segimg, segimg, [2, 4], binarize=False)
 3070    edgeimg = edgeimg.clone()
 3071    edgeimg[edgeimg == 0] = ventricles[edgeimg == 0]
 3072    segwm = ants.threshold_image(segimg, 3, 4).morphology("open", 1)
 3073
 3074    # Define cropping mask
 3075    cmask = ants.threshold_image(template_image, 1, 1.e9).iMath("MD", mask_dilation)
 3076
 3077    mapped_image = map_scalar_to_labels(scalar_label_df, template_image)
 3078    tcrop = ants.crop_image(template_image, cmask)
 3079    toviz = ants.crop_image(mapped_image, cmask)
 3080    seggm = ants.crop_image(edgeimg, cmask)
 3081       
 3082    # Map scalar values to labels and visualize
 3083    toviz = ants.pad_image(toviz, pad_width=(padding, padding, padding))
 3084    seggm = ants.pad_image(seggm, pad_width=(padding, padding, padding))
 3085    tcrop = ants.pad_image(tcrop, pad_width=(padding, padding, padding))
 3086
 3087    if xyz is None:
 3088        if template == 'cit168':
 3089            xyz=[140, 89, 94]
 3090        elif template == 'bf':
 3091            xyz=[114,92,76]
 3092        elif template == 'cerebellum':
 3093            xyz=[169, 128, 137]
 3094        elif template == 'mtl':
 3095            xyz=[154, 112, 113]
 3096        elif template == 'ctx':
 3097            xyz=[233, 190, 174]
 3098        elif template == 'jhuwm':
 3099            xyz=[146, 133, 182]
 3100
 3101    if verbose:
 3102        print("plot xyz for " + template )
 3103        print( xyz )
 3104        
 3105    if outputfilename is None:
 3106        temp = ants.plot_ortho( seggm, overlay=toviz, crop=False,
 3107                        xyz=xyz, cbar_length=0.2, cbar_vertical=False,
 3108                        flat=True, xyz_lines=False, resample=False, orient_labels=False,
 3109                        title=ttl, titlefontsize=12, title_dy=-0.02, textfontcolor='red', 
 3110                        cbar=True, allow_xyz_change=False)
 3111    else:
 3112        temp = ants.plot_ortho( seggm, overlay=toviz, crop=False,
 3113                    xyz=xyz, cbar_length=0.2, cbar_vertical=False,
 3114                    flat=True, xyz_lines=False, resample=False, orient_labels=False,
 3115                    title=ttl, titlefontsize=12, title_dy=-0.02, textfontcolor='red', 
 3116                    cbar=True, allow_xyz_change=False, filename=outputfilename )
 3117    seggm = temp['image']
 3118    toviz = temp['overlay']
 3119    return { "underlay": seggm, 'overlay': toviz, 'seg': tcrop  }
 3120
 3121def get_data( name=None, force_download=False, version=26, target_extension='.csv' ):
 3122    """
 3123    Get ANTsPyMM data filename
 3124
 3125    The first time this is called, it will download data to ~/.antspymm.
  3126    Afterwards, it will simply read data from disk.  The ~/.antspymm folder
  3127    may need to be periodically deleted in order to ensure data is current.
 3128
 3129    Arguments
 3130    ---------
 3131    name : string
 3132        name of data tag to retrieve
 3133        Options:
 3134            - 'all'
 3135
 3136    force_download: boolean
 3137
 3138    version: version of data to download (integer)
 3139
 3140    Returns
 3141    -------
 3142    string
 3143        filepath of selected data
 3144
 3145    Example
 3146    -------
 3147    >>> import antspymm
 3148    >>> antspymm.get_data()
 3149    """
 3150    os.makedirs(DATA_PATH, exist_ok=True)
 3151
 3152    def mv_subfolder_files(folder, verbose=False):
 3153        """
 3154        Move files from subfolders to the parent folder.
 3155
 3156        Parameters
 3157        ----------
 3158        folder : str
 3159            Path to the folder.
 3160        verbose : bool, optional
 3161            Print information about the files and folders being processed (default is False).
 3162
 3163        Returns
 3164        -------
 3165        None
 3166        """
 3167        import os
 3168        import shutil
 3169        for root, dirs, files in os.walk(folder):
 3170            if verbose:
 3171                print(f"Processing directory: {root}")
 3172                print(f"Subdirectories: {dirs}")
 3173                print(f"Files: {files}")
 3174            
 3175            for file in files:
 3176                if root != folder:
 3177                    if verbose:
 3178                        print(f"Moving file: {file} from {root} to {folder}")
 3179                    shutil.move(os.path.join(root, file), folder)
 3180            
 3181            for dir in dirs:
 3182                if root != folder:
 3183                    if verbose:
 3184                        print(f"Removing directory: {dir} from {root}")
 3185                    shutil.rmtree(os.path.join(root, dir))
 3186
 3187    def download_data( version ):
 3188        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
 3189        target_file_name = "16912366.zip"
 3190        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
 3191            cache_subdir=DATA_PATH, extract = True )
 3192        mv_subfolder_files( os.path.expanduser("~/.antspymm"), False )
 3193        os.remove( DATA_PATH + target_file_name )
 3194
 3195    if force_download:
 3196        download_data( version = version )
 3197
 3198
 3199    files = []
 3200    for fname in os.listdir(DATA_PATH):
 3201        if ( fname.endswith(target_extension) ) :
 3202            fname = os.path.join(DATA_PATH, fname)
 3203            files.append(fname)
 3204
 3205    if len( files ) == 0 :
 3206        download_data( version = version )
 3207        for fname in os.listdir(DATA_PATH):
 3208            if ( fname.endswith(target_extension) ) :
 3209                fname = os.path.join(DATA_PATH, fname)
 3210                files.append(fname)
 3211
 3212
 3213    if name == 'all':
 3214        return files
 3215
 3216    datapath = None
 3217
 3218    for fname in os.listdir(DATA_PATH):
 3219        mystem = (Path(fname).resolve().stem)
 3220        mystem = (Path(mystem).resolve().stem)
 3221        mystem = (Path(mystem).resolve().stem)
 3222        if ( name == mystem and fname.endswith(target_extension) ) :
 3223            datapath = os.path.join(DATA_PATH, fname)
 3224
 3225    return datapath
 3226
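# A hedged usage sketch for get_data; 'some_tag' is a placeholder for an actual
# data tag present in the downloaded figshare archive:
#
#   import antspymm
#   csv_files = antspymm.get_data('all')        # every cached file matching target_extension
#   one_path = antspymm.get_data('some_tag')    # single filepath, or None if no match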
 3227
 3228def get_models( version=3, force_download=True ):
 3229    """
 3230    Get ANTsPyMM data models
 3231
 3232    force_download: boolean
 3233
 3234    Returns
 3235    -------
 3236    None
 3237
 3238    """
 3239    os.makedirs(DATA_PATH, exist_ok=True)
 3240
 3241    def download_data( version ):
 3242        url = "https://figshare.com/ndownloader/articles/21718412/versions/"+str(version)
 3243        target_file_name = "21718412.zip"
 3244        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
 3245            cache_subdir=DATA_PATH, extract = True )
 3246        os.remove( DATA_PATH + target_file_name )
 3247
 3248    if force_download:
 3249        download_data( version = version )
 3250    return
 3251
 3252
 3253
 3254def dewarp_imageset( image_list, initial_template=None,
 3255    iterations=None, padding=0, target_idx=[0], **kwargs ):
 3256    """
 3257    Dewarp a set of images
 3258
 3259    Makes simplifying heuristic decisions about how to transform an image set
 3260    into an unbiased reference space.  Will handle plenty of decisions
 3261    automatically so beware.  Computes an average shape space for the images
 3262    and transforms them to that space.
 3263
 3264    Arguments
 3265    ---------
 3266    image_list : list containing antsImages 2D, 3D or 4D
 3267
 3268    initial_template : optional
 3269
 3270    iterations : number of template building iterations
 3271
 3272    padding:  will pad the images by an integer amount to limit edge effects
 3273
 3274    target_idx : the target indices for the time series over which we should average;
 3275        a list of integer indices into the last axis of the input images.
 3276
 3277    kwargs : keyword args
 3278        arguments passed to ants registration - these must be set explicitly
 3279
 3280    Returns
 3281    -------
 3282    a dictionary with the mean image and the list of the transformed images as
 3283    well as motion correction parameters for each image in the input list
 3284
 3285    Example
 3286    -------
 3287    >>> import antspymm
 3288    """
 3289    outlist = []
 3290    avglist = []
 3291    if len(image_list[0].shape) > 3:
 3292        imagetype = 3
 3293        for k in range(len(image_list)):
 3294            for j in range(len(target_idx)):
 3295                avglist.append( ants.slice_image( image_list[k], axis=3, idx=target_idx[j] ) )
 3296    else:
 3297        imagetype = 0
 3298        avglist=image_list
 3299
 3300    pw=[]
 3301    for k in range(len(avglist[0].shape)):
 3302        pw.append( padding )
 3303    for k in range(len(avglist)):
 3304        avglist[k] = ants.pad_image( avglist[k], pad_width=pw  )
 3305
 3306    if initial_template is None:
 3307        initial_template = avglist[0] * 0
 3308        for k in range(len(avglist)):
 3309            initial_template = initial_template + avglist[k]/len(avglist)
 3310
 3311    if iterations is None:
 3312        iterations = 2
 3313
 3314    btp = ants.build_template(
 3315        initial_template=initial_template,
 3316        image_list=avglist,
 3317        gradient_step=0.5, blending_weight=0.8,
 3318        iterations=iterations, verbose=False, **kwargs )
 3319
 3320    # last - warp all images to this frame
 3321    mocoplist = []
 3322    mocofdlist = []
 3323    reglist = []
 3324    for k in range(len(image_list)):
 3325        if imagetype == 3:
 3326            moco0 = ants.motion_correction( image=image_list[k], fixed=btp, type_of_transform='antsRegistrationSyNRepro[r]' )
 3327            mocoplist.append( moco0['motion_parameters'] )
 3328            mocofdlist.append( moco0['FD'] )
 3329            locavg = ants.slice_image( moco0['motion_corrected'], axis=3, idx=0 ) * 0.0
 3330            for j in range(len(target_idx)):
 3331                locavg = locavg + ants.slice_image( moco0['motion_corrected'], axis=3, idx=target_idx[j] )
 3332            locavg = locavg * 1.0 / len(target_idx)
 3333        else:
 3334            locavg = image_list[k]
 3335        reg = ants.registration( btp, locavg, **kwargs )
 3336        reglist.append( reg )
 3337        if imagetype == 3:
 3338            myishape = image_list[k].shape
 3339            mytslength = myishape[ len(myishape) - 1 ]
 3340            mywarpedlist = []
 3341            for j in range(mytslength):
 3342                locimg = ants.slice_image( image_list[k], axis=3, idx = j )
 3343                mywarped = ants.apply_transforms( btp, locimg,
 3344                    reg['fwdtransforms'] + moco0['motion_parameters'][j], imagetype=0 )
 3345                mywarpedlist.append( mywarped )
 3346            mywarped = ants.list_to_ndimage( image_list[k], mywarpedlist )
 3347        else:
 3348            mywarped = ants.apply_transforms( btp, image_list[k], reg['fwdtransforms'], imagetype=imagetype )
 3349        outlist.append( mywarped )
 3350
 3351    return {
 3352        'dewarpedmean':btp,
 3353        'dewarped':outlist,
 3354        'deformable_registrations': reglist,
 3355        'FD': mocofdlist,
 3356        'motionparameters': mocoplist }
 3357
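# A hedged usage sketch for dewarp_imageset; file names are placeholders and the
# registration arguments must be supplied explicitly via kwargs:
#
#   import ants
#   imgs = [ants.image_read('rep1_dwi.nii.gz'), ants.image_read('rep2_dwi.nii.gz')]
#   dw = dewarp_imageset(imgs, iterations=2, padding=4, target_idx=[0, 1, 2],
#                        type_of_transform='SyN')
#   mean_space = dw['dewarpedmean']     # unbiased average shape image
#   dewarped = dw['dewarped']           # inputs resampled into that space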
 3358
 3359def super_res_mcimage( image,
 3360    srmodel,
 3361    truncation=[0.0001,0.995],
 3362    poly_order='hist',
 3363    target_range=[0,1],
 3364    isotropic = False,
 3365    verbose=False ):
 3366    """
 3367    Super resolution on a timeseries or multi-channel image
 3368
 3369    Arguments
 3370    ---------
 3371    image : an antsImage
 3372
 3373    srmodel : a tensorflow fully convolutional model
 3374
 3375    truncation :  quantiles at which we truncate intensities to limit impact of outliers e.g. [0.005,0.995]
 3376
 3377    poly_order : if not None, will fit a global regression model to map
 3378        intensity back to original histogram space; if 'hist' will match
 3379        by histogram matching - ants.histogram_match_image
 3380
 3381    target_range : 2-element tuple
 3382        a tuple or array defining the (min, max) of the input image
 3383        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
 3384        intensity. This range should match the mapping used in the training
 3385        of the network.
 3386
 3387    isotropic : boolean
 3388
 3389    verbose : boolean
 3390
 3391    Returns
 3392    -------
 3393    super resolution version of the image
 3394
 3395    Example
 3396    -------
 3397    >>> import antspymm
 3398    """
 3399    idim = image.dimension
 3400    ishape = image.shape
 3401    nTimePoints = ishape[idim - 1]
 3402    mcsr = list()
 3403    for k in range(nTimePoints):
 3404        if verbose and (( k % 5 ) == 0 ):
 3405            mycount = round(k / nTimePoints * 100)
 3406            print(mycount, end="%.", flush=True)
 3407        temp = ants.slice_image( image, axis=idim - 1, idx=k )
 3408        temp = ants.iMath( temp, "TruncateIntensity", truncation[0], truncation[1] )
 3409        mysr = antspynet.apply_super_resolution_model_to_image( temp, srmodel,
 3410            target_range = target_range )
 3411        if poly_order is not None:
 3412            bilin = ants.resample_image_to_target( temp, mysr )
 3413            if poly_order == 'hist':
 3414                mysr = ants.histogram_match_image( mysr, bilin )
 3415            else:
 3416                mysr = ants.regression_match_image( mysr, bilin, poly_order = poly_order )
 3417        if isotropic:
 3418            mysr = down2iso( mysr )
 3419        if k == 0:
 3420            upshape = list()
 3421            for j in range(len(ishape)-1):
 3422                upshape.append( mysr.shape[j] )
 3423            upshape.append( ishape[ idim-1 ] )
 3424            if verbose:
 3425                print("SR will be of voxel size:" + str(upshape) )
 3426        mcsr.append( mysr )
 3427
 3428    upshape = list()
 3429    for j in range(len(ishape)-1):
 3430        upshape.append( mysr.shape[j] )
 3431    upshape.append( ishape[ idim-1 ] )
 3432    if verbose:
 3433        print("SR will be of voxel size:" + str(upshape) )
 3434
 3435    imageup = ants.resample_image( image, upshape, use_voxels = True )
 3436    if verbose:
 3437        print("Done")
 3438
 3439    return ants.list_to_ndimage( imageup, mcsr )
 3440
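# A hedged usage sketch for super_res_mcimage; the model file is a placeholder for
# any tensorflow super-resolution model accepted by
# antspynet.apply_super_resolution_model_to_image:
#
#   import ants
#   import tensorflow as tf
#   dwi = ants.image_read('sub-01_dwi.nii.gz')
#   srmodel = tf.keras.models.load_model('srmodel.h5', compile=False)
#   dwi_sr = super_res_mcimage(dwi, srmodel, isotropic=True, verbose=True)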
 3441
 3442def segment_timeseries_by_bvalue(bvals):
 3443    """
 3444    Segments a time series based on a threshold applied to b-values.
 3445    
 3446    This function categorizes indices of the given b-values array into two groups:
 3447    one for indices where b-values are above a near-zero threshold, and another
 3448    where b-values are at or below this threshold. The threshold is set to 1e-12.
 3449    
 3450    Parameters:
 3451    - bvals (numpy.ndarray): An array of b-values.
 3452
 3453    Returns:
 3454    - dict: A dictionary with two keys, 'largerbvals' and 'lowbvals', each containing
 3455      the indices of bvals where the b-values are above and at/below the threshold, respectively.
 3456    """
 3457    # Define the threshold
 3458    threshold = 1e-12
 3459    def find_min_value(data):
 3460        if isinstance(data, list):
 3461            return min(data)
 3462        elif isinstance(data, np.ndarray):
 3463            return np.min(data)
 3464        else:
 3465            raise TypeError("Input must be either a list or a numpy array")
 3466
 3467    # Get indices where b-values are greater than the threshold
 3468    lowermeans = list(np.where(bvals > threshold)[0])
 3469    
 3470    # Get indices where b-values are less than or equal to the threshold
 3471    highermeans = list(np.where(bvals <= threshold)[0])
 3472    
 3473    if len(highermeans) == 0:
 3474        minval = find_min_value( bvals )
 3475        lowermeans = list(np.where(bvals > minval )[0])
 3476        highermeans = list(np.where(bvals <= minval)[0])
 3477
 3478    return {
 3479        'largerbvals': lowermeans,
 3480        'lowbvals': highermeans
 3481    }
 3482
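# A quick illustrative sketch of segment_timeseries_by_bvalue:
#
#   import numpy as np
#   bvals = np.array([0, 0, 1000, 1000, 2000])
#   idx = segment_timeseries_by_bvalue(bvals)
#   # idx['lowbvals']    -> [0, 1]      (b0 volumes)
#   # idx['largerbvals'] -> [2, 3, 4]   (diffusion-weighted volumes)
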
 3483def segment_timeseries_by_meanvalue( image, quantile = 0.995 ):
 3484    """
 3485    Identify indices of a time series where we assume there is a different mean
 3486    intensity over the volumes.  The indices of volumes with higher and lower
  3487    intensities are returned.  Can be used to automatically identify B0 volumes
 3488    in DWI timeseries.
 3489
 3490    Arguments
 3491    ---------
 3492    image : an antsImage holding B0 and DWI
 3493
 3494    quantile : a quantile for splitting the indices of the volume - should be greater than 0.5
 3495
 3496    Returns
 3497    -------
 3498    dictionary holding the two sets of indices
 3499
 3500    Example
 3501    -------
 3502    >>> import antspymm
 3503    """
 3504    ishape = image.shape
 3505    lastdim = len(ishape)-1
 3506    meanvalues = list()
 3507    for x in range(ishape[lastdim]):
 3508        meanvalues.append(  ants.slice_image( image, axis=lastdim, idx=x ).mean() )
 3509    myhiq = np.quantile( meanvalues, quantile )
 3510    myloq = np.quantile( meanvalues, 1.0 - quantile )
 3511    lowerindices = list()
 3512    higherindices = list()
 3513    for x in range(len(meanvalues)):
 3514        hiabs = abs( meanvalues[x] - myhiq )
 3515        loabs = abs( meanvalues[x] - myloq )
 3516        if hiabs < loabs:
 3517            higherindices.append(x)
 3518        else:
 3519            lowerindices.append(x)
 3520
 3521    return {
 3522    'lowermeans':lowerindices,
 3523    'highermeans':higherindices }
 3524
 3525
 3526def get_average_rsf( x, min_t=10, max_t=35 ):
 3527    """
  3528    automatically generates the average bold image with quick registration;
  3529    volumes with indices in [min_t, max_t) are used to build the average
 3530    returns:
 3531        avg_bold
 3532    """
 3533    output_directory = tempfile.mkdtemp()
 3534    ofn = output_directory + "/w"
 3535    bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
 3536    oavg = ants.slice_image( x, axis=3, idx=0 )
 3537    if x.shape[3] <= min_t:
 3538        min_t=0
 3539    if x.shape[3] <= max_t:
 3540        max_t=x.shape[3]
 3541    for myidx in range(min_t,max_t):
 3542        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3543        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3544    bavg = ants.iMath( bavg, 'Normalize' )
 3545    oavg = ants.image_clone( bavg )
 3546    bavg = oavg * 0.0
 3547    for myidx in range(min_t,max_t):
 3548        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3549        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3550    import shutil
 3551    shutil.rmtree(output_directory, ignore_errors=True )
 3552    bavg = ants.iMath( bavg, 'Normalize' )
 3553    return bavg
 3554    # return ants.n4_bias_field_correction(bavg, mask=ants.get_mask( bavg ) )
 3555
 3556
 3557def get_average_dwi_b0( x, fixed_b0=None, fixed_dwi=None, fast=False ):
 3558    """
 3559    automatically generates the average b0 and dwi and outputs both;
 3560    maps dwi to b0 space at end.
 3561
 3562    x : input image
 3563
  3564    fixed_b0 : alternative reference space
  3565
  3566    fixed_dwi : alternative reference space
 3567
 3568    fast : boolean
 3569
 3570    returns:
 3571        avg_b0, avg_dwi
 3572    """
 3573    output_directory = tempfile.mkdtemp()
 3574    ofn = output_directory + "/w"
 3575    temp = segment_timeseries_by_meanvalue( x )
 3576    b0_idx = temp['highermeans']
 3577    non_b0_idx = temp['lowermeans']
 3578    if ( fixed_b0 is None and fixed_dwi is None ) or fast:
 3579        xavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
 3580        bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
 3581        fixed_b0_use = ants.slice_image( x, axis=3, idx=b0_idx[0] )
 3582        fixed_dwi_use = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
 3583    else:
 3584        temp_b0 = ants.slice_image( x, axis=3, idx=b0_idx[0] )
 3585        temp_dwi = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
 3586        xavg = fixed_b0 * 0.0
 3587        bavg = fixed_b0 * 0.0
 3588        tempreg = ants.registration( fixed_b0, temp_b0,'antsRegistrationSyNRepro[r]')
 3589        fixed_b0_use = tempreg['warpedmovout']
 3590        fixed_dwi_use = ants.apply_transforms( fixed_b0, temp_dwi, tempreg['fwdtransforms'] )
 3591    for myidx in range(x.shape[3]):
 3592        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3593        if not fast:
 3594            if not myidx in b0_idx:
 3595                xavg = xavg + ants.registration(fixed_dwi_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3596            else:
 3597                bavg = bavg + ants.registration(fixed_b0_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3598        else:
 3599            if not myidx in b0_idx:
 3600                xavg = xavg + b0
 3601            else:
 3602                bavg = bavg + b0
 3603    bavg = ants.iMath( bavg, 'Normalize' )
 3604    xavg = ants.iMath( xavg, 'Normalize' )
 3605    import shutil
 3606    shutil.rmtree(output_directory, ignore_errors=True )
 3607    avgb0=ants.n4_bias_field_correction(bavg)
 3608    avgdwi=ants.n4_bias_field_correction(xavg)
 3609    avgdwi=ants.registration( avgb0, avgdwi, 'antsRegistrationSyNRepro[r]' )['warpedmovout']
 3610    return avgb0, avgdwi
 3611
 3612def dti_template(
 3613    b_image_list=None,
 3614    w_image_list=None,
 3615    iterations=5,
 3616    gradient_step=0.5,
 3617    mask_csf=False,
 3618    average_both=True,
 3619    verbose=False
 3620):
 3621    """
 3622    two channel version of build_template
 3623
 3624    returns:
 3625        avg_b0, avg_dwi
 3626    """
 3627    output_directory = tempfile.mkdtemp()
 3628    mydeftx = tempfile.NamedTemporaryFile(delete=False,dir=output_directory).name
 3629    tmp = tempfile.NamedTemporaryFile(delete=False,dir=output_directory,suffix=".nii.gz")
 3630    wavgfn = tmp.name
 3631    tmp2 = tempfile.NamedTemporaryFile(delete=False,dir=output_directory)
 3632    comptx = tmp2.name
 3633    weights = np.repeat(1.0 / len(b_image_list), len(b_image_list))
 3634    weights = [x / sum(weights) for x in weights]
 3635    w_initial_template = w_image_list[0]
 3636    b_initial_template = b_image_list[0]
 3637    b_initial_template = ants.iMath(b_initial_template,"Normalize")
 3638    w_initial_template = ants.iMath(w_initial_template,"Normalize")
  3639    if mask_csf:
  3640        # build one CSF-excluding mask per input image so lists longer than two are handled
  3641        bcsf = [ ants.threshold_image( b, "Otsu", 2 ).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
  3642                 for b in b_image_list ]
  3643    else:
  3644        bcsf = [ b * 0 + 1 for b in b_image_list ]
  3645    bavg = b_initial_template.clone() * bcsf[0]
  3646    wavg = w_initial_template.clone() * bcsf[0]
  3647
 3648    for i in range(iterations):
 3649        for k in range(len(w_image_list)):
 3650            fimg=wavg
 3651            mimg=w_image_list[k] * bcsf[k]
 3652            fimg2=bavg
 3653            mimg2=b_image_list[k] * bcsf[k]
 3654            w1 = ants.registration(
 3655                fimg, mimg, type_of_transform='antsRegistrationSyNQuickRepro[s]',
 3656                    multivariate_extras= [ [ "CC", fimg2, mimg2, 1, 2 ]],
 3657                    outprefix=mydeftx,
 3658                    verbose=0 )
 3659            txname = ants.apply_transforms(wavg, wavg,
 3660                w1["fwdtransforms"], compose=comptx )
 3661            if k == 0:
 3662                txavg = ants.image_read(txname) * weights[k]
 3663                wavgnew = ants.apply_transforms( wavg,
 3664                    w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3665                bavgnew = ants.apply_transforms( wavg,
 3666                    b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3667            else:
 3668                txavg = txavg + ants.image_read(txname) * weights[k]
 3669                if i >= (iterations-2) and average_both:
 3670                    wavgnew = wavgnew+ants.apply_transforms( wavg,
 3671                        w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3672                    bavgnew = bavgnew+ants.apply_transforms( wavg,
 3673                        b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3674        if verbose:
 3675            print("iteration:",str(i),str(txavg.abs().mean()))
 3676        wscl = (-1.0) * gradient_step
 3677        txavg = txavg * wscl
 3678        ants.image_write( txavg, wavgfn )
 3679        wavg = ants.apply_transforms(wavg, wavgnew, wavgfn).iMath("Normalize")
 3680        bavg = ants.apply_transforms(bavg, bavgnew, wavgfn).iMath("Normalize")
 3681    import shutil
 3682    shutil.rmtree( output_directory, ignore_errors=True )
 3683    if verbose:
 3684        print("done")
 3685    return bavg, wavg
 3686
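# A hedged usage sketch for dti_template, building a two-channel (b0 + average DWI)
# template from two hypothetical repeats:
#
#   b0_a, dwi_a = get_average_dwi_b0(dwi_repeat_a)
#   b0_b, dwi_b = get_average_dwi_b0(dwi_repeat_b)
#   b0_tmpl, dwi_tmpl = dti_template(b_image_list=[b0_a, b0_b],
#                                    w_image_list=[dwi_a, dwi_b],
#                                    iterations=5, verbose=True)
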
 3687def read_ants_transforms_to_numpy(transform_files ):
 3688    """
 3689    Read a list of ANTs transform files and convert them to a NumPy array.
 3690    The function filters out any files that are not .mat and will only  use
 3691    the first .mat in each entry of the list.
 3692
 3693    :param transform_files: List of lists of file paths to ANTs transform files.  
 3694    :return: NumPy array of the transforms.
 3695    """
 3696    extension = '.mat'
 3697    # Filter the list of lists
 3698    filtered_lists = [[string for string in sublist if string.endswith(extension)] 
 3699                    for sublist in transform_files]
 3700    transforms = []
 3701    for file in filtered_lists:
 3702        transform = ants.read_transform(file[0])
 3703        np_transform = np.array(ants.get_ants_transform_parameters(transform)[0:9])
 3704        transforms.append(np_transform)
 3705    return np.array(transforms)
 3706
 3707def t1_based_dwi_brain_extraction(
 3708    t1w_head,
 3709    t1w,
 3710    dwi,
 3711    b0_idx = None,
 3712    transform='antsRegistrationSyNRepro[r]',
 3713    deform=None,
 3714    verbose=False
 3715):
 3716    """
 3717    Map a t1-based brain extraction to b0 and return a mask and average b0
 3718
 3719    Arguments
 3720    ---------
  3721    t1w_head : an antsImage of the whole head
 3722
 3723    t1w : an antsImage probably but not necessarily T1-weighted
 3724
 3725    dwi : an antsImage holding B0 and DWI
 3726
 3727    b0_idx : the indices of the B0; if None, use segment_timeseries_by_meanvalue to guess
 3728
  3729    transform : string, e.g. 'Rigid' or another ants.registration transform type
  3730
  3731    deform : if not None, follow the initial transform with a deformable (SyNOnly) registration
 3732
 3733    Returns
 3734    -------
 3735    dictionary holding the avg_b0 and its mask
 3736
 3737    Example
 3738    -------
 3739    >>> import antspymm
 3740    """
 3741    t1w_use = ants.iMath( t1w, "Normalize" )
 3742    t1bxt = ants.threshold_image( t1w_use, 0.05, 1 ).iMath("FillHoles")
 3743    if b0_idx is None:
 3744        b0_idx = segment_timeseries_by_meanvalue( dwi )['highermeans']
 3745    # first get the average b0
 3746    if len( b0_idx ) > 1:
 3747        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] ).iMath("Normalize")
 3748        for n in range(1,len(b0_idx)):
 3749            temp = ants.slice_image( dwi, axis=3, idx=b0_idx[n] )
 3750            reg = ants.registration( b0_avg, temp, 'antsRegistrationSyNRepro[r]' )
 3751            b0_avg = b0_avg + ants.iMath( reg['warpedmovout'], "Normalize")
 3752    else:
 3753        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] )
 3754    b0_avg = ants.iMath(b0_avg,"Normalize")
 3755    reg = tra_initializer( b0_avg, t1w, n_simulations=12,   verbose=verbose )
 3756    if deform is not None:
 3757        reg = ants.registration( b0_avg, t1w,
 3758            'SyNOnly',
 3759            total_sigma=0.5,
 3760            initial_transform=reg['fwdtransforms'][0],
 3761            verbose=False )
 3762    outmsk = ants.apply_transforms( b0_avg, t1bxt, reg['fwdtransforms'], interpolator='linear').threshold_image( 0.5, 1.0 )
 3763    return  {
 3764    'b0_avg':b0_avg,
 3765    'b0_mask':outmsk }
 3766
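# A hedged usage sketch for t1_based_dwi_brain_extraction; the three images are
# placeholders for a whole-head T1, its brain-extracted version, and a 4D DWI:
#
#   bxt = t1_based_dwi_brain_extraction(t1_head, t1_brain, dwi, deform=True)
#   b0_avg = bxt['b0_avg']      # average b0 in native DWI space
#   b0_mask = bxt['b0_mask']    # T1-derived brain mask resampled onto the b0
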
 3767def mc_denoise( x, ratio = 0.5 ):
 3768    """
 3769    ants denoising for timeseries (4D)
 3770
 3771    Arguments
 3772    ---------
 3773    x : an antsImage 4D
 3774
 3775    ratio : weight between 1 and 0 - lower weights bring result closer to initial image
 3776
 3777    Returns
 3778    -------
 3779    denoised time series
 3780
 3781    """
 3782    dwpimage = []
 3783    for myidx in range(x.shape[3]):
 3784        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3785        dnzb0 = ants.denoise_image( b0, p=1,r=1,noise_model='Gaussian' )
 3786        dwpimage.append( dnzb0 * ratio + b0 * (1.0-ratio) )
 3787    return ants.list_to_ndimage( x, dwpimage )
 3788
 3789def tsnr( x, mask, indices=None ):
 3790    """
 3791    3D temporal snr image from a 4D time series image ... the matrix is normalized to range of 0,1
 3792
 3793    x: image
 3794
 3795    mask : mask
 3796
 3797    indices: indices to use
 3798
 3799    returns a 3D image
 3800    """
 3801    M = ants.timeseries_to_matrix( x, mask )
 3802    M = M - M.min()
 3803    M = M / M.max()
 3804    if indices is not None:
 3805        M=M[indices,:]
 3806    stdM = np.std(M, axis=0 )
 3807    stdM[np.isnan(stdM)] = 0
 3808    tt = round( 0.975*100 )
 3809    threshold_std = np.percentile( stdM, tt )
 3810    tsnrimage = ants.make_image( mask, stdM )
 3811    return tsnrimage
 3812
 3813def dvars( x,  mask, indices=None ):
 3814    """
 3815    dvars on a time series image ... the matrix is normalized to range of 0,1
 3816
 3817    x: image
 3818
 3819    mask : mask
 3820
 3821    indices: indices to use
 3822
 3823    returns an array
 3824    """
 3825    M = ants.timeseries_to_matrix( x, mask )
 3826    M = M - M.min()
 3827    M = M / M.max()
 3828    if indices is not None:
 3829        M=M[indices,:]
 3830    DVARS = np.zeros( M.shape[0] )
 3831    for i in range(1, M.shape[0] ):
 3832        vecdiff = M[i-1,:] - M[i,:]
 3833        DVARS[i] = np.sqrt( ( vecdiff * vecdiff ).mean() )
 3834    DVARS[0] = DVARS.mean()
 3835    return DVARS
 3836
 3837
 3838def slice_snr( x,  background_mask, foreground_mask, indices=None ):
 3839    """
 3840    slice-wise SNR on a time series image
 3841
 3842    x: image
 3843
 3844    background_mask : mask - maybe CSF or background or dilated brain mask minus original brain mask
 3845
 3846    foreground_mask : mask - maybe cortex or WM or brain mask
 3847
 3848    indices: indices to use
 3849
 3850    returns an array
 3851    """
 3852    xuse=ants.iMath(x,"Normalize")
 3853    MB = ants.timeseries_to_matrix( xuse, background_mask )
 3854    MF = ants.timeseries_to_matrix( xuse, foreground_mask )
 3855    if indices is not None:
 3856        MB=MB[indices,:]
 3857        MF=MF[indices,:]
 3858    ssnr = np.zeros( MB.shape[0] )
 3859    for i in range( MB.shape[0] ):
 3860        ssnr[i]=MF[i,:].mean()/MB[i,:].std()
 3861    ssnr[np.isnan(ssnr)] = 0
 3862    return ssnr
 3863
 3864
 3865def impute_fa( fa, md ):
 3866    """
  3867    impute bad values (where FA == 1) in the fa and md images
 3868    """
 3869    def imputeit( x, fa ):
 3870        badfa=ants.threshold_image(fa,1,1)
 3871        if badfa.max() == 1:
 3872            temp=ants.image_clone(x)
 3873            temp[badfa==1]=0
 3874            temp=ants.iMath(temp,'GD',2)
 3875            x[ badfa==1 ]=temp[badfa==1]
 3876        return x
 3877    md=imputeit( md, fa )
 3878    fa=imputeit( ants.image_clone(fa), fa )
 3879    return fa, md
 3880
 3881def trim_dti_mask( fa, mask, param=4.0 ):
 3882    """
 3883    trim the dti mask to get rid of bright fa rim
 3884
  3885    this function erodes the input mask by param amount, then segments the rim into
  3886    bright and less bright parts.  the bright parts are trimmed from the mask
  3887    and the remaining edges are cleaned up a bit with closing.
  3888
  3889    param: erosion/closing radius, in physical-space units
 3890    """
 3891    spacing = ants.get_spacing(mask)
 3892    spacing_product = np.prod( spacing )
 3893    spcmin = min( spacing )
 3894    paramVox = int(np.round( param / spcmin ))
 3895    trim_mask = ants.image_clone( mask )
 3896    trim_mask = ants.iMath( trim_mask, "FillHoles" )
 3897    edgemask = trim_mask - ants.iMath( trim_mask, "ME", paramVox )
 3898    maxk=4
 3899    edgemask = ants.threshold_image( fa * edgemask, "Otsu", maxk )
 3900    edgemask = ants.threshold_image( edgemask, maxk-1, maxk )
 3901    trim_mask[edgemask >= 1 ]=0
 3902    trim_mask = ants.iMath(trim_mask,"ME",paramVox-1)
 3903    trim_mask = ants.iMath(trim_mask,'GetLargestComponent')
 3904    trim_mask = ants.iMath(trim_mask,"MD",paramVox-1)
 3905    return trim_mask
 3906
 3907
 3908
 3909def efficient_tensor_fit( gtab, fit_method, imagein, maskin, diffusion_model='DTI', 
 3910                         chunk_size=10, num_threads=1, verbose=True):
 3911    """
 3912    Efficient and optionally parallelized tensor reconstruction using DiPy.
 3913
 3914    Parameters
 3915    ----------
 3916    gtab : GradientTable
 3917        Dipy gradient table.
 3918    fit_method : str
 3919        Tensor fitting method (e.g. 'WLS', 'OLS', 'RESTORE').
 3920    imagein : ants.ANTsImage
 3921        4D diffusion-weighted image.
 3922    maskin : ants.ANTsImage
 3923        Binary brain mask image.
 3924    diffusion_model : string, optional
  3925        'DTI' or 'FreeWater' (DKI is handled by efficient_dwi_fit).
 3926    chunk_size : int, optional
 3927        Number of slices (along z-axis) to process at once.
 3928    num_threads : int, optional
 3929        Number of threads to use (1 = single-threaded).
 3930    verbose : bool, optional
 3931        Print status updates.
 3932    
 3933    Returns
 3934    -------
 3935    tenfit : TensorFit or FreeWaterTensorFit
 3936        Fitted tensor model.
 3937    FA : ants.ANTsImage
 3938        Fractional anisotropy image.
 3939    MD : ants.ANTsImage
 3940        Mean diffusivity image.
 3941    RGB : ants.ANTsImage
 3942        RGB FA map.
 3943    """
 3944    assert imagein.dimension == 4, "Input image must be 4D"
 3945
 3946    import ants
 3947    import numpy as np
 3948    import dipy.reconst.dti as dti
 3949    import dipy.reconst.fwdti as fwdti
 3950    from dipy.reconst.dti import fractional_anisotropy
 3951    from dipy.reconst.dti import color_fa
 3952    from concurrent.futures import ThreadPoolExecutor, as_completed
 3953
 3954    img_data = imagein.numpy()
 3955    mask = maskin.numpy().astype(bool)
 3956    X, Y, Z, N = img_data.shape
 3957    if verbose:
 3958        print(f"Input shape: {img_data.shape}, Processing in chunks of {chunk_size} slices.")
 3959
 3960    model = fwdti.FreeWaterTensorModel(gtab) if diffusion_model == 'FreeWater' else dti.TensorModel(gtab, fit_method=fit_method)
 3961
 3962    def process_chunk(z_start):
 3963        z_end = min(Z, z_start + chunk_size)
 3964        local_data = img_data[:, :, z_start:z_end, :]
 3965        local_mask = mask[:, :, z_start:z_end]
 3966        masked_data = local_data * local_mask[..., None]
 3967        masked_data = np.nan_to_num(masked_data, nan=0)
 3968        fit = model.fit(masked_data)
 3969        FA_chunk = fractional_anisotropy(fit.evals)
 3970        FA_chunk[np.isnan(FA_chunk)] = 1
 3971        FA_chunk = np.clip(FA_chunk, 0, 1)
 3972        MD_chunk = dti.mean_diffusivity(fit.evals)
 3973        RGB_chunk = color_fa(FA_chunk, fit.evecs)
 3974        return z_start, z_end, FA_chunk, MD_chunk, RGB_chunk
 3975
 3976    FA_vol = np.zeros((X, Y, Z), dtype=np.float32)
 3977    MD_vol = np.zeros((X, Y, Z), dtype=np.float32)
 3978    RGB_vol = np.zeros((X, Y, Z, 3), dtype=np.float32)
 3979
 3980    chunks = range(0, Z, chunk_size)
 3981    if num_threads > 1:
 3982        with ThreadPoolExecutor(max_workers=num_threads) as executor:
 3983            futures = {executor.submit(process_chunk, z): z for z in chunks}
 3984            for f in as_completed(futures):
 3985                z_start, z_end, FA_chunk, MD_chunk, RGB_chunk = f.result()
 3986                FA_vol[:, :, z_start:z_end] = FA_chunk
 3987                MD_vol[:, :, z_start:z_end] = MD_chunk
 3988                RGB_vol[:, :, z_start:z_end, :] = RGB_chunk
 3989    else:
 3990        for z in chunks:
 3991            z_start, z_end, FA_chunk, MD_chunk, RGB_chunk = process_chunk(z)
 3992            FA_vol[:, :, z_start:z_end] = FA_chunk
 3993            MD_vol[:, :, z_start:z_end] = MD_chunk
 3994            RGB_vol[:, :, z_start:z_end, :] = RGB_chunk
 3995
 3996    b0 = ants.slice_image(imagein, axis=3, idx=0)
 3997    FA = ants.copy_image_info(b0, ants.from_numpy(FA_vol))
 3998    MD = ants.copy_image_info(b0, ants.from_numpy(MD_vol))
 3999    RGB_channels = [ants.copy_image_info(b0, ants.from_numpy(RGB_vol[..., i])) for i in range(3)]
 4000    RGB = ants.merge_channels(RGB_channels)
 4001
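    # the chunked passes above only populate FA/MD/RGB; the model object that is
    # returned below comes from one additional whole-volume fit of the masked data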
 4002    return model.fit(img_data * mask[..., None]), FA, MD, RGB
 4003
 4004
 4005
 4006def efficient_dwi_fit(gtab, diffusion_model, imagein, maskin,
 4007                      model_params=None, bvals_to_use=None,
 4008                      chunk_size=1024, num_threads=1, verbose=True):
 4009    """
 4010    Efficient and optionally parallelized diffusion model reconstruction using DiPy.
 4011
 4012    Parameters
 4013    ----------
 4014    gtab : GradientTable
 4015        DiPy gradient table.
 4016    diffusion_model : str
 4017        One of ['DTI', 'FreeWater', 'DKI'].
 4018    imagein : ants.ANTsImage
 4019        4D diffusion-weighted image.
 4020    maskin : ants.ANTsImage
 4021        Binary brain mask image.
 4022    model_params : dict, optional
 4023        Additional parameters passed to model constructors.
 4024    bvals_to_use : list of int, optional
 4025        Subset of b-values to use for the fit (e.g., [0, 1000, 2000]).
 4026    chunk_size : int, optional
 4027        Maximum number of voxels per chunk (default 1024).
 4028    num_threads : int, optional
 4029        Number of parallel threads.
 4030    verbose : bool, optional
 4031        Whether to print status messages.
 4032
 4033    Returns
 4034    -------
 4035    fit : dipy ModelFit
 4036        The fitted model object.
 4037    FA : ants.ANTsImage or None
 4038        Fractional anisotropy image (if applicable).
 4039    MD : ants.ANTsImage or None
 4040        Mean diffusivity image (if applicable).
 4041    RGB : ants.ANTsImage or None
 4042        Color FA image (if applicable).
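
    Example
    -------
    A minimal, hedged sketch of a chunked DTI fit restricted to a b-value subset;
    the file names and b-values below are hypothetical placeholders.

    >>> import ants
    >>> from dipy.io.gradients import read_bvals_bvecs
    >>> from dipy.core.gradients import gradient_table
    >>> bvals, bvecs = read_bvals_bvecs('dwi.bval', 'dwi.bvec')
    >>> gtab = gradient_table(bvals, bvecs=bvecs)
    >>> dwi = ants.image_read('dwi.nii.gz')
    >>> mask = ants.image_read('dwi_mask.nii.gz')
    >>> fit, FA, MD, RGB = efficient_dwi_fit(gtab, 'DTI', dwi, mask,
    ...     bvals_to_use=[0, 1000], num_threads=2)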
 4043    """
 4044    import ants
 4045    import numpy as np
 4046    import dipy.reconst.dti as dti
 4047    import dipy.reconst.fwdti as fwdti
 4048    import dipy.reconst.dki as dki
 4049    from dipy.core.gradients import gradient_table
 4050    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
 4051    from concurrent.futures import ThreadPoolExecutor, as_completed
 4052
 4053    assert imagein.dimension == 4, "Input image must be 4D"
 4054    model_params = model_params or {}
 4055
 4056    img_data = imagein.numpy()
 4057    mask = maskin.numpy().astype(bool)
 4058    X, Y, Z, N = img_data.shape
 4059    inplane_size = X * Y
 4060
 4061    # Convert chunk_size from voxel count to number of slices
 4062    slices_per_chunk = max(1, chunk_size // inplane_size)
 4063
 4064    if verbose:
 4065        print(f"[INFO] Image shape: {img_data.shape}")
 4066        print(f"[INFO] Using model: {diffusion_model}")
 4067        print(f"[INFO] Max voxels per chunk: {chunk_size} (→ {slices_per_chunk} slices) | Threads: {num_threads}")
 4068
 4069    if bvals_to_use is not None:
 4070        bvals_to_use = set(bvals_to_use)
 4071        sel = np.isin(gtab.bvals, list(bvals_to_use))
 4072        img_data = img_data[..., sel]
 4073        gtab = gradient_table(gtab.bvals[sel], bvecs=gtab.bvecs[sel])
 4074        if verbose:
 4075            print(f"[INFO] Selected b-values: {sorted(bvals_to_use)}")
 4076            print(f"[INFO] Selected volumes: {sel.sum()} / {N}")
 4077
 4078    def get_model(name, gtab, **params):
 4079        if name == 'DTI':
 4080            return dti.TensorModel(gtab, **params)
 4081        elif name == 'FreeWater':
 4082            return fwdti.FreeWaterTensorModel(gtab)
 4083        elif name == 'DKI':
 4084            return dki.DiffusionKurtosisModel(gtab, **params)
 4085        else:
 4086            raise ValueError(f"Unsupported model: {name}")
 4087
 4088    model = get_model(diffusion_model, gtab, **model_params)
 4089
 4090    FA_vol = np.zeros((X, Y, Z), dtype=np.float32)
 4091    MD_vol = np.zeros((X, Y, Z), dtype=np.float32)
 4092    RGB_vol = np.zeros((X, Y, Z, 3), dtype=np.float32)
 4093    has_tensor_metrics = diffusion_model in ['DTI', 'FreeWater']
 4094
 4095    def process_chunk(z_start):
 4096        z_end = min(Z, z_start + slices_per_chunk)
 4097        local_data = img_data[:, :, z_start:z_end, :]
 4098        local_mask = mask[:, :, z_start:z_end]
 4099        masked_data = local_data * local_mask[..., None]
 4100        masked_data = np.nan_to_num(masked_data, nan=0)
 4101        fit = model.fit(masked_data)
 4102        if has_tensor_metrics and hasattr(fit, 'evals') and hasattr(fit, 'evecs'):
 4103            FA = fractional_anisotropy(fit.evals)
 4104            FA[np.isnan(FA)] = 1
 4105            FA = np.clip(FA, 0, 1)
 4106            MD = mean_diffusivity(fit.evals)
 4107            RGB = color_fa(FA, fit.evecs)
 4108            return z_start, z_end, FA, MD, RGB
 4109        return z_start, z_end, None, None, None
 4110
 4111    chunks = range(0, Z, slices_per_chunk)
 4112    if num_threads > 1:
 4113        with ThreadPoolExecutor(max_workers=num_threads) as executor:
 4114            futures = {executor.submit(process_chunk, z): z for z in chunks}
 4115            for f in as_completed(futures):
 4116                z_start, z_end, FA, MD, RGB = f.result()
 4117                if FA is not None:
 4118                    FA_vol[:, :, z_start:z_end] = FA
 4119                    MD_vol[:, :, z_start:z_end] = MD
 4120                    RGB_vol[:, :, z_start:z_end, :] = RGB
 4121    else:
 4122        for z in chunks:
 4123            z_start, z_end, FA, MD, RGB = process_chunk(z)
 4124            if FA is not None:
 4125                FA_vol[:, :, z_start:z_end] = FA
 4126                MD_vol[:, :, z_start:z_end] = MD
 4127                RGB_vol[:, :, z_start:z_end, :] = RGB
 4128
 4129    b0 = ants.slice_image(imagein, axis=3, idx=0)
 4130    FA_img = ants.copy_image_info(b0, ants.from_numpy(FA_vol)) if has_tensor_metrics else None
 4131    MD_img = ants.copy_image_info(b0, ants.from_numpy(MD_vol)) if has_tensor_metrics else None
 4132    RGB_img = (ants.merge_channels([
 4133        ants.copy_image_info(b0, ants.from_numpy(RGB_vol[..., i])) for i in range(3)
 4134    ]) if has_tensor_metrics else None)
 4135
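    # the chunked passes above only populate FA/MD/RGB; one additional whole-volume
    # fit of the masked data provides the returned model object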
 4136    full_fit = model.fit(img_data * mask[..., None])
 4137    return full_fit, FA_img, MD_img, RGB_img
 4138
 4139
 4140def efficient_dwi_fit_voxelwise(imagein, maskin, bvals, bvecs_5d, model_params=None,
 4141                                bvals_to_use=None, num_threads=1, verbose=True):
 4142    """
 4143    Voxel-wise diffusion model fitting with individual b-vectors per voxel.
 4144
 4145    Parameters
 4146    ----------
 4147    imagein : ants.ANTsImage
 4148        4D DWI image (X, Y, Z, N).
 4149    maskin : ants.ANTsImage
 4150        3D binary mask.
 4151    bvals : (N,) array-like
 4152        Common b-values across volumes.
 4153    bvecs_5d : (X, Y, Z, N, 3) ndarray
 4154        Voxel-specific b-vectors.
 4155    model_params : dict
 4156        Extra arguments for model.
 4157    bvals_to_use : list[int]
 4158        Subset of b-values to include.
 4159    num_threads : int
 4160        Number of threads to use.
 4161    verbose : bool
 4162        Whether to print status.
 4163
 4164    Returns
 4165    -------
 4166    FA_img : ants.ANTsImage
 4167        Fractional anisotropy.
 4168    MD_img : ants.ANTsImage
 4169        Mean diffusivity.
 4170    RGB_img : ants.ANTsImage
 4171        RGB FA image.
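
    Example
    -------
    A hedged sketch; dwi, mask, bvals, bvecs and voxel_rotations are assumed to
    already exist, with the rotation field typically derived from a
    deformation-field Jacobian (see generate_voxelwise_bvecs).

    >>> bvecs_5d = generate_voxelwise_bvecs(bvecs, voxel_rotations)
    >>> FA, MD, RGB = efficient_dwi_fit_voxelwise(dwi, mask, bvals, bvecs_5d, num_threads=4)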
 4172    """
 4173    import numpy as np
 4174    import ants
 4175    import dipy.reconst.dti as dti
 4176    from dipy.core.gradients import gradient_table
 4177    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
 4178    from concurrent.futures import ThreadPoolExecutor
 4179    from tqdm import tqdm
 4180
 4181    model_params = model_params or {}
 4182    img = imagein.numpy()
 4183    mask = maskin.numpy().astype(bool)
 4184    X, Y, Z, N = img.shape
 4185
 4186    if bvals_to_use is not None:
 4187        sel = np.isin(bvals, bvals_to_use)
 4188        img = img[..., sel]
 4189        bvals = bvals[sel]
 4190        bvecs_5d = bvecs_5d[..., sel, :]
 4191
 4192    FA = np.zeros((X, Y, Z), dtype=np.float32)
 4193    MD = np.zeros((X, Y, Z), dtype=np.float32)
 4194    RGB = np.zeros((X, Y, Z, 3), dtype=np.float32)
 4195
 4196    def fit_voxel(ix, iy, iz):
 4197        if not mask[ix, iy, iz]:
 4198            return
 4199        sig = img[ix, iy, iz, :]
 4200        if np.all(sig == 0):
 4201            return
 4202        bv = bvecs_5d[ix, iy, iz, :, :]
 4203        gtab = gradient_table(bvals, bvecs=bv)
 4204        try:
 4205            model = dti.TensorModel(gtab, **model_params)
 4206            fit = model.fit(sig)
 4207            evals = fit.evals
 4208            evecs = fit.evecs
 4209            FA[ix, iy, iz] = fractional_anisotropy(evals)
 4210            MD[ix, iy, iz] = mean_diffusivity(evals)
 4211            RGB[ix, iy, iz, :] = color_fa(FA[ix, iy, iz], evecs)
 4212        except Exception as e:
 4213            if verbose:
 4214                print(f"Voxel ({ix},{iy},{iz}) fit failed: {e}")
 4215
 4216    coords = np.argwhere(mask)
 4217    if verbose:
 4218        print(f"[INFO] Fitting {len(coords)} voxels using {num_threads} threads...")
 4219
 4220    if num_threads > 1:
 4221        with ThreadPoolExecutor(max_workers=num_threads) as executor:
 4222            list(tqdm(executor.map(lambda c: fit_voxel(*c), coords), total=len(coords)))
 4223    else:
 4224        for c in tqdm(coords):
 4225            fit_voxel(*c)
 4226
 4227    ref = ants.slice_image(imagein, axis=3, idx=0)
 4228    return (
 4229        ants.copy_image_info(ref, ants.from_numpy(FA)),
 4230        ants.copy_image_info(ref, ants.from_numpy(MD)),
 4231        ants.merge_channels([ants.copy_image_info(ref, ants.from_numpy(RGB[..., i])) for i in range(3)])
 4232    )
 4233
 4234
 4235def generate_voxelwise_bvecs(global_bvecs, voxel_rotations, transpose=False):
 4236    """
 4237    Generate voxel-wise b-vectors from a global bvec and voxel-wise rotation field.
 4238
 4239    Parameters
 4240    ----------
 4241    global_bvecs : ndarray of shape (N, 3)
 4242        Global diffusion gradient directions.
 4243    voxel_rotations : ndarray of shape (X, Y, Z, 3, 3)
 4244        3x3 rotation matrix for each voxel (can come from Jacobian of deformation field).
 4245    transpose : bool, optional
 4246        If True, transpose the rotation matrices before applying them to the b-vectors.
 4247
 4248
 4249    Returns
 4250    -------
 4251    bvecs_5d : ndarray of shape (X, Y, Z, N, 3)
 4252        Voxel-specific b-vectors.
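
    Example
    -------
    A small self-contained sketch with identity rotations, just to illustrate
    the expected shapes.

    >>> import numpy as np
    >>> bvecs = np.random.randn(7, 3)
    >>> bvecs /= np.linalg.norm(bvecs, axis=1, keepdims=True)
    >>> rot = np.tile(np.eye(3), (4, 4, 4, 1, 1))
    >>> generate_voxelwise_bvecs(bvecs, rot).shape
    (4, 4, 4, 7, 3)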
 4253    """
 4254    X, Y, Z, _, _ = voxel_rotations.shape
 4255    N = global_bvecs.shape[0]
 4256    bvecs_5d = np.zeros((X, Y, Z, N, 3), dtype=np.float32)
 4257
 4258    for n in range(N):
 4259        bvec = global_bvecs[n]
 4260        for i in range(X):
 4261            for j in range(Y):
 4262                for k in range(Z):
 4263                    R = voxel_rotations[i, j, k]
 4264                    if transpose:
 4265                        R = R.T  # Use transpose if needed
 4266                    bvecs_5d[i, j, k, n, :] = R @ bvec
 4267    return bvecs_5d
 4268
 4269def dipy_dti_recon(
 4270    image,
 4271    bvalsfn,
 4272    bvecsfn,
 4273    mask = None,
 4274    b0_idx = None,
 4275    mask_dilation = 2,
 4276    mask_closing = 5,
 4277    fit_method='WLS',
 4278    trim_the_mask=2.0,
 4279    diffusion_model='DTI',
 4280    verbose=False ):
 4281    """
 4282    DiPy DTI reconstruction - building on the DiPy basic DTI example
 4283
 4284    Arguments
 4285    ---------
 4286    image : an antsImage holding B0 and DWI
 4287
 4288    bvalsfn : bvalues  obtained by dipy read_bvals_bvecs or the values themselves
 4289
 4290    bvecsfn : bvectors obtained by dipy read_bvals_bvecs or the values themselves
 4291
 4292    mask : brain mask for the DWI/DTI reconstruction; if it is not in the same
 4293        space as the image, we will resample directly to the image space.  This
 4294        could lead to problems if the inputs are really incorrect.
 4295
 4296    b0_idx : the indices of the B0; if None, use segment_timeseries_by_bvalue
 4297
 4298    mask_dilation : integer zero or more dilates the brain mask
 4299
 4300    mask_closing : integer zero or more closes the brain mask
 4301
 4302    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel) ... if None, will not reconstruct DTI.
 4303
 4304    trim_the_mask : float >=0 post-hoc method for trimming the mask
 4305
 4306    diffusion_model : string
 4307        DTI, FreeWater, DKI
 4308
 4309    verbose : boolean
 4310
 4311    Returns
 4312    -------
 4313    dictionary holding the tensorfit, MD, FA and RGB images and motion parameters (optional)
 4314
 4315    NOTE -- see dipy reorient_bvecs(gtab, affines, atol=1e-2)
 4316
 4317    NOTE -- if the bvec.shape[0] is smaller than the image.shape[3], we neglect
 4318        the trailing image volumes.
 4319
 4320    Example
 4321    -------
 4322    >>> import antspymm
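    >>> # a hedged sketch; the file names below are hypothetical placeholders
    >>> import ants
    >>> dwi = ants.image_read('dwi.nii.gz')
    >>> recon = antspymm.dipy_dti_recon(dwi, 'dwi.bval', 'dwi.bvec', verbose=True)
    >>> fa = recon['FA']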
 4323    """
 4324
 4325    import dipy.reconst.fwdti as fwdti
 4326
 4327    if isinstance(bvecsfn, str):
 4328        bvals, bvecs = read_bvals_bvecs( bvalsfn , bvecsfn   )
 4329    else: # assume we already read them
 4330        bvals = bvalsfn.copy()
 4331        bvecs = bvecsfn.copy()
 4332
 4333    if bvals.max() < 1.0:
 4334        raise ValueError("DTI recon error: maximum bvalues are too small.")
 4335
 4336    b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals'] if b0_idx is None else b0_idx
 4337
 4338    b0 = ants.slice_image( image, axis=3, idx=b0_idx[0] )
 4339    # modality hint passed to antspynet.brain_extraction when no mask is provided
 4340    bxtmod='t2'
 4341    constant_mask=False
 4342    if verbose:
 4343        print( np.unique( bvals ), flush=True )
 4344    if mask is not None:
 4345        if verbose:
 4346            print("use set bxt in dipy_dti_recon", flush=True)
 4347        constant_mask=True
 4348        mask = ants.resample_image_to_target( mask, b0, interp_type='nearestNeighbor')
 4349    else:
 4350        if verbose:
 4351            print("use deep learning bxt in dipy_dti_recon")
 4352        mask = antspynet.brain_extraction( b0, bxtmod ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
 4353    if mask_closing > 0 and not constant_mask :
 4354        mask = ants.morphology( mask, "close", mask_closing ) # good
 4355    maskdil = ants.iMath( mask, "MD", mask_dilation )
 4356
 4357    if verbose:
 4358        print("recon dti.TensorModel",flush=True)
 4359
 4360    bvecs = repair_bvecs( bvecs )
 4361    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
 4362    mynt=1
 4363    threads_env = os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")
 4364    if threads_env is not None:
 4365        mynt = int(threads_env)
 4366    tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, maskdil,
 4367                                             num_threads=mynt )
 4368    if verbose:
 4369        print("recon dti.TensorModel done",flush=True)
 4370
 4371    # change the brain mask based on high FA values
 4372    if trim_the_mask > 0 and fit_method is not None:
 4373        mask = trim_dti_mask( FA, mask, trim_the_mask )
 4374        tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, mask,
 4375                                             num_threads=mynt )
 4376
 4377    return {
 4378        'tensormodel' : tenfit,
 4379        'MD' : MD1 ,
 4380        'FA' : FA ,
 4381        'RGB' : RGB,
 4382        'dwi_mask':mask,
 4383        'bvals':bvals,
 4384        'bvecs':bvecs
 4385        }
 4386
 4387
 4388def concat_dewarp(
 4389        refimg,
 4390        originalDWI,
 4391        physSpaceDWI,
 4392        dwpTx,
 4393        motion_parameters,
 4394        motion_correct=True,
 4395        verbose=False ):
 4396    """
 4397    Apply concatenated motion correction and dewarping transforms to a timeseries image.
 4398
 4399    Arguments
 4400    ---------
 4401
 4402    refimg : an antsImage defining the reference domain (3D)
 4403
 4404    originalDWI : the antsImage in original (not interpolated space) (4D)
 4405
 4406    physSpaceDWI : ants antsImage defining the physical space of the mapping (4D)
 4407
 4408    dwpTx : dewarping transform
 4409
 4410    motion_parameters : previously computed list of motion parameters
 4411
 4412    motion_correct : boolean
 4413
 4414    verbose : boolean
 4415
 4416    """
 4417    # apply the dewarping tx to the original dwi and reconstruct again
 4418    # NOTE: refimg must be in the same space for this to work correctly
 4419    # due to the use of ants.list_to_ndimage( originalDWI, dwpimage )
 4420    dwpimage = []
 4421    for myidx in range(originalDWI.shape[3]):
 4422        b0 = ants.slice_image( originalDWI, axis=3, idx=myidx)
 4423        concatx = dwpTx.copy()
 4424        if motion_correct:
 4425            concatx = concatx + motion_parameters[myidx]
 4426        if verbose and myidx == 0:
 4427            print("dwp parameters")
 4428            print( dwpTx )
 4429            print("Motion parameters")
 4430            print( motion_parameters[myidx] )
 4431            print("concat parameters")
 4432            print(concatx)
 4433        warpedb0 = ants.apply_transforms( refimg, b0, concatx,
 4434            interpolator='nearestNeighbor' )
 4435        dwpimage.append( warpedb0 )
 4436    return ants.list_to_ndimage( physSpaceDWI, dwpimage )
 4437
 4438
 4439def joint_dti_recon(
 4440    img_LR,
 4441    bval_LR,
 4442    bvec_LR,
 4443    jhu_atlas,
 4444    jhu_labels,
 4445    reference_B0,
 4446    reference_DWI,
 4447    srmodel = None,
 4448    img_RL = None,
 4449    bval_RL = None,
 4450    bvec_RL = None,
 4451    t1w = None,
 4452    brain_mask = None,
 4453    motion_correct = None,
 4454    dewarp_modality = 'FA',
 4455    denoise=False,
 4456    fit_method='WLS',
 4457    impute = False,
 4458    censor = True,
 4459    diffusion_model = 'DTI',
 4460    verbose = False ):
 4461    """
 4462    1. pass in subject data and 1mm JHU atlas/labels
 4463    2. perform initial LR, RL reconstruction (2nd is optional) and motion correction (optional)
 4464    3. dewarp the images using dewarp_modality or T1w
 4465    4. apply dewarping to the original data
 4466        ===> may want to apply SR at this step
 4467    5. reconstruct DTI again
 4468    6. label images and do registration
 4469    7. return relevant outputs
 4470
 4471    NOTE: RL images are optional; should pass t1w in this case.
 4472
 4473    Arguments
 4474    ---------
 4475
 4476    img_LR : an antsImage holding B0 and DWI LR acquisition
 4477
 4478    bval_LR : bvalue filename LR
 4479
 4480    bvec_LR : bvector filename LR
 4481
 4482    jhu_atlas : atlas FA image
 4483
 4484    jhu_labels : atlas labels
 4485
 4486    reference_B0 : the "target" B0 image space
 4487
 4488    reference_DWI : the "target" DW image space
 4489
 4490    srmodel : optional h5 (tensorflow) model
 4491
 4492    img_RL : an antsImage holding B0 and DWI RL acquisition
 4493
 4494    bval_RL : bvalue filename RL
 4495
 4496    bvec_RL : bvector filename RL
 4497
 4498    t1w : antsimage t1w neuroimage (brain-extracted)
 4499
 4500    brain_mask : mask for the DWI - just 3D - provided brain mask should be in reference_B0 space
 4501
 4502    motion_correct : None Rigid or SyN
 4503
 4504    dewarp_modality : string average_dwi, average_b0, MD or FA
 4505
 4506    denoise: boolean
 4507
 4508    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
 4509
 4510    impute : boolean
 4511
 4512    censor : boolean
 4513
 4514    diffusion_model : string
 4515        DTI, FreeWater, DKI
 4516
 4517    verbose : boolean
 4518
 4519    Returns
 4520    -------
 4521    dictionary holding the mean_fa, its summary statistics via JHU labels,
 4522        the JHU registration, the JHU labels, the dewarping dictionary and the
 4523        dti reconstruction dictionaries.
 4524
 4525    Example
 4526    -------
 4527    >>> import antspymm
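    >>> # a hedged sketch; the LR DWI image, JHU FA atlas and labels, and the
    >>> # average b0 / average DWI reference images are assumed to already exist
    >>> dd = antspymm.joint_dti_recon(
    ...     img_LR=dwi_LR, bval_LR='LR.bval', bvec_LR='LR.bvec',
    ...     jhu_atlas=jhu_fa, jhu_labels=jhu_labels,
    ...     reference_B0=b0_avg, reference_DWI=dwi_avg,
    ...     motion_correct='Rigid', verbose=True)
    >>> fa_summary = dd['recon_fa_summary']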
 4528    """
 4529
 4530    if verbose:
 4531        print("Recon DTI on OR images ...")
 4532
 4533    def fix_dwi_shape( img, bvalfn, bvecfn ):
 4534        if isinstance(bvecfn, str):
 4535            bvals, bvecs = read_bvals_bvecs( bvalfn , bvecfn   )
 4536        else:
 4537            bvals = bvalfn
 4538            bvecs = bvecfn
 4539        if bvecs.shape[0] < img.shape[3]:
 4540            imgout = ants.from_numpy( img[:,:,:,0:bvecs.shape[0]] )
 4541            imgout = ants.copy_image_info( img, imgout )
 4542            return( imgout )
 4543        else:
 4544            return( img )
 4545
 4546    img_LR = fix_dwi_shape( img_LR, bval_LR, bvec_LR )
 4547    if denoise :
 4548        img_LR = mc_denoise( img_LR )
 4549    if img_RL is not None:
 4550        img_RL = fix_dwi_shape( img_RL, bval_RL, bvec_RL )
 4551        if denoise :
 4552            img_RL = mc_denoise( img_RL )
 4553
 4554    brainmaske = None
 4555    if brain_mask is not None:
 4556        maskInRightSpace = ants.image_physical_space_consistency( brain_mask, reference_B0 )
 4557        if not maskInRightSpace :
 4558            raise ValueError('not maskInRightSpace ... provided brain mask should be in reference_B0 space')
 4559        brainmaske = ants.iMath( brain_mask, "ME", 2 )
 4560
 4561    if img_RL is not None :
 4562        if verbose:
 4563            print("img_RL correction")
 4564        reg_RL = dti_reg(
 4565            img_RL,
 4566            avg_b0=reference_B0,
 4567            avg_dwi=reference_DWI,
 4568            bvals=bval_RL,
 4569            bvecs=bvec_RL,
 4570            type_of_transform=motion_correct,
 4571            brain_mask_eroded=brainmaske,
 4572            verbose=True )
 4573    else:
 4574        reg_RL=None
 4575
 4576
 4577    if verbose:
 4578        print("img_LR correction")
 4579    reg_LR = dti_reg(
 4580            img_LR,
 4581            avg_b0=reference_B0,
 4582            avg_dwi=reference_DWI,
 4583            bvals=bval_LR,
 4584            bvecs=bvec_LR,
 4585            type_of_transform=motion_correct,
 4586            brain_mask_eroded=brainmaske,
 4587            verbose=True )
 4588
 4589    ts_LR_avg = None
 4590    ts_RL_avg = None
 4591    reg_its = [100,50,10]
 4592    img_LRdwp = ants.image_clone( reg_LR[ 'motion_corrected' ] )
 4593    if img_RL is not None:
 4594        img_RLdwp = ants.image_clone( reg_RL[ 'motion_corrected' ] )
 4595        if srmodel is not None:
 4596            if verbose:
 4597                print("convert img_RL_dwp to img_RL_dwp_SR")
 4598            img_RLdwp = super_res_mcimage( img_RLdwp, srmodel, isotropic=True,
 4599                        verbose=verbose )
 4600    if srmodel is not None:
 4601        reg_its = [100] + reg_its
 4602        if verbose:
 4603            print("convert img_LR_dwp to img_LR_dwp_SR")
 4604        img_LRdwp = super_res_mcimage( img_LRdwp, srmodel, isotropic=True,
 4605                verbose=verbose )
 4606    if verbose:
 4607        print("recon after distortion correction", flush=True)
 4608
 4609    if impute:
 4610        print("impute begin", flush=True)
 4611        img_LRdwp=impute_dwi( img_LRdwp, verbose=True )
 4612        print("impute done", flush=True)
 4613    elif censor:
 4614        print("censor begin", flush=True)
 4615        img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'] = censor_dwi( img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'], verbose=True )
 4616        print("censor done", flush=True)
 4617    if impute and img_RL is not None:
 4618        img_RLdwp=impute_dwi( img_RLdwp, verbose=True )
 4619    elif censor and img_RL is not None:
 4620        img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'] = censor_dwi( img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'], verbose=True )
 4621
 4622    if img_RL is not None:
 4623        img_LRdwp, bval_LR, bvec_LR = merge_dwi_data(
 4624            img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'],
 4625            img_RLdwp, reg_RL['bvals'], reg_RL['bvecs']
 4626        )
 4627    else:
 4628        bval_LR=reg_LR['bvals']
 4629        bvec_LR=reg_LR['bvecs']
 4630
 4631    if verbose:
 4632        print("final recon", flush=True)
 4633        print(img_LRdwp)
 4634
 4635    recon_LR_dewarp = dipy_dti_recon(
 4636            img_LRdwp, bval_LR, bvec_LR,
 4637            mask = brain_mask,
 4638            fit_method=fit_method,
 4639            mask_dilation=0, diffusion_model=diffusion_model, verbose=True )
 4640    if verbose:
 4641        print("recon done", flush=True)
 4642
 4643    if img_RL is not None:
 4644        fdjoin = [ reg_LR['FD'],
 4645                   reg_RL['FD'] ]
 4646        framewise_displacement=np.concatenate( fdjoin )
 4647    else:
 4648        framewise_displacement=reg_LR['FD']
 4649
 4650    motion_count = ( framewise_displacement > 1.5  ).sum()
 4651    reconFA = recon_LR_dewarp['FA']
 4652    reconMD = recon_LR_dewarp['MD']
 4653
 4654    if verbose:
 4655        print("JHU reg",flush=True)
 4656
 4657    OR_FA2JHUreg = ants.registration( reconFA, jhu_atlas,
 4658        type_of_transform = 'antsRegistrationSyNQuickRepro[s]', 
 4659        reg_iterations=reg_its, verbose=False )
 4660    OR_FA_jhulabels = ants.apply_transforms( reconFA, jhu_labels,
 4661        OR_FA2JHUreg['fwdtransforms'], interpolator='genericLabel')
 4662
 4663    df_FA_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
 4664        'FA_JHU_labels_edited',
 4665        reconFA,
 4666        OR_FA_jhulabels)
 4667    df_FA_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 4668            {'df_FA_JHU_ORRL' : df_FA_JHU_ORRL},
 4669            col_names = ['Mean'] )
 4670
 4671    df_MD_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
 4672        'MD_JHU_labels_edited',
 4673        reconMD,
 4674        OR_FA_jhulabels)
 4675    df_MD_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 4676            {'df_MD_JHU_ORRL' : df_MD_JHU_ORRL},
 4677            col_names = ['Mean'] )
 4678
 4679    temp = segment_timeseries_by_meanvalue( img_LRdwp )
 4680    b0_idx = temp['highermeans']
 4681    non_b0_idx = temp['lowermeans']
 4682
 4683    nonbrainmask = ants.iMath( recon_LR_dewarp['dwi_mask'], "MD",2) - recon_LR_dewarp['dwi_mask']
 4684    fgmask = ants.threshold_image( reconFA, 0.5 , 1.0).iMath("GetLargestComponent")
 4685    bgmask = ants.threshold_image( reconFA, 1e-4 , 0.1)
 4686    fa_SNR = 0.0
 4687    fa_SNR = mask_snr( reconFA, bgmask, fgmask, bias_correct=False )
 4688    fa_evr = antspyt1w.patch_eigenvalue_ratio( reconFA, 512, [16,16,16], evdepth = 0.9, mask=recon_LR_dewarp['dwi_mask'] )
 4689
 4690    dti_itself = get_dti( reconFA, recon_LR_dewarp['tensormodel'], return_image=True )
 4691    return convert_np_in_dict( {
 4692        'dti': dti_itself,
 4693        'recon_fa':reconFA,
 4694        'recon_fa_summary':df_FA_JHU_ORRL_bfwide,
 4695        'recon_md':reconMD,
 4696        'recon_md_summary':df_MD_JHU_ORRL_bfwide,
 4697        'jhu_labels':OR_FA_jhulabels,
 4698        'jhu_registration':OR_FA2JHUreg,
 4699        'reg_LR':reg_LR,
 4700        'reg_RL':reg_RL,
 4701        'dtrecon_LR_dewarp':recon_LR_dewarp,
 4702        'dwi_LR_dewarped':img_LRdwp,
 4703        'bval_unique_count': len(np.unique(bval_LR)),
 4704        'bval_LR':bval_LR,
 4705        'bvec_LR':bvec_LR,
 4706        'bval_RL':bval_RL,
 4707        'bvec_RL':bvec_RL,
 4708        'b0avg': reference_B0,
 4709        'dwiavg': reference_DWI,
 4710        'framewise_displacement':framewise_displacement,
 4711        'high_motion_count': motion_count,
 4712        'tsnr_b0': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
 4713        'tsnr_dwi': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
 4714        'dvars_b0': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
 4715        'dvars_dwi': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
 4716        'ssnr_b0': slice_snr( img_LRdwp, bgmask , fgmask, b0_idx),
 4717        'ssnr_dwi': slice_snr( img_LRdwp, bgmask, fgmask, non_b0_idx),
 4718        'fa_evr': fa_evr,
 4719        'fa_SNR': fa_SNR
 4720    } )
 4721
 4722
 4723def middle_slice_snr( x, background_dilation=5 ):
 4724    """
 4725
 4726    Estimate the signal-to-noise ratio (SNR) in the 2D middle slice of a 3D image.
 4727    Estimates noise from a background mask which is a
 4728    dilation of the foreground mask minus the foreground mask.
 4729    Actually estimates the reciprocal of the coefficient of variation.
 4730
 4731    Arguments
 4732    ---------
 4733
 4734    x : an antsImage
 4735
 4736    background_dilation : integer - amount to dilate foreground mask
 4737
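    Example
    -------
    A small sketch; the file name is a hypothetical placeholder.

    >>> import ants, antspymm
    >>> img = ants.image_read('t1.nii.gz')
    >>> snr = antspymm.middle_slice_snr(img, background_dilation=5)
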
 4738    """
 4739    xshp = x.shape
 4740    xmidslice = ants.slice_image( x, 2, int( xshp[2]/2 )  )
 4741    xmidslice = ants.iMath( xmidslice - xmidslice.min(), "Normalize" )
 4742    xmidslice = ants.n3_bias_field_correction( xmidslice )
 4743    xmidslice = ants.n3_bias_field_correction( xmidslice )
 4744    xmidslicemask = ants.threshold_image( xmidslice, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
 4745    xbkgmask = ants.iMath( xmidslicemask, "MD", background_dilation ) - xmidslicemask
 4746    signal = (xmidslice[ xmidslicemask == 1] ).mean()
 4747    noise = (xmidslice[ xbkgmask == 1] ).std()
 4748    return signal / noise
 4749
 4750def foreground_background_snr( x, background_dilation=10,
 4751        erode_foreground=False):
 4752    """
 4753
 4754    Estimate signal to noise ratio (SNR) in an image.
 4755    Estimates noise from a background mask which is a
 4756    dilation of the foreground mask minus the foreground mask.
 4757    Actually estimates the reciprocal of the coefficient of variation.
 4758
 4759    Arguments
 4760    ---------
 4761
 4762    x : an antsImage
 4763
 4764    background_dilation : integer - amount to dilate foreground mask
 4765
 4766    erode_foreground : boolean - 2nd option which erodes the initial
 4767    foreground mask to create a new foreground mask.  The background
 4768    mask is then the initial mask minus the eroded mask.
 4769
 4770    """
 4771    xshp = x.shape
 4772    xbc = ants.iMath( x - x.min(), "Normalize" )
 4773    xbc = ants.n3_bias_field_correction( xbc )
 4774    xmask = ants.threshold_image( xbc, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
 4775    xbkgmask = ants.iMath( xmask, "MD", background_dilation ) - xmask
 4776    fgmask = xmask
 4777    if erode_foreground:
 4778        fgmask = ants.iMath( xmask, "ME", background_dilation )
 4779        xbkgmask = xmask - fgmask
 4780    signal = (xbc[ fgmask == 1] ).mean()
 4781    noise = (xbc[ xbkgmask == 1] ).std()
 4782    return signal / noise
 4783
 4784def quantile_snr( x,
 4785    lowest_quantile=0.01,
 4786    low_quantile=0.1,
 4787    high_quantile=0.5,
 4788    highest_quantile=0.95 ):
 4789    """
 4790
 4791    Estimate signal to noise ratio (SNR) in an image.
 4792    Estimates noise from a background mask which is a
 4793    dilation of the foreground mask minus the foreground mask.
 4794    Actually estimates the reciprocal of the coefficient of variation.
 4795
 4796    Arguments
 4797    ---------
 4798
 4799    x : an antsImage
 4800
 4801    lowest_quantile : float value < 1 and > 0
 4802
 4803    low_quantile : float value < 1 and > 0
 4804
 4805    high_quantile : float value < 1 and > 0
 4806
 4807    highest_quantile : float value < 1 and > 0
 4808
 4809    """
 4810    import numpy as np
 4811    xshp = x.shape
 4812    xbc = ants.iMath( x - x.min(), "Normalize" )
 4813    xbc = ants.n3_bias_field_correction( xbc )
 4814    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
 4815    y = xbc.numpy()
 4816    ylowest = np.quantile( y[y>0], lowest_quantile )
 4817    ylo = np.quantile( y[y>0], low_quantile )
 4818    yhi = np.quantile( y[y>0], high_quantile )
 4819    yhiest = np.quantile( y[y>0], highest_quantile )
 4820    xbkgmask = ants.threshold_image( xbc, ylowest, ylo )
 4821    fgmask = ants.threshold_image( xbc, yhi, yhiest )
 4822    signal = (xbc[ fgmask == 1] ).mean()
 4823    noise = (xbc[ xbkgmask == 1] ).std()
 4824    return signal / noise
 4825
 4826def mask_snr( x, background_mask, foreground_mask, bias_correct=True ):
 4827    """
 4828
 4829    Estimate signal to noise ratio (SNR) in an image using
 4830    a user-defined foreground and background mask.
 4831    Actually estimates the reciprocal of the coefficient of variation.
 4832
 4833    Arguments
 4834    ---------
 4835
 4836    x : an antsImage
 4837
 4838    background_mask : binary antsImage
 4839
 4840    foreground_mask : binary antsImage
 4841
 4842    bias_correct : boolean
 4843
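    Example
    -------
    A small sketch mirroring how joint_dti_recon uses this function; the FA
    file name is a hypothetical placeholder.

    >>> import ants, antspymm
    >>> fa = ants.image_read('fa.nii.gz')
    >>> fg = ants.threshold_image(fa, 0.5, 1.0)
    >>> bg = ants.threshold_image(fa, 1e-4, 0.1)
    >>> snr = antspymm.mask_snr(fa, bg, fg, bias_correct=False)
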
 4844    """
 4845    import numpy as np
 4846    if foreground_mask.sum() <= 1 or background_mask.sum() <= 1:
 4847        return 0
 4848    xbc = ants.iMath( x - x.min(), "Normalize" )
 4849    if bias_correct:
 4850        xbc = ants.n3_bias_field_correction( xbc )
 4851    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
 4852    signal = (xbc[ foreground_mask == 1] ).mean()
 4853    noise = (xbc[ background_mask == 1] ).std()
 4854    return signal / noise
 4855
 4856
 4857def dwi_deterministic_tracking(
 4858    dwi,
 4859    fa,
 4860    bvals,
 4861    bvecs,
 4862    num_processes=1,
 4863    mask=None,
 4864    label_image = None,
 4865    seed_labels = None,
 4866    fa_thresh = 0.05,
 4867    seed_density = 1,
 4868    step_size = 0.15,
 4869    peak_indices = None,
 4870    fit_method='WLS',
 4871    verbose = False ):
 4872    """
 4873
 4874    Performs deterministic tractography on the DWI and returns the resulting
 4875    streamlines and peak indices.
 4876
 4877    Arguments
 4878    ---------
 4879
 4880    dwi : an antsImage holding DWI acquisition
 4881
 4882    fa : an antsImage holding FA values
 4883
 4884    bvals : bvalues
 4885
 4886    bvecs : bvectors
 4887
 4888    num_processes : number of subprocesses
 4889
 4890    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
 4891        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 4892
 4893    label_image : atlas labels
 4894
 4895    seed_labels : list of label numbers from the atlas labels
 4896
 4897    fa_thresh : FA threshold for the tracking mask, seeding and stopping criterion (default 0.05)
 4898
 4899    seed_density : number of seeds per voxel (default 1)
 4900
 4901    step_size : tracking step size (default 0.15)
 4902
 4903    peak_indices : pass these in, if they are previously estimated.  otherwise, will
 4904        compute on the fly (slow)
 4905
 4906    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
 4907
 4908    verbose : boolean
 4909
 4910    Returns
 4911    -------
 4912    dictionary holding tracts and stateful object.
 4913
 4914    Example
 4915    -------
 4916    >>> import antspymm
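    >>> # a hedged sketch; dwi is a 4D antsImage, the bval/bvec paths are placeholders,
    >>> # and the FA map and mask come from a prior dipy_dti_recon call
    >>> rec = antspymm.dipy_dti_recon(dwi, 'dwi.bval', 'dwi.bvec')
    >>> trk = antspymm.dwi_deterministic_tracking(dwi, rec['FA'], 'dwi.bval', 'dwi.bvec',
    ...     mask=rec['dwi_mask'], fa_thresh=0.2, verbose=True)
    >>> streamlines = trk['streamlines']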
 4917    """
 4918    import os
 4919    import re
 4920    import nibabel as nib
 4921    import numpy as np
 4922    import ants
 4923    from dipy.io.gradients import read_bvals_bvecs
 4924    from dipy.core.gradients import gradient_table
 4925    from dipy.tracking import utils
 4926    import dipy.reconst.dti as dti
 4927    from dipy.segment.clustering import QuickBundles
 4928    from dipy.tracking.utils import path_length
 4929    if verbose:
 4930        print("begin tracking",flush=True)
 4931
 4932    affine = ants_to_nibabel_affine(dwi)
 4933
 4934    if isinstance( bvals, str ) or isinstance( bvecs, str ):
 4935        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
 4936    bvecs = repair_bvecs( bvecs )
 4937    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
 4938    if mask is None:
 4939        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 4940    dwi_data = dwi.numpy()
 4941    dwi_mask = mask.numpy() == 1
 4942    dti_model = dti.TensorModel(gtab,fit_method=fit_method)
 4943    if verbose:
 4944        print("begin tracking fit",flush=True)
 4945    dti_fit = dti_model.fit(dwi_data, mask=dwi_mask)  # This step may take a while
 4946    evecs_img = dti_fit.evecs
 4947    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
 4948    stopping_criterion = ThresholdStoppingCriterion(fa.numpy(), fa_thresh)
 4949    from dipy.data import get_sphere
 4950    sphere = get_sphere(name='symmetric362')
 4951    from dipy.direction import peaks_from_model
 4952    if peak_indices is None:
 4953        # problems with multi-threading ...
 4954        # see https://github.com/dipy/dipy/issues/2519
 4955        if verbose:
 4956            print("begin peaks",flush=True)
 4957        mynump=1
 4958        # if os.getenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"):
 4959        #    mynump = os.environ['ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS']
 4960        # current_openblas = os.environ.get('OPENBLAS_NUM_THREADS', '')
 4961        # current_mkl = os.environ.get('MKL_NUM_THREADS', '')
 4962        # os.environ['DIPY_OPENBLAS_NUM_THREADS'] = current_openblas
 4963        # os.environ['DIPY_MKL_NUM_THREADS'] = current_mkl
 4964        # os.environ['OPENBLAS_NUM_THREADS'] = '1'
 4965        # os.environ['MKL_NUM_THREADS'] = '1'
 4966        peak_indices = peaks_from_model(
 4967            model=dti_model,
 4968            data=dwi_data,
 4969            sphere=sphere,
 4970            relative_peak_threshold=.5,
 4971            min_separation_angle=25,
 4972            mask=dwi_mask,
 4973            npeaks=3, return_odf=False,
 4974            return_sh=False,
 4975            parallel=int(mynump) > 1,
 4976            num_processes=int(mynump)
 4977            )
 4978        if False:
 4979            if 'DIPY_OPENBLAS_NUM_THREADS' in os.environ:
 4980                os.environ['OPENBLAS_NUM_THREADS'] = \
 4981                    os.environ.pop('DIPY_OPENBLAS_NUM_THREADS', '')
 4982                if os.environ['OPENBLAS_NUM_THREADS'] in ['', None]:
 4983                    os.environ.pop('OPENBLAS_NUM_THREADS', '')
 4984            if 'DIPY_MKL_NUM_THREADS' in os.environ:
 4985                os.environ['MKL_NUM_THREADS'] = \
 4986                    os.environ.pop('DIPY_MKL_NUM_THREADS', '')
 4987                if os.environ['MKL_NUM_THREADS'] in ['', None]:
 4988                    os.environ.pop('MKL_NUM_THREADS', '')
 4989
 4990    if label_image is None or seed_labels is None:
 4991        seed_mask = fa.numpy().copy()
 4992        seed_mask[seed_mask >= fa_thresh] = 1
 4993        seed_mask[seed_mask < fa_thresh] = 0
 4994    else:
 4995        labels = label_image.numpy()
 4996        seed_mask = labels * 0
 4997        for u in seed_labels:
 4998            seed_mask[ labels == u ] = 1
 4999    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
 5000    from dipy.tracking.local_tracking import LocalTracking
 5001    from dipy.tracking.streamline import Streamlines
 5002    if verbose:
 5003        print("streamlines begin ...", flush=True)
 5004    streamlines_generator = LocalTracking(
 5005        peak_indices, stopping_criterion, seeds, affine=affine, step_size=step_size)
 5006    streamlines = Streamlines(streamlines_generator)
 5007    from dipy.io.stateful_tractogram import Space, StatefulTractogram
 5008    from dipy.io.streamline import save_tractogram
 5009    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
 5010    if verbose:
 5011        print("streamlines done", flush=True)
 5012    return {
 5013          'tractogram': sft,
 5014          'streamlines': streamlines,
 5015          'peak_indices': peak_indices
 5016          }
 5017
 5018def repair_bvecs( bvecs ):
 5019    bvecnorm = np.linalg.norm(bvecs,axis=1).reshape( bvecs.shape[0],1 )
 5020    # check that each non-zero b-vector has (approximately) unit norm; if not, warn
 5021    # and normalize the offending rows.  zero rows (b=0 volumes) are left as zero.
 5022    nonzero = bvecnorm[:,0] > 1e-16
 5023    if nonzero.any() and np.abs( bvecnorm[nonzero,0] - 1.0 ).max() > 0.009:
 5024        warnings.warn( "Warning: bvecs are not unit norm - we normalize them here but this may indicate a problem with the data.  Max norm deviation is : " + str( np.abs( bvecnorm[nonzero,0] - 1.0 ).max() ) + " shape is " + str( bvecs.shape[0] ) + " " + str( bvecs.shape[1] ))
 5025        bvecs=np.where(bvecnorm > 1e-16, bvecs / bvecnorm, 0)
 5026    return bvecs
 5027
 5028
 5029def dwi_closest_peak_tracking(
 5030    dwi,
 5031    fa,
 5032    bvals,
 5033    bvecs,
 5034    num_processes=1,
 5035    mask=None,
 5036    label_image = None,
 5037    seed_labels = None,
 5038    fa_thresh = 0.05,
 5039    seed_density = 1,
 5040    step_size = 0.15,
 5041    peak_indices = None,
 5042    verbose = False ):
 5043    """
 5044
 5045    Performs closest-peak tractography on the DWI, using a constrained spherical
 5046    deconvolution model, and returns the resulting streamlines.
 5047
 5048    Arguments
 5049    ---------
 5050
 5051    dwi : an antsImage holding DWI acquisition
 5052
 5053    fa : an antsImage holding FA values
 5054
 5055    bvals : bvalues
 5056
 5057    bvecs : bvectors
 5058
 5059    num_processes : number of subprocesses
 5060
 5061    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
 5062        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 5063
 5064    label_image : atlas labels
 5065
 5066    seed_labels : list of label numbers from the atlas labels
 5067
 5068    fa_thresh : FA threshold for the tracking mask and seeding (default 0.05)
 5069
 5070    seed_density : number of seeds per voxel (default 1)
 5071
 5072    step_size : tracking step size (default 0.15)
 5073
 5074    peak_indices : pass these in, if they are previously estimated.  otherwise, will
 5075        compute on the fly (slow)
 5076
 5077    verbose : boolean
 5078
 5079    Returns
 5080    -------
 5081    dictionary holding tracts and stateful object.
 5082
 5083    Example
 5084    -------
 5085    >>> import antspymm
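    >>> # a hedged sketch; inputs are hypothetical placeholders, as in dwi_deterministic_tracking
    >>> trk = antspymm.dwi_closest_peak_tracking(dwi, fa, 'dwi.bval', 'dwi.bvec',
    ...     mask=dwi_mask, fa_thresh=0.2, verbose=True)
    >>> streamlines = trk['streamlines']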
 5086    """
 5087    import os
 5088    import re
 5089    import nibabel as nib
 5090    import numpy as np
 5091    import ants
 5092    from dipy.io.gradients import read_bvals_bvecs
 5093    from dipy.core.gradients import gradient_table
 5094    from dipy.tracking import utils
 5095    import dipy.reconst.dti as dti
 5096    from dipy.segment.clustering import QuickBundles
 5097    from dipy.tracking.utils import path_length
 5098    from dipy.core.gradients import gradient_table
 5099    from dipy.data import small_sphere
 5100    from dipy.direction import BootDirectionGetter, ClosestPeakDirectionGetter
 5101    from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,
 5102                                    auto_response_ssst)
 5103    from dipy.reconst.shm import CsaOdfModel
 5104    from dipy.tracking.local_tracking import LocalTracking
 5105    from dipy.tracking.streamline import Streamlines
 5106    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
 5107
 5108    if verbose:
 5109        print("begin tracking",flush=True)
 5110
 5111    affine = ants_to_nibabel_affine(dwi)
 5112    if isinstance( bvals, str ) or isinstance( bvecs, str ):
 5113        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
 5114    bvecs = repair_bvecs( bvecs )
 5115    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
 5116    if mask is None:
 5117        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 5118    dwi_data = dwi.numpy()
 5119    dwi_mask = mask.numpy() == 1
 5120
 5121
 5122    response, ratio = auto_response_ssst(gtab, dwi_data, roi_radii=10, fa_thr=0.7)
 5123    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
 5124    csd_fit = csd_model.fit(dwi_data, mask=dwi_mask)
 5125    csa_model = CsaOdfModel(gtab, sh_order=6)
 5126    gfa = csa_model.fit(dwi_data, mask=dwi_mask).gfa
 5127    stopping_criterion = ThresholdStoppingCriterion(gfa, .25)
 5128
 5129
 5130    if label_image is None or seed_labels is None:
 5131        seed_mask = fa.numpy().copy()
 5132        seed_mask[seed_mask >= fa_thresh] = 1
 5133        seed_mask[seed_mask < fa_thresh] = 0
 5134    else:
 5135        labels = label_image.numpy()
 5136        seed_mask = labels * 0
 5137        for u in seed_labels:
 5138            seed_mask[ labels == u ] = 1
 5139    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
 5140    if verbose:
 5141        print("streamlines begin ...", flush=True)
 5142
 5143    pmf = csd_fit.odf(small_sphere).clip(min=0)
 5144    if verbose:
 5145        print("ClosestPeakDirectionGetter begin ...", flush=True)
 5146    peak_dg = ClosestPeakDirectionGetter.from_pmf(pmf, max_angle=30.,
 5147                                                sphere=small_sphere)
 5148    if verbose:
 5149        print("local tracking begin ...", flush=True)
 5150    streamlines_generator = LocalTracking(peak_dg, stopping_criterion, seeds,
 5151                                            affine, step_size=step_size)
 5152    streamlines = Streamlines(streamlines_generator)
 5153    from dipy.io.stateful_tractogram import Space, StatefulTractogram
 5154    from dipy.io.streamline import save_tractogram
 5155    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
 5156    if verbose:
 5157        print("streamlines done", flush=True)
 5158    return {
 5159          'tractogram': sft,
 5160          'streamlines': streamlines
 5161          }
 5162
 5163def dwi_streamline_pairwise_connectivity( streamlines, label_image, labels_to_connect=[1,None], verbose=False ):
 5164    """
 5165
 5166    Return streamlines connecting all of the regions in the label set. Ideal
 5167    for just 2 regions.
 5168
 5169    Arguments
 5170    ---------
 5171
 5172    streamlines : streamline object from dipy
 5173
 5174    label_image : atlas labels
 5175
 5176    labels_to_connect : list of 2 labels or [label,None]
 5177
 5178    verbose : boolean
 5179
 5180    Returns
 5181    -------
 5182    the subset of streamlines and a streamline count
 5183
 5184    Example
 5185    -------
 5186    >>> import antspymm
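    >>> # a hedged sketch; streamlines come from a prior tracking call and label_image
    >>> # is a hypothetical label image containing regions 1 and 2
    >>> res = antspymm.dwi_streamline_pairwise_connectivity(streamlines, label_image, [1, 2])
    >>> n_connecting = res['count']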
 5187    """
 5188    from dipy.tracking import utils
 5188    from dipy.tracking.streamline import Streamlines
 5189    keep_streamlines = Streamlines()
 5190
 5191    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5192
 5193    lin_T, offset = utils._mapping_to_voxel(affine)
 5194    label_image_np = label_image.numpy()
 5195    def check_it( sl, target_label, label_image, index, full=False ):
 5196        if full:
 5197            maxind=sl.shape[0]
 5198            for index in range(maxind):
 5199                pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
 5200                mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
 5201                if mylab == target_label[0] or mylab == target_label[1]:
 5202                    return { 'ok': True, 'label':mylab }
 5203        else:
 5204            pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
 5205            mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
 5206            if mylab == target_label[0] or mylab == target_label[1]:
 5207                return { 'ok': True, 'label':mylab }
 5208        return { 'ok': False, 'label':None }
 5209    ct=0
 5210    for k in range( len( streamlines ) ):
 5211        sl = streamlines[k]
 5212        mycheck = check_it( sl, labels_to_connect, label_image_np, index=0, full=True )
 5213        if mycheck['ok']:
 5214            otherind=1
 5215            if mycheck['label'] == labels_to_connect[1]:
 5216                otherind=0
 5217            lsl = len( sl )-1
 5218            pt = utils._to_voxel_coordinates(sl[lsl,:], lin_T, offset)
 5219            mylab_end = (label_image_np[ pt[0], pt[1], pt[2] ]).astype(int)
 5220            accept_point = mylab_end == labels_to_connect[otherind]
 5221            if verbose and accept_point:
 5222                print( mylab_end )
 5223            if labels_to_connect[1] is None:
 5224                accept_point = mylab_end != 0
 5225            if accept_point:
 5226                keep_streamlines.append(sl)
 5227                ct=ct+1
 5228    return { 'streamlines': keep_streamlines, 'count': ct }
 5229
 5230def dwi_streamline_pairwise_connectivity_old(
 5231    streamlines,
 5232    label_image,
 5233    exclusion_label = None,
 5234    verbose = False ):
 5235    import os
 5236    import re
 5237    import nibabel as nib
 5238    import numpy as np
 5239    import ants
 5240    from dipy.io.gradients import read_bvals_bvecs
 5241    from dipy.core.gradients import gradient_table
 5242    from dipy.tracking import utils
 5243    import dipy.reconst.dti as dti
 5244    from dipy.segment.clustering import QuickBundles
 5245    from dipy.tracking.utils import path_length
 5246    from dipy.tracking.local_tracking import LocalTracking
 5247    from dipy.tracking.streamline import Streamlines
 5248    volUnit = np.prod( ants.get_spacing( label_image ) )
 5249    labels = label_image.numpy()
 5250
 5251    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5252
 5253    import numpy as np
 5254    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
 5255    import pandas as pd
 5256    ulabs = np.unique( labels[ labels > 0 ] )
 5257    if exclusion_label is not None:
 5258        ulabs = ulabs[ ulabs != exclusion_label ]
 5259        exc_slice = labels == exclusion_label
 5260    if verbose:
 5261        print("Begin connectivity")
 5262    tracts = []
 5263    for k in range(len(ulabs)):
 5264        cc_slice = labels == ulabs[k]
 5265        cc_streamlines = utils.target(streamlines, affine, cc_slice)
 5266        cc_streamlines = Streamlines(cc_streamlines)
 5267        if exclusion_label is not None:
 5268            cc_streamlines = utils.target(cc_streamlines, affine, exc_slice, include=False)
 5269            cc_streamlines = Streamlines(cc_streamlines)
 5270        for j in range(len(ulabs)):
 5271            cc_slice2 = labels == ulabs[j]
 5272            cc_streamlines2 = utils.target(cc_streamlines, affine, cc_slice2)
 5273            cc_streamlines2 = Streamlines(cc_streamlines2)
 5274            if exclusion_label is not None:
 5275                cc_streamlines2 = utils.target(cc_streamlines2, affine, exc_slice, include=False)
 5276                cc_streamlines2 = Streamlines(cc_streamlines2)
 5277            tracts.append( cc_streamlines2 )
 5278        if verbose:
 5279            print("end connectivity")
 5280    return {
 5281          'pairwise_tracts': tracts
 5282          }
 5283
 5284
 5285def dwi_streamline_connectivity(
 5286    streamlines,
 5287    label_image,
 5288    label_dataframe,
 5289    verbose = False ):
 5290    """
 5291
 5292    Summarize network connectivity of the input streamlines between all of the
 5293    regions in the label set.
 5294
 5295    Arguments
 5296    ---------
 5297
 5298    streamlines : streamline object from dipy
 5299
 5300    label_image : atlas labels
 5301
 5302    label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)
 5303
 5304    verbose : boolean
 5305
 5306    Returns
 5307    -------
 5308    dictionary holding summary connection statistics in wide format and matrix format.
 5309
 5310    Example
 5311    -------
 5312    >>> import antspymm
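    >>> # a hedged sketch; streamlines come from a prior tracking call and label_df is a
    >>> # hypothetical antspy-style dataframe with Label and Description columns
    >>> cnx = antspymm.dwi_streamline_connectivity(streamlines, label_image, label_df)
    >>> cnx_mat = cnx['connectivity_matrix']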
 5313    """
 5314    import os
 5315    import re
 5316    import nibabel as nib
 5317    import numpy as np
 5318    import ants
 5319    from dipy.io.gradients import read_bvals_bvecs
 5320    from dipy.core.gradients import gradient_table
 5321    from dipy.tracking import utils
 5322    import dipy.reconst.dti as dti
 5323    from dipy.segment.clustering import QuickBundles
 5324    from dipy.tracking.utils import path_length
 5325    from dipy.tracking.local_tracking import LocalTracking
 5326    from dipy.tracking.streamline import Streamlines
 5340    volUnit = np.prod( ants.get_spacing( label_image ) )
 5341    labels = label_image.numpy()
 5342
 5343    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5344
 5345    import numpy as np
 5346    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
 5347    import pandas as pd
 5348    ulabs = label_dataframe['Label']
 5349    labels_to_connect = ulabs[ulabs > 0]
 5350    Ctdf = None
 5351    lin_T, offset = utils._mapping_to_voxel(affine)
 5352    label_image_np = label_image.numpy()
 5353    def check_it( sl, target_label, label_image, index, not_label = None ):
 5354        pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
 5355        mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
 5356        if not_label is None:
 5357            if ( mylab == target_label ).sum() > 0 :
 5358                return { 'ok': True, 'label':mylab }
 5359        else:
 5360            if ( mylab == target_label ).sum() > 0 and ( mylab == not_label ).sum() == 0:
 5361                return { 'ok': True, 'label':mylab }
 5362        return { 'ok': False, 'label':None }
 5363    ct=0
 5364    which = lambda lst:list(np.where(lst)[0])
 5365    myCount = np.zeros( [len(ulabs),len(ulabs)])
 5366    for k in range( len( streamlines ) ):
 5367            sl = streamlines[k]
 5368            mycheck = check_it( sl, labels_to_connect, label_image_np, index=0 )
 5369            if mycheck['ok']:
 5370                exclabel=mycheck['label']
 5371                lsl = len( sl )-1
 5372                mycheck2 = check_it( sl, labels_to_connect, label_image_np, index=lsl, not_label=exclabel )
 5373                if mycheck2['ok']:
 5374                    myCount[ulabs == mycheck['label'],ulabs == mycheck2['label']]+=1
 5375                    ct=ct+1
 5376    Ctdf = label_dataframe.copy()
 5377    for k in range(len(ulabs)):
 5378            nn3 = "CnxCount"+str(k).zfill(3)
 5379            Ctdf.insert(Ctdf.shape[1], nn3, myCount[k,:] )
 5380    Ctdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkc': Ctdf },  Ctdf.keys()[2:Ctdf.shape[1]] )
 5381    return { 'connectivity_matrix' :  myCount, 'connectivity_wide' : Ctdfw }
 5382
 5383def dwi_streamline_connectivity_old(
 5384    streamlines,
 5385    label_image,
 5386    label_dataframe,
 5387    verbose = False ):
 5388    """
 5389
 5390    Summarize network connectivity of the input streamlines between all of the
 5391    regions in the label set.
 5392
 5393    Arguments
 5394    ---------
 5395
 5396    streamlines : streamline object from dipy
 5397
 5398    label_image : atlas labels
 5399
 5400    label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)
 5401
 5402    verbose : boolean
 5403
 5404    Returns
 5405    -------
 5406    dictionary holding summary connection statistics in wide format and matrix format.
 5407
 5408    Example
 5409    -------
 5410    >>> import antspymm
 5411    """
 5412
 5413    if verbose:
 5414        print("streamline connections ...")
 5415
 5416    import os
 5417    import re
 5418    import nibabel as nib
 5419    import numpy as np
 5420    import ants
 5421    from dipy.io.gradients import read_bvals_bvecs
 5422    from dipy.core.gradients import gradient_table
 5423    from dipy.tracking import utils
 5424    import dipy.reconst.dti as dti
 5425    from dipy.segment.clustering import QuickBundles
 5426    from dipy.tracking.utils import path_length
 5427    from dipy.tracking.local_tracking import LocalTracking
 5428    from dipy.tracking.streamline import Streamlines
 5429
 5430    volUnit = np.prod( ants.get_spacing( label_image ) )
 5431    labels = label_image.numpy()
 5432
 5433    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5434
 5435    if verbose:
 5436        print("path length begin ... volUnit = " + str( volUnit ) )
 5437    import numpy as np
 5438    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
 5439    import pandas as pd
 5440    ulabs = label_dataframe['Label']
 5441    pathLmean = np.zeros( [len(ulabs)])
 5442    pathLtot = np.zeros( [len(ulabs)])
 5443    pathCt = np.zeros( [len(ulabs)])
 5444    for k in range(len(ulabs)):
 5445        cc_slice = labels == ulabs[k]
 5446        cc_streamlines = utils.target(streamlines, affine, cc_slice)
 5447        cc_streamlines = Streamlines(cc_streamlines)
 5448        if len(cc_streamlines) > 0:
 5449            wmpl = path_length(cc_streamlines, affine, cc_slice)
 5450            mean_path_length = wmpl[wmpl>0].mean()
 5451            total_path_length = wmpl[wmpl>0].sum()
 5452            pathLmean[int(k)] = mean_path_length
 5453            pathLtot[int(k)] = total_path_length
 5454            pathCt[int(k)] = len(cc_streamlines) * volUnit
 5455
 5456    # convert paths to data frames
 5457    pathdf = label_dataframe.copy()
 5458    pathdf.insert(pathdf.shape[1], "mean_path_length", pathLmean )
 5459    pathdf.insert(pathdf.shape[1], "total_path_length", pathLtot )
 5460    pathdf.insert(pathdf.shape[1], "streamline_count", pathCt )
 5461    pathdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 5462        { 'path_length' : pathdf }, ['mean_path_length', 'total_path_length', 'streamline_count'] )
 5463    allconnexwide = pathdfw
 5464
 5465    if verbose:
 5466        print("path length done ...")
 5467
 5468    Mdfw = None
 5469    Tdfw = None
 5470    Mdf = None
 5471    Tdf = None
 5472    Ctdf = None
 5473    Ctdfw = None
 5474    if True:
 5475        if verbose:
 5476            print("Begin connectivity")
 5477        M = np.zeros( [len(ulabs),len(ulabs)])
 5478        T = np.zeros( [len(ulabs),len(ulabs)])
 5479        myCount = np.zeros( [len(ulabs),len(ulabs)])
 5480        for k in range(len(ulabs)):
 5481            cc_slice = labels == ulabs[k]
 5482            cc_streamlines = utils.target(streamlines, affine, cc_slice)
 5483            cc_streamlines = Streamlines(cc_streamlines)
 5484            for j in range(len(ulabs)):
 5485                cc_slice2 = labels == ulabs[j]
 5486                cc_streamlines2 = utils.target(cc_streamlines, affine, cc_slice2)
 5487                cc_streamlines2 = Streamlines(cc_streamlines2)
 5488                if len(cc_streamlines2) > 0 :
 5489                    wmpl = path_length(cc_streamlines2, affine, cc_slice2)
 5490                    mean_path_length = wmpl[wmpl>0].mean()
 5491                    total_path_length = wmpl[wmpl>0].sum()
 5492                    M[int(j),int(k)] = mean_path_length
 5493                    T[int(j),int(k)] = total_path_length
 5494                    myCount[int(j),int(k)] = len( cc_streamlines2 ) * volUnit
 5495        if verbose:
 5496            print("end connectivity")
 5497        Mdf = label_dataframe.copy()
 5498        Tdf = label_dataframe.copy()
 5499        Ctdf = label_dataframe.copy()
 5500        for k in range(len(ulabs)):
 5501            nn1 = "CnxMeanPL"+str(k).zfill(3)
 5502            nn2 = "CnxTotPL"+str(k).zfill(3)
 5503            nn3 = "CnxCount"+str(k).zfill(3)
 5504            Mdf.insert(Mdf.shape[1], nn1, M[k,:] )
 5505            Tdf.insert(Tdf.shape[1], nn2, T[k,:] )
 5506            Ctdf.insert(Ctdf.shape[1], nn3, myCount[k,:] )
 5507        Mdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkm' : Mdf },  Mdf.keys()[2:Mdf.shape[1]] )
 5508        Tdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkt' : Tdf },  Tdf.keys()[2:Tdf.shape[1]] )
 5509        Ctdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkc': Ctdf },  Ctdf.keys()[2:Ctdf.shape[1]] )
 5510        allconnexwide = pd.concat( [
 5511            pathdfw,
 5512            Mdfw,
 5513            Tdfw,
 5514            Ctdfw ], axis=1, ignore_index=False )
 5515
 5516    return {
 5517          'connectivity': allconnexwide,
 5518          'connectivity_matrix_mean': Mdf,
 5519          'connectivity_matrix_total': Tdf,
 5520          'connectivity_matrix_count': Ctdf
 5521          }
 5522
 5523
 5524def hierarchical_modality_summary(
 5525    target_image,
 5526    hier,
 5527    transformlist,
 5528    modality_name,
 5529    return_keys = ["Mean","Volume"],
 5530    verbose = False ):
 5531    """
 5532
 5533    Use output of antspyt1w.hierarchical to summarize a modality
 5534
 5535    Arguments
 5536    ---------
 5537
 5538    target_image : the image to summarize - should be brain extracted
 5539
 5540    hier : dictionary holding antspyt1w.hierarchical output
 5541
 5542    transformlist : spatial transformations mapping from T1 to this modality (e.g. from ants.registration)
 5543
 5544    modality_name : adds the modality name to the data frame columns
 5545
 5546    return_keys = ["Mean","Volume"] keys to return
 5547
 5548    verbose : boolean
 5549
 5550    Returns
 5551    -------
 5552    data frame holding summary statistics in wide format
 5553
 5554    Example
 5555    -------
 5556    >>> import antspymm
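    A minimal, illustrative call; hier is the output of antspyt1w.hierarchical and
    the transform list comes from a registration of T1 to the target modality
    (variable names below are hypothetical):

    >>> # df = antspymm.hierarchical_modality_summary( flair_brain, hier,
    >>> #     reg['fwdtransforms'], modality_name='flair' )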
 5557    """
 5558    dfout = pd.DataFrame()
 5559    def myhelper( target_image, seg, mytx, mapname, modname, mydf, extra='', verbose=False ):
 5560        if verbose:
 5561            print( mapname )
 5562        target_image_mask = ants.image_clone( target_image ) * 0.0
 5563        target_image_mask[ target_image != 0 ] = 1
 5564        cortmapped = ants.apply_transforms(
 5565            target_image,
 5566            seg,
 5567            mytx, interpolator='nearestNeighbor' ) * target_image_mask
 5568        mapped = antspyt1w.map_intensity_to_dataframe(
 5569            mapname,
 5570            target_image,
 5571            cortmapped )
 5572        mapped.iloc[:,1] = modname + '_' + extra + mapped.iloc[:,1]
 5573        mappedw = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 5574            { 'x' : mapped},
 5575            col_names = return_keys )
 5576        if verbose:
 5577            print( mappedw.keys() )
 5578        if mydf.shape[0] > 0:
 5579            mydf = pd.concat( [ mydf, mappedw], axis=1, ignore_index=False )
 5580        else:
 5581            mydf = mappedw
 5582        return mydf
 5583    if hier['dkt_parc']['dkt_cortex'] is not None:
 5584        dfout = myhelper( target_image, hier['dkt_parc']['dkt_cortex'], transformlist,
 5585            "dkt", modality_name, dfout, extra='', verbose=verbose )
 5586    if hier['deep_cit168lab'] is not None:
 5587        dfout = myhelper( target_image, hier['deep_cit168lab'], transformlist,
 5588            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='deep_', verbose=verbose )
 5589    if hier['cit168lab'] is not None:
 5590        dfout = myhelper( target_image, hier['cit168lab'], transformlist,
 5591            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='', verbose=verbose  )
 5592    if hier['bf'] is not None:
 5593        dfout = myhelper( target_image, hier['bf'], transformlist,
 5594            "nbm3CH13", modality_name, dfout, extra='', verbose=verbose  )
 5595    # if hier['mtl'] is not None:
 5596    #    dfout = myhelper( target_image, hier['mtl'], reg,
 5597    #        "mtl_description", modality_name, dfout, extra='', verbose=verbose  )
 5598    return dfout
 5599
 5600def get_rsf_outputs( coords ):
 5601    if coords == 'powers':
 5602        return list([ 'meanBold', 'fmri_template', 'alff', 'falff', 'PerAF', 
 5603                   'CinguloopercularTaskControl', 'DefaultMode', 
 5604                   'MemoryRetrieval', 'VentralAttention', 'Visual',
 5605                   'FrontoparietalTaskControl', 'Salience', 'Subcortical', 'DorsalAttention'])
 5606    else:
 5607        yeo = pd.read_csv( get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".csv")) # yeo 2023 coordinates
 5608        return list( yeo['SystemName'].unique() )
 5609
 5610def tra_initializer( fixed, moving, n_simulations=32, max_rotation=30,
 5611    transform=['rigid'], compreg=None, random_seed=42, verbose=False ):
 5612    """
 5613    multi-start multi-transform registration solution - based on ants.registration
 5614
 5615    fixed: fixed image
 5616
 5617    moving: moving image
 5618
 5619    n_simulations : number of simulations
 5620
 5621    max_rotation : maximum rotation angle
 5622
 5623    transform : list of transforms to loop through
 5624
 5625    compreg : registration results against which to compare
 5626
 5627    random_seed : random seed for reproducibility
 5628
 5629    verbose : boolean
 5630
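    Example (illustrative; fixed and moving are pre-loaded ANTsImages):

    >>> # bestreg = tra_initializer( fixed, moving, n_simulations=8, max_rotation=20 )
    >>> # bestreg['warpedmovout']  # registration with the lowest mutual information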
 5631    """
 5632    import random
 5633    if random_seed is not None:
 5634        random.seed(random_seed)
 5635    if True:
 5636        output_directory = tempfile.mkdtemp()
 5637        output_directory_w = output_directory + "/tra_reg/"
 5638        os.makedirs(output_directory_w,exist_ok=True)
 5639        bestmi = math.inf
 5640        bestvar = 0.0
 5641        myorig = list(ants.get_origin( fixed ))
 5642        mymax = 0
 5643        for k in range(len( myorig ) ):
 5644            if abs(myorig[k]) > mymax:
 5645                mymax = abs(myorig[k])
 5646        maxtrans = mymax * 0.05
 5647        if compreg is None:
 5648            bestreg=ants.registration( fixed,moving,'Translation',
 5649                outprefix=output_directory_w+"trans")
 5650            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
 5651        else :
 5652            bestreg=compreg
 5653            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
 5654        for mytx in transform:
 5655            regtx = 'antsRegistrationSyNRepro[r]'
 5656            with tempfile.NamedTemporaryFile(suffix='.h5') as tp:
 5657                if mytx == 'translation':
 5658                    regtx = 'Translation'
 5659                    rRotGenerator = ants.contrib.RandomTranslate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
 5660                elif mytx == 'affine':
 5661                    regtx = 'Affine'
 5662                    rRotGenerator = ants.contrib.RandomRotate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
 5663                else:
 5664                    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rotation*(-1.0), max_rotation ), reference=fixed )
 5665                for k in range(n_simulations):
 5666                    simtx = ants.compose_ants_transforms( [rRotGenerator.transform(), initx] )
 5667                    ants.write_transform( simtx, tp.name )
 5668                    if k > 0:
 5669                        reg = ants.registration( fixed, moving, regtx,
 5670                            initial_transform=tp.name,
 5671                            outprefix=output_directory_w+"reg"+str(k),
 5672                            verbose=False )
 5673                    else:
 5674                        reg = ants.registration( fixed, moving,
 5675                            regtx,
 5676                            outprefix=output_directory_w+"reg"+str(k),
 5677                            verbose=False )
 5678                    mymi = math.inf
 5679                    temp = reg['warpedmovout']
 5680                    myvar = temp.numpy().var()
 5681                    if verbose:
 5682                        print( str(k) + " : " + regtx  + " : " + mytx + " _var_ " + str( myvar ) )
 5683                    if myvar > 0 :
 5684                        mymi = ants.image_mutual_information( fixed, temp )
 5685                        if mymi < bestmi:
 5686                            if verbose:
 5687                                print( "mi @ " + str(k) + " : " + str(mymi), flush=True)
 5688                            bestmi = mymi
 5689                            bestreg = reg
 5690                            bestvar = myvar
 5691        if bestvar == 0.0 and compreg is not None:
 5692            return compreg        
 5693        return bestreg
 5694
 5695def neuromelanin( list_nm_images, t1, t1_head, t1lab, brain_stem_dilation=8,
 5696    bias_correct=True,
 5697    denoise=None,
 5698    srmodel=None,
 5699    target_range=[0,1],
 5700    poly_order='hist',
 5701    normalize_nm = False,
 5702    verbose=False ) :
 5703
 5704  """
 5705  Outputs the averaged and registered neuromelanin image, and neuromelanin labels
 5706
 5707  Arguments
 5708  ---------
 5709  list_nm_images : list of ANTsImages
 5710    list of neuromelanin repeat images
 5711
 5712  t1 : ANTsImage
 5713    input 3-D T1 brain image
 5714
 5715  t1_head : ANTsImage
 5716    input 3-D T1 head image
 5717
 5718  t1lab : ANTsImage
 5719    t1 labels that will be propagated to the NM
 5720
 5721  brain_stem_dilation : integer default 8
 5722    dilates the brain stem mask to better match coverage of NM
 5723
 5724  bias_correct : boolean
 5725
 5726  denoise : None or integer
 5727
 5728  srmodel : None -- this is a work in progress feature, probably not optimal
 5729
 5730  target_range : 2-element tuple
 5731        a tuple or array defining the (min, max) of the input image
 5732        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
 5733        intensity. This range should match the mapping used in the training
 5734        of the network.
 5735
 5736  poly_order : if not None, will fit a global regression model to map
 5737      intensity back to original histogram space; if 'hist' will match
 5738      by histogram matching - ants.histogram_match_image
 5739
 5740  normalize_nm : boolean - WIP not validated
 5741
 5742  verbose : boolean
 5743
 5744  Returns
 5745  ---------
 5746  Dictionary with the averaged and registered neuromelanin image, the neuromelanin labels, and a wide-format summary csv
 5747
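  Example
  ---------
  An illustrative call; the NM repeat images and the T1-derived inputs come from
  earlier processing steps and the variable names here are hypothetical:

  >>> # nmout = neuromelanin( [nm1, nm2], t1_brain, t1_head, t1_labels )
  >>> # nmout['NM_avg_cropped']      # averaged, registered and cropped NM image
  >>> # nmout['NM_dataframe_wide']   # wide-format summary measurements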
 5748  """
 5749
 5750  fnt=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz" )
 5751  fntNM=os.path.expanduser("~/.antspymm/CIT168_T1w_700um_pad_adni_NM_norm_avg.nii.gz" )
 5752  fntbst=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni_brainstem.nii.gz")
 5753  fnslab=os.path.expanduser("~/.antspyt1w/CIT168_MT_Slab_adni.nii.gz")
 5754  fntseg=os.path.expanduser("~/.antspyt1w/det_atlas_25_pad_LR_adni.nii.gz")
 5755
 5756  template = mm_read( fnt )
 5757  templateNM = ants.iMath( mm_read( fntNM ), "Normalize" )
 5758  templatebstem = mm_read( fntbst ).threshold_image( 1, 1000 )
 5760  reg = ants.registration( t1, template, 'antsRegistrationSyNQuickRepro[s]' )
 5761  # map NM avg to t1 for neuromelanin processing
 5762  nmavg2t1 = ants.apply_transforms( t1, templateNM,
 5763    reg['fwdtransforms'], interpolator='linear' )
 5764  slab2t1 = ants.threshold_image( nmavg2t1, "Otsu", 2 ).threshold_image(1,2).iMath("MD",1).iMath("FillHoles")
 5765  # map brain stem and slab to t1 for neuromelanin processing
 5766  bstem2t1 = ants.apply_transforms( t1, templatebstem,
 5767    reg['fwdtransforms'],
 5768    interpolator='nearestNeighbor' ).iMath("MD",1)
 5769  slab2t1B = ants.apply_transforms( t1, mm_read( fnslab ),
 5770    reg['fwdtransforms'], interpolator = 'nearestNeighbor')
 5771  bstem2t1 = ants.crop_image( bstem2t1, slab2t1 )
 5772  cropper = ants.decrop_image( bstem2t1, slab2t1 ).iMath("MD",brain_stem_dilation)
 5773
 5774  # Average images in image_list
 5775  nm_avg = list_nm_images[0]*0.0
 5776  for k in range(len( list_nm_images )):
 5777    if denoise is not None:
 5778        list_nm_images[k] = ants.denoise_image( list_nm_images[k],
 5779            shrink_factor=1,
 5780            p=denoise,
 5781            r=denoise+1,
 5782            noise_model='Gaussian' )
 5783    if bias_correct :
 5784        n4mask = ants.threshold_image( ants.iMath(list_nm_images[k], "Normalize" ), 0.05, 1 )
 5785        list_nm_images[k] = ants.n4_bias_field_correction( list_nm_images[k], mask=n4mask )
 5786    nm_avg = nm_avg + ants.resample_image_to_target( list_nm_images[k], nm_avg ) / len( list_nm_images )
 5787
 5788  if verbose:
 5789      print("Register each nm image in list_nm_images to the averaged nm image (avg)")
 5790  nm_avg_new = nm_avg * 0.0
 5791  txlist = []
 5792  for k in range(len( list_nm_images )):
 5793    if verbose:
 5794        print(str(k) + " of " + str(len( list_nm_images ) ) )
 5795    current_image = ants.registration( list_nm_images[k], nm_avg,
 5796        type_of_transform = 'antsRegistrationSyNRepro[r]' )
 5797    txlist.append( current_image['fwdtransforms'][0] )
 5798    current_image = current_image['warpedfixout']
 5799    nm_avg_new = nm_avg_new + current_image / len( list_nm_images )
 5800  nm_avg = nm_avg_new
 5801
 5802  if verbose:
 5803      print("do slab registration to map anatomy to NM space")
 5804  t1c = ants.crop_image( t1_head, slab2t1 ).iMath("Normalize") # old way
 5805  nmavg2t1c = ants.crop_image( nmavg2t1, slab2t1 ).iMath("Normalize")
 5806  # slabreg = ants.registration( nm_avg, nmavg2t1c, 'antsRegistrationSyNRepro[r]' )
 5807  slabreg = tra_initializer( nm_avg, t1c, verbose=verbose )
 5808  if False:
 5809      slabregT1 = tra_initializer( nm_avg, t1c, verbose=verbose  )
 5810      miNM = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
 5811            ants.iMath(slabreg['warpedmovout'],"Normalize") )
 5812      miT1 = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
 5813            ants.iMath(slabregT1['warpedmovout'],"Normalize") )
 5814      if miT1 < miNM:
 5815        slabreg = slabregT1
 5816  labels2nm = ants.apply_transforms( nm_avg, t1lab, slabreg['fwdtransforms'],
 5817    interpolator = 'genericLabel' )
 5818  cropper2nm = ants.apply_transforms( nm_avg, cropper, slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
 5819  nm_avg_cropped = ants.crop_image( nm_avg, cropper2nm )
 5820
 5821  if verbose:
 5822      print("now map these labels to each individual nm")
 5823  crop_mask_list = []
 5824  crop_nm_list = []
 5825  for k in range(len( list_nm_images )):
 5826      concattx = []
 5827      concattx.append( txlist[k] )
 5828      concattx.append( slabreg['fwdtransforms'][0] )
 5829      cropmask = ants.apply_transforms( list_nm_images[k], cropper,
 5830        concattx, interpolator = 'nearestNeighbor' )
 5831      crop_mask_list.append( cropmask )
 5832      temp = ants.crop_image( list_nm_images[k], cropmask )
 5833      crop_nm_list.append( temp )
 5834
 5835  if srmodel is not None:
 5836      if verbose:
 5837          print( " start sr " + str(len( crop_nm_list )) )
 5838      for k in range(len( crop_nm_list )):
 5839          if verbose:
 5840              print( " do sr " + str(k) )
 5841              print( crop_nm_list[k] )
 5842          temp = antspynet.apply_super_resolution_model_to_image(
 5843                crop_nm_list[k], srmodel, target_range=target_range,
 5844                regression_order=None )
 5845          if poly_order is not None:
 5846              bilin = ants.resample_image_to_target( crop_nm_list[k], temp )
 5847              if poly_order == 'hist':
 5848                  temp = ants.histogram_match_image( temp, bilin )
 5849              else:
 5850                  temp = antspynet.regression_match_image( temp, bilin, poly_order = poly_order )
 5851          crop_nm_list[k] = temp
 5852
 5853  nm_avg_cropped = crop_nm_list[0]*0.0
 5854  if verbose:
 5855      print( "cropped average" )
 5856      print( nm_avg_cropped )
 5857  for k in range(len( crop_nm_list )):
 5858      nm_avg_cropped = nm_avg_cropped + ants.apply_transforms( nm_avg_cropped,
 5859        crop_nm_list[k], txlist[k] ) / len( crop_nm_list )
 5860  for loop in range( 3 ):
 5861      nm_avg_cropped_new = nm_avg_cropped * 0.0
 5862      for k in range(len( crop_nm_list )):
 5863            myreg = ants.registration(
 5864                ants.iMath(nm_avg_cropped,"Normalize"),
 5865                ants.iMath(crop_nm_list[k],"Normalize"),
 5866                'antsRegistrationSyNRepro[r]' )
 5867            warpednext = ants.apply_transforms(
 5868                nm_avg_cropped_new,
 5869                crop_nm_list[k],
 5870                myreg['fwdtransforms'] )
 5871            nm_avg_cropped_new = nm_avg_cropped_new + warpednext
 5872      nm_avg_cropped = nm_avg_cropped_new / len( crop_nm_list )
 5873
 5874  slabregUpdated = tra_initializer( nm_avg_cropped, t1c, compreg=slabreg,verbose=verbose  )
 5875  tempOrig = ants.apply_transforms( nm_avg_cropped_new, t1c, slabreg['fwdtransforms'] )
 5876  tempUpdate = ants.apply_transforms( nm_avg_cropped_new, t1c, slabregUpdated['fwdtransforms'] )
 5877  miUpdate = ants.image_mutual_information(
 5878    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempUpdate,"Normalize") )
 5879  miOrig = ants.image_mutual_information(
 5880    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempOrig,"Normalize") )
 5881  if miUpdate < miOrig :
 5882      slabreg = slabregUpdated
 5883
 5884  if normalize_nm:
 5885      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
 5886      nm_avg_cropped = ants.iMath( nm_avg_cropped, "TruncateIntensity",0.05,0.95)
 5887      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
 5888
 5889  labels2nm = ants.apply_transforms( nm_avg_cropped, t1lab,
 5890        slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
 5891
 5892  # fix the reference region - keep top two parts
 5893  def get_biggest_part( x, labeln ):
 5894      temp33 = ants.threshold_image( x, labeln, labeln ).iMath("GetLargestComponent")
 5895      x[ x == labeln] = 0
 5896      x[ temp33 == 1 ] = labeln
 5897
 5898  get_biggest_part( labels2nm, 33 )
 5899  get_biggest_part( labels2nm, 34 )
 5900
 5901  if verbose:
 5902      print( "map summary measurements to wide format" )
 5903  nmdf = antspyt1w.map_intensity_to_dataframe(
 5904          'CIT168_Reinf_Learn_v1_label_descriptions_pad',
 5905          nm_avg_cropped,
 5906          labels2nm)
 5907  if verbose:
 5908      print( "merge to wide format" )
 5909  nmdf_wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 5910              {'NM' : nmdf},
 5911              col_names = ['Mean'] )
 5912
 5913  rr_mask = ants.mask_image( labels2nm, labels2nm, [33,34] , binarize=True )
 5914  sn_mask = ants.mask_image( labels2nm, labels2nm, [7,9,23,25] , binarize=True )
 5915  nmavgsnr = mask_snr( nm_avg_cropped, rr_mask, sn_mask, bias_correct = False )
 5916
 5917  snavg = nm_avg_cropped[ sn_mask == 1].mean()
 5918  rravg = nm_avg_cropped[ rr_mask == 1].mean()
 5919  snstd = nm_avg_cropped[ sn_mask == 1].std()
 5920  rrstd = nm_avg_cropped[ rr_mask == 1].std()
 5921  vol_element = np.prod( ants.get_spacing(sn_mask) )
 5922  snvol = vol_element * sn_mask.sum()
 5923
 5924  # get the mean voxel position of the SN
 5925  if snvol > 0:
 5926      sn_z = ants.transform_physical_point_to_index( sn_mask, ants.get_center_of_mass(sn_mask ))[2]
 5927      sn_z = sn_z/sn_mask.shape[2] # around 0.5 would be nice
 5928  else:
 5929      sn_z = math.nan
 5930
 5931  nm_evr = 0.0
 5932  if cropper2nm.sum() > 0:
 5933    nm_evr = antspyt1w.patch_eigenvalue_ratio( nm_avg, 512, [6,6,6], 
 5934        evdepth = 0.9, mask=cropper2nm )
 5935
 5936  simg = ants.smooth_image( nm_avg_cropped, np.min(ants.get_spacing(nm_avg_cropped)) )
 5937  k = 2.0
 5938  rrthresh = (rravg + k * rrstd)
 5939  nmabovekthresh_mask = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
 5940  snvolabovethresh = vol_element * nmabovekthresh_mask.sum()
 5941  snintmeanabovethresh = float( ( simg * nmabovekthresh_mask ).mean() )
 5942  snintsumabovethresh = float( ( simg * nmabovekthresh_mask ).sum() )
 5943
 5944  k = 3.0
 5945  rrthresh = (rravg + k * rrstd)
 5946  nmabovekthresh_mask3 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
 5947  snvolabovethresh3 = vol_element * nmabovekthresh_mask3.sum()
 5948
 5949  k = 1.0
 5950  rrthresh = (rravg + k * rrstd)
 5951  nmabovekthresh_mask1 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
 5952  snvolabovethresh1 = vol_element * nmabovekthresh_mask1.sum()
 5953  
 5954  if verbose:
 5955    print( "nm vol @2std above rrmean: " + str( snvolabovethresh ) )
 5956    print( "nm intmean @2std above rrmean: " + str( snintmeanabovethresh ) )
 5957    print( "nm intsum @2std above rrmean: " + str( snintsumabovethresh ) )
 5958    print( "nm done" )
 5959
 5960  return convert_np_in_dict( {
 5961      'NM_avg' : nm_avg,
 5962      'NM_avg_cropped' : nm_avg_cropped,
 5963      'NM_labels': labels2nm,
 5964      'NM_cropped': crop_nm_list,
 5965      'NM_midbrainROI': cropper2nm,
 5966      'NM_dataframe': nmdf,
 5967      'NM_dataframe_wide': nmdf_wide,
 5968      't1_to_NM': slabreg['warpedmovout'],
 5969      't1_to_NM_transform' : slabreg['fwdtransforms'],
 5970      'NM_avg_signaltonoise' : nmavgsnr,
 5971      'NM_avg_substantianigra' : snavg,
 5972      'NM_std_substantianigra' : snstd,
 5973      'NM_volume_substantianigra' : snvol,
 5974      'NM_volume_substantianigra_1std' : snvolabovethresh1,
 5975      'NM_volume_substantianigra_2std' : snvolabovethresh,
 5976      'NM_intmean_substantianigra_2std' : snintmeanabovethresh,
 5977      'NM_intsum_substantianigra_2std' : snintsumabovethresh,
 5978      'NM_volume_substantianigra_3std' : snvolabovethresh3,
 5979      'NM_avg_refregion' : rravg,
 5980      'NM_std_refregion' : rrstd,
 5981      'NM_min' : nm_avg_cropped.min(),
 5982      'NM_max' : nm_avg_cropped.max(),
 5983      'NM_mean' : nm_avg_cropped.numpy().mean(),
 5984      'NM_sd' : np.std( nm_avg_cropped.numpy() ),
 5985      'NM_q0pt05' : np.quantile( nm_avg_cropped.numpy(), 0.05 ),
 5986      'NM_q0pt10' : np.quantile( nm_avg_cropped.numpy(), 0.10 ),
 5987      'NM_q0pt90' : np.quantile( nm_avg_cropped.numpy(), 0.90 ),
 5988      'NM_q0pt95' : np.quantile( nm_avg_cropped.numpy(), 0.95 ),
 5989      'NM_substantianigra_z_coordinate' : sn_z,
 5990      'NM_evr' : nm_evr,
 5991      'NM_count': len( list_nm_images )
 5992       } )
 5993
 5994
 5995
 5996def estimate_optimal_pca_components(data, variance_threshold=0.80, plot=False):
 5997    """
 5998    Estimate the optimal number of PCA components to represent the given data.
 5999
 6000    :param data: The data matrix (samples x features).
 6001    :param variance_threshold: Threshold for cumulative explained variance (default 0.80).
 6002    :param plot: If True, plot the cumulative explained variance graph (default False).
 6003    :return: The optimal number of principal components.
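    Illustrative use on random data (the exact component count depends on the data):

    >>> import numpy as np
    >>> X = np.random.randn(200, 25)
    >>> k = estimate_optimal_pca_components(X, variance_threshold=0.8)  # returns an int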
 6004    """
 6005    import numpy as np
 6006    from sklearn.decomposition import PCA
 6007    import matplotlib.pyplot as plt
 6008
 6009    # Perform PCA
 6010    pca = PCA()
 6011    pca.fit(data)
 6012
 6013    # Calculate cumulative explained variance
 6014    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
 6015
 6016    # Determine the number of components for desired explained variance
 6017    n_components = np.where(cumulative_variance >= variance_threshold)[0][0] + 1
 6018
 6019    # Optionally plot the explained variance
 6020    if plot:
 6021        plt.figure(figsize=(8, 4))
 6022        plt.plot(cumulative_variance, linewidth=2)
 6023        plt.axhline(y=variance_threshold, color='r', linestyle='--')
 6024        plt.axvline(x=n_components - 1, color='r', linestyle='--')
 6025        plt.xlabel('Number of Components')
 6026        plt.ylabel('Cumulative Explained Variance')
 6027        plt.title('Explained Variance by Number of Principal Components')
 6028        plt.show()
 6029
 6030    return n_components
 6031
 6032import numpy as np
 6033
 6034def compute_PerAF_voxel(time_series):
 6035    """
 6036    Compute the Percentage Amplitude Fluctuation (PerAF) for a given time series.
 6037
 6038    10.1371/journal.pone.0227021
 6039
 6040    PerAF = 100/n * sum(|(x_i - m)/m|) 
 6041    where m = 1/n * sum(x_i), x_i is the signal intensity at each time point, 
 6042    and n is the total number of time points.
 6043
 6044    :param time_series: Numpy array of time series data
 6045    :return: Computed PerAF value
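    Worked example: for [1, 2, 3], m = 2 and PerAF = 100/3 * (0.5 + 0 + 0.5) ~= 33.33

    >>> import numpy as np
    >>> val = compute_PerAF_voxel(np.array([1.0, 2.0, 3.0]))  # ~33.33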
 6046    """
 6047    n = len(time_series)
 6048    m = np.mean(time_series)
 6049    perAF = 100 / n * np.sum(np.abs((time_series - m) / m))
 6050    return perAF
 6051
 6052def calculate_trimmed_mean(data, proportion_to_trim):
 6053    """
 6054    Calculate the trimmed mean for a given data array.
 6055
 6056    :param data: A numpy array of data.
 6057    :param proportion_to_trim: Proportion (0 to 0.5) of data to trim from each end.
 6058    :return: The trimmed mean of the data.
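    Worked example: trimming 10% from each end of 0..9 drops 0 and 9, leaving
    mean(1..8) = 4.5

    >>> import numpy as np
    >>> tm = calculate_trimmed_mean(np.arange(10.0), 0.1)  # 4.5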
 6059    """
 6060    from scipy import stats
 6061    return stats.trim_mean(data, proportion_to_trim)
 6062
 6063def PerAF( x, mask, globalmean=True ):
 6064    """
 6065    Compute the Percentage Amplitude Fluctuation (PerAF) for a given time series.
 6066
 6067    10.1371/journal.pone.0227021
 6068
 6069    PerAF = 100/n * sum(|(x_i - m)/m|) 
 6070    where m = 1/n * sum(x_i), x_i is the signal intensity at each time point, 
 6071    and n is the total number of time points.
 6072
 6073    :param x: time series antsImage
 6074    :param mask: brain mask
 6075    :param globalmean: boolean if True divide by the globalmean in the brain mask
 6076    :return: Computed PerAF image
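    Illustrative call (bold and brain_mask are hypothetical, pre-loaded ANTsImages):

    >>> # peraf_img = PerAF( bold, brain_mask, globalmean=True )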
 6077    """
 6078    time_series = ants.timeseries_to_matrix( x, mask )
 6079    n = time_series.shape[1]
 6080    vec = np.zeros( n )
 6081    for i in range(n):
 6082        vec[i] = compute_PerAF_voxel( time_series[:,i] )
 6083    outimg = ants.make_image( mask, vec )
 6084    if globalmean:
 6085        outimg = outimg / calculate_trimmed_mean( vec, 0.01 )
 6086    return outimg
 6087
 6088
 6089
 6090def resting_state_fmri_networks( fmri, fmri_template, t1, t1segmentation,
 6091    f=[0.03, 0.08],
 6092    FD_threshold=5.0,
 6093    spa = None,
 6094    spt = None,
 6095    nc = 5,
 6096    outlier_threshold=0.250,
 6097    ica_components = 0,
 6098    impute = True,
 6099    censor = True,
 6100    despike = 2.5,
 6101    motion_as_nuisance = True,
 6102    powers = False,
 6103    upsample = 3.0,
 6104    clean_tmp = None,
 6105    paramset='unset',
 6106    verbose=False ):
 6107  """
 6108  Compute resting state network correlation maps based on the Powers (powers=True)
 6109  or Yeo 2023 homotopic labels.  This will output a map for each of the major
 6110  network systems.  This function will optionally upsample the data during the
 6111  registration process if its resolution is coarser than the requested upsample value (in mm).
 6112
 6113  registration - despike - anatomy - smooth - nuisance - bandpass - regress.nuisance - censor - falff - correlations
 6114
 6115  Arguments
 6116  ---------
 6117  fmri : BOLD fmri antsImage
 6118
 6119  fmri_template : reference space for BOLD
 6120
 6121  t1 : ANTsImage
 6122    input 3-D T1 brain image (brain extracted)
 6123
 6124  t1segmentation : ANTsImage
 6125    t1 segmentation - a six tissue segmentation image in T1 space
 6126
 6127  f : band pass limits for frequency filtering; we use high-pass here as per Shirer 2015
 6128
 6129  spa : gaussian smoothing for spatial component (physical coordinates)
 6130
 6131  spt : gaussian smoothing for temporal component
 6132
 6133  nc  : number of components for compcor filtering; if less than 1 we estimate the number on the fly based on explained variance; Shirer 2015 used 10 (5 from CSF and 5 from WM)
 6134
 6135  ica_components : integer if greater than 0 then include ica components
 6136
 6137  impute : boolean if True, then use imputation in f/ALFF, PerAF calculation
 6138
 6139  censor : boolean if True, then censor high-motion and outlier volumes
 6140
 6141  despike : if this is greater than zero will run voxel-wise despiking in the 3dDespike (afni) sense; after motion-correction
 6142
 6143  motion_as_nuisance: boolean will add motion and first derivative of motion as nuisance
 6144
 6145  powers : boolean if True use Powers nodes otherwise 2023 Yeo 500 homotopic nodes (10.1016/j.neuroimage.2023.120010)
 6146
 6147  upsample : float optionally isotropically upsample data to upsample (the parameter value) in mm during the registration process if data is below that resolution; if the input spacing is less than that provided by the user, the data will simply be resampled to isotropic resolution
 6148
 6149  clean_tmp : will automatically try to clean the tmp directory - not recommended but can be used in distributed computing systems to help prevent failures due to accumulation of tmp files when doing large-scale processing.  if this is set, the float value clean_tmp will be interpreted as the age in hours of files to be cleaned.
 6150
 6151  verbose : boolean
 6152
 6153  Returns
 6154  ---------
 6155  a dictionary containing the derived network maps
 6156
 6157  References
 6158  ---------
 6159
 6160  10.1162/netn_a_00071 "Methods that included global signal regression were the most consistently effective de-noising strategies."
 6161
 6162  10.1016/j.neuroimage.2019.116157 "frontal and default mode networks are most reliable whereas subcortical networks are least reliable"  "the most comprehensive studies of pipeline effects on edge-level reliability have been done by shirer (2015) and Parkes (2018)" "slice timing correction has minimal impact" "use of low-pass or narrow filter (discarding  high frequency information) reduced both reliability and signal-noise separation"
 6163
 6164  10.1016/j.neuroimage.2017.12.073: Our results indicate that (1) simple linear regression of regional fMRI time series against head motion parameters and WM/CSF signals (with or without expansion terms) is not sufficient to remove head motion artefacts; (2) aCompCor pipelines may only be viable in low-motion data; (3) volume censoring performs well at minimising motion-related artefact but a major benefit of this approach derives from the exclusion of high-motion individuals; (4) while not as effective as volume censoring, ICA-AROMA performed well across our benchmarks for relatively low cost in terms of data loss; (5) the addition of global signal regression improved the performance of nearly all pipelines on most benchmarks, but exacerbated the distance-dependence of correlations between motion and functional connectivity; and (6) group comparisons in functional connectivity between healthy controls and schizophrenia patients are highly dependent on preprocessing strategy. We offer some recommendations for best practice and outline simple analyses to facilitate transparent reporting of the degree to which a given set of findings may be affected by motion-related artefact.
 6165
 6166  10.1016/j.dcn.2022.101087 : We found that: 1) the most efficacious pipeline for both noise removal and information recovery included censoring, GSR, bandpass filtering, and head motion parameter (HMP) regression, 2) ICA-AROMA performed similarly to HMP regression and did not obviate the need for censoring, 3) GSR had a minimal impact on connectome fingerprinting but improved ISC, and 4) the strictest censoring approaches reduced motion correlated edges but negatively impacted identifiability.
 6167
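  Example
  ---------
  An illustrative call; the inputs are hypothetical, pre-computed ANTsImages
  (BOLD time series, BOLD reference, brain-extracted T1 and its six-tissue
  segmentation):

  >>> # rsf = resting_state_fmri_networks( bold, bold_ref, t1_brain, t1_seg )
  >>> # rsf['DefaultMode']  # default-mode correlation map
  >>> # rsf['corr_wide']    # wide-format between-network correlation summary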
 6168  """
 6169
 6170  import warnings
 6171
 6172  if clean_tmp is not None:
 6173    clean_tmp_directory( age_hours = clean_tmp )
 6174
 6175  if nc > 1:
 6176    nc = int(nc)
 6177  else:
 6178    nc=float(nc)
 6179
 6180  type_of_transform="antsRegistrationSyNQuickRepro[r]" # should probably not change this
 6181  remove_it=True
 6182  output_directory = tempfile.mkdtemp()
 6183  output_directory_w = output_directory + "/ts_t1_reg/"
 6184  os.makedirs(output_directory_w,exist_ok=True)
 6185  ofnt1tx = tempfile.NamedTemporaryFile(delete=False,suffix='t1_deformation',dir=output_directory_w).name
 6186
 6187  import numpy as np
 6189
 6190  if upsample > 0.0:
 6191      spc = ants.get_spacing( fmri )
 6192      minspc = upsample
 6193      if min(spc[0:3]) < minspc:
 6194          minspc = min(spc[0:3])
 6195      newspc = [minspc,minspc,minspc]
 6196      fmri_template = ants.resample_image( fmri_template, newspc, interp_type=0 )
 6197
 6198  def temporal_derivative_same_shape(array):
 6199    """
 6200    Compute the temporal derivative of a 2D numpy array along the 0th axis (time)
 6201    and ensure the output has the same shape as the input.
 6202
 6203    :param array: 2D numpy array with time as the 0th axis.
 6204    :return: 2D numpy array of the temporal derivative with the same shape as input.
 6205    """
 6206    derivative = np.diff(array, axis=0)
 6207    
 6208    # Prepend a row to keep the output the same shape as the input
 6209    # You can choose to prepend a row of zeros or repeat the first derivative row
 6210    # Here, a row of zeros is prepended
 6211    zeros_row = np.zeros((1, array.shape[1]))
 6212    return np.vstack((zeros_row, derivative ))
 6213
 6214  def compute_tSTD(M, quantile, x=0, axis=0):
 6215    stdM = np.std(M, axis=axis)
 6216    # set bad values to x
 6217    stdM[stdM == 0] = x
 6218    stdM[np.isnan(stdM)] = x
 6219    tt = round(quantile * 100)
 6220    threshold_std = np.percentile(stdM, tt)
 6221    return {'tSTD': stdM, 'threshold_std': threshold_std}
 6222
 6223  def get_compcor_matrix(boldImage, mask, quantile):
 6224    """
 6225    Compute the compcor matrix.
 6226
 6227    :param boldImage: The bold image.
 6228    :param mask: The mask to apply, if None, it will be computed.
 6229    :param quantile: Quantile for computing threshold in tSTD.
 6230    :return: The compcor matrix.
 6231    """
 6232    if mask is None:
 6233        temp = ants.slice_image(boldImage, axis=boldImage.dimension - 1, idx=0)
 6234        mask = ants.get_mask(temp)
 6235
 6236    imagematrix = ants.timeseries_to_matrix(boldImage, mask)
 6237    temp = compute_tSTD(imagematrix, quantile, 0)
 6238    tsnrmask = ants.make_image(mask, temp['tSTD'])
 6239    tsnrmask = ants.threshold_image(tsnrmask, temp['threshold_std'], temp['tSTD'].max())
 6240    M = ants.timeseries_to_matrix(boldImage, tsnrmask)
 6241    return M
 6242
 6243
 6244  from sklearn.decomposition import FastICA
 6245  def find_indices(lst, value):
 6246    return [index for index, element in enumerate(lst) if element > value]
 6247
 6248  def mean_of_list(lst):
 6249    if not lst:  # Check if the list is not empty
 6250        return 0  # Return 0 or appropriate value for an empty list
 6251    return sum(lst) / len(lst)
 6252  fmrispc = list( ants.get_spacing( fmri ) )
 6253  if spa is None:
 6254    spa = mean_of_list( fmrispc[0:3] ) * 1.0
 6255  if spt is None:
 6256    spt = fmrispc[3] * 0.5
 6257      
 6258  import numpy as np
 6259  import pandas as pd
 6260  import re
 6261  import math
 6262  # point data resources
 6263  A = np.zeros((1,1))
 6264  dfnname='DefaultMode'
 6265  if powers:
 6266      powers_areal_mni_itk = pd.read_csv( get_data('powers_mni_itk', target_extension=".csv")) # power coordinates
 6267      coords='powers'
 6268  else:
 6269      powers_areal_mni_itk = pd.read_csv( get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".csv")) # yeo 2023 coordinates
 6270      coords='yeo_17_500_2023'
 6271  fmri = ants.iMath( fmri, 'Normalize' )
 6272  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("FillHoles")
 6273  if verbose:
 6274      print("Begin rsfmri motion correction")
 6275  debug=False
 6276  if debug:
 6277      ants.image_write( fmri_template, '/tmp/fmri_template.nii.gz' )
 6278      ants.image_write( fmri, '/tmp/fmri.nii.gz' )
 6279      print("debug wrote fmri and fmri_template")
 6280  # mot-co
 6281  corrmo = timeseries_reg(
 6282    fmri, fmri_template,
 6283    type_of_transform=type_of_transform,
 6284    total_sigma=0.5,
 6285    fdOffset=2.0,
 6286    trim = 8,
 6287    output_directory=None,
 6288    verbose=verbose,
 6289    syn_metric='CC',
 6290    syn_sampling=2,
 6291    reg_iterations=[40,20,5],
 6292    return_numpy_motion_parameters=True )
 6293  
 6294  if verbose:
 6295      print("End rsfmri motion correction")
 6296      print("--maximum motion : " + str(corrmo['FD'].max()) )
 6297      print("=== next anatomically based mapping ===")
 6298
 6299  despiking_count = np.zeros( corrmo['motion_corrected'].shape[3] )
 6300  if despike > 0.0:
 6301      corrmo['motion_corrected'], despiking_count = despike_time_series_afni( corrmo['motion_corrected'], c1=despike )
 6302
 6303  despiking_count_summary = despiking_count.sum() / np.prod( corrmo['motion_corrected'].shape )
 6304  high_motion_count=(corrmo['FD'] > FD_threshold ).sum()
 6305  high_motion_pct=high_motion_count / fmri.shape[3]
 6306
 6307  # filter mask based on TSNR
 6308  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
 6309  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
 6310  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",2)
 6311  bmask = bmask * tsnrmask
 6312
 6313  # anatomical mapping
 6314  und = fmri_template * bmask
 6315  t1reg = ants.registration( und, t1,
 6316     "antsRegistrationSyNQuickRepro[s]", outprefix=ofnt1tx )
 6317  if verbose:
 6318    print("t1 2 bold done")
 6319  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
 6320  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
 6321  gmseg = ants.threshold_image( gmseg, 1, 4 )
 6322  gmseg = ants.iMath( gmseg, 'MD', 1 ) # FIXMERSF
 6323  gmseg = ants.apply_transforms( und, gmseg,
 6324    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' ) * bmask
 6325  csfAndWM = ( ants.threshold_image( t1segmentation, 1, 1 ) +
 6326               ants.threshold_image( t1segmentation, 3, 3 ) ).morphology("erode",1)
 6327  csfAndWM = ants.apply_transforms( und, csfAndWM,
 6328    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 6329  csf = ants.threshold_image( t1segmentation, 1, 1 )
 6330  csf = ants.apply_transforms( und, csf, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 6331  wm = ants.threshold_image( t1segmentation, 3, 3 ).morphology("erode",1)
 6332  wm = ants.apply_transforms( und, wm, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 6333  if powers:
 6334    ch2 = mm_read( ants.get_ants_data( "ch2" ) )
 6335  else:
 6336    ch2 = mm_read( get_data( "PPMI_template0_brain", target_extension='.nii.gz' ) )
 6337  treg = ants.registration( 
 6338    # this is to make the impact of resolution consistent
 6339    ants.resample_image(t1, [1.0,1.0,1.0], interp_type=0), 
 6340    ch2, "antsRegistrationSyNQuickRepro[s]" )
 6341  if powers:
 6342    concatx2 = treg['invtransforms'] + t1reg['invtransforms']
 6343    pts2bold = ants.apply_transforms_to_points( 3, powers_areal_mni_itk, concatx2,
 6344        whichtoinvert = ( True, False, True, False ) )
 6345    locations = pts2bold.iloc[:,:3].values
 6346    ptImg = ants.make_points_image( locations, bmask, radius = 2 )
 6347  else:
 6348    concatx2 = t1reg['fwdtransforms'] + treg['fwdtransforms']    
 6349    rsfsegfn = get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".nii.gz")
 6350    rsfsegimg = ants.image_read( rsfsegfn )
 6351    ptImg = ants.apply_transforms( und, rsfsegimg, concatx2, interpolator='nearestNeighbor' ) * bmask
 6352    pts2bold = powers_areal_mni_itk
 6353    # ants.plot( und, ptImg, crop=True, axis=2 )
 6354
 6355  # optional smoothing
 6356  tr = ants.get_spacing( corrmo['motion_corrected'] )[3]
 6357  smth = ( spa, spa, spa, spt ) # this is for sigmaInPhysicalCoordinates = TRUE
 6358  simg = ants.smooth_image( corrmo['motion_corrected'], smth, sigma_in_physical_coordinates = True )
 6359
 6360  # collect censoring indices
 6361  hlinds = find_indices( corrmo['FD'], FD_threshold )
 6362  if verbose:
 6363    print("high motion indices")
 6364    print( hlinds )
 6365  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 6366    fmrimotcorr, hlinds2 = loop_timeseries_censoring( corrmo['motion_corrected'], 
 6367      threshold=outlier_threshold, verbose=verbose )
 6368    hlinds.extend( hlinds2 )
 6369    del fmrimotcorr
 6370  hlinds = list(set(hlinds)) # make unique
 6371
 6372  # nuisance
 6373  globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], bmask )
 6374  globalsignal = np.nanmean( globalmat, axis = 1 )
 6375  del globalmat
 6376  compcorquantile=0.50
 6377  nc_wm=nc_csf=nc
 6378  if nc < 1:
 6379    globalmat = get_compcor_matrix( corrmo['motion_corrected'], wm, compcorquantile )
 6380    nc_wm = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
 6381    globalmat = get_compcor_matrix( corrmo['motion_corrected'], csf, compcorquantile )
 6382    nc_csf = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
 6383    del globalmat
 6384  if verbose:
 6385    print("include compcor components as nuisance: csf " + str(nc_csf) + " wm " + str(nc_wm))
 6386  mycompcor_csf = ants.compcor( corrmo['motion_corrected'],
 6387    ncompcor=nc_csf, quantile=compcorquantile, mask = csf,
 6388    filter_type='polynomial', degree=2 )
 6389  mycompcor_wm = ants.compcor( corrmo['motion_corrected'],
 6390    ncompcor=nc_wm, quantile=compcorquantile, mask = wm,
 6391    filter_type='polynomial', degree=2 )
 6392  nuisance = np.c_[ mycompcor_csf[ 'components' ], mycompcor_wm[ 'components' ] ]
 6393
 6394  if motion_as_nuisance:
 6395      if verbose:
 6396          print("include motion as nuisance")
 6397          print( corrmo['motion_parameters'].shape )
 6398      deriv = temporal_derivative_same_shape( corrmo['motion_parameters']  )
 6399      nuisance = np.c_[ nuisance, corrmo['motion_parameters'], deriv ]
 6400
 6401  if ica_components > 0:
 6402    if verbose:
 6403        print("include ica components as nuisance: " + str(ica_components))
 6404    ica = FastICA(n_components=ica_components, max_iter=10000, tol=0.001, random_state=42 )
 6405    globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], csfAndWM )
 6406    nuisance_ica = ica.fit_transform(globalmat)  # Reconstruct signals
 6407    nuisance = np.c_[ nuisance, nuisance_ica ]
 6408    del globalmat
 6409
 6410  # concat all nuisance data
 6411  # nuisance = np.c_[ nuisance, mycompcor['basis'] ]
 6412  # nuisance = np.c_[ nuisance, corrmo['FD'] ]
 6413  nuisance = np.c_[ nuisance, globalsignal ]
 6414
 6415  if impute:
 6416    simgimp = impute_timeseries( simg, hlinds, method='linear')
 6417  else:
 6418    simgimp = simg
 6419
 6420  # falff/alff stuff  def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
 6421  myfalff=alff_image( simgimp, bmask, flo=f[0], fhi=f[1], nuisance=nuisance  )
 6422
 6423  # bandpass any data collected before here -- if bandpass requested
 6424  if f[0] > 0 and f[1] < 1.0:
 6425    if verbose:
 6426        print( "bandpass: " + str(f[0]) + " <=> " + str( f[1] ) )
 6427    nuisance = ants.bandpass_filter_matrix( nuisance, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
 6428    globalmat = ants.timeseries_to_matrix( simg, bmask )
 6429    globalmat = ants.bandpass_filter_matrix( globalmat, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
 6430    simg = ants.matrix_to_timeseries( simg, globalmat, bmask )
 6431
 6432  if verbose:
 6433    print("now regress nuisance")
 6434
 6435
 6436  if len( hlinds ) > 0 :
 6437    if censor:
 6438        nuisance = remove_elements_from_numpy_array( nuisance, hlinds  )
 6439        simg = remove_volumes_from_timeseries( simg, hlinds )
 6440
 6441  gmmat = ants.timeseries_to_matrix( simg, bmask )
 6442  gmmat = ants.regress_components( gmmat, nuisance )
 6443  simg = ants.matrix_to_timeseries(simg, gmmat, bmask)
 6444
 6445
 6446  # structure the output data
 6447  outdict = {}
 6448  outdict['paramset'] = paramset
 6449  outdict['upsampling'] = upsample
 6450  outdict['coords'] = coords
 6451  outdict['dfnname']=dfnname
 6452  outdict['meanBold'] = und
 6453
 6454  # add correlation matrix that captures each node pair
 6455  # some of the spheres overlap so extract separately from each ROI
 6456  if powers:
 6457    nPoints = int(pts2bold['ROI'].max())
 6458    pointrange = list(range(int(nPoints)))
 6459  else:
 6460    nPoints = int(ptImg.max())
 6461    pointrange = list(range(int(nPoints)))
 6462  nVolumes = simg.shape[3]
 6463  meanROI = np.zeros([nVolumes, nPoints])
 6464  roiNames = []
 6465  if debug:
 6466      ptImgAll = und * 0.
 6467  for i in pointrange:
 6468    # specify name for matrix entries that's links back to ROI number and network; e.g., ROI1_Uncertain
 6469    netLabel = re.sub( " ", "", pts2bold.loc[i,'SystemName'])
 6470    netLabel = re.sub( "-", "", netLabel )
 6471    netLabel = re.sub( "/", "", netLabel )
 6472    roiLabel = "ROI" + str(pts2bold.loc[i,'ROI']) + '_' + netLabel
 6473    roiNames.append( roiLabel )
 6474    if powers:
 6475        ptImage = ants.make_points_image(pts2bold.iloc[[i],:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
 6476    else:
 6477        #print("Doing " + pts2bold.loc[i,'SystemName'] + " at " + str(i) )
 6478        #ptImage = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==pts2bold.loc[i,'SystemName']],binarize=True)
 6479        ptImage=ants.threshold_image( ptImg, pts2bold.loc[i,'ROI'], pts2bold.loc[i,'ROI'] )
 6480    if debug:
 6481      ptImgAll = ptImgAll + ptImage
 6482    if ptImage.sum() > 0 :
 6483        meanROI[:,i] = ants.timeseries_to_matrix( simg, ptImage).mean(axis=1)
 6484
 6485  if debug:
 6486      ants.image_write( simg, '/tmp/simg.nii.gz' )
 6487      ants.image_write( ptImgAll, '/tmp/ptImgAll.nii.gz' )
 6488      ants.image_write( und, '/tmp/und.nii.gz' )
 6490
 6491  # get full correlation matrix
 6492  corMat = np.corrcoef(meanROI, rowvar=False)
 6493  outputMat = pd.DataFrame(corMat)
 6494  outputMat.columns = roiNames
 6495  outputMat['ROIs'] = roiNames
 6496  # add to dictionary
 6497  outdict['fullCorrMat'] = outputMat
 6498
 6499  networks = powers_areal_mni_itk['SystemName'].unique()
 6500  # this is just for human readability - reminds us of which we choose by default
 6501  if powers:
 6502    netnames = ['Cingulo-opercular Task Control', 'Default Mode',
 6503                    'Memory Retrieval', 'Ventral Attention', 'Visual',
 6504                    'Fronto-parietal Task Control', 'Salience', 'Subcortical',
 6505                    'Dorsal Attention']
 6506    numofnets = [3,5,6,7,8,9,10,11,13]
 6507  else:
 6508    netnames = networks
 6509    numofnets = list(range(len(netnames)))
 6510 
 6511  ct = 0
 6512  for mynet in numofnets:
 6513    netname = re.sub( " ", "", networks[mynet] )
 6514    netname = re.sub( "-", "", netname )
 6515    ww = np.where( powers_areal_mni_itk['SystemName'] == networks[mynet] )[0]
 6516    if powers:
 6517        dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
 6518    else:
 6519        dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[mynet]],binarize=True)
 6520    if dfnImg.max() >= 1:
 6521        if verbose:
 6522            print("DO: " + coords + " " + netname )
 6523        dfnmat = ants.timeseries_to_matrix( simg, ants.threshold_image( dfnImg, 1, dfnImg.max() ) )
 6524        dfnsignal = np.nanmean( dfnmat, axis = 1 )
 6525        nan_count_dfn = np.count_nonzero( np.isnan( dfnsignal) )
 6526        if nan_count_dfn > 0 :
 6527            warnings.warn( " mynet " + netname + " mean-signal has nans " + str( nan_count_dfn ) )
 6528        gmmatDFNCorr = np.zeros( gmmat.shape[1] )
 6529        if nan_count_dfn == 0:
 6530            for k in range( gmmat.shape[1] ):
 6531                nan_count_gm = np.count_nonzero( np.isnan( gmmat[:,k]) )
 6532                if debug and False:
 6533                    print( str( k ) +  " nans gm " + str(nan_count_gm)  )
 6534                if nan_count_gm == 0:
 6535                    gmmatDFNCorr[ k ] = pearsonr( dfnsignal, gmmat[:,k] )[0]
 6536        corrImg = ants.make_image( bmask, gmmatDFNCorr  )
 6537        outdict[ netname ] = corrImg * gmseg
 6538    else:
 6539        outdict[ netname ] = None
 6540    ct = ct + 1
 6541
 6542  A = np.zeros( ( len( numofnets ) , len( numofnets ) ) )
 6543  A_wide = np.zeros( ( 1, len( numofnets ) * len( numofnets ) ) )
 6544  newnames=[]
 6545  newnames_wide=[]
 6546  ct = 0
 6547  for i in range( len( numofnets ) ):
 6548      netnamei = re.sub( " ", "", networks[numofnets[i]] )
 6549      netnamei = re.sub( "-", "", netnamei )
 6550      newnames.append( netnamei  )
 6551      ww = np.where( powers_areal_mni_itk['SystemName'] == networks[numofnets[i]] )[0]
 6552      if powers:
 6553          dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
 6554      else:
 6555          dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[numofnets[i]]],binarize=True)
 6556      for j in range( len( numofnets ) ):
 6557          netnamej = re.sub( " ", "", networks[numofnets[j]] )
 6558          netnamej = re.sub( "-", "", netnamej )
 6559          newnames_wide.append( netnamei + "_2_" + netnamej )
 6560          A[i,j] = 0
 6561          if dfnImg is not None and netnamej is not None:
 6562            subbit = dfnImg == 1
 6563            if subbit is not None:
 6564                if subbit.sum() > 0 and netnamej in outdict:
 6565                    A[i,j] = outdict[ netnamej ][ subbit ].mean()
 6566          A_wide[0,ct] = A[i,j]
 6567          ct=ct+1
 6568
 6569  A = pd.DataFrame( A )
 6570  A.columns = newnames
 6571  A['networks']=newnames
 6572  A_wide = pd.DataFrame( A_wide )
 6573  A_wide.columns = newnames_wide
 6574  outdict['corr'] = A
 6575  outdict['corr_wide'] = A_wide
 6576  outdict['fmri_template'] = fmri_template
 6577  outdict['brainmask'] = bmask
 6578  outdict['gmmask'] = gmseg
 6579  outdict['alff'] = myfalff['alff']
 6580  outdict['falff'] = myfalff['falff']
 6581  # add global mean and standard deviation for post-hoc z-scoring
 6582  outdict['alff_mean'] = (myfalff['alff'][myfalff['alff']!=0]).mean()
 6583  outdict['alff_sd'] = (myfalff['alff'][myfalff['alff']!=0]).std()
 6584  outdict['falff_mean'] = (myfalff['falff'][myfalff['falff']!=0]).mean()
 6585  outdict['falff_sd'] = (myfalff['falff'][myfalff['falff']!=0]).std()
 6586
 6587  perafimg = PerAF( simgimp, bmask )
 6588  for k in pointrange:
 6589    anatname=( pts2bold['AAL'][k] )
 6590    if isinstance(anatname, str):
 6591        anatname = re.sub("_","",anatname)
 6592    else:
 6593        anatname='Unk'
 6594    if powers:
 6595        kk = f"{k:0>3}"+"_"
 6596    else:
 6597        kk = f"{k % int(nPoints/2):0>3}"+"_"
 6598    fname='falffPoint'+kk+anatname
 6599    aname='alffPoint'+kk+anatname
 6600    pname='perafPoint'+kk+anatname
 6601    localsel = ptImg == k
 6602    if localsel.sum() > 0 : # check if non-empty
 6603        outdict[fname]=(outdict['falff'][localsel]).mean()
 6604        outdict[aname]=(outdict['alff'][localsel]).mean()
 6605        outdict[pname]=(perafimg[localsel]).mean()
 6606    else:
 6607        outdict[fname]=math.nan
 6608        outdict[aname]=math.nan
 6609        outdict[pname]=math.nan
 6610
 6611  rsfNuisance = pd.DataFrame( nuisance )
 6612  if remove_it:
 6613    import shutil
 6614    shutil.rmtree(output_directory, ignore_errors=True )
 6615
 6616  if not powers:
 6617    dfnsum=outdict['DefaultA']+outdict['DefaultB']+outdict['DefaultC']
 6618    outdict['DefaultMode']=dfnsum
 6619    dfnsum=outdict['VisCent']+outdict['VisPeri']
 6620    outdict['Visual']=dfnsum
 6621
 6622  nonbrainmask = ants.iMath( bmask, "MD",2) - bmask
 6623  trimmask = ants.iMath( bmask, "ME",2)
 6624  edgemask = ants.iMath( bmask, "ME",1) - trimmask
 6625  outdict['motion_corrected'] = corrmo['motion_corrected']
 6626  outdict['nuisance'] = rsfNuisance
 6627  outdict['PerAF'] = perafimg
 6628  outdict['tsnr'] = mytsnr
 6629  outdict['ssnr'] = slice_snr( corrmo['motion_corrected'], csfAndWM, gmseg )
 6630  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
 6631  outdict['bandpass_freq_0']=f[0]
 6632  outdict['bandpass_freq_1']=f[1]
 6633  outdict['censor']=int(censor)
 6634  outdict['spatial_smoothing']=spa
 6635  outdict['outlier_threshold']=outlier_threshold
 6636  outdict['FD_threshold']=outlier_threshold
 6637  outdict['high_motion_count'] = high_motion_count
 6638  outdict['high_motion_pct'] = high_motion_pct
 6639  outdict['despiking_count_summary'] = despiking_count_summary
 6640  outdict['FD_max'] = corrmo['FD'].max()
 6641  outdict['FD_mean'] = corrmo['FD'].mean()
 6642  outdict['FD_sd'] = corrmo['FD'].std()
 6643  outdict['bold_evr'] =  antspyt1w.patch_eigenvalue_ratio( und, 512, [16,16,16], evdepth = 0.9, mask = bmask )
 6644  outdict['n_outliers'] = len(hlinds)
 6645  outdict['nc_wm'] = int(nc_wm)
 6646  outdict['nc_csf'] = int(nc_csf)
 6647  outdict['minutes_original_data'] = ( tr * fmri.shape[3] ) / 60.0 # total minutes of acquired data
 6648  outdict['minutes_censored_data'] = ( tr * simg.shape[3] ) / 60.0 # minutes of data remaining after censoring
 6649  return convert_np_in_dict( outdict )
 6650
 6651
 6652def calculate_CBF(Delta_M, M_0, mask,
 6653                  Lambda=0.9, T_1=0.67, Alpha=0.68, w=1.0, Tau=1.5):
 6654    """
 6655    Calculate the Cerebral Blood Flow (CBF) where Delta_M and M_0 are antsImages 
 6656    and the other variables are scalars.  Guesses at default values are used here. 
 6657    We use the pCASL equation.  NOT YET TESTED.
 6658
 6659    Parameters:
 6660    Delta_M (antsImage): Change in magnetization (matrix)
 6661    M_0 (antsImage): Initial magnetization (matrix)
 6662    mask ( antsImage ): where to do the calculation
 6663    Lambda (float): Scalar
 6664    T_1 (float): Scalar representing relaxation time
 6665    Alpha (float): Scalar representing flip angle
 6666    w (float): Scalar
 6667    Tau (float): Scalar
 6668
 6669    Returns:
 6670    np.ndarray: CBF values (matrix)
 6671    """
 6672    cbf = M_0 * 0.0
 6673    m0thresh = np.quantile( M_0[mask==1], 0.1 )
 6674    sel = ( mask == 1 ) * ( M_0 > m0thresh )  # element-wise intersection; python's 'and' is not element-wise for images
 6675    cbf[ sel ] = Delta_M[ sel ] * 60. * 100. * (Lambda * T_1)/( M_0[sel] * 2.0 * Alpha * 
 6676        (np.exp( -w * T_1) - np.exp(-(Tau + w) * T_1)))
 6677    cbf[ cbf < 0.0]=0.0
 6678    return cbf
 6679
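# Hedged usage sketch (illustrative, not part of the package API): how one might
# redo CBF quantification from the outputs of bold_perfusion or
# bold_perfusion_minimal with acquisition-specific constants.  The keyword
# defaults below simply repeat the guesses used by calculate_CBF and should be
# replaced with your own protocol's values.
def _example_requantify_cbf( perf_dict, Lambda=0.9, T_1=0.67, Alpha=0.68, w=1.0, Tau=1.5 ):
    """Illustrative only: recompute CBF from a perfusion output dictionary."""
    return calculate_CBF(
        Delta_M=perf_dict['perfusion'],
        M_0=perf_dict['m0'],
        mask=perf_dict['brainmask'],
        Lambda=Lambda, T_1=T_1, Alpha=Alpha, w=w, Tau=Tau )
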
 6680def despike_time_series_afni(image, c1=2.5, c2=4):
 6681    """
 6682    Despike a time series image using a per-voxel polynomial trend fit and nonlinear filtering.
 6683    Modeled on AFNI's 3dDespike; unlike 3dDespike, the trend here is fit by ordinary least squares rather than an L1 fit.
 6684
 6685    :param image: ANTsPy image object containing time series data.
 6686    :param c1: Spike threshold value. Default is 2.5.
 6687    :param c2: Upper range of allowed deviation. Default is 4.
 6688    :return: Despiked ANTsPy image object.
 6689    """
 6690    data = image.numpy()  # Convert to numpy array
 6691    despiked_data = np.copy(data)  # Create a copy for despiked data
 6692    curve = despiked_data * 0.0
 6693
 6694    def l1_fit_polynomial(time_series, degree=2):
 6695        """
 6696        Fit a polynomial of given degree to the time series using least squares.
 6697        
 6698        :param time_series: 1D numpy array of voxel time series data.
 6699        :param degree: Degree of the polynomial to fit.
 6700        :return: Fitted polynomial values for the time series.
 6701        """
 6702        t = np.arange(len(time_series))
 6703        coefs = np.polyfit(t, time_series, degree)
 6704        polynomial = np.polyval(coefs, t)
 6705        return polynomial
 6706
 6707    # Fit a smooth-ish polynomial trend to each voxel time series
 6708    # (np.polyfit is a least-squares fit; AFNI's 3dDespike uses an L1 fit here)
 6709    for x in range(data.shape[0]):
 6710        for y in range(data.shape[1]):
 6711            for z in range(data.shape[2]):
 6712                voxel_time_series = data[x, y, z, :]
 6713                curve[x, y, z, :] = l1_fit_polynomial(voxel_time_series, degree=2)
 6714
 6715    # Compute the MAD of the residuals
 6716    residuals = data - curve
 6717    mad = np.median(np.abs(residuals - np.median(residuals, axis=-1, keepdims=True)), axis=-1, keepdims=True)
 6718    sigma = np.sqrt(np.pi / 2) * mad
 6719    # Ensure sigma is not zero to avoid division by zero
 6720    sigma_safe = np.where(sigma == 0, 1e-10, sigma)
 6721
 6722    # Optionally, handle NaN or inf values in data, curve, or sigma
 6723    data = np.nan_to_num(data, nan=0.0, posinf=np.finfo(np.float64).max, neginf=np.finfo(np.float64).min)
 6724    curve = np.nan_to_num(curve, nan=0.0, posinf=np.finfo(np.float64).max, neginf=np.finfo(np.float64).min)
 6725    sigma_safe = np.nan_to_num(sigma_safe, nan=1e-10, posinf=np.finfo(np.float64).max, neginf=np.finfo(np.float64).min)
 6726
 6727    # Despike algorithm
 6728    spike_counts = np.zeros( image.shape[3] )
 6729    for i in range(data.shape[-1]):
 6730        s = (data[..., i] - curve[..., i]) / sigma_safe[..., 0]
 6731        ww = s > c1
 6732        s_prime = np.where( ww, c1 + (c2 - c1) * np.tanh((s - c1) / (c2 - c1)), s)
 6733        spike_counts[i] = ww.sum()
 6734        despiked_data[..., i] = curve[..., i] + s_prime * sigma[..., 0]
 6735
 6736    # Convert back to ANTsPy image
 6737    despiked_image = ants.from_numpy(despiked_data)
 6738    return ants.copy_image_info( image, despiked_image ), spike_counts
 6739
 6740def despike_time_series(image, threshold=3.0, replacement='threshold' ):
 6741    """
 6742    Despike a time series image.
 6743    
 6744    :param image: ANTsPy image object containing time series data.
 6745    :param threshold: z-score value to identify spikes. Default is 3.
 6746    :param replacement: 'median' or 'threshold'; the latter is similar in spirit to AFNI 3dDespike but simpler.
 6747    :return: Despiked ANTsPy image object.
 6748    """
 6749    # Convert image to numpy array
 6750    data = image.numpy()
 6751    
 6752    # Calculate the mean and standard deviation along the time axis
 6753    mean = np.mean(data, axis=-1)
 6754    std = np.std(data, axis=-1)
 6755
 6756    # Identify spikes: points where the deviation from the mean exceeds the threshold
 6757    spikes = np.abs(data - mean[..., np.newaxis]) > threshold * std[..., np.newaxis]
 6758
 6759    # Replace spike values
 6760    spike_counts = np.zeros( image.shape[3] )
 6761    for i in range(data.shape[-1]):
 6762        slice = data[..., i]
 6763        spike_locations = spikes[..., i]
 6764        spike_counts[i] = spike_locations.sum()
 6765        if replacement == 'median':
 6766            slice[spike_locations] = np.median(slice)  # Replace with median or another method
 6767        else:
 6768            # Calculate threshold values (mean ± threshold * std)
 6769            threshold_values = mean + np.sign(slice - mean) * threshold * std
 6770            slice[spike_locations] = threshold_values[spike_locations]
 6771        data[..., i] = slice
 6772    # Convert back to ANTsPy image
 6773    despike_image = ants.from_numpy(data)
 6774    despike_image = ants.copy_image_info( image, despike_image )
 6775    return despike_image, spike_counts
 6776
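# Hedged usage sketch (illustrative): apply the simple z-score despiker to a 4D
# time series and report how many voxels were modified in each volume.  The
# file name below is a placeholder.
def _example_despike( fmri_filename="/path/to/bold.nii.gz", threshold=3.0 ):
    """Illustrative only: despike a time series and summarize spike counts."""
    img = ants.image_read( fmri_filename )
    despiked, spike_counts = despike_time_series( img, threshold=threshold )
    print( "spike-replaced voxels per volume: " + str( spike_counts ) )
    return despiked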
 6777
 6778
 6779def bold_perfusion_minimal( 
 6780        fmri, 
 6781        m0_image = None,
 6782        spa = (0., 0., 0., 0.),
 6783        nc  = 0,
 6784        tc='alternating',
 6785        n_to_trim=0,
 6786        outlier_threshold=0.250,
 6787        plot_brain_mask=False,
 6788        verbose=False ):
 6789  """
 6790  Estimate perfusion from a BOLD time series image.  The function will attempt to infer the tag-control (T-C) labels from the data.  Default parameter values are used to quantify CBF; these will usually not be correct for your own data.  See calculate_CBF for how one might redo the quantification from the outputs of this function, specifically the perfusion, m0 and mask images in the output dictionary.
 6791
 6792  This function is intended for use in debugging/testing or when one lacks a T1w image.
 6793
 6794  Arguments
 6795  ---------
 6796
 6797  fmri : BOLD fmri antsImage
 6798
 6799  m0_image: a pre-defined m0 antsImage
 6800
 6801  spa : gaussian smoothing for spatial and temporal component e.g. (1,1,1,0) in physical space coordinates
 6802
 6803  nc  : number of components for compcor filtering
 6804
 6805  tc: string either alternating or split (default is alternating ie CTCTCT; split is CCCCTTTT)
 6806
 6807  n_to_trim: number of volumes to trim off the front of the time series to account for initial magnetic saturation effects or to allow the signal to reach a steady state.  In some cases, trailing volumes or other outlier volumes may also need to be rejected; this code does not currently handle that issue.
 6808
 6809  outlier_threshold (numeric): between zero (remove all) and one (remove none); automatically calculates outlierness and uses it to censor the time series.
 6810
 6811  plot_brain_mask : boolean can help with checking data quality visually
 6812
 6813  verbose : boolean
 6814
 6815  Returns
 6816  ---------
 6817  a dictionary containing the derived perfusion, CBF, m0 and QC summary measurements
 6818
 6819  """
 6820  import numpy as np
 6821  import pandas as pd
 6822  import re
 6823  import math
 6824  from sklearn.linear_model import RANSACRegressor, TheilSenRegressor, HuberRegressor, QuantileRegressor, LinearRegression, SGDRegressor
 6825  from sklearn.multioutput import MultiOutputRegressor
 6826  from sklearn.preprocessing import StandardScaler
 6827
 6828  def replicate_list(user_list, target_size):
 6829    # Calculate the number of times the list should be replicated
 6830    replication_factor = target_size // len(user_list)
 6831    # Replicate the list and handle any remaining elements
 6832    replicated_list = user_list * replication_factor
 6833    remaining_elements = target_size % len(user_list)
 6834    replicated_list += user_list[:remaining_elements]
 6835    return replicated_list
 6836
 6837  def one_hot_encode(char_list):
 6838    unique_chars = list(set(char_list))
 6839    encoding_dict = {char: [1 if char == c else 0 for c in unique_chars] for char in unique_chars}
 6840    encoded_matrix = np.array([encoding_dict[char] for char in char_list])
 6841    return encoded_matrix
 6842  
 6843  A = np.zeros((1,1))
 6844  fmri_template = ants.get_average_of_timeseries( fmri )
 6845  if n_to_trim is None:
 6846    n_to_trim=0
 6847  mytrim=n_to_trim
 6848  perf_total_sigma = 1.5
 6849  corrmo = timeseries_reg(
 6850    fmri, fmri_template,
 6851    type_of_transform='antsRegistrationSyNRepro[r]',
 6852    total_sigma=perf_total_sigma,
 6853    fdOffset=2.0,
 6854    trim = mytrim,
 6855    output_directory=None,
 6856    verbose=verbose,
 6857    syn_metric='CC',
 6858    syn_sampling=2,
 6859    reg_iterations=[40,20,5] )
 6860  if verbose:
 6861      print("End perfusion motion correction")
 6862      print("--maximum motion : " + str(corrmo['FD'].max()) )
 6863
 6864  if m0_image is not None:
 6865      m0 = m0_image
 6866
 6867  ntp = corrmo['motion_corrected'].shape[3]
 6868  fmri_template = ants.get_average_of_timeseries( corrmo['motion_corrected'] )
 6869  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
 6870  if plot_brain_mask:
 6871    ants.plot( fmri_template, bmask, axis=1, crop=True )
 6872    ants.plot( fmri_template, bmask, axis=2, crop=True )
 6873
 6874  if tc == 'alternating':
 6875      tclist = replicate_list( ['C','T'], ntp )
 6876  else:
 6877      tclist = replicate_list( ['C'], int(ntp/2) ) + replicate_list( ['T'],  int(ntp/2) )
 6878
 6879  tclist = one_hot_encode( tclist[0:ntp ] )
 6880  fmrimotcorr=corrmo['motion_corrected']
 6881  hlinds = None
 6882  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 6883    fmrimotcorr, hlinds = loop_timeseries_censoring( fmrimotcorr, outlier_threshold, mask=None, verbose=verbose )
 6884    tclist = remove_elements_from_numpy_array( tclist, hlinds)
 6885    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 6886
 6887  # redo template and registration at (potentially) upsampled scale
 6888  fmri_template = ants.iMath( ants.get_average_of_timeseries( fmrimotcorr ), "Normalize" )
 6889  corrmo = timeseries_reg(
 6890        fmri, fmri_template,
 6891        type_of_transform='antsRegistrationSyNRepro[r]',
 6892        total_sigma=perf_total_sigma,
 6893        fdOffset=2.0,
 6894        trim = mytrim,
 6895        output_directory=None,
 6896        verbose=verbose,
 6897        syn_metric='CC',
 6898        syn_sampling=2,
 6899        reg_iterations=[40,20,5] )
 6900  if verbose:
 6901        print("End 2nd perfusion motion correction")
 6902        print("--maximum motion : " + str(corrmo['FD'].max()) )
 6903
 6904  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 6905    corrmo['motion_corrected'] = remove_volumes_from_timeseries( corrmo['motion_corrected'], hlinds )
 6906    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 6907
 6908  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
 6909  if plot_brain_mask:
 6910    ants.plot( fmri_template, bmask, axis=1, crop=True )
 6911    ants.plot( fmri_template, bmask, axis=2, crop=True )
 6912
 6913  regression_mask = bmask.clone()
 6914  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
 6915  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
 6916  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",3)
 6917  bmask = bmask * ants.iMath( tsnrmask, "FillHoles" )
 6918  fmrimotcorr=corrmo['motion_corrected']
 6919  und = fmri_template * bmask
 6920  compcorquantile=0.50
 6921  mycompcor = ants.compcor( fmrimotcorr,
 6922    ncompcor=nc, quantile=compcorquantile, mask = bmask,
 6923    filter_type='polynomial', degree=2 )
 6924  tr = ants.get_spacing( fmrimotcorr )[3]
 6925  simg = ants.smooth_image(fmrimotcorr, spa, sigma_in_physical_coordinates = True )
 6926  nuisance = mycompcor['basis']
 6927  nuisance = np.c_[ nuisance, mycompcor['components'] ]
 6928  if verbose:
 6929    print("make sure nuisance is independent of TC")
 6930  nuisance = ants.regress_components( nuisance, tclist )
 6931  regression_mask = bmask.clone()
 6932  gmmat = ants.timeseries_to_matrix( simg, regression_mask )
 6933  regvars = np.hstack( (nuisance, tclist ))
 6934  coefind = regvars.shape[1]-1
 6935  regvars = regvars[:,range(coefind)]
 6936  predictor_of_interest_idx = regvars.shape[1]-1
 6937  valid_perf_models = ['huber','quantile','theilsen','ransac', 'sgd', 'linear','SM']
 6938  perfusion_regression_model='linear'
 6939  if verbose:
 6940    print( "begin perfusion estimation with " + perfusion_regression_model + " model " )
 6941  regression_model = LinearRegression()
 6942  regression_model.fit( regvars, gmmat )
 6943  coefind = regression_model.coef_.shape[1]-1
 6944  perfimg = ants.make_image( regression_mask, regression_model.coef_[:,coefind] )
 6945  gmseg = ants.image_clone( bmask )
 6946  meangmval = ( perfimg[ gmseg == 1 ] ).mean()
 6947  if meangmval < 0:
 6948      perfimg = perfimg * (-1.0)
 6949  negative_voxels = ( perfimg < 0.0 ).sum() / np.prod( perfimg.shape ) * 100.0
 6950  perfimg[ perfimg < 0.0 ] = 0.0 # non-physiological
 6951
 6952  if m0_image is None:
 6953    m0 = ants.get_average_of_timeseries( fmrimotcorr )
 6954  else:
 6955    # register m0 to current template
 6956    m0reg = ants.registration( fmri_template, m0, 'antsRegistrationSyNRepro[r]', verbose=False )
 6957    m0 = m0reg['warpedmovout']
 6958
 6959  if ntp == 2 :
 6960      img0 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=0 )
 6961      img1 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=1 )
 6962      if m0_image is None:
 6963        if img0.mean() < img1.mean():
 6964            perfimg=img0
 6965            m0=img1
 6966        else:
 6967            perfimg=img1
 6968            m0=img0
 6969      else:
 6970        if img0.mean() < img1.mean():
 6971            perfimg=img1-img0
 6972        else:
 6973            perfimg=img0-img1
 6974  
 6975  cbf = calculate_CBF( Delta_M=perfimg, M_0=m0, mask=bmask )
 6976  meangmval = ( perfimg[ gmseg == 1 ] ).mean()        
 6977  meangmvalcbf = ( cbf[ gmseg == 1 ] ).mean()
 6978  if verbose:
 6979    print("perfimg.max() " + str(  perfimg.max() ) )
 6980  outdict = {}
 6981  outdict['meanBold'] = und
 6982  outdict['brainmask'] = bmask
 6983  rsfNuisance = pd.DataFrame( nuisance )
 6984  rsfNuisance['FD']=corrmo['FD']
 6985  outdict['perfusion']=perfimg
 6986  outdict['cbf']=cbf
 6987  outdict['m0']=m0
 6988  outdict['perfusion_gm_mean']=meangmval
 6989  outdict['cbf_gm_mean']=meangmvalcbf
 6990  outdict['motion_corrected'] = corrmo['motion_corrected']
 6991  outdict['brain_mask'] = bmask
 6992  outdict['nuisance'] = rsfNuisance
 6993  outdict['tsnr'] = mytsnr
 6994  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
 6995  outdict['FD_max'] = rsfNuisance['FD'].max()
 6996  outdict['FD_mean'] = rsfNuisance['FD'].mean()
 6997  outdict['FD_sd'] = rsfNuisance['FD'].std()
 6998  outdict['outlier_volumes']=hlinds
 6999  outdict['negative_voxels']=negative_voxels
 7000  return convert_np_in_dict( outdict )
 7001
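# Hedged usage sketch (illustrative): run the minimal, T1-free perfusion pipeline
# on a tag-control time series.  The file name is a placeholder; see
# bold_perfusion below for the full, T1-guided pipeline.
def _example_bold_perfusion_minimal( asl_filename="/path/to/asl_timeseries.nii.gz" ):
    """Illustrative only: minimal perfusion estimate and mean GM values."""
    asl = ants.image_read( asl_filename )
    perf = bold_perfusion_minimal( asl, tc='alternating', outlier_threshold=0.25, verbose=True )
    print( "perfusion GM mean: " + str( perf['perfusion_gm_mean'] ) )
    print( "cbf GM mean: " + str( perf['cbf_gm_mean'] ) )
    return perf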
 7002
 7003
 7004def warn_if_small_mask( mask: ants.ANTsImage, threshold_fraction: float = 0.05, label: str = ' ' ):
 7005    """
 7006    Warn the user if the number of non-zero voxels in the mask
 7007    is less than a given fraction of the total number of voxels in the mask.
 7008
 7009    Parameters
 7010    ----------
 7011    mask : ants.ANTsImage
 7012        The binary mask to evaluate.
 7013    threshold_fraction : float, optional
 7014        Fraction threshold below which a warning is triggered (default is 0.05).
 7015    label : str, optional; short name included in the warning message to identify which mask triggered it.
 7016    Returns
 7017    -------
 7018    None
 7019    """
 7020    import warnings
 7021    image_size = np.prod(mask.shape)
 7022    mask_size = np.count_nonzero(mask.numpy())
 7023    if mask_size / image_size < threshold_fraction:
 7024        percentage = 100.0 * mask_size / image_size
 7025        warnings.warn(
 7026            f"[ants] Warning: {label} contains only {mask_size} voxels "
 7027            f"({percentage:.2f}% of image volume). "
 7028            f"This is below the threshold of {threshold_fraction * 100:.2f}% and may lead to unreliable results.",
 7029            UserWarning
 7030        )
 7031
 7032def bold_perfusion( 
 7033    fmri, t1head, t1, t1segmentation, t1dktcit,
 7034                   FD_threshold=0.5,
 7035                   spa = (0., 0., 0., 0.),
 7036                   nc = 3,
 7037                   type_of_transform='antsRegistrationSyNRepro[r]',
 7038                   tc='alternating',
 7039                   n_to_trim=0,
 7040                   m0_image = None,
 7041                   m0_indices=None,
 7042                   outlier_threshold=0.250,
 7043                   add_FD_to_nuisance=False,
 7044                   n3=False,
 7045                   segment_timeseries=False,
 7046                   trim_the_mask=4.25,
 7047                   upsample=True,
 7048                   perfusion_regression_model='linear',
 7049                   verbose=False ):
 7050  """
 7051  Estimate perfusion from a BOLD time series image.  The function will attempt to infer the tag-control (T-C) labels from the data.  Default parameter values are used to quantify CBF; these will usually not be correct for your own data.  See calculate_CBF for how one might redo the quantification from the outputs of this function, specifically the perfusion, m0 and mask images in the output dictionary.
 7052
 7053  Arguments
 7054  ---------
 7055  fmri : BOLD fmri antsImage
 7056
 7057  t1head : ANTsImage
 7058    input 3-D T1 brain image (not brain extracted)
 7059
 7060  t1 : ANTsImage
 7061    input 3-D T1 brain image (brain extracted)
 7062
 7063  t1segmentation : ANTsImage
 7064    t1 segmentation - a six tissue segmentation image in T1 space
 7065
 7066  t1dktcit : ANTsImage
 7067    t1 dkt cortex plus cit parcellation
 7068
 7069  spa : gaussian smoothing for spatial and temporal component e.g. (1,1,1,0) in physical space coordinates
 7070
 7071  nc  : number of components for compcor filtering
 7072
 7073  type_of_transform : SyN or Rigid
 7074
 7075  tc: string either alternating or split (default is alternating ie CTCTCT; split is CCCCTTTT)
 7076
 7077  n_to_trim: number of volumes to trim off the front of the time series to account for initial magnetic saturation effects or to allow the signal to reach a steady state.  In some cases, trailing volumes or other outlier volumes may also need to be rejected; this code does not currently handle that issue.
 7078
 7079  m0_image: a pre-defined m0 image - we expect this to be 3D.  If it is not, we naively
 7080    average over the 4th dimension.
 7081
 7082  m0_indices: which indices in the perfusion image are the m0.  if set, n_to_trim will be ignored.
 7083
 7084  outlier_threshold (numeric): between zero (remove all) and one (remove none); automatically calculates outlierness and uses it to censor the time series.
 7085
 7086  add_FD_to_nuisance: boolean
 7087
 7088  n3: boolean
 7089
 7090  segment_timeseries : boolean
 7091
 7092  trim_the_mask : float >= 0 post-hoc method for trimming the mask
 7093
 7094  upsample: boolean
 7095
 7096  perfusion_regression_model: string 'linear', 'ransac', 'theilsen', 'huber', 'quantile', 'sgd'; 'linear' and 'huber' are the only ones that work ok by default and are relatively quick to compute.
 7097
 7098  verbose : boolean
 7099
 7100  Returns
 7101  ---------
 7102  a dictionary containing the derived perfusion and CBF images, regional summary dataframes and QC measurements
 7103
 7104  """
 7105  import numpy as np
 7106  import pandas as pd
 7107  import re
 7108  import math
 7109  from sklearn.linear_model import RANSACRegressor, TheilSenRegressor, HuberRegressor, QuantileRegressor, LinearRegression, SGDRegressor
 7110  from sklearn.multioutput import MultiOutputRegressor
 7111  from sklearn.preprocessing import StandardScaler
 7112
 7113  # remove outlier volumes
 7114  if segment_timeseries:
 7115    lo_vs_high = segment_timeseries_by_meanvalue(fmri)
 7116    fmri = remove_volumes_from_timeseries( fmri, lo_vs_high['lowermeans'] )
 7117
 7118  ex_path = os.path.expanduser( "~/.antspyt1w/" )
 7119  cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
 7120
 7121  if n3:
 7122    fmri = timeseries_n3( fmri )
 7123
 7124  if m0_image is not None:
 7125    if m0_image.dimension == 4:
 7126      m0_image = ants.get_average_of_timeseries( m0_image )
 7127
 7128  def select_regression_model(regression_model, min_samples=10 ):
 7129    if regression_model == 'sgd' :
 7130      sgd_regressor = SGDRegressor(penalty='elasticnet', alpha=1e-5, l1_ratio=0.15, max_iter=20000, tol=1e-3, random_state=42)
 7131      return sgd_regressor
 7132    elif regression_model == 'ransac':
 7133      ransac = RANSACRegressor(
 7134            min_samples=0.8,
 7135            max_trials=10,         # Maximum number of iterations
 7136#            min_samples=min_samples, # Minimum number samples to be chosen as inliers in each iteration
 7137#            stop_probability=0.80,  # Probability to stop the algorithm if a good subset is found
 7138#            stop_n_inliers=40,      # Stop if this number of inliers is found
 7139#            stop_score=0.8,         # Stop if the model score reaches this value
 7140#            n_jobs=int(os.getenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")) # Use all available CPU cores for parallel processing
 7141        )
 7142      return ransac
 7143    models = {
 7144        'sgd':SGDRegressor,
 7145        'ransac': RANSACRegressor,
 7146        'theilsen': TheilSenRegressor,
 7147        'huber': HuberRegressor,
 7148        'quantile': QuantileRegressor
 7149    }
 7150    return models.get(regression_model.lower(), LinearRegression)()
 7151
 7152  def replicate_list(user_list, target_size):
 7153    # Calculate the number of times the list should be replicated
 7154    replication_factor = target_size // len(user_list)
 7155    # Replicate the list and handle any remaining elements
 7156    replicated_list = user_list * replication_factor
 7157    remaining_elements = target_size % len(user_list)
 7158    replicated_list += user_list[:remaining_elements]
 7159    return replicated_list
 7160
 7161  def one_hot_encode(char_list):
 7162    unique_chars = list(set(char_list))
 7163    encoding_dict = {char: [1 if char == c else 0 for c in unique_chars] for char in unique_chars}
 7164    encoded_matrix = np.array([encoding_dict[char] for char in char_list])
 7165    return encoded_matrix
 7166  
 7167  A = np.zeros((1,1))
 7168  # fmri = ants.iMath( fmri, 'Normalize' )
 7169  fmri_template, hlinds = loop_timeseries_censoring( fmri, 0.10 )
 7170  fmri_template = ants.get_average_of_timeseries( fmri_template )
 7171  hlinds = []  # reset; repopulated below if outlier censoring is enabled (avoids NameError when it is not)
 7172  rig = ants.registration( fmri_template, t1head, 'antsRegistrationSyNRepro[r]' )
 7173  bmask = ants.apply_transforms( fmri_template, ants.threshold_image(t1segmentation,1,6), rig['fwdtransforms'][0], interpolator='genericLabel' )
 7174  if m0_indices is None:
 7175    if n_to_trim is None:
 7176        n_to_trim=0
 7177    mytrim=n_to_trim
 7178  else:
 7179    mytrim = 0
 7180  perf_total_sigma = 1.5
 7181  corrmo = timeseries_reg(
 7182    fmri, fmri_template,
 7183    type_of_transform=type_of_transform,
 7184    total_sigma=perf_total_sigma,
 7185    fdOffset=2.0,
 7186    trim = mytrim,
 7187    output_directory=None,
 7188    verbose=verbose,
 7189    syn_metric='CC',
 7190    syn_sampling=2,
 7191    reg_iterations=[40,20,5] )
 7192  if verbose:
 7193      print("End perfusion motion correction")
 7194
 7195  if m0_image is not None:
 7196      m0 = m0_image
 7197  elif m0_indices is not None:
 7198    not_m0 = list( range( fmri.shape[3] ) )
 7199    not_m0 = [x for x in not_m0 if x not in m0_indices]
 7200    if verbose:
 7201        print( m0_indices )
 7202        print( not_m0 )
 7203    # then remove it from the time series
 7204    m0 = remove_volumes_from_timeseries( corrmo['motion_corrected'], not_m0 )
 7205    m0 = ants.get_average_of_timeseries( m0 )
 7206    corrmo['motion_corrected'] = remove_volumes_from_timeseries( 
 7207        corrmo['motion_corrected'], m0_indices )
 7208    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], m0_indices )
 7209    fmri = remove_volumes_from_timeseries( fmri, m0_indices )
 7210
 7211  ntp = corrmo['motion_corrected'].shape[3]
 7212  if tc == 'alternating':
 7213      tclist = replicate_list( ['C','T'], ntp )
 7214  else:
 7215      tclist = replicate_list( ['C'], int(ntp/2) ) + replicate_list( ['T'],  int(ntp/2) )
 7216
 7217  tclist = one_hot_encode( tclist[0:ntp ] )
 7218  fmrimotcorr=corrmo['motion_corrected']
 7219  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 7220    fmrimotcorr, hlinds = loop_timeseries_censoring( fmrimotcorr, outlier_threshold, mask=None, verbose=verbose )
 7221    tclist = remove_elements_from_numpy_array( tclist, hlinds)
 7222    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 7223
 7224  # redo template and registration at (potentially) upsampled scale
 7225  fmri_template = ants.iMath( ants.get_average_of_timeseries( fmrimotcorr ), "Normalize" )
 7226  if upsample:
 7227      spc = ants.get_spacing( fmri )
 7228      minspc = 2.0
 7229      if min(spc[0:3]) < minspc:
 7230          minspc = min(spc[0:3])
 7231      newspc = [minspc,minspc,minspc]
 7232      fmri_template = ants.resample_image( fmri_template, newspc, interp_type=0 )
 7233
 7234  if verbose:
 7235      print( 'fmri_template')
 7236      print( fmri_template )
 7237
 7238  rig = ants.registration( fmri_template, t1head, 'antsRegistrationSyNRepro[r]' )
 7239  bmask = ants.apply_transforms( fmri_template, 
 7240    ants.threshold_image(t1segmentation,1,6), 
 7241    rig['fwdtransforms'][0], 
 7242    interpolator='genericLabel' )
 7243
 7244  warn_if_small_mask( bmask, label='bold_perfusion:bmask')
 7245
 7246  corrmo = timeseries_reg(
 7247        fmri, fmri_template,
 7248        type_of_transform=type_of_transform,
 7249        total_sigma=perf_total_sigma,
 7250        fdOffset=2.0,
 7251        trim = mytrim,
 7252        output_directory=None,
 7253        verbose=verbose,
 7254        syn_metric='CC',
 7255        syn_sampling=2,
 7256        reg_iterations=[40,20,5] )
 7257  if verbose:
 7258        print("End 2nd perfusion motion correction")
 7259
 7260  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 7261    corrmo['motion_corrected'] = remove_volumes_from_timeseries( corrmo['motion_corrected'], hlinds )
 7262    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 7263
 7264  regression_mask = bmask.clone()
 7265  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
 7266  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
 7267  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",3)
 7268  bmask = bmask * ants.iMath( tsnrmask, "FillHoles" )
 7269  warn_if_small_mask( bmask, label='bold_perfusion:bmask*tsnrmask')
 7270  fmrimotcorr=corrmo['motion_corrected']
 7271  und = fmri_template * bmask
 7272  t1reg = ants.registration( und, t1, "antsRegistrationSyNRepro[s]" )
 7273  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
 7274  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
 7275  gmseg = ants.threshold_image( gmseg, 1, 4 )
 7276  gmseg = ants.iMath( gmseg, 'MD', 1 )
 7277  gmseg = ants.apply_transforms( und, gmseg,
 7278    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7279  csfseg = ants.threshold_image( t1segmentation, 1, 1 )
 7280  wmseg = ants.threshold_image( t1segmentation, 3, 3 )
 7281  csfAndWM = ( csfseg + wmseg ).morphology("erode",1)
 7282  compcorquantile=0.50
 7283  csfAndWM = ants.apply_transforms( und, csfAndWM,
 7284    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7285  csfseg = ants.apply_transforms( und, csfseg,
 7286    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7287  wmseg = ants.apply_transforms( und, wmseg,
 7288    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7289  warn_if_small_mask( wmseg, label='bold_perfusion:wmseg')
 7290  # warn_if_small_mask( csfseg, threshold_fraction=0.01, label='bold_perfusion:csfseg')
 7291  warn_if_small_mask( csfAndWM, label='bold_perfusion:csfAndWM')
 7292  mycompcor = ants.compcor( fmrimotcorr,
 7293    ncompcor=nc, quantile=compcorquantile, mask = csfAndWM,
 7294    filter_type='polynomial', degree=2 )
 7295  tr = ants.get_spacing( fmrimotcorr )[3]
 7296  simg = ants.smooth_image(fmrimotcorr, spa, sigma_in_physical_coordinates = True )
 7297  nuisance = mycompcor['basis']
 7298  nuisance = np.c_[ nuisance, mycompcor['components'] ]
 7299  if add_FD_to_nuisance:
 7300    nuisance = np.c_[ nuisance, corrmo['FD'] ]
 7301  if verbose:
 7302    print("make sure nuisance is independent of TC")
 7303  nuisance = ants.regress_components( nuisance, tclist )
 7304  regression_mask = bmask.clone()
 7305  gmmat = ants.timeseries_to_matrix( simg, regression_mask )
 7306  regvars = np.hstack( (nuisance, tclist ))
 7307  coefind = regvars.shape[1]-1
 7308  regvars = regvars[:,range(coefind)]
 7309  predictor_of_interest_idx = regvars.shape[1]-1
 7310  valid_perf_models = ['huber','quantile','theilsen','ransac', 'sgd', 'linear','SM']
 7311  if verbose:
 7312    print( "begin perfusion estimation with " + perfusion_regression_model + " model " )
 7313  if perfusion_regression_model == 'linear':
 7314    regression_model = LinearRegression()
 7315    regression_model.fit( regvars, gmmat )
 7316    coefind = regression_model.coef_.shape[1]-1
 7317    perfimg = ants.make_image( regression_mask, regression_model.coef_[:,coefind] )
 7318  elif perfusion_regression_model == 'SM': #
 7319    import statsmodels.api as sm
 7320    coeffs = np.zeros( gmmat.shape[1] )
 7321    # Loop over each outcome column in the outcomes matrix
 7322    for outcome_idx in range(gmmat.shape[1]):
 7323        outcome = gmmat[:, outcome_idx]  # Select one outcome column
 7324        model = sm.RLM(outcome, sm.add_constant(regvars), M=sm.robust.norms.HuberT())  # Huber's T norm for robust regression
 7325        results = model.fit()
 7326        coefficients = results.params  # Coefficients of all predictors
 7327        coeffs[outcome_idx] = coefficients[predictor_of_interest_idx]
 7328    perfimg = ants.make_image( regression_mask, coeffs )
 7329  elif perfusion_regression_model in valid_perf_models :
 7330    scaler = StandardScaler()
 7331    gmmat = scaler.fit_transform(gmmat)
 7332    coeffs = np.zeros( gmmat.shape[1] )
 7333    huber_regressor = select_regression_model( perfusion_regression_model )
 7334    multioutput_model = MultiOutputRegressor(huber_regressor)
 7335    multioutput_model.fit( regvars, gmmat )
 7336    ct=0
 7337    for i, estimator in enumerate(multioutput_model.estimators_):
 7338      coefficients = estimator.coef_
 7339      coeffs[ct]=coefficients[predictor_of_interest_idx]
 7340      ct=ct+1
 7341    perfimg = ants.make_image( regression_mask, coeffs )
 7342  else:
 7343    raise ValueError( perfusion_regression_model + " regression model is not found.")
 7344  meangmval = ( perfimg[ gmseg == 1 ] ).mean()
 7345  if meangmval < 0:
 7346      perfimg = perfimg * (-1.0)
 7347  negative_voxels = ( perfimg < 0.0 ).sum() / np.prod( perfimg.shape ) * 100.0
 7348  perfimg[ perfimg < 0.0 ] = 0.0 # non-physiological
 7349
 7350  # LaTeX code for Cerebral Blood Flow (CBF) calculation using ASL MRI
 7351  """
 7352CBF = \\frac{\\Delta M \\cdot \\lambda}{2 \\cdot T_1 \\cdot \\alpha \\cdot M_0 \\cdot (e^{-\\frac{w}{T_1}} - e^{-\\frac{w + \\tau}{T_1}})}
 7353
 7354Where:
 7355- \\Delta M is the difference in magnetization between labeled and control images.
 7356- \\lambda is the brain-blood partition coefficient, typically around 0.9 mL/g.
 7357- T_1 is the longitudinal relaxation time of blood, which is a tissue-specific constant.
 7358- \\alpha is the labeling efficiency.
 7359- M_0 is the equilibrium magnetization of brain tissue (from the M0 image).
 7360- w is the post-labeling delay, the time between the end of the labeling and the acquisition of the image.
 7361- \\tau is the labeling duration.
 7362  """
 7363  if m0_indices is None and m0_image is None:
 7364    m0 = ants.get_average_of_timeseries( fmrimotcorr )
 7365  else:
 7366    # register m0 to current template
 7367    m0reg = ants.registration( fmri_template, m0, 'antsRegistrationSyNRepro[r]', verbose=False )
 7368    m0 = m0reg['warpedmovout']
 7369
 7370  if ntp == 2 :
 7371      img0 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=0 )
 7372      img1 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=1 )
 7373      if m0_image is None:
 7374        if img0.mean() < img1.mean():
 7375            perfimg=img0
 7376            m0=img1
 7377        else:
 7378            perfimg=img1
 7379            m0=img0
 7380      else:
 7381        if img0.mean() < img1.mean():
 7382            perfimg=img1-img0
 7383        else:
 7384            perfimg=img0-img1
 7385
 7386  cbf = calculate_CBF(
 7387      Delta_M=perfimg, M_0=m0, mask=bmask )
 7388  if trim_the_mask > 0.0 :
 7389    bmask = trim_dti_mask( cbf, bmask, trim_the_mask )
 7390    perfimg = perfimg * bmask
 7391    cbf = cbf * bmask
 7392
 7393  meangmval = ( perfimg[ gmseg == 1 ] ).mean()        
 7394  meangmvalcbf = ( cbf[ gmseg == 1 ] ).mean()
 7395  if verbose:
 7396    print("perfimg.max() " + str(  perfimg.max() ) )
 7397  outdict = {}
 7398  outdict['meanBold'] = und
 7399  outdict['brainmask'] = bmask
 7400  rsfNuisance = pd.DataFrame( nuisance )
 7401  rsfNuisance['FD']=corrmo['FD']
 7402
 7403  if verbose:
 7404      print("perfusion dataframe begin")
 7405  dktseg = ants.apply_transforms( und, t1dktcit,
 7406    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7407  df_perf = antspyt1w.map_intensity_to_dataframe(
 7408        'dkt_cortex_cit_deep_brain',
 7409        perfimg,
 7410        dktseg)
 7411  df_perf = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 7412              {'perf' : df_perf},
 7413              col_names = ['Mean'] )
 7414  df_cbf = antspyt1w.map_intensity_to_dataframe(
 7415        'dkt_cortex_cit_deep_brain',
 7416        cbf,
 7417        dktseg)
 7418  df_cbf = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 7419              {'cbf' : df_cbf},
 7420              col_names = ['Mean'] )
 7421  df_cbf = df_cbf.add_prefix('cbf_')
 7422  df_perf = pd.concat( [df_perf,df_cbf], axis=1, ignore_index=False )
 7423  if verbose:
 7424      print("perfusion dataframe end")
 7425
 7426  outdict['perfusion']=perfimg
 7427  outdict['cbf']=cbf
 7428  outdict['m0']=m0
 7429  outdict['perfusion_gm_mean']=meangmval
 7430  outdict['cbf_gm_mean']=meangmvalcbf
 7431  outdict['perf_dataframe']=df_perf
 7432  outdict['motion_corrected'] = corrmo['motion_corrected']
 7433  outdict['gmseg'] = gmseg
 7434  outdict['brain_mask'] = bmask
 7435  outdict['nuisance'] = rsfNuisance
 7436  outdict['tsnr'] = mytsnr
 7437  outdict['ssnr'] = slice_snr( corrmo['motion_corrected'], csfAndWM, gmseg )
 7438  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
 7439  outdict['high_motion_count'] = (rsfNuisance['FD'] > FD_threshold ).sum()
 7440  outdict['high_motion_pct'] = (rsfNuisance['FD'] > FD_threshold ).sum() / rsfNuisance.shape[0]
 7441  outdict['FD_max'] = rsfNuisance['FD'].max()
 7442  outdict['FD_mean'] = rsfNuisance['FD'].mean()
 7443  outdict['FD_sd'] = rsfNuisance['FD'].std()
 7444  outdict['bold_evr'] =  antspyt1w.patch_eigenvalue_ratio( und, 512, [16,16,16], evdepth = 0.9, mask = bmask )
 7445  outdict['t1reg'] = t1reg
 7446  outdict['outlier_volumes']=hlinds
 7447  outdict['n_outliers']=len(hlinds)
 7448  outdict['negative_voxels']=negative_voxels
 7449  return convert_np_in_dict( outdict )
 7450
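# Hedged usage sketch (illustrative): T1-guided perfusion processing, mirroring
# the call made by mm().  Here asl is a 4D tag-control antsImage, t1 is the raw
# (non-brain-extracted) T1 antsImage, and hier is assumed to be the output of
# antspyt1w.hierarchical.
def _example_bold_perfusion( asl, t1, hier ):
    """Illustrative only: call bold_perfusion with antspyt1w.hierarchical outputs."""
    return bold_perfusion(
        asl,
        t1,
        hier['brain_n4_dnz'],
        hier['dkt_parc']['tissue_segmentation'],
        hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
        verbose=True )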
 7451
 7452def pet3d_summary( pet3d, t1head, t1, t1segmentation, t1dktcit,
 7453                   spa = (0., 0., 0.),
 7454                   type_of_transform='antsRegistrationSyNRepro[r]',
 7455                   upsample=True,
 7456                   verbose=False ):
 7457  """
 7458  Summarize a 3D PET image by aligning it with T1-derived anatomy and reporting tissue-level (GM, WM, CSF) means and regional means over the DKT cortex plus CIT168 parcellation.
 7459
 7460  Arguments
 7461  ---------
 7462  pet3d : 3D PET antsImage
 7463
 7464  t1head : ANTsImage
 7465    input 3-D T1 brain image (not brain extracted)
 7466
 7467  t1 : ANTsImage
 7468    input 3-D T1 brain image (brain extracted)
 7469
 7470  t1segmentation : ANTsImage
 7471    t1 segmentation - a six tissue segmentation image in T1 space
 7472
 7473  t1dktcit : ANTsImage
 7474    t1 dkt cortex plus cit parcellation
 7475
 7476  type_of_transform : SyN or Rigid
 7477
 7478  upsample: boolean
 7479
 7480  verbose : boolean
 7481
 7482  Returns
 7483  ---------
 7484  a dictionary containing the (optionally upsampled) PET image, masks, tissue means and the regional summary dataframe
 7485
 7486  """
 7487  import numpy as np
 7488  import pandas as pd
 7489  import re
 7490  import math
 7491
 7492  ex_path = os.path.expanduser( "~/.antspyt1w/" )
 7493  cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
 7494  
 7495  pet3dr=pet3d
 7496  if upsample:
 7497      spc = ants.get_spacing( pet3d )
 7498      minspc = 1.0
 7499      if min(spc) < minspc:
 7500        minspc = min(spc)
 7501      newspc = [minspc,minspc,minspc]
 7502      pet3dr = ants.resample_image( pet3d, newspc, interp_type=0 )
 7503
 7504  rig = ants.registration( pet3dr, t1head, 'antsRegistrationSyNRepro[r]' )
 7505  bmask = ants.apply_transforms( pet3dr, 
 7506    ants.threshold_image(t1segmentation,1,6), 
 7507    rig['fwdtransforms'][0], 
 7508    interpolator='genericLabel' )
 7509  if verbose:
 7510    print("End t1=>pet registration")
 7511
 7512  und = pet3dr * bmask
 7513#  t1reg = ants.registration( und, t1, "antsRegistrationSyNQuickRepro[s]" )
 7514  t1reg = rig # ants.registration( und, t1, "Rigid" )
 7515  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
 7516  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
 7517  gmseg = ants.threshold_image( gmseg, 1, 4 )
 7518  gmseg = ants.iMath( gmseg, 'MD', 1 )
 7519  gmseg = ants.apply_transforms( und, gmseg,
 7520    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7521  csfseg = ants.threshold_image( t1segmentation, 1, 1 )
 7522  wmseg = ants.threshold_image( t1segmentation, 3, 3 )
 7523  csfAndWM = ( csfseg + wmseg ).morphology("erode",1)
 7524  csfAndWM = ants.apply_transforms( und, csfAndWM,
 7525    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7526  csfseg = ants.apply_transforms( und, csfseg,
 7527    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7528  wmseg = ants.apply_transforms( und, wmseg,
 7529    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7530  wmsignal = pet3dr[ ants.iMath(wmseg,"ME",1) == 1 ].mean()
 7531  gmsignal = pet3dr[ gmseg == 1 ].mean()
 7532  csfsignal = pet3dr[ csfseg == 1 ].mean()
 7533  if verbose:
 7534    print("pet3dr.max() " + str(  pet3dr.max() ) )
 7535  if verbose:
 7536      print("pet3d dataframe begin")
 7537  dktseg = ants.apply_transforms( und, t1dktcit,
 7538    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7539  df_pet3d = antspyt1w.map_intensity_to_dataframe(
 7540        'dkt_cortex_cit_deep_brain',
 7541        und,
 7542        dktseg)
 7543  df_pet3d = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 7544              {'pet3d' : df_pet3d},
 7545              col_names = ['Mean'] )
 7546  if verbose:
 7547      print("pet3d dataframe end")
 7548
 7549  outdict = {}
 7550  outdict['pet3d_dataframe']=df_pet3d
 7551  outdict['pet3d'] = pet3dr
 7552  outdict['brainmask'] = bmask
 7553  outdict['gm_mean']=gmsignal
 7554  outdict['wm_mean']=wmsignal
 7555  outdict['csf_mean']=csfsignal
 7556  outdict['pet3d_dataframe']=df_pet3d
 7557  outdict['t1reg'] = t1reg
 7558  return convert_np_in_dict( outdict )
 7559
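# Hedged usage sketch (illustrative): regional PET summary using the same inputs
# that mm() passes to pet3d_summary; hier is assumed to be antspyt1w.hierarchical
# output and pet3d a 3D antsImage.
def _example_pet3d_summary( pet3d, t1, hier ):
    """Illustrative only: summarize a 3D PET image against T1-derived anatomy."""
    return pet3d_summary(
        pet3d,
        t1,
        hier['brain_n4_dnz'],
        hier['dkt_parc']['tissue_segmentation'],
        hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
        verbose=True )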
 7560
 7561def write_bvals_bvecs(bvals, bvecs, prefix ):
 7562    ''' Write FSL FDT bvals and bvecs files
 7563
 7564    adapted from dipy.external code
 7565
 7566    Parameters
 7567    -------------
 7568    bvals : (N,) sequence
 7569       Vector with diffusion gradient strength (one per diffusion
 7570       acquisition, N=no of acquisitions)
 7571    bvecs : (N, 3) array-like
 7572       diffusion gradient directions
 7573    prefix : string
 7574       path to write FDT bvals, bvecs text files
 7575       files are written as prefix + '.bval' and prefix + '.bvec'
 7576    '''
 7577    _VAL_FMT = '   %e'
 7578    bvals = tuple(bvals)
 7579    bvecs = np.asarray(bvecs)
 7580    bvecs[np.isnan(bvecs)] = 0
 7581    N = len(bvals)
 7582    fname = prefix + '.bval'
 7583    fmt = _VAL_FMT * N + '\n'
 7584    myfile = open(fname, 'wt')
 7585    myfile.write(fmt % bvals)
 7586    myfile.close()
 7587    fname = prefix + '.bvec'
 7588    bvf = open(fname, 'wt')
 7589    for dim_vals in bvecs.T:
 7590        bvf.write(fmt % tuple(dim_vals))
 7591    bvf.close()
 7592    
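# Hedged usage sketch (illustrative): write a small gradient table to FSL FDT
# style text files.  The output prefix is a placeholder.
def _example_write_bvals_bvecs( prefix="/tmp/dwi" ):
    """Illustrative only: write a 3-volume gradient table as prefix.bval/.bvec."""
    bvals = [0, 1000, 1000]
    bvecs = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    write_bvals_bvecs( bvals, bvecs, prefix )
    return prefix + '.bval', prefix + '.bvec'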
 7593
 7594def crop_mcimage( x, mask, padder=None ):
 7595    """
 7596    crop a time series (4D) image by a 3D mask
 7597
 7598    Parameters
 7599    -------------
 7600
 7601    x : raw image
 7602
 7603    mask  : mask for cropping
 7604    padder : optional; pad width passed to ants.pad_image for each cropped volume
 7605    """
 7606    cropmask = ants.crop_image( mask, mask )
 7607    myorig = list( ants.get_origin(cropmask) )
 7608    myorig.append( ants.get_origin( x )[3] )
 7609    croplist = []
 7610    if len(x.shape) > 3:
 7611        for k in range(x.shape[3]):
 7612            temp = ants.slice_image( x, axis=3, idx=k )
 7613            temp = ants.crop_image( temp, mask )
 7614            if padder is not None:
 7615                temp = ants.pad_image( temp, pad_width=padder )
 7616            croplist.append( temp )
 7617        temp = ants.list_to_ndimage( x, croplist )
 7618        temp.set_origin( myorig )
 7619        return temp
 7620    else:
 7621        return( ants.crop_image( x, mask ) )
 7622
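# Hedged usage sketch (illustrative): crop a 4D time series to the bounding box
# of a 3D brain mask derived from its temporal average.  The file name is a
# placeholder.
def _example_crop_mcimage( fmri_filename="/path/to/bold.nii.gz" ):
    """Illustrative only: mask-driven cropping of a 4D image."""
    fmri = ants.image_read( fmri_filename )
    avg = ants.get_average_of_timeseries( fmri )
    mask = ants.get_mask( avg )
    return crop_mcimage( fmri, mask )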
 7623
 7624def mm(
 7625    t1_image,
 7626    hier,
 7627    rsf_image=[],
 7628    flair_image=None,
 7629    nm_image_list=None,
 7630    dw_image=[], bvals=[], bvecs=[],
 7631    perfusion_image=None,
 7632    srmodel=None,
 7633    do_tractography = False,
 7634    do_kk = False,
 7635    do_normalization = None,
 7636    group_template = None,
 7637    group_transform = None,
 7638    target_range = [0,1],
 7639    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
 7640    dti_denoise = False,
 7641    perfusion_trim=10,
 7642    perfusion_m0_image=None,
 7643    perfusion_m0=None,
 7644    rsf_upsampling=3.0,
 7645    pet_3d_image=None,
 7646    test_run = False,
 7647    verbose = False ):
 7648    """
 7649    Multiple modality processing and normalization
 7650
 7651    aggregates modality-specific processing under one roof.  see individual
 7652    modality specific functions for details.
 7653
 7654    Parameters
 7655    -------------
 7656
 7657    t1_image : raw t1 image
 7658
 7659    hier  : output of antspyt1w.hierarchical ( see read hierarchical )
 7660
 7661    rsf_image : list of resting state fmri
 7662
 7663    flair_image : flair
 7664
 7665    nm_image_list : list of neuromelanin images
 7666
 7667    dw_image : list of diffusion weighted images
 7668
 7669    bvals : list of bvals file names
 7670
 7671    bvecs : list of bvecs file names
 7672
 7673    perfusion_image : single perfusion image
 7674
 7675    srmodel : optional srmodel
 7676
 7677    do_tractography : boolean
 7678
 7679    do_kk : boolean to control whether we compute kelly kapowski thickness image (slow)
 7680
 7681    do_normalization : template transformation if available
 7682
 7683    group_template : optional reference template corresponding to the group_transform
 7684
 7685    group_transform : optional transforms corresponding to the group_template
 7686
 7687    target_range : 2-element tuple
 7688        a tuple or array defining the (min, max) of the input image
 7689        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
 7690        intensity. This range should match the mapping used in the training
 7691        of the network.
 7692    
 7693    dti_motion_correct : None Rigid or SyN
 7694
 7695    dti_denoise : boolean
 7696
 7697    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
 7698
 7699    perfusion_m0_image : optional antsImage m0 associated with the perfusion time series
 7700
 7701    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
 7702
 7703    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
 7704
 7705    pet_3d_image : optional antsImage for a 3D pet; we make no assumptions about the contents of 
 7706        this image.  we just process it and provide summary information.
 7707
 7708    test_run : boolean 
 7709
 7710    verbose : boolean
 7711
 7712    """
 7713    from os.path import exists
 7714    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 7715    ex_path_mm = os.path.expanduser( "~/.antspymm/" )
 7716    mycsvfn = ex_path + "FA_JHU_labels_edited.csv"
 7717    citcsvfn = ex_path + "CIT168_Reinf_Learn_v1_label_descriptions_pad.csv"
 7718    dktcsvfn = ex_path + "dkt.csv"
 7719    cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
 7720    JHU_atlasfn = ex_path + 'JHU-ICBM-FA-1mm.nii.gz' # Read in JHU atlas
 7721    JHU_labelsfn = ex_path + 'JHU-ICBM-labels-1mm.nii.gz' # Read in JHU labels
 7722    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 7723    if not exists( mycsvfn ) or not exists( citcsvfn ) or not exists( cnxcsvfn ) or not exists( dktcsvfn ) or not exists( JHU_atlasfn ) or not exists( JHU_labelsfn ) or not exists( templatefn ):
 7724        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
 7725        raise ValueError('**missing files** => call get_data from latest antspyt1w and antspymm.')
 7726    mycsv = pd.read_csv(  mycsvfn )
 7727    citcsv = pd.read_csv(  os.path.expanduser( citcsvfn ) )
 7728    dktcsv = pd.read_csv(  os.path.expanduser( dktcsvfn ) )
 7729    cnxcsv = pd.read_csv(  os.path.expanduser( cnxcsvfn ) )
 7730    JHU_atlas = mm_read( JHU_atlasfn ) # Read in JHU atlas
 7731    JHU_labels = mm_read( JHU_labelsfn ) # Read in JHU labels
 7732    template = mm_read( templatefn ) # Read in template
 7733    if group_template is None:
 7734        group_template = template
 7735        group_transform = do_normalization['fwdtransforms']
 7736    if verbose:
 7737        print("Using group template:")
 7738        print( group_template )
 7739    #####################
 7740    #  T1 hierarchical  #
 7741    #####################
 7742    t1imgbrn = hier['brain_n4_dnz']
 7743    t1atropos = hier['dkt_parc']['tissue_segmentation']
 7744    output_dict = {
 7745        'kk': None,
 7746        'rsf': None,
 7747        'flair' : None,
 7748        'NM' : None,
 7749        'DTI' : None,
 7750        'FA_summ' : None,
 7751        'MD_summ' : None,
 7752        'tractography' : None,
 7753        'tractography_connectivity' : None,
 7754        'perf' : None,
 7755        'pet3d' : None,
 7756    }
 7757    normalization_dict = {
 7758        'kk_norm': None,
 7759        'NM_norm' : None,
 7760        'DTI_norm': None,
 7761        'FA_norm' : None,
 7762        'MD_norm' : None,
 7763        'perf_norm' : None,
 7764        'alff_norm' : None,
 7765        'falff_norm' : None,
 7766        'CinguloopercularTaskControl_norm' : None,
 7767        'DefaultMode_norm' : None,
 7768        'MemoryRetrieval_norm' : None,
 7769        'VentralAttention_norm' : None,
 7770        'Visual_norm' : None,
 7771        'FrontoparietalTaskControl_norm' : None,
 7772        'Salience_norm' : None,
 7773        'Subcortical_norm' : None,
 7774        'DorsalAttention_norm' : None,
 7775        'pet3d_norm' : None
 7776    }
 7777    if test_run:
 7778        return output_dict, normalization_dict
 7779
 7780    if do_kk:
 7781        if verbose:
 7782            print('kk in mm')
 7783        output_dict['kk'] = antspyt1w.kelly_kapowski_thickness( t1_image,
 7784            labels=hier['dkt_parc']['dkt_cortex'], iterations=45 )
 7785
 7786    if perfusion_image is not None:
 7787        if perfusion_image.shape[3] > 1: # FIXME - better heuristic?
 7788            output_dict['perf'] = bold_perfusion(
 7789                perfusion_image,
 7790                t1_image,
 7791                hier['brain_n4_dnz'],
 7792                t1atropos,
 7793                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
 7794                n_to_trim = perfusion_trim,
 7795                m0_image = perfusion_m0_image,
 7796                m0_indices = perfusion_m0,
 7797                verbose=verbose )
 7798
 7799    if pet_3d_image is not None:
 7800        if pet_3d_image.dimension == 3: # FIXME - better heuristic?
 7801            output_dict['pet3d'] = pet3d_summary(
 7802                pet_3d_image,
 7803                t1_image,
 7804                hier['brain_n4_dnz'],
 7805                t1atropos,
 7806                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
 7807                verbose=verbose )
 7808    ################################## do the rsf .....
 7809    if len(rsf_image) > 0:
 7810        my_motion_tx = 'antsRegistrationSyNRepro[r]'
 7811        rsf_image = [i for i in rsf_image if i is not None]
 7812        if verbose:
 7813            print('rsf length ' + str( len( rsf_image ) ) )
 7814        if len( rsf_image ) >= 2: # assume 2 is the largest possible value
 7815            rsf_image1 = rsf_image[0]
 7816            rsf_image2 = rsf_image[1]
 7817            # build a template then join the images
 7818            if verbose:
 7819                print("initial average for rsf")
 7820            rsfavg1, hlinds = loop_timeseries_censoring( rsf_image1, 0.1 )
 7821            rsfavg1=get_average_rsf(rsfavg1)
 7822            rsfavg2, hlinds = loop_timeseries_censoring( rsf_image2, 0.1 )
 7823            rsfavg2=get_average_rsf(rsfavg2)
 7824            if verbose:
 7825                print("template average for rsf")
 7826            init_temp = ants.image_clone( rsfavg1 )
 7827            if rsf_image1.shape[3] < rsf_image2.shape[3]:
 7828                init_temp = ants.image_clone( rsfavg2 )
 7829            boldTemplate = ants.build_template(
 7830                initial_template = init_temp,
 7831                image_list=[rsfavg1,rsfavg2],
 7832                type_of_transform="antsRegistrationSyNQuickRepro[s]",
 7833                iterations=5, verbose=False )
 7834            if verbose:
 7835                print("join the 2 rsf")
 7836            if rsf_image1.shape[3] > 10 and rsf_image2.shape[3] > 10:
 7837                leadvols = list(range(8))
 7838                rsf_image2 = remove_volumes_from_timeseries( rsf_image2, leadvols )
 7839                rsf_image = merge_timeseries_data( rsf_image1, rsf_image2 )
 7840            elif rsf_image1.shape[3] > rsf_image2.shape[3]:
 7841                rsf_image = rsf_image1
 7842            else:
 7843                rsf_image = rsf_image2
 7844        elif len( rsf_image ) == 1:
 7845            rsf_image = rsf_image[0]
 7846            boldTemplate, hlinds = loop_timeseries_censoring( rsf_image, 0.1 )
 7847            boldTemplate = get_average_rsf(boldTemplate)
 7848        if rsf_image.shape[3] > 10: # FIXME - better heuristic?
 7849            rsfprolist = [] # FIXMERSF
 7850            # Create the parameter DataFrame
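            # Each row below configures one run of resting_state_fmri_networks:
            #   num    -> paramset, loop -> outlier_threshold, cens -> censor,
            #   HM     -> FD_threshold, ff -> frequency band f (mapped below),
            #   CC     -> nc (compcor), imp -> impute, up -> upsample,
            #   coords -> powers (Powers coordinates vs parcellation)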
 7851            df = pd.DataFrame({
 7852                "num": [134, 122, 129],
 7853                "loop": [0.50, 0.25, 0.50],
 7854                "cens": [True, True, True],
 7855                "HM": [1.0, 5.0, 0.5],
 7856                "ff": ["tight", "tight", "tight"],
 7857                "CC": [5, 5, 0.8],
 7858                "imp": [True, True, True],
 7859                "up": [rsf_upsampling, rsf_upsampling, rsf_upsampling],
 7860                "coords": [False,False,False]
 7861            }, index=[0, 1, 2])
 7862            for p in range(df.shape[0]):
 7863                if verbose:
 7864                    print("rsf parameters")
 7865                    print( df.iloc[p] )
 7866                if df['ff'].iloc[p] == 'broad':
 7867                    f=[ 0.008, 0.15 ]
 7868                elif df['ff'].iloc[p] == 'tight':
 7869                    f=[ 0.03, 0.08 ]
 7870                elif df['ff'].iloc[p] == 'mid':
 7871                    f=[ 0.01, 0.1 ]
 7872                elif df['ff'].iloc[p] == 'mid2':
 7873                    f=[ 0.01, 0.08 ]
 7874                else:
 7875                    raise ValueError("we do not recognize this parameter choice for frequency filtering: " + df['ff'].iloc[p] )
 7876                HM = df['HM'].iloc[p]
 7877                CC = df['CC'].iloc[p]
 7878                loop= df['loop'].iloc[p]
 7879                cens =df['cens'].iloc[p]
 7880                imp = df['imp'].iloc[p]
 7881                rsf0 = resting_state_fmri_networks(
 7882                                            rsf_image,
 7883                                            boldTemplate,
 7884                                            hier['brain_n4_dnz'],
 7885                                            t1atropos,
 7886                                            f=f,
 7887                                            FD_threshold=HM, 
 7888                                            spa = None, 
 7889                                            spt = None, 
 7890                                            nc = CC,
 7891                                            outlier_threshold=loop,
 7892                                            ica_components = 0,
 7893                                            impute = imp,
 7894                                            censor = cens,
 7895                                            despike = 2.5,
 7896                                            motion_as_nuisance = True,
 7897                                            upsample=df['up'].iloc[p],
 7898                                            clean_tmp=0.66,
 7899                                            paramset=df['num'].iloc[p],
 7900                                            powers=df['coords'].iloc[p],
 7901                                            verbose=verbose ) # default
 7902                rsfprolist.append( rsf0 )
 7903            output_dict['rsf'] = rsfprolist
 7904
 7905    if nm_image_list is not None:
 7906        if verbose:
 7907            print('nm')
 7908        if srmodel is None:
 7909            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], verbose=verbose )
 7910        else:
 7911            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], srmodel=srmodel, target_range=target_range, verbose=verbose  )
 7912################################## do the dti .....
 7913    if len(dw_image) > 0 :
 7914        if verbose:
 7915            print('dti-x')
 7916        if len( dw_image ) == 1: # use T1 for distortion correction and brain extraction
 7917            if verbose:
 7918                print("We have only one DTI: " + str(len(dw_image)))
 7919            dw_image = dw_image[0]
 7920            btpB0, btpDW = get_average_dwi_b0(dw_image)
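                 # btpB0 / btpDW : average b0 and average diffusion-weighted volumes; below they
                 # serve as fixed targets for registering the T1 brain so that a brain mask can
                 # be propagated into native DWI space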
 7921            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
 7922            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
 7923                syn_metric='CC', syn_sampling=2,
 7924                reg_iterations=[50,50,20],
 7925                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
 7926                initial_transform=initrig
 7927                )
 7928            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
 7929            btpDW = ants.apply_transforms( btpDW, btpDW,
 7930                tempreg['invtransforms'][1], interpolator='linear')
 7931            btpB0 = ants.apply_transforms( btpB0, btpB0,
 7932                tempreg['invtransforms'][1], interpolator='linear')
 7933            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'][1], interpolator='nearestNeighbor')
 7934            # dwimask = ants.iMath(dwimask,'MD',1)
 7935            t12dwi = ants.apply_transforms( btpDW, hier['brain_n4_dnz'], tempreg['fwdtransforms'][1], interpolator='linear')
 7936            output_dict['DTI'] = joint_dti_recon(
 7937                dw_image,
 7938                bvals[0],
 7939                bvecs[0],
 7940                jhu_atlas=JHU_atlas,
 7941                jhu_labels=JHU_labels,
 7942                brain_mask = dwimask,
 7943                reference_B0 = btpB0,
 7944                reference_DWI = btpDW,
 7945                srmodel=srmodel,
 7946                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
 7947                denoise=dti_denoise,
 7948                verbose = verbose)
 7949        else :  # use phase encoding acquisitions for distortion correction and T1 for brain extraction
 7950            if verbose:
 7951                print("We have both DTI_LR and DTI_RL: " + str(len(dw_image)))
 7952            a1b,a1w=get_average_dwi_b0(dw_image[0])
 7953            a2b,a2w=get_average_dwi_b0(dw_image[1],fixed_b0=a1b,fixed_dwi=a1w)
 7954            btpB0, btpDW = dti_template(
 7955                b_image_list=[a1b,a2b],
 7956                w_image_list=[a1w,a2w],
 7957                iterations=7, verbose=verbose )
 7958            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
 7959            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
 7960                syn_metric='CC', syn_sampling=2,
 7961                reg_iterations=[50,50,20],
 7962                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
 7963                initial_transform=initrig
 7964                )
 7965            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
 7966            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'], interpolator='nearestNeighbor')
 7967            output_dict['DTI'] = joint_dti_recon(
 7968                dw_image[0],
 7969                bvals[0],
 7970                bvecs[0],
 7971                jhu_atlas=JHU_atlas,
 7972                jhu_labels=JHU_labels,
 7973                brain_mask = dwimask,
 7974                reference_B0 = btpB0,
 7975                reference_DWI = btpDW,
 7976                srmodel=srmodel,
 7977                img_RL=dw_image[1],
 7978                bval_RL=bvals[1],
 7979                bvec_RL=bvecs[1],
 7980                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
 7981                denoise=dti_denoise,
 7982                verbose = verbose)
 7983        mydti = output_dict['DTI']
 7984        # summarize dwi with T1 outputs
 7985        # first - register ....
 7986        reg = ants.registration( mydti['recon_fa'], hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[s]', total_sigma=1.0 )
 7987        ##################################################
 7988        output_dict['FA_summ'] = hierarchical_modality_summary(
 7989            mydti['recon_fa'],
 7990            hier=hier,
 7991            modality_name='fa',
 7992            transformlist=reg['fwdtransforms'],
 7993            verbose = False )
 7994        ##################################################
 7995        output_dict['MD_summ'] = hierarchical_modality_summary(
 7996            mydti['recon_md'],
 7997            hier=hier,
 7998            modality_name='md',
 7999            transformlist=reg['fwdtransforms'],
 8000            verbose = False )
 8001        # these inputs should come from nicely processed data
 8002        dktmapped = ants.apply_transforms(
 8003            mydti['recon_fa'],
 8004            hier['dkt_parc']['dkt_cortex'],
 8005            reg['fwdtransforms'], interpolator='nearestNeighbor' )
 8006        citmapped = ants.apply_transforms(
 8007            mydti['recon_fa'],
 8008            hier['cit168lab'],
 8009            reg['fwdtransforms'], interpolator='nearestNeighbor' )
 8010        dktmapped[ citmapped > 0]=0
 8011        mask = ants.threshold_image( mydti['recon_fa'], 0.01, 2.0 ).iMath("GetLargestComponent")
 8012        if do_tractography: # dwi_deterministic_tracking dwi_closest_peak_tracking
 8013            output_dict['tractography'] = dwi_deterministic_tracking(
 8014                mydti['dwi_LR_dewarped'],
 8015                mydti['recon_fa'],
 8016                mydti['bval_LR'],
 8017                mydti['bvec_LR'],
 8018                seed_density = 1,
 8019                mask=mask,
 8020                verbose = verbose )
 8021            mystr = output_dict['tractography']
 8022            output_dict['tractography_connectivity'] = dwi_streamline_connectivity( mystr['streamlines'], dktmapped+citmapped, cnxcsv, verbose=verbose )
 8023    ################################## do the flair .....
 8024    if flair_image is not None:
 8025        if verbose:
 8026            print('flair')
 8027        wmhprior = None
 8028        priorfn = ex_path_mm + 'CIT168_wmhprior_700um_pad_adni.nii.gz'
 8029        if ( exists( priorfn ) ):
 8030            wmhprior = ants.image_read( priorfn )
 8031            wmhprior = ants.apply_transforms( t1_image, wmhprior, do_normalization['invtransforms'] )
 8032        output_dict['flair'] = boot_wmh( flair_image, t1_image, t1atropos,
 8033            prior_probability=wmhprior, verbose=verbose )
 8034    #################################################################
 8035    ### NOTES: deforming to a common space and writing out images ###
 8036    ### images we want come from: DTI, NM, rsf, thickness ###########
 8037    #################################################################
 8038    if do_normalization is not None:
 8039        if verbose:
 8040            print('normalization')
 8041        # might reconsider this template space - cropped and/or higher res?
 8042        # template = ants.resample_image( template, [1,1,1], use_voxels=False )
 8043        # t1reg = ants.registration( template, hier['brain_n4_dnz'], "antsRegistrationSyNQuickRepro[s]")
 8044        t1reg = do_normalization
 8045        if do_kk:
 8046            normalization_dict['kk_norm'] = ants.apply_transforms( group_template, output_dict['kk']['thickness_image'], group_transform )
 8047        if output_dict['DTI'] is not None:
 8048            mydti = output_dict['DTI']
 8049            dtirig = ants.registration( hier['brain_n4_dnz'], mydti['recon_fa'], 'antsRegistrationSyNRepro[r]' )
 8050            normalization_dict['MD_norm'] = ants.apply_transforms( group_template, mydti['recon_md'],group_transform+dtirig['fwdtransforms'] )
 8051            normalization_dict['FA_norm'] = ants.apply_transforms( group_template, mydti['recon_fa'],group_transform+dtirig['fwdtransforms'] )
 8052            output_directory = tempfile.mkdtemp()
 8053            do_dti_norm=False
 8054            if do_dti_norm:
 8055                comptx = ants.apply_transforms( group_template, group_template, group_transform+dtirig['fwdtransforms'], compose = output_directory + '/xxx' )
 8056                tspc=[2.,2.,2.]
 8057                if srmodel is not None:
 8058                    tspc=[1.,1.,1.]
 8059                group_template2mm = ants.resample_image( group_template, tspc  )
 8060                normalization_dict['DTI_norm'] = transform_and_reorient_dti( group_template2mm, mydti['dti'], comptx, verbose=False )
 8061            import shutil
 8062            shutil.rmtree(output_directory, ignore_errors=True )
 8063        if output_dict['rsf'] is not None:
 8064            if False:
 8065                rsfpro = output_dict['rsf'] # FIXME
 8066                rsfrig = ants.registration( hier['brain_n4_dnz'], rsfpro['meanBold'], 'antsRegistrationSyNRepro[r]' )
 8067                for netid in get_antsimage_keys( rsfpro ):
 8068                    rsfkey = netid + "_norm"
 8069                    normalization_dict[rsfkey] = ants.apply_transforms(
 8070                        group_template, rsfpro[netid],
 8071                        group_transform+rsfrig['fwdtransforms'] )
 8072        if output_dict['perf'] is not None: # zizzer
 8073            comptx = group_transform + output_dict['perf']['t1reg']['invtransforms']
 8074            normalization_dict['perf_norm'] = ants.apply_transforms( group_template,
 8075                output_dict['perf']['perfusion'], comptx,
 8076                whichtoinvert=[False,False,True,False] )
 8077            normalization_dict['cbf_norm'] = ants.apply_transforms( group_template,
 8078                output_dict['perf']['cbf'], comptx,
 8079                whichtoinvert=[False,False,True,False] )
 8080        if output_dict['pet3d'] is not None: # zizzer
 8081            secondTx=output_dict['pet3d']['t1reg']['invtransforms']
 8082            comptx = group_transform + secondTx
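                 # whichtoinvert needs one boolean per transform in comptx; the single True flags
                 # the affine (.mat) taken from the inverse-transform list so apply_transforms
                 # inverts it on the fly (its length differs when secondTx holds 1 vs 2 transforms)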
 8083            if len( secondTx ) == 2:
 8084                wti=[False,False,True,False]
 8085            else:
 8086                wti=[False,False,True]
 8087            normalization_dict['pet3d_norm'] = ants.apply_transforms( group_template,
 8088                output_dict['pet3d']['pet3d'], comptx,
 8089                whichtoinvert=wti )
 8090        if nm_image_list is not None:
 8091            nmpro = output_dict['NM']
 8092            nmrig = nmpro['t1_to_NM_transform'] # this is an inverse tx
 8093            normalization_dict['NM_norm'] = ants.apply_transforms( group_template, nmpro['NM_avg'], group_transform+nmrig,
 8094                whichtoinvert=[False,False,True])
 8095
 8096    if verbose:
 8097        print('mm done')
 8098    return output_dict, normalization_dict
 8099
 8100
 8101def write_mm( output_prefix, mm, mm_norm=None, t1wide=None, separator='_', verbose=False ):
 8102    """
 8103    write the tabular and normalization output of the mm function
 8104
 8105    Parameters
 8106    -------------
 8107
 8108    output_prefix : prefix for file outputs - modality specific postfix will be added
 8109
 8110    mm  : output of the mm function for modality-space processing; should be a dictionary
 8111        with an entry for each modality.
 8112
 8113    mm_norm : output of mm function for normalized processing
 8114
 8115    t1wide : wide output data frame from t1 hierarchical
 8116
 8117    separator : string or character separator for filenames
 8118
 8119    verbose : boolean
 8120
 8121    Returns
 8122    ---------
 8123
 8124    both csv and image files written to disk.  the primary outputs will be
 8125    output_prefix + separator + 'mmwide.csv' and *norm.nii.gz images
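
         Example
         ---------

         a minimal sketch; the output prefix is hypothetical and t1, hier and templateTx
         are assumed to come from earlier T1w processing / registration steps:

         >>> tabPro, normPro = antspymm.mm( t1, hier, do_kk=True, do_normalization=templateTx )
         >>> antspymm.write_mm( '/tmp/PPMI-101018-20220228-T1w-1497590', tabPro,
         ...     mm_norm=normPro, t1wide=None, separator='-' )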
 8126
 8127    """
 8128    from dipy.io.streamline import save_tractogram
 8129    if mm_norm is not None:
 8130        for mykey in mm_norm:
 8131            tempfn = output_prefix + separator + mykey + '.nii.gz'
 8132            if mm_norm[mykey] is not None:
 8133                image_write_with_thumbnail( mm_norm[mykey], tempfn )
 8134    thkderk = None
 8135    if t1wide is not None:
 8136        thkderk = t1wide.iloc[: , 1:]
 8137    kkderk = None
 8138    if 'kk' in mm:
 8139        if mm['kk'] is not None:
 8140            kkderk = mm['kk']['thickness_dataframe'].iloc[: , 1:]
 8141            mykey='thickness_image'
 8142            tempfn = output_prefix + separator + mykey + '.nii.gz'
 8143            image_write_with_thumbnail( mm['kk'][mykey], tempfn )
 8144    nmderk = None
 8145    if 'NM' in mm:
 8146        if mm['NM'] is not None:
 8147            nmderk = mm['NM']['NM_dataframe_wide'].iloc[: , 1:]
 8148            for mykey in get_antsimage_keys( mm['NM'] ):
 8149                tempfn = output_prefix + separator + mykey + '.nii.gz'
 8150                image_write_with_thumbnail( mm['NM'][mykey], tempfn, thumb=False )
 8151
 8152    faderk = mdderk = fat1derk = mdt1derk = None
 8153
 8154    if 'DTI' in mm:
 8155        if mm['DTI'] is not None:
 8156            mydti = mm['DTI']
 8157            myop = output_prefix + separator
 8158            ants.image_write( mydti['dti'],  myop + 'dti.nii.gz' )
 8159            write_bvals_bvecs( mydti['bval_LR'], mydti['bvec_LR'], myop + 'reoriented' )
 8160            image_write_with_thumbnail( mydti['dwi_LR_dewarped'],  myop + 'dwi.nii.gz' )
 8161            image_write_with_thumbnail( mydti['dtrecon_LR_dewarp']['RGB'] ,  myop + 'DTIRGB.nii.gz' )
 8162            image_write_with_thumbnail( mydti['jhu_labels'],  myop+'dtijhulabels.nii.gz', mydti['recon_fa'] )
 8163            image_write_with_thumbnail( mydti['recon_fa'],  myop+'dtifa.nii.gz' )
 8164            image_write_with_thumbnail( mydti['recon_md'],  myop+'dtimd.nii.gz' )
 8165            image_write_with_thumbnail( mydti['b0avg'],  myop+'b0avg.nii.gz' )
 8166            image_write_with_thumbnail( mydti['dwiavg'],  myop+'dwiavg.nii.gz' )
 8167            faderk = mm['DTI']['recon_fa_summary'].iloc[: , 1:]
 8168            mdderk = mm['DTI']['recon_md_summary'].iloc[: , 1:]
 8169            fat1derk = mm['FA_summ'].iloc[: , 1:]
 8170            mdt1derk = mm['MD_summ'].iloc[: , 1:]
 8171    if 'tractography' in mm:
 8172        if mm['tractography'] is not None:
 8173            ofn = output_prefix + separator + 'tractogram.trk'
 8174            if mm['tractography']['tractogram'] is not None:
 8175                save_tractogram( mm['tractography']['tractogram'], ofn )
 8176    cnxderk = None
 8177    if 'tractography_connectivity' in mm:
 8178        if mm['tractography_connectivity'] is not None:
 8179            cnxderk = mm['tractography_connectivity']['connectivity_wide'].iloc[: , 1:] # NOTE: connectivity_wide is not much tested
 8180            ofn = output_prefix + separator + 'dtistreamlineconn.csv'
 8181            pd.DataFrame(mm['tractography_connectivity']['connectivity_matrix']).to_csv( ofn )
 8182
 8183    dlist = [
 8184        thkderk,
 8185        kkderk,
 8186        nmderk,
 8187        faderk,
 8188        mdderk,
 8189        fat1derk,
 8190        mdt1derk,
 8191        cnxderk
 8192        ]
 8193    is_all_none = all(element is None for element in dlist)
 8194    if is_all_none:
 8195        mm_wide = pd.DataFrame({'u_hier_id': [output_prefix] })
 8196    else:
 8197        mm_wide = pd.concat( dlist, axis=1, ignore_index=False )
 8198
 8199    mm_wide = mm_wide.copy()
 8200    if 'NM' in mm:
 8201        if mm['NM'] is not None:
 8202            nmwide = dict_to_dataframe( mm['NM'] )
 8203            if mm_wide.shape[0] > 0 and nmwide.shape[0] > 0:
 8204                nmwide.set_index( mm_wide.index, inplace=True )
 8205            mm_wide = pd.concat( [mm_wide, nmwide ], axis=1, ignore_index=False )
 8206    if 'flair' in mm:
 8207        if mm['flair'] is not None:
 8208            myop = output_prefix + separator + 'wmh.nii.gz'
 8209            pngfnb = output_prefix + separator + 'wmh_seg.png'
 8210            ants.plot( mm['flair']['flair'], mm['flair']['WMH_posterior_probability_map'], axis=2, nslices=21, ncol=7, filename=pngfnb, crop=True )
 8211            if mm['flair']['WMH_probability_map'] is not None:
 8212                image_write_with_thumbnail( mm['flair']['WMH_probability_map'], myop, thumb=False )
 8213            flwide = dict_to_dataframe( mm['flair'] )
 8214            if mm_wide.shape[0] > 0 and flwide.shape[0] > 0:
 8215                flwide.set_index( mm_wide.index, inplace=True )
 8216            mm_wide = pd.concat( [mm_wide, flwide ], axis=1, ignore_index=False )
 8217    if 'rsf' in mm:
 8218        if mm['rsf'] is not None:
 8219            fcnxpro=99
 8220            rsfdata = mm['rsf']
 8221            if not isinstance( rsfdata, list ):
 8222                rsfdata = [ rsfdata ]
 8223            for rsfpro in rsfdata:
 8224                fcnxpro=str( rsfpro['paramset']  )
 8225                pronum = 'fcnxpro'+str(fcnxpro)+"_"
 8226                if verbose:
 8227                    print("Collect rsf data " + pronum)
 8228                new_rsf_wide = dict_to_dataframe( rsfpro )
 8229                new_rsf_wide = pd.concat( [new_rsf_wide, rsfpro['corr_wide'] ], axis=1, ignore_index=False )
 8230                new_rsf_wide = new_rsf_wide.add_prefix( pronum )
 8231                new_rsf_wide.set_index( mm_wide.index, inplace=True )
 8232                ofn = output_prefix + separator + pronum + '.csv'
 8233                new_rsf_wide.to_csv( ofn )
 8234                mm_wide = pd.concat( [mm_wide, new_rsf_wide ], axis=1, ignore_index=False )
 8235                for mykey in get_antsimage_keys( rsfpro ):
 8236                    myop = output_prefix + separator + pronum + mykey + '.nii.gz'
 8237                    image_write_with_thumbnail( rsfpro[mykey], myop, thumb=True )
 8238                ofn = output_prefix + separator + pronum + 'rsfcorr.csv'
 8239                rsfpro['corr'].to_csv( ofn )
 8240                # apply same principle to new correlation matrix, doesn't need to be incorporated with mm_wide
 8241                ofn2 = output_prefix + separator + pronum + 'nodescorr.csv'
 8242                rsfpro['fullCorrMat'].to_csv( ofn2 )
 8243    if 'DTI' in mm:
 8244        if mm['DTI'] is not None:
 8245            mydti = mm['DTI']
 8246            mm_wide['dti_tsnr_b0_mean'] =  mydti['tsnr_b0'].mean()
 8247            mm_wide['dti_tsnr_dwi_mean'] =  mydti['tsnr_dwi'].mean()
 8248            mm_wide['dti_dvars_b0_mean'] =  mydti['dvars_b0'].mean()
 8249            mm_wide['dti_dvars_dwi_mean'] =  mydti['dvars_dwi'].mean()
 8250            mm_wide['dti_ssnr_b0_mean'] =  mydti['ssnr_b0'].mean()
 8251            mm_wide['dti_ssnr_dwi_mean'] =  mydti['ssnr_dwi'].mean()
 8252            mm_wide['dti_fa_evr'] =  mydti['fa_evr']
 8253            mm_wide['dti_fa_SNR'] =  mydti['fa_SNR']
 8254            if mydti['framewise_displacement'] is not None:
 8255                mm_wide['dti_high_motion_count'] =  mydti['high_motion_count']
 8256                mm_wide['dti_FD_mean'] = mydti['framewise_displacement'].mean()
 8257                mm_wide['dti_FD_max'] = mydti['framewise_displacement'].max()
 8258                mm_wide['dti_FD_sd'] = mydti['framewise_displacement'].std()
 8259                fdfn = output_prefix + separator + '_fd.csv'
 8260            else:
 8261                mm_wide['dti_FD_mean'] = mm_wide['dti_FD_max'] = mm_wide['dti_FD_sd'] = 'NA'
 8262
 8263    if 'perf' in mm:
 8264        if mm['perf'] is not None:
 8265            perfpro = mm['perf']
 8266            prwide = dict_to_dataframe( perfpro )
 8267            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
 8268                prwide.set_index( mm_wide.index, inplace=True )
 8269            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
 8270            if 'perf_dataframe' in perfpro.keys():
 8271                pderk = perfpro['perf_dataframe'].iloc[: , 1:]
 8272                pderk.set_index( mm_wide.index, inplace=True )
 8273                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
 8274            else:
 8275                print("FIXME - perfusion dataframe")
 8276            for mykey in get_antsimage_keys( mm['perf'] ):
 8277                tempfn = output_prefix + separator + mykey + '.nii.gz'
 8278                image_write_with_thumbnail( mm['perf'][mykey], tempfn, thumb=False )
 8279
 8280    if 'pet3d' in mm:
 8281        if mm['pet3d'] is not None:
 8282            pet3dpro = mm['pet3d']
 8283            prwide = dict_to_dataframe( pet3dpro )
 8284            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
 8285                prwide.set_index( mm_wide.index, inplace=True )
 8286            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
 8287            if 'pet3d_dataframe' in pet3dpro.keys():
 8288                pderk = pet3dpro['pet3d_dataframe'].iloc[: , 1:]
 8289                pderk.set_index( mm_wide.index, inplace=True )
 8290                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
 8291            else:
 8292                print("FIXME - pet3d dataframe")
 8293            for mykey in get_antsimage_keys( mm['pet3d'] ):
 8294                tempfn = output_prefix + separator + mykey + '.nii.gz'
 8295                image_write_with_thumbnail( mm['pet3d'][mykey], tempfn, thumb=False )
 8296
 8297    mmwidefn = output_prefix + separator + 'mmwide.csv'
 8298    mm_wide.to_csv( mmwidefn )
 8299    if verbose:
 8300        print( output_prefix + " write_mm done." )
 8301    return
 8302
 8303
 8304def mm_nrg(
 8305    studyid,   # pandas data frame
 8306    sourcedir = os.path.expanduser( "~/data/PPMI/MV/example_s3_b/images/PPMI/" ),
 8307    sourcedatafoldername = 'images', # root for source data
 8308    processDir = "processed", # where output will go - parallel to sourcedatafoldername
 8309    mysep = '-', # define a separator for filename components
 8310    srmodel_T1 = False, # optional - will add a great deal of time
 8311    srmodel_NM = False, # optional - will add a great deal of time
 8312    srmodel_DTI = False, # optional - will add a great deal of time
 8313    visualize = True,
 8314    nrg_modality_list = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI" ],
 8315    verbose = True
 8316):
 8317    """
 8318    too dangerous to document ... use with care.
 8319
 8320    processes multiple modality MRI specifically:
 8321
 8322    * T1w
 8323    * T2Flair
 8324    * DTI, DTI_LR, DTI_RL
 8325    * rsfMRI, rsfMRI_LR, rsfMRI_RL
 8326    * NM2DMT (neuromelanin)
 8327
 8328    other modalities may be added later ...
 8329
 8330    "trust me, i know what i'm doing" - sledgehammer
 8331
 8332    convert to ipynb via:
 8333        p2j mm.py -o
 8334
 8335    convert the ipynb to html via:
 8336        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
 8337
 8338    this function assumes NRG format for the input data ....
 8339    we also assume that t1w hierarchical (if already done) was written
 8340    via its standardized write function.
 8341    NRG = https://github.com/stnava/biomedicalDataOrganization
 8342
 8343    this function is verbose
 8344
 8345    Parameters
 8346    -------------
 8347
 8348    studyid : must have columns 1. subjectID 2. date (in form 20220228) and 3. imageID
 8349        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
 8350        these provide unique image IDs for these modalities: nm=neuromelanin, dti=diffusion tensor,
 8351        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
 8352        t1 is required.  rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.  see antspymm.generate_mm_dataframe
 8353
 8354    sourcedir : a study specific folder containing individual subject folders
 8355
 8356    sourcedatafoldername : root for source data e.g. "images"
 8357
 8358    processDir : where output will go - parallel to sourcedatafoldername e.g.
 8359        "processed"
 8360
 8361    mysep : define a character separator for filename components
 8362
 8363    srmodel_T1 : False (default) - will add a great deal of time - or h5 filename, 2 chan
 8364
 8365    srmodel_NM : False (default) - will add a great deal of time - or h5 filename, 1 chan
 8366
 8367    srmodel_DTI : False (default) - will add a great deal of time - or h5 filename, 1 chan
 8368
 8369    visualize : True - will plot some results to png
 8370
 8371    nrg_modality_list : list of permissible modalities - always include [T1w] as base
 8372
 8373    verbose : boolean
 8374
 8375    Returns
 8376    ---------
 8377
 8378    writes output to disk and potentially produces figures that may be
 8379    captured in an ipynb / html file.
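
         Example
         ---------

         a minimal sketch; the subject, date and image identifiers and the source
         directory are hypothetical and must point at NRG-organized data:

         >>> import pandas as pd
         >>> import antspymm
         >>> studyid = pd.DataFrame( { 'subjectID': ['101018'],
         ...     'date': ['20220228'], 'imageID': ['1497590'] } )
         >>> antspymm.mm_nrg( studyid, sourcedir='/path/to/images/PPMI/',
         ...     visualize=False, verbose=True )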
 8380
 8381    """
 8382    studyid = studyid.dropna(axis=1)
 8383    if studyid.shape[0] < 1:
 8384        raise ValueError('studyid has no rows')
 8385    musthavecols = ['subjectID','date','imageID']
 8386    for k in range(len(musthavecols)):
 8387        if not musthavecols[k] in studyid.keys():
 8388            raise ValueError('studyid is missing column ' +musthavecols[k] )
 8389    def makewideout( x, separator = '-' ):
 8390        return x + separator + 'mmwide.csv'
 8391    if nrg_modality_list[0] != 'T1w':
 8392        nrg_modality_list.insert(0, "T1w" )
 8393    testloop = False
 8394    counter=0
 8395    import glob as glob
 8396    from os.path import exists
 8397    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 8398    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
 8399    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 8400    if not exists( templatefn ):
 8401        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
 8402        antspyt1w.get_data( force_download=True )
 8403        get_data( force_download=True )
 8404    temp = sourcedir.split( "/" )
 8405    splitCount = len( temp )
 8406    template = mm_read( templatefn ) # Read in template
 8407    test_run = False
 8408    if test_run:
 8409        visualize=False
 8410    # get sid and dtid from studyid
 8411    sid = str(studyid['subjectID'].iloc[0])
 8412    dtid = str(studyid['date'].iloc[0])
 8413    iid = str(studyid['imageID'].iloc[0])
 8414    subjectrootpath = os.path.join(sourcedir,sid, dtid)
 8415    if verbose:
 8416        print("subjectrootpath: "+ subjectrootpath )
 8417    myimgsInput = glob.glob( subjectrootpath+"/*" )
 8418    myimgsInput.sort( )
 8419    if verbose:
 8420        print( myimgsInput )
 8421    # hierarchical
 8422    # NOTE: if there are multiple T1s for this time point, should take
 8423    # the one with the highest resnetGrade
 8424    t1_search_path = os.path.join(subjectrootpath, "T1w", iid, "*nii.gz")
 8425    if verbose:
 8426        print(f"t1 search path: {t1_search_path}")
 8427    t1fn = glob.glob(t1_search_path)
 8428    t1fn.sort()
 8429    if len( t1fn ) < 1:
 8430        raise ValueError('mm_nrg cannot find the T1w with uid ' + iid + ' @ ' + subjectrootpath )
 8431    t1fn = t1fn[0]
 8432    t1 = mm_read( t1fn )
 8433    hierfn0 = re.sub( sourcedatafoldername, processDir, t1fn)
 8434    hierfn0 = re.sub( ".nii.gz", "", hierfn0)
 8435    hierfn = re.sub( "T1w", "T1wHierarchical", hierfn0)
 8436    hierfn = hierfn + mysep
 8437    hierfntest = hierfn + 'snseg.csv'
 8438    regout = hierfn0 + mysep + "syn"
 8439    templateTx = {
 8440        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
 8441        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
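         # the transform lists follow ants.registration output conventions:
         # fwdtransforms = [warp, affine] maps the subject into template space,
         # invtransforms = [affine, inverse warp] maps the template back to subject space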
 8442    if verbose:
 8443        print( "-<REGISTRATION EXISTENCE>-: \n" + 
 8444              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
 8445            str(exists( templateTx['fwdtransforms'][0])) + " " +
 8446            str(exists( templateTx['fwdtransforms'][1])) + " " +
 8447            str(exists( templateTx['invtransforms'][0])) + " " +
 8448            str(exists( templateTx['invtransforms'][1])) )
 8449    if verbose:
 8450        print( hierfntest )
 8451    hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
 8452    hier = None
 8453    if not hierexists and not testloop:
 8454        subjectpropath = os.path.dirname( hierfn )
 8455        if verbose:
 8456            print( subjectpropath )
 8457        os.makedirs( subjectpropath, exist_ok=True  )
 8458        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
 8459        antspyt1w.write_hierarchical( hier, hierfn )
 8460        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8461                hier['dataframes'], identifier=None )
 8462        t1wide.to_csv( hierfn + 'mmwide.csv' )
 8463    ################# read the hierarchical data ###############################
 8464    hier = antspyt1w.read_hierarchical( hierfn )
 8465    if exists( hierfn + 'mmwide.csv' ) :
 8466        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
 8467    elif not testloop:
 8468        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8469                hier['dataframes'], identifier=None )
 8470    if srmodel_T1 is not None and srmodel_T1 is not False :  # only run the SR branch when a model filename is given
 8471        hierfnSR = re.sub( sourcedatafoldername, processDir, t1fn)
 8472        hierfnSR = re.sub( "T1w", "T1wHierarchicalSR", hierfnSR)
 8473        hierfnSR = re.sub( ".nii.gz", "", hierfnSR)
 8474        hierfnSR = hierfnSR + mysep
 8475        hierfntest = hierfnSR + 'mtl.csv'
 8476        if verbose:
 8477            print( hierfntest )
 8478        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
 8479        if not hierexists:
 8480            subjectpropath = os.path.dirname( hierfnSR )
 8481            if verbose:
 8482                print( subjectpropath )
 8483            os.makedirs( subjectpropath, exist_ok=True  )
 8484            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
 8485            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
 8486            mdlfn = re.sub( 'bestup', bestup, srmodel_T1 ) 
 8487            if verbose:
 8488                print( mdlfn )
 8489            if exists( mdlfn ):
 8490                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 8491            else:
 8492                print( mdlfn + " does not exist - will not run.")
 8493            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
 8494            antspyt1w.write_hierarchical( hierSR, hierfnSR )
 8495            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8496                    hierSR['dataframes'], identifier=None )
 8497            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
 8498    hier = antspyt1w.read_hierarchical( hierfn )
 8499    if exists( hierfn + 'mmwide.csv' ) :
 8500        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
 8501    elif not testloop:
 8502        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8503                hier['dataframes'], identifier=None )
 8504    if not testloop:
 8505        t1imgbrn = hier['brain_n4_dnz']
 8506        t1atropos = hier['dkt_parc']['tissue_segmentation']
 8507    # loop over modalities and then unique image IDs
 8508    # we treat NM in a "special" way -- aggregating repeats
 8509    # other modalities (beyond T1) are treated individually
 8510    nimages = len(myimgsInput)
 8511    if verbose:
 8512        print(  " we have : " + str(nimages) + " modalities.")
 8513    for overmodX in nrg_modality_list:
 8514        counter=counter+1
 8515        if counter > (len(nrg_modality_list)+1):
 8516            print("This is weird. " + str(counter))
 8517            return
 8518        if overmodX == 'T1w':
 8519            iidOtherMod = iid
 8520            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8521            myimgsr = glob.glob(mod_search_path)
 8522        elif overmodX == 'NM2DMT' and ('nmid1' in studyid.keys() ):
 8523            iidOtherMod = str( int(studyid['nmid1'].iloc[0]) )
 8524            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8525            myimgsr = glob.glob(mod_search_path)
 8526            for nmnum in range(2,11):
 8527                locnmnum = 'nmid'+str(nmnum)
 8528                if locnmnum in studyid.keys() :
 8529                    iidOtherMod = str( int(studyid[locnmnum].iloc[0]) )
 8530                    mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8531                    myimgsr.append( glob.glob(mod_search_path)[0] )
 8532        elif 'rsfMRI' in overmodX and ( ( 'rsfid1' in studyid.keys() ) or ('rsfid2' in studyid.keys() ) ):
 8533            myimgsr = []
 8534            if  'rsfid1' in studyid.keys():
 8535                iidOtherMod = str( int(studyid['rsfid1'].iloc[0]) )
 8536                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8537                myimgsr.append( glob.glob(mod_search_path)[0] )
 8538            if  'rsfid2' in studyid.keys():
 8539                iidOtherMod = str( int(studyid['rsfid2'].iloc[0]) )
 8540                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8541                myimgsr.append( glob.glob(mod_search_path)[0] )
 8542        elif 'DTI' in overmodX and (  'dtid1' in studyid.keys() or  'dtid2' in studyid.keys() ):
 8543            myimgsr = []
 8544            if  'dtid1' in studyid.keys():
 8545                iidOtherMod = str( int(studyid['dtid1'].iloc[0]) )
 8546                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8547                myimgsr.append( glob.glob(mod_search_path)[0] )
 8548            if  'dtid2' in studyid.keys():
 8549                iidOtherMod = str( int(studyid['dtid2'].iloc[0]) )
 8550                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8551                myimgsr.append( glob.glob(mod_search_path)[0] )
 8552        elif 'T2Flair' in overmodX and ('flairid' in studyid.keys() ):
 8553            iidOtherMod = str( int(studyid['flairid'].iloc[0]) )
 8554            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8555            myimgsr = glob.glob(mod_search_path)
 8556        if verbose:
 8557            print( "overmod " + overmodX + " " + iidOtherMod )
 8558            print(f"modality search path: {mod_search_path}")
 8559        myimgsr.sort()
 8560        if len(myimgsr) > 0:
 8561            overmodXx = str(overmodX)
 8562            dowrite=False
 8563            if verbose:
 8564                print( 'overmodX is : ' + overmodXx )
 8565                print( 'example image name is : '  )
 8566                print( myimgsr )
 8567            if overmodXx == 'NM2DMT':
 8568                myimgsr2 = myimgsr
 8569                myimgsr2.sort()
 8570                is4d = False
 8571                temp = ants.image_read( myimgsr2[0] )
 8572                if temp.dimension == 4:
 8573                    is4d = True
 8574                if len( myimgsr2 ) == 1 and not is4d: # check dimension
 8575                    myimgsr2 = myimgsr2 + myimgsr2
 8576                subjectpropath = os.path.dirname( myimgsr2[0] )
 8577                subjectpropath = re.sub( sourcedatafoldername, processDir,subjectpropath )
 8578                if verbose:
 8579                    print( "subjectpropath " + subjectpropath )
 8580                mysplit = subjectpropath.split( "/" )
 8581                os.makedirs( subjectpropath, exist_ok=True  )
 8582                mysplitCount = len( mysplit )
 8583                project = mysplit[mysplitCount-5]
 8584                subject = mysplit[mysplitCount-4]
 8585                date = mysplit[mysplitCount-3]
 8586                modality = mysplit[mysplitCount-2]
 8587                uider = mysplit[mysplitCount-1]
 8588                identifier = mysep.join([project, subject, date, modality ])
 8589                identifier = identifier + "_" + iid
 8590                mymm = subjectpropath + "/" + identifier
 8591                mymmout = makewideout( mymm )
 8592                if verbose and not exists( mymmout ):
 8593                    print( "NM " + mymm  + ' execution ')
 8594                elif verbose and exists( mymmout ) :
 8595                    print( "NM " + mymm + ' complete ' )
 8596                if exists( mymmout ):
 8597                    continue
 8598                if is4d:
 8599                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
 8600                else:
 8601                    nmlist = []
 8602                    for zz in myimgsr2:
 8603                        nmlist.append( mm_read( zz ) )
 8604                srmodel_NM_mdl = None
 8605                if srmodel_NM is not None and srmodel_NM is not False:
 8606                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
 8607                    if isinstance( srmodel_NM, str ):
 8608                        mdlfn = re.sub( "bestup", bestup, srmodel_NM )
 8609                    if exists( mdlfn ):
 8610                        if verbose:
 8611                            print(mdlfn)
 8612                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
 8613                    else:
 8614                        print( mdlfn + " does not exist - won't use SR")
 8615                if not testloop:
 8616                    tabPro, normPro = mm( t1, hier,
 8617                            nm_image_list = nmlist,
 8618                            srmodel=srmodel_NM_mdl,
 8619                            do_tractography=False,
 8620                            do_kk=False,
 8621                            do_normalization=templateTx,
 8622                            test_run=test_run,
 8623                            verbose=True )
 8624                    if not test_run:
 8625                        write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=None, separator=mysep )
 8626                        nmpro = tabPro['NM']
 8627                        mysl = range( nmpro['NM_avg'].shape[2] )
 8628                    if visualize:
 8629                        mysl = range( nmpro['NM_avg'].shape[2] )
 8630                        ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
 8631                        mysl = range( nmpro['NM_avg_cropped'].shape[2] )
 8632                        ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
 8633                        ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
 8634                        ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
 8635            else :
 8636                if len( myimgsr ) > 0:
 8637                    dowrite=False
 8638                    myimgcount = 0
 8639                    if len( myimgsr ) > 0 :
 8640                        myimg = myimgsr[myimgcount]
 8641                        subjectpropath = os.path.dirname( myimg )
 8642                        subjectpropath = re.sub( sourcedatafoldername, processDir, subjectpropath )
 8643                        mysplit = subjectpropath.split("/")
 8644                        mysplitCount = len( mysplit )
 8645                        project = mysplit[mysplitCount-5]
 8646                        subject = mysplit[mysplitCount-4]
 8647                        date = mysplit[mysplitCount-3]
 8648                        mymod = mysplit[mysplitCount-2] # FIXME system dependent
 8649                        uid = mysplit[mysplitCount-1] # unique image id
 8650                        os.makedirs( subjectpropath, exist_ok=True  )
 8651                        if mymod == 'T1w':
 8652                            identifier = mysep.join([project, subject, date, mymod, uid])
 8653                        else:  # add the T1 unique id since that drives a lot of the analysis
 8654                            identifier = mysep.join([project, subject, date, mymod, uid ])
 8655                            identifier = identifier + "_" + iid
 8656                        mymm = subjectpropath + "/" + identifier
 8657                        mymmout = makewideout( mymm )
 8658                        if verbose and not exists( mymmout ):
 8659                            print("Modality specific processing: " + mymod + " execution " )
 8660                            print( mymm )
 8661                        elif verbose and exists( mymmout ) :
 8662                            print("Modality specific processing: " + mymod + " complete " )
 8663                        if exists( mymmout ) :
 8664                            continue
 8665                        if verbose:
 8666                            print(subjectpropath)
 8667                            print(identifier)
 8668                            print( myimg )
 8669                        if not testloop:
 8670                            img = mm_read( myimg )
 8671                            ishapelen = len( img.shape )
 8672                            if mymod == 'T1w' and ishapelen == 3: # for a real run, set to True
 8673                                if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
 8674                                    if verbose:
 8675                                        print('start t1 registration')
 8676                                    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 8677                                    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 8678                                    template = mm_read( templatefn )
 8679                                    template = ants.resample_image( template, [1,1,1], use_voxels=False )
 8680                                    t1reg = ants.registration( template, hier['brain_n4_dnz'],
 8681                                        "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
 8682                                    myjac = ants.create_jacobian_determinant_image( template,
 8683                                        t1reg['fwdtransforms'][0], do_log=True, geom=True )
 8684                                    image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
 8685                                    if visualize:
 8686                                        ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
 8687                                        ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
 8688                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
 8689                                    dowrite=True
 8690                                    if verbose:
 8691                                        print('start kk')
 8692                                    tabPro, normPro = mm( t1, hier,
 8693                                        srmodel=None,
 8694                                        do_tractography=False,
 8695                                        do_kk=True,
 8696                                        do_normalization=templateTx,
 8697                                        test_run=test_run,
 8698                                        verbose=True )
 8699                                    if visualize:
 8700                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
 8701                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
 8702                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
 8703                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
 8704                            if mymod == 'T2Flair' and ishapelen == 3:
 8705                                dowrite=True
 8706                                tabPro, normPro = mm( t1, hier,
 8707                                    flair_image = img,
 8708                                    srmodel=None,
 8709                                    do_tractography=False,
 8710                                    do_kk=False,
 8711                                    do_normalization=templateTx,
 8712                                    test_run=test_run,
 8713                                    verbose=True )
 8714                                if visualize:
 8715                                    maxslice = np.min( [21, img.shape[2] ] )
 8716                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
 8717                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
 8718                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
 8719                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
 8720                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
 8721                                img2 = None
 8722                                if len( myimgsr ) > 1:
 8723                                    img2 = mm_read( myimgsr[myimgcount+1] )
 8724                                    ishapelen2 = len( img2.shape )
 8725                                    if ishapelen2 != 4 :
 8726                                        img2 = None
 8727                                dowrite=True
 8728                                tabPro, normPro = mm( t1, hier,
 8729                                    rsf_image=[img,img2],
 8730                                    srmodel=None,
 8731                                    do_tractography=False,
 8732                                    do_kk=False,
 8733                                    do_normalization=templateTx,
 8734                                    test_run=test_run,
 8735                                    verbose=True )
 8736                                if tabPro['rsf'] is not None and visualize:
 8737                                    dfn=tabPro['rsf']['dfnname']
 8738                                    maxslice = np.min( [21, tabPro['rsf']['meanBold'].shape[2] ] )
 8739                                    ants.plot( tabPro['rsf']['meanBold'],
 8740                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=mymm+mysep+"meanBOLD.png" )
 8741                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['alff'],"Normalize"),
 8742                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=mymm+mysep+"boldALFF.png" )
 8743                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['falff'],"Normalize"),
 8744                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=mymm+mysep+"boldfALFF.png" )
 8745                                    ants.plot( tabPro['rsf']['meanBold'], tabPro['rsf'][dfn],
 8746                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='DefaultMode', filename=mymm+mysep+"boldDefaultMode.png" )
 8747                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
 8748                                dowrite=True
 8749                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
 8750                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
 8751                                imgList = [ img ]
 8752                                bvalfnList = [ bvalfn ]
 8753                                bvecfnList = [ bvecfn ]
 8754                                if len( myimgsr ) > 1:  # find DTI_RL
 8755                                    dtilrfn = myimgsr[myimgcount+1]
 8756                                    if exists( dtilrfn ):
 8757                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
 8758                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
 8759                                        imgRL = ants.image_read( dtilrfn )
 8760                                        imgList.append( imgRL )
 8761                                        bvalfnList.append( bvalfnRL )
 8762                                        bvecfnList.append( bvecfnRL )
 8763                                srmodel_DTI_mdl=None
 8764                                if srmodel_DTI is not None and srmodel_DTI is not False:
 8765                                    temp = ants.get_spacing(img)
 8766                                    dtspc=[temp[0],temp[1],temp[2]]
 8767                                    bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
 8768                                    if isinstance( srmodel_DTI, str ):
 8769                                        mdlfn = re.sub( "bestup", bestup, srmodel_DTI )
 8770                                    if exists( mdlfn ):
 8771                                        if verbose:
 8772                                            print(mdlfn)
 8773                                        srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 8774                                    else:
 8775                                        print(mdlfn + " does not exist - won't use SR")
 8776                                tabPro, normPro = mm( t1, hier,
 8777                                    dw_image=imgList,
 8778                                    bvals = bvalfnList,
 8779                                    bvecs = bvecfnList,
 8780                                    srmodel=srmodel_DTI_mdl,
 8781                                    do_tractography=not test_run,
 8782                                    do_kk=False,
 8783                                    do_normalization=templateTx,
 8784                                    test_run=test_run,
 8785                                    verbose=True )
 8786                                mydti = tabPro['DTI']
 8787                                if visualize:
 8788                                    maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
 8789                                    ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
 8790                                    ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
 8791                                    ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
 8792                            if dowrite:
 8793                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep, verbose=True )
 8794                                for mykey in normPro.keys():
 8795                                    if normPro[mykey] is not None:
 8796                                        if visualize and normPro[mykey].components == 1 and False:
 8797                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
 8798        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
 8799            return
 8800        if verbose:
 8801            print("done with " + overmodX )
 8802    if verbose:
 8803        print("mm_nrg complete.")
 8804    return
 8805
 8806
 8807
 8808def mm_csv(
 8809    studycsv,   # pandas data frame
 8810    mysep = '-', # or "_" for BIDS
 8811    srmodel_T1 = False, # optional - will add a great deal of time
 8812    srmodel_NM = False, # optional - will add a great deal of time
 8813    srmodel_DTI = False, # optional - will add a great deal of time
 8814    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
 8815    dti_denoise = False,
 8816    nrg_modality_list = None,
 8817    normalization_template = None,
 8818    normalization_template_output = None,
 8819    normalization_template_transform_type = "antsRegistrationSyNRepro[s]",
 8820    normalization_template_spacing=None,
 8821    enantiomorphic=False,
 8822    perfusion_trim = 10,
 8823    perfusion_m0_image = None,
 8824    perfusion_m0 = None,
 8825    rsf_upsampling = 3.0,
 8826    pet3d = None,
 8827    min_t1_spacing_for_sr = 0.8,
 8828):
 8829    """
 8830    too dangerous to document ... use with care.
 8831
 8832    processes multiple modality MRI specifically:
 8833
 8834    * T1w
 8835    * T2Flair
 8836    * DTI, DTI_LR, DTI_RL
 8837    * rsfMRI, rsfMRI_LR, rsfMRI_RL
 8838    * NM2DMT (neuromelanin)
 8839
 8840    other modalities may be added later ...
 8841
 8842    "trust me, i know what i'm doing" - sledgehammer
 8843
 8844    convert to ipynb via:
 8845        p2j mm.py -o
 8846
 8847    convert the ipynb to html via:
 8848        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
 8849
 8850    this function does not assume NRG format for the input data ....
 8851
 8852    Parameters
 8853    -------------
 8854
 8855    studycsv : must have columns:
 8856        - subjectID
 8857        - date or session
 8858        - imageID
 8859        - modality
 8860        - sourcedir
 8861        - outputdir
 8862        - filename (path to the t1 image)
 8863        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
 8864        these provide filenames for these modalities: nm=neuromelanin, dti=diffusion tensor,
 8865        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
 8866        t1 is required. rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.
 8867        see antspymm.generate_mm_dataframe
 8868
 8869    sourcedir : a study specific folder containing individual subject folders
 8870
 8871    outputdir : a study specific folder where individual output subject folders will go
 8872
 8873    filename : the raw image filename (full path)
 8874
 8875    srmodel_T1 : False (default) - .keras or h5 filename for SR model (siq generated).
 8876
 8877    srmodel_NM : False (default) - .keras or h5 filename for SR model (siq generated)
 8878    the model name should follow a style like prefix_bestup_postfix where bestup will be replaced with an optimal upsampling factor eg 2x2x2 based on the data.  see siq.optimize_upsampling_shape.
 8879
 8880    srmodel_DTI : False (default) - .keras or h5 filename for SR model (siq generated).
 8881    the model name should follow a style like prefix_bestup_postfix where bestup will be replaced with an optimal upsampling factor eg 2x2x2 based on the data.  see siq.optimize_upsampling_shape.
 8882
 8883    dti_motion_correct : None, 'Rigid', 'SyN' or another ants registration transform string; default 'antsRegistrationSyNQuickRepro[r]'
 8884
 8885    dti_denoise : boolean
 8886
 8887    nrg_modality_list : optional; defaults to None; use to focus on a given modality
 8888
 8889    normalization_template : optional; defaults to None; if present, all images will
 8890        be deformed into this space and the deformation will be stored with an extension
 8891        related to this variable.  this should be a brain extracted T1w image.
 8892
 8893    normalization_template_output : optional string; defaults to None; naming for the 
 8894        normalization_template outputs which will be in the T1w directory.
 8895
 8896    normalization_template_transform_type : optional string transform type passed to ants.registration
 8897
 8898    normalization_template_spacing : 3-tuple controlling the resolution at which registration is computed 
 8899    
 8900    enantiomorphic: boolean (WIP)
 8901
 8902    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
 8903
 8904    perfusion_m0_image : optional m0 antsImage associated with the perfusion time series
 8905
 8906    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
 8907
 8908    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
 8909
 8910    pet3d : optional antsImage for PET (or other 3d scalar) data which we want to summarize
 8911
 8912    min_t1_spacing_for_sr : float
 8913        if the minimum input T1w spacing is already below this value, super-resolution
 8914        is skipped and the original-resolution image is used.  Default 0.8.
 8915
 8916    Returns
 8917    ---------
 8918
 8919    writes output to disk and produces figures
 8920
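         Example
         ---------

         a minimal sketch; the ids and paths below are hypothetical placeholders and
         antspymm.generate_mm_dataframe is the recommended way to build studycsv:

         import pandas as pd
         import antspymm
         studycsv = pd.DataFrame({
             'projectID': ['MYSTUDY'], 'subjectID': ['sub001'], 'date': ['20230101'],
             'imageID': ['000'], 'modality': ['T1w'],
             'sourcedir': ['/path/to/source'], 'outputdir': ['/path/to/output'],
             'filename': ['/path/to/source/sub001_t1.nii.gz'] })
         antspymm.mm_csv( studycsv )
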
 8921    """
 8922    import traceback
 8923    visualize = True
 8924    verbose = True
 8925    if verbose:
 8926        print( version() )
 8927    if nrg_modality_list is None:
 8928        nrg_modality_list = get_valid_modalities()
 8929    if studycsv.shape[0] < 1:
 8930        raise ValueError('studycsv has no rows')
 8931    musthavecols = ['projectID', 'subjectID','date','imageID','modality','sourcedir','outputdir','filename']
 8932    for k in range(len(musthavecols)):
 8933        if not musthavecols[k] in studycsv.keys():
 8934            raise ValueError('studycsv is missing column ' +musthavecols[k] )
 8935    def makewideout( x, separator = mysep ):
 8936        return x + separator + 'mmwide.csv'
 8937    testloop = False
 8938    counter=0
 8939    import glob as glob
 8940    from os.path import exists
 8941    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 8942    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
 8943    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 8944    if not exists( templatefn ):
 8945        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
 8946        antspyt1w.get_data( force_download=True )
 8947        get_data( force_download=True )
 8948    template = mm_read( templatefn ) # Read in template
 8949    test_run = False
 8950    if test_run:
 8951        visualize=False
 8952    # get sid and dtid from studycsv
 8953    # musthavecols = ['projectID','subjectID','date','imageID','modality','sourcedir','outputdir','filename']
 8954    projid = str(studycsv['projectID'].iloc[0])
 8955    sid = str(studycsv['subjectID'].iloc[0])
 8956    dtid = str(studycsv['date'].iloc[0])
 8957    iid = str(studycsv['imageID'].iloc[0])
 8958    t1iidUse=iid
 8959    modality = str(studycsv['modality'].iloc[0])
 8960    sourcedir = str(studycsv['sourcedir'].iloc[0])
 8961    outputdir = str(studycsv['outputdir'].iloc[0])
 8962    filename = str(studycsv['filename'].iloc[0])
 8963    if not exists(filename):
 8964        raise ValueError('mm_csv cannot find filename ' + filename )
 8965
 8966    # hierarchical
 8967    # NOTE: if there are multiple T1s for this time point, should take
 8968    # the one with the highest resnetGrade
 8969    t1fn = filename
 8970    if not exists( t1fn ):
 8971        raise ValueError('mm_csv cannot find the T1w image ' + t1fn )
 8972    t1 = mm_read( t1fn, modality='T1w' )
 8973    minspc = np.min(ants.get_spacing(t1))
 8974    minshape = np.min(t1.shape)
 8975    if minspc < 1e-16:
 8976        warnings.warn('minimum spacing in T1w is too small - cannot process. ' + str(minspc) )
 8977        return
 8978    if minshape < 32:
 8979        warnings.warn('minimum shape in T1w is too small - cannot process. ' + str(minshape) )
 8980        return
 8981
 8982    if enantiomorphic:
 8983        t1 = enantiomorphic_filling_without_mask( t1, axis=0 )[0]
 8984    hierfn = outputdir + "/"  + projid + "/" + sid + "/" + dtid + "/" + "T1wHierarchical" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchical" + mysep + iid + mysep
 8985    hierfnSR = outputdir + "/" + projid + "/"  + sid + "/" + dtid + "/" + "T1wHierarchicalSR" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchicalSR" + mysep + iid + mysep
 8986    hierfntest = hierfn + 'cerebellum.csv'
 8987    if verbose:
 8988        print( hierfntest )
 8989    regout = re.sub("T1wHierarchical","T1w",hierfn) + "syn"
 8990    templateTx = {
 8991        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
 8992        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
 8993    groupTx = None
 8994    # make the T1w directory
 8995    os.makedirs( os.path.dirname(re.sub("T1wHierarchical","T1w",hierfn)), exist_ok=True  )
 8996    if normalization_template_output is not None:
 8997        normout = re.sub("T1wHierarchical","T1w",hierfn) +  normalization_template_output
 8998        templateNormTx = {
 8999            'fwdtransforms': [ normout+'1Warp.nii.gz', normout+'0GenericAffine.mat'],
 9000            'invtransforms': [ normout+'0GenericAffine.mat', normout+'1InverseWarp.nii.gz']  }
 9001        groupTx = templateNormTx['fwdtransforms']
 9002    if verbose:
 9003        print( "-<REGISTRATION EXISTENCE>-: \n" + 
 9004              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
 9005            str(exists( templateTx['fwdtransforms'][0])) + " " +
 9006            str(exists( templateTx['fwdtransforms'][1])) + " " +
 9007            str(exists( templateTx['invtransforms'][0])) + " " +
 9008            str(exists( templateTx['invtransforms'][1])) )
 9009    if verbose:
 9010        print( hierfntest )
 9011    hierexists = exists( hierfntest ) and exists( templateTx['fwdtransforms'][0]) and exists( templateTx['fwdtransforms'][1]) and exists( templateTx['invtransforms'][0]) and exists( templateTx['invtransforms'][1])
 9012    hier = None
 9013    if srmodel_T1 is not None and srmodel_T1 is not False:
 9014        srmodel_T1_mdl = tf.keras.models.load_model( srmodel_T1, compile=False )
 9015        if verbose:
 9016            print("Convert T1w to SR via model ", srmodel_T1 )
 9017        t1 = t1w_super_resolution_with_hemispheres( t1, srmodel_T1_mdl,
 9018            min_spacing=min_t1_spacing_for_sr )
 9019    if not hierexists and not testloop:
 9020        subjectpropath = os.path.dirname( hierfn )
 9021        if verbose:
 9022            print( subjectpropath )
 9023        os.makedirs( subjectpropath, exist_ok=True  )
 9024        ants.image_write( t1, hierfn + 'head.nii.gz' )
 9025        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
 9026        antspyt1w.write_hierarchical( hier, hierfn )
 9027        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9028                hier['dataframes'], identifier=None )
 9029        t1wide.to_csv( hierfn + 'mmwide.csv' )
 9030    ################# read the hierarchical data ###############################
 9031    # over-write the rbp data with a consistent and recent approach ############
 9032    redograding = True
 9033    if redograding:
 9034        myx = antspyt1w.inspect_raw_t1( 
 9035            ants.image_read(t1fn), hierfn + 'rbp' , option='both' )
 9036        myx['brain'].to_csv( hierfn + 'rbp.csv', index=False )
 9037        myx['brain'].to_csv( hierfn + 'rbpbrain.csv', index=False )
 9038        del myx
 9039
 9040    hier = antspyt1w.read_hierarchical( hierfn )
 9041    t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9042        hier['dataframes'], identifier=None )
 9043    rgrade = str( t1wide['resnetGrade'].iloc[0] )
 9044    if t1wide['resnetGrade'].iloc[0] < 0.20:
 9045        warnings.warn('T1w quality check indicates failure: ' + rgrade + " will not process." )
 9046        return
 9047    else:
 9048        print('T1w quality check indicates success: ' + rgrade + " will process." )
 9049
 9050    if srmodel_T1 is not None and False : # deprecated
 9051        hierfntest = hierfnSR + 'mtl.csv'
 9052        if verbose:
 9053            print( hierfntest )
 9054        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
 9055        if not hierexists:
 9056            subjectpropath = os.path.dirname( hierfnSR )
 9057            if verbose:
 9058                print( subjectpropath )
 9059            os.makedirs( subjectpropath, exist_ok=True  )
 9060            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
 9061            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
 9062            if isinstance( srmodel_T1, str ):
 9063                mdlfn = re.sub( 'bestup', bestup, srmodel_T1 )
 9064            if verbose:
 9065                print( mdlfn )
 9066            if exists( mdlfn ):
 9067                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 9068            else:
 9069                print( mdlfn + " does not exist - will not run.")
 9070            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
 9071            antspyt1w.write_hierarchical( hierSR, hierfnSR )
 9072            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9073                    hierSR['dataframes'], identifier=None )
 9074            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
 9075    hier = antspyt1w.read_hierarchical( hierfn )
 9076    if exists( hierfn + 'mmwide.csv' ) :
 9077        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
 9078    elif not testloop:
 9079        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9080                hier['dataframes'], identifier=None )
 9081    if not testloop:
 9082        t1imgbrn = hier['brain_n4_dnz']
 9083        t1atropos = hier['dkt_parc']['tissue_segmentation']
 9084
 9085    if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
 9086        if verbose:
 9087            print('start t1 registration')
 9088        ex_path = os.path.expanduser( "~/.antspyt1w/" )
 9089        templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 9090        template = mm_read( templatefn )
 9091        template = ants.resample_image( template, [1,1,1], use_voxels=False )
 9092        t1reg = ants.registration( template, 
 9093            hier['brain_n4_dnz'],
 9094            "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
 9095        myjac = ants.create_jacobian_determinant_image( template,
 9096            t1reg['fwdtransforms'][0], do_log=True, geom=True )
 9097        image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
 9098        if visualize:
 9099            ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
 9100            ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
 9101
 9102    if normalization_template_output is not None and normalization_template is not None:
 9103        if verbose:
 9104            print("begin group template registration")
 9105        if not exists( normout+'0GenericAffine.mat' ):
 9106            if normalization_template_spacing is not None:
 9107                normalization_template_rr=ants.resample_image(normalization_template,normalization_template_spacing)
 9108            else:
 9109                normalization_template_rr=normalization_template
 9110            greg = ants.registration( 
 9111                normalization_template_rr, 
 9112                hier['brain_n4_dnz'],
 9113                normalization_template_transform_type,
 9114                outprefix = normout, verbose=False )
 9115            myjac = ants.create_jacobian_determinant_image( normalization_template_rr,
 9116                    greg['fwdtransforms'][0], do_log=True, geom=True )
 9117            image_write_with_thumbnail( myjac, normout + "logjacobian.nii.gz", thumb=False )
 9118            if verbose:
 9119                print("end group template registration")
 9120        else:
 9121            if verbose:
 9122                print("group template registration already done")
 9123
 9124    # loop over modalities and then unique image IDs
 9125    # we treat NM in a "special" way -- aggregating repeats
 9126    # other modalities (beyond T1) are treated individually
 9127    for overmodX in nrg_modality_list:
 9128        # define 1. input images 2. output prefix
 9129        mydoc = docsamson( overmodX, studycsv=studycsv, outputdir=outputdir, projid=projid, sid=sid, dtid=dtid, mysep=mysep,t1iid=t1iidUse )
 9130        myimgsr = mydoc['images']
 9131        mymm = mydoc['outprefix']
 9132        mymod = mydoc['modality']
 9133        if verbose:
 9134            print( mydoc )
 9135        if len(myimgsr) > 0:
 9136            dowrite=False
 9137            if verbose:
 9138                print( 'overmodX is : ' + overmodX )
 9139                print( 'example image name is : '  )
 9140                print( myimgsr )
 9141            if overmodX == 'NM2DMT':
 9142                dowrite = True
 9143                visualize = True
 9144                subjectpropath = os.path.dirname( mydoc['outprefix'] )
 9145                if verbose:
 9146                    print("subjectpropath is")
 9147                    print(subjectpropath)
 9148                os.makedirs( subjectpropath, exist_ok=True  )
 9149                myimgsr2 = myimgsr
 9150                myimgsr2.sort()
 9151                is4d = False
 9152                temp = ants.image_read( myimgsr2[0] )
 9153                if temp.dimension == 4:
 9154                    is4d = True
 9155                if len( myimgsr2 ) == 1 and not is4d: # check dimension
 9156                    myimgsr2 = myimgsr2 + myimgsr2
 9157                mymmout = makewideout( mymm )
 9158                if verbose and not exists( mymmout ):
 9159                    print( "NM " + mymm  + ' execution ')
 9160                elif verbose and exists( mymmout ) :
 9161                    print( "NM " + mymm + ' complete ' )
 9162                if exists( mymmout ):
 9163                    continue
 9164                if is4d:
 9165                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
 9166                else:
 9167                    nmlist = []
 9168                    for zz in myimgsr2:
 9169                        nmlist.append( mm_read( zz ) )
 9170                srmodel_NM_mdl = None
 9171                if srmodel_NM is not None and srmodel_NM is not False:
 9172                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
 9173                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.keras"
 9174                    if isinstance( srmodel_NM, str ):
 9175                        srmodel_NM = re.sub( "bestup", bestup, srmodel_NM )
 9176                        mdlfn = os.path.join( ex_pathmm, srmodel_NM )
 9177                    if exists( mdlfn ):
 9178                        if verbose:
 9179                            print(mdlfn)
 9180                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
 9181                    else:
 9182                        print( mdlfn + " does not exist - won't use SR")
 9183                if not testloop:
 9184                    try:
 9185                        tabPro, normPro = mm( t1, hier,
 9186                            nm_image_list = nmlist,
 9187                            srmodel=srmodel_NM_mdl,
 9188                            do_tractography=False,
 9189                            do_kk=False,
 9190                            do_normalization=templateTx,
 9191                            group_template = normalization_template,
 9192                            group_transform = groupTx,
 9193                            test_run=test_run,
 9194                            verbose=True )
 9195                    except Exception as e:
 9196                        error_info = traceback.format_exc()
 9197                        print(error_info)
 9198                        visualize=False
 9199                        dowrite=False
 9200                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9201                        pass
 9202                    if not test_run:
 9203                        if dowrite:
 9204                            write_mm( output_prefix=mymm, mm=tabPro,
 9205                                mm_norm=normPro, t1wide=None, separator=mysep )
 9206                        if visualize :
 9207                            nmpro = tabPro['NM']
 9208                            mysl = range( nmpro['NM_avg'].shape[2] )
 9209                            ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
 9210                            mysl = range( nmpro['NM_avg_cropped'].shape[2] )
 9211                            ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
 9212                            ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
 9213                            ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
 9214            else :
 9215                if len( myimgsr ) > 0 :
 9216                    dowrite=False
 9217                    myimgcount=0
 9218                    if len( myimgsr ) > 0 :
 9219                        myimg = myimgsr[ myimgcount ]
 9220                        subjectpropath = os.path.dirname( mydoc['outprefix'] )
 9221                        if verbose:
 9222                            print("subjectpropath is")
 9223                            print(subjectpropath)
 9224                        os.makedirs( subjectpropath, exist_ok=True  )
 9225                        mymmout = makewideout( mymm )
 9226                        if verbose and not exists( mymmout ):
 9227                            print( "Modality specific processing: " + mymod + " execution " )
 9228                            print( mymm )
 9229                        elif verbose and exists( mymmout ) :
 9230                            print("Modality specific processing: " + mymod + " complete " )
 9231                        if exists( mymmout ) :
 9232                            continue
 9233                        if verbose:
 9234                            print( subjectpropath )
 9235                            print( myimg )
 9236                        if not testloop:
 9237                            img = mm_read( myimg )
 9238                            ishapelen = len( img.shape )
 9239                            if mymod == 'T1w' and ishapelen == 3:
 9240                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
 9241                                    dowrite=True
 9242                                    if verbose:
 9243                                        print('start kk')
 9244                                    try:
 9245                                        tabPro, normPro = mm( t1, hier,
 9246                                            srmodel=None,
 9247                                            do_tractography=False,
 9248                                            do_kk=True,
 9249                                            do_normalization=templateTx,
 9250                                            group_template = normalization_template,
 9251                                            group_transform = groupTx,
 9252                                            test_run=test_run,
 9253                                            verbose=True )
 9254                                    except Exception as e:
 9255                                        error_info = traceback.format_exc()
 9256                                        print(error_info)
 9257                                        visualize=False
 9258                                        dowrite=False
 9259                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9260                                        pass
 9261                                    if visualize:
 9262                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
 9263                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
 9264                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
 9265                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
 9266                            if mymod == 'T2Flair' and ishapelen == 3 and np.min(img.shape) > 15:
 9267                                dowrite=True
 9268                                try:
 9269                                    tabPro, normPro = mm( t1, hier,
 9270                                        flair_image = img,
 9271                                        srmodel=None,
 9272                                        do_tractography=False,
 9273                                        do_kk=False,
 9274                                        do_normalization=templateTx,
 9275                                        group_template = normalization_template,
 9276                                        group_transform = groupTx,
 9277                                        test_run=test_run,
 9278                                        verbose=True )
 9279                                except Exception as e:
 9280                                        error_info = traceback.format_exc()
 9281                                        print(error_info)
 9282                                        visualize=False
 9283                                        dowrite=False
 9284                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9285                                        pass
 9286                                if visualize:
 9287                                    maxslice = np.min( [21, img.shape[2] ] )
 9288                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
 9289                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
 9290                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
 9291                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
 9292                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
 9293                                img2 = None
 9294                                if len( myimgsr ) > 1:
 9295                                    img2 = mm_read( myimgsr[myimgcount+1] )
 9296                                    ishapelen2 = len( img2.shape )
 9297                                    if ishapelen2 != 4 or 1 in img2.shape:
 9298                                        img2 = None
 9299                                if 1 in img.shape:
 9300                                    warnings.warn( 'rsfMRI image shape suggests it is an incorrectly converted mosaic image - will not process.')
 9301                                    dowrite=False
 9302                                    tabPro={'rsf':None}
 9303                                    normPro={'rsf':None}
 9304                                else:
 9305                                    dowrite=True
 9306                                    try:
 9307                                        tabPro, normPro = mm( t1, hier,
 9308                                            rsf_image=[img,img2],
 9309                                            srmodel=None,
 9310                                            do_tractography=False,
 9311                                            do_kk=False,
 9312                                            do_normalization=templateTx,
 9313                                            group_template = normalization_template,
 9314                                            group_transform = groupTx,
 9315                                            rsf_upsampling = rsf_upsampling,
 9316                                            test_run=test_run,
 9317                                            verbose=True )
 9318                                    except Exception as e:
 9319                                        error_info = traceback.format_exc()
 9320                                        print(error_info)
 9321                                        visualize=False
 9322                                        dowrite=False
 9323                                        tabPro={'rsf':None}
 9324                                        normPro={'rsf':None}
 9325                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9326                                        pass
 9327                                if tabPro['rsf'] is not None and visualize:
 9328                                    for tpro in tabPro['rsf']: # FIXMERSF
 9329                                        maxslice = np.min( [21, tpro['meanBold'].shape[2] ] )
 9330                                        tproprefix = mymm+mysep+str(tpro['paramset'])+mysep
 9331                                        ants.plot( tpro['meanBold'],
 9332                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=tproprefix+"meanBOLD.png" )
 9333                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['alff'],"Normalize"),
 9334                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=tproprefix+"boldALFF.png" )
 9335                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['falff'],"Normalize"),
 9336                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=tproprefix+"boldfALFF.png" )
 9337                                        dfn=tpro['dfnname']
 9338                                        ants.plot( tpro['meanBold'], tpro[dfn],
 9339                                            axis=2, nslices=maxslice, ncol=7, crop=True, title=dfn, filename=tproprefix+"boldDefaultMode.png" )
 9340                            if ( mymod == 'perf' ) and ishapelen == 4:
 9341                                dowrite=True
 9342                                try:
 9343                                    tabPro, normPro = mm( t1, hier,
 9344                                        perfusion_image=img,
 9345                                        srmodel=None,
 9346                                        do_tractography=False,
 9347                                        do_kk=False,
 9348                                        do_normalization=templateTx,
 9349                                        group_template = normalization_template,
 9350                                        group_transform = groupTx,
 9351                                        test_run=test_run,
 9352                                        perfusion_trim=perfusion_trim,
 9353                                        perfusion_m0_image=perfusion_m0_image,
 9354                                        perfusion_m0=perfusion_m0,
 9355                                        verbose=True )
 9356                                except Exception as e:
 9357                                        error_info = traceback.format_exc()
 9358                                        print(error_info)
 9359                                        visualize=False
 9360                                        dowrite=False
 9361                                        tabPro={'perf':None}
 9362                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9363                                        pass
 9364                                if tabPro['perf'] is not None and visualize:
 9365                                    maxslice = np.min( [21, tabPro['perf']['meanBold'].shape[2] ] )
 9366                                    ants.plot( tabPro['perf']['perfusion'],
 9367                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='perfusion image', filename=mymm+mysep+"perfusion.png" )
 9368                                    ants.plot( tabPro['perf']['cbf'],
 9369                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='CBF image', filename=mymm+mysep+"cbf.png" )
 9370                                    ants.plot( tabPro['perf']['m0'],
 9371                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='M0 image', filename=mymm+mysep+"m0.png" )
 9372
 9373                            if ( mymod == 'pet3d' ) and ishapelen == 3:
 9374                                dowrite=True
 9375                                try:
 9376                                    tabPro, normPro = mm( t1, hier,
 9377                                        srmodel=None,
 9378                                        do_tractography=False,
 9379                                        do_kk=False,
 9380                                        do_normalization=templateTx,
 9381                                        group_template = normalization_template,
 9382                                        group_transform = groupTx,
 9383                                        test_run=test_run,
 9384                                        pet_3d_image=img,
 9385                                        verbose=True )
 9386                                except Exception as e:
 9387                                        error_info = traceback.format_exc()
 9388                                        print(error_info)
 9389                                        visualize=False
 9390                                        dowrite=False
 9391                                        tabPro={'pet3d':None}
 9392                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9393                                        pass
 9394                                if tabPro['pet3d'] is not None and visualize:
 9395                                    maxslice = np.min( [21, tabPro['pet3d']['pet3d'].shape[2] ] )
 9396                                    ants.plot( tabPro['pet3d']['pet3d'],
 9397                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='PET image', filename=mymm+mysep+"pet3d.png" )
 9398                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
 9399                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
 9400                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
 9401                                imgList = [ img ]
 9402                                bvalfnList = [ bvalfn ]
 9403                                bvecfnList = [ bvecfn ]
 9404                                missing_dti_data=False # bval, bvec or images
 9405                                if len( myimgsr ) == 2:  # find DTI_RL
 9406                                    dtilrfn = myimgsr[myimgcount+1]
 9407                                    if exists( dtilrfn ):
 9408                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
 9409                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
 9410                                        imgRL = ants.image_read( dtilrfn )
 9411                                        imgList.append( imgRL )
 9412                                        bvalfnList.append( bvalfnRL )
 9413                                        bvecfnList.append( bvecfnRL )
 9414                                elif len( myimgsr ) == 3:  # merge three DTI acquisitions
 9415                                    print("DTI trinity")
 9416                                    dtilrfn = myimgsr[myimgcount+1]
 9417                                    dtilrfn2 = myimgsr[myimgcount+2]
 9418                                    if exists( dtilrfn ) and exists( dtilrfn2 ):
 9419                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
 9420                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
 9421                                        bvalfnRL2 = re.sub( '.nii.gz', '.bval' , dtilrfn2 )
 9422                                        bvecfnRL2 = re.sub( '.nii.gz', '.bvec' , dtilrfn2 )
 9423                                        imgRL = ants.image_read( dtilrfn )
 9424                                        imgRL2 = ants.image_read( dtilrfn2 )
 9425                                        bvals, bvecs = read_bvals_bvecs( bvalfnRL , bvecfnRL  )
 9426                                        print( bvals.max() )
 9427                                        bvals2, bvecs2 = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
 9428                                        print( bvals2.max() )
 9429                                        temp = merge_dwi_data( imgRL, bvals, bvecs, imgRL2, bvals2, bvecs2  )
 9430                                        imgList.append( temp[0] )
 9431                                        bvalfnList.append( mymm+mysep+'joined.bval' )
 9432                                        bvecfnList.append( mymm+mysep+'joined.bvec' )
 9433                                        write_bvals_bvecs( temp[1], temp[2], mymm+mysep+'joined' )
 9434                                        bvalsX, bvecsX = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
 9435                                        print( bvalsX.max() )
 9436                                # check existence of all files expected ...
 9437                                for dtiex in bvalfnList+bvecfnList+myimgsr:
 9438                                    if not exists(dtiex):
 9439                                        print('mm_csv: missing dti data ' + dtiex )
 9440                                        missing_dti_data=True
 9441                                        dowrite=False
 9442                                if not missing_dti_data:
 9443                                    dowrite=True
 9444                                    srmodel_DTI_mdl=None
 9445                                    if srmodel_DTI is not None and srmodel_DTI is not False:
 9446                                        temp = ants.get_spacing(img)
 9447                                        dtspc=[temp[0],temp[1],temp[2]]
 9448                                        bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
 9449                                        mdlfn = re.sub( 'bestup', bestup, srmodel_DTI )
 9450                                        if isinstance( srmodel_DTI, str ):
 9451                                            srmodel_DTI = re.sub( "bestup", bestup, srmodel_DTI )
 9452                                            mdlfn = os.path.join( ex_pathmm, srmodel_DTI )
 9453                                        if exists( mdlfn ):
 9454                                            if verbose:
 9455                                                print(mdlfn)
 9456                                            srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 9457                                        else:
 9458                                            print(mdlfn + " does not exist - won't use SR")
 9459                                    try:
 9460                                        tabPro, normPro = mm( t1, hier,
 9461                                            dw_image=imgList,
 9462                                            bvals = bvalfnList,
 9463                                            bvecs = bvecfnList,
 9464                                            srmodel=srmodel_DTI_mdl,
 9465                                            do_tractography=not test_run,
 9466                                            do_kk=False,
 9467                                            do_normalization=templateTx,
 9468                                            group_template = normalization_template,
 9469                                            group_transform = groupTx,
 9470                                            dti_motion_correct = dti_motion_correct,
 9471                                            dti_denoise = dti_denoise,
 9472                                            test_run=test_run,
 9473                                            verbose=True )
 9474                                    except Exception as e:
 9475                                            error_info = traceback.format_exc()
 9476                                            print(error_info)
 9477                                            visualize=False
 9478                                            dowrite=False
 9479                                            tabPro={'DTI':None}
 9480                                            print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9481                                            pass
 9482                                    mydti = tabPro['DTI']
 9483                                    if visualize and tabPro['DTI'] is not None:
 9484                                        maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
 9485                                        ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
 9486                                        ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
 9487                                        ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
 9488                            if dowrite:
 9489                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep )
 9490                                for mykey in normPro.keys():
 9491                                    if normPro[mykey] is not None and normPro[mykey].components == 1:
 9492                                        if visualize and False:
 9493                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
 9494        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
 9495            return
 9496        if verbose:
 9497            print("done with " + overmodX )
 9498    if verbose:
 9499        print("mm_csv complete.")
 9500    return
 9501
 9502def spec_taper(x, p=0.1):
 9503    """
 9504    Computes a tapered version of x, with tapering p.
 9505
 9506    Adapted from R's stats::spec.taper at https://github.com/telmo-correa/time-series-analysis/blob/master/Python/spectrum.py
 9507
 9508    """
 9509    from scipy import stats, signal, fft
 9510    from statsmodels.regression.linear_model import yule_walker
 9511
 9512    p = np.r_[p]
 9513    assert np.all((p >= 0) & (p < 0.5)), "'p' must be between 0 and 0.5"
 9514
 9515    x = np.r_[x].astype('float64')
 9516    original_shape = x.shape
 9517
 9518    assert len(original_shape) <= 2, "'x' must have at most 2 dimensions"
 9519    while len(x.shape) < 2:
 9520        x = np.expand_dims(x, axis=1)
 9521
 9522    nr, nc = x.shape
 9523    if len(p) == 1:
 9524        p = p * np.ones(nc)
 9525    else:
 9526        assert len(p) == nc, "length of 'p' must be 1 or equal the number of columns of 'x'"
 9527
 9528    for i in range(nc):
 9529        m = int(np.floor(nr * p[i]))
 9530        if m == 0:
 9531            continue
 9532        w = 0.5 * (1 - np.cos(np.pi * np.arange(1, 2 * m, step=2)/(2 * m)))
 9533        x[:, i] = np.r_[w, np.ones(nr - 2 * m), w[::-1]] * x[:, i]
 9534
 9535    x = np.reshape(x, original_shape)
 9536    return x
 9537
 9538def plot_spec(spec_res, coverage=None, ax=None, title=None):
 9539    """Convenience plotting method, also includes confidence cross in the same style as R.
 9540
 9541    Note that the location of the cross is irrelevant; only width and height matter."""
 9542    import matplotlib.pyplot as plt
 9543    f, Pxx = spec_res['freq'], spec_res['spec']
 9544
 9545    if coverage is not None:
 9546        ci = spec_ci(spec_res['df'], coverage=coverage)
 9547        conf_x = (max(spec_res['freq']) - spec_res['bandwidth']) + np.r_[-0.5, 0.5] * spec_res['bandwidth']
 9548        conf_y = max(spec_res['spec']) / ci[1]
 9549
 9550    if ax is None:
 9551        ax = plt.gca()
 9552
 9553    ax.plot(f, Pxx, color='C0')
 9554    ax.set_xlabel('Frequency')
 9555    ax.set_ylabel('Log Spectrum')
 9556    ax.set_yscale('log')
 9557    if coverage is not None:
 9558        ax.plot(np.mean(conf_x) * np.r_[1, 1], conf_y * ci, color='red')
 9559        ax.plot(conf_x, np.mean(conf_y) * np.r_[1, 1], color='red')
 9560
 9561    ax.set_title(spec_res['method'] if title is None else title)
 9562
 9563def spec_ci(df, coverage=0.95):
 9564    """
 9565    Computes the confidence interval for a spectral fit, based on the number of degrees of freedom.
 9566
 9567    Adapted from R's stats::plot.spec at https://github.com/telmo-correa/time-series-analysis/blob/master/Python/spectrum.py
 9568
 9569    """
 9570    from scipy import stats, signal, fft
 9571    from statsmodels.regression.linear_model import yule_walker
 9572
 9573    assert coverage >= 0 and coverage < 1, "coverage probability out of range [0, 1)"
 9574
 9575    tail = 1 - coverage
 9576
 9577    phi = stats.chi2.cdf(x=df, df=df)
 9578    upper_quantile = 1 - tail * (1 - phi)
 9579    lower_quantile = tail * phi
 9580
 9581    return df / stats.chi2.ppf([upper_quantile, lower_quantile], df=df)
 9582
 9583def spec_pgram(x, xfreq=1, spans=None, kernel=None, taper=0.1, pad=0, fast=True, demean=False, detrend=True,
 9584               plot=True, **kwargs):
 9585    """
 9586    Computes the spectral density estimate using a periodogram.  Optionally, it also:
 9587    - Uses a provided kernel window, or a sequence of spans for convoluted modified Daniell kernels.
 9588    - Tapers the start and end of the series to avoid end-of-signal effects.
 9589    - Pads the provided series before computation, adding pad*(length of series) zeros at the end.
 9590    - Pads the provided series before computation to speed up FFT calculation.
 9591    - Performs demeaning or detrending on the series.
 9592    - Plots results.
 9593
 9594    Implemented to ensure compatibility with R's spectral functions, as opposed to reusing scipy's periodogram.
 9595
 9596    Adapted from R's stats::spec.pgram at https://github.com/telmo-correa/time-series-analysis/blob/master/Python/spectrum.py
 9597
 9598    example:
 9599
 9600    import numpy as np
 9601    import antspymm
 9602    myx = np.random.rand(100,1)
 9603    myspec = antspymm.spec_pgram(myx,0.5)
 9604
 9605    """
 9606    from scipy import stats, signal, fft
 9607    from statsmodels.regression.linear_model import yule_walker
 9608    def daniell_window_modified(m):
 9609        """ Single-pass modified Daniell kernel window.
 9610
 9611        Weight is normalized to add up to 1, and all values are the same, other than the first and the
 9612        last, which are divided by 2.
 9613        """
 9614        def w(k):
 9615            return np.where(np.abs(k) < m, 1 / (2*m), np.where(np.abs(k) == m, 1/(4*m), 0))
 9616
 9617        return w(np.arange(-m, m+1))
 9618
 9619    def daniell_window_convolve(v):
 9620        """ Convolved version of multiple modified Daniell kernel windows.
 9621
 9622        Parameter v should be an iterable of m values.
 9623        """
 9624
 9625        if len(v) == 0:
 9626            return np.r_[1]
 9627
 9628        if len(v) == 1:
 9629            return daniell_window_modified(v[0])
 9630
 9631        return signal.convolve(daniell_window_modified(v[0]), daniell_window_convolve(v[1:]))
 9632
 9633    # Ensure we can store non-integers in x, and that it is a numpy object
 9634    x = np.r_[x].astype('float64')
 9635    original_shape = x.shape
 9636
 9637    # Ensure correct dimensions
 9638    assert len(original_shape) <= 2, "'x' must have at most 2 dimensions"
 9639    while len(x.shape) < 2:
 9640        x = np.expand_dims(x, axis=1)
 9641
 9642    N, nser = x.shape
 9643    N0 = N
 9644
 9645    # Ensure only one of spans, kernel is provided, and build the kernel window if needed
 9646    assert (spans is None) or (kernel is None), "must specify only one of 'spans' or 'kernel'"
 9647    if spans is not None:
 9648        kernel = daniell_window_convolve(np.floor_divide(np.r_[spans], 2))
 9649
 9650    # Detrend or demean the series
 9651    if detrend:
 9652        t = np.arange(N) - (N - 1)/2
 9653        sumt2 = N * (N**2 - 1)/12
 9654        x -= (np.repeat(np.expand_dims(np.mean(x, axis=0), 0), N, axis=0) + np.outer(np.sum(x.T * t, axis=1), t/sumt2).T)
 9655    elif demean:
 9656        x -= np.mean(x, axis=0)
 9657
 9658    # Compute taper and taper adjustment variables
 9659    x = spec_taper(x, taper)
 9660    u2 = (1 - (5/8) * taper * 2)
 9661    u4 = (1 - (93/128) * taper * 2)
 9662
 9663    # Pad the series with copies of the same shape, but filled with zeroes
 9664    if pad > 0:
 9665        x = np.r_[x, np.zeros((pad * x.shape[0], x.shape[1]))]
 9666        N = x.shape[0]
 9667
 9668    # Further pad the series to accelerate FFT computation
 9669    if fast:
 9670        newN = fft.next_fast_len(N, True)
 9671        x = np.r_[x, np.zeros((newN - N, x.shape[1]))]
 9672        N = newN
 9673
 9674    # Compute the Fourier frequencies (R's spec.pgram convention style)
 9675    Nspec = int(np.floor(N/2))
 9676    freq = (np.arange(Nspec) + 1) * xfreq / N
 9677
 9678    # Translations to keep same row / column convention as stats::mvfft
 9679    xfft = fft.fft(x.T).T
 9680
 9681    # Compute the periodogram for each i, j
 9682    pgram = np.empty((N, nser, nser), dtype='complex')
 9683    for i in range(nser):
 9684        for j in range(nser):
 9685            pgram[:, i, j] = xfft[:, i] * np.conj(xfft[:, j]) / (N0 * xfreq)
 9686            pgram[0, i, j] = 0.5 * (pgram[1, i, j] + pgram[-1, i, j])
 9687
 9688    if kernel is None:
 9689        # Values pre-adjustment
 9690        df = 2
 9691        bandwidth = np.sqrt(1 / 12)
 9692    else:
 9693        def conv_circular(signal, kernel):
 9694            """
 9695            Performs 1D circular convolution, in the same style as R::kernapply,
 9696            assuming the kernel window is centered at 0.
 9697            """
 9698            pad = len(signal) - len(kernel)
 9699            half_window = int((len(kernel) + 1) / 2)
 9700            indexes = range(-half_window, len(signal) - half_window)
 9701            orig_conv = np.real(fft.ifft(fft.fft(signal) * fft.fft(np.r_[np.zeros(pad), kernel])))
 9702            return orig_conv.take(indexes, mode='wrap')
 9703
 9704        # Convolve pgram with kernel with circular conv
 9705        for i in range(nser):
 9706            for j in range(nser):
 9707                pgram[:, i, j] = conv_circular(pgram[:, i, j], kernel)
 9708
 9709        df = 2 / np.sum(kernel**2)
 9710        m = (len(kernel) - 1)/2
 9711        k = np.arange(-m, m+1)
 9712        bandwidth = np.sqrt(np.sum((1/12 + k**2) * kernel))
 9713
 9714    df = df/(u4/u2**2)*(N0/N)
 9715    bandwidth = bandwidth * xfreq/N
 9716
 9717    # Remove padded results
 9718    pgram = pgram[1:(Nspec+1), :, :]
 9719
 9720    spec = np.empty((Nspec, nser))
 9721    for i in range(nser):
 9722        spec[:, i] = np.real(pgram[:, i, i])
 9723
 9724    if nser == 1:
 9725        coh = None
 9726        phase = None
 9727    else:
 9728        coh = np.empty((Nspec, int(nser * (nser - 1)/2)))
 9729        phase = np.empty((Nspec, int(nser * (nser - 1)/2)))
 9730        for i in range(nser):
 9731            for j in range(i+1, nser):
 9732                index = int(i + j*(j-1)/2)
 9733                coh[:, index] = np.abs(pgram[:, i, j])**2 / (spec[:, i] * spec[:, j])
 9734                phase[:, index] = np.angle(pgram[:, i, j])
 9735
 9736    spec = spec / u2
 9737    spec = spec.squeeze()
 9738
 9739    results = {
 9740        'freq': freq,
 9741        'spec': spec,
 9742        'coh': coh,
 9743        'phase': phase,
 9744        'kernel': kernel,
 9745        'df': df,
 9746        'bandwidth': bandwidth,
 9747        'n.used': N,
 9748        'orig.n': N0,
 9749        'taper': taper,
 9750        'pad': pad,
 9751        'detrend': detrend,
 9752        'demean': demean,
 9753        'method': 'Raw Periodogram' if kernel is None else 'Smoothed Periodogram'
 9754    }
 9755
 9756    if plot:
 9757        plot_spec(results, coverage=0.95, **kwargs)
 9758
 9759    return results
 9760
 9761def alffmap( x, flo=0.01, fhi=0.1, tr=1, detrend = True ):
 9762    """
 9763    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
 9764    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
 9765    are related measures that quantify the amplitude of low frequency
 9766    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
 9767    same function in ANTsR.
 9768
 9769    x input vector for the time series of interest
 9770    flo low frequency, typically 0.01
 9771    fhi high frequency, typically 0.1
 9772    tr the period associated with the vector x (inverse of frequency)
 9773    detrend detrend the input time series
 9774
 9775    returns a dictionary with 'alff' and 'falff' values
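
         example (illustrative sketch):

         import numpy as np
         import antspymm
         ts = np.random.rand( 128 )   # a single voxel time series
         antspymm.alffmap( ts, flo=0.01, fhi=0.1, tr=2.0 )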
 9776    """
 9777    temp = spec_pgram( x, xfreq=1.0/tr, demean=False, detrend=detrend, taper=0, fast=True, plot=False )
 9778    fselect = np.logical_and( temp['freq'] >= flo, temp['freq'] <= fhi )
 9779    denom = (temp['spec']).sum()
 9780    numer = (temp['spec'][fselect]).sum()
 9781    return {  'alff':numer, 'falff': numer/denom }
 9782
 9783
 9784def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
 9785    """
 9786    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
 9787    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
 9788    are related measures that quantify the amplitude of low frequency
 9789    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
 9790
 9791    x - input clean resting state fmri
 9792    mask - mask over which to compute f/alff
 9793    flo - low frequency, typically 0.01
 9794    fhi - high frequency, typically 0.1
 9795    nuisance - optional nuisance matrix
 9796
 9797    return dictionary with ALFF and fALFF images
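
         example (illustrative sketch; the fmri path is a hypothetical placeholder):

         import ants
         import antspymm
         bold = ants.image_read( '/path/to/cleaned_rsfmri.nii.gz' )   # 4D time series
         mask = ants.get_mask( ants.get_average_of_timeseries( bold ) )
         myalff = antspymm.alff_image( bold, mask )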
 9798    """
 9799    xmat = ants.timeseries_to_matrix( x, mask )
 9800    if nuisance is not None:
 9801        xmat = ants.regress_components( xmat, nuisance )
 9802    alffvec = xmat[0,:]*0
 9803    falffvec = xmat[0,:]*0
 9804    mytr = ants.get_spacing( x )[3]
 9805    for n in range( xmat.shape[1] ):
 9806        temp = alffmap( xmat[:,n], flo=flo, fhi=fhi, tr=mytr )
 9807        alffvec[n]=temp['alff']
 9808        falffvec[n]=temp['falff']
 9809    alffi=ants.make_image( mask, alffvec )
 9810    falffi=ants.make_image( mask, falffvec )
 9811    alfftrimmedmean = calculate_trimmed_mean( alffvec, 0.01 )
 9812    falfftrimmedmean = calculate_trimmed_mean( falffvec, 0.01 )
 9813    alffi=alffi / alfftrimmedmean
 9814    falffi=falffi / falfftrimmedmean
 9815    return {  'alff': alffi, 'falff': falffi }
 9816
 9817
 9818def down2iso( x, interpolation='linear', takemin=False ):
 9819    """
 9820    will downsample an anisotropic image to an isotropic resolution
 9821
 9822    x: input image
 9823
 9824    interpolation: linear or nearestneighbor
 9825
 9826    takemin : boolean map to min space; otherwise max
 9827
 9828    return image downsampled to isotropic resolution
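
         example (illustrative sketch; the image path is a hypothetical placeholder):

         import ants
         import antspymm
         img = ants.image_read( '/path/to/anisotropic_image.nii.gz' )
         iso = antspymm.down2iso( img )   # resampled to the coarsest spacing by default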
 9829    """
 9830    spc = ants.get_spacing( x )
 9831    if takemin:
 9832        newspc = np.asarray(spc).min()
 9833    else:
 9834        newspc = np.asarray(spc).max()
 9835    newspc = np.repeat( newspc, x.dimension )
 9836    if interpolation == 'linear':
 9837        xs = ants.resample_image( x, newspc, interp_type=0)
 9838    else:
 9839        xs = ants.resample_image( x, newspc, interp_type=1)
 9840    return xs
 9841
 9842
 9843def read_mm_csv( x, is_t1=False, colprefix=None, separator='-', verbose=False ):
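         """
         read a single antspymm wide csv (*mmwide.csv) and prepend identifier columns
         parsed from its NRG-style filename.

         x : filename of the wide csv

         is_t1 : boolean; set True when reading a T1w-derived wide csv so that the image
             uid fills both the mmimageuid and t1imageuid columns

         colprefix : optional string prepended to the csv column names

         separator : string usually '-' or '_'

         verbose : boolean

         returns a one-row dataframe with id columns (u_hier_id, sid, visitdate, modality,
         mmimageuid, t1imageuid) bound to the csv contents, or None if the csv is empty
         """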
 9844    splitter=os.path.basename(x).split( separator )
 9845    lensplit = len( splitter )-1
 9846    temp = os.path.basename(x)
 9847    temp = os.path.splitext(temp)[0]
 9848    temp = re.sub(separator+'mmwide','',temp)
 9849    idcols = ['u_hier_id','sid','visitdate','modality','mmimageuid','t1imageuid']
 9850    df = pd.DataFrame( columns = idcols, index=range(1) )
 9851    valstoadd = [temp] + splitter[1:(lensplit-1)]
 9852    if is_t1:
 9853        valstoadd = valstoadd + [splitter[(lensplit-1)],splitter[(lensplit-1)]]
 9854    else:
 9855        split2=splitter[(lensplit-1)].split( "_" )
 9856        if len(split2) == 1:
 9857            split2.append( split2[0] )
 9858        if len(valstoadd) == 3:
 9859            valstoadd = valstoadd + [split2[0]] + [math.nan] + [split2[1]]
 9860        else:
 9861            valstoadd = valstoadd + [split2[0],split2[1]]
 9862    if verbose:
 9863        print( valstoadd )
 9864    df.iloc[0] = valstoadd
 9865    if verbose:
 9866        print( "read xdf: " + x )
 9867    xdf = pd.read_csv( x )
 9868    df.reset_index()
 9869    xdf.reset_index(drop=True)
 9870    if "Unnamed: 0" in xdf.columns:
 9871        holder=xdf.pop( "Unnamed: 0" )
 9872    if "Unnamed: 1" in xdf.columns:
 9873        holder=xdf.pop( "Unnamed: 1" )
 9874    if "u_hier_id.1" in xdf.columns:
 9875        holder=xdf.pop( "u_hier_id.1" )
 9876    if "u_hier_id" in xdf.columns:
 9877        holder=xdf.pop( "u_hier_id" )
 9878    if not is_t1:
 9879        if 'resnetGrade' in xdf.columns:
 9880            index_no = xdf.columns.get_loc('resnetGrade')
 9881            xdf = xdf.drop( xdf.columns[range(index_no+1)] , axis=1)
 9882
 9883    if xdf.shape[0] == 2:
 9884        xdfcols = xdf.columns
 9885        xdf = xdf.iloc[1]
 9886        ddnum = xdf.to_numpy()
 9887        ddnum = ddnum.reshape([1,ddnum.shape[0]])
 9888        newcolnames = xdf.index.to_list()
 9889        if len(newcolnames) != ddnum.shape[1]:
 9890            print("Cannot Merge : Shape MisMatch " + str( len(newcolnames) ) + " " + str(ddnum.shape[1]))
 9891        else:
 9892            xdf = pd.DataFrame(ddnum, columns=xdfcols )
 9893    if xdf.shape[1] == 0:
 9894        return None
 9895    if colprefix is not None:
 9896        xdf.columns=colprefix + xdf.columns
 9897    return pd.concat( [df,xdf], axis=1, ignore_index=False )
 9898
 9899def merge_wides_to_study_dataframe( sdf, processing_dir, separator='-', sid_is_int=True, id_is_int=True, date_is_int=True, report_missing=False,
 9900progress=False, verbose=False ):
 9901    """
 9902    extend a study data frame with wide outputs
 9903
 9904    sdf : the input study dataframe from antspymm QC output
 9905
 9906    processing_dir:  the directory location of the processed data 
 9907
 9908    separator : string usually '-' or '_'
 9909
 9910    sid_is_int : boolean set to True to cast unique subject ids to int; can be useful if they are inadvertently stored as float by pandas
 9911
 9912    date_is_int : boolean set to True to cast date to int; can be useful if they are inadvertently stored as float by pandas
 9913
 9914    id_is_int : boolean set to True to cast unique image ids to int; can be useful if they are inadvertently stored as float by pandas
 9915
 9916    report_missing : boolean; when combined with verbose, report modalities whose wide csv files are missing
 9917
 9918    progress : integer; if greater than zero, print percent progress for rows whose index is a multiple of progress
 9919
 9920    verbose : boolean
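
    Example (a sketch; the processing directory below is hypothetical):

    merged = antspymm.merge_wides_to_study_dataframe( sdf, '/data/processed/', progress=100 )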
 9921    """
 9922    from os.path import exists
 9923    musthavecols = ['projectID', 'subjectID','date','imageID']
 9924    for k in range(len(musthavecols)):
 9925        if not musthavecols[k] in sdf.keys():
 9926            raise ValueError('sdf is missing column ' +musthavecols[k] + ' in merge_wides_to_study_dataframe' )
 9927    possible_iids = [ 'imageID', 'imageID', 'imageID', 'flairid', 'dtid1', 'dtid2', 'rsfid1', 'rsfid2', 'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5', 'nmid6', 'nmid7', 'nmid8', 'nmid9', 'nmid10', 'perfid' ]
 9928    modality_ids = [ 'T1wHierarchical', 'T1wHierarchicalSR', 'T1w', 'T2Flair', 'DTI', 'DTI', 'rsfMRI', 'rsfMRI', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'perf']
 9929    alldf=pd.DataFrame()
 9930    for myk in sdf.index:
 9931        if progress > 0 and int(myk) % int(progress) == 0:
 9932            print( str( round( myk/sdf.shape[0]*100.0)) + "%...", end='', flush=True)
 9933        if verbose:
 9934            print( "DOROW " + str(myk) + ' of ' + str( sdf.shape[0] ) )
 9935        csvrow = sdf.loc[sdf.index == myk].dropna(axis=1)
 9936        ct=-1
 9937        for iidkey in possible_iids:
 9938            ct=ct+1
 9939            mod_name = modality_ids[ct]
 9940            if iidkey in csvrow.keys():
 9941                if id_is_int:
 9942                    iid = str( int( csvrow[iidkey].iloc[0] ) )
 9943                else:
 9944                    iid = str( csvrow[iidkey].iloc[0] )
 9945                if verbose:
 9946                    print( "iidkey " + iidkey + " modality " + mod_name + ' iid '+ iid )
 9947                pid=str(csvrow['projectID'].iloc[0] )
 9948                if sid_is_int:
 9949                    sid=str(int(csvrow['subjectID'].iloc[0] ))
 9950                else:
 9951                    sid=str(csvrow['subjectID'].iloc[0] )
 9952                if date_is_int:
 9953                    dt=str(int(csvrow['date'].iloc[0]))
 9954                else:
 9955                    dt=str(csvrow['date'].iloc[0])
 9956                if id_is_int:
 9957                    t1iid=str(int(csvrow['imageID'].iloc[0]))
 9958                else:
 9959                    t1iid=str(csvrow['imageID'].iloc[0])
 9960                if t1iid != iid:
 9961                    iidj=iid+"_"+t1iid
 9962                else:
 9963                    iidj=iid
 9964                rootid = pid +separator+ sid +separator+dt+separator+mod_name+separator+iidj
 9965                myext = rootid +separator+'mmwide.csv'
 9966                nrgwidefn=os.path.join( processing_dir, pid, sid, dt, mod_name, iid, myext )
 9967                moddersub = mod_name
 9968                is_t1=False
 9969                if mod_name == 'T1wHierarchical':
 9970                    is_t1=True
 9971                    moddersub='T1Hier'
 9972                elif mod_name == 'T1wHierarchicalSR':
 9973                    is_t1=True
 9974                    moddersub='T1HSR'
 9975                if exists( nrgwidefn ):
 9976                    if verbose:
 9977                        print( nrgwidefn + " exists")
 9978                    mm=read_mm_csv( nrgwidefn, colprefix=moddersub+'_', is_t1=is_t1, separator=separator, verbose=verbose )
 9979                    if mm is not None:
 9980                        if mod_name == 'T1wHierarchical':
 9981                            a=list( csvrow.keys() )
 9982                            b=list( mm.keys() )
 9983                            abintersect=list(set(b).intersection( set(a) ) )
 9984                            if len( abintersect  ) > 0 :
 9985                                for qq in abintersect:
 9986                                    mm.pop( qq )
 9987                        # mm.index=csvrow.index
 9988                        uidname = mod_name + '_mmwide_filename'
 9989                        mm[ uidname ] = rootid
 9990                        csvrow=pd.concat( [csvrow,mm], axis=1, ignore_index=False )
 9991                else:
 9992                    if verbose and report_missing:
 9993                        print( nrgwidefn + " absent")
 9994        if alldf.shape[0] == 0:
 9995            alldf = csvrow.copy()
 9996            alldf = alldf.loc[:,~alldf.columns.duplicated()]
 9997        else:
 9998            csvrow=csvrow.loc[:,~csvrow.columns.duplicated()]
 9999            alldf = alldf.loc[:,~alldf.columns.duplicated()]
10000            alldf = pd.concat( [alldf, csvrow], axis=0, ignore_index=True )
10001    return alldf
10002
10003def assemble_modality_specific_dataframes( mm_wide_csvs, hierdfin, nrg_modality, separator='-', progress=None, verbose=False ):
    import glob
10004    moddersub = re.sub( "[*]","",nrg_modality)
10005    nmdf=pd.DataFrame()
10006    for k in range( hierdfin.shape[0] ):
10007        if progress is not None:
10008            if k % progress == 0:
10009                progger = str( np.round( k / hierdfin.shape[0] * 100 ) )
10010                print( progger, end ="...", flush=True)
10011        temp = mm_wide_csvs[k]
10012        mypartsf = temp.split("T1wHierarchical")
10013        myparts = mypartsf[0]
10014        t1iid = str(mypartsf[1].split("/")[1])
10015        fnsnm = glob.glob(myparts+"/" + nrg_modality + "/*/*" + t1iid + "*wide.csv")
10016        if len( fnsnm ) > 0 :
10017            for y in fnsnm:
10018                temp=read_mm_csv( y, colprefix=moddersub+'_', is_t1=False, separator=separator, verbose=verbose )
10019                if temp is not None:
10020                    nmdf=pd.concat( [nmdf, temp], axis=0, ignore_index=False )
10021    return nmdf
10022
10023def bind_wide_mm_csvs( mm_wide_csvs, merge=True, separator='-', verbose = 0 ) :
10024    """
10025    will convert a list of t1w hierarchical csv filenames to a merged dataframe
10026
10027    returns a pair of data frames, the left side having all entries and the
10028        right side having row averaged entries i.e. unique values for each visit
10029
10030    set merge to False to return individual dataframes ( for debugging )
10031
10032    return alldata, row_averaged_data
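
    Example (a sketch; the glob pattern below is hypothetical):

    import glob
    fns = glob.glob( '/data/processed/*/*/*/T1wHierarchical/*/*mmwide.csv' )
    alldf, avgdf = antspymm.bind_wide_mm_csvs( fns, verbose=1 )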
10033    """
10034    mm_wide_csvs.sort()
10035    if not mm_wide_csvs:
10036        print("No files found with specified pattern")
10037        return
10038    # 1. row-bind the t1whier data
10039    # 2. same for each other modality
10040    # 3. merge the modalities by the keys
10041    hierdf = pd.DataFrame()
10042    for y in mm_wide_csvs:
10043        temp=read_mm_csv( y, colprefix='T1Hier_', separator=separator, is_t1=True )
10044        if temp is not None:
10045            hierdf=pd.concat( [hierdf, temp], axis=0, ignore_index=False )
10046    if verbose > 0:
10047        mypro=50
10048    else:
10049        mypro=None
10050    if verbose > 0:
10051        print("thickness")
10052    thkdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T1w', progress=mypro, verbose=verbose==2)
10053    if verbose > 0:
10054        print("flair")
10055    flairdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T2Flair', progress=mypro, verbose=verbose==2)
10056    if verbose > 0:
10057        print("NM")
10058    nmdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'NM2DMT', progress=mypro, verbose=verbose==2)
10059    if verbose > 0:
10060        print("rsf")
10061    rsfdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'rsfMRI*', progress=mypro, verbose=verbose==2)
10062    if verbose > 0:
10063        print("dti")
10064    dtidf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'DTI*', progress=mypro, verbose=verbose==2 )
10065    if not merge:
10066        return hierdf, thkdf, flairdf, nmdf, rsfdf, dtidf
10067    hierdfmix = hierdf.copy()
10068    modality_df_suffixes = [
10069        (thkdf, "_thk"),
10070        (flairdf, "_flair"),
10071        (nmdf, "_nm"),
10072        (rsfdf, "_rsf"),
10073        (dtidf, "_dti"),
10074    ]
10075    for pair in modality_df_suffixes:
10076        hierdfmix = merge_mm_dataframe(hierdfmix, pair[0], pair[1])
10077    hierdfmix = hierdfmix.replace(r'^\s*$', np.nan, regex=True)
10078    return hierdfmix, hierdfmix.groupby("u_hier_id", as_index=False).mean(numeric_only=True)
10079
10080def merge_mm_dataframe(hierdf, mmdf, mm_suffix):
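    """
    Left-merge a modality-specific wide dataframe into the hierarchical dataframe
    on ['sid', 'visitdate', 't1imageuid'], appending mm_suffix to overlapping
    column names; if the merge keys are missing, the hierarchical dataframe is
    returned unchanged.
    """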
10081    try:
10082        hierdf = hierdf.merge(mmdf, on=['sid', 'visitdate', 't1imageuid'], suffixes=("",mm_suffix),how='left')
10083        return hierdf
10084    except KeyError:
10085        return hierdf
10086
10087def augment_image( x,  max_rot=10, nzsd=1 ):
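    """
    Apply a random 3D rotation of up to max_rot degrees and additive Gaussian
    noise with standard deviation nzsd; returns the augmented image, the
    forward transform and its inverse.
    """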
10088    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rot*(-1.0), max_rot ), reference=x )
10089    tx = rRotGenerator.transform()
10090    itx = ants.invert_ants_transform(tx)
10091    y = ants.apply_ants_transform_to_image( tx, x, x, interpolation='linear')
10092    y = ants.add_noise_to_image( y,'additivegaussian', [0,nzsd] )
10093    return y, tx, itx
10094
10095def boot_wmh( flair, t1, t1seg, mmfromconvexhull = 0.0, strict=True,
10096        probability_mask=None, prior_probability=None, n_simulations=16,
10097        random_seed = 42,
10098        verbose=False ) :
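    """
    Estimate WMH by averaging wmh() over n_simulations randomly rotated,
    noise-augmented copies of the normalized FLAIR; each probability map is
    mapped back to the original FLAIR space before averaging.

    Returns a dictionary with the normalized FLAIR, the averaged WMH
    probability map (and posterior map when prior_probability is given), the
    averaged wmh_mass and wmh_mass_prior, and the wmh_evr and wmh_SNR of the
    final simulation.

    Example (a sketch; assumes flair, t1 and t1seg are ANTsImages):

    bwmh = antspymm.boot_wmh( flair, t1, t1seg, n_simulations=8 )
    """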
10099    import random
10100    random.seed( random_seed )
10101    if verbose and prior_probability is None:
10102        print("augmented flair")
10103    if verbose and prior_probability is not None:
10104        print("augmented flair with prior")
10105    wmh_sum_aug = 0
10106    wmh_sum_prior_aug = 0
10107    augprob = flair * 0.0
10108    augprob_prior = None
10109    if prior_probability is not None:
10110        augprob_prior = flair * 0.0
10111    for n in range(n_simulations):
10112        augflair, tx, itx = augment_image( ants.iMath(flair,"Normalize"), 5, 0.01 )
10113        locwmh = wmh( augflair, t1, t1seg, mmfromconvexhull = mmfromconvexhull,
10114            strict=strict, probability_mask=None, prior_probability=prior_probability )
10115        if verbose:
10116            print( "flair sim: " + str(n) + " vol: " + str( locwmh['wmh_mass'] )+ " vol-prior: " + str( locwmh['wmh_mass_prior'] )+ " snr: " + str( locwmh['wmh_SNR'] ) )
10117        wmh_sum_aug = wmh_sum_aug + locwmh['wmh_mass']
10118        wmh_sum_prior_aug = wmh_sum_prior_aug + locwmh['wmh_mass_prior']
10119        temp = locwmh['WMH_probability_map']
10120        augprob = augprob + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10121        if prior_probability is not None:
10122            temp = locwmh['WMH_posterior_probability_map']
10123            augprob_prior = augprob_prior + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10124    augprob = augprob * (1.0/float( n_simulations ))
10125    if prior_probability is not None:
10126        augprob_prior = augprob_prior * (1.0/float( n_simulations ))
10127    wmh_sum_aug = wmh_sum_aug / float( n_simulations )
10128    wmh_sum_prior_aug = wmh_sum_prior_aug / float( n_simulations )
10129    return{
10130      'flair' : ants.iMath(flair,"Normalize"),
10131      'WMH_probability_map' : augprob,
10132      'WMH_posterior_probability_map' : augprob_prior,
10133      'wmh_mass': wmh_sum_aug,
10134      'wmh_mass_prior': wmh_sum_prior_aug,
10135      'wmh_evr': locwmh['wmh_evr'],
10136      'wmh_SNR': locwmh['wmh_SNR']  }
10137
10138
10139def threaded_bind_wide_mm_csvs( mm_wide_csvs, n_workers ):
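    """
    Run bind_wide_mm_csvs concurrently by splitting the csv list into
    n_workers chunks and row-binding the per-chunk results.

    returns alldata, row_averaged_data concatenated across chunks

    Example (a sketch; assumes fns is a list of T1wHierarchical mmwide csv filenames):

    alldf, avgdf = antspymm.threaded_bind_wide_mm_csvs( fns, n_workers=8 )
    """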
10140    from concurrent import futures
10143    def chunks(l, n):
10144        """Yield n number of sequential chunks from l."""
10145        d, r = divmod(len(l), n)
10146        for i in range(n):
10147            si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r)
10148            yield l[si:si+(d+1 if i < r else d)]
10149    import numpy as np
10150    newx = list( chunks( mm_wide_csvs, n_workers ) )
10151    import pandas as pd
10152    alldf = pd.DataFrame()
10153    alldfavg = pd.DataFrame()
10154    with futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
10155        to_do = []
10156        for group in range(len(newx)) :
10157            future = executor.submit(bind_wide_mm_csvs, newx[group] )
10158            to_do.append(future)
10159        results = []
10160        for future in futures.as_completed(to_do):
10161            res0, res1 = future.result()
10162            alldf=pd.concat(  [alldf, res0 ], axis=0, ignore_index=False )
10163            alldfavg=pd.concat(  [alldfavg, res1 ], axis=0, ignore_index=False )
10164    return alldf, alldfavg
10165
10166
10167def get_names_from_data_frame(x, demogIn, exclusions=None):
10168    """
10169    data = {'Name':['Tom', 'nick', 'krish', 'jack'], 'Age':[20, 21, 19, 18]}
    df = pd.DataFrame( data )
10170    antspymm.get_names_from_data_frame( ['e'], df )
10171    antspymm.get_names_from_data_frame( ['a','e'], df )
10172    antspymm.get_names_from_data_frame( ['e'], df, exclusions='N' )
10173    """
10174    # Check if x is a string and convert it to a list
10175    if isinstance(x, str):
10176        x = [x]
10177    def get_unique( qq ):
10178        unique = []
10179        for number in qq:
10180            if number in unique:
10181                continue
10182            else:
10183                unique.append(number)
10184        return unique
10185    outnames = list(demogIn.columns[demogIn.columns.str.contains(x[0])])
10186    if len(x) > 1:
10187        for y in x[1:]:
10188            outnames = [i for i in outnames if y in i]
10189    outnames = get_unique( outnames )
10190    if exclusions is not None:
10191        toexclude = [name for name in outnames if exclusions[0] in name ]
10192        if len(exclusions) > 1:
10193            for zz in exclusions[1:]:
10194                toexclude.extend([name for name in outnames if zz in name ])
10195        if len(toexclude) > 0:
10196            outnames = [name for name in outnames if name not in toexclude]
10197    return outnames
10198
10199
10200def average_mm_df( jmm_in, diagnostic_n=25, corr_thresh=0.9, verbose=False ):
10201    """
10202    jmrowavg, jmmcolavg, diagnostics = antspymm.average_mm_df( jmm_in, verbose=True )
10203    """
10204
10205    jmm = jmm_in.copy()
10206    dxcols=['subjectid1','subjectid2','modalityid','joinid','correlation','distance']
10207    joinDiagnostics = pd.DataFrame( columns = dxcols )
10208    nanList=[math.nan]
10209    def rob(x, y=0.99):
10210        x[x > np.nanquantile(x, y)] = np.nan
10211        return x
10212
10213    jmm = jmm.replace(r'^\s*$', np.nan, regex=True)
10214
10215    if verbose:
10216        print("do rsfMRI")
10217    # here - we first have to average within each row
10218    dt0 = get_names_from_data_frame(["rsfMRI"], jmm, exclusions=["Unnamed", "rsfMRI_LR", "rsfMRI_RL"])
10219    dt1 = get_names_from_data_frame(["rsfMRI_RL"], jmm, exclusions=["Unnamed"])
10220    if len( dt0 ) > 0 and len( dt1 ) > 0:
10221        flid = dt0[0]
10222        wrows = []
10223        for i in range(jmm.shape[0]):
10224            if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) :
10225                wrows.append(i)
10226        for k in wrows:
10227            v1 = jmm.iloc[k][dt0[1:]].astype(float)
10228            v2 = jmm.iloc[k][dt1[1:]].astype(float)
10229            vvec = [v1[0], v2[0]]
10230            if any(~np.isnan(vvec)):
10231                mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10232                jmm.loc[k, dt0[0]] = 'rsfMRI'
10233                if len(mynna) == 1:
10234                    if mynna[0] == 0:
10235                        jmm.loc[k, dt0[1:]] = v1
10236                    if mynna[0] == 1:
10237                        jmm.loc[k, dt0[1:]] = v2
10238                elif len(mynna) > 1:
10239                    if len(v2) > diagnostic_n:
10240                        v1dx=v1[0:diagnostic_n]
10241                        v2dx=v2[0:diagnostic_n]
10242                    else :
10243                        v1dx=v1
10244                        v2dx=v2
10245                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10246                    mycorr = np.corrcoef( v1dx.values, v2dx.values )[0,1]
10247                    myerr=np.sqrt(np.mean((v1dx.values - v2dx.values)**2))
10248                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'rsfMRI','colavg',mycorr,myerr]
10249                    if mycorr > corr_thresh:
10250                        jmm.loc[k, dt0[1:]] = v1.values*0.5 + v2.values*0.5
10251                    else:
10252                        jmm.loc[k, dt0[1:]] = nanList * len(v1)
10253                    if verbose:
10254                        print( joinDiagnosticsLoc )
10255                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10256
10257    if verbose:
10258        print("do DTI")
10259    # here - we first have to average within each row
10260    dt0 = get_names_from_data_frame(["DTI"], jmm, exclusions=["Unnamed", "DTI_LR", "DTI_RL"])
10261    dt1 = get_names_from_data_frame(["DTI_LR"], jmm, exclusions=["Unnamed"])
10262    dt2 = get_names_from_data_frame( ["DTI_RL"], jmm, exclusions=["Unnamed"])
10263    flid = dt0[0]
10264    wrows = []
10265    for i in range(jmm.shape[0]):
10266        if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) or not pd.isna(jmm[dt2[1]][i]):
10267            wrows.append(i)
10268    for k in wrows:
10269        v1 = jmm.loc[k, dt0[1:]].astype(float)
10270        v2 = jmm.loc[k, dt1[1:]].astype(float)
10271        v3 = jmm.loc[k, dt2[1:]].astype(float)
10272        checkcol = dt0[5]
10273        if not np.isnan(v1[checkcol]):
10274            if v1[checkcol] < 0.25:
10275                v1[:] = np.nan  # QC value below threshold: invalidate this acquisition
10276        checkcol = dt1[5]
10277        if not np.isnan(v2[checkcol]):
10278            if v2[checkcol] < 0.25:
10279                v2[:] = np.nan  # QC value below threshold: invalidate this acquisition
10280        checkcol = dt2[5]
10281        if not np.isnan(v3[checkcol]):
10282            if v3[checkcol] < 0.25:
10283                v3[:] = np.nan  # QC value below threshold: invalidate this acquisition
10284        vvec = [v1[0], v2[0], v3[0]]
10285        if any(~np.isnan(vvec)):
10286            mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10287            jmm.loc[k, dt0[0]] = 'DTI'
10288            if len(mynna) == 1:
10289                if mynna[0] == 0:
10290                    jmm.loc[k, dt0[1:]] = v1
10291                if mynna[0] == 1:
10292                    jmm.loc[k, dt0[1:]] = v2
10293                if mynna[0] == 2:
10294                    jmm.loc[k, dt0[1:]] = v3
10295            elif len(mynna) > 1:
10296                if mynna[0] == 0:
10297                    jmm.loc[k, dt0[1:]] = v1
10298                else:
10299                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10300                    mycorr = np.corrcoef( v2[0:diagnostic_n].values, v3[0:diagnostic_n].values )[0,1]
10301                    myerr=np.sqrt(np.mean((v2[0:diagnostic_n].values - v3[0:diagnostic_n].values)**2))
10302                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'DTI','colavg',mycorr,myerr]
10303                    if mycorr > corr_thresh:
10304                        jmm.loc[k, dt0[1:]] = v2.values*0.5 + v3.values*0.5
10305                    else: #
10306                        jmm.loc[k, dt0[1:]] = nanList * len( dt0[1:] )
10307                    if verbose:
10308                        print( joinDiagnosticsLoc )
10309                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10310
10311
10312    # first task - sort by u_hier_id
10313    jmm = jmm.sort_values( "u_hier_id" )
10314    # get rid of junk columns
10315    badnames = get_names_from_data_frame( ['Unnamed'], jmm )
10316    jmm=jmm.drop(badnames, axis=1)
10317    jmm=jmm.set_index("u_hier_id",drop=False)
10318    # 2nd - get rid of duplicated u_hier_id
10319    jmmUniq = jmm.drop_duplicates( subset="u_hier_id" ) # fast and easy
10320    # for each modality, count which ids have more than one
10321    mod_names = get_valid_modalities()
10322    for mod_name in mod_names:
10323        fl_names = get_names_from_data_frame([mod_name], jmm,
10324            exclusions=['Unnamed',"DTI_LR","DTI_RL","rsfMRI_RL","rsfMRI_LR"])
10325        if len( fl_names ) > 1:
10326            if verbose:
10327                print(mod_name)
10328                print(fl_names)
10329            fl_id = fl_names[0]
10330            n_names = len(fl_names)
10331            locvec = jmm[fl_names[n_names-1]].astype(float)
10332            boolvec=~pd.isna(locvec)
10333            jmmsub = jmm[boolvec][ ['u_hier_id']+fl_names]
10334            my_tbl = Counter(jmmsub['u_hier_id'])
10335            gtoavg = [name for name in my_tbl.keys() if my_tbl[name] == 1]
10336            gtoavgG1 = [name for name in my_tbl.keys() if my_tbl[name] > 1]
10337            if verbose:
10338                print("Join 1")
10339            jmmsub1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavg)][['u_hier_id']+fl_names]
10340            for u in gtoavg:
10341                jmmUniq.loc[u, fl_names[1:]] = jmmsub1.loc[u, fl_names[1:]]
10342            if verbose and len(gtoavgG1) > 1:
10343                print("Join >1")
10344            jmmsubG1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavgG1)][['u_hier_id']+fl_names]
10345            for u in gtoavgG1:
10346                temp = jmmsubG1.loc[u][ ['u_hier_id']+fl_names ]
10347                dropnames = get_names_from_data_frame( ['MM.ID'], temp )
10348                tempVec = temp.drop(columns=dropnames)
10349                joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10350                id1=temp[fl_id].iloc[0]
10351                id2=temp[fl_id].iloc[1]
10352                v1=tempVec.iloc[0][1:].astype(float).to_numpy()
10353                v2=tempVec.iloc[1][1:].astype(float).to_numpy()
10354                if len(v2) > diagnostic_n:
10355                    v1=v1[0:diagnostic_n]
10356                    v2=v2[0:diagnostic_n]
10357                mycorr = np.corrcoef( v1, v2 )[0,1]
10358                # mycorr=temparr[np.triu_indices_from(temparr, k=1)].mean()
10359                myerr=np.sqrt(np.mean((v1 - v2)**2))
10360                joinDiagnosticsLoc.iloc[0] = [id1,id2,mod_name,'rowavg',mycorr,myerr]
10361                if verbose:
10362                    print( joinDiagnosticsLoc )
10363                temp = jmmsubG1.loc[u][fl_names[1:]].astype(float)
10364                if mycorr > corr_thresh or len( v1 ) < 10:
10365                    jmmUniq.loc[u, fl_names[1:]] = temp.mean(axis=0)
10366                else:
10367                    jmmUniq.loc[u, fl_names[1:]] = nanList * temp.shape[1]
10368                joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], 
10369                                            axis=0, ignore_index=False )
10370
10371    return jmmUniq, jmm, joinDiagnostics
10372
10373
10374
10375def quick_viz_mm_nrg(
10376    sourcedir, # root folder
10377    projectid, # project name
10378    sid , # subject unique id
10379    dtid, # date
10380    extract_brain=True,
10381    slice_factor = 0.55,
10382    post = False,
10383    original_sourcedir = None,
10384    filename = None, # output path
10385    verbose = True
10386):
10387    """
10388    This function creates visualizations of brain images for a specific subject in a project using ANTsPy.
10389
10390    Args:
10391
10392    sourcedir (str): Root folder for original data (if post=False) or processed data (post=True)
10393    
10394    projectid (str): Project name.
10395    
10396    sid (str): Subject unique id.
10397    
10398    dtid (str): Date.
10399    
10400    extract_brain (bool): If True, the function extracts the brain from the T1w image. Default is True.
10401    
10402    slice_factor (float): The slice to be visualized is determined by multiplying the image size by this factor. Default is 0.55.
10403
10404    post ( bool ) : if True, will visualize example post-processing results.
10405    
10406    original_sourcedir (str): Root folder for original data (used if post=True)
10407    
10408    filename (str): Output path with extension (.png)
10409    
10410    verbose (bool): If True, information will be printed while running the function. Default is True.
10411
10412    Returns:
10413    None
10414
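    Example (a sketch; the NRG-organized source directory below is hypothetical):

    antspymm.quick_viz_mm_nrg( '/data/nrg/', 'PPMI', '100001', '20230101', filename='/tmp/100001_viz.png' )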
10415    """
10416    iid='*'
10417    import glob as glob
10418    from os.path import exists
10419    import ants
10420    temp = sourcedir.split( "/" )
10421    subjectrootpath = os.path.join(sourcedir, projectid, sid, dtid)
10422    if verbose:
10423        print( 'subjectrootpath' )
10424        print( subjectrootpath )
10425    t1_search_path = os.path.join(subjectrootpath, "T1w", "*", "*nii.gz")
10426    if verbose:
10427        print(f"t1 search path: {t1_search_path}")
10428    t1fn = glob.glob(t1_search_path)
10429    if len( t1fn ) < 1:
10430        raise ValueError('quick_viz_mm_nrg cannot find the T1w @ ' + subjectrootpath )
10431    vizlist=[]
10432    undlist=[]
10433    nrg_modality_list = [ 'T1w', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10434    if post:
10435        nrg_modality_list = [ 'T1wHierarchical', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10436    for nrgNum in [0,1,2,3,4,5]:
10437        underlay = None
10438        overmodX = nrg_modality_list[nrgNum]
10439        if  'T1w' in overmodX :
10440            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10441            if post:
10442                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*brain_n4_dnz.nii.gz")
10443                mod_search_path_ol = os.path.join(subjectrootpath, overmodX, iid, "*thickness_image.nii.gz" )
10444                mod_search_path_ol = re.sub( "T1wHierarchical","T1w",mod_search_path_ol)
10445                myol = glob.glob(mod_search_path_ol)
10446                if len( myol ) > 0:
10447                    temper = find_most_recent_file( myol )[0]
10448                    underlay = ants.image_read(  temper )
10449                    if verbose:
10450                        print("T1w overlay " + temper )
10451                    underlay = underlay * ants.threshold_image( underlay, 0.2, math.inf )
10452            myimgsr = glob.glob(mod_search_path)
10453            if len( myimgsr ) == 0:
10454                if verbose:
10455                    print("No t1 images: " + sid + dtid )
10456                return None
10457            myimgsr=find_most_recent_file( myimgsr )[0]
10458            vimg=ants.image_read( myimgsr )
10459        elif  'T2Flair' in overmodX :
10460            if verbose:
10461                print("search flair")
10462            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10463            if post and original_sourcedir is not None:
10464                if verbose:
10465                    print("post in flair")
10466                mysubdir = os.path.join(original_sourcedir, projectid, sid, dtid)
10467                mod_search_path_under = os.path.join(mysubdir, overmodX, iid, "*T2Flair*.nii.gz")
10468                if verbose:
10469                    print("post in flair mod_search_path_under " + mod_search_path_under)
10470                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*wmh.nii.gz")
10471                if verbose:
10472                    print("post in flair mod_search_path " + mod_search_path )
10473                myimgul = glob.glob(mod_search_path_under)
10474                if len( myimgul ) > 0:
10475                    myimgul = find_most_recent_file( myimgul )[0]
10476                    if verbose:
10477                        print("Flair  " + myimgul )
10478                    vimg = ants.image_read( myimgul )
10479                    myol = glob.glob(mod_search_path)
10480                    if len( myol ) == 0:
10481                        underlay = vimg * 0.0
10482                    else:
10483                        myol = find_most_recent_file( myol )[0]
10484                        if verbose:
10485                            print("Flair overlay " + myol )
10486                        underlay=ants.image_read( myol )
10487                        underlay=underlay*ants.threshold_image(underlay,0.05,math.inf)
10488                else:
10489                    vimg = noizimg.clone()
10490                    underlay = vimg * 0.0
10491            if original_sourcedir is None:
10492                myimgsr = glob.glob(mod_search_path)
10493                if len( myimgsr ) == 0:
10494                    vimg = noizimg.clone()
10495                else:
10496                    myimgsr=find_most_recent_file( myimgsr )[0]
10497                    vimg=ants.image_read( myimgsr )
10498        elif overmodX == 'DTI':
10499            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10500            if post:
10501                mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*fa.nii.gz")
10502            myimgsr = glob.glob(mod_search_path)
10503            if len( myimgsr ) > 0:
10504                myimgsr=find_most_recent_file( myimgsr )[0]
10505                vimg=ants.image_read( myimgsr )
10506            else:
10507                if verbose:
10508                    print("No " + overmodX)
10509                vimg = noizimg.clone()
10510        elif overmodX == 'DTI2':
10511            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10512            myimgsr = glob.glob(mod_search_path)
10513            if len( myimgsr ) > 0:
10514                myimgsr.sort()
10515                myimgsr=myimgsr[len(myimgsr)-1]
10516                vimg=ants.image_read( myimgsr )
10517            else:
10518                if verbose:
10519                    print("No " + overmodX)
10520                vimg = noizimg.clone()
10521        elif overmodX == 'NM2DMT':
10522            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10523            if post:
10524                mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*NM_avg.nii.gz" )
10525            myimgsr = glob.glob(mod_search_path)
10526            if len( myimgsr ) > 0:
10527                myimgsr0=myimgsr[0]
10528                vimg=ants.image_read( myimgsr0 )
10529                for k in range(1,len(myimgsr)):
10530                    temp = ants.image_read( myimgsr[k])
10531                    vimg=vimg+ants.resample_image_to_target(temp,vimg)
10532            else:
10533                if verbose:
10534                    print("No " + overmodX)
10535                vimg = noizimg.clone()
10536        elif overmodX == 'rsfMRI':
10537            mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*nii.gz")
10538            if post:
10539                mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_meanBold.nii.gz" )
10540                mod_search_path_ol = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_DefaultMode.nii.gz" )
10541                myol = glob.glob(mod_search_path_ol)
10542                if len( myol ) > 0:
10543                    myol = find_most_recent_file( myol )[0]
10544                    underlay = ants.image_read( myol )
10545                    if verbose:
10546                        print("BOLD overlay " + myol )
10547                    underlay = underlay * ants.threshold_image( underlay, 0.1, math.inf )
10548            myimgsr = glob.glob(mod_search_path)
10549            if len( myimgsr ) > 0:
10550                myimgsr=find_most_recent_file( myimgsr )[0]
10551                vimg=mm_read_to_3d( myimgsr )
10552            else:
10553                if verbose:
10554                    print("No " + overmodX)
10555                vimg = noizimg.clone()
10556        elif overmodX == 'perf':
10557            mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*nii.gz")
10558            if post:
10559                mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*cbf.nii.gz")
10560            myimgsr = glob.glob(mod_search_path)
10561            if len( myimgsr ) > 0:
10562                myimgsr=find_most_recent_file( myimgsr )[0]
10563                vimg=mm_read_to_3d( myimgsr )
10564            else:
10565                if verbose:
10566                    print("No " + overmodX)
10567                vimg = noizimg.clone()
10568        else :
10569            if verbose:
10570                print("Something else here")
10571            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10572            myimgsr = glob.glob(mod_search_path)
10573            if post:
10574                myimgsr=[]
10575            if len( myimgsr ) > 0:
10576                myimgsr=find_most_recent_file( myimgsr )[0]
10577                vimg=ants.image_read( myimgsr )
10578            else:
10579                if verbose:
10580                    print("No " + overmodX)
10581                vimg = noizimg
10582        if True:
10583            if extract_brain and overmodX == 'T1w' and post == False:
10584                vimg = vimg * antspyt1w.brain_extraction(vimg)
10585            if verbose:
10586                print(f"modality search path: {myimgsr}" + " num: " + str(nrgNum))
10587            if vimg.dimension == 4 and ( overmodX == "DTI2"  ):
10588                ttb0, ttdw=get_average_dwi_b0(vimg)
10589                vimg = ttdw
10590            elif vimg.dimension == 4 and overmodX == "DTI":
10591                ttb0, ttdw=get_average_dwi_b0(vimg)
10592                vimg = ttb0
10593            elif vimg.dimension == 4 :
10594                vimg=ants.get_average_of_timeseries(vimg)
10595            msk=ants.get_mask(vimg)
10596            if overmodX == 'T2Flair':
10597                msk=vimg*0+1
10598            if underlay is not None:
10599                print( overmodX + " has underlay" )
10600            else:
10601                underlay = vimg * 0.0
10602            if nrgNum == 0:
10603                refimg=ants.image_clone( vimg )
10604                noizimg = ants.add_noise_to_image( refimg*0, 'additivegaussian', [100,1] )
10605                vizlist.append( vimg )
10606                undlist.append( underlay )
10607            else:
10608                vimg = ants.iMath( vimg, 'TruncateIntensity',0.01,0.98)
10609                vizlist.append( ants.iMath( vimg, 'Normalize' ) * 255 )
10610                undlist.append( underlay )
10611
10612    # mask & crop systematically ...
10613    msk = ants.get_mask( refimg )
10614    refimg = ants.crop_image( refimg, msk )
10615
10616    for jj in range(len(vizlist)):
10617        vizlist[jj]=ants.resample_image_to_target( vizlist[jj], refimg )
10618        undlist[jj]=ants.resample_image_to_target( undlist[jj], refimg )
10619        print( 'viz: ' + str( jj ) )
10620        print( vizlist[jj] )
10621        print( 'und: ' + str( jj ) )
10622        print( undlist[jj] )
10623
10624
10625    xyz = [None]*3
10626    for i in range(3):
10627        if xyz[i] is None:
10628            xyz[i] = int(refimg.shape[i] * slice_factor )
10629
10630    if verbose:
10631        print('slice positions')
10632        print( xyz )
10633
10634    ants.plot_ortho_stack( vizlist, overlays=undlist, crop=False, reorient=False, filename=filename, xyz=xyz, orient_labels=False )
10635    if verbose:
10636        print("viz complete.")
10637    return
10653
10654
10655def blind_image_assessment(
10656    image,
10657    viz_filename=None,
10658    title=False,
10659    pull_rank=False,
10660    resample=None,
10661    n_to_skip = 10,
10662    verbose=False
10663):
10664    """
10665    quick blind image assessment and triplanar visualization of an image ... 4D input will be visualized and assessed in 3D.  produces a png and csv where csv contains:
10666
10667    * reflection error ( estimates asymmetry )
10668
10669    * brisq ( blind quality assessment )
10670
10671    * patch eigenvalue ratio ( blind quality assessment )
10672
10673    * PSNR and SSIM vs a smoothed reference (4D or 3D appropriate)
10674
10675    * mask volume ( estimates foreground object size )
10676
10677    * spacing
10678
10679    * dimension after cropping by mask
10680
10681    image : character or image object usually a nifti image
10682
10683    viz_filename : character for a png output image
10684
10685    title : display a summary title on the png
10686
10687    pull_rank : boolean
10688
10689    resample : None, numeric max or min, resamples image to isotropy
10690
10691    n_to_skip : 10 by default; samples time series every n_to_skip volume
10692
10693    verbose : boolean
10694
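    Example (a sketch; the filename below is hypothetical):

    qcdf = antspymm.blind_image_assessment( '/tmp/sub-01_T1w.nii.gz', viz_filename='/tmp/sub-01_T1w_qc.png' )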
10695    """
10696    import glob as glob
10697    from os.path import exists
10698    import ants
10699    import matplotlib.pyplot as plt
10700    from PIL import Image
10701    from pathlib import Path
10702    import json
10703    import re
10704    from dipy.io.gradients import read_bvals_bvecs
10705    mystem=''
10706    if isinstance(image,list):
10707        isfilename=isinstance( image[0], str)
10708        image = image[0]
10709    else:
10710        isfilename=isinstance( image, str)
10711    outdf = pd.DataFrame()
10712    mymeta = None
10713    MagneticFieldStrength = None
10714    image_filename=''
10715    if isfilename:
10716        image_filename = image
10717        if isinstance(image,list):
10718            image_filename=image[0]
10719        json_name = re.sub(".nii.gz",".json",image_filename)
10720        if exists( json_name ):
10721            try:
10722                with open(json_name, 'r') as fcc_file:
10723                    mymeta = json.load(fcc_file)
10724                    if verbose:
10725                        print(json.dumps(mymeta, indent=4))
10726                    fcc_file.close()
10727            except:
10728                pass
10729        mystem=Path( image ).stem
10730        mystem=Path( mystem ).stem
10731        image_reference = ants.image_read( image )
10732        image = ants.image_read( image )
10733    else:
10734        image_reference = ants.image_clone( image )
10735    ntimepoints = 1
10736    bvalueMax=None
10737    bvecnorm=None
10738    if image_reference.dimension == 4:
10739        ntimepoints = image_reference.shape[3]
10740        if "DTI" in image_filename:
10741            myTSseg = segment_timeseries_by_meanvalue( image_reference )
10742            image_b0, image_dwi = get_average_dwi_b0( image_reference, fast=True )
10743            image_b0 = ants.iMath( image_b0, 'Normalize' )
10744            image_dwi = ants.iMath( image_dwi, 'Normalize' )
10745            bval_name = re.sub(".nii.gz",".bval",image_filename)
10746            bvec_name = re.sub(".nii.gz",".bvec",image_filename)
10747            if exists( bval_name ) and exists( bvec_name ):
10748                bvals, bvecs = read_bvals_bvecs( bval_name , bvec_name  )
10749                bvalueMax = bvals.max()
10750                bvecnorm = np.linalg.norm(bvecs,axis=1).reshape( bvecs.shape[0],1 )
10751                bvecnorm = bvecnorm.max()
10752        else:
10753            image_b0 = ants.get_average_of_timeseries( image_reference ).iMath("Normalize")
10754    else:
10755        image_compare = ants.smooth_image( image_reference, 3, sigma_in_physical_coordinates=False )
10756    for jjj in range(0,ntimepoints,n_to_skip):
10757        modality='unknown'
10758        if "rsfMRI" in image_filename:
10759            modality='rsfMRI'
10760        elif "perf" in image_filename:
10761            modality='perf'
10762        elif "DTI" in image_filename:
10763            modality='DTI'
10764        elif "T1w" in image_filename:
10765            modality='T1w'
10766        elif "T2Flair" in image_filename:
10767            modality='T2Flair'
10768        elif "NM2DMT" in image_filename:
10769            modality='NM2DMT'
10770        if image_reference.dimension == 4:
10771            image = ants.slice_image( image_reference, idx=int(jjj), axis=3 )
10772            if "DTI" in image_filename:
10773                if jjj in myTSseg['highermeans']:
10774                    image_compare = ants.image_clone( image_b0 )
10775                    modality='DTIb0'
10776                else:
10777                    image_compare = ants.image_clone( image_dwi )
10778                    modality='DTIdwi'
10779            else:
10780                image_compare = ants.image_clone( image_b0 )
10781        # image = ants.iMath( image, 'TruncateIntensity',0.01,0.995)
10782        minspc = np.min(ants.get_spacing(image))
10783        maxspc = np.max(ants.get_spacing(image))
10784        if resample is not None:
10785            if resample == 'min':
10786                if minspc < 1e-12:
10787                    minspc = np.max(ants.get_spacing(image))
10788                newspc = np.repeat( minspc, 3 )
10789            elif resample == 'max':
10790                newspc = np.repeat( maxspc, 3 )
10791            else:
10792                newspc = np.repeat( resample, 3 )
10793            image = ants.resample_image( image, newspc )
10794            image_compare = ants.resample_image( image_compare, newspc )
10795        else:
10796            # check for spc close to zero
10797            spc = list(ants.get_spacing(image))
10798            for spck in range(len(spc)):
10799                if spc[spck] < 1e-12:
10800                    spc[spck]=1
10801            ants.set_spacing( image, spc )
10802            ants.set_spacing( image_compare, spc )
10803        # if "NM2DMT" in image_filename or "FIXME" in image_filename or "SPECT" in image_filename or "UNKNOWN" in image_filename:
10804        minspc = np.min(ants.get_spacing(image))
10805        maxspc = np.max(ants.get_spacing(image))
10806        msk = ants.threshold_image( ants.iMath(image,'Normalize'), 0.15, 1.0 )
10807        # else:
10808        #    msk = ants.get_mask( image )
10809        msk = ants.morphology(msk, "close", 3 )
10810        bgmsk = msk*0+1-msk
10811        mskdil = ants.iMath(msk, "MD", 4 )
10812        # ants.plot_ortho( image, msk, crop=False )
10813        nvox = int( msk.sum() )
10814        spc = ants.get_spacing( image )
10815        org = ants.get_origin( image )
10816        if ( nvox > 0 ):
10817            image = ants.crop_image( image, mskdil ).iMath("Normalize")
10818            msk = ants.crop_image( msk, mskdil ).iMath("Normalize")
10819            bgmsk = ants.crop_image( bgmsk, mskdil ).iMath("Normalize")
10820            image_compare = ants.crop_image( image_compare, mskdil ).iMath("Normalize")           
10821            npatch = int( np.round(  0.1 * nvox ) )
10822            npatch = np.min(  [512,npatch ] )
10823            patch_shape = []
10824            for k in range( 3 ):
10825                p = int( 32.0 / ants.get_spacing( image  )[k] )
10826                if p > int( np.round( image.shape[k] * 0.5 ) ):
10827                    p = int( np.round( image.shape[k] * 0.5 ) )
10828                patch_shape.append( p )
10829            if verbose:
10830                print(image)
10831                print( patch_shape )
10832                print( npatch )
10833            myevr = math.nan # dont want to fail if something odd happens in patch extraction
10834            try:
10835                myevr = antspyt1w.patch_eigenvalue_ratio( image, npatch, patch_shape,
10836                    evdepth = 0.9, mask=msk )
10837            except:
10838                pass
10839            if pull_rank:
10840                image = ants.rank_intensity(image)
10841            imagereflect = ants.reflect_image(image, axis=0)
10842            asym_err = ( image - imagereflect ).abs().mean()
10843            # estimate noise by center cropping, denoizing and taking magnitude of difference
10844            nocrop=False
10845            if image.dimension == 3:
10846                if image.shape[2] == 1:
10847                    nocrop=True        
10848            if maxspc/minspc > 10:
10849                nocrop=True
10850            if nocrop:
10851                mycc = ants.image_clone( image )
10852            else:
10853                mycc = antspyt1w.special_crop( image,
10854                    ants.get_center_of_mass( msk *0 + 1 ), patch_shape )
10855            myccd = ants.denoise_image( mycc, p=1,r=1,noise_model='Gaussian' )
10856            noizlevel = ( mycc - myccd ).abs().mean()
10857    #        ants.plot_ortho( image, crop=False, filename=viz_filename, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
10858    #        from brisque import BRISQUE
10859    #        obj = BRISQUE(url=False)
10860    #        mybrisq = obj.score( np.array( Image.open( viz_filename )) )
10861            msk_vol = msk.sum() * np.prod( spc )
10862            bgstd = image[ bgmsk == 1 ].std()
10863            fgmean = image[ msk == 1 ].mean()
10864            bgmean = image[ bgmsk == 1 ].mean()
10865            snrref = fgmean / bgstd
10866            cnrref = ( fgmean - bgmean ) / bgstd
10867            psnrref = antspynet.psnr(  image_compare, image  )
10868            ssimref = antspynet.ssim(  image_compare, image  )
10869            if nocrop:
10870                mymi = math.inf
10871            else:
10872                mymi = ants.image_mutual_information( image_compare, image )
10873        else:
10874            msk_vol = 0
10875            myevr = mymi = ssimref = psnrref = cnrref = snrref = asym_err = noizlevel = math.nan
10876            
10877        mriseries=None
10878        mrimfg=None
10879        mrimodel=None
10880        mriSAR=None
10881        BandwidthPerPixelPhaseEncode=None
10882        PixelBandwidth=None
10883        if mymeta is not None:
10884            # mriseries=mymeta['']
10885            try:
10886                mrimfg=mymeta['Manufacturer']
10887            except:
10888                pass
10889            try:
10890                mrimodel=mymeta['ManufacturersModelName']
10891            except:
10892                pass
10893            try:
10894                MagneticFieldStrength=mymeta['MagneticFieldStrength']
10895            except:
10896                pass
10897            try:
10898                PixelBandwidth=mymeta['PixelBandwidth']
10899            except:
10900                pass
10901            try:
10902                BandwidthPerPixelPhaseEncode=mymeta['BandwidthPerPixelPhaseEncode']
10903            except:
10904                pass
10905            try:
10906                mriSAR=mymeta['SAR']
10907            except:
10908                pass
10909        ttl=mystem + ' '
10910        ttl=''
10911        ttl=ttl + "NZ: " + "{:0.4f}".format(noizlevel) + " SNR: " + "{:0.4f}".format(snrref) + " CNR: " + "{:0.4f}".format(cnrref) + " PS: " + "{:0.4f}".format(psnrref)+ " SS: " + "{:0.4f}".format(ssimref) + " EVR: " + "{:0.4f}".format(myevr)+ " MI: " + "{:0.4f}".format(mymi)
10912        if viz_filename is not None and ( jjj == 0 or (jjj % 30 == 0) ) and image.shape[2] < 685:
10913            viz_filename_use = re.sub( ".png", "_slice"+str(jjj).zfill(4)+".png", viz_filename )
10914            ants.plot_ortho( image, crop=False, filename=viz_filename_use, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0,  title=ttl, titlefontsize=12, title_dy=-0.02,textfontcolor='red' )
10915        df = pd.DataFrame([[ 
10916            mystem, 
10917            image_reference.dimension, 
10918            noizlevel, snrref, cnrref, psnrref, ssimref, mymi, asym_err, myevr, msk_vol, 
10919            spc[0], spc[1], spc[2],org[0], org[1], org[2], 
10920            image.shape[0], image.shape[1], image.shape[2], ntimepoints, 
10921            jjj, modality, mriseries, mrimfg, mrimodel, MagneticFieldStrength, mriSAR, PixelBandwidth, BandwidthPerPixelPhaseEncode, bvalueMax, bvecnorm ]], 
10922            columns=[
10923                'filename', 
10924                'dimensionality',
10925                'noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol', 'spc0','spc1','spc2','org0','org1','org2','dimx','dimy','dimz','dimt','slice','modality', 'mriseries', 'mrimfg', 'mrimodel', 'mriMagneticFieldStrength', 'mriSAR', 'mriPixelBandwidth', 'mriPixelBandwidthPE', 'dti_bvalueMax', 'dti_bvecnorm' ])
10926        outdf = pd.concat( [outdf, df ], axis=0, ignore_index=False )
10927        if verbose:
10928            print( outdf )
10929    if viz_filename is not None:
10930        csvfn = re.sub( "png", "csv", viz_filename )
10931        outdf.to_csv( csvfn )
10932    return outdf
10933
10934def remove_unwanted_columns(df):
10935    # Identify columns to drop: those named 'X' or starting with 'Unnamed'
10936    cols_to_drop = [col for col in df.columns if col == 'X' or col.startswith('Unnamed')]
10937    
10938    # Drop the identified columns from the DataFrame, if any
10939    df_cleaned = df.drop(columns=cols_to_drop, errors='ignore')
10940    
10941    return df_cleaned
10942
10943def process_dataframe_generalized(df, group_by_column):
10944    # Make sure the group_by_column is excluded from both numeric and other columns calculations
10945    numeric_cols = df.select_dtypes(include='number').columns.difference([group_by_column])
10946    other_cols = df.columns.difference(numeric_cols).difference([group_by_column])
10947    
10948    # Define aggregation functions: mean for numeric cols, mode for other cols
10949    # Update to handle empty mode results safely
10950    agg_dict = {col: 'mean' for col in numeric_cols}
10951    agg_dict.update({
10952        col: lambda x: pd.Series.mode(x).iloc[0] if not pd.Series.mode(x).empty else None for col in other_cols
10953    })    
10954    # Group by the specified column, applying different aggregation functions to different columns
10955    processed_df = df.groupby(group_by_column, as_index=False).agg(agg_dict)
10956    return processed_df
10957
10958def average_blind_qc_by_modality(qc_full,verbose=False):
10959    """
10960    Averages time-series qc results to yield one entry per image. This also filters to "known" columns.
10961
10962    Args:
10963    qc_full: pandas dataframe containing the full qc data.
10964
10965    Returns:
10966    pandas dataframe containing the processed qc data.
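
    Example (a sketch; assumes qc_full was produced by blind_image_assessment or collect_blind_qc_by_modality):

    qc_avg = antspymm.average_blind_qc_by_modality( qc_full, verbose=True )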
10967    """
10968    qc_full = remove_unwanted_columns( qc_full )
10969    # Get unique modalities
10970    modalities = qc_full['modality'].unique()
10971    modalities = modalities[modalities != 'unknown']
10972    # Get unique ids
10973    uid = qc_full['filename']
10974    to_average = uid.unique()
10975    meta = pd.DataFrame(columns=qc_full.columns )
10976    # Process each unique id
10977    n = len(to_average)
10978    for k in range(n):
10979        if verbose:
10980            if k % 100 == 0:
10981                progger = str( np.round( k / n * 100 ) )
10982                print( progger, end ="...", flush=True)
10983        m1sel = uid == to_average[k]
10984        if sum(m1sel) > 1:
10985            # If more than one entry for id, take the average of continuous columns,
10986            # maximum of the slice column, and the first entry of the other columns
10987            mfsub = process_dataframe_generalized(qc_full[m1sel],'filename')
10988        else:
10989            mfsub = qc_full[m1sel]
10990        meta.loc[k] = mfsub.iloc[0]
10991    meta['modality'] = meta['modality'].replace(['DTIdwi', 'DTIb0'], 'DTI', regex=True)
10992    return meta
10993
10994def wmh( flair, t1, t1seg,
10995    mmfromconvexhull = 3.0,
10996    strict=True,
10997    probability_mask=None,
10998    prior_probability=None,
10999    model='sysu',
11000    verbose=False ) :
11001    """
11002    Outputs the WMH probability mask and a summary single measurement
11003
11004    Arguments
11005    ---------
11006    flair : ANTsImage
11007        input 3-D FLAIR brain image (not skull-stripped).
11008
11009    t1 : ANTsImage
11010        input 3-D T1 brain image (not skull-stripped).
11011
11012    t1seg : ANTsImage
11013        T1 segmentation image
11014
11015    mmfromconvexhull : float
11016        restrict WMH to regions that are WM or mmfromconvexhull mm away from the
11017        convex hull of the cerebrum.   we choose a default value based on
11018        Figure 4 from:
11019        https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6240579/pdf/fnagi-10-00339.pdf
11020
11021    strict: boolean - if True, only use convex hull distance
11022
11023    probability_mask : None - use to compute wmh just once - then this function
11024        just does refinement and summary
11025
11026    prior_probability : optional prior probability image in space of the input t1
11027
11028    model : either sysu or hyper
11029
11030    verbose : boolean
11031
11032    Returns
11033    ---------
11034    WMH probability map and a summary single measurement which is the sum of the WMH map
11035
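    Example (a sketch; assumes flair, t1 and t1seg are ANTsImages in native space):

    wmhres = antspymm.wmh( flair, t1, t1seg, model='sysu' )
    wmh_load = wmhres['wmh_mass']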
11036    """
11037    import numpy as np
11038    import math
11039    t1_2_flair_reg = ants.registration(flair, t1, type_of_transform = 'antsRegistrationSyNRepro[r]') # Register T1 to Flair
11040    if probability_mask is None and model == 'sysu':
11041        if verbose:
11042            print('sysu')
11043        probability_mask = antspynet.sysu_media_wmh_segmentation( flair )
11044    elif probability_mask is None and model == 'hyper':
11045        if verbose:
11046            print('hyper')
11047        probability_mask = antspynet.hypermapp3r_segmentation( t1_2_flair_reg['warpedmovout'], flair )
11048    # t1_2_flair_reg = tra_initializer( flair, t1, n_simulations=4, max_rotation=5, transform=['rigid'], verbose=False )
11049    prior_probability_flair = None
11050    if prior_probability is not None:
11051        prior_probability_flair = ants.apply_transforms( flair, prior_probability,
11052            t1_2_flair_reg['fwdtransforms'] )
11053    wmseg_mask = ants.threshold_image( t1seg,
11054        low_thresh = 3, high_thresh = 3).iMath("FillHoles")
11055    wmseg_mask_use = ants.image_clone( wmseg_mask )
11056    distmask = None
11057    if mmfromconvexhull > 0:
11058            convexhull = ants.threshold_image( t1seg, 1, 4 )
11059            spc2vox = np.prod( ants.get_spacing( t1seg ) )
11060            voxdist = 0.0
11061            myspc = ants.get_spacing( t1seg )
11062            for k in range( t1seg.dimension ):
11063                voxdist = voxdist + myspc[k] * myspc[k]
11064            voxdist = math.sqrt( voxdist )
11065            nmorph = round( 2.0 / voxdist )
11066            convexhull = ants.morphology( convexhull, "close", nmorph ).iMath("FillHoles")
11067            dist = ants.iMath( convexhull, "MaurerDistance" ) * -1.0
11068            distmask = ants.threshold_image( dist, mmfromconvexhull, 1.e80 )
11069            wmseg_mask = wmseg_mask + distmask
11070            if strict:
11071                wmseg_mask_use = ants.threshold_image( wmseg_mask, 2, 2 )
11072            else:
11073                wmseg_mask_use = ants.threshold_image( wmseg_mask, 1, 2 )
11074    ##############################################################################
11075    wmseg_2_flair = ants.apply_transforms(flair, wmseg_mask_use,
11076        transformlist = t1_2_flair_reg['fwdtransforms'],
11077        interpolator = 'nearestNeighbor' )
11078    seg_2_flair = ants.apply_transforms(flair, t1seg,
11079        transformlist = t1_2_flair_reg['fwdtransforms'],
11080        interpolator = 'nearestNeighbor' )
11081    csfmask = ants.threshold_image(seg_2_flair,1,1)
11082    flairsnr = mask_snr( flair, csfmask, wmseg_2_flair, bias_correct = False )
11083    probability_mask_WM = wmseg_2_flair * probability_mask # Remove WMH signal outside of WM
11084    wmh_sum = np.prod( ants.get_spacing( flair ) ) * probability_mask_WM.sum()
11085    wmh_sum_prior = math.nan
11086    probability_mask_posterior = None
11087    if prior_probability_flair is not None:
11088        probability_mask_posterior = prior_probability_flair * probability_mask # use prior
11089        wmh_sum_prior = np.prod( ants.get_spacing(flair) ) * probability_mask_posterior.sum()
11090    if math.isnan( wmh_sum ):
11091        wmh_sum=0
11092    if math.isnan( wmh_sum_prior ):
11093        wmh_sum_prior=0
11094    flair_evr = antspyt1w.patch_eigenvalue_ratio( flair, 512, [16,16,16], evdepth = 0.9, mask=wmseg_2_flair )
11095    return{
11096        'WMH_probability_map_raw': probability_mask,
11097        'WMH_probability_map' : probability_mask_WM,
11098        'WMH_posterior_probability_map' : probability_mask_posterior,
11099        'wmh_mass': wmh_sum,
11100        'wmh_mass_prior': wmh_sum_prior,
11101        'wmh_evr' : flair_evr,
11102        'wmh_SNR' : flairsnr,
11103        'convexhull_mask': distmask }
11104
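# Minimal usage sketch for wmh(); the file paths below are hypothetical placeholders and
# the call assumes a FLAIR image, a T1 in its native space, and the corresponding tissue
# segmentation (e.g. CSF=1, GM=2, WM=3).
def _example_wmh():
    flair = ants.image_read('flair.nii.gz')             # hypothetical path
    t1 = ants.image_read('t1.nii.gz')                   # hypothetical path
    t1seg = ants.image_read('t1_segmentation.nii.gz')   # hypothetical path
    result = wmh(flair, t1, t1seg, model='sysu')
    return result['WMH_probability_map'], result['wmh_mass']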
11105
11106def replace_elements_in_numpy_array(original_array, indices_to_replace, new_value):
11107    """
11108    Replace specified elements or rows in a numpy array with a new value.
11109
11110    Parameters:
11111    original_array (numpy.ndarray): A numpy array in which elements or rows are to be replaced.
11112    indices_to_replace (list or numpy.ndarray): Indices of elements or rows to be replaced.
11113    new_value: The new value to replace the specified elements or rows.
11114
11115    Returns:
11116    numpy.ndarray: The input array with the specified elements or rows replaced (the array is modified in place). If the input array is None,
11117                   the function returns None.
11118    """
11119
11120    if original_array is None:
11121        return None
11122
11123    max_index = original_array.size if original_array.ndim == 1 else original_array.shape[0]
11124
11125    # Filter out invalid indices and check for any out-of-bounds indices
11126    valid_indices = []
11127    for idx in indices_to_replace:
11128        if idx < max_index:
11129            valid_indices.append(idx)
11130        else:
11131            warnings.warn(f"Warning: Index {idx} is out of bounds and will be ignored.")
11132
11133    if original_array.ndim == 1:
11134        # Replace elements in a 1D array
11135        original_array[valid_indices] = new_value
11136    elif original_array.ndim == 2:
11137        # Replace rows in a 2D array
11138        original_array[valid_indices, :] = new_value
11139    else:
11140        raise ValueError("original_array must be either 1D or 2D.")
11141
11142    return original_array
11143
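# Small self-contained sketch of replace_elements_in_numpy_array on synthetic data;
# note that the input array is modified in place.
def _example_replace_elements_in_numpy_array():
    arr = np.arange(12.0).reshape(4, 3)                      # 4 rows x 3 columns
    out = replace_elements_in_numpy_array(arr, [1, 3], 0.0)  # zero out rows 1 and 3
    return out                                               # rows 1 and 3 are now all zeros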
11144
11145
11146def remove_elements_from_numpy_array(original_array, indices_to_remove):
11147    """
11148    Remove specified elements or rows from a numpy array.
11149
11150    Parameters:
11151    original_array (numpy.ndarray): A numpy array from which elements or rows are to be removed.
11152    indices_to_remove (list or numpy.ndarray): Indices of elements or rows to be removed.
11153
11154    Returns:
11155    numpy.ndarray: A new numpy array with the specified elements or rows removed. If the input array is None,
11156                   the function returns None.
11157    """
11158
11159    if original_array is None:
11160        return None
11161
11162    if original_array.ndim == 1:
11163        # Remove elements from a 1D array
11164        return np.delete(original_array, indices_to_remove)
11165    elif original_array.ndim == 2:
11166        # Remove rows from a 2D array
11167        return np.delete(original_array, indices_to_remove, axis=0)
11168    else:
11169        raise ValueError("original_array must be either 1D or 2D.")
11170
11171def remove_volumes_from_timeseries(time_series, volumes_to_remove):
11172    """
11173    Remove specified volumes from a time series.
11174
11175    :param time_series: ANTsImage representing the time series (4D image).
11176    :param volumes_to_remove: List of volume indices to remove.
11177    :return: ANTsImage with specified volumes removed.
11178    """
11179    if not isinstance(time_series, ants.core.ants_image.ANTsImage):
11180        raise ValueError("time_series must be an ANTsImage.")
11181
11182    if time_series.dimension != 4:
11183        raise ValueError("time_series must be a 4D image.")
11184
11185    # Create a boolean index for volumes to keep
11186    volumes_to_keep = [i for i in range(time_series.shape[3]) if i not in volumes_to_remove]
11187
11188    # Select the volumes to keep
11189    filtered_time_series = ants.from_numpy( time_series.numpy()[..., volumes_to_keep] )
11190
11191    return ants.copy_image_info( time_series, filtered_time_series )
11192
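# Illustrative sketch of remove_volumes_from_timeseries using a synthetic 4-D image;
# real use would pass an fMRI, DWI or perfusion time series instead.
def _example_remove_volumes_from_timeseries():
    ts = ants.from_numpy(np.random.randn(16, 16, 8, 10).astype('float32'))
    trimmed = remove_volumes_from_timeseries(ts, [0, 9])   # drop the first and last volumes
    return trimmed.shape                                   # (16, 16, 8, 8)
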
11193def remove_elements_from_list(original_list, elements_to_remove):
11194    """
11195    Remove specified elements from a list.
11196
11197    Parameters:
11198    original_list (list): The original list from which elements will be removed.
11199    elements_to_remove (list): A list of elements that need to be removed from the original list.
11200
11201    Returns:
11202    list: A new list with the specified elements removed.
11203    """
11204    return [element for element in original_list if element not in elements_to_remove]
11205
11206
11207def impute_timeseries(time_series, volumes_to_impute, method='linear', verbose=False):
11208    """
11209    Impute specified volumes of a time series with interpolated values.
11210
11211    :param time_series: ANTsImage representing the time series (4D image).
11212    :param volumes_to_impute: List of volume indices to impute.
11213    :param method: Interpolation method ('linear' or other methods if implemented).
11214    :param verbose: boolean
11215    :return: ANTsImage with specified volumes imputed.
11216    """
11217    if not isinstance(time_series, ants.core.ants_image.ANTsImage):
11218        raise ValueError("time_series must be an ANTsImage.")
11219
11220    if time_series.dimension != 4:
11221        raise ValueError("time_series must be a 4D image.")
11222
11223    # Convert time_series to numpy for manipulation
11224    time_series_np = time_series.numpy()
11225    total_volumes = time_series_np.shape[3]
11226
11227    # Create a complement list of volumes not to impute
11228    volumes_not_to_impute = [i for i in range(total_volumes) if i not in volumes_to_impute]
11229
11230    # Define the lower and upper bounds
11231    min_valid_index = min(volumes_not_to_impute)
11232    max_valid_index = max(volumes_not_to_impute)
11233
11234    for vol_idx in volumes_to_impute:
11235        # Ensure the volume index is within the valid range
11236        if vol_idx < 0 or vol_idx >= total_volumes:
11237            raise ValueError(f"Volume index {vol_idx} is out of bounds.")
11238
11239        # Find the nearest valid lower index within the bounds
11240        lower_candidates = [v for v in volumes_not_to_impute if v <= vol_idx]
11241        lower_idx = max(lower_candidates) if lower_candidates else min_valid_index
11242
11243        # Find the nearest valid upper index within the bounds
11244        upper_candidates = [v for v in volumes_not_to_impute if v >= vol_idx]
11245        upper_idx = min(upper_candidates) if upper_candidates else max_valid_index
11246
11247        if verbose:
11248            print(f"Imputing volume {vol_idx} using indices {lower_idx} and {upper_idx}")
11249
11250        if method == 'linear':
11251            # 'Linear' here means the unweighted average of the nearest valid volumes on either side
11252            lower_volume = time_series_np[..., lower_idx]
11253            upper_volume = time_series_np[..., upper_idx]
11254            interpolated_volume = (lower_volume + upper_volume) / 2
11255        else:
11256            # Placeholder for other interpolation methods
11257            raise NotImplementedError("Currently, only linear interpolation is implemented.")
11258
11259        # Replace the specified volume with the interpolated volume
11260        time_series_np[..., vol_idx] = interpolated_volume
11261
11262    # Convert the numpy array back to ANTsImage
11263    imputed_time_series = ants.from_numpy(time_series_np)
11264    imputed_time_series = ants.copy_image_info(time_series, imputed_time_series)
11265
11266    return imputed_time_series
11267
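# Illustrative sketch of impute_timeseries on a synthetic series: volumes 3 and 4 are
# replaced by the average of their nearest non-imputed neighbors.
def _example_impute_timeseries():
    ts = ants.from_numpy(np.random.randn(12, 12, 6, 10).astype('float32'))
    return impute_timeseries(ts, [3, 4], method='linear', verbose=False)
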
11268def impute_dwi( dwi, threshold = 0.20, imputeb0=False, mask=None, verbose=False ):
11269    """
11270    Identify bad volumes in a dwi and impute them fully automatically.
11271
11272    :param dwi: ANTsImage representing the time series (4D image).
11273    :param threshold: threshold (0,1) for outlierness (lower means impute more data)
11274    :param imputeb0: boolean will impute the b0 with dwi if True
11275    :param mask: restricts to a region of interest
11276    :param verbose: boolean
11277    :return: ANTsImage automatically imputed.
11278    """
11279    list1 = segment_timeseries_by_meanvalue( dwi )['highermeans']
11280    if imputeb0:
11281        dwib = impute_timeseries( dwi, list1 ) # focus on the dwi - not the b0
11282        looped, list2 = loop_timeseries_censoring( dwib, threshold, mask )
11283    else:
11284        looped, list2 = loop_timeseries_censoring( dwi, threshold, mask )
11285    if verbose:
11286        print( list1 )
11287        print( list2 )
11288    complement = remove_elements_from_list( list2, list1 )
11289    if verbose:
11290        print( "Imputing:")
11291        print( complement )
11292    if len( complement ) == 0:
11293        return dwi
11294    return impute_timeseries( dwi, complement )
11295
11296def censor_dwi( dwi, bval, bvec, threshold = 0.20, imputeb0=False, mask=None, verbose=False ):
11297    """
11298    Identify bad volumes in a dwi and censor (remove) them fully automatically.
11299
11300    :param dwi: ANTsImage representing the time series (4D image).
11301    :param bval: bval array
11302    :param bvec: bvec array
11303    :param threshold: threshold (0,1) for outlierness (lower means impute more data)
11304    :param imputeb0: boolean will impute the b0 with dwi if True
11305    :param mask: restricts to a region of interest
11306    :param verbose: boolean
11307    :return: tuple of (censored dwi ANTsImage, censored bvals, censored bvecs).
11308    """
11309    list1 = segment_timeseries_by_meanvalue( dwi )['highermeans']
11310    if imputeb0:
11311        dwib = impute_timeseries( dwi, list1 ) # focus on the dwi - not the b0
11312        looped, list2 = loop_timeseries_censoring( dwib, threshold, mask, verbose=verbose)
11313    else:
11314        looped, list2 = loop_timeseries_censoring( dwi, threshold, mask, verbose=verbose )
11315    if verbose:
11316        print( list1 )
11317        print( list2 )
11318    complement = remove_elements_from_list( list2, list1 )
11319    if verbose:
11320        print( "censoring:")
11321        print( complement )
11322    if len( complement ) == 0:
11323        return dwi, bval, bvec
11324    return remove_volumes_from_timeseries( dwi, complement ), remove_elements_from_numpy_array( bval, complement ), remove_elements_from_numpy_array( bvec, complement )
11325
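# Hypothetical usage sketch for censor_dwi; the paths are placeholders and bval/bvec are
# assumed to be organized with one entry (row) per volume so that censoring stays aligned.
def _example_censor_dwi():
    dwi = ants.image_read('dwi.nii.gz')     # hypothetical path
    bval = np.loadtxt('dwi.bval')           # hypothetical path
    bvec = np.loadtxt('dwi.bvec')           # hypothetical path, volumes-by-3
    return censor_dwi(dwi, bval, bvec, threshold=0.20, verbose=True)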
11326
11327def flatten_time_series(time_series):
11328    """
11329    Flatten a 4D time series into a 2D array.
11330    
11331    :param time_series: A 4D numpy array where the last dimension is time.
11332    :return: A 2D numpy array where each row is a flattened volume.
11333    """
11334    n_volumes = time_series.shape[3]
11335    return time_series.reshape(-1, n_volumes).T
11336
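# Quick sketch of flatten_time_series: a (4, 4, 2, 10) series becomes a (10, 32) matrix
# with one row per time point.
def _example_flatten_time_series():
    arr = np.zeros((4, 4, 2, 10))
    return flatten_time_series(arr).shape   # (10, 32)
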
11337def calculate_loop_scores_full(flattened_series, n_neighbors=20, verbose=True ):
11338    """
11339    Calculate Local Outlier Probabilities for each volume.
11340    
11341    :param flattened_series: A 2D numpy array from flatten_time_series.
11342    :param n_neighbors: Number of neighbors to use when calculating LoOP scores.
11343    :param verbose: boolean
11344    :return: An array of LoOP scores.
11345    """
11346    from PyNomaly import loop
11347    from sklearn.neighbors import NearestNeighbors
11348    from sklearn.preprocessing import StandardScaler
11349    # replace nans with zero
11350    if verbose:
11351        print("loop: nan_to_num")
11352    flattened_series=np.nan_to_num(flattened_series, nan=0)
11353    scaler = StandardScaler()
11354    scaler.fit(flattened_series)
11355    data = scaler.transform(flattened_series)
11356    data=np.nan_to_num(data, nan=0)
11357    if n_neighbors > int(flattened_series.shape[0]/2.0):
11358        n_neighbors = int(flattened_series.shape[0]/2.0)
11359    if verbose:
11360        print("loop: nearest neighbors init")
11361    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric='minkowski')
11362    if verbose:
11363        print("loop: nearest neighbors fit")
11364    neigh.fit(data)
11365    d, idx = neigh.kneighbors(data, return_distance=True)
11366    if verbose:
11367        print("loop: probability")
11368    m = loop.LocalOutlierProbability(distance_matrix=d, neighbor_matrix=idx, n_neighbors=n_neighbors).fit()
11369    return m.local_outlier_probabilities[:]
11370
11371
11372def calculate_loop_scores(
11373    flattened_series,
11374    n_neighbors=20,
11375    n_features_sample=0.02,
11376    n_feature_repeats=5,
11377    seed=42,
11378    use_approx_knn=True,
11379    verbose=True,
11380):
11381    """
11382    Memory-efficient and robust LoOP score estimation with optional approximate KNN
11383    and averaging over multiple random feature subsets.
11384
11385    Parameters:
11386        flattened_series (np.ndarray): 2D array (n_samples x n_features)
11387        n_neighbors (int): Number of neighbors for LoOP
11388        n_features_sample (int or float): Number or fraction of features to sample
11389        n_feature_repeats (int): How many independent feature subsets to sample and average over
11390        seed (int): Random seed
11391        use_approx_knn (bool): Whether to use fast approximate KNN (via pynndescent)
11392        verbose (bool): Verbose output
11393
11394    Returns:
11395        np.ndarray: Averaged local outlier probabilities (length n_samples)
11396    """
11397    import numpy as np
11398    from sklearn.preprocessing import StandardScaler
11399    from PyNomaly import loop
11400
11401    # Optional approximate nearest neighbors
11402    try:
11403        from pynndescent import NNDescent
11404        has_nn_descent = True
11405    except ImportError:
11406        has_nn_descent = False
11407
11408    rng = np.random.default_rng(seed)
11409    X = np.nan_to_num(flattened_series, nan=0).astype(np.float32)
11410    n_samples, n_features = X.shape
11411
11412    # Handle feature sampling
11413    if isinstance(n_features_sample, float):
11414        if 0 < n_features_sample <= 1.0:
11415            n_features_sample = max(1, int(n_features_sample * n_features))
11416        else:
11417            raise ValueError("If float, n_features_sample must be in (0, 1].")
11418
11419    n_features_sample = min(n_features, n_features_sample)
11420
11421    if n_neighbors >= n_samples:
11422        n_neighbors = max(1, n_samples // 2)
11423
11424    if verbose:
11425        print(f"[LoOP] Input shape: {X.shape}")
11426        print(f"[LoOP] Sampling {n_features_sample} features per repeat, {n_feature_repeats} repeats")
11427        print(f"[LoOP] Using {n_neighbors} neighbors")
11428
11429    loop_scores = []
11430
11431    for rep in range(n_feature_repeats):
11432        feature_idx = rng.choice(n_features, n_features_sample, replace=False)
11433        X_sub = X[:, feature_idx]
11434
11435        scaler = StandardScaler(copy=False)
11436        X_sub = scaler.fit_transform(X_sub)
11437        X_sub = np.nan_to_num(X_sub, nan=0)
11438
11439        # Approximate or exact KNN
11440        if use_approx_knn and has_nn_descent and n_samples > 1000:
11441            if verbose:
11442                print(f"  [Rep {rep+1}] Using NNDescent (approximate KNN)")
11443            ann = NNDescent(X_sub, n_neighbors=n_neighbors, random_state=seed + rep)
11444            indices, dists = ann.neighbor_graph
11445        else:
11446            from sklearn.neighbors import NearestNeighbors
11447            if verbose:
11448                print(f"  [Rep {rep+1}] Using NearestNeighbors (exact KNN)")
11449            nn = NearestNeighbors(n_neighbors=n_neighbors)
11450            nn.fit(X_sub)
11451            dists, indices = nn.kneighbors(X_sub)
11452
11453        # LoOP score for this repeat
11454        model = loop.LocalOutlierProbability(
11455            distance_matrix=dists,
11456            neighbor_matrix=indices,
11457            n_neighbors=n_neighbors
11458        ).fit()
11459        loop_scores.append(model.local_outlier_probabilities[:])
11460
11461    # Average over repeats
11462    loop_scores = np.stack(loop_scores)
11463    loop_scores_mean = loop_scores.mean(axis=0)
11464
11465    if verbose:
11466        print(f"[LoOP] Averaged over {n_feature_repeats} feature subsets. Final shape: {loop_scores_mean.shape}")
11467
11468    return loop_scores_mean
11469
11470
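# Synthetic sketch of calculate_loop_scores: 50 "volumes" by 500 features with one
# artificially shifted volume; its averaged LoOP score should be comparatively high.
def _example_calculate_loop_scores():
    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 500)).astype(np.float32)
    X[25, :] += 10.0   # make volume 25 an obvious outlier
    return calculate_loop_scores(X, n_neighbors=10, n_features_sample=0.5,
                                 n_feature_repeats=2, verbose=False)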
11471
11472def score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg ):
11473    """
11474    Process CBF time series to remove high-leverage points.
11475    Derived from the SCORE algorithm by Sudipto Dolui et al.
11476
11477    Parameters:
11478    cbfts (ANTsImage): 4D ANTsImage of CBF time series.
11479    csf_seg (ANTsImage): CSF binary map.
11480    gm_seg (ANTsImage): Gray matter binary map.
11481    wm_seg (ANTsImage): WM binary map.
11482
11483    Returns:
11484    ANTsImage: Processed CBF time series.
11485    ndarray: Boolean mask indicating the removed volumes.
11486    """
11487    
11488    n_gm_voxels = np.sum(gm_seg.numpy()) - 1
11489    n_wm_voxels = np.sum(wm_seg.numpy()) - 1
11490    n_csf_voxels = np.sum(csf_seg.numpy()) - 1
11491    mask1img = gm_seg + wm_seg + csf_seg
11492    mask1 = (mask1img==1).numpy()
11493    
11494    cbfts_np = cbfts.numpy()
11495    gmbool = (gm_seg==1).numpy()
11496    csfbool = (csf_seg==1).numpy()
11497    wmbool = (wm_seg==1).numpy()
11498    gm_cbf_ts = ants.timeseries_to_matrix( cbfts, gm_seg )
11499    gm_cbf_ts = np.squeeze(np.mean(gm_cbf_ts, axis=1))
11500    
11501    median_gm_cbf = np.median(gm_cbf_ts)
11502    mad_gm_cbf = np.median(np.abs(gm_cbf_ts - median_gm_cbf)) / 0.675
11503    indx = np.abs(gm_cbf_ts - median_gm_cbf) > (2.5 * mad_gm_cbf)
11504    
11505    # the spatial mean
11506    spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11507    spatmean = ants.from_numpy( spatmeannp )
11508    V = (
11509        n_gm_voxels * np.var(spatmeannp[gmbool])
11510        + n_wm_voxels * np.var(spatmeannp[wmbool])
11511        + n_csf_voxels * np.var(spatmeannp[csfbool])
11512    )
11513    V1 = math.inf
11514    ct=0
11515    while V < V1:
11516        ct=ct+1
11517        V1 = V
11518        CC = np.zeros(cbfts_np.shape[3])
11519        for s in range(cbfts_np.shape[3]):
11520            if indx[s]:
11521                continue
11522            tmp1 = ants.from_numpy( cbfts_np[:, :, :, s] )
11523            CC[s] = ants.image_similarity( spatmean, tmp1, metric_type='Correlation', fixed_mask=mask1img )
11524        inx = np.argmin(CC)
11525        indx[inx] = True
11526        spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11527        spatmean = ants.from_numpy( spatmeannp )
11528        V = (
11529          n_gm_voxels * np.var(spatmeannp[gmbool]) + 
11530          n_wm_voxels * np.var(spatmeannp[wmbool]) + 
11531          n_csf_voxels * np.var(spatmeannp[csfbool])
11532        )
11533    cbfts_recon = cbfts_np[:, :, :, ~indx]
11534    cbfts_recon = np.nan_to_num(cbfts_recon)
11535    cbfts_recon_ants = ants.from_numpy(cbfts_recon)
11536    cbfts_recon_ants = ants.copy_image_info(cbfts, cbfts_recon_ants)
11537    return cbfts_recon_ants, indx
11538
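# Hypothetical usage sketch for score_fmri_censoring; paths are placeholders and the
# three tissue masks are assumed to be binary images in the space of the CBF series.
def _example_score_fmri_censoring():
    cbf = ants.image_read('cbf_timeseries.nii.gz')   # hypothetical path
    csf = ants.image_read('csf_mask.nii.gz')         # hypothetical path
    gm = ants.image_read('gm_mask.nii.gz')           # hypothetical path
    wm = ants.image_read('wm_mask.nii.gz')           # hypothetical path
    cbf_clean, removed = score_fmri_censoring(cbf, csf, gm, wm)
    return cbf_clean, removed
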
11539def loop_timeseries_censoring(x, threshold=0.5, mask=None, n_features_sample=0.02, seed=42, verbose=True):
11540    """
11541    Censor high leverage volumes from a time series using Local Outlier Probabilities (LoOP).
11542
11543    Parameters:
11544    x (ANTsImage): A 4D time series image.
11545    threshold (float): Threshold for determining high leverage volumes based on LoOP scores.
11546    mask (antsImage): restricts the computation to an ROI
11547    n_features_sample (int/float): feature sample size, default 0.02; if less than one it is interpreted as a fraction of the total features, otherwise it sets the number of features to be used
11548    seed (int): random seed
11549    verbose (bool)
11550
11551    Returns:
11552    tuple: A tuple containing the censored time series (ANTsImage) and the indices of the high leverage volumes.
11553    """
11554    import warnings
11555    if x.shape[3] < 20: # just a guess at what we need here ...
11556        warnings.warn("The time dimension is < 20 - too few samples for LoOP; returning the original data.")
11557        return x, []
11558    if mask is None:
11559        flattened_series = flatten_time_series(x.numpy())
11560    else:
11561        flattened_series = ants.timeseries_to_matrix( x, mask )
11562    if verbose:
11563        print("loop_timeseries_censoring: flattened")
11564    loop_scores = calculate_loop_scores(flattened_series, n_features_sample=n_features_sample, seed=seed, verbose=verbose )
11565    high_leverage_volumes = np.where(loop_scores > threshold)[0]
11566    if verbose:
11567        print("loop_timeseries_censoring: High Leverage Volumes:", high_leverage_volumes)
11568    new_asl = remove_volumes_from_timeseries(x, high_leverage_volumes)
11569    return new_asl, high_leverage_volumes
11570
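# Synthetic sketch of loop_timeseries_censoring; with fewer than 20 volumes the function
# returns the input unchanged, so 30 volumes are used here.
def _example_loop_timeseries_censoring():
    ts = ants.from_numpy(np.random.randn(10, 10, 6, 30).astype('float32'))
    cleaned, bad_volumes = loop_timeseries_censoring(ts, threshold=0.5, verbose=False)
    return cleaned, bad_volumes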
11571
11572def novelty_detection_ee(df_train, df_test, contamination=0.05):
11573    """
11574    This function performs novelty detection using Elliptic Envelope.
11575
11576    Parameters:
11577
11578    - df_train (pandas dataframe): training data used to fit the model
11579
11580    - df_test (pandas dataframe): test data used to predict novelties
11581
11582    - contamination (float): parameter controlling the proportion of outliers in the data (default: 0.05)
11583
11584    Returns:
11585
11586    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11587    """
11588    import pandas as pd
11589    from sklearn.covariance import EllipticEnvelope
11590    # Fit the model on the training data
11591    clf = EllipticEnvelope(contamination=contamination,support_fraction=1)
11592    df_train[ df_train == math.inf ] = 0
11593    df_test[ df_test == math.inf ] = 0
11594    from sklearn.preprocessing import StandardScaler
11595    scaler = StandardScaler()
11596    scaler.fit(df_train)
11597    clf.fit(scaler.transform(df_train))
11598    predictions = clf.predict(scaler.transform(df_test))
11599    predictions[predictions==1]=0
11600    predictions[predictions==-1]=1
11601    if isinstance(df_train, pd.DataFrame):
11602        return pd.Series(predictions, index=df_test.index)
11603    else:
11604        return pd.Series(predictions)
11605
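# Toy sketch of the novelty detection interface (shown for novelty_detection_ee; the
# svm and lof variants are called the same way): train on well-behaved data, then flag
# an obvious outlier in the test frame.
def _example_novelty_detection_ee():
    rng = np.random.default_rng(0)
    train = pd.DataFrame(rng.standard_normal((200, 2)), columns=['a', 'b'])
    test = pd.DataFrame([[0.1, -0.2], [8.0, 9.0]], columns=['a', 'b'])
    return novelty_detection_ee(train, test, contamination=0.05)   # second row should flag as 1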
11606
11607
11608def novelty_detection_svm(df_train, df_test, nu=0.05, kernel='rbf'):
11609    """
11610    This function performs novelty detection using One-Class SVM.
11611
11612    Parameters:
11613
11614    - df_train (pandas dataframe): training data used to fit the model
11615
11616    - df_test (pandas dataframe): test data used to predict novelties
11617
11618    - nu (float): parameter controlling the fraction of training errors and the fraction of support vectors (default: 0.05)
11619
11620    - kernel (str): kernel type used in the SVM algorithm (default: 'rbf')
11621
11622    Returns:
11623
11624    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11625    """
11626    from sklearn.svm import OneClassSVM
11627    # Fit the model on the training data
11628    df_train[ df_train == math.inf ] = 0
11629    df_test[ df_test == math.inf ] = 0
11630    clf = OneClassSVM(nu=nu, kernel=kernel)
11631    from sklearn.preprocessing import StandardScaler
11632    scaler = StandardScaler()
11633    scaler.fit(df_train)
11634    clf.fit(scaler.transform(df_train))
11635    predictions = clf.predict(scaler.transform(df_test))
11636    predictions[predictions==1]=0
11637    predictions[predictions==-1]=1
11638    if isinstance(df_train, pd.DataFrame):
11639        return pd.Series(predictions, index=df_test.index)
11640    else:
11641        return pd.Series(predictions)
11642
11643
11644
11645def novelty_detection_lof(df_train, df_test, n_neighbors=20):
11646    """
11647    This function performs novelty detection using Local Outlier Factor (LOF).
11648
11649    Parameters:
11650
11651    - df_train (pandas dataframe): training data used to fit the model
11652
11653    - df_test (pandas dataframe): test data used to predict novelties
11654
11655    - n_neighbors (int): number of neighbors used to compute the LOF (default: 20)
11656
11657    Returns:
11658
11659    - predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11660
11661    """
11662    from sklearn.neighbors import LocalOutlierFactor
11663    # Fit the model on the training data
11664    df_train[ df_train == math.inf ] = 0
11665    df_test[ df_test == math.inf ] = 0
11666    clf = LocalOutlierFactor(n_neighbors=n_neighbors, algorithm='auto',contamination='auto', novelty=True)
11667    from sklearn.preprocessing import StandardScaler
11668    scaler = StandardScaler()
11669    scaler.fit(df_train)
11670    clf.fit(scaler.transform(df_train))
11671    predictions = clf.predict(scaler.transform(df_test))
11672    predictions[predictions==1]=0
11673    predictions[predictions==-1]=1
11674    if isinstance(df_train, pd.DataFrame):
11675        return pd.Series(predictions, index=df_test.index)
11676    else:
11677        return pd.Series(predictions)
11678
11679
11680def novelty_detection_loop(df_train, df_test, n_neighbors=20, distance_metric='minkowski'):
11681    """
11682    This function performs novelty detection using Local Outlier Probabilities (LoOP).
11683
11684    Parameters:
11685
11686    - df_train (pandas dataframe): training data used to fit the model
11687
11688    - df_test (pandas dataframe): test data used to predict novelties
11689
11690    - n_neighbors (int): number of neighbors used to compute the LoOP (default: 20)
11691
11692    - distance_metric : default minkowski
11693
11694    Returns:
11695
11696    - outlier probabilities (numpy array): LoOP scores in [0,1] for each test row; higher values indicate greater outlierness
11697
11698    """
11699    from PyNomaly import loop
11700    from sklearn.neighbors import NearestNeighbors
11701    from sklearn.preprocessing import StandardScaler
11702    scaler = StandardScaler()
11703    scaler.fit(df_train)
11704    data = np.vstack( [scaler.transform(df_test),scaler.transform(df_train)])
11705    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric=distance_metric)
11706    neigh.fit(data)
11707    d, idx = neigh.kneighbors(data, return_distance=True)
11708    m = loop.LocalOutlierProbability(distance_matrix=d, neighbor_matrix=idx, n_neighbors=n_neighbors).fit()
11709    return m.local_outlier_probabilities[range(df_test.shape[0])]
11710
11711
11712
11713def novelty_detection_quantile(df_train, df_test):
11714    """
11715    This function performs novelty detection using quantiles for each column.
11716
11717    Parameters:
11718
11719    - df_train (pandas dataframe): training data used to fit the model
11720
11721    - df_test (pandas dataframe): test data used to predict novelties
11722
11723    Returns:
11724
11725    - quantiles for the test sample at each column where values range in [0,1]
11726        and higher values mean the column is closer to the edge of the distribution
11727
11728    """
11729    myqs = df_test.copy()
11730    n = df_train.shape[0]
11731    df_trainkeys = df_train.keys()
11732    for k in range( df_train.shape[1] ):
11733        mykey = df_trainkeys[k]
11734        temp = (myqs[mykey][0] >  df_train[mykey]).sum() / n
11735        myqs[mykey] = abs( temp - 0.5 ) / 0.5
11736    return myqs
11737
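# Toy sketch of novelty_detection_quantile: for each column the returned value is near 1
# when the test observation sits at the edge of the training distribution and near 0 when
# it sits at the median.
def _example_novelty_detection_quantile():
    rng = np.random.default_rng(0)
    train = pd.DataFrame({'x': rng.standard_normal(100), 'y': rng.standard_normal(100)})
    test = pd.DataFrame({'x': [4.0], 'y': [0.0]})   # x is extreme, y is typical
    return novelty_detection_quantile(train, test)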
11738
11739
11740def shorten_pymm_names(x):
11741    """
11742    Shortens antspymm column names by applying a series of regex substitutions.
11743    
11744    Parameters:
11745    x (str): The input string to be shortened
11746    
11747    Returns:
11748    str: The shortened string
11749    """
11750    xx = x.lower()
11751    xx = re.sub("_", ".", xx)  # Replace underscores with periods
11752    xx = re.sub(r"\.\.", ".", xx, flags=re.I)  # Replace double dots with single dot
11753    # Apply the following regex substitutions in order
11754    xx = re.sub(r"sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus.", "ilf.and.ifo", xx, flags=re.I)
11755    xx = re.sub(r".cres.stria.terminalis.can.not.be.resolved.with.current.resolution.", "", xx, flags=re.I)
11758    xx = re.sub("_", ".", xx)  # Replace underscores with periods
11759    xx = re.sub(r"longitudinal.fasciculus", "l.fasc", xx, flags=re.I)
11760    xx = re.sub(r"corona.radiata", "cor.rad", xx, flags=re.I)
11761    xx = re.sub("central", "cent", xx, flags=re.I)
11762    xx = re.sub(r"deep.cit168", "dp.", xx, flags=re.I)
11763    xx = re.sub("cit168", "", xx, flags=re.I)
11764    xx = re.sub(".include", "", xx, flags=re.I)
11765    xx = re.sub("mtg.sn", "", xx, flags=re.I)
11766    xx = re.sub("brainstem", ".bst", xx, flags=re.I)
11767    xx = re.sub(r"rsfmri.", "rsf.", xx, flags=re.I)
11768    xx = re.sub(r"dti.mean.fa.", "dti.fa.", xx, flags=re.I)
11769    xx = re.sub("perf.cbf.mean.", "cbf.", xx, flags=re.I)
11770    xx = re.sub(".jhu.icbm.labels.1mm", "", xx, flags=re.I)
11771    xx = re.sub(".include.optic.radiation.", "", xx, flags=re.I)
11772    xx = re.sub(r"\.\.", ".", xx, flags=re.I)  # Replace double dots with single dot
11773    xx = re.sub(r"\.\.", ".", xx, flags=re.I)  # Applied twice so that runs of three dots also collapse
11774    xx = re.sub("cerebellar.peduncle", "cereb.ped", xx, flags=re.I)
11775    xx = re.sub(r"anterior.limb.of.internal.capsule", "ant.int.cap", xx, flags=re.I)
11776    xx = re.sub(r"posterior.limb.of.internal.capsule", "post.int.cap", xx, flags=re.I)
11777    xx = re.sub("t1hier.", "t1.", xx, flags=re.I)
11778    xx = re.sub("anterior", "ant", xx, flags=re.I)
11779    xx = re.sub("posterior", "post", xx, flags=re.I)
11780    xx = re.sub("inferior", "inf", xx, flags=re.I)
11781    xx = re.sub("superior", "sup", xx, flags=re.I)
11782    xx = re.sub(r"dktcortex", ".ctx", xx, flags=re.I)
11783    xx = re.sub(".lravg", "", xx, flags=re.I)
11784    xx = re.sub("dti.mean.fa", "dti.fa", xx, flags=re.I)
11785    xx = re.sub(r"retrolenticular.part.of.internal", "rent.int.cap", xx, flags=re.I)
11786    xx = re.sub(r"iculus.could.be.a.part.of.ant.internal.capsule", "", xx, flags=re.I)  # Twice
11787    xx = re.sub(".fronto.occipital.", ".frnt.occ.", xx, flags=re.I)
11788    xx = re.sub(r".longitidinal.fasciculus.", ".long.fasc.", xx, flags=re.I)  # Twice
11789    xx = re.sub(".external.capsule", ".ext.cap", xx, flags=re.I)
11790    xx = re.sub("of.internal.capsule", ".int.cap", xx, flags=re.I)
11791    xx = re.sub("fornix.cres.stria.terminalis", "fornix.", xx, flags=re.I)
11792    xx = re.sub("capsule", "", xx, flags=re.I)
11793    xx = re.sub("and.inf.frnt.occ.fasciculus.", "", xx, flags=re.I)
11794    xx = re.sub("crossing.tract.a.part.of.mcp.", "", xx, flags=re.I)
11795    return xx[:40]  # Truncate to first 40 characters
11796
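# Illustrative call to shorten_pymm_names; the input is a made-up antspymm-style column
# name and the output is its abbreviated, lower-cased form (truncated to 40 characters).
def _example_shorten_pymm_names():
    return shorten_pymm_names('DTI_mean_FA.superior_corona_radiata.left')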
11797
11798def shorten_pymm_names2(x, verbose=False ):
11799    """
11800    Shortens antspymm column names by applying a series of regex substitutions.
11801
11802    Parameters:
11803    x (str): The input string to be shortened
11804
11805    verbose (bool): explain the patterns and replacements and their impact
11806
11807    Returns:
11808    str: The shortened string
11809    """
11810    # Define substitution patterns as tuples
11811    substitutions = [
11812        ("_", "."),  
11813        (r"\.\.", "."),
11814        (r"sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus.", "ilf.and.ifo"),
11816        (r".cres.stria.terminalis.can.not.be.resolved.with.current.resolution.", ""),
11817        ("_", "."),
11818        (r"longitudinal.fasciculus", "l.fasc"),
11819        (r"corona.radiata", "cor.rad"),
11820        ("central", "cent"),
11821        (r"deep.cit168", "dp."),
11822        ("cit168", ""),
11823        (".include", ""),
11824        ("mtg.sn", ""),
11825        ("brainstem", ".bst"),
11826        (r"rsfmri.", "rsf."),
11827        (r"dti.mean.fa.", "dti.fa."),
11828        ("perf.cbf.mean.", "cbf."),
11829        (".jhu.icbm.labels.1mm", ""),
11830        (".include.optic.radiation.", ""),
11831        (r"\.\.", "."),  # Replace double dots with single dot
11832        (r"\.\.", "."),  # Applied twice so that runs of three dots also collapse
11833        ("cerebellar.peduncle", "cereb.ped"),
11834        (r"anterior.limb.of.internal.capsule", "ant.int.cap"),
11835        (r"posterior.limb.of.internal.capsule", "post.int.cap"),
11836        ("t1hier.", "t1."),
11837        ("anterior", "ant"),
11838        ("posterior", "post"),
11839        ("inferior", "inf"),
11840        ("superior", "sup"),
11841        (r"dktcortex", ".ctx"),
11842        (".lravg", ""),
11843        ("dti.mean.fa", "dti.fa"),
11844        (r"retrolenticular.part.of.internal", "rent.int.cap"),
11845        (r"iculus.could.be.a.part.of.ant.internal.capsule", ""),  # Twice
11846        (".fronto.occipital.", ".frnt.occ."),
11847        (r".longitidinal.fasciculus.", ".long.fasc."),  # Twice
11848        (".external.capsule", ".ext.cap"),
11849        ("of.internal.capsule", ".int.cap"),
11850        ("fornix.cres.stria.terminalis", "fornix."),
11851        ("capsule", ""),
11852        ("and.inf.frnt.occ.fasciculus.", ""),
11853        ("crossing.tract.a.part.of.mcp.", "")
11854      ]
11855
11856    # Apply substitutions in order
11857    for pattern, replacement in substitutions:
11858        if verbose:
11859            print("Pre " + x + " pattern "+pattern + " repl " + replacement )
11860        x = re.sub(pattern, replacement, x.lower(), flags=re.IGNORECASE)
11861        if verbose:
11862            print("Post " + x)
11863
11864    return x[:40]  # Truncate to first 40 characters
11865
11866
11867def brainmap_figure(statistical_df, data_dictionary, output_prefix, brain_image, overlay_cmap='bwr', nslices=21, ncol=7, edge_image_dilation = 0, black_bg=True, axes = [0,1,2], fixed_overlay_range=None, crop=5, verbose=0 ):
11868    """
11869    Create figures based on statistical data and an underlying brain image.
11870
11871    Assumes both ~/.antspyt1w and ~/.antspymm data is available
11872
11873    Parameters:
11874    - statistical_df (pandas dataframe): with 2 columns named anat and values.
11875        The anat column should contain names that partially match regions measured
11876        in antspymm; values holds the value to be displayed.  If a given region
11877        appears more than once in statistical_df, the entry with the largest
11878        absolute value is used for display.
11879    - data_dictionary (pandas dataframe): antspymm data dictionary.
11880    - output_prefix (str): Prefix for the output figure filenames.
11881    - brain_image (antsImage): the brain image on which results will overlay.
11882    - overlay_cmap (str): see matplotlib
11883    - nslices (int): number of slices to show
11884    - ncol (int): number of columns to show
11885    - edge_image_dilation (int): integer greater than or equal to zero
11886    - black_bg (bool): boolean
11887    - axes (list): integer list typically [0,1,2] sagittal coronal axial
11888    - fixed_overlay_range (list): scalar pair will try to keep a constant cbar and will truncate the overlay at these min/max values
11889    - crop (int): crops the image to display by the extent of the overlay; larger values dilate the masks more.
11890    - verbose (bool): boolean
11891
11892    Returns:
11893    an image with values mapped to the associated regions
11894    """
11895    import re
11896
11897    def is_bst_region(filename):
11898        return filename[-4:] == '.bst'
11899
11900    # Read the statistical file
11901    zz = statistical_df 
11902    
11903    # Read the data dictionary from a CSV file
11904    mydict = data_dictionary
11905    mydict = mydict[~mydict['Measurement'].str.contains("tractography-based connectivity", na=False)]
11906    mydict2=mydict.copy()
11907    mydict2['tidynames']=mydict2['tidynames'].str.replace(".left","")
11908    mydict2['tidynames']=mydict2['tidynames'].str.replace(".right","")
11909
11910    statistical_df['anat'] = statistical_df['anat'].str.replace("_", ".", regex=True)
11911
11912    # Load image and process it
11913    edgeimg = ants.iMath(brain_image,"Normalize")
11914    if edge_image_dilation > 0:
11915        edgeimg = ants.iMath( edgeimg, "MD", edge_image_dilation)
11916
11917    # Define lists and data frames
11918    postfix = ['bf', 'cit168lab', 'mtl', 'cerebellum', 'dkt_cortex','brainstem','JHU_wm','yeo']
11919    atlas = ['BF', 'CIT168', 'MTL', 'TustisonCobra', 'desikan-killiany-tourville','brainstem','JHU_wm','yeo']
11920    postdesc = ['nbm3CH13', 'CIT168_Reinf_Learn_v1_label_descriptions_pad', 'mtl_description', 'cerebellum', 'dkt','CIT168_T1w_700um_pad_adni_brainstem','FA_JHU_labels_edited','ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic']
11921    templateprefix = '~/.antspymm/PPMI_template0_'
11922    # Iterate through columns and create figures
11923    col2viz = 'values'
11924    if True:
11925        anattoshow = zz['anat'].unique()
11926        if verbose > 0:
11927            print(col2viz)
11928            print(anattoshow)
11929        # Rest of your code for figure creation goes here...
11930        addem = edgeimg * 0
11931        for k in range(len(anattoshow)):
11932            if verbose > 0 :
11933                print(str(k) +  " " + anattoshow[k]  )
11934            mysub = zz[zz['anat'].str.contains(anattoshow[k])]
11935            anatsear=shorten_pymm_names( anattoshow[k] )
11936            anatsear=re.sub(r'[()]', '.', anatsear )
11937            anatsear=re.sub(r'\.\.', '.', anatsear )
11938            anatsear=re.sub("dti.mean.md.snc","md.snc",anatsear)
11939            anatsear=re.sub("dti.mean.fa.snc","fa.snc",anatsear)
11940            anatsear=re.sub("dti.mean.md.snr","md.snr",anatsear)
11941            anatsear=re.sub("dti.mean.fa.snr","fa.snr",anatsear)
11942            anatsear=re.sub("dti.mean.md.","",anatsear)
11943            anatsear=re.sub("dti.mean.fa.","",anatsear)
11944            anatsear=re.sub("dti.md.","",anatsear)
11945            anatsear=re.sub("dti.fa.","",anatsear)
11946            anatsear=re.sub("dti.md","",anatsear)
11947            anatsear=re.sub("dti.fa","",anatsear)
11948            anatsear=re.sub("cbf.","",anatsear)
11949            anatsear=re.sub("rsfmri.fcnxpro122.","",anatsear)
11950            anatsear=re.sub("rsfmri.fcnxpro129.","",anatsear)
11951            anatsear=re.sub("rsfmri.fcnxpro134.","",anatsear)
11952            anatsear=re.sub("t1hier.vollravg","",anatsear)
11953            anatsear=re.sub("t1hier.volasym","",anatsear)
11954            anatsear=re.sub("t1hier.thkasym","",anatsear)
11955            anatsear=re.sub("t1hier.areaasym","",anatsear)
11956            anatsear=re.sub("t1hier.vol.","",anatsear)
11957            anatsear=re.sub("t1hier.thk.","",anatsear)
11958            anatsear=re.sub("t1hier.area.","",anatsear)
11959            anatsear=re.sub("t1.volasym","",anatsear)
11960            anatsear=re.sub("t1.thkasym","",anatsear)
11961            anatsear=re.sub("t1.areaasym","",anatsear)
11962            anatsear=re.sub("t1.vol.","",anatsear)
11963            anatsear=re.sub("t1.thk.","",anatsear)
11964            anatsear=re.sub("t1.area.","",anatsear)
11965            anatsear=re.sub("asymdp.","",anatsear)
11966            anatsear=re.sub("asym.","",anatsear)
11967            anatsear=re.sub("asym","",anatsear)
11968            anatsear=re.sub("lravg.","",anatsear)
11969            anatsear=re.sub("lravg","",anatsear)
11970            anatsear=re.sub("dktcortex","",anatsear)
11971            anatsear=re.sub("dktregions","",anatsear)
11972            anatsear=re.sub("_",".",anatsear)
11973            anatsear=re.sub("superior","sup",anatsear)
11974            anatsear=re.sub("cerebellum","",anatsear)
11975            anatsear=re.sub("brainstem","",anatsear)
11976            anatsear=re.sub("t.limb.int","t.int",anatsear)
11977            anatsear=re.sub("paracentral","paracent",anatsear)
11978            anatsear=re.sub("precentral","precent",anatsear)
11979            anatsear=re.sub("postcentral","postcent",anatsear)
11980            anatsear=re.sub("sup.cerebellar.peduncle","sup.cereb.ped",anatsear)
11981            anatsear=re.sub("inferior.cerebellar.peduncle","inf.cereb.ped",anatsear)
11982            anatsear=re.sub(".crossing.tract.a.part.of.mcp.","",anatsear)
11983            anatsear=re.sub(".crossing.tract.a.part.of.","",anatsear)
11984            anatsear=re.sub(".column.and.body.of.fornix.","",anatsear)
11985            anatsear=re.sub("fronto.occipital.fasciculus.could.be.a.part.of.ant.internal.capsule","frnt.occ",anatsear)
11986            anatsear=re.sub("inferior.fronto.occipital.fasciculus.could.be.a.part.of.anterior.internal.capsule","inf.frnt.occ",anatsear)
11987            anatsear=re.sub("fornix.cres.stria.terminalis.can.not.be.resolved.with.current.resolution","fornix.column.and.body.of.fornix",anatsear)
11988            anatsear=re.sub("external.capsule","ext.cap",anatsear)
11989            anatsear=re.sub(".jhu.icbm.labels.1mm","",anatsear)
11990            anatsear=re.sub("dp.",".",anatsear)
11991            anatsear=re.sub(".mtg.sn.snc.",".snc.",anatsear)
11992            anatsear=re.sub(".mtg.sn.snr.",".snr.",anatsear)
11993            anatsear=re.sub("mtg.sn.snc.",".snc.",anatsear)
11994            anatsear=re.sub("mtg.sn.snr.",".snr.",anatsear)
11995            anatsear=re.sub("mtg.sn.snc",".snc.",anatsear)
11996            anatsear=re.sub("mtg.sn.snr",".snr.",anatsear)
11997            anatsear=re.sub("anterior.","ant.",anatsear)
11998            anatsear=re.sub("rsf.","",anatsear)
11999            anatsear=re.sub("fcnxpro122.","",anatsear)
12000            anatsear=re.sub("fcnxpro129.","",anatsear)
12001            anatsear=re.sub("fcnxpro134.","",anatsear)
12002            anatsear=re.sub("ant.corona.radiata","ant.cor.rad",anatsear)
12003            anatsear=re.sub("sup.corona.radiata","sup.cor.rad",anatsear)
12004            anatsear=re.sub("posterior.thalamic.radiation.include.optic.radiation","post.thalamic.radiation",anatsear)
12005            anatsear=re.sub("retrolenticular.part.of.internal.capsule","rent.int.cap",anatsear)
12006            anatsear=re.sub("post.limb.of.internal.capsule","post.int.cap",anatsear)
12007            anatsear=re.sub("ant.limb.of.internal.capsule","ant.int.cap",anatsear)
12008            anatsear=re.sub("sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus","ilf.and.ifo",anatsear)
12009            anatsear=re.sub("post.thalamic.radiation.optic.rad","post.thalamic.radiation",anatsear)
12010            atlassearch = mydict['tidynames'].str.contains(anatsear)
12011            if atlassearch.sum() == 0:
12012                atlassearch = mydict2['tidynames'].str.contains(anatsear)
12013            if verbose > 0 :
12014                print( " anatsear " + anatsear + " atlassearch " )
12015            if atlassearch.sum() > 0:
12016                whichatlas = mydict[atlassearch]['Atlas'].iloc[0]
12017                oglabelname = mydict[atlassearch]['Label'].iloc[0]
12018                oglabelname=re.sub("_",".",oglabelname)
12019                oglabelname=re.sub(r'\.\.','.',oglabelname)
12020            else:
12021                print(anatsear)
12022                oglabelname='unknown'
12023                whichatlas=None
12024            if verbose > 0:
12025                print("oglabelname " + oglabelname + " whichatlas " + str(whichatlas) )
12026            vals2viz = mysub[col2viz].agg(['min', 'max'])
12027            vals2viz = vals2viz[abs(vals2viz).idxmax()]
12028            myext = None
12029            if anatsear == 'cingulum.hippocampus':
12030                myext = 'JHU_wm'
12031            elif 'dktcortex' in anattoshow[k] or whichatlas == 'desikan-killiany-tourville' or 'dktregions' in anattoshow[k]:
12032                myext = 'dkt_cortex'
12033            elif ('cit168' in anattoshow[k] or whichatlas == 'CIT168') and not 'brainstem' in anattoshow[k] and not is_bst_region(anatsear):
12034                myext = 'cit168lab'
12035            elif 'mtl' in anattoshow[k]:
12036                myext = 'mtl'
12037                oglabelname=re.sub('mtl', '',anatsear)
12038            elif 'cerebellum' in anattoshow[k]:
12039                myext = 'cerebellum'
12040                oglabelname=re.sub('cerebellum', '',anatsear)
12041                oglabelname=re.sub('t1.vo','',oglabelname)
12042                # oglabelname=oglabelname[2:]
12043            elif 'brainstem' in anattoshow[k] or is_bst_region(anatsear):
12044                myext = 'brainstem'
12045            elif any(item in anattoshow[k] for item in ['nbm', 'bf']):
12046                myext = 'bf'
12047                oglabelname=re.sub('bf', '',oglabelname)
12048#                oglabelname=re.sub(r'\.', '_',anatsear)
12049            elif whichatlas == 'johns hopkins white matter':
12050                myext = 'JHU_wm'
12051            elif whichatlas == 'desikan-killiany-tourville':
12052                myext = 'dkt_cortex'
12053            elif whichatlas == 'CIT168':
12054                myext = 'cit168lab'
12055            elif whichatlas == 'BF':
12056                myext = 'bf'
12057                oglabelname=re.sub('bf', '',oglabelname)
12058            elif whichatlas == 'yeo_homotopic':
12059                myext = 'yeo'
12060            if myext is None and verbose > 0 :
12061                if whichatlas is None:
12062                    whichatlas='None'
12063                if anattoshow[k] is None:
12064                    anattoshow[k]='None'
12065                print( "MYEXT " + anattoshow[k] + ' unfound ' + whichatlas )
12066            else:
12067                if verbose > 0 :
12068                    print( "MYEXT " + myext )
12069
12070            if myext == 'cit168lab':
12071                oglabelname=re.sub("cit168","",oglabelname)
12072            
12073            for j in postfix:
12074                if j == "dkt_cortex":
12075                    j = 'dktcortex'
12076                if j == "deep_cit168lab":
12077                    j = 'deep_cit168'
12078                anattoshow[k] = anattoshow[k].replace(j, "")
12079            if verbose > 0:
12080                print( anattoshow[k] + " " + str( vals2viz ) )
12081            correctdescript = postdesc[postfix.index(myext)]
12082            locfilename =  templateprefix + myext + '.nii.gz'
12083            if verbose > 0:
12084                print( locfilename )
12085            if myext == 'yeo':
12086                oglabelname=oglabelname.lower()
12087                oglabelname=re.sub("rsfmri_fcnxpro122_","",oglabelname)
12088                oglabelname=re.sub("rsfmri_fcnxpro129_","",oglabelname)
12089                oglabelname=re.sub("rsfmri_fcnxpro134_","",oglabelname)
12090                oglabelname=re.sub("rsfmri.fcnxpro122.","",oglabelname)
12091                oglabelname=re.sub("rsfmri.fcnxpro129.","",oglabelname)
12092                oglabelname=re.sub("rsfmri.fcnxpro134.","",oglabelname)
12093                oglabelname=re.sub("_",".",oglabelname)
12094                locfilename = "~/.antspymm/ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic.nii.gz"
12095                atlasDescript = pd.read_csv(f"~/.antspymm/{correctdescript}.csv")
12096                atlasDescript.rename(columns={'SystemName': 'Description'}, inplace=True)
12097                atlasDescript.rename(columns={'ROI': 'Label'}, inplace=True)
12098                atlasDescript['Description'] = atlasDescript['Description'].str.lower()
12099            else:
12100                atlasDescript = pd.read_csv(f"~/.antspyt1w/{correctdescript}.csv")
12101                atlasDescript['Description'] = atlasDescript['Description'].str.lower()
12102                atlasDescript['Description'] = atlasDescript['Description'].str.replace(" ", "_")
12103                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_left_", "_")
12104                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_right_", "_")
12105                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_left", "")
12106                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_right", "")
12107                atlasDescript['Description'] = atlasDescript['Description'].str.replace("left_", "")
12108                atlasDescript['Description'] = atlasDescript['Description'].str.replace("right_", "")
12109                atlasDescript['Description'] = atlasDescript['Description'].str.replace("/",".")
12110                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_",".")
12111                atlasDescript['Description'] = atlasDescript['Description'].str.replace(r'[()]', '', regex=True)
12112                atlasDescript['Description'] = atlasDescript['Description'].str.replace(r'\.\.', '.', regex=True)
12113                if myext == 'JHU_wm':
12114                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("-", ".")
12115                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("jhu.icbm.labels.1mm", "")
12116                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("fronto-occipital", "frnt.occ")
12117                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("superior", "sup")
12118                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("fa-", "")
12119                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("-left-", "")
12120                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("-right-", "")
12121                if myext == 'cerebellum':
12122                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("l_", "")
12123                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("r_", "")
12124                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("l.", "")
12125                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("r.", "")
12126                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("_",".")
12127
12128            if verbose > 0:
12129                print( atlasDescript )
12130            oglabelname = oglabelname.lower()
12131            oglabelname = re.sub(" ", "_",oglabelname)
12132            oglabelname = re.sub("_left_", "_",oglabelname)
12133            oglabelname = re.sub("_right_", "_",oglabelname)
12134            oglabelname = re.sub("_left", "",oglabelname)
12135            oglabelname = re.sub("_right", "",oglabelname)
12136            oglabelname = re.sub("t1hier_vol_", "",oglabelname)
12137            oglabelname = re.sub("t1hier_area_", "",oglabelname)
12138            oglabelname = re.sub("t1hier_thk_", "",oglabelname)
12139            oglabelname = re.sub("dktregions", "",oglabelname)
12140            oglabelname = re.sub("dktcortex", "",oglabelname)
12141
12142            oglabelname = re.sub(" ", ".",oglabelname)
12143            oglabelname = re.sub(".left.", ".",oglabelname)
12144            oglabelname = re.sub(".right.", ".",oglabelname)
12145            oglabelname = re.sub(".left", "",oglabelname)
12146            oglabelname = re.sub(".right", "",oglabelname)
12147            oglabelname = re.sub("t1hier.vol.", "",oglabelname)
12148            oglabelname = re.sub("t1hier.area.", "",oglabelname)
12149            oglabelname = re.sub("t1hier.thk.", "",oglabelname)
12150            oglabelname = re.sub("dktregions", "",oglabelname)
12151            oglabelname = re.sub("dktcortex", "",oglabelname)
12152            oglabelname=re.sub("brainstem","",oglabelname)
12153            if myext == 'JHU_wm':
12154                oglabelname = re.sub("dti_mean_fa.", "",oglabelname)
12155                oglabelname = re.sub("dti_mean_md.", "",oglabelname)
12156                oglabelname = re.sub("dti.mean.fa.", "",oglabelname)
12157                oglabelname = re.sub("dti.mean.md.", "",oglabelname)
12158                oglabelname = re.sub(".left.", "",oglabelname)
12159                oglabelname = re.sub(".right.", "",oglabelname)
12160                oglabelname = re.sub(".lravg.", "",oglabelname)
12161                oglabelname = re.sub(".asym.", "",oglabelname)
12162                oglabelname = re.sub(".jhu.icbm.labels.1mm", "",oglabelname)
12163                oglabelname = re.sub("superior", "sup",oglabelname)
12164
12165            if verbose > 0:
12166                print("oglabelname " + oglabelname )
12167
12168            if myext == 'cerebellum':
12169                if not atlasDescript.empty and 'Description' in atlasDescript.columns:
12170                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("l_", "")
12171                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("r_", "")
12172                    oglabelname=re.sub("ravg","",oglabelname)
12173                    oglabelname=re.sub("lavg","",oglabelname)
12174                    whichindex = atlasDescript.index[atlasDescript['Description'] == oglabelname].values
12175                else:
12176                    if atlasDescript.empty:
12177                        print("The DataFrame 'atlasDescript' is empty.")
12178                    if 'Description' not in atlasDescript.columns:
12179                        print("The column 'Description' does not exist in 'atlasDescript'.")
12180            else:
12181                whichindex = atlasDescript.index[atlasDescript['Description'].str.contains(oglabelname)]
12182
12183            if type(whichindex) is np.int64:
12184                labelnums = atlasDescript.loc[whichindex, 'Label']
12185            else:
12186                labelnums = list(atlasDescript.loc[whichindex, 'Label'])
12187
12188            if myext == 'yeo':
12189                parts = re.findall(r'\D+', oglabelname)
12190                oglabelname = [part.replace('_', '') for part in parts if part.replace('_', '')]
12191                oglabelname = [part.replace('.', '') for part in parts if part.replace('.', '')]
12192                filtered_df = atlasDescript[atlasDescript['Description'].isin(oglabelname)]
12193                labelnums = filtered_df['Label'].tolist()
12194
12195            if not isinstance(labelnums, list):
12196                labelnums=[labelnums]
12197            addemiszero = ants.threshold_image(addem, 0, 0)
12198            temp = ants.image_read(locfilename)
12199            temp = ants.mask_image(temp, temp, level=labelnums, binarize=True)
12200            if verbose > 0:
12201                print("DEBUG")
12202                print(  temp.sum() ) 
12203                print( labelnums )
12204            temp[temp == 1] = (vals2viz)
12205            temp[addemiszero == 0] = 0
12206            addem = addem + temp
12207
12208        if verbose > 0:
12209            print('Done Adding')
12210        for axx in axes:
12211            figfn=output_prefix+f"fig{col2viz}ax{axx}_py.jpg"
12212            if crop > 0:
12213                cmask = ants.threshold_image( addem,1e-5, 1e9 ).iMath("MD",crop) + ants.threshold_image( addem,-1e9, -1e-5 ).iMath("MD",crop)
12214                addemC = ants.crop_image( addem, cmask )
12215                edgeimgC = ants.crop_image( edgeimg, cmask )
12216            else:
12217                addemC = addem
12218                edgeimgC = edgeimg
12219            if fixed_overlay_range is not None:
12220                addemC[0:3,0:3,0:3]=fixed_overlay_range[0]
12221                addemC[4:7,4:7,4:7]=fixed_overlay_range[1]
12222                addemC[ addemC <= fixed_overlay_range[0] ] = 0 # fixed_overlay_range[0]
12223                addemC[ addemC >= fixed_overlay_range[1] ] = fixed_overlay_range[1]
12224            ants.plot(edgeimgC, addemC, axis=axx, nslices=nslices, ncol=ncol,       
12225                overlay_cmap=overlay_cmap, resample=False, overlay_alpha=1.0,
12226                filename=figfn, cbar=axx==axes[0], crop=True, black_bg=black_bg )
12227        if verbose > 0:
12228            print(f"{col2viz} done")
12229    if verbose:
12230        print("DONE brain map figures")
12231    return addem
12232
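# A standalone, hypothetical sketch of the underlay/overlay plotting pattern used in the
# loop above, substituting bundled ANTsPy sample data for the study-specific edge image
# and label-value overlay; the output jpg would go to a temporary directory.
# demo_under = ants.image_read(ants.get_ants_data('mni'))
# demo_over = ants.threshold_image(demo_under, 'Otsu', 3)
# demo_fig = os.path.join(tempfile.mkdtemp(), 'brainmap_demo_ax2.jpg')
# ants.plot(demo_under, demo_over, axis=2, nslices=21, ncol=7, overlay_cmap='plasma',
#           overlay_alpha=1.0, crop=True, cbar=True, filename=demo_fig)
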
12233def filter_df(indf, myprefix):
12234    """
12235    Process and filter a pandas DataFrame, removing certain columns, 
12236    filtering based on data types, computing the mean of numeric columns, 
12237    and adding a prefix to column names.
12238
12239    Parameters:
12240    indf (pandas.DataFrame): The input DataFrame to be processed.
12241    myprefix (str): A string prefix to be added to the column names 
12242                    of the processed DataFrame.
12243
12244    Steps:
12245    1. Removes columns with names containing 'Unnamed'.
12246    2. If the DataFrame has no rows, it returns the empty DataFrame.
12247    3. Filters out columns based on the type of the first element, 
12248       keeping those that are of type `object`, `int`, or `float`.
12249    4. Removes columns that are of `object` dtype.
12250    5. Calculates the mean of the remaining columns, skipping NaN values.
12251    6. Adds the specified `myprefix` to the column names.
12252
12253    Returns:
12254    pandas.DataFrame: A transformed DataFrame with a single row containing 
12255                      the mean values of the filtered columns, and with 
12256                      column names prefixed as specified.
12257    """
12258    indf = indf.loc[:, ~indf.columns.str.contains('Unnamed*', na=False, regex=True)]
12259    if indf.shape[0] == 0:
12260        return indf
12261    nums = [isinstance(indf[col].iloc[0], (object, int, float)) for col in indf.columns]
12262    indf = indf.loc[:, nums]
12263    indf = indf.loc[:, indf.dtypes != 'object']
12264    indf = pd.DataFrame(indf.mean(axis=0, skipna=True)).T
12265    indf = indf.add_prefix(myprefix)
12266    return indf
12267
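# A minimal, hypothetical sketch of filter_df: object-dtype columns are dropped, numeric
# columns are averaged into a single row, and the chosen prefix is prepended.
example_wide = pd.DataFrame({'vol_hippocampus': [4100.0, 4250.0],
                             'thk_cingulate': [2.41, 2.39],
                             'filename': ['a.nii.gz', 'b.nii.gz']})
example_summary = filter_df(example_wide, 'T1Hier_')
# example_summary has one row with columns T1Hier_vol_hippocampus and T1Hier_thk_cingulate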
12268
12269def aggregate_antspymm_results(input_csv, subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Processed/ANTsExpArt/", hiervariable='T1wHierarchical', valid_modalities=None, verbose=False ):
12270    """
12271    Aggregate ANTsPyMM results from the specified CSV file and save the aggregated results to a new CSV file.
12272
12273    Parameters:
12274    - input_csv (str): File path of the input CSV file containing ANTsPyMM QC results averaged and with outlier measurements.
12275    - subject_col (str): Name of the column to store subject IDs.
12276    - date_col (str): Name of the column to store date information.
12277    - image_col (str): Name of the column to store image IDs.
12278    - date_column (str): Session value written into the date column and used in the on-disk search path (default 'ses-1').
12279    - base_path (str): Base path for search paths. Defaults to "./Processed/ANTsExpArt/".
12280    - hiervariable (str) : the string variable denoting the Hierarchical output
12281    - valid_modalities (list of str) : modality identifiers to aggregate; if None, replaced by get_valid_modalities('long')
12282    - verbose : boolean
12283
12284    Note:
12285    This function is tested under limited circumstances. Use with caution.
12286
12287    Example usage:
12288    agg_df = aggregate_antspymm_results("qcdfaol.csv", subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Your/Custom/Path/")
12289
12290    Author:
12291    Avants and ChatGPT
12292    """
12293    import pandas as pd
12294    import numpy as np
12295    from glob import glob
12296
12297    def myread_csv(x, cnms):
12298        """
12299        Reads a CSV file and returns a DataFrame excluding specified columns.
12300
12301        Parameters:
12302        - x (str): File path of the input CSV file describing the blind QC output
12303        - cnms (list): List of column names to exclude from the DataFrame.
12304
12305        Returns:
12306        pd.DataFrame: DataFrame with specified columns excluded.
12307        """
12308        df = pd.read_csv(x)
12309        return df.loc[:, ~df.columns.isin(cnms)]
12310
12311    import warnings
12312    # Warning message for untested function
12313    warnings.warn("Warning: This function is not well tested. Use with caution.")
12314
12315    if valid_modalities is None:
12316        valid_modalities = get_valid_modalities('long')
12317
12318    # Read the input CSV file
12319    df = pd.read_csv(input_csv)
12320
12321    # Filter rows where modality is 'T1w'
12322    df = df[df['modality'] == 'T1w']
12323    badnames = get_names_from_data_frame( ['Unnamed'], df )
12324    df=df.drop(badnames, axis=1)
12325
12326    # Add new columns for subject ID, date, and image ID
12327    df[subject_col] = np.nan
12328    df[date_col] = date_column
12329    df[image_col] = np.nan
12330    df = df.astype({subject_col: str, date_col: str, image_col: str })
12331
12332#    if verbose:
12333#        print( df.shape )
12334#        print( df.dtypes )
12335
12336    # prefilter df for data that exists
12337    keep = np.tile( False, df.shape[0] )
12338    for x in range(df.shape[0]):
12339        temp = df['filename'].iloc[x].split("_")
12340        # Generalized search paths
12341        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12342        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12343        if len( hierfn ) > 0:
12344            keep[x]=True
12345
12346    
12347    df=df[keep]
12348    
12349    if verbose:
12350        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12351        print( df.shape )
12352
12353    myct = 0
12354    for x in range( df.shape[0]):
12355        if verbose:
12356            print(f"{x}...")
12357        locind = df.index[x]
12358        temp = df['filename'].iloc[x].split("_")
12359        if verbose:
12360            print( temp )
12361        df.loc[locind, subject_col]=temp[0]
12362        df.loc[locind, date_col]=date_column
12363        df.loc[locind, image_col]=temp[1]
12364
12365        # Generalized search paths
12366        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12367        if verbose:
12368            print(path_template)
12369        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12370        if len( hierfn ) > 0:
12371            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12372            if verbose:
12373                print(hierfn)
12374            hdf = pd.read_csv(hierfn[0])
12375            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12376            hdf=hdf.drop(badnames, axis=1)
12377            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12378            corenames = list(np.array(hdf.columns)[nums])
12379            hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12380            myct = myct + 1
12381            dflist = [hdf]
12382
12383            for mymod in valid_modalities:
12384                t1wfn = sorted(glob( path_template+ "-" + mymod + "-*wide.csv" ) )
12385                if len( t1wfn ) > 0 :
12386                    if verbose:
12387                        print(t1wfn)
12388                    t1df = myread_csv(t1wfn[0], corenames)
12389                    t1df = filter_df( t1df, mymod+'_')
12390                    dflist = dflist + [t1df]
12391                
12392            hdf = pd.concat( dflist, axis=1, ignore_index=False )
12393            if verbose:
12394                print( df.loc[locind,'filename'] )
12395            if myct == 1:
12396                subdf = df.iloc[[x]]
12397                hdf.index = subdf.index.copy()
12398                df = pd.concat( [df,hdf], axis=1, ignore_index=False )
12399            else:
12400                commcols = list(set(hdf.columns).intersection(df.columns))
12401                df.loc[locind, commcols] = hdf.loc[0, commcols]
12402    badnames = get_names_from_data_frame( ['Unnamed'], df )
12403    df=df.drop(badnames, axis=1)
12404    return( df )
12405
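# Hypothetical sketch of the per-subject search performed inside aggregate_antspymm_results:
# the glob below mirrors path_template for base_path "./Processed/ANTsExpArt/", a subject
# whose filename begins with "sub001", session "ses-1" and the default hiervariable.
from glob import glob
example_hier_csvs = glob("./Processed/ANTsExpArt/sub001/ses-1/*/*/*-T1wHierarchical-*wide.csv")
# an empty list means no hierarchical wide csv exists on disk for that subject/session
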
12406def find_most_recent_file(file_list):
12407    """
12408    Finds and returns the most recently modified file from a list of file paths.
12409    
12410    Parameters:
12411    - file_list: A list of strings, where each string is a path to a file.
12412    
12413    Returns:
12414    - The path to the most recently modified file in the list, or None if the list is empty or contains no valid files.
12415    """
12416    # Filter out items that are not files or do not exist
12417    valid_files = [f for f in file_list if os.path.isfile(f)]
12418    
12419    # Check if the filtered list is not empty
12420    if valid_files:
12421        # Find the file with the latest modification time
12422        most_recent_file = max(valid_files, key=os.path.getmtime)
12423        return [most_recent_file]
12424    else:
12425        return None
12426    
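# Minimal sketch of find_most_recent_file with throwaway files in a temporary directory;
# among duplicate candidate csvs it returns the most recently modified one in a one-element list.
_demo_dir = tempfile.mkdtemp()
_demo_old = os.path.join(_demo_dir, 'runA-rsfMRI-wide.csv')
_demo_new = os.path.join(_demo_dir, 'runB-rsfMRI-wide.csv')
Path(_demo_old).touch()
Path(_demo_new).touch()
newest_demo = find_most_recent_file([_demo_old, _demo_new])  # typically [_demo_new]; None if no path exists
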
12427def aggregate_antspymm_results_sdf(
12428    study_df, 
12429    project_col='projectID',
12430    subject_col='subjectID', 
12431    date_col='date', 
12432    image_col='imageID', 
12433    base_path="./", 
12434    hiervariable='T1wHierarchical', 
12435    splitsep='-',
12436    idsep='-',
12437    wild_card_modality_id=False,
12438    second_split=False,
12439    verbose=False ):
12440    """
12441    Aggregate ANTsPyMM results from the specified study data frame and store the aggregated results in a new data frame.  This assumes data is organized on disk 
12442    as follows:  rootdir/projectID/subjectID/date/outputid/imageid/ where 
12443    outputid is modality-specific and created by ANTsPyMM processing.
12444
12445    Parameters:
12446    - study_df (pandas df): pandas data frame, output of generate_mm_dataframe.
12447    - project_col (str): Name of the column that stores the project ID
12448    - subject_col (str): Name of the column to store subject IDs.
12449    - date_col (str): Name of the column to store date information.
12450    - image_col (str): Name of the column to store image IDs.
12451    - base_path (str): Base path for searching for processing outputs of ANTsPyMM.
12452    - hiervariable (str) : the string variable denoting the Hierarchical output
12453    - splitsep (str):  the separator used to split the filename
12454    - idsep (str): the separator used to partition subjectid date and imageid 
12455        for example, if idsep is - then we have subjectid-date-imageid
12456    - wild_card_modality_id (bool): keep this False for safer execution
12457    - second_split (bool): this is a hack that will split the imageID by . and keep the first part of the split; may be needed when the input filenames contain .
12458    - verbose : boolean
12459
12460    Note:
12461    This function is tested under limited circumstances. Use with caution.
12462    One particular gotcha is if the imageID is stored as a numeric value in the dataframe 
12463    but is meant to be a string.  E.g. '000' (string) would be interpreted as 0 in the 
12464    file name glob.  This would miss the extant (on disk) csv.
12465
12466    Example usage:
12467    agg_df = aggregate_antspymm_results_sdf( studydf, subject_col='subjectID', date_col='date', image_col='imageID', base_path="./Your/Custom/Path/")
12468
12469    Author:
12470    Avants and ChatGPT
12471    """
12472    import pandas as pd
12473    import numpy as np
12474    from glob import glob
12475
12476    def progress_reporter(current_step, total_steps, width=50):
12477        # Calculate the proportion of progress
12478        progress = current_step / total_steps
12479        # Calculate the number of 'filled' characters in the progress bar
12480        filled_length = int(width * progress)
12481        # Create the progress bar string
12482        bar = '█' * filled_length + '-' * (width - filled_length)
12483        # Print the progress bar with percentage
12484        print(f'\rProgress: |{bar}| {int(100 * progress)}%', end='\r')
12485        # Print a new line when the progress is complete
12486        if current_step == total_steps:
12487            print()
12488
12489    def myread_csv(x, cnms):
12490        """
12491        Reads a CSV file and returns a DataFrame excluding specified columns.
12492
12493        Parameters:
12494        - x (str): File path of the input CSV file describing the blind QC output
12495        - cnms (list): List of column names to exclude from the DataFrame.
12496
12497        Returns:
12498        pd.DataFrame: DataFrame with specified columns excluded.
12499        """
12500        df = pd.read_csv(x)
12501        return df.loc[:, ~df.columns.isin(cnms)]
12502
12503    import warnings
12504    # Warning message for untested function
12505    warnings.warn("Warning: This function is not well tested. Use with caution.")
12506
12507    vmoddict = {}
12508    # Add key-value pairs
12509    vmoddict['imageID'] = 'T1w'
12510    vmoddict['flairid'] = 'T2Flair'
12511    vmoddict['perfid'] = 'perf'
12512    vmoddict['pet3did'] = 'pet3d'
12513    vmoddict['rsfid1'] = 'rsfMRI'
12514#    vmoddict['rsfid2'] = 'rsfMRI'
12515    vmoddict['dtid1'] = 'DTI'
12516#    vmoddict['dtid2'] = 'DTI'
12517    vmoddict['nmid1'] = 'NM2DMT'
12518#    vmoddict['nmid2'] = 'NM2DMT'
12519
12520    # Filter rows where modality is 'T1w'
12521    df = study_df[ study_df['modality'] == 'T1w']
12522    badnames = get_names_from_data_frame( ['Unnamed'], df )
12523    df=df.drop(badnames, axis=1)
12524    # prefilter df for data that exists
12525    keep = np.tile( False, df.shape[0] )
12526    for x in range(df.shape[0]):
12527        myfn = os.path.basename( df['filename'].iloc[x] )
12528        temp = myfn.split( splitsep )
12529        # Generalized search paths
12530        sid0 = str( temp[1] )
12531        sid = str( df[subject_col].iloc[x] )
12532        if sid0 != sid:
12533            warnings.warn("OUTER: the id derived from the filename " + sid0 + " does not match the id stored in the data frame " + sid )
12534            warnings.warn( "filename is : " +  myfn )
12535            warnings.warn( "sid is : " + sid )
12536            warnings.warn( "x is : " + str(x) )
12537        myproj = str(df[project_col].iloc[x])
12538        mydate = str(df[date_col].iloc[x])
12539        myid = str(df[image_col].iloc[x])
12540        if second_split:
12541            myid = myid.split(".")[0]
12542        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12543        hierfn = sorted(glob( path_template + "*" + hiervariable + "*wide.csv" ) )
12544        if len( hierfn ) == 0:
12545            print( hierfn )
12546            print( path_template )
12547            print( myproj )
12548            print( sid )
12549            print( mydate ) 
12550            print( myid )
12551        if len( hierfn ) > 0:
12552            keep[x]=True
12553
12554    # df=df[keep]
12555    if df.shape[0] == 0:
12556        warnings.warn("input data frame shape is filtered down to zero")
12557        return df
12558
12559    if not df.index.is_unique:
12560        warnings.warn("data frame does not have unique indices.  we therefore reset the index to allow the function to continue on." )
12561        df = df.reset_index()
12562
12563    
12564    if verbose:
12565        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12566        print( df.shape )
12567
12568    dfout = pd.DataFrame()
12569    myct = 0
12570    for x in range( df.shape[0]):
12571        if verbose:
12572            print("\n\n-------------------------------------------------")
12573            print(f"{x}...")
12574        else:
12575            progress_reporter(x, df.shape[0], width=500)
12576        locind = df.index[x]
12577        myfn = os.path.basename( df['filename'].iloc[x] )
12578        sid = str( df[subject_col].iloc[x] )
12579        tempB = myfn.split( splitsep )
12580        sid0 = str(tempB[1])
12581        if sid0 != sid and verbose:
12582            warnings.warn("INNER: the id derived from the filename " + str(sid) + " does not match the id stored in the data frame " + str(sid0) )
12583            warnings.warn( "filename is : " +  str(myfn) )
12584            warnings.warn( "sid is : " + str(sid) )
12585            warnings.warn( "x is : " + str(x) )
12586            warnings.warn( "index is : " + str(locind) )
12587        myproj = str(df[project_col].iloc[x])
12588        mydate = str(df[date_col].iloc[x])
12589        myid = str(df[image_col].iloc[x])
12590        if second_split:
12591            myid = myid.split(".")[0]
12592        if verbose:
12593            print( myfn )
12594            print( tempB )
12595            print( "id " + sid  )
12596        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12597        searchhier = path_template + "*" + hiervariable + "*wide.csv"
12598        if verbose:
12599            print( searchhier )
12600        hierfn = sorted( glob( searchhier ) )
12601        if len( hierfn ) > 1:
12602            raise ValueError("there are " + str( len( hierfn ) ) + " number of hier fns with search path " + searchhier )
12603        if len( hierfn ) == 1:
12604            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12605            if verbose:
12606                print(hierfn)
12607            hdf = pd.read_csv(hierfn[0])
12608            if verbose:
12609                print( hdf['vol_hemisphere_lefthemispheres'] )
12610            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12611            hdf=hdf.drop(badnames, axis=1)
12612            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12613            corenames = list(np.array(hdf.columns)[nums])
12614            # hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12615            hdf = hdf.add_prefix("T1Hier_")
12616            myct = myct + 1
12617            dflist = [hdf]
12618
12619            for mymod in vmoddict.keys():
12620                if verbose:
12621                    print("\n\n************************* " + mymod + " *************************")
12622                modalityclass = vmoddict[ mymod ]
12623                if wild_card_modality_id:
12624                    mymodid = '*'
12625                else:
12626                    mymodid = str( df[mymod].iloc[x] )
12627                    if mymodid.lower() != "nan" and mymodid.lower() != "na":
12628                        mymodid = os.path.basename( mymodid )
12629                        mymodid = os.path.splitext( mymodid )[0]
12630                        mymodid = os.path.splitext( mymodid )[0]
12631                        temp = mymodid.split( idsep )
12632                        mymodid = temp[ len( temp )-1 ]
12633                    else:
12634                        if verbose:
12635                            print("missing")
12636                        continue
12637                if verbose:
12638                    print( "modality id is " + mymodid + " for modality " + modalityclass + ' modality specific subj ' + sid + ' modality specific id is ' + myid + " its date " +  mydate )
12639                modalityclasssearch = modalityclass
12640                if modalityclass in ['rsfMRI','DTI']:
12641                    modalityclasssearch=modalityclass+"*"
12642                path_template_m = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + modalityclasssearch + '/' + mymodid + "/"
12643                modsearch = path_template_m + "*" + modalityclasssearch + "*wide.csv"
12644                if verbose:
12645                    print( modsearch )
12646                t1wfn = sorted( glob( modsearch ) )
12647                if len( t1wfn ) > 1:
12648                    nlarge = len(t1wfn)
12649                    t1wfn = find_most_recent_file( t1wfn )
12650                    warnings.warn("there are " + str( nlarge ) + " number of wide fns with search path " + modsearch + " we take the most recent of these " + t1wfn[0] )
12651                if len( t1wfn ) == 1:
12652                    if verbose:
12653                        print(t1wfn)
12654                    t1df = myread_csv(t1wfn[0], corenames)
12655                    t1df = filter_df( t1df, modalityclass+'_')
12656                    dflist = dflist + [t1df]
12657                else:
12658                    if verbose:
12659                        print( " cannot find " + modsearch )
12660                
12661            hdf = pd.concat( dflist, axis=1, ignore_index=False)
12662            if verbose:
12663                print( "count: " + str( myct ) )
12664            subdf = df.iloc[[x]]
12665            hdf.index = subdf.index.copy()
12666            subdf = pd.concat( [subdf,hdf], axis=1, ignore_index=False)
12667            dfout = pd.concat( [dfout,subdf], axis=0, ignore_index=False )
12668
12669    if dfout.shape[0] > 0:
12670        badnames = get_names_from_data_frame( ['Unnamed'], dfout )
12671        dfout=dfout.drop(badnames, axis=1)
12672    return dfout
12673
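# Hypothetical sketch: read the study dataframe with string dtypes so that imageIDs such
# as '000' are not coerced to integers (see the gotcha in the docstring), then aggregate
# against ANTsPyMM outputs stored under a placeholder base_path.
# studydf = pd.read_csv("study_dataframe.csv", dtype={'imageID': str, 'subjectID': str})
# widedf = aggregate_antspymm_results_sdf(studydf, base_path="/data/antspymm_processed/", verbose=False)
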
12674def enantiomorphic_filling_without_mask( image, axis=0, intensity='low' ):
12675    """
12676    Perform an enantiomorphic lesion filling on an image without a lesion mask.
12677
12678    Args:
12679    image (antsImage): The ants image to flip and fill
12680    axis ( int ): the axis along which to reflect the image
12681    intensity ( str ) : low or high
12682
12683    Returns:
12684    ants.ANTsImage: The image after enantiomorphic filling.
12685    """
12686    imagen = ants.iMath( image, 'Normalize' )
12687    imagen = ants.iMath( imagen, "TruncateIntensity", 1e-6, 0.98 )
12688    imagen = ants.iMath( imagen, 'Normalize' )
12689    # Create a mirror image (flipping left and right)
12690    mirror_image = ants.reflect_image(imagen, axis=axis, tx='antsRegistrationSyNQuickRepro[s]' )['warpedmovout']
12691
12692    # Create a symmetric version of the image by averaging the original and the mirror image
12693    symmetric_image = imagen * 0.5 + mirror_image * 0.5
12694
12695    # Identify potential lesion areas by finding differences between the original and symmetric image
12696    difference_image = image - symmetric_image
12697    diffseg = ants.threshold_image(difference_image, "Otsu", 3 )
12698    if intensity == 'low':
12699        likely_lesion = ants.threshold_image( diffseg, 1,  1)
12700    else:
12701        likely_lesion = ants.threshold_image( diffseg, 3,  3)
12702    likely_lesion = ants.smooth_image( likely_lesion, 3.0 ).iMath("Normalize")
12703    lesionneg = ( imagen*0+1.0 ) - likely_lesion
12704    filled_image = ants.image_clone(imagen)    
12705    filled_image = imagen * lesionneg + mirror_image * likely_lesion
12706
12707    return filled_image, diffseg
12708
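# Hypothetical sketch (placeholder file name): fill likely lesion areas of a T1w image from
# its reflected, registered counterpart; the second return value is the Otsu difference segmentation.
# t1 = ants.image_read("sub-01_T1w.nii.gz")
# t1_filled, lesion_seg = enantiomorphic_filling_without_mask(t1, axis=0, intensity='low')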
12709
12710
12711def filter_image_files(image_paths, criteria='largest'):
12712    """
12713    Filters a list of image file paths based on specified criteria and returns 
12714    the path of the image that best matches that criteria (smallest, largest, or brightest).
12715
12716    Args:
12717    image_paths (list): A list of file paths to the images.
12718    criteria (str): Criteria for selecting the image ('smallest', 'largest', 'brightest').
12719
12720    Returns:
12721    str: The file path of the selected image, or None if no valid images are found.
12722    """
12723    import numpy as np
12724    if not image_paths:
12725        return None
12726
12727    selected_image_path = None
12728    if criteria == 'smallest' or criteria == 'largest':
12729        extreme_volume = None
12730
12731        for path in image_paths:
12732            try:
12733                image = ants.image_read(path)
12734                volume = np.prod(image.shape)
12735
12736                if criteria == 'largest':
12737                    if extreme_volume is None or volume > extreme_volume:
12738                        extreme_volume = volume
12739                        selected_image_path = path
12740                elif criteria == 'smallest':
12741                    if extreme_volume is None or volume < extreme_volume:
12742                        extreme_volume = volume
12743                        selected_image_path = path
12744
12745            except Exception as e:
12746                print(f"Error processing image {path}: {e}")
12747
12748    elif criteria == 'brightest':
12749        max_brightness = None
12750
12751        for path in image_paths:
12752            try:
12753                image = ants.image_read(path)
12754                brightness = np.mean(image.numpy())
12755
12756                if max_brightness is None or brightness > max_brightness:
12757                    max_brightness = brightness
12758                    selected_image_path = path
12759
12760            except Exception as e:
12761                print(f"Error processing image {path}: {e}")
12762
12763    else:
12764        raise ValueError("Criteria must be 'smallest', 'largest', or 'brightest'.")
12765
12766    return selected_image_path
12767
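# Minimal runnable sketch using bundled ANTsPy sample data: pick the candidate image with
# the largest voxel count (criteria may also be 'smallest' or 'brightest').
candidate_paths = [ants.get_ants_data('r16'), ants.get_ants_data('mni')]
largest_path = filter_image_files(candidate_paths, criteria='largest')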
12768
12769
12770def mm_match_by_qc_scoring(df_a, df_b, match_column, criteria, prefix='matched_', exclude_columns=None):
12771    """
12772    Match each row in df_a to a row in df_b based on a matching column and criteria for selecting the best match,
12773    with options to prefix column names from df_b and exclude certain columns from the final output. Additionally,
12774    returns a DataFrame containing rows from df_b that were not matched to any row in df_a.
12775
12776    Parameters:
12777    - df_a: DataFrame A.
12778    - df_b: DataFrame B.
12779    - match_column: The column name on which rows should match between DataFrame A and B.
12780    - criteria: A dictionary where keys are column names and values are 'min' or 'max', indicating whether
12781                the column should be minimized or maximized for the best match.
12782    - prefix: A string prefix to add to column names from df_b in the final output to avoid duplication.
12783    - exclude_columns: A list of column names from df_b to exclude from the final output.
12784    
12785    Returns:
12786    - A tuple of two DataFrames: 
12787        1. A new DataFrame combining df_a with matched rows from df_b.
12788        2. A DataFrame containing rows from df_b that were not matched to df_a.
12789    """
12790    from scipy.stats import zscore
12791    df_a = df_a.loc[:, ~df_a.columns.str.startswith('Unnamed:')].copy()
12792    if df_b is not None:
12793        df_b = df_b.loc[:, ~df_b.columns.str.startswith('Unnamed:')].copy()
12794    else:
12795        return df_a, pd.DataFrame()
12796    
12797    # Normalize df_b based on criteria
12798    for col, crit in criteria.items():
12799        if crit == 'max':
12800            df_b.loc[df_b.index, f'score_{col}'] = zscore(-df_b[col])
12801        elif crit == 'min':
12802            df_b.loc[df_b.index, f'score_{col}'] = zscore(df_b[col])
12803
12804    # Calculate 'best_score' by summing all score columns
12805    score_columns = [f'score_{col}' for col in criteria.keys()]
12806    df_b['best_score'] = df_b[score_columns].sum(axis=1)
12807
12808    matched_indices = []  # Track indices of matched rows in df_b
12809
12810    # Match rows
12811    matched_rows = []
12812    for _, row_a in df_a.iterrows():
12813        matches = df_b[df_b[match_column] == row_a[match_column]]
12814        if not matches.empty:
12815            best_idx = matches['best_score'].idxmin()
12816            best_match = matches.loc[best_idx]
12817            matched_indices.append(best_idx)  # Track this index as matched
12818            matched_rows.append(best_match)
12819        else:
12820            matched_rows.append(pd.Series(dtype='float64'))
12821
12822    # Create a DataFrame from matched rows
12823    df_matched = pd.DataFrame(matched_rows).reset_index(drop=True)
12824    
12825    # Exclude specified columns and add prefix
12826    if exclude_columns is not None:
12827        df_matched = df_matched.drop(columns=exclude_columns, errors='ignore')
12828    df_matched = df_matched.rename(columns=lambda x: f"{prefix}{x}" if x != match_column and x in df_matched.columns else x)
12829
12830    # Combine df_a with matched rows from df_b
12831    result_df = pd.concat([df_a.reset_index(drop=True), df_matched], axis=1)
12832    
12833    # Extract unmatched rows from df_b
12834    unmatched_df_b = df_b.drop(index=matched_indices).reset_index(drop=True)
12835
12836    return result_df, unmatched_df_b
12837
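# Toy, hypothetical sketch of mm_match_by_qc_scoring: each row of the first frame is paired
# with the best-scoring row of the second frame that shares the same subjectIDdate, here
# minimizing ol_loop and maximizing snr.
toy_a = pd.DataFrame({'subjectIDdate': ['s1-2020', 's2-2020'], 'T1_snr': [12.0, 9.0]})
toy_b = pd.DataFrame({'subjectIDdate': ['s1-2020', 's1-2020', 's2-2020'],
                      'snr': [10.0, 14.0, 11.0], 'ol_loop': [0.4, 0.2, 0.3]})
toy_matched, toy_unmatched = mm_match_by_qc_scoring(
    toy_a, toy_b, 'subjectIDdate', {'ol_loop': 'min', 'snr': 'max'}, prefix='T2Flair_')
# toy_matched pairs s1-2020 with the (snr 14.0, ol_loop 0.2) row; toy_unmatched holds the rest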
12838
12839def fix_LR_RL_stuff(df, col1, col2, size_col1, size_col2, id1, id2 ):
12840    df_copy = df.copy()
12841    # Ensure columns contain strings for substring checks
12842    df_copy[col1] = df_copy[col1].astype(str)
12843    df_copy[col2] = df_copy[col2].astype(str)
12844    df_copy[id1] = df_copy[id1].astype(str)
12845    df_copy[id2] = df_copy[id2].astype(str)
12846    
12847    for index, row in df_copy.iterrows():
12848        col1_val = row[col1]
12849        col2_val = row[col2]
12850        size1 = row[size_col1]
12851        size2 = row[size_col2]
12852        
12853        # Check for 'RL' or 'LR' in each column and compare sizes
12854        if ('RL' in col1_val or 'LR' in col1_val) and ('RL' in col2_val or 'LR' in col2_val):
12855            continue
12856        elif 'RL' not in col1_val and 'LR' not in col1_val and 'RL' not in col2_val and 'LR' not in col2_val:
12857            if size1 < size2:
12858                df_copy.at[index, col1] = df_copy.at[index, col2]
12859                df_copy.at[index, size_col1] = df_copy.at[index, size_col2]
12860                df_copy.at[index, id1] = df_copy.at[index, id2]
12861                df_copy.at[index, size_col2] = 0
12862                df_copy.at[index, col2] = None
12863                df_copy.at[index, id2] = None
12864            else:
12865                df_copy.at[index, col2] = None
12866                df_copy.at[index, size_col2] = 0
12867                df_copy.at[index, id2] = None
12868        elif 'RL' in col1_val or 'LR' in col1_val:
12869            if size1 < size2:
12870                df_copy.at[index, col1] = df_copy.at[index, col2]
12871                df_copy.at[index, id1] = df_copy.at[index, id2]
12872                df_copy.at[index, size_col1] = df_copy.at[index, size_col2]
12873                df_copy.at[index, size_col2] = 0
12874                df_copy.at[index, col2] = None
12875                df_copy.at[index, id2] = None
12876            else:
12877                df_copy.at[index, col2] = None
12878                df_copy.at[index, id2] = None
12879                df_copy.at[index, size_col2] = 0
12880        elif 'RL' in col2_val or 'LR' in col2_val:
12881            if size2 < size1:
12882                df_copy.at[index, id2] = None
12883                df_copy.at[index, col2] = None
12884                df_copy.at[index, size_col2] = 0
12885            else:
12886                df_copy.at[index, col1] = df_copy.at[index, col2]
12887                df_copy.at[index, id1] = df_copy.at[index, id2]
12888                df_copy.at[index, size_col1] = df_copy.at[index, size_col2]
12889                df_copy.at[index, size_col2] = 0
12890                df_copy.at[index, col2] = None    
12891                df_copy.at[index, id2] = None    
12892    return df_copy
12893
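# Toy, hypothetical sketch of fix_LR_RL_stuff: when neither filename encodes 'LR'/'RL',
# the smaller acquisition is discarded and the larger one is promoted into the first slot.
toy_rsf = pd.DataFrame({'rsf1_filename': ['sub-01_bold.nii.gz'],
                        'rsf2_filename': ['sub-01_run-2_bold.nii.gz'],
                        'rsf1_dimt': [100], 'rsf2_dimt': [300],
                        'rsf1_imageID': ['001'], 'rsf2_imageID': ['002']})
toy_fixed = fix_LR_RL_stuff(toy_rsf, 'rsf1_filename', 'rsf2_filename',
                            'rsf1_dimt', 'rsf2_dimt', 'rsf1_imageID', 'rsf2_imageID')
# toy_fixed keeps the 300-volume run in the rsf1_* columns and blanks the rsf2_* columns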
12894
12895def renameit(df, old_col_name, new_col_name):
12896    """
12897    Renames a column in a pandas DataFrame in place. Issues a warning (and leaves the DataFrame unchanged) if the specified old column name does not exist.
12898
12899    Parameters:
12900    - df: pandas.DataFrame
12901        The DataFrame in which the column is to be renamed.
12902    - old_col_name: str
12903        The current name of the column to be renamed.
12904    - new_col_name: str
12905        The new name for the column.
12906    
12907    Warns:
12908    - UserWarning: If the old column name does not exist in the DataFrame; the rename is skipped.
12909    
12910    Returns:
12911    None
12912    """
12913    import warnings
12914    # Check if the old column name exists in the DataFrame
12915    if old_col_name not in df.columns:
12916        warnings.warn(f"The column '{old_col_name}' does not exist in the DataFrame.")
12917        return
12918    
12919    # Proceed with renaming the column if it exists
12920    df.rename(columns={old_col_name: new_col_name}, inplace=True)
12921
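# Minimal sketch of renameit: in-place rename with a warning (rather than an exception)
# when the requested column is missing.
toy_cols = pd.DataFrame({'rsf1_imageID': ['0001']})
renameit(toy_cols, 'rsf1_imageID', 'rsfid1')      # toy_cols.columns is now ['rsfid1']
# renameit(toy_cols, 'not_a_column', 'anything')  # would only warn and leave toy_cols unchanged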
12922
12923def mm_match_by_qc_scoring_all( qc_dataframe, fix_LRRL=True, mysep='-', verbose=True ):
12924    """
12925    Processes a quality control (QC) DataFrame to perform modality-specific matching and filtering based
12926    on predefined criteria, optimizing for minimal outliers and noise, and maximal signal-to-noise ratio (SNR),
12927    expected value of randomness (EVR), and dimensionality time (dimt).
12928
12929    This function iteratively matches dataframes derived from the QC dataframe for different imaging modalities,
12930    applying a series of filters to select the best matches based on the QC metrics. Matches are made with
12931    consideration to minimize the LoOP outlier score (ol_loop) and noise, while maximizing SNR, EVR, and dimt for each modality.
12932
12933    Parameters:
12934    ----------
12935    qc_dataframe : pandas.DataFrame
12936        The DataFrame containing QC metrics for different modalities and imaging data.
12937    fix_LRRL : bool, optional ; if True, reconcile LR/RL phase-encoded duplicate runs for the rsfMRI columns (see fix_LR_RL_stuff)
12938    mysep : string ; a character such as - or _ , the usual antspymm separator argument
12939
12940    verbose : bool, optional
12941        If True, prints the progress and the shape of the DataFrame being processed in each step.
12942
12943    Process:
12944    -------
12945    1. Standardizes modalities by merging DTI-related entries.
12946    2. Converts specific columns to appropriate data types for processing.
12947    3. Performs modality-specific matching and filtering based on the outlier column and criteria for each modality.
12948    4. Iteratively processes unmatched data for predefined modalities with specific prefixes to find further matches.
12949    
12950    Returns:
12951    -------
12952    pandas.DataFrame
12953        The matched and filtered DataFrame after applying all QC scoring and matching operations across specified modalities.
12954
12955    """
12956    qc_dataframe=remove_unwanted_columns( qc_dataframe )
12957    qc_dataframe['modality'] = qc_dataframe['modality'].replace(['DTIdwi', 'DTIb0'], 'DTI', regex=True)
12958    qc_dataframe['filename']=qc_dataframe['filename'].astype(str)
12959    qc_dataframe['ol_loop']=qc_dataframe['ol_loop'].astype(float)
12960    qc_dataframe['ol_lof']=qc_dataframe['ol_lof'].astype(float)
12961    qc_dataframe['ol_lof_decision']=qc_dataframe['ol_lof_decision'].astype(float)
12962    outlier_column='ol_loop'
12963    mmdf0 = best_mmm( qc_dataframe, 'T1w', outlier_column=outlier_column, mysep=mysep )['filt']
12964    fldf = best_mmm( qc_dataframe, 'T2Flair', outlier_column=outlier_column, mysep=mysep  )['filt']
12965    nmdf = best_mmm( qc_dataframe, 'NM2DMT', outlier_column=outlier_column, mysep=mysep  )['filt']
12966    rsdf = best_mmm( qc_dataframe, 'rsfMRI', outlier_column=outlier_column, mysep=mysep  )['filt']
12967    dtdf = best_mmm( qc_dataframe, 'DTI', outlier_column=outlier_column, mysep=mysep  )['filt']
12968    pfdf = best_mmm( qc_dataframe, 'perf', outlier_column=outlier_column, mysep=mysep  )['filt']
12969
12970    criteria = {'ol_loop': 'min', 'noise': 'min', 'snr': 'max', 'EVR': 'max', 'reflection_err':'min'}
12971    xcl = [ 'mrimfg', 'mrimodel','mriMagneticFieldStrength', 'dti_failed', 'rsf_failed', 'subjectID', 'date', 'subjectIDdate','repeat']
12972    # Assuming df_a and df_b are already loaded
12973    mmdf, undffl = mm_match_by_qc_scoring(mmdf0, fldf, 'subjectIDdate', criteria, 
12974                        prefix='T2Flair_', exclude_columns=xcl )
12975
12976    mmdf, undfpf = mm_match_by_qc_scoring(mmdf, pfdf, 'subjectIDdate', criteria, 
12977                        prefix='perf_', exclude_columns=xcl )
12978
12979    prefixes = ['NM1_', 'NM2_', 'NM3_', 'NM4_', 'NM5_', 'NM6_']  
12980    undfmod = nmdf  # Initialize 'undfmod' with 'nmdf' for the first iteration
12981    if undfmod is not None:
12982        if verbose:
12983            print('start NM')
12984            print( undfmod.shape )
12985        for prefix in prefixes:
12986            if undfmod.shape[0] > 50:
12987                mmdf, undfmod = mm_match_by_qc_scoring(mmdf, undfmod, 'subjectIDdate', criteria, prefix=prefix, exclude_columns=xcl)
12988                if verbose:
12989                    print( prefix )
12990                    print( undfmod.shape )
12991
12992    criteria = {'ol_loop': 'min', 'noise': 'min', 'snr': 'max', 'EVR': 'max', 'dimt':'max'}
12993    # higher bvalues lead to more noise ...
12994    criteria = {'ol_loop': 'min', 'noise': 'min',  'dti_bvalueMax':'min',  'dimt':'max'}
12995    prefixes = ['DTI1_', 'DTI2_', 'DTI3_']  # List of prefixes for each matching iteration
12996    undfmod = dtdf
12997    if undfmod is not None:
12998        if verbose:
12999            print('start DT')
13000            print( undfmod.shape )
13001        for prefix in prefixes:
13002            if undfmod.shape[0] > 50:
13003                mmdf, undfmod = mm_match_by_qc_scoring(mmdf, undfmod, 'subjectIDdate', criteria, prefix=prefix, exclude_columns=xcl)
13004                if verbose:
13005                    print( prefix )
13006                    print( undfmod.shape )
13007
13008    prefixes = ['rsf1_', 'rsf2_', 'rsf3_']  # List of prefixes for each matching iteration
13009    undfmod = rsdf  # Initialize 'undfmod' with 'rsdf' for the first iteration
13010    if undfmod is not None:
13011        if verbose:
13012            print('start rsf')
13013            print( undfmod.shape )
13014        for prefix in prefixes:
13015            if undfmod.shape[0] > 50:
13016                mmdf, undfmod = mm_match_by_qc_scoring(mmdf, undfmod, 'subjectIDdate', criteria, prefix=prefix, exclude_columns=xcl)
13017                if verbose:
13018                    print( prefix )
13019                    print( undfmod.shape )
13020    
13021    if fix_LRRL:
13022        #        mmdf=fix_LR_RL_stuff( mmdf, 'DTI1_filename', 'DTI2_filename', 'DTI1_dimt', 'DTI2_dimt')
13023        mmdf=fix_LR_RL_stuff( mmdf, 'rsf1_filename', 'rsf2_filename', 'rsf1_dimt', 'rsf2_dimt', 'rsf1_imageID', 'rsf2_imageID'  )
13024    else:
13025        import warnings
13026        warnings.warn("FIXME: should fix LR and RL situation for the DTI and rsfMRI")
13027
13028    # now do the necessary replacements
13029    
13030    renameit( mmdf, 'perf_imageID', 'perfid' )
13031    renameit( mmdf, 'perf_filename', 'perffn' )
13032    renameit( mmdf, 'T2Flair_imageID', 'flairid' )
13033    renameit( mmdf, 'T2Flair_filename', 'flairfn' )
13034    renameit( mmdf, 'rsf1_imageID', 'rsfid1' )
13035    renameit( mmdf, 'rsf2_imageID', 'rsfid2' )
13036    renameit( mmdf, 'rsf1_filename', 'rsffn1' )
13037    renameit( mmdf, 'rsf2_filename', 'rsffn2' )
13038    renameit( mmdf, 'DTI1_imageID', 'dtid1' )
13039    renameit( mmdf, 'DTI2_imageID', 'dtid2' )
13040    renameit( mmdf, 'DTI3_imageID', 'dtid3' )
13041    renameit( mmdf, 'DTI1_filename', 'dtfn1' )
13042    renameit( mmdf, 'DTI2_filename', 'dtfn2' )
13043    renameit( mmdf, 'DTI3_filename', 'dtfn3' )
13044    for x in range(1,6):
13045        temp0="NM"+str(x)+"_imageID"
13046        temp1="nmid"+str(x)
13047        renameit( mmdf, temp0, temp1 )
13048        temp0="NM"+str(x)+"_filename"
13049        temp1="nmfn"+str(x)
13050        renameit( mmdf, temp0, temp1 )
13051    return mmdf
13052
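# Hypothetical end-to-end sketch: starting from an averaged blind-QC dataframe with outlier
# columns (e.g. built via average_blind_qc_by_modality plus the novelty_detection_* helpers),
# produce the matched multi-modality dataframe used downstream.
# qcavg = pd.read_csv("blind_qc_averaged_with_outliers.csv")
# matched = mm_match_by_qc_scoring_all(qcavg, fix_LRRL=True, mysep='-', verbose=False)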
13053
13054def t1w_super_resolution_with_hemispheres(
13055    t1img,
13056    model,
13057    dilation_amount=8,
13058    truncation=[0.001, 0.999],
13059    target_range=[0, 1],
13060    poly_order="hist",
13061    min_spacing=0.8,
13062    verbose=True
13063):
13064    """
13065    Perform hemisphere-aware super-resolution on a T1-weighted image using a segmentation-guided DBPN model.
13066
13067    This function performs brain extraction, hemisphere labeling, and segmentation-aware
13068    super-resolution using the provided T1 image and model. If the resolution is sufficient,
13069    hemisphere labels guide targeted SR via the `siq.inference` function.
13070
13071    Parameters
13072    ----------
13073    t1img : ANTsImage
13074        Input T1-weighted image.
13075
13076    model : keras.Model
13077        Super-resolution model (e.g., from `siq.default_dbpn` or loaded from `.keras` file).
13078
13079    dilation_amount : int
13080        Amount of dilation to apply around labeled regions before SR.
13081
13082    truncation : list of float
13083        Percentile values used to truncate intensity before model inference.
13084
13085    target_range : list of float
13086        Range to normalize input intensities for model input.
13087
13088    poly_order : str or int
13089        Polynomial order or "hist" for histogram matching after SR.
13090
13091    min_spacing : float 
13092        if the minimum input image spacing is less than this value, 
13093        the function will return the original image.  Default 0.8.
13094
13095    verbose : bool
13096        If True, print progress updates.
13097
13098    Returns
13099    -------
13100    ANTsImage
13101        Super-resolved T1-weighted image.
13102    """
13103    if np.min(ants.get_spacing(t1img)) < min_spacing:
13104        if verbose:
13105            print("Image resolution too high — skipping SR.")
13106        return t1img
13107
13108    if verbose:
13109        print("Performing brain extraction...")
13110    brain_mask = antspyt1w.brain_extraction(t1img)
13111    brain = t1img * brain_mask
13112
13113    if verbose:
13114        print("Begin template loading")
13115    tlrfn = antspyt1w.get_data('T_template0_LR', target_extension='.nii.gz')
13116    tfn = antspyt1w.get_data('T_template0', target_extension='.nii.gz')
13117    template = ants.image_read(tfn)
13118    template = (template * antspynet.brain_extraction(template, 't1')).iMath("Normalize")
13119    template_lr = ants.image_read(tlrfn)
13120    if verbose:
13121        print("Done template loading")
13122
13123    if verbose:
13124        print("Labeling hemispheres...")
13125    hemi_seg = antspyt1w.label_hemispheres(brain, template, template_lr)
13126
13127    # Combine segmentation and brain mask — label values 1, 2 (hemi) → 3, 4
13128    hemisphere_mask = hemi_seg + 2.0 * brain_mask
13129
13130    if verbose:
13131        print("Starting segmentation-aware super-resolution...")
13132    sr_result = siq.inference(
13133        t1img,
13134        model,
13135        segmentation=hemisphere_mask,
13136        truncation=truncation,
13137        target_range=target_range,
13138        dilation_amount=dilation_amount,
13139        poly_order=poly_order,
13140        verbose=verbose
13141    )
13142
13143    sr_image = sr_result['super_resolution'] if isinstance(sr_result, dict) else sr_result
13144
13145    if verbose:
13146        print("Done super-resolution.")
13147    return sr_image
13148
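# Hypothetical sketch: load a trained DBPN super-resolution model (placeholder file name)
# and apply hemisphere-aware SR to a low-resolution T1w image.
# import tensorflow as tf
# sr_model = tf.keras.models.load_model("my_dbpn_model.keras", compile=False)
# t1_lowres = ants.image_read("sub-01_T1w_lowres.nii.gz")
# t1_sr = t1w_super_resolution_with_hemispheres(t1_lowres, sr_model)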
13149
13150def map_idps_to_rois(
13151    idp_data_frame: pd.DataFrame,
13152    roi_image: ants.ANTsImage,
13153    idp_column: str,
13154    map_type: str = 'average'
13155) -> ants.ANTsImage:
13156    """
13157    Produces a new ANTsImage where each ROI is assigned a value based on IDP data
13158    from a DataFrame. ROIs are identified by integer labels in `roi_image`
13159    and values are linked via `idp_data_frame`.
13160
13161    Assumes `idp_data_frame` contains both 'Label' (integer ROI ID) and
13162    'Description' (string description for the ROI, e.g., 'left caudal anterior cingulate')
13163    columns, in addition to the specified `idp_column`.
13164
13165    Parameters:
13166    - idp_data_frame (pd.DataFrame): DataFrame containing IDP measurements.
13167      Must have 'Label', 'Description' (for hemisphere parsing), and `idp_column`.
13168    - roi_image (ants.ANTsImage): An ANTsImage where each voxel contains an integer
13169      label identifying an ROI.
13170    - idp_column (str): The name of the column in `idp_data_frame` whose values
13171      are to be mapped to the ROIs (e.g., 'VolumeInMillimeters').
13172    - map_type (str): Type of mapping to perform.
13173      - 'average': For identified paired left/right ROIs, their `idp_column` values are
13174                   averaged and this average is assigned to both the left and right
13175                   hemisphere ROIs in the output image. If only one side of a pair
13176                   is found, its raw value is used.
13177      - 'asymmetry': For identified paired left/right ROIs, the (Left - Right)
13178                     difference for `idp_column` is calculated and assigned only
13179                     to the left hemisphere ROI. Right hemisphere ROIs that are
13180                     part of a pair, and any unpaired ROIs, will be set to 0 in
13181                     the output image.
13182      - 'raw': Each ROI's original value from `idp_column` is mapped directly to
13183               its corresponding ROI in the output image.
13184      Default is 'average'.
13185
13186    Returns:
13187    - ants.ANTsImage: A new ANTsImage with the same header (origin, spacing,
13188                      direction, etc.) as `roi_image`, where ROI voxels are filled
13189                      with the mapped IDP values. Voxels not part of any described
13190                      ROI, or unmatched based on `map_type`, will be 0.
13191
13192    Raises:
13193    - ValueError: If required columns (`Label`, `Description`, `idp_column`) are missing
13194                  from `idp_data_frame`, `roi_image` is not an ANTsImage,
13195                  or `map_type` is invalid.
13196    """
13197    import logging
13198
13199    logging.info(f"Starting map_idps_to_rois (map_type='{map_type}', IDP column='{idp_column}')")
13200
13201    # --- 1. Input Validation ---
13202    required_idp_cols = ['Label', 'Description', idp_column]
13203    if not all(col in idp_data_frame.columns for col in required_idp_cols):
13204        raise ValueError(f"idp_data_frame must contain columns: {', '.join(required_idp_cols)}")
13205
13206    if not isinstance(roi_image, ants.ANTsImage):
13207        raise ValueError("roi_image must be an ants.ANTsImage object.")
13208
13209    valid_map_types = ['average', 'asymmetry', 'raw']
13210    if map_type not in valid_map_types:
13211        raise ValueError(f"Invalid map_type: '{map_type}'. Must be one of {valid_map_types}.")
13212
13213    # --- 2. Prepare Data (use idp_data_frame directly) ---
13214    # Select only the necessary columns from the input idp_data_frame
13215    processed_idp_df = idp_data_frame[['Label', 'Description', idp_column]].copy()
13216    processed_idp_df.rename(columns={idp_column: 'Value'}, inplace=True)
13217
13218    # Ensure 'Label' column is numeric and drop rows where conversion fails
13219    processed_idp_df['Label'] = pd.to_numeric(processed_idp_df['Label'], errors='coerce')
13220    processed_idp_df = processed_idp_df.dropna(subset=['Label']) # Drop rows where Label is NaN after coercion
13221    processed_idp_df['Label'] = processed_idp_df['Label'].astype(int) # Convert to integer labels
13222
13223    logging.info(f"Processed IDP data contains {len(processed_idp_df)} entries. "
13224                 f"{processed_idp_df['Value'].isnull().sum()} entries have no valid IDP value (NaN).")
13225
13226    # --- 3. Identify Hemispheres and Base Regions ---
13227    # This helper function parses the ROI description to determine its hemisphere
13228    # and a common base name for pairing (e.g., 'caudal anterior cingulate').
13229    def get_hemisphere_and_base(description):
13230        desc = str(description).strip()
13231
13232        # Pattern 1: FreeSurfer-like (e.g., "left caudal anterior cingulate")
13233        match_fs = re.match(r"^(left|right)\s(.+)$", desc, re.IGNORECASE)
13234        if match_fs:
13235            return match_fs.group(1).capitalize(), match_fs.group(2).strip()
13236
13237        # Pattern 2: BN_STR-like (e.g., "BN_STR_Pu_Left")
13238        match_bn = re.match(r"(.+)_(Left|Right)$", desc, re.IGNORECASE)
13239        if match_bn:
13240            return match_bn.group(2).capitalize(), match_bn.group(1).strip()
13241
13242        # No clear hemisphere identified (e.g., 'corpus callosum')
13243        return 'Unknown', desc
13244
13245    processed_idp_df[['Hemisphere', 'BaseRegion']] = processed_idp_df['Description'].apply(
13246        lambda x: pd.Series(get_hemisphere_and_base(x))
13247    )
13248
13249    # Dictionary to store the final computed value for each ROI Label
13250    label_value_map = {}
13251
13252    # --- 4. Process Values based on map_type ---
13253    if map_type == 'raw':
13254        logging.info("Mapping raw IDP values to ROIs.")
13255        # Directly map values where available
13256        for _, row in processed_idp_df.dropna(subset=['Value']).iterrows():
13257            label_value_map[row['Label']] = row['Value']
13258
13259    else: # 'average' or 'asymmetry' types which require pairing logic
13260        # Group by the 'BaseRegion' to find potential left/right pairs
13261        grouped = processed_idp_df.groupby('BaseRegion')
13262
13263        for base_region, group_df in grouped:
13264            left_roi_data = group_df[group_df['Hemisphere'] == 'Left'].dropna(subset=['Value'])
13265            right_roi_data = group_df[group_df['Hemisphere'] == 'Right'].dropna(subset=['Value'])
13266
13267            # Handle ROIs that are not clearly left/right (e.g., 'Bilateral' or 'Unknown')
13268            # For these, we include their raw value regardless of map_type.
13269            other_rois = group_df[ (group_df['Hemisphere'] != 'Left') & (group_df['Hemisphere'] != 'Right') ].dropna(subset=['Value'])
13270            for _, row in other_rois.iterrows():
13271                label_value_map[row['Label']] = row['Value']
13272                logging.debug(f"ROI: '{base_region}' (Label: {row['Label']}) - Not a pair, mapped raw value: {row['Value']:.2f}")
13273
13274            # Process paired regions if both left and right data are available
13275            if not left_roi_data.empty and not right_roi_data.empty:
13276                l_label = left_roi_data['Label'].iloc[0]
13277                l_value = left_roi_data['Value'].iloc[0]
13278                r_label = right_roi_data['Label'].iloc[0]
13279                r_value = right_roi_data['Value'].iloc[0]
13280
13281                if map_type == 'average':
13282                    avg_val = (l_value + r_value) / 2
13283                    label_value_map[l_label] = avg_val
13284                    label_value_map[r_label] = avg_val
13285                    logging.debug(f"ROI: '{base_region}' - Paired AVG: {avg_val:.2f} (L:{l_value:.2f}, R:{r_value:.2f})")
13286                elif map_type == 'asymmetry':
13287                    asym_val = l_value - r_value
13288                    label_value_map[l_label] = asym_val
13289                    # Right ROI is not assigned a value based on asymmetry, so it retains 0
13290                    logging.debug(f"ROI: '{base_region}' - Paired ASYM (L-R): {asym_val:.2f} (L:{l_value:.2f}, R:{r_value:.2f})")
13291            else:
13292                # If only one side of a pair (or neither) is found with a valid value
13293                if map_type == 'average':
13294                    if not left_roi_data.empty:
13295                        label_value_map[left_roi_data['Label'].iloc[0]] = left_roi_data['Value'].iloc[0]
13296                        logging.debug(f"ROI: '{base_region}' - L-only AVG (raw): {left_roi_data['Value'].iloc[0]:.2f}")
13297                    if not right_roi_data.empty:
13298                        label_value_map[right_roi_data['Label'].iloc[0]] = right_roi_data['Value'].iloc[0]
13299                        logging.debug(f"ROI: '{base_region}' - R-only AVG (raw): {right_roi_data['Value'].iloc[0]:.2f}")
13300                elif map_type == 'asymmetry':
13301                    # If only one hemisphere's data is available, asymmetry cannot be computed.
13302                    # For asymmetry map_type, any unpaired ROI (including left) gets 0.
13303                    if not left_roi_data.empty:
13304                        label_value_map[left_roi_data['Label'].iloc[0]] = 0.0
13305                        logging.debug(f"ROI: '{base_region}' - L-only ASYM: Set to 0.0 (no pair for computation).")
13306                    if not right_roi_data.empty:
13307                        logging.debug(f"ROI: '{base_region}' - R-only ASYM: Not assigned (relevant for left hemisphere only).")
13308
13309    # --- 5. Populate Output Image Array ---
13310    # Initialize the output NumPy array with zeros, using the robust float32 type
13311    output_numpy = np.zeros(roi_image.shape, dtype=np.float32)
13312    # Get the input ROI image's data as a NumPy array for fast lookups
13313    roi_image_numpy = roi_image.numpy()
13314
13315    total_voxels_mapped = 0
13316    unique_labels_mapped_in_image = set() # Track unique labels actually processed in the image
13317
13318    # Iterate through the `label_value_map` to assign values to the output image
13319    for label_id, value in label_value_map.items():
13320        # Only map if the value is not NaN (means it had data from IDP, valid conversion, etc.)
13321        if not np.isnan(value):
13322            # Find all voxels in `roi_image_numpy` that match the current `label_id`
13323            matching_indices = np.where(roi_image_numpy == int(label_id))
13324
13325            if matching_indices[0].size > 0: # Check if this label actually exists in roi_image
13326                output_numpy[matching_indices] = value
13327                total_voxels_mapped += matching_indices[0].size
13328                unique_labels_mapped_in_image.add(label_id)
13329
13330    logging.info(f"Mapped values for {len(unique_labels_mapped_in_image)} unique ROI labels found in `roi_image`, affecting {total_voxels_mapped} voxels.")
13331    logging.info("Unmapped ROIs in `roi_image` (not present in `idp_data_frame` or outside processing scope) retain value of 0.")
13332
13333    # --- 6. Create ANTsImage Output ---
13334    # Construct the final ANTsImage from the populated NumPy array,
13335    # preserving the spatial header information from the original `roi_image`.
13336    output_image = ants.from_numpy(
13337        output_numpy,
13338        origin=roi_image.origin,
13339        spacing=roi_image.spacing,
13340        direction=roi_image.direction
13341    )
13342
13343    logging.info("map_idps_to_rois completed successfully.")
13344    return output_image
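
# Toy, hypothetical sketch of map_idps_to_rois: two made-up hippocampal volumes are mapped
# onto a tiny synthetic label image; with map_type='average' both labels receive the mean value.
toy_rois = ants.from_numpy(np.array([[[1, 2], [0, 0]]], dtype='float32'))
toy_idps = pd.DataFrame({'Label': [1, 2],
                         'Description': ['left hippocampus', 'right hippocampus'],
                         'VolumeInMillimeters': [4100.0, 4300.0]})
toy_idp_map = map_idps_to_rois(toy_idps, toy_rois, 'VolumeInMillimeters', map_type='average')
# voxels labelled 1 and 2 both receive 4200.0; background voxels remain 0
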
145def version( ):
146    """
147    report versions of this package and primary dependencies
148
149    Arguments
150    ---------
151    None
152
153    Returns
154    -------
155    a dictionary with package name and versions
156
157    Example
158    -------
159    >>> import antspymm
160    >>> antspymm.version()
161    """
162    import pkg_resources
163    return {
164              'tensorflow': pkg_resources.get_distribution("tensorflow").version,
165              'antspyx': pkg_resources.get_distribution("antspyx").version,
166              'antspynet': pkg_resources.get_distribution("antspynet").version,
167              'antspyt1w': pkg_resources.get_distribution("antspyt1w").version,
168              'antspymm': pkg_resources.get_distribution("antspymm").version
169              }

report versions of this package and primary dependencies

Arguments

None

Returns

a dictionary with package name and versions

Example

>>> import antspymm
>>> antspymm.version()
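
For instance, a caller can use the returned dictionary to log the dependency stack before a long run (a minimal sketch; antspymm and its dependencies must already be installed):

import antspymm

# Print the version of antspymm and each primary dependency it reports.
for pkg, ver in antspymm.version().items():
    print(pkg + ': ' + ver)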
def mm_read(x, standardize_intensity=False, modality=''):
1671def mm_read( x, standardize_intensity=False, modality='' ):
1672    """
1673    read an image from a filename - same as ants.image_read (for now)
1674
1675    standardize_intensity : boolean ; if True will set negative values to zero and normalize into the range of zero to one
1676
1677    modality : not used
1678    """
1679    if x is None:
1680        raise ValueError( " None passed to function antspymm.mm_read." )
1681    if not isinstance(x,str):
1682        raise ValueError( " Non-string passed to function antspymm.mm_read." )
1683    if not os.path.exists( x ):
1684        raise ValueError( " file " + x + " does not exist." )
1685    img = ants.image_read( x, reorient=False )
1686    if standardize_intensity:
1687        img[img<0.0]=0.0
1688        img=ants.iMath(img,'Normalize')
1689    if modality == "T1w" and img.dimension == 4:
1690        print("WARNING: input image is 4D - we attempt a hack fix that works in some odd cases of PPMI data - please check this image: " + x, flush=True )
1691        i1=ants.slice_image(img,3,0)
1692        i2=ants.slice_image(img,3,1)
1693        kk=np.concatenate( [i1.numpy(),i2.numpy()], axis=2 )
1694        kk=ants.from_numpy(kk)
1695        img=ants.copy_image_info(i1,kk)
1696    return img

read an image from a filename - same as ants.image_read (for now)

standardize_intensity : boolean ; if True will set negative values to zero and normalize into the range of zero to one

modality : not used
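
A minimal usage sketch (the NIfTI path below is hypothetical):

import antspymm

# mm_read raises ValueError for None, non-string, or missing paths.
t1 = antspymm.mm_read('/data/PPMI-0001-20250101-T1w-123456.nii.gz',
                      standardize_intensity=True)
# With standardize_intensity=True, negatives are zeroed and intensities rescaled to [0, 1].
print(t1.min(), t1.max())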

def mm_read_to_3d(x, slice=None, modality=''):
1698def mm_read_to_3d( x, slice=None, modality='' ):
1699    """
1700    read an image from a filename - and return as 3d or None if that is not possible
1701    """
1702    img = ants.image_read( x, reorient=False )
1703    if img.dimension <= 3:
1704        return img
1705    elif img.dimension == 4:
1706        nslices = img.shape[3]
1707        if slice is None:
1708            sl = np.round( nslices * 0.5 )
1709        else:
1710            sl = slice
1711        if sl >= nslices:
1712            sl = nslices-1
1713        return ants.slice_image( img, axis=3, idx=int(sl) )
1714    elif img.dimension > 4:
1715        return img
1716    return None

read an image from a filename - and return as 3d or None if that is not possible
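
For a 4D series, the middle volume is returned unless an explicit index is given (path hypothetical):

import antspymm

bold = '/data/PPMI-0001-20250101-rsfMRI-123456.nii.gz'   # hypothetical 4D acquisition
mid_volume = antspymm.mm_read_to_3d(bold)                 # middle volume along axis 3
first_volume = antspymm.mm_read_to_3d(bold, slice=0)      # or request a specific volume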

def image_write_with_thumbnail(x, fn, y=None, thumb=True):
1754def image_write_with_thumbnail( x,  fn, y=None, thumb=True ):
1755    """
1756    will write the image and (optionally) a png thumbnail with (optional) overlay/underlay
1757    """
1758    ants.image_write( x, fn )
1759    if not thumb or x.components > 1:
1760        return
1761    thumb_fn=re.sub(".nii.gz","_3dthumb.png",fn)
1762    if thumb and x.dimension == 3:
1763        if y is None:
1764            try:
1765                ants.plot_ortho( x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1766            except:
1767                pass
1768        else:
1769            try:
1770                ants.plot_ortho( y, x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1771            except:
1772                pass
1773    if thumb and x.dimension == 4:
1774        thumb_fn=re.sub(".nii.gz","_4dthumb.png",fn)
1775        nslices = x.shape[3]
1776        sl = np.round( nslices * 0.5 )
1777        if sl > nslices:
1778            sl = nslices-1
1779        xview = ants.slice_image( x, axis=3, idx=int(sl) )
1780        if y is None:
1781            try:
1782                ants.plot_ortho( xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1783            except:
1784                pass
1785        else:
1786            if y.dimension == 3:
1787                try:
1788                    ants.plot_ortho(y, xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1789                except:
1790                    pass
1791    return

will write the image and (optionally) a png thumbnail with (optional) overlay/underlay
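
A short sketch writing an image plus its png thumbnail (output path hypothetical; the example image ships with ANTsPy):

import ants
import antspymm

img = ants.image_read(ants.get_ants_data('ch2'))          # any 3D ANTsImage
# Writes /tmp/example.nii.gz and, if plotting succeeds, /tmp/example_3dthumb.png.
antspymm.image_write_with_thumbnail(img, '/tmp/example.nii.gz')
# Passing an underlay via y would instead render img over y in the thumbnail.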

def nrg_format_path(projectID, subjectID, date, modality, imageID, separator='-'):
1181def nrg_format_path( projectID, subjectID, date, modality, imageID, separator='-' ):
1182    """
1183    create the NRG path on disk given the project, subject id, date, modality and image id
1184
1185    Arguments
1186    ---------
1187
1188    projectID : string for the project e.g. PPMI
1189
1190    subjectID : string uniquely identifying the subject e.g. 0001
1191
1192    date : string for the date usually 20550228 ie YYYYMMDD format
1193
1194    modality : string should be one of T1w, T2Flair, rsfMRI, NM2DMT and DTI ... rsfMRI and DTI may also be DTI_LR, DTI_RL, rsfMRI_LR and rsfMRI_RL where the RL / LR relates to phase encoding direction (even if it is AP/PA)
1195
1196    imageID : string uniquely identifying the specific image
1197
1198    separator : default to -
1199
1200    Returns
1201    -------
1202    the path where one would write the image on disk
1203
1204    """
1205    thedirectory = os.path.join( str(projectID), str(subjectID), str(date), str(modality), str(imageID) )
1206    thefilename = str(projectID) + separator + str(subjectID) + separator + str(date) + separator + str(modality) + separator + str(imageID)
1207    return os.path.join( thedirectory, thefilename )

create the NRG path on disk given the project, subject id, date, modality and image id

Arguments

projectID : string for the project e.g. PPMI

subjectID : string uniquely identifying the subject e.g. 0001

date : string for the date usually 20550228 ie YYYYMMDD format

modality : string should be one of T1w, T2Flair, rsfMRI, NM2DMT and DTI ... rsfMRI and DTI may also be DTI_LR, DTI_RL, rsfMRI_LR and rsfMRI_RL where the RL / LR relates to phase encoding direction (even if it is AP/PA)

imageID : string uniquely identifying the specific image

separator : default to -

Returns

the path where one would write the image on disk
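
For example, the returned string combines the directory and the file prefix, so the caller appends the extension (values illustrative):

import antspymm

p = antspymm.nrg_format_path('PPMI', '0001', '20250101', 'T1w', '123456')
# p == 'PPMI/0001/20250101/T1w/123456/PPMI-0001-20250101-T1w-123456' (on POSIX systems)
nifti_path = p + '.nii.gz'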

def highest_quality_repeat(mxdfin, idvar, visitvar, qualityvar):
1373def highest_quality_repeat(mxdfin, idvar, visitvar, qualityvar):
1374    """
1375    This function returns a subset of the input dataframe that retains only the rows
1376    that correspond to the highest quality observation for each combination of ID and visit.
1377
1378    Parameters:
1379    ----------
1380    mxdfin: pandas.DataFrame
1381        The input dataframe.
1382    idvar: str
1383        The name of the column that contains the ID variable.
1384    visitvar: str
1385        The name of the column that contains the visit variable.
1386    qualityvar: str
1387        The name of the column that contains the quality variable.
1388
1389    Returns:
1390    -------
1391    pandas.DataFrame
1392        A subset of the input dataframe that retains only the rows that correspond
1393        to the highest quality observation for each combination of ID and visit.
1394    """
1395    if visitvar not in mxdfin.columns:
1396        raise ValueError("visitvar not in dataframe")
1397    if idvar not in mxdfin.columns:
1398        raise ValueError("idvar not in dataframe")
1399    if qualityvar not in mxdfin.columns:
1400        raise ValueError("qualityvar not in dataframe")
1401
1402    mxdfin[qualityvar] = mxdfin[qualityvar].astype(float)
1403
1404    vizzes = mxdfin[visitvar].unique()
1405    uids = mxdfin[idvar].unique()
1406    useit = np.zeros(mxdfin.shape[0], dtype=bool)
1407
1408    for u in uids:
1409        losel = mxdfin[idvar] == u
1410        vizzesloc = mxdfin[losel][visitvar].unique()
1411
1412        for v in vizzesloc:
1413            losel = (mxdfin[idvar] == u) & (mxdfin[visitvar] == v)
1414            mysnr = mxdfin.loc[losel, qualityvar]
1415            myw = np.where(losel)[0]
1416
1417            if len(myw) > 1:
1418                if any(~np.isnan(mysnr)):
1419                    useit[myw[np.argmax(mysnr)]] = True
1420                else:
1421                    useit[myw] = True
1422            else:
1423                useit[myw] = True
1424
1425    return mxdfin[useit]

This function returns a subset of the input dataframe that retains only the rows that correspond to the highest quality observation for each combination of ID and visit.

Parameters:

  • mxdfin (pandas.DataFrame): the input dataframe.
  • idvar (str): the name of the column that contains the ID variable.
  • visitvar (str): the name of the column that contains the visit variable.
  • qualityvar (str): the name of the column that contains the quality variable.

Returns:

  • pandas.DataFrame: a subset of the input dataframe that retains only the rows that correspond to the highest quality observation for each combination of ID and visit.
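
A small illustration with a synthetic dataframe; the quality column follows the higher-is-better convention implied by np.argmax above (column names are arbitrary):

import pandas as pd
import antspymm

df = pd.DataFrame({
    'sid':   ['s1', 's1', 's1', 's2'],
    'visit': ['v1', 'v1', 'v2', 'v1'],
    'snr':   [10.0, 14.0, 9.0, 11.0]})
best = antspymm.highest_quality_repeat(df, 'sid', 'visit', 'snr')
# Keeps the snr==14.0 repeat for (s1, v1) plus the single observations for (s1, v2) and (s2, v1).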

def match_modalities( qc_dataframe, unique_identifier='filename', outlier_column='ol_loop', mysep='-', verbose=False):
1428def match_modalities( qc_dataframe, unique_identifier='filename', outlier_column='ol_loop', mysep='-', verbose=False ):
1429    """
1430    Find the best multiple modality dataset at each time point
1431
1432    :param qc_dataframe: quality control data frame with per-image QC metrics
1433    :param unique_identifier : the unique NRG filename for each image
1434    :param outlier_column: outlierness score used to identify the best image (pair) at a given date
1435    :param mysep (str, optional): the separator used in the image file names. Defaults to '-'.
1436    :param verbose: boolean
1437    :return: filtered matched modality data frame
1438    """
1439    import pandas as pd
1440    import numpy as np
1441    qc_dataframe['filename']=qc_dataframe['filename'].astype(str)
1442    qc_dataframe['ol_loop']=qc_dataframe['ol_loop'].astype(float)
1443    qc_dataframe['ol_lof']=qc_dataframe['ol_lof'].astype(float)
1444    qc_dataframe['ol_lof_decision']=qc_dataframe['ol_lof_decision'].astype(float)
1445    mmdf = best_mmm( qc_dataframe, 'T1w', outlier_column=outlier_column )['filt']
1446    fldf = best_mmm( qc_dataframe, 'T2Flair', outlier_column=outlier_column )['filt']
1447    nmdf = best_mmm( qc_dataframe, 'NM2DMT', outlier_column=outlier_column )['filt']
1448    rsdf = best_mmm( qc_dataframe, 'rsfMRI', outlier_column=outlier_column )['filt']
1449    dtdf = best_mmm( qc_dataframe, 'DTI', outlier_column=outlier_column )['filt']
1450    mmdf['flairid'] = None
1451    mmdf['flairfn'] = None
1452    mmdf['flairloop'] = None
1453    mmdf['flairlof'] = None
1454    mmdf['dtid1'] = None
1455    mmdf['dtfn1'] = None
1456    mmdf['dtntimepoints1'] = 0
1457    mmdf['dtloop1'] = math.nan
1458    mmdf['dtlof1'] = math.nan
1459    mmdf['dtid2'] = None
1460    mmdf['dtfn2'] = None
1461    mmdf['dtntimepoints2'] = 0
1462    mmdf['dtloop2'] = math.nan
1463    mmdf['dtlof2'] = math.nan
1464    mmdf['rsfid1'] = None
1465    mmdf['rsffn1'] = None
1466    mmdf['rsfntimepoints1'] = 0
1467    mmdf['rsfloop1'] = math.nan
1468    mmdf['rsflof1'] = math.nan
1469    mmdf['rsfid2'] = None
1470    mmdf['rsffn2'] = None
1471    mmdf['rsfntimepoints2'] = 0
1472    mmdf['rsfloop2'] = math.nan
1473    mmdf['rsflof2'] = math.nan
1474    for k in range(1,11):
1475        myid='nmid'+str(k)
1476        mmdf[myid] = None
1477        myid='nmfn'+str(k)
1478        mmdf[myid] = None
1479        myid='nmloop'+str(k)
1480        mmdf[myid] = math.nan
1481        myid='nmlof'+str(k)
1482        mmdf[myid] = math.nan
1483    if verbose:
1484        print( mmdf.shape )
1485    for k in range(mmdf.shape[0]):
1486        if verbose:
1487            if k % 100 == 0:
1488                progger = str( k ) # np.round( k / mmdf.shape[0] * 100 ) )
1489                print( progger, end ="...", flush=True)
1490        if dtdf is not None:
1491            locsel = (dtdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
1492            if sum(locsel) == 1:
1493                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = dtdf["imageID"][locsel].values[0]
1494                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = dtdf[unique_identifier][locsel].values[0]
1495                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = dtdf[outlier_column][locsel].values[0]
1496                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(dtdf['ol_lof_decision'][locsel].values[0])
1497                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(dtdf['dimt'][locsel].values[0])
1498            elif sum(locsel) > 1:
1499                locdf = dtdf[locsel]
1500                dedupe = locdf[["snr","cnr"]].duplicated()
1501                locdf = locdf[~dedupe]
1502                if locdf.shape[0] > 1:
1503                    locdf = locdf.sort_values(outlier_column).iloc[:2]
1504                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = locdf["imageID"].values[0]
1505                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = locdf[unique_identifier].values[0]
1506                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = locdf[outlier_column].values[0]
1507                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(locdf['ol_lof_decision'].values[0])
1508                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(locdf['dimt'].values[0])
1509                if locdf.shape[0] > 1:
1510                    mmdf.iloc[k, mmdf.columns.get_loc("dtid2")] = locdf["imageID"].values[1]
1511                    mmdf.iloc[k, mmdf.columns.get_loc("dtfn2")] = locdf[unique_identifier].values[1]
1512                    mmdf.iloc[k, mmdf.columns.get_loc("dtloop2")] = locdf[outlier_column].values[1]
1513                    mmdf.iloc[k, mmdf.columns.get_loc("dtlof2")] = float(locdf['ol_lof_decision'].values[1])
1514                    mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints2")] = float(locdf['dimt'].values[1])
1515        if rsdf is not None:
1516            locsel = (rsdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
1517            if sum(locsel) == 1:
1518                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = rsdf["imageID"][locsel].values[0]
1519                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = rsdf[unique_identifier][locsel].values[0]
1520                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = rsdf[outlier_column][locsel].values[0]
1521                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(rsdf['ol_lof_decision'][locsel].values[0])
1522                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(rsdf['dimt'][locsel].values[0])
1523            elif sum(locsel) > 1:
1524                locdf = rsdf[locsel]
1525                dedupe = locdf[["snr","cnr"]].duplicated()
1526                locdf = locdf[~dedupe]
1527                if locdf.shape[0] > 1:
1528                    locdf = locdf.sort_values(outlier_column).iloc[:2]
1529                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = locdf["imageID"].values[0]
1530                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = locdf[unique_identifier].values[0]
1531                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = locdf[outlier_column].values[0]
1532                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(locdf['ol_lof_decision'].values[0])
1533                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(locdf['dimt'].values[0])
1534                if locdf.shape[0] > 1:
1535                    mmdf.iloc[k, mmdf.columns.get_loc("rsfid2")] = locdf["imageID"].values[1]
1536                    mmdf.iloc[k, mmdf.columns.get_loc("rsffn2")] = locdf[unique_identifier].values[1]
1537                    mmdf.iloc[k, mmdf.columns.get_loc("rsfloop2")] = locdf[outlier_column].values[1]
1538                    mmdf.iloc[k, mmdf.columns.get_loc("rsflof2")] = float(locdf['ol_lof_decision'].values[1])
1539                    mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints2")] = float(locdf['dimt'].values[1])
1540
1541        if fldf is not None:
1542            locsel = fldf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
1543            if locsel.sum() == 1:
1544                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = fldf['imageID'][locsel].values[0]
1545                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = fldf[unique_identifier][locsel].values[0]
1546                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = fldf[outlier_column][locsel].values[0]
1547                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(fldf['ol_lof_decision'][locsel].values[0])
1548            elif sum(locsel) > 1:
1549                locdf = fldf[locsel]
1550                dedupe = locdf[["snr","cnr"]].duplicated()
1551                locdf = locdf[~dedupe]
1552                if locdf.shape[0] > 1:
1553                    locdf = locdf.sort_values(outlier_column).iloc[:2]
1554                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = locdf["imageID"].values[0]
1555                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = locdf[unique_identifier].values[0]
1556                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = locdf[outlier_column].values[0]
1557                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(locdf['ol_lof_decision'].values[0])
1558
1559        if nmdf is not None:
1560            locsel = nmdf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
1561            if locsel.sum() > 0:
1562                locdf = nmdf[locsel]
1563                for i in range(np.min( [10,locdf.shape[0]])):
1564                    nmid = "nmid"+str(i+1)
1565                    mmdf.loc[k,nmid] = locdf['imageID'].iloc[i]
1566                    nmfn = "nmfn"+str(i+1)
1567                    mmdf.loc[k,nmfn] = locdf[unique_identifier].iloc[i]
1568                    nmloop = "nmloop"+str(i+1)
1569                    mmdf.loc[k,nmloop] = locdf[outlier_column].iloc[i]
1570                    nmloop = "nmlof"+str(i+1)
1571                    mmdf.loc[k,nmloop] = float(locdf['ol_lof_decision'].iloc[i])
1572
1573    mmdf['rsf_total_timepoints']=mmdf['rsfntimepoints1']+mmdf['rsfntimepoints2']
1574    mmdf['dt_total_timepoints']=mmdf['dtntimepoints1']+mmdf['dtntimepoints2']
1575    return mmdf

Find the best multiple modality dataset at each time point

Parameters
  • qc_dataframe: quality control data frame with per-image QC metrics
  • unique_identifier: the unique NRG filename for each image
  • outlier_column: outlierness score used to identify the best image (pair) at a given date
  • mysep (str, optional): the separator used in the image file names. Defaults to '-'.
  • verbose: boolean
Returns

filtered matched modality data frame
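
In practice the QC dataframe is assembled upstream (e.g. by collect_blind_qc_by_modality followed by outlierness_by_modality); a hedged sketch of the call, with a hypothetical CSV path:

import pandas as pd
import antspymm

qc = pd.read_csv('/data/blind_qc_all_modalities.csv')     # hypothetical aggregated QC table
qc = antspymm.outlierness_by_modality(qc)                  # adds ol_loop / ol_lof / ol_lof_decision
matched = antspymm.match_modalities(qc, outlier_column='ol_loop')
# One row per selected T1w, with flairfn, dtfn1/dtfn2, rsffn1/rsffn2 and nmfn1..nmfn10 linking modalities.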

def mc_resample_image_to_target(x, y, interp_type='linear'):
1810def mc_resample_image_to_target( x , y, interp_type='linear' ):
1811    """
1812    multichannel version of resample_image_to_target
1813    """
1814    xx=ants.split_channels( x )
1815    yy=ants.split_channels( y )[0]
1816    newl=[]
1817    for k in range(len(xx)):
1818        newl.append(  ants.resample_image_to_target( xx[k], yy, interp_type=interp_type ) )
1819    return ants.merge_channels( newl )

multichannel version of resample_image_to_target
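
A brief sketch with two-channel images built from ANTsPy example data (purely illustrative):

import ants
import antspymm

a = ants.image_read(ants.get_ants_data('r16'))
b = ants.image_read(ants.get_ants_data('r64'))
x = ants.merge_channels([a, a])                    # multichannel moving image
y = ants.merge_channels([b, b])                    # multichannel target
xr = antspymm.mc_resample_image_to_target(x, y)    # each channel resampled into y's grid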

def nrg_filelist_to_dataframe(filename_list, myseparator='-'):
1821def nrg_filelist_to_dataframe( filename_list, myseparator="-" ):
1822    """
1823    convert a list of files in nrg format to a dataframe
1824
1825    Arguments
1826    ---------
1827    filename_list : globbed list of files
1828
1829    myseparator : string separator between nrg parts
1830
1831    Returns
1832    -------
1833
1834    df : pandas data frame
1835
1836    """
1837    def getmtime(x):
1838        x= dt.datetime.fromtimestamp(os.path.getmtime(x)).strftime("%Y-%m-%d %H:%M:%S")
1839        return x
1840    df=pd.DataFrame(columns=['filename','file_last_mod_t','else','sid','visitdate','modality','uid'])
1841    df.set_index('filename')
1842    df['filename'] = pd.Series([file for file in filename_list ])
1843    # I applied a time modified file to df['file_last_mod_t'] by getmtime function
1844    df['file_last_mod_t'] = df['filename'].apply(lambda x: getmtime(x))
1845    for k in range(df.shape[0]):
1846        locfn=df['filename'].iloc[k]
1847        splitter=os.path.basename(locfn).split( myseparator )
1848        df['sid'].iloc[k]=splitter[1]
1849        df['visitdate'].iloc[k]=splitter[2]
1850        df['modality'].iloc[k]=splitter[3]
1851        temp = os.path.splitext(splitter[4])[0]
1852        df['uid'].iloc[k]=os.path.splitext(temp)[0]
1853    return df

convert a list of files in nrg format to a dataframe

Arguments

filename_list : globbed list of files

myseparator : string separator between nrg parts

Returns

df : pandas data frame
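
For instance, one might glob an NRG-organized tree and tabulate it (directory layout hypothetical):

import glob
import antspymm

files = sorted(glob.glob('/data/PPMI/*/*/T1w/*/*.nii.gz'))   # hypothetical NRG tree
df = antspymm.nrg_filelist_to_dataframe(files, myseparator='-')
# Columns: filename, file_last_mod_t, else, sid, visitdate, modality, uid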

def merge_timeseries_data(img_LR, img_RL, allow_resample=True):
1856def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
1857    """
1858    merge time series data into space of reference_image
1859
1860    img_LR : image
1861
1862    img_RL : image
1863
1864    allow_resample : boolean
1865
1866    """
1867    # concatenate the images into the reference space
1868    mimg=[]
1869    for kk in range( img_LR.shape[3] ):
1870        temp = ants.slice_image( img_LR, axis=3, idx=kk )
1871        mimg.append( temp )
1872    for kk in range( img_RL.shape[3] ):
1873        temp = ants.slice_image( img_RL, axis=3, idx=kk )
1874        if kk == 0:
1875            insamespace = ants.image_physical_space_consistency( temp, mimg[0] )
1876        if allow_resample and not insamespace :
1877            temp = ants.resample_image_to_target( temp, mimg[0] )
1878        mimg.append( temp )
1879    return ants.list_to_ndimage( img_LR, mimg )

merge time series data into space of reference_image

img_LR : image

img_RL : image

allow_resample : boolean
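
A hedged sketch, assuming two acquisitions that differ only by phase-encoding direction (paths hypothetical):

import antspymm

img_lr = antspymm.mm_read('/data/rsfMRI_LR.nii.gz')   # hypothetical
img_rl = antspymm.mm_read('/data/rsfMRI_RL.nii.gz')   # hypothetical
merged = antspymm.merge_timeseries_data(img_lr, img_rl)
# merged.shape[3] == img_lr.shape[3] + img_rl.shape[3]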

def timeseries_reg( image, avg_b0, type_of_transform='antsRegistrationSyNRepro[r]', total_sigma=1.0, fdOffset=2.0, trim=0, output_directory=None, return_numpy_motion_parameters=False, verbose=False, **kwargs):
1965def timeseries_reg(
1966    image,
1967    avg_b0,
1968    type_of_transform='antsRegistrationSyNRepro[r]',
1969    total_sigma=1.0,
1970    fdOffset=2.0,
1971    trim = 0,
1972    output_directory=None,
1973    return_numpy_motion_parameters=False,
1974    verbose=False, **kwargs
1975):
1976    """
1977    Correct time-series data for motion.
1978
1979    Arguments
1980    ---------
1981    image: antsImage, usually ND where D=4.
1982
1983    avg_b0: Fixed image b0 image
1984
1985    type_of_transform : string
1986            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
1987            See ants registration for details.
1988
1989    fdOffset: offset value to use in framewise displacement calculation
1990
1991    trim : integer - trim this many images off the front of the time series
1992
1993    output_directory : string
1994            output will be placed in this directory plus a numeric extension.
1995
1996    return_numpy_motion_parameters : boolean
1997
1998    verbose: boolean
1999
2000    kwargs: keyword args
2001            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
2002
2003    Returns
2004    -------
2005    dict containing the following key/value pairs:
2006        `motion_corrected`: Moving image warped to space of fixed image.
2007        `motion_parameters`: transforms for each image in the time series.
2008        `FD`: Framewise displacement generalized for arbitrary transformations.
2009
2010    Notes
2011    -----
2012    Control extra arguments via kwargs. see ants.registration for details.
2013
2014    Example
2015    -------
2016    >>> import ants
2017    """
2018    idim = image.dimension
2019    ishape = image.shape
2020    nTimePoints = ishape[idim - 1]
2021    FD = np.zeros(nTimePoints)
2022    if type_of_transform is None:
2023        return {
2024            "motion_corrected": image,
2025            "motion_parameters": None,
2026            "FD": FD
2027        }
2028
2029    remove_it=False
2030    if output_directory is None:
2031        remove_it=True
2032        output_directory = tempfile.mkdtemp()
2033    output_directory_w = output_directory + "/ts_reg/"
2034    os.makedirs(output_directory_w,exist_ok=True)
2035    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
2036    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
2037    if verbose:
2038        print('bold motcorr with ' + type_of_transform)
2039        print(output_directory_w)
2040        print(ofnG)
2041        print(ofnL)
2042        print("remove_it " + str( remove_it ) )
2043
2044    # get a local deformation from slice to local avg space
2045    motion_parameters = list()
2046    motion_corrected = list()
2047    mask = ants.get_mask( avg_b0 )
2048    centerOfMass = mask.get_center_of_mass()
2049    npts = pow(2, idim - 1)
2050    pointOffsets = np.zeros((npts, idim - 1))
2051    myrad = np.ones(idim - 1).astype(int).tolist()
2052    mask1vals = np.zeros(int(mask.sum()))
2053    mask1vals[round(len(mask1vals) / 2)] = 1
2054    mask1 = ants.make_image(mask, mask1vals)
2055    myoffsets = ants.get_neighborhood_in_mask(
2056        mask1, mask1, radius=myrad, spatial_info=True
2057    )["offsets"]
2058    mycols = list("xy")
2059    if idim - 1 == 3:
2060        mycols = list("xyz")
2061    useinds = list()
2062    for k in range(myoffsets.shape[0]):
2063        if abs(myoffsets[k, :]).sum() == (idim - 2):
2064            useinds.append(k)
2065        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
2066    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
2067    if verbose:
2068        print("Progress:")
2069    counter = round( nTimePoints / 10 ) + 1
2070    for k in range( nTimePoints):
2071        if verbose and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
2072            myperc = round( k / nTimePoints * 100)
2073            print(myperc, end="%.", flush=True)
2074        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2075        temp = ants.iMath(temp, "Normalize")
2076        txprefix = ofnL+str(k % 2).zfill(4)+"_"
2077        if temp.numpy().var() > 0:
2078            myrig = ants.registration(
2079                    avg_b0, temp,
2080                    type_of_transform='antsRegistrationSyNRepro[r]',
2081                    outprefix=txprefix
2082                )
2083            if type_of_transform == 'SyN':
2084                myreg = ants.registration(
2085                    avg_b0, temp,
2086                    type_of_transform='SyNOnly',
2087                    total_sigma=total_sigma,
2088                    initial_transform=myrig['fwdtransforms'][0],
2089                    outprefix=txprefix,
2090                    **kwargs
2091                )
2092            else:
2093                myreg = myrig
2094            fdptsTxI = ants.apply_transforms_to_points(
2095                idim - 1, fdpts, myrig["fwdtransforms"]
2096            )
2097            if k > 0 and motion_parameters[k - 1] != "NA":
2098                fdptsTxIminus1 = ants.apply_transforms_to_points(
2099                    idim - 1, fdpts, motion_parameters[k - 1]
2100                )
2101            else:
2102                fdptsTxIminus1 = fdptsTxI
2103            # take the absolute value, then the mean across columns, then the sum
2104            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
2105            motion_parameters.append(myreg["fwdtransforms"])
2106        else:
2107            motion_parameters.append("NA")
2108
2109        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2110        if temp.numpy().var() > 0:
2111            img1w = ants.apply_transforms( avg_b0,
2112                temp,
2113                motion_parameters[k] )
2114            motion_corrected.append(img1w)
2115        else:
2116            motion_corrected.append(avg_b0)
2117
2118    motion_parameters = motion_parameters[trim:len(motion_parameters)]
2119    if return_numpy_motion_parameters:
2120        motion_parameters = read_ants_transforms_to_numpy( motion_parameters )
2121
2122    if remove_it:
2123        import shutil
2124        shutil.rmtree(output_directory, ignore_errors=True )
2125
2126    if verbose:
2127        print("Done")
2128    d4siz = list(avg_b0.shape)
2129    d4siz.append( 2 )
2130    spc = list(ants.get_spacing( avg_b0 ))
2131    spc.append( ants.get_spacing(image)[3] )
2132    mydir = ants.get_direction( avg_b0 )
2133    mydir4d = ants.get_direction( image )
2134    mydir4d[0:3,0:3]=mydir
2135    myorg = list(ants.get_origin( avg_b0 ))
2136    myorg.append( 0.0 )
2137    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
2138    return {
2139        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected[trim:len(motion_corrected)]),
2140        "motion_parameters": motion_parameters,
2141        "FD": FD[trim:len(FD)]
2142    }

Correct time-series data for motion.

Arguments

image: antsImage, usually ND where D=4.

avg_b0: Fixed image b0 image

type_of_transform : string A linear or non-linear registration type. Mutual information metric and rigid transformation by default. See ants registration for details.

fdOffset: offset value to use in framewise displacement calculation

trim : integer - trim this many images off the front of the time series

output_directory : string output will be placed in this directory plus a numeric extension.

return_numpy_motion_parameters : boolean

verbose: boolean

kwargs: keyword args extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.

Returns

dict containing the following key/value pairs: motion_corrected: Moving image warped to space of fixed image. motion_parameters: transforms for each image in the time series. FD: Framewise displacement generalized for arbitrary transformations.

Notes

Control extra arguments via kwargs. see ants.registration for details.

Example

>>> import ants
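
A hedged usage sketch for a BOLD series, using its mean volume as the fixed reference passed through the avg_b0 argument (path hypothetical):

import ants
import antspymm

bold = antspymm.mm_read('/data/rsfMRI.nii.gz')              # hypothetical 4D series
ref = ants.get_average_of_timeseries(bold)                  # fixed reference volume
reg = antspymm.timeseries_reg(bold, ref,
                              type_of_transform='antsRegistrationSyNRepro[r]',
                              fdOffset=2.0, verbose=True)
moco = reg['motion_corrected']                              # motion-corrected 4D image
fd = reg['FD']                                              # framewise displacement per volume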
def merge_dwi_data(img_LRdwp, bval_LR, bvec_LR, img_RLdwp, bval_RL, bvec_RL):
2145def merge_dwi_data( img_LRdwp, bval_LR, bvec_LR, img_RLdwp, bval_RL, bvec_RL ):
2146    """
2147    merge motion and distortion corrected data if possible
2148
2149    img_LRdwp : image
2150
2151    bval_LR : array
2152
2153    bvec_LR : array
2154
2155    img_RLdwp : image
2156
2157    bval_RL : array
2158
2159    bvec_RL : array
2160
2161    """
2162    import warnings
2163    insamespace = ants.image_physical_space_consistency( img_LRdwp, img_RLdwp )
2164    if not insamespace :
2165        warnings.warn('not insamespace ... corrected image pair should occupy the same physical space; returning only the 1st set and will not join these data.')
2166        return img_LRdwp, bval_LR, bvec_LR
2167    
2168    bval_LR = np.concatenate([bval_LR,bval_RL])
2169    bvec_LR = np.concatenate([bvec_LR,bvec_RL])
2170    # concatenate the images
2171    mimg=[]
2172    for kk in range( img_LRdwp.shape[3] ):
2173            mimg.append( ants.slice_image( img_LRdwp, axis=3, idx=kk ) )
2174    for kk in range( img_RLdwp.shape[3] ):
2175            mimg.append( ants.slice_image( img_RLdwp, axis=3, idx=kk ) )
2176    img_LRdwp = ants.list_to_ndimage( img_LRdwp, mimg )
2177    return img_LRdwp, bval_LR, bvec_LR

merge motion and distortion corrected data if possible

img_LRdwp : image

bval_LR : array

bvec_LR : array

img_RLdwp : image

bval_RL : array

bvec_RL : array
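
A sketch assuming the two corrected series already occupy the same physical space (inputs hypothetical, e.g. dti_reg outputs for the LR and RL acquisitions):

import antspymm
from dipy.io.gradients import read_bvals_bvecs

img_lr = antspymm.mm_read('/data/DTI_LR_corrected.nii.gz')            # hypothetical
img_rl = antspymm.mm_read('/data/DTI_RL_corrected.nii.gz')            # hypothetical
bval_lr, bvec_lr = read_bvals_bvecs('/data/DTI_LR.bval', '/data/DTI_LR.bvec')
bval_rl, bvec_rl = read_bvals_bvecs('/data/DTI_RL.bval', '/data/DTI_RL.bvec')
img, bvals, bvecs = antspymm.merge_dwi_data(img_lr, bval_lr, bvec_lr,
                                            img_rl, bval_rl, bvec_rl)
# If the pair is not in the same space, only the LR set is returned unchanged.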

def outlierness_by_modality( qcdf, uid='filename', outlier_columns=['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'], verbose=False):
1124def outlierness_by_modality( qcdf, uid='filename', outlier_columns = ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi','reflection_err', 'EVR', 'msk_vol'], verbose=False ):
1125    """
1126    Calculates outlierness scores for each modality in a dataframe based on given outlier columns using antspyt1w.loop_outlierness() and LOF.  LOF appears to be more conservative.  This function will impute missing columns with the mean.
1127
1128    Args:
1129    - qcdf: (Pandas DataFrame) Dataframe containing columns with outlier information for each modality.
1130    - uid: (str) Unique identifier for a subject. Default is 'filename'.
1131    - outlier_columns: (list) List of columns containing outlier information. Default is ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'].
1132    - verbose: (bool) If True, prints information for each modality. Default is False.
1133
1134    Returns:
1135    - qcdf: (Pandas DataFrame) Updated dataframe with outlierness scores for each modality in the 'ol_loop' and 'ol_lof' column.  Higher values near 1 are more outlying.
1136
1137    Raises:
1138    - ValueError: If uid is not present in the dataframe.
1139
1140    Example:
1141    >>> df = pd.read_csv('data.csv')
1142    >>> outlierness_by_modality(df)
1143    """
1144    from PyNomaly import loop
1145    from sklearn.neighbors import LocalOutlierFactor
1146    qcdfout = qcdf.copy()
1147    pd.set_option('future.no_silent_downcasting', True)
1148    qcdfout.replace([np.inf, -np.inf], np.nan, inplace=True)
1149    if uid not in qcdfout.keys():
1150        raise ValueError( str(uid) + " not in dataframe")
1151    if 'ol_loop' not in qcdfout.keys():
1152        qcdfout['ol_loop']=math.nan
1153    if 'ol_lof' not in qcdfout.keys():
1154        qcdfout['ol_lof']=math.nan
1155    didit=False
1156    for mod in get_valid_modalities( qc=True ):
1157        didit=True
1158        lof = LocalOutlierFactor()
1159        locsel = qcdfout["modality"] == mod
1160        rr = qcdfout[locsel][outlier_columns]
1161        column_means = rr.mean()
1162        rr.fillna(column_means, inplace=True)
1163        if rr.shape[0] > 1:
1164            if verbose:
1165                print("calc: " + mod + " outlierness " )
1166            myneigh = np.min( [24, int(np.round(rr.shape[0]*0.5)) ] )
1167            temp = antspyt1w.loop_outlierness(rr.astype(float), standardize=True, extent=3, n_neighbors=myneigh, cluster_labels=None)
1168            qcdfout.loc[locsel,'ol_loop']=temp.astype('float64')
1169            yhat = lof.fit_predict(rr)
1170            temp = lof.negative_outlier_factor_*(-1.0)
1171            temp = temp - temp.min()
1172            yhat[ yhat == 1] = 0
1173            yhat[ yhat == -1] = 1 # these are outliers
1174            qcdfout.loc[locsel,'ol_lof_decision']=yhat
1175            qcdfout.loc[locsel,'ol_lof']=temp/temp.max()
1176    if verbose:
1177        print( didit )
1178    return qcdfout

Calculates outlierness scores for each modality in a dataframe based on given outlier columns using antspyt1w.loop_outlierness() and LOF. LOF appears to be more conservative. This function will impute missing columns with the mean.

Args:

  • qcdf: (Pandas DataFrame) Dataframe containing columns with outlier information for each modality.
  • uid: (str) Unique identifier for a subject. Default is 'filename'.
  • outlier_columns: (list) List of columns containing outlier information. Default is ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'].
  • verbose: (bool) If True, prints information for each modality. Default is False.

Returns:

  • qcdf: (Pandas DataFrame) Updated dataframe with outlierness scores for each modality in the 'ol_loop' and 'ol_lof' column. Higher values near 1 are more outlying.

Raises:

  • ValueError: If uid is not present in the dataframe.

Example:

>>> df = pd.read_csv('data.csv')
>>> outlierness_by_modality(df)
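
A slightly fuller sketch that also flags the LOF decisions (CSV path hypothetical; the table must contain 'filename', 'modality' and the outlier columns):

import pandas as pd
import antspymm

qc = pd.read_csv('/data/blind_qc.csv')                 # hypothetical blind QC table
qc = antspymm.outlierness_by_modality(qc, verbose=True)
# ol_loop and ol_lof are scaled to [0, 1]; values near 1 are more outlying.
suspects = qc[qc['ol_lof_decision'] == 1.0]            # images LOF marked as outliers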
def bvec_reorientation(motion_parameters, bvecs, rebase=None):
2179def bvec_reorientation( motion_parameters, bvecs, rebase=None ):
2180    if motion_parameters is None:
2181        return bvecs
2182    n = len(motion_parameters)
2183    if n < 1:
2184        return bvecs
2185    from scipy.linalg import inv, polar
2186    from dipy.core.gradients import reorient_bvecs
2187    dipymoco = np.zeros( [n,3,3] )
2188    for myidx in range(n):
2189        if myidx < bvecs.shape[0]:
2190            dipymoco[myidx,:,:] = np.eye( 3 )
2191            if motion_parameters[myidx] != 'NA':
2192                temp = motion_parameters[myidx]
2193                if len(temp) == 4 :
2194                    temp1=temp[3] # FIXME should be composite of index 1 and 3
2195                    temp2=temp[1] # FIXME should be composite of index 1 and 3
2196                    txparam1 = ants.read_transform(temp1)
2197                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
2198                    txparam2 = ants.read_transform(temp2)
2199                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
2200                    Rinv = inv( np.dot( txparam2, txparam1 ) )
2201                elif len(temp) == 2 :
2202                    temp=temp[1] # FIXME should be composite of index 1 and 3
2203                    txparam = ants.read_transform(temp)
2204                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
2205                    Rinv = inv( txparam )
2206                elif len(temp) == 3 :
2207                    temp1=temp[2] # FIXME should be composite of index 1 and 3
2208                    temp2=temp[1] # FIXME should be composite of index 1 and 3
2209                    txparam1 = ants.read_transform(temp1)
2210                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
2211                    txparam2 = ants.read_transform(temp2)
2212                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
2213                    Rinv = inv( np.dot( txparam2, txparam1 ) )
2214                else:
2215                    temp=temp[0]
2216                    txparam = ants.read_transform(temp)
2217                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
2218                    Rinv = inv( txparam )
2219                bvecs[myidx,:] = np.dot( Rinv, bvecs[myidx,:] )
2220                if rebase is not None:
2221                    # FIXME - should combine these operations
2222                    bvecs[myidx,:] = np.dot( rebase, bvecs[myidx,:] )
2223    return bvecs
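
bvec_reorientation has no docstring; the essential operation is that each gradient direction is multiplied by the inverse of the 3x3 matrix of that volume's motion transform(s), then optionally rebased by the direction-cosine product supplied via rebase. A minimal per-volume sketch of that idea (the transform filename is hypothetical):

import numpy as np
import ants
from scipy.linalg import inv

# Hypothetical rigid transform written by ants.registration for one DWI volume.
tx = ants.read_transform('/tmp/moco0001rig_0GenericAffine.mat')
R = ants.get_ants_transform_parameters(tx)[0:9].reshape([3, 3])

bvec = np.array([0.0, 1.0, 0.0])      # original gradient direction
bvec_rot = np.dot(inv(R), bvec)       # counter-rotate the gradient by the motion estimate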
def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False):
2259def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False ):
2260    """
2261    extract DTI data from a dipy tensormodel
2262
2263    reference_image : antsImage defining physical space (3D)
2264
2265    tensormodel : from dipy e.g. the variable myoutx['dtrecon_LR_dewarp']['tensormodel'] if myoutx is produced by joint_dti_recon
2266
2267    upper_triangular: boolean otherwise use lower triangular coding
2268
2269    return_image : boolean return the ANTsImage form of DTI otherwise return an array
2270
2271    Returns
2272    -------
2273    either an ANTsImage (dim=X.Y.Z with 6 component voxels, upper triangular form)
2274        or a 5D NumPy array (dim=X.Y.Z.3.3)
2275
2276    Notes
2277    -----
2278    DiPy returns lower triangular form but ANTs expects upper triangular.
2279        Here, we default to the ANTs standard but could generalize in the future 
2280        because not much here depends on ANTs standards of tensor data.
2281        ANTs xx,xy,xz,yy,yz,zz
2282        DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
2283
2284    """
2285    # make the DTI - see 
2286    # https://dipy.org/documentation/1.7.0/examples_built/07_reconstruction/reconst_dti/#sphx-glr-examples-built-07-reconstruction-reconst-dti-py
2287    # By default, in DIPY, values are ordered as (Dxx, Dxy, Dyy, Dxz, Dyz, Dzz)
2288    # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
2289    reoind = np.array([0,1,3,2,4,5]) # arrays are faster than lists
2290    import dipy.reconst.dti as dti
2291    dtiut = dti.lower_triangular(tensormodel.quadratic_form)
2292    it = np.ndindex( reference_image.shape )
2293    yyind=2
2294    xzind=3
2295    if upper_triangular:
2296        yyind=3
2297        xzind=2
2298        for i in it: # convert to upper triangular
2299            dtiut[i] = dtiut[i][ reoind ] # do we care if this is doing extra work?
2300    if return_image:
2301        dtiAnts = ants.from_numpy(dtiut,has_components=True)
2302        ants.copy_image_info( reference_image, dtiAnts )
2303        return dtiAnts
2304    # copy these data into a tensor 
2305    dtinp = np.zeros(reference_image.shape + (3,3), dtype=float)  
2306    dtix = np.zeros((3,3), dtype=float)  
2307    it = np.ndindex( reference_image.shape )
2308    for i in it:
2309        dtivec = dtiut[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
2310        dtix[0,0]=dtivec[0]
2311        dtix[1,1]=dtivec[yyind] # 2 for LT
2312        dtix[2,2]=dtivec[5] 
2313        dtix[0,1]=dtix[1,0]=dtivec[1]
2314        dtix[0,2]=dtix[2,0]=dtivec[xzind] # 3 for LT
2315        dtix[1,2]=dtix[2,1]=dtivec[4]
2316        dtinp[i]=dtix
2317    return dtinp

extract DTI data from a dipy tensormodel

reference_image : antsImage defining physical space (3D)

tensormodel : from dipy e.g. the variable myoutx['dtrecon_LR_dewarp']['tensormodel'] if myoutx is produced by joint_dti_recon

upper_triangular: boolean otherwise use lower triangular coding

return_image : boolean return the ANTsImage form of DTI otherwise return an array

Returns

either an ANTsImage (dim=X.Y.Z with 6 component voxels, upper triangular form) or a 5D NumPy array (dim=X.Y.Z.3.3)

Notes

DiPy returns lower triangular form but ANTs expects upper triangular. Here, we default to the ANTs standard but could generalize in the future because not much here depends on ANTs standards of tensor data. ANTs xx,xy,xz,yy,yz,zz DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
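
The Notes amount to a fixed permutation of the six unique tensor components; a small NumPy check of that reordering:

import numpy as np

# DiPy lower-triangular order: [Dxx, Dxy, Dyy, Dxz, Dyz, Dzz]
dipy_vec = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
reoind = np.array([0, 1, 3, 2, 4, 5])
ants_vec = dipy_vec[reoind]
# ANTs upper-triangular order: [Dxx, Dxy, Dxz, Dyy, Dyz, Dzz]
assert np.allclose(ants_vec, [1.0, 2.0, 4.0, 3.0, 5.0, 6.0])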

def dti_reg( image, avg_b0, avg_dwi, bvals=None, bvecs=None, b0_idx=None, type_of_transform='antsRegistrationSyNRepro[r]', total_sigma=3.0, fdOffset=2.0, mask_csf=False, brain_mask_eroded=None, output_directory=None, verbose=False, **kwargs):
2600def dti_reg(
2601    image,
2602    avg_b0,
2603    avg_dwi,
2604    bvals=None,
2605    bvecs=None,
2606    b0_idx=None,
2607    type_of_transform="antsRegistrationSyNRepro[r]",
2608    total_sigma=3.0,
2609    fdOffset=2.0,
2610    mask_csf=False,
2611    brain_mask_eroded=None,
2612    output_directory=None,
2613    verbose=False, **kwargs
2614):
2615    """
2616    Correct time-series data for motion - with optional deformation.
2617
2618    Arguments
2619    ---------
2620        image: antsImage, usually ND where D=4.
2621
2622        avg_b0: Fixed image b0 image
2623
2624        avg_dwi: Fixed dwi same space as b0 image
2625
2626        bvals: bvalues (file or array)
2627
2628        bvecs: bvecs (file or array)
2629
2630        b0_idx: indices of b0
2631
2632        type_of_transform : string
2633            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
2634            See ants registration for details.
2635
2636        fdOffset: offset value to use in framewise displacement calculation
2637
2638        mask_csf: boolean
2639
2640        brain_mask_eroded: optional mask that will trigger mixed interpolation
2641
2642        output_directory : string
2643            output will be placed in this directory plus a numeric extension.
2644
2645        verbose: boolean
2646
2647        kwargs: keyword args
2648            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
2649
2650    Returns
2651    -------
2652    dict containing the following key/value pairs:
2653        `motion_corrected`: Moving image warped to space of fixed image.
2654        `motion_parameters`: transforms for each image in the time series.
2655        `FD`: Framewise displacement generalized for arbitrary transformations.
2656
2657    Notes
2658    -----
2659    Control extra arguments via kwargs. see ants.registration for details.
2660
2661    Example
2662    -------
2663    >>> import ants
2664    """
2665
2666    idim = image.dimension
2667    ishape = image.shape
2668    nTimePoints = ishape[idim - 1]
2669    FD = np.zeros(nTimePoints)
2670    if bvals is not None and bvecs is not None:
2671        if isinstance(bvecs, str):
2672            bvals, bvecs = read_bvals_bvecs( bvals , bvecs  )
2673        else: # assume we already read them
2674            bvals = bvals.copy()
2675            bvecs = bvecs.copy()
2676    if type_of_transform is None:
2677        return {
2678            "motion_corrected": image,
2679            "motion_parameters": None,
2680            "FD": FD,
2681            'bvals':bvals,
2682            'bvecs':bvecs
2683        }
2684
2685    from scipy.linalg import inv, polar
2686    from dipy.core.gradients import reorient_bvecs
2687
2688    remove_it=False
2689    if output_directory is None:
2690        remove_it=True
2691        output_directory = tempfile.mkdtemp()
2692    output_directory_w = output_directory + "/dti_reg/"
2693    os.makedirs(output_directory_w,exist_ok=True)
2694    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
2695    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
2696    if verbose:
2697        print(output_directory_w)
2698        print(ofnG)
2699        print(ofnL)
2700        print("remove_it " + str( remove_it ) )
2701
2702    if b0_idx is None:
2703        # b0_idx = segment_timeseries_by_meanvalue( image )['highermeans']
2704        b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals']
2705
2706    # first get a local deformation from slice to local avg space
2707    # then get a global deformation from avg to ref space
2708    ab0, adw = get_average_dwi_b0( image )
2709    # mask is used to roughly locate middle of brain
2710    mask = ants.threshold_image( ants.iMath(adw,'Normalize'), 0.1, 1.0 )
2711    if brain_mask_eroded is None:
2712        brain_mask_eroded = mask * 0 + 1
2713    motion_parameters = list()
2714    motion_corrected = list()
2715    centerOfMass = mask.get_center_of_mass()
2716    npts = pow(2, idim - 1)
2717    pointOffsets = np.zeros((npts, idim - 1))
2718    myrad = np.ones(idim - 1).astype(int).tolist()
2719    mask1vals = np.zeros(int(mask.sum()))
2720    mask1vals[round(len(mask1vals) / 2)] = 1
2721    mask1 = ants.make_image(mask, mask1vals)
2722    myoffsets = ants.get_neighborhood_in_mask(
2723        mask1, mask1, radius=myrad, spatial_info=True
2724    )["offsets"]
2725    mycols = list("xy")
2726    if idim - 1 == 3:
2727        mycols = list("xyz")
2728    useinds = list()
2729    for k in range(myoffsets.shape[0]):
2730        if abs(myoffsets[k, :]).sum() == (idim - 2):
2731            useinds.append(k)
2732        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
2733    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
2734
2735
2736    if verbose:
2737        print("begin global distortion correction")
2738    # initrig = tra_initializer(avg_b0, ab0, max_rotation=60, transform=['rigid'], verbose=verbose)
2739    if mask_csf:
2740        bcsf = ants.threshold_image( avg_b0,"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
2741    else:
2742        bcsf = ab0 * 0 + 1
2743
2744    initrig = ants.registration( avg_b0, ab0,'antsRegistrationSyNRepro[r]',outprefix=ofnG)
2745    deftx = ants.registration( avg_dwi, adw, 'SyNOnly',
2746        syn_metric='CC', syn_sampling=2,
2747        reg_iterations=[50,50,20],
2748        multivariate_extras=[ [ "CC", avg_b0, ab0, 1, 2 ]],
2749        initial_transform=initrig['fwdtransforms'][0],
2750        outprefix=ofnG
2751        )['fwdtransforms']
2752    if verbose:
2753        print("end global distortion correction")
2754
2755    if verbose:
2756        print("Progress:")
2757    counter = round( nTimePoints / 10 ) + 1
2758    for k in range(nTimePoints):
2759        if verbose and nTimePoints > 0 and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
2760            myperc = round( k / nTimePoints * 100)
2761            print(myperc, end="%.", flush=True)
2762        if k in b0_idx:
2763            fixed=ants.image_clone( ab0 )
2764        else:
2765            fixed=ants.image_clone( adw )
2766        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2767        temp = ants.iMath(temp, "Normalize")
2768        txprefix = ofnL+str(k).zfill(4)+"rig_"
2769        txprefix2 = ofnL+str(k % 2).zfill(4)+"def_"
2770        if temp.numpy().var() > 0:
2771            myrig = ants.registration(
2772                    fixed, temp,
2773                    type_of_transform='antsRegistrationSyNRepro[r]',
2774                    outprefix=txprefix,
2775                    **kwargs
2776                )
2777            if type_of_transform == 'SyN':
2778                myreg = ants.registration(
2779                    fixed, temp,
2780                    type_of_transform='SyNOnly',
2781                    total_sigma=total_sigma, grad_step=0.1,
2782                    initial_transform=myrig['fwdtransforms'][0],
2783                    outprefix=txprefix2,
2784                    **kwargs
2785                )
2786            else:
2787                myreg = myrig
2788            fdptsTxI = ants.apply_transforms_to_points(
2789                idim - 1, fdpts, myrig["fwdtransforms"]
2790            )
2791            if k > 0 and motion_parameters[k - 1] != "NA":
2792                fdptsTxIminus1 = ants.apply_transforms_to_points(
2793                    idim - 1, fdpts, motion_parameters[k - 1]
2794                )
2795            else:
2796                fdptsTxIminus1 = fdptsTxI
2797            # take the absolute value, then the mean across columns, then the sum
2798            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
2799            motion_parameters.append(myreg["fwdtransforms"])
2800        else:
2801            motion_parameters.append("NA")
2802
2803        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2804        if k in b0_idx:
2805            fixed=ants.image_clone( ab0 )
2806        else:
2807            fixed=ants.image_clone( adw )
2808        if temp.numpy().var() > 0:
2809            motion_parameters[k]=deftx+motion_parameters[k]
2810            img1w = apply_transforms_mixed_interpolation( avg_dwi,
2811                ants.slice_image(image, axis=idim - 1, idx=k),
2812                motion_parameters[k], mask=brain_mask_eroded )
2813            motion_corrected.append(img1w)
2814        else:
2815            motion_corrected.append(fixed)
2816
2817    if verbose:
2818        print("Reorient bvecs")
2819    if bvecs is not None:
2820        #    direction = target->GetDirection().GetTranspose() * img_mov->GetDirection().GetVnlMatrix();
2821        rebase = np.dot( np.transpose( avg_b0.direction  ), ab0.direction )
2822        bvecs = bvec_reorientation( motion_parameters, bvecs, rebase )
2823
2824    if remove_it:
2825        import shutil
2826        shutil.rmtree(output_directory, ignore_errors=True )
2827
2828    if verbose:
2829        print("Done")
2830    d4siz = list(avg_b0.shape)
2831    d4siz.append( 2 )
2832    spc = list(ants.get_spacing( avg_b0 ))
2833    spc.append( 1.0 )
2834    mydir = ants.get_direction( avg_b0 )
2835    mydir4d = ants.get_direction( image )
2836    mydir4d[0:3,0:3]=mydir
2837    myorg = list(ants.get_origin( avg_b0 ))
2838    myorg.append( 0.0 )
2839    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
2840    return {
2841        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected),
2842        "motion_parameters": motion_parameters,
2843        "FD": FD,
2844        'bvals':bvals,
2845        'bvecs':bvecs
2846    }

Correct time-series data for motion - with optional deformation.

Arguments

image: antsImage, usually ND where D=4.

avg_b0: Fixed image b0 image

avg_dwi: Fixed dwi same space as b0 image

bvals: bvalues (file or array)

bvecs: bvecs (file or array)

b0_idx: indices of b0

type_of_transform : string
    A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
    See ants registration for details.

fdOffset: offset value to use in framewise displacement calculation

mask_csf: boolean

brain_mask_eroded: optional mask that will trigger mixed interpolation

output_directory : string
    output will be placed in this directory plus a numeric extension.

verbose: boolean

kwargs: keyword args
    extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.

Returns

dict containing the following key/value pairs: motion_corrected: Moving image warped to space of fixed image. motion_parameters: transforms for each image in the time series. FD: Framewise displacement generalized for arbitrary transformations.

Notes

Control extra arguments via kwargs. see ants.registration for details.

Example

>>> import ants
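
A hedged sketch of a typical call, deriving the fixed averages from the series itself via get_average_dwi_b0 (paths hypothetical):

import antspymm
from dipy.io.gradients import read_bvals_bvecs

dwi = antspymm.mm_read('/data/DTI_LR.nii.gz')                 # hypothetical 4D DWI
bvals, bvecs = read_bvals_bvecs('/data/DTI_LR.bval', '/data/DTI_LR.bvec')
avg_b0, avg_dwi = antspymm.get_average_dwi_b0(dwi)
reg = antspymm.dti_reg(dwi, avg_b0, avg_dwi, bvals=bvals, bvecs=bvecs,
                       type_of_transform='SyN', verbose=True)
moco_dwi = reg['motion_corrected']
bvecs_reoriented = reg['bvecs']     # rotated to account for per-volume motion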
def mc_reg( image, fixed=None, type_of_transform='antsRegistrationSyNRepro[r]', mask=None, total_sigma=3.0, fdOffset=2.0, output_directory=None, verbose=False, **kwargs):
2849def mc_reg(
2850    image,
2851    fixed=None,
2852    type_of_transform="antsRegistrationSyNRepro[r]",
2853    mask=None,
2854    total_sigma=3.0,
2855    fdOffset=2.0,
2856    output_directory=None,
2857    verbose=False, **kwargs
2858):
2859    """
2860    Correct time-series data for motion - with deformation.
2861
2862    Arguments
2863    ---------
2864        image: antsImage, usually ND where D=4.
2865
2866        fixed: Fixed image to register all timepoints to.  If not provided,
2867            mean image is used.
2868
2869        type_of_transform : string
2870            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
2871            See ants registration for details.
2872
2873        fdOffset: offset value to use in framewise displacement calculation
2874
2875        output_directory : string
2876            output will be named with this prefix plus a numeric extension.
2877
2878        verbose: boolean
2879
2880        kwargs: keyword args
2881            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
2882
2883    Returns
2884    -------
2885    dict containing the following key/value pairs:
2886        `motion_corrected`: Moving image warped to space of fixed image.
2887        `motion_parameters`: transforms for each image in the time series.
2888        `FD`: Framewise displacement generalized for arbitrary transformations.
2889
2890    Notes
2891    -----
2892    Control extra arguments via kwargs. see ants.registration for details.
2893
2894    Example
2895    -------
2896    >>> import ants
2897    >>> fi = ants.image_read(ants.get_ants_data('ch2'))
2898    >>> mytx = ants.motion_correction( fi )
2899    """
2900    remove_it=False
2901    if output_directory is None:
2902        remove_it=True
2903        output_directory = tempfile.mkdtemp()
2904    output_directory_w = output_directory + "/mc_reg/"
2905    os.makedirs(output_directory_w,exist_ok=True)
2906    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
2907    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
2908    if verbose:
2909        print(output_directory_w)
2910        print(ofnG)
2911        print(ofnL)
2912
2913    idim = image.dimension
2914    ishape = image.shape
2915    nTimePoints = ishape[idim - 1]
2916    if fixed is None:
2917        fixed = ants.get_average_of_timeseries( image )
2918    if mask is None:
2919        mask = ants.get_mask(fixed)
2920    FD = np.zeros(nTimePoints)
2921    motion_parameters = list()
2922    motion_corrected = list()
2923    centerOfMass = mask.get_center_of_mass()
2924    npts = pow(2, idim - 1)
2925    pointOffsets = np.zeros((npts, idim - 1))
2926    myrad = np.ones(idim - 1).astype(int).tolist()
2927    mask1vals = np.zeros(int(mask.sum()))
2928    mask1vals[round(len(mask1vals) / 2)] = 1
2929    mask1 = ants.make_image(mask, mask1vals)
2930    myoffsets = ants.get_neighborhood_in_mask(
2931        mask1, mask1, radius=myrad, spatial_info=True
2932    )["offsets"]
2933    mycols = list("xy")
2934    if idim - 1 == 3:
2935        mycols = list("xyz")
2936    useinds = list()
2937    for k in range(myoffsets.shape[0]):
2938        if abs(myoffsets[k, :]).sum() == (idim - 2):
2939            useinds.append(k)
2940        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
2941    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
2942    if verbose:
2943        print("Progress:")
2944    counter = 0
2945    for k in range(nTimePoints):
2946        mycount = round(k / nTimePoints * 100)
2947        if verbose and mycount == counter:
2948            counter = counter + 10
2949            print(mycount, end="%.", flush=True)
2950        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2951        temp = ants.iMath(temp, "Normalize")
2952        if temp.numpy().var() > 0:
2953            myrig = ants.registration(
2954                    fixed, temp,
2955                    type_of_transform='antsRegistrationSyNRepro[r]',
2956                    outprefix=ofnL+str(k).zfill(4)+"_",
2957                    **kwargs
2958                )
2959            if type_of_transform == 'SyN':
2960                myreg = ants.registration(
2961                    fixed, temp,
2962                    type_of_transform='SyNOnly',
2963                    total_sigma=total_sigma,
2964                    initial_transform=myrig['fwdtransforms'][0],
2965                    outprefix=ofnL+str(k).zfill(4)+"_",
2966                    **kwargs
2967                )
2968            else:
2969                myreg = myrig
2970            fdptsTxI = ants.apply_transforms_to_points(
2971                idim - 1, fdpts, myreg["fwdtransforms"]
2972            )
2973            if k > 0 and motion_parameters[k - 1] != "NA":
2974                fdptsTxIminus1 = ants.apply_transforms_to_points(
2975                    idim - 1, fdpts, motion_parameters[k - 1]
2976                )
2977            else:
2978                fdptsTxIminus1 = fdptsTxI
2979            # take the absolute value, then the mean across columns, then the sum
2980            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
2981            motion_parameters.append(myreg["fwdtransforms"])
2982            img1w = ants.apply_transforms( fixed,
2983                ants.slice_image(image, axis=idim - 1, idx=k),
2984                myreg["fwdtransforms"] )
2985            motion_corrected.append(img1w)
2986        else:
2987            motion_parameters.append("NA")
2988            motion_corrected.append(temp)
2989
2990    if remove_it:
2991        import shutil
2992        shutil.rmtree(output_directory, ignore_errors=True )
2993
2994    if verbose:
2995        print("Done")
2996    return {
2997        "motion_corrected": ants.list_to_ndimage(image, motion_corrected),
2998        "motion_parameters": motion_parameters,
2999        "FD": FD,
3000    }

Correct time-series data for motion - with deformation.

Arguments

image: antsImage, usually ND where D=4.

fixed: Fixed image to register all timepoints to.  If not provided,
    mean image is used.

type_of_transform : string
    A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
    See ants registration for details.

fdOffset: offset value to use in framewise displacement calculation

output_directory : string
    output will be named with this prefix plus a numeric extension.

verbose: boolean

kwargs: keyword args
    extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.

Returns

dict containing the following key/value pairs: motion_corrected: moving image warped to the space of the fixed image; motion_parameters: transforms for each image in the time series; FD: framewise displacement generalized for arbitrary transformations.

Notes

Control extra arguments via kwargs; see ants.registration for details.

Example

>>> import ants
>>> fi = ants.image_read(ants.get_ants_data('ch2'))
>>> mytx = ants.motion_correction( fi )
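
The example above is carried over from ants.motion_correction; as a hedged illustration of calling mc_reg itself on a 4D acquisition (the input filename below is a placeholder), one might write:

    import ants
    import antspymm
    import numpy as np

    bold = ants.image_read('my_bold_4d.nii.gz')   # placeholder 4D acquisition
    moco = antspymm.mc_reg(bold, type_of_transform='SyN', fdOffset=2.0, verbose=True)
    corrected = moco['motion_corrected']          # 4D image resampled to the fixed space
    fd = np.asarray(moco['FD'])                   # framewise displacement per volume
    print('mean FD:', fd.mean(), '; volumes with FD > 1.5:', int((fd > 1.5).sum()))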
def get_data(name=None, force_download=False, version=26, target_extension='.csv'):
3123def get_data( name=None, force_download=False, version=26, target_extension='.csv' ):
3124    """
3125    Get ANTsPyMM data filename
3126
3127    The first time this is called, it will download data to ~/.antspymm.
3128    After, it will just read data from disk.  The ~/.antspymm may need to
3129    be periodically deleted in order to ensure data is current.
3130
3131    Arguments
3132    ---------
3133    name : string
3134        name of data tag to retrieve
3135        Options:
3136            - 'all'
3137
3138    force_download: boolean
3139
3140    version: version of data to download (integer)
3141
3142    Returns
3143    -------
3144    string
3145        filepath of selected data
3146
3147    Example
3148    -------
3149    >>> import antspymm
3150    >>> antspymm.get_data()
3151    """
3152    os.makedirs(DATA_PATH, exist_ok=True)
3153
3154    def mv_subfolder_files(folder, verbose=False):
3155        """
3156        Move files from subfolders to the parent folder.
3157
3158        Parameters
3159        ----------
3160        folder : str
3161            Path to the folder.
3162        verbose : bool, optional
3163            Print information about the files and folders being processed (default is False).
3164
3165        Returns
3166        -------
3167        None
3168        """
3169        import os
3170        import shutil
3171        for root, dirs, files in os.walk(folder):
3172            if verbose:
3173                print(f"Processing directory: {root}")
3174                print(f"Subdirectories: {dirs}")
3175                print(f"Files: {files}")
3176            
3177            for file in files:
3178                if root != folder:
3179                    if verbose:
3180                        print(f"Moving file: {file} from {root} to {folder}")
3181                    shutil.move(os.path.join(root, file), folder)
3182            
3183            for dir in dirs:
3184                if root != folder:
3185                    if verbose:
3186                        print(f"Removing directory: {dir} from {root}")
3187                    shutil.rmtree(os.path.join(root, dir))
3188
3189    def download_data( version ):
3190        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
3191        target_file_name = "16912366.zip"
3192        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
3193            cache_subdir=DATA_PATH, extract = True )
3194        mv_subfolder_files( os.path.expanduser("~/.antspymm"), False )
3195        os.remove( DATA_PATH + target_file_name )
3196
3197    if force_download:
3198        download_data( version = version )
3199
3200
3201    files = []
3202    for fname in os.listdir(DATA_PATH):
3203        if ( fname.endswith(target_extension) ) :
3204            fname = os.path.join(DATA_PATH, fname)
3205            files.append(fname)
3206
3207    if len( files ) == 0 :
3208        download_data( version = version )
3209        for fname in os.listdir(DATA_PATH):
3210            if ( fname.endswith(target_extension) ) :
3211                fname = os.path.join(DATA_PATH, fname)
3212                files.append(fname)
3213
3214
3215    if name == 'all':
3216        return files
3217
3218    datapath = None
3219
3220    for fname in os.listdir(DATA_PATH):
3221        mystem = (Path(fname).resolve().stem)
3222        mystem = (Path(mystem).resolve().stem)
3223        mystem = (Path(mystem).resolve().stem)
3224        if ( name == mystem and fname.endswith(target_extension) ) :
3225            datapath = os.path.join(DATA_PATH, fname)
3226
3227    return datapath

Get ANTsPyMM data filename

The first time this is called, it will download data to ~/.antspymm. Afterwards, it will just read data from disk. The ~/.antspymm directory may need to be deleted periodically in order to ensure the data is current.

Arguments

name : string - name of the data tag to retrieve; 'all' returns every cached file with the target extension.

force_download: boolean

version: version of data to download (integer)

Returns

string - filepath of the selected data

Example

>>> import antspymm
>>> antspymm.get_data()
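
A hedged sketch of typical calls (the specific tag name below is hypothetical):

    import antspymm

    csv_files = antspymm.get_data('all')            # every cached .csv under ~/.antspymm
    one_path  = antspymm.get_data('some_table')     # hypothetical tag; returns None if not found
    antspymm.get_data(force_download=True)          # refresh the cached archive from figshare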
def get_models(version=3, force_download=True):
3230def get_models( version=3, force_download=True ):
3231    """
3232    Get ANTsPyMM data models
3233
3234    force_download: boolean
3235
3236    Returns
3237    -------
3238    None
3239
3240    """
3241    os.makedirs(DATA_PATH, exist_ok=True)
3242
3243    def download_data( version ):
3244        url = "https://figshare.com/ndownloader/articles/21718412/versions/"+str(version)
3245        target_file_name = "21718412.zip"
3246        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
3247            cache_subdir=DATA_PATH, extract = True )
3248        os.remove( DATA_PATH + target_file_name )
3249
3250    if force_download:
3251        download_data( version = version )
3252    return

Get ANTsPyMM data models

force_download: boolean

Returns

None

def get_valid_modalities(long=False, asString=False, qc=False):
625def get_valid_modalities( long=False, asString=False, qc=False ):
626    """
627    return a list of valid modality identifiers used in NRG modality designation
628    and that can be processed by this package.
629
630    long - return the long version
631
632    asString - concat list to string
633    """
634    if long:
635        mymod = ["T1w", "NM2DMT", "rsfMRI", "rsfMRI_LR", "rsfMRI_RL", "rsfMRILR", "rsfMRIRL", "DTI", "DTI_LR","DTI_RL",  "DTILR","DTIRL","T2Flair", "dwi", "dwi_ap", "dwi_pa", "func", "func_ap", "func_pa", "perf", 'pet3d']
636    elif qc:
637        mymod = [ 'T1w', 'T2Flair', 'NM2DMT', 'DTI', 'DTIdwi','DTIb0', 'rsfMRI', "perf", 'pet3d' ]
638    else:
639        mymod = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI", "perf", 'pet3d' ]
640    if not asString:
641        return mymod
642    else:
643        mymodchar=""
644        for x in mymod:
645            mymodchar = mymodchar + " " + str(x)
646        return mymodchar

Return a list of valid modality identifiers that are used in the NRG modality designation and can be processed by this package.

long - return the long version

asString - concatenate the list into a single string
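
A small usage sketch, e.g. to validate a modality string before processing:

    import antspymm

    if 'rsfMRI_LR' in antspymm.get_valid_modalities(long=True):
        print('modality is supported')
    print(antspymm.get_valid_modalities(asString=True))   # space-separated summary of the short list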

def dewarp_imageset( image_list, initial_template=None, iterations=None, padding=0, target_idx=[0], **kwargs):
3256def dewarp_imageset( image_list, initial_template=None,
3257    iterations=None, padding=0, target_idx=[0], **kwargs ):
3258    """
3259    Dewarp a set of images
3260
3261    Makes simplifying heuristic decisions about how to transform an image set
3262    into an unbiased reference space.  Will handle plenty of decisions
3263    automatically so beware.  Computes an average shape space for the images
3264    and transforms them to that space.
3265
3266    Arguments
3267    ---------
3268    image_list : list containing antsImages 2D, 3D or 4D
3269
3270    initial_template : optional
3271
3272    iterations : number of template building iterations
3273
3274    padding:  will pad the images by an integer amount to limit edge effects
3275
3276    target_idx : the target indices for the time series over which we should average;
3277        a list of integer indices into the last axis of the input images.
3278
3279    kwargs : keyword args
3280        arguments passed to ants registration - these must be set explicitly
3281
3282    Returns
3283    -------
3284    a dictionary with the mean image and the list of the transformed images as
3285    well as motion correction parameters for each image in the input list
3286
3287    Example
3288    -------
3289    >>> import antspymm
3290    """
3291    outlist = []
3292    avglist = []
3293    if len(image_list[0].shape) > 3:
3294        imagetype = 3
3295        for k in range(len(image_list)):
3296            for j in range(len(target_idx)):
3297                avglist.append( ants.slice_image( image_list[k], axis=3, idx=target_idx[j] ) )
3298    else:
3299        imagetype = 0
3300        avglist=image_list
3301
3302    pw=[]
3303    for k in range(len(avglist[0].shape)):
3304        pw.append( padding )
3305    for k in range(len(avglist)):
3306        avglist[k] = ants.pad_image( avglist[k], pad_width=pw  )
3307
3308    if initial_template is None:
3309        initial_template = avglist[0] * 0
3310        for k in range(len(avglist)):
3311            initial_template = initial_template + avglist[k]/len(avglist)
3312
3313    if iterations is None:
3314        iterations = 2
3315
3316    btp = ants.build_template(
3317        initial_template=initial_template,
3318        image_list=avglist,
3319        gradient_step=0.5, blending_weight=0.8,
3320        iterations=iterations, verbose=False, **kwargs )
3321
3322    # last - warp all images to this frame
3323    mocoplist = []
3324    mocofdlist = []
3325    reglist = []
3326    for k in range(len(image_list)):
3327        if imagetype == 3:
3328            moco0 = ants.motion_correction( image=image_list[k], fixed=btp, type_of_transform='antsRegistrationSyNRepro[r]' )
3329            mocoplist.append( moco0['motion_parameters'] )
3330            mocofdlist.append( moco0['FD'] )
3331            locavg = ants.slice_image( moco0['motion_corrected'], axis=3, idx=0 ) * 0.0
3332            for j in range(len(target_idx)):
3333                locavg = locavg + ants.slice_image( moco0['motion_corrected'], axis=3, idx=target_idx[j] )
3334            locavg = locavg * 1.0 / len(target_idx)
3335        else:
3336            locavg = image_list[k]
3337        reg = ants.registration( btp, locavg, **kwargs )
3338        reglist.append( reg )
3339        if imagetype == 3:
3340            myishape = image_list[k].shape
3341            mytslength = myishape[ len(myishape) - 1 ]
3342            mywarpedlist = []
3343            for j in range(mytslength):
3344                locimg = ants.slice_image( image_list[k], axis=3, idx = j )
3345                mywarped = ants.apply_transforms( btp, locimg,
3346                    reg['fwdtransforms'] + moco0['motion_parameters'][j], imagetype=0 )
3347                mywarpedlist.append( mywarped )
3348            mywarped = ants.list_to_ndimage( image_list[k], mywarpedlist )
3349        else:
3350            mywarped = ants.apply_transforms( btp, image_list[k], reg['fwdtransforms'], imagetype=imagetype )
3351        outlist.append( mywarped )
3352
3353    return {
3354        'dewarpedmean':btp,
3355        'dewarped':outlist,
3356        'deformable_registrations': reglist,
3357        'FD': mocofdlist,
3358        'motionparameters': mocoplist }

Dewarp a set of images

Makes simplifying heuristic decisions about how to transform an image set into an unbiased reference space. Will handle plenty of decisions automatically so beware. Computes an average shape space for the images and transforms them to that space.

Arguments

image_list : list containing antsImages 2D, 3D or 4D

initial_template : optional

iterations : number of template building iterations

padding: will pad the images by an integer amount to limit edge effects

target_idx : the target indices for the time series over which we should average; a list of integer indices into the last axis of the input images.

kwargs : keyword arguments passed to ants registration - these must be set explicitly

Returns

a dictionary with the mean image and the list of the transformed images as well as motion correction parameters for each image in the input list

Example

>>> import antspymm
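
Because the registration kwargs must be set explicitly, a hedged sketch of a call (the file names are placeholders) could be:

    import ants
    import antspymm

    runs = [ants.image_read('run1_bold.nii.gz'),   # placeholder repeats of one modality
            ants.image_read('run2_bold.nii.gz')]
    dw = antspymm.dewarp_imageset(runs, iterations=2, padding=4, target_idx=[0],
                                  type_of_transform='SyN')   # forwarded to ants registration
    unbiased_mean = dw['dewarpedmean']
    dewarped_runs = dw['dewarped']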
def super_res_mcimage( image, srmodel, truncation=[0.0001, 0.995], poly_order='hist', target_range=[0, 1], isotropic=False, verbose=False):
3361def super_res_mcimage( image,
3362    srmodel,
3363    truncation=[0.0001,0.995],
3364    poly_order='hist',
3365    target_range=[0,1],
3366    isotropic = False,
3367    verbose=False ):
3368    """
3369    Super resolution on a timeseries or multi-channel image
3370
3371    Arguments
3372    ---------
3373    image : an antsImage
3374
3375    srmodel : a tensorflow fully convolutional model
3376
3377    truncation :  quantiles at which we truncate intensities to limit impact of outliers e.g. [0.005,0.995]
3378
3379    poly_order : if not None, will fit a global regression model to map
3380        intensity back to original histogram space; if 'hist' will match
3381        by histogram matching - ants.histogram_match_image
3382
3383    target_range : 2-element tuple
3384        a tuple or array defining the (min, max) of the input image
3385        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
3386        intensity. This range should match the mapping used in the training
3387        of the network.
3388
3389    isotropic : boolean
3390
3391    verbose : boolean
3392
3393    Returns
3394    -------
3395    super resolution version of the image
3396
3397    Example
3398    -------
3399    >>> import antspymm
3400    """
3401    idim = image.dimension
3402    ishape = image.shape
3403    nTimePoints = ishape[idim - 1]
3404    mcsr = list()
3405    for k in range(nTimePoints):
3406        if verbose and (( k % 5 ) == 0 ):
3407            mycount = round(k / nTimePoints * 100)
3408            print(mycount, end="%.", flush=True)
3409        temp = ants.slice_image( image, axis=idim - 1, idx=k )
3410        temp = ants.iMath( temp, "TruncateIntensity", truncation[0], truncation[1] )
3411        mysr = antspynet.apply_super_resolution_model_to_image( temp, srmodel,
3412            target_range = target_range )
3413        if poly_order is not None:
3414            bilin = ants.resample_image_to_target( temp, mysr )
3415            if poly_order == 'hist':
3416                mysr = ants.histogram_match_image( mysr, bilin )
3417            else:
3418                mysr = ants.regression_match_image( mysr, bilin, poly_order = poly_order )
3419        if isotropic:
3420            mysr = down2iso( mysr )
3421        if k == 0:
3422            upshape = list()
3423            for j in range(len(ishape)-1):
3424                upshape.append( mysr.shape[j] )
3425            upshape.append( ishape[ idim-1 ] )
3426            if verbose:
3427                print("SR will be of voxel size:" + str(upshape) )
3428        mcsr.append( mysr )
3429
3430    upshape = list()
3431    for j in range(len(ishape)-1):
3432        upshape.append( mysr.shape[j] )
3433    upshape.append( ishape[ idim-1 ] )
3434    if verbose:
3435        print("SR will be of voxel size:" + str(upshape) )
3436
3437    imageup = ants.resample_image( image, upshape, use_voxels = True )
3438    if verbose:
3439        print("Done")
3440
3441    return ants.list_to_ndimage( imageup, mcsr )

Super resolution on a timeseries or multi-channel image

Arguments

image : an antsImage

srmodel : a tensorflow fully convolutional model

truncation : quantiles at which we truncate intensities to limit impact of outliers e.g. [0.005,0.995]

poly_order : if not None, will fit a global regression model to map intensity back to original histogram space; if 'hist' will match by histogram matching - ants.histogram_match_image

target_range : a 2-element tuple or array defining the (min, max) of the input image (e.g., [-127.5, 127.5] or [0, 1]). Output images will be scaled back to the original intensity. This range should match the mapping used in the training of the network.

isotropic : boolean

verbose : boolean

Returns

super resolution version of the image

Example

>>> import antspymm
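
A hedged sketch, assuming a pretrained super-resolution model is available as a local h5 file (both paths below are placeholders):

    import ants
    import tensorflow as tf
    import antspymm

    srmodel = tf.keras.models.load_model('sr_model.h5', compile=False)  # placeholder model file
    dwi = ants.image_read('dwi_4d.nii.gz')                              # placeholder 4D image
    dwi_sr = antspymm.super_res_mcimage(dwi, srmodel, poly_order='hist',
                                        target_range=[0, 1], isotropic=True, verbose=True)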
def segment_timeseries_by_meanvalue(image, quantile=0.995):
3485def segment_timeseries_by_meanvalue( image, quantile = 0.995 ):
3486    """
3487    Identify indices of a time series where we assume there is a different mean
3488    intensity over the volumes.  The indices of volumes with higher and lower
3489    intensities is returned.  Can be used to automatically identify B0 volumes
3490    in DWI timeseries.
3491
3492    Arguments
3493    ---------
3494    image : an antsImage holding B0 and DWI
3495
3496    quantile : a quantile for splitting the indices of the volume - should be greater than 0.5
3497
3498    Returns
3499    -------
3500    dictionary holding the two sets of indices
3501
3502    Example
3503    -------
3504    >>> import antspymm
3505    """
3506    ishape = image.shape
3507    lastdim = len(ishape)-1
3508    meanvalues = list()
3509    for x in range(ishape[lastdim]):
3510        meanvalues.append(  ants.slice_image( image, axis=lastdim, idx=x ).mean() )
3511    myhiq = np.quantile( meanvalues, quantile )
3512    myloq = np.quantile( meanvalues, 1.0 - quantile )
3513    lowerindices = list()
3514    higherindices = list()
3515    for x in range(len(meanvalues)):
3516        hiabs = abs( meanvalues[x] - myhiq )
3517        loabs = abs( meanvalues[x] - myloq )
3518        if hiabs < loabs:
3519            higherindices.append(x)
3520        else:
3521            lowerindices.append(x)
3522
3523    return {
3524    'lowermeans':lowerindices,
3525    'highermeans':higherindices }

Identify indices of a time series where we assume there is a different mean intensity over the volumes. The indices of volumes with higher and lower intensities are returned. Can be used to automatically identify B0 volumes in DWI timeseries.

Arguments

image : an antsImage holding B0 and DWI

quantile : a quantile for splitting the indices of the volume - should be greater than 0.5

Returns

dictionary holding the two sets of indices

Example

>>> import antspymm
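
For example, one might split a DWI series into B0 and diffusion-weighted indices and average the B0 volumes (the input path is a placeholder):

    import ants
    import antspymm

    dwi = ants.image_read('dwi_4d.nii.gz')                 # placeholder 4D DWI
    idx = antspymm.segment_timeseries_by_meanvalue(dwi)
    b0_idx, weighted_idx = idx['highermeans'], idx['lowermeans']
    b0_avg = ants.slice_image(dwi, axis=3, idx=b0_idx[0]) * 0.0
    for k in b0_idx:                                        # naive average of the brighter volumes
        b0_avg = b0_avg + ants.slice_image(dwi, axis=3, idx=k) / len(b0_idx)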
def get_average_rsf(x, min_t=10, max_t=35):
3528def get_average_rsf( x, min_t=10, max_t=35 ):
3529    """
3530    automatically generates the average bold image with quick registration
3531
3532    returns:
3533        avg_bold
3534    """
3535    output_directory = tempfile.mkdtemp()
3536    ofn = output_directory + "/w"
3537    bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
3538    oavg = ants.slice_image( x, axis=3, idx=0 )
3539    if x.shape[3] <= min_t:
3540        min_t=0
3541    if x.shape[3] <= max_t:
3542        max_t=x.shape[3]
3543    for myidx in range(min_t,max_t):
3544        b0 = ants.slice_image( x, axis=3, idx=myidx)
3545        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3546    bavg = ants.iMath( bavg, 'Normalize' )
3547    oavg = ants.image_clone( bavg )
3548    bavg = oavg * 0.0
3549    for myidx in range(min_t,max_t):
3550        b0 = ants.slice_image( x, axis=3, idx=myidx)
3551        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3552    import shutil
3553    shutil.rmtree(output_directory, ignore_errors=True )
3554    bavg = ants.iMath( bavg, 'Normalize' )
3555    return bavg
3556    # return ants.n4_bias_field_correction(bavg, mask=ants.get_mask( bavg ) )

Automatically generates the average BOLD image with quick registration.

returns: avg_bold
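
A hedged usage sketch, using the average as the fixed image for motion correction (the path is a placeholder):

    import ants
    import antspymm

    bold = ants.image_read('rsfMRI_4d.nii.gz')     # placeholder 4D BOLD
    avg_bold = antspymm.get_average_rsf(bold, min_t=10, max_t=35)
    moco = ants.motion_correction(bold, fixed=avg_bold,
                                  type_of_transform='antsRegistrationSyNRepro[r]')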

def get_average_dwi_b0(x, fixed_b0=None, fixed_dwi=None, fast=False):
3559def get_average_dwi_b0( x, fixed_b0=None, fixed_dwi=None, fast=False ):
3560    """
3561    automatically generates the average b0 and dwi and outputs both;
3562    maps dwi to b0 space at end.
3563
3564    x : input image
3565
3566    fixed_b0 : alternative reference space
3567
3568    fixed_dwi : alternative reference space
3569
3570    fast : boolean
3571
3572    returns:
3573        avg_b0, avg_dwi
3574    """
3575    output_directory = tempfile.mkdtemp()
3576    ofn = output_directory + "/w"
3577    temp = segment_timeseries_by_meanvalue( x )
3578    b0_idx = temp['highermeans']
3579    non_b0_idx = temp['lowermeans']
3580    if ( fixed_b0 is None and fixed_dwi is None ) or fast:
3581        xavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
3582        bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
3583        fixed_b0_use = ants.slice_image( x, axis=3, idx=b0_idx[0] )
3584        fixed_dwi_use = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
3585    else:
3586        temp_b0 = ants.slice_image( x, axis=3, idx=b0_idx[0] )
3587        temp_dwi = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
3588        xavg = fixed_b0 * 0.0
3589        bavg = fixed_b0 * 0.0
3590        tempreg = ants.registration( fixed_b0, temp_b0,'antsRegistrationSyNRepro[r]')
3591        fixed_b0_use = tempreg['warpedmovout']
3592        fixed_dwi_use = ants.apply_transforms( fixed_b0, temp_dwi, tempreg['fwdtransforms'] )
3593    for myidx in range(x.shape[3]):
3594        b0 = ants.slice_image( x, axis=3, idx=myidx)
3595        if not fast:
3596            if not myidx in b0_idx:
3597                xavg = xavg + ants.registration(fixed_dwi_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3598            else:
3599                bavg = bavg + ants.registration(fixed_b0_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3600        else:
3601            if not myidx in b0_idx:
3602                xavg = xavg + b0
3603            else:
3604                bavg = bavg + b0
3605    bavg = ants.iMath( bavg, 'Normalize' )
3606    xavg = ants.iMath( xavg, 'Normalize' )
3607    import shutil
3608    shutil.rmtree(output_directory, ignore_errors=True )
3609    avgb0=ants.n4_bias_field_correction(bavg)
3610    avgdwi=ants.n4_bias_field_correction(xavg)
3611    avgdwi=ants.registration( avgb0, avgdwi, 'antsRegistrationSyNRepro[r]' )['warpedmovout']
3612    return avgb0, avgdwi

automatically generates the average b0 and dwi and outputs both; maps dwi to b0 space at end.

x : input image

fixed_b0 : alternative reference space

fixed_dwi : alternative reference space

fast : boolean

returns: avg_b0, avg_dwi
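
A hedged sketch; the resulting pair can serve as the reference_B0 / reference_DWI spaces expected by joint_dti_recon further below (the path is a placeholder):

    import ants
    import antspymm

    dwi = ants.image_read('dwi_4d.nii.gz')                  # placeholder 4D DWI
    avg_b0, avg_dwi = antspymm.get_average_dwi_b0(dwi, fast=True)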

def dti_template( b_image_list=None, w_image_list=None, iterations=5, gradient_step=0.5, mask_csf=False, average_both=True, verbose=False):
3614def dti_template(
3615    b_image_list=None,
3616    w_image_list=None,
3617    iterations=5,
3618    gradient_step=0.5,
3619    mask_csf=False,
3620    average_both=True,
3621    verbose=False
3622):
3623    """
3624    two channel version of build_template
3625
3626    returns:
3627        avg_b0, avg_dwi
3628    """
3629    output_directory = tempfile.mkdtemp()
3630    mydeftx = tempfile.NamedTemporaryFile(delete=False,dir=output_directory).name
3631    tmp = tempfile.NamedTemporaryFile(delete=False,dir=output_directory,suffix=".nii.gz")
3632    wavgfn = tmp.name
3633    tmp2 = tempfile.NamedTemporaryFile(delete=False,dir=output_directory)
3634    comptx = tmp2.name
3635    weights = np.repeat(1.0 / len(b_image_list), len(b_image_list))
3636    weights = [x / sum(weights) for x in weights]
3637    w_initial_template = w_image_list[0]
3638    b_initial_template = b_image_list[0]
3639    b_initial_template = ants.iMath(b_initial_template,"Normalize")
3640    w_initial_template = ants.iMath(w_initial_template,"Normalize")
3641    if mask_csf:
3642        bcsf0 = ants.threshold_image( b_image_list[0],"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
3643        bcsf1 = ants.threshold_image( b_image_list[1],"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
3644    else:
3645        bcsf0 = b_image_list[0] * 0 + 1
3646        bcsf1 = b_image_list[1] * 0 + 1
3647    bavg = b_initial_template.clone() * bcsf0
3648    wavg = w_initial_template.clone() * bcsf0
3649    bcsf = [ bcsf0, bcsf1 ]
3650    for i in range(iterations):
3651        for k in range(len(w_image_list)):
3652            fimg=wavg
3653            mimg=w_image_list[k] * bcsf[k]
3654            fimg2=bavg
3655            mimg2=b_image_list[k] * bcsf[k]
3656            w1 = ants.registration(
3657                fimg, mimg, type_of_transform='antsRegistrationSyNQuickRepro[s]',
3658                    multivariate_extras= [ [ "CC", fimg2, mimg2, 1, 2 ]],
3659                    outprefix=mydeftx,
3660                    verbose=0 )
3661            txname = ants.apply_transforms(wavg, wavg,
3662                w1["fwdtransforms"], compose=comptx )
3663            if k == 0:
3664                txavg = ants.image_read(txname) * weights[k]
3665                wavgnew = ants.apply_transforms( wavg,
3666                    w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3667                bavgnew = ants.apply_transforms( wavg,
3668                    b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3669            else:
3670                txavg = txavg + ants.image_read(txname) * weights[k]
3671                if i >= (iterations-2) and average_both:
3672                    wavgnew = wavgnew+ants.apply_transforms( wavg,
3673                        w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3674                    bavgnew = bavgnew+ants.apply_transforms( wavg,
3675                        b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3676        if verbose:
3677            print("iteration:",str(i),str(txavg.abs().mean()))
3678        wscl = (-1.0) * gradient_step
3679        txavg = txavg * wscl
3680        ants.image_write( txavg, wavgfn )
3681        wavg = ants.apply_transforms(wavg, wavgnew, wavgfn).iMath("Normalize")
3682        bavg = ants.apply_transforms(bavg, bavgnew, wavgfn).iMath("Normalize")
3683    import shutil
3684    shutil.rmtree( output_directory, ignore_errors=True )
3685    if verbose:
3686        print("done")
3687    return bavg, wavg

two channel version of build_template

returns: avg_b0, avg_dwi
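
A hedged sketch of building the two-channel template from per-subject average b0/dwi pairs (the input images are placeholders, e.g. outputs of get_average_dwi_b0):

    import antspymm

    b0_list  = [avg_b0_subj1, avg_b0_subj2]     # placeholder 3D average b0 images
    dwi_list = [avg_dwi_subj1, avg_dwi_subj2]   # matched 3D average dwi images
    b0_template, dwi_template = antspymm.dti_template(
        b_image_list=b0_list, w_image_list=dwi_list,
        iterations=5, gradient_step=0.5, verbose=True)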

def t1_based_dwi_brain_extraction( t1w_head, t1w, dwi, b0_idx=None, transform='antsRegistrationSyNRepro[r]', deform=None, verbose=False):
3709def t1_based_dwi_brain_extraction(
3710    t1w_head,
3711    t1w,
3712    dwi,
3713    b0_idx = None,
3714    transform='antsRegistrationSyNRepro[r]',
3715    deform=None,
3716    verbose=False
3717):
3718    """
3719    Map a t1-based brain extraction to b0 and return a mask and average b0
3720
3721    Arguments
3722    ---------
3723    t1w_head : an antsImage of the whole head
3724
3725    t1w : an antsImage probably but not necessarily T1-weighted
3726
3727    dwi : an antsImage holding B0 and DWI
3728
3729    b0_idx : the indices of the B0; if None, use segment_timeseries_by_meanvalue to guess
3730
3731    transform : string Rigid or other ants.registration tx type
3732
3733    deform : follow up transform with deformation
3734
3735    Returns
3736    -------
3737    dictionary holding the avg_b0 and its mask
3738
3739    Example
3740    -------
3741    >>> import antspymm
3742    """
3743    t1w_use = ants.iMath( t1w, "Normalize" )
3744    t1bxt = ants.threshold_image( t1w_use, 0.05, 1 ).iMath("FillHoles")
3745    if b0_idx is None:
3746        b0_idx = segment_timeseries_by_meanvalue( dwi )['highermeans']
3747    # first get the average b0
3748    if len( b0_idx ) > 1:
3749        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] ).iMath("Normalize")
3750        for n in range(1,len(b0_idx)):
3751            temp = ants.slice_image( dwi, axis=3, idx=b0_idx[n] )
3752            reg = ants.registration( b0_avg, temp, 'antsRegistrationSyNRepro[r]' )
3753            b0_avg = b0_avg + ants.iMath( reg['warpedmovout'], "Normalize")
3754    else:
3755        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] )
3756    b0_avg = ants.iMath(b0_avg,"Normalize")
3757    reg = tra_initializer( b0_avg, t1w, n_simulations=12,   verbose=verbose )
3758    if deform is not None:
3759        reg = ants.registration( b0_avg, t1w,
3760            'SyNOnly',
3761            total_sigma=0.5,
3762            initial_transform=reg['fwdtransforms'][0],
3763            verbose=False )
3764    outmsk = ants.apply_transforms( b0_avg, t1bxt, reg['fwdtransforms'], interpolator='linear').threshold_image( 0.5, 1.0 )
3765    return  {
3766    'b0_avg':b0_avg,
3767    'b0_mask':outmsk }

Map a t1-based brain extraction to b0 and return a mask and average b0

Arguments

t1w_head : an antsImage of the whole head

t1w : an antsImage probably but not necessarily T1-weighted

dwi : an antsImage holding B0 and DWI

b0_idx : the indices of the B0; if None, use segment_timeseries_by_meanvalue to guess

transform : string Rigid or other ants.registration tx type

deform : follow up transform with deformation

Returns

dictionary holding the avg_b0 and its mask

Example

>>> import antspymm
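
A hedged usage sketch (the file names are placeholders):

    import ants
    import antspymm

    t1_head  = ants.image_read('t1w_head.nii.gz')      # placeholder whole-head T1w
    t1_brain = ants.image_read('t1w_brain.nii.gz')     # placeholder brain-extracted T1w
    dwi      = ants.image_read('dwi_4d.nii.gz')        # placeholder 4D DWI
    bxt = antspymm.t1_based_dwi_brain_extraction(t1_head, t1_brain, dwi, deform=True)
    b0_avg, b0_mask = bxt['b0_avg'], bxt['b0_mask']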
def mc_denoise(x, ratio=0.5):
3769def mc_denoise( x, ratio = 0.5 ):
3770    """
3771    ants denoising for timeseries (4D)
3772
3773    Arguments
3774    ---------
3775    x : an antsImage 4D
3776
3777    ratio : weight between 1 and 0 - lower weights bring result closer to initial image
3778
3779    Returns
3780    -------
3781    denoised time series
3782
3783    """
3784    dwpimage = []
3785    for myidx in range(x.shape[3]):
3786        b0 = ants.slice_image( x, axis=3, idx=myidx)
3787        dnzb0 = ants.denoise_image( b0, p=1,r=1,noise_model='Gaussian' )
3788        dwpimage.append( dnzb0 * ratio + b0 * (1.0-ratio) )
3789    return ants.list_to_ndimage( x, dwpimage )

ants denoising for timeseries (4D)

Arguments

x : an antsImage 4D

ratio : weight between 1 and 0 - lower weights bring result closer to initial image

Returns

denoised time series
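
For instance (placeholder path):

    import ants
    import antspymm

    bold = ants.image_read('rsfMRI_4d.nii.gz')        # placeholder 4D image
    bold_dn = antspymm.mc_denoise(bold, ratio=0.5)    # ratio=0 keeps the input, 1 uses the fully denoised frames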

def tsnr(x, mask, indices=None):
3791def tsnr( x, mask, indices=None ):
3792    """
3793    3D temporal snr image from a 4D time series image ... the matrix is normalized to range of 0,1
3794
3795    x: image
3796
3797    mask : mask
3798
3799    indices: indices to use
3800
3801    returns a 3D image
3802    """
3803    M = ants.timeseries_to_matrix( x, mask )
3804    M = M - M.min()
3805    M = M / M.max()
3806    if indices is not None:
3807        M=M[indices,:]
3808    stdM = np.std(M, axis=0 )
3809    stdM[np.isnan(stdM)] = 0
3810    tt = round( 0.975*100 )
3811    threshold_std = np.percentile( stdM, tt )
3812    tsnrimage = ants.make_image( mask, stdM )
3813    return tsnrimage

3D temporal SNR image from a 4D time series image; the matrix is normalized to the range [0, 1].

x: image

mask : mask

indices: indices to use

returns a 3D image

def dvars(x, mask, indices=None):
3815def dvars( x,  mask, indices=None ):
3816    """
3817    dvars on a time series image ... the matrix is normalized to range of 0,1
3818
3819    x: image
3820
3821    mask : mask
3822
3823    indices: indices to use
3824
3825    returns an array
3826    """
3827    M = ants.timeseries_to_matrix( x, mask )
3828    M = M - M.min()
3829    M = M / M.max()
3830    if indices is not None:
3831        M=M[indices,:]
3832    DVARS = np.zeros( M.shape[0] )
3833    for i in range(1, M.shape[0] ):
3834        vecdiff = M[i-1,:] - M[i,:]
3835        DVARS[i] = np.sqrt( ( vecdiff * vecdiff ).mean() )
3836    DVARS[0] = DVARS.mean()
3837    return DVARS

DVARS on a time series image; the matrix is normalized to the range [0, 1].

x: image

mask : mask

indices: indices to use

returns an array

def slice_snr(x, background_mask, foreground_mask, indices=None):
3840def slice_snr( x,  background_mask, foreground_mask, indices=None ):
3841    """
3842    slice-wise SNR on a time series image
3843
3844    x: image
3845
3846    background_mask : mask - maybe CSF or background or dilated brain mask minus original brain mask
3847
3848    foreground_mask : mask - maybe cortex or WM or brain mask
3849
3850    indices: indices to use
3851
3852    returns an array
3853    """
3854    xuse=ants.iMath(x,"Normalize")
3855    MB = ants.timeseries_to_matrix( xuse, background_mask )
3856    MF = ants.timeseries_to_matrix( xuse, foreground_mask )
3857    if indices is not None:
3858        MB=MB[indices,:]
3859        MF=MF[indices,:]
3860    ssnr = np.zeros( MB.shape[0] )
3861    for i in range( MB.shape[0] ):
3862        ssnr[i]=MF[i,:].mean()/MB[i,:].std()
3863    ssnr[np.isnan(ssnr)] = 0
3864    return ssnr

slice-wise SNR on a time series image

x: image

background_mask : mask - maybe CSF or background or dilated brain mask minus original brain mask

foreground_mask : mask - maybe cortex or WM or brain mask

indices: indices to use

returns an array
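
The three QC helpers above (tsnr, dvars, slice_snr) can be combined on a motion-corrected series; a hedged sketch with illustrative masks (the path is a placeholder):

    import ants
    import numpy as np
    import antspymm

    bold = ants.image_read('rsfMRI_moco.nii.gz')        # placeholder motion-corrected 4D image
    avg = ants.get_average_of_timeseries(bold)
    brain = ants.get_mask(avg)
    background = ants.iMath(brain, 'MD', 10) - brain    # dilated brain mask minus brain mask
    tsnr_img  = antspymm.tsnr(bold, brain)
    dvars_arr = antspymm.dvars(bold, brain)
    ssnr_arr  = antspymm.slice_snr(bold, background, brain)
    print('mean DVARS:', np.mean(dvars_arr), '; mean slice SNR:', np.mean(ssnr_arr))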

def impute_fa(fa, md):
3867def impute_fa( fa, md ):
3868    """
3869    impute bad values in dti, fa, md
3870    """
3871    def imputeit( x, fa ):
3872        badfa=ants.threshold_image(fa,1,1)
3873        if badfa.max() == 1:
3874            temp=ants.image_clone(x)
3875            temp[badfa==1]=0
3876            temp=ants.iMath(temp,'GD',2)
3877            x[ badfa==1 ]=temp[badfa==1]
3878        return x
3879    md=imputeit( md, fa )
3880    fa=imputeit( ants.image_clone(fa), fa )
3881    return fa, md

impute bad values in dti, fa, md

def trim_dti_mask(fa, mask, param=4.0):
3883def trim_dti_mask( fa, mask, param=4.0 ):
3884    """
3885    trim the dti mask to get rid of bright fa rim
3886
3887    this function erodes the famask by param amount then segments the rim into
3888    bright and less bright parts.  the bright parts are trimmed from the mask
3889    and the remaining edges are cleaned up a bit with closing.
3890
3891    param: closing radius unit is in physical space
3892    """
3893    spacing = ants.get_spacing(mask)
3894    spacing_product = np.prod( spacing )
3895    spcmin = min( spacing )
3896    paramVox = int(np.round( param / spcmin ))
3897    trim_mask = ants.image_clone( mask )
3898    trim_mask = ants.iMath( trim_mask, "FillHoles" )
3899    edgemask = trim_mask - ants.iMath( trim_mask, "ME", paramVox )
3900    maxk=4
3901    edgemask = ants.threshold_image( fa * edgemask, "Otsu", maxk )
3902    edgemask = ants.threshold_image( edgemask, maxk-1, maxk )
3903    trim_mask[edgemask >= 1 ]=0
3904    trim_mask = ants.iMath(trim_mask,"ME",paramVox-1)
3905    trim_mask = ants.iMath(trim_mask,'GetLargestComponent')
3906    trim_mask = ants.iMath(trim_mask,"MD",paramVox-1)
3907    return trim_mask

Trim the DTI mask to get rid of the bright FA rim.

This function erodes the FA mask by the param amount, then segments the rim into bright and less-bright parts. The bright parts are trimmed from the mask and the remaining edges are cleaned up a bit with closing.

param: closing radius; the unit is physical space.
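
A hedged sketch, assuming fa and mask come from a prior reconstruction such as dipy_dti_recon below:

    import antspymm

    trimmed_mask = antspymm.trim_dti_mask(fa, mask, param=4.0)   # param is a radius in physical units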

def dipy_dti_recon( image, bvalsfn, bvecsfn, mask=None, b0_idx=None, mask_dilation=2, mask_closing=5, fit_method='WLS', trim_the_mask=2.0, diffusion_model='DTI', verbose=False):
4271def dipy_dti_recon(
4272    image,
4273    bvalsfn,
4274    bvecsfn,
4275    mask = None,
4276    b0_idx = None,
4277    mask_dilation = 2,
4278    mask_closing = 5,
4279    fit_method='WLS',
4280    trim_the_mask=2.0,
4281    diffusion_model='DTI',
4282    verbose=False ):
4283    """
4284    DiPy DTI reconstruction - building on the DiPy basic DTI example
4285
4286    Arguments
4287    ---------
4288    image : an antsImage holding B0 and DWI
4289
4290    bvalsfn : bvalues  obtained by dipy read_bvals_bvecs or the values themselves
4291
4292    bvecsfn : bvectors obtained by dipy read_bvals_bvecs or the values themselves
4293
4294    mask : brain mask for the DWI/DTI reconstruction; if it is not in the same
4295        space as the image, we will resample directly to the image space.  This
4296        could lead to problems if the inputs are really incorrect.
4297
4298    b0_idx : the indices of the B0; if None, use segment_timeseries_by_bvalue
4299
4300    mask_dilation : integer zero or more dilates the brain mask
4301
4302    mask_closing : integer zero or more closes the brain mask
4303
4304    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel) ... if None, will not reconstruct DTI.
4305
4306    trim_the_mask : float >=0 post-hoc method for trimming the mask
4307
4308    diffusion_model : string
4309        DTI, FreeWater, DKI
4310
4311    verbose : boolean
4312
4313    Returns
4314    -------
4315    dictionary holding the tensorfit, MD, FA and RGB images and motion parameters (optional)
4316
4317    NOTE -- see dipy reorient_bvecs(gtab, affines, atol=1e-2)
4318
4319    NOTE -- if the bvec.shape[0] is smaller than the image.shape[3], we neglect
4320        the trailing image volumes.
4321
4322    Example
4323    -------
4324    >>> import antspymm
4325    """
4326
4327    import dipy.reconst.fwdti as fwdti
4328
4329    if isinstance(bvecsfn, str):
4330        bvals, bvecs = read_bvals_bvecs( bvalsfn , bvecsfn   )
4331    else: # assume we already read them
4332        bvals = bvalsfn.copy()
4333        bvecs = bvecsfn.copy()
4334
4335    if bvals.max() < 1.0:
4336        raise ValueError("DTI recon error: maximum bvalues are too small.")
4337
4338    b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals']
4339
4340    b0 = ants.slice_image( image, axis=3, idx=b0_idx[0] )
4341    bxtmod='bold'
4342    bxtmod='t2'
4343    constant_mask=False
4344    if verbose:
4345        print( np.unique( bvals ), flush=True )
4346    if mask is not None:
4347        if verbose:
4348            print("use set bxt in dipy_dti_recon", flush=True)
4349        constant_mask=True
4350        mask = ants.resample_image_to_target( mask, b0, interp_type='nearestNeighbor')
4351    else:
4352        if verbose:
4353            print("use deep learning bxt in dipy_dti_recon")
4354        mask = antspynet.brain_extraction( b0, bxtmod ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
4355    if mask_closing > 0 and not constant_mask :
4356        mask = ants.morphology( mask, "close", mask_closing ) # good
4357    maskdil = ants.iMath( mask, "MD", mask_dilation )
4358
4359    if verbose:
4360        print("recon dti.TensorModel",flush=True)
4361
4362    bvecs = repair_bvecs( bvecs )
4363    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
4364    mynt=1
4365    threads_env = os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")
4366    if threads_env is not None:
4367        mynt = int(threads_env)
4368    tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, maskdil,
4369                                             num_threads=mynt )
4370    if verbose:
4371        print("recon dti.TensorModel done",flush=True)
4372
4373    # change the brain mask based on high FA values
4374    if trim_the_mask > 0 and fit_method is not None:
4375        mask = trim_dti_mask( FA, mask, trim_the_mask )
4376        tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, maskdil,
4377                                             num_threads=mynt )
4378
4379    return {
4380        'tensormodel' : tenfit,
4381        'MD' : MD1 ,
4382        'FA' : FA ,
4383        'RGB' : RGB,
4384        'dwi_mask':mask,
4385        'bvals':bvals,
4386        'bvecs':bvecs
4387        }

DiPy DTI reconstruction - building on the DiPy basic DTI example

Arguments

image : an antsImage holding B0 and DWI

bvalsfn : bvalues obtained by dipy read_bvals_bvecs or the values themselves

bvecsfn : bvectors obtained by dipy read_bvals_bvecs or the values themselves

mask : brain mask for the DWI/DTI reconstruction; if it is not in the same space as the image, we will resample directly to the image space. This could lead to problems if the inputs are really incorrect.

b0_idx : the indices of the B0; if None, use segment_timeseries_by_bvalue

mask_dilation : integer zero or more dilates the brain mask

mask_closing : integer zero or more closes the brain mask

fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel) ... if None, will not reconstruct DTI.

trim_the_mask : float >=0 post-hoc method for trimming the mask

diffusion_model : string DTI, FreeWater, DKI

verbose : boolean

Returns

dictionary holding the tensorfit, MD, FA and RGB images and motion parameters (optional)

NOTE -- see dipy reorient_bvecs(gtab, affines, atol=1e-2)

NOTE -- if the bvec.shape[0] is smaller than the image.shape[3], we neglect the trailing image volumes.

Example

>>> import antspymm
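
A hedged usage sketch (the image and bval/bvec paths are placeholders):

    import ants
    import antspymm

    dwi = ants.image_read('dwi_4d.nii.gz')               # placeholder motion-corrected 4D DWI
    rec = antspymm.dipy_dti_recon(dwi, 'dwi.bval', 'dwi.bvec',
                                  fit_method='WLS', diffusion_model='DTI', verbose=True)
    fa, md, dwi_mask = rec['FA'], rec['MD'], rec['dwi_mask']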
def concat_dewarp( refimg, originalDWI, physSpaceDWI, dwpTx, motion_parameters, motion_correct=True, verbose=False):
4390def concat_dewarp(
4391        refimg,
4392        originalDWI,
4393        physSpaceDWI,
4394        dwpTx,
4395        motion_parameters,
4396        motion_correct=True,
4397        verbose=False ):
4398    """
4399    Apply concatenated motion correction and dewarping transforms to a timeseries image.
4400
4401    Arguments
4402    ---------
4403
4404    refimg : an antsImage defining the reference domain (3D)
4405
4406    originalDWI : the antsImage in original (not interpolated space) (4D)
4407
4408    physSpaceDWI : ants antsImage defining the physical space of the mapping (4D)
4409
4410    dwpTx : dewarping transform
4411
4412    motion_parameters : previously computed list of motion parameters
4413
4414    motion_correct : boolean
4415
4416    verbose : boolean
4417
4418    """
4419    # apply the dewarping tx to the original dwi and reconstruct again
4420    # NOTE: refimg must be in the same space for this to work correctly
4421    # due to the use of ants.list_to_ndimage( originalDWI, dwpimage )
4422    dwpimage = []
4423    for myidx in range(originalDWI.shape[3]):
4424        b0 = ants.slice_image( originalDWI, axis=3, idx=myidx)
4425        concatx = dwpTx.copy()
4426        if motion_correct:
4427            concatx = concatx + motion_parameters[myidx]
4428        if verbose and myidx == 0:
4429            print("dwp parameters")
4430            print( dwpTx )
4431            print("Motion parameters")
4432            print( motion_parameters[myidx] )
4433            print("concat parameters")
4434            print(concatx)
4435        warpedb0 = ants.apply_transforms( refimg, b0, concatx,
4436            interpolator='nearestNeighbor' )
4437        dwpimage.append( warpedb0 )
4438    return ants.list_to_ndimage( physSpaceDWI, dwpimage )

Apply concatenated motion correction and dewarping transforms to a timeseries image.

Arguments

refimg : an antsImage defining the reference domain (3D)

originalDWI : the antsImage in original (not interpolated space) (4D)

physSpaceDWI : ants antsImage defining the physical space of the mapping (4D)

dwpTx : dewarping transform

motion_parameters : previously computed list of motion parameters

motion_correct : boolean

verbose : boolean
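
A hedged sketch of composing a previously computed dewarping transform with per-volume motion parameters (all inputs are placeholders produced by earlier steps, e.g. a template registration and dti_reg):

    import antspymm

    dwi_dewarped = antspymm.concat_dewarp(
        refimg=avg_b0,                          # 3D reference space (placeholder)
        originalDWI=dwi_raw,                    # original, non-interpolated 4D DWI (placeholder)
        physSpaceDWI=dwi_raw,
        dwpTx=dewarp_reg['fwdtransforms'],      # dewarping transform list (placeholder)
        motion_parameters=moco['motion_parameters'],
        motion_correct=True)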

def joint_dti_recon( img_LR, bval_LR, bvec_LR, jhu_atlas, jhu_labels, reference_B0, reference_DWI, srmodel=None, img_RL=None, bval_RL=None, bvec_RL=None, t1w=None, brain_mask=None, motion_correct=None, dewarp_modality='FA', denoise=False, fit_method='WLS', impute=False, censor=True, diffusion_model='DTI', verbose=False):
4441def joint_dti_recon(
4442    img_LR,
4443    bval_LR,
4444    bvec_LR,
4445    jhu_atlas,
4446    jhu_labels,
4447    reference_B0,
4448    reference_DWI,
4449    srmodel = None,
4450    img_RL = None,
4451    bval_RL = None,
4452    bvec_RL = None,
4453    t1w = None,
4454    brain_mask = None,
4455    motion_correct = None,
4456    dewarp_modality = 'FA',
4457    denoise=False,
4458    fit_method='WLS',
4459    impute = False,
4460    censor = True,
4461    diffusion_model = 'DTI',
4462    verbose = False ):
4463    """
4464    1. pass in subject data and 1mm JHU atlas/labels
4465    2. perform initial LR, RL reconstruction (2nd is optional) and motion correction (optional)
4466    3. dewarp the images using dewarp_modality or T1w
4467    4. apply dewarping to the original data
4468        ===> may want to apply SR at this step
4469    5. reconstruct DTI again
4470    6. label images and do registration
4471    7. return relevant outputs
4472
4473    NOTE: RL images are optional; should pass t1w in this case.
4474
4475    Arguments
4476    ---------
4477
4478    img_LR : an antsImage holding B0 and DWI LR acquisition
4479
4480    bval_LR : bvalue filename LR
4481
4482    bvec_LR : bvector filename LR
4483
4484    jhu_atlas : atlas FA image
4485
4486    jhu_labels : atlas labels
4487
4488    reference_B0 : the "target" B0 image space
4489
4490    reference_DWI : the "target" DW image space
4491
4492    srmodel : optional h5 (tensorflow) model
4493
4494    img_RL : an antsImage holding B0 and DWI RL acquisition
4495
4496    bval_RL : bvalue filename RL
4497
4498    bvec_RL : bvector filename RL
4499
4500    t1w : antsimage t1w neuroimage (brain-extracted)
4501
4502    brain_mask : mask for the DWI - just 3D - provided brain mask should be in reference_B0 space
4503
4504    motion_correct : None Rigid or SyN
4505
4506    dewarp_modality : string average_dwi, average_b0, MD or FA
4507
4508    denoise: boolean
4509
4510    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
4511
4512    impute : boolean
4513
4514    censor : boolean
4515
4516    diffusion_model : string
4517        DTI, FreeWater, DKI
4518
4519    verbose : boolean
4520
4521    Returns
4522    -------
4523    dictionary holding the mean_fa, its summary statistics via JHU labels,
4524        the JHU registration, the JHU labels, the dewarping dictionary and the
4525        dti reconstruction dictionaries.
4526
4527    Example
4528    -------
4529    >>> import antspymm
4530    """
4531
4532    if verbose:
4533        print("Recon DTI on OR images ...")
4534
4535    def fix_dwi_shape( img, bvalfn, bvecfn ):
4536        if isinstance(bvecfn, str):
4537            bvals, bvecs = read_bvals_bvecs( bvalfn , bvecfn   )
4538        else:
4539            bvals = bvalfn
4540            bvecs = bvecfn
4541        if bvecs.shape[0] < img.shape[3]:
4542            imgout = ants.from_numpy( img[:,:,:,0:bvecs.shape[0]] )
4543            imgout = ants.copy_image_info( img, imgout )
4544            return( imgout )
4545        else:
4546            return( img )
4547
4548    img_LR = fix_dwi_shape( img_LR, bval_LR, bvec_LR )
4549    if denoise :
4550        img_LR = mc_denoise( img_LR )
4551    if img_RL is not None:
4552        img_RL = fix_dwi_shape( img_RL, bval_RL, bvec_RL )
4553        if denoise :
4554            img_RL = mc_denoise( img_RL )
4555
4556    brainmaske = None
4557    if brain_mask is not None:
4558        maskInRightSpace = ants.image_physical_space_consistency( brain_mask, reference_B0 )
4559        if not maskInRightSpace :
4560            raise ValueError('not maskInRightSpace ... provided brain mask should be in reference_B0 space')
4561        brainmaske = ants.iMath( brain_mask, "ME", 2 )
4562
4563    if img_RL is not None :
4564        if verbose:
4565            print("img_RL correction")
4566        reg_RL = dti_reg(
4567            img_RL,
4568            avg_b0=reference_B0,
4569            avg_dwi=reference_DWI,
4570            bvals=bval_RL,
4571            bvecs=bvec_RL,
4572            type_of_transform=motion_correct,
4573            brain_mask_eroded=brainmaske,
4574            verbose=True )
4575    else:
4576        reg_RL=None
4577
4578
4579    if verbose:
4580        print("img_LR correction")
4581    reg_LR = dti_reg(
4582            img_LR,
4583            avg_b0=reference_B0,
4584            avg_dwi=reference_DWI,
4585            bvals=bval_LR,
4586            bvecs=bvec_LR,
4587            type_of_transform=motion_correct,
4588            brain_mask_eroded=brainmaske,
4589            verbose=True )
4590
4591    ts_LR_avg = None
4592    ts_RL_avg = None
4593    reg_its = [100,50,10]
4594    img_LRdwp = ants.image_clone( reg_LR[ 'motion_corrected' ] )
4595    if img_RL is not None:
4596        img_RLdwp = ants.image_clone( reg_RL[ 'motion_corrected' ] )
4597        if srmodel is not None:
4598            if verbose:
4599                print("convert img_RL_dwp to img_RL_dwp_SR")
4600            img_RLdwp = super_res_mcimage( img_RLdwp, srmodel, isotropic=True,
4601                        verbose=verbose )
4602    if srmodel is not None:
4603        reg_its = [100] + reg_its
4604        if verbose:
4605            print("convert img_LR_dwp to img_LR_dwp_SR")
4606        img_LRdwp = super_res_mcimage( img_LRdwp, srmodel, isotropic=True,
4607                verbose=verbose )
4608    if verbose:
4609        print("recon after distortion correction", flush=True)
4610
4611    if impute:
4612        print("impute begin", flush=True)
4613        img_LRdwp=impute_dwi( img_LRdwp, verbose=True )
4614        print("impute done", flush=True)
4615    elif censor:
4616        print("censor begin", flush=True)
4617        img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'] = censor_dwi( img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'], verbose=True )
4618        print("censor done", flush=True)
4619    if impute and img_RL is not None:
4620        img_RLdwp=impute_dwi( img_RLdwp, verbose=True )
4621    elif censor and img_RL is not None:
4622        img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'] = censor_dwi( img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'], verbose=True )
4623
4624    if img_RL is not None:
4625        img_LRdwp, bval_LR, bvec_LR = merge_dwi_data(
4626            img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'],
4627            img_RLdwp, reg_RL['bvals'], reg_RL['bvecs']
4628        )
4629    else:
4630        bval_LR=reg_LR['bvals']
4631        bvec_LR=reg_LR['bvecs']
4632
4633    if verbose:
4634        print("final recon", flush=True)
4635        print(img_LRdwp)
4636
4637    recon_LR_dewarp = dipy_dti_recon(
4638            img_LRdwp, bval_LR, bvec_LR,
4639            mask = brain_mask,
4640            fit_method=fit_method,
4641            mask_dilation=0, diffusion_model=diffusion_model, verbose=True )
4642    if verbose:
4643        print("recon done", flush=True)
4644
4645    if img_RL is not None:
4646        fdjoin = [ reg_LR['FD'],
4647                   reg_RL['FD'] ]
4648        framewise_displacement=np.concatenate( fdjoin )
4649    else:
4650        framewise_displacement=reg_LR['FD']
4651
4652    motion_count = ( framewise_displacement > 1.5  ).sum()
4653    reconFA = recon_LR_dewarp['FA']
4654    reconMD = recon_LR_dewarp['MD']
4655
4656    if verbose:
4657        print("JHU reg",flush=True)
4658
4659    OR_FA2JHUreg = ants.registration( reconFA, jhu_atlas,
4660        type_of_transform = 'antsRegistrationSyNQuickRepro[s]', 
4661        reg_iterations=reg_its, verbose=False )
4662    OR_FA_jhulabels = ants.apply_transforms( reconFA, jhu_labels,
4663        OR_FA2JHUreg['fwdtransforms'], interpolator='genericLabel')
4664
4665    df_FA_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
4666        'FA_JHU_labels_edited',
4667        reconFA,
4668        OR_FA_jhulabels)
4669    df_FA_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
4670            {'df_FA_JHU_ORRL' : df_FA_JHU_ORRL},
4671            col_names = ['Mean'] )
4672
4673    df_MD_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
4674        'MD_JHU_labels_edited',
4675        reconMD,
4676        OR_FA_jhulabels)
4677    df_MD_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
4678            {'df_MD_JHU_ORRL' : df_MD_JHU_ORRL},
4679            col_names = ['Mean'] )
4680
4681    temp = segment_timeseries_by_meanvalue( img_LRdwp )
4682    b0_idx = temp['highermeans']
4683    non_b0_idx = temp['lowermeans']
4684
4685    nonbrainmask = ants.iMath( recon_LR_dewarp['dwi_mask'], "MD",2) - recon_LR_dewarp['dwi_mask']
4686    fgmask = ants.threshold_image( reconFA, 0.5 , 1.0).iMath("GetLargestComponent")
4687    bgmask = ants.threshold_image( reconFA, 1e-4 , 0.1)
4688    fa_SNR = 0.0
4689    fa_SNR = mask_snr( reconFA, bgmask, fgmask, bias_correct=False )
4690    fa_evr = antspyt1w.patch_eigenvalue_ratio( reconFA, 512, [16,16,16], evdepth = 0.9, mask=recon_LR_dewarp['dwi_mask'] )
4691
4692    dti_itself = get_dti( reconFA, recon_LR_dewarp['tensormodel'], return_image=True )
4693    return convert_np_in_dict( {
4694        'dti': dti_itself,
4695        'recon_fa':reconFA,
4696        'recon_fa_summary':df_FA_JHU_ORRL_bfwide,
4697        'recon_md':reconMD,
4698        'recon_md_summary':df_MD_JHU_ORRL_bfwide,
4699        'jhu_labels':OR_FA_jhulabels,
4700        'jhu_registration':OR_FA2JHUreg,
4701        'reg_LR':reg_LR,
4702        'reg_RL':reg_RL,
4703        'dtrecon_LR_dewarp':recon_LR_dewarp,
4704        'dwi_LR_dewarped':img_LRdwp,
4705        'bval_unique_count': len(np.unique(bval_LR)),
4706        'bval_LR':bval_LR,
4707        'bvec_LR':bvec_LR,
4708        'bval_RL':bval_RL,
4709        'bvec_RL':bvec_RL,
4710        'b0avg': reference_B0,
4711        'dwiavg': reference_DWI,
4712        'framewise_displacement':framewise_displacement,
4713        'high_motion_count': motion_count,
4714        'tsnr_b0': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
4715        'tsnr_dwi': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
4716        'dvars_b0': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
4717        'dvars_dwi': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
4718        'ssnr_b0': slice_snr( img_LRdwp, bgmask , fgmask, b0_idx),
4719        'ssnr_dwi': slice_snr( img_LRdwp, bgmask, fgmask, non_b0_idx),
4720        'fa_evr': fa_evr,
4721        'fa_SNR': fa_SNR
4722    } )
  1. pass in subject data and 1mm JHU atlas/labels
  2. perform initial LR, RL reconstruction (2nd is optional) and motion correction (optional)
  3. dewarp the images using dewarp_modality or T1w
  4. apply dewarping to the original data ===> may want to apply SR at this step
  5. reconstruct DTI again
  6. label images and do registration
  7. return relevant outputs

NOTE: RL images are optional; should pass t1w in this case.

Arguments

img_LR : an antsImage holding B0 and DWI LR acquisition

bval_LR : bvalue filename LR

bvec_LR : bvector filename LR

jhu_atlas : atlas FA image

jhu_labels : atlas labels

reference_B0 : the "target" B0 image space

reference_DWI : the "target" DW image space

srmodel : optional h5 (tensorflow) model

img_RL : an antsImage holding B0 and DWI RL acquisition

bval_RL : bvalue filename RL

bvec_RL : bvector filename RL

t1w : antsimage t1w neuroimage (brain-extracted)

brain_mask : mask for the DWI - just 3D - provided brain mask should be in reference_B0 space

motion_correct : None Rigid or SyN

dewarp_modality : string average_dwi, average_b0, MD or FA

denoise: boolean

fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)

impute : boolean

censor : boolean

diffusion_model : string DTI, FreeWater, DKI

verbose : boolean

Returns

dictionary holding the mean_fa, its summary statistics via JHU labels, the JHU registration, the JHU labels, the dewarping dictionary and the dti reconstruction dictionaries.

Example

>>> import antspymm
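
As a rough orientation, a call might look like the following sketch (hypothetical file names; keyword names follow the argument list above, and any arguments not shown are left at whatever defaults the function provides):

>>> import ants, antspymm
>>> img_LR = ants.image_read('dwi_LR.nii.gz')            # hypothetical LR B0+DWI acquisition
>>> jhu_fa = ants.image_read('JHU_FA_1mm.nii.gz')        # hypothetical atlas FA
>>> jhu_lab = ants.image_read('JHU_labels_1mm.nii.gz')   # hypothetical atlas labels
>>> b0_ref = ants.image_read('b0_avg.nii.gz')            # hypothetical target B0 space
>>> dwi_ref = ants.image_read('dwi_avg.nii.gz')          # hypothetical target DWI space
>>> t1_brain = ants.image_read('t1_brain.nii.gz')        # hypothetical brain-extracted T1
>>> rec = antspymm.joint_dti_recon(
...     img_LR=img_LR, bval_LR='dwi_LR.bval', bvec_LR='dwi_LR.bvec',
...     jhu_atlas=jhu_fa, jhu_labels=jhu_lab,
...     reference_B0=b0_ref, reference_DWI=dwi_ref,
...     t1w=t1_brain, motion_correct='Rigid', verbose=True)
>>> rec['recon_fa_summary']   # wide data frame of FA summarized over JHU labels
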
def middle_slice_snr(x, background_dilation=5):
4725def middle_slice_snr( x, background_dilation=5 ):
4726    """
4727
4728    Estimate signal to noise ratio (SNR) in 2D mid image from a 3D image.
4729    Estimates noise from a background mask which is a
4730    dilation of the foreground mask minus the foreground mask.
4731    Actually estimates the reciprocal of the coefficient of variation.
4732
4733    Arguments
4734    ---------
4735
4736    x : an antsImage
4737
4738    background_dilation : integer - amount to dilate foreground mask
4739
4740    """
4741    xshp = x.shape
4742    xmidslice = ants.slice_image( x, 2, int( xshp[2]/2 )  )
4743    xmidslice = ants.iMath( xmidslice - xmidslice.min(), "Normalize" )
4744    xmidslice = ants.n3_bias_field_correction( xmidslice )
4745    xmidslice = ants.n3_bias_field_correction( xmidslice )
4746    xmidslicemask = ants.threshold_image( xmidslice, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
4747    xbkgmask = ants.iMath( xmidslicemask, "MD", background_dilation ) - xmidslicemask
4748    signal = (xmidslice[ xmidslicemask == 1] ).mean()
4749    noise = (xmidslice[ xbkgmask == 1] ).std()
4750    return signal / noise

Estimate the signal-to-noise ratio (SNR) on the 2D middle slice of a 3D image. Noise is estimated from a background mask formed by dilating the foreground mask and then subtracting the foreground mask; in effect this estimates the reciprocal of a coefficient of variation.

Arguments

x : an antsImage

background_dilation : integer - amount to dilate foreground mask

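A minimal usage sketch (hypothetical file name):

>>> import ants, antspymm
>>> img = ants.image_read('t1.nii.gz')   # hypothetical path to any 3D image
>>> antspymm.middle_slice_snr(img, background_dilation=5)
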
def foreground_background_snr(x, background_dilation=10, erode_foreground=False):
4752def foreground_background_snr( x, background_dilation=10,
4753        erode_foreground=False):
4754    """
4755
4756    Estimate signal to noise ratio (SNR) in an image.
4757    Estimates noise from a background mask which is a
4758    dilation of the foreground mask minus the foreground mask.
4759    Actually estimates the reciprocal of the coefficient of variation.
4760
4761    Arguments
4762    ---------
4763
4764    x : an antsImage
4765
4766    background_dilation : integer - amount to dilate foreground mask
4767
4768    erode_foreground : boolean - 2nd option which erodes the initial
4769    foregound mask  to create a new foreground mask.  the background
4770    mask is the initial mask minus the eroded mask.
4771
4772    """
4773    xshp = x.shape
4774    xbc = ants.iMath( x - x.min(), "Normalize" )
4775    xbc = ants.n3_bias_field_correction( xbc )
4776    xmask = ants.threshold_image( xbc, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
4777    xbkgmask = ants.iMath( xmask, "MD", background_dilation ) - xmask
4778    fgmask = xmask
4779    if erode_foreground:
4780        fgmask = ants.iMath( xmask, "ME", background_dilation )
4781        xbkgmask = xmask - fgmask
4782    signal = (xbc[ fgmask == 1] ).mean()
4783    noise = (xbc[ xbkgmask == 1] ).std()
4784    return signal / noise

Estimate the signal-to-noise ratio (SNR) of an image. Noise is estimated from a background mask formed by dilating the foreground mask and then subtracting the foreground mask; in effect this estimates the reciprocal of a coefficient of variation.

Arguments

x : an antsImage

background_dilation : integer - amount to dilate foreground mask

erode_foreground : boolean - alternative option that erodes the initial foreground mask to create a new foreground mask; the background mask is then the initial mask minus the eroded mask.

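A minimal usage sketch (hypothetical file name), showing both background definitions:

>>> import ants, antspymm
>>> img = ants.image_read('t1.nii.gz')   # hypothetical path
>>> antspymm.foreground_background_snr(img, background_dilation=10)
>>> antspymm.foreground_background_snr(img, erode_foreground=True)
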
def quantile_snr( x, lowest_quantile=0.01, low_quantile=0.1, high_quantile=0.5, highest_quantile=0.95):
4786def quantile_snr( x,
4787    lowest_quantile=0.01,
4788    low_quantile=0.1,
4789    high_quantile=0.5,
4790    highest_quantile=0.95 ):
4791    """
4792
4793    Estimate signal to noise ratio (SNR) in an image.
4794    Estimates noise from a background mask which is a
4795    dilation of the foreground mask minus the foreground mask.
4796    Actually estimates the reciprocal of the coefficient of variation.
4797
4798    Arguments
4799    ---------
4800
4801    x : an antsImage
4802
4803    lowest_quantile : float value < 1 and > 0
4804
4805    low_quantile : float value < 1 and > 0
4806
4807    high_quantile : float value < 1 and > 0
4808
4809    highest_quantile : float value < 1 and > 0
4810
4811    """
4812    import numpy as np
4813    xshp = x.shape
4814    xbc = ants.iMath( x - x.min(), "Normalize" )
4815    xbc = ants.n3_bias_field_correction( xbc )
4816    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
4817    y = xbc.numpy()
4818    ylowest = np.quantile( y[y>0], lowest_quantile )
4819    ylo = np.quantile( y[y>0], low_quantile )
4820    yhi = np.quantile( y[y>0], high_quantile )
4821    yhiest = np.quantile( y[y>0], highest_quantile )
4822    xbkgmask = ants.threshold_image( xbc, ylowest, ylo )
4823    fgmask = ants.threshold_image( xbc, yhi, yhiest )
4824    signal = (xbc[ fgmask == 1] ).mean()
4825    noise = (xbc[ xbkgmask == 1] ).std()
4826    return signal / noise

Estimate the signal-to-noise ratio (SNR) of an image. Foreground and background masks are defined by intensity quantiles of the bias-corrected image (signal from the high-quantile mask, noise from the low-quantile mask); in effect this estimates the reciprocal of a coefficient of variation.

Arguments

x : an antsImage

lowest_quantile : float value < 1 and > 0

low_quantile : float value < 1 and > 0

high_quantile : float value < 1 and > 0

highest_quantile : float value < 1 and > 0

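A minimal usage sketch (hypothetical file name); the quantiles below are simply the documented defaults:

>>> import ants, antspymm
>>> img = ants.image_read('t1.nii.gz')   # hypothetical path
>>> antspymm.quantile_snr(img, lowest_quantile=0.01, low_quantile=0.1,
...     high_quantile=0.5, highest_quantile=0.95)
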
def mask_snr(x, background_mask, foreground_mask, bias_correct=True):
4828def mask_snr( x, background_mask, foreground_mask, bias_correct=True ):
4829    """
4830
4831    Estimate signal to noise ratio (SNR) in an image using
4832    a user-defined foreground and background mask.
4833    Actually estimates the reciprocal of the coefficient of variation.
4834
4835    Arguments
4836    ---------
4837
4838    x : an antsImage
4839
4840    background_mask : binary antsImage
4841
4842    foreground_mask : binary antsImage
4843
4844    bias_correct : boolean
4845
4846    """
4847    import numpy as np
4848    if foreground_mask.sum() <= 1 or background_mask.sum() <= 1:
4849        return 0
4850    xbc = ants.iMath( x - x.min(), "Normalize" )
4851    if bias_correct:
4852        xbc = ants.n3_bias_field_correction( xbc )
4853    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
4854    signal = (xbc[ foreground_mask == 1] ).mean()
4855    noise = (xbc[ background_mask == 1] ).std()
4856    return signal / noise

Estimate the signal-to-noise ratio (SNR) of an image using user-defined foreground and background masks. In effect this estimates the reciprocal of a coefficient of variation.

Arguments

x : an antsImage

background_mask : binary antsImage

foreground_mask : binary antsImage

bias_correct : boolean

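A minimal usage sketch (hypothetical file name); here the background mask is built as a dilated shell around a rough foreground mask, mirroring what the other SNR helpers above do internally:

>>> import ants, antspymm
>>> img = ants.image_read('t1.nii.gz')          # hypothetical path
>>> fg = ants.get_mask(img)                     # rough foreground mask
>>> bg = ants.iMath(fg, "MD", 10) - fg          # dilated shell serves as background
>>> antspymm.mask_snr(img, bg, fg, bias_correct=True)
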
def dwi_deterministic_tracking( dwi, fa, bvals, bvecs, num_processes=1, mask=None, label_image=None, seed_labels=None, fa_thresh=0.05, seed_density=1, step_size=0.15, peak_indices=None, fit_method='WLS', verbose=False):
4859def dwi_deterministic_tracking(
4860    dwi,
4861    fa,
4862    bvals,
4863    bvecs,
4864    num_processes=1,
4865    mask=None,
4866    label_image = None,
4867    seed_labels = None,
4868    fa_thresh = 0.05,
4869    seed_density = 1,
4870    step_size = 0.15,
4871    peak_indices = None,
4872    fit_method='WLS',
4873    verbose = False ):
4874    """
4875
4876    Performs deterministic tractography from the DWI and returns a tractogram
4877    and path length data frame.
4878
4879    Arguments
4880    ---------
4881
4882    dwi : an antsImage holding DWI acquisition
4883
4884    fa : an antsImage holding FA values
4885
4886    bvals : bvalues
4887
4888    bvecs : bvectors
4889
4890    num_processes : number of subprocesses
4891
4892    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
4893        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
4894
4895    label_image : atlas labels
4896
4897    seed_labels : list of label numbers from the atlas labels
4898
4899    fa_thresh : 0.25 defaults
4900
4901    seed_density : 1 default number of seeds per voxel
4902
4903    step_size : for tracking
4904
4905    peak_indices : pass these in, if they are previously estimated.  otherwise, will
4906        compute on the fly (slow)
4907
4908    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
4909
4910    verbose : boolean
4911
4912    Returns
4913    -------
4914    dictionary holding tracts and stateful object.
4915
4916    Example
4917    -------
4918    >>> import antspymm
4919    """
4920    import os
4921    import re
4922    import nibabel as nib
4923    import numpy as np
4924    import ants
4925    from dipy.io.gradients import read_bvals_bvecs
4926    from dipy.core.gradients import gradient_table
4927    from dipy.tracking import utils
4928    import dipy.reconst.dti as dti
4929    from dipy.segment.clustering import QuickBundles
4930    from dipy.tracking.utils import path_length
4931    if verbose:
4932        print("begin tracking",flush=True)
4933
4934    affine = ants_to_nibabel_affine(dwi)
4935
4936    if isinstance( bvals, str ) or isinstance( bvecs, str ):
4937        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
4938    bvecs = repair_bvecs( bvecs )
4939    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
4940    if mask is None:
4941        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
4942    dwi_data = dwi.numpy()
4943    dwi_mask = mask.numpy() == 1
4944    dti_model = dti.TensorModel(gtab,fit_method=fit_method)
4945    if verbose:
4946        print("begin tracking fit",flush=True)
4947    dti_fit = dti_model.fit(dwi_data, mask=dwi_mask)  # This step may take a while
4948    evecs_img = dti_fit.evecs
4949    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
4950    stopping_criterion = ThresholdStoppingCriterion(fa.numpy(), fa_thresh)
4951    from dipy.data import get_sphere
4952    sphere = get_sphere(name='symmetric362')
4953    from dipy.direction import peaks_from_model
4954    if peak_indices is None:
4955        # problems with multi-threading ...
4956        # see https://github.com/dipy/dipy/issues/2519
4957        if verbose:
4958            print("begin peaks",flush=True)
4959        mynump=1
4960        # if os.getenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"):
4961        #    mynump = os.environ['ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS']
4962        # current_openblas = os.environ.get('OPENBLAS_NUM_THREADS', '')
4963        # current_mkl = os.environ.get('MKL_NUM_THREADS', '')
4964        # os.environ['DIPY_OPENBLAS_NUM_THREADS'] = current_openblas
4965        # os.environ['DIPY_MKL_NUM_THREADS'] = current_mkl
4966        # os.environ['OPENBLAS_NUM_THREADS'] = '1'
4967        # os.environ['MKL_NUM_THREADS'] = '1'
4968        peak_indices = peaks_from_model(
4969            model=dti_model,
4970            data=dwi_data,
4971            sphere=sphere,
4972            relative_peak_threshold=.5,
4973            min_separation_angle=25,
4974            mask=dwi_mask,
4975            npeaks=3, return_odf=False,
4976            return_sh=False,
4977            parallel=int(mynump) > 1,
4978            num_processes=int(mynump)
4979            )
4980        if False:
4981            if 'DIPY_OPENBLAS_NUM_THREADS' in os.environ:
4982                os.environ['OPENBLAS_NUM_THREADS'] = \
4983                    os.environ.pop('DIPY_OPENBLAS_NUM_THREADS', '')
4984                if os.environ['OPENBLAS_NUM_THREADS'] in ['', None]:
4985                    os.environ.pop('OPENBLAS_NUM_THREADS', '')
4986            if 'DIPY_MKL_NUM_THREADS' in os.environ:
4987                os.environ['MKL_NUM_THREADS'] = \
4988                    os.environ.pop('DIPY_MKL_NUM_THREADS', '')
4989                if os.environ['MKL_NUM_THREADS'] in ['', None]:
4990                    os.environ.pop('MKL_NUM_THREADS', '')
4991
4992    if label_image is None or seed_labels is None:
4993        seed_mask = fa.numpy().copy()
4994        seed_mask[seed_mask >= fa_thresh] = 1
4995        seed_mask[seed_mask < fa_thresh] = 0
4996    else:
4997        labels = label_image.numpy()
4998        seed_mask = labels * 0
4999        for u in seed_labels:
5000            seed_mask[ labels == u ] = 1
5001    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
5002    from dipy.tracking.local_tracking import LocalTracking
5003    from dipy.tracking.streamline import Streamlines
5004    if verbose:
5005        print("streamlines begin ...", flush=True)
5006    streamlines_generator = LocalTracking(
5007        peak_indices, stopping_criterion, seeds, affine=affine, step_size=step_size)
5008    streamlines = Streamlines(streamlines_generator)
5009    from dipy.io.stateful_tractogram import Space, StatefulTractogram
5010    from dipy.io.streamline import save_tractogram
5011    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
5012    if verbose:
5013        print("streamlines done", flush=True)
5014    return {
5015          'tractogram': sft,
5016          'streamlines': streamlines,
5017          'peak_indices': peak_indices
5018          }

Performs deterministic tractography from the DWI and returns a tractogram and path length data frame.

Arguments

dwi : an antsImage holding DWI acquisition

fa : an antsImage holding FA values

bvals : bvalues

bvecs : bvectors

num_processes : number of subprocesses

mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")

label_image : atlas labels

seed_labels : list of label numbers from the atlas labels

fa_thresh : FA threshold used for the default tracking mask, the seed mask, and the stopping criterion (signature default 0.05)

seed_density : 1 default number of seeds per voxel

step_size : for tracking

peak_indices : pass these in, if they are previously estimated. otherwise, will compute on the fly (slow)

fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)

verbose : boolean

Returns

dictionary holding tracts and stateful object.

Example

>>> import antspymm
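
A minimal usage sketch (hypothetical file names; the FA image would typically come from joint_dti_recon above):

>>> import ants, antspymm
>>> dwi = ants.image_read('dwi_dewarped.nii.gz')   # hypothetical motion-corrected/dewarped DWI
>>> fa  = ants.image_read('fa.nii.gz')             # hypothetical FA image
>>> trk = antspymm.dwi_deterministic_tracking(dwi, fa, 'dwi.bval', 'dwi.bvec',
...     fa_thresh=0.05, seed_density=1, step_size=0.15, verbose=True)
>>> len(trk['streamlines'])

Passing trk['peak_indices'] back in via the peak_indices argument on a later call avoids recomputing the peaks, which is the slow step.
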
def dwi_closest_peak_tracking( dwi, fa, bvals, bvecs, num_processes=1, mask=None, label_image=None, seed_labels=None, fa_thresh=0.05, seed_density=1, step_size=0.15, peak_indices=None, verbose=False):
5031def dwi_closest_peak_tracking(
5032    dwi,
5033    fa,
5034    bvals,
5035    bvecs,
5036    num_processes=1,
5037    mask=None,
5038    label_image = None,
5039    seed_labels = None,
5040    fa_thresh = 0.05,
5041    seed_density = 1,
5042    step_size = 0.15,
5043    peak_indices = None,
5044    verbose = False ):
5045    """
5046
5047    Performs deterministic tractography from the DWI and returns a tractogram
5048    and path length data frame.
5049
5050    Arguments
5051    ---------
5052
5053    dwi : an antsImage holding DWI acquisition
5054
5055    fa : an antsImage holding FA values
5056
5057    bvals : bvalues
5058
5059    bvecs : bvectors
5060
5061    num_processes : number of subprocesses
5062
5063    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
5064        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
5065
5066    label_image : atlas labels
5067
5068    seed_labels : list of label numbers from the atlas labels
5069
5070    fa_thresh : 0.25 defaults
5071
5072    seed_density : 1 default number of seeds per voxel
5073
5074    step_size : for tracking
5075
5076    peak_indices : pass these in, if they are previously estimated.  otherwise, will
5077        compute on the fly (slow)
5078
5079    verbose : boolean
5080
5081    Returns
5082    -------
5083    dictionary holding tracts and stateful object.
5084
5085    Example
5086    -------
5087    >>> import antspymm
5088    """
5089    import os
5090    import re
5091    import nibabel as nib
5092    import numpy as np
5093    import ants
5094    from dipy.io.gradients import read_bvals_bvecs
5095    from dipy.core.gradients import gradient_table
5096    from dipy.tracking import utils
5097    import dipy.reconst.dti as dti
5098    from dipy.segment.clustering import QuickBundles
5099    from dipy.tracking.utils import path_length
5100    from dipy.core.gradients import gradient_table
5101    from dipy.data import small_sphere
5102    from dipy.direction import BootDirectionGetter, ClosestPeakDirectionGetter
5103    from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,
5104                                    auto_response_ssst)
5105    from dipy.reconst.shm import CsaOdfModel
5106    from dipy.tracking.local_tracking import LocalTracking
5107    from dipy.tracking.streamline import Streamlines
5108    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
5109
5110    if verbose:
5111        print("begin tracking",flush=True)
5112
5113    affine = ants_to_nibabel_affine(dwi)
5114    if isinstance( bvals, str ) or isinstance( bvecs, str ):
5115        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
5116    bvecs = repair_bvecs( bvecs )
5117    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
5118    if mask is None:
5119        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
5120    dwi_data = dwi.numpy()
5121    dwi_mask = mask.numpy() == 1
5122
5123
5124    response, ratio = auto_response_ssst(gtab, dwi_data, roi_radii=10, fa_thr=0.7)
5125    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
5126    csd_fit = csd_model.fit(dwi_data, mask=dwi_mask)
5127    csa_model = CsaOdfModel(gtab, sh_order=6)
5128    gfa = csa_model.fit(dwi_data, mask=dwi_mask).gfa
5129    stopping_criterion = ThresholdStoppingCriterion(gfa, .25)
5130
5131
5132    if label_image is None or seed_labels is None:
5133        seed_mask = fa.numpy().copy()
5134        seed_mask[seed_mask >= fa_thresh] = 1
5135        seed_mask[seed_mask < fa_thresh] = 0
5136    else:
5137        labels = label_image.numpy()
5138        seed_mask = labels * 0
5139        for u in seed_labels:
5140            seed_mask[ labels == u ] = 1
5141    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
5142    if verbose:
5143        print("streamlines begin ...", flush=True)
5144
5145    pmf = csd_fit.odf(small_sphere).clip(min=0)
5146    if verbose:
5147        print("ClosestPeakDirectionGetter begin ...", flush=True)
5148    peak_dg = ClosestPeakDirectionGetter.from_pmf(pmf, max_angle=30.,
5149                                                sphere=small_sphere)
5150    if verbose:
5151        print("local tracking begin ...", flush=True)
5152    streamlines_generator = LocalTracking(peak_dg, stopping_criterion, seeds,
5153                                            affine, step_size=.5)
5154    streamlines = Streamlines(streamlines_generator)
5155    from dipy.io.stateful_tractogram import Space, StatefulTractogram
5156    from dipy.io.streamline import save_tractogram
5157    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
5158    if verbose:
5159        print("streamlines done", flush=True)
5160    return {
5161          'tractogram': sft,
5162          'streamlines': streamlines
5163          }

Performs deterministic tractography from the DWI and returns a tractogram and path length data frame.

Arguments

dwi : an antsImage holding DWI acquisition

fa : an antsImage holding FA values

bvals : bvalues

bvecs : bvectors

num_processes : number of subprocesses

mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")

label_image : atlas labels

seed_labels : list of label numbers from the atlas labels

fa_thresh : FA threshold used for the default tracking mask and the seed mask (signature default 0.05)

seed_density : 1 default number of seeds per voxel

step_size : for tracking

peak_indices : pass these in, if they are previously estimated. otherwise, will compute on the fly (slow)

verbose : boolean

Returns

dictionary holding tracts and stateful object.

Example

>>> import antspymm
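
A minimal usage sketch, analogous to the deterministic tracker above (hypothetical file names):

>>> import ants, antspymm
>>> dwi = ants.image_read('dwi_dewarped.nii.gz')   # hypothetical DWI
>>> fa  = ants.image_read('fa.nii.gz')             # hypothetical FA image
>>> trk = antspymm.dwi_closest_peak_tracking(dwi, fa, 'dwi.bval', 'dwi.bvec', verbose=True)
>>> len(trk['streamlines'])
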
def dwi_streamline_pairwise_connectivity(streamlines, label_image, labels_to_connect=[1, None], verbose=False):
5165def dwi_streamline_pairwise_connectivity( streamlines, label_image, labels_to_connect=[1,None], verbose=False ):
5166    """
5167
5168    Return streamlines connecting all of the regions in the label set. Ideal
5169    for just 2 regions.
5170
5171    Arguments
5172    ---------
5173
5174    streamlines : streamline object from dipy
5175
5176    label_image : atlas labels
5177
5178    labels_to_connect : list of 2 labels or [label,None]
5179
5180    verbose : boolean
5181
5182    Returns
5183    -------
5184    the subset of streamlines and a streamline count
5185
5186    Example
5187    -------
5188    >>> import antspymm
5189    """
5190    from dipy.tracking.streamline import Streamlines
5191    keep_streamlines = Streamlines()
5192
5193    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
5194
5195    lin_T, offset = utils._mapping_to_voxel(affine)
5196    label_image_np = label_image.numpy()
5197    def check_it( sl, target_label, label_image, index, full=False ):
5198        if full:
5199            maxind=sl.shape[0]
5200            for index in range(maxind):
5201                pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
5202                mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
5203                if mylab == target_label[0] or mylab == target_label[1]:
5204                    return { 'ok': True, 'label':mylab }
5205        else:
5206            pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
5207            mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
5208            if mylab == target_label[0] or mylab == target_label[1]:
5209                return { 'ok': True, 'label':mylab }
5210        return { 'ok': False, 'label':None }
5211    ct=0
5212    for k in range( len( streamlines ) ):
5213        sl = streamlines[k]
5214        mycheck = check_it( sl, labels_to_connect, label_image_np, index=0, full=True )
5215        if mycheck['ok']:
5216            otherind=1
5217            if mycheck['label'] == labels_to_connect[1]:
5218                otherind=0
5219            lsl = len( sl )-1
5220            pt = utils._to_voxel_coordinates(sl[lsl,:], lin_T, offset)
5221            mylab_end = (label_image_np[ pt[0], pt[1], pt[2] ]).astype(int)
5222            accept_point = mylab_end == labels_to_connect[otherind]
5223            if verbose and accept_point:
5224                print( mylab_end )
5225            if labels_to_connect[1] is None:
5226                accept_point = mylab_end != 0
5227            if accept_point:
5228                keep_streamlines.append(sl)
5229                ct=ct+1
5230    return { 'streamlines': keep_streamlines, 'count': ct }

Return streamlines connecting all of the regions in the label set. Ideal for just 2 regions.

Arguments

streamlines : streamline object from dipy

label_image : atlas labels

labels_to_connect : list of 2 labels or [label,None]

verbose : boolean

Returns

the subset of streamlines and a streamline count

Example

>>> import antspymm
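
A minimal usage sketch (hypothetical file names and label values), reusing streamlines from one of the trackers above:

>>> import ants, antspymm
>>> dwi = ants.image_read('dwi_dewarped.nii.gz')       # hypothetical DWI
>>> fa  = ants.image_read('fa.nii.gz')                 # hypothetical FA image
>>> labels = ants.image_read('atlas_labels.nii.gz')    # hypothetical atlas labels in DWI space
>>> trk = antspymm.dwi_deterministic_tracking(dwi, fa, 'dwi.bval', 'dwi.bvec')
>>> sel = antspymm.dwi_streamline_pairwise_connectivity(
...     trk['streamlines'], labels, labels_to_connect=[1, 2])
>>> sel['count']
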
def dwi_streamline_connectivity(streamlines, label_image, label_dataframe, verbose=False):
5287def dwi_streamline_connectivity(
5288    streamlines,
5289    label_image,
5290    label_dataframe,
5291    verbose = False ):
5292    """
5293
5294    Summarize network connetivity of the input streamlines between all of the
5295    regions in the label set.
5296
5297    Arguments
5298    ---------
5299
5300    streamlines : streamline object from dipy
5301
5302    label_image : atlas labels
5303
5304    label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)
5305
5306    verbose : boolean
5307
5308    Returns
5309    -------
5310    dictionary holding summary connection statistics in wide format and matrix format.
5311
5312    Example
5313    -------
5314    >>> import antspymm
5315    """
5316    import os
5317    import re
5318    import nibabel as nib
5319    import numpy as np
5320    import ants
5321    from dipy.io.gradients import read_bvals_bvecs
5322    from dipy.core.gradients import gradient_table
5323    from dipy.tracking import utils
5324    import dipy.reconst.dti as dti
5325    from dipy.segment.clustering import QuickBundles
5326    from dipy.tracking.utils import path_length
5327    from dipy.tracking.local_tracking import LocalTracking
5328    from dipy.tracking.streamline import Streamlines
5342    volUnit = np.prod( ants.get_spacing( label_image ) )
5343    labels = label_image.numpy()
5344
5345    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
5346
5347    import numpy as np
5348    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
5349    import pandas as pd
5350    ulabs = label_dataframe['Label']
5351    labels_to_connect = ulabs[ulabs > 0]
5352    Ctdf = None
5353    lin_T, offset = utils._mapping_to_voxel(affine)
5354    label_image_np = label_image.numpy()
5355    def check_it( sl, target_label, label_image, index, not_label = None ):
5356        pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
5357        mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
5358        if not_label is None:
5359            if ( mylab == target_label ).sum() > 0 :
5360                return { 'ok': True, 'label':mylab }
5361        else:
5362            if ( mylab == target_label ).sum() > 0 and ( mylab == not_label ).sum() == 0:
5363                return { 'ok': True, 'label':mylab }
5364        return { 'ok': False, 'label':None }
5365    ct=0
5366    which = lambda lst:list(np.where(lst)[0])
5367    myCount = np.zeros( [len(ulabs),len(ulabs)])
5368    for k in range( len( streamlines ) ):
5369            sl = streamlines[k]
5370            mycheck = check_it( sl, labels_to_connect, label_image_np, index=0 )
5371            if mycheck['ok']:
5372                exclabel=mycheck['label']
5373                lsl = len( sl )-1
5374                mycheck2 = check_it( sl, labels_to_connect, label_image_np, index=lsl, not_label=exclabel )
5375                if mycheck2['ok']:
5376                    myCount[ulabs == mycheck['label'],ulabs == mycheck2['label']]+=1
5377                    ct=ct+1
5378    Ctdf = label_dataframe.copy()
5379    for k in range(len(ulabs)):
5380            nn3 = "CnxCount"+str(k).zfill(3)
5381            Ctdf.insert(Ctdf.shape[1], nn3, myCount[k,:] )
5382    Ctdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkc': Ctdf },  Ctdf.keys()[2:Ctdf.shape[1]] )
5383    return { 'connectivity_matrix' :  myCount, 'connectivity_wide' : Ctdfw }

Summarize network connectivity of the input streamlines between all of the regions in the label set.

Arguments

streamlines : streamline object from dipy

label_image : atlas labels

label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)

verbose : boolean

Returns

dictionary holding summary connection statistics in wide format and matrix format.

Example

>>> import antspymm
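
A minimal usage sketch (hypothetical file names; the toy label_dataframe below just illustrates the expected Label/Description columns):

>>> import ants, antspymm
>>> import pandas as pd
>>> dwi = ants.image_read('dwi_dewarped.nii.gz')       # hypothetical DWI
>>> fa  = ants.image_read('fa.nii.gz')                 # hypothetical FA image
>>> labels = ants.image_read('atlas_labels.nii.gz')    # hypothetical atlas labels in DWI space
>>> ldf = pd.DataFrame({'Label': [1, 2, 3],
...                     'Description': ['regionA', 'regionB', 'regionC']})  # toy example
>>> trk = antspymm.dwi_deterministic_tracking(dwi, fa, 'dwi.bval', 'dwi.bvec')
>>> cnx = antspymm.dwi_streamline_connectivity(trk['streamlines'], labels, ldf)
>>> cnx['connectivity_matrix']
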
def hierarchical_modality_summary( target_image, hier, transformlist, modality_name, return_keys=['Mean', 'Volume'], verbose=False):
5526def hierarchical_modality_summary(
5527    target_image,
5528    hier,
5529    transformlist,
5530    modality_name,
5531    return_keys = ["Mean","Volume"],
5532    verbose = False ):
5533    """
5534
5535    Use output of antspyt1w.hierarchical to summarize a modality
5536
5537    Arguments
5538    ---------
5539
5540    target_image : the image to summarize - should be brain extracted
5541
5542    hier : dictionary holding antspyt1w.hierarchical output
5543
5544    transformlist : spatial transformations mapping from T1 to this modality (e.g. from ants.registration)
5545
5546    modality_name : adds the modality name to the data frame columns
5547
5548    return_keys = ["Mean","Volume"] keys to return
5549
5550    verbose : boolean
5551
5552    Returns
5553    -------
5554    data frame holding summary statistics in wide format
5555
5556    Example
5557    -------
5558    >>> import antspymm
5559    """
5560    dfout = pd.DataFrame()
5561    def myhelper( target_image, seg, mytx, mapname, modname, mydf, extra='', verbose=False ):
5562        if verbose:
5563            print( mapname )
5564        target_image_mask = ants.image_clone( target_image ) * 0.0
5565        target_image_mask[ target_image != 0 ] = 1
5566        cortmapped = ants.apply_transforms(
5567            target_image,
5568            seg,
5569            mytx, interpolator='nearestNeighbor' ) * target_image_mask
5570        mapped = antspyt1w.map_intensity_to_dataframe(
5571            mapname,
5572            target_image,
5573            cortmapped )
5574        mapped.iloc[:,1] = modname + '_' + extra + mapped.iloc[:,1]
5575        mappedw = antspyt1w.merge_hierarchical_csvs_to_wide_format(
5576            { 'x' : mapped},
5577            col_names = return_keys )
5578        if verbose:
5579            print( mappedw.keys() )
5580        if mydf.shape[0] > 0:
5581            mydf = pd.concat( [ mydf, mappedw], axis=1, ignore_index=False )
5582        else:
5583            mydf = mappedw
5584        return mydf
5585    if hier['dkt_parc']['dkt_cortex'] is not None:
5586        dfout = myhelper( target_image, hier['dkt_parc']['dkt_cortex'], transformlist,
5587            "dkt", modality_name, dfout, extra='', verbose=verbose )
5588    if hier['deep_cit168lab'] is not None:
5589        dfout = myhelper( target_image, hier['deep_cit168lab'], transformlist,
5590            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='deep_', verbose=verbose )
5591    if hier['cit168lab'] is not None:
5592        dfout = myhelper( target_image, hier['cit168lab'], transformlist,
5593            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='', verbose=verbose  )
5594    if hier['bf'] is not None:
5595        dfout = myhelper( target_image, hier['bf'], transformlist,
5596            "nbm3CH13", modality_name, dfout, extra='', verbose=verbose  )
5597    # if hier['mtl'] is not None:
5598    #    dfout = myhelper( target_image, hier['mtl'], reg,
5599    #        "mtl_description", modality_name, dfout, extra='', verbose=verbose  )
5600    return dfout

Use output of antspyt1w.hierarchical to summarize a modality

Arguments

target_image : the image to summarize - should be brain extracted

hier : dictionary holding antspyt1w.hierarchical output

transformlist : spatial transformations mapping from T1 to this modality (e.g. from ants.registration)

modality_name : adds the modality name to the data frame columns

return_keys : list of keys to return (default ["Mean","Volume"])

verbose : boolean

Returns

data frame holding summary statistics in wide format

Example

>>> import antspymm
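
A usage sketch under stated assumptions: hier comes from an upstream antspyt1w.hierarchical(t1, output_prefix=...) call, the modality image is brain extracted, and the registration supplies the T1-to-modality transforms (file names hypothetical):

>>> import ants, antspyt1w, antspymm
>>> t1 = ants.image_read('t1.nii.gz')                              # hypothetical T1
>>> hier = antspyt1w.hierarchical(t1, output_prefix='/tmp/hier_')  # heavy computation
>>> flair = ants.image_read('flair_brain.nii.gz')                  # hypothetical brain-extracted modality
>>> reg = ants.registration(flair, t1, 'antsRegistrationSyNRepro[r]')   # T1 -> modality
>>> df = antspymm.hierarchical_modality_summary(flair, hier, reg['fwdtransforms'],
...     modality_name='T2Flair', return_keys=['Mean', 'Volume'])
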
def tra_initializer( fixed, moving, n_simulations=32, max_rotation=30, transform=['rigid'], compreg=None, random_seed=42, verbose=False):
5612def tra_initializer( fixed, moving, n_simulations=32, max_rotation=30,
5613    transform=['rigid'], compreg=None, random_seed=42, verbose=False ):
5614    """
5615    multi-start multi-transform registration solution - based on ants.registration
5616
5617    fixed: fixed image
5618
5619    moving: moving image
5620
5621    n_simulations : number of simulations
5622
5623    max_rotation : maximum rotation angle
5624
5625    transform : list of transforms to loop through
5626
5627    compreg : registration results against which to compare
5628
5629    random_seed : random seed for reproducibility
5630
5631    verbose : boolean
5632
5633    """
5634    import random
5635    if random_seed is not None:
5636        random.seed(random_seed)
5637    if True:
5638        output_directory = tempfile.mkdtemp()
5639        output_directory_w = output_directory + "/tra_reg/"
5640        os.makedirs(output_directory_w,exist_ok=True)
5641        bestmi = math.inf
5642        bestvar = 0.0
5643        myorig = list(ants.get_origin( fixed ))
5644        mymax = 0;
5645        for k in range(len( myorig ) ):
5646            if abs(myorig[k]) > mymax:
5647                mymax = abs(myorig[k])
5648        maxtrans = mymax * 0.05
5649        if compreg is None:
5650            bestreg=ants.registration( fixed,moving,'Translation',
5651                outprefix=output_directory_w+"trans")
5652            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
5653        else :
5654            bestreg=compreg
5655            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
5656        for mytx in transform:
5657            regtx = 'antsRegistrationSyNRepro[r]'
5658            with tempfile.NamedTemporaryFile(suffix='.h5') as tp:
5659                if mytx == 'translation':
5660                    regtx = 'Translation'
5661                    rRotGenerator = ants.contrib.RandomTranslate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
5662                elif mytx == 'affine':
5663                    regtx = 'Affine'
5664                    rRotGenerator = ants.contrib.RandomRotate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
5665                else:
5666                    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rotation*(-1.0), max_rotation ), reference=fixed )
5667                for k in range(n_simulations):
5668                    simtx = ants.compose_ants_transforms( [rRotGenerator.transform(), initx] )
5669                    ants.write_transform( simtx, tp.name )
5670                    if k > 0:
5671                        reg = ants.registration( fixed, moving, regtx,
5672                            initial_transform=tp.name,
5673                            outprefix=output_directory_w+"reg"+str(k),
5674                            verbose=False )
5675                    else:
5676                        reg = ants.registration( fixed, moving,
5677                            regtx,
5678                            outprefix=output_directory_w+"reg"+str(k),
5679                            verbose=False )
5680                    mymi = math.inf
5681                    temp = reg['warpedmovout']
5682                    myvar = temp.numpy().var()
5683                    if verbose:
5684                        print( str(k) + " : " + regtx  + " : " + mytx + " _var_ " + str( myvar ) )
5685                    if myvar > 0 :
5686                        mymi = ants.image_mutual_information( fixed, temp )
5687                        if mymi < bestmi:
5688                            if verbose:
5689                                print( "mi @ " + str(k) + " : " + str(mymi), flush=True)
5690                            bestmi = mymi
5691                            bestreg = reg
5692                            bestvar = myvar
5693        if bestvar == 0.0 and compreg is not None:
5694            return compreg        
5695        return bestreg

multi-start multi-transform registration solution - based on ants.registration

fixed: fixed image

moving: moving image

n_simulations : number of simulations

max_rotation : maximum rotation angle

transform : list of transforms to loop through

compreg : registration results against which to compare

random_seed : random seed for reproducibility

verbose : boolean

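A minimal usage sketch (hypothetical file names); the result has the same layout as an ants.registration output:

>>> import ants, antspymm
>>> fixed  = ants.image_read('nm_avg.nii.gz')    # hypothetical fixed image
>>> moving = ants.image_read('t1_slab.nii.gz')   # hypothetical moving image
>>> reg = antspymm.tra_initializer(fixed, moving, n_simulations=32, max_rotation=30,
...     transform=['rigid'], verbose=True)
>>> warped = reg['warpedmovout']
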
def neuromelanin( list_nm_images, t1, t1_head, t1lab, brain_stem_dilation=8, bias_correct=True, denoise=None, srmodel=None, target_range=[0, 1], poly_order='hist', normalize_nm=False, verbose=False):
5697def neuromelanin( list_nm_images, t1, t1_head, t1lab, brain_stem_dilation=8,
5698    bias_correct=True,
5699    denoise=None,
5700    srmodel=None,
5701    target_range=[0,1],
5702    poly_order='hist',
5703    normalize_nm = False,
5704    verbose=False ) :
5705
5706  """
5707  Outputs the averaged and registered neuromelanin image, and neuromelanin labels
5708
5709  Arguments
5710  ---------
5711  list_nm_image : list of ANTsImages
5712    list of neuromenlanin repeat images
5713
5714  t1 : ANTsImage
5715    input 3-D T1 brain image
5716
5717  t1_head : ANTsImage
5718    input 3-D T1 head image
5719
5720  t1lab : ANTsImage
5721    t1 labels that will be propagated to the NM
5722
5723  brain_stem_dilation : integer default 8
5724    dilates the brain stem mask to better match coverage of NM
5725
5726  bias_correct : boolean
5727
5728  denoise : None or integer
5729
5730  srmodel : None -- this is a work in progress feature, probably not optimal
5731
5732  target_range : 2-element tuple
5733        a tuple or array defining the (min, max) of the input image
5734        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
5735        intensity. This range should match the mapping used in the training
5736        of the network.
5737
5738  poly_order : if not None, will fit a global regression model to map
5739      intensity back to original histogram space; if 'hist' will match
5740      by histogram matching - ants.histogram_match_image
5741
5742  normalize_nm : boolean - WIP not validated
5743
5744  verbose : boolean
5745
5746  Returns
5747  ---------
5748  Averaged and registered neuromelanin image and neuromelanin labels and wide csv
5749
5750  """
5751
5752  fnt=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz" )
5753  fntNM=os.path.expanduser("~/.antspymm/CIT168_T1w_700um_pad_adni_NM_norm_avg.nii.gz" )
5754  fntbst=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni_brainstem.nii.gz")
5755  fnslab=os.path.expanduser("~/.antspyt1w/CIT168_MT_Slab_adni.nii.gz")
5756  fntseg=os.path.expanduser("~/.antspyt1w/det_atlas_25_pad_LR_adni.nii.gz")
5757
5758  template = mm_read( fnt )
5759  templateNM = ants.iMath( mm_read( fntNM ), "Normalize" )
5760  templatebstem = mm_read( fntbst ).threshold_image( 1, 1000 )
5761  # reg = ants.registration( t1, template, 'antsRegistrationSyNQuickRepro[s]' )
5762  reg = ants.registration( t1, template, 'antsRegistrationSyNQuickRepro[s]' )
5763  # map NM avg to t1 for neuromelanin processing
5764  nmavg2t1 = ants.apply_transforms( t1, templateNM,
5765    reg['fwdtransforms'], interpolator='linear' )
5766  slab2t1 = ants.threshold_image( nmavg2t1, "Otsu", 2 ).threshold_image(1,2).iMath("MD",1).iMath("FillHoles")
5767  # map brain stem and slab to t1 for neuromelanin processing
5768  bstem2t1 = ants.apply_transforms( t1, templatebstem,
5769    reg['fwdtransforms'],
5770    interpolator='nearestNeighbor' ).iMath("MD",1)
5771  slab2t1B = ants.apply_transforms( t1, mm_read( fnslab ),
5772    reg['fwdtransforms'], interpolator = 'nearestNeighbor')
5773  bstem2t1 = ants.crop_image( bstem2t1, slab2t1 )
5774  cropper = ants.decrop_image( bstem2t1, slab2t1 ).iMath("MD",brain_stem_dilation)
5775
5776  # Average images in image_list
5777  nm_avg = list_nm_images[0]*0.0
5778  for k in range(len( list_nm_images )):
5779    if denoise is not None:
5780        list_nm_images[k] = ants.denoise_image( list_nm_images[k],
5781            shrink_factor=1,
5782            p=denoise,
5783            r=denoise+1,
5784            noise_model='Gaussian' )
5785    if bias_correct :
5786        n4mask = ants.threshold_image( ants.iMath(list_nm_images[k], "Normalize" ), 0.05, 1 )
5787        list_nm_images[k] = ants.n4_bias_field_correction( list_nm_images[k], mask=n4mask )
5788    nm_avg = nm_avg + ants.resample_image_to_target( list_nm_images[k], nm_avg ) / len( list_nm_images )
5789
5790  if verbose:
5791      print("Register each nm image in list_nm_images to the averaged nm image (avg)")
5792  nm_avg_new = nm_avg * 0.0
5793  txlist = []
5794  for k in range(len( list_nm_images )):
5795    if verbose:
5796        print(str(k) + " of " + str(len( list_nm_images ) ) )
5797    current_image = ants.registration( list_nm_images[k], nm_avg,
5798        type_of_transform = 'antsRegistrationSyNRepro[r]' )
5799    txlist.append( current_image['fwdtransforms'][0] )
5800    current_image = current_image['warpedfixout']
5801    nm_avg_new = nm_avg_new + current_image / len( list_nm_images )
5802  nm_avg = nm_avg_new
5803
5804  if verbose:
5805      print("do slab registration to map anatomy to NM space")
5806  t1c = ants.crop_image( t1_head, slab2t1 ).iMath("Normalize") # old way
5807  nmavg2t1c = ants.crop_image( nmavg2t1, slab2t1 ).iMath("Normalize")
5808  # slabreg = ants.registration( nm_avg, nmavg2t1c, 'antsRegistrationSyNRepro[r]' )
5809  slabreg = tra_initializer( nm_avg, t1c, verbose=verbose )
5810  if False:
5811      slabregT1 = tra_initializer( nm_avg, t1c, verbose=verbose  )
5812      miNM = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
5813            ants.iMath(slabreg0['warpedmovout'],"Normalize") )
5814      miT1 = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
5815            ants.iMath(slabreg1['warpedmovout'],"Normalize") )
5816      if miT1 < miNM:
5817        slabreg = slabregT1
5818  labels2nm = ants.apply_transforms( nm_avg, t1lab, slabreg['fwdtransforms'],
5819    interpolator = 'genericLabel' )
5820  cropper2nm = ants.apply_transforms( nm_avg, cropper, slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
5821  nm_avg_cropped = ants.crop_image( nm_avg, cropper2nm )
5822
5823  if verbose:
5824      print("now map these labels to each individual nm")
5825  crop_mask_list = []
5826  crop_nm_list = []
5827  for k in range(len( list_nm_images )):
5828      concattx = []
5829      concattx.append( txlist[k] )
5830      concattx.append( slabreg['fwdtransforms'][0] )
5831      cropmask = ants.apply_transforms( list_nm_images[k], cropper,
5832        concattx, interpolator = 'nearestNeighbor' )
5833      crop_mask_list.append( cropmask )
5834      temp = ants.crop_image( list_nm_images[k], cropmask )
5835      crop_nm_list.append( temp )
5836
5837  if srmodel is not None:
5838      if verbose:
5839          print( " start sr " + str(len( crop_nm_list )) )
5840      for k in range(len( crop_nm_list )):
5841          if verbose:
5842              print( " do sr " + str(k) )
5843              print( crop_nm_list[k] )
5844          temp = antspynet.apply_super_resolution_model_to_image(
5845                crop_nm_list[k], srmodel, target_range=target_range,
5846                regression_order=None )
5847          if poly_order is not None:
5848              bilin = ants.resample_image_to_target( crop_nm_list[k], temp )
5849              if poly_order == 'hist':
5850                  temp = ants.histogram_match_image( temp, bilin )
5851              else:
5852                  temp = antspynet.regression_match_image( temp, bilin, poly_order = poly_order )
5853          crop_nm_list[k] = temp
5854
5855  nm_avg_cropped = crop_nm_list[0]*0.0
5856  if verbose:
5857      print( "cropped average" )
5858      print( nm_avg_cropped )
5859  for k in range(len( crop_nm_list )):
5860      nm_avg_cropped = nm_avg_cropped + ants.apply_transforms( nm_avg_cropped,
5861        crop_nm_list[k], txlist[k] ) / len( crop_nm_list )
5862  for loop in range( 3 ):
5863      nm_avg_cropped_new = nm_avg_cropped * 0.0
5864      for k in range(len( crop_nm_list )):
5865            myreg = ants.registration(
5866                ants.iMath(nm_avg_cropped,"Normalize"),
5867                ants.iMath(crop_nm_list[k],"Normalize"),
5868                'antsRegistrationSyNRepro[r]' )
5869            warpednext = ants.apply_transforms(
5870                nm_avg_cropped_new,
5871                crop_nm_list[k],
5872                myreg['fwdtransforms'] )
5873            nm_avg_cropped_new = nm_avg_cropped_new + warpednext
5874      nm_avg_cropped = nm_avg_cropped_new / len( crop_nm_list )
5875
5876  slabregUpdated = tra_initializer( nm_avg_cropped, t1c, compreg=slabreg,verbose=verbose  )
5877  tempOrig = ants.apply_transforms( nm_avg_cropped_new, t1c, slabreg['fwdtransforms'] )
5878  tempUpdate = ants.apply_transforms( nm_avg_cropped_new, t1c, slabregUpdated['fwdtransforms'] )
5879  miUpdate = ants.image_mutual_information(
5880    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempUpdate,"Normalize") )
5881  miOrig = ants.image_mutual_information(
5882    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempOrig,"Normalize") )
5883  if miUpdate < miOrig :
5884      slabreg = slabregUpdated
5885
5886  if normalize_nm:
5887      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
5888      nm_avg_cropped = ants.iMath( nm_avg_cropped, "TruncateIntensity",0.05,0.95)
5889      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
5890
5891  labels2nm = ants.apply_transforms( nm_avg_cropped, t1lab,
5892        slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
5893
5894  # fix the reference region - keep top two parts
5895  def get_biggest_part( x, labeln ):
5896      temp33 = ants.threshold_image( x, labeln, labeln ).iMath("GetLargestComponent")
5897      x[ x == labeln] = 0
5898      x[ temp33 == 1 ] = labeln
5899
5900  get_biggest_part( labels2nm, 33 )
5901  get_biggest_part( labels2nm, 34 )
5902
5903  if verbose:
5904      print( "map summary measurements to wide format" )
5905  nmdf = antspyt1w.map_intensity_to_dataframe(
5906          'CIT168_Reinf_Learn_v1_label_descriptions_pad',
5907          nm_avg_cropped,
5908          labels2nm)
5909  if verbose:
5910      print( "merge to wide format" )
5911  nmdf_wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
5912              {'NM' : nmdf},
5913              col_names = ['Mean'] )
5914
5915  rr_mask = ants.mask_image( labels2nm, labels2nm, [33,34] , binarize=True )
5916  sn_mask = ants.mask_image( labels2nm, labels2nm, [7,9,23,25] , binarize=True )
5917  nmavgsnr = mask_snr( nm_avg_cropped, rr_mask, sn_mask, bias_correct = False )
5918
5919  snavg = nm_avg_cropped[ sn_mask == 1].mean()
5920  rravg = nm_avg_cropped[ rr_mask == 1].mean()
5921  snstd = nm_avg_cropped[ sn_mask == 1].std()
5922  rrstd = nm_avg_cropped[ rr_mask == 1].std()
5923  vol_element = np.prod( ants.get_spacing(sn_mask) )
5924  snvol = vol_element * sn_mask.sum()
5925
5926  # get the mean voxel position of the SN
5927  if snvol > 0:
5928      sn_z = ants.transform_physical_point_to_index( sn_mask, ants.get_center_of_mass(sn_mask ))[2]
5929      sn_z = sn_z/sn_mask.shape[2] # around 0.5 would be nice
5930  else:
5931      sn_z = math.nan
5932
5933  nm_evr = 0.0
5934  if cropper2nm.sum() > 0:
5935    nm_evr = antspyt1w.patch_eigenvalue_ratio( nm_avg, 512, [6,6,6], 
5936        evdepth = 0.9, mask=cropper2nm )
5937
5938  simg = ants.smooth_image( nm_avg_cropped, np.min(ants.get_spacing(nm_avg_cropped)) )
5939  k = 2.0
5940  rrthresh = (rravg + k * rrstd)
5941  nmabovekthresh_mask = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
5942  snvolabovethresh = vol_element * nmabovekthresh_mask.sum()
5943  snintmeanabovethresh = float( ( simg * nmabovekthresh_mask ).mean() )
5944  snintsumabovethresh = float( ( simg * nmabovekthresh_mask ).sum() )
5945
5946  k = 3.0
5947  rrthresh = (rravg + k * rrstd)
5948  nmabovekthresh_mask3 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
5949  snvolabovethresh3 = vol_element * nmabovekthresh_mask3.sum()
5950
5951  k = 1.0
5952  rrthresh = (rravg + k * rrstd)
5953  nmabovekthresh_mask1 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
5954  snvolabovethresh1 = vol_element * nmabovekthresh_mask1.sum()
5955  
5956  if verbose:
5957    print( "nm vol @2std above rrmean: " + str( snvolabovethresh ) )
5958    print( "nm intmean @2std above rrmean: " + str( snintmeanabovethresh ) )
5959    print( "nm intsum @2std above rrmean: " + str( snintsumabovethresh ) )
5960    print( "nm done" )
5961
5962  return convert_np_in_dict( {
5963      'NM_avg' : nm_avg,
5964      'NM_avg_cropped' : nm_avg_cropped,
5965      'NM_labels': labels2nm,
5966      'NM_cropped': crop_nm_list,
5967      'NM_midbrainROI': cropper2nm,
5968      'NM_dataframe': nmdf,
5969      'NM_dataframe_wide': nmdf_wide,
5970      't1_to_NM': slabreg['warpedmovout'],
5971      't1_to_NM_transform' : slabreg['fwdtransforms'],
5972      'NM_avg_signaltonoise' : nmavgsnr,
5973      'NM_avg_substantianigra' : snavg,
5974      'NM_std_substantianigra' : snstd,
5975      'NM_volume_substantianigra' : snvol,
5976      'NM_volume_substantianigra_1std' : snvolabovethresh1,
5977      'NM_volume_substantianigra_2std' : snvolabovethresh,
5978      'NM_intmean_substantianigra_2std' : snintmeanabovethresh,
5979      'NM_intsum_substantianigra_2std' : snintsumabovethresh,
5980      'NM_volume_substantianigra_3std' : snvolabovethresh3,
5981      'NM_avg_refregion' : rravg,
5982      'NM_std_refregion' : rrstd,
5983      'NM_min' : nm_avg_cropped.min(),
5984      'NM_max' : nm_avg_cropped.max(),
5985      'NM_mean' : nm_avg_cropped.numpy().mean(),
5986      'NM_sd' : np.std( nm_avg_cropped.numpy() ),
5987      'NM_q0pt05' : np.quantile( nm_avg_cropped.numpy(), 0.05 ),
5988      'NM_q0pt10' : np.quantile( nm_avg_cropped.numpy(), 0.10 ),
5989      'NM_q0pt90' : np.quantile( nm_avg_cropped.numpy(), 0.90 ),
5990      'NM_q0pt95' : np.quantile( nm_avg_cropped.numpy(), 0.95 ),
5991      'NM_substantianigra_z_coordinate' : sn_z,
5992      'NM_evr' : nm_evr,
5993      'NM_count': len( list_nm_images )
5994       } )

Outputs the averaged and registered neuromelanin image, and neuromelanin labels

Arguments

list_nm_images : list of ANTsImages - list of neuromelanin repeat images

t1 : ANTsImage input 3-D T1 brain image

t1_head : ANTsImage input 3-D T1 head image

t1lab : ANTsImage t1 labels that will be propagated to the NM

brain_stem_dilation : integer default 8 dilates the brain stem mask to better match coverage of NM

bias_correct : boolean

denoise : None or integer

srmodel : None -- this is a work in progress feature, probably not optimal

target_range : 2-element tuple a tuple or array defining the (min, max) of the input image (e.g., [-127.5, 127.5] or [0,1]). Output images will be scaled back to original intensity. This range should match the mapping used in the training of the network.

poly_order : if not None, will fit a global regression model to map intensity back to original histogram space; if 'hist' will match by histogram matching - ants.histogram_match_image

normalize_nm : boolean - WIP not validated

verbose : boolean

Returns

Averaged and registered neuromelanin image and neuromelanin labels and wide csv

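A usage sketch (hypothetical file names); t1lab would typically be CIT168-style labels in T1 space, for example from antspyt1w.hierarchical output:

>>> import ants, antspymm
>>> nm_list = [ants.image_read(f) for f in ['nm_run1.nii.gz', 'nm_run2.nii.gz']]  # hypothetical repeats
>>> t1_brain = ants.image_read('t1_brain.nii.gz')        # hypothetical, brain extracted
>>> t1_head  = ants.image_read('t1_head.nii.gz')         # hypothetical, whole head
>>> t1_labels = ants.image_read('cit168_labels.nii.gz')  # hypothetical labels in T1 space
>>> nm = antspymm.neuromelanin(nm_list, t1_brain, t1_head, t1_labels, verbose=True)
>>> nm['NM_dataframe_wide']
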
def resting_state_fmri_networks( fmri, fmri_template, t1, t1segmentation, f=[0.03, 0.08], FD_threshold=5.0, spa=None, spt=None, nc=5, outlier_threshold=0.25, ica_components=0, impute=True, censor=True, despike=2.5, motion_as_nuisance=True, powers=False, upsample=3.0, clean_tmp=None, paramset='unset', verbose=False):
6092def resting_state_fmri_networks( fmri, fmri_template, t1, t1segmentation,
6093    f=[0.03, 0.08],
6094    FD_threshold=5.0,
6095    spa = None,
6096    spt = None,
6097    nc = 5,
6098    outlier_threshold=0.250,
6099    ica_components = 0,
6100    impute = True,
6101    censor = True,
6102    despike = 2.5,
6103    motion_as_nuisance = True,
6104    powers = False,
6105    upsample = 3.0,
6106    clean_tmp = None,
6107    paramset='unset',
6108    verbose=False ):
6109  """
6110  Compute resting state network correlation maps based on the J Power labels.
6111  This will output a map for each of the major network systems.  This function
6112  will optionally upsample the data to the resolution given by the upsample argument
6113  during the registration process, if the input data is below that resolution.
6114
6115  registration - despike - anatomy - smooth - nuisance - bandpass - regress.nuisance - censor - falff - correlations
6116
6117  Arguments
6118  ---------
6119  fmri : BOLD fmri antsImage
6120
6121  fmri_template : reference space for BOLD
6122
6123  t1 : ANTsImage
6124    input 3-D T1 brain image (brain extracted)
6125
6126  t1segmentation : ANTsImage
6127    t1 segmentation - a six tissue segmentation image in T1 space
6128
6129  f : band pass limits for frequency filtering; we use high-pass here as per Shirer 2015
6130
6131  spa : gaussian smoothing for spatial component (physical coordinates)
6132
6133  spt : gaussian smoothing for temporal component
6134
6135  nc  : number of components for compcor filtering; if less than 1 we estimate on the fly based on explained variance; 10 wrt Shirer 2015 5 from csf and 5 from wm
6136
6137  ica_components : integer if greater than 0 then include ica components
6138
6139  impute : boolean if True, then use imputation in f/ALFF, PerAF calculation
6140
6141  censor : boolean if True, then use censoring (censoring)
6142
6143  despike : if this is greater than zero will run voxel-wise despiking in the 3dDespike (afni) sense; after motion-correction
6144
6145  motion_as_nuisance: boolean will add motion and first derivative of motion as nuisance
6146
6147  powers : boolean if True use Powers nodes otherwise 2023 Yeo 500 homotopic nodes (10.1016/j.neuroimage.2023.120010)
6148
6149  upsample : float optionally isotropically upsample data to upsample (the parameter value) in mm during the registration process if data is below that resolution; if the input spacing is less than that provided by the user, the data will simply be resampled to isotropic resolution
6150
6151  clean_tmp : will automatically try to clean the tmp directory - not recommended but can be used in distributed computing systems to help prevent failures due to accumulation of tmp files when doing large-scale processing.  if this is set, the float value clean_tmp will be interpreted as the age in hours of files to be cleaned.
6152
6153  verbose : boolean
6154
6155  Returns
6156  ---------
6157  a dictionary containing the derived network maps
6158
6159  References
6160  ---------
6161
6162  10.1162/netn_a_00071 "Methods that included global signal regression were the most consistently effective de-noising strategies."
6163
6164  10.1016/j.neuroimage.2019.116157 "frontal and default model networks are most reliable whereas subcortical networks are least reliable"  "the most comprehensive studies of pipeline effects on edge-level reliability have been done by shirer (2015) and Parkes (2018)" "slice timing correction has minimal impact" "use of low-pass or narrow filter (discarding high frequency information) reduced both reliability and signal-noise separation"
6165
6166  10.1016/j.neuroimage.2017.12.073: Our results indicate that (1) simple linear regression of regional fMRI time series against head motion parameters and WM/CSF signals (with or without expansion terms) is not sufficient to remove head motion artefacts; (2) aCompCor pipelines may only be viable in low-motion data; (3) volume censoring performs well at minimising motion-related artefact but a major benefit of this approach derives from the exclusion of high-motion individuals; (4) while not as effective as volume censoring, ICA-AROMA performed well across our benchmarks for relatively low cost in terms of data loss; (5) the addition of global signal regression improved the performance of nearly all pipelines on most benchmarks, but exacerbated the distance-dependence of correlations between motion and functional connectivity; and (6) group comparisons in functional connectivity between healthy controls and schizophrenia patients are highly dependent on preprocessing strategy. We offer some recommendations for best practice and outline simple analyses to facilitate transparent reporting of the degree to which a given set of findings may be affected by motion-related artefact.
6167
6168  10.1016/j.dcn.2022.101087 : We found that: 1) the most efficacious pipeline for both noise removal and information recovery included censoring, GSR, bandpass filtering, and head motion parameter (HMP) regression, 2) ICA-AROMA performed similarly to HMP regression and did not obviate the need for censoring, 3) GSR had a minimal impact on connectome fingerprinting but improved ISC, and 4) the strictest censoring approaches reduced motion correlated edges but negatively impacted identifiability.
6169
6170  """
6171
6172  import warnings
6173
6174  if clean_tmp is not None:
6175    clean_tmp_directory( age_hours = clean_tmp )
6176
6177  if nc > 1:
6178    nc = int(nc)
6179  else:
6180    nc=float(nc)
6181
6182  type_of_transform="antsRegistrationSyNQuickRepro[r]" # should probably not change this
6183  remove_it=True
6184  output_directory = tempfile.mkdtemp()
6185  output_directory_w = output_directory + "/ts_t1_reg/"
6186  os.makedirs(output_directory_w,exist_ok=True)
6187  ofnt1tx = tempfile.NamedTemporaryFile(delete=False,suffix='t1_deformation',dir=output_directory_w).name
6188
6189  import numpy as np
6190  # optionally resample the BOLD template to isotropic resolution before registration
6191
6192  if upsample > 0.0:
6193      spc = ants.get_spacing( fmri )
6194      minspc = upsample
6195      if min(spc[0:3]) < minspc:
6196          minspc = min(spc[0:3])
6197      newspc = [minspc,minspc,minspc]
6198      fmri_template = ants.resample_image( fmri_template, newspc, interp_type=0 )
6199
6200  def temporal_derivative_same_shape(array):
6201    """
6202    Compute the temporal derivative of a 2D numpy array along the 0th axis (time)
6203    and ensure the output has the same shape as the input.
6204
6205    :param array: 2D numpy array with time as the 0th axis.
6206    :return: 2D numpy array of the temporal derivative with the same shape as input.
6207    """
6208    derivative = np.diff(array, axis=0)
6209    
6210    # Append a row to maintain the same shape
6211    # You can choose to append a row of zeros or the last row of the derivative
6212    # Here, a row of zeros is appended
6213    zeros_row = np.zeros((1, array.shape[1]))
6214    return np.vstack((zeros_row, derivative ))
6215
6216  def compute_tSTD(M, quantile, x=0, axis=0):
6217    stdM = np.std(M, axis=axis)
6218    # set bad values to x
6219    stdM[stdM == 0] = x
6220    stdM[np.isnan(stdM)] = x
6221    tt = round(quantile * 100)
6222    threshold_std = np.percentile(stdM, tt)
6223    return {'tSTD': stdM, 'threshold_std': threshold_std}
6224
6225  def get_compcor_matrix(boldImage, mask, quantile):
6226    """
6227    Compute the compcor matrix.
6228
6229    :param boldImage: The bold image.
6230    :param mask: The mask to apply, if None, it will be computed.
6231    :param quantile: Quantile for computing threshold in tSTD.
6232    :return: The compor matrix.
6233    """
6234    if mask is None:
6235        temp = ants.slice_image(boldImage, axis=boldImage.dimension - 1, idx=0)
6236        mask = ants.get_mask(temp)
6237
6238    imagematrix = ants.timeseries_to_matrix(boldImage, mask)
6239    temp = compute_tSTD(imagematrix, quantile, 0)
6240    tsnrmask = ants.make_image(mask, temp['tSTD'])
6241    tsnrmask = ants.threshold_image(tsnrmask, temp['threshold_std'], temp['tSTD'].max())
6242    M = ants.timeseries_to_matrix(boldImage, tsnrmask)
6243    return M
6244
6245
6246  from sklearn.decomposition import FastICA
6247  def find_indices(lst, value):
6248    return [index for index, element in enumerate(lst) if element > value]
6249
6250  def mean_of_list(lst):
6251    if not lst:  # Check if the list is not empty
6252        return 0  # Return 0 or appropriate value for an empty list
6253    return sum(lst) / len(lst)
6254  fmrispc = list( ants.get_spacing( fmri ) )
6255  if spa is None:
6256    spa = mean_of_list( fmrispc[0:3] ) * 1.0
6257  if spt is None:
6258    spt = fmrispc[3] * 0.5
6259      
6260  import numpy as np
6261  import pandas as pd
6262  import re
6263  import math
6264  # point data resources
6265  A = np.zeros((1,1))
6266  dfnname='DefaultMode'
6267  if powers:
6268      powers_areal_mni_itk = pd.read_csv( get_data('powers_mni_itk', target_extension=".csv")) # power coordinates
6269      coords='powers'
6270  else:
6271      powers_areal_mni_itk = pd.read_csv( get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".csv")) # yeo 2023 coordinates
6272      coords='yeo_17_500_2023'
6273  fmri = ants.iMath( fmri, 'Normalize' )
6274  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("FillHoles")
6275  if verbose:
6276      print("Begin rsfmri motion correction")
6277  debug=False
6278  if debug:
6279      ants.image_write( fmri_template, '/tmp/fmri_template.nii.gz' )
6280      ants.image_write( fmri, '/tmp/fmri.nii.gz' )
6281      print("debug wrote fmri and fmri_template")
6282  # mot-co
6283  corrmo = timeseries_reg(
6284    fmri, fmri_template,
6285    type_of_transform=type_of_transform,
6286    total_sigma=0.5,
6287    fdOffset=2.0,
6288    trim = 8,
6289    output_directory=None,
6290    verbose=verbose,
6291    syn_metric='CC',
6292    syn_sampling=2,
6293    reg_iterations=[40,20,5],
6294    return_numpy_motion_parameters=True )
6295  
6296  if verbose:
6297      print("End rsfmri motion correction")
6298      print("--maximum motion : " + str(corrmo['FD'].max()) )
6299      print("=== next anatomically based mapping ===")
6300
6301  despiking_count = np.zeros( corrmo['motion_corrected'].shape[3] )
6302  if despike > 0.0:
6303      corrmo['motion_corrected'], despiking_count = despike_time_series_afni( corrmo['motion_corrected'], c1=despike )
6304
6305  despiking_count_summary = despiking_count.sum() / np.prod( corrmo['motion_corrected'].shape )
6306  high_motion_count=(corrmo['FD'] > FD_threshold ).sum()
6307  high_motion_pct=high_motion_count / fmri.shape[3]
6308
6309  # filter mask based on TSNR
6310  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
6311  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
6312  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",2)
6313  bmask = bmask * tsnrmask
6314
6315  # anatomical mapping
6316  und = fmri_template * bmask
6317  t1reg = ants.registration( und, t1,
6318     "antsRegistrationSyNQuickRepro[s]", outprefix=ofnt1tx )
6319  if verbose:
6320    print("t1 2 bold done")
6321  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
6322  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
6323  gmseg = ants.threshold_image( gmseg, 1, 4 )
6324  gmseg = ants.iMath( gmseg, 'MD', 1 ) # FIXMERSF
6325  gmseg = ants.apply_transforms( und, gmseg,
6326    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' ) * bmask
6327  csfAndWM = ( ants.threshold_image( t1segmentation, 1, 1 ) +
6328               ants.threshold_image( t1segmentation, 3, 3 ) ).morphology("erode",1)
6329  csfAndWM = ants.apply_transforms( und, csfAndWM,
6330    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
6331  csf = ants.threshold_image( t1segmentation, 1, 1 )
6332  csf = ants.apply_transforms( und, csf, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
6333  wm = ants.threshold_image( t1segmentation, 3, 3 ).morphology("erode",1)
6334  wm = ants.apply_transforms( und, wm, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
6335  if powers:
6336    ch2 = mm_read( ants.get_ants_data( "ch2" ) )
6337  else:
6338    ch2 = mm_read( get_data( "PPMI_template0_brain", target_extension='.nii.gz' ) )
6339  treg = ants.registration( 
6340    # this is to make the impact of resolution consistent
6341    ants.resample_image(t1, [1.0,1.0,1.0], interp_type=0), 
6342    ch2, "antsRegistrationSyNQuickRepro[s]" )
6343  if powers:
6344    concatx2 = treg['invtransforms'] + t1reg['invtransforms']
6345    pts2bold = ants.apply_transforms_to_points( 3, powers_areal_mni_itk, concatx2,
6346        whichtoinvert = ( True, False, True, False ) )
6347    locations = pts2bold.iloc[:,:3].values
6348    ptImg = ants.make_points_image( locations, bmask, radius = 2 )
6349  else:
6350    concatx2 = t1reg['fwdtransforms'] + treg['fwdtransforms']    
6351    rsfsegfn = get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".nii.gz")
6352    rsfsegimg = ants.image_read( rsfsegfn )
6353    ptImg = ants.apply_transforms( und, rsfsegimg, concatx2, interpolator='nearestNeighbor' ) * bmask
6354    pts2bold = powers_areal_mni_itk
6355    # ants.plot( und, ptImg, crop=True, axis=2 )
6356
6357  # optional smoothing
6358  tr = ants.get_spacing( corrmo['motion_corrected'] )[3]
6359  smth = ( spa, spa, spa, spt ) # this is for sigmaInPhysicalCoordinates = TRUE
6360  simg = ants.smooth_image( corrmo['motion_corrected'], smth, sigma_in_physical_coordinates = True )
6361
6362  # collect censoring indices
6363  hlinds = find_indices( corrmo['FD'], FD_threshold )
6364  if verbose:
6365    print("high motion indices")
6366    print( hlinds )
6367  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
6368    fmrimotcorr, hlinds2 = loop_timeseries_censoring( corrmo['motion_corrected'], 
6369      threshold=outlier_threshold, verbose=verbose )
6370    hlinds.extend( hlinds2 )
6371    del fmrimotcorr
6372  hlinds = list(set(hlinds)) # make unique
6373
6374  # nuisance
6375  globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], bmask )
6376  globalsignal = np.nanmean( globalmat, axis = 1 )
6377  del globalmat
6378  compcorquantile=0.50
6379  nc_wm=nc_csf=nc
6380  if nc < 1:
6381    globalmat = get_compcor_matrix( corrmo['motion_corrected'], wm, compcorquantile )
6382    nc_wm = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
6383    globalmat = get_compcor_matrix( corrmo['motion_corrected'], csf, compcorquantile )
6384    nc_csf = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
6385    del globalmat
6386  if verbose:
6387    print("include compcor components as nuisance: csf " + str(nc_csf) + " wm " + str(nc_wm))
6388  mycompcor_csf = ants.compcor( corrmo['motion_corrected'],
6389    ncompcor=nc_csf, quantile=compcorquantile, mask = csf,
6390    filter_type='polynomial', degree=2 )
6391  mycompcor_wm = ants.compcor( corrmo['motion_corrected'],
6392    ncompcor=nc_wm, quantile=compcorquantile, mask = wm,
6393    filter_type='polynomial', degree=2 )
6394  nuisance = np.c_[ mycompcor_csf[ 'components' ], mycompcor_wm[ 'components' ] ]
6395
6396  if motion_as_nuisance:
6397      if verbose:
6398          print("include motion as nuisance")
6399          print( corrmo['motion_parameters'].shape )
6400      deriv = temporal_derivative_same_shape( corrmo['motion_parameters']  )
6401      nuisance = np.c_[ nuisance, corrmo['motion_parameters'], deriv ]
6402
6403  if ica_components > 0:
6404    if verbose:
6405        print("include ica components as nuisance: " + str(ica_components))
6406    ica = FastICA(n_components=ica_components, max_iter=10000, tol=0.001, random_state=42 )
6407    globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], csfAndWM )
6408    nuisance_ica = ica.fit_transform(globalmat)  # Reconstruct signals
6409    nuisance = np.c_[ nuisance, nuisance_ica ]
6410    del globalmat
6411
6412  # concat all nuisance data
6413  # nuisance = np.c_[ nuisance, mycompcor['basis'] ]
6414  # nuisance = np.c_[ nuisance, corrmo['FD'] ]
6415  nuisance = np.c_[ nuisance, globalsignal ]
6416
6417  if impute:
6418    simgimp = impute_timeseries( simg, hlinds, method='linear')
6419  else:
6420    simgimp = simg
6421
6422  # f/ALFF computation via alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None )
6423  myfalff=alff_image( simgimp, bmask, flo=f[0], fhi=f[1], nuisance=nuisance  )
6424
6425  # bandpass any data collected before here -- if bandpass requested
6426  if f[0] > 0 and f[1] < 1.0:
6427    if verbose:
6428        print( "bandpass: " + str(f[0]) + " <=> " + str( f[1] ) )
6429    nuisance = ants.bandpass_filter_matrix( nuisance, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
6430    globalmat = ants.timeseries_to_matrix( simg, bmask )
6431    globalmat = ants.bandpass_filter_matrix( globalmat, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
6432    simg = ants.matrix_to_timeseries( simg, globalmat, bmask )
6433
6434  if verbose:
6435    print("now regress nuisance")
6436
6437
6438  if len( hlinds ) > 0 :
6439    if censor:
6440        nuisance = remove_elements_from_numpy_array( nuisance, hlinds  )
6441        simg = remove_volumes_from_timeseries( simg, hlinds )
6442
6443  gmmat = ants.timeseries_to_matrix( simg, bmask )
6444  gmmat = ants.regress_components( gmmat, nuisance )
6445  simg = ants.matrix_to_timeseries(simg, gmmat, bmask)
6446
6447
6448  # structure the output data
6449  outdict = {}
6450  outdict['paramset'] = paramset
6451  outdict['upsampling'] = upsample
6452  outdict['coords'] = coords
6453  outdict['dfnname']=dfnname
6454  outdict['meanBold'] = und
6455
6456  # add correlation matrix that captures each node pair
6457  # some of the spheres overlap so extract separately from each ROI
6458  if powers:
6459    nPoints = int(pts2bold['ROI'].max())
6460    pointrange = list(range(int(nPoints)))
6461  else:
6462    nPoints = int(ptImg.max())
6463    pointrange = list(range(int(nPoints)))
6464  nVolumes = simg.shape[3]
6465  meanROI = np.zeros([nVolumes, nPoints])
6466  roiNames = []
6467  if debug:
6468      ptImgAll = und * 0.
6469  for i in pointrange:
6470    # specify name for matrix entries that's links back to ROI number and network; e.g., ROI1_Uncertain
6471    netLabel = re.sub( " ", "", pts2bold.loc[i,'SystemName'])
6472    netLabel = re.sub( "-", "", netLabel )
6473    netLabel = re.sub( "/", "", netLabel )
6474    roiLabel = "ROI" + str(pts2bold.loc[i,'ROI']) + '_' + netLabel
6475    roiNames.append( roiLabel )
6476    if powers:
6477        ptImage = ants.make_points_image(pts2bold.iloc[[i],:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
6478    else:
6479        #print("Doing " + pts2bold.loc[i,'SystemName'] + " at " + str(i) )
6480        #ptImage = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==pts2bold.loc[i,'SystemName']],binarize=True)
6481        ptImage=ants.threshold_image( ptImg, pts2bold.loc[i,'ROI'], pts2bold.loc[i,'ROI'] )
6482    if debug:
6483      ptImgAll = ptImgAll + ptImage
6484    if ptImage.sum() > 0 :
6485        meanROI[:,i] = ants.timeseries_to_matrix( simg, ptImage).mean(axis=1)
6486
6487  if debug:
6488      ants.image_write( simg, '/tmp/simg.nii.gz' )
6489      ants.image_write( ptImgAll, '/tmp/ptImgAll.nii.gz' )
6490      ants.image_write( und, '/tmp/und.nii.gz' )
6492
6493  # get full correlation matrix
6494  corMat = np.corrcoef(meanROI, rowvar=False)
6495  outputMat = pd.DataFrame(corMat)
6496  outputMat.columns = roiNames
6497  outputMat['ROIs'] = roiNames
6498  # add to dictionary
6499  outdict['fullCorrMat'] = outputMat
6500
6501  networks = powers_areal_mni_itk['SystemName'].unique()
6502  # this is just for human readability - reminds us of which we choose by default
6503  if powers:
6504    netnames = ['Cingulo-opercular Task Control', 'Default Mode',
6505                    'Memory Retrieval', 'Ventral Attention', 'Visual',
6506                    'Fronto-parietal Task Control', 'Salience', 'Subcortical',
6507                    'Dorsal Attention']
6508    numofnets = [3,5,6,7,8,9,10,11,13]
6509  else:
6510    netnames = networks
6511    numofnets = list(range(len(netnames)))
6512 
6513  ct = 0
6514  for mynet in numofnets:
6515    netname = re.sub( " ", "", networks[mynet] )
6516    netname = re.sub( "-", "", netname )
6517    ww = np.where( powers_areal_mni_itk['SystemName'] == networks[mynet] )[0]
6518    if powers:
6519        dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
6520    else:
6521        dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[mynet]],binarize=True)
6522    if dfnImg.max() >= 1:
6523        if verbose:
6524            print("DO: " + coords + " " + netname )
6525        dfnmat = ants.timeseries_to_matrix( simg, ants.threshold_image( dfnImg, 1, dfnImg.max() ) )
6526        dfnsignal = np.nanmean( dfnmat, axis = 1 )
6527        nan_count_dfn = np.count_nonzero( np.isnan( dfnsignal) )
6528        if nan_count_dfn > 0 :
6529            warnings.warn( " mynet " + netname + " mean-signal has nans " + str( nan_count_dfn ) )
6530        gmmatDFNCorr = np.zeros( gmmat.shape[1] )
6531        if nan_count_dfn == 0:
6532            for k in range( gmmat.shape[1] ):
6533                nan_count_gm = np.count_nonzero( np.isnan( gmmat[:,k]) )
6534                if debug and False:
6535                    print( str( k ) +  " nans gm " + str(nan_count_gm)  )
6536                if nan_count_gm == 0:
6537                    gmmatDFNCorr[ k ] = pearsonr( dfnsignal, gmmat[:,k] )[0]
6538        corrImg = ants.make_image( bmask, gmmatDFNCorr  )
6539        outdict[ netname ] = corrImg * gmseg
6540    else:
6541        outdict[ netname ] = None
6542    ct = ct + 1
6543
6544  A = np.zeros( ( len( numofnets ) , len( numofnets ) ) )
6545  A_wide = np.zeros( ( 1, len( numofnets ) * len( numofnets ) ) )
6546  newnames=[]
6547  newnames_wide=[]
6548  ct = 0
6549  for i in range( len( numofnets ) ):
6550      netnamei = re.sub( " ", "", networks[numofnets[i]] )
6551      netnamei = re.sub( "-", "", netnamei )
6552      newnames.append( netnamei  )
6553      ww = np.where( powers_areal_mni_itk['SystemName'] == networks[numofnets[i]] )[0]
6554      if powers:
6555          dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
6556      else:
6557          dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[numofnets[i]]],binarize=True)
6558      for j in range( len( numofnets ) ):
6559          netnamej = re.sub( " ", "", networks[numofnets[j]] )
6560          netnamej = re.sub( "-", "", netnamej )
6561          newnames_wide.append( netnamei + "_2_" + netnamej )
6562          A[i,j] = 0
6563          if dfnImg is not None and netnamej is not None:
6564            subbit = dfnImg == 1
6565            if subbit is not None:
6566                if subbit.sum() > 0 and netnamej in outdict:
6567                    A[i,j] = outdict[ netnamej ][ subbit ].mean()
6568          A_wide[0,ct] = A[i,j]
6569          ct=ct+1
6570
6571  A = pd.DataFrame( A )
6572  A.columns = newnames
6573  A['networks']=newnames
6574  A_wide = pd.DataFrame( A_wide )
6575  A_wide.columns = newnames_wide
6576  outdict['corr'] = A
6577  outdict['corr_wide'] = A_wide
6578  outdict['fmri_template'] = fmri_template
6579  outdict['brainmask'] = bmask
6580  outdict['gmmask'] = gmseg
6581  outdict['alff'] = myfalff['alff']
6582  outdict['falff'] = myfalff['falff']
6583  # add global mean and standard deviation for post-hoc z-scoring
6584  outdict['alff_mean'] = (myfalff['alff'][myfalff['alff']!=0]).mean()
6585  outdict['alff_sd'] = (myfalff['alff'][myfalff['alff']!=0]).std()
6586  outdict['falff_mean'] = (myfalff['falff'][myfalff['falff']!=0]).mean()
6587  outdict['falff_sd'] = (myfalff['falff'][myfalff['falff']!=0]).std()
6588
6589  perafimg = PerAF( simgimp, bmask )
6590  for k in pointrange:
6591    anatname=( pts2bold['AAL'][k] )
6592    if isinstance(anatname, str):
6593        anatname = re.sub("_","",anatname)
6594    else:
6595        anatname='Unk'
6596    if powers:
6597        kk = f"{k:0>3}"+"_"
6598    else:
6599        kk = f"{k % int(nPoints/2):0>3}"+"_"
6600    fname='falffPoint'+kk+anatname
6601    aname='alffPoint'+kk+anatname
6602    pname='perafPoint'+kk+anatname
6603    localsel = ptImg == k
6604    if localsel.sum() > 0 : # check if non-empty
6605        outdict[fname]=(outdict['falff'][localsel]).mean()
6606        outdict[aname]=(outdict['alff'][localsel]).mean()
6607        outdict[pname]=(perafimg[localsel]).mean()
6608    else:
6609        outdict[fname]=math.nan
6610        outdict[aname]=math.nan
6611        outdict[pname]=math.nan
6612
6613  rsfNuisance = pd.DataFrame( nuisance )
6614  if remove_it:
6615    import shutil
6616    shutil.rmtree(output_directory, ignore_errors=True )
6617
6618  if not powers:
6619    dfnsum=outdict['DefaultA']+outdict['DefaultB']+outdict['DefaultC']
6620    outdict['DefaultMode']=dfnsum
6621    dfnsum=outdict['VisCent']+outdict['VisPeri']
6622    outdict['Visual']=dfnsum
6623
6624  nonbrainmask = ants.iMath( bmask, "MD",2) - bmask
6625  trimmask = ants.iMath( bmask, "ME",2)
6626  edgemask = ants.iMath( bmask, "ME",1) - trimmask
6627  outdict['motion_corrected'] = corrmo['motion_corrected']
6628  outdict['nuisance'] = rsfNuisance
6629  outdict['PerAF'] = perafimg
6630  outdict['tsnr'] = mytsnr
6631  outdict['ssnr'] = slice_snr( corrmo['motion_corrected'], csfAndWM, gmseg )
6632  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
6633  outdict['bandpass_freq_0']=f[0]
6634  outdict['bandpass_freq_1']=f[1]
6635  outdict['censor']=int(censor)
6636  outdict['spatial_smoothing']=spa
6637  outdict['outlier_threshold']=outlier_threshold
6638  outdict['FD_threshold']=FD_threshold
6639  outdict['high_motion_count'] = high_motion_count
6640  outdict['high_motion_pct'] = high_motion_pct
6641  outdict['despiking_count_summary'] = despiking_count_summary
6642  outdict['FD_max'] = corrmo['FD'].max()
6643  outdict['FD_mean'] = corrmo['FD'].mean()
6644  outdict['FD_sd'] = corrmo['FD'].std()
6645  outdict['bold_evr'] =  antspyt1w.patch_eigenvalue_ratio( und, 512, [16,16,16], evdepth = 0.9, mask = bmask )
6646  outdict['n_outliers'] = len(hlinds)
6647  outdict['nc_wm'] = int(nc_wm)
6648  outdict['nc_csf'] = int(nc_csf)
6649  outdict['minutes_original_data'] = ( tr * fmri.shape[3] ) / 60.0 # minutes of useful data
6650  outdict['minutes_censored_data'] = ( tr * simg.shape[3] ) / 60.0 # minutes of useful data
6651  return convert_np_in_dict( outdict )

Compute resting state network correlation maps based on the J Power labels (or, optionally, the 2023 Yeo 500 homotopic labels). This will output a map for each of the major network systems. This function will optionally upsample data to 2mm during the registration process if the data are below that resolution.

registration - despike - anatomy - smooth - nuisance - bandpass - regress.nuisance - censor - falff - correlations

Arguments

fmri : BOLD fmri antsImage

fmri_template : reference space for BOLD

t1 : ANTsImage input 3-D T1 brain image (brain extracted)

t1segmentation : ANTsImage t1 segmentation - a six tissue segmentation image in T1 space

f : band pass limits for frequency filtering; we use high-pass here as per Shirer 2015

spa : gaussian smoothing for spatial component (physical coordinates)

spt : gaussian smoothing for temporal component

nc : number of components for compcor filtering; if less than 1, the number is estimated on the fly based on explained variance; a typical choice following Shirer 2015 is 10 (5 from csf and 5 from wm)

ica_components : integer; if greater than 0, include this many ICA components as additional nuisance regressors

impute : boolean; if True, use imputation in the f/ALFF and PerAF calculations

censor : boolean; if True, censor (remove) high-motion and outlier volumes before nuisance regression

despike : if greater than zero, run voxel-wise despiking (in the AFNI 3dDespike sense) after motion correction, using this value as the despiking threshold

motion_as_nuisance : boolean; if True, add the motion parameters and their first temporal derivative as nuisance regressors

powers : boolean; if True, use Powers nodes, otherwise the 2023 Yeo 500 homotopic nodes (10.1016/j.neuroimage.2023.120010)

upsample : float; if greater than 0, isotropically resample the BOLD template to this resolution (in mm) during the registration process when the input data are coarser; if the input spacing is already finer than this value, the data are resampled to isotropic resolution at the finest input spacing

clean_tmp : will automatically try to clean the tmp directory - not recommended but can be used in distributed computing systems to help prevent failures due to accumulation of tmp files when doing large-scale processing. if this is set, the float value clean_tmp will be interpreted as the age in hours of files to be cleaned.

verbose : boolean

Returns

a dictionary containing the derived network maps

References

10.1162/netn_a_00071 "Methods that included global signal regression were the most consistently effective de-noising strategies."

10.1016/j.neuroimage.2019.116157 "frontal and default mode networks are most reliable whereas subcortical networks are least reliable" "the most comprehensive studies of pipeline effects on edge-level reliability have been done by Shirer (2015) and Parkes (2018)" "slice timing correction has minimal impact" "use of a low-pass or narrow filter (discarding high frequency information) reduced both reliability and signal-noise separation"

10.1016/j.neuroimage.2017.12.073: Our results indicate that (1) simple linear regression of regional fMRI time series against head motion parameters and WM/CSF signals (with or without expansion terms) is not sufficient to remove head motion artefacts; (2) aCompCor pipelines may only be viable in low-motion data; (3) volume censoring performs well at minimising motion-related artefact but a major benefit of this approach derives from the exclusion of high-motion individuals; (4) while not as effective as volume censoring, ICA-AROMA performed well across our benchmarks for relatively low cost in terms of data loss; (5) the addition of global signal regression improved the performance of nearly all pipelines on most benchmarks, but exacerbated the distance-dependence of correlations between motion and functional connectivity; and (6) group comparisons in functional connectivity between healthy controls and schizophrenia patients are highly dependent on preprocessing strategy. We offer some recommendations for best practice and outline simple analyses to facilitate transparent reporting of the degree to which a given set of findings may be affected by motion-related artefact.

10.1016/j.dcn.2022.101087 : We found that: 1) the most efficacious pipeline for both noise removal and information recovery included censoring, GSR, bandpass filtering, and head motion parameter (HMP) regression, 2) ICA-AROMA performed similarly to HMP regression and did not obviate the need for censoring, 3) GSR had a minimal impact on connectome fingerprinting but improved ISC, and 4) the strictest censoring approaches reduced motion correlated edges but negatively impacted identifiability.
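
The example below is a minimal usage sketch, not part of the package: the file names are illustrative, and the parameter values mirror one of the presets that mm() passes to this function. The BOLD template is built with the same loop_timeseries_censoring / get_average_rsf calls used by mm().

    import ants
    import antspymm

    # illustrative input file names (replace with your own data)
    fmri = ants.image_read('sub-01_task-rest_bold.nii.gz')    # 4D BOLD
    t1 = ants.image_read('sub-01_T1w_brain.nii.gz')           # brain-extracted T1
    t1seg = ants.image_read('sub-01_T1w_seg6.nii.gz')         # six-tissue segmentation in T1 space

    # build a motion-robust BOLD template, as mm() does before calling this function
    fmri_censored, bad_vols = antspymm.loop_timeseries_censoring(fmri, 0.1)
    fmri_template = antspymm.get_average_rsf(fmri_censored)

    rsf = antspymm.resting_state_fmri_networks(
        fmri, fmri_template, t1, t1seg,
        f=[0.03, 0.08],             # the "tight" band used by mm()
        FD_threshold=0.5,
        spa=None, spt=None,         # smoothing defaults derived from voxel size and TR
        nc=0.8,                     # < 1: estimate compcor components from explained variance
        outlier_threshold=0.5,
        ica_components=0,
        impute=True, censor=True, despike=2.5,
        motion_as_nuisance=True,
        upsample=3.0,
        paramset=134,               # identifier recorded in the output; value taken from a mm() preset
        powers=False,               # use the 2023 Yeo 500 homotopic nodes
        verbose=True)

    print(rsf['corr'])                                         # between-network correlation matrix
    ants.image_write(rsf['DefaultMode'], 'sub-01_DefaultMode.nii.gz')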

def write_bvals_bvecs(bvals, bvecs, prefix):
7563def write_bvals_bvecs(bvals, bvecs, prefix ):
7564    ''' Write FSL FDT bvals and bvecs files
7565
7566    adapted from dipy.external code
7567
7568    Parameters
7569    -------------
7570    bvals : (N,) sequence
7571       Vector with diffusion gradient strength (one per diffusion
7572       acquisition, N=no of acquisitions)
7573    bvecs : (N, 3) array-like
7574       diffusion gradient directions
7575    prefix : string
7576       path to write FDT bvals, bvecs text files
7577       None results in current working directory.
7578    '''
7579    _VAL_FMT = '   %e'
7580    bvals = tuple(bvals)
7581    bvecs = np.asarray(bvecs)
7582    bvecs[np.isnan(bvecs)] = 0
7583    N = len(bvals)
7584    fname = prefix + '.bval'
7585    fmt = _VAL_FMT * N + '\n'
7586    myfile = open(fname, 'wt')
7587    myfile.write(fmt % bvals)
7588    myfile.close()
7589    fname = prefix + '.bvec'
7590    bvf = open(fname, 'wt')
7591    for dim_vals in bvecs.T:
7592        bvf.write(fmt % tuple(dim_vals))
7593    bvf.close()

Write FSL FDT bvals and bvecs files

adapted from dipy.external code

Parameters

bvals : (N,) sequence; vector with diffusion gradient strength (one per diffusion acquisition, N = number of acquisitions)

bvecs : (N, 3) array-like; diffusion gradient directions

prefix : string; path to write FDT bvals, bvecs text files. None results in current working directory.
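
A short illustrative sketch (the gradient scheme and output prefix are made up): one b=0 volume plus two b=1000 directions are written as /tmp/example_dwi.bval and /tmp/example_dwi.bvec.

    import numpy as np
    import antspymm

    bvals = [0, 1000, 1000]
    bvecs = np.array([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])

    # writes /tmp/example_dwi.bval and /tmp/example_dwi.bvec in FSL FDT format
    antspymm.write_bvals_bvecs(bvals, bvecs, '/tmp/example_dwi')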

def crop_mcimage(x, mask, padder=None):
7596def crop_mcimage( x, mask, padder=None ):
7597    """
7598    crop a time series (4D) image by a 3D mask
7599
7600    Parameters
7601    -------------
7602
7603    x : raw image
7604
7605    mask  : mask for cropping
7606
7607    """
7608    cropmask = ants.crop_image( mask, mask )
7609    myorig = list( ants.get_origin(cropmask) )
7610    myorig.append( ants.get_origin( x )[3] )
7611    croplist = []
7612    if len(x.shape) > 3:
7613        for k in range(x.shape[3]):
7614            temp = ants.slice_image( x, axis=3, idx=k )
7615            temp = ants.crop_image( temp, mask )
7616            if padder is not None:
7617                temp = ants.pad_image( temp, pad_width=padder )
7618            croplist.append( temp )
7619        temp = ants.list_to_ndimage( x, croplist )
7620        temp.set_origin( myorig )
7621        return temp
7622    else:
7623        return( ants.crop_image( x, mask ) )

crop a time series (4D) image by a 3D mask

Parameters

x : raw image

mask : mask for cropping
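
A minimal sketch, assuming a 4D BOLD image on disk (illustrative path); the mask is derived from the first volume, mirroring the approach used elsewhere in this module.

    import ants
    import antspymm

    bold = ants.image_read('sub-01_task-rest_bold.nii.gz')    # 4D time series (illustrative path)
    first_vol = ants.slice_image(bold, axis=3, idx=0)
    mask = ants.get_mask(first_vol)

    # every volume is cropped by the 3D mask and the stack is reassembled with a corrected origin
    bold_cropped = antspymm.crop_mcimage(bold, mask)
    print(bold.shape, '->', bold_cropped.shape)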

def mm( t1_image, hier, rsf_image=[], flair_image=None, nm_image_list=None, dw_image=[], bvals=[], bvecs=[], perfusion_image=None, srmodel=None, do_tractography=False, do_kk=False, do_normalization=None, group_template=None, group_transform=None, target_range=[0, 1], dti_motion_correct='antsRegistrationSyNQuickRepro[r]', dti_denoise=False, perfusion_trim=10, perfusion_m0_image=None, perfusion_m0=None, rsf_upsampling=3.0, pet_3d_image=None, test_run=False, verbose=False):
7626def mm(
7627    t1_image,
7628    hier,
7629    rsf_image=[],
7630    flair_image=None,
7631    nm_image_list=None,
7632    dw_image=[], bvals=[], bvecs=[],
7633    perfusion_image=None,
7634    srmodel=None,
7635    do_tractography = False,
7636    do_kk = False,
7637    do_normalization = None,
7638    group_template = None,
7639    group_transform = None,
7640    target_range = [0,1],
7641    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
7642    dti_denoise = False,
7643    perfusion_trim=10,
7644    perfusion_m0_image=None,
7645    perfusion_m0=None,
7646    rsf_upsampling=3.0,
7647    pet_3d_image=None,
7648    test_run = False,
7649    verbose = False ):
7650    """
7651    Multiple modality processing and normalization
7652
7653    aggregates modality-specific processing under one roof.  see individual
7654    modality specific functions for details.
7655
7656    Parameters
7657    -------------
7658
7659    t1_image : raw t1 image
7660
7661    hier  : output of antspyt1w.hierarchical ( see read hierarchical )
7662
7663    rsf_image : list of resting state fmri
7664
7665    flair_image : flair
7666
7667    nm_image_list : list of neuromelanin images
7668
7669    dw_image : list of diffusion weighted images
7670
7671    bvals : list of bvals file names
7672
7673    bvecs : list of bvecs file names
7674
7675    perfusion_image : single perfusion image
7676
7677    srmodel : optional srmodel
7678
7679    do_tractography : boolean
7680
7681    do_kk : boolean to control whether we compute kelly kapowski thickness image (slow)
7682
7683    do_normalization : template transformation if available
7684
7685    group_template : optional reference template corresponding to the group_transform
7686
7687    group_transform : optional transforms corresponding to the group_template
7688
7689    target_range : 2-element tuple
7690        a tuple or array defining the (min, max) of the input image
7691        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
7692        intensity. This range should match the mapping used in the training
7693        of the network.
7694    
7695    dti_motion_correct : None, Rigid, or SyN
7696
7697    dti_denoise : boolean
7698
7699    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
7700
7701    perfusion_m0_image : optional antsImage m0 associated with the perfusion time series
7702
7703    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
7704
7705    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
7706
7707    pet_3d_image : optional antsImage for a 3D pet; we make no assumptions about the contents of 
7708        this image.  we just process it and provide summary information.
7709
7710    test_run : boolean 
7711
7712    verbose : boolean
7713
7714    """
7715    from os.path import exists
7716    ex_path = os.path.expanduser( "~/.antspyt1w/" )
7717    ex_path_mm = os.path.expanduser( "~/.antspymm/" )
7718    mycsvfn = ex_path + "FA_JHU_labels_edited.csv"
7719    citcsvfn = ex_path + "CIT168_Reinf_Learn_v1_label_descriptions_pad.csv"
7720    dktcsvfn = ex_path + "dkt.csv"
7721    cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
7722    JHU_atlasfn = ex_path + 'JHU-ICBM-FA-1mm.nii.gz' # Read in JHU atlas
7723    JHU_labelsfn = ex_path + 'JHU-ICBM-labels-1mm.nii.gz' # Read in JHU labels
7724    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
7725    if not exists( mycsvfn ) or not exists( citcsvfn ) or not exists( cnxcsvfn ) or not exists( dktcsvfn ) or not exists( JHU_atlasfn ) or not exists( JHU_labelsfn ) or not exists( templatefn ):
7726        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
7727        raise ValueError('**missing files** => call get_data from latest antspyt1w and antspymm.')
7728    mycsv = pd.read_csv(  mycsvfn )
7729    citcsv = pd.read_csv(  os.path.expanduser( citcsvfn ) )
7730    dktcsv = pd.read_csv(  os.path.expanduser( dktcsvfn ) )
7731    cnxcsv = pd.read_csv(  os.path.expanduser( cnxcsvfn ) )
7732    JHU_atlas = mm_read( JHU_atlasfn ) # Read in JHU atlas
7733    JHU_labels = mm_read( JHU_labelsfn ) # Read in JHU labels
7734    template = mm_read( templatefn ) # Read in template
7735    if group_template is None:
7736        group_template = template
7737        group_transform = do_normalization['fwdtransforms']
7738    if verbose:
7739        print("Using group template:")
7740        print( group_template )
7741    #####################
7742    #  T1 hierarchical  #
7743    #####################
7744    t1imgbrn = hier['brain_n4_dnz']
7745    t1atropos = hier['dkt_parc']['tissue_segmentation']
7746    output_dict = {
7747        'kk': None,
7748        'rsf': None,
7749        'flair' : None,
7750        'NM' : None,
7751        'DTI' : None,
7752        'FA_summ' : None,
7753        'MD_summ' : None,
7754        'tractography' : None,
7755        'tractography_connectivity' : None,
7756        'perf' : None,
7757        'pet3d' : None,
7758    }
7759    normalization_dict = {
7760        'kk_norm': None,
7761        'NM_norm' : None,
7762        'DTI_norm': None,
7763        'FA_norm' : None,
7764        'MD_norm' : None,
7765        'perf_norm' : None,
7766        'alff_norm' : None,
7767        'falff_norm' : None,
7768        'CinguloopercularTaskControl_norm' : None,
7769        'DefaultMode_norm' : None,
7770        'MemoryRetrieval_norm' : None,
7771        'VentralAttention_norm' : None,
7772        'Visual_norm' : None,
7773        'FrontoparietalTaskControl_norm' : None,
7774        'Salience_norm' : None,
7775        'Subcortical_norm' : None,
7776        'DorsalAttention_norm' : None,
7777        'pet3d_norm' : None
7778    }
7779    if test_run:
7780        return output_dict, normalization_dict
7781
7782    if do_kk:
7783        if verbose:
7784            print('kk in mm')
7785        output_dict['kk'] = antspyt1w.kelly_kapowski_thickness( t1_image,
7786            labels=hier['dkt_parc']['dkt_cortex'], iterations=45 )
7787
7788    if perfusion_image is not None:
7789        if perfusion_image.shape[3] > 1: # FIXME - better heuristic?
7790            output_dict['perf'] = bold_perfusion(
7791                perfusion_image,
7792                t1_image,
7793                hier['brain_n4_dnz'],
7794                t1atropos,
7795                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
7796                n_to_trim = perfusion_trim,
7797                m0_image = perfusion_m0_image,
7798                m0_indices = perfusion_m0,
7799                verbose=verbose )
7800
7801    if pet_3d_image is not None:
7802        if pet_3d_image.dimension == 3: # FIXME - better heuristic?
7803            output_dict['pet3d'] = pet3d_summary(
7804                pet_3d_image,
7805                t1_image,
7806                hier['brain_n4_dnz'],
7807                t1atropos,
7808                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
7809                verbose=verbose )
7810    ################################## do the rsf .....
7811    if len(rsf_image) > 0:
7812        my_motion_tx = 'antsRegistrationSyNRepro[r]'
7813        rsf_image = [i for i in rsf_image if i is not None]
7814        if verbose:
7815            print('rsf length ' + str( len( rsf_image ) ) )
7816        if len( rsf_image ) >= 2: # assume 2 is the largest possible value
7817            rsf_image1 = rsf_image[0]
7818            rsf_image2 = rsf_image[1]
7819            # build a template then join the images
7820            if verbose:
7821                print("initial average for rsf")
7822            rsfavg1, hlinds = loop_timeseries_censoring( rsf_image1, 0.1 )
7823            rsfavg1=get_average_rsf(rsfavg1)
7824            rsfavg2, hlinds = loop_timeseries_censoring( rsf_image2, 0.1 )
7825            rsfavg2=get_average_rsf(rsfavg2)
7826            if verbose:
7827                print("template average for rsf")
7828            init_temp = ants.image_clone( rsfavg1 )
7829            if rsf_image1.shape[3] < rsf_image2.shape[3]:
7830                init_temp = ants.image_clone( rsfavg2 )
7831            boldTemplate = ants.build_template(
7832                initial_template = init_temp,
7833                image_list=[rsfavg1,rsfavg2],
7834                type_of_transform="antsRegistrationSyNQuickRepro[s]",
7835                iterations=5, verbose=False )
7836            if verbose:
7837                print("join the 2 rsf")
7838            if rsf_image1.shape[3] > 10 and rsf_image2.shape[3] > 10:
7839                leadvols = list(range(8))
7840                rsf_image2 = remove_volumes_from_timeseries( rsf_image2, leadvols )
7841                rsf_image = merge_timeseries_data( rsf_image1, rsf_image2 )
7842            elif rsf_image1.shape[3] > rsf_image2.shape[3]:
7843                rsf_image = rsf_image1
7844            else:
7845                rsf_image = rsf_image2
7846        elif len( rsf_image ) == 1:
7847            rsf_image = rsf_image[0]
7848            boldTemplate, hlinds = loop_timeseries_censoring( rsf_image, 0.1 )
7849            boldTemplate = get_average_rsf(boldTemplate)
7850        if rsf_image.shape[3] > 10: # FIXME - better heuristic?
7851            rsfprolist = [] # FIXMERSF
7852            # Create the parameter DataFrame
7853            df = pd.DataFrame({
7854                "num": [134, 122, 129],
7855                "loop": [0.50, 0.25, 0.50],
7856                "cens": [True, True, True],
7857                "HM": [1.0, 5.0, 0.5],
7858                "ff": ["tight", "tight", "tight"],
7859                "CC": [5, 5, 0.8],
7860                "imp": [True, True, True],
7861                "up": [rsf_upsampling, rsf_upsampling, rsf_upsampling],
7862                "coords": [False,False,False]
7863            }, index=[0, 1, 2])
7864            for p in range(df.shape[0]):
7865                if verbose:
7866                    print("rsf parameters")
7867                    print( df.iloc[p] )
7868                if df['ff'].iloc[p] == 'broad':
7869                    f=[ 0.008, 0.15 ]
7870                elif df['ff'].iloc[p] == 'tight':
7871                    f=[ 0.03, 0.08 ]
7872                elif df['ff'].iloc[p] == 'mid':
7873                    f=[ 0.01, 0.1 ]
7874                elif df['ff'].iloc[p] == 'mid2':
7875                    f=[ 0.01, 0.08 ]
7876                else:
7877                    raise ValueError("we do not recognize this parameter choice for frequency filtering: " + df['ff'].iloc[p] )
7878                HM = df['HM'].iloc[p]
7879                CC = df['CC'].iloc[p]
7880                loop= df['loop'].iloc[p]
7881                cens =df['cens'].iloc[p]
7882                imp = df['imp'].iloc[p]
7883                rsf0 = resting_state_fmri_networks(
7884                                            rsf_image,
7885                                            boldTemplate,
7886                                            hier['brain_n4_dnz'],
7887                                            t1atropos,
7888                                            f=f,
7889                                            FD_threshold=HM, 
7890                                            spa = None, 
7891                                            spt = None, 
7892                                            nc = CC,
7893                                            outlier_threshold=loop,
7894                                            ica_components = 0,
7895                                            impute = imp,
7896                                            censor = cens,
7897                                            despike = 2.5,
7898                                            motion_as_nuisance = True,
7899                                            upsample=df['up'].iloc[p],
7900                                            clean_tmp=0.66,
7901                                            paramset=df['num'].iloc[p],
7902                                            powers=df['coords'].iloc[p],
7903                                            verbose=verbose ) # default
7904                rsfprolist.append( rsf0 )
7905            output_dict['rsf'] = rsfprolist
7906
7907    if nm_image_list is not None:
7908        if verbose:
7909            print('nm')
7910        if srmodel is None:
7911            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], verbose=verbose )
7912        else:
7913            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], srmodel=srmodel, target_range=target_range, verbose=verbose  )
7914################################## do the dti .....
7915    if len(dw_image) > 0 :
7916        if verbose:
7917            print('dti-x')
7918        if len( dw_image ) == 1: # use T1 for distortion correction and brain extraction
7919            if verbose:
7920                print("We have only one DTI: " + str(len(dw_image)))
7921            dw_image = dw_image[0]
7922            btpB0, btpDW = get_average_dwi_b0(dw_image)
7923            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
7924            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
7925                syn_metric='CC', syn_sampling=2,
7926                reg_iterations=[50,50,20],
7927                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
7928                initial_transform=initrig
7929                )
7930            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
7931            btpDW = ants.apply_transforms( btpDW, btpDW,
7932                tempreg['invtransforms'][1], interpolator='linear')
7933            btpB0 = ants.apply_transforms( btpB0, btpB0,
7934                tempreg['invtransforms'][1], interpolator='linear')
7935            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'][1], interpolator='nearestNeighbor')
7936            # dwimask = ants.iMath(dwimask,'MD',1)
7937            t12dwi = ants.apply_transforms( btpDW, hier['brain_n4_dnz'], tempreg['fwdtransforms'][1], interpolator='linear')
7938            output_dict['DTI'] = joint_dti_recon(
7939                dw_image,
7940                bvals[0],
7941                bvecs[0],
7942                jhu_atlas=JHU_atlas,
7943                jhu_labels=JHU_labels,
7944                brain_mask = dwimask,
7945                reference_B0 = btpB0,
7946                reference_DWI = btpDW,
7947                srmodel=srmodel,
7948                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
7949                denoise=dti_denoise,
7950                verbose = verbose)
7951        else :  # use phase encoding acquisitions for distortion correction and T1 for brain extraction
7952            if verbose:
7953                print("We have both DTI_LR and DTI_RL: " + str(len(dw_image)))
7954            a1b,a1w=get_average_dwi_b0(dw_image[0])
7955            a2b,a2w=get_average_dwi_b0(dw_image[1],fixed_b0=a1b,fixed_dwi=a1w)
7956            btpB0, btpDW = dti_template(
7957                b_image_list=[a1b,a2b],
7958                w_image_list=[a1w,a2w],
7959                iterations=7, verbose=verbose )
7960            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
7961            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
7962                syn_metric='CC', syn_sampling=2,
7963                reg_iterations=[50,50,20],
7964                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
7965                initial_transform=initrig
7966                )
7967            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
7968            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'], interpolator='nearestNeighbor')
7969            output_dict['DTI'] = joint_dti_recon(
7970                dw_image[0],
7971                bvals[0],
7972                bvecs[0],
7973                jhu_atlas=JHU_atlas,
7974                jhu_labels=JHU_labels,
7975                brain_mask = dwimask,
7976                reference_B0 = btpB0,
7977                reference_DWI = btpDW,
7978                srmodel=srmodel,
7979                img_RL=dw_image[1],
7980                bval_RL=bvals[1],
7981                bvec_RL=bvecs[1],
7982                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
7983                denoise=dti_denoise,
7984                verbose = verbose)
7985        mydti = output_dict['DTI']
7986        # summarize dwi with T1 outputs
7987        # first - register ....
7988        reg = ants.registration( mydti['recon_fa'], hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[s]', total_sigma=1.0 )
7989        ##################################################
7990        output_dict['FA_summ'] = hierarchical_modality_summary(
7991            mydti['recon_fa'],
7992            hier=hier,
7993            modality_name='fa',
7994            transformlist=reg['fwdtransforms'],
7995            verbose = False )
7996        ##################################################
7997        output_dict['MD_summ'] = hierarchical_modality_summary(
7998            mydti['recon_md'],
7999            hier=hier,
8000            modality_name='md',
8001            transformlist=reg['fwdtransforms'],
8002            verbose = False )
8003        # these inputs should come from nicely processed data
8004        dktmapped = ants.apply_transforms(
8005            mydti['recon_fa'],
8006            hier['dkt_parc']['dkt_cortex'],
8007            reg['fwdtransforms'], interpolator='nearestNeighbor' )
8008        citmapped = ants.apply_transforms(
8009            mydti['recon_fa'],
8010            hier['cit168lab'],
8011            reg['fwdtransforms'], interpolator='nearestNeighbor' )
8012        dktmapped[ citmapped > 0]=0
8013        mask = ants.threshold_image( mydti['recon_fa'], 0.01, 2.0 ).iMath("GetLargestComponent")
8014        if do_tractography: # dwi_deterministic_tracking dwi_closest_peak_tracking
8015            output_dict['tractography'] = dwi_deterministic_tracking(
8016                mydti['dwi_LR_dewarped'],
8017                mydti['recon_fa'],
8018                mydti['bval_LR'],
8019                mydti['bvec_LR'],
8020                seed_density = 1,
8021                mask=mask,
8022                verbose = verbose )
8023            mystr = output_dict['tractography']
8024            output_dict['tractography_connectivity'] = dwi_streamline_connectivity( mystr['streamlines'], dktmapped+citmapped, cnxcsv, verbose=verbose )
8025    ################################## do the flair .....
8026    if flair_image is not None:
8027        if verbose:
8028            print('flair')
8029        wmhprior = None
8030        priorfn = ex_path_mm + 'CIT168_wmhprior_700um_pad_adni.nii.gz'
8031        if ( exists( priorfn ) ):
8032            wmhprior = ants.image_read( priorfn )
8033            wmhprior = ants.apply_transforms( t1_image, wmhprior, do_normalization['invtransforms'] )
8034        output_dict['flair'] = boot_wmh( flair_image, t1_image, t1atropos,
8035            prior_probability=wmhprior, verbose=verbose )
8036    #################################################################
8037    ### NOTES: deforming to a common space and writing out images ###
8038    ### images we want come from: DTI, NM, rsf, thickness ###########
8039    #################################################################
8040    if do_normalization is not None:
8041        if verbose:
8042            print('normalization')
8043        # might reconsider this template space - cropped and/or higher res?
8044        # template = ants.resample_image( template, [1,1,1], use_voxels=False )
8045        # t1reg = ants.registration( template, hier['brain_n4_dnz'], "antsRegistrationSyNQuickRepro[s]")
8046        t1reg = do_normalization
8047        if do_kk:
8048            normalization_dict['kk_norm'] = ants.apply_transforms( group_template, output_dict['kk']['thickness_image'], group_transform )
8049        if output_dict['DTI'] is not None:
8050            mydti = output_dict['DTI']
8051            dtirig = ants.registration( hier['brain_n4_dnz'], mydti['recon_fa'], 'antsRegistrationSyNRepro[r]' )
8052            normalization_dict['MD_norm'] = ants.apply_transforms( group_template, mydti['recon_md'],group_transform+dtirig['fwdtransforms'] )
8053            normalization_dict['FA_norm'] = ants.apply_transforms( group_template, mydti['recon_fa'],group_transform+dtirig['fwdtransforms'] )
8054            output_directory = tempfile.mkdtemp()
8055            do_dti_norm=False
8056            if do_dti_norm:
8057                comptx = ants.apply_transforms( group_template, group_template, group_transform+dtirig['fwdtransforms'], compose = output_directory + '/xxx' )
8058                tspc=[2.,2.,2.]
8059                if srmodel is not None:
8060                    tspc=[1.,1.,1.]
8061                group_template2mm = ants.resample_image( group_template, tspc  )
8062                normalization_dict['DTI_norm'] = transform_and_reorient_dti( group_template2mm, mydti['dti'], comptx, verbose=False )
8063            import shutil
8064            shutil.rmtree(output_directory, ignore_errors=True )
8065        if output_dict['rsf'] is not None:
8066            if False:
8067                rsfpro = output_dict['rsf'] # FIXME
8068                rsfrig = ants.registration( hier['brain_n4_dnz'], rsfpro['meanBold'], 'antsRegistrationSyNRepro[r]' )
8069                for netid in get_antsimage_keys( rsfpro ):
8070                    rsfkey = netid + "_norm"
8071                    normalization_dict[rsfkey] = ants.apply_transforms(
8072                        group_template, rsfpro[netid],
8073                        group_transform+rsfrig['fwdtransforms'] )
8074        if output_dict['perf'] is not None: # zizzer
8075            comptx = group_transform + output_dict['perf']['t1reg']['invtransforms']
8076            normalization_dict['perf_norm'] = ants.apply_transforms( group_template,
8077                output_dict['perf']['perfusion'], comptx,
8078                whichtoinvert=[False,False,True,False] )
8079            normalization_dict['cbf_norm'] = ants.apply_transforms( group_template,
8080                output_dict['perf']['cbf'], comptx,
8081                whichtoinvert=[False,False,True,False] )
8082        if output_dict['pet3d'] is not None: # zizzer
8083            secondTx=output_dict['pet3d']['t1reg']['invtransforms']
8084            comptx = group_transform + secondTx
8085            if len( secondTx ) == 2:
8086                wti=[False,False,True,False]
8087            else:
8088                wti=[False,False,True]
8089            normalization_dict['pet3d_norm'] = ants.apply_transforms( group_template,
8090                output_dict['pet3d']['pet3d'], comptx,
8091                whichtoinvert=wti )
8092        if nm_image_list is not None:
8093            nmpro = output_dict['NM']
8094            nmrig = nmpro['t1_to_NM_transform'] # this is an inverse tx
8095            normalization_dict['NM_norm'] = ants.apply_transforms( group_template, nmpro['NM_avg'], group_transform+nmrig,
8096                whichtoinvert=[False,False,True])
8097
8098    if verbose:
8099        print('mm done')
8100    return output_dict, normalization_dict

Multiple modality processing and normalization

aggregates modality-specific processing under one roof. see individual modality specific functions for details.

Parameters

t1_image : raw t1 image

hier : output of antspyt1w.hierarchical ( see read hierarchical )

rsf_image : list of resting state fmri

flair_image : flair

nm_image_list : list of neuromelanin images

dw_image : list of diffusion weighted images

bvals : list of bvals file names

bvecs : list of bvecs file names

perfusion_image : single perfusion image

srmodel : optional srmodel

do_tractography : boolean

do_kk : boolean to control whether we compute kelly kapowski thickness image (slow)

do_normalization : template transformation if available

group_template : optional reference template corresponding to the group_transform

group_transform : optional transforms corresponding to the group_template

target_range : 2-element tuple a tuple or array defining the (min, max) of the input image (e.g., [-127.5, 127.5] or [0,1]). Output images will be scaled back to original intensity. This range should match the mapping used in the training of the network.

dti_motion_correct : None, Rigid, or SyN

dti_denoise : boolean

perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series

perfusion_m0_image : optional antsImage m0 associated with the perfusion time series

perfusion_m0 : optional list containing indices of the m0 in the perfusion time series

rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done

pet_3d_image : optional antsImage for a 3D pet; we make no assumptions about the contents of this image. we just process it and provide summary information.

test_run : boolean

verbose : boolean
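
A minimal end-to-end sketch under stated assumptions: the file names are illustrative, the antspyt1w.hierarchical call is assumed to follow that package's documented interface, and only a single resting-state series is passed. The template registration supplies do_normalization, from which mm() derives the group transform when group_template and group_transform are not provided.

    import os
    import ants
    import antspyt1w
    import antspymm

    t1 = ants.image_read('sub-01_T1w.nii.gz')                 # illustrative paths
    rsf = ants.image_read('sub-01_task-rest_bold.nii.gz')

    # T1 hierarchical processing (antspyt1w); the output_prefix argument is an assumption of that interface
    hier = antspyt1w.hierarchical(t1, output_prefix='/tmp/sub-01_')

    # registration to the same template that mm() reads from ~/.antspyt1w/
    template = antspymm.mm_read(os.path.expanduser('~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz'))
    t1reg = ants.registration(template, hier['brain_n4_dnz'], 'antsRegistrationSyNQuickRepro[s]')

    out, norm = antspymm.mm(
        t1, hier,
        rsf_image=[rsf],
        do_normalization=t1reg,
        do_kk=False,
        verbose=True)

    # write tabular and normalized-image outputs
    antspymm.write_mm('/tmp/sub-01_mm', out, norm)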

def write_mm( output_prefix, mm, mm_norm=None, t1wide=None, separator='_', verbose=False):
8103def write_mm( output_prefix, mm, mm_norm=None, t1wide=None, separator='_', verbose=False ):
8104    """
8105    write the tabular and normalization output of the mm function
8106
8107    Parameters
8108    -------------
8109
8110    output_prefix : prefix for file outputs - modality specific postfix will be added
8111
8112    mm  : output of mm function for modality-space processing should be a dictionary with 
8113        dictionary entries for each modality.
8114
8115    mm_norm : output of mm function for normalized processing
8116
8117    t1wide : wide output data frame from t1 hierarchical
8118
8119    separator : string or character separator for filenames
8120
8121    verbose : boolean
8122
8123    Returns
8124    ---------
8125
8126    both csv and image files written to disk.  the primary outputs will be
8127    output_prefix + separator + 'mmwide.csv' and *norm.nii.gz images
8128
8129    """
8130    from dipy.io.streamline import save_tractogram
8131    if mm_norm is not None:
8132        for mykey in mm_norm:
8133            tempfn = output_prefix + separator + mykey + '.nii.gz'
8134            if mm_norm[mykey] is not None:
8135                image_write_with_thumbnail( mm_norm[mykey], tempfn )
8136    thkderk = None
8137    if t1wide is not None:
8138        thkderk = t1wide.iloc[: , 1:]
8139    kkderk = None
8140    if 'kk' in mm:
8141        if mm['kk'] is not None:
8142            kkderk = mm['kk']['thickness_dataframe'].iloc[: , 1:]
8143            mykey='thickness_image'
8144            tempfn = output_prefix + separator + mykey + '.nii.gz'
8145            image_write_with_thumbnail( mm['kk'][mykey], tempfn )
8146    nmderk = None
8147    if 'NM' in mm:
8148        if mm['NM'] is not None:
8149            nmderk = mm['NM']['NM_dataframe_wide'].iloc[: , 1:]
8150            for mykey in get_antsimage_keys( mm['NM'] ):
8151                tempfn = output_prefix + separator + mykey + '.nii.gz'
8152                image_write_with_thumbnail( mm['NM'][mykey], tempfn, thumb=False )
8153
8154    faderk = mdderk = fat1derk = mdt1derk = None
8155
8156    if 'DTI' in mm:
8157        if mm['DTI'] is not None:
8158            mydti = mm['DTI']
8159            myop = output_prefix + separator
8160            ants.image_write( mydti['dti'],  myop + 'dti.nii.gz' )
8161            write_bvals_bvecs( mydti['bval_LR'], mydti['bvec_LR'], myop + 'reoriented' )
8162            image_write_with_thumbnail( mydti['dwi_LR_dewarped'],  myop + 'dwi.nii.gz' )
8163            image_write_with_thumbnail( mydti['dtrecon_LR_dewarp']['RGB'] ,  myop + 'DTIRGB.nii.gz' )
8164            image_write_with_thumbnail( mydti['jhu_labels'],  myop+'dtijhulabels.nii.gz', mydti['recon_fa'] )
8165            image_write_with_thumbnail( mydti['recon_fa'],  myop+'dtifa.nii.gz' )
8166            image_write_with_thumbnail( mydti['recon_md'],  myop+'dtimd.nii.gz' )
8167            image_write_with_thumbnail( mydti['b0avg'],  myop+'b0avg.nii.gz' )
8168            image_write_with_thumbnail( mydti['dwiavg'],  myop+'dwiavg.nii.gz' )
8169            faderk = mm['DTI']['recon_fa_summary'].iloc[: , 1:]
8170            mdderk = mm['DTI']['recon_md_summary'].iloc[: , 1:]
8171            fat1derk = mm['FA_summ'].iloc[: , 1:]
8172            mdt1derk = mm['MD_summ'].iloc[: , 1:]
8173    if 'tractography' in mm:
8174        if mm['tractography'] is not None:
8175            ofn = output_prefix + separator + 'tractogram.trk'
8176            if mm['tractography']['tractogram'] is not None:
8177                save_tractogram( mm['tractography']['tractogram'], ofn )
8178    cnxderk = None
8179    if 'tractography_connectivity' in mm:
8180        if mm['tractography_connectivity'] is not None:
8181            cnxderk = mm['tractography_connectivity']['connectivity_wide'].iloc[: , 1:] # NOTE: connectivity_wide is not much tested
8182            ofn = output_prefix + separator + 'dtistreamlineconn.csv'
8183            pd.DataFrame(mm['tractography_connectivity']['connectivity_matrix']).to_csv( ofn )
8184
8185    dlist = [
8186        thkderk,
8187        kkderk,
8188        nmderk,
8189        faderk,
8190        mdderk,
8191        fat1derk,
8192        mdt1derk,
8193        cnxderk
8194        ]
8195    is_all_none = all(element is None for element in dlist)
8196    if is_all_none:
8197        mm_wide = pd.DataFrame({'u_hier_id': [output_prefix] })
8198    else:
8199        mm_wide = pd.concat( dlist, axis=1, ignore_index=False )
8200
8201    mm_wide = mm_wide.copy()
8202    if 'NM' in mm:
8203        if mm['NM'] is not None:
8204            nmwide = dict_to_dataframe( mm['NM'] )
8205            if mm_wide.shape[0] > 0 and nmwide.shape[0] > 0:
8206                nmwide.set_index( mm_wide.index, inplace=True )
8207            mm_wide = pd.concat( [mm_wide, nmwide ], axis=1, ignore_index=False )
8208    if 'flair' in mm:
8209        if mm['flair'] is not None:
8210            myop = output_prefix + separator + 'wmh.nii.gz'
8211            pngfnb = output_prefix + separator + 'wmh_seg.png'
8212            ants.plot( mm['flair']['flair'], mm['flair']['WMH_posterior_probability_map'], axis=2, nslices=21, ncol=7, filename=pngfnb, crop=True )
8213            if mm['flair']['WMH_probability_map'] is not None:
8214                image_write_with_thumbnail( mm['flair']['WMH_probability_map'], myop, thumb=False )
8215            flwide = dict_to_dataframe( mm['flair'] )
8216            if mm_wide.shape[0] > 0 and flwide.shape[0] > 0:
8217                flwide.set_index( mm_wide.index, inplace=True )
8218            mm_wide = pd.concat( [mm_wide, flwide ], axis=1, ignore_index=False )
8219    if 'rsf' in mm:
8220        if mm['rsf'] is not None:
8221            fcnxpro=99
8222            rsfdata = mm['rsf']
8223            if not isinstance( rsfdata, list ):
8224                rsfdata = [ rsfdata ]
8225            for rsfpro in rsfdata:
8226                fcnxpro=str( rsfpro['paramset']  )
8227                pronum = 'fcnxpro'+str(fcnxpro)+"_"
8228                if verbose:
8229                    print("Collect rsf data " + pronum)
8230                new_rsf_wide = dict_to_dataframe( rsfpro )
8231                new_rsf_wide = pd.concat( [new_rsf_wide, rsfpro['corr_wide'] ], axis=1, ignore_index=False )
8232                new_rsf_wide = new_rsf_wide.add_prefix( pronum )
8233                new_rsf_wide.set_index( mm_wide.index, inplace=True )
8234                ofn = output_prefix + separator + pronum + '.csv'
8235                new_rsf_wide.to_csv( ofn )
8236                mm_wide = pd.concat( [mm_wide, new_rsf_wide ], axis=1, ignore_index=False )
8237                for mykey in get_antsimage_keys( rsfpro ):
8238                    myop = output_prefix + separator + pronum + mykey + '.nii.gz'
8239                    image_write_with_thumbnail( rsfpro[mykey], myop, thumb=True )
8240                ofn = output_prefix + separator + pronum + 'rsfcorr.csv'
8241                rsfpro['corr'].to_csv( ofn )
8242                # apply same principle to new correlation matrix, doesn't need to be incorporated with mm_wide
8243                ofn2 = output_prefix + separator + pronum + 'nodescorr.csv'
8244                rsfpro['fullCorrMat'].to_csv( ofn2 )
8245    if 'DTI' in mm:
8246        if mm['DTI'] is not None:
8247            mydti = mm['DTI']
8248            mm_wide['dti_tsnr_b0_mean'] =  mydti['tsnr_b0'].mean()
8249            mm_wide['dti_tsnr_dwi_mean'] =  mydti['tsnr_dwi'].mean()
8250            mm_wide['dti_dvars_b0_mean'] =  mydti['dvars_b0'].mean()
8251            mm_wide['dti_dvars_dwi_mean'] =  mydti['dvars_dwi'].mean()
8252            mm_wide['dti_ssnr_b0_mean'] =  mydti['ssnr_b0'].mean()
8253            mm_wide['dti_ssnr_dwi_mean'] =  mydti['ssnr_dwi'].mean()
8254            mm_wide['dti_fa_evr'] =  mydti['fa_evr']
8255            mm_wide['dti_fa_SNR'] =  mydti['fa_SNR']
8256            if mydti['framewise_displacement'] is not None:
8257                mm_wide['dti_high_motion_count'] =  mydti['high_motion_count']
8258                mm_wide['dti_FD_mean'] = mydti['framewise_displacement'].mean()
8259                mm_wide['dti_FD_max'] = mydti['framewise_displacement'].max()
8260                mm_wide['dti_FD_sd'] = mydti['framewise_displacement'].std()
8261                fdfn = output_prefix + separator + '_fd.csv'
8262            else:
8263                mm_wide['dti_FD_mean'] = mm_wide['dti_FD_max'] = mm_wide['dti_FD_sd'] = 'NA'
8264
8265    if 'perf' in mm:
8266        if mm['perf'] is not None:
8267            perfpro = mm['perf']
8268            prwide = dict_to_dataframe( perfpro )
8269            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
8270                prwide.set_index( mm_wide.index, inplace=True )
8271            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
8272            if 'perf_dataframe' in perfpro.keys():
8273                pderk = perfpro['perf_dataframe'].iloc[: , 1:]
8274                pderk.set_index( mm_wide.index, inplace=True )
8275                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
8276            else:
8277                print("FIXME - perfusion dataframe")
8278            for mykey in get_antsimage_keys( mm['perf'] ):
8279                tempfn = output_prefix + separator + mykey + '.nii.gz'
8280                image_write_with_thumbnail( mm['perf'][mykey], tempfn, thumb=False )
8281
8282    if 'pet3d' in mm:
8283        if mm['pet3d'] is not None:
8284            pet3dpro = mm['pet3d']
8285            prwide = dict_to_dataframe( pet3dpro )
8286            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
8287                prwide.set_index( mm_wide.index, inplace=True )
8288            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
8289            if 'pet3d_dataframe' in pet3dpro.keys():
8290                pderk = pet3dpro['pet3d_dataframe'].iloc[: , 1:]
8291                pderk.set_index( mm_wide.index, inplace=True )
8292                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
8293            else:
8294                print("FIXME - pet3d dataframe")
8295            for mykey in get_antsimage_keys( mm['pet3d'] ):
8296                tempfn = output_prefix + separator + mykey + '.nii.gz'
8297                image_write_with_thumbnail( mm['pet3d'][mykey], tempfn, thumb=False )
8298
8299    mmwidefn = output_prefix + separator + 'mmwide.csv'
8300    mm_wide.to_csv( mmwidefn )
8301    if verbose:
8302        print( output_prefix + " write_mm done." )
8303    return

write the tabular and normalization output of the mm function

Parameters

output_prefix : prefix for file outputs - modality specific postfix will be added

mm : output of mm function for modality-space processing should be a dictionary with dictionary entries for each modality.

mm_norm : output of mm function for normalized processing

t1wide : wide output data frame from t1 hierarchical

separator : string or character separator for filenames

verbose : boolean

Returns

both csv and image files written to disk. the primary outputs will be output_prefix + separator + 'mmwide.csv' and *norm.nii.gz images
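
For example, continuing the sketch above (the output prefix is hypothetical; tabPro and normPro come from mm, and a wide t1 data frame from antspyt1w.merge_hierarchical_csvs_to_wide_format may also be passed as t1wide):

    outpre = '/tmp/processed/PPMI-101018-20220228-T1wHierarchical-1496183'  # hypothetical prefix
    antspymm.write_mm(output_prefix=outpre, mm=tabPro, mm_norm=normPro,
                      t1wide=None, separator='-')
    # the tabular summary lands at outpre + '-mmwide.csv'; normalized images are
    # written as outpre + '-<key>.nii.gz' along with png thumbnails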

def mm_nrg( studyid, sourcedir='/Users/stnava/data/PPMI/MV/example_s3_b/images/PPMI/', sourcedatafoldername='images', processDir='processed', mysep='-', srmodel_T1=False, srmodel_NM=False, srmodel_DTI=False, visualize=True, nrg_modality_list=['T1w', 'NM2DMT', 'DTI', 'T2Flair', 'rsfMRI'], verbose=True):
8306def mm_nrg(
8307    studyid,   # pandas data frame
8308    sourcedir = os.path.expanduser( "~/data/PPMI/MV/example_s3_b/images/PPMI/" ),
8309    sourcedatafoldername = 'images', # root for source data
8310    processDir = "processed", # where output will go - parallel to sourcedatafoldername
8311    mysep = '-', # define a separator for filename components
8312    srmodel_T1 = False, # optional - will add a great deal of time
8313    srmodel_NM = False, # optional - will add a great deal of time
8314    srmodel_DTI = False, # optional - will add a great deal of time
8315    visualize = True,
8316    nrg_modality_list = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI" ],
8317    verbose = True
8318):
8319    """
8320    too dangerous to document ... use with care.
8321
8322    processes multiple modality MRI specifically:
8323
8324    * T1w
8325    * T2Flair
8326    * DTI, DTI_LR, DTI_RL
8327    * rsfMRI, rsfMRI_LR, rsfMRI_RL
8328    * NM2DMT (neuromelanin)
8329
8330    other modalities may be added later ...
8331
8332    "trust me, i know what i'm doing" - sledgehammer
8333
8334    convert to pynb via:
8335        p2j mm.py -o
8336
8337    convert the ipynb to html via:
8338        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
8339
8340    this function assumes NRG format for the input data ....
8341    we also assume that t1w hierarchical (if already done) was written
8342    via its standardized write function.
8343    NRG = https://github.com/stnava/biomedicalDataOrganization
8344
8345    this function is verbose
8346
8347    Parameters
8348    -------------
8349
8350    studyid : must have columns 1. subjectID 2. date (in form 20220228) and 3. imageID
8351        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
8352        these provide unique image IDs for these modalities: nm=neuromelanin, dti=diffusion tensor,
8353        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
8354        t1 is required.  rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.  see antspymm.generate_mm_dataframe
8355
8356    sourcedir : a study specific folder containing individual subject folders
8357
8358    sourcedatafoldername : root for source data e.g. "images"
8359
8360    processDir : where output will go - parallel to sourcedatafoldername e.g.
8361        "processed"
8362
8363    mysep : define a character separator for filename components
8364
8365    srmodel_T1 : False (default) - will add a great deal of time - or h5 filename, 2 chan
8366
8367    srmodel_NM : False (default) - will add a great deal of time - or h5 filename, 1 chan
8368
8369    srmodel_DTI : False (default) - will add a great deal of time - or h5 filename, 1 chan
8370
8371    visualize : True - will plot some results to png
8372
8373    nrg_modality_list : list of permissible modalities - always include [T1w] as base
8374
8375    verbose : boolean
8376
8377    Returns
8378    ---------
8379
8380    writes output to disk and potentially produces figures that may be
8381    captured in an ipynb / html file.
8382
8383    """
8384    studyid = studyid.dropna(axis=1)
8385    if studyid.shape[0] < 1:
8386        raise ValueError('studyid has no rows')
8387    musthavecols = ['subjectID','date','imageID']
8388    for k in range(len(musthavecols)):
8389        if not musthavecols[k] in studyid.keys():
8390            raise ValueError('studyid is missing column ' +musthavecols[k] )
8391    def makewideout( x, separator = '-' ):
8392        return x + separator + 'mmwide.csv'
8393    if nrg_modality_list[0] != 'T1w':
8394        nrg_modality_list.insert(0, "T1w" )
8395    testloop = False
8396    counter=0
8397    import glob as glob
8398    from os.path import exists
8399    ex_path = os.path.expanduser( "~/.antspyt1w/" )
8400    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
8401    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
8402    if not exists( templatefn ):
8403        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
8404        antspyt1w.get_data( force_download=True )
8405        get_data( force_download=True )
8406    temp = sourcedir.split( "/" )
8407    splitCount = len( temp )
8408    template = mm_read( templatefn ) # Read in template
8409    test_run = False
8410    if test_run:
8411        visualize=False
8412    # get sid and dtid from studyid
8413    sid = str(studyid['subjectID'].iloc[0])
8414    dtid = str(studyid['date'].iloc[0])
8415    iid = str(studyid['imageID'].iloc[0])
8416    subjectrootpath = os.path.join(sourcedir,sid, dtid)
8417    if verbose:
8418        print("subjectrootpath: "+ subjectrootpath )
8419    myimgsInput = glob.glob( subjectrootpath+"/*" )
8420    myimgsInput.sort( )
8421    if verbose:
8422        print( myimgsInput )
8423    # hierarchical
8424    # NOTE: if there are multiple T1s for this time point, should take
8425    # the one with the highest resnetGrade
8426    t1_search_path = os.path.join(subjectrootpath, "T1w", iid, "*nii.gz")
8427    if verbose:
8428        print(f"t1 search path: {t1_search_path}")
8429    t1fn = glob.glob(t1_search_path)
8430    t1fn.sort()
8431    if len( t1fn ) < 1:
8432        raise ValueError('mm_nrg cannot find the T1w with uid ' + iid + ' @ ' + subjectrootpath )
8433    t1fn = t1fn[0]
8434    t1 = mm_read( t1fn )
8435    hierfn0 = re.sub( sourcedatafoldername, processDir, t1fn)
8436    hierfn0 = re.sub( ".nii.gz", "", hierfn0)
8437    hierfn = re.sub( "T1w", "T1wHierarchical", hierfn0)
8438    hierfn = hierfn + mysep
8439    hierfntest = hierfn + 'snseg.csv'
8440    regout = hierfn0 + mysep + "syn"
8441    templateTx = {
8442        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
8443        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
8444    if verbose:
8445        print( "-<REGISTRATION EXISTENCE>-: \n" + 
8446              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
8447            str(exists( templateTx['fwdtransforms'][0])) + " " +
8448            str(exists( templateTx['fwdtransforms'][1])) + " " +
8449            str(exists( templateTx['invtransforms'][0])) + " " +
8450            str(exists( templateTx['invtransforms'][1])) )
8451    if verbose:
8452        print( hierfntest )
8453    hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
8454    hier = None
8455    if not hierexists and not testloop:
8456        subjectpropath = os.path.dirname( hierfn )
8457        if verbose:
8458            print( subjectpropath )
8459        os.makedirs( subjectpropath, exist_ok=True  )
8460        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
8461        antspyt1w.write_hierarchical( hier, hierfn )
8462        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8463                hier['dataframes'], identifier=None )
8464        t1wide.to_csv( hierfn + 'mmwide.csv' )
8465    ################# read the hierarchical data ###############################
8466    hier = antspyt1w.read_hierarchical( hierfn )
8467    if exists( hierfn + 'mmwide.csv' ) :
8468        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
8469    elif not testloop:
8470        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8471                hier['dataframes'], identifier=None )
8472    if srmodel_T1 is not None and srmodel_T1 is not False :  # False (the default) disables SR
8473        hierfnSR = re.sub( sourcedatafoldername, processDir, t1fn)
8474        hierfnSR = re.sub( "T1w", "T1wHierarchicalSR", hierfnSR)
8475        hierfnSR = re.sub( ".nii.gz", "", hierfnSR)
8476        hierfnSR = hierfnSR + mysep
8477        hierfntest = hierfnSR + 'mtl.csv'
8478        if verbose:
8479            print( hierfntest )
8480        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
8481        if not hierexists:
8482            subjectpropath = os.path.dirname( hierfnSR )
8483            if verbose:
8484                print( subjectpropath )
8485            os.makedirs( subjectpropath, exist_ok=True  )
8486            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
8487            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
8488            mdlfn = re.sub( 'bestup', bestup, srmodel_T1 ) 
8489            if verbose:
8490                print( mdlfn )
8491            if exists( mdlfn ):
8492                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
8493            else:
8494                print( mdlfn + " does not exist - will not run.")
8495            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
8496            antspyt1w.write_hierarchical( hierSR, hierfnSR )
8497            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8498                    hierSR['dataframes'], identifier=None )
8499            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
8500    hier = antspyt1w.read_hierarchical( hierfn )
8501    if exists( hierfn + 'mmwide.csv' ) :
8502        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
8503    elif not testloop:
8504        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8505                hier['dataframes'], identifier=None )
8506    if not testloop:
8507        t1imgbrn = hier['brain_n4_dnz']
8508        t1atropos = hier['dkt_parc']['tissue_segmentation']
8509    # loop over modalities and then unique image IDs
8510    # we treat NM in a "special" way -- aggregating repeats
8511    # other modalities (beyond T1) are treated individually
8512    nimages = len(myimgsInput)
8513    if verbose:
8514        print(  " we have : " + str(nimages) + " modalities.")
8515    for overmodX in nrg_modality_list:
8516        counter=counter+1
8517        if counter > (len(nrg_modality_list)+1):
8518            print("This is weird. " + str(counter))
8519            return
8520        if overmodX == 'T1w':
8521            iidOtherMod = iid
8522            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8523            myimgsr = glob.glob(mod_search_path)
8524        elif overmodX == 'NM2DMT' and ('nmid1' in studyid.keys() ):
8525            iidOtherMod = str( int(studyid['nmid1'].iloc[0]) )
8526            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8527            myimgsr = glob.glob(mod_search_path)
8528            for nmnum in range(2,11):
8529                locnmnum = 'nmid'+str(nmnum)
8530                if locnmnum in studyid.keys() :
8531                    iidOtherMod = str( int(studyid[locnmnum].iloc[0]) )
8532                    mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8533                    myimgsr.append( glob.glob(mod_search_path)[0] )
8534        elif 'rsfMRI' in overmodX and ( ( 'rsfid1' in studyid.keys() ) or ('rsfid2' in studyid.keys() ) ):
8535            myimgsr = []
8536            if  'rsfid1' in studyid.keys():
8537                iidOtherMod = str( int(studyid['rsfid1'].iloc[0]) )
8538                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8539                myimgsr.append( glob.glob(mod_search_path)[0] )
8540            if  'rsfid2' in studyid.keys():
8541                iidOtherMod = str( int(studyid['rsfid2'].iloc[0]) )
8542                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8543                myimgsr.append( glob.glob(mod_search_path)[0] )
8544        elif 'DTI' in overmodX and (  'dtid1' in studyid.keys() or  'dtid2' in studyid.keys() ):
8545            myimgsr = []
8546            if  'dtid1' in studyid.keys():
8547                iidOtherMod = str( int(studyid['dtid1'].iloc[0]) )
8548                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8549                myimgsr.append( glob.glob(mod_search_path)[0] )
8550            if  'dtid2' in studyid.keys():
8551                iidOtherMod = str( int(studyid['dtid2'].iloc[0]) )
8552                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8553                myimgsr.append( glob.glob(mod_search_path)[0] )
8554        elif 'T2Flair' in overmodX and ('flairid' in studyid.keys() ):
8555            iidOtherMod = str( int(studyid['flairid'].iloc[0]) )
8556            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8557            myimgsr = glob.glob(mod_search_path)
8558        if verbose:
8559            print( "overmod " + overmodX + " " + iidOtherMod )
8560            print(f"modality search path: {mod_search_path}")
8561        myimgsr.sort()
8562        if len(myimgsr) > 0:
8563            overmodXx = str(overmodX)
8564            dowrite=False
8565            if verbose:
8566                print( 'overmodX is : ' + overmodXx )
8567                print( 'example image name is : '  )
8568                print( myimgsr )
8569            if overmodXx == 'NM2DMT':
8570                myimgsr2 = myimgsr
8571                myimgsr2.sort()
8572                is4d = False
8573                temp = ants.image_read( myimgsr2[0] )
8574                if temp.dimension == 4:
8575                    is4d = True
8576                if len( myimgsr2 ) == 1 and not is4d: # check dimension
8577                    myimgsr2 = myimgsr2 + myimgsr2
8578                subjectpropath = os.path.dirname( myimgsr2[0] )
8579                subjectpropath = re.sub( sourcedatafoldername, processDir,subjectpropath )
8580                if verbose:
8581                    print( "subjectpropath " + subjectpropath )
8582                mysplit = subjectpropath.split( "/" )
8583                os.makedirs( subjectpropath, exist_ok=True  )
8584                mysplitCount = len( mysplit )
8585                project = mysplit[mysplitCount-5]
8586                subject = mysplit[mysplitCount-4]
8587                date = mysplit[mysplitCount-3]
8588                modality = mysplit[mysplitCount-2]
8589                uider = mysplit[mysplitCount-1]
8590                identifier = mysep.join([project, subject, date, modality ])
8591                identifier = identifier + "_" + iid
8592                mymm = subjectpropath + "/" + identifier
8593                mymmout = makewideout( mymm )
8594                if verbose and not exists( mymmout ):
8595                    print( "NM " + mymm  + ' execution ')
8596                elif verbose and exists( mymmout ) :
8597                    print( "NM " + mymm + ' complete ' )
8598                if exists( mymmout ):
8599                    continue
8600                if is4d:
8601                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
8602                else:
8603                    nmlist = []
8604                    for zz in myimgsr2:
8605                        nmlist.append( mm_read( zz ) )
8606                srmodel_NM_mdl = None
8607                if srmodel_NM is not None and srmodel_NM is not False:  # False (the default) disables SR
8608                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
8609                    if isinstance( srmodel_NM, str ):
8610                        mdlfn = re.sub( "bestup", bestup, srmodel_NM )
8611                    if exists( mdlfn ):
8612                        if verbose:
8613                            print(mdlfn)
8614                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
8615                    else:
8616                        print( mdlfn + " does not exist - won't use SR")
8617                if not testloop:
8618                    tabPro, normPro = mm( t1, hier,
8619                            nm_image_list = nmlist,
8620                            srmodel=srmodel_NM_mdl,
8621                            do_tractography=False,
8622                            do_kk=False,
8623                            do_normalization=templateTx,
8624                            test_run=test_run,
8625                            verbose=True )
8626                    if not test_run:
8627                        write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=None, separator=mysep )
8628                        nmpro = tabPro['NM']
8629                        mysl = range( nmpro['NM_avg'].shape[2] )
8630                    if visualize:
8631                        mysl = range( nmpro['NM_avg'].shape[2] )
8632                        ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
8633                        mysl = range( nmpro['NM_avg_cropped'].shape[2] )
8634                        ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
8635                        ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
8636                        ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
8637            else :
8638                if len( myimgsr ) > 0:
8639                    dowrite=False
8640                    myimgcount = 0
8641                    if len( myimgsr ) > 0 :
8642                        myimg = myimgsr[myimgcount]
8643                        subjectpropath = os.path.dirname( myimg )
8644                        subjectpropath = re.sub( sourcedatafoldername, processDir, subjectpropath )
8645                        mysplit = subjectpropath.split("/")
8646                        mysplitCount = len( mysplit )
8647                        project = mysplit[mysplitCount-5]
8648                        date = mysplit[mysplitCount-4] # NB: this position holds the subject folder in the NRG layout
8649                        subject = mysplit[mysplitCount-3] # NB: and this one the date folder; the join below still matches the on-disk order
8650                        mymod = mysplit[mysplitCount-2] # FIXME system dependent
8651                        uid = mysplit[mysplitCount-1] # unique image id
8652                        os.makedirs( subjectpropath, exist_ok=True  )
8653                        if mymod == 'T1w':
8654                            identifier = mysep.join([project, date, subject, mymod, uid])
8655                        else:  # add the T1 unique id since that drives a lot of the analysis
8656                            identifier = mysep.join([project, date, subject, mymod, uid ])
8657                            identifier = identifier + "_" + iid
8658                        mymm = subjectpropath + "/" + identifier
8659                        mymmout = makewideout( mymm )
8660                        if verbose and not exists( mymmout ):
8661                            print("Modality specific processing: " + mymod + " execution " )
8662                            print( mymm )
8663                        elif verbose and exists( mymmout ) :
8664                            print("Modality specific processing: " + mymod + " complete " )
8665                        if exists( mymmout ) :
8666                            continue
8667                        if verbose:
8668                            print(subjectpropath)
8669                            print(identifier)
8670                            print( myimg )
8671                        if not testloop:
8672                            img = mm_read( myimg )
8673                            ishapelen = len( img.shape )
8674                            if mymod == 'T1w' and ishapelen == 3: # for a real run, set to True
8675                                if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
8676                                    if verbose:
8677                                        print('start t1 registration')
8678                                    ex_path = os.path.expanduser( "~/.antspyt1w/" )
8679                                    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
8680                                    template = mm_read( templatefn )
8681                                    template = ants.resample_image( template, [1,1,1], use_voxels=False )
8682                                    t1reg = ants.registration( template, hier['brain_n4_dnz'],
8683                                        "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
8684                                    myjac = ants.create_jacobian_determinant_image( template,
8685                                        t1reg['fwdtransforms'][0], do_log=True, geom=True )
8686                                    image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
8687                                    if visualize:
8688                                        ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
8689                                        ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
8690                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
8691                                    dowrite=True
8692                                    if verbose:
8693                                        print('start kk')
8694                                    tabPro, normPro = mm( t1, hier,
8695                                        srmodel=None,
8696                                        do_tractography=False,
8697                                        do_kk=True,
8698                                        do_normalization=templateTx,
8699                                        test_run=test_run,
8700                                        verbose=True )
8701                                    if visualize:
8702                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
8703                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
8704                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
8705                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
8706                            if mymod == 'T2Flair' and ishapelen == 3:
8707                                dowrite=True
8708                                tabPro, normPro = mm( t1, hier,
8709                                    flair_image = img,
8710                                    srmodel=None,
8711                                    do_tractography=False,
8712                                    do_kk=False,
8713                                    do_normalization=templateTx,
8714                                    test_run=test_run,
8715                                    verbose=True )
8716                                if visualize:
8717                                    maxslice = np.min( [21, img.shape[2] ] )
8718                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
8719                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
8720                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
8721                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
8722                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
8723                                img2 = None
8724                                if len( myimgsr ) > 1:
8725                                    img2 = mm_read( myimgsr[myimgcount+1] )
8726                                    ishapelen2 = len( img2.shape )
8727                                    if ishapelen2 != 4 :
8728                                        img2 = None
8729                                dowrite=True
8730                                tabPro, normPro = mm( t1, hier,
8731                                    rsf_image=[img,img2],
8732                                    srmodel=None,
8733                                    do_tractography=False,
8734                                    do_kk=False,
8735                                    do_normalization=templateTx,
8736                                    test_run=test_run,
8737                                    verbose=True )
8738                                if tabPro['rsf'] is not None and visualize:
8739                                    dfn=tabPro['rsf']['dfnname']
8740                                    maxslice = np.min( [21, tabPro['rsf']['meanBold'].shape[2] ] )
8741                                    ants.plot( tabPro['rsf']['meanBold'],
8742                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=mymm+mysep+"meanBOLD.png" )
8743                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['alff'],"Normalize"),
8744                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=mymm+mysep+"boldALFF.png" )
8745                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['falff'],"Normalize"),
8746                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=mymm+mysep+"boldfALFF.png" )
8747                                    ants.plot( tabPro['rsf']['meanBold'], tabPro['rsf'][dfn],
8748                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='DefaultMode', filename=mymm+mysep+"boldDefaultMode.png" )
8749                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
8750                                dowrite=True
8751                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
8752                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
8753                                imgList = [ img ]
8754                                bvalfnList = [ bvalfn ]
8755                                bvecfnList = [ bvecfn ]
8756                                if len( myimgsr ) > 1:  # find DTI_RL
8757                                    dtilrfn = myimgsr[myimgcount+1]
8758                                    if len( dtilrfn ) == 1: # FIXME: dtilrfn is a filename string, so this test is rarely true; an existence check was likely intended
8759                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
8760                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
8761                                        imgRL = ants.image_read( dtilrfn )
8762                                        imgList.append( imgRL )
8763                                        bvalfnList.append( bvalfnRL )
8764                                        bvecfnList.append( bvecfnRL )
8765                                srmodel_DTI_mdl=None
8766                                if srmodel_DTI is not None and srmodel_DTI is not False:  # False (the default) disables SR
8767                                    temp = ants.get_spacing(img)
8768                                    dtspc=[temp[0],temp[1],temp[2]]
8769                                    bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
8770                                    if isinstance( srmodel_DTI, str ):
8771                                        mdlfn = re.sub( "bestup", bestup, srmodel_DTI )
8772                                    if exists( mdlfn ):
8773                                        if verbose:
8774                                            print(mdlfn)
8775                                        srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
8776                                    else:
8777                                        print(mdlfn + " does not exist - won't use SR")
8778                                tabPro, normPro = mm( t1, hier,
8779                                    dw_image=imgList,
8780                                    bvals = bvalfnList,
8781                                    bvecs = bvecfnList,
8782                                    srmodel=srmodel_DTI_mdl,
8783                                    do_tractography=not test_run,
8784                                    do_kk=False,
8785                                    do_normalization=templateTx,
8786                                    test_run=test_run,
8787                                    verbose=True )
8788                                mydti = tabPro['DTI']
8789                                if visualize:
8790                                    maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
8791                                    ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
8792                                    ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
8793                                    ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
8794                            if dowrite:
8795                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep, verbose=True )
8796                                for mykey in normPro.keys():
8797                                    if normPro[mykey] is not None:
8798                                        if visualize and normPro[mykey].components == 1 and False:
8799                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
8800        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
8801            return
8802        if verbose:
8803            print("done with " + overmodX )
8804    if verbose:
8805        print("mm_nrg complete.")
8806    return

too dangerous to document ... use with care.

processes multiple modality MRI specifically:

  • T1w
  • T2Flair
  • DTI, DTI_LR, DTI_RL
  • rsfMRI, rsfMRI_LR, rsfMRI_RL
  • NM2DMT (neuromelanin)

other modalities may be added later ...

"trust me, i know what i'm doing" - sledgehammer

convert to pynb via: p2j mm.py -o

convert the ipynb to html via: jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html

this function assumes NRG format for the input data .... we also assume that t1w hierarchical (if already done) was written via its standardized write function. NRG = https://github.com/stnava/biomedicalDataOrganization

this function is verbose

Parameters

studyid : must have columns 1. subjectID 2. date (in form 20220228) and 3. imageID other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid; these provide unique image IDs for these modalities: nm=neuromelanin, dti=diffusion tensor, rsf=resting state fmri, flair=T2Flair. none of these are required. only t1 is required. rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*. see antspymm.generate_mm_dataframe

sourcedir : a study specific folder containing individual subject folders

sourcedatafoldername : root for source data e.g. "images"

processDir : where output will go - parallel to sourcedatafoldername e.g. "processed"

mysep : define a character separator for filename components

srmodel_T1 : False (default) - will add a great deal of time - or h5 filename, 2 chan

srmodel_NM : False (default) - will add a great deal of time - or h5 filename, 1 chan

srmodel_DTI : False (default) - will add a great deal of time - or h5 filename, 1 chan

visualize : True - will plot some results to png

nrg_modality_list : list of permissible modalities - always include [T1w] as base

verbose : boolean

Returns

writes output to disk and potentially produces figures that may be captured in an ipynb / html file.
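
An illustrative call (all IDs and paths below are hypothetical; only subjectID, date and imageID are required):

    import pandas as pd
    import antspymm
    studyid = pd.DataFrame({
        'subjectID': ['101018'],
        'date':      ['20220228'],
        'imageID':   ['1496183'],   # T1w unique image id
        'flairid':   [1496185],     # optional modality-specific ids
        'dtid1':     [1496186],
        'rsfid1':    [1496187]})
    # expects the NRG layout: <sourcedir>/<subjectID>/<date>/<Modality>/<imageID>/*.nii.gz
    antspymm.mm_nrg(studyid,
        sourcedir='/data/PPMI/images/PPMI/',   # hypothetical study folder
        sourcedatafoldername='images',
        processDir='processed',
        mysep='-',
        visualize=True,
        verbose=True)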

def mm_csv( studycsv, mysep='-', srmodel_T1=False, srmodel_NM=False, srmodel_DTI=False, dti_motion_correct='antsRegistrationSyNQuickRepro[r]', dti_denoise=False, nrg_modality_list=None, normalization_template=None, normalization_template_output=None, normalization_template_transform_type='antsRegistrationSyNRepro[s]', normalization_template_spacing=None, enantiomorphic=False, perfusion_trim=10, perfusion_m0_image=None, perfusion_m0=None, rsf_upsampling=3.0, pet3d=None, min_t1_spacing_for_sr=0.8):
8810def mm_csv(
8811    studycsv,   # pandas data frame
8812    mysep = '-', # or "_" for BIDS
8813    srmodel_T1 = False, # optional - will add a great deal of time
8814    srmodel_NM = False, # optional - will add a great deal of time
8815    srmodel_DTI = False, # optional - will add a great deal of time
8816    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
8817    dti_denoise = False,
8818    nrg_modality_list = None,
8819    normalization_template = None,
8820    normalization_template_output = None,
8821    normalization_template_transform_type = "antsRegistrationSyNRepro[s]",
8822    normalization_template_spacing=None,
8823    enantiomorphic=False,
8824    perfusion_trim = 10,
8825    perfusion_m0_image = None,
8826    perfusion_m0 = None,
8827    rsf_upsampling = 3.0,
8828    pet3d = None,
8829    min_t1_spacing_for_sr = 0.8,
8830):
8831    """
8832    too dangerous to document ... use with care.
8833
8834    processes multiple modality MRI specifically:
8835
8836    * T1w
8837    * T2Flair
8838    * DTI, DTI_LR, DTI_RL
8839    * rsfMRI, rsfMRI_LR, rsfMRI_RL
8840    * NM2DMT (neuromelanin)
8841
8842    other modalities may be added later ...
8843
8844    "trust me, i know what i'm doing" - sledgehammer
8845
8846    convert to pynb via:
8847        p2j mm.py -o
8848
8849    convert the ipynb to html via:
8850        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
8851
8852    this function does not assume NRG format for the input data ....
8853
8854    Parameters
8855    -------------
8856
8857    studycsv : must have columns:
8858        - subjectID
8859        - date or session
8860        - imageID
8861        - modality
8862        - sourcedir
8863        - outputdir
8864        - filename (path to the t1 image)
8865        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
8866        these provide filenames for these modalities: nm=neuromelanin, dti=diffusion tensor,
8867        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
8868        t1 is required. rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.
8869        see antspymm.generate_mm_dataframe
8870
8871    sourcedir : a study specific folder containing individual subject folders
8872
8873    outputdir : a study specific folder where individual output subject folders will go
8874
8875    filename : the raw image filename (full path)
8876
8877    srmodel_T1 : False (default) - .keras or h5 filename for SR model (siq generated).
8878
8879    srmodel_NM : False (default) - .keras or h5 filename for SR model (siq generated)
8880    the model name should follow a style like prefix_bestup_postfix where bestup will be replaced with an optimal upsampling factor eg 2x2x2 based on the data.  see siq.optimize_upsampling_shape.
8881
8882    srmodel_DTI : False (default) - .keras or h5 filename for SR model (siq generated).
8883    the model name should follow a style like prefix_bestup_postfix where bestup will be replaced with an optimal upsampling factor eg 2x2x2 based on the data.  see siq.optimize_upsampling_shape.
8884
8885    dti_motion_correct : None, Rigid or SyN
8886
8887    dti_denoise : boolean
8888
8889    nrg_modality_list : optional; defaults to None; use to focus on a given modality
8890
8891    normalization_template : optional; defaults to None; if present, all images will
8892        be deformed into this space and the deformation will be stored with an extension
8893        related to this variable.  this should be a brain extracted T1w image.
8894
8895    normalization_template_output : optional string; defaults to None; naming for the 
8896        normalization_template outputs which will be in the T1w directory.
8897
8898    normalization_template_transform_type : optional string transform type passed to ants.registration
8899
8900    normalization_template_spacing : 3-tuple controlling the resolution at which registration is computed 
8901    
8902    enantiomorphic: boolean (WIP)
8903
8904    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
8905
8906    perfusion_m0_image : optional m0 antsImage associated with the perfusion time series
8907
8908    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
8909
8910    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
8911
8912    pet3d : optional antsImage for PET (or other 3d scalar) data which we want to summarize
8913
8914    min_t1_spacing_for_sr : float 
8915        if the minimum input image spacing is less than this value, 
8916        the function will return the original image.  Default 0.8.
8917
8918    Returns
8919    ---------
8920
8921    writes output to disk and produces figures
8922
8923    """
8924    import traceback
8925    visualize = True
8926    verbose = True
8927    if verbose:
8928        print( version() )
8929    if nrg_modality_list is None:
8930        nrg_modality_list = get_valid_modalities()
8931    if studycsv.shape[0] < 1:
8932        raise ValueError('studycsv has no rows')
8933    musthavecols = ['projectID', 'subjectID','date','imageID','modality','sourcedir','outputdir','filename']
8934    for k in range(len(musthavecols)):
8935        if not musthavecols[k] in studycsv.keys():
8936            raise ValueError('studycsv is missing column ' +musthavecols[k] )
8937    def makewideout( x, separator = mysep ):
8938        return x + separator + 'mmwide.csv'
8939    testloop = False
8940    counter=0
8941    import glob as glob
8942    from os.path import exists
8943    ex_path = os.path.expanduser( "~/.antspyt1w/" )
8944    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
8945    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
8946    if not exists( templatefn ):
8947        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
8948        antspyt1w.get_data( force_download=True )
8949        get_data( force_download=True )
8950    template = mm_read( templatefn ) # Read in template
8951    test_run = False
8952    if test_run:
8953        visualize=False
8954    # get sid and dtid from studycsv
8955    # musthavecols = ['projectID','subjectID','date','imageID','modality','sourcedir','outputdir','filename']
8956    projid = str(studycsv['projectID'].iloc[0])
8957    sid = str(studycsv['subjectID'].iloc[0])
8958    dtid = str(studycsv['date'].iloc[0])
8959    iid = str(studycsv['imageID'].iloc[0])
8960    t1iidUse=iid
8961    modality = str(studycsv['modality'].iloc[0])
8962    sourcedir = str(studycsv['sourcedir'].iloc[0])
8963    outputdir = str(studycsv['outputdir'].iloc[0])
8964    filename = str(studycsv['filename'].iloc[0])
8965    if not exists(filename):
8966        raise ValueError('mm_csv cannot find filename ' + filename )
8967
8968    # hierarchical
8969    # NOTE: if there are multiple T1s for this time point, should take
8970    # the one with the highest resnetGrade
8971    t1fn = filename
8972    if not exists( t1fn ):
8973        raise ValueError('mm_csv cannot find the T1w file ' + t1fn )
8974    t1 = mm_read( t1fn, modality='T1w' )
8975    minspc = np.min(ants.get_spacing(t1))
8976    minshape = np.min(t1.shape)
8977    if minspc < 1e-16:
8978        warnings.warn('minimum spacing in T1w is too small - cannot process. ' + str(minspc) )
8979        return
8980    if minshape < 32:
8981        warnings.warn('minimum shape in T1w is too small - cannot process. ' + str(minshape) )
8982        return
8983
8984    if enantiomorphic:
8985        t1 = enantiomorphic_filling_without_mask( t1, axis=0 )[0]
8986    hierfn = outputdir + "/"  + projid + "/" + sid + "/" + dtid + "/" + "T1wHierarchical" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchical" + mysep + iid + mysep
8987    hierfnSR = outputdir + "/" + projid + "/"  + sid + "/" + dtid + "/" + "T1wHierarchicalSR" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchicalSR" + mysep + iid + mysep
8988    hierfntest = hierfn + 'cerebellum.csv'
8989    if verbose:
8990        print( hierfntest )
8991    regout = re.sub("T1wHierarchical","T1w",hierfn) + "syn"
8992    templateTx = {
8993        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
8994        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
8995    groupTx = None
8996    # make the T1w directory
8997    os.makedirs( os.path.dirname(re.sub("T1wHierarchical","T1w",hierfn)), exist_ok=True  )
8998    if normalization_template_output is not None:
8999        normout = re.sub("T1wHierarchical","T1w",hierfn) +  normalization_template_output
9000        templateNormTx = {
9001            'fwdtransforms': [ normout+'1Warp.nii.gz', normout+'0GenericAffine.mat'],
9002            'invtransforms': [ normout+'0GenericAffine.mat', normout+'1InverseWarp.nii.gz']  }
9003        groupTx = templateNormTx['fwdtransforms']
9004    if verbose:
9005        print( "-<REGISTRATION EXISTENCE>-: \n" + 
9006              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
9007            str(exists( templateTx['fwdtransforms'][0])) + " " +
9008            str(exists( templateTx['fwdtransforms'][1])) + " " +
9009            str(exists( templateTx['invtransforms'][0])) + " " +
9010            str(exists( templateTx['invtransforms'][1])) )
9011    if verbose:
9012        print( hierfntest )
9013    hierexists = exists( hierfntest ) and exists( templateTx['fwdtransforms'][0]) and exists( templateTx['fwdtransforms'][1]) and exists( templateTx['invtransforms'][0]) and exists( templateTx['invtransforms'][1])
9014    hier = None
9015    if srmodel_T1 is not None and srmodel_T1 is not False:  # False (the default) disables SR
9016        srmodel_T1_mdl = tf.keras.models.load_model( srmodel_T1, compile=False )
9017        if verbose:
9018            print("Convert T1w to SR via model ", srmodel_T1 )
9019        t1 = t1w_super_resolution_with_hemispheres( t1, srmodel_T1_mdl,
9020            min_spacing=min_t1_spacing_for_sr )
9021    if not hierexists and not testloop:
9022        subjectpropath = os.path.dirname( hierfn )
9023        if verbose:
9024            print( subjectpropath )
9025        os.makedirs( subjectpropath, exist_ok=True  )
9026        ants.image_write( t1, hierfn + 'head.nii.gz' )
9027        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
9028        antspyt1w.write_hierarchical( hier, hierfn )
9029        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9030                hier['dataframes'], identifier=None )
9031        t1wide.to_csv( hierfn + 'mmwide.csv' )
9032    ################# read the hierarchical data ###############################
9033    # over-write the rbp data with a consistent and recent approach ############
9034    redograding = True
9035    if redograding:
9036        myx = antspyt1w.inspect_raw_t1( 
9037            ants.image_read(t1fn), hierfn + 'rbp' , option='both' )
9038        myx['brain'].to_csv( hierfn + 'rbp.csv', index=False )
9039        myx['brain'].to_csv( hierfn + 'rbpbrain.csv', index=False )
9040        del myx
9041
9042    hier = antspyt1w.read_hierarchical( hierfn )
9043    t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9044        hier['dataframes'], identifier=None )
9045    rgrade = str( t1wide['resnetGrade'].iloc[0] )
9046    if t1wide['resnetGrade'].iloc[0] < 0.20:
9047        warnings.warn('T1w quality check indicates failure: ' + rgrade + " will not process." )
9048        return
9049    else:
9050        print('T1w quality check indicates success: ' + rgrade + " will process." )
9051
9052    if srmodel_T1 is not None and False : # deprecated
9053        hierfntest = hierfnSR + 'mtl.csv'
9054        if verbose:
9055            print( hierfntest )
9056        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
9057        if not hierexists:
9058            subjectpropath = os.path.dirname( hierfnSR )
9059            if verbose:
9060                print( subjectpropath )
9061            os.makedirs( subjectpropath, exist_ok=True  )
9062            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
9063            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
9064            if isinstance( srmodel_T1, str ):
9065                mdlfn = re.sub( 'bestup', bestup, srmodel_T1 )
9066            if verbose:
9067                print( mdlfn )
9068            if exists( mdlfn ):
9069                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
9070            else:
9071                print( mdlfn + " does not exist - will not run.")
9072            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
9073            antspyt1w.write_hierarchical( hierSR, hierfnSR )
9074            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9075                    hierSR['dataframes'], identifier=None )
9076            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
9077    hier = antspyt1w.read_hierarchical( hierfn )
9078    if exists( hierfn + 'mmwide.csv' ) :
9079        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
9080    elif not testloop:
9081        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9082                hier['dataframes'], identifier=None )
9083    if not testloop:
9084        t1imgbrn = hier['brain_n4_dnz']
9085        t1atropos = hier['dkt_parc']['tissue_segmentation']
9086
9087    if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
9088        if verbose:
9089            print('start t1 registration')
9090        ex_path = os.path.expanduser( "~/.antspyt1w/" )
9091        templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
9092        template = mm_read( templatefn )
9093        template = ants.resample_image( template, [1,1,1], use_voxels=False )
9094        t1reg = ants.registration( template, 
9095            hier['brain_n4_dnz'],
9096            "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
9097        myjac = ants.create_jacobian_determinant_image( template,
9098            t1reg['fwdtransforms'][0], do_log=True, geom=True )
9099        image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
9100        if visualize:
9101            ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
9102            ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
9103
9104    if normalization_template_output is not None and normalization_template is not None:
9105        if verbose:
9106            print("begin group template registration")
9107        if not exists( normout+'0GenericAffine.mat' ):
9108            if normalization_template_spacing is not None:
9109                normalization_template_rr=ants.resample_image(normalization_template,normalization_template_spacing)
9110            else:
9111                normalization_template_rr=normalization_template
9112            greg = ants.registration( 
9113                normalization_template_rr, 
9114                hier['brain_n4_dnz'],
9115                normalization_template_transform_type,
9116                outprefix = normout, verbose=False )
9117            myjac = ants.create_jacobian_determinant_image( normalization_template_rr,
9118                    greg['fwdtransforms'][0], do_log=True, geom=True )
9119            image_write_with_thumbnail( myjac, normout + "logjacobian.nii.gz", thumb=False )
9120            if verbose:
9121                print("end group template registration")
9122        else:
9123            if verbose:
9124                print("group template registration already done")
9125
9126    # loop over modalities and then unique image IDs
9127    # we treat NM in a "special" way -- aggregating repeats
9128    # other modalities (beyond T1) are treated individually
9129    for overmodX in nrg_modality_list:
9130        # define 1. input images 2. output prefix
9131        mydoc = docsamson( overmodX, studycsv=studycsv, outputdir=outputdir, projid=projid, sid=sid, dtid=dtid, mysep=mysep,t1iid=t1iidUse )
9132        myimgsr = mydoc['images']
9133        mymm = mydoc['outprefix']
9134        mymod = mydoc['modality']
9135        if verbose:
9136            print( mydoc )
9137        if len(myimgsr) > 0:
9138            dowrite=False
9139            if verbose:
9140                print( 'overmodX is : ' + overmodX )
9141                print( 'example image name is : '  )
9142                print( myimgsr )
9143            if overmodX == 'NM2DMT':
9144                dowrite = True
9145                visualize = True
9146                subjectpropath = os.path.dirname( mydoc['outprefix'] )
9147                if verbose:
9148                    print("subjectpropath is")
9149                    print(subjectpropath)
9150                os.makedirs( subjectpropath, exist_ok=True  )
9151                myimgsr2 = myimgsr
9152                myimgsr2.sort()
9153                is4d = False
9154                temp = ants.image_read( myimgsr2[0] )
9155                if temp.dimension == 4:
9156                    is4d = True
9157                if len( myimgsr2 ) == 1 and not is4d: # check dimension
9158                    myimgsr2 = myimgsr2 + myimgsr2
9159                mymmout = makewideout( mymm )
9160                if verbose and not exists( mymmout ):
9161                    print( "NM " + mymm  + ' execution ')
9162                elif verbose and exists( mymmout ) :
9163                    print( "NM " + mymm + ' complete ' )
9164                if exists( mymmout ):
9165                    continue
9166                if is4d:
9167                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
9168                else:
9169                    nmlist = []
9170                    for zz in myimgsr2:
9171                        nmlist.append( mm_read( zz ) )
9172                srmodel_NM_mdl = None
9173                if srmodel_NM is not None:
9174                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
9175                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.keras"
9176                    if isinstance( srmodel_NM, str ):
9177                        srmodel_NM = re.sub( "bestup", bestup, srmodel_NM )
9178                        mdlfn = os.path.join( ex_pathmm, srmodel_NM )
9179                    if exists( mdlfn ):
9180                        if verbose:
9181                            print(mdlfn)
9182                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
9183                    else:
9184                        print( mdlfn + " does not exist - won't use SR")
9185                if not testloop:
9186                    try:
9187                        tabPro, normPro = mm( t1, hier,
9188                            nm_image_list = nmlist,
9189                            srmodel=srmodel_NM_mdl,
9190                            do_tractography=False,
9191                            do_kk=False,
9192                            do_normalization=templateTx,
9193                            group_template = normalization_template,
9194                            group_transform = groupTx,
9195                            test_run=test_run,
9196                            verbose=True )
9197                    except Exception as e:
9198                        error_info = traceback.format_exc()
9199                        print(error_info)
9200                        visualize=False
9201                        dowrite=False
9202                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9203                        pass
9204                    if not test_run:
9205                        if dowrite:
9206                            write_mm( output_prefix=mymm, mm=tabPro,
9207                                mm_norm=normPro, t1wide=None, separator=mysep )
9208                        if visualize :
9209                            nmpro = tabPro['NM']
9210                            mysl = range( nmpro['NM_avg'].shape[2] )
9211                            ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
9212                            mysl = range( nmpro['NM_avg_cropped'].shape[2] )
9213                            ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
9214                            ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
9215                            ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
9216            else :
9217                if len( myimgsr ) > 0 :
9218                    dowrite=False
9219                    myimgcount=0
9220                    if len( myimgsr ) > 0 :
9221                        myimg = myimgsr[ myimgcount ]
9222                        subjectpropath = os.path.dirname( mydoc['outprefix'] )
9223                        if verbose:
9224                            print("subjectpropath is")
9225                            print(subjectpropath)
9226                        os.makedirs( subjectpropath, exist_ok=True  )
9227                        mymmout = makewideout( mymm )
9228                        if verbose and not exists( mymmout ):
9229                            print( "Modality specific processing: " + mymod + " execution " )
9230                            print( mymm )
9231                        elif verbose and exists( mymmout ) :
9232                            print("Modality specific processing: " + mymod + " complete " )
9233                        if exists( mymmout ) :
9234                            continue
9235                        if verbose:
9236                            print( subjectpropath )
9237                            print( myimg )
9238                        if not testloop:
9239                            img = mm_read( myimg )
9240                            ishapelen = len( img.shape )
9241                            if mymod == 'T1w' and ishapelen == 3:
9242                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
9243                                    dowrite=True
9244                                    if verbose:
9245                                        print('start kk')
9246                                    try:
9247                                        tabPro, normPro = mm( t1, hier,
9248                                            srmodel=None,
9249                                            do_tractography=False,
9250                                            do_kk=True,
9251                                            do_normalization=templateTx,
9252                                            group_template = normalization_template,
9253                                            group_transform = groupTx,
9254                                            test_run=test_run,
9255                                            verbose=True )
9256                                    except Exception as e:
9257                                        error_info = traceback.format_exc()
9258                                        print(error_info)
9259                                        visualize=False
9260                                        dowrite=False
9261                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9262                                        pass
9263                                    if visualize:
9264                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
9265                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
9266                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
9267                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
9268                            if mymod == 'T2Flair' and ishapelen == 3 and np.min(img.shape) > 15:
9269                                dowrite=True
9270                                try:
9271                                    tabPro, normPro = mm( t1, hier,
9272                                        flair_image = img,
9273                                        srmodel=None,
9274                                        do_tractography=False,
9275                                        do_kk=False,
9276                                        do_normalization=templateTx,
9277                                        group_template = normalization_template,
9278                                        group_transform = groupTx,
9279                                        test_run=test_run,
9280                                        verbose=True )
9281                                except Exception as e:
9282                                    error_info = traceback.format_exc()
9283                                    print(error_info)
9284                                    visualize=False
9285                                    dowrite=False
9286                                    print(f"antspymmerror occurred while processing {overmodX}: {e}")
9287                                    pass
9288                                if visualize:
9289                                    maxslice = np.min( [21, img.shape[2] ] )
9290                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
9291                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
9292                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
9293                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
9294                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
9295                                img2 = None
9296                                if len( myimgsr ) > 1:
9297                                    img2 = mm_read( myimgsr[myimgcount+1] )
9298                                    ishapelen2 = len( img2.shape )
9299                                    if ishapelen2 != 4 or 1 in img2.shape:
9300                                        img2 = None
9301                                if 1 in img.shape:
9302                                    warnings.warn( 'rsfMRI image shape suggests it is an incorrectly converted mosaic image - will not process.')
9303                                    dowrite=False
9304                                    tabPro={'rsf':None}
9305                                    normPro={'rsf':None}
9306                                else:
9307                                    dowrite=True
9308                                    try:
9309                                        tabPro, normPro = mm( t1, hier,
9310                                            rsf_image=[img,img2],
9311                                            srmodel=None,
9312                                            do_tractography=False,
9313                                            do_kk=False,
9314                                            do_normalization=templateTx,
9315                                            group_template = normalization_template,
9316                                            group_transform = groupTx,
9317                                            rsf_upsampling = rsf_upsampling,
9318                                            test_run=test_run,
9319                                            verbose=True )
9320                                    except Exception as e:
9321                                        error_info = traceback.format_exc()
9322                                        print(error_info)
9323                                        visualize=False
9324                                        dowrite=False
9325                                        tabPro={'rsf':None}
9326                                        normPro={'rsf':None}
9327                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9328                                        pass
9329                                if tabPro['rsf'] is not None and visualize:
9330                                    for tpro in tabPro['rsf']: # FIXMERSF
9331                                        maxslice = np.min( [21, tpro['meanBold'].shape[2] ] )
9332                                        tproprefix = mymm+mysep+str(tpro['paramset'])+mysep
9333                                        ants.plot( tpro['meanBold'],
9334                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=tproprefix+"meanBOLD.png" )
9335                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['alff'],"Normalize"),
9336                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=tproprefix+"boldALFF.png" )
9337                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['falff'],"Normalize"),
9338                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=tproprefix+"boldfALFF.png" )
9339                                        dfn=tpro['dfnname']
9340                                        ants.plot( tpro['meanBold'], tpro[dfn],
9341                                            axis=2, nslices=maxslice, ncol=7, crop=True, title=dfn, filename=tproprefix+"boldDefaultMode.png" )
9342                            if ( mymod == 'perf' ) and ishapelen == 4:
9343                                dowrite=True
9344                                try:
9345                                    tabPro, normPro = mm( t1, hier,
9346                                        perfusion_image=img,
9347                                        srmodel=None,
9348                                        do_tractography=False,
9349                                        do_kk=False,
9350                                        do_normalization=templateTx,
9351                                        group_template = normalization_template,
9352                                        group_transform = groupTx,
9353                                        test_run=test_run,
9354                                        perfusion_trim=perfusion_trim,
9355                                        perfusion_m0_image=perfusion_m0_image,
9356                                        perfusion_m0=perfusion_m0,
9357                                        verbose=True )
9358                                except Exception as e:
9359                                    error_info = traceback.format_exc()
9360                                    print(error_info)
9361                                    visualize=False
9362                                    dowrite=False
9363                                    tabPro={'perf':None}
9364                                    print(f"antspymmerror occurred while processing {overmodX}: {e}")
9365                                    pass
9366                                if tabPro['perf'] is not None and visualize:
9367                                    maxslice = np.min( [21, tabPro['perf']['meanBold'].shape[2] ] )
9368                                    ants.plot( tabPro['perf']['perfusion'],
9369                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='perfusion image', filename=mymm+mysep+"perfusion.png" )
9370                                    ants.plot( tabPro['perf']['cbf'],
9371                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='CBF image', filename=mymm+mysep+"cbf.png" )
9372                                    ants.plot( tabPro['perf']['m0'],
9373                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='M0 image', filename=mymm+mysep+"m0.png" )
9374
9375                            if ( mymod == 'pet3d' ) and ishapelen == 3:
9376                                dowrite=True
9377                                try:
9378                                    tabPro, normPro = mm( t1, hier,
9379                                        srmodel=None,
9380                                        do_tractography=False,
9381                                        do_kk=False,
9382                                        do_normalization=templateTx,
9383                                        group_template = normalization_template,
9384                                        group_transform = groupTx,
9385                                        test_run=test_run,
9386                                        pet_3d_image=img,
9387                                        verbose=True )
9388                                except Exception as e:
9389                                    error_info = traceback.format_exc()
9390                                    print(error_info)
9391                                    visualize=False
9392                                    dowrite=False
9393                                    tabPro={'pet3d':None}
9394                                    print(f"antspymmerror occurred while processing {overmodX}: {e}")
9395                                    pass
9396                                if tabPro['pet3d'] is not None and visualize:
9397                                    maxslice = np.min( [21, tabPro['pet3d']['pet3d'].shape[2] ] )
9398                                    ants.plot( tabPro['pet3d']['pet3d'],
9399                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='PET image', filename=mymm+mysep+"pet3d.png" )
9400                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
9401                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
9402                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
9403                                imgList = [ img ]
9404                                bvalfnList = [ bvalfn ]
9405                                bvecfnList = [ bvecfn ]
9406                                missing_dti_data=False # bval, bvec or images
9407                                if len( myimgsr ) == 2:  # find DTI_RL
9408                                    dtilrfn = myimgsr[myimgcount+1]
9409                                    if exists( dtilrfn ):
9410                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
9411                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
9412                                        imgRL = ants.image_read( dtilrfn )
9413                                        imgList.append( imgRL )
9414                                        bvalfnList.append( bvalfnRL )
9415                                        bvecfnList.append( bvecfnRL )
9416                                elif len( myimgsr ) == 3:  # find DTI_RL
9417                                    print("DTI trinity")
9418                                    dtilrfn = myimgsr[myimgcount+1]
9419                                    dtilrfn2 = myimgsr[myimgcount+2]
9420                                    if exists( dtilrfn ) and exists( dtilrfn2 ):
9421                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
9422                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
9423                                        bvalfnRL2 = re.sub( '.nii.gz', '.bval' , dtilrfn2 )
9424                                        bvecfnRL2 = re.sub( '.nii.gz', '.bvec' , dtilrfn2 )
9425                                        imgRL = ants.image_read( dtilrfn )
9426                                        imgRL2 = ants.image_read( dtilrfn2 )
9427                                        bvals, bvecs = read_bvals_bvecs( bvalfnRL , bvecfnRL  )
9428                                        print( bvals.max() )
9429                                        bvals2, bvecs2 = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
9430                                        print( bvals2.max() )
9431                                        temp = merge_dwi_data( imgRL, bvals, bvecs, imgRL2, bvals2, bvecs2  )
9432                                        imgList.append( temp[0] )
9433                                        bvalfnList.append( mymm+mysep+'joined.bval' )
9434                                        bvecfnList.append( mymm+mysep+'joined.bvec' )
9435                                        write_bvals_bvecs( temp[1], temp[2], mymm+mysep+'joined' )
9436                                        bvalsX, bvecsX = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
9437                                        print( bvalsX.max() )
9438                                # check existence of all files expected ...
9439                                for dtiex in bvalfnList+bvecfnList+myimgsr:
9440                                    if not exists(dtiex):
9441                                        print('mm_csv: missing dti data ' + dtiex )
9442                                        missing_dti_data=True
9443                                        dowrite=False
9444                                if not missing_dti_data:
9445                                    dowrite=True
9446                                    srmodel_DTI_mdl=None
9447                                    if srmodel_DTI is not None:
9448                                        temp = ants.get_spacing(img)
9449                                        dtspc=[temp[0],temp[1],temp[2]]
9450                                        bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
9451                                        mdlfn = re.sub( 'bestup', bestup, srmodel_DTI )
9452                                        if isinstance( srmodel_DTI, str ):
9453                                            srmodel_DTI = re.sub( "bestup", bestup, srmodel_DTI )
9454                                            mdlfn = os.path.join( ex_pathmm, srmodel_DTI )
9455                                        if exists( mdlfn ):
9456                                            if verbose:
9457                                                print(mdlfn)
9458                                            srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
9459                                        else:
9460                                            print(mdlfn + " does not exist - won't use SR")
9461                                    try:
9462                                        tabPro, normPro = mm( t1, hier,
9463                                            dw_image=imgList,
9464                                            bvals = bvalfnList,
9465                                            bvecs = bvecfnList,
9466                                            srmodel=srmodel_DTI_mdl,
9467                                            do_tractography=not test_run,
9468                                            do_kk=False,
9469                                            do_normalization=templateTx,
9470                                            group_template = normalization_template,
9471                                            group_transform = groupTx,
9472                                            dti_motion_correct = dti_motion_correct,
9473                                            dti_denoise = dti_denoise,
9474                                            test_run=test_run,
9475                                            verbose=True )
9476                                    except Exception as e:
9477                                        error_info = traceback.format_exc()
9478                                        print(error_info)
9479                                        visualize=False
9480                                        dowrite=False
9481                                        tabPro={'DTI':None}
9482                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9483                                        pass
9484                                    mydti = tabPro['DTI']
9485                                    if visualize and tabPro['DTI'] is not None:
9486                                        maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
9487                                        ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
9488                                        ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
9489                                        ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
9490                            if dowrite:
9491                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep )
9492                                for mykey in normPro.keys():
9493                                    if normPro[mykey] is not None and normPro[mykey].components == 1:
9494                                        if visualize and False:
9495                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
9496        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
9497            return
9498        if verbose:
9499            print("done with " + overmodX )
9500    if verbose:
9501        print("mm_nrg complete.")
9502    return

too dangerous to document ... use with care.

processes multiple-modality MRI, specifically:

  • T1w
  • T2Flair
  • DTI, DTI_LR, DTI_RL
  • rsfMRI, rsfMRI_LR, rsfMRI_RL
  • NM2DMT (neuromelanin)

other modalities may be added later ...

"trust me, i know what i'm doing" - sledgehammer

convert to ipynb via: p2j mm.py -o

convert the ipynb to html via: jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html

this function does not assume NRG format for the input data ....

Parameters

studycsv : must have columns:

  • subjectID
  • date or session
  • imageID
  • modality
  • sourcedir
  • outputdir
  • filename (path to the t1 image)

Other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid; these provide filenames for the corresponding modalities: nm=neuromelanin, dti=diffusion tensor, rsf=resting state fmri, flair=T2Flair. None of these are required; only the t1 is required. rsfid1/rsfid2 will be processed jointly, and the same holds for dtid1/dtid2 and nmid*. See antspymm.generate_mm_dataframe.

sourcedir : a study specific folder containing individual subject folders

outputdir : a study specific folder where individual output subject folders will go

filename : the raw image filename (full path)

srmodel_T1 : None (default) - .keras or h5 filename for SR model (siq generated).

srmodel_NM : None (default) - .keras or h5 filename for SR model (siq generated) the model name should follow a style like prefix_bestup_postfix where bestup will be replaced with an optimal upsampling factor eg 2x2x2 based on the data. see siq.optimize_upsampling_shape.

srmodel_DTI : None (default) - .keras or h5 filename for SR model (siq generated). the model name should follow a style like prefix_bestup_postfix where bestup will be replaced with an optimal upsampling factor eg 2x2x2 based on the data. see siq.optimize_upsampling_shape.

dti_motion_correct : None, Rigid or SyN

dti_denoise : boolean

nrg_modality_list : optional; defaults to None; use to focus on a given modality

normalization_template : optional; defaults to None; if present, all images will be deformed into this space and the deformation will be stored with an extension related to this variable. this should be a brain extracted T1w image.

normalization_template_output : optional string; defaults to None; naming for the normalization_template outputs which will be in the T1w directory.

normalization_template_transform_type : optional string transform type passed to ants.registration

normalization_template_spacing : 3-tuple controlling the resolution at which registration is computed

enantiomorphic: boolean (WIP)

perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series

perfusion_m0_image : optional m0 antsImage associated with the perfusion time series

perfusion_m0 : optional list containing indices of the m0 in the perfusion time series

rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done

pet3d : optional antsImage for PET (or other 3d scalar) data which we want to summarize

min_t1_spacing_for_sr : float; if the minimum input image spacing is less than this value, the function will return the original image. Default 0.8.

Returns

writes output to disk and produces figures
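
Example (a minimal sketch, not part of the original docstring): the block below builds a one-row study dataframe by hand with the required columns listed above and passes it to mm_csv. Every path, ID and value is a placeholder assumption; antspymm.generate_mm_dataframe can construct the same structure programmatically.

    import pandas as pd
    import antspymm

    # hypothetical single-session study row; all paths and IDs are placeholders
    studycsv = pd.DataFrame({
        'subjectID': ['sub001'],
        'date': ['20240101'],
        'imageID': ['000'],
        'modality': ['T1w'],
        'sourcedir': ['/data/raw/'],
        'outputdir': ['/data/processed/'],
        'filename': ['/data/raw/sub001/20240101/T1w/000/sub001-20240101-T1w-000.nii.gz'],
    })
    antspymm.mm_csv(studycsv)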

def collect_blind_qc_by_modality(modality_path, set_index_to_fn=True):
1094def collect_blind_qc_by_modality( modality_path, set_index_to_fn=True ):
1095    """
1096    Collects blind QC data from multiple CSV files with the same modality.
1097
1098    Args:
1099
1100    modality_path (str): The path to the folder containing the CSV files.
1101
1102    set_index_to_fn: boolean
1103
1104    Returns:
1105    Pandas DataFrame: A DataFrame containing all the blind QC data from the CSV files.
1106    """
1107    import glob as glob
1108    fns = glob.glob( modality_path )
1109    fns.sort()
1110    jdf = pd.DataFrame()
1111    for k in range(len(fns)):
1112        temp=pd.read_csv(fns[k])
1113        if not 'filename' in temp.keys():
1114            temp['filename']=fns[k]
1115        jdf=pd.concat( [jdf,temp], axis=0, ignore_index=False )
1116    if set_index_to_fn:
1117        jdf = jdf.reset_index(drop=True)
1118        if "Unnamed: 0" in jdf.columns:
1119            holder=jdf.pop( "Unnamed: 0" )
1120        jdf = jdf.set_index('filename')
1121    return jdf

Collects blind QC data from multiple CSV files with the same modality.

Args:

modality_path (str): The path to the folder containing the CSV files.

set_index_to_fn: boolean

Returns: Pandas DataFrame: A DataFrame containing all the blind QC data from the CSV files.
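
A hedged usage sketch: modality_path is passed directly to glob.glob, so a wildcard pattern selects the per-image QC CSVs for one modality. The directory and filename pattern below are assumptions about where the blind QC output was written.

    import antspymm

    # the glob pattern is an assumed layout for blind QC output
    t1_qc = antspymm.collect_blind_qc_by_modality("qc_output/*T1w*blind*qc*.csv")
    print(t1_qc.shape)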

def alffmap(x, flo=0.01, fhi=0.1, tr=1, detrend=True):
9763def alffmap( x, flo=0.01, fhi=0.1, tr=1, detrend = True ):
9764    """
9765    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
9766    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
9767    are related measures that quantify the amplitude of low frequency
9768    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
9769    same function in ANTsR.
9770
9771    x input vector for the time series of interest
9772    flo low frequency, typically 0.01
9773    fhi high frequency, typically 0.1
9774    tr the period associated with the vector x (inverse of frequency)
9775    detrend detrend the input time series
9776
9777    return vector is output showing ALFF and fALFF values
9778    """
9779    temp = spec_pgram( x, xfreq=1.0/tr, demean=False, detrend=detrend, taper=0, fast=True, plot=False )
9780    fselect = np.logical_and( temp['freq'] >= flo, temp['freq'] <= fhi )
9781    denom = (temp['spec']).sum()
9782    numer = (temp['spec'][fselect]).sum()
9783    return {  'alff':numer, 'falff': numer/denom }

Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008) are related measures that quantify the amplitude of low frequency oscillations (LFOs). This function outputs ALFF and fALFF for the input. same function in ANTsR.

x input vector for the time series of interest flo low frequency, typically 0.01 fhi high frequency, typically 0.1 tr the period associated with the vector x (inverse of frequency) detrend detrend the input time series

return vector is output showing ALFF and fALFF values
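
A small worked example (not from the original docstring) on a synthetic series: the summed spectral power between flo and fhi is the ALFF value, and dividing it by the total spectral power gives fALFF, so fALFF is the in-band fraction of total power and is bounded by 1.

    import numpy as np
    import antspymm

    tr = 2.0                                   # seconds per time point
    t = np.arange(0, 300, tr)                  # 150 time points
    # a slow 0.05 Hz oscillation inside the band, plus noise
    x = np.sin(2 * np.pi * 0.05 * t) + 0.2 * np.random.randn(len(t))
    out = antspymm.alffmap(x, flo=0.01, fhi=0.1, tr=tr)
    print(out['alff'], out['falff'])           # falff = in-band power / total power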

def alff_image(x, mask, flo=0.01, fhi=0.1, nuisance=None):
9786def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
9787    """
9788    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
9789    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
9790    are related measures that quantify the amplitude of low frequency
9791    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
9792
9793    x - input clean resting state fmri
9794    mask - mask over which to compute f/alff
9795    flo - low frequency, typically 0.01
9796    fhi - high frequency, typically 0.1
9797    nuisance - optional nuisance matrix
9798
9799    return dictionary with ALFF and fALFF images
9800    """
9801    xmat = ants.timeseries_to_matrix( x, mask )
9802    if nuisance is not None:
9803        xmat = ants.regress_components( xmat, nuisance )
9804    alffvec = xmat[0,:]*0
9805    falffvec = xmat[0,:]*0
9806    mytr = ants.get_spacing( x )[3]
9807    for n in range( xmat.shape[1] ):
9808        temp = alffmap( xmat[:,n], flo=flo, fhi=fhi, tr=mytr )
9809        alffvec[n]=temp['alff']
9810        falffvec[n]=temp['falff']
9811    alffi=ants.make_image( mask, alffvec )
9812    falffi=ants.make_image( mask, falffvec )
9813    alfftrimmedmean = calculate_trimmed_mean( alffvec, 0.01 )
9814    falfftrimmedmean = calculate_trimmed_mean( falffvec, 0.01 )
9815    alffi=alffi / alfftrimmedmean
9816    falffi=falffi / falfftrimmedmean
9817    return {  'alff': alffi, 'falff': falffi }

Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008) are related measures that quantify the amplitude of low frequency oscillations (LFOs). This function outputs ALFF and fALFF for the input.

x - input clean resting state fmri mask - mask over which to compute f/alff flo - low frequency, typically 0.01 fhi - high frequency, typically 0.1 nuisance - optional nuisance matrix

return dictionary with ALFF and fALFF images
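
A hedged usage sketch, assuming a pre-cleaned 4D BOLD image on disk (the filename and the preprocessing are assumptions, not something this function performs):

    import ants
    import antspymm

    bold = ants.image_read('cleaned_bold.nii.gz')   # assumed 4D, motion corrected
    avg = ants.get_average_of_timeseries(bold)
    mask = ants.get_mask(avg)
    maps = antspymm.alff_image(bold, mask, flo=0.01, fhi=0.1)
    ants.image_write(maps['alff'], 'alff.nii.gz')
    ants.image_write(maps['falff'], 'falff.nii.gz')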

def down2iso(x, interpolation='linear', takemin=False):
9820def down2iso( x, interpolation='linear', takemin=False ):
9821    """
9822    will downsample an anisotropic image to an isotropic resolution
9823
9824    x: input image
9825
9826    interpolation: linear or nearestneighbor
9827
9828    takemin : boolean map to min space; otherwise max
9829
9830    return image downsampled to isotropic resolution
9831    """
9832    spc = ants.get_spacing( x )
9833    if takemin:
9834        newspc = np.asarray(spc).min()
9835    else:
9836        newspc = np.asarray(spc).max()
9837    newspc = np.repeat( newspc, x.dimension )
9838    if interpolation == 'linear':
9839        xs = ants.resample_image( x, newspc, interp_type=0)
9840    else:
9841        xs = ants.resample_image( x, newspc, interp_type=1)
9842    return xs

will downsample an anisotropic image to an isotropic resolution

x: input image

interpolation: linear or nearestneighbor

takemin : boolean map to min space; otherwise max

return image downsampled to isotropic resolution
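
A brief illustrative call, assuming an anisotropic acquisition such as 0.8 x 0.8 x 3.0 mm (the filename is a placeholder):

    import ants
    import antspymm

    img = ants.image_read('aniso_t2.nii.gz')         # e.g. spacing (0.8, 0.8, 3.0)
    iso_lo = antspymm.down2iso(img)                  # 3.0 mm isotropic (max spacing)
    iso_hi = antspymm.down2iso(img, takemin=True)    # 0.8 mm isotropic instead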

def read_mm_csv(x, is_t1=False, colprefix=None, separator='-', verbose=False):
9845def read_mm_csv( x, is_t1=False, colprefix=None, separator='-', verbose=False ):
9846    splitter=os.path.basename(x).split( separator )
9847    lensplit = len( splitter )-1
9848    temp = os.path.basename(x)
9849    temp = os.path.splitext(temp)[0]
9850    temp = re.sub(separator+'mmwide','',temp)
9851    idcols = ['u_hier_id','sid','visitdate','modality','mmimageuid','t1imageuid']
9852    df = pd.DataFrame( columns = idcols, index=range(1) )
9853    valstoadd = [temp] + splitter[1:(lensplit-1)]
9854    if is_t1:
9855        valstoadd = valstoadd + [splitter[(lensplit-1)],splitter[(lensplit-1)]]
9856    else:
9857        split2=splitter[(lensplit-1)].split( "_" )
9858        if len(split2) == 1:
9859            split2.append( split2[0] )
9860        if len(valstoadd) == 3:
9861            valstoadd = valstoadd + [split2[0]] + [math.nan] + [split2[1]]
9862        else:
9863            valstoadd = valstoadd + [split2[0],split2[1]]
9864    if verbose:
9865        print( valstoadd )
9866    df.iloc[0] = valstoadd
9867    if verbose:
9868        print( "read xdf: " + x )
9869    xdf = pd.read_csv( x )
9870    df.reset_index()
9871    xdf.reset_index(drop=True)
9872    if "Unnamed: 0" in xdf.columns:
9873        holder=xdf.pop( "Unnamed: 0" )
9874    if "Unnamed: 1" in xdf.columns:
9875        holder=xdf.pop( "Unnamed: 1" )
9876    if "u_hier_id.1" in xdf.columns:
9877        holder=xdf.pop( "u_hier_id.1" )
9878    if "u_hier_id" in xdf.columns:
9879        holder=xdf.pop( "u_hier_id" )
9880    if not is_t1:
9881        if 'resnetGrade' in xdf.columns:
9882            index_no = xdf.columns.get_loc('resnetGrade')
9883            xdf = xdf.drop( xdf.columns[range(index_no+1)] , axis=1)
9884
9885    if xdf.shape[0] == 2:
9886        xdfcols = xdf.columns
9887        xdf = xdf.iloc[1]
9888        ddnum = xdf.to_numpy()
9889        ddnum = ddnum.reshape([1,ddnum.shape[0]])
9890        newcolnames = xdf.index.to_list()
9891        if len(newcolnames) != ddnum.shape[1]:
9892            print("Cannot Merge : Shape MisMatch " + str( len(newcolnames) ) + " " + str(ddnum.shape[1]))
9893        else:
9894            xdf = pd.DataFrame(ddnum, columns=xdfcols )
9895    if xdf.shape[1] == 0:
9896        return None
9897    if colprefix is not None:
9898        xdf.columns=colprefix + xdf.columns
9899    return pd.concat( [df,xdf], axis=1, ignore_index=False )
def assemble_modality_specific_dataframes( mm_wide_csvs, hierdfin, nrg_modality, separator='-', progress=None, verbose=False):
10005def assemble_modality_specific_dataframes( mm_wide_csvs, hierdfin, nrg_modality, separator='-', progress=None, verbose=False ):
10006    moddersub = re.sub( "[*]","",nrg_modality)
10007    nmdf=pd.DataFrame()
10008    for k in range( hierdfin.shape[0] ):
10009        if progress is not None:
10010            if k % progress == 0:
10011                progger = str( np.round( k / hierdfin.shape[0] * 100 ) )
10012                print( progger, end ="...", flush=True)
10013        temp = mm_wide_csvs[k]
10014        mypartsf = temp.split("T1wHierarchical")
10015        myparts = mypartsf[0]
10016        t1iid = str(mypartsf[1].split("/")[1])
10017        fnsnm = glob.glob(myparts+"/" + nrg_modality + "/*/*" + t1iid + "*wide.csv")
10018        if len( fnsnm ) > 0 :
10019            for y in fnsnm:
10020                temp=read_mm_csv( y, colprefix=moddersub+'_', is_t1=False, separator=separator, verbose=verbose )
10021                if temp is not None:
10022                    nmdf=pd.concat( [nmdf, temp], axis=0, ignore_index=False )
10023    return nmdf
def bind_wide_mm_csvs(mm_wide_csvs, merge=True, separator='-', verbose=0):
10025def bind_wide_mm_csvs( mm_wide_csvs, merge=True, separator='-', verbose = 0 ) :
10026    """
10027    will convert a list of t1w hierarchical csv filenames to a merged dataframe
10028
10029    returns a pair of data frames, the left side having all entries and the
10030        right side having row averaged entries i.e. unique values for each visit
10031
10032    set merge to False to return individual dataframes ( for debugging )
10033
10034    return alldata, row_averaged_data
10035    """
10036    mm_wide_csvs.sort()
10037    if not mm_wide_csvs:
10038        print("No files found with specified pattern")
10039        return
10040    # 1. row-bind the t1whier data
10041    # 2. same for each other modality
10042    # 3. merge the modalities by the keys
10043    hierdf = pd.DataFrame()
10044    for y in mm_wide_csvs:
10045        temp=read_mm_csv( y, colprefix='T1Hier_', separator=separator, is_t1=True )
10046        if temp is not None:
10047            hierdf=pd.concat( [hierdf, temp], axis=0, ignore_index=False )
10048    if verbose > 0:
10049        mypro=50
10050    else:
10051        mypro=None
10052    if verbose > 0:
10053        print("thickness")
10054    thkdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T1w', progress=mypro, verbose=verbose==2)
10055    if verbose > 0:
10056        print("flair")
10057    flairdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T2Flair', progress=mypro, verbose=verbose==2)
10058    if verbose > 0:
10059        print("NM")
10060    nmdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'NM2DMT', progress=mypro, verbose=verbose==2)
10061    if verbose > 0:
10062        print("rsf")
10063    rsfdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'rsfMRI*', progress=mypro, verbose=verbose==2)
10064    if verbose > 0:
10065        print("dti")
10066    dtidf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'DTI*', progress=mypro, verbose=verbose==2 )
10067    if not merge:
10068        return hierdf, thkdf, flairdf, nmdf, rsfdf, dtidf
10069    hierdfmix = hierdf.copy()
10070    modality_df_suffixes = [
10071        (thkdf, "_thk"),
10072        (flairdf, "_flair"),
10073        (nmdf, "_nm"),
10074        (rsfdf, "_rsf"),
10075        (dtidf, "_dti"),
10076    ]
10077    for pair in modality_df_suffixes:
10078        hierdfmix = merge_mm_dataframe(hierdfmix, pair[0], pair[1])
10079    hierdfmix = hierdfmix.replace(r'^\s*$', np.nan, regex=True)
10080    return hierdfmix, hierdfmix.groupby("u_hier_id", as_index=False).mean(numeric_only=True)

will convert a list of t1w hierarchical csv filenames to a merged dataframe

returns a pair of data frames, the left side having all entries and the right side having row averaged entries i.e. unique values for each visit

set merge to False to return individual dataframes ( for debugging )

return alldata, row_averaged_data
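
A hedged sketch of typical use: gather the T1wHierarchical wide CSVs written by mm_csv and bind them. The glob pattern reflects an assumed NRG-style output tree; the paths must contain "T1wHierarchical" because the function splits on that token to locate the other modality outputs.

    import glob
    import antspymm

    fns = glob.glob('/data/processed/*/*/*/T1wHierarchical/*/*mmwide.csv')
    alldata, row_averaged = antspymm.bind_wide_mm_csvs(fns, merge=True, verbose=1)
    row_averaged.to_csv('study_row_averaged.csv')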

def merge_mm_dataframe(hierdf, mmdf, mm_suffix):
10082def merge_mm_dataframe(hierdf, mmdf, mm_suffix):
10083    try:
10084        hierdf = hierdf.merge(mmdf, on=['sid', 'visitdate', 't1imageuid'], suffixes=("",mm_suffix),how='left')
10085        return hierdf
10086    except KeyError:
10087        return hierdf
def augment_image(x, max_rot=10, nzsd=1):
10089def augment_image( x,  max_rot=10, nzsd=1 ):
10090    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rot*(-1.0), max_rot ), reference=x )
10091    tx = rRotGenerator.transform()
10092    itx = ants.invert_ants_transform(tx)
10093    y = ants.apply_ants_transform_to_image( tx, x, x, interpolation='linear')
10094    y = ants.add_noise_to_image( y,'additivegaussian', [0,nzsd] )
10095    return y, tx, itx
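
An illustrative sketch (placeholder filename): augment_image returns the rotated, noise-corrupted image together with the forward and inverse transforms, so results computed in the augmented space can be mapped back onto the original image, which is how boot_wmh below uses it.

    import ants
    import antspymm

    t1 = ants.image_read('t1.nii.gz')
    aug, tx, itx = antspymm.augment_image(t1, max_rot=10, nzsd=0.01)
    # itx maps images computed in the augmented space back onto t1's space
    back = ants.apply_ants_transform_to_image(itx, aug, t1, interpolation='linear')
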
def boot_wmh( flair, t1, t1seg, mmfromconvexhull=0.0, strict=True, probability_mask=None, prior_probability=None, n_simulations=16, random_seed=42, verbose=False):
10097def boot_wmh( flair, t1, t1seg, mmfromconvexhull = 0.0, strict=True,
10098        probability_mask=None, prior_probability=None, n_simulations=16,
10099        random_seed = 42,
10100        verbose=False ) :
10101    import random
10102    random.seed( random_seed )
10103    if verbose and prior_probability is None:
10104        print("augmented flair")
10105    if verbose and prior_probability is not None:
10106        print("augmented flair with prior")
10107    wmh_sum_aug = 0
10108    wmh_sum_prior_aug = 0
10109    augprob = flair * 0.0
10110    augprob_prior = None
10111    if prior_probability is not None:
10112        augprob_prior = flair * 0.0
10113    for n in range(n_simulations):
10114        augflair, tx, itx = augment_image( ants.iMath(flair,"Normalize"), 5, 0.01 )
10115        locwmh = wmh( augflair, t1, t1seg, mmfromconvexhull = mmfromconvexhull,
10116            strict=strict, probability_mask=None, prior_probability=prior_probability )
10117        if verbose:
10118            print( "flair sim: " + str(n) + " vol: " + str( locwmh['wmh_mass'] )+ " vol-prior: " + str( locwmh['wmh_mass_prior'] )+ " snr: " + str( locwmh['wmh_SNR'] ) )
10119        wmh_sum_aug = wmh_sum_aug + locwmh['wmh_mass']
10120        wmh_sum_prior_aug = wmh_sum_prior_aug + locwmh['wmh_mass_prior']
10121        temp = locwmh['WMH_probability_map']
10122        augprob = augprob + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10123        if prior_probability is not None:
10124            temp = locwmh['WMH_posterior_probability_map']
10125            augprob_prior = augprob_prior + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10126    augprob = augprob * (1.0/float( n_simulations ))
10127    if prior_probability is not None:
10128        augprob_prior = augprob_prior * (1.0/float( n_simulations ))
10129    wmh_sum_aug = wmh_sum_aug / float( n_simulations )
10130    wmh_sum_prior_aug = wmh_sum_prior_aug / float( n_simulations )
10131    return{
10132      'flair' : ants.iMath(flair,"Normalize"),
10133      'WMH_probability_map' : augprob,
10134      'WMH_posterior_probability_map' : augprob_prior,
10135      'wmh_mass': wmh_sum_aug,
10136      'wmh_mass_prior': wmh_sum_prior_aug,
10137      'wmh_evr': locwmh['wmh_evr'],
10138      'wmh_SNR': locwmh['wmh_SNR']  }
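
A hedged end-to-end sketch: the three inputs are assumed to come from a matched FLAIR acquisition and an earlier antspyt1w.hierarchical run (brain-extracted T1 and tissue segmentation); all filenames are placeholders.

    import ants
    import antspymm

    flair = ants.image_read('flair.nii.gz')
    t1 = ants.image_read('t1_brain_n4_dnz.nii.gz')
    t1seg = ants.image_read('tissue_segmentation.nii.gz')
    bw = antspymm.boot_wmh(flair, t1, t1seg, n_simulations=8, verbose=True)
    ants.image_write(bw['WMH_probability_map'], 'wmh_probability_boot.nii.gz')
    print(bw['wmh_mass'])    # WMH mass averaged over the augmented runs
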
def threaded_bind_wide_mm_csvs(mm_wide_csvs, n_workers):
10141def threaded_bind_wide_mm_csvs( mm_wide_csvs, n_workers ):
10142    from concurrent.futures import as_completed
10143    from concurrent import futures
10144    import concurrent.futures
10145    def chunks(l, n):
10146        """Yield n number of sequential chunks from l."""
10147        d, r = divmod(len(l), n)
10148        for i in range(n):
10149            si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r)
10150            yield l[si:si+(d+1 if i < r else d)]
10151    import numpy as np
10152    newx = list( chunks( mm_wide_csvs, n_workers ) )
10153    import pandas as pd
10154    alldf = pd.DataFrame()
10155    alldfavg = pd.DataFrame()
10156    with futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
10157        to_do = []
10158        for group in range(len(newx)) :
10159            future = executor.submit(bind_wide_mm_csvs, newx[group] )
10160            to_do.append(future)
10161        results = []
10162        for future in futures.as_completed(to_do):
10163            res0, res1 = future.result()
10164            alldf=pd.concat(  [alldf, res0 ], axis=0, ignore_index=False )
10165            alldfavg=pd.concat(  [alldfavg, res1 ], axis=0, ignore_index=False )
10166    return alldf, alldfavg
def get_names_from_data_frame(x, demogIn, exclusions=None):
10169def get_names_from_data_frame(x, demogIn, exclusions=None):
10170    """
10171    data = {'Name':['Tom', 'nick', 'krish', 'jack'], 'Age':[20, 21, 19, 18]}; df = pd.DataFrame(data)
10172    antspymm.get_names_from_data_frame( ['e'], df )
10173    antspymm.get_names_from_data_frame( ['a','e'], df )
10174    antspymm.get_names_from_data_frame( ['e'], df, exclusions='N' )
10175    """
10176    # Check if x is a string and convert it to a list
10177    if isinstance(x, str):
10178        x = [x]
10179    def get_unique( qq ):
10180        unique = []
10181        for number in qq:
10182            if number in unique:
10183                continue
10184            else:
10185                unique.append(number)
10186        return unique
10187    outnames = list(demogIn.columns[demogIn.columns.str.contains(x[0])])
10188    if len(x) > 1:
10189        for y in x[1:]:
10190            outnames = [i for i in outnames if y in i]
10191    outnames = get_unique( outnames )
10192    if exclusions is not None:
10193        toexclude = [name for name in outnames if exclusions[0] in name ]
10194        if len(exclusions) > 1:
10195            for zz in exclusions[1:]:
10196                toexclude.extend([name for name in outnames if zz in name ])
10197        if len(toexclude) > 0:
10198            outnames = [name for name in outnames if name not in toexclude]
10199    return outnames

data = {'Name':['Tom', 'nick', 'krish', 'jack'], 'Age':[20, 21, 19, 18]}
df = pd.DataFrame(data)
antspymm.get_names_from_data_frame( ['e'], df )
antspymm.get_names_from_data_frame( ['a','e'], df )
antspymm.get_names_from_data_frame( ['e'], df, exclusions='N' )

def average_mm_df(jmm_in, diagnostic_n=25, corr_thresh=0.9, verbose=False):
10202def average_mm_df( jmm_in, diagnostic_n=25, corr_thresh=0.9, verbose=False ):
10203    """
10204    jmrowavg, jmmcolavg, diagnostics = antspymm.average_mm_df( jmm_in, verbose=True )
10205    """
10206
10207    jmm = jmm_in.copy()
10208    dxcols=['subjectid1','subjectid2','modalityid','joinid','correlation','distance']
10209    joinDiagnostics = pd.DataFrame( columns = dxcols )
10210    nanList=[math.nan]
10211    def rob(x, y=0.99):
10212        x[x > np.nanquantile(x, y)] = np.nan
10213        return x
10214
10215    jmm = jmm.replace(r'^\s*$', np.nan, regex=True)
10216
10217    if verbose:
10218        print("do rsfMRI")
10219    # here - we first have to average within each row
10220    dt0 = get_names_from_data_frame(["rsfMRI"], jmm, exclusions=["Unnamed", "rsfMRI_LR", "rsfMRI_RL"])
10221    dt1 = get_names_from_data_frame(["rsfMRI_RL"], jmm, exclusions=["Unnamed"])
10222    if len( dt0 ) > 0 and len( dt1 ) > 0:
10223        flid = dt0[0]
10224        wrows = []
10225        for i in range(jmm.shape[0]):
10226            if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) :
10227                wrows.append(i)
10228        for k in wrows:
10229            v1 = jmm.iloc[k][dt0[1:]].astype(float)
10230            v2 = jmm.iloc[k][dt1[1:]].astype(float)
10231            vvec = [v1[0], v2[0]]
10232            if any(~np.isnan(vvec)):
10233                mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10234                jmm.iloc[k][dt0[0]] = 'rsfMRI'
10235                if len(mynna) == 1:
10236                    if mynna[0] == 0:
10237                        jmm.iloc[k][dt0[1:]] = v1
10238                    if mynna[0] == 1:
10239                        jmm.iloc[k][dt0[1:]] = v2
10240                elif len(mynna) > 1:
10241                    if len(v2) > diagnostic_n:
10242                        v1dx=v1[0:diagnostic_n]
10243                        v2dx=v2[0:diagnostic_n]
10244                    else :
10245                        v1dx=v1
10246                        v2dx=v2
10247                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10248                    mycorr = np.corrcoef( v1dx.values, v2dx.values )[0,1]
10249                    myerr=np.sqrt(np.mean((v1dx.values - v2dx.values)**2))
10250                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'rsfMRI','colavg',mycorr,myerr]
10251                    if mycorr > corr_thresh:
10252                        jmm.loc[k, dt0[1:]] = v1.values*0.5 + v2.values*0.5
10253                    else:
10254                        jmm.loc[k, dt0[1:]] = nanList * len(v1)
10255                    if verbose:
10256                        print( joinDiagnosticsLoc )
10257                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10258
10259    if verbose:
10260        print("do DTI")
10261    # here - we first have to average within each row
10262    dt0 = get_names_from_data_frame(["DTI"], jmm, exclusions=["Unnamed", "DTI_LR", "DTI_RL"])
10263    dt1 = get_names_from_data_frame(["DTI_LR"], jmm, exclusions=["Unnamed"])
10264    dt2 = get_names_from_data_frame( ["DTI_RL"], jmm, exclusions=["Unnamed"])
10265    flid = dt0[0]
10266    wrows = []
10267    for i in range(jmm.shape[0]):
10268        if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) or not pd.isna(jmm[dt2[1]][i]):
10269            wrows.append(i)
10270    for k in wrows:
10271        v1 = jmm.loc[k, dt0[1:]].astype(float)
10272        v2 = jmm.loc[k, dt1[1:]].astype(float)
10273        v3 = jmm.loc[k, dt2[1:]].astype(float)
10274        checkcol = dt0[5]
10275        if not np.isnan(v1[checkcol]):
10276            if v1[checkcol] < 0.25:
10277                v1[:] = np.nan
10278        checkcol = dt1[5]
10279        if not np.isnan(v2[checkcol]):
10280            if v2[checkcol] < 0.25:
10281                v2[:] = np.nan
10282        checkcol = dt2[5]
10283        if not np.isnan(v3[checkcol]):
10284            if v3[checkcol] < 0.25:
10285                v3[:] = np.nan
10286        vvec = [v1[0], v2[0], v3[0]]
10287        if any(~np.isnan(vvec)):
10288            mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10289            jmm.loc[k, dt0[0]] = 'DTI'
10290            if len(mynna) == 1:
10291                if mynna[0] == 0:
10292                    jmm.loc[k, dt0[1:]] = v1
10293                if mynna[0] == 1:
10294                    jmm.loc[k, dt0[1:]] = v2
10295                if mynna[0] == 2:
10296                    jmm.loc[k, dt0[1:]] = v3
10297            elif len(mynna) > 1:
10298                if mynna[0] == 0:
10299                    jmm.loc[k, dt0[1:]] = v1
10300                else:
10301                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10302                    mycorr = np.corrcoef( v2[0:diagnostic_n].values, v3[0:diagnostic_n].values )[0,1]
10303                    myerr=np.sqrt(np.mean((v2[0:diagnostic_n].values - v3[0:diagnostic_n].values)**2))
10304                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'DTI','colavg',mycorr,myerr]
10305                    if mycorr > corr_thresh:
10306                        jmm.loc[k, dt0[1:]] = v2.values*0.5 + v3.values*0.5
10307                    else: #
10308                        jmm.loc[k, dt0[1:]] = nanList * len( dt0[1:] )
10309                    if verbose:
10310                        print( joinDiagnosticsLoc )
10311                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10312
10313
10314    # first task - sort by u_hier_id
10315    jmm = jmm.sort_values( "u_hier_id" )
10316    # get rid of junk columns
10317    badnames = get_names_from_data_frame( ['Unnamed'], jmm )
10318    jmm=jmm.drop(badnames, axis=1)
10319    jmm=jmm.set_index("u_hier_id",drop=False)
10320    # 2nd - get rid of duplicated u_hier_id
10321    jmmUniq = jmm.drop_duplicates( subset="u_hier_id" ) # fast and easy
10322    # for each modality, count which ids have more than one
10323    mod_names = get_valid_modalities()
10324    for mod_name in mod_names:
10325        fl_names = get_names_from_data_frame([mod_name], jmm,
10326            exclusions=['Unnamed',"DTI_LR","DTI_RL","rsfMRI_RL","rsfMRI_LR"])
10327        if len( fl_names ) > 1:
10328            if verbose:
10329                print(mod_name)
10330                print(fl_names)
10331            fl_id = fl_names[0]
10332            n_names = len(fl_names)
10333            locvec = jmm[fl_names[n_names-1]].astype(float)
10334            boolvec=~pd.isna(locvec)
10335            jmmsub = jmm[boolvec][ ['u_hier_id']+fl_names]
10336            my_tbl = Counter(jmmsub['u_hier_id'])
10337            gtoavg = [name for name in my_tbl.keys() if my_tbl[name] == 1]
10338            gtoavgG1 = [name for name in my_tbl.keys() if my_tbl[name] > 1]
10339            if verbose:
10340                print("Join 1")
10341            jmmsub1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavg)][['u_hier_id']+fl_names]
10342            for u in gtoavg:
10343                jmmUniq.loc[u][fl_names[1:]] = jmmsub1.loc[u][fl_names[1:]]
10344            if verbose and len(gtoavgG1) > 1:
10345                print("Join >1")
10346            jmmsubG1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavgG1)][['u_hier_id']+fl_names]
10347            for u in gtoavgG1:
10348                temp = jmmsubG1.loc[u][ ['u_hier_id']+fl_names ]
10349                dropnames = get_names_from_data_frame( ['MM.ID'], temp )
10350                tempVec = temp.drop(columns=dropnames)
10351                joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10352                id1=temp[fl_id].iloc[0]
10353                id2=temp[fl_id].iloc[1]
10354                v1=tempVec.iloc[0][1:].astype(float).to_numpy()
10355                v2=tempVec.iloc[1][1:].astype(float).to_numpy()
10356                if len(v2) > diagnostic_n:
10357                    v1=v1[0:diagnostic_n]
10358                    v2=v2[0:diagnostic_n]
10359                mycorr = np.corrcoef( v1, v2 )[0,1]
10360                # mycorr=temparr[np.triu_indices_from(temparr, k=1)].mean()
10361                myerr=np.sqrt(np.mean((v1 - v2)**2))
10362                joinDiagnosticsLoc.iloc[0] = [id1,id2,mod_name,'rowavg',mycorr,myerr]
10363                if verbose:
10364                    print( joinDiagnosticsLoc )
10365                temp = jmmsubG1.loc[u][fl_names[1:]].astype(float)
10366                if mycorr > corr_thresh or len( v1 ) < 10:
10367                    jmmUniq.loc[u][fl_names[1:]] = temp.mean(axis=0)
10368                else:
10369                    jmmUniq.loc[u][fl_names[1:]] = nanList * temp.shape[1]
10370                joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], 
10371                                            axis=0, ignore_index=False )
10372
10373    return jmmUniq, jmm, joinDiagnostics

Example usage: jmrowavg, jmmcolavg, diagnostics = antspymm.average_mm_df( jmm_in, verbose=True )

def quick_viz_mm_nrg( sourcedir, projectid, sid, dtid, extract_brain=True, slice_factor=0.55, post=False, original_sourcedir=None, filename=None, verbose=True):
10377def quick_viz_mm_nrg(
10378    sourcedir, # root folder
10379    projectid, # project name
10380    sid , # subject unique id
10381    dtid, # date
10382    extract_brain=True,
10383    slice_factor = 0.55,
10384    post = False,
10385    original_sourcedir = None,
10386    filename = None, # output path
10387    verbose = True
10388):
10389    """
10390    This function creates visualizations of brain images for a specific subject in a project using ANTsPy.
10391
10392    Args:
10393
10394    sourcedir (str): Root folder for original data (if post=False) or processed data (post=True)
10395    
10396    projectid (str): Project name.
10397    
10398    sid (str): Subject unique id.
10399    
10400    dtid (str): Date.
10401    
10402    extract_brain (bool): If True, the function extracts the brain from the T1w image. Default is True.
10403    
10404    slice_factor (float): The slice to be visualized is determined by multiplying the image size by this factor. Default is 0.55.
10405
10406    post ( bool ) : if True, will visualize example post-processing results.
10407    
10408    original_sourcedir (str): Root folder for original data (used if post=True)
10409    
10410    filename (str): Output path with extension (.png)
10411    
10412    verbose (bool): If True, information will be printed while running the function. Default is True.
10413
10414    Returns:
10415    None
10416
10417    """
10418    iid='*'
10419    import glob as glob
10420    from os.path import exists
10421    import ants
10422    temp = sourcedir.split( "/" )
10423    subjectrootpath = os.path.join(sourcedir, projectid, sid, dtid)
10424    if verbose:
10425        print( 'subjectrootpath' )
10426        print( subjectrootpath )
10427    t1_search_path = os.path.join(subjectrootpath, "T1w", "*", "*nii.gz")
10428    if verbose:
10429        print(f"t1 search path: {t1_search_path}")
10430    t1fn = glob.glob(t1_search_path)
10431    if len( t1fn ) < 1:
10432        raise ValueError('quick_viz_mm_nrg cannot find the T1w @ ' + subjectrootpath )
10433    vizlist=[]
10434    undlist=[]
10435    nrg_modality_list = [ 'T1w', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10436    if post:
10437        nrg_modality_list = [ 'T1wHierarchical', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10438    for nrgNum in [0,1,2,3,4,5]:
10439        underlay = None
10440        overmodX = nrg_modality_list[nrgNum]
10441        if  'T1w' in overmodX :
10442            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10443            if post:
10444                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*brain_n4_dnz.nii.gz")
10445                mod_search_path_ol = os.path.join(subjectrootpath, overmodX, iid, "*thickness_image.nii.gz" )
10446                mod_search_path_ol = re.sub( "T1wHierarchical","T1w",mod_search_path_ol)
10447                myol = glob.glob(mod_search_path_ol)
10448                if len( myol ) > 0:
10449                    temper = find_most_recent_file( myol )[0]
10450                    underlay = ants.image_read(  temper )
10451                    if verbose:
10452                        print("T1w overlay " + temper )
10453                    underlay = underlay * ants.threshold_image( underlay, 0.2, math.inf )
10454            myimgsr = glob.glob(mod_search_path)
10455            if len( myimgsr ) == 0:
10456                if verbose:
10457                    print("No t1 images: " + sid + dtid )
10458                return None
10459            myimgsr=find_most_recent_file( myimgsr )[0]
10460            vimg=ants.image_read( myimgsr )
10461        elif  'T2Flair' in overmodX :
10462            if verbose:
10463                print("search flair")
10464            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10465            if post and original_sourcedir is not None:
10466                if verbose:
10467                    print("post in flair")
10468                mysubdir = os.path.join(original_sourcedir, projectid, sid, dtid)
10469                mod_search_path_under = os.path.join(mysubdir, overmodX, iid, "*T2Flair*.nii.gz")
10470                if verbose:
10471                    print("post in flair mod_search_path_under " + mod_search_path_under)
10472                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*wmh.nii.gz")
10473                if verbose:
10474                    print("post in flair mod_search_path " + mod_search_path )
10475                myimgul = glob.glob(mod_search_path_under)
10476                if len( myimgul ) > 0:
10477                    myimgul = find_most_recent_file( myimgul )[0]
10478                    if verbose:
10479                        print("Flair  " + myimgul )
10480                    vimg = ants.image_read( myimgul )
10481                    myol = glob.glob(mod_search_path)
10482                    if len( myol ) == 0:
10483                        underlay = vimg * 0.0  # no wmh overlay found: use an empty underlay
10484                    else:
10485                        myol = find_most_recent_file( myol )[0]
10486                        if verbose:
10487                            print("Flair overlay " + myol )
10488                        underlay=ants.image_read( myol )
10489                        underlay=underlay*ants.threshold_image(underlay,0.05,math.inf)
10490                else:
10491                    vimg = noizimg.clone()
10492                    underlay = vimg * 0.0
10493            if original_sourcedir is None:
10494                myimgsr = glob.glob(mod_search_path)
10495                if len( myimgsr ) == 0:
10496                    vimg = noizimg.clone()
10497                else:
10498                    myimgsr=find_most_recent_file( myimgsr )[0]
10499                    vimg=ants.image_read( myimgsr )
10500        elif overmodX == 'DTI':
10501            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10502            if post:
10503                mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*fa.nii.gz")
10504            myimgsr = glob.glob(mod_search_path)
10505            if len( myimgsr ) > 0:
10506                myimgsr=find_most_recent_file( myimgsr )[0]
10507                vimg=ants.image_read( myimgsr )
10508            else:
10509                if verbose:
10510                    print("No " + overmodX)
10511                vimg = noizimg.clone()
10512        elif overmodX == 'DTI2':
10513            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10514            myimgsr = glob.glob(mod_search_path)
10515            if len( myimgsr ) > 0:
10516                myimgsr.sort()
10517                myimgsr=myimgsr[len(myimgsr)-1]
10518                vimg=ants.image_read( myimgsr )
10519            else:
10520                if verbose:
10521                    print("No " + overmodX)
10522                vimg = noizimg.clone()
10523        elif overmodX == 'NM2DMT':
10524            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10525            if post:
10526                mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*NM_avg.nii.gz" )
10527            myimgsr = glob.glob(mod_search_path)
10528            if len( myimgsr ) > 0:
10529                myimgsr0=myimgsr[0]
10530                vimg=ants.image_read( myimgsr0 )
10531                for k in range(1,len(myimgsr)):
10532                    temp = ants.image_read( myimgsr[k])
10533                    vimg=vimg+ants.resample_image_to_target(temp,vimg)
10534            else:
10535                if verbose:
10536                    print("No " + overmodX)
10537                vimg = noizimg.clone()
10538        elif overmodX == 'rsfMRI':
10539            mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*nii.gz")
10540            if post:
10541                mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_meanBold.nii.gz" )
10542                mod_search_path_ol = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_DefaultMode.nii.gz" )
10543                myol = glob.glob(mod_search_path_ol)
10544                if len( myol ) > 0:
10545                    myol = find_most_recent_file( myol )[0]
10546                    underlay = ants.image_read( myol )
10547                    if verbose:
10548                        print("BOLD overlay " + myol )
10549                    underlay = underlay * ants.threshold_image( underlay, 0.1, math.inf )
10550            myimgsr = glob.glob(mod_search_path)
10551            if len( myimgsr ) > 0:
10552                myimgsr=find_most_recent_file( myimgsr )[0]
10553                vimg=mm_read_to_3d( myimgsr )
10554            else:
10555                if verbose:
10556                    print("No " + overmodX)
10557                vimg = noizimg.clone()
10558        elif overmodX == 'perf':
10559            mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*nii.gz")
10560            if post:
10561                mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*cbf.nii.gz")
10562            myimgsr = glob.glob(mod_search_path)
10563            if len( myimgsr ) > 0:
10564                myimgsr=find_most_recent_file( myimgsr )[0]
10565                vimg=mm_read_to_3d( myimgsr )
10566            else:
10567                if verbose:
10568                    print("No " + overmodX)
10569                vimg = noizimg.clone()
10570        else :
10571            if verbose:
10572                print("Something else here")
10573            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10574            myimgsr = glob.glob(mod_search_path)
10575            if post:
10576                myimgsr=[]
10577            if len( myimgsr ) > 0:
10578                myimgsr=find_most_recent_file( myimgsr )[0]
10579                vimg=ants.image_read( myimgsr )
10580            else:
10581                if verbose:
10582                    print("No " + overmodX)
10583                vimg = noizimg
10584        if True:
10585            if extract_brain and overmodX == 'T1w' and post == False:
10586                vimg = vimg * antspyt1w.brain_extraction(vimg)
10587            if verbose:
10588                print(f"modality search path: {myimgsr}" + " num: " + str(nrgNum))
10589            if vimg.dimension == 4 and ( overmodX == "DTI2"  ):
10590                ttb0, ttdw=get_average_dwi_b0(vimg)
10591                vimg = ttdw
10592            elif vimg.dimension == 4 and overmodX == "DTI":
10593                ttb0, ttdw=get_average_dwi_b0(vimg)
10594                vimg = ttb0
10595            elif vimg.dimension == 4 :
10596                vimg=ants.get_average_of_timeseries(vimg)
10597            msk=ants.get_mask(vimg)
10598            if overmodX == 'T2Flair':
10599                msk=vimg*0+1
10600            if underlay is not None:
10601                print( overmodX + " has underlay" )
10602            else:
10603                underlay = vimg * 0.0
10604            if nrgNum == 0:
10605                refimg=ants.image_clone( vimg )
10606                noizimg = ants.add_noise_to_image( refimg*0, 'additivegaussian', [100,1] )
10607                vizlist.append( vimg )
10608                undlist.append( underlay )
10609            else:
10610                vimg = ants.iMath( vimg, 'TruncateIntensity',0.01,0.98)
10611                vizlist.append( ants.iMath( vimg, 'Normalize' ) * 255 )
10612                undlist.append( underlay )
10613
10614    # mask & crop systematically ...
10615    msk = ants.get_mask( refimg )
10616    refimg = ants.crop_image( refimg, msk )
10617
10618    for jj in range(len(vizlist)):
10619        vizlist[jj]=ants.resample_image_to_target( vizlist[jj], refimg )
10620        undlist[jj]=ants.resample_image_to_target( undlist[jj], refimg )
10621        print( 'viz: ' + str( jj ) )
10622        print( vizlist[jj] )
10623        print( 'und: ' + str( jj ) )
10624        print( undlist[jj] )
10625
10626
10627    xyz = [None]*3
10628    for i in range(3):
10629        if xyz[i] is None:
10630            xyz[i] = int(refimg.shape[i] * slice_factor )
10631
10632    if verbose:
10633        print('slice positions')
10634        print( xyz )
10635
10636    ants.plot_ortho_stack( vizlist, overlays=undlist, crop=False, reorient=False, filename=filename, xyz=xyz, orient_labels=False )
10637    return
10638    # listlen = len( vizlist )
10639    # vizlist = np.asarray( vizlist )
10640    if show_it is not None:
10641        filenameout=None
10642        if verbose:
10643            print( show_it )
10644        for a in [0,1,2]:
10645            n=int(np.round( refimg.shape[a] * slice_factor ))
10646            slices=np.repeat( int(n), listlen  )
10647            if isinstance(show_it,str):
10648                filenameout=show_it+'_ax'+str(int(a))+'_sl'+str(n)+'.png'
10649                if verbose:
10650                    print( filenameout )
10651#            ants.plot_grid(vizlist.reshape(2,3), slices.reshape(2,3), title='MM Subject ' + sid + ' ' + dtid, rfacecolor='white', axes=a, filename=filenameout )
10652    if verbose:
10653        print("viz complete.")
10654    return vizlist

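A minimal usage sketch, assuming an NRG-organized tree under a hypothetical root /data/nrg with placeholder project, subject and date identifiers; the ortho-stack png is written to the path given in filename:

import antspymm

antspymm.quick_viz_mm_nrg(
    sourcedir='/data/nrg',            # hypothetical NRG root: projectid/sid/dtid/modality/...
    projectid='PPMI',                 # placeholder project name
    sid='12345',                      # placeholder subject id
    dtid='20230101',                  # placeholder acquisition date
    extract_brain=True,
    slice_factor=0.55,
    post=False,                       # visualize raw inputs rather than processed outputs
    filename='/tmp/PPMI-12345-20230101_viz.png' )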

def blind_image_assessment( image, viz_filename=None, title=False, pull_rank=False, resample=None, n_to_skip=10, verbose=False):
10657def blind_image_assessment(
10658    image,
10659    viz_filename=None,
10660    title=False,
10661    pull_rank=False,
10662    resample=None,
10663    n_to_skip = 10,
10664    verbose=False
10665):
10666    """
10667    Quick blind image assessment and triplanar visualization of an image. 4D input will be visualized and assessed in 3D. Produces a png and a csv, where the csv contains:
10668
10669    * reflection error ( estimates asymmetry )
10670
10671    * brisq ( blind quality assessment )
10672
10673    * patch eigenvalue ratio ( blind quality assessment )
10674
10675    * PSNR and SSIM vs a smoothed reference (4D or 3D appropriate)
10676
10677    * mask volume ( estimates foreground object size )
10678
10679    * spacing
10680
10681    * dimension after cropping by mask
10682
10683    image : character or image object usually a nifti image
10684
10685    viz_filename : character for a png output image
10686
10687    title : display a summary title on the png
10688
10689    pull_rank : boolean
10690
10691    resample : None, numeric max or min, resamples image to isotropy
10692
10693    n_to_skip : 10 by default; samples time series every n_to_skip volume
10694
10695    verbose : boolean
10696
10697    """
10698    import glob as glob
10699    from os.path import exists
10700    import ants
10701    import matplotlib.pyplot as plt
10702    from PIL import Image
10703    from pathlib import Path
10704    import json
10705    import re
10706    from dipy.io.gradients import read_bvals_bvecs
10707    mystem=''
10708    if isinstance(image,list):
10709        isfilename=isinstance( image[0], str)
10710        image = image[0]
10711    else:
10712        isfilename=isinstance( image, str)
10713    outdf = pd.DataFrame()
10714    mymeta = None
10715    MagneticFieldStrength = None
10716    image_filename=''
10717    if isfilename:
10718        image_filename = image
10719        if isinstance(image,list):
10720            image_filename=image[0]
10721        json_name = re.sub(".nii.gz",".json",image_filename)
10722        if exists( json_name ):
10723            try:
10724                with open(json_name, 'r') as fcc_file:
10725                    mymeta = json.load(fcc_file)
10726                    if verbose:
10727                        print(json.dumps(mymeta, indent=4))
10728                    fcc_file.close()
10729            except:
10730                pass
10731        mystem=Path( image ).stem
10732        mystem=Path( mystem ).stem
10733        image_reference = ants.image_read( image )
10734        image = ants.image_read( image )
10735    else:
10736        image_reference = ants.image_clone( image )
10737    ntimepoints = 1
10738    bvalueMax=None
10739    bvecnorm=None
10740    if image_reference.dimension == 4:
10741        ntimepoints = image_reference.shape[3]
10742        if "DTI" in image_filename:
10743            myTSseg = segment_timeseries_by_meanvalue( image_reference )
10744            image_b0, image_dwi = get_average_dwi_b0( image_reference, fast=True )
10745            image_b0 = ants.iMath( image_b0, 'Normalize' )
10746            image_dwi = ants.iMath( image_dwi, 'Normalize' )
10747            bval_name = re.sub(".nii.gz",".bval",image_filename)
10748            bvec_name = re.sub(".nii.gz",".bvec",image_filename)
10749            if exists( bval_name ) and exists( bvec_name ):
10750                bvals, bvecs = read_bvals_bvecs( bval_name , bvec_name  )
10751                bvalueMax = bvals.max()
10752                bvecnorm = np.linalg.norm(bvecs,axis=1).reshape( bvecs.shape[0],1 )
10753                bvecnorm = bvecnorm.max()
10754        else:
10755            image_b0 = ants.get_average_of_timeseries( image_reference ).iMath("Normalize")
10756    else:
10757        image_compare = ants.smooth_image( image_reference, 3, sigma_in_physical_coordinates=False )
10758    for jjj in range(0,ntimepoints,n_to_skip):
10759        modality='unknown'
10760        if "rsfMRI" in image_filename:
10761            modality='rsfMRI'
10762        elif "perf" in image_filename:
10763            modality='perf'
10764        elif "DTI" in image_filename:
10765            modality='DTI'
10766        elif "T1w" in image_filename:
10767            modality='T1w'
10768        elif "T2Flair" in image_filename:
10769            modality='T2Flair'
10770        elif "NM2DMT" in image_filename:
10771            modality='NM2DMT'
10772        if image_reference.dimension == 4:
10773            image = ants.slice_image( image_reference, idx=int(jjj), axis=3 )
10774            if "DTI" in image_filename:
10775                if jjj in myTSseg['highermeans']:
10776                    image_compare = ants.image_clone( image_b0 )
10777                    modality='DTIb0'
10778                else:
10779                    image_compare = ants.image_clone( image_dwi )
10780                    modality='DTIdwi'
10781            else:
10782                image_compare = ants.image_clone( image_b0 )
10783        # image = ants.iMath( image, 'TruncateIntensity',0.01,0.995)
10784        minspc = np.min(ants.get_spacing(image))
10785        maxspc = np.max(ants.get_spacing(image))
10786        if resample is not None:
10787            if resample == 'min':
10788                if minspc < 1e-12:
10789                    minspc = np.max(ants.get_spacing(image))
10790                newspc = np.repeat( minspc, 3 )
10791            elif resample == 'max':
10792                newspc = np.repeat( maxspc, 3 )
10793            else:
10794                newspc = np.repeat( resample, 3 )
10795            image = ants.resample_image( image, newspc )
10796            image_compare = ants.resample_image( image_compare, newspc )
10797        else:
10798            # check for spc close to zero
10799            spc = list(ants.get_spacing(image))
10800            for spck in range(len(spc)):
10801                if spc[spck] < 1e-12:
10802                    spc[spck]=1
10803            ants.set_spacing( image, spc )
10804            ants.set_spacing( image_compare, spc )
10805        # if "NM2DMT" in image_filename or "FIXME" in image_filename or "SPECT" in image_filename or "UNKNOWN" in image_filename:
10806        minspc = np.min(ants.get_spacing(image))
10807        maxspc = np.max(ants.get_spacing(image))
10808        msk = ants.threshold_image( ants.iMath(image,'Normalize'), 0.15, 1.0 )
10809        # else:
10810        #    msk = ants.get_mask( image )
10811        msk = ants.morphology(msk, "close", 3 )
10812        bgmsk = msk*0+1-msk
10813        mskdil = ants.iMath(msk, "MD", 4 )
10814        # ants.plot_ortho( image, msk, crop=False )
10815        nvox = int( msk.sum() )
10816        spc = ants.get_spacing( image )
10817        org = ants.get_origin( image )
10818        if ( nvox > 0 ):
10819            image = ants.crop_image( image, mskdil ).iMath("Normalize")
10820            msk = ants.crop_image( msk, mskdil ).iMath("Normalize")
10821            bgmsk = ants.crop_image( bgmsk, mskdil ).iMath("Normalize")
10822            image_compare = ants.crop_image( image_compare, mskdil ).iMath("Normalize")           
10823            npatch = int( np.round(  0.1 * nvox ) )
10824            npatch = np.min(  [512,npatch ] )
10825            patch_shape = []
10826            for k in range( 3 ):
10827                p = int( 32.0 / ants.get_spacing( image  )[k] )
10828                if p > int( np.round( image.shape[k] * 0.5 ) ):
10829                    p = int( np.round( image.shape[k] * 0.5 ) )
10830                patch_shape.append( p )
10831            if verbose:
10832                print(image)
10833                print( patch_shape )
10834                print( npatch )
10835            myevr = math.nan # dont want to fail if something odd happens in patch extraction
10836            try:
10837                myevr = antspyt1w.patch_eigenvalue_ratio( image, npatch, patch_shape,
10838                    evdepth = 0.9, mask=msk )
10839            except:
10840                pass
10841            if pull_rank:
10842                image = ants.rank_intensity(image)
10843            imagereflect = ants.reflect_image(image, axis=0)
10844            asym_err = ( image - imagereflect ).abs().mean()
10845            # estimate noise by center cropping, denoizing and taking magnitude of difference
10846            nocrop=False
10847            if image.dimension == 3:
10848                if image.shape[2] == 1:
10849                    nocrop=True        
10850            if maxspc/minspc > 10:
10851                nocrop=True
10852            if nocrop:
10853                mycc = ants.image_clone( image )
10854            else:
10855                mycc = antspyt1w.special_crop( image,
10856                    ants.get_center_of_mass( msk *0 + 1 ), patch_shape )
10857            myccd = ants.denoise_image( mycc, p=1,r=1,noise_model='Gaussian' )
10858            noizlevel = ( mycc - myccd ).abs().mean()
10859    #        ants.plot_ortho( image, crop=False, filename=viz_filename, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
10860    #        from brisque import BRISQUE
10861    #        obj = BRISQUE(url=False)
10862    #        mybrisq = obj.score( np.array( Image.open( viz_filename )) )
10863            msk_vol = msk.sum() * np.prod( spc )
10864            bgstd = image[ bgmsk == 1 ].std()
10865            fgmean = image[ msk == 1 ].mean()
10866            bgmean = image[ bgmsk == 1 ].mean()
10867            snrref = fgmean / bgstd
10868            cnrref = ( fgmean - bgmean ) / bgstd
10869            psnrref = antspynet.psnr(  image_compare, image  )
10870            ssimref = antspynet.ssim(  image_compare, image  )
10871            if nocrop:
10872                mymi = math.inf
10873            else:
10874                mymi = ants.image_mutual_information( image_compare, image )
10875        else:
10876            msk_vol = 0
10877            myevr = mymi = ssimref = psnrref = cnrref = asym_err = noizlevel = snrref = math.nan
10878            
10879        mriseries=None
10880        mrimfg=None
10881        mrimodel=None
10882        mriSAR=None
10883        BandwidthPerPixelPhaseEncode=None
10884        PixelBandwidth=None
10885        if mymeta is not None:
10886            # mriseries=mymeta['']
10887            try:
10888                mrimfg=mymeta['Manufacturer']
10889            except:
10890                pass
10891            try:
10892                mrimodel=mymeta['ManufacturersModelName']
10893            except:
10894                pass
10895            try:
10896                MagneticFieldStrength=mymeta['MagneticFieldStrength']
10897            except:
10898                pass
10899            try:
10900                PixelBandwidth=mymeta['PixelBandwidth']
10901            except:
10902                pass
10903            try:
10904                BandwidthPerPixelPhaseEncode=mymeta['BandwidthPerPixelPhaseEncode']
10905            except:
10906                pass
10907            try:
10908                mriSAR=mymeta['SAR']
10909            except:
10910                pass
10911        ttl=mystem + ' '
10912        ttl=''
10913        ttl=ttl + "NZ: " + "{:0.4f}".format(noizlevel) + " SNR: " + "{:0.4f}".format(snrref) + " CNR: " + "{:0.4f}".format(cnrref) + " PS: " + "{:0.4f}".format(psnrref)+ " SS: " + "{:0.4f}".format(ssimref) + " EVR: " + "{:0.4f}".format(myevr)+ " MI: " + "{:0.4f}".format(mymi)
10914        if viz_filename is not None and ( jjj == 0 or (jjj % 30 == 0) ) and image.shape[2] < 685:
10915            viz_filename_use = re.sub( ".png", "_slice"+str(jjj).zfill(4)+".png", viz_filename )
10916            ants.plot_ortho( image, crop=False, filename=viz_filename_use, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0,  title=ttl, titlefontsize=12, title_dy=-0.02,textfontcolor='red' )
10917        df = pd.DataFrame([[ 
10918            mystem, 
10919            image_reference.dimension, 
10920            noizlevel, snrref, cnrref, psnrref, ssimref, mymi, asym_err, myevr, msk_vol, 
10921            spc[0], spc[1], spc[2],org[0], org[1], org[2], 
10922            image.shape[0], image.shape[1], image.shape[2], ntimepoints, 
10923            jjj, modality, mriseries, mrimfg, mrimodel, MagneticFieldStrength, mriSAR, PixelBandwidth, BandwidthPerPixelPhaseEncode, bvalueMax, bvecnorm ]], 
10924            columns=[
10925                'filename', 
10926                'dimensionality',
10927                'noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol', 'spc0','spc1','spc2','org0','org1','org2','dimx','dimy','dimz','dimt','slice','modality', 'mriseries', 'mrimfg', 'mrimodel', 'mriMagneticFieldStrength', 'mriSAR', 'mriPixelBandwidth', 'mriPixelBandwidthPE', 'dti_bvalueMax', 'dti_bvecnorm' ])
10928        outdf = pd.concat( [outdf, df ], axis=0, ignore_index=False )
10929        if verbose:
10930            print( outdf )
10931    if viz_filename is not None:
10932        csvfn = re.sub( "png", "csv", viz_filename )
10933        outdf.to_csv( csvfn )
10934    return outdf

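A hedged usage sketch on a single, hypothetical NRG-named T1w image; a png and a matching csv are written alongside viz_filename:

import antspymm

qcdf = antspymm.blind_image_assessment(
    '/data/nrg/PPMI/12345/20230101/T1w/000/PPMI-12345-20230101-T1w-000.nii.gz',  # hypothetical path
    viz_filename='/tmp/t1_qc.png',
    title=True,
    resample='min',     # resample to isotropic voxels at the smallest spacing
    verbose=False )
print( qcdf[['filename','snr','cnr','EVR','msk_vol']] )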

def average_blind_qc_by_modality(qc_full, verbose=False):
10960def average_blind_qc_by_modality(qc_full,verbose=False):
10961    """
10962    Averages time series QC results to yield one entry per image. This also filters to "known" columns.
10963
10964    Args:
10965    qc_full: pandas dataframe containing the full qc data.
10966
10967    Returns:
10968    pandas dataframe containing the processed qc data.
10969    """
10970    qc_full = remove_unwanted_columns( qc_full )
10971    # Get unique modalities
10972    modalities = qc_full['modality'].unique()
10973    modalities = modalities[modalities != 'unknown']
10974    # Get unique ids
10975    uid = qc_full['filename']
10976    to_average = uid.unique()
10977    meta = pd.DataFrame(columns=qc_full.columns )
10978    # Process each unique id
10979    n = len(to_average)
10980    for k in range(n):
10981        if verbose:
10982            if k % 100 == 0:
10983                progger = str( np.round( k / n * 100 ) )
10984                print( progger, end ="...", flush=True)
10985        m1sel = uid == to_average[k]
10986        if sum(m1sel) > 1:
10987            # If more than one entry for id, take the average of continuous columns,
10988            # maximum of the slice column, and the first entry of the other columns
10989            mfsub = process_dataframe_generalized(qc_full[m1sel],'filename')
10990        else:
10991            mfsub = qc_full[m1sel]
10992        meta.loc[k] = mfsub.iloc[0]
10993    meta['modality'] = meta['modality'].replace(['DTIdwi', 'DTIb0'], 'DTI', regex=True)
10994    return meta

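A brief sketch, assuming qc_full is the concatenated per-volume output of blind_image_assessment (for example, the csv files it writes, read back and stacked with pandas); the folder path is hypothetical:

import glob
import pandas as pd
import antspymm

qc_csvs = glob.glob( '/tmp/qc/*.csv' )   # hypothetical folder of blind_image_assessment csv outputs
qc_full = pd.concat( [pd.read_csv(f) for f in qc_csvs], ignore_index=True )
qc_avg = antspymm.average_blind_qc_by_modality( qc_full, verbose=True )
print( qc_avg['modality'].value_counts() )   # DTIdwi/DTIb0 are collapsed to DTI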

def best_mmm(mmdf, wmod, mysep='-', outlier_column='ol_loop', verbose=False):
1599def best_mmm( mmdf, wmod, mysep='-', outlier_column='ol_loop', verbose=False):
1600    """
1601    Selects the best repeats per modality.
1602
1603    Args:
1604    wmod (str): the modality of the image ( 'T1w', 'T2Flair', 'NM2DMT' 'rsfMRI', 'DTI')
1605
1606    mysep (str, optional): the separator used in the image file names. Defaults to '-'.
1607
1608    outlier_column (str, optional): column name for the outlier score. Defaults to 'ol_loop'.
1609
1610    verbose (bool, optional): default False
1611
1612    Returns:
1613
1614    dict: a dictionary with two metadata dataframes - 'raw' contains all the metadata for the selected modality and 'filt' contains the metadata filtered for the highest quality repeats.
1615
1616    """
1617#    mmdf = mmdf.astype(str)
1618    mmdf[outlier_column]=mmdf[outlier_column].astype(float)
1619    msel = mmdf['modality'] == wmod
1620    if wmod == 'rsfMRI':
1621        msel1 = mmdf['modality'] == 'rsfMRI'
1622        msel2 = mmdf['modality'] == 'rsfMRI_LR'
1623        msel3 = mmdf['modality'] == 'rsfMRI_RL'
1624        msel = msel1 | msel2
1625        msel = msel | msel3
1626    if wmod == 'DTI':
1627        msel1 = mmdf['modality'] == 'DTI'
1628        msel2 = mmdf['modality'] == 'DTI_LR'
1629        msel3 = mmdf['modality'] == 'DTI_RL'
1630        msel4 = mmdf['modality'] == 'DTIdwi'
1631        msel5 = mmdf['modality'] == 'DTIb0'
1632        msel = msel1 | msel2 | msel3 | msel4 | msel5
1633    if sum(msel) == 0:
1634        return {'raw': None, 'filt': None}
1635    metasub = mmdf[msel].copy()
1636
1637    if verbose:
1638        print(f"{wmod} {(metasub.shape[0])} pre")
1639
1640    metasub['subjectID']=None
1641    metasub['date']=None
1642    metasub['subjectIDdate']=None
1643    metasub['imageID']=None
1644    metasub['negol']=math.nan
1645    for k in metasub.index:
1646        temp = metasub.loc[k, 'filename'].split( mysep )
1647        metasub.loc[k,'subjectID'] = str( temp[1] )
1648        metasub.loc[k,'date'] = str( temp[2] )
1649        metasub.loc[k,'subjectIDdate'] = str( temp[1] + mysep + temp[2] )
1650        metasub.loc[k,'imageID'] = str( temp[4])
1651
1652
1653    if 'ol_' in outlier_column:
1654        metasub['negol'] = metasub[outlier_column].max() - metasub[outlier_column]
1655    else:
1656        metasub['negol'] = metasub[outlier_column]
1657    if 'date' not in metasub.keys():
1658        metasub['date']=None
1659    metasubq = add_repeat_column( metasub, 'subjectIDdate' )
1660    metasubq = highest_quality_repeat(metasubq, 'filename', 'date', 'negol')
1661
1662    if verbose:
1663        print(f"{wmod} {metasubq.shape[0]} post")
1664
1665#    metasub = metasub.astype(str)
1666#    metasubq = metasubq.astype(str)
1667    metasub[outlier_column]=metasub[outlier_column].astype(float)
1668    metasubq[outlier_column]=metasubq[outlier_column].astype(float)
1669    return {'raw': metasub, 'filt': metasubq}

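A short sketch, assuming an averaged QC table with NRG-style filenames and an outlier score column such as 'ol_loop'; the csv path is hypothetical:

import pandas as pd
import antspymm

mmdf = pd.read_csv( '/tmp/qc_averaged_with_outliers.csv' )   # hypothetical averaged QC + outlier scores
best = antspymm.best_mmm( mmdf, 'T1w', mysep='-', outlier_column='ol_loop' )
raw_t1  = best['raw']    # all rows matching the requested modality
filt_t1 = best['filt']   # one highest-quality repeat per subject/date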

def nrg_2_bids(nrg_filename):
 990def nrg_2_bids( nrg_filename ):
 991    """
 992    Convert an NRG filename to BIDS path/filename.
 993
 994    Parameters:
 995    nrg_filename (str): The NRG filename to convert.
 996
 997    Returns:
 998    str: The BIDS path/filename.
 999    """
1000
1001    # Split the NRG filename into its components
1002    nrg_dirname, nrg_basename = os.path.split(nrg_filename)
1003    nrg_suffix = '.' + nrg_basename.split('.',1)[-1]
1004    nrg_basename = nrg_basename.replace(nrg_suffix, '') # remove ext
1005    nrg_parts = nrg_basename.split('-')
1006    nrg_subject_id = nrg_parts[1]
1007    nrg_modality = nrg_parts[3]
1008    nrg_repeat= nrg_parts[4]
1009
1010    # Build the BIDS path/filename
1011    bids_dirname = os.path.join(nrg_dirname, 'bids')
1012    bids_subject = f'sub-{nrg_subject_id}'
1013    bids_session = f'ses-{nrg_repeat}'
1014
1015    valid_modalities = get_valid_modalities()
1016    if nrg_modality is not None:
1017        if not nrg_modality in valid_modalities:
1018            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
1019
1020    if nrg_modality == 'T1w' :
1021        bids_modality_folder = 'anat'
1022        bids_modality_filename = 'T1w'
1023
1024    if nrg_modality == 'T2Flair' :
1025        bids_modality_folder = 'anat'
1026        bids_modality_filename = 'flair'
1027
1028    if nrg_modality == 'NM2DMT' :
1029        bids_modality_folder = 'anat'
1030        bids_modality_filename = 'nm2dmt'
1031
1032    if nrg_modality == 'DTI' or nrg_modality == 'DTI_RL' or nrg_modality == 'DTI_LR' :
1033        bids_modality_folder = 'dwi'
1034        bids_modality_filename = 'dwi'
1035
1036    if nrg_modality == 'rsfMRI' or nrg_modality == 'rsfMRI_RL' or nrg_modality == 'rsfMRI_LR' :
1037        bids_modality_folder = 'func'
1038        bids_modality_filename = 'func'
1039
1040    if nrg_modality == 'perf'  :
1041        bids_modality_folder = 'perf'
1042        bids_modality_filename = 'perf'
1043
1044    bids_suffix = nrg_suffix[1:]
1045    bids_filename = f'{bids_subject}_{bids_session}_{bids_modality_filename}.{bids_suffix}'
1046
1047    # Return bids filepath/filename
1048    return os.path.join(bids_dirname, bids_subject, bids_session, bids_modality_folder, bids_filename)

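A quick sketch with a hypothetical NRG filename (project-subjectID-date-modality-imageID.nii.gz):

import antspymm

nrg_fn = 'PPMI-12345-20230101-T1w-000.nii.gz'   # placeholder NRG-style name
print( antspymm.nrg_2_bids( nrg_fn ) )
# expected to resemble: bids/sub-12345/ses-000/anat/sub-12345_ses-000_T1w.nii.gz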

def bids_2_nrg(bids_filename, project_name, date, nrg_modality=None):
1051def bids_2_nrg( bids_filename, project_name, date, nrg_modality=None ):
1052    """
1053    Convert a BIDS filename to NRG path/filename.
1054
1055    Parameters:
1056    bids_filename (str): The BIDS filename to convert
1057    project_name (str) : Name of project (i.e. PPMI)
1058    date (str) : Date of image acquisition
1059
1060
1061    Returns:
1062    str: The NRG path/filename.
1063    """
1064
1065    bids_dirname, bids_basename = os.path.split(bids_filename)
1066    bids_suffix = '.'+ bids_basename.split('.',1)[-1]
1067    bids_basename = bids_basename.replace(bids_suffix, '') # remove ext
1068    bids_parts = bids_basename.split('_')
1069    nrg_subject_id = bids_parts[0].replace('sub-','')
1070    nrg_image_id = bids_parts[1].replace('ses-', '')
1071    bids_modality = bids_parts[2]
1072    valid_modalities = get_valid_modalities()
1073    if nrg_modality is not None:
1074        if not nrg_modality in valid_modalities:
1075            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality: " + get_valid_modalities(asString=True))
1076
1077    if bids_modality == 'anat' and nrg_modality is None :
1078        nrg_modality = 'T1w'
1079
1080    if bids_modality == 'dwi' and nrg_modality is None  :
1081        nrg_modality = 'DTI'
1082
1083    if bids_modality == 'func' and nrg_modality is None  :
1084        nrg_modality = 'rsfMRI'
1085
1086    if bids_modality == 'perf' and nrg_modality is None  :
1087        nrg_modality = 'perf'
1088
1089    nrg_suffix = bids_suffix[1:]
1090    nrg_filename = f'{project_name}-{nrg_subject_id}-{date}-{nrg_modality}-{nrg_image_id}.{nrg_suffix}'
1091
1092    return os.path.join(project_name, nrg_subject_id, date, nrg_modality, nrg_image_id,nrg_filename)

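A quick sketch converting back; the modality is passed explicitly here, since the BIDS basename alone does not always determine the NRG modality:

import antspymm

bids_fn = 'sub-12345/ses-000/anat/sub-12345_ses-000_T1w.nii.gz'   # placeholder BIDS path
print( antspymm.bids_2_nrg( bids_fn, 'PPMI', '20230101', nrg_modality='T1w' ) )
# expected to resemble: PPMI/12345/20230101/T1w/000/PPMI-12345-20230101-T1w-000.nii.gz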

def parse_nrg_filename(x, separator='-'):
973def parse_nrg_filename( x, separator='-' ):
974    """
975    split a NRG filename into its named parts
976    """
977    temp = x.split( separator )
978    if len(temp) != 5:
979        raise ValueError(x + " not a valid NRG filename")
980    return {
981        'project':temp[0],
982        'subjectID':temp[1],
983        'date':temp[2],
984        'modality':temp[3],
985        'imageID':temp[4]
986    }

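A one-line sketch with a hypothetical NRG basename:

import antspymm

parts = antspymm.parse_nrg_filename( 'PPMI-12345-20230101-T1w-000' )
print( parts['subjectID'], parts['modality'] )   # 12345 T1w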

def novelty_detection_svm(df_train, df_test, nu=0.05, kernel='rbf'):
11610def novelty_detection_svm(df_train, df_test, nu=0.05, kernel='rbf'):
11611    """
11612    This function performs novelty detection using One-Class SVM.
11613
11614    Parameters:
11615
11616    - df_train (pandas dataframe): training data used to fit the model
11617
11618    - df_test (pandas dataframe): test data used to predict novelties
11619
11620    - nu (float): parameter controlling the fraction of training errors and the fraction of support vectors (default: 0.05)
11621
11622    - kernel (str): kernel type used in the SVM algorithm (default: 'rbf')
11623
11624    Returns:
11625
11626    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11627    """
11628    from sklearn.svm import OneClassSVM
11629    # Fit the model on the training data
11630    df_train[ df_train == math.inf ] = 0
11631    df_test[ df_test == math.inf ] = 0
11632    clf = OneClassSVM(nu=nu, kernel=kernel)
11633    from sklearn.preprocessing import StandardScaler
11634    scaler = StandardScaler()
11635    scaler.fit(df_train)
11636    clf.fit(scaler.transform(df_train))
11637    predictions = clf.predict(scaler.transform(df_test))
11638    predictions[predictions==1]=0
11639    predictions[predictions==-1]=1
11640    if str(type(df_train))=="<class 'pandas.core.frame.DataFrame'>":
11641        return pd.Series(predictions, index=df_test.index)
11642    else:
11643        return pd.Series(predictions)

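A toy sketch with synthetic, hypothetical feature tables standing in for QC dataframes:

import numpy as np
import pandas as pd
import antspymm

rng = np.random.default_rng(0)
df_train = pd.DataFrame( rng.normal(size=(200, 4)), columns=list('abcd') )
df_test  = pd.DataFrame( rng.normal(size=(10, 4)),  columns=list('abcd') )
df_test.iloc[0] = 25.0   # an obviously extreme row
flags = antspymm.novelty_detection_svm( df_train, df_test, nu=0.05 )
print( flags )   # 1 marks a suspected novelty, 0 an inlier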

def novelty_detection_ee(df_train, df_test, contamination=0.05):
11574def novelty_detection_ee(df_train, df_test, contamination=0.05):
11575    """
11576    This function performs novelty detection using Elliptic Envelope.
11577
11578    Parameters:
11579
11580    - df_train (pandas dataframe): training data used to fit the model
11581
11582    - df_test (pandas dataframe): test data used to predict novelties
11583
11584    - contamination (float): parameter controlling the proportion of outliers in the data (default: 0.05)
11585
11586    Returns:
11587
11588    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11589    """
11590    import pandas as pd
11591    from sklearn.covariance import EllipticEnvelope
11592    # Fit the model on the training data
11593    clf = EllipticEnvelope(contamination=contamination,support_fraction=1)
11594    df_train[ df_train == math.inf ] = 0
11595    df_test[ df_test == math.inf ] = 0
11596    from sklearn.preprocessing import StandardScaler
11597    scaler = StandardScaler()
11598    scaler.fit(df_train)
11599    clf.fit(scaler.transform(df_train))
11600    predictions = clf.predict(scaler.transform(df_test))
11601    predictions[predictions==1]=0
11602    predictions[predictions==-1]=1
11603    if str(type(df_train))=="<class 'pandas.core.frame.DataFrame'>":
11604        return pd.Series(predictions, index=df_test.index)
11605    else:
11606        return pd.Series(predictions)

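The same pattern with the Elliptic Envelope variant, again on hypothetical synthetic data:

import numpy as np
import pandas as pd
import antspymm

rng = np.random.default_rng(1)
df_train = pd.DataFrame( rng.normal(size=(200, 3)), columns=['x','y','z'] )
df_test  = pd.DataFrame( rng.normal(loc=5.0, size=(4, 3)), columns=['x','y','z'] )
print( antspymm.novelty_detection_ee( df_train, df_test, contamination=0.05 ) )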

def novelty_detection_lof(df_train, df_test, n_neighbors=20):
11647def novelty_detection_lof(df_train, df_test, n_neighbors=20):
11648    """
11649    This function performs novelty detection using Local Outlier Factor (LOF).
11650
11651    Parameters:
11652
11653    - df_train (pandas dataframe): training data used to fit the model
11654
11655    - df_test (pandas dataframe): test data used to predict novelties
11656
11657    - n_neighbors (int): number of neighbors used to compute the LOF (default: 20)
11658
11659    Returns:
11660
11661    - predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11662
11663    """
11664    from sklearn.neighbors import LocalOutlierFactor
11665    # Fit the model on the training data
11666    df_train[ df_train == math.inf ] = 0
11667    df_test[ df_test == math.inf ] = 0
11668    clf = LocalOutlierFactor(n_neighbors=n_neighbors, algorithm='auto',contamination='auto', novelty=True)
11669    from sklearn.preprocessing import StandardScaler
11670    scaler = StandardScaler()
11671    scaler.fit(df_train)
11672    clf.fit(scaler.transform(df_train))
11673    predictions = clf.predict(scaler.transform(df_test))
11674    predictions[predictions==1]=0
11675    predictions[predictions==-1]=1
11676    if str(type(df_train))=="<class 'pandas.core.frame.DataFrame'>":
11677        return pd.Series(predictions, index=df_test.index)
11678    else:
11679        return pd.Series(predictions)

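A hypothetical example of the LOF variant on synthetic data:

import numpy as np
import pandas as pd
import antspymm

rng = np.random.default_rng(2)
df_train = pd.DataFrame( rng.normal(size=(300, 3)), columns=['x','y','z'] )
df_test  = pd.DataFrame( rng.normal(loc=6.0, size=(4, 3)), columns=['x','y','z'] )
print( antspymm.novelty_detection_lof( df_train, df_test, n_neighbors=20 ) )
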
def novelty_detection_loop(df_train, df_test, n_neighbors=20, distance_metric='minkowski'):
11682def novelty_detection_loop(df_train, df_test, n_neighbors=20, distance_metric='minkowski'):
11683    """
11684    This function performs novelty detection using Local Outlier Probabilities (LoOP).
11685
11686    Parameters:
11687
11688    - df_train (pandas dataframe): training data used to fit the model
11689
11690    - df_test (pandas dataframe): test data used to predict novelties
11691
11692    - n_neighbors (int): number of neighbors used to compute the LOOP (default: 20)
11693
11694    - distance_metric : default minkowski
11695
11696    Returns:
11697
11698    - predictions (numpy array): local outlier probabilities for the test data, in [0, 1]; higher values indicate more likely novelties
11699
11700    """
11701    from PyNomaly import loop
11702    from sklearn.neighbors import NearestNeighbors
11703    from sklearn.preprocessing import StandardScaler
11704    scaler = StandardScaler()
11705    scaler.fit(df_train)
11706    data = np.vstack( [scaler.transform(df_test),scaler.transform(df_train)])
11707    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric=distance_metric)
11708    neigh.fit(data)
11709    d, idx = neigh.kneighbors(data, return_distance=True)
11710    m = loop.LocalOutlierProbability(distance_matrix=d, neighbor_matrix=idx, n_neighbors=n_neighbors).fit()
11711    return m.local_outlier_probabilities[range(df_test.shape[0])]

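A hypothetical example of the LoOP variant on synthetic data; this path additionally requires the PyNomaly package:

import numpy as np
import pandas as pd
import antspymm

rng = np.random.default_rng(3)
df_train = pd.DataFrame( rng.normal(size=(300, 3)), columns=['x','y','z'] )
df_test  = pd.DataFrame( rng.normal(loc=5.0, size=(4, 3)), columns=['x','y','z'] )
probs = antspymm.novelty_detection_loop( df_train, df_test, n_neighbors=20 )
print( probs )   # outlier probabilities in [0, 1], one per test row
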
def novelty_detection_quantile(df_train, df_test):
11715def novelty_detection_quantile(df_train, df_test):
11716    """
11717    This function performs novelty detection using quantiles for each column.
11718
11719    Parameters:
11720
11721    - df_train (pandas dataframe): training data used to fit the model
11722
11723    - df_test (pandas dataframe): test data used to predict novelties
11724
11725    Returns:
11726
11727    - quantiles for the test sample at each column where values range in [0,1]
11728        and higher values mean the column is closer to the edge of the distribution
11729
11730    """
11731    myqs = df_test.copy()
11732    n = df_train.shape[0]
11733    df_trainkeys = df_train.keys()
11734    for k in range( df_train.shape[1] ):
11735        mykey = df_trainkeys[k]
11736        temp = (myqs[mykey][0] >  df_train[mykey]).sum() / n
11737        myqs[mykey] = abs( temp - 0.5 ) / 0.5
11738    return myqs

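A hypothetical example of the quantile-based variant with a single test row:

import numpy as np
import pandas as pd
import antspymm

rng = np.random.default_rng(4)
df_train = pd.DataFrame( rng.normal(size=(500, 2)), columns=['snr','cnr'] )
df_test  = pd.DataFrame( {'snr': [3.5], 'cnr': [0.0]} )   # one hypothetical test image
print( antspymm.novelty_detection_quantile( df_train, df_test ) )
# values near 1 flag columns near the edge of the training distribution
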
def generate_mm_dataframe( projectID, subjectID, date, imageUniqueID, modality, source_image_directory, output_image_directory, t1_filename, flair_filename=[], rsf_filenames=[], dti_filenames=[], nm_filenames=[], perf_filename=[], pet3d_filename=[]):
649def generate_mm_dataframe(
650        projectID,
651        subjectID,
652        date,
653        imageUniqueID,
654        modality,
655        source_image_directory,
656        output_image_directory,
657        t1_filename,
658        flair_filename=[],
659        rsf_filenames=[],
660        dti_filenames=[],
661        nm_filenames=[],
662        perf_filename=[],
663        pet3d_filename=[],
664):
665    """
666    Generate a DataFrame for medical imaging data with extensive validation of input parameters.
667
668    This function creates a DataFrame containing information about medical imaging files,
669    ensuring that filenames match expected patterns for their modalities and that all
670    required images exist. It also validates the number of filenames provided for specific
671    modalities like rsfMRI, DTI, and NM.
672
673    Parameters:
674    - projectID (str): Project identifier.
675    - subjectID (str): Subject identifier.
676    - date (str): Date of the imaging study.
677    - imageUniqueID (str): Unique image identifier.
678    - modality (str): Modality of the imaging study.
679    - source_image_directory (str): Directory of the source images.
680    - output_image_directory (str): Directory for output images.
681    - t1_filename (str): Filename of the T1-weighted image.
682    - flair_filename (list): List of filenames for FLAIR images.
683    - rsf_filenames (list): List of filenames for rsfMRI images.
684    - dti_filenames (list): List of filenames for DTI images.
685    - nm_filenames (list): List of filenames for NM images.
686    - perf_filename (list): List of filenames for perfusion images.
687    - pet3d_filename (list): List of filenames for pet3d images.
688
689    Returns:
690    - pandas.DataFrame: A DataFrame containing the validated imaging study information.
691
692    Raises:
693    - ValueError: If any validation checks fail or if the number of columns does not match the data.
694    """
695    def check_pd_construction(data, columns):
696        # Check if the length of columns matches the length of data in each row
697        if all(len(row) == len(columns) for row in data):
698            return True
699        else:
700            return False
701    from os.path import exists
702    valid_modalities = get_valid_modalities()
703    if not isinstance(t1_filename, str):
704        raise ValueError("t1_filename is not a string")
705    if not exists(t1_filename):
706        raise ValueError("t1_filename does not exist")
707    if modality not in valid_modalities:
708        raise ValueError('modality ' + str(modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
709    # if not exists( output_image_directory ):
710    #    raise ValueError("output_image_directory does not exist")
711    if not exists( source_image_directory ):
712        raise ValueError("source_image_directory does not exist")
713    if len( rsf_filenames ) > 2:
714        raise ValueError("len( rsf_filenames ) > 2")
715    if len( dti_filenames ) > 3:
716        raise ValueError("len( dti_filenames ) > 3")
717    if len( nm_filenames ) > 11:
718        raise ValueError("len( nm_filenames ) > 11")
719    if len( rsf_filenames ) < 2:
720        for k in range(len(rsf_filenames),2):
721            rsf_filenames.append(None)
722    if len( dti_filenames ) < 3:
723        for k in range(len(dti_filenames),3):
724            dti_filenames.append(None)
725    if len( nm_filenames ) < 10:
726        for k in range(len(nm_filenames),10):
727            nm_filenames.append(None)
728    # check modality names
729    if not "T1w" in t1_filename:
730        raise ValueError("T1w is not in t1 filename " + t1_filename)
731    if flair_filename is not None:
732        if isinstance(flair_filename,list):
733            if (len(flair_filename) == 0):
734                flair_filename=None
735            else:
736                print("Take first entry from flair_filename list")
737                flair_filename=flair_filename[0]
738    if flair_filename is not None and not "lair" in flair_filename:
739            raise ValueError("flair is not flair filename " + flair_filename)
740    ## perfusion
741    if perf_filename is not None:
742        if isinstance(perf_filename,list):
743            if (len(perf_filename) == 0):
744                perf_filename=None
745            else:
746                print("Take first entry from perf_filename list")
747                perf_filename=perf_filename[0]
748    if perf_filename is not None and not "perf" in perf_filename:
749            raise ValueError("perf_filename is not perf filename " + perf_filename)
750
751    if pet3d_filename is not None:
752        if isinstance(pet3d_filename,list):
753            if (len(pet3d_filename) == 0):
754                pet3d_filename=None
755            else:
756                print("Take first entry from pet3d_filename list")
757                pet3d_filename=pet3d_filename[0]
758    if pet3d_filename is not None and not "pet" in pet3d_filename:
759            raise ValueError("pet is not in pet3d filename " + pet3d_filename)
760    
761    for k in nm_filenames:
762        if k is not None:
763            if not "NM" in k:
764                raise ValueError("NM is not in NM filename " + k)
765    for k in dti_filenames:
766        if k is not None:
767            if not "DTI" in k and not "dwi" in k:
768                raise ValueError("DTI/dwi is not in DTI filename " + k)
769    for k in rsf_filenames:
770        if k is not None:
771            if not "fMRI" in k and not "func" in k:
772                raise ValueError("fMRI/func is not in rsfMRI filename " + k)
773    if perf_filename is not None:
774        if not "perf" in perf_filename:
775            raise ValueError("perf_filename is not a valid perfusion (perf) filename " + perf_filename)
776    allfns = [t1_filename] + [flair_filename] + nm_filenames + dti_filenames + rsf_filenames + [perf_filename] + [pet3d_filename]
777    for k in allfns:
778        if k is not None:
779            if not isinstance(k, str):
780                raise ValueError(str(k) + " is not a string")
781            if not exists( k ):
782                raise ValueError( "image " + k + " does not exist")
783    coredata = [
784        projectID,
785        subjectID,
786        date,
787        imageUniqueID,
788        modality,
789        source_image_directory,
790        output_image_directory,
791        t1_filename,
792        flair_filename, 
793        perf_filename,
794        pet3d_filename]
795    mydata0 = coredata +  rsf_filenames + dti_filenames
796    mydata = mydata0 + nm_filenames
797    corecols = [
798        'projectID',
799        'subjectID',
800        'date',
801        'imageID',
802        'modality',
803        'sourcedir',
804        'outputdir',
805        'filename',
806        'flairid',
807        'perfid',
808        'pet3did']
809    mycols0 = corecols + [
810        'rsfid1', 'rsfid2',
811        'dtid1', 'dtid2','dtid3']
812    nmext = [
813        'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5',
814        'nmid6', 'nmid7','nmid8', 'nmid9', 'nmid10' #, 'nmid11'
815    ]
816    mycols = mycols0 + nmext
817    if not check_pd_construction( [mydata], mycols ) :
818#        print( mydata )
819#        print( len(mydata ))
820#        print( mycols )
821#        print( len(mycols ))
822        raise ValueError( "Error in generate_mm_dataframe: len( mycols ) != len( mydata ) which indicates a bad input parameter to this function." )
823    studycsv = pd.DataFrame([ mydata ], columns=mycols)
824    return studycsv

Generate a DataFrame for medical imaging data with extensive validation of input parameters.

This function creates a DataFrame containing information about medical imaging files, ensuring that filenames match expected patterns for their modalities and that all required images exist. It also validates the number of filenames provided for specific modalities like rsfMRI, DTI, and NM.

Parameters:

  • projectID (str): Project identifier.
  • subjectID (str): Subject identifier.
  • date (str): Date of the imaging study.
  • imageUniqueID (str): Unique image identifier.
  • modality (str): Modality of the imaging study.
  • source_image_directory (str): Directory of the source images.
  • output_image_directory (str): Directory for output images.
  • t1_filename (str): Filename of the T1-weighted image.
  • flair_filename (list): List of filenames for FLAIR images.
  • rsf_filenames (list): List of filenames for rsfMRI images.
  • dti_filenames (list): List of filenames for DTI images.
  • nm_filenames (list): List of filenames for NM images.
  • perf_filename (str or list): Filename (or single-element list) for the perfusion image.
  • pet3d_filename (str or list): Filename (or single-element list) for the 3D PET image.

Returns:

  • pandas.DataFrame: A DataFrame containing the validated imaging study information.

Raises:

  • ValueError: If any validation checks fail or if the number of columns does not match the data.
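
A minimal usage sketch follows (the project, subject and file paths are hypothetical placeholders; the existence and naming checks above require the files to actually be present on disk, and unused modalities are passed as empty lists):

import antspymm

t1fn = '/data/nrg/PPMI/101/20230101/T1w/000/PPMI-101-20230101-T1w-000.nii.gz'  # placeholder path
studycsv = antspymm.generate_mm_dataframe(
    'PPMI',              # projectID
    '101',               # subjectID
    '20230101',          # date
    '000',               # imageUniqueID
    'T1w',               # modality
    '/data/nrg/',        # source_image_directory
    '/data/processed/',  # output_image_directory
    t1_filename=t1fn,
    flair_filename=[],
    rsf_filenames=[],
    dti_filenames=[],
    nm_filenames=[],
    perf_filename=[])
print(studycsv)  # a single-row DataFrame with the identifier, directory and filename columns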
def aggregate_antspymm_results( input_csv, subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path='./Processed/ANTsExpArt/', hiervariable='T1wHierarchical', valid_modalities=None, verbose=False):
12271def aggregate_antspymm_results(input_csv, subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Processed/ANTsExpArt/", hiervariable='T1wHierarchical', valid_modalities=None, verbose=False ):
12272    """
12273    Aggregate ANTsPyMM results from the specified CSV file and return the aggregated results as a new data frame.
12274
12275    Parameters:
12276    - input_csv (str): File path of the input CSV file containing ANTsPyMM QC results averaged and with outlier measurements.
12277    - subject_col (str): Name of the column to store subject IDs.
12278    - date_col (str): Name of the column to store date information.
12279    - image_col (str): Name of the column to store image IDs.
12280    - date_column (str): Name of the column representing the date information.
12281    - base_path (str): Base path for search paths. Defaults to "./Processed/ANTsExpArt/".
12282    - hiervariable (str) : the string variable denoting the Hierarchical output
12283    - valid_modalities (str array) : identifiers for each modality; if None, will be replaced by get_valid_modalities(long=True)
12284    - verbose : boolean
12285
12286    Note:
12287    This function is tested under limited circumstances. Use with caution.
12288
12289    Example usage:
12290    agg_df = aggregate_antspymm_results("qcdfaol.csv", subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Your/Custom/Path/")
12291
12292    Author:
12293    Avants and ChatGPT
12294    """
12295    import pandas as pd
12296    import numpy as np
12297    from glob import glob
12298
12299    def myread_csv(x, cnms):
12300        """
12301        Reads a CSV file and returns a DataFrame excluding specified columns.
12302
12303        Parameters:
12304        - x (str): File path of the input CSV file describing the blind QC output
12305        - cnms (list): List of column names to exclude from the DataFrame.
12306
12307        Returns:
12308        pd.DataFrame: DataFrame with specified columns excluded.
12309        """
12310        df = pd.read_csv(x)
12311        return df.loc[:, ~df.columns.isin(cnms)]
12312
12313    import warnings
12314    # Warning message for untested function
12315    warnings.warn("Warning: This function is not well tested. Use with caution.")
12316
12317    if valid_modalities is None:
12318        valid_modalities = get_valid_modalities('long')
12319
12320    # Read the input CSV file
12321    df = pd.read_csv(input_csv)
12322
12323    # Filter rows where modality is 'T1w'
12324    df = df[df['modality'] == 'T1w']
12325    badnames = get_names_from_data_frame( ['Unnamed'], df )
12326    df=df.drop(badnames, axis=1)
12327
12328    # Add new columns for subject ID, date, and image ID
12329    df[subject_col] = np.nan
12330    df[date_col] = date_column
12331    df[image_col] = np.nan
12332    df = df.astype({subject_col: str, date_col: str, image_col: str })
12333
12334#    if verbose:
12335#        print( df.shape )
12336#        print( df.dtypes )
12337
12338    # prefilter df for data that exists
12339    keep = np.tile( False, df.shape[0] )
12340    for x in range(df.shape[0]):
12341        temp = df['filename'].iloc[x].split("_")
12342        # Generalized search paths
12343        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12344        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12345        if len( hierfn ) > 0:
12346            keep[x]=True
12347
12348    
12349    df=df[keep]
12350    
12351    if verbose:
12352        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12353        print( df.shape )
12354
12355    myct = 0
12356    for x in range( df.shape[0]):
12357        if verbose:
12358            print(f"{x}...")
12359        locind = df.index[x]
12360        temp = df['filename'].iloc[x].split("_")
12361        if verbose:
12362            print( temp )
12363        df[subject_col].iloc[x]=temp[0]
12364        df[date_col].iloc[x]=date_column
12365        df[image_col].iloc[x]=temp[1]
12366
12367        # Generalized search paths
12368        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12369        if verbose:
12370            print(path_template)
12371        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12372        if len( hierfn ) > 0:
12373            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12374            if verbose:
12375                print(hierfn)
12376            hdf = pd.read_csv(hierfn[0])
12377            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12378            hdf=hdf.drop(badnames, axis=1)
12379            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12380            corenames = list(np.array(hdf.columns)[nums])
12381            hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12382            myct = myct + 1
12383            dflist = [hdf]
12384
12385            for mymod in valid_modalities:
12386                t1wfn = sorted(glob( path_template+ "-" + mymod + "-*wide.csv" ) )
12387                if len( t1wfn ) > 0 :
12388                    if verbose:
12389                        print(t1wfn)
12390                    t1df = myread_csv(t1wfn[0], corenames)
12391                    t1df = filter_df( t1df, mymod+'_')
12392                    dflist = dflist + [t1df]
12393                
12394            hdf = pd.concat( dflist, axis=1, ignore_index=False )
12395            if verbose:
12396                print( df.loc[locind,'filename'] )
12397            if myct == 1:
12398                subdf = df.iloc[[x]]
12399                hdf.index = subdf.index.copy()
12400                df = pd.concat( [df,hdf], axis=1, ignore_index=False )
12401            else:
12402                commcols = list(set(hdf.columns).intersection(df.columns))
12403                df.loc[locind, commcols] = hdf.loc[0, commcols]
12404    badnames = get_names_from_data_frame( ['Unnamed'], df )
12405    df=df.drop(badnames, axis=1)
12406    return( df )

Aggregate ANTsPyMM results from the specified CSV file and return the aggregated results as a new data frame.

Parameters:

  • input_csv (str): File path of the input CSV file containing ANTsPyMM QC results averaged and with outlier measurements.
  • subject_col (str): Name of the column to store subject IDs.
  • date_col (str): Name of the column to store date information.
  • image_col (str): Name of the column to store image IDs.
  • date_column (str): Name of the column representing the date information.
  • base_path (str): Base path for search paths. Defaults to "./Processed/ANTsExpArt/".
  • hiervariable (str) : the string variable denoting the Hierarchical output
  • valid_modalities (str array) : identifiers for each modality; if None, will be replaced by get_valid_modalities(long=True)
  • verbose : boolean

Note: This function is tested under limited circumstances. Use with caution.

Example usage: agg_df = aggregate_antspymm_results("qcdfaol.csv", subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Your/Custom/Path/")

Author: Avants and ChatGPT

def aggregate_antspymm_results_sdf( study_df, project_col='projectID', subject_col='subjectID', date_col='date', image_col='imageID', base_path='./', hiervariable='T1wHierarchical', splitsep='-', idsep='-', wild_card_modality_id=False, second_split=False, verbose=False):
12429def aggregate_antspymm_results_sdf(
12430    study_df, 
12431    project_col='projectID',
12432    subject_col='subjectID', 
12433    date_col='date', 
12434    image_col='imageID', 
12435    base_path="./", 
12436    hiervariable='T1wHierarchical', 
12437    splitsep='-',
12438    idsep='-',
12439    wild_card_modality_id=False,
12440    second_split=False,
12441    verbose=False ):
12442    """
12443    Aggregate ANTsPyMM results from the specified study data frame and store the aggregated results in a new data frame.  This assumes data is organized on disk 
12444    as follows:  rootdir/projectID/subjectID/date/outputid/imageid/ where 
12445    outputid is modality-specific and created by ANTsPyMM processing.
12446
12447    Parameters:
12448    - study_df (pandas df): pandas data frame, output of generate_mm_dataframe.
12449    - project_col (str): Name of the column that stores the project ID
12450    - subject_col (str): Name of the column to store subject IDs.
12451    - date_col (str): Name of the column to store date information.
12452    - image_col (str): Name of the column to store image IDs.
12453    - base_path (str): Base path for searching for processing outputs of ANTsPyMM.
12454    - hiervariable (str) : the string variable denoting the Hierarchical output
12455    - splitsep (str):  the separator used to split the filename
12456    - idsep (str): the separator used to partition subjectid date and imageid 
12457        for example, if idsep is - then we have subjectid-date-imageid
12458    - wild_card_modality_id (bool): keep if False for safer execution
12459    - second_split (bool): this is a hack that will split the imageID by . and keep the first part of the split; may be needed when the input filenames contain .
12460    - verbose : boolean
12461
12462    Note:
12463    This function is tested under limited circumstances. Use with caution.
12464    One particular gotcha is if the imageID is stored as a numeric value in the dataframe 
12465    but is meant to be a string.  E.g. '000' (string) would be interpreted as 0 in the 
12466    file name glob.  This would miss the extant (on disk) csv.
12467
12468    Example usage:
12469    agg_df = aggregate_antspymm_results_sdf( studydf, subject_col='subjectID', date_col='date', image_col='imageID', base_path="./Your/Custom/Path/")
12470
12471    Author:
12472    Avants and ChatGPT
12473    """
12474    import pandas as pd
12475    import numpy as np
12476    from glob import glob
12477
12478    def progress_reporter(current_step, total_steps, width=50):
12479        # Calculate the proportion of progress
12480        progress = current_step / total_steps
12481        # Calculate the number of 'filled' characters in the progress bar
12482        filled_length = int(width * progress)
12483        # Create the progress bar string
12484        bar = '█' * filled_length + '-' * (width - filled_length)
12485        # Print the progress bar with percentage
12486        print(f'\rProgress: |{bar}| {int(100 * progress)}%', end='\r')
12487        # Print a new line when the progress is complete
12488        if current_step == total_steps:
12489            print()
12490
12491    def myread_csv(x, cnms):
12492        """
12493        Reads a CSV file and returns a DataFrame excluding specified columns.
12494
12495        Parameters:
12496        - x (str): File path of the input CSV file describing the blind QC output
12497        - cnms (list): List of column names to exclude from the DataFrame.
12498
12499        Returns:
12500        pd.DataFrame: DataFrame with specified columns excluded.
12501        """
12502        df = pd.read_csv(x)
12503        return df.loc[:, ~df.columns.isin(cnms)]
12504
12505    import warnings
12506    # Warning message for untested function
12507    warnings.warn("Warning: This function is not well tested. Use with caution.")
12508
12509    vmoddict = {}
12510    # Add key-value pairs
12511    vmoddict['imageID'] = 'T1w'
12512    vmoddict['flairid'] = 'T2Flair'
12513    vmoddict['perfid'] = 'perf'
12514    vmoddict['pet3did'] = 'pet3d'
12515    vmoddict['rsfid1'] = 'rsfMRI'
12516#    vmoddict['rsfid2'] = 'rsfMRI'
12517    vmoddict['dtid1'] = 'DTI'
12518#    vmoddict['dtid2'] = 'DTI'
12519    vmoddict['nmid1'] = 'NM2DMT'
12520#    vmoddict['nmid2'] = 'NM2DMT'
12521
12522    # Filter rows where modality is 'T1w'
12523    df = study_df[ study_df['modality'] == 'T1w']
12524    badnames = get_names_from_data_frame( ['Unnamed'], df )
12525    df=df.drop(badnames, axis=1)
12526    # prefilter df for data that exists
12527    keep = np.tile( False, df.shape[0] )
12528    for x in range(df.shape[0]):
12529        myfn = os.path.basename( df['filename'].iloc[x] )
12530        temp = myfn.split( splitsep )
12531        # Generalized search paths
12532        sid0 = str( temp[1] )
12533        sid = str( df[subject_col].iloc[x] )
12534        if sid0 != sid:
12535            warnings.warn("OUTER: the id derived from the filename " + sid0 + " does not match the id stored in the data frame " + sid )
12536            warnings.warn( "filename is : " +  myfn )
12537            warnings.warn( "sid is : " + sid )
12538            warnings.warn( "x is : " + str(x) )
12539        myproj = str(df[project_col].iloc[x])
12540        mydate = str(df[date_col].iloc[x])
12541        myid = str(df[image_col].iloc[x])
12542        if second_split:
12543            myid = myid.split(".")[0]
12544        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12545        hierfn = sorted(glob( path_template + "*" + hiervariable + "*wide.csv" ) )
12546        if len( hierfn ) == 0:
12547            print( hierfn )
12548            print( path_template )
12549            print( myproj )
12550            print( sid )
12551            print( mydate ) 
12552            print( myid )
12553        if len( hierfn ) > 0:
12554            keep[x]=True
12555
12556    # df=df[keep]
12557    if df.shape[0] == 0:
12558        warnings.warn("input data frame shape is filtered down to zero")
12559        return df
12560
12561    if not df.index.is_unique:
12562        warnings.warn("data frame does not have unique indices.  we therefore reset the index to allow the function to continue on." )
12563        df = df.reset_index()
12564
12565    
12566    if verbose:
12567        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12568        print( df.shape )
12569
12570    dfout = pd.DataFrame()
12571    myct = 0
12572    for x in range( df.shape[0]):
12573        if verbose:
12574            print("\n\n-------------------------------------------------")
12575            print(f"{x}...")
12576        else:
12577            progress_reporter(x, df.shape[0], width=500)
12578        locind = df.index[x]
12579        myfn = os.path.basename( df['filename'].iloc[x] )
12580        sid = str( df[subject_col].iloc[x] )
12581        tempB = myfn.split( splitsep )
12582        sid0 = str(tempB[1])
12583        if sid0 != sid and verbose:
12584            warnings.warn("INNER: the id derived from the filename " + str(sid) + " does not match the id stored in the data frame " + str(sid0) )
12585            warnings.warn( "filename is : " +  str(myfn) )
12586            warnings.warn( "sid is : " + str(sid) )
12587            warnings.warn( "x is : " + str(x) )
12588            warnings.warn( "index is : " + str(locind) )
12589        myproj = str(df[project_col].iloc[x])
12590        mydate = str(df[date_col].iloc[x])
12591        myid = str(df[image_col].iloc[x])
12592        if second_split:
12593            myid = myid.split(".")[0]
12594        if verbose:
12595            print( myfn )
12596            print( tempB )
12597            print( "id " + sid  )
12598        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12599        searchhier = path_template + "*" + hiervariable + "*wide.csv"
12600        if verbose:
12601            print( searchhier )
12602        hierfn = sorted( glob( searchhier ) )
12603        if len( hierfn ) > 1:
12604            raise ValueError("there are " + str( len( hierfn ) ) + " hierarchical wide csv files matching search path " + searchhier )
12605        if len( hierfn ) == 1:
12606            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12607            if verbose:
12608                print(hierfn)
12609            hdf = pd.read_csv(hierfn[0])
12610            if verbose:
12611                print( hdf['vol_hemisphere_lefthemispheres'] )
12612            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12613            hdf=hdf.drop(badnames, axis=1)
12614            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12615            corenames = list(np.array(hdf.columns)[nums])
12616            # hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12617            hdf = hdf.add_prefix("T1Hier_")
12618            myct = myct + 1
12619            dflist = [hdf]
12620
12621            for mymod in vmoddict.keys():
12622                if verbose:
12623                    print("\n\n************************* " + mymod + " *************************")
12624                modalityclass = vmoddict[ mymod ]
12625                if wild_card_modality_id:
12626                    mymodid = '*'
12627                else:
12628                    mymodid = str( df[mymod].iloc[x] )
12629                    if mymodid.lower() != "nan" and mymodid.lower() != "na":
12630                        mymodid = os.path.basename( mymodid )
12631                        mymodid = os.path.splitext( mymodid )[0]
12632                        mymodid = os.path.splitext( mymodid )[0]
12633                        temp = mymodid.split( idsep )
12634                        mymodid = temp[ len( temp )-1 ]
12635                    else:
12636                        if verbose:
12637                            print("missing")
12638                        continue
12639                if verbose:
12640                    print( "modality id is " + mymodid + " for modality " + modalityclass + ' modality specific subj ' + sid + ' modality specific id is ' + myid + " its date " +  mydate )
12641                modalityclasssearch = modalityclass
12642                if modalityclass in ['rsfMRI','DTI']:
12643                    modalityclasssearch=modalityclass+"*"
12644                path_template_m = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + modalityclasssearch + '/' + mymodid + "/"
12645                modsearch = path_template_m + "*" + modalityclasssearch + "*wide.csv"
12646                if verbose:
12647                    print( modsearch )
12648                t1wfn = sorted( glob( modsearch ) )
12649                if len( t1wfn ) > 1:
12650                    nlarge = len(t1wfn)
12651                    t1wfn = find_most_recent_file( t1wfn )
12652                    warnings.warn("there are " + str( nlarge ) + " wide csv files matching search path " + modsearch + "; we take the most recent of these: " + t1wfn[0] )
12653                if len( t1wfn ) == 1:
12654                    if verbose:
12655                        print(t1wfn)
12656                    t1df = myread_csv(t1wfn[0], corenames)
12657                    t1df = filter_df( t1df, modalityclass+'_')
12658                    dflist = dflist + [t1df]
12659                else:
12660                    if verbose:
12661                        print( " cannot find " + modsearch )
12662                
12663            hdf = pd.concat( dflist, axis=1, ignore_index=False)
12664            if verbose:
12665                print( "count: " + str( myct ) )
12666            subdf = df.iloc[[x]]
12667            hdf.index = subdf.index.copy()
12668            subdf = pd.concat( [subdf,hdf], axis=1, ignore_index=False)
12669            dfout = pd.concat( [dfout,subdf], axis=0, ignore_index=False )
12670
12671    if dfout.shape[0] > 0:
12672        badnames = get_names_from_data_frame( ['Unnamed'], dfout )
12673        dfout=dfout.drop(badnames, axis=1)
12674    return dfout

Aggregate ANTsPyMM results from the specified study data frame and store the aggregated results in a new data frame. This assumes data is organized on disk as follows: rootdir/projectID/subjectID/date/outputid/imageid/ where outputid is modality-specific and created by ANTsPyMM processing.

Parameters:

  • study_df (pandas df): pandas data frame, output of generate_mm_dataframe.
  • project_col (str): Name of the column that stores the project ID
  • subject_col (str): Name of the column to store subject IDs.
  • date_col (str): Name of the column to store date information.
  • image_col (str): Name of the column to store image IDs.
  • base_path (str): Base path for searching for processing outputs of ANTsPyMM.
  • hiervariable (str) : the string variable denoting the Hierarchical output
  • splitsep (str): the separator used to split the filename
  • idsep (str): the separator used to partition subjectid date and imageid for example, if idsep is - then we have subjectid-date-imageid
  • wild_card_modality_id (bool): keep if False for safer execution
  • second_split (bool): this is a hack that will split the imageID by . and keep the first part of the split; may be needed when the input filenames contain .
  • verbose : boolean

Note: This function is tested under limited circumstances. Use with caution. One particular gotcha is if the imageID is stored as a numeric value in the dataframe but is meant to be a string. E.g. '000' (string) would be interpreted as 0 in the file name glob. This would miss the extant (on disk) csv.

Example usage: agg_df = aggregate_antspymm_results_sdf( studydf, subject_col='subjectID', date_col='date', image_col='imageID', base_path="./Your/Custom/Path/")

Author: Avants and ChatGPT
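
A hedged sketch of the imageID gotcha noted above (the csv name and search path are placeholders): read the identifier columns as strings and keep their zero padding, so that an imageID such as '000' matches the on-disk folder name instead of being interpreted as the integer 0.

import pandas as pd
import antspymm

# keep ids as zero-padded strings so the file-name glob matches the folder on disk
studydf = pd.read_csv('study.csv', dtype={'subjectID': str, 'date': str, 'imageID': str})
studydf['imageID'] = studydf['imageID'].str.zfill(3)
agg_df = antspymm.aggregate_antspymm_results_sdf(studydf, base_path='./Your/Custom/Path/')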

def study_dataframe_from_matched_dataframe(matched_dataframe, rootdir, outputdir, verbose=False):
1231def study_dataframe_from_matched_dataframe( matched_dataframe, rootdir, outputdir, verbose=False ):
1232    """
1233    converts the output of antspymm.match_modalities dataframe (one row) to that needed for a study-driving dataframe for input to mm_csv
1234
1235    matched_dataframe : output of antspymm.match_modalities
1236
1237    rootdir : location for the input data root folder (in e.g. NRG format)
1238
1239    outputdir : location for the output data
1240
1241    verbose : boolean
1242    """
1243    iext='.nii.gz'
1244    from os.path import exists
1245    musthavecols = ['projectID', 'subjectID','date','imageID','filename']
1246    for k in range(len(musthavecols)):
1247        if not musthavecols[k] in matched_dataframe.keys():
1248            raise ValueError('matched_dataframe is missing column ' + musthavecols[k] + ' in study_dataframe_from_qc_dataframe' )
1249    csvrow=matched_dataframe.dropna(axis=1)
1250    pid=get_first_item_as_string( csvrow, 'projectID'  )
1251    sid=get_first_item_as_string( csvrow, 'subjectID'  ) # str(csvrow['subjectID'].iloc[0] )
1252    dt=get_first_item_as_string( csvrow, 'date'  )  # str(csvrow['date'].iloc[0])
1253    iid=get_first_item_as_string( csvrow, 'imageID'  ) # str(csvrow['imageID'].iloc[0])
1254    nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
1255    if not exists( nrgt1fn ) and iid == '0':
1256        iid='000'
1257        nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
1258    if not exists( nrgt1fn ):
1259        raise ValueError("T1 " + nrgt1fn + " does not exist in study_dataframe_from_qc_dataframe")
1260    flList=[]
1261    dtList=[]
1262    rsfList=[]
1263    nmList=[]
1264    perfList=[]
1265    if 'flairfn' in csvrow.keys():
1266        flid=get_first_item_as_string( csvrow, 'flairid' )
1267        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
1268        if not exists( nrgt2fn ) and flid == '0':
1269            flid='000'
1270            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
1271        if verbose:
1272            print("Trying " + nrgt2fn )
1273        if exists( nrgt2fn ):
1274            if verbose:
1275                print("success" )
1276            flList.append( nrgt2fn )
1277    if 'perffn' in csvrow.keys():
1278        flid=get_first_item_as_string( csvrow, 'perfid' )
1279        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
1280        if not exists( nrgt2fn ) and flid == '0':
1281            flid='000'
1282            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
1283        if verbose:
1284            print("Trying " + nrgt2fn )
1285        if exists( nrgt2fn ):
1286            if verbose:
1287                print("success" )
1288            perfList.append( nrgt2fn )
1289    if 'dtfn1' in csvrow.keys():
1290        dtid=get_first_item_as_string( csvrow, 'dtid1' )
1291        dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
1292        if len( dtfn1) == 0 :
1293            dtid = '000'
1294            dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
1295        dtfn1=dtfn1[0]
1296        if exists( dtfn1 ):
1297            dtList.append( dtfn1 )
1298    if 'dtfn2' in csvrow.keys():
1299        dtid=get_first_item_as_string( csvrow, 'dtid2' )
1300        dtfn2=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
1301        if len( dtfn2) == 0 :
1302            dtid = '000'
1303            dtfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
1304        dtfn2=dtfn2[0]
1305        if exists( dtfn2 ):
1306            dtList.append( dtfn2 )
1307    if 'dtfn3' in csvrow.keys():
1308        dtid=get_first_item_as_string( csvrow, 'dtid3' )
1309        dtfn3=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
1310        if len( dtfn3) == 0 :
1311            dtid = '000'
1312            dtfn3=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
1313        dtfn3=dtfn3[0]
1314        if exists( dtfn3 ):
1315            dtList.append( dtfn3 )
1316    if 'rsffn1' in csvrow.keys():
1317        rsid=get_first_item_as_string( csvrow, 'rsfid1' )
1318        rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
1319        if len( rsfn1 ) == 0 :
1320            rsid = '000'
1321            rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
1322        rsfn1=rsfn1[0]
1323        if exists( rsfn1 ):
1324            rsfList.append( rsfn1 )
1325    if 'rsffn2' in csvrow.keys():
1326        rsid=get_first_item_as_string( csvrow, 'rsfid2' )
1327        rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))[0]
1328        if len( rsfn2 ) == 0 :
1329            rsid = '000'
1330            rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))
1331        rsfn2=rsfn2[0]
1332        if exists( rsfn2 ):
1333            rsfList.append( rsfn2 )
1334    for j in range(11):
1335        keyname="nmfn"+str(j)
1336        keynameid="nmid"+str(j)
1337        if keyname in csvrow.keys() and keynameid in csvrow.keys():
1338            nmid=get_first_item_as_string( csvrow, keynameid )
1339            nmsearchpath=os.path.join( rootdir, pid, sid, dt, 'NM2DMT', nmid, "*"+nmid+iext)
1340            nmfn=glob.glob( nmsearchpath )
1341            nmfn=nmfn[0]
1342            if exists( nmfn ):
1343                nmList.append( nmfn )
1344    if verbose:
1345        print("assembled the image lists mapping to ....")
1346        print(nrgt1fn)
1347        print("NM")
1348        print(nmList)
1349        print("FLAIR")
1350        print(flList)
1351        print("DTI")
1352        print(dtList)
1353        print("rsfMRI")
1354        print(rsfList)
1355        print("perf")
1356        print(perfList)
1357    studycsv = generate_mm_dataframe(
1358        pid,
1359        sid,
1360        dt,
1361        iid, # the T1 id
1362        'T1w',
1363        rootdir,
1364        outputdir,
1365        t1_filename=nrgt1fn,
1366        flair_filename=flList,
1367        dti_filenames=dtList,
1368        rsf_filenames=rsfList,
1369        nm_filenames=nmList,
1370        perf_filename=perfList)
1371    return studycsv.dropna(axis=1)

converts the output of antspymm.match_modalities dataframe (one row) to that needed for a study-driving dataframe for input to mm_csv

matched_dataframe : output of antspymm.match_modalities

rootdir : location for the input data root folder (in e.g. NRG format)

outputdir : location for the output data

verbose : boolean
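
A short hedged sketch (the csv file and directories are placeholders): take one matched row and convert it into a study-driving data frame for mm_csv.

import pandas as pd
import antspymm

matched = pd.read_csv('matched.csv')   # assumed output of antspymm.match_modalities
studycsv = antspymm.study_dataframe_from_matched_dataframe(
    matched.iloc[[0]],                 # one subject/session row
    rootdir='/data/nrg/',              # NRG-format input root (placeholder)
    outputdir='/data/processed/',      # output location (placeholder)
    verbose=True)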

def merge_wides_to_study_dataframe( sdf, processing_dir, separator='-', sid_is_int=True, id_is_int=True, date_is_int=True, report_missing=False, progress=False, verbose=False):
 9901def merge_wides_to_study_dataframe( sdf, processing_dir, separator='-', sid_is_int=True, id_is_int=True, date_is_int=True, report_missing=False,
 9902progress=False, verbose=False ):
 9903    """
 9904    extend a study data frame with wide outputs
 9905
 9906    sdf : the input study dataframe from antspymm QC output
 9907
 9908    processing_dir:  the directory location of the processed data 
 9909
 9910    separator : string usually '-' or '_'
 9911
 9912    sid_is_int : boolean set to True to cast unique subject ids to int; can be useful if they are inadvertently stored as float by pandas
 9913
 9914    date_is_int : boolean set to True to cast date to int; can be useful if they are inadvertently stored as float by pandas
 9915
 9916    id_is_int : boolean set to True to cast unique image ids to int; can be useful if they are inadvertently stored as float by pandas
 9917
 9918    report_missing : boolean combined with verbose will report missing modalities
 9919
 9920    progress : integer reports percent progress modulo progress value 
 9921
 9922    verbose : boolean
 9923    """
 9924    from os.path import exists
 9925    musthavecols = ['projectID', 'subjectID','date','imageID']
 9926    for k in range(len(musthavecols)):
 9927        if not musthavecols[k] in sdf.keys():
 9928            raise ValueError('sdf is missing column ' +musthavecols[k] + ' in merge_wides_to_study_dataframe' )
 9929    possible_iids = [ 'imageID', 'imageID', 'imageID', 'flairid', 'dtid1', 'dtid2', 'rsfid1', 'rsfid2', 'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5', 'nmid6', 'nmid7', 'nmid8', 'nmid9', 'nmid10', 'perfid' ]
 9930    modality_ids = [ 'T1wHierarchical', 'T1wHierarchicalSR', 'T1w', 'T2Flair', 'DTI', 'DTI', 'rsfMRI', 'rsfMRI', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'perf']
 9931    alldf=pd.DataFrame()
 9932    for myk in sdf.index:
 9933        if progress > 0 and int(myk) % int(progress) == 0:
 9934            print( str( round( myk/sdf.shape[0]*100.0)) + "%...", end='', flush=True)
 9935        if verbose:
 9936            print( "DOROW " + str(myk) + ' of ' + str( sdf.shape[0] ) )
 9937        csvrow = sdf.loc[sdf.index == myk].dropna(axis=1)
 9938        ct=-1
 9939        for iidkey in possible_iids:
 9940            ct=ct+1
 9941            mod_name = modality_ids[ct]
 9942            if iidkey in csvrow.keys():
 9943                if id_is_int:
 9944                    iid = str( int( csvrow[iidkey].iloc[0] ) )
 9945                else:
 9946                    iid = str( csvrow[iidkey].iloc[0] )
 9947                if verbose:
 9948                    print( "iidkey " + iidkey + " modality " + mod_name + ' iid '+ iid )
 9949                pid=str(csvrow['projectID'].iloc[0] )
 9950                if sid_is_int:
 9951                    sid=str(int(csvrow['subjectID'].iloc[0] ))
 9952                else:
 9953                    sid=str(csvrow['subjectID'].iloc[0] )
 9954                if date_is_int:
 9955                    dt=str(int(csvrow['date'].iloc[0]))
 9956                else:
 9957                    dt=str(csvrow['date'].iloc[0])
 9958                if id_is_int:
 9959                    t1iid=str(int(csvrow['imageID'].iloc[0]))
 9960                else:
 9961                    t1iid=str(csvrow['imageID'].iloc[0])
 9962                if t1iid != iid:
 9963                    iidj=iid+"_"+t1iid
 9964                else:
 9965                    iidj=iid
 9966                rootid = pid +separator+ sid +separator+dt+separator+mod_name+separator+iidj
 9967                myext = rootid +separator+'mmwide.csv'
 9968                nrgwidefn=os.path.join( processing_dir, pid, sid, dt, mod_name, iid, myext )
 9969                moddersub = mod_name
 9970                is_t1=False
 9971                if mod_name == 'T1wHierarchical':
 9972                    is_t1=True
 9973                    moddersub='T1Hier'
 9974                elif mod_name == 'T1wHierarchicalSR':
 9975                    is_t1=True
 9976                    moddersub='T1HSR'
 9977                if exists( nrgwidefn ):
 9978                    if verbose:
 9979                        print( nrgwidefn + " exists")
 9980                    mm=read_mm_csv( nrgwidefn, colprefix=moddersub+'_', is_t1=is_t1, separator=separator, verbose=verbose )
 9981                    if mm is not None:
 9982                        if mod_name == 'T1wHierarchical':
 9983                            a=list( csvrow.keys() )
 9984                            b=list( mm.keys() )
 9985                            abintersect=list(set(b).intersection( set(a) ) )
 9986                            if len( abintersect  ) > 0 :
 9987                                for qq in abintersect:
 9988                                    mm.pop( qq )
 9989                        # mm.index=csvrow.index
 9990                        uidname = mod_name + '_mmwide_filename'
 9991                        mm[ uidname ] = rootid
 9992                        csvrow=pd.concat( [csvrow,mm], axis=1, ignore_index=False )
 9993                else:
 9994                    if verbose and report_missing:
 9995                        print( nrgwidefn + " absent")
 9996        if alldf.shape[0] == 0:
 9997            alldf = csvrow.copy()
 9998            alldf = alldf.loc[:,~alldf.columns.duplicated()]
 9999        else:
10000            csvrow=csvrow.loc[:,~csvrow.columns.duplicated()]
10001            alldf = alldf.loc[:,~alldf.columns.duplicated()]
10002            alldf = pd.concat( [alldf, csvrow], axis=0, ignore_index=True )
10003    return alldf

extend a study data frame with wide outputs

sdf : the input study dataframe from antspymm QC output

processing_dir: the directory location of the processed data

separator : string usually '-' or '_'

sid_is_int : boolean set to True to cast unique subject ids to int; can be useful if they are inadvertently stored as float by pandas

date_is_int : boolean set to True to cast date to int; can be useful if they are inadvertently stored as float by pandas

id_is_int : boolean set to True to cast unique image ids to int; can be useful if they are inadvertently stored as float by pandas

report_missing : boolean combined with verbose will report missing modalities

progress : integer reports percent progress modulo progress value

verbose : boolean
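
A short hedged sketch (file names and directories are placeholders): extend a study data frame with all modality-specific wide outputs found under the processing tree.

import pandas as pd
import antspymm

sdf = pd.read_csv('study.csv')                 # assumed antspymm study/QC data frame
wide = antspymm.merge_wides_to_study_dataframe(
    sdf,
    processing_dir='./processed/',             # root of the ANTsPyMM output tree
    separator='-',
    progress=10)                               # report percent progress every 10th row
wide.to_csv('study_merged_wide.csv', index=False)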

def filter_image_files(image_paths, criteria='largest'):
12713def filter_image_files(image_paths, criteria='largest'):
12714    """
12715    Filters a list of image file paths based on specified criteria and returns 
12716    the path of the image that best matches that criteria (smallest, largest, or brightest).
12717
12718    Args:
12719    image_paths (list): A list of file paths to the images.
12720    criteria (str): Criteria for selecting the image ('smallest', 'largest', 'brightest').
12721
12722    Returns:
12723    str: The file path of the selected image, or None if no valid images are found.
12724    """
12725    import numpy as np
12726    if not image_paths:
12727        return None
12728
12729    selected_image_path = None
12730    if criteria == 'smallest' or criteria == 'largest':
12731        extreme_volume = None
12732
12733        for path in image_paths:
12734            try:
12735                image = ants.image_read(path)
12736                volume = np.prod(image.shape)
12737
12738                if criteria == 'largest':
12739                    if extreme_volume is None or volume > extreme_volume:
12740                        extreme_volume = volume
12741                        selected_image_path = path
12742                elif criteria == 'smallest':
12743                    if extreme_volume is None or volume < extreme_volume:
12744                        extreme_volume = volume
12745                        selected_image_path = path
12746
12747            except Exception as e:
12748                print(f"Error processing image {path}: {e}")
12749
12750    elif criteria == 'brightest':
12751        max_brightness = None
12752
12753        for path in image_paths:
12754            try:
12755                image = ants.image_read(path)
12756                brightness = np.mean(image.numpy())
12757
12758                if max_brightness is None or brightness > max_brightness:
12759                    max_brightness = brightness
12760                    selected_image_path = path
12761
12762            except Exception as e:
12763                print(f"Error processing image {path}: {e}")
12764
12765    else:
12766        raise ValueError("Criteria must be 'smallest', 'largest', or 'brightest'.")
12767
12768    return selected_image_path

Filters a list of image file paths based on specified criteria and returns the path of the image that best matches that criteria (smallest, largest, or brightest).

Args:

  • image_paths (list): A list of file paths to the images.
  • criteria (str): Criteria for selecting the image ('smallest', 'largest', 'brightest').

Returns: str: The file path of the selected image, or None if no valid images are found.
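
A short hedged example (the glob pattern is a placeholder): among repeated T1w acquisitions in one session, keep the image with the largest voxel count.

import glob
import antspymm

candidates = glob.glob('/data/nrg/PPMI/101/20230101/T1w/*/*.nii.gz')  # placeholder pattern
best_t1 = antspymm.filter_image_files(candidates, criteria='largest')
print(best_t1)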

def docsamson( locmod, studycsv, outputdir, projid, sid, dtid, mysep, t1iid=None, verbose=True):
522def docsamson(locmod, studycsv, outputdir, projid, sid, dtid, mysep, t1iid=None, verbose=True):
523    """
524    Processes image file names based on the specified imaging modality and other parameters.
525
526    The function selects file names from the provided dictionary `studycsv` based on the imaging modality.
527    It supports various modalities like T1w, T2Flair, perf, NM2DMT, rsfMRI, DTI, and configures the filenames accordingly.
528    The function can optionally print verbose output during processing.
529
530    Parameters:
531    locmod (str): The imaging modality. Options include 'T1w', 'T2Flair', 'perf', 'NM2DMT', 'rsfMRI', 'DTI'.
532    studycsv (dict): A dictionary with keys corresponding to imaging modalities and values as file names.
533    outputdir (str): Base directory for output files.
534    projid (str): Project identifier.
535    sid (str): Subject identifier.
536    dtid (str): Data acquisition time identifier.
537    mysep (str): Separator used in file naming.
538    t1iid (str, optional): Identifier related to T1-weighted images, used in naming output files when locmod is not 'T1w'.
539    verbose (bool, optional): If True, prints detailed information during execution.
540
541    Returns:
542    dict: A dictionary with keys 'modality', 'outprefix', and 'images'.
543        - 'modality' (str): The imaging modality used.
544        - 'outprefix' (str): The prefix for output file paths.
545        - 'images' (list): A list of processed image file names.
546
547    Notes:
548    - The function is designed to work within a specific workflow and might require adaptation for general use.
549
550    Examples:
551    >>> result = docsamson('T1w', studycsv, outputdir, projid, sid, dtid, mysep)
552    >>> print(result['modality'])
553    'T1w'
554    >>> print(result['outprefix'])
555    '/path/to/output/directory/T1w/some_identifier'
556    >>> print(result['images'])
557    ['image1.nii', 'image2.nii']
558    """
559
560    import os
561    import re
562
563    myimgsInput = []
564    myoutputPrefix = None
565    imfns = ['filename', 'rsfid1', 'rsfid2', 'dtid1', 'dtid2', 'flairid']
566    
567    # Define image file names based on the modality
568    if locmod == 'T1w':
569        imfns=['filename']
570    elif locmod == 'T2Flair':
571        imfns=['flairid']
572    elif locmod == 'perf':
573        imfns=['perfid']
574    elif locmod == 'pet3d':
575        imfns=['pet3did']
576    elif locmod == 'NM2DMT':
577        imfns=[]
578        for i in range(11):
579            imfns.append('nmid' + str(i))
580    elif locmod == 'rsfMRI':
581        imfns=[]
582        for i in range(4):
583            imfns.append('rsfid' + str(i))
584    elif locmod == 'DTI':
585        imfns=[]
586        for i in range(4):
587            imfns.append('dtid' + str(i))
588    else:
589        raise ValueError("docsamson: no match of modality to filename id " + locmod )
590
591    # Process each file name
592    for i in imfns:
593        if verbose:
594            print(i + " " + locmod)
595        if i in studycsv.keys():
596            fni = str(studycsv[i].iloc[0])
597            if verbose:
598                print(i + " " + fni + ' exists ' + str(os.path.exists(fni)))
599            if os.path.exists(fni):
600                myimgsInput.append(fni)
601                temp = os.path.basename(fni)
602                mysplit = temp.split(mysep)
603                iid = re.sub(".nii.gz", "", mysplit[-1])
604                iid = re.sub(".mha", "", iid)
605                iid = re.sub(".nii", "", iid)
606                iid2 = iid
607                if locmod != 'T1w' and t1iid is not None:
608                    iid2 = iid + "_" + t1iid
609                else:
610                    iid2 = t1iid
611                myoutputPrefix = os.path.join(outputdir, projid, sid, dtid, locmod, iid, projid + mysep + sid + mysep + dtid + mysep + locmod + mysep + iid2)
612    
613    if verbose:
614        print(locmod)
615        print(myimgsInput)
616        print(myoutputPrefix)
617    
618    return {
619        'modality': locmod,
620        'outprefix': myoutputPrefix,
621        'images': myimgsInput
622    }

Processes image file names based on the specified imaging modality and other parameters.

The function selects file names from the provided dictionary studycsv based on the imaging modality. It supports various modalities like T1w, T2Flair, perf, NM2DMT, rsfMRI, DTI, and configures the filenames accordingly. The function can optionally print verbose output during processing.

Parameters:

  • locmod (str): The imaging modality. Options include 'T1w', 'T2Flair', 'perf', 'NM2DMT', 'rsfMRI', 'DTI'.
  • studycsv (dict): A dictionary with keys corresponding to imaging modalities and values as file names.
  • outputdir (str): Base directory for output files.
  • projid (str): Project identifier.
  • sid (str): Subject identifier.
  • dtid (str): Data acquisition time identifier.
  • mysep (str): Separator used in file naming.
  • t1iid (str, optional): Identifier related to T1-weighted images, used in naming output files when locmod is not 'T1w'.
  • verbose (bool, optional): If True, prints detailed information during execution.

Returns: dict: A dictionary with keys 'modality', 'outprefix', and 'images'.

  • 'modality' (str): The imaging modality used.
  • 'outprefix' (str): The prefix for output file paths.
  • 'images' (list): A list of processed image file names.

Notes:

  • The function is designed to work within a specific workflow and might require adaptation for general use.

Examples:

>>> result = docsamson('T1w', studycsv, outputdir, projid, sid, dtid, mysep)
>>> print(result['modality'])
'T1w'
>>> print(result['outprefix'])
'/path/to/output/directory/T1w/some_identifier'
>>> print(result['images'])
['image1.nii', 'image2.nii']
def enantiomorphic_filling_without_mask(image, axis=0, intensity='low'):
12676def enantiomorphic_filling_without_mask( image, axis=0, intensity='low' ):
12677    """
12678    Perform an enantiomorphic lesion filling on an image without a lesion mask.
12679
12680    Args:
12681    image (antsImage): The ants image to flip and fill
12682    axis ( int ): the axis along which to reflect the image
12683    intensity ( str ) : low or high
12684
12685    Returns:
12686    tuple: (filled_image, diffseg) - the ANTsImage after enantiomorphic filling and the Otsu difference segmentation used to estimate the likely lesion.
12687    """
12688    imagen = ants.iMath( image, 'Normalize' )
12689    imagen = ants.iMath( imagen, "TruncateIntensity", 1e-6, 0.98 )
12690    imagen = ants.iMath( imagen, 'Normalize' )
12691    # Create a mirror image (flipping left and right)
12692    mirror_image = ants.reflect_image(imagen, axis=axis, tx='antsRegistrationSyNQuickRepro[s]' )['warpedmovout']
12693
12694    # Create a symmetric version of the image by averaging the original and the mirror image
12695    symmetric_image = imagen * 0.5 + mirror_image * 0.5
12696
12697    # Identify potential lesion areas by finding differences between the original and symmetric image
12698    difference_image = image - symmetric_image
12699    diffseg = ants.threshold_image(difference_image, "Otsu", 3 )
12700    if intensity == 'low':
12701        likely_lesion = ants.threshold_image( diffseg, 1,  1)
12702    else:
12703        likely_lesion = ants.threshold_image( diffseg, 3,  3)
12704    likely_lesion = ants.smooth_image( likely_lesion, 3.0 ).iMath("Normalize")
12705    lesionneg = ( imagen*0+1.0 ) - likely_lesion
12706    filled_image = ants.image_clone(imagen)    
12707    filled_image = imagen * lesionneg + mirror_image * likely_lesion
12708
12709    return filled_image, diffseg

Perform an enantiomorphic lesion filling on an image without a lesion mask.

Args:

  • image (antsImage): The ants image to flip and fill
  • axis (int): the axis along which to reflect the image
  • intensity (str): low or high

Returns: tuple: the filled ANTsImage and the Otsu difference segmentation used to estimate the likely lesion.
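
A minimal hedged sketch (the input file name is a placeholder; in practice this would be a lesioned T1w or FLAIR volume):

import ants
import antspymm

img = ants.image_read('t1_with_lesion.nii.gz')   # placeholder: lesioned T1w or FLAIR volume
filled, diffseg = antspymm.enantiomorphic_filling_without_mask(img, axis=0, intensity='low')
ants.image_write(filled, 'filled.nii.gz')        # diffseg is the Otsu difference segmentation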

def wmh( flair, t1, t1seg, mmfromconvexhull=3.0, strict=True, probability_mask=None, prior_probability=None, model='sysu', verbose=False):
10996def wmh( flair, t1, t1seg,
10997    mmfromconvexhull = 3.0,
10998    strict=True,
10999    probability_mask=None,
11000    prior_probability=None,
11001    model='sysu',
11002    verbose=False ) :
11003    """
11004    Outputs the WMH probability mask and a summary single measurement
11005
11006    Arguments
11007    ---------
11008    flair : ANTsImage
11009        input 3-D FLAIR brain image (not skull-stripped).
11010
11011    t1 : ANTsImage
11012        input 3-D T1 brain image (not skull-stripped).
11013
11014    t1seg : ANTsImage
11015        T1 segmentation image
11016
11017    mmfromconvexhull : float
11018        restrict WMH to regions that are WM or mmfromconvexhull mm away from the
11019        convex hull of the cerebrum.   we choose a default value based on
11020        Figure 4 from:
11021        https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6240579/pdf/fnagi-10-00339.pdf
11022
11023    strict: boolean - if True, only use convex hull distance
11024
11025    probability_mask : None - use to compute wmh just once - then this function
11026        just does refinement and summary
11027
11028    prior_probability : optional prior probability image in space of the input t1
11029
11030    model : either sysu or hyper
11031
11032    verbose : boolean
11033
11034    Returns
11035    ---------
11036    WMH probability map and a summary single measurement which is the sum of the WMH map
11037
11038    """
11039    import numpy as np
11040    import math
11041    t1_2_flair_reg = ants.registration(flair, t1, type_of_transform = 'antsRegistrationSyNRepro[r]') # Register T1 to Flair
11042    if probability_mask is None and model == 'sysu':
11043        if verbose:
11044            print('sysu')
11045        probability_mask = antspynet.sysu_media_wmh_segmentation( flair )
11046    elif probability_mask is None and model == 'hyper':
11047        if verbose:
11048            print('hyper')
11049        probability_mask = antspynet.hypermapp3r_segmentation( t1_2_flair_reg['warpedmovout'], flair )
11050    # t1_2_flair_reg = tra_initializer( flair, t1, n_simulations=4, max_rotation=5, transform=['rigid'], verbose=False )
11051    prior_probability_flair = None
11052    if prior_probability is not None:
11053        prior_probability_flair = ants.apply_transforms( flair, prior_probability,
11054            t1_2_flair_reg['fwdtransforms'] )
11055    wmseg_mask = ants.threshold_image( t1seg,
11056        low_thresh = 3, high_thresh = 3).iMath("FillHoles")
11057    wmseg_mask_use = ants.image_clone( wmseg_mask )
11058    distmask = None
11059    if mmfromconvexhull > 0:
11060            convexhull = ants.threshold_image( t1seg, 1, 4 )
11061            spc2vox = np.prod( ants.get_spacing( t1seg ) )
11062            voxdist = 0.0
11063            myspc = ants.get_spacing( t1seg )
11064            for k in range( t1seg.dimension ):
11065                voxdist = voxdist + myspc[k] * myspc[k]
11066            voxdist = math.sqrt( voxdist )
11067            nmorph = round( 2.0 / voxdist )
11068            convexhull = ants.morphology( convexhull, "close", nmorph ).iMath("FillHoles")
11069            dist = ants.iMath( convexhull, "MaurerDistance" ) * -1.0
11070            distmask = ants.threshold_image( dist, mmfromconvexhull, 1.e80 )
11071            wmseg_mask = wmseg_mask + distmask
11072            if strict:
11073                wmseg_mask_use = ants.threshold_image( wmseg_mask, 2, 2 )
11074            else:
11075                wmseg_mask_use = ants.threshold_image( wmseg_mask, 1, 2 )
11076    ##############################################################################
11077    wmseg_2_flair = ants.apply_transforms(flair, wmseg_mask_use,
11078        transformlist = t1_2_flair_reg['fwdtransforms'],
11079        interpolator = 'nearestNeighbor' )
11080    seg_2_flair = ants.apply_transforms(flair, t1seg,
11081        transformlist = t1_2_flair_reg['fwdtransforms'],
11082        interpolator = 'nearestNeighbor' )
11083    csfmask = ants.threshold_image(seg_2_flair,1,1)
11084    flairsnr = mask_snr( flair, csfmask, wmseg_2_flair, bias_correct = False )
11085    probability_mask_WM = wmseg_2_flair * probability_mask # Remove WMH signal outside of WM
11086    wmh_sum = np.prod( ants.get_spacing( flair ) ) * probability_mask_WM.sum()
11087    wmh_sum_prior = math.nan
11088    probability_mask_posterior = None
11089    if prior_probability_flair is not None:
11090        probability_mask_posterior = prior_probability_flair * probability_mask # use prior
11091        wmh_sum_prior = np.prod( ants.get_spacing(flair) ) * probability_mask_posterior.sum()
11092    if math.isnan( wmh_sum ):
11093        wmh_sum=0
11094    if math.isnan( wmh_sum_prior ):
11095        wmh_sum_prior=0
11096    flair_evr = antspyt1w.patch_eigenvalue_ratio( flair, 512, [16,16,16], evdepth = 0.9, mask=wmseg_2_flair )
11097    return{
11098        'WMH_probability_map_raw': probability_mask,
11099        'WMH_probability_map' : probability_mask_WM,
11100        'WMH_posterior_probability_map' : probability_mask_posterior,
11101        'wmh_mass': wmh_sum,
11102        'wmh_mass_prior': wmh_sum_prior,
11103        'wmh_evr' : flair_evr,
11104        'wmh_SNR' : flairsnr,
11105        'convexhull_mask': distmask }

Outputs the WMH probability mask and a summary single measurement

Arguments

flair : ANTsImage input 3-D FLAIR brain image (not skull-stripped).

t1 : ANTsImage input 3-D T1 brain image (not skull-stripped).

t1seg : ANTsImage T1 segmentation image

mmfromconvexhull : float restrict WMH to regions that are WM or mmfromconvexhull mm away from the convex hull of the cerebrum. we choose a default value based on Figure 4 from: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6240579/pdf/fnagi-10-00339.pdf

strict: boolean - if True, only use convex hull distance

probability_mask : None - use to compute wmh just once - then this function just does refinement and summary

prior_probability : optional prior probability image in space of the input t1

model : either sysu or hyper

verbose : boolean

Returns

WMH probability map and a summary single measurement which is the sum of the WMH map
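
A short hedged usage sketch (the file names are placeholders; t1seg is assumed to be a tissue segmentation in T1 space whose label 1 marks CSF and label 3 marks white matter, as the function expects):

import ants
import antspymm

flair = ants.image_read('flair.nii.gz')
t1 = ants.image_read('t1.nii.gz')
t1seg = ants.image_read('t1_segmentation.nii.gz')
wmhres = antspymm.wmh(flair, t1, t1seg, model='sysu', verbose=True)
print(wmhres['wmh_mass'])   # spacing-weighted sum of the WM-restricted probability map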

def remove_elements_from_numpy_array(original_array, indices_to_remove):
11148def remove_elements_from_numpy_array(original_array, indices_to_remove):
11149    """
11150    Remove specified elements or rows from a numpy array.
11151
11152    Parameters:
11153    original_array (numpy.ndarray): A numpy array from which elements or rows are to be removed.
11154    indices_to_remove (list or numpy.ndarray): Indices of elements or rows to be removed.
11155
11156    Returns:
11157    numpy.ndarray: A new numpy array with the specified elements or rows removed. If the input array is None,
11158                   the function returns None.
11159    """
11160
11161    if original_array is None:
11162        return None
11163
11164    if original_array.ndim == 1:
11165        # Remove elements from a 1D array
11166        return np.delete(original_array, indices_to_remove)
11167    elif original_array.ndim == 2:
11168        # Remove rows from a 2D array
11169        return np.delete(original_array, indices_to_remove, axis=0)
11170    else:
11171        raise ValueError("original_array must be either 1D or 2D.")

Remove specified elements or rows from a numpy array.

Parameters
  • original_array (numpy.ndarray): A numpy array from which elements or rows are to be removed.
  • indices_to_remove (list or numpy.ndarray): Indices of elements or rows to be removed.
Returns

numpy.ndarray: A new numpy array with the specified elements or rows removed. If the input array is None, the function returns None.
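
A small illustrative sketch with toy data, e.g. dropping the same volumes from matched bvals/bvecs arrays:

    import numpy as np
    from antspymm import remove_elements_from_numpy_array

    bvals = np.array([0, 1000, 1000, 2000, 1000])   # 1D: elements are removed
    bvecs = np.random.rand(5, 3)                     # 2D: whole rows are removed

    kept_bvals = remove_elements_from_numpy_array(bvals, [0, 3])
    kept_bvecs = remove_elements_from_numpy_array(bvecs, [0, 3])
    print(kept_bvals.shape, kept_bvecs.shape)        # (3,) (3, 3)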

def score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg):
11474def score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg ):
11475    """
11476    Process CBF time series to remove high-leverage points.
11477    Derived from the SCORE algorithm by Sudipto Dolui et al.
11478
11479    Parameters:
11480    cbfts (ANTsImage): 4D ANTsImage of CBF time series.
11481    csf_seg (ANTsImage): CSF binary map.
11482    gm_seg (ANTsImage): Gray matter binary map.
11483    wm_seg (ANTsImage): WM binary map.
11484
11485    Returns:
11486    ANTsImage: Processed CBF time series.
11487    ndarray: Index of removed volumes.
11488    """
11489    
11490    n_gm_voxels = np.sum(gm_seg.numpy()) - 1
11491    n_wm_voxels = np.sum(wm_seg.numpy()) - 1
11492    n_csf_voxels = np.sum(csf_seg.numpy()) - 1
11493    mask1img = gm_seg + wm_seg + csf_seg
11494    mask1 = (mask1img==1).numpy()
11495    
11496    cbfts_np = cbfts.numpy()
11497    gmbool = (gm_seg==1).numpy()
11498    csfbool = (csf_seg==1).numpy()
11499    wmbool = (wm_seg==1).numpy()
11500    gm_cbf_ts = ants.timeseries_to_matrix( cbfts, gm_seg )
11501    gm_cbf_ts = np.squeeze(np.mean(gm_cbf_ts, axis=1))
11502    
11503    median_gm_cbf = np.median(gm_cbf_ts)
11504    mad_gm_cbf = np.median(np.abs(gm_cbf_ts - median_gm_cbf)) / 0.675
11505    indx = np.abs(gm_cbf_ts - median_gm_cbf) > (2.5 * mad_gm_cbf)  # initial censoring: mean GM CBF deviates by > 2.5 robust SDs from the median
11506    
11507    # the spatial mean
11508    spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11509    spatmean = ants.from_numpy( spatmeannp )
11510    V = (
11511        n_gm_voxels * np.var(spatmeannp[gmbool])
11512        + n_wm_voxels * np.var(spatmeannp[wmbool])
11513        + n_csf_voxels * np.var(spatmeannp[csfbool])
11514    )
11515    V1 = math.inf
11516    ct=0
11517    while V < V1:  # keep censoring volumes until the tissue-weighted variance of the mean CBF stops decreasing
11518        ct=ct+1
11519        V1 = V
11520        CC = np.zeros(cbfts_np.shape[3])
11521        for s in range(cbfts_np.shape[3]):
11522            if indx[s]:
11523                continue
11524            tmp1 = ants.from_numpy( cbfts_np[:, :, :, s] )
11525            CC[s] = ants.image_similarity( spatmean, tmp1, metric_type='Correlation', fixed_mask=mask1img )
11526        inx = np.argmin(CC)
11527        indx[inx] = True
11528        spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11529        spatmean = ants.from_numpy( spatmeannp )
11530        V = (
11531          n_gm_voxels * np.var(spatmeannp[gmbool]) + 
11532          n_wm_voxels * np.var(spatmeannp[wmbool]) + 
11533          n_csf_voxels * np.var(spatmeannp[csfbool])
11534        )
11535    cbfts_recon = cbfts_np[:, :, :, ~indx]
11536    cbfts_recon = np.nan_to_num(cbfts_recon)
11537    cbfts_recon_ants = ants.from_numpy(cbfts_recon)
11538    cbfts_recon_ants = ants.copy_image_info(cbfts, cbfts_recon_ants)
11539    return cbfts_recon_ants, indx

Process CBF time series to remove high-leverage points. Derived from the SCORE algorithm by Sudipto Dolui et al.

Parameters
  • cbfts (ANTsImage): 4D ANTsImage of CBF time series.
  • csf_seg (ANTsImage): CSF binary map.
  • gm_seg (ANTsImage): Gray matter binary map.
  • wm_seg (ANTsImage): WM binary map.
Returns

ANTsImage: Processed CBF time series.
ndarray: Boolean index of removed volumes.
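
A usage sketch with hypothetical inputs (a 4-D CBF series and binary tissue masks in the same space):

    import ants
    import numpy as np
    import antspymm

    cbfts = ants.image_read("sub-01_cbf_timeseries.nii.gz")  # hypothetical 4-D CBF series
    csf_seg = ants.image_read("sub-01_csf_mask.nii.gz")
    gm_seg = ants.image_read("sub-01_gm_mask.nii.gz")
    wm_seg = ants.image_read("sub-01_wm_mask.nii.gz")

    cbf_clean, censored = antspymm.score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg)
    print("volumes removed:", np.where(censored)[0])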

def remove_volumes_from_timeseries(time_series, volumes_to_remove):
11173def remove_volumes_from_timeseries(time_series, volumes_to_remove):
11174    """
11175    Remove specified volumes from a time series.
11176
11177    :param time_series: ANTsImage representing the time series (4D image).
11178    :param volumes_to_remove: List of volume indices to remove.
11179    :return: ANTsImage with specified volumes removed.
11180    """
11181    if not isinstance(time_series, ants.core.ants_image.ANTsImage):
11182        raise ValueError("time_series must be an ANTsImage.")
11183
11184    if time_series.dimension != 4:
11185        raise ValueError("time_series must be a 4D image.")
11186
11187    # Create a boolean index for volumes to keep
11188    volumes_to_keep = [i for i in range(time_series.shape[3]) if i not in volumes_to_remove]
11189
11190    # Select the volumes to keep
11191    filtered_time_series = ants.from_numpy( time_series.numpy()[..., volumes_to_keep] )
11192
11193    return ants.copy_image_info( time_series, filtered_time_series )

Remove specified volumes from a time series.

Parameters
  • time_series: ANTsImage representing the time series (4D image).
  • volumes_to_remove: List of volume indices to remove.
Returns

ANTsImage with specified volumes removed.
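
A self-contained sketch on a toy 4-D image:

    import ants
    import numpy as np
    from antspymm import remove_volumes_from_timeseries

    ts = ants.from_numpy(np.random.rand(8, 8, 8, 10))   # toy 4-D time series
    trimmed = remove_volumes_from_timeseries(ts, [0, 5, 9])
    print(ts.shape, "->", trimmed.shape)                 # (8, 8, 8, 10) -> (8, 8, 8, 7)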

def loop_timeseries_censoring( x, threshold=0.5, mask=None, n_features_sample=0.02, seed=42, verbose=True):
11541def loop_timeseries_censoring(x, threshold=0.5, mask=None, n_features_sample=0.02, seed=42, verbose=True):
11542    """
11543    Censor high leverage volumes from a time series using Local Outlier Probabilities (LoOP).
11544
11545    Parameters:
11546    x (ANTsImage): A 4D time series image.
11547    threshold (float): Threshold for determining high leverage volumes based on LoOP scores.
11548    mask (antsImage): restricts to a ROI
11549    n_features_sample (int/float): feature sample size, default 0.02; values less than one are interpreted as a fraction of the total features, otherwise as the number of features to use
11550    seed (int): random seed
11551    verbose (bool)
11552
11553    Returns:
11554    tuple: A tuple containing the censored time series (ANTsImage) and the indices of the high leverage volumes.
11555    """
11556    import warnings
11557    if x.shape[3] < 20: # just a guess at what we need here ...
11558        warnings.warn("the time dimension is < 20 - too few samples for LoOP; returning the original data unchanged.")
11559        return x, []
11560    if mask is None:
11561        flattened_series = flatten_time_series(x.numpy())
11562    else:
11563        flattened_series = ants.timeseries_to_matrix( x, mask )
11564    if verbose:
11565        print("loop_timeseries_censoring: flattened")
11566    loop_scores = calculate_loop_scores(flattened_series, n_features_sample=n_features_sample, seed=seed, verbose=verbose )
11567    high_leverage_volumes = np.where(loop_scores > threshold)[0]
11568    if verbose:
11569        print("loop_timeseries_censoring: High Leverage Volumes:", high_leverage_volumes)
11570    new_asl = remove_volumes_from_timeseries(x, high_leverage_volumes)
11571    return new_asl, high_leverage_volumes

Censor high leverage volumes from a time series using Local Outlier Probabilities (LoOP).

Parameters
  • x (ANTsImage): A 4D time series image.
  • threshold (float): Threshold for determining high leverage volumes based on LoOP scores.
  • mask (ANTsImage): restricts the computation to a ROI.
  • n_features_sample (int/float): feature sample size, default 0.02; values less than one are interpreted as a fraction of the total features, otherwise as the number of features to use.
  • seed (int): random seed.
  • verbose (bool)
Returns

tuple: A tuple containing the censored time series (ANTsImage) and the indices of the high leverage volumes.
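
A usage sketch with a hypothetical BOLD file; the brain mask is optional but reduces the feature space:

    import ants
    import antspymm

    bold = ants.image_read("sub-01_bold.nii.gz")                # hypothetical 4-D fMRI
    mask = ants.get_mask(ants.get_average_of_timeseries(bold))  # restrict LoOP to the brain

    censored_bold, bad_volumes = antspymm.loop_timeseries_censoring(
        bold, threshold=0.5, mask=mask, n_features_sample=0.02, verbose=True)
    print("high leverage volumes:", bad_volumes)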

def clean_tmp_directory( age_hours=1.0, use_sudo=False, extensions=['.nii', '.nii.gz'], log_file_path=None):
469def clean_tmp_directory(age_hours=1., use_sudo=False, extensions=[ '.nii', '.nii.gz' ], log_file_path=None):
470    """
471    Clean the /tmp directory by removing files and directories older than a certain number of hours.
472    Optionally uses sudo and can filter files by extensions.
473
474    :param age_hours: Age in hours to consider files and directories for deletion.
475    :param use_sudo: Whether to use sudo for removal commands.
476    :param extensions: List of file extensions to delete. If None, all files are considered.
477    :param log_file_path: Path to the log file. If None, a default path will be used based on the OS.
478
479    # Usage
480    # Example: clean_tmp_directory(age_hours=1, use_sudo=True, extensions=['.log', '.tmp'])
481    """
482    import os
483    import platform
484    import subprocess
485    from datetime import datetime, timedelta
486
487    if not isinstance(age_hours, (int, float)):  # accept ints as well, e.g. age_hours=1
488        return
489
490    # Determine the tmp directory based on the operating system
491    tmp_dir = '/tmp'
492
493    # Set the log file path
494    if log_file_path is not None:
495        log_file = log_file_path
496
497    current_time = datetime.now()
498    for item in os.listdir(tmp_dir):
499        try:
500            item_path = os.path.join(tmp_dir, item)
501            item_stat = os.stat(item_path)
502
503            # Calculate the age of the file/directory
504            item_age = current_time - datetime.fromtimestamp(item_stat.st_mtime)
505            if item_age > timedelta(hours=age_hours):
506                # Check for file extensions if provided
507                if extensions is None or any(item.endswith(ext) for ext in extensions):
508                    # Construct the removal command
509                    rm_command = ['sudo', 'rm', '-rf', item_path] if use_sudo else ['rm', '-rf', item_path]
510                    subprocess.run(rm_command)
511
512                if log_file_path is not None:
513                    with open(log_file, 'a') as log:
514                        log.write(f"{datetime.now()}: Deleted {item_path}\n")
515        except Exception as e:
516            if log_file_path is not None:
517                with open(log_file, 'a') as log:
518                    log.write(f"{datetime.now()}: Error deleting {item_path}: {e}\n")

Clean the /tmp directory by removing files and directories older than a certain number of hours. Optionally uses sudo and can filter files by extensions.

Parameters
  • age_hours: Age in hours to consider files and directories for deletion.
  • use_sudo: Whether to use sudo for removal commands.
  • extensions: List of file extensions to delete. If None, all files are considered.
  • log_file_path: Path to the log file. If None, a default path will be used based on the OS.

Usage

Example: clean_tmp_directory(age_hours=1, use_sudo=True, extensions=['.log', '.tmp'])
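
For instance, to remove stale NIfTI scratch files and keep a log of what was deleted (paths are illustrative):

    from antspymm import clean_tmp_directory

    # delete NIfTI files in /tmp older than two hours, logging each removal
    clean_tmp_directory(age_hours=2.0, extensions=['.nii', '.nii.gz'],
                        log_file_path='/tmp/antspymm_cleanup.log')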

def validate_nrg_file_format(path, separator):
189def validate_nrg_file_format(path, separator):
190    """
191    is your path nrg-etic?
192    Validates if a given path conforms to the NRG file format, taking into account known extensions
193    and the expected directory structure.
194
195    :param path: The file path to validate.
196    :param separator: The separator used in the filename and directory structure.
197    :return: A tuple (bool, str) indicating whether the path is valid and a message explaining the validation result.
198
199    : example
200
201    ntfn='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses-1/rsfMRI_RL/000/ANTsLIMBIC_sub08C105120Yr_ses-1_rsfMRI_RL_000.nii.gz'
202    ntfngood='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses_1/rsfMRI_RL/000/ANTsLIMBIC-sub08C105120Yr-ses_1-rsfMRI_RL-000.nii.gz'
203
204    validate_nrg_file_format(ntfngood, '-')
205    print( validate_nrg_file_format(ntfn, '-') )
206    print( validate_nrg_file_format(ntfn, '_') )
207
208    """
209    import re    
210
211    def normalize_path(path):
212        """
213        Replace multiple repeated '/' with just a single '/'
214        
215        :param path: The file path to normalize.
216        :return: The normalized file path with single '/'.
217        """
218        normalized_path = re.sub(r'/+', '/', path)
219        return normalized_path
220
221    def strip_known_extension(filename, known_extensions):
222        """
223        Strips a known extension from the filename.
224
225        :param filename: The filename from which to strip the extension.
226        :param known_extensions: A list of known extensions to strip from the filename.
227        :return: The filename with the known extension stripped off, if found.
228        """
229        for ext in known_extensions:
230            if filename.endswith(ext):
231                # Strip the extension and return the modified filename
232                return filename[:-len(ext)]
233        # If no known extension is found, return the original filename
234        return filename
235
236    import warnings
237    if normalize_path( path ) != path:
238        path = normalize_path( path )
239        warnings.warn("Probably had multiple repeated slashes eg /// in the file path.  this might cause issues. clean up with re.sub(r'/+', '/', path)")
240
241    known_extensions = [".nii.gz", ".nii", ".mhd", ".nrrd", ".mha", ".json", ".bval", ".bvec"]
242    known_extensions2 = [ext.lstrip('.') for ext in known_extensions]
243    def get_extension(filename, known_extensions ):
244        # List of known extensions in priority order
245        for ext in known_extensions:
246            if filename.endswith(ext):
247                return ext.strip('.')
248        return "Invalid extension"
249    
250    parts = path.split('/')
251    if len(parts) < 7:  # Checking for minimum path structure
252        return False, "Path structure is incomplete. Expected at least 7 components, found {}.".format(len(parts))
253    
254    # Extract directory components and filename
255    directory_components = parts[1:-1]  # Exclude the root '/' and filename
256    filename = parts[-1]
257    filename_without_extension = strip_known_extension( filename, known_extensions )
258    file_extension = get_extension( filename, known_extensions )
259    
260    # Validating file extension
261    if file_extension not in known_extensions2:
262        print( file_extension )
263        return False, "Invalid file extension: {}. Expected one of: {}.".format(file_extension, ", ".join(known_extensions2))
264    
265    # Splitting the filename to validate individual parts
266    filename_parts = filename_without_extension.split(separator)
267    if len(filename_parts) != 5:  # Expecting 5 parts based on the NRG format
268        print( filename_parts )
269        return False, "Filename does not have exactly 5 parts separated by '{}'. Found {} parts.".format(separator, len(filename_parts))
270    
271    # Reconstruct expected filename from directory components
272    expected_filename_parts = directory_components[-5:]
273    expected_filename = separator.join(expected_filename_parts)
274    if filename_without_extension != expected_filename:
275        print( filename_without_extension )
276        print("--- vs expected ---")
277        print( expected_filename )
278        return False, "Filename structure does not match directory structure. Expected filename: {}.".format(expected_filename)
279    
280    # Validate directory structure against NRG format
281    study_name, subject_id, session, modality = directory_components[-4:-1] + [directory_components[-1].split('/')[0]]
282    if not all([study_name, subject_id, session, modality]):
283        return False, "Directory structure does not follow NRG format. Ensure StudyName, SubjectID, Session (ses_x), and Modality are correctly specified."
284    
285    # If all checks pass
286    return True, "The path conforms to the NRG format."

is your path nrg-etic? Validates if a given path conforms to the NRG file format, taking into account known extensions and the expected directory structure.

Parameters
  • path: The file path to validate.
  • separator: The separator used in the filename and directory structure.
Returns

A tuple (bool, str) indicating whether the path is valid and a message explaining the validation result.

Example

ntfn='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses-1/rsfMRI_RL/000/ANTsLIMBIC_sub08C105120Yr_ses-1_rsfMRI_RL_000.nii.gz'
ntfngood='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses_1/rsfMRI_RL/000/ANTsLIMBIC-sub08C105120Yr-ses_1-rsfMRI_RL-000.nii.gz'

validate_nrg_file_format(ntfngood, '-')
print( validate_nrg_file_format(ntfn, '-') )
print( validate_nrg_file_format(ntfn, '_') )
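
Another small sketch with made-up paths, showing a conforming and a non-conforming case:

    from antspymm import validate_nrg_file_format

    good = '/data/NRG/MyStudy/sub001/ses_1/T1w/000/MyStudy-sub001-ses_1-T1w-000.nii.gz'
    bad  = '/data/NRG/MyStudy/sub001/ses_1/T1w/000/MyStudy_sub001_T1w.nii.gz'

    print(validate_nrg_file_format(good, '-'))  # (True, 'The path conforms to the NRG format.')
    print(validate_nrg_file_format(bad, '-'))   # (False, explanation of the mismatch)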

def ants_to_nibabel_affine(ants_img):
390def ants_to_nibabel_affine(ants_img):
391    """
392    Convert an ANTsPy image (in LPS space) to a Nibabel-compatible affine (in RAS space).
393    Handles 2D, 3D, 4D input (only spatial dimensions are encoded in the affine).
394    
395    Returns:
396        4x4 np.ndarray affine matrix in RAS space.
397    """
398    spatial_dim = ants_img.dimension
399    spacing = np.array(ants_img.spacing)
400    origin = np.array(ants_img.origin)
401    direction = np.array(ants_img.direction).reshape((spatial_dim, spatial_dim))
402    # Compute rotation-scale matrix
403    affine_linear = direction @ np.diag(spacing)
404    # Build full 4x4 affine with identity in homogeneous bottom row
405    affine = np.eye(4)
406    affine[:spatial_dim, :spatial_dim] = affine_linear
407    affine[:spatial_dim, 3] = origin
408    affine[3, 3]=1
409    # Convert LPS -> RAS by flipping x and y
410    lps_to_ras = np.diag([-1, -1, 1, 1])
411    affine = lps_to_ras @ affine
412    return affine

Convert an ANTsPy image (in LPS space) to a Nibabel-compatible affine (in RAS space). Handles 2D, 3D, 4D input (only spatial dimensions are encoded in the affine).

Returns: 4x4 np.ndarray affine matrix in RAS space.
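
A sketch converting a sample ANTsPy image to a nibabel NIfTI, assuming nibabel is available:

    import ants
    import nibabel as nib
    from antspymm import ants_to_nibabel_affine

    img = ants.image_read(ants.get_ants_data('mni'))   # sample 3-D image shipped with ANTsPy
    affine = ants_to_nibabel_affine(img)                # RAS affine for nibabel
    nib.save(nib.Nifti1Image(img.numpy(), affine), 'mni_from_ants.nii.gz')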

def dict_to_dataframe( data_dict, convert_lists=True, convert_arrays=True, convert_images=True, verbose=False):
415def dict_to_dataframe(data_dict, convert_lists=True, convert_arrays=True, convert_images=True, verbose=False):
416    """
417    Convert a dictionary to a pandas DataFrame, excluding items that cannot be processed by pandas.
418
419    :param data_dict: Dictionary to be converted.
420    :param convert_lists: boolean
421    :param convert_arrays: boolean
422    :param convert_images: boolean
423    :param verbose: boolean
424    :return: DataFrame representation of the dictionary.
425    """
426    processed_data = {}
427    list_length = None
428    def mean_of_list(lst):
429        if not lst:  # Check if the list is not empty
430            return 0  # Return 0 or appropriate value for an empty list
431        all_numeric = all(isinstance(item, (int, float)) for item in lst)
432        if all_numeric:
433            return sum(lst) / len(lst)
434        return None
435    
436    for key, value in data_dict.items():
437        # Check if value is a scalar
438        if isinstance(value, (int, float, str, bool)):
439            processed_data[key] = [value]
440        # Check if value is a list of scalars
441        elif isinstance(value, list) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_lists:
442            meanvalue = mean_of_list( value )
443            newkey = key+"_mean"
444            if verbose:
445                print( " Key " + key + " is list with mean " + str(meanvalue) + " to " + newkey )
446            if newkey not in data_dict.keys() and convert_lists:
447                processed_data[newkey] = meanvalue
448        elif isinstance(value, np.ndarray) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_arrays:
449            meanvalue = value.mean()
450            newkey = key+"_mean"
451            if verbose:
452                print( " Key " + key + " is nparray with mean " + str(meanvalue) + " to " + newkey )
453            if newkey not in data_dict.keys():
454                processed_data[newkey] = meanvalue
455        elif isinstance(value, ants.core.ants_image.ANTsImage ) and convert_images:
456            meanvalue = value.mean()
457            newkey = key+"_mean"
458            if newkey not in data_dict.keys():
459                if verbose:
460                    print( " Key " + key + " is antsimage with mean " + str(meanvalue) + " to " + newkey )
461                processed_data[newkey] = meanvalue
462            else:
463                if verbose:
464                    print( " Key " + key + " is antsimage with mean " + str(meanvalue) + " but " + newkey + " already exists" )
465
466    return pd.DataFrame.from_dict(processed_data)

Convert a dictionary to a pandas DataFrame, excluding items that cannot be processed by pandas.

Parameters
  • data_dict: Dictionary to be converted.
  • convert_lists: boolean
  • convert_arrays: boolean
  • convert_images: boolean
  • verbose: boolean
Returns

DataFrame representation of the dictionary.
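
A toy sketch showing which entries survive the conversion (scalars are kept as columns; numeric lists, arrays, and images become *_mean columns):

    import numpy as np
    import antspymm

    summary = {
        'id': 'sub-01',                        # scalar -> kept as a column
        'snr': 14.2,                           # scalar -> kept as a column
        'dvars': [1.1, 0.9, 1.4],              # numeric list -> 'dvars_mean' column
        'motion': np.array([0.1, 0.2, 0.3]),   # ndarray -> 'motion_mean' column
    }
    df = antspymm.dict_to_dataframe(summary, verbose=True)
    print(df)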