antspymm.mm

    1__all__ = ['version',
    2    'mm_read',
    3    'mm_read_to_3d',
    4    'image_write_with_thumbnail',
    5    'nrg_format_path',
    6    'highest_quality_repeat',
    7    'match_modalities',
    8    'mc_resample_image_to_target',
    9    'nrg_filelist_to_dataframe',
   10    'merge_timeseries_data',
   11    'timeseries_reg',
   12    'merge_dwi_data',
   13    'outlierness_by_modality',
   14    'bvec_reorientation',
   15    'get_dti',
   16    'dti_reg',
   17    'mc_reg',
   18    'get_data',
   19    'get_models',
   20    'get_valid_modalities',
   21    'dewarp_imageset',
   22    'super_res_mcimage',
   23    'segment_timeseries_by_meanvalue',
   24    'get_average_rsf',
   25    'get_average_dwi_b0',
   26    'dti_template',
   27    't1_based_dwi_brain_extraction',
   28    'mc_denoise',
   29    'tsnr',
   30    'dvars',
   31    'slice_snr',
   32    'impute_fa',
   33    'trim_dti_mask',
   34    'dipy_dti_recon',
   35    'concat_dewarp',
   36    'joint_dti_recon',
   37    'middle_slice_snr',
   38    'foreground_background_snr',
   39    'quantile_snr',
   40    'mask_snr',
   41    'dwi_deterministic_tracking',
   42    'dwi_closest_peak_tracking',
   43    'dwi_streamline_pairwise_connectivity',
   44    'dwi_streamline_connectivity',
   45    'hierarchical_modality_summary',
   46    'tra_initializer',
   47    'neuromelanin',
   48    'resting_state_fmri_networks',
   49    'write_bvals_bvecs',
   50    'crop_mcimage',
   51    'mm',
   52    'write_mm',
   53    'mm_nrg',
   54    'mm_csv',
   55    'collect_blind_qc_by_modality',
   56    'alffmap',
   57    'alff_image',
   58    'down2iso',
   59    'read_mm_csv',
   60    'assemble_modality_specific_dataframes',
   61    'bind_wide_mm_csvs',
   62    'merge_mm_dataframe',
   63    'augment_image',
   64    'boot_wmh',
   65    'threaded_bind_wide_mm_csvs',
   66    'get_names_from_data_frame',
   67    'average_mm_df',
   68    'quick_viz_mm_nrg',
   69    'blind_image_assessment',
   70    'average_blind_qc_by_modality',
   71    'best_mmm',
   72    'nrg_2_bids',
   73    'bids_2_nrg',
   74    'parse_nrg_filename',
   75    'novelty_detection_svm',
   76    'novelty_detection_ee',
   77    'novelty_detection_lof',
   78    'novelty_detection_loop',
   79    'novelty_detection_quantile',
   80    'generate_mm_dataframe',
   81    'aggregate_antspymm_results',
   82    'aggregate_antspymm_results_sdf',
   83    'study_dataframe_from_matched_dataframe',
   84    'merge_wides_to_study_dataframe',
   85    'filter_image_files',
   86    'docsamson',
   87    'enantiomorphic_filling_without_mask',
   88    'wmh',
   89    'remove_elements_from_numpy_array',
   90    'score_fmri_censoring',
   91    'remove_volumes_from_timeseries',
   92    'loop_timeseries_censoring',
   93    'clean_tmp_directory',
   94    'validate_nrg_file_format',
   95    'ants_to_nibabel_affine',
   96    'dict_to_dataframe']
   97
   98from pathlib import Path
   99from pathlib import PurePath
  100import os
  101import pandas as pd
  102import math
  103import os.path
  104from os import path
  105import pickle
  106import sys
  107import numpy as np
  108import random
  109import functools
  110from operator import mul
  111from scipy.sparse.linalg import svds
  112from scipy.stats import pearsonr
  113import re
  114import datetime as dt
  115from collections import Counter
  116import tempfile
  117import uuid
  118import warnings
  119
  120from dipy.core.histeq import histeq
  121import dipy.reconst.dti as dti
  122from dipy.core.gradients import (gradient_table, gradient_table_from_gradient_strength_bvecs)
  123from dipy.io.gradients import read_bvals_bvecs
  124from dipy.segment.mask import median_otsu
  125from dipy.reconst.dti import fractional_anisotropy, color_fa
  126import nibabel as nib
  127
  128import ants
  129import antspynet
  130import antspyt1w
  131import siq
  132import tensorflow as tf
  133
  134from multiprocessing import Pool
  135import glob as glob
  136
  137antspyt1w.set_global_scientific_computing_random_seed(
  138    antspyt1w.get_global_scientific_computing_random_seed( )
  139)
  140
  141DATA_PATH = os.path.expanduser('~/.antspymm/')
  142
  143def version( ):
  144    """
  145    report versions of this package and primary dependencies
  146
  147    Arguments
  148    ---------
  149    None
  150
  151    Returns
  152    -------
  153    a dictionary with package name and versions
  154
  155    Example
  156    -------
  157    >>> import antspymm
  158    >>> antspymm.version()
  159    """
  160    import pkg_resources
  161    return {
  162              'tensorflow': pkg_resources.get_distribution("tensorflow").version,
  163              'antspyx': pkg_resources.get_distribution("antspyx").version,
  164              'antspynet': pkg_resources.get_distribution("antspynet").version,
  165              'antspyt1w': pkg_resources.get_distribution("antspyt1w").version,
  166              'antspymm': pkg_resources.get_distribution("antspymm").version
  167              }
  168
  169def nrg_filename_to_subjectvisit(s, separator='-'):
  170    """
  171    Extracts a pattern from the input string.
  172    
  173    Parameters:
  174    - s: The input string from which to extract the pattern.
  175    - separator: The separator used in the string (default is '-').
  176    
  177    Returns:
  178    - A string in the format of 'PREFIX-Number-Date'
  179    """
  180    parts = os.path.basename(s).split(separator)
  181    # Assuming the pattern is always in the form of PREFIX-Number-Date-...
  182    # and PREFIX is always "PPMI", extract the first three parts
  183    extracted = separator.join(parts[:3])
  184    return extracted
  185
  186
  187def validate_nrg_file_format(path, separator):
  188    """
  189    is your path nrg-etic?
  190    Validates if a given path conforms to the NRG file format, taking into account known extensions
  191    and the expected directory structure.
  192
  193    :param path: The file path to validate.
  194    :param separator: The separator used in the filename and directory structure.
  195    :return: A tuple (bool, str) indicating whether the path is valid and a message explaining the validation result.
  196
  197    : example
  198
  199    ntfn='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses-1/rsfMRI_RL/000/ANTsLIMBIC_sub08C105120Yr_ses-1_rsfMRI_RL_000.nii.gz'
  200    ntfngood='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses_1/rsfMRI_RL/000/ANTsLIMBIC-sub08C105120Yr-ses_1-rsfMRI_RL-000.nii.gz'
  201
  202    validate_nrg_detailed(ntfngood, '-')
  203    print( validate_nrg_detailed(ntfn, '-') )
  204    print( validate_nrg_detailed(ntfn, '_') )
  205
  206    """
  207    import re    
  208
  209    def normalize_path(path):
  210        """
  211        Replace multiple repeated '/' with just a single '/'
  212        
  213        :param path: The file path to normalize.
  214        :return: The normalized file path with single '/'.
  215        """
  216        normalized_path = re.sub(r'/+', '/', path)
  217        return normalized_path
  218
  219    def strip_known_extension(filename, known_extensions):
  220        """
  221        Strips a known extension from the filename.
  222
  223        :param filename: The filename from which to strip the extension.
  224        :param known_extensions: A list of known extensions to strip from the filename.
  225        :return: The filename with the known extension stripped off, if found.
  226        """
  227        for ext in known_extensions:
  228            if filename.endswith(ext):
  229                # Strip the extension and return the modified filename
  230                return filename[:-len(ext)]
  231        # If no known extension is found, return the original filename
  232        return filename
  233
  234    import warnings
  235    if normalize_path( path ) != path:
  236        path = normalize_path( path )
  237        warnings.warn("Probably had multiple repeated slashes eg /// in the file path.  this might cause issues. clean up with re.sub(r'/+', '/', path)")
  238
  239    known_extensions = [".nii.gz", ".nii", ".mhd", ".nrrd", ".mha", ".json", ".bval", ".bvec"]
  240    known_extensions2 = [ext.lstrip('.') for ext in known_extensions]
  241    def get_extension(filename, known_extensions ):
  242        # List of known extensions in priority order
  243        for ext in known_extensions:
  244            if filename.endswith(ext):
  245                return ext.strip('.')
  246        return "Invalid extension"
  247    
  248    parts = path.split('/')
  249    if len(parts) < 7:  # Checking for minimum path structure
  250        return False, "Path structure is incomplete. Expected at least 7 components, found {}.".format(len(parts))
  251    
  252    # Extract directory components and filename
  253    directory_components = parts[1:-1]  # Exclude the root '/' and filename
  254    filename = parts[-1]
  255    filename_without_extension = strip_known_extension( filename, known_extensions )
  256    file_extension = get_extension( filename, known_extensions )
  257    
  258    # Validating file extension
  259    if file_extension not in known_extensions2:
  260        print( file_extension )
  261        return False, "Invalid file extension: {}. Expected 'nii.gz' or 'json'.".format(file_extension)
  262    
  263    # Splitting the filename to validate individual parts
  264    filename_parts = filename_without_extension.split(separator)
  265    if len(filename_parts) != 5:  # Expecting 5 parts based on the NRG format
  266        print( filename_parts )
  267        return False, "Filename does not have exactly 5 parts separated by '{}'. Found {} parts.".format(separator, len(filename_parts))
  268    
  269    # Reconstruct expected filename from directory components
  270    expected_filename_parts = directory_components[-5:]
  271    expected_filename = separator.join(expected_filename_parts)
  272    if filename_without_extension != expected_filename:
  273        print( filename_without_extension )
  274        print("--- vs expected ---")
  275        print( expected_filename )
  276        return False, "Filename structure does not match directory structure. Expected filename: {}.".format(expected_filename)
  277    
  278    # Validate directory structure against NRG format
  279    study_name, subject_id, session, modality = directory_components[-4:-1] + [directory_components[-1].split('/')[0]]
  280    if not all([study_name, subject_id, session, modality]):
  281        return False, "Directory structure does not follow NRG format. Ensure StudyName, SubjectID, Session (ses_x), and Modality are correctly specified."
  282    
  283    # If all checks pass
  284    return True, "The path conforms to the NRG format."
  285
  286
  287
  288def apply_transforms_mixed_interpolation(
  289    fixed,
  290    moving,
  291    transformlist,
  292    interpolator="linear",
  293    imagetype=0,
  294    whichtoinvert=None,
  295    mask=None,
  296    **kwargs
  297):
  298    """
  299    Apply ANTs transforms with mixed interpolation:
  300    - Linear interpolation inside `mask`
  301    - Nearest neighbor outside `mask`
  302
  303    Parameters
  304    ----------
  305    fixed : ANTsImage
  306        Fixed/reference image to define spatial domain.
  307
  308    moving : ANTsImage
  309        Moving image to be transformed.
  310
  311    transformlist : list of str
  312        List of filenames for transforms.
  313
  314    interpolator : str, optional
  315        Interpolator used inside the mask. Default is "linear".
  316
  317    imagetype : int
  318        Image type used by ANTs (0 = scalar, 1 = vector, etc.)
  319
  320    whichtoinvert : list of bool, optional
  321        List of booleans indicating which transforms to invert.
  322
  323    mask : ANTsImage
  324        Binary mask image indicating where to apply `interpolator` (e.g., "linear").
  325        Outside the mask, nearest neighbor is used.
  326
  327    kwargs : dict
  328        Additional arguments passed to `ants.apply_transforms`.
  329
  330    Returns
  331    -------
  332    ANTsImage
  333        Interpolated image using mixed interpolation, added across masked regions.
  334    """
  335    if mask is None:
  336        raise ValueError("A binary `mask` image must be provided.")
  337
  338    # Apply linear interpolation inside the mask
  339    interp_linear = ants.apply_transforms(
  340        fixed=fixed,
  341        moving=moving,
  342        transformlist=transformlist,
  343        interpolator=interpolator,
  344        imagetype=imagetype,
  345        whichtoinvert=whichtoinvert,
  346        **kwargs
  347    )
  348
  349    # Apply nearest-neighbor interpolation everywhere
  350    interp_nn = ants.apply_transforms(
  351        fixed=fixed,
  352        moving=moving,
  353        transformlist=transformlist,
  354        interpolator="nearestNeighbor",
  355        imagetype=imagetype,
  356        whichtoinvert=whichtoinvert,
  357        **kwargs
  358    )
  359
  360    # Combine: linear * mask + nn * (1 - mask)
  361    mixed_result = (interp_linear * mask) + (interp_nn * (1 - mask))
  362
  363    return mixed_result
  364
  365def get_antsimage_keys(dictionary):
  366    """
  367    Return the keys of the dictionary where the values are ANTsImages.
  368
  369    :param dictionary: A dictionary to inspect
  370    :return: A list of keys for which the values are ANTsImages
  371    """
  372    return [key for key, value in dictionary.items() if isinstance(value, ants.core.ants_image.ANTsImage)]
  373
  374def to_nibabel(img: "ants.core.ants_image.ANTsImage") -> nib.Nifti1Image:
  375    """
  376    Convert an ANTsPy image to a Nibabel Nifti1Image in-memory, using correct spatial affine.
  377
  378    Parameters:
  379        img (ants.ANTsImage): An image from ANTsPy.
  380
  381    Returns:
  382        nib.Nifti1Image: The corresponding Nibabel image with spatial orientation in RAS.
  383    """
  384    array_data = img.numpy()  # get voxel data as NumPy array
  385    affine = ants_to_nibabel_affine(img)
  386    return nib.Nifti1Image(array_data, affine)
  387
  388def ants_to_nibabel_affine(ants_img):
  389    """
  390    Convert an ANTsPy image (in LPS space) to a Nibabel-compatible affine (in RAS space).
  391    Handles 2D, 3D, 4D input (only spatial dimensions are encoded in the affine).
  392    
  393    Returns:
  394        4x4 np.ndarray affine matrix in RAS space.
  395    """
  396    spatial_dim = ants_img.dimension
  397    spacing = np.array(ants_img.spacing)
  398    origin = np.array(ants_img.origin)
  399    direction = np.array(ants_img.direction).reshape((spatial_dim, spatial_dim))
  400    # Compute rotation-scale matrix
  401    affine_linear = direction @ np.diag(spacing)
  402    # Build full 4x4 affine with identity in homogeneous bottom row
  403    affine = np.eye(4)
  404    affine[:spatial_dim, :spatial_dim] = affine_linear
  405    affine[:spatial_dim, 3] = origin
  406    affine[3, 3]=1
  407    # Convert LPS -> RAS by flipping x and y
  408    lps_to_ras = np.diag([-1, -1, 1, 1])
  409    affine = lps_to_ras @ affine
  410    return affine
  411
  412
  413def dict_to_dataframe(data_dict, convert_lists=True, convert_arrays=True, convert_images=True, verbose=False):
  414    """
  415    Convert a dictionary to a pandas DataFrame, excluding items that cannot be processed by pandas.
  416
  417    :param data_dict: Dictionary to be converted.
  418    :param convert_lists: boolean
  419    :param convert_arrays: boolean
  420    :param convert_images: boolean
  421    :param verbose: boolean
  422    :return: DataFrame representation of the dictionary.
  423    """
  424    processed_data = {}
  425    list_length = None
  426    def mean_of_list(lst):
  427        if not lst:  # Check if the list is not empty
  428            return 0  # Return 0 or appropriate value for an empty list
  429        all_numeric = all(isinstance(item, (int, float)) for item in lst)
  430        if all_numeric:
  431            return sum(lst) / len(lst)
  432        return None
  433    
  434    for key, value in data_dict.items():
  435        # Check if value is a scalar
  436        if isinstance(value, (int, float, str, bool)):
  437            processed_data[key] = [value]
  438        # Check if value is a list of scalars
  439        elif isinstance(value, list) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_lists:
  440            meanvalue = mean_of_list( value )
  441            newkey = key+"_mean"
  442            if verbose:
  443                print( " Key " + key + " is list with mean " + str(meanvalue) + " to " + newkey )
  444            if newkey not in data_dict.keys() and convert_lists:
  445                processed_data[newkey] = meanvalue
  446        elif isinstance(value, np.ndarray) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_arrays:
  447            meanvalue = value.mean()
  448            newkey = key+"_mean"
  449            if verbose:
  450                print( " Key " + key + " is nparray with mean " + str(meanvalue) + " to " + newkey )
  451            if newkey not in data_dict.keys():
  452                processed_data[newkey] = meanvalue
  453        elif isinstance(value, ants.core.ants_image.ANTsImage ) and convert_images:
  454            meanvalue = value.mean()
  455            newkey = key+"_mean"
  456            if newkey not in data_dict.keys():
  457                if verbose:
  458                    print( " Key " + key + " is antsimage with mean " + str(meanvalue) + " to " + newkey )
  459                processed_data[newkey] = meanvalue
  460            else:
  461                if verbose:
  462                    print( " Key " + key + " is antsimage with mean " + str(meanvalue) + " but " + newkey + " already exists" )
  463
  464    return pd.DataFrame.from_dict(processed_data)
  465
  466
  467def clean_tmp_directory(age_hours=1., use_sudo=False, extensions=[ '.nii', '.nii.gz' ], log_file_path=None):
  468    """
  469    Clean the /tmp directory by removing files and directories older than a certain number of hours.
  470    Optionally uses sudo and can filter files by extensions.
  471
  472    :param age_hours: Age in hours to consider files and directories for deletion.
  473    :param use_sudo: Whether to use sudo for removal commands.
  474    :param extensions: List of file extensions to delete. If None, all files are considered.
  475    :param log_file_path: Path to the log file. If None, a default path will be used based on the OS.
  476
  477    # Usage
  478    # Example: clean_tmp_directory(age_hours=1, use_sudo=True, extensions=['.log', '.tmp'])
  479    """
  480    import os
  481    import platform
  482    import subprocess
  483    from datetime import datetime, timedelta
  484
  485    if not isinstance(age_hours, float):
  486        return
  487
  488    # Determine the tmp directory based on the operating system
  489    tmp_dir = '/tmp'
  490
  491    # Set the log file path
  492    if log_file_path is not None:
  493        log_file = log_file_path
  494
  495    current_time = datetime.now()
  496    for item in os.listdir(tmp_dir):
  497        try:
  498            item_path = os.path.join(tmp_dir, item)
  499            item_stat = os.stat(item_path)
  500
  501            # Calculate the age of the file/directory
  502            item_age = current_time - datetime.fromtimestamp(item_stat.st_mtime)
  503            if item_age > timedelta(hours=age_hours):
  504                # Check for file extensions if provided
  505                if extensions is None or any(item.endswith(ext) for ext in extensions):
  506                    # Construct the removal command
  507                    rm_command = ['sudo', 'rm', '-rf', item_path] if use_sudo else ['rm', '-rf', item_path]
  508                    subprocess.run(rm_command)
  509
  510                if log_file_path is not None:
  511                    with open(log_file, 'a') as log:
  512                        log.write(f"{datetime.now()}: Deleted {item_path}\n")
  513        except Exception as e:
  514            if log_file_path is not None:
  515                with open(log_file, 'a') as log:
  516                    log.write(f"{datetime.now()}: Error deleting {item_path}: {e}\n")
  517
  518
  519
  520def docsamson(locmod, studycsv, outputdir, projid, sid, dtid, mysep, t1iid=None, verbose=True):
  521    """
  522    Processes image file names based on the specified imaging modality and other parameters.
  523
  524    The function selects file names from the provided dictionary `studycsv` based on the imaging modality.
  525    It supports various modalities like T1w, T2Flair, perf, NM2DMT, rsfMRI, DTI, and configures the filenames accordingly.
  526    The function can optionally print verbose output during processing.
  527
  528    Parameters:
  529    locmod (str): The imaging modality. Options include 'T1w', 'T2Flair', 'perf', 'NM2DMT', 'rsfMRI', 'DTI'.
  530    studycsv (dict): A dictionary with keys corresponding to imaging modalities and values as file names.
  531    outputdir (str): Base directory for output files.
  532    projid (str): Project identifier.
  533    sid (str): Subject identifier.
  534    dtid (str): Data acquisition time identifier.
  535    mysep (str): Separator used in file naming.
  536    t1iid (str, optional): Identifier related to T1-weighted images, used in naming output files when locmod is not 'T1w'.
  537    verbose (bool, optional): If True, prints detailed information during execution.
  538
  539    Returns:
  540    dict: A dictionary with keys 'modality', 'outprefix', and 'images'.
  541        - 'modality' (str): The imaging modality used.
  542        - 'outprefix' (str): The prefix for output file paths.
  543        - 'images' (list): A list of processed image file names.
  544
  545    Notes:
  546    - The function is designed to work within a specific workflow and might require adaptation for general use.
  547
  548    Examples:
  549    >>> result = docsamson('T1w', studycsv, outputdir, projid, sid, dtid, mysep)
  550    >>> print(result['modality'])
  551    'T1w'
  552    >>> print(result['outprefix'])
  553    '/path/to/output/directory/T1w/some_identifier'
  554    >>> print(result['images'])
  555    ['image1.nii', 'image2.nii']
  556    """
  557
  558    import os
  559    import re
  560
  561    myimgsInput = []
  562    myoutputPrefix = None
  563    imfns = ['filename', 'rsfid1', 'rsfid2', 'dtid1', 'dtid2', 'flairid']
  564    
  565    # Define image file names based on the modality
  566    if locmod == 'T1w':
  567        imfns=['filename']
  568    elif locmod == 'T2Flair':
  569        imfns=['flairid']
  570    elif locmod == 'perf':
  571        imfns=['perfid']
  572    elif locmod == 'pet3d':
  573        imfns=['pet3did']
  574    elif locmod == 'NM2DMT':
  575        imfns=[]
  576        for i in range(11):
  577            imfns.append('nmid' + str(i))
  578    elif locmod == 'rsfMRI':
  579        imfns=[]
  580        for i in range(4):
  581            imfns.append('rsfid' + str(i))
  582    elif locmod == 'DTI':
  583        imfns=[]
  584        for i in range(4):
  585            imfns.append('dtid' + str(i))
  586    else:
  587        raise ValueError("docsamson: no match of modality to filename id " + locmod )
  588
  589    # Process each file name
  590    for i in imfns:
  591        if verbose:
  592            print(i + " " + locmod)
  593        if i in studycsv.keys():
  594            fni = str(studycsv[i].iloc[0])
  595            if verbose:
  596                print(i + " " + fni + ' exists ' + str(os.path.exists(fni)))
  597            if os.path.exists(fni):
  598                myimgsInput.append(fni)
  599                temp = os.path.basename(fni)
  600                mysplit = temp.split(mysep)
  601                iid = re.sub(".nii.gz", "", mysplit[-1])
  602                iid = re.sub(".mha", "", iid)
  603                iid = re.sub(".nii", "", iid)
  604                iid2 = iid
  605                if locmod != 'T1w' and t1iid is not None:
  606                    iid2 = iid + "_" + t1iid
  607                else:
  608                    iid2 = t1iid
  609                myoutputPrefix = os.path.join(outputdir, projid, sid, dtid, locmod, iid, projid + mysep + sid + mysep + dtid + mysep + locmod + mysep + iid2)
  610    
  611    if verbose:
  612        print(locmod)
  613        print(myimgsInput)
  614        print(myoutputPrefix)
  615    
  616    return {
  617        'modality': locmod,
  618        'outprefix': myoutputPrefix,
  619        'images': myimgsInput
  620    }
  621
  622
  623def get_valid_modalities( long=False, asString=False, qc=False ):
  624    """
  625    return a list of valid modality identifiers used in NRG modality designation
  626    and that can be processed by this package.
  627
  628    long - return the long version
  629
  630    asString - concat list to string
  631    """
  632    if long:
  633        mymod = ["T1w", "NM2DMT", "rsfMRI", "rsfMRI_LR", "rsfMRI_RL", "rsfMRILR", "rsfMRIRL", "DTI", "DTI_LR","DTI_RL",  "DTILR","DTIRL","T2Flair", "dwi", "dwi_ap", "dwi_pa", "func", "func_ap", "func_pa", "perf", 'pet3d']
  634    elif qc:
  635        mymod = [ 'T1w', 'T2Flair', 'NM2DMT', 'DTI', 'DTIdwi','DTIb0', 'rsfMRI', "perf", 'pet3d' ]
  636    else:
  637        mymod = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI", "perf", 'pet3d' ]
  638    if not asString:
  639        return mymod
  640    else:
  641        mymodchar=""
  642        for x in mymod:
  643            mymodchar = mymodchar + " " + str(x)
  644        return mymodchar
  645
  646
  647def generate_mm_dataframe(
  648        projectID,
  649        subjectID,
  650        date,
  651        imageUniqueID,
  652        modality,
  653        source_image_directory,
  654        output_image_directory,
  655        t1_filename,
  656        flair_filename=[],
  657        rsf_filenames=[],
  658        dti_filenames=[],
  659        nm_filenames=[],
  660        perf_filename=[],
  661        pet3d_filename=[],
  662):
  663    """
  664    Generate a DataFrame for medical imaging data with extensive validation of input parameters.
  665
  666    This function creates a DataFrame containing information about medical imaging files,
  667    ensuring that filenames match expected patterns for their modalities and that all
  668    required images exist. It also validates the number of filenames provided for specific
  669    modalities like rsfMRI, DTI, and NM.
  670
  671    Parameters:
  672    - projectID (str): Project identifier.
  673    - subjectID (str): Subject identifier.
  674    - date (str): Date of the imaging study.
  675    - imageUniqueID (str): Unique image identifier.
  676    - modality (str): Modality of the imaging study.
  677    - source_image_directory (str): Directory of the source images.
  678    - output_image_directory (str): Directory for output images.
  679    - t1_filename (str): Filename of the T1-weighted image.
  680    - flair_filename (list): List of filenames for FLAIR images.
  681    - rsf_filenames (list): List of filenames for rsfMRI images.
  682    - dti_filenames (list): List of filenames for DTI images.
  683    - nm_filenames (list): List of filenames for NM images.
  684    - perf_filename (list): List of filenames for perfusion images.
  685    - pet3d_filename (list): List of filenames for pet3d images.
  686
  687    Returns:
  688    - pandas.DataFrame: A DataFrame containing the validated imaging study information.
  689
  690    Raises:
  691    - ValueError: If any validation checks fail or if the number of columns does not match the data.
  692    """
  693    def check_pd_construction(data, columns):
  694        # Check if the length of columns matches the length of data in each row
  695        if all(len(row) == len(columns) for row in data):
  696            return True
  697        else:
  698            return False
  699    from os.path import exists
  700    valid_modalities = get_valid_modalities()
  701    if not isinstance(t1_filename, str):
  702        raise ValueError("t1_filename is not a string")
  703    if not exists(t1_filename):
  704        raise ValueError("t1_filename does not exist")
  705    if modality not in valid_modalities:
  706        raise ValueError('modality ' + str(modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
  707    # if not exists( output_image_directory ):
  708    #    raise ValueError("output_image_directory does not exist")
  709    if not exists( source_image_directory ):
  710        raise ValueError("source_image_directory does not exist")
  711    if len( rsf_filenames ) > 2:
  712        raise ValueError("len( rsf_filenames ) > 2")
  713    if len( dti_filenames ) > 3:
  714        raise ValueError("len( dti_filenames ) > 3")
  715    if len( nm_filenames ) > 11:
  716        raise ValueError("len( nm_filenames ) > 11")
  717    if len( rsf_filenames ) < 2:
  718        for k in range(len(rsf_filenames),2):
  719            rsf_filenames.append(None)
  720    if len( dti_filenames ) < 3:
  721        for k in range(len(dti_filenames),3):
  722            dti_filenames.append(None)
  723    if len( nm_filenames ) < 10:
  724        for k in range(len(nm_filenames),10):
  725            nm_filenames.append(None)
  726    # check modality names
  727    if not "T1w" in t1_filename:
  728        raise ValueError("T1w is not in t1 filename " + t1_filename)
  729    if flair_filename is not None:
  730        if isinstance(flair_filename,list):
  731            if (len(flair_filename) == 0):
  732                flair_filename=None
  733            else:
  734                print("Take first entry from flair_filename list")
  735                flair_filename=flair_filename[0]
  736    if flair_filename is not None and not "lair" in flair_filename:
  737            raise ValueError("flair is not flair filename " + flair_filename)
  738    ## perfusion
  739    if perf_filename is not None:
  740        if isinstance(perf_filename,list):
  741            if (len(perf_filename) == 0):
  742                perf_filename=None
  743            else:
  744                print("Take first entry from perf_filename list")
  745                perf_filename=perf_filename[0]
  746    if perf_filename is not None and not "perf" in perf_filename:
  747            raise ValueError("perf_filename is not perf filename " + perf_filename)
  748
  749    if pet3d_filename is not None:
  750        if isinstance(pet3d_filename,list):
  751            if (len(pet3d_filename) == 0):
  752                pet3d_filename=None
  753            else:
  754                print("Take first entry from pet3d_filename list")
  755                pet3d_filename=pet3d_filename[0]
  756    if pet3d_filename is not None and not "pet" in pet3d_filename:
  757            raise ValueError("pet3d_filename is not pet filename " + pet3d_filename)
  758    
  759    for k in nm_filenames:
  760        if k is not None:
  761            if not "NM" in k:
  762                raise ValueError("NM is not flair filename " + k)
  763    for k in dti_filenames:
  764        if k is not None:
  765            if not "DTI" in k and not "dwi" in k:
  766                raise ValueError("DTI/DWI is not dti filename " + k)
  767    for k in rsf_filenames:
  768        if k is not None:
  769            if not "fMRI" in k and not "func" in k:
  770                raise ValueError("rsfMRI/func is not rsfmri filename " + k)
  771    if perf_filename is not None:
  772        if not "perf" in perf_filename:
  773                raise ValueError("perf_filename is not a valid perfusion (perf) filename " + k)
  774    allfns = [t1_filename] + [flair_filename] + nm_filenames + dti_filenames + rsf_filenames + [perf_filename] + [pet3d_filename]
  775    for k in allfns:
  776        if k is not None:
  777            if not isinstance(k, str):
  778                raise ValueError(str(k) + " is not a string")
  779            if not exists( k ):
  780                raise ValueError( "image " + k + " does not exist")
  781    coredata = [
  782        projectID,
  783        subjectID,
  784        date,
  785        imageUniqueID,
  786        modality,
  787        source_image_directory,
  788        output_image_directory,
  789        t1_filename,
  790        flair_filename, 
  791        perf_filename,
  792        pet3d_filename]
  793    mydata0 = coredata +  rsf_filenames + dti_filenames
  794    mydata = mydata0 + nm_filenames
  795    corecols = [
  796        'projectID',
  797        'subjectID',
  798        'date',
  799        'imageID',
  800        'modality',
  801        'sourcedir',
  802        'outputdir',
  803        'filename',
  804        'flairid',
  805        'perfid',
  806        'pet3did']
  807    mycols0 = corecols + [
  808        'rsfid1', 'rsfid2',
  809        'dtid1', 'dtid2','dtid3']
  810    nmext = [
  811        'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5',
  812        'nmid6', 'nmid7','nmid8', 'nmid9', 'nmid10' #, 'nmid11'
  813    ]
  814    mycols = mycols0 + nmext
  815    if not check_pd_construction( [mydata], mycols ) :
  816#        print( mydata )
  817#        print( len(mydata ))
  818#        print( mycols )
  819#        print( len(mycols ))
  820        raise ValueError( "Error in generate_mm_dataframe: len( mycols ) != len( mydata ) which indicates a bad input parameter to this function." )
  821    studycsv = pd.DataFrame([ mydata ], columns=mycols)
  822    return studycsv
  823
  824import pandas as pd
  825from os.path import exists
  826
  827def validate_filename(filename, valid_keywords, error_message):
  828    """
  829    Validate if the given filename contains any of the specified keywords.
  830
  831    Parameters:
  832    - filename (str): The filename to validate.
  833    - valid_keywords (list): A list of keywords to look for in the filename.
  834    - error_message (str): The error message to raise if validation fails.
  835
  836    Raises:
  837    - ValueError: If none of the keywords are found in the filename.
  838    """
  839    if filename is not None and not any(keyword in filename for keyword in valid_keywords):
  840        raise ValueError(error_message)
  841
  842def validate_modality(modality, valid_modalities):
  843    if modality not in valid_modalities:
  844        valid_modalities_str = ', '.join(valid_modalities)
  845        raise ValueError(f'Modality {modality} not a valid mm modality: {valid_modalities_str}')
  846
  847def extend_list_to_length(lst, target_length, fill_value=None):
  848    return lst + [fill_value] * (target_length - len(lst))
  849
  850def generate_mm_dataframe_gpt(
  851        projectID, subjectID, date, imageUniqueID, modality, 
  852        source_image_directory, output_image_directory, t1_filename, 
  853        flair_filename=[], rsf_filenames=[], dti_filenames=[], nm_filenames=[], perf_filename=[] ):
  854    """
  855    see help for generate_mm_dataframe - same as this
  856    """
  857    def check_pd_construction(data, columns):
  858        return all(len(row) == len(columns) for row in data)
  859
  860    flair_filename.sort()
  861    rsf_filenames.sort()
  862    dti_filenames.sort()
  863    nm_filenames.sort()
  864    perf_filename.sort()
  865
  866    valid_modalities = get_valid_modalities()  
  867
  868    if not isinstance(t1_filename, str):
  869        raise ValueError("t1_filename is not a string")
  870    if not exists(t1_filename):
  871        raise ValueError("t1_filename does not exist")
  872
  873    validate_modality(modality, valid_modalities)
  874
  875    if not exists(source_image_directory):
  876        raise ValueError("source_image_directory does not exist")
  877
  878    rsf_filenames = extend_list_to_length(rsf_filenames, 2)
  879    dti_filenames = extend_list_to_length(dti_filenames, 2)
  880    nm_filenames = extend_list_to_length(nm_filenames, 11)
  881
  882    validate_filename(t1_filename, ["T1w"], "T1w is not in t1 filename " + t1_filename)
  883
  884    if flair_filename:
  885        flair_filename = flair_filename[0] if isinstance(flair_filename, list) else flair_filename
  886        validate_filename(flair_filename, ["lair"], "flair is not in flair filename " + flair_filename)
  887
  888    if perf_filename:
  889        perf_filename = perf_filename[0] if isinstance(perf_filename, list) else perf_filename
  890        validate_filename(perf_filename, ["perf"], "perf_filename is not a valid perfusion (perf) filename")
  891
  892    for k in nm_filenames:
  893        if k: validate_filename(k, ["NM"], "NM is not in NM filename " + k)
  894
  895    for k in dti_filenames:
  896        if k: validate_filename(k, ["DTI","dwi"], "DTI or dwi is not in DTI filename " + k)
  897
  898    for k in rsf_filenames:
  899        if k: validate_filename(k, ["fMRI","func"], "rsfMRI or func is not in rsfMRI filename " + k)
  900
  901    allfns = [t1_filename, flair_filename] + nm_filenames + dti_filenames + rsf_filenames + [perf_filename]
  902    for k in allfns:
  903        if k and not exists(k):
  904            raise ValueError("image " + k + " does not exist")
  905
  906    coredata = [projectID, subjectID, date, imageUniqueID, modality,
  907                source_image_directory, output_image_directory, t1_filename, 
  908                flair_filename, perf_filename]
  909    mydata0 = coredata + rsf_filenames + dti_filenames
  910    mydata = mydata0 + nm_filenames
  911
  912    corecols = ['projectID', 'subjectID', 'date', 'imageID', 'modality',
  913                'sourcedir', 'outputdir', 'filename', 'flairid', 'perfid']
  914    mycols0 = corecols + ['rsfid1', 'rsfid2', 'dtid1', 'dtid2']
  915    nmext = ['nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5',
  916             'nmid6', 'nmid7', 'nmid8', 'nmid9', 'nmid10', 'nmid11']
  917    mycols = mycols0 + nmext
  918
  919    if not check_pd_construction([mydata], mycols):
  920        print( mydata )
  921        print( mycols )
  922        raise ValueError("Error in generate_mm_dataframe: len(mycols) != len(mydata), indicating bad input parameters.")
  923
  924    studycsv = pd.DataFrame([mydata], columns=mycols)
  925    return studycsv
  926
  927
  928
  929def filter_columns_by_nan_percentage(df, max_nan_percentage=50.0):
  930    """
  931    Filter columns in a DataFrame based on a threshold for the percentage of NaN values.
  932
  933    Parameters
  934    ----------
  935    df : pandas.DataFrame
  936        The input DataFrame from which columns are to be filtered.
  937    max_nan_percentage : float, optional
  938        The maximum allowed percentage of NaN values in a column. Columns with a higher
  939        percentage of NaN values than this threshold will be removed from the DataFrame.
  940        The default is 50.0, which means columns with more than 50% NaN values will be removed.
  941
  942    Returns
  943    -------
  944    pandas.DataFrame
  945        A DataFrame with columns filtered based on the NaN values percentage criterion.
  946
  947    Examples
  948    --------
  949    >>> import pandas as pd
  950    >>> data = {'A': [1, 2, None, 4], 'B': [None, 2, 3, None], 'C': [1, 2, 3, 4]}
  951    >>> df = pd.DataFrame(data)
  952    >>> filtered_df = filter_columns_by_nan_percentage(df, 50.0)
  953    >>> print(filtered_df)
  954
  955    Notes
  956    -----
  957    The function calculates the percentage of NaN values in each column and filters out
  958    those columns where the percentage exceeds the `max_nan_percentage` threshold.
  959    """
  960    # Calculate the percentage of NaN values for each column
  961    nan_percentage = df.isnull().mean() * 100
  962
  963    # Filter columns where the percentage of NaN values is less than or equal to the threshold
  964    columns_to_keep = nan_percentage[nan_percentage <= max_nan_percentage].index
  965
  966    # Return the filtered DataFrame
  967    return df[columns_to_keep]
  968
  969
  970
  971def parse_nrg_filename( x, separator='-' ):
  972    """
  973    split a NRG filename into its named parts
  974    """
  975    temp = x.split( separator )
  976    if len(temp) != 5:
  977        raise ValueError(x + " not a valid NRG filename")
  978    return {
  979        'project':temp[0],
  980        'subjectID':temp[1],
  981        'date':temp[2],
  982        'modality':temp[3],
  983        'imageID':temp[4]
  984    }
  985
  986
  987
  988def nrg_2_bids( nrg_filename ):
  989    """
  990    Convert an NRG filename to BIDS path/filename.
  991
  992    Parameters:
  993    nrg_filename (str): The NRG filename to convert.
  994
  995    Returns:
  996    str: The BIDS path/filename.
  997    """
  998
  999    # Split the NRG filename into its components
 1000    nrg_dirname, nrg_basename = os.path.split(nrg_filename)
 1001    nrg_suffix = '.' + nrg_basename.split('.',1)[-1]
 1002    nrg_basename = nrg_basename.replace(nrg_suffix, '') # remove ext
 1003    nrg_parts = nrg_basename.split('-')
 1004    nrg_subject_id = nrg_parts[1]
 1005    nrg_modality = nrg_parts[3]
 1006    nrg_repeat= nrg_parts[4]
 1007
 1008    # Build the BIDS path/filename
 1009    bids_dirname = os.path.join(nrg_dirname, 'bids')
 1010    bids_subject = f'sub-{nrg_subject_id}'
 1011    bids_session = f'ses-{nrg_repeat}'
 1012
 1013    valid_modalities = get_valid_modalities()
 1014    if nrg_modality is not None:
 1015        if not nrg_modality in valid_modalities:
 1016            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
 1017
 1018    if nrg_modality == 'T1w' :
 1019        bids_modality_folder = 'anat'
 1020        bids_modality_filename = 'T1w'
 1021
 1022    if nrg_modality == 'T2Flair' :
 1023        bids_modality_folder = 'anat'
 1024        bids_modality_filename = 'flair'
 1025
 1026    if nrg_modality == 'NM2DMT' :
 1027        bids_modality_folder = 'anat'
 1028        bids_modality_filename = 'nm2dmt'
 1029
 1030    if nrg_modality == 'DTI' or nrg_modality == 'DTI_RL' or nrg_modality == 'DTI_LR' :
 1031        bids_modality_folder = 'dwi'
 1032        bids_modality_filename = 'dwi'
 1033
 1034    if nrg_modality == 'rsfMRI' or nrg_modality == 'rsfMRI_RL' or nrg_modality == 'rsfMRI_LR' :
 1035        bids_modality_folder = 'func'
 1036        bids_modality_filename = 'func'
 1037
 1038    if nrg_modality == 'perf'  :
 1039        bids_modality_folder = 'perf'
 1040        bids_modality_filename = 'perf'
 1041
 1042    bids_suffix = nrg_suffix[1:]
 1043    bids_filename = f'{bids_subject}_{bids_session}_{bids_modality_filename}.{bids_suffix}'
 1044
 1045    # Return bids filepath/filename
 1046    return os.path.join(bids_dirname, bids_subject, bids_session, bids_modality_folder, bids_filename)
 1047
 1048
 1049def bids_2_nrg( bids_filename, project_name, date, nrg_modality=None ):
 1050    """
 1051    Convert a BIDS filename to NRG path/filename.
 1052
 1053    Parameters:
 1054    bids_filename (str): The BIDS filename to convert
 1055    project_name (str) : Name of project (i.e. PPMI)
 1056    date (str) : Date of image acquisition
 1057
 1058
 1059    Returns:
 1060    str: The NRG path/filename.
 1061    """
 1062
 1063    bids_dirname, bids_basename = os.path.split(bids_filename)
 1064    bids_suffix = '.'+ bids_basename.split('.',1)[-1]
 1065    bids_basename = bids_basename.replace(bids_suffix, '') # remove ext
 1066    bids_parts = bids_basename.split('_')
 1067    nrg_subject_id = bids_parts[0].replace('sub-','')
 1068    nrg_image_id = bids_parts[1].replace('ses-', '')
 1069    bids_modality = bids_parts[2]
 1070    valid_modalities = get_valid_modalities()
 1071    if nrg_modality is not None:
 1072        if not nrg_modality in valid_modalities:
 1073            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality: " + get_valid_modalities(asString=True))
 1074
 1075    if bids_modality == 'anat' and nrg_modality is None :
 1076        nrg_modality = 'T1w'
 1077
 1078    if bids_modality == 'dwi' and nrg_modality is None  :
 1079        nrg_modality = 'DTI'
 1080
 1081    if bids_modality == 'func' and nrg_modality is None  :
 1082        nrg_modality = 'rsfMRI'
 1083
 1084    if bids_modality == 'perf' and nrg_modality is None  :
 1085        nrg_modality = 'perf'
 1086
 1087    nrg_suffix = bids_suffix[1:]
 1088    nrg_filename = f'{project_name}-{nrg_subject_id}-{date}-{nrg_modality}-{nrg_image_id}.{nrg_suffix}'
 1089
 1090    return os.path.join(project_name, nrg_subject_id, date, nrg_modality, nrg_image_id,nrg_filename)
 1091
 1092def collect_blind_qc_by_modality( modality_path, set_index_to_fn=True ):
 1093    """
 1094    Collects blind QC data from multiple CSV files with the same modality.
 1095
 1096    Args:
 1097
 1098    modality_path (str): The path to the folder containing the CSV files.
 1099
 1100    set_index_to_fn: boolean
 1101
 1102    Returns:
 1103    Pandas DataFrame: A DataFrame containing all the blind QC data from the CSV files.
 1104    """
 1105    import glob as glob
 1106    fns = glob.glob( modality_path )
 1107    fns.sort()
 1108    jdf = pd.DataFrame()
 1109    for k in range(len(fns)):
 1110        temp=pd.read_csv(fns[k])
 1111        if not 'filename' in temp.keys():
 1112            temp['filename']=fns[k]
 1113        jdf=pd.concat( [jdf,temp], axis=0, ignore_index=False )
 1114    if set_index_to_fn:
 1115        jdf.reset_index(drop=True)
 1116        if "Unnamed: 0" in jdf.columns:
 1117            holder=jdf.pop( "Unnamed: 0" )
 1118        jdf.set_index('filename')
 1119    return jdf
 1120
 1121
 1122def outlierness_by_modality( qcdf, uid='filename', outlier_columns = ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi','reflection_err', 'EVR', 'msk_vol'], verbose=False ):
 1123    """
 1124    Calculates outlierness scores for each modality in a dataframe based on given outlier columns using antspyt1w.loop_outlierness() and LOF.  LOF appears to be more conservative.  This function will impute missing columns with the mean.
 1125
 1126    Args:
 1127    - qcdf: (Pandas DataFrame) Dataframe containing columns with outlier information for each modality.
 1128    - uid: (str) Unique identifier for a subject. Default is 'filename'.
 1129    - outlier_columns: (list) List of columns containing outlier information. Default is ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'].
 1130    - verbose: (bool) If True, prints information for each modality. Default is False.
 1131
 1132    Returns:
 1133    - qcdf: (Pandas DataFrame) Updated dataframe with outlierness scores for each modality in the 'ol_loop' and 'ol_lof' column.  Higher values near 1 are more outlying.
 1134
 1135    Raises:
 1136    - ValueError: If uid is not present in the dataframe.
 1137
 1138    Example:
 1139    >>> df = pd.read_csv('data.csv')
 1140    >>> outlierness_by_modality(df)
 1141    """
 1142    from PyNomaly import loop
 1143    from sklearn.neighbors import LocalOutlierFactor
 1144    qcdfout = qcdf.copy()
 1145    pd.set_option('future.no_silent_downcasting', True)
 1146    qcdfout.replace([np.inf, -np.inf], np.nan, inplace=True)
 1147    if uid not in qcdfout.keys():
 1148        raise ValueError( str(uid) + " not in dataframe")
 1149    if 'ol_loop' not in qcdfout.keys():
 1150        qcdfout['ol_loop']=math.nan
 1151    if 'ol_lof' not in qcdfout.keys():
 1152        qcdfout['ol_lof']=math.nan
 1153    didit=False
 1154    for mod in get_valid_modalities( qc=True ):
 1155        didit=True
 1156        lof = LocalOutlierFactor()
 1157        locsel = qcdfout["modality"] == mod
 1158        rr = qcdfout[locsel][outlier_columns]
 1159        column_means = rr.mean()
 1160        rr.fillna(column_means, inplace=True)
 1161        if rr.shape[0] > 1:
 1162            if verbose:
 1163                print("calc: " + mod + " outlierness " )
 1164            myneigh = np.min( [24, int(np.round(rr.shape[0]*0.5)) ] )
 1165            temp = antspyt1w.loop_outlierness(rr.astype(float), standardize=True, extent=3, n_neighbors=myneigh, cluster_labels=None)
 1166            qcdfout.loc[locsel,'ol_loop']=temp.astype('float64')
 1167            yhat = lof.fit_predict(rr)
 1168            temp = lof.negative_outlier_factor_*(-1.0)
 1169            temp = temp - temp.min()
 1170            yhat[ yhat == 1] = 0
 1171            yhat[ yhat == -1] = 1 # these are outliers
 1172            qcdfout.loc[locsel,'ol_lof_decision']=yhat
 1173            qcdfout.loc[locsel,'ol_lof']=temp/temp.max()
 1174    if verbose:
 1175        print( didit )
 1176    return qcdfout
 1177
 1178
 1179def nrg_format_path( projectID, subjectID, date, modality, imageID, separator='-' ):
 1180    """
 1181    create the NRG path on disk given the project, subject id, date, modality and image id
 1182
 1183    Arguments
 1184    ---------
 1185
 1186    projectID : string for the project e.g. PPMI
 1187
 1188    subjectID : string uniquely identifying the subject e.g. 0001
 1189
 1190    date : string for the date usually 20550228 ie YYYYMMDD format
 1191
 1192    modality : string should be one of T1w, T2Flair, rsfMRI, NM2DMT and DTI ... rsfMRI and DTI may also be DTI_LR, DTI_RL, rsfMRI_LR and rsfMRI_RL where the RL / LR relates to phase encoding direction (even if it is AP/PA)
 1193
 1194    imageID : string uniquely identifying the specific image
 1195
 1196    separator : default to -
 1197
 1198    Returns
 1199    -------
 1200    the path where one would write the image on disk
 1201
 1202    """
 1203    thedirectory = os.path.join( str(projectID), str(subjectID), str(date), str(modality), str(imageID) )
 1204    thefilename = str(projectID) + separator + str(subjectID) + separator + str(date) + separator + str(modality) + separator + str(imageID)
 1205    return os.path.join( thedirectory, thefilename )
 1206
 1207
 1208def get_first_item_as_string(df, column_name):
 1209    """
 1210    Check if the first item in the specified column of the DataFrame is a string.
 1211    If it is not a string, attempt to convert it to an integer and then to a string.
 1212
 1213    Parameters:
 1214    df (pd.DataFrame): The DataFrame to operate on.
 1215    column_name (str): The name of the column to check.
 1216
 1217    Returns:
 1218    str: The first item in the specified column, guaranteed to be returned as a string.
 1219    """
 1220    if isinstance(df[column_name].iloc[0], str):
 1221        return df[column_name].iloc[0]
 1222    else:
 1223        try:
 1224            return str(int(df[column_name].iloc[0]))
 1225        except ValueError:
 1226            raise ValueError("The value cannot be converted to an integer.")
 1227
 1228
 1229def study_dataframe_from_matched_dataframe( matched_dataframe, rootdir, outputdir, verbose=False ):
 1230    """
 1231    converts the output of antspymm.match_modalities dataframe (one row) to that needed for a study-driving dataframe for input to mm_csv
 1232
 1233    matched_dataframe : output of antspymm.match_modalities
 1234
 1235    rootdir : location for the input data root folder (in e.g. NRG format)
 1236
 1237    outputdir : location for the output data
 1238
 1239    verbose : boolean
 1240    """
 1241    iext='.nii.gz'
 1242    from os.path import exists
 1243    musthavecols = ['projectID', 'subjectID','date','imageID','filename']
 1244    for k in range(len(musthavecols)):
 1245        if not musthavecols[k] in matched_dataframe.keys():
 1246            raise ValueError('matched_dataframe is missing column ' + musthavecols[k] + ' in study_dataframe_from_qc_dataframe' )
 1247    csvrow=matched_dataframe.dropna(axis=1)
 1248    pid=get_first_item_as_string( csvrow, 'projectID'  )
 1249    sid=get_first_item_as_string( csvrow, 'subjectID'  ) # str(csvrow['subjectID'].iloc[0] )
 1250    dt=get_first_item_as_string( csvrow, 'date'  )  # str(csvrow['date'].iloc[0])
 1251    iid=get_first_item_as_string( csvrow, 'imageID'  ) # str(csvrow['imageID'].iloc[0])
 1252    nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
 1253    if not exists( nrgt1fn ) and iid == '0':
 1254        iid='000'
 1255        nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
 1256    if not exists( nrgt1fn ):
 1257        raise ValueError("T1 " + nrgt1fn + " does not exist in study_dataframe_from_qc_dataframe")
 1258    flList=[]
 1259    dtList=[]
 1260    rsfList=[]
 1261    nmList=[]
 1262    perfList=[]
 1263    if 'flairfn' in csvrow.keys():
 1264        flid=get_first_item_as_string( csvrow, 'flairid' )
 1265        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
 1266        if not exists( nrgt2fn ) and flid == '0':
 1267            flid='000'
 1268            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
 1269        if verbose:
 1270            print("Trying " + nrgt2fn )
 1271        if exists( nrgt2fn ):
 1272            if verbose:
 1273                print("success" )
 1274            flList.append( nrgt2fn )
 1275    if 'perffn' in csvrow.keys():
 1276        flid=get_first_item_as_string( csvrow, 'perfid' )
 1277        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
 1278        if not exists( nrgt2fn ) and flid == '0':
 1279            flid='000'
 1280            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
 1281        if verbose:
 1282            print("Trying " + nrgt2fn )
 1283        if exists( nrgt2fn ):
 1284            if verbose:
 1285                print("success" )
 1286            perfList.append( nrgt2fn )
 1287    if 'dtfn1' in csvrow.keys():
 1288        dtid=get_first_item_as_string( csvrow, 'dtid1' )
 1289        dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
 1290        if len( dtfn1) == 0 :
 1291            dtid = '000'
 1292            dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
 1293        dtfn1=dtfn1[0]
 1294        if exists( dtfn1 ):
 1295            dtList.append( dtfn1 )
 1296    if 'dtfn2' in csvrow.keys():
 1297        dtid=get_first_item_as_string( csvrow, 'dtid2' )
 1298        dtfn2=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
 1299        if len( dtfn2) == 0 :
 1300            dtid = '000'
 1301            dtfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
 1302        dtfn2=dtfn2[0]
 1303        if exists( dtfn2 ):
 1304            dtList.append( dtfn2 )
 1305    if 'dtfn3' in csvrow.keys():
 1306        dtid=get_first_item_as_string( csvrow, 'dtid3' )
 1307        dtfn3=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
 1308        if len( dtfn3) == 0 :
 1309            dtid = '000'
 1310            dtfn3=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
 1311        dtfn3=dtfn3[0]
 1312        if exists( dtfn3 ):
 1313            dtList.append( dtfn3 )
 1314    if 'rsffn1' in csvrow.keys():
 1315        rsid=get_first_item_as_string( csvrow, 'rsfid1' )
 1316        rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
 1317        if len( rsfn1 ) == 0 :
 1318            rsid = '000'
 1319            rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
 1320        rsfn1=rsfn1[0]
 1321        if exists( rsfn1 ):
 1322            rsfList.append( rsfn1 )
 1323    if 'rsffn2' in csvrow.keys():
 1324        rsid=get_first_item_as_string( csvrow, 'rsfid2' )
 1325        rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))[0]
 1326        if len( rsfn2 ) == 0 :
 1327            rsid = '000'
 1328            rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))
 1329        rsfn2=rsfn2[0]
 1330        if exists( rsfn2 ):
 1331            rsfList.append( rsfn2 )
 1332    for j in range(11):
 1333        keyname="nmfn"+str(j)
 1334        keynameid="nmid"+str(j)
 1335        if keyname in csvrow.keys() and keynameid in csvrow.keys():
 1336            nmid=get_first_item_as_string( csvrow, keynameid )
 1337            nmsearchpath=os.path.join( rootdir, pid, sid, dt, 'NM2DMT', nmid, "*"+nmid+iext)
 1338            nmfn=glob.glob( nmsearchpath )
                 if len( nmfn ) == 0:
                     continue
 1339            nmfn=nmfn[0]
 1340            if exists( nmfn ):
 1341                nmList.append( nmfn )
 1342    if verbose:
 1343        print("assembled the image lists mapping to ....")
 1344        print(nrgt1fn)
 1345        print("NM")
 1346        print(nmList)
 1347        print("FLAIR")
 1348        print(flList)
 1349        print("DTI")
 1350        print(dtList)
 1351        print("rsfMRI")
 1352        print(rsfList)
 1353        print("perf")
 1354        print(perfList)
 1355    studycsv = generate_mm_dataframe(
 1356        pid,
 1357        sid,
 1358        dt,
 1359        iid, # the T1 id
 1360        'T1w',
 1361        rootdir,
 1362        outputdir,
 1363        t1_filename=nrgt1fn,
 1364        flair_filename=flList,
 1365        dti_filenames=dtList,
 1366        rsf_filenames=rsfList,
 1367        nm_filenames=nmList,
 1368        perf_filename=perfList)
 1369    return studycsv.dropna(axis=1)
 1370
 1371def highest_quality_repeat(mxdfin, idvar, visitvar, qualityvar):
 1372    """
 1373    This function returns a subset of the input dataframe that retains only the rows
 1374    that correspond to the highest quality observation for each combination of ID and visit.
 1375
 1376    Parameters:
 1377    ----------
 1378    mxdfin: pandas.DataFrame
 1379        The input dataframe.
 1380    idvar: str
 1381        The name of the column that contains the ID variable.
 1382    visitvar: str
 1383        The name of the column that contains the visit variable.
 1384    qualityvar: str
 1385        The name of the column that contains the quality variable.
 1386
 1387    Returns:
 1388    -------
 1389    pandas.DataFrame
 1390        A subset of the input dataframe that retains only the rows that correspond
 1391        to the highest quality observation for each combination of ID and visit.
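
         Example:
         -------
         A minimal sketch on a synthetic dataframe (column names here are illustrative):

         >>> import pandas as pd
         >>> df = pd.DataFrame({'sid': ['A', 'A', 'B'],
         ...                    'visit': ['v1', 'v1', 'v1'],
         ...                    'snr': [10.0, 12.0, 9.0]})
         >>> best = highest_quality_repeat(df, 'sid', 'visit', 'snr')
         >>> best.shape[0]
         2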
 1392    """
 1393    if visitvar not in mxdfin.columns:
 1394        raise ValueError("visitvar not in dataframe")
 1395    if idvar not in mxdfin.columns:
 1396        raise ValueError("idvar not in dataframe")
 1397    if qualityvar not in mxdfin.columns:
 1398        raise ValueError("qualityvar not in dataframe")
 1399
 1400    mxdfin[qualityvar] = mxdfin[qualityvar].astype(float)
 1401
 1402    vizzes = mxdfin[visitvar].unique()
 1403    uids = mxdfin[idvar].unique()
 1404    useit = np.zeros(mxdfin.shape[0], dtype=bool)
 1405
 1406    for u in uids:
 1407        losel = mxdfin[idvar] == u
 1408        vizzesloc = mxdfin[losel][visitvar].unique()
 1409
 1410        for v in vizzesloc:
 1411            losel = (mxdfin[idvar] == u) & (mxdfin[visitvar] == v)
 1412            mysnr = mxdfin.loc[losel, qualityvar]
 1413            myw = np.where(losel)[0]
 1414
 1415            if len(myw) > 1:
 1416                if any(~np.isnan(mysnr)):
 1417                    useit[myw[np.nanargmax(mysnr)]] = True
 1418                else:
 1419                    useit[myw] = True
 1420            else:
 1421                useit[myw] = True
 1422
 1423    return mxdfin[useit]
 1424
 1425
 1426def match_modalities( qc_dataframe, unique_identifier='filename', outlier_column='ol_loop', mysep='-', verbose=False ):
 1427    """
 1428    Find the best multiple modality dataset at each time point
 1429
 1430    :param qc_dataframe: quality control data frame with one row per image and blind QC columns (filename, outlier scores, etc.)
 1431    :param unique_identifier : the unique NRG filename for each image
 1432    :param outlier_column: outlierness score used to identify the best image (pair) at a given date
 1433    :param mysep (str, optional): the separator used in the image file names. Defaults to '-'.
 1434    :param verbose: boolean
 1435    :return: filtered matched modality data frame
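
         Example (a sketch; `qcdf` is an assumed blind QC dataframe covering all modalities):

         >>> mmdf = match_modalities( qcdf, outlier_column='ol_loop' )
         >>> mmdf[['flairfn','dtfn1','rsffn1']].head()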
 1436    """
 1437    import pandas as pd
 1438    import numpy as np
 1439    qc_dataframe['filename']=qc_dataframe['filename'].astype(str)
 1440    qc_dataframe['ol_loop']=qc_dataframe['ol_loop'].astype(float)
 1441    qc_dataframe['ol_lof']=qc_dataframe['ol_lof'].astype(float)
 1442    qc_dataframe['ol_lof_decision']=qc_dataframe['ol_lof_decision'].astype(float)
 1443    mmdf = best_mmm( qc_dataframe, 'T1w', outlier_column=outlier_column )['filt']
 1444    fldf = best_mmm( qc_dataframe, 'T2Flair', outlier_column=outlier_column )['filt']
 1445    nmdf = best_mmm( qc_dataframe, 'NM2DMT', outlier_column=outlier_column )['filt']
 1446    rsdf = best_mmm( qc_dataframe, 'rsfMRI', outlier_column=outlier_column )['filt']
 1447    dtdf = best_mmm( qc_dataframe, 'DTI', outlier_column=outlier_column )['filt']
 1448    mmdf['flairid'] = None
 1449    mmdf['flairfn'] = None
 1450    mmdf['flairloop'] = None
 1451    mmdf['flairlof'] = None
 1452    mmdf['dtid1'] = None
 1453    mmdf['dtfn1'] = None
 1454    mmdf['dtntimepoints1'] = 0
 1455    mmdf['dtloop1'] = math.nan
 1456    mmdf['dtlof1'] = math.nan
 1457    mmdf['dtid2'] = None
 1458    mmdf['dtfn2'] = None
 1459    mmdf['dtntimepoints2'] = 0
 1460    mmdf['dtloop2'] = math.nan
 1461    mmdf['dtlof2'] = math.nan
 1462    mmdf['rsfid1'] = None
 1463    mmdf['rsffn1'] = None
 1464    mmdf['rsfntimepoints1'] = 0
 1465    mmdf['rsfloop1'] = math.nan
 1466    mmdf['rsflof1'] = math.nan
 1467    mmdf['rsfid2'] = None
 1468    mmdf['rsffn2'] = None
 1469    mmdf['rsfntimepoints2'] = 0
 1470    mmdf['rsfloop2'] = math.nan
 1471    mmdf['rsflof2'] = math.nan
 1472    for k in range(1,11):
 1473        myid='nmid'+str(k)
 1474        mmdf[myid] = None
 1475        myid='nmfn'+str(k)
 1476        mmdf[myid] = None
 1477        myid='nmloop'+str(k)
 1478        mmdf[myid] = math.nan
 1479        myid='nmlof'+str(k)
 1480        mmdf[myid] = math.nan
 1481    if verbose:
 1482        print( mmdf.shape )
 1483    for k in range(mmdf.shape[0]):
 1484        if verbose:
 1485            if k % 100 == 0:
 1486                progger = str( k ) # np.round( k / mmdf.shape[0] * 100 ) )
 1487                print( progger, end ="...", flush=True)
 1488        if dtdf is not None:
 1489            locsel = (dtdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
 1490            if sum(locsel) == 1:
 1491                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = dtdf["imageID"][locsel].values[0]
 1492                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = dtdf[unique_identifier][locsel].values[0]
 1493                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = dtdf[outlier_column][locsel].values[0]
 1494                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(dtdf['ol_lof_decision'][locsel].values[0])
 1495                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(dtdf['dimt'][locsel].values[0])
 1496            elif sum(locsel) > 1:
 1497                locdf = dtdf[locsel]
 1498                dedupe = locdf[["snr","cnr"]].duplicated()
 1499                locdf = locdf[~dedupe]
 1500                if locdf.shape[0] > 1:
 1501                    locdf = locdf.sort_values(outlier_column).iloc[:2]
 1502                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = locdf["imageID"].values[0]
 1503                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = locdf[unique_identifier].values[0]
 1504                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = locdf[outlier_column].values[0]
 1505                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(locdf['ol_lof_decision'].values[0])
 1506                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(locdf['dimt'].values[0])
 1507                if locdf.shape[0] > 1:
 1508                    mmdf.iloc[k, mmdf.columns.get_loc("dtid2")] = locdf["imageID"].values[1]
 1509                    mmdf.iloc[k, mmdf.columns.get_loc("dtfn2")] = locdf[unique_identifier].values[1]
 1510                    mmdf.iloc[k, mmdf.columns.get_loc("dtloop2")] = locdf[outlier_column].values[1]
 1511                    mmdf.iloc[k, mmdf.columns.get_loc("dtlof2")] = float(locdf['ol_lof_decision'].values[1])
 1512                    mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints2")] = float(locdf['dimt'].values[1])
 1513        if rsdf is not None:
 1514            locsel = (rsdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
 1515            if sum(locsel) == 1:
 1516                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = rsdf["imageID"][locsel].values[0]
 1517                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = rsdf[unique_identifier][locsel].values[0]
 1518                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = rsdf[outlier_column][locsel].values[0]
 1519                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(rsdf['ol_lof_decision'][locsel].values[0])
 1520                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(rsdf['dimt'][locsel].values[0])
 1521            elif sum(locsel) > 1:
 1522                locdf = rsdf[locsel]
 1523                dedupe = locdf[["snr","cnr"]].duplicated()
 1524                locdf = locdf[~dedupe]
 1525                if locdf.shape[0] > 1:
 1526                    locdf = locdf.sort_values(outlier_column).iloc[:2]
 1527                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = locdf["imageID"].values[0]
 1528                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = locdf[unique_identifier].values[0]
 1529                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = locdf[outlier_column].values[0]
 1530                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(locdf['ol_lof_decision'].values[0])
 1531                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(locdf['dimt'].values[0])
 1532                if locdf.shape[0] > 1:
 1533                    mmdf.iloc[k, mmdf.columns.get_loc("rsfid2")] = locdf["imageID"].values[1]
 1534                    mmdf.iloc[k, mmdf.columns.get_loc("rsffn2")] = locdf[unique_identifier].values[1]
 1535                    mmdf.iloc[k, mmdf.columns.get_loc("rsfloop2")] = locdf[outlier_column].values[1]
 1536                    mmdf.iloc[k, mmdf.columns.get_loc("rsflof2")] = float(locdf['ol_lof_decision'].values[1])
 1537                    mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints2")] = float(locdf['dimt'].values[1])
 1538
 1539        if fldf is not None:
 1540            locsel = fldf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
 1541            if locsel.sum() == 1:
 1542                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = fldf['imageID'][locsel].values[0]
 1543                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = fldf[unique_identifier][locsel].values[0]
 1544                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = fldf[outlier_column][locsel].values[0]
 1545                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(fldf['ol_lof_decision'][locsel].values[0])
 1546            elif sum(locsel) > 1:
 1547                locdf = fldf[locsel]
 1548                dedupe = locdf[["snr","cnr"]].duplicated()
 1549                locdf = locdf[~dedupe]
 1550                if locdf.shape[0] > 1:
 1551                    locdf = locdf.sort_values(outlier_column).iloc[:2]
 1552                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = locdf["imageID"].values[0]
 1553                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = locdf[unique_identifier].values[0]
 1554                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = locdf[outlier_column].values[0]
 1555                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(locdf['ol_lof_decision'].values[0])
 1556
 1557        if nmdf is not None:
 1558            locsel = nmdf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
 1559            if locsel.sum() > 0:
 1560                locdf = nmdf[locsel]
 1561                for i in range(np.min( [10,locdf.shape[0]])):
 1562                    nmid = "nmid"+str(i+1)
 1563                    mmdf.iloc[k, mmdf.columns.get_loc(nmid)] = locdf['imageID'].iloc[i]
 1564                    nmfn = "nmfn"+str(i+1)
 1565                    mmdf.iloc[k, mmdf.columns.get_loc(nmfn)] = locdf[unique_identifier].iloc[i]
 1566                    nmloop = "nmloop"+str(i+1)
 1567                    mmdf.iloc[k, mmdf.columns.get_loc(nmloop)] = locdf[outlier_column].iloc[i]
 1568                    nmlof = "nmlof"+str(i+1)
 1569                    mmdf.iloc[k, mmdf.columns.get_loc(nmlof)] = float(locdf['ol_lof_decision'].iloc[i])
 1570
 1571    mmdf['rsf_total_timepoints']=mmdf['rsfntimepoints1']+mmdf['rsfntimepoints2']
 1572    mmdf['dt_total_timepoints']=mmdf['dtntimepoints1']+mmdf['dtntimepoints2']
 1573    return mmdf
 1574
 1575
 1576def add_repeat_column(df, groupby_column):
 1577    """
 1578    Adds a 'repeat' column to the DataFrame that counts occurrences of each unique value
 1579    in the specified 'groupby_column'. The count increments from 1 for each identical entry.
 1580    
 1581    Parameters:
 1582    - df: pandas DataFrame.
 1583    - groupby_column: The name of the column to group by and count repeats.
 1584    
 1585    Returns:
 1586    - Modified pandas DataFrame with an added 'repeat' column.
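
         Example (illustrative):

         >>> import pandas as pd
         >>> df = pd.DataFrame({'subjectIDdate': ['s1-2020', 's1-2020', 's2-2021']})
         >>> add_repeat_column(df, 'subjectIDdate')['repeat'].tolist()
         [1, 2, 1]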
 1587    """
 1588    # Validate if the groupby_column exists in the DataFrame
 1589    if groupby_column not in df.columns:
 1590        raise ValueError(f"Column '{groupby_column}' does not exist in the DataFrame.")
 1591    
 1592    # Count the occurrences of each unique value in the specified column and increment from 1
 1593    df['repeat'] = df.groupby(groupby_column).cumcount() + 1
 1594    
 1595    return df
 1596
 1597def best_mmm( mmdf, wmod, mysep='-', outlier_column='ol_loop', verbose=False):
 1598    """
 1599    Selects the best repeats per modality.
 1600
 1601    Args:
 1602    mmdf (pandas.DataFrame): the blind QC data frame

         wmod (str): the modality of the image ( 'T1w', 'T2Flair', 'NM2DMT', 'rsfMRI', 'DTI')
 1603
 1604    mysep (str, optional): the separator used in the image file names. Defaults to '-'.
 1605
 1606    outlier_column (str, optional): column name for the outlier score. Defaults to 'ol_loop'.
 1607
 1608    verbose (bool, optional): default False
 1609
 1610    Returns:
 1611
 1612    dict: a dictionary with keys 'raw' and 'filt' - raw contains all the metadata for the selected modality and filt contains the metadata filtered for the highest quality repeats.
 1613
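         Example
         -------
         A sketch, assuming `qcdf` is a blind QC dataframe with NRG-style filenames and an 'ol_loop' column:

         >>> sel = best_mmm( qcdf, 'T1w', outlier_column='ol_loop' )
         >>> t1filt = sel['filt']   # highest-quality repeat per subject and date
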
 1614    """
 1615#    mmdf = mmdf.astype(str)
 1616    mmdf[outlier_column]=mmdf[outlier_column].astype(float)
 1617    msel = mmdf['modality'] == wmod
 1618    if wmod == 'rsfMRI':
 1619        msel1 = mmdf['modality'] == 'rsfMRI'
 1620        msel2 = mmdf['modality'] == 'rsfMRI_LR'
 1621        msel3 = mmdf['modality'] == 'rsfMRI_RL'
 1622        msel = msel1 | msel2
 1623        msel = msel | msel3
 1624    if wmod == 'DTI':
 1625        msel1 = mmdf['modality'] == 'DTI'
 1626        msel2 = mmdf['modality'] == 'DTI_LR'
 1627        msel3 = mmdf['modality'] == 'DTI_RL'
 1628        msel4 = mmdf['modality'] == 'DTIdwi'
 1629        msel5 = mmdf['modality'] == 'DTIb0'
 1630        msel = msel1 | msel2 | msel3 | msel4 | msel5
 1631    if sum(msel) == 0:
 1632        return {'raw': None, 'filt': None}
 1633    metasub = mmdf[msel].copy()
 1634
 1635    if verbose:
 1636        print(f"{wmod} {(metasub.shape[0])} pre")
 1637
 1638    metasub['subjectID']=None
 1639    metasub['date']=None
 1640    metasub['subjectIDdate']=None
 1641    metasub['imageID']=None
 1642    metasub['negol']=math.nan
 1643    for k in metasub.index:
 1644        temp = metasub.loc[k, 'filename'].split( mysep )
 1645        metasub.loc[k,'subjectID'] = str( temp[1] )
 1646        metasub.loc[k,'date'] = str( temp[2] )
 1647        metasub.loc[k,'subjectIDdate'] = str( temp[1] + mysep + temp[2] )
 1648        metasub.loc[k,'imageID'] = str( temp[4])
 1649
 1650
 1651    if 'ol_' in outlier_column:
 1652        metasub['negol'] = metasub[outlier_column].max() - metasub[outlier_column]
 1653    else:
 1654        metasub['negol'] = metasub[outlier_column]
 1655    if 'date' not in metasub.keys():
 1656        metasub['date']=None
 1657    metasubq = add_repeat_column( metasub, 'subjectIDdate' )
 1658    metasubq = highest_quality_repeat(metasubq, 'filename', 'date', 'negol')
 1659
 1660    if verbose:
 1661        print(f"{wmod} {metasubq.shape[0]} post")
 1662
 1663#    metasub = metasub.astype(str)
 1664#    metasubq = metasubq.astype(str)
 1665    metasub[outlier_column]=metasub[outlier_column].astype(float)
 1666    metasubq[outlier_column]=metasubq[outlier_column].astype(float)
 1667    return {'raw': metasub, 'filt': metasubq}
 1668
 1669def mm_read( x, standardize_intensity=False, modality='' ):
 1670    """
 1671    read an image from a filename - same as ants.image_read (for now)
 1672
 1673    standardize_intensity : boolean ; if True will set negative values to zero and normalize into the range of zero to one
 1674
 1675    modality : if 'T1w' and the image is 4D, a workaround is attempted to reduce it to 3D (seen in some odd PPMI data)
 1676    """
 1677    if x is None:
 1678        raise ValueError( " None passed to function antspymm.mm_read." )
 1679    if not isinstance(x,str):
 1680        raise ValueError( " Non-string passed to function antspymm.mm_read." )
 1681    if not os.path.exists( x ):
 1682        raise ValueError( " file " + x + " does not exist." )
 1683    img = ants.image_read( x, reorient=False )
 1684    if standardize_intensity:
 1685        img[img<0.0]=0.0
 1686        img=ants.iMath(img,'Normalize')
 1687    if modality == "T1w" and img.dimension == 4:
 1688        print("WARNING: input image is 4D - we attempt a hack fix that works in some odd cases of PPMI data - please check this image: " + x, flush=True )
 1689        i1=ants.slice_image(img,3,0)
 1690        i2=ants.slice_image(img,3,1)
 1691        kk=np.concatenate( [i1.numpy(),i2.numpy()], axis=2 )
 1692        kk=ants.from_numpy(kk)
 1693        img=ants.copy_image_info(i1,kk)
 1694    return img
 1695
 1696def mm_read_to_3d( x, slice=None, modality='' ):
 1697    """
 1698    read an image from a filename and return it as a 3D image when possible; 4D inputs are reduced to a single volume (the middle one by default) along the last axis
 1699    """
 1700    img = ants.image_read( x, reorient=False )
 1701    if img.dimension <= 3:
 1702        return img
 1703    elif img.dimension == 4:
 1704        nslices = img.shape[3]
 1705        if slice is None:
 1706            sl = np.round( nslices * 0.5 )
 1707        else:
 1708            sl = slice
 1709        if sl >= nslices:
 1710            sl = nslices-1
 1711        return ants.slice_image( img, axis=3, idx=int(sl) )
 1712    elif img.dimension > 4:
 1713        return img
 1714    return None
 1715
 1716def timeseries_n3(x):
 1717    """
 1718    Perform N3 bias field correction on a time-series image dataset using ANTsPy library.
 1719
 1720    This function processes a multi-dimensional image dataset, where the last dimension
 1721    represents different time points. It applies N3 bias field correction to each time point 
 1722    individually to correct intensity non-uniformity.
 1723
 1724    Parameters:
 1725    x (ndarray): A multi-dimensional array where the last dimension represents time points. 
 1726                 Each 'slice' along this dimension is a separate image to be corrected.
 1727
 1728    Returns:
 1729    ndarray: A multi-dimensional array of the same shape as x, with N3 bias field correction 
 1730             applied to each time slice.
 1731
 1732    The function works as follows:
 1733    - Initializes an empty list `mimg` to store the corrected images.
 1734    - Determines the number of time points in the input image series.
 1735    - Iterates over each time point, extracting the image slice and applying N3 bias 
 1736      field correction.
 1737    - The corrected images are then appended to the `mimg` list.
 1738    - Finally, the list of corrected images is converted back into a multi-dimensional 
 1739      array and returned.
 1740
 1741    Example:
 1742    corrected_images = timeseries_n3(image_data)
 1743    """
 1744    mimg = []
 1745    n = len(x.shape) - 1
 1746    for kk in range(x.shape[n]):
 1747        temp = ants.slice_image(x, axis=n, idx=kk)
 1748        temp = ants.n3_bias_field_correction(temp, downsample_factor=2)
 1749        mimg.append(temp)
 1750    return ants.list_to_ndimage(x, mimg)
 1751
 1752def image_write_with_thumbnail( x,  fn, y=None, thumb=True ):
 1753    """
 1754    will write the image and (optionally) a png thumbnail with (optional) overlay/underlay
 1755    """
 1756    ants.image_write( x, fn )
 1757    if not thumb or x.components > 1:
 1758        return
 1759    thumb_fn=re.sub(r"\.nii\.gz","_3dthumb.png",fn)
 1760    if thumb and x.dimension == 3:
 1761        if y is None:
 1762            try:
 1763                ants.plot_ortho( x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1764            except:
 1765                pass
 1766        else:
 1767            try:
 1768                ants.plot_ortho( y, x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1769            except:
 1770                pass
 1771    if thumb and x.dimension == 4:
 1772        thumb_fn=re.sub(r"\.nii\.gz","_4dthumb.png",fn)
 1773        nslices = x.shape[3]
 1774        sl = np.round( nslices * 0.5 )
 1775        if sl > nslices:
 1776            sl = nslices-1
 1777        xview = ants.slice_image( x, axis=3, idx=int(sl) )
 1778        if y is None:
 1779            try:
 1780                ants.plot_ortho( xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1781            except:
 1782                pass
 1783        else:
 1784            if y.dimension == 3:
 1785                try:
 1786                    ants.plot_ortho(y, xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
 1787                except:
 1788                    pass
 1789    return
 1790
 1791def convert_np_in_dict(data_dict):
 1792    """
 1793    Convert values in the dictionary from numpy float or int types to regular Python float or int.
 1794
 1795    :param data_dict: A dictionary with values of various types.
 1796    :return: Dictionary with numpy values converted.
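
         Example:

         >>> convert_np_in_dict({'a': np.float32(1.5), 'b': np.int64(2), 'c': 'x'})
         {'a': 1.5, 'b': 2, 'c': 'x'}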
 1797    """
 1798    converted_dict = {}
 1799    for key, value in data_dict.items():
 1800        if isinstance(value, (np.float32, np.float64)):
 1801            converted_dict[key] = float(value)
 1802        elif isinstance(value, (np.int8,  np.uint8, np.int16,  np.uint16, np.int32,  np.uint32, np.int64,  np.uint64)):
 1803            converted_dict[key] = int(value)
 1804        else:
 1805            converted_dict[key] = value
 1806    return converted_dict
 1807
 1808def mc_resample_image_to_target( x , y, interp_type='linear' ):
 1809    """
 1810    multichannel version of resample_image_to_target
 1811    """
 1812    xx=ants.split_channels( x )
 1813    yy=ants.split_channels( y )[0]
 1814    newl=[]
 1815    for k in range(len(xx)):
 1816        newl.append(  ants.resample_image_to_target( xx[k], yy, interp_type=interp_type ) )
 1817    return ants.merge_channels( newl )
 1818
 1819def nrg_filelist_to_dataframe( filename_list, myseparator="-" ):
 1820    """
 1821    convert a list of files in nrg format to a dataframe
 1822
 1823    Arguments
 1824    ---------
 1825    filename_list : globbed list of files
 1826
 1827    myseparator : string separator between nrg parts
 1828
 1829    Returns
 1830    -------
 1831
 1832    df : pandas data frame
 1833
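         Example
         -------
         A sketch; the glob pattern below is purely illustrative of an NRG-organized tree:

         >>> import glob
         >>> flist = glob.glob( '/data/NRG/*/*/*/*/*/*.nii.gz' )
         >>> df = nrg_filelist_to_dataframe( flist, myseparator='-' )
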
 1834    """
 1835    def getmtime(x):
 1836        x= dt.datetime.fromtimestamp(os.path.getmtime(x)).strftime("%Y-%m-%d %H:%M:%S")
 1837        return x
 1838    df=pd.DataFrame(columns=['filename','file_last_mod_t','else','sid','visitdate','modality','uid'])
 1840    df['filename'] = pd.Series([file for file in filename_list ])
 1841    # record each file's last-modified time via getmtime
 1842    df['file_last_mod_t'] = df['filename'].apply(lambda x: getmtime(x))
 1843    for k in range(df.shape[0]):
 1844        locfn=df['filename'].iloc[k]
 1845        splitter=os.path.basename(locfn).split( myseparator )
 1846        df.loc[k,'sid']=splitter[1]
 1847        df.loc[k,'visitdate']=splitter[2]
 1848        df.loc[k,'modality']=splitter[3]
 1849        temp = os.path.splitext(splitter[4])[0]
 1850        df.loc[k,'uid']=os.path.splitext(temp)[0]
 1851    return df
 1852
 1853
 1854def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
 1855    """
 1856    merge two time-series acquisitions into the space of the first image (img_LR)
 1857
 1858    img_LR : image
 1859
 1860    img_RL : image
 1861
 1862    allow_resample : boolean
 1863
 1864    """
 1865    # concatenate the images into the reference space
 1866    mimg=[]
 1867    for kk in range( img_LR.shape[3] ):
 1868        temp = ants.slice_image( img_LR, axis=3, idx=kk )
 1869        mimg.append( temp )
 1870    for kk in range( img_RL.shape[3] ):
 1871        temp = ants.slice_image( img_RL, axis=3, idx=kk )
 1872        if kk == 0:
 1873            insamespace = ants.image_physical_space_consistency( temp, mimg[0] )
 1874        if allow_resample and not insamespace :
 1875            temp = ants.resample_image_to_target( temp, mimg[0] )
 1876        mimg.append( temp )
 1877    return ants.list_to_ndimage( img_LR, mimg )
 1878
 1879def copy_spatial_metadata_from_3d_to_4d(spatial_img, timeseries_img):
 1880    """
 1881    Copy spatial metadata (origin, spacing, direction) from a 3D image to the
 1882    spatial dimensions (first 3) of a 4D image, preserving the 4th dimension's metadata.
 1883
 1884    Parameters
 1885    ----------
 1886    spatial_img : ants.ANTsImage
 1887        A 3D ANTsImage with the desired spatial metadata.
 1888    timeseries_img : ants.ANTsImage
 1889        A 4D ANTsImage to update.
 1890
 1891    Returns
 1892    -------
 1893    ants.ANTsImage
 1894        A 4D ANTsImage with updated spatial metadata.
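
         Example
         -------
         A sketch, assuming `img3d` and `img4d` are 3D and 4D ANTsImages:

         >>> img4d_fixed = copy_spatial_metadata_from_3d_to_4d( img3d, img4d )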
 1895    """
 1896    if spatial_img.dimension != 3:
 1897        raise ValueError("spatial_img must be a 3D ANTsImage.")
 1898    if timeseries_img.dimension != 4:
 1899        raise ValueError("timeseries_img must be a 4D ANTsImage.")
 1900    # Get 3D metadata
 1901    spatial_origin = list(spatial_img.origin)
 1902    spatial_spacing = list(spatial_img.spacing)
 1903    spatial_direction = spatial_img.direction  # 3x3
 1904    # Get original 4D metadata
 1905    ts_spacing = list(timeseries_img.spacing)
 1906    ts_origin = list(timeseries_img.origin)
 1907    ts_direction = timeseries_img.direction  # 4x4
 1908    # Replace only the first 3 entries for origin and spacing
 1909    new_origin = spatial_origin + [ts_origin[3]]
 1910    new_spacing = spatial_spacing + [ts_spacing[3]]
 1911    # Replace top-left 3x3 block of direction matrix, preserve last row/column
 1912    new_direction = ts_direction.copy()
 1913    new_direction[:3, :3] = spatial_direction
 1914    # Create updated image
 1915    updated_img = ants.from_numpy(
 1916        timeseries_img.numpy(),
 1917        origin=new_origin,
 1918        spacing=new_spacing,
 1919        direction=new_direction
 1920    )
 1921    return updated_img
 1922
 1923def timeseries_transform(transform, image, reference, interpolation='linear'):
 1924    """
 1925    Apply a spatial transform to each 3D volume in a 4D time series image.
 1926
 1927    Parameters
 1928    ----------
 1929    transform : ants transform object
 1930        An ANTs transform object (e.g. from ants.read_transform) applied to each 3D volume.
 1931    image : ants.ANTsImage
 1932        4D input image with shape (X, Y, Z, T).
 1933    reference : ants.ANTsImage
 1934        Reference image to match in space.
 1935    interpolation : str
 1936        Interpolation method: 'linear', 'nearestNeighbor', etc.
 1937
 1938    Returns
 1939    -------
 1940    ants.ANTsImage
 1941        4D transformed image.
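
         Example
         -------
         A sketch; `tx` (an ANTs transform), `bold` (4D) and `t1` (3D reference) are assumed inputs:

         >>> bold_in_t1 = timeseries_transform( tx, bold, reference=t1 )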
 1942    """
 1943    if image.dimension != 4:
 1944        raise ValueError("Input image must be 4D (X, Y, Z, T).")
 1945    n_volumes = image.shape[3]
 1946    transformed_volumes = []
 1947    for t in range(n_volumes):
 1948        vol = ants.slice_image( image, 3, t )
 1949        transformed = ants.apply_ants_transform_to_image(
 1950            transform=transform,
 1951            image=vol,
 1952            reference=reference,
 1953            interpolation=interpolation
 1954        )
 1955        transformed_volumes.append(transformed.numpy())
 1956    # Stack along time axis and convert to ANTsImage
 1957    transformed_array = np.stack(transformed_volumes, axis=-1)
 1958    out_image = ants.from_numpy(transformed_array)
 1959    out_image = ants.copy_image_info(image, out_image)
 1960    out_image = copy_spatial_metadata_from_3d_to_4d(reference, out_image)
 1961    return out_image
 1962
 1963def timeseries_reg(
 1964    image,
 1965    avg_b0,
 1966    type_of_transform='antsRegistrationSyNRepro[r]',
 1967    total_sigma=1.0,
 1968    fdOffset=2.0,
 1969    trim = 0,
 1970    output_directory=None,
 1971    return_numpy_motion_parameters=False,
 1972    verbose=False, **kwargs
 1973):
 1974    """
 1975    Correct time-series data for motion.
 1976
 1977    Arguments
 1978    ---------
 1979    image: antsImage, usually ND where D=4.
 1980
 1981    avg_b0: fixed reference image (e.g. an average b0 or mean bold volume)
 1982
 1983    type_of_transform : string
 1984            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
 1985            See ants registration for details.
 1986
 1987    fdOffset: offset value to use in framewise displacement calculation
 1988
 1989    trim : integer - trim this many images off the front of the time series
 1990
 1991    output_directory : string
 1992            output will be placed in this directory plus a numeric extension.
 1993
 1994    return_numpy_motion_parameters : boolean
 1995
 1996    verbose: boolean
 1997
 1998    kwargs: keyword args
 1999            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
 2000
 2001    Returns
 2002    -------
 2003    dict containing the following key/value pairs:
 2004        `motion_corrected`: Moving image warped to space of fixed image.
 2005        `motion_parameters`: transforms for each image in the time series.
 2006        `FD`: Framewise displacement generalized for arbitrary transformations.
 2007
 2008    Notes
 2009    -----
 2010    Control extra arguments via kwargs. see ants.registration for details.
 2011
 2012    Example
 2013    -------
 2014    >>> import ants
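         >>> # a minimal sketch; `bold` (a 4D time series) and `avg_vol` (its 3D mean) are assumed inputs
         >>> reg = timeseries_reg( bold, avg_vol, type_of_transform='antsRegistrationSyNRepro[r]' )
         >>> moco = reg['motion_corrected']
         >>> fd = reg['FD']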
 2015    """
 2016    idim = image.dimension
 2017    ishape = image.shape
 2018    nTimePoints = ishape[idim - 1]
 2019    FD = np.zeros(nTimePoints)
 2020    if type_of_transform is None:
 2021        return {
 2022            "motion_corrected": image,
 2023            "motion_parameters": None,
 2024            "FD": FD
 2025        }
 2026
 2027    remove_it=False
 2028    if output_directory is None:
 2029        remove_it=True
 2030        output_directory = tempfile.mkdtemp()
 2031    output_directory_w = output_directory + "/ts_reg/"
 2032    os.makedirs(output_directory_w,exist_ok=True)
 2033    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
 2034    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
 2035    if verbose:
 2036        print('bold motcorr with ' + type_of_transform)
 2037        print(output_directory_w)
 2038        print(ofnG)
 2039        print(ofnL)
 2040        print("remove_it " + str( remove_it ) )
 2041
 2042    # get a local deformation from slice to local avg space
 2043    motion_parameters = list()
 2044    motion_corrected = list()
 2045    mask = ants.get_mask( avg_b0 )
 2046    centerOfMass = mask.get_center_of_mass()
 2047    npts = pow(2, idim - 1)
 2048    pointOffsets = np.zeros((npts, idim - 1))
 2049    myrad = np.ones(idim - 1).astype(int).tolist()
 2050    mask1vals = np.zeros(int(mask.sum()))
 2051    mask1vals[round(len(mask1vals) / 2)] = 1
 2052    mask1 = ants.make_image(mask, mask1vals)
 2053    myoffsets = ants.get_neighborhood_in_mask(
 2054        mask1, mask1, radius=myrad, spatial_info=True
 2055    )["offsets"]
 2056    mycols = list("xy")
 2057    if idim - 1 == 3:
 2058        mycols = list("xyz")
 2059    useinds = list()
 2060    for k in range(myoffsets.shape[0]):
 2061        if abs(myoffsets[k, :]).sum() == (idim - 2):
 2062            useinds.append(k)
 2063        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
 2064    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
 2065    if verbose:
 2066        print("Progress:")
 2067    counter = round( nTimePoints / 10 ) + 1
 2068    for k in range( nTimePoints):
 2069        if verbose and ( ( ( k % counter ) ==  0 ) or ( k == (nTimePoints-1) ) ):
 2070            myperc = round( k / nTimePoints * 100)
 2071            print(myperc, end="%.", flush=True)
 2072        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2073        temp = ants.iMath(temp, "Normalize")
 2074        txprefix = ofnL+str(k % 2).zfill(4)+"_"
 2075        if temp.numpy().var() > 0:
 2076            myrig = ants.registration(
 2077                    avg_b0, temp,
 2078                    type_of_transform='antsRegistrationSyNRepro[r]',
 2079                    outprefix=txprefix
 2080                )
 2081            if type_of_transform == 'SyN':
 2082                myreg = ants.registration(
 2083                    avg_b0, temp,
 2084                    type_of_transform='SyNOnly',
 2085                    total_sigma=total_sigma,
 2086                    initial_transform=myrig['fwdtransforms'][0],
 2087                    outprefix=txprefix,
 2088                    **kwargs
 2089                )
 2090            else:
 2091                myreg = myrig
 2092            fdptsTxI = ants.apply_transforms_to_points(
 2093                idim - 1, fdpts, myrig["fwdtransforms"]
 2094            )
 2095            if k > 0 and motion_parameters[k - 1] != "NA":
 2096                fdptsTxIminus1 = ants.apply_transforms_to_points(
 2097                    idim - 1, fdpts, motion_parameters[k - 1]
 2098                )
 2099            else:
 2100                fdptsTxIminus1 = fdptsTxI
 2101            # take the absolute value, then the mean across columns, then the sum
 2102            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
 2103            motion_parameters.append(myreg["fwdtransforms"])
 2104        else:
 2105            motion_parameters.append("NA")
 2106
 2107        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2108        if temp.numpy().var() > 0:
 2109            img1w = ants.apply_transforms( avg_b0,
 2110                temp,
 2111                motion_parameters[k] )
 2112            motion_corrected.append(img1w)
 2113        else:
 2114            motion_corrected.append(avg_b0)
 2115
 2116    motion_parameters = motion_parameters[trim:len(motion_parameters)]
 2117    if return_numpy_motion_parameters:
 2118        motion_parameters = read_ants_transforms_to_numpy( motion_parameters )
 2119
 2120    if remove_it:
 2121        import shutil
 2122        shutil.rmtree(output_directory, ignore_errors=True )
 2123
 2124    if verbose:
 2125        print("Done")
 2126    d4siz = list(avg_b0.shape)
 2127    d4siz.append( 2 )
 2128    spc = list(ants.get_spacing( avg_b0 ))
 2129    spc.append( ants.get_spacing(image)[3] )
 2130    mydir = ants.get_direction( avg_b0 )
 2131    mydir4d = ants.get_direction( image )
 2132    mydir4d[0:3,0:3]=mydir
 2133    myorg = list(ants.get_origin( avg_b0 ))
 2134    myorg.append( 0.0 )
 2135    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
 2136    return {
 2137        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected[trim:len(motion_corrected)]),
 2138        "motion_parameters": motion_parameters,
 2139        "FD": FD[trim:len(FD)]
 2140    }
 2141
 2142
 2143def merge_dwi_data( img_LRdwp, bval_LR, bvec_LR, img_RLdwp, bval_RL, bvec_RL ):
 2144    """
 2145    merge motion and distortion corrected data if possible
 2146
 2147    img_LRdwp : image
 2148
 2149    bval_LR : array
 2150
 2151    bvec_LR : array
 2152
 2153    img_RLdwp : image
 2154
 2155    bval_RL : array
 2156
 2157    bvec_RL : array
 2158
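         Example
         -------
         A sketch, assuming distortion-corrected LR/RL acquisitions that already share a physical space:

         >>> dwi, bval, bvec = merge_dwi_data( dwiLR, bvalLR, bvecLR, dwiRL, bvalRL, bvecRL )
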
 2159    """
 2160    import warnings
 2161    insamespace = ants.image_physical_space_consistency( img_LRdwp, img_RLdwp )
 2162    if not insamespace :
 2163        warnings.warn('not insamespace ... corrected image pair should occupy the same physical space; returning only the 1st set and will not join these data.')
 2164        return img_LRdwp, bval_LR, bvec_LR
 2165    
 2166    bval_LR = np.concatenate([bval_LR,bval_RL])
 2167    bvec_LR = np.concatenate([bvec_LR,bvec_RL])
 2168    # concatenate the images
 2169    mimg=[]
 2170    for kk in range( img_LRdwp.shape[3] ):
 2171            mimg.append( ants.slice_image( img_LRdwp, axis=3, idx=kk ) )
 2172    for kk in range( img_RLdwp.shape[3] ):
 2173            mimg.append( ants.slice_image( img_RLdwp, axis=3, idx=kk ) )
 2174    img_LRdwp = ants.list_to_ndimage( img_LRdwp, mimg )
 2175    return img_LRdwp, bval_LR, bvec_LR
 2176
 2177def bvec_reorientation( motion_parameters, bvecs, rebase=None ):
         """
         reorient b-vectors by the rotation component of per-volume motion transforms (see dti_reg); optionally apply an additional rebase rotation
         """
 2178    if motion_parameters is None:
 2179        return bvecs
 2180    n = len(motion_parameters)
 2181    if n < 1:
 2182        return bvecs
 2183    from scipy.linalg import inv, polar
 2184    from dipy.core.gradients import reorient_bvecs
 2185    dipymoco = np.zeros( [n,3,3] )
 2186    for myidx in range(n):
 2187        if myidx < bvecs.shape[0]:
 2188            dipymoco[myidx,:,:] = np.eye( 3 )
 2189            if motion_parameters[myidx] != 'NA':
 2190                temp = motion_parameters[myidx]
 2191                if len(temp) == 4 :
 2192                    temp1=temp[3] # FIXME should be composite of index 1 and 3
 2193                    temp2=temp[1] # FIXME should be composite of index 1 and 3
 2194                    txparam1 = ants.read_transform(temp1)
 2195                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
 2196                    txparam2 = ants.read_transform(temp2)
 2197                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
 2198                    Rinv = inv( np.dot( txparam2, txparam1 ) )
 2199                elif len(temp) == 2 :
 2200                    temp=temp[1] # FIXME should be composite of index 1 and 3
 2201                    txparam = ants.read_transform(temp)
 2202                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
 2203                    Rinv = inv( txparam )
 2204                elif len(temp) == 3 :
 2205                    temp1=temp[2] # FIXME should be composite of index 1 and 3
 2206                    temp2=temp[1] # FIXME should be composite of index 1 and 3
 2207                    txparam1 = ants.read_transform(temp1)
 2208                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
 2209                    txparam2 = ants.read_transform(temp2)
 2210                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
 2211                    Rinv = inv( np.dot( txparam2, txparam1 ) )
 2212                else:
 2213                    temp=temp[0]
 2214                    txparam = ants.read_transform(temp)
 2215                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
 2216                    Rinv = inv( txparam )
 2217                bvecs[myidx,:] = np.dot( Rinv, bvecs[myidx,:] )
 2218                if rebase is not None:
 2219                    # FIXME - should combine these operations
 2220                    bvecs[myidx,:] = np.dot( rebase, bvecs[myidx,:] )
 2221    return bvecs
 2222
 2223
 2224def distortion_correct_bvecs(bvecs, def_grad, A_img, A_ref):
 2225    """
 2226    Vectorized computation of voxel-wise distortion corrected b-vectors.
 2227
 2228    Parameters
 2229    ----------
 2230    bvecs : ndarray (N, 3)
 2231    def_grad : ndarray (X, Y, Z, 3, 3) containing rotations derived from the deformation gradient
 2232    A_img : ndarray (3, 3) direction matrix of the fixed image (target undistorted space)
 2233    A_ref : ndarray (3, 3) direction matrix of the moving image (being corrected)
 2234
 2235    Returns
 2236    -------
 2237    bvecs_5d : ndarray (X, Y, Z, N, 3)
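
         Example
         -------
         A small synthetic check; identity direction matrices and identity rotations leave the b-vector directions unchanged:

         >>> import numpy as np
         >>> bvecs = np.random.randn(30, 3)
         >>> def_grad = np.tile(np.eye(3), (4, 4, 4, 1, 1))
         >>> out = distortion_correct_bvecs(bvecs, def_grad, np.eye(3), np.eye(3))
         >>> out.shape
         (4, 4, 4, 30, 3)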
 2238    """
 2239    X, Y, Z = def_grad.shape[:3]
 2240    N = bvecs.shape[0]
 2241    # Combined rotation: R_voxel = A_ref.T @ A_img @ def_grad
 2242    A = A_ref.T @ A_img
 2243    R_voxel = np.einsum('ij,xyzjk->xyzik', A, def_grad)  # (X, Y, Z, 3, 3)
 2244    # Apply R_voxel.T @ bvecs
 2245    # First, reshape R_voxel: (X*Y*Z, 3, 3)
 2246    R_voxel_reshaped = R_voxel.reshape(-1, 3, 3)
 2247    # Rotate all bvecs for each voxel
 2248    # Output: (X*Y*Z, N, 3)
 2249    rotated = np.einsum('vij,nj->vni', R_voxel_reshaped, bvecs)
 2250    # Normalize
 2251    norms = np.linalg.norm(rotated, axis=2, keepdims=True)
 2252    rotated /= np.clip(norms, 1e-8, None)
 2253    # Reshape back to (X, Y, Z, N, 3)
 2254    bvecs_5d = rotated.reshape(X, Y, Z, N, 3)
 2255    return bvecs_5d    
 2256
 2257def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False ):
 2258    """
 2259    extract DTI data from a dipy tensormodel
 2260
 2261    reference_image : antsImage defining physical space (3D)
 2262
 2263    tensormodel : from dipy e.g. the variable myoutx['dtrecon_LR_dewarp']['tensormodel'] if myoutx is produced by joint_dti_recon
 2264
 2265    upper_triangular: boolean otherwise use lower triangular coding
 2266
 2267    return_image : boolean return the ANTsImage form of DTI otherwise return an array
 2268
 2269    Returns
 2270    -------
 2271    either an ANTsImage (dim=X.Y.Z with 6 component voxels, upper triangular form)
 2272        or a 5D NumPy array (dim=X.Y.Z.3.3)
 2273
 2274    Notes
 2275    -----
 2276    DiPy returns lower triangular form but ANTs expects upper triangular.
 2277        Here, we default to the ANTs standard but could generalize in the future 
 2278        because not much here depends on ANTs standards of tensor data.
 2279        ANTs xx,xy,xz,yy,yz,zz
 2280        DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
 2281
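         Example
         -------
         A sketch, following the note above (`myoutx` from joint_dti_recon, `ref` a 3D reference image):

         >>> dti6 = get_dti( ref, myoutx['dtrecon_LR_dewarp']['tensormodel'], return_image=True )
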
 2282    """
 2283    # make the DTI - see 
 2284    # https://dipy.org/documentation/1.7.0/examples_built/07_reconstruction/reconst_dti/#sphx-glr-examples-built-07-reconstruction-reconst-dti-py
 2285    # By default, in DIPY, values are ordered as (Dxx, Dxy, Dyy, Dxz, Dyz, Dzz)
 2286    # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2287    reoind = np.array([0,1,3,2,4,5]) # arrays are faster than lists
 2288    import dipy.reconst.dti as dti
 2289    dtiut = dti.lower_triangular(tensormodel.quadratic_form)
 2290    it = np.ndindex( reference_image.shape )
 2291    yyind=2
 2292    xzind=3
 2293    if upper_triangular:
 2294        yyind=3
 2295        xzind=2
 2296        for i in it: # convert to upper triangular
 2297            dtiut[i] = dtiut[i][ reoind ] # do we care if this is doing extra work?
 2298    if return_image:
 2299        dtiAnts = ants.from_numpy(dtiut,has_components=True)
 2300        ants.copy_image_info( reference_image, dtiAnts )
 2301        return dtiAnts
 2302    # copy these data into a tensor 
 2303    dtinp = np.zeros(reference_image.shape + (3,3), dtype=float)  
 2304    dtix = np.zeros((3,3), dtype=float)  
 2305    it = np.ndindex( reference_image.shape )
 2306    for i in it:
 2307        dtivec = dtiut[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2308        dtix[0,0]=dtivec[0]
 2309        dtix[1,1]=dtivec[yyind] # 2 for LT
 2310        dtix[2,2]=dtivec[5] 
 2311        dtix[0,1]=dtix[1,0]=dtivec[1]
 2312        dtix[0,2]=dtix[2,0]=dtivec[xzind] # 3 for LT
 2313        dtix[1,2]=dtix[2,1]=dtivec[4]
 2314        dtinp[i]=dtix
 2315    return dtinp
 2316
 2317def triangular_to_tensor( image, upper_triangular=True ):
 2318    """
 2319    convert triangular tensor image to a full tensor form (in numpy)
 2320
 2321    image : antsImage holding dti in either upper or lower triangular format 
 2322
 2323    upper_triangular: boolean
 2324
 2325    Note
 2326    --------
 2327    see get_dti function for more details
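
         Example
         -------
         A sketch, assuming `dti6` is a 6-component (upper triangular) ANTsImage:

         >>> tens = triangular_to_tensor( dti6 )   # numpy array of shape X.Y.Z.3.3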
 2328    """
 2329    reoind = np.array([0,1,3,2,4,5]) # arrays are faster than lists
 2330    it = np.ndindex( image.shape )
 2331    yyind=2
 2332    xzind=3
 2333    if upper_triangular:
 2334        yyind=3
 2335        xzind=2
 2336    # copy these data into a tensor 
 2337    dtinp = np.zeros(image.shape + (3,3), dtype=float)
 2338    dtix = np.zeros((3,3), dtype=float)
 2339    it = np.ndindex( image.shape )
 2340    dtiut = image.numpy()
 2341    for i in it:
 2342        dtivec = dtiut[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2343        dtix[0,0]=dtivec[0]
 2344        dtix[1,1]=dtivec[yyind] # 2 for LT
 2345        dtix[2,2]=dtivec[5] 
 2346        dtix[0,1]=dtix[1,0]=dtivec[1]
 2347        dtix[0,2]=dtix[2,0]=dtivec[xzind] # 3 for LT
 2348        dtix[1,2]=dtix[2,1]=dtivec[4]
 2349        dtinp[i]=dtix
 2350    return dtinp
 2351
 2352
 2353def dti_numpy_to_image( reference_image, tensorarray, upper_triangular=True):
 2354    """
 2355    convert numpy DTI data to antsImage
 2356
 2357    reference_image : antsImage defining physical space (3D)
 2358
 2359    tensorarray : numpy array X,Y,Z,3,3 shape
 2360
 2361    upper_triangular: boolean otherwise use lower triangular coding
 2362
 2363    Returns
 2364    -------
 2365    ANTsImage
 2366
 2367    Notes
 2368    -----
 2369    DiPy returns lower triangular form but ANTs expects upper triangular.
 2370        Here, we default to the ANTs standard but could generalize in the future 
 2371        because not much here depends on ANTs standards of tensor data.
 2372        ANTs xx,xy,xz,yy,yz,zz
 2373        DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
 2374
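         Example
         -------
         A sketch; `ref` (3D reference) and `tens` (an X.Y.Z.3.3 array, e.g. from triangular_to_tensor) are assumed:

         >>> dti6_img = dti_numpy_to_image( ref, tens )
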
 2375    """
 2376    dtiut = np.zeros(reference_image.shape + (6,), dtype=float)  
 2377    dtivec = np.zeros(6, dtype=float)  
 2378    it = np.ndindex( reference_image.shape )
 2379    yyind=2
 2380    xzind=3
 2381    if upper_triangular:
 2382        yyind=3
 2383        xzind=2
 2384    for i in it:
 2385        dtix = tensorarray[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
 2386        dtivec[0]=dtix[0,0]
 2387        dtivec[yyind]=dtix[1,1] # 2 for LT
 2388        dtivec[5]=dtix[2,2]
 2389        dtivec[1]=dtix[0,1]
 2390        dtivec[xzind]=dtix[2,0] # 3 for LT
 2391        dtivec[4]=dtix[1,2]
 2392        dtiut[i]=dtivec
 2393    dtiAnts = ants.from_numpy( dtiut, has_components=True )
 2394    ants.copy_image_info( reference_image, dtiAnts )
 2395    return dtiAnts
 2396
 2397
 2398def deformation_gradient_optimized(warp_image, to_rotation=False, to_inverse_rotation=False):
 2399    """
 2400    Compute the deformation gradient tensor from a displacement (warp) field image.
 2401
 2402    This function computes the **deformation gradient** `F = ∂φ/∂x` where `φ(x) = x + u(x)` is the mapping
 2403    induced by the displacement field `u(x)` stored in `warp_image`.
 2404
 2405    The returned tensor field has shape `(x, y, z, dim, dim)` (for 3D), where each matrix represents 
 2406    the **Jacobian** of the transformation at that voxel. The gradient is computed in the physical space 
 2407    of the image using spacing and direction metadata.
 2408
 2409    Optionally, the deformation gradient can be projected onto the space of pure rotations using the polar
 2410    decomposition (via SVD). This is useful for applications like reorientation of tensors (e.g., DTI).
 2411
 2412    Parameters
 2413    ----------
 2414    warp_image : ants.ANTsImage
 2415        A vector-valued ANTsImage encoding the warp/displacement field. It must have `dim` components
 2416        (e.g., shape `(x, y, z, 3)` for 3D) representing the displacements in each spatial direction.
 2417        
 2418    to_rotation : bool, optional
 2419        If True, the deformation gradient will be replaced with its **nearest rotation matrix**
 2420        using the polar decomposition (`F → R`, where `F = R U`).
 2421        
 2422    to_inverse_rotation : bool, optional
 2423        If True, the deformation gradient will be replaced with the **inverse of the rotation**
 2424        (`F → R.T`), which is often needed for transforming tensors **back** to their original frame.
 2425
 2426    Returns
 2427    -------
 2428    F : np.ndarray
 2429        A NumPy array of shape `(x, y, z, dim, dim)` (or `(dim1, dim2, ..., dim, dim)` in general),
 2430        representing the deformation gradient tensor field at each voxel.
 2431
 2432    Raises
 2433    ------
 2434    RuntimeError
 2435        If `warp_image` is not an `ants.ANTsImage`.
 2436
 2437    Notes
 2438    -----
 2439    - The function computes gradients in physical space using the spacing of the image and applies 
 2440      the image direction matrix (`tdir`) to properly orient gradients.
 2441    - The gradient is added to the identity matrix to yield the deformation gradient `F = I + ∂u/∂x`.
 2442    - The polar decomposition ensures `F` is replaced with the closest rotation matrix (orthogonal, det=1).
 2443    - This is a **vectorized pure NumPy implementation**, intended for performance and simplicity.
 2444
 2445    Examples
 2446    --------
 2447    >>> warp = ants.create_warp_image(reference_image, displacement_field)
 2448    >>> F = deformation_gradient_optimized(warp)
 2449    >>> R = deformation_gradient_optimized(warp, to_rotation=True)
 2450    >>> Rinv = deformation_gradient_optimized(warp, to_inverse_rotation=True)
 2451    """
 2452    if not ants.is_image(warp_image):
 2453        raise RuntimeError("antsimage is required")
 2454    dim = warp_image.dimension
 2455    tshp = warp_image.shape
 2456    tdir = warp_image.direction
 2457    spc = warp_image.spacing
 2458    warpnp = warp_image.numpy()
 2459    gradient_list = [np.gradient(warpnp[..., k], *spc, axis=range(dim)) for k in range(dim)]
 2460    # This correctly calculates J.T, where dg[..., i, j] = d(u_j)/d(x_i)
 2461    dg = np.stack([np.stack(grad_k, axis=-1) for grad_k in gradient_list], axis=-1)
 2462    dg = (tdir @ dg).swapaxes(-1, -2)
 2463    dg += np.eye(dim)
 2464    if to_rotation or to_inverse_rotation:
 2465        U, s, Vh = np.linalg.svd(dg)
 2466        Z = U @ Vh
 2467        dets = np.linalg.det(Z)
 2468        reflection_mask = dets < 0
 2469        Vh[reflection_mask, -1, :] *= -1
 2470        Z[reflection_mask] = U[reflection_mask] @ Vh[reflection_mask]
 2471        dg = Z
 2472        if to_inverse_rotation:
 2473            dg = np.transpose(dg, axes=(*range(dg.ndim - 2), dg.ndim - 1, dg.ndim - 2))
 2474    new_shape = tshp + (dim,dim)
 2475    return np.reshape(dg, new_shape)
 2476
 2477
 2478def transform_and_reorient_dti( fixed, moving_dti, composite_transform, verbose=False, **kwargs):
 2479    """
 2480    Applies a transformation to a DTI image using an ANTs composite transform,
 2481    including local tensor reorientation via the Finite Strain method.
 2482
 2483    This function expects:
 2484    - Input `moving_dti` to be a 6-component ANTsImage (upper triangular format).
 2485    - `composite_transform` to point to an ANTs-readable transform file,
 2486      which maps points from `fixed` space to `moving` space.
 2487
 2488    Args:
 2489        fixed (ants.ANTsImage): The reference space image (defines the output grid).
 2490        moving_dti (ants.ANTsImage): The input DTI (6-component), to be transformed.
 2491        composite_transform (str): File path to an ANTs transform
 2492                                   (e.g., from `ants.read_transform` or a written composite transform).
 2493        verbose (bool): Whether to print verbose output during execution.
 2494        **kwargs: Additional keyword arguments passed to `ants.apply_transforms`.
 2495
 2496    Returns:
 2497        ants.ANTsImage: The transformed and reoriented DTI image in the `fixed` space,
 2498                        in 6-component upper triangular format.
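
         Example:
             A sketch; `reg` is assumed to come from ants.registration with a dense warp written to disk:

             >>> dti_in_fixed = transform_and_reorient_dti( fixed_fa, dti6, reg['fwdtransforms'][0], verbose=True )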
 2499    """
 2500    if moving_dti.dimension != 3:
 2501        raise ValueError('moving_dti must be 3-dimensional.')
 2502    if moving_dti.components != 6:
 2503        raise ValueError('moving_dti must have 6 components (upper triangular format).')
 2504
 2505    if verbose:
 2506        print("1. Spatially transforming DTI scalar components from moving to fixed space...")
 2507
 2508    # ants.apply_transforms resamples the *values* of each DTI component from 'moving_dti'
 2509    # onto the grid of 'fixed'.
 2510    # The output 'dtiw' will have the same spatial metadata (spacing, origin, direction) as 'fixed'.
 2511    # However, the tensor values contained within it are still oriented as they were in
 2512    # 'moving_dti's original image space, not 'fixed' image space, and certainly not yet reoriented
 2513    # by the local deformation.
 2514    dtsplit = moving_dti.split_channels()
 2515    dtiw_channels = []
 2516    for k in range(len(dtsplit)):
 2517        dtiw_channels.append( ants.apply_transforms( fixed, dtsplit[k], composite_transform, **kwargs ) )
 2518    dtiw = ants.merge_channels(dtiw_channels)
 2519    
 2520    if verbose:
 2521        print(f"   DTI scalar components resampled to fixed grid. Result shape: {dtiw.shape}")
 2522        print("2. Computing local rotation field from composite transform...")
 2523    
 2524    # Read the composite transform as an image (assumed to be a displacement field).
 2525    # The 'deformation_gradient_optimized' function is assumed to be 100% correct,
 2526    # meaning it returns the appropriate local rotation matrix field (R_moving_to_fixed)
 2527    # in (spatial_dims..., 3, 3 ) format when called with these flags.
 2528    wtx = ants.image_read(composite_transform)
 2529    R_moving_to_fixed_forward = deformation_gradient_optimized(
 2530        wtx,
 2531        to_rotation=False,      # This means the *deformation gradient* F=I+J is computed first.
 2532        to_inverse_rotation=True # This requests the inverse of the rotation part of F.
 2533    )
 2534
 2535    if verbose:
 2536        print(f"   Local reorientation matrices (R_moving_to_fixed_forward) computed. Shape: {R_moving_to_fixed_forward.shape}")
 2537        print("3. Converting 6-component DTI to full 3x3 tensors for vectorized reorientation...")
 2538    
 2539    # Convert `dtiw` (resampled, but still in moving-image-space orientation)
 2540    # from 6-components to full 3x3 tensor representation.
 2541    # dtiw2tensor_np will have shape (spatial_dims..., 3, 3).
 2542    dtiw2tensor_np = triangular_to_tensor(dtiw)
 2543    
 2544    if verbose:
 2545        print("4. Applying vectorized tensor reorientation (Finite Strain Method)...")
 2546    
 2547    # --- Vectorized Tensor Reorientation ---
 2548    # This replaces the entire `for i in it:` loop and its contents with efficient NumPy operations.
 2549
 2550    # Step 4.1: Rebase tensors from `moving_dti.direction` coordinate system to World Coordinates.
 2551    # D_world_moving_orient = moving_dti.direction @ D_moving_image_frame @ moving_dti.direction.T
 2552    # This transforms the tensor's components from being relative to `moving_dti`'s image axes
 2553    # (where they are currently defined) into absolute World (physical) coordinates.
 2554    D_world_moving_orient = np.einsum(
 2555        'ab, ...bc, cd -> ...ad',
 2556        moving_dti.direction,            # 3x3 matrix (moving_image_axes -> world_axes)
 2557        dtiw2tensor_np,                  # (spatial_dims..., 3, 3)
 2558        moving_dti.direction.T           # 3x3 matrix (world_axes -> moving_image_axes) - inverse of moving_dti.direction
 2559    )
 2560
 2561    # Step 4.2: Apply local rotation in World Coordinates (Finite Strain Reorientation).
 2562    # D_reoriented_world = R_moving_to_fixed_forward @ D_world_moving_orient @ (R_moving_to_fixed_forward).T
 2563    # This is the core reorientation step, transforming the tensor's orientation from
 2564    # the original `moving` space to the new `fixed` space, all within world coordinates.
 2565    D_world_fixed_orient = np.einsum(
 2566        '...ab, ...bc, ...cd -> ...ad',
 2567        R_moving_to_fixed_forward,      # (spatial_dims..., 3, 3) - local rotation
 2568        D_world_moving_orient,          # (spatial_dims..., 3, 3) - tensor in world space, moving_orient
 2569        np.swapaxes(R_moving_to_fixed_forward, -1, -2) # (spatial_dims..., 3, 3) - transpose of local rotation
 2570    )
 2571
 2572    # Step 4.3: Rebase reoriented tensors from World Coordinates to `fixed.direction` coordinate system.
 2573    # D_final_fixed_image_frame = (fixed.direction).T @ D_world_fixed_orient @ fixed.direction
 2574    # This transforms the tensor's components from absolute World (physical) coordinates
 2575    # back into `fixed.direction`'s image coordinate system.
 2576    final_dti_tensors_numpy = np.einsum(
 2577        'ba, ...bc, cd -> ...ad',
 2578        fixed.direction,                # Using `fixed.direction` here, but 'ba' indices specify to use its transpose.
 2579        D_world_fixed_orient,           # (spatial_dims..., 3, 3)
 2580        fixed.direction                 # 3x3 matrix (world_axes -> fixed_image_axes)
 2581    )
 2582
 2583    if verbose:
 2584        print("   Vectorized tensor reorientation complete.")
 2585
 2586    if verbose:
 2587        print("5. Converting reoriented full tensors back to 6-component ANTsImage...")
 2588    
 2589    # Convert the final (spatial_dims..., 3, 3) NumPy array of tensors back into a
 2590    # 6-component ANTsImage with the correct spatial metadata from `fixed`.
 2591    final_dti_image = dti_numpy_to_image(fixed, final_dti_tensors_numpy)
 2592    
 2593    if verbose:
 2594        print(f"Done. Final reoriented DTI image in fixed space generated. Shape: {final_dti_image.shape}")
 2595
 2596    return final_dti_image
 2597
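# Illustrative sketch (not executed) of the finite-strain reorientation steps above on
# synthetic arrays; all names below are hypothetical and only the shapes matter:
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   D = rng.standard_normal((4, 4, 4, 3, 3)); D = D + np.swapaxes(D, -1, -2)   # symmetric tensors
#   R = np.tile(np.eye(3), (4, 4, 4, 1, 1))                                    # local rotation field
#   dir_mov, dir_fix = np.eye(3), np.eye(3)                                    # image direction matrices
#   D_world = np.einsum('ab,...bc,cd->...ad', dir_mov, D, dir_mov.T)           # step 4.1: image -> world
#   D_rot = np.einsum('...ab,...bc,...cd->...ad', R, D_world, np.swapaxes(R, -1, -2))  # step 4.2: rotate
#   D_fix = np.einsum('ba,...bc,cd->...ad', dir_fix, D_rot, dir_fix)           # step 4.3: world -> fixed image
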
 2598def dti_reg(
 2599    image,
 2600    avg_b0,
 2601    avg_dwi,
 2602    bvals=None,
 2603    bvecs=None,
 2604    b0_idx=None,
 2605    type_of_transform="antsRegistrationSyNRepro[r]",
 2606    total_sigma=3.0,
 2607    fdOffset=2.0,
 2608    mask_csf=False,
 2609    brain_mask_eroded=None,
 2610    output_directory=None,
 2611    verbose=False, **kwargs
 2612):
 2613    """
 2614    Correct time-series data for motion - with optional deformation.
 2615
 2616    Arguments
 2617    ---------
 2618        image: antsImage, usually ND where D=4.
 2619
        avg_b0: fixed b0 reference image
 2621
        avg_dwi: fixed average DWI image in the same space as the b0 image
 2623
 2624        bvals: bvalues (file or array)
 2625
 2626        bvecs: bvecs (file or array)
 2627
        b0_idx: indices of the b0 volumes; if None, estimated from bvals
 2629
 2630        type_of_transform : string
 2631            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
 2632            See ants registration for details.
 2633
 2634        fdOffset: offset value to use in framewise displacement calculation
 2635
 2636        mask_csf: boolean
 2637
 2638        brain_mask_eroded: optional mask that will trigger mixed interpolation
 2639
        output_directory : string
            intermediate outputs will be written under this directory; if None, a temporary directory is created and removed when finished.
 2642
 2643        verbose: boolean
 2644
 2645        kwargs: keyword args
            extra arguments that control the details of the registration; see ants.registration for more.
 2647
 2648    Returns
 2649    -------
    dict containing the following key/value pairs:
 2651        `motion_corrected`: Moving image warped to space of fixed image.
 2652        `motion_parameters`: transforms for each image in the time series.
 2653        `FD`: Framewise displacement generalized for arbitrary transformations.
 2654
 2655    Notes
 2656    -----
 2657    Control extra arguments via kwargs. see ants.registration for details.
 2658
 2659    Example
 2660    -------
 2661    >>> import ants
 2662    """
 2663
 2664    idim = image.dimension
 2665    ishape = image.shape
 2666    nTimePoints = ishape[idim - 1]
 2667    FD = np.zeros(nTimePoints)
 2668    if bvals is not None and bvecs is not None:
 2669        if isinstance(bvecs, str):
 2670            bvals, bvecs = read_bvals_bvecs( bvals , bvecs  )
 2671        else: # assume we already read them
 2672            bvals = bvals.copy()
 2673            bvecs = bvecs.copy()
 2674    if type_of_transform is None:
 2675        return {
 2676            "motion_corrected": image,
 2677            "motion_parameters": None,
 2678            "FD": FD,
 2679            'bvals':bvals,
 2680            'bvecs':bvecs
 2681        }
 2682
 2683    from scipy.linalg import inv, polar
 2684    from dipy.core.gradients import reorient_bvecs
 2685
 2686    remove_it=False
 2687    if output_directory is None:
 2688        remove_it=True
 2689        output_directory = tempfile.mkdtemp()
 2690    output_directory_w = output_directory + "/dti_reg/"
 2691    os.makedirs(output_directory_w,exist_ok=True)
 2692    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
 2693    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
 2694    if verbose:
 2695        print(output_directory_w)
 2696        print(ofnG)
 2697        print(ofnL)
 2698        print("remove_it " + str( remove_it ) )
 2699
 2700    if b0_idx is None:
 2701        # b0_idx = segment_timeseries_by_meanvalue( image )['highermeans']
 2702        b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals']
 2703
 2704    # first get a local deformation from slice to local avg space
 2705    # then get a global deformation from avg to ref space
 2706    ab0, adw = get_average_dwi_b0( image )
 2707    # mask is used to roughly locate middle of brain
 2708    mask = ants.threshold_image( ants.iMath(adw,'Normalize'), 0.1, 1.0 )
 2709    if brain_mask_eroded is None:
 2710        brain_mask_eroded = mask * 0 + 1
 2711    motion_parameters = list()
 2712    motion_corrected = list()
 2713    centerOfMass = mask.get_center_of_mass()
 2714    npts = pow(2, idim - 1)
 2715    pointOffsets = np.zeros((npts, idim - 1))
 2716    myrad = np.ones(idim - 1).astype(int).tolist()
 2717    mask1vals = np.zeros(int(mask.sum()))
 2718    mask1vals[round(len(mask1vals) / 2)] = 1
 2719    mask1 = ants.make_image(mask, mask1vals)
 2720    myoffsets = ants.get_neighborhood_in_mask(
 2721        mask1, mask1, radius=myrad, spatial_info=True
 2722    )["offsets"]
 2723    mycols = list("xy")
 2724    if idim - 1 == 3:
 2725        mycols = list("xyz")
 2726    useinds = list()
 2727    for k in range(myoffsets.shape[0]):
 2728        if abs(myoffsets[k, :]).sum() == (idim - 2):
 2729            useinds.append(k)
 2730        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
 2731    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
 2732
 2733
 2734    if verbose:
 2735        print("begin global distortion correction")
 2736    # initrig = tra_initializer(avg_b0, ab0, max_rotation=60, transform=['rigid'], verbose=verbose)
 2737    if mask_csf:
 2738        bcsf = ants.threshold_image( avg_b0,"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
 2739    else:
 2740        bcsf = ab0 * 0 + 1
 2741
 2742    initrig = ants.registration( avg_b0, ab0,'antsRegistrationSyNRepro[r]',outprefix=ofnG)
 2743    deftx = ants.registration( avg_dwi, adw, 'SyNOnly',
 2744        syn_metric='CC', syn_sampling=2,
 2745        reg_iterations=[50,50,20],
 2746        multivariate_extras=[ [ "CC", avg_b0, ab0, 1, 2 ]],
 2747        initial_transform=initrig['fwdtransforms'][0],
 2748        outprefix=ofnG
 2749        )['fwdtransforms']
 2750    if verbose:
 2751        print("end global distortion correction")
 2752
 2753    if verbose:
 2754        print("Progress:")
 2755    counter = round( nTimePoints / 10 ) + 1
 2756    for k in range(nTimePoints):
        if verbose and nTimePoints > 0 and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
 2758            myperc = round( k / nTimePoints * 100)
 2759            print(myperc, end="%.", flush=True)
 2760        if k in b0_idx:
 2761            fixed=ants.image_clone( ab0 )
 2762        else:
 2763            fixed=ants.image_clone( adw )
 2764        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2765        temp = ants.iMath(temp, "Normalize")
 2766        txprefix = ofnL+str(k).zfill(4)+"rig_"
 2767        txprefix2 = ofnL+str(k % 2).zfill(4)+"def_"
 2768        if temp.numpy().var() > 0:
 2769            myrig = ants.registration(
 2770                    fixed, temp,
 2771                    type_of_transform='antsRegistrationSyNRepro[r]',
 2772                    outprefix=txprefix,
 2773                    **kwargs
 2774                )
 2775            if type_of_transform == 'SyN':
 2776                myreg = ants.registration(
 2777                    fixed, temp,
 2778                    type_of_transform='SyNOnly',
 2779                    total_sigma=total_sigma, grad_step=0.1,
 2780                    initial_transform=myrig['fwdtransforms'][0],
 2781                    outprefix=txprefix2,
 2782                    **kwargs
 2783                )
 2784            else:
 2785                myreg = myrig
 2786            fdptsTxI = ants.apply_transforms_to_points(
 2787                idim - 1, fdpts, myrig["fwdtransforms"]
 2788            )
 2789            if k > 0 and motion_parameters[k - 1] != "NA":
 2790                fdptsTxIminus1 = ants.apply_transforms_to_points(
 2791                    idim - 1, fdpts, motion_parameters[k - 1]
 2792                )
 2793            else:
 2794                fdptsTxIminus1 = fdptsTxI
 2795            # take the absolute value, then the mean across columns, then the sum
 2796            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
 2797            motion_parameters.append(myreg["fwdtransforms"])
 2798        else:
 2799            motion_parameters.append("NA")
 2800
 2801        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2802        if k in b0_idx:
 2803            fixed=ants.image_clone( ab0 )
 2804        else:
 2805            fixed=ants.image_clone( adw )
 2806        if temp.numpy().var() > 0:
 2807            motion_parameters[k]=deftx+motion_parameters[k]
 2808            img1w = apply_transforms_mixed_interpolation( avg_dwi,
 2809                ants.slice_image(image, axis=idim - 1, idx=k),
 2810                motion_parameters[k], mask=brain_mask_eroded )
 2811            motion_corrected.append(img1w)
 2812        else:
 2813            motion_corrected.append(fixed)
 2814
 2815    if verbose:
 2816        print("Reorient bvecs")
 2817    if bvecs is not None:
 2818        #    direction = target->GetDirection().GetTranspose() * img_mov->GetDirection().GetVnlMatrix();
 2819        rebase = np.dot( np.transpose( avg_b0.direction  ), ab0.direction )
 2820        bvecs = bvec_reorientation( motion_parameters, bvecs, rebase )
 2821
 2822    if remove_it:
 2823        import shutil
 2824        shutil.rmtree(output_directory, ignore_errors=True )
 2825
 2826    if verbose:
 2827        print("Done")
 2828    d4siz = list(avg_b0.shape)
 2829    d4siz.append( 2 )
 2830    spc = list(ants.get_spacing( avg_b0 ))
 2831    spc.append( 1.0 )
 2832    mydir = ants.get_direction( avg_b0 )
 2833    mydir4d = ants.get_direction( image )
 2834    mydir4d[0:3,0:3]=mydir
 2835    myorg = list(ants.get_origin( avg_b0 ))
 2836    myorg.append( 0.0 )
 2837    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
 2838    return {
 2839        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected),
 2840        "motion_parameters": motion_parameters,
 2841        "FD": FD,
 2842        'bvals':bvals,
 2843        'bvecs':bvecs
 2844    }
 2845
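# Usage sketch for dti_reg (illustrative, not executed); `dwi`, `bvals` and `bvecs` are
# hypothetical inputs, e.g. a 4D diffusion ANTsImage plus arrays read via read_bvals_bvecs:
#
#   avg_b0, avg_dwi = get_average_dwi_b0( dwi )
#   moco = dti_reg( dwi, avg_b0, avg_dwi, bvals=bvals, bvecs=bvecs,
#                   type_of_transform='antsRegistrationSyNRepro[r]', verbose=True )
#   corrected = moco['motion_corrected']   # 4D image resampled into the avg_dwi space
#   fd = moco['FD']                        # framewise displacement per volume
#   rotated_bvecs = moco['bvecs']          # bvecs reoriented to match the motion correction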
 2846
 2847def mc_reg(
 2848    image,
 2849    fixed=None,
 2850    type_of_transform="antsRegistrationSyNRepro[r]",
 2851    mask=None,
 2852    total_sigma=3.0,
 2853    fdOffset=2.0,
 2854    output_directory=None,
 2855    verbose=False, **kwargs
 2856):
 2857    """
 2858    Correct time-series data for motion - with deformation.
 2859
 2860    Arguments
 2861    ---------
 2862        image: antsImage, usually ND where D=4.
 2863
 2864        fixed: Fixed image to register all timepoints to.  If not provided,
 2865            mean image is used.
 2866
 2867        type_of_transform : string
 2868            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
 2869            See ants registration for details.
 2870
 2871        fdOffset: offset value to use in framewise displacement calculation
 2872
        output_directory : string
            intermediate outputs will be written under this directory; if None, a temporary directory is created and removed when finished.
 2875
 2876        verbose: boolean
 2877
 2878        kwargs: keyword args
            extra arguments that control the details of the registration; see ants.registration for more.
 2880
 2881    Returns
 2882    -------
    dict containing the following key/value pairs:
 2884        `motion_corrected`: Moving image warped to space of fixed image.
 2885        `motion_parameters`: transforms for each image in the time series.
 2886        `FD`: Framewise displacement generalized for arbitrary transformations.
 2887
 2888    Notes
 2889    -----
 2890    Control extra arguments via kwargs. see ants.registration for details.
 2891
 2892    Example
 2893    -------
 2894    >>> import ants
    >>> # given a 4D time-series antsImage, fi
    >>> # mytx = mc_reg( fi )
 2897    """
 2898    remove_it=False
 2899    if output_directory is None:
 2900        remove_it=True
 2901        output_directory = tempfile.mkdtemp()
 2902    output_directory_w = output_directory + "/mc_reg/"
 2903    os.makedirs(output_directory_w,exist_ok=True)
 2904    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
 2905    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
 2906    if verbose:
 2907        print(output_directory_w)
 2908        print(ofnG)
 2909        print(ofnL)
 2910
 2911    idim = image.dimension
 2912    ishape = image.shape
 2913    nTimePoints = ishape[idim - 1]
 2914    if fixed is None:
 2915        fixed = ants.get_average_of_timeseries( image )
 2916    if mask is None:
 2917        mask = ants.get_mask(fixed)
 2918    FD = np.zeros(nTimePoints)
 2919    motion_parameters = list()
 2920    motion_corrected = list()
 2921    centerOfMass = mask.get_center_of_mass()
 2922    npts = pow(2, idim - 1)
 2923    pointOffsets = np.zeros((npts, idim - 1))
 2924    myrad = np.ones(idim - 1).astype(int).tolist()
 2925    mask1vals = np.zeros(int(mask.sum()))
 2926    mask1vals[round(len(mask1vals) / 2)] = 1
 2927    mask1 = ants.make_image(mask, mask1vals)
 2928    myoffsets = ants.get_neighborhood_in_mask(
 2929        mask1, mask1, radius=myrad, spatial_info=True
 2930    )["offsets"]
 2931    mycols = list("xy")
 2932    if idim - 1 == 3:
 2933        mycols = list("xyz")
 2934    useinds = list()
 2935    for k in range(myoffsets.shape[0]):
 2936        if abs(myoffsets[k, :]).sum() == (idim - 2):
 2937            useinds.append(k)
 2938        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
 2939    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
 2940    if verbose:
 2941        print("Progress:")
 2942    counter = 0
 2943    for k in range(nTimePoints):
 2944        mycount = round(k / nTimePoints * 100)
 2945        if verbose and mycount == counter:
 2946            counter = counter + 10
 2947            print(mycount, end="%.", flush=True)
 2948        temp = ants.slice_image(image, axis=idim - 1, idx=k)
 2949        temp = ants.iMath(temp, "Normalize")
 2950        if temp.numpy().var() > 0:
 2951            myrig = ants.registration(
 2952                    fixed, temp,
 2953                    type_of_transform='antsRegistrationSyNRepro[r]',
 2954                    outprefix=ofnL+str(k).zfill(4)+"_",
 2955                    **kwargs
 2956                )
 2957            if type_of_transform == 'SyN':
 2958                myreg = ants.registration(
 2959                    fixed, temp,
 2960                    type_of_transform='SyNOnly',
 2961                    total_sigma=total_sigma,
 2962                    initial_transform=myrig['fwdtransforms'][0],
 2963                    outprefix=ofnL+str(k).zfill(4)+"_",
 2964                    **kwargs
 2965                )
 2966            else:
 2967                myreg = myrig
 2968            fdptsTxI = ants.apply_transforms_to_points(
 2969                idim - 1, fdpts, myreg["fwdtransforms"]
 2970            )
 2971            if k > 0 and motion_parameters[k - 1] != "NA":
 2972                fdptsTxIminus1 = ants.apply_transforms_to_points(
 2973                    idim - 1, fdpts, motion_parameters[k - 1]
 2974                )
 2975            else:
 2976                fdptsTxIminus1 = fdptsTxI
 2977            # take the absolute value, then the mean across columns, then the sum
 2978            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
 2979            motion_parameters.append(myreg["fwdtransforms"])
 2980            img1w = ants.apply_transforms( fixed,
 2981                ants.slice_image(image, axis=idim - 1, idx=k),
 2982                myreg["fwdtransforms"] )
 2983            motion_corrected.append(img1w)
 2984        else:
 2985            motion_parameters.append("NA")
 2986            motion_corrected.append(temp)
 2987
 2988    if remove_it:
 2989        import shutil
 2990        shutil.rmtree(output_directory, ignore_errors=True )
 2991
 2992    if verbose:
 2993        print("Done")
 2994    return {
 2995        "motion_corrected": ants.list_to_ndimage(image, motion_corrected),
 2996        "motion_parameters": motion_parameters,
 2997        "FD": FD,
 2998    }
 2999
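# Usage sketch for mc_reg (illustrative, not executed); `bold` is a hypothetical 4D ANTsImage:
#
#   moco = mc_reg( bold, type_of_transform='antsRegistrationSyNRepro[r]', verbose=True )
#   corrected = moco['motion_corrected']   # time series registered to the mean (or supplied) fixed image
#   fd = moco['FD']                        # framewise displacement per volume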
 3000def map_scalar_to_labels(dataframe, label_image_template):
 3001    """
 3002    Map scalar values from a DataFrame to associated integer image labels.
 3003
 3004    Parameters:
 3005    - dataframe (pd.DataFrame): A Pandas DataFrame containing a label column and scalar_value column.
 3006    - label_image_template (ants.ANTsImage): ANTs image with (at least some of) the same values as labels.
 3007
 3008    Returns:
 3009    - ants.ANTsImage: A label image with scalar values mapped to associated integer labels.
 3010    """
 3011
 3012    # Create an empty label image with the same geometry as the template
 3013    mapped_label_image = label_image_template.clone() * 0.0
 3014
 3015    # Loop through DataFrame and map scalar values to labels
 3016    for index, row in dataframe.iterrows():
        label = int(row['label'])  # the DataFrame must provide an integer 'label' column
        scalar_value = row['scalar_value']  # and a matching 'scalar_value' column
 3019        mapped_label_image[label_image_template == label] = scalar_value
 3020
 3021    return mapped_label_image
 3022
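# Usage sketch for map_scalar_to_labels (illustrative, not executed); `label_image` is a
# hypothetical integer-valued ANTsImage:
#
#   df = pd.DataFrame({'label': [1, 2, 3], 'scalar_value': [0.5, 0.8, 1.2]})
#   scalar_img = map_scalar_to_labels( df, label_image )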
 3023
 3024def template_figure_with_overlay(scalar_label_df, prefix, outputfilename=None, template='cit168', xyz=None, mask_dilation=25, padding=12, verbose=True):
 3025    """
 3026    Process and visualize images with mapped scalar values.
 3027
    Parameters:
    - scalar_label_df (pd.DataFrame): A Pandas DataFrame containing scalar values and labels.
    - prefix (str): The prefix for input image files.
    - outputfilename (str, optional): If provided, the ortho figure is written to this filename.
    - template (str, optional): Template for selecting image data (default is 'cit168').
    - xyz (list of int, optional): The [x, y, z] voxel indices of the slices to display.
    - mask_dilation (int, optional): Dilation factor for creating a mask (default is 25).
    - padding (int, optional): Padding value for the mapped images (default is 12).
    - verbose (bool, optional): Enable verbose mode for printing (default is True).

    Returns:
    - dict: with keys 'underlay', 'overlay' and 'seg' holding the visualized images.

    Example Usage:
    >>> scalar_label_df = pd.DataFrame({'label': [1, 2, 3], 'scalar_value': [0.5, 0.8, 1.2]})
    >>> prefix = '../PPMI_template0_'
    >>> template_figure_with_overlay(scalar_label_df, prefix, template='cit168', xyz=None, mask_dilation=25, padding=12, verbose=True)
    """
 3042
 3043    # Template image paths
 3044    template_paths = {
 3045        'cit168': 'cit168lab.nii.gz',
 3046        'bf': 'bf.nii.gz',
 3047        'cerebellum': 'cerebellum.nii.gz',
 3048        'mtl': 'mtl.nii.gz',
 3049        'ctx': 'dkt_cortex.nii.gz',
 3050        'jhuwm': 'JHU_wm.nii.gz'
 3051    }
 3052
 3053    if template not in template_paths:
 3054        print( "Valid options:")
 3055        print( template_paths )
 3056        raise ValueError(f"Template option '{template}' does not exist.")
 3057
 3058    template_image_path = template_paths[template]
 3059    template_image = ants.image_read(f'{prefix}{template_image_path}')
 3060
 3061    # Load image data
 3062    edgeimg = ants.image_read(f'{prefix}edge.nii.gz')
 3063    dktimg = ants.image_read(f'{prefix}dkt_parcellation.nii.gz')
 3064    segimg = ants.image_read(f'{prefix}tissue_segmentation.nii.gz')
 3065    ttl = ''
 3066
 3067    # Load and process the template image
 3068    ventricles = ants.threshold_image(dktimg, 4, 4) + ants.threshold_image(dktimg, 43, 43)
 3069    seggm = ants.mask_image(segimg, segimg, [2, 4], binarize=False)
 3070    edgeimg = edgeimg.clone()
 3071    edgeimg[edgeimg == 0] = ventricles[edgeimg == 0]
 3072    segwm = ants.threshold_image(segimg, 3, 4).morphology("open", 1)
 3073
 3074    # Define cropping mask
 3075    cmask = ants.threshold_image(template_image, 1, 1.e9).iMath("MD", mask_dilation)
 3076
 3077    mapped_image = map_scalar_to_labels(scalar_label_df, template_image)
 3078    tcrop = ants.crop_image(template_image, cmask)
 3079    toviz = ants.crop_image(mapped_image, cmask)
 3080    seggm = ants.crop_image(edgeimg, cmask)
 3081       
 3082    # Map scalar values to labels and visualize
 3083    toviz = ants.pad_image(toviz, pad_width=(padding, padding, padding))
 3084    seggm = ants.pad_image(seggm, pad_width=(padding, padding, padding))
 3085    tcrop = ants.pad_image(tcrop, pad_width=(padding, padding, padding))
 3086
 3087    if xyz is None:
 3088        if template == 'cit168':
 3089            xyz=[140, 89, 94]
 3090        elif template == 'bf':
 3091            xyz=[114,92,76]
 3092        elif template == 'cerebellum':
 3093            xyz=[169, 128, 137]
 3094        elif template == 'mtl':
 3095            xyz=[154, 112, 113]
 3096        elif template == 'ctx':
 3097            xyz=[233, 190, 174]
 3098        elif template == 'jhuwm':
 3099            xyz=[146, 133, 182]
 3100
 3101    if verbose:
 3102        print("plot xyz for " + template )
 3103        print( xyz )
 3104        
 3105    if outputfilename is None:
 3106        temp = ants.plot_ortho( seggm, overlay=toviz, crop=False,
 3107                        xyz=xyz, cbar_length=0.2, cbar_vertical=False,
 3108                        flat=True, xyz_lines=False, resample=False, orient_labels=False,
 3109                        title=ttl, titlefontsize=12, title_dy=-0.02, textfontcolor='red', 
 3110                        cbar=True, allow_xyz_change=False)
 3111    else:
 3112        temp = ants.plot_ortho( seggm, overlay=toviz, crop=False,
 3113                    xyz=xyz, cbar_length=0.2, cbar_vertical=False,
 3114                    flat=True, xyz_lines=False, resample=False, orient_labels=False,
 3115                    title=ttl, titlefontsize=12, title_dy=-0.02, textfontcolor='red', 
 3116                    cbar=True, allow_xyz_change=False, filename=outputfilename )
 3117    seggm = temp['image']
 3118    toviz = temp['overlay']
 3119    return { "underlay": seggm, 'overlay': toviz, 'seg': tcrop  }
 3120
 3121def get_data( name=None, force_download=False, version=26, target_extension='.csv' ):
 3122    """
 3123    Get ANTsPyMM data filename
 3124
    The first time this is called, it will download data to ~/.antspymm.
    After that, it will read the data from disk.  The ~/.antspymm folder may need to
    be periodically deleted in order to ensure the data is current.
 3128
 3129    Arguments
 3130    ---------
 3131    name : string
 3132        name of data tag to retrieve
 3133        Options:
 3134            - 'all'
 3135
 3136    force_download: boolean
 3137
 3138    version: version of data to download (integer)
 3139
 3140    Returns
 3141    -------
 3142    string
 3143        filepath of selected data
 3144
 3145    Example
 3146    -------
 3147    >>> import antspymm
 3148    >>> antspymm.get_data()
 3149    """
 3150    os.makedirs(DATA_PATH, exist_ok=True)
 3151
 3152    def mv_subfolder_files(folder, verbose=False):
 3153        """
 3154        Move files from subfolders to the parent folder.
 3155
 3156        Parameters
 3157        ----------
 3158        folder : str
 3159            Path to the folder.
 3160        verbose : bool, optional
 3161            Print information about the files and folders being processed (default is False).
 3162
 3163        Returns
 3164        -------
 3165        None
 3166        """
 3167        import os
 3168        import shutil
 3169        for root, dirs, files in os.walk(folder):
 3170            if verbose:
 3171                print(f"Processing directory: {root}")
 3172                print(f"Subdirectories: {dirs}")
 3173                print(f"Files: {files}")
 3174            
 3175            for file in files:
 3176                if root != folder:
 3177                    if verbose:
 3178                        print(f"Moving file: {file} from {root} to {folder}")
 3179                    shutil.move(os.path.join(root, file), folder)
 3180            
 3181            for dir in dirs:
 3182                if root != folder:
 3183                    if verbose:
 3184                        print(f"Removing directory: {dir} from {root}")
 3185                    shutil.rmtree(os.path.join(root, dir))
 3186
 3187    def download_data( version ):
 3188        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
 3189        target_file_name = "16912366.zip"
 3190        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
 3191            cache_subdir=DATA_PATH, extract = True )
 3192        mv_subfolder_files( os.path.expanduser("~/.antspymm"), False )
 3193        os.remove( DATA_PATH + target_file_name )
 3194
 3195    if force_download:
 3196        download_data( version = version )
 3197
 3198
 3199    files = []
 3200    for fname in os.listdir(DATA_PATH):
 3201        if ( fname.endswith(target_extension) ) :
 3202            fname = os.path.join(DATA_PATH, fname)
 3203            files.append(fname)
 3204
 3205    if len( files ) == 0 :
 3206        download_data( version = version )
 3207        for fname in os.listdir(DATA_PATH):
 3208            if ( fname.endswith(target_extension) ) :
 3209                fname = os.path.join(DATA_PATH, fname)
 3210                files.append(fname)
 3211
 3212
 3213    if name == 'all':
 3214        return files
 3215
 3216    datapath = None
 3217
 3218    for fname in os.listdir(DATA_PATH):
 3219        mystem = (Path(fname).resolve().stem)
 3220        mystem = (Path(mystem).resolve().stem)
 3221        mystem = (Path(mystem).resolve().stem)
 3222        if ( name == mystem and fname.endswith(target_extension) ) :
 3223            datapath = os.path.join(DATA_PATH, fname)
 3224
 3225    return datapath
 3226
 3227
 3228def get_models( version=3, force_download=True ):
 3229    """
 3230    Get ANTsPyMM data models
 3231
 3232    force_download: boolean
 3233
 3234    Returns
 3235    -------
 3236    None
 3237
 3238    """
 3239    os.makedirs(DATA_PATH, exist_ok=True)
 3240
 3241    def download_data( version ):
 3242        url = "https://figshare.com/ndownloader/articles/21718412/versions/"+str(version)
 3243        target_file_name = "21718412.zip"
 3244        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
 3245            cache_subdir=DATA_PATH, extract = True )
 3246        os.remove( DATA_PATH + target_file_name )
 3247
 3248    if force_download:
 3249        download_data( version = version )
 3250    return
 3251
 3252
 3253
 3254def dewarp_imageset( image_list, initial_template=None,
 3255    iterations=None, padding=0, target_idx=[0], **kwargs ):
 3256    """
 3257    Dewarp a set of images
 3258
 3259    Makes simplifying heuristic decisions about how to transform an image set
 3260    into an unbiased reference space.  Will handle plenty of decisions
 3261    automatically so beware.  Computes an average shape space for the images
 3262    and transforms them to that space.
 3263
 3264    Arguments
 3265    ---------
 3266    image_list : list containing antsImages 2D, 3D or 4D
 3267
 3268    initial_template : optional
 3269
 3270    iterations : number of template building iterations
 3271
 3272    padding:  will pad the images by an integer amount to limit edge effects
 3273
 3274    target_idx : the target indices for the time series over which we should average;
 3275        a list of integer indices into the last axis of the input images.
 3276
 3277    kwargs : keyword args
 3278        arguments passed to ants registration - these must be set explicitly
 3279
 3280    Returns
 3281    -------
 3282    a dictionary with the mean image and the list of the transformed images as
 3283    well as motion correction parameters for each image in the input list
 3284
 3285    Example
 3286    -------
 3287    >>> import antspymm
 3288    """
 3289    outlist = []
 3290    avglist = []
 3291    if len(image_list[0].shape) > 3:
 3292        imagetype = 3
 3293        for k in range(len(image_list)):
 3294            for j in range(len(target_idx)):
 3295                avglist.append( ants.slice_image( image_list[k], axis=3, idx=target_idx[j] ) )
 3296    else:
 3297        imagetype = 0
 3298        avglist=image_list
 3299
 3300    pw=[]
 3301    for k in range(len(avglist[0].shape)):
 3302        pw.append( padding )
 3303    for k in range(len(avglist)):
 3304        avglist[k] = ants.pad_image( avglist[k], pad_width=pw  )
 3305
 3306    if initial_template is None:
 3307        initial_template = avglist[0] * 0
 3308        for k in range(len(avglist)):
 3309            initial_template = initial_template + avglist[k]/len(avglist)
 3310
 3311    if iterations is None:
 3312        iterations = 2
 3313
 3314    btp = ants.build_template(
 3315        initial_template=initial_template,
 3316        image_list=avglist,
 3317        gradient_step=0.5, blending_weight=0.8,
 3318        iterations=iterations, verbose=False, **kwargs )
 3319
 3320    # last - warp all images to this frame
 3321    mocoplist = []
 3322    mocofdlist = []
 3323    reglist = []
 3324    for k in range(len(image_list)):
 3325        if imagetype == 3:
 3326            moco0 = ants.motion_correction( image=image_list[k], fixed=btp, type_of_transform='antsRegistrationSyNRepro[r]' )
 3327            mocoplist.append( moco0['motion_parameters'] )
 3328            mocofdlist.append( moco0['FD'] )
 3329            locavg = ants.slice_image( moco0['motion_corrected'], axis=3, idx=0 ) * 0.0
 3330            for j in range(len(target_idx)):
 3331                locavg = locavg + ants.slice_image( moco0['motion_corrected'], axis=3, idx=target_idx[j] )
 3332            locavg = locavg * 1.0 / len(target_idx)
 3333        else:
 3334            locavg = image_list[k]
 3335        reg = ants.registration( btp, locavg, **kwargs )
 3336        reglist.append( reg )
 3337        if imagetype == 3:
 3338            myishape = image_list[k].shape
 3339            mytslength = myishape[ len(myishape) - 1 ]
 3340            mywarpedlist = []
 3341            for j in range(mytslength):
 3342                locimg = ants.slice_image( image_list[k], axis=3, idx = j )
 3343                mywarped = ants.apply_transforms( btp, locimg,
 3344                    reg['fwdtransforms'] + moco0['motion_parameters'][j], imagetype=0 )
 3345                mywarpedlist.append( mywarped )
 3346            mywarped = ants.list_to_ndimage( image_list[k], mywarpedlist )
 3347        else:
 3348            mywarped = ants.apply_transforms( btp, image_list[k], reg['fwdtransforms'], imagetype=imagetype )
 3349        outlist.append( mywarped )
 3350
 3351    return {
 3352        'dewarpedmean':btp,
 3353        'dewarped':outlist,
 3354        'deformable_registrations': reglist,
 3355        'FD': mocofdlist,
 3356        'motionparameters': mocoplist }
 3357
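# Usage sketch for dewarp_imageset (illustrative, not executed); `img_a` and `img_b` are
# hypothetical 3D ANTsImages and the registration kwargs must be given explicitly:
#
#   dw = dewarp_imageset( [img_a, img_b], iterations=2, padding=4,
#                         type_of_transform='antsRegistrationSyNRepro[s]' )
#   unbiased_mean = dw['dewarpedmean']
#   dewarped_list = dw['dewarped']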
 3358
 3359def super_res_mcimage( image,
 3360    srmodel,
 3361    truncation=[0.0001,0.995],
 3362    poly_order='hist',
 3363    target_range=[0,1],
 3364    isotropic = False,
 3365    verbose=False ):
 3366    """
 3367    Super resolution on a timeseries or multi-channel image
 3368
 3369    Arguments
 3370    ---------
 3371    image : an antsImage
 3372
 3373    srmodel : a tensorflow fully convolutional model
 3374
 3375    truncation :  quantiles at which we truncate intensities to limit impact of outliers e.g. [0.005,0.995]
 3376
 3377    poly_order : if not None, will fit a global regression model to map
 3378        intensity back to original histogram space; if 'hist' will match
 3379        by histogram matching - ants.histogram_match_image
 3380
 3381    target_range : 2-element tuple
 3382        a tuple or array defining the (min, max) of the input image
 3383        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
 3384        intensity. This range should match the mapping used in the training
 3385        of the network.
 3386
 3387    isotropic : boolean
 3388
 3389    verbose : boolean
 3390
 3391    Returns
 3392    -------
 3393    super resolution version of the image
 3394
 3395    Example
 3396    -------
 3397    >>> import antspymm
 3398    """
 3399    idim = image.dimension
 3400    ishape = image.shape
 3401    nTimePoints = ishape[idim - 1]
 3402    mcsr = list()
 3403    for k in range(nTimePoints):
 3404        if verbose and (( k % 5 ) == 0 ):
 3405            mycount = round(k / nTimePoints * 100)
 3406            print(mycount, end="%.", flush=True)
 3407        temp = ants.slice_image( image, axis=idim - 1, idx=k )
 3408        temp = ants.iMath( temp, "TruncateIntensity", truncation[0], truncation[1] )
 3409        mysr = antspynet.apply_super_resolution_model_to_image( temp, srmodel,
 3410            target_range = target_range )
 3411        if poly_order is not None:
 3412            bilin = ants.resample_image_to_target( temp, mysr )
 3413            if poly_order == 'hist':
 3414                mysr = ants.histogram_match_image( mysr, bilin )
 3415            else:
 3416                mysr = antspynet.regression_match_image( mysr, bilin, poly_order = poly_order )
 3417        if isotropic:
 3418            mysr = down2iso( mysr )
 3419        if k == 0:
 3420            upshape = list()
 3421            for j in range(len(ishape)-1):
 3422                upshape.append( mysr.shape[j] )
 3423            upshape.append( ishape[ idim-1 ] )
 3424            if verbose:
 3425                print("SR will be of voxel size:" + str(upshape) )
 3426        mcsr.append( mysr )
 3427
 3428    upshape = list()
 3429    for j in range(len(ishape)-1):
 3430        upshape.append( mysr.shape[j] )
 3431    upshape.append( ishape[ idim-1 ] )
 3432    if verbose:
 3433        print("SR will be of voxel size:" + str(upshape) )
 3434
 3435    imageup = ants.resample_image( image, upshape, use_voxels = True )
 3436    if verbose:
 3437        print("Done")
 3438
 3439    return ants.list_to_ndimage( imageup, mcsr )
 3440
 3441
 3442def segment_timeseries_by_bvalue(bvals):
 3443    """
 3444    Segments a time series based on a threshold applied to b-values.
 3445    
 3446    This function categorizes indices of the given b-values array into two groups:
 3447    one for indices where b-values are above a near-zero threshold, and another
 3448    where b-values are at or below this threshold. The threshold is set to 1e-12.
 3449    
 3450    Parameters:
 3451    - bvals (numpy.ndarray): An array of b-values.
 3452
 3453    Returns:
 3454    - dict: A dictionary with two keys, 'largerbvals' and 'lowbvals', each containing
 3455      the indices of bvals where the b-values are above and at/below the threshold, respectively.
 3456    """
 3457    # Define the threshold
 3458    threshold = 1e-12
 3459    def find_min_value(data):
 3460        if isinstance(data, list):
 3461            return min(data)
 3462        elif isinstance(data, np.ndarray):
 3463            return np.min(data)
 3464        else:
 3465            raise TypeError("Input must be either a list or a numpy array")
 3466
    # Indices of diffusion-weighted volumes (b-values above the near-zero threshold)
    above_threshold = list(np.where(bvals > threshold)[0])

    # Indices of b0 volumes (b-values at or below the near-zero threshold)
    at_or_below_threshold = list(np.where(bvals <= threshold)[0])

    # If nothing falls at or below the near-zero threshold, fall back to the minimum b-value
    if len(at_or_below_threshold) == 0:
        minval = find_min_value( bvals )
        above_threshold = list(np.where(bvals > minval )[0])
        at_or_below_threshold = list(np.where(bvals <= minval)[0])

    return {
        'largerbvals': above_threshold,
        'lowbvals': at_or_below_threshold
    }
 3482
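# Usage sketch for segment_timeseries_by_bvalue (illustrative, not executed):
#
#   bvals = np.array([0, 0, 1000, 1000, 2000])
#   groups = segment_timeseries_by_bvalue( bvals )
#   groups['lowbvals']      # -> [0, 1]    indices of b0 volumes
#   groups['largerbvals']   # -> [2, 3, 4] indices of diffusion-weighted volumes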
 3483def segment_timeseries_by_meanvalue( image, quantile = 0.995 ):
 3484    """
 3485    Identify indices of a time series where we assume there is a different mean
 3486    intensity over the volumes.  The indices of volumes with higher and lower
 3487    intensities is returned.  Can be used to automatically identify B0 volumes
 3488    in DWI timeseries.
 3489
 3490    Arguments
 3491    ---------
 3492    image : an antsImage holding B0 and DWI
 3493
 3494    quantile : a quantile for splitting the indices of the volume - should be greater than 0.5
 3495
 3496    Returns
 3497    -------
 3498    dictionary holding the two sets of indices
 3499
 3500    Example
 3501    -------
 3502    >>> import antspymm
 3503    """
 3504    ishape = image.shape
 3505    lastdim = len(ishape)-1
 3506    meanvalues = list()
 3507    for x in range(ishape[lastdim]):
 3508        meanvalues.append(  ants.slice_image( image, axis=lastdim, idx=x ).mean() )
 3509    myhiq = np.quantile( meanvalues, quantile )
 3510    myloq = np.quantile( meanvalues, 1.0 - quantile )
 3511    lowerindices = list()
 3512    higherindices = list()
 3513    for x in range(len(meanvalues)):
 3514        hiabs = abs( meanvalues[x] - myhiq )
 3515        loabs = abs( meanvalues[x] - myloq )
 3516        if hiabs < loabs:
 3517            higherindices.append(x)
 3518        else:
 3519            lowerindices.append(x)
 3520
 3521    return {
 3522    'lowermeans':lowerindices,
 3523    'highermeans':higherindices }
 3524
 3525
 3526def get_average_rsf( x, min_t=10, max_t=35 ):
 3527    """
    automatically generates the average BOLD image via quick rigid registration of the time points in the range [min_t, max_t)
 3529
 3530    returns:
 3531        avg_bold
 3532    """
 3533    output_directory = tempfile.mkdtemp()
 3534    ofn = output_directory + "/w"
 3535    bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
 3536    oavg = ants.slice_image( x, axis=3, idx=0 )
 3537    if x.shape[3] <= min_t:
 3538        min_t=0
 3539    if x.shape[3] <= max_t:
 3540        max_t=x.shape[3]
 3541    for myidx in range(min_t,max_t):
 3542        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3543        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3544    bavg = ants.iMath( bavg, 'Normalize' )
 3545    oavg = ants.image_clone( bavg )
 3546    bavg = oavg * 0.0
 3547    for myidx in range(min_t,max_t):
 3548        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3549        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3550    import shutil
 3551    shutil.rmtree(output_directory, ignore_errors=True )
 3552    bavg = ants.iMath( bavg, 'Normalize' )
 3553    return bavg
 3554    # return ants.n4_bias_field_correction(bavg, mask=ants.get_mask( bavg ) )
 3555
 3556
 3557def get_average_dwi_b0( x, fixed_b0=None, fixed_dwi=None, fast=False ):
 3558    """
 3559    automatically generates the average b0 and dwi and outputs both;
 3560    maps dwi to b0 space at end.
 3561
 3562    x : input image
 3563
    fixed_b0 : alternative reference space

    fixed_dwi : alternative reference space
 3567
 3568    fast : boolean
 3569
 3570    returns:
 3571        avg_b0, avg_dwi
 3572    """
 3573    output_directory = tempfile.mkdtemp()
 3574    ofn = output_directory + "/w"
 3575    temp = segment_timeseries_by_meanvalue( x )
 3576    b0_idx = temp['highermeans']
 3577    non_b0_idx = temp['lowermeans']
 3578    if ( fixed_b0 is None and fixed_dwi is None ) or fast:
 3579        xavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
 3580        bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
 3581        fixed_b0_use = ants.slice_image( x, axis=3, idx=b0_idx[0] )
 3582        fixed_dwi_use = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
 3583    else:
 3584        temp_b0 = ants.slice_image( x, axis=3, idx=b0_idx[0] )
 3585        temp_dwi = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
 3586        xavg = fixed_b0 * 0.0
 3587        bavg = fixed_b0 * 0.0
 3588        tempreg = ants.registration( fixed_b0, temp_b0,'antsRegistrationSyNRepro[r]')
 3589        fixed_b0_use = tempreg['warpedmovout']
 3590        fixed_dwi_use = ants.apply_transforms( fixed_b0, temp_dwi, tempreg['fwdtransforms'] )
 3591    for myidx in range(x.shape[3]):
 3592        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3593        if not fast:
 3594            if not myidx in b0_idx:
 3595                xavg = xavg + ants.registration(fixed_dwi_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3596            else:
 3597                bavg = bavg + ants.registration(fixed_b0_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
 3598        else:
 3599            if not myidx in b0_idx:
 3600                xavg = xavg + b0
 3601            else:
 3602                bavg = bavg + b0
 3603    bavg = ants.iMath( bavg, 'Normalize' )
 3604    xavg = ants.iMath( xavg, 'Normalize' )
 3605    import shutil
 3606    shutil.rmtree(output_directory, ignore_errors=True )
 3607    avgb0=ants.n4_bias_field_correction(bavg)
 3608    avgdwi=ants.n4_bias_field_correction(xavg)
 3609    avgdwi=ants.registration( avgb0, avgdwi, 'antsRegistrationSyNRepro[r]' )['warpedmovout']
 3610    return avgb0, avgdwi
 3611
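# Usage sketch for get_average_dwi_b0 (illustrative, not executed); `dwi` is a hypothetical
# 4D diffusion ANTsImage:
#
#   avg_b0, avg_dwi = get_average_dwi_b0( dwi, fast=True )  # fast=True averages without per-volume registration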
 3612def dti_template(
 3613    b_image_list=None,
 3614    w_image_list=None,
 3615    iterations=5,
 3616    gradient_step=0.5,
 3617    mask_csf=False,
 3618    average_both=True,
 3619    verbose=False
 3620):
 3621    """
 3622    two channel version of build_template
 3623
 3624    returns:
 3625        avg_b0, avg_dwi
 3626    """
 3627    output_directory = tempfile.mkdtemp()
 3628    mydeftx = tempfile.NamedTemporaryFile(delete=False,dir=output_directory).name
 3629    tmp = tempfile.NamedTemporaryFile(delete=False,dir=output_directory,suffix=".nii.gz")
 3630    wavgfn = tmp.name
 3631    tmp2 = tempfile.NamedTemporaryFile(delete=False,dir=output_directory)
 3632    comptx = tmp2.name
 3633    weights = np.repeat(1.0 / len(b_image_list), len(b_image_list))
 3634    weights = [x / sum(weights) for x in weights]
 3635    w_initial_template = w_image_list[0]
 3636    b_initial_template = b_image_list[0]
 3637    b_initial_template = ants.iMath(b_initial_template,"Normalize")
 3638    w_initial_template = ants.iMath(w_initial_template,"Normalize")
 3639    if mask_csf:
 3640        bcsf0 = ants.threshold_image( b_image_list[0],"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
 3641        bcsf1 = ants.threshold_image( b_image_list[1],"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
 3642    else:
 3643        bcsf0 = b_image_list[0] * 0 + 1
 3644        bcsf1 = b_image_list[1] * 0 + 1
 3645    bavg = b_initial_template.clone() * bcsf0
 3646    wavg = w_initial_template.clone() * bcsf0
 3647    bcsf = [ bcsf0, bcsf1 ]
 3648    for i in range(iterations):
 3649        for k in range(len(w_image_list)):
 3650            fimg=wavg
 3651            mimg=w_image_list[k] * bcsf[k]
 3652            fimg2=bavg
 3653            mimg2=b_image_list[k] * bcsf[k]
 3654            w1 = ants.registration(
 3655                fimg, mimg, type_of_transform='antsRegistrationSyNQuickRepro[s]',
 3656                    multivariate_extras= [ [ "CC", fimg2, mimg2, 1, 2 ]],
 3657                    outprefix=mydeftx,
 3658                    verbose=0 )
 3659            txname = ants.apply_transforms(wavg, wavg,
 3660                w1["fwdtransforms"], compose=comptx )
 3661            if k == 0:
 3662                txavg = ants.image_read(txname) * weights[k]
 3663                wavgnew = ants.apply_transforms( wavg,
 3664                    w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3665                bavgnew = ants.apply_transforms( wavg,
 3666                    b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3667            else:
 3668                txavg = txavg + ants.image_read(txname) * weights[k]
 3669                if i >= (iterations-2) and average_both:
 3670                    wavgnew = wavgnew+ants.apply_transforms( wavg,
 3671                        w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3672                    bavgnew = bavgnew+ants.apply_transforms( wavg,
 3673                        b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
 3674        if verbose:
 3675            print("iteration:",str(i),str(txavg.abs().mean()))
 3676        wscl = (-1.0) * gradient_step
 3677        txavg = txavg * wscl
 3678        ants.image_write( txavg, wavgfn )
 3679        wavg = ants.apply_transforms(wavg, wavgnew, wavgfn).iMath("Normalize")
 3680        bavg = ants.apply_transforms(bavg, bavgnew, wavgfn).iMath("Normalize")
 3681    import shutil
 3682    shutil.rmtree( output_directory, ignore_errors=True )
 3683    if verbose:
 3684        print("done")
 3685    return bavg, wavg
 3686
 3687def read_ants_transforms_to_numpy(transform_files ):
 3688    """
 3689    Read a list of ANTs transform files and convert them to a NumPy array.
    The function filters out any files that are not .mat and will only use
    the first .mat file in each entry of the list.
 3692
 3693    :param transform_files: List of lists of file paths to ANTs transform files.  
 3694    :return: NumPy array of the transforms.
 3695    """
 3696    extension = '.mat'
 3697    # Filter the list of lists
 3698    filtered_lists = [[string for string in sublist if string.endswith(extension)] 
 3699                    for sublist in transform_files]
 3700    transforms = []
 3701    for file in filtered_lists:
 3702        transform = ants.read_transform(file[0])
 3703        np_transform = np.array(ants.get_ants_transform_parameters(transform)[0:9])
 3704        transforms.append(np_transform)
 3705    return np.array(transforms)
 3706
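# Usage sketch for read_ants_transforms_to_numpy (illustrative, not executed); assumes every
# entry of the hypothetical `transform_list` (e.g. motion parameters from mc_reg) contains at
# least one '.mat' affine file:
#
#   tx_array = read_ants_transforms_to_numpy( transform_list )   # shape (nTimePoints, 9)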
 3707def t1_based_dwi_brain_extraction(
 3708    t1w_head,
 3709    t1w,
 3710    dwi,
 3711    b0_idx = None,
 3712    transform='antsRegistrationSyNRepro[r]',
 3713    deform=None,
 3714    verbose=False
 3715):
 3716    """
 3717    Map a t1-based brain extraction to b0 and return a mask and average b0
 3718
 3719    Arguments
 3720    ---------
    t1w_head : an antsImage of the whole head
 3722
 3723    t1w : an antsImage probably but not necessarily T1-weighted
 3724
 3725    dwi : an antsImage holding B0 and DWI
 3726
 3727    b0_idx : the indices of the B0; if None, use segment_timeseries_by_meanvalue to guess
 3728
 3729    transform : string Rigid or other ants.registration tx type
 3730
 3731    deform : follow up transform with deformation
 3732
 3733    Returns
 3734    -------
 3735    dictionary holding the avg_b0 and its mask
 3736
 3737    Example
 3738    -------
 3739    >>> import antspymm
 3740    """
 3741    t1w_use = ants.iMath( t1w, "Normalize" )
 3742    t1bxt = ants.threshold_image( t1w_use, 0.05, 1 ).iMath("FillHoles")
 3743    if b0_idx is None:
 3744        b0_idx = segment_timeseries_by_meanvalue( dwi )['highermeans']
 3745    # first get the average b0
 3746    if len( b0_idx ) > 1:
 3747        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] ).iMath("Normalize")
 3748        for n in range(1,len(b0_idx)):
 3749            temp = ants.slice_image( dwi, axis=3, idx=b0_idx[n] )
 3750            reg = ants.registration( b0_avg, temp, 'antsRegistrationSyNRepro[r]' )
 3751            b0_avg = b0_avg + ants.iMath( reg['warpedmovout'], "Normalize")
 3752    else:
 3753        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] )
 3754    b0_avg = ants.iMath(b0_avg,"Normalize")
 3755    reg = tra_initializer( b0_avg, t1w, n_simulations=12,   verbose=verbose )
 3756    if deform is not None:
 3757        reg = ants.registration( b0_avg, t1w,
 3758            'SyNOnly',
 3759            total_sigma=0.5,
 3760            initial_transform=reg['fwdtransforms'][0],
 3761            verbose=False )
 3762    outmsk = ants.apply_transforms( b0_avg, t1bxt, reg['fwdtransforms'], interpolator='linear').threshold_image( 0.5, 1.0 )
 3763    return  {
 3764    'b0_avg':b0_avg,
 3765    'b0_mask':outmsk }
 3766
 3767def mc_denoise( x, ratio = 0.5 ):
 3768    """
 3769    ants denoising for timeseries (4D)
 3770
 3771    Arguments
 3772    ---------
 3773    x : an antsImage 4D
 3774
 3775    ratio : weight between 1 and 0 - lower weights bring result closer to initial image
 3776
 3777    Returns
 3778    -------
 3779    denoised time series
 3780
 3781    """
 3782    dwpimage = []
 3783    for myidx in range(x.shape[3]):
 3784        b0 = ants.slice_image( x, axis=3, idx=myidx)
 3785        dnzb0 = ants.denoise_image( b0, p=1,r=1,noise_model='Gaussian' )
 3786        dwpimage.append( dnzb0 * ratio + b0 * (1.0-ratio) )
 3787    return ants.list_to_ndimage( x, dwpimage )
 3788
 3789def tsnr( x, mask, indices=None ):
 3790    """
 3791    3D temporal snr image from a 4D time series image ... the matrix is normalized to range of 0,1
 3792
 3793    x: image
 3794
 3795    mask : mask
 3796
 3797    indices: indices to use
 3798
 3799    returns a 3D image
 3800    """
 3801    M = ants.timeseries_to_matrix( x, mask )
 3802    M = M - M.min()
 3803    M = M / M.max()
 3804    if indices is not None:
 3805        M=M[indices,:]
 3806    stdM = np.std(M, axis=0 )
 3807    stdM[np.isnan(stdM)] = 0
 3808    tt = round( 0.975*100 )
 3809    threshold_std = np.percentile( stdM, tt )
 3810    tsnrimage = ants.make_image( mask, stdM )
 3811    return tsnrimage
 3812
 3813def dvars( x,  mask, indices=None ):
 3814    """
 3815    dvars on a time series image ... the matrix is normalized to range of 0,1
 3816
 3817    x: image
 3818
 3819    mask : mask
 3820
 3821    indices: indices to use
 3822
 3823    returns an array
 3824    """
 3825    M = ants.timeseries_to_matrix( x, mask )
 3826    M = M - M.min()
 3827    M = M / M.max()
 3828    if indices is not None:
 3829        M=M[indices,:]
 3830    DVARS = np.zeros( M.shape[0] )
 3831    for i in range(1, M.shape[0] ):
 3832        vecdiff = M[i-1,:] - M[i,:]
 3833        DVARS[i] = np.sqrt( ( vecdiff * vecdiff ).mean() )
 3834    DVARS[0] = DVARS.mean()
 3835    return DVARS
 3836
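# Usage sketch for dvars (illustrative, not executed); `bold` and `brain_mask` are hypothetical
# 4D / 3D ANTsImages; the outlier threshold below is arbitrary:
#
#   d = dvars( bold, brain_mask )                                # one value per time point
#   high_motion = np.where( d > d.mean() + 2.0 * d.std() )[0]    # flag high-DVARS volumes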
 3837
 3838def slice_snr( x,  background_mask, foreground_mask, indices=None ):
 3839    """
 3840    slice-wise SNR on a time series image
 3841
 3842    x: image
 3843
 3844    background_mask : mask - maybe CSF or background or dilated brain mask minus original brain mask
 3845
 3846    foreground_mask : mask - maybe cortex or WM or brain mask
 3847
 3848    indices: indices to use
 3849
 3850    returns an array
 3851    """
 3852    xuse=ants.iMath(x,"Normalize")
 3853    MB = ants.timeseries_to_matrix( xuse, background_mask )
 3854    MF = ants.timeseries_to_matrix( xuse, foreground_mask )
 3855    if indices is not None:
 3856        MB=MB[indices,:]
 3857        MF=MF[indices,:]
 3858    ssnr = np.zeros( MB.shape[0] )
 3859    for i in range( MB.shape[0] ):
 3860        ssnr[i]=MF[i,:].mean()/MB[i,:].std()
 3861    ssnr[np.isnan(ssnr)] = 0
 3862    return ssnr
 3863
 3864
 3865def impute_fa( fa, md ):
 3866    """
    impute bad values (voxels where FA == 1) in the FA and MD images
 3868    """
 3869    def imputeit( x, fa ):
 3870        badfa=ants.threshold_image(fa,1,1)
 3871        if badfa.max() == 1:
 3872            temp=ants.image_clone(x)
 3873            temp[badfa==1]=0
 3874            temp=ants.iMath(temp,'GD',2)
 3875            x[ badfa==1 ]=temp[badfa==1]
 3876        return x
 3877    md=imputeit( md, fa )
 3878    fa=imputeit( ants.image_clone(fa), fa )
 3879    return fa, md
 3880
 3881def trim_dti_mask( fa, mask, param=4.0 ):
 3882    """
 3883    trim the dti mask to get rid of bright fa rim
 3884
 3885    this function erodes the famask by param amount then segments the rim into
 3886    bright and less bright parts.  the bright parts are trimmed from the mask
 3887    and the remaining edges are cleaned up a bit with closing.
 3888
 3889    param: closing radius unit is in physical space
 3890    """
 3891    spacing = ants.get_spacing(mask)
 3892    spacing_product = np.prod( spacing )
 3893    spcmin = min( spacing )
 3894    paramVox = int(np.round( param / spcmin ))
 3895    trim_mask = ants.image_clone( mask )
 3896    trim_mask = ants.iMath( trim_mask, "FillHoles" )
 3897    edgemask = trim_mask - ants.iMath( trim_mask, "ME", paramVox )
 3898    maxk=4
 3899    edgemask = ants.threshold_image( fa * edgemask, "Otsu", maxk )
 3900    edgemask = ants.threshold_image( edgemask, maxk-1, maxk )
 3901    trim_mask[edgemask >= 1 ]=0
 3902    trim_mask = ants.iMath(trim_mask,"ME",paramVox-1)
 3903    trim_mask = ants.iMath(trim_mask,'GetLargestComponent')
 3904    trim_mask = ants.iMath(trim_mask,"MD",paramVox-1)
 3905    return trim_mask
 3906
 3907
 3908
 3909def efficient_tensor_fit( gtab, fit_method, imagein, maskin, diffusion_model='DTI', 
 3910                         chunk_size=10, num_threads=1, verbose=True):
 3911    """
 3912    Efficient and optionally parallelized tensor reconstruction using DiPy.
 3913
 3914    Parameters
 3915    ----------
 3916    gtab : GradientTable
 3917        Dipy gradient table.
 3918    fit_method : str
 3919        Tensor fitting method (e.g. 'WLS', 'OLS', 'RESTORE').
 3920    imagein : ants.ANTsImage
 3921        4D diffusion-weighted image.
 3922    maskin : ants.ANTsImage
 3923        Binary brain mask image.
 3924    diffusion_model : string, optional
 3925        DTI, FreeWater, DKI.
 3926    chunk_size : int, optional
 3927        Number of slices (along z-axis) to process at once.
 3928    num_threads : int, optional
 3929        Number of threads to use (1 = single-threaded).
 3930    verbose : bool, optional
 3931        Print status updates.
 3932    
 3933    Returns
 3934    -------
 3935    tenfit : TensorFit or FreeWaterTensorFit
 3936        Fitted tensor model.
 3937    FA : ants.ANTsImage
 3938        Fractional anisotropy image.
 3939    MD : ants.ANTsImage
 3940        Mean diffusivity image.
 3941    RGB : ants.ANTsImage
 3942        RGB FA map.
 3943    """
 3944    assert imagein.dimension == 4, "Input image must be 4D"
 3945
 3946    import ants
 3947    import numpy as np
 3948    import dipy.reconst.dti as dti
 3949    import dipy.reconst.fwdti as fwdti
 3950    from dipy.reconst.dti import fractional_anisotropy
 3951    from dipy.reconst.dti import color_fa
 3952    from concurrent.futures import ThreadPoolExecutor, as_completed
 3953
 3954    img_data = imagein.numpy()
 3955    mask = maskin.numpy().astype(bool)
 3956    X, Y, Z, N = img_data.shape
 3957    if verbose:
 3958        print(f"Input shape: {img_data.shape}, Processing in chunks of {chunk_size} slices.")
 3959
 3960    model = fwdti.FreeWaterTensorModel(gtab) if diffusion_model == 'FreeWater' else dti.TensorModel(gtab, fit_method=fit_method)
 3961
 3962    def process_chunk(z_start):
 3963        z_end = min(Z, z_start + chunk_size)
 3964        local_data = img_data[:, :, z_start:z_end, :]
 3965        local_mask = mask[:, :, z_start:z_end]
 3966        masked_data = local_data * local_mask[..., None]
 3967        masked_data = np.nan_to_num(masked_data, nan=0)
 3968        fit = model.fit(masked_data)
 3969        FA_chunk = fractional_anisotropy(fit.evals)
 3970        FA_chunk[np.isnan(FA_chunk)] = 1
 3971        FA_chunk = np.clip(FA_chunk, 0, 1)
 3972        MD_chunk = dti.mean_diffusivity(fit.evals)
 3973        RGB_chunk = color_fa(FA_chunk, fit.evecs)
 3974        return z_start, z_end, FA_chunk, MD_chunk, RGB_chunk
 3975
 3976    FA_vol = np.zeros((X, Y, Z), dtype=np.float32)
 3977    MD_vol = np.zeros((X, Y, Z), dtype=np.float32)
 3978    RGB_vol = np.zeros((X, Y, Z, 3), dtype=np.float32)
 3979
 3980    chunks = range(0, Z, chunk_size)
 3981    if num_threads > 1:
 3982        with ThreadPoolExecutor(max_workers=num_threads) as executor:
 3983            futures = {executor.submit(process_chunk, z): z for z in chunks}
 3984            for f in as_completed(futures):
 3985                z_start, z_end, FA_chunk, MD_chunk, RGB_chunk = f.result()
 3986                FA_vol[:, :, z_start:z_end] = FA_chunk
 3987                MD_vol[:, :, z_start:z_end] = MD_chunk
 3988                RGB_vol[:, :, z_start:z_end, :] = RGB_chunk
 3989    else:
 3990        for z in chunks:
 3991            z_start, z_end, FA_chunk, MD_chunk, RGB_chunk = process_chunk(z)
 3992            FA_vol[:, :, z_start:z_end] = FA_chunk
 3993            MD_vol[:, :, z_start:z_end] = MD_chunk
 3994            RGB_vol[:, :, z_start:z_end, :] = RGB_chunk
 3995
 3996    b0 = ants.slice_image(imagein, axis=3, idx=0)
 3997    FA = ants.copy_image_info(b0, ants.from_numpy(FA_vol))
 3998    MD = ants.copy_image_info(b0, ants.from_numpy(MD_vol))
 3999    RGB_channels = [ants.copy_image_info(b0, ants.from_numpy(RGB_vol[..., i])) for i in range(3)]
 4000    RGB = ants.merge_channels(RGB_channels)
 4001
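    # note: a final fit over the full (masked) image is returned below so that the
    # fit object covers the whole volume, not just the per-chunk fits used above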
 4002    return model.fit(img_data * mask[..., None]), FA, MD, RGB
 4003
 4004
 4005
 4006def efficient_dwi_fit(gtab, diffusion_model, imagein, maskin,
 4007                      model_params=None, bvals_to_use=None,
 4008                      chunk_size=1024, num_threads=1, verbose=True):
 4009    """
 4010    Efficient and optionally parallelized diffusion model reconstruction using DiPy.
 4011
 4012    Parameters
 4013    ----------
 4014    gtab : GradientTable
 4015        DiPy gradient table.
 4016    diffusion_model : str
 4017        One of ['DTI', 'FreeWater', 'DKI'].
 4018    imagein : ants.ANTsImage
 4019        4D diffusion-weighted image.
 4020    maskin : ants.ANTsImage
 4021        Binary brain mask image.
 4022    model_params : dict, optional
 4023        Additional parameters passed to model constructors.
 4024    bvals_to_use : list of int, optional
 4025        Subset of b-values to use for the fit (e.g., [0, 1000, 2000]).
 4026    chunk_size : int, optional
 4027        Maximum number of voxels per chunk (default 1024).
 4028    num_threads : int, optional
 4029        Number of parallel threads.
 4030    verbose : bool, optional
 4031        Whether to print status messages.
 4032
 4033    Returns
 4034    -------
 4035    fit : dipy ModelFit
 4036        The fitted model object.
 4037    FA : ants.ANTsImage or None
 4038        Fractional anisotropy image (if applicable).
 4039    MD : ants.ANTsImage or None
 4040        Mean diffusivity image (if applicable).
 4041    RGB : ants.ANTsImage or None
 4042        Color FA image (if applicable).
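
    Example
    -------
    A minimal usage sketch (not executed); ``gtab`` is a DiPy GradientTable and
    ``dwi`` / ``dwi_mask`` are ANTs images assumed to have been prepared elsewhere:

    >>> fit, FA, MD, RGB = efficient_dwi_fit(
    ...     gtab, 'DTI', dwi, dwi_mask, bvals_to_use=[0, 1000], num_threads=2)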
 4043    """
 4044    import ants
 4045    import numpy as np
 4046    import dipy.reconst.dti as dti
 4047    import dipy.reconst.fwdti as fwdti
 4048    import dipy.reconst.dki as dki
 4049    from dipy.core.gradients import gradient_table
 4050    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
 4051    from concurrent.futures import ThreadPoolExecutor, as_completed
 4052
 4053    assert imagein.dimension == 4, "Input image must be 4D"
 4054    model_params = model_params or {}
 4055
 4056    img_data = imagein.numpy()
 4057    mask = maskin.numpy().astype(bool)
 4058    X, Y, Z, N = img_data.shape
 4059    inplane_size = X * Y
 4060
 4061    # Convert chunk_size from voxel count to number of slices
 4062    slices_per_chunk = max(1, chunk_size // inplane_size)
 4063
 4064    if verbose:
 4065        print(f"[INFO] Image shape: {img_data.shape}")
 4066        print(f"[INFO] Using model: {diffusion_model}")
 4067        print(f"[INFO] Max voxels per chunk: {chunk_size} (→ {slices_per_chunk} slices) | Threads: {num_threads}")
 4068
 4069    if bvals_to_use is not None:
 4070        bvals_to_use = set(bvals_to_use)
 4071        sel = np.isin(gtab.bvals, list(bvals_to_use))
 4072        img_data = img_data[..., sel]
 4073        gtab = gradient_table(gtab.bvals[sel], bvecs=gtab.bvecs[sel])
 4074        if verbose:
 4075            print(f"[INFO] Selected b-values: {sorted(bvals_to_use)}")
 4076            print(f"[INFO] Selected volumes: {sel.sum()} / {N}")
 4077
 4078    def get_model(name, gtab, **params):
 4079        if name == 'DTI':
 4080            return dti.TensorModel(gtab, **params)
 4081        elif name == 'FreeWater':
 4082            return fwdti.FreeWaterTensorModel(gtab)
 4083        elif name == 'DKI':
 4084            return dki.DiffusionKurtosisModel(gtab, **params)
 4085        else:
 4086            raise ValueError(f"Unsupported model: {name}")
 4087
 4088    model = get_model(diffusion_model, gtab, **model_params)
 4089
 4090    FA_vol = np.zeros((X, Y, Z), dtype=np.float32)
 4091    MD_vol = np.zeros((X, Y, Z), dtype=np.float32)
 4092    RGB_vol = np.zeros((X, Y, Z, 3), dtype=np.float32)
 4093    has_tensor_metrics = diffusion_model in ['DTI', 'FreeWater']
 4094
 4095    def process_chunk(z_start):
 4096        z_end = min(Z, z_start + slices_per_chunk)
 4097        local_data = img_data[:, :, z_start:z_end, :]
 4098        local_mask = mask[:, :, z_start:z_end]
 4099        masked_data = local_data * local_mask[..., None]
 4100        masked_data = np.nan_to_num(masked_data, nan=0)
 4101        fit = model.fit(masked_data)
 4102        if has_tensor_metrics and hasattr(fit, 'evals') and hasattr(fit, 'evecs'):
 4103            FA = fractional_anisotropy(fit.evals)
 4104            FA[np.isnan(FA)] = 1
 4105            FA = np.clip(FA, 0, 1)
 4106            MD = mean_diffusivity(fit.evals)
 4107            RGB = color_fa(FA, fit.evecs)
 4108            return z_start, z_end, FA, MD, RGB
 4109        return z_start, z_end, None, None, None
 4110
 4111    chunks = range(0, Z, slices_per_chunk)
 4112    if num_threads > 1:
 4113        with ThreadPoolExecutor(max_workers=num_threads) as executor:
 4114            futures = {executor.submit(process_chunk, z): z for z in chunks}
 4115            for f in as_completed(futures):
 4116                z_start, z_end, FA, MD, RGB = f.result()
 4117                if FA is not None:
 4118                    FA_vol[:, :, z_start:z_end] = FA
 4119                    MD_vol[:, :, z_start:z_end] = MD
 4120                    RGB_vol[:, :, z_start:z_end, :] = RGB
 4121    else:
 4122        for z in chunks:
 4123            z_start, z_end, FA, MD, RGB = process_chunk(z)
 4124            if FA is not None:
 4125                FA_vol[:, :, z_start:z_end] = FA
 4126                MD_vol[:, :, z_start:z_end] = MD
 4127                RGB_vol[:, :, z_start:z_end, :] = RGB
 4128
 4129    b0 = ants.slice_image(imagein, axis=3, idx=0)
 4130    FA_img = ants.copy_image_info(b0, ants.from_numpy(FA_vol)) if has_tensor_metrics else None
 4131    MD_img = ants.copy_image_info(b0, ants.from_numpy(MD_vol)) if has_tensor_metrics else None
 4132    RGB_img = (ants.merge_channels([
 4133        ants.copy_image_info(b0, ants.from_numpy(RGB_vol[..., i])) for i in range(3)
 4134    ]) if has_tensor_metrics else None)
 4135
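    # final fit over the full (masked) image so that the returned fit object covers
    # the whole volume rather than only the per-chunk fits assembled above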
 4136    full_fit = model.fit(img_data * mask[..., None])
 4137    return full_fit, FA_img, MD_img, RGB_img
 4138
 4139
 4140def efficient_dwi_fit_voxelwise(imagein, maskin, bvals, bvecs_5d, model_params=None,
 4141                                bvals_to_use=None, num_threads=1, verbose=True):
 4142    """
 4143    Voxel-wise diffusion model fitting with individual b-vectors per voxel.
 4144
 4145    Parameters
 4146    ----------
 4147    imagein : ants.ANTsImage
 4148        4D DWI image (X, Y, Z, N).
 4149    maskin : ants.ANTsImage
 4150        3D binary mask.
 4151    bvals : (N,) array-like
 4152        Common b-values across volumes.
 4153    bvecs_5d : (X, Y, Z, N, 3) ndarray
 4154        Voxel-specific b-vectors.
 4155    model_params : dict
 4156        Extra arguments for model.
 4157    bvals_to_use : list[int]
 4158        Subset of b-values to include.
 4159    num_threads : int
 4160        Number of threads to use.
 4161    verbose : bool
 4162        Whether to print status.
 4163
 4164    Returns
 4165    -------
 4166    FA_img : ants.ANTsImage
 4167        Fractional anisotropy.
 4168    MD_img : ants.ANTsImage
 4169        Mean diffusivity.
 4170    RGB_img : ants.ANTsImage
 4171        RGB FA image.
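
    Example
    -------
    A minimal usage sketch (not executed); ``dwi``, ``dwi_mask``, ``bvals`` and the
    voxel-wise ``bvecs_5d`` array (e.g. from ``generate_voxelwise_bvecs``) are
    assumed to exist already:

    >>> FA_img, MD_img, RGB_img = efficient_dwi_fit_voxelwise(
    ...     dwi, dwi_mask, bvals, bvecs_5d, num_threads=4)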
 4172    """
 4173    import numpy as np
 4174    import ants
 4175    import dipy.reconst.dti as dti
 4176    from dipy.core.gradients import gradient_table
 4177    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
 4178    from concurrent.futures import ThreadPoolExecutor
 4179    from tqdm import tqdm
 4180
 4181    model_params = model_params or {}
 4182    img = imagein.numpy()
 4183    mask = maskin.numpy().astype(bool)
 4184    X, Y, Z, N = img.shape
 4185
 4186    if bvals_to_use is not None:
 4187        sel = np.isin(bvals, bvals_to_use)
 4188        img = img[..., sel]
 4189        bvals = bvals[sel]
 4190        bvecs_5d = bvecs_5d[..., sel, :]
 4191
 4192    FA = np.zeros((X, Y, Z), dtype=np.float32)
 4193    MD = np.zeros((X, Y, Z), dtype=np.float32)
 4194    RGB = np.zeros((X, Y, Z, 3), dtype=np.float32)
 4195
 4196    def fit_voxel(ix, iy, iz):
 4197        if not mask[ix, iy, iz]:
 4198            return
 4199        sig = img[ix, iy, iz, :]
 4200        if np.all(sig == 0):
 4201            return
 4202        bv = bvecs_5d[ix, iy, iz, :, :]
 4203        gtab = gradient_table(bvals, bvecs=bv)
 4204        try:
 4205            model = dti.TensorModel(gtab, **model_params)
 4206            fit = model.fit(sig)
 4207            evals = fit.evals
 4208            evecs = fit.evecs
 4209            FA[ix, iy, iz] = fractional_anisotropy(evals)
 4210            MD[ix, iy, iz] = mean_diffusivity(evals)
 4211            RGB[ix, iy, iz, :] = color_fa(FA[ix, iy, iz], evecs)
 4212        except Exception as e:
 4213            if verbose:
 4214                print(f"Voxel ({ix},{iy},{iz}) fit failed: {e}")
 4215
 4216    coords = np.argwhere(mask)
 4217    if verbose:
 4218        print(f"[INFO] Fitting {len(coords)} voxels using {num_threads} threads...")
 4219
 4220    if num_threads > 1:
 4221        with ThreadPoolExecutor(max_workers=num_threads) as executor:
 4222            list(tqdm(executor.map(lambda c: fit_voxel(*c), coords), total=len(coords)))
 4223    else:
 4224        for c in tqdm(coords):
 4225            fit_voxel(*c)
 4226
 4227    ref = ants.slice_image(imagein, axis=3, idx=0)
 4228    return (
 4229        ants.copy_image_info(ref, ants.from_numpy(FA)),
 4230        ants.copy_image_info(ref, ants.from_numpy(MD)),
 4231        ants.merge_channels([ants.copy_image_info(ref, ants.from_numpy(RGB[..., i])) for i in range(3)])
 4232    )
 4233
 4234
 4235def generate_voxelwise_bvecs(global_bvecs, voxel_rotations, transpose=False):
 4236    """
 4237    Generate voxel-wise b-vectors from a global bvec and voxel-wise rotation field.
 4238
 4239    Parameters
 4240    ----------
 4241    global_bvecs : ndarray of shape (N, 3)
 4242        Global diffusion gradient directions.
 4243    voxel_rotations : ndarray of shape (X, Y, Z, 3, 3)
 4244        3x3 rotation matrix for each voxel (can come from Jacobian of deformation field).
 4245    transpose : bool, optional
 4246        If True, transpose the rotation matrices before applying them to the b-vectors.
 4247
 4248
 4249    Returns
 4250    -------
 4251    bvecs_5d : ndarray of shape (X, Y, Z, N, 3)
 4252        Voxel-specific b-vectors.
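
    Example
    -------
    A minimal usage sketch (not executed); ``bvecs`` is an (N, 3) array and
    ``rotations`` an (X, Y, Z, 3, 3) rotation field assumed to be derived from a
    deformation-field Jacobian computed elsewhere:

    >>> bvecs_5d = generate_voxelwise_bvecs(bvecs, rotations, transpose=False)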
 4253    """
 4254    X, Y, Z, _, _ = voxel_rotations.shape
 4255    N = global_bvecs.shape[0]
 4256    bvecs_5d = np.zeros((X, Y, Z, N, 3), dtype=np.float32)
 4257
 4258    for n in range(N):
 4259        bvec = global_bvecs[n]
 4260        for i in range(X):
 4261            for j in range(Y):
 4262                for k in range(Z):
 4263                    R = voxel_rotations[i, j, k]
 4264                    if transpose:
 4265                        R = R.T  # Use transpose if needed
 4266                    bvecs_5d[i, j, k, n, :] = R @ bvec
 4267    return bvecs_5d
 4268
 4269def dipy_dti_recon(
 4270    image,
 4271    bvalsfn,
 4272    bvecsfn,
 4273    mask = None,
 4274    b0_idx = None,
 4275    mask_dilation = 2,
 4276    mask_closing = 5,
 4277    fit_method='WLS',
 4278    trim_the_mask=2.0,
 4279    diffusion_model='DTI',
 4280    verbose=False ):
 4281    """
 4282    DiPy DTI reconstruction - building on the DiPy basic DTI example
 4283
 4284    Arguments
 4285    ---------
 4286    image : an antsImage holding B0 and DWI
 4287
 4288    bvalsfn : bvalues  obtained by dipy read_bvals_bvecs or the values themselves
 4289
 4290    bvecsfn : bvectors obtained by dipy read_bvals_bvecs or the values themselves
 4291
 4292    mask : brain mask for the DWI/DTI reconstruction; if it is not in the same
 4293        space as the image, we will resample directly to the image space.  This
 4294        could lead to problems if the inputs are really incorrect.
 4295
 4296    b0_idx : the indices of the B0; if None, use segment_timeseries_by_bvalue
 4297
 4298    mask_dilation : integer zero or more dilates the brain mask
 4299
 4300    mask_closing : integer zero or more closes the brain mask
 4301
 4302    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel) ... if None, will not reconstruct DTI.
 4303
 4304    trim_the_mask : float >=0 post-hoc method for trimming the mask
 4305
 4306    diffusion_model : string
 4307        DTI, FreeWater, DKI
 4308
 4309    verbose : boolean
 4310
 4311    Returns
 4312    -------
 4313    dictionary holding the tensorfit, MD, FA and RGB images and motion parameters (optional)
 4314
 4315    NOTE -- see dipy reorient_bvecs(gtab, affines, atol=1e-2)
 4316
 4317    NOTE -- if the bvec.shape[0] is smaller than the image.shape[3], we neglect
 4318        the tailing image volumes.
 4319
 4320    Example
 4321    -------
 4322    >>> import antspymm
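    >>> import ants
    >>> # hedged sketch, not executed: 'dwi.nii.gz', 'dwi.bval' and 'dwi.bvec' are
    >>> # hypothetical file names standing in for a real DWI acquisition
    >>> dwi = ants.image_read('dwi.nii.gz')
    >>> rec = antspymm.dipy_dti_recon(dwi, 'dwi.bval', 'dwi.bvec', verbose=True)
    >>> fa, md = rec['FA'], rec['MD']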
 4323    """
 4324
 4325    import dipy.reconst.fwdti as fwdti
 4326
 4327    if isinstance(bvecsfn, str):
 4328        bvals, bvecs = read_bvals_bvecs( bvalsfn , bvecsfn   )
 4329    else: # assume we already read them
 4330        bvals = bvalsfn.copy()
 4331        bvecs = bvecsfn.copy()
 4332
 4333    if bvals.max() < 1.0:
 4334        raise ValueError("DTI recon error: maximum bvalues are too small.")
 4335
 4336    if b0_idx is None:
        b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals']
 4337
 4338    b0 = ants.slice_image( image, axis=3, idx=b0_idx[0] )
 4340    bxtmod='t2'  # modality hint passed to antspynet.brain_extraction
 4341    constant_mask=False
 4342    if verbose:
 4343        print( np.unique( bvals ), flush=True )
 4344    if mask is not None:
 4345        if verbose:
 4346            print("use set bxt in dipy_dti_recon", flush=True)
 4347        constant_mask=True
 4348        mask = ants.resample_image_to_target( mask, b0, interp_type='nearestNeighbor')
 4349    else:
 4350        if verbose:
 4351            print("use deep learning bxt in dipy_dti_recon")
 4352        mask = antspynet.brain_extraction( b0, bxtmod ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
 4353    if mask_closing > 0 and not constant_mask :
 4354        mask = ants.morphology( mask, "close", mask_closing ) # good
 4355    maskdil = ants.iMath( mask, "MD", mask_dilation )
 4356
 4357    if verbose:
 4358        print("recon dti.TensorModel",flush=True)
 4359
 4360    bvecs = repair_bvecs( bvecs )
 4361    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
 4362    mynt=1
 4363    threads_env = os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")
 4364    if threads_env is not None:
 4365        mynt = int(threads_env)
 4366    tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, maskdil,
 4367                                             num_threads=mynt )
 4368    if verbose:
 4369        print("recon dti.TensorModel done",flush=True)
 4370
 4371    # trim the brain mask based on high FA values and refit within the trimmed mask
 4372    if trim_the_mask > 0 and fit_method is not None:
 4373        mask = trim_dti_mask( FA, mask, trim_the_mask )
 4374        tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, mask,
 4375                                             num_threads=mynt )
 4376
 4377    return {
 4378        'tensormodel' : tenfit,
 4379        'MD' : MD1 ,
 4380        'FA' : FA ,
 4381        'RGB' : RGB,
 4382        'dwi_mask':mask,
 4383        'bvals':bvals,
 4384        'bvecs':bvecs
 4385        }
 4386
 4387
 4388def concat_dewarp(
 4389        refimg,
 4390        originalDWI,
 4391        physSpaceDWI,
 4392        dwpTx,
 4393        motion_parameters,
 4394        motion_correct=True,
 4395        verbose=False ):
 4396    """
 4397    Apply concatenated motion correction and dewarping transforms to a timeseries image.
 4398
 4399    Arguments
 4400    ---------
 4401
 4402    refimg : an antsImage defining the reference domain (3D)
 4403
 4404    originalDWI : the antsImage in original (not interpolated space) (4D)
 4405
 4406    physSpaceDWI : ants antsImage defining the physical space of the mapping (4D)
 4407
 4408    dwpTx : dewarping transform
 4409
 4410    motion_parameters : previously computed list of motion parameters
 4411
 4412    motion_correct : boolean
 4413
 4414    verbose : boolean
 4415
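    Example
    -------
    A hedged sketch (not executed); ``ref``, ``dwi_orig`` and ``dwi_phys`` are ANTs
    images while ``dewarp_tx`` and ``moco_params`` are assumed to come from prior
    dewarping and motion-correction steps:

    >>> import antspymm
    >>> dewarped = antspymm.concat_dewarp(ref, dwi_orig, dwi_phys, dewarp_tx,
    ...     moco_params, motion_correct=True)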
 4416    """
 4417    # apply the dewarping tx to the original dwi and reconstruct again
 4418    # NOTE: refimg must be in the same space for this to work correctly
 4419    # due to the use of ants.list_to_ndimage( originalDWI, dwpimage )
 4420    dwpimage = []
 4421    for myidx in range(originalDWI.shape[3]):
 4422        b0 = ants.slice_image( originalDWI, axis=3, idx=myidx)
 4423        concatx = dwpTx.copy()
 4424        if motion_correct:
 4425            concatx = concatx + motion_parameters[myidx]
 4426        if verbose and myidx == 0:
 4427            print("dwp parameters")
 4428            print( dwpTx )
 4429            print("Motion parameters")
 4430            print( motion_parameters[myidx] )
 4431            print("concat parameters")
 4432            print(concatx)
 4433        warpedb0 = ants.apply_transforms( refimg, b0, concatx,
 4434            interpolator='nearestNeighbor' )
 4435        dwpimage.append( warpedb0 )
 4436    return ants.list_to_ndimage( physSpaceDWI, dwpimage )
 4437
 4438
 4439def joint_dti_recon(
 4440    img_LR,
 4441    bval_LR,
 4442    bvec_LR,
 4443    jhu_atlas,
 4444    jhu_labels,
 4445    reference_B0,
 4446    reference_DWI,
 4447    srmodel = None,
 4448    img_RL = None,
 4449    bval_RL = None,
 4450    bvec_RL = None,
 4451    t1w = None,
 4452    brain_mask = None,
 4453    motion_correct = None,
 4454    dewarp_modality = 'FA',
 4455    denoise=False,
 4456    fit_method='WLS',
 4457    impute = False,
 4458    censor = True,
 4459    diffusion_model = 'DTI',
 4460    verbose = False ):
 4461    """
 4462    1. pass in subject data and 1mm JHU atlas/labels
 4463    2. perform initial LR, RL reconstruction (2nd is optional) and motion correction (optional)
 4464    3. dewarp the images using dewarp_modality or T1w
 4465    4. apply dewarping to the original data
 4466        ===> may want to apply SR at this step
 4467    5. reconstruct DTI again
 4468    6. label images and do registration
 4469    7. return relevant outputs
 4470
 4471    NOTE: RL images are optional; should pass t1w in this case.
 4472
 4473    Arguments
 4474    ---------
 4475
 4476    img_LR : an antsImage holding B0 and DWI LR acquisition
 4477
 4478    bval_LR : bvalue filename LR
 4479
 4480    bvec_LR : bvector filename LR
 4481
 4482    jhu_atlas : atlas FA image
 4483
 4484    jhu_labels : atlas labels
 4485
 4486    reference_B0 : the "target" B0 image space
 4487
 4488    reference_DWI : the "target" DW image space
 4489
 4490    srmodel : optional h5 (tensorflow) model
 4491
 4492    img_RL : an antsImage holding B0 and DWI RL acquisition
 4493
 4494    bval_RL : bvalue filename RL
 4495
 4496    bvec_RL : bvector filename RL
 4497
 4498    t1w : antsimage t1w neuroimage (brain-extracted)
 4499
 4500    brain_mask : mask for the DWI - just 3D - provided brain mask should be in reference_B0 space
 4501
 4502    motion_correct : None, 'Rigid' or 'SyN'
 4503
 4504    dewarp_modality : string - one of 'average_dwi', 'average_b0', 'MD' or 'FA'
 4505
 4506    denoise: boolean
 4507
 4508    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
 4509
 4510    impute : boolean
 4511
 4512    censor : boolean
 4513
 4514    diffusion_model : string
 4515        DTI, FreeWater, DKI
 4516
 4517    verbose : boolean
 4518
 4519    Returns
 4520    -------
 4521    dictionary holding the mean_fa, its summary statistics via JHU labels,
 4522        the JHU registration, the JHU labels, the dewarping dictionary and the
 4523        dti reconstruction dictionaries.
 4524
 4525    Example
 4526    -------
 4527    >>> import antspymm
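    >>> # hedged sketch, not executed; the images, bval/bvec file names and JHU
    >>> # template/labels below are hypothetical placeholders for real data
    >>> rec = antspymm.joint_dti_recon(
    ...     img_LR, 'lr.bval', 'lr.bvec', jhu_atlas, jhu_labels,
    ...     reference_B0, reference_DWI, motion_correct='Rigid', verbose=True)
    >>> fa_summary = rec['recon_fa_summary']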
 4528    """
 4529
 4530    if verbose:
 4531        print("Recon DTI on OR images ...")
 4532
 4533    def fix_dwi_shape( img, bvalfn, bvecfn ):
 4534        if isinstance(bvecfn, str):
 4535            bvals, bvecs = read_bvals_bvecs( bvalfn , bvecfn   )
 4536        else:
 4537            bvals = bvalfn
 4538            bvecs = bvecfn
 4539        if bvecs.shape[0] < img.shape[3]:
 4540            imgout = ants.from_numpy( img[:,:,:,0:bvecs.shape[0]] )
 4541            imgout = ants.copy_image_info( img, imgout )
 4542            return( imgout )
 4543        else:
 4544            return( img )
 4545
 4546    img_LR = fix_dwi_shape( img_LR, bval_LR, bvec_LR )
 4547    if denoise :
 4548        img_LR = mc_denoise( img_LR )
 4549    if img_RL is not None:
 4550        img_RL = fix_dwi_shape( img_RL, bval_RL, bvec_RL )
 4551        if denoise :
 4552            img_RL = mc_denoise( img_RL )
 4553
 4554    brainmaske = None
 4555    if brain_mask is not None:
 4556        maskInRightSpace = ants.image_physical_space_consistency( brain_mask, reference_B0 )
 4557        if not maskInRightSpace :
 4558            raise ValueError('not maskInRightSpace ... provided brain mask should be in reference_B0 space')
 4559        brainmaske = ants.iMath( brain_mask, "ME", 2 )
 4560
 4561    if img_RL is not None :
 4562        if verbose:
 4563            print("img_RL correction")
 4564        reg_RL = dti_reg(
 4565            img_RL,
 4566            avg_b0=reference_B0,
 4567            avg_dwi=reference_DWI,
 4568            bvals=bval_RL,
 4569            bvecs=bvec_RL,
 4570            type_of_transform=motion_correct,
 4571            brain_mask_eroded=brainmaske,
 4572            verbose=True )
 4573    else:
 4574        reg_RL=None
 4575
 4576
 4577    if verbose:
 4578        print("img_LR correction")
 4579    reg_LR = dti_reg(
 4580            img_LR,
 4581            avg_b0=reference_B0,
 4582            avg_dwi=reference_DWI,
 4583            bvals=bval_LR,
 4584            bvecs=bvec_LR,
 4585            type_of_transform=motion_correct,
 4586            brain_mask_eroded=brainmaske,
 4587            verbose=True )
 4588
 4589    ts_LR_avg = None
 4590    ts_RL_avg = None
 4591    reg_its = [100,50,10]
 4592    img_LRdwp = ants.image_clone( reg_LR[ 'motion_corrected' ] )
 4593    if img_RL is not None:
 4594        img_RLdwp = ants.image_clone( reg_RL[ 'motion_corrected' ] )
 4595        if srmodel is not None:
 4596            if verbose:
 4597                print("convert img_RL_dwp to img_RL_dwp_SR")
 4598            img_RLdwp = super_res_mcimage( img_RLdwp, srmodel, isotropic=True,
 4599                        verbose=verbose )
 4600    if srmodel is not None:
 4601        reg_its = [100] + reg_its
 4602        if verbose:
 4603            print("convert img_LR_dwp to img_LR_dwp_SR")
 4604        img_LRdwp = super_res_mcimage( img_LRdwp, srmodel, isotropic=True,
 4605                verbose=verbose )
 4606    if verbose:
 4607        print("recon after distortion correction", flush=True)
 4608
 4609    if impute:
 4610        print("impute begin", flush=True)
 4611        img_LRdwp=impute_dwi( img_LRdwp, verbose=True )
 4612        print("impute done", flush=True)
 4613    elif censor:
 4614        print("censor begin", flush=True)
 4615        img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'] = censor_dwi( img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'], verbose=True )
 4616        print("censor done", flush=True)
 4617    if impute and img_RL is not None:
 4618        img_RLdwp=impute_dwi( img_RLdwp, verbose=True )
 4619    elif censor and img_RL is not None:
 4620        img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'] = censor_dwi( img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'], verbose=True )
 4621
 4622    if img_RL is not None:
 4623        img_LRdwp, bval_LR, bvec_LR = merge_dwi_data(
 4624            img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'],
 4625            img_RLdwp, reg_RL['bvals'], reg_RL['bvecs']
 4626        )
 4627    else:
 4628        bval_LR=reg_LR['bvals']
 4629        bvec_LR=reg_LR['bvecs']
 4630
 4631    if verbose:
 4632        print("final recon", flush=True)
 4633        print(img_LRdwp)
 4634
 4635    recon_LR_dewarp = dipy_dti_recon(
 4636            img_LRdwp, bval_LR, bvec_LR,
 4637            mask = brain_mask,
 4638            fit_method=fit_method,
 4639            mask_dilation=0, diffusion_model=diffusion_model, verbose=True )
 4640    if verbose:
 4641        print("recon done", flush=True)
 4642
 4643    if img_RL is not None:
 4644        fdjoin = [ reg_LR['FD'],
 4645                   reg_RL['FD'] ]
 4646        framewise_displacement=np.concatenate( fdjoin )
 4647    else:
 4648        framewise_displacement=reg_LR['FD']
 4649
 4650    motion_count = ( framewise_displacement > 1.5  ).sum()
 4651    reconFA = recon_LR_dewarp['FA']
 4652    reconMD = recon_LR_dewarp['MD']
 4653
 4654    if verbose:
 4655        print("JHU reg",flush=True)
 4656
 4657    OR_FA2JHUreg = ants.registration( reconFA, jhu_atlas,
 4658        type_of_transform = 'antsRegistrationSyNQuickRepro[s]', 
 4659        reg_iterations=reg_its, verbose=False )
 4660    OR_FA_jhulabels = ants.apply_transforms( reconFA, jhu_labels,
 4661        OR_FA2JHUreg['fwdtransforms'], interpolator='genericLabel')
 4662
 4663    df_FA_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
 4664        'FA_JHU_labels_edited',
 4665        reconFA,
 4666        OR_FA_jhulabels)
 4667    df_FA_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 4668            {'df_FA_JHU_ORRL' : df_FA_JHU_ORRL},
 4669            col_names = ['Mean'] )
 4670
 4671    df_MD_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
 4672        'MD_JHU_labels_edited',
 4673        reconMD,
 4674        OR_FA_jhulabels)
 4675    df_MD_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 4676            {'df_MD_JHU_ORRL' : df_MD_JHU_ORRL},
 4677            col_names = ['Mean'] )
 4678
 4679    temp = segment_timeseries_by_meanvalue( img_LRdwp )
 4680    b0_idx = temp['highermeans']
 4681    non_b0_idx = temp['lowermeans']
 4682
 4683    nonbrainmask = ants.iMath( recon_LR_dewarp['dwi_mask'], "MD",2) - recon_LR_dewarp['dwi_mask']
 4684    fgmask = ants.threshold_image( reconFA, 0.5 , 1.0).iMath("GetLargestComponent")
 4685    bgmask = ants.threshold_image( reconFA, 1e-4 , 0.1)
 4686    fa_SNR = 0.0
 4687    fa_SNR = mask_snr( reconFA, bgmask, fgmask, bias_correct=False )
 4688    fa_evr = antspyt1w.patch_eigenvalue_ratio( reconFA, 512, [16,16,16], evdepth = 0.9, mask=recon_LR_dewarp['dwi_mask'] )
 4689
 4690    dti_itself = get_dti( reconFA, recon_LR_dewarp['tensormodel'], return_image=True )
 4691    return convert_np_in_dict( {
 4692        'dti': dti_itself,
 4693        'recon_fa':reconFA,
 4694        'recon_fa_summary':df_FA_JHU_ORRL_bfwide,
 4695        'recon_md':reconMD,
 4696        'recon_md_summary':df_MD_JHU_ORRL_bfwide,
 4697        'jhu_labels':OR_FA_jhulabels,
 4698        'jhu_registration':OR_FA2JHUreg,
 4699        'reg_LR':reg_LR,
 4700        'reg_RL':reg_RL,
 4701        'dtrecon_LR_dewarp':recon_LR_dewarp,
 4702        'dwi_LR_dewarped':img_LRdwp,
 4703        'bval_unique_count': len(np.unique(bval_LR)),
 4704        'bval_LR':bval_LR,
 4705        'bvec_LR':bvec_LR,
 4706        'bval_RL':bval_RL,
 4707        'bvec_RL':bvec_RL,
 4708        'b0avg': reference_B0,
 4709        'dwiavg': reference_DWI,
 4710        'framewise_displacement':framewise_displacement,
 4711        'high_motion_count': motion_count,
 4712        'tsnr_b0': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
 4713        'tsnr_dwi': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
 4714        'dvars_b0': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
 4715        'dvars_dwi': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
 4716        'ssnr_b0': slice_snr( img_LRdwp, bgmask , fgmask, b0_idx),
 4717        'ssnr_dwi': slice_snr( img_LRdwp, bgmask, fgmask, non_b0_idx),
 4718        'fa_evr': fa_evr,
 4719        'fa_SNR': fa_SNR
 4720    } )
 4721
 4722
 4723def middle_slice_snr( x, background_dilation=5 ):
 4724    """
 4725
 4726    Estimate signal to noise ratio (SNR) in the middle 2D slice of a 3D image.
 4727    Estimates noise from a background mask which is a
 4728    dilation of the foreground mask minus the foreground mask.
 4729    Actually estimates the reciprocal of the coefficient of variation.
 4730
 4731    Arguments
 4732    ---------
 4733
 4734    x : an antsImage
 4735
 4736    background_dilation : integer - amount to dilate foreground mask
 4737
 4738    """
 4739    xshp = x.shape
 4740    xmidslice = ants.slice_image( x, 2, int( xshp[2]/2 )  )
 4741    xmidslice = ants.iMath( xmidslice - xmidslice.min(), "Normalize" )
 4742    xmidslice = ants.n3_bias_field_correction( xmidslice )
 4743    xmidslice = ants.n3_bias_field_correction( xmidslice )
 4744    xmidslicemask = ants.threshold_image( xmidslice, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
 4745    xbkgmask = ants.iMath( xmidslicemask, "MD", background_dilation ) - xmidslicemask
 4746    signal = (xmidslice[ xmidslicemask == 1] ).mean()
 4747    noise = (xmidslice[ xbkgmask == 1] ).std()
 4748    return signal / noise
 4749
 4750def foreground_background_snr( x, background_dilation=10,
 4751        erode_foreground=False):
 4752    """
 4753
 4754    Estimate signal to noise ratio (SNR) in an image.
 4755    Estimates noise from a background mask which is a
 4756    dilation of the foreground mask minus the foreground mask.
 4757    Actually estimates the reciprocal of the coefficient of variation.
 4758
 4759    Arguments
 4760    ---------
 4761
 4762    x : an antsImage
 4763
 4764    background_dilation : integer - amount to dilate foreground mask
 4765
 4766    erode_foreground : boolean - alternative option that erodes the initial
 4767    foreground mask to create a new foreground mask.  The background
 4768    mask is then the initial mask minus the eroded mask.
 4769
 4770    """
 4771    xshp = x.shape
 4772    xbc = ants.iMath( x - x.min(), "Normalize" )
 4773    xbc = ants.n3_bias_field_correction( xbc )
 4774    xmask = ants.threshold_image( xbc, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
 4775    xbkgmask = ants.iMath( xmask, "MD", background_dilation ) - xmask
 4776    fgmask = xmask
 4777    if erode_foreground:
 4778        fgmask = ants.iMath( xmask, "ME", background_dilation )
 4779        xbkgmask = xmask - fgmask
 4780    signal = (xbc[ fgmask == 1] ).mean()
 4781    noise = (xbc[ xbkgmask == 1] ).std()
 4782    return signal / noise
 4783
 4784def quantile_snr( x,
 4785    lowest_quantile=0.01,
 4786    low_quantile=0.1,
 4787    high_quantile=0.5,
 4788    highest_quantile=0.95 ):
 4789    """
 4790
 4791    Estimate signal to noise ratio (SNR) in an image.
 4792    Estimates noise from a background mask defined by the low-intensity quantiles
 4793    and signal from a foreground mask defined by the high-intensity quantiles.
 4794    Actually estimates the reciprocal of the coefficient of variation.
 4795
 4796    Arguments
 4797    ---------
 4798
 4799    x : an antsImage
 4800
 4801    lowest_quantile : float value < 1 and > 0
 4802
 4803    low_quantile : float value < 1 and > 0
 4804
 4805    high_quantile : float value < 1 and > 0
 4806
 4807    highest_quantile : float value < 1 and > 0
 4808
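    Example
    -------
    A minimal sketch (not executed); ``img`` is any 3D ANTs image:

    >>> import antspymm
    >>> snr = antspymm.quantile_snr(img, low_quantile=0.1, high_quantile=0.5)
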
 4809    """
 4810    import numpy as np
 4811    xshp = x.shape
 4812    xbc = ants.iMath( x - x.min(), "Normalize" )
 4813    xbc = ants.n3_bias_field_correction( xbc )
 4814    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
 4815    y = xbc.numpy()
 4816    ylowest = np.quantile( y[y>0], lowest_quantile )
 4817    ylo = np.quantile( y[y>0], low_quantile )
 4818    yhi = np.quantile( y[y>0], high_quantile )
 4819    yhiest = np.quantile( y[y>0], highest_quantile )
 4820    xbkgmask = ants.threshold_image( xbc, ylowest, ylo )
 4821    fgmask = ants.threshold_image( xbc, yhi, yhiest )
 4822    signal = (xbc[ fgmask == 1] ).mean()
 4823    noise = (xbc[ xbkgmask == 1] ).std()
 4824    return signal / noise
 4825
 4826def mask_snr( x, background_mask, foreground_mask, bias_correct=True ):
 4827    """
 4828
 4829    Estimate signal to noise ratio (SNR) in an image using
 4830    a user-defined foreground and background mask.
 4831    Actually estimates the reciprocal of the coefficient of variation.
 4832
 4833    Arguments
 4834    ---------
 4835
 4836    x : an antsImage
 4837
 4838    background_mask : binary antsImage
 4839
 4840    foreground_mask : binary antsImage
 4841
 4842    bias_correct : boolean
 4843
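    Example
    -------
    A minimal sketch (not executed); ``img`` is an ANTs image and the two masks are
    assumed to have been derived elsewhere (e.g. by thresholding):

    >>> import antspymm
    >>> snr = antspymm.mask_snr(img, background_mask, foreground_mask,
    ...     bias_correct=False)
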
 4844    """
 4845    import numpy as np
 4846    if foreground_mask.sum() <= 1 or background_mask.sum() <= 1:
 4847        return 0
 4848    xbc = ants.iMath( x - x.min(), "Normalize" )
 4849    if bias_correct:
 4850        xbc = ants.n3_bias_field_correction( xbc )
 4851    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
 4852    signal = (xbc[ foreground_mask == 1] ).mean()
 4853    noise = (xbc[ background_mask == 1] ).std()
 4854    return signal / noise
 4855
 4856
 4857def dwi_deterministic_tracking(
 4858    dwi,
 4859    fa,
 4860    bvals,
 4861    bvecs,
 4862    num_processes=1,
 4863    mask=None,
 4864    label_image = None,
 4865    seed_labels = None,
 4866    fa_thresh = 0.05,
 4867    seed_density = 1,
 4868    step_size = 0.15,
 4869    peak_indices = None,
 4870    fit_method='WLS',
 4871    verbose = False ):
 4872    """
 4873
 4874    Performs deterministic tractography from the DWI and returns a dictionary
 4875    holding the streamlines and peak indices.
 4876
 4877    Arguments
 4878    ---------
 4879
 4880    dwi : an antsImage holding DWI acquisition
 4881
 4882    fa : an antsImage holding FA values
 4883
 4884    bvals : bvalues
 4885
 4886    bvecs : bvectors
 4887
 4888    num_processes : number of subprocesses
 4889
 4890    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
 4891        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 4892
 4893    label_image : atlas labels
 4894
 4895    seed_labels : list of label numbers from the atlas labels
 4896
 4897    fa_thresh : FA threshold used for the tracking mask and stopping criterion (default 0.05)
 4898
 4899    seed_density : number of seeds per voxel (default 1)
 4900
 4901    step_size : step size used for streamline propagation
 4902
 4903    peak_indices : pass these in, if they are previously estimated.  otherwise, will
 4904        compute on the fly (slow)
 4905
 4906    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
 4907
 4908    verbose : boolean
 4909
 4910    Returns
 4911    -------
 4912    dictionary holding tracts and stateful object.
 4913
 4914    Example
 4915    -------
 4916    >>> import antspymm
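    >>> # hedged sketch, not executed; dwi, fa, bvals and bvecs are assumed to come
    >>> # from a prior reconstruction such as joint_dti_recon
    >>> trk = antspymm.dwi_deterministic_tracking(
    ...     dwi, fa, bvals, bvecs, fa_thresh=0.05, seed_density=1, verbose=True)
    >>> streamlines = trk['streamlines']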
 4917    """
 4918    import os
 4919    import re
 4920    import nibabel as nib
 4921    import numpy as np
 4922    import ants
 4923    from dipy.io.gradients import read_bvals_bvecs
 4924    from dipy.core.gradients import gradient_table
 4925    from dipy.tracking import utils
 4926    import dipy.reconst.dti as dti
 4927    from dipy.segment.clustering import QuickBundles
 4928    from dipy.tracking.utils import path_length
 4929    if verbose:
 4930        print("begin tracking",flush=True)
 4931
 4932    affine = ants_to_nibabel_affine(dwi)
 4933
 4934    if isinstance( bvals, str ) or isinstance( bvecs, str ):
 4935        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
 4936    bvecs = repair_bvecs( bvecs )
 4937    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
 4938    if mask is None:
 4939        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 4940    dwi_data = dwi.numpy()
 4941    dwi_mask = mask.numpy() == 1
 4942    dti_model = dti.TensorModel(gtab,fit_method=fit_method)
 4943    if verbose:
 4944        print("begin tracking fit",flush=True)
 4945    dti_fit = dti_model.fit(dwi_data, mask=dwi_mask)  # This step may take a while
 4946    evecs_img = dti_fit.evecs
 4947    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
 4948    stopping_criterion = ThresholdStoppingCriterion(fa.numpy(), fa_thresh)
 4949    from dipy.data import get_sphere
 4950    sphere = get_sphere(name='symmetric362')
 4951    from dipy.direction import peaks_from_model
 4952    if peak_indices is None:
 4953        # problems with multi-threading ...
 4954        # see https://github.com/dipy/dipy/issues/2519
 4955        if verbose:
 4956            print("begin peaks",flush=True)
 4957        mynump=1
 4958        # if os.getenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"):
 4959        #    mynump = os.environ['ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS']
 4960        # current_openblas = os.environ.get('OPENBLAS_NUM_THREADS', '')
 4961        # current_mkl = os.environ.get('MKL_NUM_THREADS', '')
 4962        # os.environ['DIPY_OPENBLAS_NUM_THREADS'] = current_openblas
 4963        # os.environ['DIPY_MKL_NUM_THREADS'] = current_mkl
 4964        # os.environ['OPENBLAS_NUM_THREADS'] = '1'
 4965        # os.environ['MKL_NUM_THREADS'] = '1'
 4966        peak_indices = peaks_from_model(
 4967            model=dti_model,
 4968            data=dwi_data,
 4969            sphere=sphere,
 4970            relative_peak_threshold=.5,
 4971            min_separation_angle=25,
 4972            mask=dwi_mask,
 4973            npeaks=3, return_odf=False,
 4974            return_sh=False,
 4975            parallel=int(mynump) > 1,
 4976            num_processes=int(mynump)
 4977            )
 4978        if False:
 4979            if 'DIPY_OPENBLAS_NUM_THREADS' in os.environ:
 4980                os.environ['OPENBLAS_NUM_THREADS'] = \
 4981                    os.environ.pop('DIPY_OPENBLAS_NUM_THREADS', '')
 4982                if os.environ['OPENBLAS_NUM_THREADS'] in ['', None]:
 4983                    os.environ.pop('OPENBLAS_NUM_THREADS', '')
 4984            if 'DIPY_MKL_NUM_THREADS' in os.environ:
 4985                os.environ['MKL_NUM_THREADS'] = \
 4986                    os.environ.pop('DIPY_MKL_NUM_THREADS', '')
 4987                if os.environ['MKL_NUM_THREADS'] in ['', None]:
 4988                    os.environ.pop('MKL_NUM_THREADS', '')
 4989
 4990    if label_image is None or seed_labels is None:
 4991        seed_mask = fa.numpy().copy()
 4992        seed_mask[seed_mask >= fa_thresh] = 1
 4993        seed_mask[seed_mask < fa_thresh] = 0
 4994    else:
 4995        labels = label_image.numpy()
 4996        seed_mask = labels * 0
 4997        for u in seed_labels:
 4998            seed_mask[ labels == u ] = 1
 4999    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
 5000    from dipy.tracking.local_tracking import LocalTracking
 5001    from dipy.tracking.streamline import Streamlines
 5002    if verbose:
 5003        print("streamlines begin ...", flush=True)
 5004    streamlines_generator = LocalTracking(
 5005        peak_indices, stopping_criterion, seeds, affine=affine, step_size=step_size)
 5006    streamlines = Streamlines(streamlines_generator)
 5007    from dipy.io.stateful_tractogram import Space, StatefulTractogram
 5008    from dipy.io.streamline import save_tractogram
 5009    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
 5010    if verbose:
 5011        print("streamlines done", flush=True)
 5012    return {
 5013          'tractogram': sft,
 5014          'streamlines': streamlines,
 5015          'peak_indices': peak_indices
 5016          }
 5017
 5018def repair_bvecs( bvecs ):
    # check that each non-zero b-vector is unit norm; if not, warn and normalize.
    # rows with (near) zero norm (b=0 volumes) are left as zero vectors.
 5019    bvecnorm = np.linalg.norm(bvecs,axis=1).reshape( bvecs.shape[0],1 )
 5020    # bvecnormnan = np.isnan( bvecnorm )
 5021    # nan_indices = list( np.unique( np.where(np.isnan(bvecs))[0]))
 5022    # bvecs = remove_elements_from_numpy_array( bvecs, nan_indices )
    nonzero = bvecnorm[:,0] > 1e-16
    if np.any( np.abs( bvecnorm[nonzero,0] - 1.0 ) > 0.009 ):
 5024        warnings.warn( "Warning: bvecs are not unit norm - we normalize them here but this may indicate a problem with the data.  Shape is " + str( bvecs.shape[0] ) + " " + str( bvecs.shape[1] ))
 5025        bvecs=np.where(bvecnorm > 1e-16, bvecs / bvecnorm, 0)
 5026    return bvecs
 5027
 5028
 5029def dwi_closest_peak_tracking(
 5030    dwi,
 5031    fa,
 5032    bvals,
 5033    bvecs,
 5034    num_processes=1,
 5035    mask=None,
 5036    label_image = None,
 5037    seed_labels = None,
 5038    fa_thresh = 0.05,
 5039    seed_density = 1,
 5040    step_size = 0.15,
 5041    peak_indices = None,
 5042    verbose = False ):
 5043    """
 5044
 5045    Performs closest-peak (CSD-based) tractography from the DWI and returns a
 5046    dictionary holding the streamlines.
 5047
 5048    Arguments
 5049    ---------
 5050
 5051    dwi : an antsImage holding DWI acquisition
 5052
 5053    fa : an antsImage holding FA values
 5054
 5055    bvals : bvalues
 5056
 5057    bvecs : bvectors
 5058
 5059    num_processes : number of subprocesses
 5060
 5061    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
 5062        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 5063
 5064    label_image : atlas labels
 5065
 5066    seed_labels : list of label numbers from the atlas labels
 5067
 5068    fa_thresh : FA threshold used for the tracking/seed mask (default 0.05)
 5069
 5070    seed_density : number of seeds per voxel (default 1)
 5071
 5072    step_size : step size used for streamline propagation
 5073
 5074    peak_indices : pass these in, if they are previously estimated.  otherwise, will
 5075        compute on the fly (slow)
 5076
 5077    verbose : boolean
 5078
 5079    Returns
 5080    -------
 5081    dictionary holding tracts and stateful object.
 5082
 5083    Example
 5084    -------
 5085    >>> import antspymm
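    >>> # hedged sketch, not executed; inputs as in dwi_deterministic_tracking
    >>> trk = antspymm.dwi_closest_peak_tracking(dwi, fa, bvals, bvecs, verbose=True)
    >>> streamlines = trk['streamlines']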
 5086    """
 5087    import os
 5088    import re
 5089    import nibabel as nib
 5090    import numpy as np
 5091    import ants
 5092    from dipy.io.gradients import read_bvals_bvecs
 5093    from dipy.core.gradients import gradient_table
 5094    from dipy.tracking import utils
 5095    import dipy.reconst.dti as dti
 5096    from dipy.segment.clustering import QuickBundles
 5097    from dipy.tracking.utils import path_length
 5099    from dipy.data import small_sphere
 5100    from dipy.direction import BootDirectionGetter, ClosestPeakDirectionGetter
 5101    from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,
 5102                                    auto_response_ssst)
 5103    from dipy.reconst.shm import CsaOdfModel
 5104    from dipy.tracking.local_tracking import LocalTracking
 5105    from dipy.tracking.streamline import Streamlines
 5106    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
 5107
 5108    if verbose:
 5109        print("begin tracking",flush=True)
 5110
 5111    affine = ants_to_nibabel_affine(dwi)
 5112    if isinstance( bvals, str ) or isinstance( bvecs, str ):
 5113        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
 5114    bvecs = repair_bvecs( bvecs )
 5115    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
 5116    if mask is None:
 5117        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
 5118    dwi_data = dwi.numpy()
 5119    dwi_mask = mask.numpy() == 1
 5120
 5121
 5122    response, ratio = auto_response_ssst(gtab, dwi_data, roi_radii=10, fa_thr=0.7)
 5123    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
 5124    csd_fit = csd_model.fit(dwi_data, mask=dwi_mask)
 5125    csa_model = CsaOdfModel(gtab, sh_order=6)
 5126    gfa = csa_model.fit(dwi_data, mask=dwi_mask).gfa
 5127    stopping_criterion = ThresholdStoppingCriterion(gfa, .25)
 5128
 5129
 5130    if label_image is None or seed_labels is None:
 5131        seed_mask = fa.numpy().copy()
 5132        seed_mask[seed_mask >= fa_thresh] = 1
 5133        seed_mask[seed_mask < fa_thresh] = 0
 5134    else:
 5135        labels = label_image.numpy()
 5136        seed_mask = labels * 0
 5137        for u in seed_labels:
 5138            seed_mask[ labels == u ] = 1
 5139    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
 5140    if verbose:
 5141        print("streamlines begin ...", flush=True)
 5142
 5143    pmf = csd_fit.odf(small_sphere).clip(min=0)
 5144    if verbose:
 5145        print("ClosestPeakDirectionGetter begin ...", flush=True)
 5146    peak_dg = ClosestPeakDirectionGetter.from_pmf(pmf, max_angle=30.,
 5147                                                sphere=small_sphere)
 5148    if verbose:
 5149        print("local tracking begin ...", flush=True)
 5150    streamlines_generator = LocalTracking(peak_dg, stopping_criterion, seeds,
 5151                                            affine, step_size=step_size)
 5152    streamlines = Streamlines(streamlines_generator)
 5153    from dipy.io.stateful_tractogram import Space, StatefulTractogram
 5154    from dipy.io.streamline import save_tractogram
 5155    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
 5156    if verbose:
 5157        print("streamlines done", flush=True)
 5158    return {
 5159          'tractogram': sft,
 5160          'streamlines': streamlines
 5161          }
 5162
 5163def dwi_streamline_pairwise_connectivity( streamlines, label_image, labels_to_connect=[1,None], verbose=False ):
 5164    """
 5165
 5166    Return streamlines connecting all of the regions in the label set. Ideal
 5167    for just 2 regions.
 5168
 5169    Arguments
 5170    ---------
 5171
 5172    streamlines : streamline object from dipy
 5173
 5174    label_image : atlas labels
 5175
 5176    labels_to_connect : list of 2 labels or [label,None]
 5177
 5178    verbose : boolean
 5179
 5180    Returns
 5181    -------
 5182    the subset of streamlines and a streamline count
 5183
 5184    Example
 5185    -------
 5186    >>> import antspymm
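    >>> # hedged sketch, not executed; streamlines come from a tracking call and
    >>> # label_img is a label image in the same physical space
    >>> sub = antspymm.dwi_streamline_pairwise_connectivity(
    ...     streamlines, label_img, labels_to_connect=[1, 2])
    >>> sub['count']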
 5187    """
 5188    from dipy.tracking.streamline import Streamlines
    from dipy.tracking import utils
 5189    keep_streamlines = Streamlines()
 5190
 5191    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5192
 5193    lin_T, offset = utils._mapping_to_voxel(affine)
 5194    label_image_np = label_image.numpy()
 5195    def check_it( sl, target_label, label_image, index, full=False ):
 5196        if full:
 5197            maxind=sl.shape[0]
 5198            for index in range(maxind):
 5199                pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
 5200                mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
 5201                if mylab == target_label[0] or mylab == target_label[1]:
 5202                    return { 'ok': True, 'label':mylab }
 5203        else:
 5204            pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
 5205            mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
 5206            if mylab == target_label[0] or mylab == target_label[1]:
 5207                return { 'ok': True, 'label':mylab }
 5208        return { 'ok': False, 'label':None }
 5209    ct=0
 5210    for k in range( len( streamlines ) ):
 5211        sl = streamlines[k]
 5212        mycheck = check_it( sl, labels_to_connect, label_image_np, index=0, full=True )
 5213        if mycheck['ok']:
 5214            otherind=1
 5215            if mycheck['label'] == labels_to_connect[1]:
 5216                otherind=0
 5217            lsl = len( sl )-1
 5218            pt = utils._to_voxel_coordinates(sl[lsl,:], lin_T, offset)
 5219            mylab_end = (label_image_np[ pt[0], pt[1], pt[2] ]).astype(int)
 5220            accept_point = mylab_end == labels_to_connect[otherind]
 5221            if verbose and accept_point:
 5222                print( mylab_end )
 5223            if labels_to_connect[1] is None:
 5224                accept_point = mylab_end != 0
 5225            if accept_point:
 5226                keep_streamlines.append(sl)
 5227                ct=ct+1
 5228    return { 'streamlines': keep_streamlines, 'count': ct }
 5229
 5230def dwi_streamline_pairwise_connectivity_old(
 5231    streamlines,
 5232    label_image,
 5233    exclusion_label = None,
 5234    verbose = False ):
 5235    import os
 5236    import re
 5237    import nibabel as nib
 5238    import numpy as np
 5239    import ants
 5240    from dipy.io.gradients import read_bvals_bvecs
 5241    from dipy.core.gradients import gradient_table
 5242    from dipy.tracking import utils
 5243    import dipy.reconst.dti as dti
 5244    from dipy.segment.clustering import QuickBundles
 5245    from dipy.tracking.utils import path_length
 5246    from dipy.tracking.local_tracking import LocalTracking
 5247    from dipy.tracking.streamline import Streamlines
 5248    volUnit = np.prod( ants.get_spacing( label_image ) )
 5249    labels = label_image.numpy()
 5250
 5251    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5252
 5253    import numpy as np
 5254    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
 5255    import pandas as pd
 5256    ulabs = np.unique( labels[ labels > 0 ] )
 5257    if exclusion_label is not None:
 5258        ulabs = ulabs[ ulabs != exclusion_label ]
 5259        exc_slice = labels == exclusion_label
 5260    if verbose:
 5261        print("Begin connectivity")
 5262    tracts = []
 5263    for k in range(len(ulabs)):
 5264        cc_slice = labels == ulabs[k]
 5265        cc_streamlines = utils.target(streamlines, affine, cc_slice)
 5266        cc_streamlines = Streamlines(cc_streamlines)
 5267        if exclusion_label is not None:
 5268            cc_streamlines = utils.target(cc_streamlines, affine, exc_slice, include=False)
 5269            cc_streamlines = Streamlines(cc_streamlines)
 5270        for j in range(len(ulabs)):
 5271            cc_slice2 = labels == ulabs[j]
 5272            cc_streamlines2 = utils.target(cc_streamlines, affine, cc_slice2)
 5273            cc_streamlines2 = Streamlines(cc_streamlines2)
 5274            if exclusion_label is not None:
 5275                cc_streamlines2 = utils.target(cc_streamlines2, affine, exc_slice, include=False)
 5276                cc_streamlines2 = Streamlines(cc_streamlines2)
 5277            tracts.append( cc_streamlines2 )
 5278        if verbose:
 5279            print("end connectivity")
 5280    return {
 5281          'pairwise_tracts': tracts
 5282          }
 5283
 5284
 5285def dwi_streamline_connectivity(
 5286    streamlines,
 5287    label_image,
 5288    label_dataframe,
 5289    verbose = False ):
 5290    """
 5291
 5292    Summarize network connectivity of the input streamlines between all of the
 5293    regions in the label set.
 5294
 5295    Arguments
 5296    ---------
 5297
 5298    streamlines : streamline object from dipy
 5299
 5300    label_image : atlas labels
 5301
 5302    label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)
 5303
 5304    verbose : boolean
 5305
 5306    Returns
 5307    -------
 5308    dictionary holding summary connection statistics in wide format and matrix format.
 5309
 5310    Example
 5311    -------
 5312    >>> import antspymm
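    >>> # hedged sketch, not executed; streamlines come from a tracking call,
    >>> # label_img is an atlas label image and label_df an antspy-style
    >>> # Label/Description dataframe
    >>> cnx = antspymm.dwi_streamline_connectivity(streamlines, label_img, label_df)
    >>> wide = cnx['connectivity_wide']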
 5313    """
 5314    import os
 5315    import re
 5316    import nibabel as nib
 5317    import numpy as np
 5318    import ants
 5319    from dipy.io.gradients import read_bvals_bvecs
 5320    from dipy.core.gradients import gradient_table
 5321    from dipy.tracking import utils
 5322    import dipy.reconst.dti as dti
 5323    from dipy.segment.clustering import QuickBundles
 5324    from dipy.tracking.utils import path_length
 5325    from dipy.tracking.local_tracking import LocalTracking
 5326    from dipy.tracking.streamline import Streamlines
 5340    volUnit = np.prod( ants.get_spacing( label_image ) )
 5341    labels = label_image.numpy()
 5342
 5343    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5344
 5346    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
 5347    import pandas as pd
 5348    ulabs = label_dataframe['Label']
 5349    labels_to_connect = ulabs[ulabs > 0]
 5350    Ctdf = None
 5351    lin_T, offset = utils._mapping_to_voxel(affine)
 5352    label_image_np = label_image.numpy()
 5353    def check_it( sl, target_label, label_image, index, not_label = None ):
 5354        pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
 5355        mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
 5356        if not_label is None:
 5357            if ( mylab == target_label ).sum() > 0 :
 5358                return { 'ok': True, 'label':mylab }
 5359        else:
 5360            if ( mylab == target_label ).sum() > 0 and ( mylab == not_label ).sum() == 0:
 5361                return { 'ok': True, 'label':mylab }
 5362        return { 'ok': False, 'label':None }
 5363    ct=0
 5364    which = lambda lst:list(np.where(lst)[0])
 5365    myCount = np.zeros( [len(ulabs),len(ulabs)])
 5366    for k in range( len( streamlines ) ):
 5367            sl = streamlines[k]
 5368            mycheck = check_it( sl, labels_to_connect, label_image_np, index=0 )
 5369            if mycheck['ok']:
 5370                exclabel=mycheck['label']
 5371                lsl = len( sl )-1
 5372                mycheck2 = check_it( sl, labels_to_connect, label_image_np, index=lsl, not_label=exclabel )
 5373                if mycheck2['ok']:
 5374                    myCount[ulabs == mycheck['label'],ulabs == mycheck2['label']]+=1
 5375                    ct=ct+1
 5376    Ctdf = label_dataframe.copy()
 5377    for k in range(len(ulabs)):
 5378            nn3 = "CnxCount"+str(k).zfill(3)
 5379            Ctdf.insert(Ctdf.shape[1], nn3, myCount[k,:] )
 5380    Ctdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkc': Ctdf },  Ctdf.keys()[2:Ctdf.shape[1]] )
 5381    return { 'connectivity_matrix' :  myCount, 'connectivity_wide' : Ctdfw }
 5382
 5383def dwi_streamline_connectivity_old(
 5384    streamlines,
 5385    label_image,
 5386    label_dataframe,
 5387    verbose = False ):
 5388    """
 5389
 5390    Summarize network connectivity of the input streamlines between all of the
 5391    regions in the label set.
 5392
 5393    Arguments
 5394    ---------
 5395
 5396    streamlines : streamline object from dipy
 5397
 5398    label_image : atlas labels
 5399
 5400    label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)
 5401
 5402    verbose : boolean
 5403
 5404    Returns
 5405    -------
 5406    dictionary holding summary connection statistics in wide format and matrix format.
 5407
 5408    Example
 5409    -------
 5410    >>> import antspymm
 5411    """
 5412
 5413    if verbose:
 5414        print("streamline connections ...")
 5415
 5416    import os
 5417    import re
 5418    import nibabel as nib
 5419    import numpy as np
 5420    import ants
 5421    from dipy.io.gradients import read_bvals_bvecs
 5422    from dipy.core.gradients import gradient_table
 5423    from dipy.tracking import utils
 5424    import dipy.reconst.dti as dti
 5425    from dipy.segment.clustering import QuickBundles
 5426    from dipy.tracking.utils import path_length
 5427    from dipy.tracking.local_tracking import LocalTracking
 5428    from dipy.tracking.streamline import Streamlines
 5429
 5430    volUnit = np.prod( ants.get_spacing( label_image ) )
 5431    labels = label_image.numpy()
 5432
 5433    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
 5434
 5435    if verbose:
 5436        print("path length begin ... volUnit = " + str( volUnit ) )
 5437    import numpy as np
 5438    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
 5439    import pandas as pd
 5440    ulabs = label_dataframe['Label']
 5441    pathLmean = np.zeros( [len(ulabs)])
 5442    pathLtot = np.zeros( [len(ulabs)])
 5443    pathCt = np.zeros( [len(ulabs)])
 5444    for k in range(len(ulabs)):
 5445        cc_slice = labels == ulabs[k]
 5446        cc_streamlines = utils.target(streamlines, affine, cc_slice)
 5447        cc_streamlines = Streamlines(cc_streamlines)
 5448        if len(cc_streamlines) > 0:
 5449            wmpl = path_length(cc_streamlines, affine, cc_slice)
 5450            mean_path_length = wmpl[wmpl>0].mean()
 5451            total_path_length = wmpl[wmpl>0].sum()
 5452            pathLmean[int(k)] = mean_path_length
 5453            pathLtot[int(k)] = total_path_length
 5454            pathCt[int(k)] = len(cc_streamlines) * volUnit
 5455
 5456    # convert paths to data frames
 5457    pathdf = label_dataframe.copy()
 5458    pathdf.insert(pathdf.shape[1], "mean_path_length", pathLmean )
 5459    pathdf.insert(pathdf.shape[1], "total_path_length", pathLtot )
 5460    pathdf.insert(pathdf.shape[1], "streamline_count", pathCt )
    pathdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format(
        { 'path_length' : pathdf }, ['mean_path_length', 'total_path_length', 'streamline_count'] )
 5463    allconnexwide = pathdfw
 5464
 5465    if verbose:
 5466        print("path length done ...")
 5467
 5468    Mdfw = None
 5469    Tdfw = None
 5470    Mdf = None
 5471    Tdf = None
 5472    Ctdf = None
 5473    Ctdfw = None
 5474    if True:
 5475        if verbose:
 5476            print("Begin connectivity")
 5477        M = np.zeros( [len(ulabs),len(ulabs)])
 5478        T = np.zeros( [len(ulabs),len(ulabs)])
 5479        myCount = np.zeros( [len(ulabs),len(ulabs)])
 5480        for k in range(len(ulabs)):
 5481            cc_slice = labels == ulabs[k]
 5482            cc_streamlines = utils.target(streamlines, affine, cc_slice)
 5483            cc_streamlines = Streamlines(cc_streamlines)
 5484            for j in range(len(ulabs)):
 5485                cc_slice2 = labels == ulabs[j]
 5486                cc_streamlines2 = utils.target(cc_streamlines, affine, cc_slice2)
 5487                cc_streamlines2 = Streamlines(cc_streamlines2)
 5488                if len(cc_streamlines2) > 0 :
 5489                    wmpl = path_length(cc_streamlines2, affine, cc_slice2)
 5490                    mean_path_length = wmpl[wmpl>0].mean()
 5491                    total_path_length = wmpl[wmpl>0].sum()
 5492                    M[int(j),int(k)] = mean_path_length
 5493                    T[int(j),int(k)] = total_path_length
 5494                    myCount[int(j),int(k)] = len( cc_streamlines2 ) * volUnit
 5495        if verbose:
 5496            print("end connectivity")
 5497        Mdf = label_dataframe.copy()
 5498        Tdf = label_dataframe.copy()
 5499        Ctdf = label_dataframe.copy()
 5500        for k in range(len(ulabs)):
 5501            nn1 = "CnxMeanPL"+str(k).zfill(3)
 5502            nn2 = "CnxTotPL"+str(k).zfill(3)
 5503            nn3 = "CnxCount"+str(k).zfill(3)
 5504            Mdf.insert(Mdf.shape[1], nn1, M[k,:] )
 5505            Tdf.insert(Tdf.shape[1], nn2, T[k,:] )
 5506            Ctdf.insert(Ctdf.shape[1], nn3, myCount[k,:] )
 5507        Mdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkm' : Mdf },  Mdf.keys()[2:Mdf.shape[1]] )
 5508        Tdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkt' : Tdf },  Tdf.keys()[2:Tdf.shape[1]] )
 5509        Ctdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkc': Ctdf },  Ctdf.keys()[2:Ctdf.shape[1]] )
 5510        allconnexwide = pd.concat( [
 5511            pathdfw,
 5512            Mdfw,
 5513            Tdfw,
 5514            Ctdfw ], axis=1, ignore_index=False )
 5515
 5516    return {
 5517          'connectivity': allconnexwide,
 5518          'connectivity_matrix_mean': Mdf,
 5519          'connectivity_matrix_total': Tdf,
 5520          'connectivity_matrix_count': Ctdf
 5521          }
 5522
 5523
 5524def hierarchical_modality_summary(
 5525    target_image,
 5526    hier,
 5527    transformlist,
 5528    modality_name,
 5529    return_keys = ["Mean","Volume"],
 5530    verbose = False ):
 5531    """
 5532
 5533    Use output of antspyt1w.hierarchical to summarize a modality
 5534
 5535    Arguments
 5536    ---------
 5537
 5538    target_image : the image to summarize - should be brain extracted
 5539
 5540    hier : dictionary holding antspyt1w.hierarchical output
 5541
 5542    transformlist : spatial transformations mapping from T1 to this modality (e.g. from ants.registration)
 5543
 5544    modality_name : adds the modality name to the data frame columns
 5545
    return_keys : list of keys to return; default ["Mean","Volume"]
 5547
 5548    verbose : boolean
 5549
 5550    Returns
 5551    -------
 5552    data frame holding summary statistics in wide format
 5553
 5554    Example
 5555    -------
 5556    >>> import antspymm
 5557    """
 5558    dfout = pd.DataFrame()
 5559    def myhelper( target_image, seg, mytx, mapname, modname, mydf, extra='', verbose=False ):
 5560        if verbose:
 5561            print( mapname )
 5562        target_image_mask = ants.image_clone( target_image ) * 0.0
 5563        target_image_mask[ target_image != 0 ] = 1
 5564        cortmapped = ants.apply_transforms(
 5565            target_image,
 5566            seg,
 5567            mytx, interpolator='nearestNeighbor' ) * target_image_mask
 5568        mapped = antspyt1w.map_intensity_to_dataframe(
 5569            mapname,
 5570            target_image,
 5571            cortmapped )
 5572        mapped.iloc[:,1] = modname + '_' + extra + mapped.iloc[:,1]
 5573        mappedw = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 5574            { 'x' : mapped},
 5575            col_names = return_keys )
 5576        if verbose:
 5577            print( mappedw.keys() )
 5578        if mydf.shape[0] > 0:
 5579            mydf = pd.concat( [ mydf, mappedw], axis=1, ignore_index=False )
 5580        else:
 5581            mydf = mappedw
 5582        return mydf
 5583    if hier['dkt_parc']['dkt_cortex'] is not None:
 5584        dfout = myhelper( target_image, hier['dkt_parc']['dkt_cortex'], transformlist,
 5585            "dkt", modality_name, dfout, extra='', verbose=verbose )
 5586    if hier['deep_cit168lab'] is not None:
 5587        dfout = myhelper( target_image, hier['deep_cit168lab'], transformlist,
 5588            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='deep_', verbose=verbose )
 5589    if hier['cit168lab'] is not None:
 5590        dfout = myhelper( target_image, hier['cit168lab'], transformlist,
 5591            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='', verbose=verbose  )
 5592    if hier['bf'] is not None:
 5593        dfout = myhelper( target_image, hier['bf'], transformlist,
 5594            "nbm3CH13", modality_name, dfout, extra='', verbose=verbose  )
 5595    # if hier['mtl'] is not None:
 5596    #    dfout = myhelper( target_image, hier['mtl'], reg,
 5597    #        "mtl_description", modality_name, dfout, extra='', verbose=verbose  )
 5598    return dfout
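# Usage sketch (hypothetical names; `hier` is antspyt1w.hierarchical output and `rig` a
# registration computed with fixed=modality image, moving=t1, e.g. via ants.registration):
#   flair_summary = hierarchical_modality_summary( flair_brain, hier, rig['fwdtransforms'], 'flair' )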
 5599
 5600def get_rsf_outputs( coords ):
 5601    if coords == 'powers':
 5602        return list([ 'meanBold', 'fmri_template', 'alff', 'falff', 'PerAF', 
 5603                   'CinguloopercularTaskControl', 'DefaultMode', 
 5604                   'MemoryRetrieval', 'VentralAttention', 'Visual',
 5605                   'FrontoparietalTaskControl', 'Salience', 'Subcortical', 'DorsalAttention'])
 5606    else:
 5607        yeo = pd.read_csv( get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".csv")) # yeo 2023 coordinates
 5608        return list( yeo['SystemName'].unique() )
 5609
 5610def tra_initializer( fixed, moving, n_simulations=32, max_rotation=30,
 5611    transform=['rigid'], compreg=None, random_seed=42, verbose=False ):
 5612    """
 5613    multi-start multi-transform registration solution - based on ants.registration
 5614
 5615    fixed: fixed image
 5616
 5617    moving: moving image
 5618
 5619    n_simulations : number of simulations
 5620
 5621    max_rotation : maximum rotation angle
 5622
 5623    transform : list of transforms to loop through
 5624
 5625    compreg : registration results against which to compare
 5626
 5627    random_seed : random seed for reproducibility
 5628
 5629    verbose : boolean
 5630
 5631    """
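    # Usage sketch (hypothetical names): run a multi-start rigid search and keep the
    # candidate with the lowest mutual information against the fixed image:
    #   best = tra_initializer( fixed_img, moving_img, n_simulations=16,
    #                           transform=['rigid'], verbose=True )
    #   warped = best['warpedmovout']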
 5632    import random
 5633    if random_seed is not None:
 5634        random.seed(random_seed)
 5635    if True:
 5636        output_directory = tempfile.mkdtemp()
 5637        output_directory_w = output_directory + "/tra_reg/"
 5638        os.makedirs(output_directory_w,exist_ok=True)
 5639        bestmi = math.inf
 5640        bestvar = 0.0
 5641        myorig = list(ants.get_origin( fixed ))
 5642        mymax = 0;
 5643        for k in range(len( myorig ) ):
 5644            if abs(myorig[k]) > mymax:
 5645                mymax = abs(myorig[k])
 5646        maxtrans = mymax * 0.05
 5647        if compreg is None:
 5648            bestreg=ants.registration( fixed,moving,'Translation',
 5649                outprefix=output_directory_w+"trans")
 5650            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
 5651        else :
 5652            bestreg=compreg
 5653            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
 5654        for mytx in transform:
 5655            regtx = 'antsRegistrationSyNRepro[r]'
 5656            with tempfile.NamedTemporaryFile(suffix='.h5') as tp:
 5657                if mytx == 'translation':
 5658                    regtx = 'Translation'
 5659                    rRotGenerator = ants.contrib.RandomTranslate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
 5660                elif mytx == 'affine':
 5661                    regtx = 'Affine'
 5662                    rRotGenerator = ants.contrib.RandomRotate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
 5663                else:
 5664                    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rotation*(-1.0), max_rotation ), reference=fixed )
 5665                for k in range(n_simulations):
 5666                    simtx = ants.compose_ants_transforms( [rRotGenerator.transform(), initx] )
 5667                    ants.write_transform( simtx, tp.name )
 5668                    if k > 0:
 5669                        reg = ants.registration( fixed, moving, regtx,
 5670                            initial_transform=tp.name,
 5671                            outprefix=output_directory_w+"reg"+str(k),
 5672                            verbose=False )
 5673                    else:
 5674                        reg = ants.registration( fixed, moving,
 5675                            regtx,
 5676                            outprefix=output_directory_w+"reg"+str(k),
 5677                            verbose=False )
 5678                    mymi = math.inf
 5679                    temp = reg['warpedmovout']
 5680                    myvar = temp.numpy().var()
 5681                    if verbose:
 5682                        print( str(k) + " : " + regtx  + " : " + mytx + " _var_ " + str( myvar ) )
 5683                    if myvar > 0 :
 5684                        mymi = ants.image_mutual_information( fixed, temp )
 5685                        if mymi < bestmi:
 5686                            if verbose:
 5687                                print( "mi @ " + str(k) + " : " + str(mymi), flush=True)
 5688                            bestmi = mymi
 5689                            bestreg = reg
 5690                            bestvar = myvar
 5691        if bestvar == 0.0 and compreg is not None:
 5692            return compreg        
 5693        return bestreg
 5694
 5695def neuromelanin( list_nm_images, t1, t1_head, t1lab, brain_stem_dilation=8,
 5696    bias_correct=True,
 5697    denoise=None,
 5698    srmodel=None,
 5699    target_range=[0,1],
 5700    poly_order='hist',
 5701    normalize_nm = False,
 5702    verbose=False ) :
 5703
 5704  """
 5705  Outputs the averaged and registered neuromelanin image, and neuromelanin labels
 5706
 5707  Arguments
 5708  ---------
  list_nm_images : list of ANTsImages
    list of neuromelanin repeat images
 5711
 5712  t1 : ANTsImage
 5713    input 3-D T1 brain image
 5714
 5715  t1_head : ANTsImage
 5716    input 3-D T1 head image
 5717
 5718  t1lab : ANTsImage
 5719    t1 labels that will be propagated to the NM
 5720
 5721  brain_stem_dilation : integer default 8
 5722    dilates the brain stem mask to better match coverage of NM
 5723
 5724  bias_correct : boolean
 5725
 5726  denoise : None or integer
 5727
 5728  srmodel : None -- this is a work in progress feature, probably not optimal
 5729
 5730  target_range : 2-element tuple
 5731        a tuple or array defining the (min, max) of the input image
 5732        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
 5733        intensity. This range should match the mapping used in the training
 5734        of the network.
 5735
 5736  poly_order : if not None, will fit a global regression model to map
 5737      intensity back to original histogram space; if 'hist' will match
 5738      by histogram matching - ants.histogram_match_image
 5739
 5740  normalize_nm : boolean - WIP not validated
 5741
 5742  verbose : boolean
 5743
 5744  Returns
 5745  ---------
 5746  Averaged and registered neuromelanin image and neuromelanin labels and wide csv
 5747
 5748  """
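  # Usage sketch (hypothetical names; the T1 pieces typically come from antspyt1w.hierarchical):
  #   nmres = neuromelanin( [nm1, nm2], t1_brain, t1_head, hier['cit168lab'], verbose=True )
  #   nmres['NM_dataframe_wide'].to_csv( 'NM_wide.csv' )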
 5749
 5750  fnt=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz" )
 5751  fntNM=os.path.expanduser("~/.antspymm/CIT168_T1w_700um_pad_adni_NM_norm_avg.nii.gz" )
 5752  fntbst=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni_brainstem.nii.gz")
 5753  fnslab=os.path.expanduser("~/.antspyt1w/CIT168_MT_Slab_adni.nii.gz")
 5754  fntseg=os.path.expanduser("~/.antspyt1w/det_atlas_25_pad_LR_adni.nii.gz")
 5755
 5756  template = mm_read( fnt )
 5757  templateNM = ants.iMath( mm_read( fntNM ), "Normalize" )
 5758  templatebstem = mm_read( fntbst ).threshold_image( 1, 1000 )
 5760  reg = ants.registration( t1, template, 'antsRegistrationSyNQuickRepro[s]' )
 5761  # map NM avg to t1 for neuromelanin processing
 5762  nmavg2t1 = ants.apply_transforms( t1, templateNM,
 5763    reg['fwdtransforms'], interpolator='linear' )
 5764  slab2t1 = ants.threshold_image( nmavg2t1, "Otsu", 2 ).threshold_image(1,2).iMath("MD",1).iMath("FillHoles")
 5765  # map brain stem and slab to t1 for neuromelanin processing
 5766  bstem2t1 = ants.apply_transforms( t1, templatebstem,
 5767    reg['fwdtransforms'],
 5768    interpolator='nearestNeighbor' ).iMath("MD",1)
 5769  slab2t1B = ants.apply_transforms( t1, mm_read( fnslab ),
 5770    reg['fwdtransforms'], interpolator = 'nearestNeighbor')
 5771  bstem2t1 = ants.crop_image( bstem2t1, slab2t1 )
 5772  cropper = ants.decrop_image( bstem2t1, slab2t1 ).iMath("MD",brain_stem_dilation)
 5773
 5774  # Average images in image_list
 5775  nm_avg = list_nm_images[0]*0.0
 5776  for k in range(len( list_nm_images )):
 5777    if denoise is not None:
 5778        list_nm_images[k] = ants.denoise_image( list_nm_images[k],
 5779            shrink_factor=1,
 5780            p=denoise,
 5781            r=denoise+1,
 5782            noise_model='Gaussian' )
 5783    if bias_correct :
 5784        n4mask = ants.threshold_image( ants.iMath(list_nm_images[k], "Normalize" ), 0.05, 1 )
 5785        list_nm_images[k] = ants.n4_bias_field_correction( list_nm_images[k], mask=n4mask )
 5786    nm_avg = nm_avg + ants.resample_image_to_target( list_nm_images[k], nm_avg ) / len( list_nm_images )
 5787
 5788  if verbose:
 5789      print("Register each nm image in list_nm_images to the averaged nm image (avg)")
 5790  nm_avg_new = nm_avg * 0.0
 5791  txlist = []
 5792  for k in range(len( list_nm_images )):
 5793    if verbose:
 5794        print(str(k) + " of " + str(len( list_nm_images ) ) )
 5795    current_image = ants.registration( list_nm_images[k], nm_avg,
 5796        type_of_transform = 'antsRegistrationSyNRepro[r]' )
 5797    txlist.append( current_image['fwdtransforms'][0] )
 5798    current_image = current_image['warpedfixout']
 5799    nm_avg_new = nm_avg_new + current_image / len( list_nm_images )
 5800  nm_avg = nm_avg_new
 5801
 5802  if verbose:
 5803      print("do slab registration to map anatomy to NM space")
 5804  t1c = ants.crop_image( t1_head, slab2t1 ).iMath("Normalize") # old way
 5805  nmavg2t1c = ants.crop_image( nmavg2t1, slab2t1 ).iMath("Normalize")
 5806  # slabreg = ants.registration( nm_avg, nmavg2t1c, 'antsRegistrationSyNRepro[r]' )
 5807  slabreg = tra_initializer( nm_avg, t1c, verbose=verbose )
 5808  if False:
 5809      slabregT1 = tra_initializer( nm_avg, t1c, verbose=verbose  )
      miNM = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
            ants.iMath(slabreg['warpedmovout'],"Normalize") )
      miT1 = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
            ants.iMath(slabregT1['warpedmovout'],"Normalize") )
 5814      if miT1 < miNM:
 5815        slabreg = slabregT1
 5816  labels2nm = ants.apply_transforms( nm_avg, t1lab, slabreg['fwdtransforms'],
 5817    interpolator = 'genericLabel' )
 5818  cropper2nm = ants.apply_transforms( nm_avg, cropper, slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
 5819  nm_avg_cropped = ants.crop_image( nm_avg, cropper2nm )
 5820
 5821  if verbose:
 5822      print("now map these labels to each individual nm")
 5823  crop_mask_list = []
 5824  crop_nm_list = []
 5825  for k in range(len( list_nm_images )):
 5826      concattx = []
 5827      concattx.append( txlist[k] )
 5828      concattx.append( slabreg['fwdtransforms'][0] )
 5829      cropmask = ants.apply_transforms( list_nm_images[k], cropper,
 5830        concattx, interpolator = 'nearestNeighbor' )
 5831      crop_mask_list.append( cropmask )
 5832      temp = ants.crop_image( list_nm_images[k], cropmask )
 5833      crop_nm_list.append( temp )
 5834
 5835  if srmodel is not None:
 5836      if verbose:
 5837          print( " start sr " + str(len( crop_nm_list )) )
 5838      for k in range(len( crop_nm_list )):
 5839          if verbose:
 5840              print( " do sr " + str(k) )
 5841              print( crop_nm_list[k] )
 5842          temp = antspynet.apply_super_resolution_model_to_image(
 5843                crop_nm_list[k], srmodel, target_range=target_range,
 5844                regression_order=None )
 5845          if poly_order is not None:
 5846              bilin = ants.resample_image_to_target( crop_nm_list[k], temp )
 5847              if poly_order == 'hist':
 5848                  temp = ants.histogram_match_image( temp, bilin )
 5849              else:
 5850                  temp = antspynet.regression_match_image( temp, bilin, poly_order = poly_order )
 5851          crop_nm_list[k] = temp
 5852
 5853  nm_avg_cropped = crop_nm_list[0]*0.0
 5854  if verbose:
 5855      print( "cropped average" )
 5856      print( nm_avg_cropped )
 5857  for k in range(len( crop_nm_list )):
 5858      nm_avg_cropped = nm_avg_cropped + ants.apply_transforms( nm_avg_cropped,
 5859        crop_nm_list[k], txlist[k] ) / len( crop_nm_list )
 5860  for loop in range( 3 ):
 5861      nm_avg_cropped_new = nm_avg_cropped * 0.0
 5862      for k in range(len( crop_nm_list )):
 5863            myreg = ants.registration(
 5864                ants.iMath(nm_avg_cropped,"Normalize"),
 5865                ants.iMath(crop_nm_list[k],"Normalize"),
 5866                'antsRegistrationSyNRepro[r]' )
 5867            warpednext = ants.apply_transforms(
 5868                nm_avg_cropped_new,
 5869                crop_nm_list[k],
 5870                myreg['fwdtransforms'] )
 5871            nm_avg_cropped_new = nm_avg_cropped_new + warpednext
 5872      nm_avg_cropped = nm_avg_cropped_new / len( crop_nm_list )
 5873
 5874  slabregUpdated = tra_initializer( nm_avg_cropped, t1c, compreg=slabreg,verbose=verbose  )
 5875  tempOrig = ants.apply_transforms( nm_avg_cropped_new, t1c, slabreg['fwdtransforms'] )
 5876  tempUpdate = ants.apply_transforms( nm_avg_cropped_new, t1c, slabregUpdated['fwdtransforms'] )
 5877  miUpdate = ants.image_mutual_information(
 5878    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempUpdate,"Normalize") )
 5879  miOrig = ants.image_mutual_information(
 5880    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempOrig,"Normalize") )
 5881  if miUpdate < miOrig :
 5882      slabreg = slabregUpdated
 5883
 5884  if normalize_nm:
 5885      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
 5886      nm_avg_cropped = ants.iMath( nm_avg_cropped, "TruncateIntensity",0.05,0.95)
 5887      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
 5888
 5889  labels2nm = ants.apply_transforms( nm_avg_cropped, t1lab,
 5890        slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
 5891
 5892  # fix the reference region - keep top two parts
 5893  def get_biggest_part( x, labeln ):
 5894      temp33 = ants.threshold_image( x, labeln, labeln ).iMath("GetLargestComponent")
 5895      x[ x == labeln] = 0
 5896      x[ temp33 == 1 ] = labeln
 5897
 5898  get_biggest_part( labels2nm, 33 )
 5899  get_biggest_part( labels2nm, 34 )
 5900
 5901  if verbose:
 5902      print( "map summary measurements to wide format" )
 5903  nmdf = antspyt1w.map_intensity_to_dataframe(
 5904          'CIT168_Reinf_Learn_v1_label_descriptions_pad',
 5905          nm_avg_cropped,
 5906          labels2nm)
 5907  if verbose:
 5908      print( "merge to wide format" )
 5909  nmdf_wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 5910              {'NM' : nmdf},
 5911              col_names = ['Mean'] )
 5912
 5913  rr_mask = ants.mask_image( labels2nm, labels2nm, [33,34] , binarize=True )
 5914  sn_mask = ants.mask_image( labels2nm, labels2nm, [7,9,23,25] , binarize=True )
 5915  nmavgsnr = mask_snr( nm_avg_cropped, rr_mask, sn_mask, bias_correct = False )
 5916
 5917  snavg = nm_avg_cropped[ sn_mask == 1].mean()
 5918  rravg = nm_avg_cropped[ rr_mask == 1].mean()
 5919  snstd = nm_avg_cropped[ sn_mask == 1].std()
 5920  rrstd = nm_avg_cropped[ rr_mask == 1].std()
 5921  vol_element = np.prod( ants.get_spacing(sn_mask) )
 5922  snvol = vol_element * sn_mask.sum()
 5923
 5924  # get the mean voxel position of the SN
 5925  if snvol > 0:
 5926      sn_z = ants.transform_physical_point_to_index( sn_mask, ants.get_center_of_mass(sn_mask ))[2]
 5927      sn_z = sn_z/sn_mask.shape[2] # around 0.5 would be nice
 5928  else:
 5929      sn_z = math.nan
 5930
 5931  nm_evr = 0.0
 5932  if cropper2nm.sum() > 0:
 5933    nm_evr = antspyt1w.patch_eigenvalue_ratio( nm_avg, 512, [6,6,6], 
 5934        evdepth = 0.9, mask=cropper2nm )
 5935
 5936  simg = ants.smooth_image( nm_avg_cropped, np.min(ants.get_spacing(nm_avg_cropped)) )
 5937  k = 2.0
 5938  rrthresh = (rravg + k * rrstd)
 5939  nmabovekthresh_mask = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
 5940  snvolabovethresh = vol_element * nmabovekthresh_mask.sum()
 5941  snintmeanabovethresh = float( ( simg * nmabovekthresh_mask ).mean() )
 5942  snintsumabovethresh = float( ( simg * nmabovekthresh_mask ).sum() )
 5943
 5944  k = 3.0
 5945  rrthresh = (rravg + k * rrstd)
 5946  nmabovekthresh_mask3 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
 5947  snvolabovethresh3 = vol_element * nmabovekthresh_mask3.sum()
 5948
 5949  k = 1.0
 5950  rrthresh = (rravg + k * rrstd)
 5951  nmabovekthresh_mask1 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
 5952  snvolabovethresh1 = vol_element * nmabovekthresh_mask1.sum()
 5953  
 5954  if verbose:
 5955    print( "nm vol @2std above rrmean: " + str( snvolabovethresh ) )
 5956    print( "nm intmean @2std above rrmean: " + str( snintmeanabovethresh ) )
 5957    print( "nm intsum @2std above rrmean: " + str( snintsumabovethresh ) )
 5958    print( "nm done" )
 5959
 5960  return convert_np_in_dict( {
 5961      'NM_avg' : nm_avg,
 5962      'NM_avg_cropped' : nm_avg_cropped,
 5963      'NM_labels': labels2nm,
 5964      'NM_cropped': crop_nm_list,
 5965      'NM_midbrainROI': cropper2nm,
 5966      'NM_dataframe': nmdf,
 5967      'NM_dataframe_wide': nmdf_wide,
 5968      't1_to_NM': slabreg['warpedmovout'],
 5969      't1_to_NM_transform' : slabreg['fwdtransforms'],
 5970      'NM_avg_signaltonoise' : nmavgsnr,
 5971      'NM_avg_substantianigra' : snavg,
 5972      'NM_std_substantianigra' : snstd,
 5973      'NM_volume_substantianigra' : snvol,
 5974      'NM_volume_substantianigra_1std' : snvolabovethresh1,
 5975      'NM_volume_substantianigra_2std' : snvolabovethresh,
 5976      'NM_intmean_substantianigra_2std' : snintmeanabovethresh,
 5977      'NM_intsum_substantianigra_2std' : snintsumabovethresh,
 5978      'NM_volume_substantianigra_3std' : snvolabovethresh3,
 5979      'NM_avg_refregion' : rravg,
 5980      'NM_std_refregion' : rrstd,
 5981      'NM_min' : nm_avg_cropped.min(),
 5982      'NM_max' : nm_avg_cropped.max(),
 5983      'NM_mean' : nm_avg_cropped.numpy().mean(),
 5984      'NM_sd' : np.std( nm_avg_cropped.numpy() ),
 5985      'NM_q0pt05' : np.quantile( nm_avg_cropped.numpy(), 0.05 ),
 5986      'NM_q0pt10' : np.quantile( nm_avg_cropped.numpy(), 0.10 ),
 5987      'NM_q0pt90' : np.quantile( nm_avg_cropped.numpy(), 0.90 ),
 5988      'NM_q0pt95' : np.quantile( nm_avg_cropped.numpy(), 0.95 ),
 5989      'NM_substantianigra_z_coordinate' : sn_z,
 5990      'NM_evr' : nm_evr,
 5991      'NM_count': len( list_nm_images )
 5992       } )
 5993
 5994
 5995
 5996def estimate_optimal_pca_components(data, variance_threshold=0.80, plot=False):
 5997    """
 5998    Estimate the optimal number of PCA components to represent the given data.
 5999
 6000    :param data: The data matrix (samples x features).
    :param variance_threshold: Threshold for cumulative explained variance (default 0.80).
 6002    :param plot: If True, plot the cumulative explained variance graph (default False).
 6003    :return: The optimal number of principal components.
 6004    """
 6005    import numpy as np
 6006    from sklearn.decomposition import PCA
 6007    import matplotlib.pyplot as plt
 6008
 6009    # Perform PCA
 6010    pca = PCA()
 6011    pca.fit(data)
 6012
 6013    # Calculate cumulative explained variance
 6014    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
 6015
 6016    # Determine the number of components for desired explained variance
 6017    n_components = np.where(cumulative_variance >= variance_threshold)[0][0] + 1
 6018
 6019    # Optionally plot the explained variance
 6020    if plot:
 6021        plt.figure(figsize=(8, 4))
 6022        plt.plot(cumulative_variance, linewidth=2)
 6023        plt.axhline(y=variance_threshold, color='r', linestyle='--')
 6024        plt.axvline(x=n_components - 1, color='r', linestyle='--')
 6025        plt.xlabel('Number of Components')
 6026        plt.ylabel('Cumulative Explained Variance')
 6027        plt.title('Explained Variance by Number of Principal Components')
 6028        plt.show()
 6029
 6030    return n_components
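# Example sketch (hypothetical data): estimate the component count for a random correlated matrix
#   X = np.random.randn(200, 50) @ np.random.randn(50, 50)
#   k = estimate_optimal_pca_components( X, variance_threshold=0.80 )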
 6031
 6032import numpy as np
 6033
 6034def compute_PerAF_voxel(time_series):
 6035    """
 6036    Compute the Percentage Amplitude Fluctuation (PerAF) for a given time series.
 6037
 6038    10.1371/journal.pone.0227021
 6039
 6040    PerAF = 100/n * sum(|(x_i - m)/m|) 
 6041    where m = 1/n * sum(x_i), x_i is the signal intensity at each time point, 
 6042    and n is the total number of time points.
 6043
 6044    :param time_series: Numpy array of time series data
 6045    :return: Computed PerAF value
 6046    """
 6047    n = len(time_series)
 6048    m = np.mean(time_series)
 6049    perAF = 100 / n * np.sum(np.abs((time_series - m) / m))
 6050    return perAF
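# Worked example of the PerAF formula above (hypothetical numbers):
#   x = [9, 10, 11] -> m = 10 and |x_i - m|/m = [0.1, 0.0, 0.1]
#   PerAF = 100/3 * 0.2 ~= 6.67 percent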
 6051
 6052def calculate_trimmed_mean(data, proportion_to_trim):
 6053    """
 6054    Calculate the trimmed mean for a given data array.
 6055
 6056    :param data: A numpy array of data.
 6057    :param proportion_to_trim: Proportion (0 to 0.5) of data to trim from each end.
 6058    :return: The trimmed mean of the data.
 6059    """
 6060    from scipy import stats
 6061    return stats.trim_mean(data, proportion_to_trim)
 6062
 6063def PerAF( x, mask, globalmean=True ):
 6064    """
 6065    Compute the Percentage Amplitude Fluctuation (PerAF) for a given time series.
 6066
 6067    10.1371/journal.pone.0227021
 6068
 6069    PerAF = 100/n * sum(|(x_i - m)/m|) 
 6070    where m = 1/n * sum(x_i), x_i is the signal intensity at each time point, 
 6071    and n is the total number of time points.
 6072
 6073    :param x: time series antsImage
 6074    :param mask: brain mask
 6075    :param globalmean: boolean if True divide by the globalmean in the brain mask
 6076    :return: Computed PerAF image
 6077    """
 6078    time_series = ants.timeseries_to_matrix( x, mask )
 6079    n = time_series.shape[1]
 6080    vec = np.zeros( n )
 6081    for i in range(n):
 6082        vec[i] = compute_PerAF_voxel( time_series[:,i] )
 6083    outimg = ants.make_image( mask, vec )
 6084    if globalmean:
 6085        outimg = outimg / calculate_trimmed_mean( vec, 0.01 )
 6086    return outimg
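# Usage sketch (hypothetical names): voxel-wise PerAF normalized by the global trimmed mean
#   peraf_img = PerAF( bold_motion_corrected, brain_mask, globalmean=True )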
 6087
 6088
 6089
 6090def resting_state_fmri_networks( fmri, fmri_template, t1, t1segmentation,
 6091    f=[0.03, 0.08],
 6092    FD_threshold=5.0,
 6093    spa = None,
 6094    spt = None,
 6095    nc = 5,
 6096    outlier_threshold=0.250,
 6097    ica_components = 0,
 6098    impute = True,
 6099    censor = True,
 6100    despike = 2.5,
 6101    motion_as_nuisance = True,
 6102    powers = False,
 6103    upsample = 3.0,
 6104    clean_tmp = None,
 6105    paramset='unset',
 6106    verbose=False ):
 6107  """
  Compute resting state network correlation maps based on the Powers nodes or the
  2023 Yeo 500 homotopic parcellation.  This will output a map for each of the major
  network systems.  This function will optionally upsample the data to the resolution
  given by the upsample parameter (in mm) during the registration process if the input
  data is coarser than that resolution.
 6112
 6113  registration - despike - anatomy - smooth - nuisance - bandpass - regress.nuisance - censor - falff - correlations
 6114
 6115  Arguments
 6116  ---------
 6117  fmri : BOLD fmri antsImage
 6118
 6119  fmri_template : reference space for BOLD
 6120
 6121  t1 : ANTsImage
 6122    input 3-D T1 brain image (brain extracted)
 6123
 6124  t1segmentation : ANTsImage
 6125    t1 segmentation - a six tissue segmentation image in T1 space
 6126
  f : band pass limits for frequency filtering; we use high-pass here as per Shirer 2015

  FD_threshold : framewise displacement threshold (in mm); volumes exceeding this value are flagged as high-motion and, if censoring is enabled, removed

  spa : gaussian smoothing for spatial component (physical coordinates)

  spt : gaussian smoothing for temporal component

  nc  : number of components for compcor filtering; if less than 1 we estimate the number on the fly based on explained variance; Shirer 2015 used 10 (5 from csf and 5 from wm)

  outlier_threshold : threshold in (0,1) passed to loop_timeseries_censoring for outlier-based volume censoring; values outside this range disable it
 6134
 6135  ica_components : integer if greater than 0 then include ica components
 6136
  impute : boolean; if True, impute censored timepoints before the f/ALFF and PerAF calculations

  censor : boolean; if True, remove high-motion and outlier volumes before nuisance regression

  despike : if greater than zero, run voxel-wise despiking (in the AFNI 3dDespike sense) with this spike threshold, after motion correction

  motion_as_nuisance : boolean; if True, add the motion parameters and their first temporal derivative as nuisance regressors

  powers : boolean; if True use the Powers nodes, otherwise the 2023 Yeo 500 homotopic nodes (10.1016/j.neuroimage.2023.120010)

  upsample : float; optionally isotropically resample the data to this resolution (in mm) during the registration process if the input data is coarser; if the input spacing is already finer than this value, the data are simply resampled to an isotropic version of their native minimum spacing
 6148
 6149  clean_tmp : will automatically try to clean the tmp directory - not recommended but can be used in distributed computing systems to help prevent failures due to accumulation of tmp files when doing large-scale processing.  if this is set, the float value clean_tmp will be interpreted as the age in hours of files to be cleaned.
 6150
 6151  verbose : boolean
 6152
 6153  Returns
 6154  ---------
 6155  a dictionary containing the derived network maps
 6156
 6157  References
 6158  ---------
 6159
 6160  10.1162/netn_a_00071 "Methods that included global signal regression were the most consistently effective de-noising strategies."
 6161
  10.1016/j.neuroimage.2019.116157 "frontal and default mode networks are most reliable whereas subcortical networks are least reliable"  "the most comprehensive studies of pipeline effects on edge-level reliability have been done by Shirer (2015) and Parkes (2018)" "slice timing correction has minimal impact" "use of low-pass or narrow filter (discarding high frequency information) reduced both reliability and signal-noise separation"
 6163
  10.1016/j.neuroimage.2017.12.073: Our results indicate that (1) simple linear regression of regional fMRI time series against head motion parameters and WM/CSF signals (with or without expansion terms) is not sufficient to remove head motion artefacts; (2) aCompCor pipelines may only be viable in low-motion data; (3) volume censoring performs well at minimising motion-related artefact but a major benefit of this approach derives from the exclusion of high-motion individuals; (4) while not as effective as volume censoring, ICA-AROMA performed well across our benchmarks for relatively low cost in terms of data loss; (5) the addition of global signal regression improved the performance of nearly all pipelines on most benchmarks, but exacerbated the distance-dependence of correlations between motion and functional connectivity; and (6) group comparisons in functional connectivity between healthy controls and schizophrenia patients are highly dependent on preprocessing strategy. We offer some recommendations for best practice and outline simple analyses to facilitate transparent reporting of the degree to which a given set of findings may be affected by motion-related artefact.
 6165
 6166  10.1016/j.dcn.2022.101087 : We found that: 1) the most efficacious pipeline for both noise removal and information recovery included censoring, GSR, bandpass filtering, and head motion parameter (HMP) regression, 2) ICA-AROMA performed similarly to HMP regression and did not obviate the need for censoring, 3) GSR had a minimal impact on connectome fingerprinting but improved ISC, and 4) the strictest censoring approaches reduced motion correlated edges but negatively impacted identifiability.
 6167
 6168  """
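  # Usage sketch (hypothetical names; e.g. fmri_template from get_average_rsf and the T1
  # pieces from antspyt1w.hierarchical):
  #   rsf = resting_state_fmri_networks( bold, bold_template, t1_brain, t1_seg,
  #                                      nc=5, outlier_threshold=0.25, verbose=True )
  #   rsf['DefaultMode']   # correlation map for the default mode network
  #   rsf['corr_wide']     # one-row wide dataframe of network-to-network correlations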
 6169
 6170  import warnings
 6171
 6172  if clean_tmp is not None:
 6173    clean_tmp_directory( age_hours = clean_tmp )
 6174
 6175  if nc > 1:
 6176    nc = int(nc)
 6177  else:
 6178    nc=float(nc)
 6179
  type_of_transform="antsRegistrationSyNQuickRepro[r]" # should probably not change this
 6181  remove_it=True
 6182  output_directory = tempfile.mkdtemp()
 6183  output_directory_w = output_directory + "/ts_t1_reg/"
 6184  os.makedirs(output_directory_w,exist_ok=True)
 6185  ofnt1tx = tempfile.NamedTemporaryFile(delete=False,suffix='t1_deformation',dir=output_directory_w).name
 6186
 6187  import numpy as np
 6189
 6190  if upsample > 0.0:
 6191      spc = ants.get_spacing( fmri )
 6192      minspc = upsample
 6193      if min(spc[0:3]) < minspc:
 6194          minspc = min(spc[0:3])
 6195      newspc = [minspc,minspc,minspc]
 6196      fmri_template = ants.resample_image( fmri_template, newspc, interp_type=0 )
 6197
 6198  def temporal_derivative_same_shape(array):
 6199    """
 6200    Compute the temporal derivative of a 2D numpy array along the 0th axis (time)
 6201    and ensure the output has the same shape as the input.
 6202
 6203    :param array: 2D numpy array with time as the 0th axis.
 6204    :return: 2D numpy array of the temporal derivative with the same shape as input.
 6205    """
 6206    derivative = np.diff(array, axis=0)
 6207    
    # Prepend a row of zeros so the output keeps the input shape:
    # row i then holds the change from timepoint i-1 to timepoint i
 6211    zeros_row = np.zeros((1, array.shape[1]))
 6212    return np.vstack((zeros_row, derivative ))
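    # e.g. [[1.],[2.],[4.]] -> [[0.],[1.],[2.]] : same shape, zeros in the first row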
 6213
 6214  def compute_tSTD(M, quantile, x=0, axis=0):
 6215    stdM = np.std(M, axis=axis)
 6216    # set bad values to x
 6217    stdM[stdM == 0] = x
 6218    stdM[np.isnan(stdM)] = x
 6219    tt = round(quantile * 100)
 6220    threshold_std = np.percentile(stdM, tt)
 6221    return {'tSTD': stdM, 'threshold_std': threshold_std}
 6222
 6223  def get_compcor_matrix(boldImage, mask, quantile):
 6224    """
 6225    Compute the compcor matrix.
 6226
 6227    :param boldImage: The bold image.
 6228    :param mask: The mask to apply, if None, it will be computed.
 6229    :param quantile: Quantile for computing threshold in tSTD.
    :return: The compcor matrix.
 6231    """
 6232    if mask is None:
 6233        temp = ants.slice_image(boldImage, axis=boldImage.dimension - 1, idx=0)
 6234        mask = ants.get_mask(temp)
 6235
 6236    imagematrix = ants.timeseries_to_matrix(boldImage, mask)
 6237    temp = compute_tSTD(imagematrix, quantile, 0)
 6238    tsnrmask = ants.make_image(mask, temp['tSTD'])
 6239    tsnrmask = ants.threshold_image(tsnrmask, temp['threshold_std'], temp['tSTD'].max())
 6240    M = ants.timeseries_to_matrix(boldImage, tsnrmask)
 6241    return M
 6242
 6243
 6244  from sklearn.decomposition import FastICA
 6245  def find_indices(lst, value):
 6246    return [index for index, element in enumerate(lst) if element > value]
 6247
 6248  def mean_of_list(lst):
 6249    if not lst:  # Check if the list is not empty
 6250        return 0  # Return 0 or appropriate value for an empty list
 6251    return sum(lst) / len(lst)
 6252  fmrispc = list( ants.get_spacing( fmri ) )
 6253  if spa is None:
 6254    spa = mean_of_list( fmrispc[0:3] ) * 1.0
 6255  if spt is None:
 6256    spt = fmrispc[3] * 0.5
 6257      
 6258  import numpy as np
 6259  import pandas as pd
 6260  import re
 6261  import math
 6262  # point data resources
 6263  A = np.zeros((1,1))
 6264  dfnname='DefaultMode'
 6265  if powers:
 6266      powers_areal_mni_itk = pd.read_csv( get_data('powers_mni_itk', target_extension=".csv")) # power coordinates
 6267      coords='powers'
 6268  else:
 6269      powers_areal_mni_itk = pd.read_csv( get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".csv")) # yeo 2023 coordinates
 6270      coords='yeo_17_500_2023'
 6271  fmri = ants.iMath( fmri, 'Normalize' )
 6272  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("FillHoles")
 6273  if verbose:
 6274      print("Begin rsfmri motion correction")
 6275  debug=False
 6276  if debug:
 6277      ants.image_write( fmri_template, '/tmp/fmri_template.nii.gz' )
 6278      ants.image_write( fmri, '/tmp/fmri.nii.gz' )
 6279      print("debug wrote fmri and fmri_template")
 6280  # mot-co
 6281  corrmo = timeseries_reg(
 6282    fmri, fmri_template,
 6283    type_of_transform=type_of_transform,
 6284    total_sigma=0.5,
 6285    fdOffset=2.0,
 6286    trim = 8,
 6287    output_directory=None,
 6288    verbose=verbose,
 6289    syn_metric='CC',
 6290    syn_sampling=2,
 6291    reg_iterations=[40,20,5],
 6292    return_numpy_motion_parameters=True )
 6293  
 6294  if verbose:
 6295      print("End rsfmri motion correction")
 6296      print("--maximum motion : " + str(corrmo['FD'].max()) )
 6297      print("=== next anatomically based mapping ===")
 6298
 6299  despiking_count = np.zeros( corrmo['motion_corrected'].shape[3] )
 6300  if despike > 0.0:
 6301      corrmo['motion_corrected'], despiking_count = despike_time_series_afni( corrmo['motion_corrected'], c1=despike )
 6302
 6303  despiking_count_summary = despiking_count.sum() / np.prod( corrmo['motion_corrected'].shape )
 6304  high_motion_count=(corrmo['FD'] > FD_threshold ).sum()
 6305  high_motion_pct=high_motion_count / fmri.shape[3]
 6306
 6307  # filter mask based on TSNR
 6308  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
 6309  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
 6310  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",2)
 6311  bmask = bmask * tsnrmask
 6312
 6313  # anatomical mapping
 6314  und = fmri_template * bmask
 6315  t1reg = ants.registration( und, t1,
 6316     "antsRegistrationSyNQuickRepro[s]", outprefix=ofnt1tx )
 6317  if verbose:
 6318    print("t1 2 bold done")
 6319  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
 6320  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
 6321  gmseg = ants.threshold_image( gmseg, 1, 4 )
 6322  gmseg = ants.iMath( gmseg, 'MD', 1 ) # FIXMERSF
 6323  gmseg = ants.apply_transforms( und, gmseg,
 6324    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' ) * bmask
 6325  csfAndWM = ( ants.threshold_image( t1segmentation, 1, 1 ) +
 6326               ants.threshold_image( t1segmentation, 3, 3 ) ).morphology("erode",1)
 6327  csfAndWM = ants.apply_transforms( und, csfAndWM,
 6328    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 6329  csf = ants.threshold_image( t1segmentation, 1, 1 )
 6330  csf = ants.apply_transforms( und, csf, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 6331  wm = ants.threshold_image( t1segmentation, 3, 3 ).morphology("erode",1)
 6332  wm = ants.apply_transforms( und, wm, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 6333  if powers:
 6334    ch2 = mm_read( ants.get_ants_data( "ch2" ) )
 6335  else:
 6336    ch2 = mm_read( get_data( "PPMI_template0_brain", target_extension='.nii.gz' ) )
 6337  treg = ants.registration( 
 6338    # this is to make the impact of resolution consistent
 6339    ants.resample_image(t1, [1.0,1.0,1.0], interp_type=0), 
 6340    ch2, "antsRegistrationSyNQuickRepro[s]" )
 6341  if powers:
 6342    concatx2 = treg['invtransforms'] + t1reg['invtransforms']
 6343    pts2bold = ants.apply_transforms_to_points( 3, powers_areal_mni_itk, concatx2,
 6344        whichtoinvert = ( True, False, True, False ) )
 6345    locations = pts2bold.iloc[:,:3].values
 6346    ptImg = ants.make_points_image( locations, bmask, radius = 2 )
 6347  else:
 6348    concatx2 = t1reg['fwdtransforms'] + treg['fwdtransforms']    
 6349    rsfsegfn = get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".nii.gz")
 6350    rsfsegimg = ants.image_read( rsfsegfn )
 6351    ptImg = ants.apply_transforms( und, rsfsegimg, concatx2, interpolator='nearestNeighbor' ) * bmask
 6352    pts2bold = powers_areal_mni_itk
 6353    # ants.plot( und, ptImg, crop=True, axis=2 )
 6354
 6355  # optional smoothing
 6356  tr = ants.get_spacing( corrmo['motion_corrected'] )[3]
 6357  smth = ( spa, spa, spa, spt ) # this is for sigmaInPhysicalCoordinates = TRUE
 6358  simg = ants.smooth_image( corrmo['motion_corrected'], smth, sigma_in_physical_coordinates = True )
 6359
 6360  # collect censoring indices
 6361  hlinds = find_indices( corrmo['FD'], FD_threshold )
 6362  if verbose:
 6363    print("high motion indices")
 6364    print( hlinds )
 6365  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 6366    fmrimotcorr, hlinds2 = loop_timeseries_censoring( corrmo['motion_corrected'], 
 6367      threshold=outlier_threshold, verbose=verbose )
 6368    hlinds.extend( hlinds2 )
 6369    del fmrimotcorr
 6370  hlinds = list(set(hlinds)) # make unique
 6371
 6372  # nuisance
 6373  globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], bmask )
 6374  globalsignal = np.nanmean( globalmat, axis = 1 )
 6375  del globalmat
 6376  compcorquantile=0.50
 6377  nc_wm=nc_csf=nc
 6378  if nc < 1:
 6379    globalmat = get_compcor_matrix( corrmo['motion_corrected'], wm, compcorquantile )
 6380    nc_wm = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
 6381    globalmat = get_compcor_matrix( corrmo['motion_corrected'], csf, compcorquantile )
 6382    nc_csf = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
 6383    del globalmat
 6384  if verbose:
 6385    print("include compcor components as nuisance: csf " + str(nc_csf) + " wm " + str(nc_wm))
 6386  mycompcor_csf = ants.compcor( corrmo['motion_corrected'],
 6387    ncompcor=nc_csf, quantile=compcorquantile, mask = csf,
 6388    filter_type='polynomial', degree=2 )
 6389  mycompcor_wm = ants.compcor( corrmo['motion_corrected'],
 6390    ncompcor=nc_wm, quantile=compcorquantile, mask = wm,
 6391    filter_type='polynomial', degree=2 )
 6392  nuisance = np.c_[ mycompcor_csf[ 'components' ], mycompcor_wm[ 'components' ] ]
 6393
 6394  if motion_as_nuisance:
 6395      if verbose:
 6396          print("include motion as nuisance")
 6397          print( corrmo['motion_parameters'].shape )
 6398      deriv = temporal_derivative_same_shape( corrmo['motion_parameters']  )
 6399      nuisance = np.c_[ nuisance, corrmo['motion_parameters'], deriv ]
 6400
 6401  if ica_components > 0:
 6402    if verbose:
 6403        print("include ica components as nuisance: " + str(ica_components))
 6404    ica = FastICA(n_components=ica_components, max_iter=10000, tol=0.001, random_state=42 )
 6405    globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], csfAndWM )
 6406    nuisance_ica = ica.fit_transform(globalmat)  # Reconstruct signals
 6407    nuisance = np.c_[ nuisance, nuisance_ica ]
 6408    del globalmat
 6409
 6410  # concat all nuisance data
 6411  # nuisance = np.c_[ nuisance, mycompcor['basis'] ]
 6412  # nuisance = np.c_[ nuisance, corrmo['FD'] ]
 6413  nuisance = np.c_[ nuisance, globalsignal ]
 6414
 6415  if impute:
 6416    simgimp = impute_timeseries( simg, hlinds, method='linear')
 6417  else:
 6418    simgimp = simg
 6419
 6420  # falff/alff stuff  def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
 6421  myfalff=alff_image( simgimp, bmask, flo=f[0], fhi=f[1], nuisance=nuisance  )
 6422
 6423  # bandpass any data collected before here -- if bandpass requested
 6424  if f[0] > 0 and f[1] < 1.0:
 6425    if verbose:
 6426        print( "bandpass: " + str(f[0]) + " <=> " + str( f[1] ) )
 6427    nuisance = ants.bandpass_filter_matrix( nuisance, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
 6428    globalmat = ants.timeseries_to_matrix( simg, bmask )
 6429    globalmat = ants.bandpass_filter_matrix( globalmat, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
 6430    simg = ants.matrix_to_timeseries( simg, globalmat, bmask )
 6431
 6432  if verbose:
 6433    print("now regress nuisance")
 6434
 6435
 6436  if len( hlinds ) > 0 :
 6437    if censor:
 6438        nuisance = remove_elements_from_numpy_array( nuisance, hlinds  )
 6439        simg = remove_volumes_from_timeseries( simg, hlinds )
 6440
 6441  gmmat = ants.timeseries_to_matrix( simg, bmask )
 6442  gmmat = ants.regress_components( gmmat, nuisance )
 6443  simg = ants.matrix_to_timeseries(simg, gmmat, bmask)
 6444
 6445
 6446  # structure the output data
 6447  outdict = {}
 6448  outdict['paramset'] = paramset
 6449  outdict['upsampling'] = upsample
 6450  outdict['coords'] = coords
 6451  outdict['dfnname']=dfnname
 6452  outdict['meanBold'] = und
 6453
 6454  # add correlation matrix that captures each node pair
 6455  # some of the spheres overlap so extract separately from each ROI
 6456  if powers:
 6457    nPoints = int(pts2bold['ROI'].max())
 6458    pointrange = list(range(int(nPoints)))
 6459  else:
 6460    nPoints = int(ptImg.max())
 6461    pointrange = list(range(int(nPoints)))
 6462  nVolumes = simg.shape[3]
 6463  meanROI = np.zeros([nVolumes, nPoints])
 6464  roiNames = []
 6465  if debug:
 6466      ptImgAll = und * 0.
 6467  for i in pointrange:
    # build a name for each matrix entry that links back to the ROI number and network; e.g., ROI1_Uncertain
 6469    netLabel = re.sub( " ", "", pts2bold.loc[i,'SystemName'])
 6470    netLabel = re.sub( "-", "", netLabel )
 6471    netLabel = re.sub( "/", "", netLabel )
 6472    roiLabel = "ROI" + str(pts2bold.loc[i,'ROI']) + '_' + netLabel
 6473    roiNames.append( roiLabel )
 6474    if powers:
 6475        ptImage = ants.make_points_image(pts2bold.iloc[[i],:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
 6476    else:
 6477        #print("Doing " + pts2bold.loc[i,'SystemName'] + " at " + str(i) )
 6478        #ptImage = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==pts2bold.loc[i,'SystemName']],binarize=True)
 6479        ptImage=ants.threshold_image( ptImg, pts2bold.loc[i,'ROI'], pts2bold.loc[i,'ROI'] )
 6480    if debug:
 6481      ptImgAll = ptImgAll + ptImage
 6482    if ptImage.sum() > 0 :
 6483        meanROI[:,i] = ants.timeseries_to_matrix( simg, ptImage).mean(axis=1)
 6484
 6485  if debug:
 6486      ants.image_write( simg, '/tmp/simg.nii.gz' )
 6487      ants.image_write( ptImgAll, '/tmp/ptImgAll.nii.gz' )
 6488      ants.image_write( und, '/tmp/und.nii.gz' )
 6490
 6491  # get full correlation matrix
 6492  corMat = np.corrcoef(meanROI, rowvar=False)
 6493  outputMat = pd.DataFrame(corMat)
 6494  outputMat.columns = roiNames
 6495  outputMat['ROIs'] = roiNames
 6496  # add to dictionary
 6497  outdict['fullCorrMat'] = outputMat
 6498
 6499  networks = powers_areal_mni_itk['SystemName'].unique()
 6500  # this is just for human readability - reminds us of which we choose by default
 6501  if powers:
 6502    netnames = ['Cingulo-opercular Task Control', 'Default Mode',
 6503                    'Memory Retrieval', 'Ventral Attention', 'Visual',
 6504                    'Fronto-parietal Task Control', 'Salience', 'Subcortical',
 6505                    'Dorsal Attention']
 6506    numofnets = [3,5,6,7,8,9,10,11,13]
 6507  else:
 6508    netnames = networks
 6509    numofnets = list(range(len(netnames)))
 6510 
 6511  ct = 0
 6512  for mynet in numofnets:
 6513    netname = re.sub( " ", "", networks[mynet] )
 6514    netname = re.sub( "-", "", netname )
 6515    ww = np.where( powers_areal_mni_itk['SystemName'] == networks[mynet] )[0]
 6516    if powers:
 6517        dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
 6518    else:
 6519        dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[mynet]],binarize=True)
 6520    if dfnImg.max() >= 1:
 6521        if verbose:
 6522            print("DO: " + coords + " " + netname )
 6523        dfnmat = ants.timeseries_to_matrix( simg, ants.threshold_image( dfnImg, 1, dfnImg.max() ) )
 6524        dfnsignal = np.nanmean( dfnmat, axis = 1 )
 6525        nan_count_dfn = np.count_nonzero( np.isnan( dfnsignal) )
 6526        if nan_count_dfn > 0 :
            warnings.warn( "network " + netname + " mean-signal has nans " + str( nan_count_dfn ) )
 6528        gmmatDFNCorr = np.zeros( gmmat.shape[1] )
 6529        if nan_count_dfn == 0:
 6530            for k in range( gmmat.shape[1] ):
 6531                nan_count_gm = np.count_nonzero( np.isnan( gmmat[:,k]) )
 6532                if debug and False:
 6533                    print( str( k ) +  " nans gm " + str(nan_count_gm)  )
 6534                if nan_count_gm == 0:
 6535                    gmmatDFNCorr[ k ] = pearsonr( dfnsignal, gmmat[:,k] )[0]
 6536        corrImg = ants.make_image( bmask, gmmatDFNCorr  )
 6537        outdict[ netname ] = corrImg * gmseg
 6538    else:
 6539        outdict[ netname ] = None
 6540    ct = ct + 1
 6541
 6542  A = np.zeros( ( len( numofnets ) , len( numofnets ) ) )
 6543  A_wide = np.zeros( ( 1, len( numofnets ) * len( numofnets ) ) )
 6544  newnames=[]
 6545  newnames_wide=[]
 6546  ct = 0
 6547  for i in range( len( numofnets ) ):
 6548      netnamei = re.sub( " ", "", networks[numofnets[i]] )
 6549      netnamei = re.sub( "-", "", netnamei )
 6550      newnames.append( netnamei  )
 6551      ww = np.where( powers_areal_mni_itk['SystemName'] == networks[numofnets[i]] )[0]
 6552      if powers:
 6553          dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
 6554      else:
 6555          dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[numofnets[i]]],binarize=True)
 6556      for j in range( len( numofnets ) ):
 6557          netnamej = re.sub( " ", "", networks[numofnets[j]] )
 6558          netnamej = re.sub( "-", "", netnamej )
 6559          newnames_wide.append( netnamei + "_2_" + netnamej )
 6560          A[i,j] = 0
 6561          if dfnImg is not None and netnamej is not None:
 6562            subbit = dfnImg == 1
 6563            if subbit is not None:
 6564                if subbit.sum() > 0 and netnamej in outdict:
 6565                    A[i,j] = outdict[ netnamej ][ subbit ].mean()
 6566          A_wide[0,ct] = A[i,j]
 6567          ct=ct+1
 6568
 6569  A = pd.DataFrame( A )
 6570  A.columns = newnames
 6571  A['networks']=newnames
 6572  A_wide = pd.DataFrame( A_wide )
 6573  A_wide.columns = newnames_wide
 6574  outdict['corr'] = A
 6575  outdict['corr_wide'] = A_wide
 6576  outdict['fmri_template'] = fmri_template
 6577  outdict['brainmask'] = bmask
 6578  outdict['gmmask'] = gmseg
 6579  outdict['alff'] = myfalff['alff']
 6580  outdict['falff'] = myfalff['falff']
 6581  # add global mean and standard deviation for post-hoc z-scoring
 6582  outdict['alff_mean'] = (myfalff['alff'][myfalff['alff']!=0]).mean()
 6583  outdict['alff_sd'] = (myfalff['alff'][myfalff['alff']!=0]).std()
 6584  outdict['falff_mean'] = (myfalff['falff'][myfalff['falff']!=0]).mean()
 6585  outdict['falff_sd'] = (myfalff['falff'][myfalff['falff']!=0]).std()
 6586
 6587  perafimg = PerAF( simgimp, bmask )
 6588  for k in pointrange:
 6589    anatname=( pts2bold['AAL'][k] )
 6590    if isinstance(anatname, str):
 6591        anatname = re.sub("_","",anatname)
 6592    else:
 6593        anatname='Unk'
 6594    if powers:
 6595        kk = f"{k:0>3}"+"_"
 6596    else:
 6597        kk = f"{k % int(nPoints/2):0>3}"+"_"
 6598    fname='falffPoint'+kk+anatname
 6599    aname='alffPoint'+kk+anatname
 6600    pname='perafPoint'+kk+anatname
 6601    localsel = ptImg == k
 6602    if localsel.sum() > 0 : # check if non-empty
 6603        outdict[fname]=(outdict['falff'][localsel]).mean()
 6604        outdict[aname]=(outdict['alff'][localsel]).mean()
 6605        outdict[pname]=(perafimg[localsel]).mean()
 6606    else:
 6607        outdict[fname]=math.nan
 6608        outdict[aname]=math.nan
 6609        outdict[pname]=math.nan
 6610
 6611  rsfNuisance = pd.DataFrame( nuisance )
 6612  if remove_it:
 6613    import shutil
 6614    shutil.rmtree(output_directory, ignore_errors=True )
 6615
 6616  if not powers:
 6617    dfnsum=outdict['DefaultA']+outdict['DefaultB']+outdict['DefaultC']
 6618    outdict['DefaultMode']=dfnsum
 6619    dfnsum=outdict['VisCent']+outdict['VisPeri']
 6620    outdict['Visual']=dfnsum
 6621
 6622  nonbrainmask = ants.iMath( bmask, "MD",2) - bmask
 6623  trimmask = ants.iMath( bmask, "ME",2)
 6624  edgemask = ants.iMath( bmask, "ME",1) - trimmask
 6625  outdict['motion_corrected'] = corrmo['motion_corrected']
 6626  outdict['nuisance'] = rsfNuisance
 6627  outdict['PerAF'] = perafimg
 6628  outdict['tsnr'] = mytsnr
 6629  outdict['ssnr'] = slice_snr( corrmo['motion_corrected'], csfAndWM, gmseg )
 6630  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
 6631  outdict['bandpass_freq_0']=f[0]
 6632  outdict['bandpass_freq_1']=f[1]
 6633  outdict['censor']=int(censor)
 6634  outdict['spatial_smoothing']=spa
 6635  outdict['outlier_threshold']=outlier_threshold
 6636  outdict['FD_threshold']=outlier_threshold
 6637  outdict['high_motion_count'] = high_motion_count
 6638  outdict['high_motion_pct'] = high_motion_pct
 6639  outdict['despiking_count_summary'] = despiking_count_summary
 6640  outdict['FD_max'] = corrmo['FD'].max()
 6641  outdict['FD_mean'] = corrmo['FD'].mean()
 6642  outdict['FD_sd'] = corrmo['FD'].std()
 6643  outdict['bold_evr'] =  antspyt1w.patch_eigenvalue_ratio( und, 512, [16,16,16], evdepth = 0.9, mask = bmask )
 6644  outdict['n_outliers'] = len(hlinds)
 6645  outdict['nc_wm'] = int(nc_wm)
 6646  outdict['nc_csf'] = int(nc_csf)
  outdict['minutes_original_data'] = ( tr * fmri.shape[3] ) / 60.0 # minutes of data before censoring
  outdict['minutes_censored_data'] = ( tr * simg.shape[3] ) / 60.0 # minutes of data retained after censoring
 6649  return convert_np_in_dict( outdict )
 6650
 6651
 6652def calculate_CBF(Delta_M, M_0, mask,
 6653                  Lambda=0.9, T_1=0.67, Alpha=0.68, w=1.0, Tau=1.5):
    """
    Calculate cerebral blood flow (CBF) from a perfusion-weighted difference image,
    where Delta_M and M_0 are antsImages and the other variables are scalars.  The
    default scalar values are rough guesses and should be replaced with values that
    match your acquisition.  We use the pCASL equation.  NOT YET TESTED.

    Parameters:
    Delta_M (antsImage): difference in magnetization between control and labeled images
    M_0 (antsImage): equilibrium (M0) magnetization image
    mask (antsImage): binary image defining where to do the calculation
    Lambda (float): brain-blood partition coefficient
    T_1 (float): longitudinal relaxation parameter of blood; note that the exponential
        terms below use exp(-w * T_1), i.e., this value enters as a rate, so the
        default of 0.67 corresponds to a blood T1 of roughly 1.5 seconds
    Alpha (float): labeling efficiency
    w (float): post-labeling delay
    Tau (float): labeling duration

    Returns:
    antsImage: CBF image (zero outside the selected voxels)
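
    Example:
    A minimal usage sketch (hypothetical filenames; in practice the scalar arguments
    should be set to match your acquisition rather than the defaults):

    >>> import ants
    >>> from antspymm.mm import calculate_CBF
    >>> dm = ants.image_read('perfusion_difference.nii.gz')
    >>> m0 = ants.image_read('m0.nii.gz')
    >>> msk = ants.get_mask(m0)
    >>> cbf = calculate_CBF(dm, m0, msk)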
 6671    """
 6672    cbf = M_0 * 0.0
 6673    m0thresh = np.quantile( M_0[mask==1], 0.1 )
    # elementwise intersection of the brain mask and sufficiently bright M0 voxels;
    # the python keyword `and` does not operate voxelwise on images
    sel = ( mask == 1 ) * ( M_0 > m0thresh )
 6675    cbf[ sel ] = Delta_M[ sel ] * 60. * 100. * (Lambda * T_1)/( M_0[sel] * 2.0 * Alpha * 
 6676        (np.exp( -w * T_1) - np.exp(-(Tau + w) * T_1)))
 6677    cbf[ cbf < 0.0]=0.0
 6678    return cbf
 6679
 6680def despike_time_series_afni(image, c1=2.5, c2=4):
 6681    """
    Despike a time series image by fitting a smooth polynomial curve to each voxel time
    series and compressing large residuals with a tanh squashing function.
    Based on AFNI 3dDespike, although the curve here is fit with ordinary least squares
    rather than AFNI's L1 fit.
 6684
 6685    :param image: ANTsPy image object containing time series data.
 6686    :param c1: Spike threshold value. Default is 2.5.
 6687    :param c2: Upper range of allowed deviation. Default is 4.
    :return: Tuple of (despiked ANTsPy image, per-volume spike counts).
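
    Example (a minimal sketch on a small synthetic time series):

    >>> import ants
    >>> import numpy as np
    >>> from antspymm.mm import despike_time_series_afni
    >>> ts = ants.from_numpy(np.random.rand(8, 8, 8, 40).astype('float32'))
    >>> despiked, spike_counts = despike_time_series_afni(ts, c1=2.5, c2=4)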
 6689    """
 6690    data = image.numpy()  # Convert to numpy array
 6691    despiked_data = np.copy(data)  # Create a copy for despiked data
 6692    curve = despiked_data * 0.0
 6693
 6694    def l1_fit_polynomial(time_series, degree=2):
 6695        """
 6696        Fit a polynomial of given degree to the time series using least squares.
 6697        
 6698        :param time_series: 1D numpy array of voxel time series data.
 6699        :param degree: Degree of the polynomial to fit.
 6700        :return: Fitted polynomial values for the time series.
 6701        """
 6702        t = np.arange(len(time_series))
 6703        coefs = np.polyfit(t, time_series, degree)
 6704        polynomial = np.polyval(coefs, t)
 6705        return polynomial
 6706
    # Fit a smooth-ish polynomial curve (least squares) to each voxel time series
 6709    for x in range(data.shape[0]):
 6710        for y in range(data.shape[1]):
 6711            for z in range(data.shape[2]):
 6712                voxel_time_series = data[x, y, z, :]
 6713                curve[x, y, z, :] = l1_fit_polynomial(voxel_time_series, degree=2)
 6714
 6715    # Compute the MAD of the residuals
 6716    residuals = data - curve
 6717    mad = np.median(np.abs(residuals - np.median(residuals, axis=-1, keepdims=True)), axis=-1, keepdims=True)
 6718    sigma = np.sqrt(np.pi / 2) * mad
 6719    # Ensure sigma is not zero to avoid division by zero
 6720    sigma_safe = np.where(sigma == 0, 1e-10, sigma)
 6721
 6722    # Optionally, handle NaN or inf values in data, curve, or sigma
 6723    data = np.nan_to_num(data, nan=0.0, posinf=np.finfo(np.float64).max, neginf=np.finfo(np.float64).min)
 6724    curve = np.nan_to_num(curve, nan=0.0, posinf=np.finfo(np.float64).max, neginf=np.finfo(np.float64).min)
 6725    sigma_safe = np.nan_to_num(sigma_safe, nan=1e-10, posinf=np.finfo(np.float64).max, neginf=np.finfo(np.float64).min)
 6726
 6727    # Despike algorithm
 6728    spike_counts = np.zeros( image.shape[3] )
 6729    for i in range(data.shape[-1]):
 6730        s = (data[..., i] - curve[..., i]) / sigma_safe[..., 0]
 6731        ww = s > c1
 6732        s_prime = np.where( ww, c1 + (c2 - c1) * np.tanh((s - c1) / (c2 - c1)), s)
 6733        spike_counts[i] = ww.sum()
 6734        despiked_data[..., i] = curve[..., i] + s_prime * sigma[..., 0]
 6735
 6736    # Convert back to ANTsPy image
 6737    despiked_image = ants.from_numpy(despiked_data)
 6738    return ants.copy_image_info( image, despiked_image ), spike_counts
 6739
 6740def despike_time_series(image, threshold=3.0, replacement='threshold' ):
 6741    """
 6742    Despike a time series image.
 6743    
 6744    :param image: ANTsPy image object containing time series data.
 6745    :param threshold: z-score value to identify spikes. Default is 3.
    :param replacement: 'median' or 'threshold' - the latter is similar to AFNI 3dDespike but simpler.
    :return: Tuple of (despiked ANTsPy image, per-volume spike counts).
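
    Example (a minimal sketch using the threshold-style replacement):

    >>> import ants
    >>> import numpy as np
    >>> from antspymm.mm import despike_time_series
    >>> ts = ants.from_numpy(np.random.rand(8, 8, 8, 40).astype('float32'))
    >>> despiked, spike_counts = despike_time_series(ts, threshold=3.0, replacement='threshold')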
 6748    """
 6749    # Convert image to numpy array
 6750    data = image.numpy()
 6751    
 6752    # Calculate the mean and standard deviation along the time axis
 6753    mean = np.mean(data, axis=-1)
 6754    std = np.std(data, axis=-1)
 6755
 6756    # Identify spikes: points where the deviation from the mean exceeds the threshold
 6757    spikes = np.abs(data - mean[..., np.newaxis]) > threshold * std[..., np.newaxis]
 6758
 6759    # Replace spike values
 6760    spike_counts = np.zeros( image.shape[3] )
 6761    for i in range(data.shape[-1]):
 6762        slice = data[..., i]
 6763        spike_locations = spikes[..., i]
 6764        spike_counts[i] = spike_locations.sum()
 6765        if replacement == 'median':
 6766            slice[spike_locations] = np.median(slice)  # Replace with median or another method
 6767        else:
            # Calculate threshold values (mean ± threshold * std)
 6769            threshold_values = mean + np.sign(slice - mean) * threshold * std
 6770            slice[spike_locations] = threshold_values[spike_locations]
 6771        data[..., i] = slice
 6772    # Convert back to ANTsPy image
 6773    despike_image = ants.from_numpy(data)
 6774    despike_image = ants.copy_image_info( image, despike_image )
 6775    return despike_image, spike_counts
 6776
 6777
 6778
 6779def bold_perfusion_minimal( 
 6780        fmri, 
 6781        m0_image = None,
 6782        spa = (0., 0., 0., 0.),
 6783        nc  = 0,
 6784        tc='alternating',
 6785        n_to_trim=0,
 6786        outlier_threshold=0.250,
 6787        plot_brain_mask=False,
 6788        verbose=False ):
 6789  """
  Estimate perfusion from a BOLD time series image.  The function will attempt to infer the tag-control (T-C) labels from the data.  Default parameter values are used to quantify CBF, but these will usually not be correct for your own data; see calculate_CBF for how one might run quantification directly from the perfusion, m0 and mask images in the output dictionary.
 6791
 6792  This function is intended for use in debugging/testing or when one lacks a T1w image.
 6793
 6794  Arguments
 6795  ---------
 6796
 6797  fmri : BOLD fmri antsImage
 6798
 6799  m0_image: a pre-defined m0 antsImage
 6800
 6801  spa : gaussian smoothing for spatial and temporal component e.g. (1,1,1,0) in physical space coordinates
 6802
 6803  nc  : number of components for compcor filtering
 6804
 6805  tc: string either alternating or split (default is alternating ie CTCTCT; split is CCCCTTTT)
 6806
  n_to_trim: number of volumes to trim off the front of the time series to account for initial magnetic saturation effects or to allow the signal to reach a steady state.  In some cases, trailing volumes or other outlier volumes may also need to be rejected; this code does not currently handle that issue.
 6808
 6809  outlier_threshold (numeric): between zero (remove all) and one (remove none); automatically calculates outlierness and uses it to censor the time series.
 6810
 6811  plot_brain_mask : boolean can help with checking data quality visually
 6812
 6813  verbose : boolean
 6814
 6815  Returns
 6816  ---------
  a dictionary containing the derived perfusion, CBF and m0 images along with QC summaries
 6818
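  Example
  ---------
  A minimal sketch (hypothetical filename; the argument values shown are illustrative only):

  >>> import ants
  >>> from antspymm.mm import bold_perfusion_minimal
  >>> asl = ants.image_read('asl_timeseries.nii.gz')
  >>> perf = bold_perfusion_minimal(asl, nc=4, outlier_threshold=0.25, verbose=True)
  >>> perfusion_img = perf['perfusion']
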
 6819  """
 6820  import numpy as np
 6821  import pandas as pd
 6822  import re
 6823  import math
 6824  from sklearn.linear_model import RANSACRegressor, TheilSenRegressor, HuberRegressor, QuantileRegressor, LinearRegression, SGDRegressor
 6825  from sklearn.multioutput import MultiOutputRegressor
 6826  from sklearn.preprocessing import StandardScaler
 6827
 6828  def replicate_list(user_list, target_size):
 6829    # Calculate the number of times the list should be replicated
 6830    replication_factor = target_size // len(user_list)
 6831    # Replicate the list and handle any remaining elements
 6832    replicated_list = user_list * replication_factor
 6833    remaining_elements = target_size % len(user_list)
 6834    replicated_list += user_list[:remaining_elements]
 6835    return replicated_list
 6836
 6837  def one_hot_encode(char_list):
 6838    unique_chars = list(set(char_list))
 6839    encoding_dict = {char: [1 if char == c else 0 for c in unique_chars] for char in unique_chars}
 6840    encoded_matrix = np.array([encoding_dict[char] for char in char_list])
 6841    return encoded_matrix
 6842  
 6843  A = np.zeros((1,1))
 6844  fmri_template = ants.get_average_of_timeseries( fmri )
 6845  if n_to_trim is None:
 6846    n_to_trim=0
 6847  mytrim=n_to_trim
 6848  perf_total_sigma = 1.5
 6849  corrmo = timeseries_reg(
 6850    fmri, fmri_template,
 6851    type_of_transform='antsRegistrationSyNRepro[r]',
 6852    total_sigma=perf_total_sigma,
 6853    fdOffset=2.0,
 6854    trim = mytrim,
 6855    output_directory=None,
 6856    verbose=verbose,
 6857    syn_metric='CC',
 6858    syn_sampling=2,
 6859    reg_iterations=[40,20,5] )
 6860  if verbose:
 6861      print("End rsfmri motion correction")
 6862      print("--maximum motion : " + str(corrmo['FD'].max()) )
 6863
 6864  if m0_image is not None:
 6865      m0 = m0_image
 6866
 6867  ntp = corrmo['motion_corrected'].shape[3]
 6868  fmri_template = ants.get_average_of_timeseries( corrmo['motion_corrected'] )
 6869  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
 6870  if plot_brain_mask:
 6871    ants.plot( fmri_template, bmask, axis=1, crop=True )
 6872    ants.plot( fmri_template, bmask, axis=2, crop=True )
 6873
 6874  if tc == 'alternating':
 6875      tclist = replicate_list( ['C','T'], ntp )
 6876  else:
 6877      tclist = replicate_list( ['C'], int(ntp/2) ) + replicate_list( ['T'],  int(ntp/2) )
 6878
 6879  tclist = one_hot_encode( tclist[0:ntp ] )
 6880  fmrimotcorr=corrmo['motion_corrected']
 6881  hlinds = None
 6882  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 6883    fmrimotcorr, hlinds = loop_timeseries_censoring( fmrimotcorr, outlier_threshold, mask=None, verbose=verbose )
 6884    tclist = remove_elements_from_numpy_array( tclist, hlinds)
 6885    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 6886
 6887  # redo template and registration at (potentially) upsampled scale
 6888  fmri_template = ants.iMath( ants.get_average_of_timeseries( fmrimotcorr ), "Normalize" )
 6889  corrmo = timeseries_reg(
 6890        fmri, fmri_template,
 6891        type_of_transform='antsRegistrationSyNRepro[r]',
 6892        total_sigma=perf_total_sigma,
 6893        fdOffset=2.0,
 6894        trim = mytrim,
 6895        output_directory=None,
 6896        verbose=verbose,
 6897        syn_metric='CC',
 6898        syn_sampling=2,
 6899        reg_iterations=[40,20,5] )
 6900  if verbose:
 6901        print("End 2nd rsfmri motion correction")
 6902        print("--maximum motion : " + str(corrmo['FD'].max()) )
 6903
 6904  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 6905    corrmo['motion_corrected'] = remove_volumes_from_timeseries( corrmo['motion_corrected'], hlinds )
 6906    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 6907
 6908  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
 6909  if plot_brain_mask:
 6910    ants.plot( fmri_template, bmask, axis=1, crop=True )
 6911    ants.plot( fmri_template, bmask, axis=2, crop=True )
 6912
 6913  regression_mask = bmask.clone()
 6914  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
 6915  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
 6916  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",3)
 6917  bmask = bmask * ants.iMath( tsnrmask, "FillHoles" )
 6918  fmrimotcorr=corrmo['motion_corrected']
 6919  und = fmri_template * bmask
 6920  compcorquantile=0.50
 6921  mycompcor = ants.compcor( fmrimotcorr,
 6922    ncompcor=nc, quantile=compcorquantile, mask = bmask,
 6923    filter_type='polynomial', degree=2 )
 6924  tr = ants.get_spacing( fmrimotcorr )[3]
 6925  simg = ants.smooth_image(fmrimotcorr, spa, sigma_in_physical_coordinates = True )
 6926  nuisance = mycompcor['basis']
 6927  nuisance = np.c_[ nuisance, mycompcor['components'] ]
 6928  if verbose:
 6929    print("make sure nuisance is independent of TC")
 6930  nuisance = ants.regress_components( nuisance, tclist )
 6931  regression_mask = bmask.clone()
 6932  gmmat = ants.timeseries_to_matrix( simg, regression_mask )
 6933  regvars = np.hstack( (nuisance, tclist ))
 6934  coefind = regvars.shape[1]-1
 6935  regvars = regvars[:,range(coefind)]
 6936  predictor_of_interest_idx = regvars.shape[1]-1
 6937  valid_perf_models = ['huber','quantile','theilsen','ransac', 'sgd', 'linear','SM']
 6938  perfusion_regression_model='linear'
 6939  if verbose:
 6940    print( "begin perfusion estimation with " + perfusion_regression_model + " model " )
 6941  regression_model = LinearRegression()
 6942  regression_model.fit( regvars, gmmat )
 6943  coefind = regression_model.coef_.shape[1]-1
 6944  perfimg = ants.make_image( regression_mask, regression_model.coef_[:,coefind] )
 6945  gmseg = ants.image_clone( bmask )
 6946  meangmval = ( perfimg[ gmseg == 1 ] ).mean()
 6947  if meangmval < 0:
 6948      perfimg = perfimg * (-1.0)
 6949  negative_voxels = ( perfimg < 0.0 ).sum() / np.prod( perfimg.shape ) * 100.0
 6950  perfimg[ perfimg < 0.0 ] = 0.0 # non-physiological
 6951
 6952  if m0_image is None:
 6953    m0 = ants.get_average_of_timeseries( fmrimotcorr )
 6954  else:
 6955    # register m0 to current template
 6956    m0reg = ants.registration( fmri_template, m0, 'antsRegistrationSyNRepro[r]', verbose=False )
 6957    m0 = m0reg['warpedmovout']
 6958
 6959  if ntp == 2 :
 6960      img0 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=0 )
 6961      img1 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=1 )
 6962      if m0_image is None:
 6963        if img0.mean() < img1.mean():
 6964            perfimg=img0
 6965            m0=img1
 6966        else:
 6967            perfimg=img1
 6968            m0=img0
 6969      else:
 6970        if img0.mean() < img1.mean():
 6971            perfimg=img1-img0
 6972        else:
 6973            perfimg=img0-img1
 6974  
 6975  cbf = calculate_CBF( Delta_M=perfimg, M_0=m0, mask=bmask )
 6976  meangmval = ( perfimg[ gmseg == 1 ] ).mean()        
 6977  meangmvalcbf = ( cbf[ gmseg == 1 ] ).mean()
 6978  if verbose:
 6979    print("perfimg.max() " + str(  perfimg.max() ) )
 6980  outdict = {}
 6981  outdict['meanBold'] = und
 6982  outdict['brainmask'] = bmask
 6983  rsfNuisance = pd.DataFrame( nuisance )
 6984  rsfNuisance['FD']=corrmo['FD']
 6985  outdict['perfusion']=perfimg
 6986  outdict['cbf']=cbf
 6987  outdict['m0']=m0
 6988  outdict['perfusion_gm_mean']=meangmval
 6989  outdict['cbf_gm_mean']=meangmvalcbf
 6990  outdict['motion_corrected'] = corrmo['motion_corrected']
 6991  outdict['brain_mask'] = bmask
 6992  outdict['nuisance'] = rsfNuisance
 6993  outdict['tsnr'] = mytsnr
 6994  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
 6995  outdict['FD_max'] = rsfNuisance['FD'].max()
 6996  outdict['FD_mean'] = rsfNuisance['FD'].mean()
 6997  outdict['FD_sd'] = rsfNuisance['FD'].std()
 6998  outdict['outlier_volumes']=hlinds
 6999  outdict['negative_voxels']=negative_voxels
 7000  return convert_np_in_dict( outdict )
 7001
 7002
 7003
 7004def warn_if_small_mask( mask: ants.ANTsImage, threshold_fraction: float = 0.05, label: str = ' ' ):
 7005    """
 7006    Warn the user if the number of non-zero voxels in the mask
 7007    is less than a given fraction of the total number of voxels in the mask.
 7008
 7009    Parameters
 7010    ----------
 7011    mask : ants.ANTsImage
 7012        The binary mask to evaluate.
 7013    threshold_fraction : float, optional
        Fraction threshold below which a warning is triggered (default is 0.05).
    label : str, optional
        Short descriptive label included in the warning message to identify the mask.
 7015    
 7016    Returns
 7017    -------
 7018    None
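
    Example
    -------
    A minimal sketch: a mask covering far less than 5% of the image triggers the warning.

    >>> import ants
    >>> import numpy as np
    >>> from antspymm.mm import warn_if_small_mask
    >>> arr = np.zeros((64, 64, 64), dtype='float32')
    >>> arr[30:32, 30:32, 30:32] = 1
    >>> warn_if_small_mask(ants.from_numpy(arr), label='tiny mask')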
 7019    """
 7020    import warnings
 7021    image_size = np.prod(mask.shape)
 7022    mask_size = np.count_nonzero(mask.numpy())
 7023    if mask_size / image_size < threshold_fraction:
 7024        percentage = 100.0 * mask_size / image_size
 7025        warnings.warn(
 7026            f"[ants] Warning: {label} contains only {mask_size} voxels "
 7027            f"({percentage:.2f}% of image volume). "
 7028            f"This is below the threshold of {threshold_fraction * 100:.2f}% and may lead to unreliable results.",
 7029            UserWarning
 7030        )
 7031
 7032def bold_perfusion( 
 7033    fmri, t1head, t1, t1segmentation, t1dktcit,
 7034                   FD_threshold=0.5,
 7035                   spa = (0., 0., 0., 0.),
 7036                   nc = 3,
 7037                   type_of_transform='antsRegistrationSyNRepro[r]',
 7038                   tc='alternating',
 7039                   n_to_trim=0,
 7040                   m0_image = None,
 7041                   m0_indices=None,
 7042                   outlier_threshold=0.250,
 7043                   add_FD_to_nuisance=False,
 7044                   n3=False,
 7045                   segment_timeseries=False,
 7046                   trim_the_mask=4.25,
 7047                   upsample=True,
 7048                   perfusion_regression_model='linear',
 7049                   verbose=False ):
 7050  """
  Estimate perfusion from a BOLD time series image.  The function will attempt to infer the tag-control (T-C) labels from the data.  Default parameter values are used to quantify CBF, but these will usually not be correct for your own data; see calculate_CBF for how one might run quantification directly from the perfusion, m0 and mask images in the output dictionary.
 7052
 7053  Arguments
 7054  ---------
 7055  fmri : BOLD fmri antsImage
 7056
 7057  t1head : ANTsImage
 7058    input 3-D T1 brain image (not brain extracted)
 7059
 7060  t1 : ANTsImage
 7061    input 3-D T1 brain image (brain extracted)
 7062
 7063  t1segmentation : ANTsImage
 7064    t1 segmentation - a six tissue segmentation image in T1 space
 7065
 7066  t1dktcit : ANTsImage
 7067    t1 dkt cortex plus cit parcellation
 7068
 7069  spa : gaussian smoothing for spatial and temporal component e.g. (1,1,1,0) in physical space coordinates
 7070
 7071  nc  : number of components for compcor filtering
 7072
 7073  type_of_transform : SyN or Rigid
 7074
 7075  tc: string either alternating or split (default is alternating ie CTCTCT; split is CCCCTTTT)
 7076
  n_to_trim: number of volumes to trim off the front of the time series to account for initial magnetic saturation effects or to allow the signal to reach a steady state.  In some cases, trailing volumes or other outlier volumes may also need to be rejected; this code does not currently handle that issue.
 7078
 7079  m0_image: a pre-defined m0 image - we expect this to be 3D.  if it is not, we naively 
 7080    average over the 4th dimension.
 7081
 7082  m0_indices: which indices in the perfusion image are the m0.  if set, n_to_trim will be ignored.
 7083
 7084  outlier_threshold (numeric): between zero (remove all) and one (remove none); automatically calculates outlierness and uses it to censor the time series.
 7085
 7086  add_FD_to_nuisance: boolean
 7087
 7088  n3: boolean
 7089
 7090  segment_timeseries : boolean
 7091
 7092  trim_the_mask : float >= 0 post-hoc method for trimming the mask
 7093
 7094  upsample: boolean
 7095
 7096  perfusion_regression_model: string 'linear', 'ransac', 'theilsen', 'huber', 'quantile', 'sgd'; 'linear' and 'huber' are the only ones that work ok by default and are relatively quick to compute.
 7097
 7098  verbose : boolean
 7099
 7100  Returns
 7101  ---------
  a dictionary containing the derived perfusion and CBF images, the regional summary dataframe and QC measures
 7103
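  Example
  ---------
  A minimal sketch assuming an antspyt1w.hierarchical result is available
  (hypothetical filenames):

  >>> import ants, antspyt1w
  >>> from antspymm.mm import bold_perfusion
  >>> asl = ants.image_read('asl_timeseries.nii.gz')
  >>> t1 = ants.image_read('t1.nii.gz')
  >>> hier = antspyt1w.hierarchical(t1, output_prefix='/tmp/t1hier')
  >>> perf = bold_perfusion(
  ...     asl, t1, hier['brain_n4_dnz'],
  ...     hier['dkt_parc']['tissue_segmentation'],
  ...     hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
  ...     verbose=True)
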
 7104  """
 7105  import numpy as np
 7106  import pandas as pd
 7107  import re
 7108  import math
 7109  from sklearn.linear_model import RANSACRegressor, TheilSenRegressor, HuberRegressor, QuantileRegressor, LinearRegression, SGDRegressor
 7110  from sklearn.multioutput import MultiOutputRegressor
 7111  from sklearn.preprocessing import StandardScaler
 7112
 7113  # remove outlier volumes
 7114  if segment_timeseries:
 7115    lo_vs_high = segment_timeseries_by_meanvalue(fmri)
 7116    fmri = remove_volumes_from_timeseries( fmri, lo_vs_high['lowermeans'] )
 7117
 7118  ex_path = os.path.expanduser( "~/.antspyt1w/" )
 7119  cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
 7120
 7121  if n3:
 7122    fmri = timeseries_n3( fmri )
 7123
 7124  if m0_image is not None:
 7125    if m0_image.dimension == 4:
 7126      m0_image = ants.get_average_of_timeseries( m0_image )
 7127
 7128  def select_regression_model(regression_model, min_samples=10 ):
 7129    if regression_model == 'sgd' :
 7130      sgd_regressor = SGDRegressor(penalty='elasticnet', alpha=1e-5, l1_ratio=0.15, max_iter=20000, tol=1e-3, random_state=42)
 7131      return sgd_regressor
 7132    elif regression_model == 'ransac':
 7133      ransac = RANSACRegressor(
 7134            min_samples=0.8,
 7135            max_trials=10,         # Maximum number of iterations
 7136#            min_samples=min_samples, # Minimum number samples to be chosen as inliers in each iteration
 7137#            stop_probability=0.80,  # Probability to stop the algorithm if a good subset is found
 7138#            stop_n_inliers=40,      # Stop if this number of inliers is found
 7139#            stop_score=0.8,         # Stop if the model score reaches this value
 7140#            n_jobs=int(os.getenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")) # Use all available CPU cores for parallel processing
 7141        )
 7142      return ransac
 7143    models = {
 7144        'sgd':SGDRegressor,
 7145        'ransac': RANSACRegressor,
 7146        'theilsen': TheilSenRegressor,
 7147        'huber': HuberRegressor,
 7148        'quantile': QuantileRegressor
 7149    }
 7150    return models.get(regression_model.lower(), LinearRegression)()
 7151
 7152  def replicate_list(user_list, target_size):
 7153    # Calculate the number of times the list should be replicated
 7154    replication_factor = target_size // len(user_list)
 7155    # Replicate the list and handle any remaining elements
 7156    replicated_list = user_list * replication_factor
 7157    remaining_elements = target_size % len(user_list)
 7158    replicated_list += user_list[:remaining_elements]
 7159    return replicated_list
 7160
 7161  def one_hot_encode(char_list):
 7162    unique_chars = list(set(char_list))
 7163    encoding_dict = {char: [1 if char == c else 0 for c in unique_chars] for char in unique_chars}
 7164    encoded_matrix = np.array([encoding_dict[char] for char in char_list])
 7165    return encoded_matrix
 7166  
 7167  A = np.zeros((1,1))
 7168  # fmri = ants.iMath( fmri, 'Normalize' )
 7169  fmri_template, hlinds = loop_timeseries_censoring( fmri, 0.10 )
 7170  fmri_template = ants.get_average_of_timeseries( fmri_template )
 7171  del hlinds
 7172  rig = ants.registration( fmri_template, t1head, 'antsRegistrationSyNRepro[r]' )
 7173  bmask = ants.apply_transforms( fmri_template, ants.threshold_image(t1segmentation,1,6), rig['fwdtransforms'][0], interpolator='genericLabel' )
 7174  if m0_indices is None:
 7175    if n_to_trim is None:
 7176        n_to_trim=0
 7177    mytrim=n_to_trim
 7178  else:
 7179    mytrim = 0
 7180  perf_total_sigma = 1.5
 7181  corrmo = timeseries_reg(
 7182    fmri, fmri_template,
 7183    type_of_transform=type_of_transform,
 7184    total_sigma=perf_total_sigma,
 7185    fdOffset=2.0,
 7186    trim = mytrim,
 7187    output_directory=None,
 7188    verbose=verbose,
 7189    syn_metric='CC',
 7190    syn_sampling=2,
 7191    reg_iterations=[40,20,5] )
 7192  if verbose:
 7193      print("End rsfmri motion correction")
 7194
 7195  if m0_image is not None:
 7196      m0 = m0_image
 7197  elif m0_indices is not None:
 7198    not_m0 = list( range( fmri.shape[3] ) )
 7199    not_m0 = [x for x in not_m0 if x not in m0_indices]
 7200    if verbose:
 7201        print( m0_indices )
 7202        print( not_m0 )
 7203    # then remove it from the time series
 7204    m0 = remove_volumes_from_timeseries( corrmo['motion_corrected'], not_m0 )
 7205    m0 = ants.get_average_of_timeseries( m0 )
 7206    corrmo['motion_corrected'] = remove_volumes_from_timeseries( 
 7207        corrmo['motion_corrected'], m0_indices )
 7208    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], m0_indices )
 7209    fmri = remove_volumes_from_timeseries( fmri, m0_indices )
 7210
 7211  ntp = corrmo['motion_corrected'].shape[3]
 7212  if tc == 'alternating':
 7213      tclist = replicate_list( ['C','T'], ntp )
 7214  else:
 7215      tclist = replicate_list( ['C'], int(ntp/2) ) + replicate_list( ['T'],  int(ntp/2) )
 7216
 7217  tclist = one_hot_encode( tclist[0:ntp ] )
 7218  fmrimotcorr=corrmo['motion_corrected']
 7219  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 7220    fmrimotcorr, hlinds = loop_timeseries_censoring( fmrimotcorr, outlier_threshold, mask=None, verbose=verbose )
 7221    tclist = remove_elements_from_numpy_array( tclist, hlinds)
 7222    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 7223
 7224  # redo template and registration at (potentially) upsampled scale
 7225  fmri_template = ants.iMath( ants.get_average_of_timeseries( fmrimotcorr ), "Normalize" )
 7226  if upsample:
 7227      spc = ants.get_spacing( fmri )
 7228      minspc = 2.0
 7229      if min(spc[0:3]) < minspc:
 7230          minspc = min(spc[0:3])
 7231      newspc = [minspc,minspc,minspc]
 7232      fmri_template = ants.resample_image( fmri_template, newspc, interp_type=0 )
 7233
 7234  if verbose:
 7235      print( 'fmri_template')
 7236      print( fmri_template )
 7237
 7238  rig = ants.registration( fmri_template, t1head, 'antsRegistrationSyNRepro[r]' )
 7239  bmask = ants.apply_transforms( fmri_template, 
 7240    ants.threshold_image(t1segmentation,1,6), 
 7241    rig['fwdtransforms'][0], 
 7242    interpolator='genericLabel' )
 7243
 7244  warn_if_small_mask( bmask, label='bold_perfusion:bmask')
 7245
 7246  corrmo = timeseries_reg(
 7247        fmri, fmri_template,
 7248        type_of_transform=type_of_transform,
 7249        total_sigma=perf_total_sigma,
 7250        fdOffset=2.0,
 7251        trim = mytrim,
 7252        output_directory=None,
 7253        verbose=verbose,
 7254        syn_metric='CC',
 7255        syn_sampling=2,
 7256        reg_iterations=[40,20,5] )
 7257  if verbose:
 7258        print("End 2nd rsfmri motion correction")
 7259
 7260  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
 7261    corrmo['motion_corrected'] = remove_volumes_from_timeseries( corrmo['motion_corrected'], hlinds )
 7262    corrmo['FD'] = remove_elements_from_numpy_array( corrmo['FD'], hlinds )
 7263
 7264  regression_mask = bmask.clone()
 7265  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
 7266  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
 7267  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",3)
 7268  bmask = bmask * ants.iMath( tsnrmask, "FillHoles" )
 7269  warn_if_small_mask( bmask, label='bold_perfusion:bmask*tsnrmask')
 7270  fmrimotcorr=corrmo['motion_corrected']
 7271  und = fmri_template * bmask
 7272  t1reg = ants.registration( und, t1, "antsRegistrationSyNRepro[s]" )
 7273  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
 7274  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
 7275  gmseg = ants.threshold_image( gmseg, 1, 4 )
 7276  gmseg = ants.iMath( gmseg, 'MD', 1 )
 7277  gmseg = ants.apply_transforms( und, gmseg,
 7278    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7279  csfseg = ants.threshold_image( t1segmentation, 1, 1 )
 7280  wmseg = ants.threshold_image( t1segmentation, 3, 3 )
 7281  csfAndWM = ( csfseg + wmseg ).morphology("erode",1)
 7282  compcorquantile=0.50
 7283  csfAndWM = ants.apply_transforms( und, csfAndWM,
 7284    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7285  csfseg = ants.apply_transforms( und, csfseg,
 7286    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7287  wmseg = ants.apply_transforms( und, wmseg,
 7288    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7289  warn_if_small_mask( wmseg, label='bold_perfusion:wmseg')
 7290  # warn_if_small_mask( csfseg, threshold_fraction=0.01, label='bold_perfusion:csfseg')
 7291  warn_if_small_mask( csfAndWM, label='bold_perfusion:csfAndWM')
 7292  mycompcor = ants.compcor( fmrimotcorr,
 7293    ncompcor=nc, quantile=compcorquantile, mask = csfAndWM,
 7294    filter_type='polynomial', degree=2 )
 7295  tr = ants.get_spacing( fmrimotcorr )[3]
 7296  simg = ants.smooth_image(fmrimotcorr, spa, sigma_in_physical_coordinates = True )
 7297  nuisance = mycompcor['basis']
 7298  nuisance = np.c_[ nuisance, mycompcor['components'] ]
 7299  if add_FD_to_nuisance:
 7300    nuisance = np.c_[ nuisance, corrmo['FD'] ]
 7301  if verbose:
 7302    print("make sure nuisance is independent of TC")
 7303  nuisance = ants.regress_components( nuisance, tclist )
 7304  regression_mask = bmask.clone()
 7305  gmmat = ants.timeseries_to_matrix( simg, regression_mask )
 7306  regvars = np.hstack( (nuisance, tclist ))
 7307  coefind = regvars.shape[1]-1
 7308  regvars = regvars[:,range(coefind)]
 7309  predictor_of_interest_idx = regvars.shape[1]-1
 7310  valid_perf_models = ['huber','quantile','theilsen','ransac', 'sgd', 'linear','SM']
 7311  if verbose:
 7312    print( "begin perfusion estimation with " + perfusion_regression_model + " model " )
 7313  if perfusion_regression_model == 'linear':
 7314    regression_model = LinearRegression()
 7315    regression_model.fit( regvars, gmmat )
 7316    coefind = regression_model.coef_.shape[1]-1
 7317    perfimg = ants.make_image( regression_mask, regression_model.coef_[:,coefind] )
 7318  elif perfusion_regression_model == 'SM': #
 7319    import statsmodels.api as sm
 7320    coeffs = np.zeros( gmmat.shape[1] )
 7321    # Loop over each outcome column in the outcomes matrix
 7322    for outcome_idx in range(gmmat.shape[1]):
 7323        outcome = gmmat[:, outcome_idx]  # Select one outcome column
 7324        model = sm.RLM(outcome, sm.add_constant(regvars), M=sm.robust.norms.HuberT())  # Huber's T norm for robust regression
 7325        results = model.fit()
 7326        coefficients = results.params  # Coefficients of all predictors
 7327        coeffs[outcome_idx] = coefficients[predictor_of_interest_idx]
 7328    perfimg = ants.make_image( regression_mask, coeffs )
 7329  elif perfusion_regression_model in valid_perf_models :
 7330    scaler = StandardScaler()
 7331    gmmat = scaler.fit_transform(gmmat)
 7332    coeffs = np.zeros( gmmat.shape[1] )
 7333    huber_regressor = select_regression_model( perfusion_regression_model )
 7334    multioutput_model = MultiOutputRegressor(huber_regressor)
 7335    multioutput_model.fit( regvars, gmmat )
 7336    ct=0
 7337    for i, estimator in enumerate(multioutput_model.estimators_):
 7338      coefficients = estimator.coef_
 7339      coeffs[ct]=coefficients[predictor_of_interest_idx]
 7340      ct=ct+1
 7341    perfimg = ants.make_image( regression_mask, coeffs )
 7342  else:
 7343    raise ValueError( perfusion_regression_model + " regression model is not found.")
 7344  meangmval = ( perfimg[ gmseg == 1 ] ).mean()
 7345  if meangmval < 0:
 7346      perfimg = perfimg * (-1.0)
 7347  negative_voxels = ( perfimg < 0.0 ).sum() / np.prod( perfimg.shape ) * 100.0
 7348  perfimg[ perfimg < 0.0 ] = 0.0 # non-physiological
 7349
 7350  # LaTeX code for Cerebral Blood Flow (CBF) calculation using ASL MRI
 7351  """
 7352CBF = \\frac{\\Delta M \\cdot \\lambda}{2 \\cdot T_1 \\cdot \\alpha \\cdot M_0 \\cdot (e^{-\\frac{w}{T_1}} - e^{-\\frac{w + \\tau}{T_1}})}
 7353
 7354Where:
 7355- \\Delta M is the difference in magnetization between labeled and control images.
 7356- \\lambda is the brain-blood partition coefficient, typically around 0.9 mL/g.
 7357- T_1 is the longitudinal relaxation time of blood, which is a tissue-specific constant.
 7358- \\alpha is the labeling efficiency.
 7359- M_0 is the equilibrium magnetization of brain tissue (from the M0 image).
 7360- w is the post-labeling delay, the time between the end of the labeling and the acquisition of the image.
 7361- \\tau is the labeling duration.
 7362  """
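  # Note (worked identity, for reference): the denominator term above can equivalently
  # be written as
  #   e^{-w/T_1} - e^{-(w+\tau)/T_1} = e^{-w/T_1} * (1 - e^{-\tau/T_1}),
  # which is the form that commonly appears in the ASL quantification literature.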
 7363  if m0_indices is None and m0_image is None:
 7364    m0 = ants.get_average_of_timeseries( fmrimotcorr )
 7365  else:
 7366    # register m0 to current template
 7367    m0reg = ants.registration( fmri_template, m0, 'antsRegistrationSyNRepro[r]', verbose=False )
 7368    m0 = m0reg['warpedmovout']
 7369
 7370  if ntp == 2 :
 7371      img0 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=0 )
 7372      img1 = ants.slice_image( corrmo['motion_corrected'], axis=3, idx=1 )
 7373      if m0_image is None:
 7374        if img0.mean() < img1.mean():
 7375            perfimg=img0
 7376            m0=img1
 7377        else:
 7378            perfimg=img1
 7379            m0=img0
 7380      else:
 7381        if img0.mean() < img1.mean():
 7382            perfimg=img1-img0
 7383        else:
 7384            perfimg=img0-img1
 7385
 7386  cbf = calculate_CBF(
 7387      Delta_M=perfimg, M_0=m0, mask=bmask )
 7388  if trim_the_mask > 0.0 :
 7389    bmask = trim_dti_mask( cbf, bmask, trim_the_mask )
 7390    perfimg = perfimg * bmask
 7391    cbf = cbf * bmask
 7392
 7393  meangmval = ( perfimg[ gmseg == 1 ] ).mean()        
 7394  meangmvalcbf = ( cbf[ gmseg == 1 ] ).mean()
 7395  if verbose:
 7396    print("perfimg.max() " + str(  perfimg.max() ) )
 7397  outdict = {}
 7398  outdict['meanBold'] = und
 7399  outdict['brainmask'] = bmask
 7400  rsfNuisance = pd.DataFrame( nuisance )
 7401  rsfNuisance['FD']=corrmo['FD']
 7402
 7403  if verbose:
 7404      print("perfusion dataframe begin")
 7405  dktseg = ants.apply_transforms( und, t1dktcit,
 7406    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7407  df_perf = antspyt1w.map_intensity_to_dataframe(
 7408        'dkt_cortex_cit_deep_brain',
 7409        perfimg,
 7410        dktseg)
 7411  df_perf = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 7412              {'perf' : df_perf},
 7413              col_names = ['Mean'] )
 7414  df_cbf = antspyt1w.map_intensity_to_dataframe(
 7415        'dkt_cortex_cit_deep_brain',
 7416        cbf,
 7417        dktseg)
 7418  df_cbf = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 7419              {'cbf' : df_cbf},
 7420              col_names = ['Mean'] )
 7421  df_cbf = df_cbf.add_prefix('cbf_')
 7422  df_perf = pd.concat( [df_perf,df_cbf], axis=1, ignore_index=False )
 7423  if verbose:
 7424      print("perfusion dataframe end")
 7425
 7426  outdict['perfusion']=perfimg
 7427  outdict['cbf']=cbf
 7428  outdict['m0']=m0
 7429  outdict['perfusion_gm_mean']=meangmval
 7430  outdict['cbf_gm_mean']=meangmvalcbf
 7431  outdict['perf_dataframe']=df_perf
 7432  outdict['motion_corrected'] = corrmo['motion_corrected']
 7433  outdict['gmseg'] = gmseg
 7434  outdict['brain_mask'] = bmask
 7435  outdict['nuisance'] = rsfNuisance
 7436  outdict['tsnr'] = mytsnr
 7437  outdict['ssnr'] = slice_snr( corrmo['motion_corrected'], csfAndWM, gmseg )
 7438  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
 7439  outdict['high_motion_count'] = (rsfNuisance['FD'] > FD_threshold ).sum()
 7440  outdict['high_motion_pct'] = (rsfNuisance['FD'] > FD_threshold ).sum() / rsfNuisance.shape[0]
 7441  outdict['FD_max'] = rsfNuisance['FD'].max()
 7442  outdict['FD_mean'] = rsfNuisance['FD'].mean()
 7443  outdict['FD_sd'] = rsfNuisance['FD'].std()
 7444  outdict['bold_evr'] =  antspyt1w.patch_eigenvalue_ratio( und, 512, [16,16,16], evdepth = 0.9, mask = bmask )
 7445  outdict['t1reg'] = t1reg
 7446  outdict['outlier_volumes']=hlinds
 7447  outdict['n_outliers']=len(hlinds)
 7448  outdict['negative_voxels']=negative_voxels
 7449  return convert_np_in_dict( outdict )
 7450
 7451
 7452def pet3d_summary( pet3d, t1head, t1, t1segmentation, t1dktcit,
 7453                   spa = (0., 0., 0.),
 7454                   type_of_transform='antsRegistrationSyNRepro[r]',
 7455                   upsample=True,
 7456                   verbose=False ):
 7457  """
  Summarize a 3D PET image in the space of the individual's T1 anatomy.  The PET image is (optionally) upsampled, rigidly registered to the T1, brain masked and then summarized over tissue classes and the DKT cortex plus CIT168 parcellation.
 7459
 7460  Arguments
 7461  ---------
 7462  pet3d : 3D PET antsImage
 7463
 7464  t1head : ANTsImage
 7465    input 3-D T1 brain image (not brain extracted)
 7466
 7467  t1 : ANTsImage
 7468    input 3-D T1 brain image (brain extracted)
 7469
 7470  t1segmentation : ANTsImage
 7471    t1 segmentation - a six tissue segmentation image in T1 space
 7472
 7473  t1dktcit : ANTsImage
 7474    t1 dkt cortex plus cit parcellation
 7475
 7476  type_of_transform : SyN or Rigid
 7477
 7478  upsample: boolean
 7479
 7480  verbose : boolean
 7481
 7482  Returns
 7483  ---------
  a dictionary containing the (resampled) PET image, brain mask, tissue means, regional summary dataframe and the registration to T1
 7485
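  Example
  ---------
  A minimal sketch (hypothetical filenames; hier is an antspyt1w.hierarchical result):

  >>> import ants, antspyt1w
  >>> from antspymm.mm import pet3d_summary
  >>> pet = ants.image_read('pet3d.nii.gz')
  >>> t1 = ants.image_read('t1.nii.gz')
  >>> hier = antspyt1w.hierarchical(t1, output_prefix='/tmp/t1hier')
  >>> petsum = pet3d_summary(
  ...     pet, t1, hier['brain_n4_dnz'],
  ...     hier['dkt_parc']['tissue_segmentation'],
  ...     hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'])
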
 7486  """
 7487  import numpy as np
 7488  import pandas as pd
 7489  import re
 7490  import math
 7491
 7492  ex_path = os.path.expanduser( "~/.antspyt1w/" )
 7493  cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
 7494  
 7495  pet3dr=pet3d
 7496  if upsample:
 7497      spc = ants.get_spacing( pet3d )
 7498      minspc = 1.0
 7499      if min(spc) < minspc:
 7500        minspc = min(spc)
 7501      newspc = [minspc,minspc,minspc]
 7502      pet3dr = ants.resample_image( pet3d, newspc, interp_type=0 )
 7503
 7504  rig = ants.registration( pet3dr, t1head, 'antsRegistrationSyNRepro[r]' )
 7505  bmask = ants.apply_transforms( pet3dr, 
 7506    ants.threshold_image(t1segmentation,1,6), 
 7507    rig['fwdtransforms'][0], 
 7508    interpolator='genericLabel' )
 7509  if verbose:
 7510    print("End t1=>pet registration")
 7511
 7512  und = pet3dr * bmask
 7513#  t1reg = ants.registration( und, t1, "antsRegistrationSyNQuickRepro[s]" )
 7514  t1reg = rig # ants.registration( und, t1, "Rigid" )
 7515  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
 7516  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
 7517  gmseg = ants.threshold_image( gmseg, 1, 4 )
 7518  gmseg = ants.iMath( gmseg, 'MD', 1 )
 7519  gmseg = ants.apply_transforms( und, gmseg,
 7520    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7521  csfseg = ants.threshold_image( t1segmentation, 1, 1 )
 7522  wmseg = ants.threshold_image( t1segmentation, 3, 3 )
 7523  csfAndWM = ( csfseg + wmseg ).morphology("erode",1)
 7524  csfAndWM = ants.apply_transforms( und, csfAndWM,
 7525    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7526  csfseg = ants.apply_transforms( und, csfseg,
 7527    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7528  wmseg = ants.apply_transforms( und, wmseg,
 7529    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
 7530  wmsignal = pet3dr[ ants.iMath(wmseg,"ME",1) == 1 ].mean()
 7531  gmsignal = pet3dr[ gmseg == 1 ].mean()
 7532  csfsignal = pet3dr[ csfseg == 1 ].mean()
 7533  if verbose:
 7534    print("pet3dr.max() " + str(  pet3dr.max() ) )
 7535  if verbose:
 7536      print("pet3d dataframe begin")
 7537  dktseg = ants.apply_transforms( und, t1dktcit,
 7538    t1reg['fwdtransforms'], interpolator = 'genericLabel' ) * bmask
 7539  df_pet3d = antspyt1w.map_intensity_to_dataframe(
 7540        'dkt_cortex_cit_deep_brain',
 7541        und,
 7542        dktseg)
 7543  df_pet3d = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 7544              {'pet3d' : df_pet3d},
 7545              col_names = ['Mean'] )
 7546  if verbose:
 7547      print("pet3d dataframe end")
 7548
 7549  outdict = {}
 7550  outdict['pet3d_dataframe']=df_pet3d
 7551  outdict['pet3d'] = pet3dr
 7552  outdict['brainmask'] = bmask
 7553  outdict['gm_mean']=gmsignal
 7554  outdict['wm_mean']=wmsignal
 7555  outdict['csf_mean']=csfsignal
 7557  outdict['t1reg'] = t1reg
 7558  return convert_np_in_dict( outdict )
 7559
 7560
 7561def write_bvals_bvecs(bvals, bvecs, prefix ):
 7562    ''' Write FSL FDT bvals and bvecs files
 7563
 7564    adapted from dipy.external code
 7565
 7566    Parameters
 7567    -------------
 7568    bvals : (N,) sequence
 7569       Vector with diffusion gradient strength (one per diffusion
 7570       acquisition, N=no of acquisitions)
 7571    bvecs : (N, 3) array-like
 7572       diffusion gradient directions
 7573    prefix : string
       path prefix used to write the FDT bvals and bvecs text files
       (output files are prefix + '.bval' and prefix + '.bvec')
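
    Example
    -------------
    A minimal sketch writing three acquisitions to /tmp/dwi.bval and /tmp/dwi.bvec:

    >>> import numpy as np
    >>> from antspymm.mm import write_bvals_bvecs
    >>> bvals = [0, 1000, 1000]
    >>> bvecs = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
    >>> write_bvals_bvecs(bvals, bvecs, '/tmp/dwi')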
 7576    '''
 7577    _VAL_FMT = '   %e'
 7578    bvals = tuple(bvals)
 7579    bvecs = np.asarray(bvecs)
 7580    bvecs[np.isnan(bvecs)] = 0
 7581    N = len(bvals)
 7582    fname = prefix + '.bval'
 7583    fmt = _VAL_FMT * N + '\n'
    with open(fname, 'wt') as myfile:
        myfile.write(fmt % bvals)
    fname = prefix + '.bvec'
    with open(fname, 'wt') as bvf:
        for dim_vals in bvecs.T:
            bvf.write(fmt % tuple(dim_vals))
 7592    
 7593
 7594def crop_mcimage( x, mask, padder=None ):
 7595    """
 7596    crop a time series (4D) image by a 3D mask
 7597
 7598    Parameters
 7599    -------------
 7600
 7601    x : raw image
 7602
    mask  : mask for cropping

    padder : optional pad width passed to ants.pad_image after cropping each volume
 7604
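    Example
    -------------
    A minimal sketch cropping a 4D time series to its brain mask (hypothetical filename):

    >>> import ants
    >>> from antspymm.mm import crop_mcimage
    >>> bold = ants.image_read('bold.nii.gz')
    >>> avg = ants.get_average_of_timeseries(bold)
    >>> msk = ants.get_mask(avg)
    >>> cropped = crop_mcimage(bold, msk, padder=[4, 4, 4])
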
 7605    """
 7606    cropmask = ants.crop_image( mask, mask )
 7607    myorig = list( ants.get_origin(cropmask) )
 7608    myorig.append( ants.get_origin( x )[3] )
 7609    croplist = []
 7610    if len(x.shape) > 3:
 7611        for k in range(x.shape[3]):
 7612            temp = ants.slice_image( x, axis=3, idx=k )
 7613            temp = ants.crop_image( temp, mask )
 7614            if padder is not None:
 7615                temp = ants.pad_image( temp, pad_width=padder )
 7616            croplist.append( temp )
 7617        temp = ants.list_to_ndimage( x, croplist )
 7618        temp.set_origin( myorig )
 7619        return temp
 7620    else:
 7621        return( ants.crop_image( x, mask ) )
 7622
 7623
 7624def mm(
 7625    t1_image,
 7626    hier,
 7627    rsf_image=[],
 7628    flair_image=None,
 7629    nm_image_list=None,
 7630    dw_image=[], bvals=[], bvecs=[],
 7631    perfusion_image=None,
 7632    srmodel=None,
 7633    do_tractography = False,
 7634    do_kk = False,
 7635    do_normalization = None,
 7636    group_template = None,
 7637    group_transform = None,
 7638    target_range = [0,1],
 7639    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
 7640    dti_denoise = False,
 7641    perfusion_trim=10,
 7642    perfusion_m0_image=None,
 7643    perfusion_m0=None,
 7644    rsf_upsampling=3.0,
 7645    pet_3d_image=None,
 7646    test_run = False,
 7647    verbose = False ):
 7648    """
 7649    Multiple modality processing and normalization
 7650
 7651    aggregates modality-specific processing under one roof.  see individual
 7652    modality specific functions for details.
 7653
 7654    Parameters
 7655    -------------
 7656
 7657    t1_image : raw t1 image
 7658
 7659    hier  : output of antspyt1w.hierarchical ( see read hierarchical )
 7660
 7661    rsf_image : list of resting state fmri
 7662
 7663    flair_image : flair
 7664
 7665    nm_image_list : list of neuromelanin images
 7666
 7667    dw_image : list of diffusion weighted images
 7668
 7669    bvals : list of bvals file names
 7670
 7671    bvecs : list of bvecs file names
 7672
 7673    perfusion_image : single perfusion image
 7674
 7675    srmodel : optional srmodel
 7676
 7677    do_tractography : boolean
 7678
 7679    do_kk : boolean to control whether we compute kelly kapowski thickness image (slow)
 7680
 7681    do_normalization : template transformation if available
 7682
 7683    group_template : optional reference template corresponding to the group_transform
 7684
 7685    group_transform : optional transforms corresponding to the group_template
 7686
 7687    target_range : 2-element tuple
 7688        a tuple or array defining the (min, max) of the input image
 7689        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
 7690        intensity. This range should match the mapping used in the training
 7691        of the network.
 7692    
    dti_motion_correct : None, Rigid, or SyN
 7694
 7695    dti_denoise : boolean
 7696
 7697    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
 7698
 7699    perfusion_m0_image : optional antsImage m0 associated with the perfusion time series
 7700
 7701    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
 7702
 7703    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
 7704
 7705    pet_3d_image : optional antsImage for a 3D pet; we make no assumptions about the contents of 
 7706        this image.  we just process it and provide summary information.
 7707
 7708    test_run : boolean 
 7709
 7710    verbose : boolean
 7711
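    Example
    -------------
    A minimal sketch with a single resting state fMRI (hypothetical filenames; the
    group template registration shown here is an assumption about typical usage):

    >>> import os, ants, antspyt1w
    >>> from antspymm.mm import mm
    >>> t1 = ants.image_read('t1.nii.gz')
    >>> hier = antspyt1w.hierarchical(t1, output_prefix='/tmp/t1hier')
    >>> template = ants.image_read(os.path.expanduser('~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz'))
    >>> norm = ants.registration(template, hier['brain_n4_dnz'], 'antsRegistrationSyNQuickRepro[s]')
    >>> rsf = ants.image_read('rest_bold.nii.gz')
    >>> outputs, norm_maps = mm(t1, hier, rsf_image=[rsf], do_normalization=norm, verbose=True)
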
 7712    """
 7713    from os.path import exists
 7714    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 7715    ex_path_mm = os.path.expanduser( "~/.antspymm/" )
 7716    mycsvfn = ex_path + "FA_JHU_labels_edited.csv"
 7717    citcsvfn = ex_path + "CIT168_Reinf_Learn_v1_label_descriptions_pad.csv"
 7718    dktcsvfn = ex_path + "dkt.csv"
 7719    cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
 7720    JHU_atlasfn = ex_path + 'JHU-ICBM-FA-1mm.nii.gz' # Read in JHU atlas
 7721    JHU_labelsfn = ex_path + 'JHU-ICBM-labels-1mm.nii.gz' # Read in JHU labels
 7722    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 7723    if not exists( mycsvfn ) or not exists( citcsvfn ) or not exists( cnxcsvfn ) or not exists( dktcsvfn ) or not exists( JHU_atlasfn ) or not exists( JHU_labelsfn ) or not exists( templatefn ):
 7724        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
 7725        raise ValueError('**missing files** => call get_data from latest antspyt1w and antspymm.')
 7726    mycsv = pd.read_csv(  mycsvfn )
 7727    citcsv = pd.read_csv(  os.path.expanduser( citcsvfn ) )
 7728    dktcsv = pd.read_csv(  os.path.expanduser( dktcsvfn ) )
 7729    cnxcsv = pd.read_csv(  os.path.expanduser( cnxcsvfn ) )
 7730    JHU_atlas = mm_read( JHU_atlasfn ) # Read in JHU atlas
 7731    JHU_labels = mm_read( JHU_labelsfn ) # Read in JHU labels
 7732    template = mm_read( templatefn ) # Read in template
 7733    if group_template is None:
 7734        group_template = template
 7735        group_transform = do_normalization['fwdtransforms']
 7736    if verbose:
 7737        print("Using group template:")
 7738        print( group_template )
 7739    #####################
 7740    #  T1 hierarchical  #
 7741    #####################
 7742    t1imgbrn = hier['brain_n4_dnz']
 7743    t1atropos = hier['dkt_parc']['tissue_segmentation']
 7744    output_dict = {
 7745        'kk': None,
 7746        'rsf': None,
 7747        'flair' : None,
 7748        'NM' : None,
 7749        'DTI' : None,
 7750        'FA_summ' : None,
 7751        'MD_summ' : None,
 7752        'tractography' : None,
 7753        'tractography_connectivity' : None,
 7754        'perf' : None,
 7755        'pet3d' : None,
 7756    }
 7757    normalization_dict = {
 7758        'kk_norm': None,
 7759        'NM_norm' : None,
 7760        'DTI_norm': None,
 7761        'FA_norm' : None,
 7762        'MD_norm' : None,
 7763        'perf_norm' : None,
 7764        'alff_norm' : None,
 7765        'falff_norm' : None,
 7766        'CinguloopercularTaskControl_norm' : None,
 7767        'DefaultMode_norm' : None,
 7768        'MemoryRetrieval_norm' : None,
 7769        'VentralAttention_norm' : None,
 7770        'Visual_norm' : None,
 7771        'FrontoparietalTaskControl_norm' : None,
 7772        'Salience_norm' : None,
 7773        'Subcortical_norm' : None,
 7774        'DorsalAttention_norm' : None,
 7775        'pet3d_norm' : None
 7776    }
 7777    if test_run:
 7778        return output_dict, normalization_dict
 7779
 7780    if do_kk:
 7781        if verbose:
 7782            print('kk in mm')
 7783        output_dict['kk'] = antspyt1w.kelly_kapowski_thickness( t1_image,
 7784            labels=hier['dkt_parc']['dkt_cortex'], iterations=45 )
 7785
 7786    if perfusion_image is not None:
 7787        if perfusion_image.shape[3] > 1: # FIXME - better heuristic?
 7788            output_dict['perf'] = bold_perfusion(
 7789                perfusion_image,
 7790                t1_image,
 7791                hier['brain_n4_dnz'],
 7792                t1atropos,
 7793                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
 7794                n_to_trim = perfusion_trim,
 7795                m0_image = perfusion_m0_image,
 7796                m0_indices = perfusion_m0,
 7797                verbose=verbose )
 7798
 7799    if pet_3d_image is not None:
 7800        if pet_3d_image.dimension == 3: # FIXME - better heuristic?
 7801            output_dict['pet3d'] = pet3d_summary(
 7802                pet_3d_image,
 7803                t1_image,
 7804                hier['brain_n4_dnz'],
 7805                t1atropos,
 7806                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
 7807                verbose=verbose )
 7808    ################################## do the rsf .....
 7809    if len(rsf_image) > 0:
 7810        my_motion_tx = 'antsRegistrationSyNRepro[r]'
 7811        rsf_image = [i for i in rsf_image if i is not None]
 7812        if verbose:
 7813            print('rsf length ' + str( len( rsf_image ) ) )
 7814        if len( rsf_image ) >= 2: # assume 2 is the largest possible value
 7815            rsf_image1 = rsf_image[0]
 7816            rsf_image2 = rsf_image[1]
 7817            # build a template then join the images
 7818            if verbose:
 7819                print("initial average for rsf")
 7820            rsfavg1, hlinds = loop_timeseries_censoring( rsf_image1, 0.1 )
 7821            rsfavg1=get_average_rsf(rsfavg1)
 7822            rsfavg2, hlinds = loop_timeseries_censoring( rsf_image2, 0.1 )
 7823            rsfavg2=get_average_rsf(rsfavg2)
 7824            if verbose:
 7825                print("template average for rsf")
 7826            init_temp = ants.image_clone( rsfavg1 )
 7827            if rsf_image1.shape[3] < rsf_image2.shape[3]:
 7828                init_temp = ants.image_clone( rsfavg2 )
 7829            boldTemplate = ants.build_template(
 7830                initial_template = init_temp,
 7831                image_list=[rsfavg1,rsfavg2],
 7832                type_of_transform="antsRegistrationSyNQuickRepro[s]",
 7833                iterations=5, verbose=False )
 7834            if verbose:
 7835                print("join the 2 rsf")
 7836            if rsf_image1.shape[3] > 10 and rsf_image2.shape[3] > 10:
 7837                leadvols = list(range(8))
 7838                rsf_image2 = remove_volumes_from_timeseries( rsf_image2, leadvols )
 7839                rsf_image = merge_timeseries_data( rsf_image1, rsf_image2 )
 7840            elif rsf_image1.shape[3] > rsf_image2.shape[3]:
 7841                rsf_image = rsf_image1
 7842            else:
 7843                rsf_image = rsf_image2
 7844        elif len( rsf_image ) == 1:
 7845            rsf_image = rsf_image[0]
 7846            boldTemplate, hlinds = loop_timeseries_censoring( rsf_image, 0.1 )
 7847            boldTemplate = get_average_rsf(boldTemplate)
 7848        if rsf_image.shape[3] > 10: # FIXME - better heuristic?
 7849            rsfprolist = [] # FIXMERSF
 7850            # Create the parameter DataFrame
 7851            df = pd.DataFrame({
 7852                "num": [134, 122, 129],
 7853                "loop": [0.50, 0.25, 0.50],
 7854                "cens": [True, True, True],
 7855                "HM": [1.0, 5.0, 0.5],
 7856                "ff": ["tight", "tight", "tight"],
 7857                "CC": [5, 5, 0.8],
 7858                "imp": [True, True, True],
 7859                "up": [rsf_upsampling, rsf_upsampling, rsf_upsampling],
 7860                "coords": [False,False,False]
 7861            }, index=[0, 1, 2])
 7862            for p in range(df.shape[0]):
 7863                if verbose:
 7864                    print("rsf parameters")
 7865                    print( df.iloc[p] )
 7866                if df['ff'].iloc[p] == 'broad':
 7867                    f=[ 0.008, 0.15 ]
 7868                elif df['ff'].iloc[p] == 'tight':
 7869                    f=[ 0.03, 0.08 ]
 7870                elif df['ff'].iloc[p] == 'mid':
 7871                    f=[ 0.01, 0.1 ]
 7872                elif df['ff'].iloc[p] == 'mid2':
 7873                    f=[ 0.01, 0.08 ]
 7874                else:
 7875                    raise ValueError("we do not recognize this parameter choice for frequency filtering: " + df['ff'].iloc[p] )
 7876                HM = df['HM'].iloc[p]
 7877                CC = df['CC'].iloc[p]
 7878                loop= df['loop'].iloc[p]
 7879                cens =df['cens'].iloc[p]
 7880                imp = df['imp'].iloc[p]
 7881                rsf0 = resting_state_fmri_networks(
 7882                                            rsf_image,
 7883                                            boldTemplate,
 7884                                            hier['brain_n4_dnz'],
 7885                                            t1atropos,
 7886                                            f=f,
 7887                                            FD_threshold=HM, 
 7888                                            spa = None, 
 7889                                            spt = None, 
 7890                                            nc = CC,
 7891                                            outlier_threshold=loop,
 7892                                            ica_components = 0,
 7893                                            impute = imp,
 7894                                            censor = cens,
 7895                                            despike = 2.5,
 7896                                            motion_as_nuisance = True,
 7897                                            upsample=df['up'].iloc[p],
 7898                                            clean_tmp=0.66,
 7899                                            paramset=df['num'].iloc[p],
 7900                                            powers=df['coords'].iloc[p],
 7901                                            verbose=verbose ) # default
 7902                rsfprolist.append( rsf0 )
 7903            output_dict['rsf'] = rsfprolist
 7904
 7905    if nm_image_list is not None:
 7906        if verbose:
 7907            print('nm')
 7908        if srmodel is None:
 7909            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], verbose=verbose )
 7910        else:
 7911            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], srmodel=srmodel, target_range=target_range, verbose=verbose  )
 7912################################## do the dti .....
 7913    if len(dw_image) > 0 :
 7914        if verbose:
 7915            print('dti-x')
 7916        if len( dw_image ) == 1: # use T1 for distortion correction and brain extraction
 7917            if verbose:
 7918                print("We have only one DTI: " + str(len(dw_image)))
 7919            dw_image = dw_image[0]
 7920            btpB0, btpDW = get_average_dwi_b0(dw_image)
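            # rigid then SyN registration of the T1 brain to the average DWI; the
            # T1-derived brain mask is then mapped into DWI space and passed to
            # joint_dti_recon as the brain mask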
 7921            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
 7922            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
 7923                syn_metric='CC', syn_sampling=2,
 7924                reg_iterations=[50,50,20],
 7925                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
 7926                initial_transform=initrig
 7927                )
 7928            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
 7929            btpDW = ants.apply_transforms( btpDW, btpDW,
 7930                tempreg['invtransforms'][1], interpolator='linear')
 7931            btpB0 = ants.apply_transforms( btpB0, btpB0,
 7932                tempreg['invtransforms'][1], interpolator='linear')
 7933            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'][1], interpolator='nearestNeighbor')
 7934            # dwimask = ants.iMath(dwimask,'MD',1)
 7935            t12dwi = ants.apply_transforms( btpDW, hier['brain_n4_dnz'], tempreg['fwdtransforms'][1], interpolator='linear')
 7936            output_dict['DTI'] = joint_dti_recon(
 7937                dw_image,
 7938                bvals[0],
 7939                bvecs[0],
 7940                jhu_atlas=JHU_atlas,
 7941                jhu_labels=JHU_labels,
 7942                brain_mask = dwimask,
 7943                reference_B0 = btpB0,
 7944                reference_DWI = btpDW,
 7945                srmodel=srmodel,
 7946                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
 7947                denoise=dti_denoise,
 7948                verbose = verbose)
 7949        else :  # use phase encoding acquisitions for distortion correction and T1 for brain extraction
 7950            if verbose:
 7951                print("We have both DTI_LR and DTI_RL: " + str(len(dw_image)))
 7952            a1b,a1w=get_average_dwi_b0(dw_image[0])
 7953            a2b,a2w=get_average_dwi_b0(dw_image[1],fixed_b0=a1b,fixed_dwi=a1w)
 7954            btpB0, btpDW = dti_template(
 7955                b_image_list=[a1b,a2b],
 7956                w_image_list=[a1w,a2w],
 7957                iterations=7, verbose=verbose )
 7958            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
 7959            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
 7960                syn_metric='CC', syn_sampling=2,
 7961                reg_iterations=[50,50,20],
 7962                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
 7963                initial_transform=initrig
 7964                )
 7965            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
 7966            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'], interpolator='nearestNeighbor')
 7967            output_dict['DTI'] = joint_dti_recon(
 7968                dw_image[0],
 7969                bvals[0],
 7970                bvecs[0],
 7971                jhu_atlas=JHU_atlas,
 7972                jhu_labels=JHU_labels,
 7973                brain_mask = dwimask,
 7974                reference_B0 = btpB0,
 7975                reference_DWI = btpDW,
 7976                srmodel=srmodel,
 7977                img_RL=dw_image[1],
 7978                bval_RL=bvals[1],
 7979                bvec_RL=bvecs[1],
 7980                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
 7981                denoise=dti_denoise,
 7982                verbose = verbose)
 7983        mydti = output_dict['DTI']
 7984        # summarize dwi with T1 outputs
 7985        # first - register ....
 7986        reg = ants.registration( mydti['recon_fa'], hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[s]', total_sigma=1.0 )
 7987        ##################################################
 7988        output_dict['FA_summ'] = hierarchical_modality_summary(
 7989            mydti['recon_fa'],
 7990            hier=hier,
 7991            modality_name='fa',
 7992            transformlist=reg['fwdtransforms'],
 7993            verbose = False )
 7994        ##################################################
 7995        output_dict['MD_summ'] = hierarchical_modality_summary(
 7996            mydti['recon_md'],
 7997            hier=hier,
 7998            modality_name='md',
 7999            transformlist=reg['fwdtransforms'],
 8000            verbose = False )
 8001        # these inputs should come from nicely processed data
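        # map DKT cortical and CIT168 subcortical labels into FA space; cortical labels
        # are zeroed where subcortical labels overlap, and the combined label image
        # defines the nodes for the streamline connectivity computed below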
 8002        dktmapped = ants.apply_transforms(
 8003            mydti['recon_fa'],
 8004            hier['dkt_parc']['dkt_cortex'],
 8005            reg['fwdtransforms'], interpolator='nearestNeighbor' )
 8006        citmapped = ants.apply_transforms(
 8007            mydti['recon_fa'],
 8008            hier['cit168lab'],
 8009            reg['fwdtransforms'], interpolator='nearestNeighbor' )
 8010        dktmapped[ citmapped > 0]=0
 8011        mask = ants.threshold_image( mydti['recon_fa'], 0.01, 2.0 ).iMath("GetLargestComponent")
 8012        if do_tractography: # dwi_deterministic_tracking dwi_closest_peak_tracking
 8013            output_dict['tractography'] = dwi_deterministic_tracking(
 8014                mydti['dwi_LR_dewarped'],
 8015                mydti['recon_fa'],
 8016                mydti['bval_LR'],
 8017                mydti['bvec_LR'],
 8018                seed_density = 1,
 8019                mask=mask,
 8020                verbose = verbose )
 8021            mystr = output_dict['tractography']
 8022            output_dict['tractography_connectivity'] = dwi_streamline_connectivity( mystr['streamlines'], dktmapped+citmapped, cnxcsv, verbose=verbose )
 8023    ################################## do the flair .....
 8024    if flair_image is not None:
 8025        if verbose:
 8026            print('flair')
 8027        wmhprior = None
 8028        priorfn = ex_path_mm + 'CIT168_wmhprior_700um_pad_adni.nii.gz'
 8029        if ( exists( priorfn ) ):
 8030            wmhprior = ants.image_read( priorfn )
 8031            wmhprior = ants.apply_transforms( t1_image, wmhprior, do_normalization['invtransforms'] )
 8032        output_dict['flair'] = boot_wmh( flair_image, t1_image, t1atropos,
 8033            prior_probability=wmhprior, verbose=verbose )
 8034    #################################################################
 8035    ### NOTES: deforming to a common space and writing out images ###
 8036    ### images we want come from: DTI, NM, rsf, thickness ###########
 8037    #################################################################
 8038    if do_normalization is not None:
 8039        if verbose:
 8040            print('normalization')
 8041        # might reconsider this template space - cropped and/or higher res?
 8042        # template = ants.resample_image( template, [1,1,1], use_voxels=False )
 8043        # t1reg = ants.registration( template, hier['brain_n4_dnz'], "antsRegistrationSyNQuickRepro[s]")
 8044        t1reg = do_normalization
 8045        if do_kk:
 8046            normalization_dict['kk_norm'] = ants.apply_transforms( group_template, output_dict['kk']['thickness_image'], group_transform )
 8047        if output_dict['DTI'] is not None:
 8048            mydti = output_dict['DTI']
 8049            dtirig = ants.registration( hier['brain_n4_dnz'], mydti['recon_fa'], 'antsRegistrationSyNRepro[r]' )
 8050            normalization_dict['MD_norm'] = ants.apply_transforms( group_template, mydti['recon_md'],group_transform+dtirig['fwdtransforms'] )
 8051            normalization_dict['FA_norm'] = ants.apply_transforms( group_template, mydti['recon_fa'],group_transform+dtirig['fwdtransforms'] )
 8052            output_directory = tempfile.mkdtemp()
 8053            do_dti_norm=False
 8054            if do_dti_norm:
 8055                comptx = ants.apply_transforms( group_template, group_template, group_transform+dtirig['fwdtransforms'], compose = output_directory + '/xxx' )
 8056                tspc=[2.,2.,2.]
 8057                if srmodel is not None:
 8058                    tspc=[1.,1.,1.]
 8059                group_template2mm = ants.resample_image( group_template, tspc  )
 8060                normalization_dict['DTI_norm'] = transform_and_reorient_dti( group_template2mm, mydti['dti'], comptx, verbose=False )
 8061            import shutil
 8062            shutil.rmtree(output_directory, ignore_errors=True )
 8063        if output_dict['rsf'] is not None:
 8064            if False:
 8065                rsfpro = output_dict['rsf'] # FIXME
 8066                rsfrig = ants.registration( hier['brain_n4_dnz'], rsfpro['meanBold'], 'antsRegistrationSyNRepro[r]' )
 8067                for netid in get_antsimage_keys( rsfpro ):
 8068                    rsfkey = netid + "_norm"
 8069                    normalization_dict[rsfkey] = ants.apply_transforms(
 8070                        group_template, rsfpro[netid],
 8071                        group_transform+rsfrig['fwdtransforms'] )
 8072        if output_dict['perf'] is not None: # zizzer
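            # compose the group transform with the subject-level perfusion-to-T1 inverse
            # transforms; whichtoinvert flags the affine (.mat) entry of the inverse
            # transforms for inversion while the group_transform entries are applied as-is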
 8073            comptx = group_transform + output_dict['perf']['t1reg']['invtransforms']
 8074            normalization_dict['perf_norm'] = ants.apply_transforms( group_template,
 8075                output_dict['perf']['perfusion'], comptx,
 8076                whichtoinvert=[False,False,True,False] )
 8077            normalization_dict['cbf_norm'] = ants.apply_transforms( group_template,
 8078                output_dict['perf']['cbf'], comptx,
 8079                whichtoinvert=[False,False,True,False] )
 8080        if output_dict['pet3d'] is not None: # zizzer
 8081            secondTx=output_dict['pet3d']['t1reg']['invtransforms']
 8082            comptx = group_transform + secondTx
 8083            if len( secondTx ) == 2:
 8084                wti=[False,False,True,False]
 8085            else:
 8086                wti=[False,False,True]
 8087            normalization_dict['pet3d_norm'] = ants.apply_transforms( group_template,
 8088                output_dict['pet3d']['pet3d'], comptx,
 8089                whichtoinvert=wti )
 8090        if nm_image_list is not None:
 8091            nmpro = output_dict['NM']
 8092            nmrig = nmpro['t1_to_NM_transform'] # this is an inverse tx
 8093            normalization_dict['NM_norm'] = ants.apply_transforms( group_template, nmpro['NM_avg'], group_transform+nmrig,
 8094                whichtoinvert=[False,False,True])
 8095
 8096    if verbose:
 8097        print('mm done')
 8098    return output_dict, normalization_dict
 8099
 8100
 8101def write_mm( output_prefix, mm, mm_norm=None, t1wide=None, separator='_', verbose=False ):
 8102    """
 8103    write the tabular and normalization output of the mm function
 8104
 8105    Parameters
 8106    -------------
 8107
 8108    output_prefix : prefix for file outputs - modality specific postfix will be added
 8109
    mm  : output of the mm function for modality-space processing; should be a
        dictionary with an entry for each modality.
 8112
 8113    mm_norm : output of mm function for normalized processing
 8114
 8115    t1wide : wide output data frame from t1 hierarchical
 8116
 8117    separator : string or character separator for filenames
 8118
 8119    verbose : boolean
 8120
 8121    Returns
 8122    ---------
 8123
    csv and image files are written to disk.  the primary outputs are
    output_prefix + separator + 'mmwide.csv' and *norm.nii.gz images
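
    Example
    ---------

    a minimal sketch; tabPro and normPro denote outputs of antspymm.mm and the
    output prefix below is hypothetical:

    >>> import antspymm
    >>> antspymm.write_mm( '/tmp/PPMI-101018-20210412-T1w-1497590',
    ...     mm=tabPro, mm_norm=normPro, t1wide=None, separator='-' )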
 8126
 8127    """
 8128    from dipy.io.streamline import save_tractogram
 8129    if mm_norm is not None:
 8130        for mykey in mm_norm:
 8131            tempfn = output_prefix + separator + mykey + '.nii.gz'
 8132            if mm_norm[mykey] is not None:
 8133                image_write_with_thumbnail( mm_norm[mykey], tempfn )
 8134    thkderk = None
 8135    if t1wide is not None:
 8136        thkderk = t1wide.iloc[: , 1:]
 8137    kkderk = None
 8138    if 'kk' in mm:
 8139        if mm['kk'] is not None:
 8140            kkderk = mm['kk']['thickness_dataframe'].iloc[: , 1:]
 8141            mykey='thickness_image'
 8142            tempfn = output_prefix + separator + mykey + '.nii.gz'
 8143            image_write_with_thumbnail( mm['kk'][mykey], tempfn )
 8144    nmderk = None
 8145    if 'NM' in mm:
 8146        if mm['NM'] is not None:
 8147            nmderk = mm['NM']['NM_dataframe_wide'].iloc[: , 1:]
 8148            for mykey in get_antsimage_keys( mm['NM'] ):
 8149                tempfn = output_prefix + separator + mykey + '.nii.gz'
 8150                image_write_with_thumbnail( mm['NM'][mykey], tempfn, thumb=False )
 8151
 8152    faderk = mdderk = fat1derk = mdt1derk = None
 8153
 8154    if 'DTI' in mm:
 8155        if mm['DTI'] is not None:
 8156            mydti = mm['DTI']
 8157            myop = output_prefix + separator
 8158            ants.image_write( mydti['dti'],  myop + 'dti.nii.gz' )
 8159            write_bvals_bvecs( mydti['bval_LR'], mydti['bvec_LR'], myop + 'reoriented' )
 8160            image_write_with_thumbnail( mydti['dwi_LR_dewarped'],  myop + 'dwi.nii.gz' )
 8161            image_write_with_thumbnail( mydti['dtrecon_LR_dewarp']['RGB'] ,  myop + 'DTIRGB.nii.gz' )
 8162            image_write_with_thumbnail( mydti['jhu_labels'],  myop+'dtijhulabels.nii.gz', mydti['recon_fa'] )
 8163            image_write_with_thumbnail( mydti['recon_fa'],  myop+'dtifa.nii.gz' )
 8164            image_write_with_thumbnail( mydti['recon_md'],  myop+'dtimd.nii.gz' )
 8165            image_write_with_thumbnail( mydti['b0avg'],  myop+'b0avg.nii.gz' )
 8166            image_write_with_thumbnail( mydti['dwiavg'],  myop+'dwiavg.nii.gz' )
 8167            faderk = mm['DTI']['recon_fa_summary'].iloc[: , 1:]
 8168            mdderk = mm['DTI']['recon_md_summary'].iloc[: , 1:]
 8169            fat1derk = mm['FA_summ'].iloc[: , 1:]
 8170            mdt1derk = mm['MD_summ'].iloc[: , 1:]
 8171    if 'tractography' in mm:
 8172        if mm['tractography'] is not None:
 8173            ofn = output_prefix + separator + 'tractogram.trk'
 8174            if mm['tractography']['tractogram'] is not None:
 8175                save_tractogram( mm['tractography']['tractogram'], ofn )
 8176    cnxderk = None
 8177    if 'tractography_connectivity' in mm:
 8178        if mm['tractography_connectivity'] is not None:
 8179            cnxderk = mm['tractography_connectivity']['connectivity_wide'].iloc[: , 1:] # NOTE: connectivity_wide is not much tested
 8180            ofn = output_prefix + separator + 'dtistreamlineconn.csv'
 8181            pd.DataFrame(mm['tractography_connectivity']['connectivity_matrix']).to_csv( ofn )
 8182
 8183    dlist = [
 8184        thkderk,
 8185        kkderk,
 8186        nmderk,
 8187        faderk,
 8188        mdderk,
 8189        fat1derk,
 8190        mdt1derk,
 8191        cnxderk
 8192        ]
 8193    is_all_none = all(element is None for element in dlist)
 8194    if is_all_none:
 8195        mm_wide = pd.DataFrame({'u_hier_id': [output_prefix] })
 8196    else:
 8197        mm_wide = pd.concat( dlist, axis=1, ignore_index=False )
 8198
 8199    mm_wide = mm_wide.copy()
 8200    if 'NM' in mm:
 8201        if mm['NM'] is not None:
 8202            nmwide = dict_to_dataframe( mm['NM'] )
 8203            if mm_wide.shape[0] > 0 and nmwide.shape[0] > 0:
 8204                nmwide.set_index( mm_wide.index, inplace=True )
 8205            mm_wide = pd.concat( [mm_wide, nmwide ], axis=1, ignore_index=False )
 8206    if 'flair' in mm:
 8207        if mm['flair'] is not None:
 8208            myop = output_prefix + separator + 'wmh.nii.gz'
 8209            pngfnb = output_prefix + separator + 'wmh_seg.png'
 8210            ants.plot( mm['flair']['flair'], mm['flair']['WMH_posterior_probability_map'], axis=2, nslices=21, ncol=7, filename=pngfnb, crop=True )
 8211            if mm['flair']['WMH_probability_map'] is not None:
 8212                image_write_with_thumbnail( mm['flair']['WMH_probability_map'], myop, thumb=False )
 8213            flwide = dict_to_dataframe( mm['flair'] )
 8214            if mm_wide.shape[0] > 0 and flwide.shape[0] > 0:
 8215                flwide.set_index( mm_wide.index, inplace=True )
 8216            mm_wide = pd.concat( [mm_wide, flwide ], axis=1, ignore_index=False )
 8217    if 'rsf' in mm:
 8218        if mm['rsf'] is not None:
 8219            fcnxpro=99
 8220            rsfdata = mm['rsf']
 8221            if not isinstance( rsfdata, list ):
 8222                rsfdata = [ rsfdata ]
 8223            for rsfpro in rsfdata:
 8224                fcnxpro=str( rsfpro['paramset']  )
 8225                pronum = 'fcnxpro'+str(fcnxpro)+"_"
 8226                if verbose:
 8227                    print("Collect rsf data " + pronum)
 8228                new_rsf_wide = dict_to_dataframe( rsfpro )
 8229                new_rsf_wide = pd.concat( [new_rsf_wide, rsfpro['corr_wide'] ], axis=1, ignore_index=False )
 8230                new_rsf_wide = new_rsf_wide.add_prefix( pronum )
 8231                new_rsf_wide.set_index( mm_wide.index, inplace=True )
 8232                ofn = output_prefix + separator + pronum + '.csv'
 8233                new_rsf_wide.to_csv( ofn )
 8234                mm_wide = pd.concat( [mm_wide, new_rsf_wide ], axis=1, ignore_index=False )
 8235                for mykey in get_antsimage_keys( rsfpro ):
 8236                    myop = output_prefix + separator + pronum + mykey + '.nii.gz'
 8237                    image_write_with_thumbnail( rsfpro[mykey], myop, thumb=True )
 8238                ofn = output_prefix + separator + pronum + 'rsfcorr.csv'
 8239                rsfpro['corr'].to_csv( ofn )
 8240                # apply same principle to new correlation matrix, doesn't need to be incorporated with mm_wide
 8241                ofn2 = output_prefix + separator + pronum + 'nodescorr.csv'
 8242                rsfpro['fullCorrMat'].to_csv( ofn2 )
 8243    if 'DTI' in mm:
 8244        if mm['DTI'] is not None:
 8245            mydti = mm['DTI']
 8246            mm_wide['dti_tsnr_b0_mean'] =  mydti['tsnr_b0'].mean()
 8247            mm_wide['dti_tsnr_dwi_mean'] =  mydti['tsnr_dwi'].mean()
 8248            mm_wide['dti_dvars_b0_mean'] =  mydti['dvars_b0'].mean()
 8249            mm_wide['dti_dvars_dwi_mean'] =  mydti['dvars_dwi'].mean()
 8250            mm_wide['dti_ssnr_b0_mean'] =  mydti['ssnr_b0'].mean()
 8251            mm_wide['dti_ssnr_dwi_mean'] =  mydti['ssnr_dwi'].mean()
 8252            mm_wide['dti_fa_evr'] =  mydti['fa_evr']
 8253            mm_wide['dti_fa_SNR'] =  mydti['fa_SNR']
 8254            if mydti['framewise_displacement'] is not None:
 8255                mm_wide['dti_high_motion_count'] =  mydti['high_motion_count']
 8256                mm_wide['dti_FD_mean'] = mydti['framewise_displacement'].mean()
 8257                mm_wide['dti_FD_max'] = mydti['framewise_displacement'].max()
 8258                mm_wide['dti_FD_sd'] = mydti['framewise_displacement'].std()
 8259                fdfn = output_prefix + separator + '_fd.csv'
 8260            else:
 8261                mm_wide['dti_FD_mean'] = mm_wide['dti_FD_max'] = mm_wide['dti_FD_sd'] = 'NA'
 8262
 8263    if 'perf' in mm:
 8264        if mm['perf'] is not None:
 8265            perfpro = mm['perf']
 8266            prwide = dict_to_dataframe( perfpro )
 8267            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
 8268                prwide.set_index( mm_wide.index, inplace=True )
 8269            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
 8270            if 'perf_dataframe' in perfpro.keys():
 8271                pderk = perfpro['perf_dataframe'].iloc[: , 1:]
 8272                pderk.set_index( mm_wide.index, inplace=True )
 8273                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
 8274            else:
 8275                print("FIXME - perfusion dataframe")
 8276            for mykey in get_antsimage_keys( mm['perf'] ):
 8277                tempfn = output_prefix + separator + mykey + '.nii.gz'
 8278                image_write_with_thumbnail( mm['perf'][mykey], tempfn, thumb=False )
 8279
 8280    if 'pet3d' in mm:
 8281        if mm['pet3d'] is not None:
 8282            pet3dpro = mm['pet3d']
 8283            prwide = dict_to_dataframe( pet3dpro )
 8284            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
 8285                prwide.set_index( mm_wide.index, inplace=True )
 8286            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
 8287            if 'pet3d_dataframe' in pet3dpro.keys():
 8288                pderk = pet3dpro['pet3d_dataframe'].iloc[: , 1:]
 8289                pderk.set_index( mm_wide.index, inplace=True )
 8290                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
 8291            else:
 8292                print("FIXME - pet3dusion dataframe")
 8293            for mykey in get_antsimage_keys( mm['pet3d'] ):
 8294                tempfn = output_prefix + separator + mykey + '.nii.gz'
 8295                image_write_with_thumbnail( mm['pet3d'][mykey], tempfn, thumb=False )
 8296
 8297    mmwidefn = output_prefix + separator + 'mmwide.csv'
 8298    mm_wide.to_csv( mmwidefn )
 8299    if verbose:
 8300        print( output_prefix + " write_mm done." )
 8301    return
 8302
 8303
 8304def mm_nrg(
 8305    studyid,   # pandas data frame
 8306    sourcedir = os.path.expanduser( "~/data/PPMI/MV/example_s3_b/images/PPMI/" ),
 8307    sourcedatafoldername = 'images', # root for source data
 8308    processDir = "processed", # where output will go - parallel to sourcedatafoldername
 8309    mysep = '-', # define a separator for filename components
 8310    srmodel_T1 = False, # optional - will add a great deal of time
 8311    srmodel_NM = False, # optional - will add a great deal of time
 8312    srmodel_DTI = False, # optional - will add a great deal of time
 8313    visualize = True,
 8314    nrg_modality_list = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI" ],
 8315    verbose = True
 8316):
 8317    """
 8318    too dangerous to document ... use with care.
 8319
 8320    processes multiple modality MRI specifically:
 8321
 8322    * T1w
 8323    * T2Flair
 8324    * DTI, DTI_LR, DTI_RL
 8325    * rsfMRI, rsfMRI_LR, rsfMRI_RL
 8326    * NM2DMT (neuromelanin)
 8327
 8328    other modalities may be added later ...
 8329
 8330    "trust me, i know what i'm doing" - sledgehammer
 8331
    convert to ipynb via:
 8333        p2j mm.py -o
 8334
 8335    convert the ipynb to html via:
 8336        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
 8337
 8338    this function assumes NRG format for the input data ....
 8339    we also assume that t1w hierarchical (if already done) was written
 8340    via its standardized write function.
 8341    NRG = https://github.com/stnava/biomedicalDataOrganization
 8342
 8343    this function is verbose
 8344
 8345    Parameters
 8346    -------------
 8347
 8348    studyid : must have columns 1. subjectID 2. date (in form 20220228) and 3. imageID
 8349        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
 8350        these provide unique image IDs for these modalities: nm=neuromelanin, dti=diffusion tensor,
 8351        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
 8352        t1 is required.  rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.  see antspymm.generate_mm_dataframe
 8353
 8354    sourcedir : a study specific folder containing individual subject folders
 8355
 8356    sourcedatafoldername : root for source data e.g. "images"
 8357
 8358    processDir : where output will go - parallel to sourcedatafoldername e.g.
 8359        "processed"
 8360
 8361    mysep : define a character separator for filename components
 8362
 8363    srmodel_T1 : False (default) - will add a great deal of time - or h5 filename, 2 chan
 8364
 8365    srmodel_NM : False (default) - will add a great deal of time - or h5 filename, 1 chan
 8366
 8367    srmodel_DTI : False (default) - will add a great deal of time - or h5 filename, 1 chan
 8368
 8369    visualize : True - will plot some results to png
 8370
 8371    nrg_modality_list : list of permissible modalities - always include [T1w] as base
 8372
 8373    verbose : boolean
 8374
 8375    Returns
 8376    ---------
 8377
 8378    writes output to disk and potentially produces figures that may be
    captured in an ipynb / html file.
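
    Example
    ---------

    a minimal sketch; the subject, date and image IDs below are hypothetical:

    >>> import antspymm, pandas as pd
    >>> studyid = pd.DataFrame({ 'subjectID': ['101018'], 'date': ['20210412'],
    ...     'imageID': ['1497590'] })
    >>> antspymm.mm_nrg( studyid, sourcedir='/data/study/images/PPMI/', visualize=False )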
 8380
 8381    """
 8382    studyid = studyid.dropna(axis=1)
 8383    if studyid.shape[0] < 1:
 8384        raise ValueError('studyid has no rows')
 8385    musthavecols = ['subjectID','date','imageID']
 8386    for k in range(len(musthavecols)):
 8387        if not musthavecols[k] in studyid.keys():
 8388            raise ValueError('studyid is missing column ' +musthavecols[k] )
    def makewideout( x, separator = mysep ):
 8390        return x + separator + 'mmwide.csv'
 8391    if nrg_modality_list[0] != 'T1w':
 8392        nrg_modality_list.insert(0, "T1w" )
 8393    testloop = False
 8394    counter=0
 8395    import glob as glob
 8396    from os.path import exists
 8397    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 8398    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
 8399    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 8400    if not exists( templatefn ):
 8401        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
 8402        antspyt1w.get_data( force_download=True )
 8403        get_data( force_download=True )
 8404    temp = sourcedir.split( "/" )
 8405    splitCount = len( temp )
 8406    template = mm_read( templatefn ) # Read in template
 8407    test_run = False
 8408    if test_run:
 8409        visualize=False
 8410    # get sid and dtid from studyid
 8411    sid = str(studyid['subjectID'].iloc[0])
 8412    dtid = str(studyid['date'].iloc[0])
 8413    iid = str(studyid['imageID'].iloc[0])
 8414    subjectrootpath = os.path.join(sourcedir,sid, dtid)
 8415    if verbose:
 8416        print("subjectrootpath: "+ subjectrootpath )
 8417    myimgsInput = glob.glob( subjectrootpath+"/*" )
 8418    myimgsInput.sort( )
 8419    if verbose:
 8420        print( myimgsInput )
 8421    # hierarchical
 8422    # NOTE: if there are multiple T1s for this time point, should take
 8423    # the one with the highest resnetGrade
 8424    t1_search_path = os.path.join(subjectrootpath, "T1w", iid, "*nii.gz")
 8425    if verbose:
 8426        print(f"t1 search path: {t1_search_path}")
 8427    t1fn = glob.glob(t1_search_path)
 8428    t1fn.sort()
 8429    if len( t1fn ) < 1:
 8430        raise ValueError('mm_nrg cannot find the T1w with uid ' + iid + ' @ ' + subjectrootpath )
 8431    t1fn = t1fn[0]
 8432    t1 = mm_read( t1fn )
 8433    hierfn0 = re.sub( sourcedatafoldername, processDir, t1fn)
 8434    hierfn0 = re.sub( ".nii.gz", "", hierfn0)
 8435    hierfn = re.sub( "T1w", "T1wHierarchical", hierfn0)
 8436    hierfn = hierfn + mysep
 8437    hierfntest = hierfn + 'snseg.csv'
 8438    regout = hierfn0 + mysep + "syn"
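    # transforms to the group template are expected at this prefix; if the files
    # already exist they are reused, otherwise they are computed in the T1w branch below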
 8439    templateTx = {
 8440        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
 8441        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
 8442    if verbose:
 8443        print( "-<REGISTRATION EXISTENCE>-: \n" + 
 8444              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
 8445            str(exists( templateTx['fwdtransforms'][0])) + " " +
 8446            str(exists( templateTx['fwdtransforms'][1])) + " " +
 8447            str(exists( templateTx['invtransforms'][0])) + " " +
 8448            str(exists( templateTx['invtransforms'][1])) )
 8449    if verbose:
 8450        print( hierfntest )
 8451    hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
 8452    hier = None
 8453    if not hierexists and not testloop:
 8454        subjectpropath = os.path.dirname( hierfn )
 8455        if verbose:
 8456            print( subjectpropath )
 8457        os.makedirs( subjectpropath, exist_ok=True  )
 8458        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
 8459        antspyt1w.write_hierarchical( hier, hierfn )
 8460        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8461                hier['dataframes'], identifier=None )
 8462        t1wide.to_csv( hierfn + 'mmwide.csv' )
 8463    ################# read the hierarchical data ###############################
 8464    hier = antspyt1w.read_hierarchical( hierfn )
 8465    if exists( hierfn + 'mmwide.csv' ) :
 8466        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
 8467    elif not testloop:
 8468        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8469                hier['dataframes'], identifier=None )
 8470    if srmodel_T1 is not False :
 8471        hierfnSR = re.sub( sourcedatafoldername, processDir, t1fn)
 8472        hierfnSR = re.sub( "T1w", "T1wHierarchicalSR", hierfnSR)
 8473        hierfnSR = re.sub( ".nii.gz", "", hierfnSR)
 8474        hierfnSR = hierfnSR + mysep
 8475        hierfntest = hierfnSR + 'mtl.csv'
 8476        if verbose:
 8477            print( hierfntest )
 8478        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
 8479        if not hierexists:
 8480            subjectpropath = os.path.dirname( hierfnSR )
 8481            if verbose:
 8482                print( subjectpropath )
 8483            os.makedirs( subjectpropath, exist_ok=True  )
 8484            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
 8485            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
 8486            mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_2chan_featvggL6_postseg_best_mdl.h5"
 8487            if isinstance( srmodel_T1, str ):
 8488                mdlfn = os.path.join( ex_pathmm, srmodel_T1 )
 8489            if verbose:
 8490                print( mdlfn )
 8491            if exists( mdlfn ):
 8492                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 8493            else:
 8494                print( mdlfn + " does not exist - will not run.")
 8495            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
 8496            antspyt1w.write_hierarchical( hierSR, hierfnSR )
 8497            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8498                    hierSR['dataframes'], identifier=None )
 8499            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
 8500    hier = antspyt1w.read_hierarchical( hierfn )
 8501    if exists( hierfn + 'mmwide.csv' ) :
 8502        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
 8503    elif not testloop:
 8504        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 8505                hier['dataframes'], identifier=None )
 8506    if not testloop:
 8507        t1imgbrn = hier['brain_n4_dnz']
 8508        t1atropos = hier['dkt_parc']['tissue_segmentation']
 8509    # loop over modalities and then unique image IDs
 8510    # we treat NM in a "special" way -- aggregating repeats
 8511    # other modalities (beyond T1) are treated individually
 8512    nimages = len(myimgsInput)
 8513    if verbose:
 8514        print(  " we have : " + str(nimages) + " modalities.")
 8515    for overmodX in nrg_modality_list:
 8516        counter=counter+1
 8517        if counter > (len(nrg_modality_list)+1):
 8518            print("This is weird. " + str(counter))
 8519            return
 8520        if overmodX == 'T1w':
 8521            iidOtherMod = iid
 8522            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8523            myimgsr = glob.glob(mod_search_path)
 8524        elif overmodX == 'NM2DMT' and ('nmid1' in studyid.keys() ):
 8525            iidOtherMod = str( int(studyid['nmid1'].iloc[0]) )
 8526            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8527            myimgsr = glob.glob(mod_search_path)
 8528            for nmnum in range(2,11):
 8529                locnmnum = 'nmid'+str(nmnum)
 8530                if locnmnum in studyid.keys() :
 8531                    iidOtherMod = str( int(studyid[locnmnum].iloc[0]) )
 8532                    mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8533                    myimgsr.append( glob.glob(mod_search_path)[0] )
 8534        elif 'rsfMRI' in overmodX and ( ( 'rsfid1' in studyid.keys() ) or ('rsfid2' in studyid.keys() ) ):
 8535            myimgsr = []
 8536            if  'rsfid1' in studyid.keys():
 8537                iidOtherMod = str( int(studyid['rsfid1'].iloc[0]) )
 8538                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8539                myimgsr.append( glob.glob(mod_search_path)[0] )
 8540            if  'rsfid2' in studyid.keys():
 8541                iidOtherMod = str( int(studyid['rsfid2'].iloc[0]) )
 8542                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8543                myimgsr.append( glob.glob(mod_search_path)[0] )
 8544        elif 'DTI' in overmodX and (  'dtid1' in studyid.keys() or  'dtid2' in studyid.keys() ):
 8545            myimgsr = []
 8546            if  'dtid1' in studyid.keys():
 8547                iidOtherMod = str( int(studyid['dtid1'].iloc[0]) )
 8548                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8549                myimgsr.append( glob.glob(mod_search_path)[0] )
 8550            if  'dtid2' in studyid.keys():
 8551                iidOtherMod = str( int(studyid['dtid2'].iloc[0]) )
 8552                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
 8553                myimgsr.append( glob.glob(mod_search_path)[0] )
 8554        elif 'T2Flair' in overmodX and ('flairid' in studyid.keys() ):
 8555            iidOtherMod = str( int(studyid['flairid'].iloc[0]) )
 8556            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
 8557            myimgsr = glob.glob(mod_search_path)
 8558        if verbose:
 8559            print( "overmod " + overmodX + " " + iidOtherMod )
 8560            print(f"modality search path: {mod_search_path}")
 8561        myimgsr.sort()
 8562        if len(myimgsr) > 0:
 8563            overmodXx = str(overmodX)
 8564            dowrite=False
 8565            if verbose:
 8566                print( 'overmodX is : ' + overmodXx )
 8567                print( 'example image name is : '  )
 8568                print( myimgsr )
 8569            if overmodXx == 'NM2DMT':
 8570                myimgsr2 = myimgsr
 8571                myimgsr2.sort()
 8572                is4d = False
 8573                temp = ants.image_read( myimgsr2[0] )
 8574                if temp.dimension == 4:
 8575                    is4d = True
 8576                if len( myimgsr2 ) == 1 and not is4d: # check dimension
 8577                    myimgsr2 = myimgsr2 + myimgsr2
 8578                subjectpropath = os.path.dirname( myimgsr2[0] )
 8579                subjectpropath = re.sub( sourcedatafoldername, processDir,subjectpropath )
 8580                if verbose:
 8581                    print( "subjectpropath " + subjectpropath )
 8582                mysplit = subjectpropath.split( "/" )
 8583                os.makedirs( subjectpropath, exist_ok=True  )
 8584                mysplitCount = len( mysplit )
 8585                project = mysplit[mysplitCount-5]
 8586                subject = mysplit[mysplitCount-4]
 8587                date = mysplit[mysplitCount-3]
 8588                modality = mysplit[mysplitCount-2]
 8589                uider = mysplit[mysplitCount-1]
 8590                identifier = mysep.join([project, subject, date, modality ])
 8591                identifier = identifier + "_" + iid
 8592                mymm = subjectpropath + "/" + identifier
 8593                mymmout = makewideout( mymm )
 8594                if verbose and not exists( mymmout ):
 8595                    print( "NM " + mymm  + ' execution ')
 8596                elif verbose and exists( mymmout ) :
 8597                    print( "NM " + mymm + ' complete ' )
 8598                if exists( mymmout ):
 8599                    continue
 8600                if is4d:
 8601                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
 8602                else:
 8603                    nmlist = []
 8604                    for zz in myimgsr2:
 8605                        nmlist.append( mm_read( zz ) )
 8606                srmodel_NM_mdl = None
 8607                if srmodel_NM is not False:
 8608                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
 8609                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
 8610                    if isinstance( srmodel_NM, str ):
 8611                        srmodel_NM = re.sub( "bestup", bestup, srmodel_NM )
 8612                        mdlfn = os.path.join( ex_pathmm, srmodel_NM )
 8613                    if exists( mdlfn ):
 8614                        if verbose:
 8615                            print(mdlfn)
 8616                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
 8617                    else:
 8618                        print( mdlfn + " does not exist - wont use SR")
 8619                if not testloop:
 8620                    tabPro, normPro = mm( t1, hier,
 8621                            nm_image_list = nmlist,
 8622                            srmodel=srmodel_NM_mdl,
 8623                            do_tractography=False,
 8624                            do_kk=False,
 8625                            do_normalization=templateTx,
 8626                            test_run=test_run,
 8627                            verbose=True )
 8628                    if not test_run:
 8629                        write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=None, separator=mysep )
 8630                        nmpro = tabPro['NM']
 8631                        mysl = range( nmpro['NM_avg'].shape[2] )
 8632                    if visualize:
 8633                        mysl = range( nmpro['NM_avg'].shape[2] )
 8634                        ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
 8635                        mysl = range( nmpro['NM_avg_cropped'].shape[2] )
 8636                        ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
 8637                        ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
 8638                        ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
 8639            else :
 8640                if len( myimgsr ) > 0:
 8641                    dowrite=False
 8642                    myimgcount = 0
 8643                    if len( myimgsr ) > 0 :
 8644                        myimg = myimgsr[myimgcount]
 8645                        subjectpropath = os.path.dirname( myimg )
 8646                        subjectpropath = re.sub( sourcedatafoldername, processDir, subjectpropath )
 8647                        mysplit = subjectpropath.split("/")
 8648                        mysplitCount = len( mysplit )
                        # NRG path order is project/subject/date/modality/uid
                        project = mysplit[mysplitCount-5]
                        subject = mysplit[mysplitCount-4]
                        date = mysplit[mysplitCount-3]
                        mymod = mysplit[mysplitCount-2] # FIXME system dependent
                        uid = mysplit[mysplitCount-1] # unique image id
                        os.makedirs( subjectpropath, exist_ok=True  )
                        if mymod == 'T1w':
                            identifier = mysep.join([project, subject, date, mymod, uid])
                        else:  # add the T1 unique id since that drives a lot of the analysis
                            identifier = mysep.join([project, subject, date, mymod, uid ])
                            identifier = identifier + "_" + iid
 8660                        mymm = subjectpropath + "/" + identifier
 8661                        mymmout = makewideout( mymm )
 8662                        if verbose and not exists( mymmout ):
 8663                            print("Modality specific processing: " + mymod + " execution " )
 8664                            print( mymm )
 8665                        elif verbose and exists( mymmout ) :
 8666                            print("Modality specific processing: " + mymod + " complete " )
 8667                        if exists( mymmout ) :
 8668                            continue
 8669                        if verbose:
 8670                            print(subjectpropath)
 8671                            print(identifier)
 8672                            print( myimg )
 8673                        if not testloop:
 8674                            img = mm_read( myimg )
 8675                            ishapelen = len( img.shape )
 8676                            if mymod == 'T1w' and ishapelen == 3: # for a real run, set to True
 8677                                if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
 8678                                    if verbose:
 8679                                        print('start t1 registration')
 8680                                    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 8681                                    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 8682                                    template = mm_read( templatefn )
 8683                                    template = ants.resample_image( template, [1,1,1], use_voxels=False )
 8684                                    t1reg = ants.registration( template, hier['brain_n4_dnz'],
 8685                                        "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
 8686                                    myjac = ants.create_jacobian_determinant_image( template,
 8687                                        t1reg['fwdtransforms'][0], do_log=True, geom=True )
 8688                                    image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
 8689                                    if visualize:
 8690                                        ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
 8691                                        ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
 8692                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
 8693                                    dowrite=True
 8694                                    if verbose:
 8695                                        print('start kk')
 8696                                    tabPro, normPro = mm( t1, hier,
 8697                                        srmodel=None,
 8698                                        do_tractography=False,
 8699                                        do_kk=True,
 8700                                        do_normalization=templateTx,
 8701                                        test_run=test_run,
 8702                                        verbose=True )
 8703                                    if visualize:
 8704                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
 8705                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
 8706                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
 8707                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
 8708                            if mymod == 'T2Flair' and ishapelen == 3:
 8709                                dowrite=True
 8710                                tabPro, normPro = mm( t1, hier,
 8711                                    flair_image = img,
 8712                                    srmodel=None,
 8713                                    do_tractography=False,
 8714                                    do_kk=False,
 8715                                    do_normalization=templateTx,
 8716                                    test_run=test_run,
 8717                                    verbose=True )
 8718                                if visualize:
 8719                                    maxslice = np.min( [21, img.shape[2] ] )
 8720                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
 8721                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
 8722                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
 8723                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
 8724                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
 8725                                img2 = None
 8726                                if len( myimgsr ) > 1:
 8727                                    img2 = mm_read( myimgsr[myimgcount+1] )
 8728                                    ishapelen2 = len( img2.shape )
 8729                                    if ishapelen2 != 4 :
 8730                                        img2 = None
 8731                                dowrite=True
 8732                                tabPro, normPro = mm( t1, hier,
 8733                                    rsf_image=[img,img2],
 8734                                    srmodel=None,
 8735                                    do_tractography=False,
 8736                                    do_kk=False,
 8737                                    do_normalization=templateTx,
 8738                                    test_run=test_run,
 8739                                    verbose=True )
 8740                                if tabPro['rsf'] is not None and visualize:
 8741                                    dfn=tabPro['rsf']['dfnname']
 8742                                    maxslice = np.min( [21, tabPro['rsf']['meanBold'].shape[2] ] )
 8743                                    ants.plot( tabPro['rsf']['meanBold'],
 8744                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=mymm+mysep+"meanBOLD.png" )
 8745                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['alff'],"Normalize"),
 8746                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=mymm+mysep+"boldALFF.png" )
 8747                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['falff'],"Normalize"),
 8748                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=mymm+mysep+"boldfALFF.png" )
 8749                                    ants.plot( tabPro['rsf']['meanBold'], tabPro['rsf'][dfn],
 8750                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='DefaultMode', filename=mymm+mysep+"boldDefaultMode.png" )
 8751                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
 8752                                dowrite=True
 8753                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
 8754                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
 8755                                imgList = [ img ]
 8756                                bvalfnList = [ bvalfn ]
 8757                                bvecfnList = [ bvecfn ]
 8758                                if len( myimgsr ) > 1:  # find DTI_RL
 8759                                    dtilrfn = myimgsr[myimgcount+1]
                                    if exists( dtilrfn ): # dtilrfn is a single filename string
 8761                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
 8762                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
 8763                                        imgRL = ants.image_read( dtilrfn )
 8764                                        imgList.append( imgRL )
 8765                                        bvalfnList.append( bvalfnRL )
 8766                                        bvecfnList.append( bvecfnRL )
 8767                                srmodel_DTI_mdl=None
 8768                                if srmodel_DTI is not False:
 8769                                    temp = ants.get_spacing(img)
 8770                                    dtspc=[temp[0],temp[1],temp[2]]
 8771                                    bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
 8772                                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
 8773                                    if isinstance( srmodel_DTI, str ):
 8774                                        srmodel_DTI = re.sub( "bestup", bestup, srmodel_DTI )
 8775                                        mdlfn = os.path.join( ex_pathmm, srmodel_DTI )
 8776                                    if exists( mdlfn ):
 8777                                        if verbose:
 8778                                            print(mdlfn)
 8779                                        srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 8780                                    else:
 8781                                        print(mdlfn + " does not exist - wont use SR")
 8782                                tabPro, normPro = mm( t1, hier,
 8783                                    dw_image=imgList,
 8784                                    bvals = bvalfnList,
 8785                                    bvecs = bvecfnList,
 8786                                    srmodel=srmodel_DTI_mdl,
 8787                                    do_tractography=not test_run,
 8788                                    do_kk=False,
 8789                                    do_normalization=templateTx,
 8790                                    test_run=test_run,
 8791                                    verbose=True )
 8792                                mydti = tabPro['DTI']
 8793                                if visualize:
                                    maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
 8795                                    ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
 8796                                    ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
 8797                                    ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
 8798                            if dowrite:
 8799                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep, verbose=True )
 8800                                for mykey in normPro.keys():
 8801                                    if normPro[mykey] is not None:
 8802                                        if visualize and normPro[mykey].components == 1 and False:
 8803                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
        if overmodX == nrg_modality_list[-1]:
 8805            return
 8806        if verbose:
 8807            print("done with " + overmodX )
 8808    if verbose:
 8809        print("mm_nrg complete.")
 8810    return
 8811
 8812
 8813
 8814def mm_csv(
 8815    studycsv,   # pandas data frame
 8816    mysep = '-', # or "_" for BIDS
 8817    srmodel_T1 = False, # optional - will add a great deal of time
 8818    srmodel_NM = False, # optional - will add a great deal of time
 8819    srmodel_DTI = False, # optional - will add a great deal of time
 8820    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
 8821    dti_denoise = True,
 8822    nrg_modality_list = None,
 8823    normalization_template = None,
 8824    normalization_template_output = None,
 8825    normalization_template_transform_type = "antsRegistrationSyNRepro[s]",
 8826    normalization_template_spacing=None,
 8827    enantiomorphic=False,
 8828    perfusion_trim = 10,
 8829    perfusion_m0_image = None,
 8830    perfusion_m0 = None,
 8831    rsf_upsampling = 3.0,
 8832    pet3d = None,
 8833):
 8834    """
 8835    too dangerous to document ... use with care.
 8836
 8837    processes multiple modality MRI specifically:
 8838
 8839    * T1w
 8840    * T2Flair
 8841    * DTI, DTI_LR, DTI_RL
 8842    * rsfMRI, rsfMRI_LR, rsfMRI_RL
 8843    * NM2DMT (neuromelanin)
 8844
 8845    other modalities may be added later ...
 8846
 8847    "trust me, i know what i'm doing" - sledgehammer
 8848
    convert to ipynb via:
 8850        p2j mm.py -o
 8851
 8852    convert the ipynb to html via:
 8853        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
 8854
 8855    this function does not assume NRG format for the input data ....
 8856
 8857    Parameters
 8858    -------------
 8859
 8860    studycsv : must have columns:
 8861        - subjectID
 8862        - date or session
 8863        - imageID
 8864        - modality
 8865        - sourcedir
 8866        - outputdir
 8867        - filename (path to the t1 image)
 8868        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
 8869        these provide filenames for these modalities: nm=neuromelanin, dti=diffusion tensor,
 8870        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
 8871        t1 is required. rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.
 8872        see antspymm.generate_mm_dataframe
 8873
 8874    sourcedir : a study specific folder containing individual subject folders
 8875
 8876    outputdir : a study specific folder where individual output subject folders will go
 8877
 8878    filename : the raw image filename (full path)
 8879
 8880    srmodel_T1 : False (default) - will add a great deal of time - or h5 filename, 2 chan
 8881
 8882    srmodel_NM : False (default) - will add a great deal of time - or h5 filename, 1 chan
 8883
 8884    srmodel_DTI : False (default) - will add a great deal of time - or h5 filename, 1 chan
 8885
    dti_motion_correct : None (to skip motion correction) or a transform type string such as
        'Rigid', 'SyN' or the default 'antsRegistrationSyNQuickRepro[r]'
 8887
 8888    dti_denoise : boolean
 8889
 8890    nrg_modality_list : optional; defaults to None; use to focus on a given modality
 8891
 8892    normalization_template : optional; defaults to None; if present, all images will
 8893        be deformed into this space and the deformation will be stored with an extension
 8894        related to this variable.  this should be a brain extracted T1w image.
 8895
 8896    normalization_template_output : optional string; defaults to None; naming for the 
 8897        normalization_template outputs which will be in the T1w directory.
 8898
 8899    normalization_template_transform_type : optional string transform type passed to ants.registration
 8900
 8901    normalization_template_spacing : 3-tuple controlling the resolution at which registration is computed 
 8902    
 8903    enantiomorphic: boolean (WIP)
 8904
 8905    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
 8906
 8907    perfusion_m0_image : optional m0 antsImage associated with the perfusion time series
 8908
 8909    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
 8910
 8911    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
 8912
 8913    pet3d : optional antsImage for PET (or other 3d scalar) data which we want to summarize
 8914
 8915    Returns
 8916    ---------
 8917
 8918    writes output to disk and produces figures
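
    Example (a minimal sketch; the project, subject, date, image id and file paths
    below are hypothetical and must point at real data in practice):

    import pandas as pd
    import antspymm
    studycsv = pd.DataFrame({
        'projectID': ['PPMI'],
        'subjectID': ['1001'],
        'date': ['20230101'],
        'imageID': ['123456'],
        'modality': ['T1w'],
        'sourcedir': ['/data/PPMI/source'],
        'outputdir': ['/data/PPMI/processed'],
        'filename': ['/data/PPMI/source/t1.nii.gz'] })
    antspymm.mm_csv( studycsv )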
 8919
 8920    """
 8921    import traceback
 8922    visualize = True
 8923    verbose = True
 8924    if verbose:
 8925        print( version() )
 8926    if nrg_modality_list is None:
 8927        nrg_modality_list = get_valid_modalities()
 8928    if studycsv.shape[0] < 1:
 8929        raise ValueError('studycsv has no rows')
 8930    musthavecols = ['projectID', 'subjectID','date','imageID','modality','sourcedir','outputdir','filename']
 8931    for k in range(len(musthavecols)):
 8932        if not musthavecols[k] in studycsv.keys():
 8933            raise ValueError('studycsv is missing column ' +musthavecols[k] )
 8934    def makewideout( x, separator = mysep ):
 8935        return x + separator + 'mmwide.csv'
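    # e.g. (hypothetical prefix) makewideout( 'PPMI-1001-20230101-T1w-123456' )
    #      -> 'PPMI-1001-20230101-T1w-123456-mmwide.csv'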
 8936    testloop = False
 8937    counter=0
 8938    import glob as glob
 8939    from os.path import exists
 8940    ex_path = os.path.expanduser( "~/.antspyt1w/" )
 8941    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
 8942    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 8943    if not exists( templatefn ):
 8944        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
 8945        antspyt1w.get_data( force_download=True )
 8946        get_data( force_download=True )
 8947    template = mm_read( templatefn ) # Read in template
 8948    test_run = False
 8949    if test_run:
 8950        visualize=False
 8951    # get sid and dtid from studycsv
 8952    # musthavecols = ['projectID','subjectID','date','imageID','modality','sourcedir','outputdir','filename']
 8953    projid = str(studycsv['projectID'].iloc[0])
 8954    sid = str(studycsv['subjectID'].iloc[0])
 8955    dtid = str(studycsv['date'].iloc[0])
 8956    iid = str(studycsv['imageID'].iloc[0])
 8957    t1iidUse=iid
 8958    modality = str(studycsv['modality'].iloc[0])
 8959    sourcedir = str(studycsv['sourcedir'].iloc[0])
 8960    outputdir = str(studycsv['outputdir'].iloc[0])
 8961    filename = str(studycsv['filename'].iloc[0])
 8962    if not exists(filename):
        raise ValueError('mm_csv cannot find filename ' + filename )
 8964
 8965    # hierarchical
 8966    # NOTE: if there are multiple T1s for this time point, should take
 8967    # the one with the highest resnetGrade
 8968    t1fn = filename
 8969    if not exists( t1fn ):
        raise ValueError('mm_csv cannot find the T1w file ' + t1fn )
 8971    t1 = mm_read( t1fn, modality='T1w' )
 8972    minspc = np.min(ants.get_spacing(t1))
 8973    minshape = np.min(t1.shape)
 8974    if minspc < 1e-16:
 8975        warnings.warn('minimum spacing in T1w is too small - cannot process. ' + str(minspc) )
 8976        return
 8977    if minshape < 32:
 8978        warnings.warn('minimum shape in T1w is too small - cannot process. ' + str(minshape) )
 8979        return
 8980
 8981    if enantiomorphic:
 8982        t1 = enantiomorphic_filling_without_mask( t1, axis=0 )[0]
 8983    hierfn = outputdir + "/"  + projid + "/" + sid + "/" + dtid + "/" + "T1wHierarchical" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchical" + mysep + iid + mysep
 8984    hierfnSR = outputdir + "/" + projid + "/"  + sid + "/" + dtid + "/" + "T1wHierarchicalSR" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchicalSR" + mysep + iid + mysep
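    # hierfn / hierfnSR follow the NRG-style output layout
    #   outputdir/projectID/subjectID/date/<Modality>/imageID/
    #     projectID-subjectID-date-<Modality>-imageID-        (when mysep == '-')
    # e.g. (hypothetical) /data/processed/PPMI/1001/20230101/T1wHierarchical/123456/
    #     PPMI-1001-20230101-T1wHierarchical-123456-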
 8985    hierfntest = hierfn + 'cerebellum.csv'
 8986    if verbose:
 8987        print( hierfntest )
 8988    regout = re.sub("T1wHierarchical","T1w",hierfn) + "syn"
 8989    templateTx = {
 8990        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
 8991        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
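    # templateTx records the filenames that ants.registration writes when called with
    # outprefix=regout and a SyN-type transform: 0GenericAffine.mat plus the forward and
    # inverse warp fields; existence of these files is checked below to decide whether
    # the template registration has already been computed and can be reused.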
 8992    groupTx = None
 8993    # make the T1w directory
 8994    os.makedirs( os.path.dirname(re.sub("T1wHierarchical","T1w",hierfn)), exist_ok=True  )
 8995    if normalization_template_output is not None:
 8996        normout = re.sub("T1wHierarchical","T1w",hierfn) +  normalization_template_output
 8997        templateNormTx = {
 8998            'fwdtransforms': [ normout+'1Warp.nii.gz', normout+'0GenericAffine.mat'],
 8999            'invtransforms': [ normout+'0GenericAffine.mat', normout+'1InverseWarp.nii.gz']  }
 9000        groupTx = templateNormTx['fwdtransforms']
 9001    if verbose:
 9002        print( "-<REGISTRATION EXISTENCE>-: \n" + 
 9003              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
 9004            str(exists( templateTx['fwdtransforms'][0])) + " " +
 9005            str(exists( templateTx['fwdtransforms'][1])) + " " +
 9006            str(exists( templateTx['invtransforms'][0])) + " " +
 9007            str(exists( templateTx['invtransforms'][1])) )
 9008    if verbose:
 9009        print( hierfntest )
 9010    hierexists = exists( hierfntest ) and exists( templateTx['fwdtransforms'][0]) and exists( templateTx['fwdtransforms'][1]) and exists( templateTx['invtransforms'][0]) and exists( templateTx['invtransforms'][1])
 9011    hier = None
 9012    if not hierexists and not testloop:
 9013        subjectpropath = os.path.dirname( hierfn )
 9014        if verbose:
 9015            print( subjectpropath )
 9016        os.makedirs( subjectpropath, exist_ok=True  )
 9017        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
 9018        antspyt1w.write_hierarchical( hier, hierfn )
 9019        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9020                hier['dataframes'], identifier=None )
 9021        t1wide.to_csv( hierfn + 'mmwide.csv' )
 9022    ################# read the hierarchical data ###############################
 9023    # over-write the rbp data with a consistent and recent approach ############
 9024    redograding = True
 9025    if redograding:
 9026        myx = antspyt1w.inspect_raw_t1( t1, hierfn + 'rbp' , option='both' )
 9027        myx['brain'].to_csv( hierfn + 'rbp.csv', index=False )
 9028        myx['brain'].to_csv( hierfn + 'rbpbrain.csv', index=False )
 9029        del myx
 9030
 9031    hier = antspyt1w.read_hierarchical( hierfn )
 9032    t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9033        hier['dataframes'], identifier=None )
 9034    rgrade = str( t1wide['resnetGrade'].iloc[0] )
 9035    if t1wide['resnetGrade'].iloc[0] < 0.20:
 9036        warnings.warn('T1w quality check indicates failure: ' + rgrade + " will not process." )
 9037        return
 9038    else:
 9039        print('T1w quality check indicates success: ' + rgrade + " will process." )
 9040
 9041    if srmodel_T1 is not False :
 9042        hierfntest = hierfnSR + 'mtl.csv'
 9043        if verbose:
 9044            print( hierfntest )
 9045        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
 9046        if not hierexists:
 9047            subjectpropath = os.path.dirname( hierfnSR )
 9048            if verbose:
 9049                print( subjectpropath )
 9050            os.makedirs( subjectpropath, exist_ok=True  )
 9051            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
 9052            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
 9053            mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_2chan_featvggL6_postseg_best_mdl.h5"
 9054            if isinstance( srmodel_T1, str ):
 9055                mdlfn = os.path.join( ex_pathmm, srmodel_T1 )
 9056            if verbose:
 9057                print( mdlfn )
            srmodel_T1_mdl = None
            if exists( mdlfn ):
                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
            else:
                print( mdlfn + " does not exist - will not run SR.")
            if srmodel_T1_mdl is not None:
                hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
                antspyt1w.write_hierarchical( hierSR, hierfnSR )
                t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
                        hierSR['dataframes'], identifier=None )
                t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
 9067    hier = antspyt1w.read_hierarchical( hierfn )
 9068    if exists( hierfn + 'mmwide.csv' ) :
 9069        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
 9070    elif not testloop:
 9071        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
 9072                hier['dataframes'], identifier=None )
 9073    if not testloop:
 9074        t1imgbrn = hier['brain_n4_dnz']
 9075        t1atropos = hier['dkt_parc']['tissue_segmentation']
 9076
 9077    if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
 9078        if verbose:
 9079            print('start t1 registration')
 9080        ex_path = os.path.expanduser( "~/.antspyt1w/" )
 9081        templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
 9082        template = mm_read( templatefn )
 9083        template = ants.resample_image( template, [1,1,1], use_voxels=False )
 9084        t1reg = ants.registration( template, 
 9085            hier['brain_n4_dnz'],
 9086            "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
 9087        myjac = ants.create_jacobian_determinant_image( template,
 9088            t1reg['fwdtransforms'][0], do_log=True, geom=True )
 9089        image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
 9090        if visualize:
 9091            ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
 9092            ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
 9093
 9094    if normalization_template_output is not None and normalization_template is not None:
 9095        if verbose:
 9096            print("begin group template registration")
 9097        if not exists( normout+'0GenericAffine.mat' ):
 9098            if normalization_template_spacing is not None:
 9099                normalization_template_rr=ants.resample_image(normalization_template,normalization_template_spacing)
 9100            else:
 9101                normalization_template_rr=normalization_template
 9102            greg = ants.registration( 
 9103                normalization_template_rr, 
 9104                hier['brain_n4_dnz'],
 9105                normalization_template_transform_type,
 9106                outprefix = normout, verbose=False )
            myjac = ants.create_jacobian_determinant_image( normalization_template_rr,
                    greg['fwdtransforms'][0], do_log=True, geom=True )
 9109            image_write_with_thumbnail( myjac, normout + "logjacobian.nii.gz", thumb=False )
 9110            if verbose:
 9111                print("end group template registration")
 9112        else:
 9113            if verbose:
 9114                print("group template registration already done")
 9115
 9116    # loop over modalities and then unique image IDs
 9117    # we treat NM in a "special" way -- aggregating repeats
 9118    # other modalities (beyond T1) are treated individually
 9119    for overmodX in nrg_modality_list:
 9120        # define 1. input images 2. output prefix
 9121        mydoc = docsamson( overmodX, studycsv=studycsv, outputdir=outputdir, projid=projid, sid=sid, dtid=dtid, mysep=mysep,t1iid=t1iidUse )
 9122        myimgsr = mydoc['images']
 9123        mymm = mydoc['outprefix']
 9124        mymod = mydoc['modality']
 9125        if verbose:
 9126            print( mydoc )
 9127        if len(myimgsr) > 0:
 9128            dowrite=False
 9129            if verbose:
 9130                print( 'overmodX is : ' + overmodX )
 9131                print( 'example image name is : '  )
 9132                print( myimgsr )
 9133            if overmodX == 'NM2DMT':
 9134                dowrite = True
 9135                visualize = True
 9136                subjectpropath = os.path.dirname( mydoc['outprefix'] )
 9137                if verbose:
 9138                    print("subjectpropath is")
 9139                    print(subjectpropath)
                os.makedirs( subjectpropath, exist_ok=True  )
 9141                myimgsr2 = myimgsr
 9142                myimgsr2.sort()
 9143                is4d = False
 9144                temp = ants.image_read( myimgsr2[0] )
 9145                if temp.dimension == 4:
 9146                    is4d = True
 9147                if len( myimgsr2 ) == 1 and not is4d: # check dimension
 9148                    myimgsr2 = myimgsr2 + myimgsr2
 9149                mymmout = makewideout( mymm )
 9150                if verbose and not exists( mymmout ):
 9151                    print( "NM " + mymm  + ' execution ')
 9152                elif verbose and exists( mymmout ) :
 9153                    print( "NM " + mymm + ' complete ' )
 9154                if exists( mymmout ):
 9155                    continue
 9156                if is4d:
 9157                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
 9158                else:
 9159                    nmlist = []
 9160                    for zz in myimgsr2:
 9161                        nmlist.append( mm_read( zz ) )
 9162                srmodel_NM_mdl = None
 9163                if srmodel_NM is not False:
 9164                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
 9165                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
 9166                    if isinstance( srmodel_NM, str ):
 9167                        srmodel_NM = re.sub( "bestup", bestup, srmodel_NM )
 9168                        mdlfn = os.path.join( ex_pathmm, srmodel_NM )
 9169                    if exists( mdlfn ):
 9170                        if verbose:
 9171                            print(mdlfn)
 9172                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
 9173                    else:
                        print( mdlfn + " does not exist - won't use SR")
 9175                if not testloop:
 9176                    try:
 9177                        tabPro, normPro = mm( t1, hier,
 9178                            nm_image_list = nmlist,
 9179                            srmodel=srmodel_NM_mdl,
 9180                            do_tractography=False,
 9181                            do_kk=False,
 9182                            do_normalization=templateTx,
 9183                            group_template = normalization_template,
 9184                            group_transform = groupTx,
 9185                            test_run=test_run,
 9186                            verbose=True )
 9187                    except Exception as e:
 9188                        error_info = traceback.format_exc()
 9189                        print(error_info)
 9190                        visualize=False
 9191                        dowrite=False
 9192                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9193                        pass
 9194                    if not test_run:
 9195                        if dowrite:
 9196                            write_mm( output_prefix=mymm, mm=tabPro,
 9197                                mm_norm=normPro, t1wide=None, separator=mysep )
 9198                        if visualize :
 9199                            nmpro = tabPro['NM']
 9200                            mysl = range( nmpro['NM_avg'].shape[2] )
 9201                            ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
 9202                            mysl = range( nmpro['NM_avg_cropped'].shape[2] )
 9203                            ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
 9204                            ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
 9205                            ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
 9206            else :
 9207                if len( myimgsr ) > 0 :
 9208                    dowrite=False
 9209                    myimgcount=0
 9210                    if len( myimgsr ) > 0 :
 9211                        myimg = myimgsr[ myimgcount ]
 9212                        subjectpropath = os.path.dirname( mydoc['outprefix'] )
 9213                        if verbose:
 9214                            print("subjectpropath is")
 9215                            print(subjectpropath)
 9216                        os.makedirs( subjectpropath, exist_ok=True  )
 9217                        mymmout = makewideout( mymm )
 9218                        if verbose and not exists( mymmout ):
 9219                            print( "Modality specific processing: " + mymod + " execution " )
 9220                            print( mymm )
 9221                        elif verbose and exists( mymmout ) :
 9222                            print("Modality specific processing: " + mymod + " complete " )
 9223                        if exists( mymmout ) :
 9224                            continue
 9225                        if verbose:
 9226                            print( subjectpropath )
 9227                            print( myimg )
 9228                        if not testloop:
 9229                            img = mm_read( myimg )
 9230                            ishapelen = len( img.shape )
 9231                            if mymod == 'T1w' and ishapelen == 3:
 9232                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
 9233                                    dowrite=True
 9234                                    if verbose:
 9235                                        print('start kk')
 9236                                    try:
 9237                                        tabPro, normPro = mm( t1, hier,
 9238                                            srmodel=None,
 9239                                            do_tractography=False,
 9240                                            do_kk=True,
 9241                                            do_normalization=templateTx,
 9242                                            group_template = normalization_template,
 9243                                            group_transform = groupTx,
 9244                                            test_run=test_run,
 9245                                            verbose=True )
 9246                                    except Exception as e:
 9247                                        error_info = traceback.format_exc()
 9248                                        print(error_info)
 9249                                        visualize=False
 9250                                        dowrite=False
 9251                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9252                                        pass
 9253                                    if visualize:
 9254                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
 9255                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
 9256                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
 9257                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
 9258                            if mymod == 'T2Flair' and ishapelen == 3 and np.min(img.shape) > 15:
 9259                                dowrite=True
 9260                                try:
 9261                                    tabPro, normPro = mm( t1, hier,
 9262                                        flair_image = img,
 9263                                        srmodel=None,
 9264                                        do_tractography=False,
 9265                                        do_kk=False,
 9266                                        do_normalization=templateTx,
 9267                                        group_template = normalization_template,
 9268                                        group_transform = groupTx,
 9269                                        test_run=test_run,
 9270                                        verbose=True )
 9271                                except Exception as e:
 9272                                        error_info = traceback.format_exc()
 9273                                        print(error_info)
 9274                                        visualize=False
 9275                                        dowrite=False
 9276                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9277                                        pass
 9278                                if visualize:
 9279                                    maxslice = np.min( [21, img.shape[2] ] )
 9280                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
 9281                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
 9282                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
 9283                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
 9284                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
 9285                                img2 = None
 9286                                if len( myimgsr ) > 1:
 9287                                    img2 = mm_read( myimgsr[myimgcount+1] )
 9288                                    ishapelen2 = len( img2.shape )
 9289                                    if ishapelen2 != 4 or 1 in img2.shape:
 9290                                        img2 = None
 9291                                if 1 in img.shape:
 9292                                    warnings.warn( 'rsfMRI image shape suggests it is an incorrectly converted mosaic image - will not process.')
 9293                                    dowrite=False
 9294                                    tabPro={'rsf':None}
 9295                                    normPro={'rsf':None}
 9296                                else:
 9297                                    dowrite=True
 9298                                    try:
 9299                                        tabPro, normPro = mm( t1, hier,
 9300                                            rsf_image=[img,img2],
 9301                                            srmodel=None,
 9302                                            do_tractography=False,
 9303                                            do_kk=False,
 9304                                            do_normalization=templateTx,
 9305                                            group_template = normalization_template,
 9306                                            group_transform = groupTx,
 9307                                            rsf_upsampling = rsf_upsampling,
 9308                                            test_run=test_run,
 9309                                            verbose=True )
 9310                                    except Exception as e:
 9311                                        error_info = traceback.format_exc()
 9312                                        print(error_info)
 9313                                        visualize=False
 9314                                        dowrite=False
 9315                                        tabPro={'rsf':None}
 9316                                        normPro={'rsf':None}
 9317                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9318                                        pass
 9319                                if tabPro['rsf'] is not None and visualize:
 9320                                    for tpro in tabPro['rsf']: # FIXMERSF
 9321                                        maxslice = np.min( [21, tpro['meanBold'].shape[2] ] )
 9322                                        tproprefix = mymm+mysep+str(tpro['paramset'])+mysep
 9323                                        ants.plot( tpro['meanBold'],
 9324                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=tproprefix+"meanBOLD.png" )
 9325                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['alff'],"Normalize"),
 9326                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=tproprefix+"boldALFF.png" )
 9327                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['falff'],"Normalize"),
 9328                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=tproprefix+"boldfALFF.png" )
 9329                                        dfn=tpro['dfnname']
 9330                                        ants.plot( tpro['meanBold'], tpro[dfn],
 9331                                            axis=2, nslices=maxslice, ncol=7, crop=True, title=dfn, filename=tproprefix+"boldDefaultMode.png" )
 9332                            if ( mymod == 'perf' ) and ishapelen == 4:
 9333                                dowrite=True
 9334                                try:
 9335                                    tabPro, normPro = mm( t1, hier,
 9336                                        perfusion_image=img,
 9337                                        srmodel=None,
 9338                                        do_tractography=False,
 9339                                        do_kk=False,
 9340                                        do_normalization=templateTx,
 9341                                        group_template = normalization_template,
 9342                                        group_transform = groupTx,
 9343                                        test_run=test_run,
 9344                                        perfusion_trim=perfusion_trim,
 9345                                        perfusion_m0_image=perfusion_m0_image,
 9346                                        perfusion_m0=perfusion_m0,
 9347                                        verbose=True )
 9348                                except Exception as e:
 9349                                        error_info = traceback.format_exc()
 9350                                        print(error_info)
 9351                                        visualize=False
 9352                                        dowrite=False
 9353                                        tabPro={'perf':None}
 9354                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9355                                        pass
 9356                                if tabPro['perf'] is not None and visualize:
 9357                                    maxslice = np.min( [21, tabPro['perf']['meanBold'].shape[2] ] )
 9358                                    ants.plot( tabPro['perf']['perfusion'],
 9359                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='perfusion image', filename=mymm+mysep+"perfusion.png" )
 9360                                    ants.plot( tabPro['perf']['cbf'],
 9361                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='CBF image', filename=mymm+mysep+"cbf.png" )
 9362                                    ants.plot( tabPro['perf']['m0'],
 9363                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='M0 image', filename=mymm+mysep+"m0.png" )
 9364
 9365                            if ( mymod == 'pet3d' ) and ishapelen == 3:
 9366                                dowrite=True
 9367                                try:
 9368                                    tabPro, normPro = mm( t1, hier,
 9369                                        srmodel=None,
 9370                                        do_tractography=False,
 9371                                        do_kk=False,
 9372                                        do_normalization=templateTx,
 9373                                        group_template = normalization_template,
 9374                                        group_transform = groupTx,
 9375                                        test_run=test_run,
 9376                                        pet_3d_image=img,
 9377                                        verbose=True )
 9378                                except Exception as e:
 9379                                        error_info = traceback.format_exc()
 9380                                        print(error_info)
 9381                                        visualize=False
 9382                                        dowrite=False
 9383                                        tabPro={'pet3d':None}
 9384                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9385                                        pass
 9386                                if tabPro['pet3d'] is not None and visualize:
 9387                                    maxslice = np.min( [21, tabPro['pet3d']['pet3d'].shape[2] ] )
 9388                                    ants.plot( tabPro['pet3d']['pet3d'],
 9389                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='PET image', filename=mymm+mysep+"pet3d.png" )
 9390                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
 9391                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
 9392                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
 9393                                imgList = [ img ]
 9394                                bvalfnList = [ bvalfn ]
 9395                                bvecfnList = [ bvecfn ]
 9396                                missing_dti_data=False # bval, bvec or images
 9397                                if len( myimgsr ) == 2:  # find DTI_RL
 9398                                    dtilrfn = myimgsr[myimgcount+1]
 9399                                    if exists( dtilrfn ):
 9400                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
 9401                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
 9402                                        imgRL = ants.image_read( dtilrfn )
 9403                                        imgList.append( imgRL )
 9404                                        bvalfnList.append( bvalfnRL )
 9405                                        bvecfnList.append( bvecfnRL )
                                elif len( myimgsr ) == 3:  # three DTI acquisitions: join the 2nd and 3rd
 9407                                    print("DTI trinity")
 9408                                    dtilrfn = myimgsr[myimgcount+1]
 9409                                    dtilrfn2 = myimgsr[myimgcount+2]
 9410                                    if exists( dtilrfn ) and exists( dtilrfn2 ):
 9411                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
 9412                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
 9413                                        bvalfnRL2 = re.sub( '.nii.gz', '.bval' , dtilrfn2 )
 9414                                        bvecfnRL2 = re.sub( '.nii.gz', '.bvec' , dtilrfn2 )
 9415                                        imgRL = ants.image_read( dtilrfn )
 9416                                        imgRL2 = ants.image_read( dtilrfn2 )
 9417                                        bvals, bvecs = read_bvals_bvecs( bvalfnRL , bvecfnRL  )
 9418                                        print( bvals.max() )
 9419                                        bvals2, bvecs2 = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
 9420                                        print( bvals2.max() )
 9421                                        temp = merge_dwi_data( imgRL, bvals, bvecs, imgRL2, bvals2, bvecs2  )
 9422                                        imgList.append( temp[0] )
 9423                                        bvalfnList.append( mymm+mysep+'joined.bval' )
 9424                                        bvecfnList.append( mymm+mysep+'joined.bvec' )
 9425                                        write_bvals_bvecs( temp[1], temp[2], mymm+mysep+'joined' )
 9426                                        bvalsX, bvecsX = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
 9427                                        print( bvalsX.max() )
 9428                                # check existence of all files expected ...
 9429                                for dtiex in bvalfnList+bvecfnList+myimgsr:
 9430                                    if not exists(dtiex):
 9431                                        print('mm_csv: missing dti data ' + dtiex )
 9432                                        missing_dti_data=True
 9433                                        dowrite=False
 9434                                if not missing_dti_data:
 9435                                    dowrite=True
 9436                                    srmodel_DTI_mdl=None
 9437                                    if srmodel_DTI is not False:
 9438                                        temp = ants.get_spacing(img)
 9439                                        dtspc=[temp[0],temp[1],temp[2]]
 9440                                        bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
 9441                                        mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
 9442                                        if isinstance( srmodel_DTI, str ):
 9443                                            srmodel_DTI = re.sub( "bestup", bestup, srmodel_DTI )
 9444                                            mdlfn = os.path.join( ex_pathmm, srmodel_DTI )
 9445                                        if exists( mdlfn ):
 9446                                            if verbose:
 9447                                                print(mdlfn)
 9448                                            srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
 9449                                        else:
                                            print(mdlfn + " does not exist - won't use SR")
 9451                                    try:
 9452                                        tabPro, normPro = mm( t1, hier,
 9453                                            dw_image=imgList,
 9454                                            bvals = bvalfnList,
 9455                                            bvecs = bvecfnList,
 9456                                            srmodel=srmodel_DTI_mdl,
 9457                                            do_tractography=not test_run,
 9458                                            do_kk=False,
 9459                                            do_normalization=templateTx,
 9460                                            group_template = normalization_template,
 9461                                            group_transform = groupTx,
 9462                                            dti_motion_correct = dti_motion_correct,
 9463                                            dti_denoise = dti_denoise,
 9464                                            test_run=test_run,
 9465                                            verbose=True )
 9466                                    except Exception as e:
 9467                                            error_info = traceback.format_exc()
 9468                                            print(error_info)
 9469                                            visualize=False
 9470                                            dowrite=False
 9471                                            tabPro={'DTI':None}
 9472                                            print(f"antspymmerror occurred while processing {overmodX}: {e}")
 9473                                            pass
 9474                                    mydti = tabPro['DTI']
 9475                                    if visualize and tabPro['DTI'] is not None:
                                        maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
 9477                                        ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
 9478                                        ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
 9479                                        ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
 9480                            if dowrite:
 9481                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep )
 9482                                for mykey in normPro.keys():
 9483                                    if normPro[mykey] is not None and normPro[mykey].components == 1:
 9484                                        if visualize and False:
 9485                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
 9486        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
 9487            return
 9488        if verbose:
 9489            print("done with " + overmodX )
 9490    if verbose:
        print("mm_csv complete.")
 9492    return
 9493
 9494def spec_taper(x, p=0.1):
    """
    Computes a tapered version of x, with tapering proportion p applied to each end of the series.

    Adapted from R's stats::spec.taper at https://github.com/telmo-correa/time-series-analysis/blob/master/Python/spectrum.py

    """
 9503
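    # The taper multiplies the first and last floor(nr * p) samples of each column by a
    # raised-cosine (split cosine bell) ramp, leaving the middle of the series unchanged.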
 9504    p = np.r_[p]
 9505    assert np.all((p >= 0) & (p < 0.5)), "'p' must be between 0 and 0.5"
 9506
 9507    x = np.r_[x].astype('float64')
 9508    original_shape = x.shape
 9509
 9510    assert len(original_shape) <= 2, "'x' must have at most 2 dimensions"
 9511    while len(x.shape) < 2:
 9512        x = np.expand_dims(x, axis=1)
 9513
 9514    nr, nc = x.shape
 9515    if len(p) == 1:
 9516        p = p * np.ones(nc)
 9517    else:
 9518        assert len(p) == nc, "length of 'p' must be 1 or equal the number of columns of 'x'"
 9519
 9520    for i in range(nc):
 9521        m = int(np.floor(nr * p[i]))
 9522        if m == 0:
 9523            continue
 9524        w = 0.5 * (1 - np.cos(np.pi * np.arange(1, 2 * m, step=2)/(2 * m)))
 9525        x[:, i] = np.r_[w, np.ones(nr - 2 * m), w[::-1]] * x[:, i]
 9526
 9527    x = np.reshape(x, original_shape)
 9528    return x
 9529
 9530def plot_spec(spec_res, coverage=None, ax=None, title=None):
    """Convenience plotting method; also draws the confidence cross in the same style as R.

    Note that the location of the cross is irrelevant; only its width and height matter."""
    import matplotlib.pyplot as plt
 9535    f, Pxx = spec_res['freq'], spec_res['spec']
 9536
 9537    if coverage is not None:
 9538        ci = spec_ci(spec_res['df'], coverage=coverage)
 9539        conf_x = (max(spec_res['freq']) - spec_res['bandwidth']) + np.r_[-0.5, 0.5] * spec_res['bandwidth']
 9540        conf_y = max(spec_res['spec']) / ci[1]
 9541
 9542    if ax is None:
 9543        ax = plt.gca()
 9544
 9545    ax.plot(f, Pxx, color='C0')
 9546    ax.set_xlabel('Frequency')
 9547    ax.set_ylabel('Log Spectrum')
 9548    ax.set_yscale('log')
 9549    if coverage is not None:
 9550        ax.plot(np.mean(conf_x) * np.r_[1, 1], conf_y * ci, color='red')
 9551        ax.plot(conf_x, np.mean(conf_y) * np.r_[1, 1], color='red')
 9552
 9553    ax.set_title(spec_res['method'] if title is None else title)
 9554
 9555def spec_ci(df, coverage=0.95):
    """
    Computes the confidence interval for a spectral fit, based on the number of degrees of freedom.

    Adapted from R's stats::plot.spec at https://github.com/telmo-correa/time-series-analysis/blob/master/Python/spectrum.py

    """
    from scipy import stats
 9564
 9565    assert coverage >= 0 and coverage < 1, "coverage probability out of range [0, 1)"
 9566
 9567    tail = 1 - coverage
 9568
 9569    phi = stats.chi2.cdf(x=df, df=df)
 9570    upper_quantile = 1 - tail * (1 - phi)
 9571    lower_quantile = tail * phi
 9572
 9573    return df / stats.chi2.ppf([upper_quantile, lower_quantile], df=df)
 9574
 9575def spec_pgram(x, xfreq=1, spans=None, kernel=None, taper=0.1, pad=0, fast=True, demean=False, detrend=True,
 9576               plot=True, **kwargs):
 9577    """
 9578    Computes the spectral density estimate using a periodogram.  Optionally, it also:
 9579    - Uses a provided kernel window, or a sequence of spans for convoluted modified Daniell kernels.
 9580    - Tapers the start and end of the series to avoid end-of-signal effects.
 9581    - Pads the provided series before computation, adding pad*(length of series) zeros at the end.
 9582    - Pads the provided series before computation to speed up FFT calculation.
 9583    - Performs demeaning or detrending on the series.
 9584    - Plots results.
 9585
 9586    Implemented to ensure compatibility with R's spectral functions, as opposed to reusing scipy's periodogram.
 9587
 9588    Adapted from R's stats::spec.pgram at https://github.com/telmo-correa/time-series-analysis/blob/master/Python/spectrum.py
 9589
 9590    example:
 9591
 9592    import numpy as np
 9593    import antspymm
 9594    myx = np.random.rand(100,1)
 9595    myspec = antspymm.spec_pgram(myx,0.5)
 9596
 9597    """
 9598    from scipy import stats, signal, fft
 9599    from statsmodels.regression.linear_model import yule_walker
 9600    def daniell_window_modified(m):
 9601        """ Single-pass modified Daniell kernel window.
 9602
 9603        Weight is normalized to add up to 1, and all values are the same, other than the first and the
 9604        last, which are divided by 2.
 9605        """
 9606        def w(k):
 9607            return np.where(np.abs(k) < m, 1 / (2*m), np.where(np.abs(k) == m, 1/(4*m), 0))
 9608
 9609        return w(np.arange(-m, m+1))
 9610
 9611    def daniell_window_convolve(v):
 9612        """ Convolved version of multiple modified Daniell kernel windows.
 9613
 9614        Parameter v should be an iterable of m values.
 9615        """
 9616
 9617        if len(v) == 0:
 9618            return np.r_[1]
 9619
 9620        if len(v) == 1:
 9621            return daniell_window_modified(v[0])
 9622
 9623        return signal.convolve(daniell_window_modified(v[0]), daniell_window_convolve(v[1:]))
 9624
 9625    # Ensure we can store non-integers in x, and that it is a numpy object
 9626    x = np.r_[x].astype('float64')
 9627    original_shape = x.shape
 9628
 9629    # Ensure correct dimensions
 9630    assert len(original_shape) <= 2, "'x' must have at most 2 dimensions"
 9631    while len(x.shape) < 2:
 9632        x = np.expand_dims(x, axis=1)
 9633
 9634    N, nser = x.shape
 9635    N0 = N
 9636
 9637    # Ensure only one of spans, kernel is provided, and build the kernel window if needed
 9638    assert (spans is None) or (kernel is None), "must specify only one of 'spans' or 'kernel'"
 9639    if spans is not None:
 9640        kernel = daniell_window_convolve(np.floor_divide(np.r_[spans], 2))
 9641
 9642    # Detrend or demean the series
 9643    if detrend:
 9644        t = np.arange(N) - (N - 1)/2
 9645        sumt2 = N * (N**2 - 1)/12
 9646        x -= (np.repeat(np.expand_dims(np.mean(x, axis=0), 0), N, axis=0) + np.outer(np.sum(x.T * t, axis=1), t/sumt2).T)
 9647    elif demean:
 9648        x -= np.mean(x, axis=0)
 9649
 9650    # Compute taper and taper adjustment variables
 9651    x = spec_taper(x, taper)
 9652    u2 = (1 - (5/8) * taper * 2)
 9653    u4 = (1 - (93/128) * taper * 2)
 9654
 9655    # Pad the series with copies of the same shape, but filled with zeroes
 9656    if pad > 0:
 9657        x = np.r_[x, np.zeros((pad * x.shape[0], x.shape[1]))]
 9658        N = x.shape[0]
 9659
 9660    # Further pad the series to accelerate FFT computation
 9661    if fast:
 9662        newN = fft.next_fast_len(N, True)
 9663        x = np.r_[x, np.zeros((newN - N, x.shape[1]))]
 9664        N = newN
 9665
 9666    # Compute the Fourier frequencies (R's spec.pgram convention style)
 9667    Nspec = int(np.floor(N/2))
 9668    freq = (np.arange(Nspec) + 1) * xfreq / N
 9669
 9670    # Translations to keep same row / column convention as stats::mvfft
 9671    xfft = fft.fft(x.T).T
 9672
 9673    # Compute the periodogram for each i, j
 9674    pgram = np.empty((N, nser, nser), dtype='complex')
 9675    for i in range(nser):
 9676        for j in range(nser):
 9677            pgram[:, i, j] = xfft[:, i] * np.conj(xfft[:, j]) / (N0 * xfreq)
 9678            pgram[0, i, j] = 0.5 * (pgram[1, i, j] + pgram[-1, i, j])
 9679
 9680    if kernel is None:
 9681        # Values pre-adjustment
 9682        df = 2
 9683        bandwidth = np.sqrt(1 / 12)
 9684    else:
 9685        def conv_circular(signal, kernel):
 9686            """
 9687            Performs 1D circular convolution, in the same style as R::kernapply,
 9688            assuming the kernel window is centered at 0.
 9689            """
 9690            pad = len(signal) - len(kernel)
 9691            half_window = int((len(kernel) + 1) / 2)
 9692            indexes = range(-half_window, len(signal) - half_window)
 9693            orig_conv = np.real(fft.ifft(fft.fft(signal) * fft.fft(np.r_[np.zeros(pad), kernel])))
 9694            return orig_conv.take(indexes, mode='wrap')
 9695
 9696        # Convolve pgram with kernel with circular conv
 9697        for i in range(nser):
 9698            for j in range(nser):
 9699                pgram[:, i, j] = conv_circular(pgram[:, i, j], kernel)
 9700
 9701        df = 2 / np.sum(kernel**2)
 9702        m = (len(kernel) - 1)/2
 9703        k = np.arange(-m, m+1)
 9704        bandwidth = np.sqrt(np.sum((1/12 + k**2) * kernel))
 9705
 9706    df = df/(u4/u2**2)*(N0/N)
 9707    bandwidth = bandwidth * xfreq/N
 9708
 9709    # Remove padded results
 9710    pgram = pgram[1:(Nspec+1), :, :]
 9711
 9712    spec = np.empty((Nspec, nser))
 9713    for i in range(nser):
 9714        spec[:, i] = np.real(pgram[:, i, i])
 9715
 9716    if nser == 1:
 9717        coh = None
 9718        phase = None
 9719    else:
 9720        coh = np.empty((Nspec, int(nser * (nser - 1)/2)))
 9721        phase = np.empty((Nspec, int(nser * (nser - 1)/2)))
 9722        for i in range(nser):
 9723            for j in range(i+1, nser):
 9724                index = int(i + j*(j-1)/2)
 9725                coh[:, index] = np.abs(pgram[:, i, j])**2 / (spec[:, i] * spec[:, j])
 9726                phase[:, index] = np.angle(pgram[:, i, j])
 9727
 9728    spec = spec / u2
 9729    spec = spec.squeeze()
 9730
 9731    results = {
 9732        'freq': freq,
 9733        'spec': spec,
 9734        'coh': coh,
 9735        'phase': phase,
 9736        'kernel': kernel,
 9737        'df': df,
 9738        'bandwidth': bandwidth,
 9739        'n.used': N,
 9740        'orig.n': N0,
 9741        'taper': taper,
 9742        'pad': pad,
 9743        'detrend': detrend,
 9744        'demean': demean,
 9745        'method': 'Raw Periodogram' if kernel is None else 'Smoothed Periodogram'
 9746    }
 9747
 9748    if plot:
 9749        plot_spec(results, coverage=0.95, **kwargs)
 9750
 9751    return results
 9752
 9753def alffmap( x, flo=0.01, fhi=0.1, tr=1, detrend = True ):
 9754    """
 9755    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
 9756    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
 9757    are related measures that quantify the amplitude of low frequency
 9758    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
 9759    same function in ANTsR.
 9760
 9761    x input vector for the time series of interest
 9762    flo low frequency, typically 0.01
 9763    fhi high frequency, typically 0.1
 9764    tr the period associated with the vector x (inverse of frequency)
 9765    detrend detrend the input time series
 9766
 9767    return vector is output showing ALFF and fALFF values
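
    Example (a minimal sketch on random data):

    import numpy as np
    import antspymm
    ts = np.random.rand( 200 )
    antspymm.alffmap( ts, flo=0.01, fhi=0.1, tr=2.0 )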
 9768    """
 9769    temp = spec_pgram( x, xfreq=1.0/tr, demean=False, detrend=detrend, taper=0, fast=True, plot=False )
 9770    fselect = np.logical_and( temp['freq'] >= flo, temp['freq'] <= fhi )
 9771    denom = (temp['spec']).sum()
 9772    numer = (temp['spec'][fselect]).sum()
 9773    return {  'alff':numer, 'falff': numer/denom }
 9774
 9775
 9776def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
 9777    """
 9778    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
 9779    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
 9780    are related measures that quantify the amplitude of low frequency
 9781    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
 9782
 9783    x - input clean resting state fmri
 9784    mask - mask over which to compute f/alff
 9785    flo - low frequency, typically 0.01
 9786    fhi - high frequency, typically 0.1
 9787    nuisance - optional nuisance matrix
 9788
 9789    return dictionary with ALFF and fALFF images
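
    Example (a minimal synthetic sketch; the image size and spacing are arbitrary):

    import numpy as np
    import ants
    import antspymm
    bold = ants.from_numpy( np.random.rand( 16, 16, 8, 40 ) )
    bold.set_spacing( (2., 2., 2., 2.5) )
    avg = ants.get_average_of_timeseries( bold )
    mask = ants.threshold_image( avg, 0.0, 2.0 )
    out = antspymm.alff_image( bold, mask )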
 9790    """
 9791    xmat = ants.timeseries_to_matrix( x, mask )
 9792    if nuisance is not None:
 9793        xmat = ants.regress_components( xmat, nuisance )
 9794    alffvec = xmat[0,:]*0
 9795    falffvec = xmat[0,:]*0
 9796    mytr = ants.get_spacing( x )[3]
 9797    for n in range( xmat.shape[1] ):
 9798        temp = alffmap( xmat[:,n], flo=flo, fhi=fhi, tr=mytr )
 9799        alffvec[n]=temp['alff']
 9800        falffvec[n]=temp['falff']
 9801    alffi=ants.make_image( mask, alffvec )
 9802    falffi=ants.make_image( mask, falffvec )
 9803    alfftrimmedmean = calculate_trimmed_mean( alffvec, 0.01 )
 9804    falfftrimmedmean = calculate_trimmed_mean( falffvec, 0.01 )
 9805    alffi=alffi / alfftrimmedmean
 9806    falffi=falffi / falfftrimmedmean
 9807    return {  'alff': alffi, 'falff': falffi }
 9808
 9809
 9810def down2iso( x, interpolation='linear', takemin=False ):
 9811    """
 9812    will downsample an anisotropic image to an isotropic resolution
 9813
 9814    x: input image
 9815
 9816    interpolation: linear or nearestneighbor
 9817
    takemin : boolean; if True, resample to the minimum spacing, otherwise to the maximum spacing
 9819
 9820    return image downsampled to isotropic resolution
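
    Example (a minimal sketch using a bundled ANTsPy image):

    import ants
    import antspymm
    img = ants.image_read( ants.get_ants_data( 'mni' ) )
    iso = antspymm.down2iso( img )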
 9821    """
 9822    spc = ants.get_spacing( x )
 9823    if takemin:
 9824        newspc = np.asarray(spc).min()
 9825    else:
 9826        newspc = np.asarray(spc).max()
 9827    newspc = np.repeat( newspc, x.dimension )
 9828    if interpolation == 'linear':
 9829        xs = ants.resample_image( x, newspc, interp_type=0)
 9830    else:
 9831        xs = ants.resample_image( x, newspc, interp_type=1)
 9832    return xs
 9833
 9834
def read_mm_csv( x, is_t1=False, colprefix=None, separator='-', verbose=False ):
    """
    Read a single antspymm wide csv (*mmwide.csv); parse subject id, visit date, modality
    and image uids from its NRG-style filename; return a one-row dataframe that binds these
    identifiers to the csv contents, with an optional colprefix prepended to the csv column
    names.  Returns None if the csv contains no columns.
    """
 9836    splitter=os.path.basename(x).split( separator )
 9837    lensplit = len( splitter )-1
 9838    temp = os.path.basename(x)
 9839    temp = os.path.splitext(temp)[0]
 9840    temp = re.sub(separator+'mmwide','',temp)
 9841    idcols = ['u_hier_id','sid','visitdate','modality','mmimageuid','t1imageuid']
 9842    df = pd.DataFrame( columns = idcols, index=range(1) )
 9843    valstoadd = [temp] + splitter[1:(lensplit-1)]
 9844    if is_t1:
 9845        valstoadd = valstoadd + [splitter[(lensplit-1)],splitter[(lensplit-1)]]
 9846    else:
 9847        split2=splitter[(lensplit-1)].split( "_" )
 9848        if len(split2) == 1:
 9849            split2.append( split2[0] )
 9850        if len(valstoadd) == 3:
 9851            valstoadd = valstoadd + [split2[0]] + [math.nan] + [split2[1]]
 9852        else:
 9853            valstoadd = valstoadd + [split2[0],split2[1]]
 9854    if verbose:
 9855        print( valstoadd )
 9856    df.iloc[0] = valstoadd
 9857    if verbose:
 9858        print( "read xdf: " + x )
 9859    xdf = pd.read_csv( x )
 9860    df.reset_index()
 9861    xdf.reset_index(drop=True)
 9862    if "Unnamed: 0" in xdf.columns:
 9863        holder=xdf.pop( "Unnamed: 0" )
 9864    if "Unnamed: 1" in xdf.columns:
 9865        holder=xdf.pop( "Unnamed: 1" )
 9866    if "u_hier_id.1" in xdf.columns:
 9867        holder=xdf.pop( "u_hier_id.1" )
 9868    if "u_hier_id" in xdf.columns:
 9869        holder=xdf.pop( "u_hier_id" )
 9870    if not is_t1:
 9871        if 'resnetGrade' in xdf.columns:
 9872            index_no = xdf.columns.get_loc('resnetGrade')
 9873            xdf = xdf.drop( xdf.columns[range(index_no+1)] , axis=1)
 9874
 9875    if xdf.shape[0] == 2:
 9876        xdfcols = xdf.columns
 9877        xdf = xdf.iloc[1]
 9878        ddnum = xdf.to_numpy()
 9879        ddnum = ddnum.reshape([1,ddnum.shape[0]])
 9880        newcolnames = xdf.index.to_list()
 9881        if len(newcolnames) != ddnum.shape[1]:
 9882            print("Cannot Merge : Shape MisMatch " + str( len(newcolnames) ) + " " + str(ddnum.shape[1]))
 9883        else:
 9884            xdf = pd.DataFrame(ddnum, columns=xdfcols )
 9885    if xdf.shape[1] == 0:
 9886        return None
 9887    if colprefix is not None:
 9888        xdf.columns=colprefix + xdf.columns
 9889    return pd.concat( [df,xdf], axis=1, ignore_index=False )
 9890
 9891def merge_wides_to_study_dataframe( sdf, processing_dir, separator='-', sid_is_int=True, id_is_int=True, date_is_int=True, report_missing=False,
 9892progress=False, verbose=False ):
 9893    """
 9894    extend a study data frame with wide outputs
 9895
 9896    sdf : the input study dataframe from antspymm QC output
 9897
    processing_dir : the directory containing the processed output data (NRG-style layout)
 9899
 9900    separator : string usually '-' or '_'
 9901
 9902    sid_is_int : boolean set to True to cast unique subject ids to int; can be useful if they are inadvertently stored as float by pandas
 9903
 9904    date_is_int : boolean set to True to cast date to int; can be useful if they are inadvertently stored as float by pandas
 9905
 9906    id_is_int : boolean set to True to cast unique image ids to int; can be useful if they are inadvertently stored as float by pandas
 9907
    report_missing : boolean; when True (together with verbose) report expected modality files that are absent

    progress : integer; if greater than zero, print percent progress whenever the row index is divisible by this value
 9911
 9912    verbose : boolean
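
    Example (illustrative; the study dataframe and processing directory are hypothetical):

    mm_wide = antspymm.merge_wides_to_study_dataframe( sdf, '/data/processing/', progress=10 )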
 9913    """
 9914    from os.path import exists
 9915    musthavecols = ['projectID', 'subjectID','date','imageID']
 9916    for k in range(len(musthavecols)):
 9917        if not musthavecols[k] in sdf.keys():
 9918            raise ValueError('sdf is missing column ' +musthavecols[k] + ' in merge_wides_to_study_dataframe' )
 9919    possible_iids = [ 'imageID', 'imageID', 'imageID', 'flairid', 'dtid1', 'dtid2', 'rsfid1', 'rsfid2', 'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5', 'nmid6', 'nmid7', 'nmid8', 'nmid9', 'nmid10', 'perfid' ]
 9920    modality_ids = [ 'T1wHierarchical', 'T1wHierarchicalSR', 'T1w', 'T2Flair', 'DTI', 'DTI', 'rsfMRI', 'rsfMRI', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'perf']
 9921    alldf=pd.DataFrame()
 9922    for myk in sdf.index:
 9923        if progress > 0 and int(myk) % int(progress) == 0:
 9924            print( str( round( myk/sdf.shape[0]*100.0)) + "%...", end='', flush=True)
 9925        if verbose:
 9926            print( "DOROW " + str(myk) + ' of ' + str( sdf.shape[0] ) )
 9927        csvrow = sdf.loc[sdf.index == myk].dropna(axis=1)
 9928        ct=-1
 9929        for iidkey in possible_iids:
 9930            ct=ct+1
 9931            mod_name = modality_ids[ct]
 9932            if iidkey in csvrow.keys():
 9933                if id_is_int:
 9934                    iid = str( int( csvrow[iidkey].iloc[0] ) )
 9935                else:
 9936                    iid = str( csvrow[iidkey].iloc[0] )
 9937                if verbose:
 9938                    print( "iidkey " + iidkey + " modality " + mod_name + ' iid '+ iid )
 9939                pid=str(csvrow['projectID'].iloc[0] )
 9940                if sid_is_int:
 9941                    sid=str(int(csvrow['subjectID'].iloc[0] ))
 9942                else:
 9943                    sid=str(csvrow['subjectID'].iloc[0] )
 9944                if date_is_int:
 9945                    dt=str(int(csvrow['date'].iloc[0]))
 9946                else:
 9947                    dt=str(csvrow['date'].iloc[0])
 9948                if id_is_int:
 9949                    t1iid=str(int(csvrow['imageID'].iloc[0]))
 9950                else:
 9951                    t1iid=str(csvrow['imageID'].iloc[0])
 9952                if t1iid != iid:
 9953                    iidj=iid+"_"+t1iid
 9954                else:
 9955                    iidj=iid
 9956                rootid = pid +separator+ sid +separator+dt+separator+mod_name+separator+iidj
 9957                myext = rootid +separator+'mmwide.csv'
 9958                nrgwidefn=os.path.join( processing_dir, pid, sid, dt, mod_name, iid, myext )
 9959                moddersub = mod_name
 9960                is_t1=False
 9961                if mod_name == 'T1wHierarchical':
 9962                    is_t1=True
 9963                    moddersub='T1Hier'
 9964                elif mod_name == 'T1wHierarchicalSR':
 9965                    is_t1=True
 9966                    moddersub='T1HSR'
 9967                if exists( nrgwidefn ):
 9968                    if verbose:
 9969                        print( nrgwidefn + " exists")
 9970                    mm=read_mm_csv( nrgwidefn, colprefix=moddersub+'_', is_t1=is_t1, separator=separator, verbose=verbose )
 9971                    if mm is not None:
 9972                        if mod_name == 'T1wHierarchical':
 9973                            a=list( csvrow.keys() )
 9974                            b=list( mm.keys() )
 9975                            abintersect=list(set(b).intersection( set(a) ) )
 9976                            if len( abintersect  ) > 0 :
 9977                                for qq in abintersect:
 9978                                    mm.pop( qq )
 9979                        # mm.index=csvrow.index
 9980                        uidname = mod_name + '_mmwide_filename'
 9981                        mm[ uidname ] = rootid
 9982                        csvrow=pd.concat( [csvrow,mm], axis=1, ignore_index=False )
 9983                else:
 9984                    if verbose and report_missing:
 9985                        print( nrgwidefn + " absent")
 9986        if alldf.shape[0] == 0:
 9987            alldf = csvrow.copy()
 9988            alldf = alldf.loc[:,~alldf.columns.duplicated()]
 9989        else:
 9990            csvrow=csvrow.loc[:,~csvrow.columns.duplicated()]
 9991            alldf = alldf.loc[:,~alldf.columns.duplicated()]
 9992            alldf = pd.concat( [alldf, csvrow], axis=0, ignore_index=True )
 9993    return alldf
 9994
 9995def assemble_modality_specific_dataframes( mm_wide_csvs, hierdfin, nrg_modality, separator='-', progress=None, verbose=False ):
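    """
    Row-bind the wide csvs of one non-T1 modality ( e.g. 'T2Flair', 'NM2DMT', 'rsfMRI*', 'DTI*' )
    that live alongside the T1wHierarchical csvs listed in mm_wide_csvs.  hierdfin is used only
    for its row count ( iteration and progress reporting ).  Returns one row per matching csv.
    """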
    import glob as glob  # local import, following the convention used by other functions in this module
    moddersub = re.sub( "[*]","",nrg_modality)
    nmdf=pd.DataFrame()
 9998    for k in range( hierdfin.shape[0] ):
 9999        if progress is not None:
10000            if k % progress == 0:
10001                progger = str( np.round( k / hierdfin.shape[0] * 100 ) )
10002                print( progger, end ="...", flush=True)
10003        temp = mm_wide_csvs[k]
10004        mypartsf = temp.split("T1wHierarchical")
10005        myparts = mypartsf[0]
10006        t1iid = str(mypartsf[1].split("/")[1])
10007        fnsnm = glob.glob(myparts+"/" + nrg_modality + "/*/*" + t1iid + "*wide.csv")
10008        if len( fnsnm ) > 0 :
10009            for y in fnsnm:
10010                temp=read_mm_csv( y, colprefix=moddersub+'_', is_t1=False, separator=separator, verbose=verbose )
10011                if temp is not None:
10012                    nmdf=pd.concat( [nmdf, temp], axis=0, ignore_index=False )
10013    return nmdf
10014
10015def bind_wide_mm_csvs( mm_wide_csvs, merge=True, separator='-', verbose = 0 ) :
10016    """
10017    will convert a list of t1w hierarchical csv filenames to a merged dataframe
10018
10019    returns a pair of data frames, the left side having all entries and the
10020        right side having row averaged entries i.e. unique values for each visit
10021
10022    set merge to False to return individual dataframes ( for debugging )
10023
10024    return alldata, row_averaged_data
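
    Example (illustrative; the search pattern is hypothetical):

    import glob
    csvs = glob.glob( '/data/processing/*/*/*/T1wHierarchical/*/*mmwide.csv' )
    alldata, row_averaged_data = antspymm.bind_wide_mm_csvs( csvs, verbose=1 )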
10025    """
10026    mm_wide_csvs.sort()
10027    if not mm_wide_csvs:
10028        print("No files found with specified pattern")
10029        return
10030    # 1. row-bind the t1whier data
10031    # 2. same for each other modality
10032    # 3. merge the modalities by the keys
10033    hierdf = pd.DataFrame()
10034    for y in mm_wide_csvs:
10035        temp=read_mm_csv( y, colprefix='T1Hier_', separator=separator, is_t1=True )
10036        if temp is not None:
10037            hierdf=pd.concat( [hierdf, temp], axis=0, ignore_index=False )
10038    if verbose > 0:
10039        mypro=50
10040    else:
10041        mypro=None
10042    if verbose > 0:
10043        print("thickness")
10044    thkdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T1w', progress=mypro, verbose=verbose==2)
10045    if verbose > 0:
10046        print("flair")
10047    flairdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T2Flair', progress=mypro, verbose=verbose==2)
10048    if verbose > 0:
10049        print("NM")
10050    nmdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'NM2DMT', progress=mypro, verbose=verbose==2)
10051    if verbose > 0:
10052        print("rsf")
10053    rsfdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'rsfMRI*', progress=mypro, verbose=verbose==2)
10054    if verbose > 0:
10055        print("dti")
10056    dtidf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'DTI*', progress=mypro, verbose=verbose==2 )
10057    if not merge:
10058        return hierdf, thkdf, flairdf, nmdf, rsfdf, dtidf
10059    hierdfmix = hierdf.copy()
10060    modality_df_suffixes = [
10061        (thkdf, "_thk"),
10062        (flairdf, "_flair"),
10063        (nmdf, "_nm"),
10064        (rsfdf, "_rsf"),
10065        (dtidf, "_dti"),
10066    ]
10067    for pair in modality_df_suffixes:
10068        hierdfmix = merge_mm_dataframe(hierdfmix, pair[0], pair[1])
10069    hierdfmix = hierdfmix.replace(r'^\s*$', np.nan, regex=True)
10070    return hierdfmix, hierdfmix.groupby("u_hier_id", as_index=False).mean(numeric_only=True)
10071
10072def merge_mm_dataframe(hierdf, mmdf, mm_suffix):
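    """
    Left-merge a modality-specific wide dataframe into the hierarchical (T1) dataframe on
    ['sid', 'visitdate', 't1imageuid'], adding mm_suffix to clashing column names.  If the
    merge keys are missing ( KeyError ), hierdf is returned unchanged.
    """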
10073    try:
10074        hierdf = hierdf.merge(mmdf, on=['sid', 'visitdate', 't1imageuid'], suffixes=("",mm_suffix),how='left')
10075        return hierdf
10076    except KeyError:
10077        return hierdf
10078
10079def augment_image( x,  max_rot=10, nzsd=1 ):
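    """
    Apply a random 3D rotation of up to +/- max_rot ( in the units expected by
    ants.contrib.RandomRotate3D ) and additive gaussian noise with standard deviation nzsd.
    Returns the augmented image, the forward transform and its inverse.
    """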
10080    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rot*(-1.0), max_rot ), reference=x )
10081    tx = rRotGenerator.transform()
10082    itx = ants.invert_ants_transform(tx)
10083    y = ants.apply_ants_transform_to_image( tx, x, x, interpolation='linear')
10084    y = ants.add_noise_to_image( y,'additivegaussian', [0,nzsd] )
10085    return y, tx, itx
10086
10087def boot_wmh( flair, t1, t1seg, mmfromconvexhull = 0.0, strict=True,
10088        probability_mask=None, prior_probability=None, n_simulations=16,
10089        random_seed = 42,
10090        verbose=False ) :
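    """
    Augmentation-based ( bootstrapped ) WMH estimate: the normalized flair is randomly rotated
    and noised n_simulations times, wmh() is run on each augmented image, and the probability
    maps and masses are mapped back to the original space and averaged.  Returns a dictionary
    with the averaged WMH_probability_map, the optional posterior map, the mean wmh_mass and
    wmh_mass_prior, plus the wmh_evr and wmh_SNR from the final simulation.  See wmh() for the
    meaning of the remaining arguments.
    """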
10091    import random
10092    random.seed( random_seed )
10093    if verbose and prior_probability is None:
10094        print("augmented flair")
10095    if verbose and prior_probability is not None:
10096        print("augmented flair with prior")
10097    wmh_sum_aug = 0
10098    wmh_sum_prior_aug = 0
10099    augprob = flair * 0.0
10100    augprob_prior = None
10101    if prior_probability is not None:
10102        augprob_prior = flair * 0.0
10103    for n in range(n_simulations):
10104        augflair, tx, itx = augment_image( ants.iMath(flair,"Normalize"), 5, 0.01 )
10105        locwmh = wmh( augflair, t1, t1seg, mmfromconvexhull = mmfromconvexhull,
10106            strict=strict, probability_mask=None, prior_probability=prior_probability )
10107        if verbose:
10108            print( "flair sim: " + str(n) + " vol: " + str( locwmh['wmh_mass'] )+ " vol-prior: " + str( locwmh['wmh_mass_prior'] )+ " snr: " + str( locwmh['wmh_SNR'] ) )
10109        wmh_sum_aug = wmh_sum_aug + locwmh['wmh_mass']
10110        wmh_sum_prior_aug = wmh_sum_prior_aug + locwmh['wmh_mass_prior']
10111        temp = locwmh['WMH_probability_map']
10112        augprob = augprob + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10113        if prior_probability is not None:
10114            temp = locwmh['WMH_posterior_probability_map']
10115            augprob_prior = augprob_prior + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10116    augprob = augprob * (1.0/float( n_simulations ))
10117    if prior_probability is not None:
10118        augprob_prior = augprob_prior * (1.0/float( n_simulations ))
10119    wmh_sum_aug = wmh_sum_aug / float( n_simulations )
10120    wmh_sum_prior_aug = wmh_sum_prior_aug / float( n_simulations )
10121    return{
10122      'flair' : ants.iMath(flair,"Normalize"),
10123      'WMH_probability_map' : augprob,
10124      'WMH_posterior_probability_map' : augprob_prior,
10125      'wmh_mass': wmh_sum_aug,
10126      'wmh_mass_prior': wmh_sum_prior_aug,
10127      'wmh_evr': locwmh['wmh_evr'],
10128      'wmh_SNR': locwmh['wmh_SNR']  }
10129
10130
10131def threaded_bind_wide_mm_csvs( mm_wide_csvs, n_workers ):
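    """
    Split mm_wide_csvs into n_workers chunks and run bind_wide_mm_csvs on each chunk in a
    thread pool, concatenating the results.  Returns the row-bound dataframe and the
    row-averaged dataframe, as in bind_wide_mm_csvs.
    """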
    from concurrent import futures
10135    def chunks(l, n):
10136        """Yield n number of sequential chunks from l."""
10137        d, r = divmod(len(l), n)
10138        for i in range(n):
10139            si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r)
10140            yield l[si:si+(d+1 if i < r else d)]
    newx = list( chunks( mm_wide_csvs, n_workers ) )
    alldf = pd.DataFrame()
    alldfavg = pd.DataFrame()
10146    with futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
10147        to_do = []
10148        for group in range(len(newx)) :
10149            future = executor.submit(bind_wide_mm_csvs, newx[group] )
10150            to_do.append(future)
10151        results = []
10152        for future in futures.as_completed(to_do):
10153            res0, res1 = future.result()
10154            alldf=pd.concat(  [alldf, res0 ], axis=0, ignore_index=False )
10155            alldfavg=pd.concat(  [alldfavg, res1 ], axis=0, ignore_index=False )
10156    return alldf, alldfavg
10157
10158
10159def get_names_from_data_frame(x, demogIn, exclusions=None):
10160    """
    data = {'Name':['Tom', 'nick', 'krish', 'jack'], 'Age':[20, 21, 19, 18]}
    df = pd.DataFrame(data)
    antspymm.get_names_from_data_frame( ['e'], df )
    antspymm.get_names_from_data_frame( ['a','e'], df )
    antspymm.get_names_from_data_frame( ['e'], df, exclusions='N' )
10165    """
10166    # Check if x is a string and convert it to a list
10167    if isinstance(x, str):
10168        x = [x]
10169    def get_unique( qq ):
10170        unique = []
10171        for number in qq:
10172            if number in unique:
10173                continue
10174            else:
10175                unique.append(number)
10176        return unique
10177    outnames = list(demogIn.columns[demogIn.columns.str.contains(x[0])])
10178    if len(x) > 1:
10179        for y in x[1:]:
10180            outnames = [i for i in outnames if y in i]
10181    outnames = get_unique( outnames )
10182    if exclusions is not None:
10183        toexclude = [name for name in outnames if exclusions[0] in name ]
10184        if len(exclusions) > 1:
10185            for zz in exclusions[1:]:
10186                toexclude.extend([name for name in outnames if zz in name ])
10187        if len(toexclude) > 0:
10188            outnames = [name for name in outnames if name not in toexclude]
10189    return outnames
10190
10191
10192def average_mm_df( jmm_in, diagnostic_n=25, corr_thresh=0.9, verbose=False ):
    """
    Average or reconcile repeated modality measurements ( e.g. LR/RL phase-encoded runs and
    repeat acquisitions ) within a wide multi-modality dataframe.  Returns the row-averaged
    unique dataframe, the full ( sorted ) dataframe and a dataframe of join diagnostics
    ( correlation and RMS distance between the entries that were combined ).

    jmrowavg, jmmcolavg, diagnostics = antspymm.average_mm_df( jmm_in, verbose=True )
    """
10196
10197    jmm = jmm_in.copy()
10198    dxcols=['subjectid1','subjectid2','modalityid','joinid','correlation','distance']
10199    joinDiagnostics = pd.DataFrame( columns = dxcols )
10200    nanList=[math.nan]
    def rob(x, y=0.99):
        # winsorize: values above the y quantile (NaN-aware) are set to NaN
        x[x > np.nanquantile(x, y)] = np.nan
        return x
10204
10205    jmm = jmm.replace(r'^\s*$', np.nan, regex=True)
10206
10207    if verbose:
10208        print("do rsfMRI")
10209    # here - we first have to average within each row
10210    dt0 = get_names_from_data_frame(["rsfMRI"], jmm, exclusions=["Unnamed", "rsfMRI_LR", "rsfMRI_RL"])
10211    dt1 = get_names_from_data_frame(["rsfMRI_RL"], jmm, exclusions=["Unnamed"])
10212    if len( dt0 ) > 0 and len( dt1 ) > 0:
10213        flid = dt0[0]
10214        wrows = []
10215        for i in range(jmm.shape[0]):
10216            if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) :
10217                wrows.append(i)
10218        for k in wrows:
10219            v1 = jmm.iloc[k][dt0[1:]].astype(float)
10220            v2 = jmm.iloc[k][dt1[1:]].astype(float)
10221            vvec = [v1[0], v2[0]]
10222            if any(~np.isnan(vvec)):
10223                mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
                # use .loc (not chained iloc indexing) so the assignment modifies jmm in place,
                # consistent with the DTI block below
                jmm.loc[k, dt0[0]] = 'rsfMRI'
                if len(mynna) == 1:
                    if mynna[0] == 0:
                        jmm.loc[k, dt0[1:]] = v1
                    if mynna[0] == 1:
                        jmm.loc[k, dt0[1:]] = v2
10230                elif len(mynna) > 1:
10231                    if len(v2) > diagnostic_n:
10232                        v1dx=v1[0:diagnostic_n]
10233                        v2dx=v2[0:diagnostic_n]
10234                    else :
10235                        v1dx=v1
10236                        v2dx=v2
10237                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10238                    mycorr = np.corrcoef( v1dx.values, v2dx.values )[0,1]
10239                    myerr=np.sqrt(np.mean((v1dx.values - v2dx.values)**2))
10240                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'rsfMRI','colavg',mycorr,myerr]
10241                    if mycorr > corr_thresh:
10242                        jmm.loc[k, dt0[1:]] = v1.values*0.5 + v2.values*0.5
10243                    else:
10244                        jmm.loc[k, dt0[1:]] = nanList * len(v1)
10245                    if verbose:
10246                        print( joinDiagnosticsLoc )
10247                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10248
10249    if verbose:
10250        print("do DTI")
10251    # here - we first have to average within each row
10252    dt0 = get_names_from_data_frame(["DTI"], jmm, exclusions=["Unnamed", "DTI_LR", "DTI_RL"])
10253    dt1 = get_names_from_data_frame(["DTI_LR"], jmm, exclusions=["Unnamed"])
10254    dt2 = get_names_from_data_frame( ["DTI_RL"], jmm, exclusions=["Unnamed"])
10255    flid = dt0[0]
10256    wrows = []
10257    for i in range(jmm.shape[0]):
10258        if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) or not pd.isna(jmm[dt2[1]][i]):
10259            wrows.append(i)
10260    for k in wrows:
10261        v1 = jmm.loc[k, dt0[1:]].astype(float)
10262        v2 = jmm.loc[k, dt1[1:]].astype(float)
10263        v3 = jmm.loc[k, dt2[1:]].astype(float)
        # if the quality-check column falls below threshold, invalidate that acquisition
        checkcol = dt0[5]
        if not np.isnan(v1[checkcol]):
            if v1[checkcol] < 0.25:
                v1[:] = np.nan
        checkcol = dt1[5]
        if not np.isnan(v2[checkcol]):
            if v2[checkcol] < 0.25:
                v2[:] = np.nan
        checkcol = dt2[5]
        if not np.isnan(v3[checkcol]):
            if v3[checkcol] < 0.25:
                v3[:] = np.nan
10276        vvec = [v1[0], v2[0], v3[0]]
10277        if any(~np.isnan(vvec)):
10278            mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10279            jmm.loc[k, dt0[0]] = 'DTI'
10280            if len(mynna) == 1:
10281                if mynna[0] == 0:
10282                    jmm.loc[k, dt0[1:]] = v1
10283                if mynna[0] == 1:
10284                    jmm.loc[k, dt0[1:]] = v2
10285                if mynna[0] == 2:
10286                    jmm.loc[k, dt0[1:]] = v3
10287            elif len(mynna) > 1:
10288                if mynna[0] == 0:
10289                    jmm.loc[k, dt0[1:]] = v1
10290                else:
10291                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10292                    mycorr = np.corrcoef( v2[0:diagnostic_n].values, v3[0:diagnostic_n].values )[0,1]
10293                    myerr=np.sqrt(np.mean((v2[0:diagnostic_n].values - v3[0:diagnostic_n].values)**2))
10294                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'DTI','colavg',mycorr,myerr]
10295                    if mycorr > corr_thresh:
10296                        jmm.loc[k, dt0[1:]] = v2.values*0.5 + v3.values*0.5
10297                    else: #
10298                        jmm.loc[k, dt0[1:]] = nanList * len( dt0[1:] )
10299                    if verbose:
10300                        print( joinDiagnosticsLoc )
10301                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10302
10303
10304    # first task - sort by u_hier_id
10305    jmm = jmm.sort_values( "u_hier_id" )
10306    # get rid of junk columns
10307    badnames = get_names_from_data_frame( ['Unnamed'], jmm )
10308    jmm=jmm.drop(badnames, axis=1)
10309    jmm=jmm.set_index("u_hier_id",drop=False)
10310    # 2nd - get rid of duplicated u_hier_id
10311    jmmUniq = jmm.drop_duplicates( subset="u_hier_id" ) # fast and easy
10312    # for each modality, count which ids have more than one
10313    mod_names = get_valid_modalities()
10314    for mod_name in mod_names:
10315        fl_names = get_names_from_data_frame([mod_name], jmm,
10316            exclusions=['Unnamed',"DTI_LR","DTI_RL","rsfMRI_RL","rsfMRI_LR"])
10317        if len( fl_names ) > 1:
10318            if verbose:
10319                print(mod_name)
10320                print(fl_names)
10321            fl_id = fl_names[0]
10322            n_names = len(fl_names)
10323            locvec = jmm[fl_names[n_names-1]].astype(float)
10324            boolvec=~pd.isna(locvec)
10325            jmmsub = jmm[boolvec][ ['u_hier_id']+fl_names]
10326            my_tbl = Counter(jmmsub['u_hier_id'])
10327            gtoavg = [name for name in my_tbl.keys() if my_tbl[name] == 1]
10328            gtoavgG1 = [name for name in my_tbl.keys() if my_tbl[name] > 1]
10329            if verbose:
10330                print("Join 1")
10331            jmmsub1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavg)][['u_hier_id']+fl_names]
10332            for u in gtoavg:
                jmmUniq.loc[u, fl_names[1:]] = jmmsub1.loc[u, fl_names[1:]]
10334            if verbose and len(gtoavgG1) > 1:
10335                print("Join >1")
10336            jmmsubG1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavgG1)][['u_hier_id']+fl_names]
10337            for u in gtoavgG1:
10338                temp = jmmsubG1.loc[u][ ['u_hier_id']+fl_names ]
10339                dropnames = get_names_from_data_frame( ['MM.ID'], temp )
10340                tempVec = temp.drop(columns=dropnames)
10341                joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10342                id1=temp[fl_id].iloc[0]
10343                id2=temp[fl_id].iloc[1]
10344                v1=tempVec.iloc[0][1:].astype(float).to_numpy()
10345                v2=tempVec.iloc[1][1:].astype(float).to_numpy()
10346                if len(v2) > diagnostic_n:
10347                    v1=v1[0:diagnostic_n]
10348                    v2=v2[0:diagnostic_n]
10349                mycorr = np.corrcoef( v1, v2 )[0,1]
10350                # mycorr=temparr[np.triu_indices_from(temparr, k=1)].mean()
10351                myerr=np.sqrt(np.mean((v1 - v2)**2))
10352                joinDiagnosticsLoc.iloc[0] = [id1,id2,mod_name,'rowavg',mycorr,myerr]
10353                if verbose:
10354                    print( joinDiagnosticsLoc )
                temp = jmmsubG1.loc[u][fl_names[1:]].astype(float)
                if mycorr > corr_thresh or len( v1 ) < 10:
                    jmmUniq.loc[u, fl_names[1:]] = temp.mean(axis=0)
                else:
                    jmmUniq.loc[u, fl_names[1:]] = nanList * temp.shape[1]
10360                joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], 
10361                                            axis=0, ignore_index=False )
10362
10363    return jmmUniq, jmm, joinDiagnostics
10364
10365
10366
10367def quick_viz_mm_nrg(
10368    sourcedir, # root folder
10369    projectid, # project name
10370    sid , # subject unique id
10371    dtid, # date
10372    extract_brain=True,
10373    slice_factor = 0.55,
10374    post = False,
10375    original_sourcedir = None,
10376    filename = None, # output path
10377    verbose = True
10378):
10379    """
10380    This function creates visualizations of brain images for a specific subject in a project using ANTsPy.
10381
10382    Args:
10383
10384    sourcedir (str): Root folder for original data (if post=False) or processed data (post=True)
10385    
10386    projectid (str): Project name.
10387    
10388    sid (str): Subject unique id.
10389    
10390    dtid (str): Date.
10391    
10392    extract_brain (bool): If True, the function extracts the brain from the T1w image. Default is True.
10393    
10394    slice_factor (float): The slice to be visualized is determined by multiplying the image size by this factor. Default is 0.55.
10395
10396    post ( bool ) : if True, will visualize example post-processing results.
10397    
10398    original_sourcedir (str): Root folder for original data (used if post=True)
10399    
10400    filename (str): Output path with extension (.png)
10401    
10402    verbose (bool): If True, information will be printed while running the function. Default is True.
10403
10404    Returns:
10405    None
10406
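    Example (illustrative; the root folder and ids below are hypothetical):

    antspymm.quick_viz_mm_nrg( '/data/nrg/', 'PPMI', '100018', '20210202',
        filename='/tmp/100018_20210202_viz.png' )
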
10407    """
10408    iid='*'
10409    import glob as glob
10410    from os.path import exists
10411    import ants
10412    temp = sourcedir.split( "/" )
10413    subjectrootpath = os.path.join(sourcedir, projectid, sid, dtid)
10414    if verbose:
10415        print( 'subjectrootpath' )
10416        print( subjectrootpath )
10417    t1_search_path = os.path.join(subjectrootpath, "T1w", "*", "*nii.gz")
10418    if verbose:
10419        print(f"t1 search path: {t1_search_path}")
10420    t1fn = glob.glob(t1_search_path)
10421    if len( t1fn ) < 1:
10422        raise ValueError('quick_viz_mm_nrg cannot find the T1w @ ' + subjectrootpath )
10423    vizlist=[]
10424    undlist=[]
10425    nrg_modality_list = [ 'T1w', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10426    if post:
10427        nrg_modality_list = [ 'T1wHierarchical', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10428    for nrgNum in [0,1,2,3,4,5]:
10429        underlay = None
10430        overmodX = nrg_modality_list[nrgNum]
10431        if  'T1w' in overmodX :
10432            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10433            if post:
10434                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*brain_n4_dnz.nii.gz")
10435                mod_search_path_ol = os.path.join(subjectrootpath, overmodX, iid, "*thickness_image.nii.gz" )
10436                mod_search_path_ol = re.sub( "T1wHierarchical","T1w",mod_search_path_ol)
10437                myol = glob.glob(mod_search_path_ol)
10438                if len( myol ) > 0:
10439                    temper = find_most_recent_file( myol )[0]
10440                    underlay = ants.image_read(  temper )
10441                    if verbose:
10442                        print("T1w overlay " + temper )
10443                    underlay = underlay * ants.threshold_image( underlay, 0.2, math.inf )
10444            myimgsr = glob.glob(mod_search_path)
10445            if len( myimgsr ) == 0:
10446                if verbose:
10447                    print("No t1 images: " + sid + dtid )
10448                return None
10449            myimgsr=find_most_recent_file( myimgsr )[0]
10450            vimg=ants.image_read( myimgsr )
10451        elif  'T2Flair' in overmodX :
10452            if verbose:
10453                print("search flair")
10454            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10455            if post and original_sourcedir is not None:
10456                if verbose:
10457                    print("post in flair")
10458                mysubdir = os.path.join(original_sourcedir, projectid, sid, dtid)
10459                mod_search_path_under = os.path.join(mysubdir, overmodX, iid, "*T2Flair*.nii.gz")
10460                if verbose:
10461                    print("post in flair mod_search_path_under " + mod_search_path_under)
10462                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*wmh.nii.gz")
10463                if verbose:
10464                    print("post in flair mod_search_path " + mod_search_path )
10465                myimgul = glob.glob(mod_search_path_under)
10466                if len( myimgul ) > 0:
10467                    myimgul = find_most_recent_file( myimgul )[0]
10468                    if verbose:
10469                        print("Flair  " + myimgul )
10470                    vimg = ants.image_read( myimgul )
10471                    myol = glob.glob(mod_search_path)
10472                    if len( myol ) == 0:
                        underlay = vimg * 0.0  # zero overlay for the flair just read; myimgsr is a filename string here
10474                    else:
10475                        myol = find_most_recent_file( myol )[0]
10476                        if verbose:
10477                            print("Flair overlay " + myol )
10478                        underlay=ants.image_read( myol )
10479                        underlay=underlay*ants.threshold_image(underlay,0.05,math.inf)
10480                else:
10481                    vimg = noizimg.clone()
10482                    underlay = vimg * 0.0
10483            if original_sourcedir is None:
10484                myimgsr = glob.glob(mod_search_path)
10485                if len( myimgsr ) == 0:
10486                    vimg = noizimg.clone()
10487                else:
10488                    myimgsr=find_most_recent_file( myimgsr )[0]
10489                    vimg=ants.image_read( myimgsr )
10490        elif overmodX == 'DTI':
10491            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10492            if post:
10493                mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*fa.nii.gz")
10494            myimgsr = glob.glob(mod_search_path)
10495            if len( myimgsr ) > 0:
10496                myimgsr=find_most_recent_file( myimgsr )[0]
10497                vimg=ants.image_read( myimgsr )
10498            else:
10499                if verbose:
10500                    print("No " + overmodX)
10501                vimg = noizimg.clone()
10502        elif overmodX == 'DTI2':
10503            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10504            myimgsr = glob.glob(mod_search_path)
10505            if len( myimgsr ) > 0:
10506                myimgsr.sort()
10507                myimgsr=myimgsr[len(myimgsr)-1]
10508                vimg=ants.image_read( myimgsr )
10509            else:
10510                if verbose:
10511                    print("No " + overmodX)
10512                vimg = noizimg.clone()
10513        elif overmodX == 'NM2DMT':
10514            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10515            if post:
10516                mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*NM_avg.nii.gz" )
10517            myimgsr = glob.glob(mod_search_path)
10518            if len( myimgsr ) > 0:
10519                myimgsr0=myimgsr[0]
10520                vimg=ants.image_read( myimgsr0 )
10521                for k in range(1,len(myimgsr)):
10522                    temp = ants.image_read( myimgsr[k])
10523                    vimg=vimg+ants.resample_image_to_target(temp,vimg)
10524            else:
10525                if verbose:
10526                    print("No " + overmodX)
10527                vimg = noizimg.clone()
10528        elif overmodX == 'rsfMRI':
10529            mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*nii.gz")
10530            if post:
10531                mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_meanBold.nii.gz" )
10532                mod_search_path_ol = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_DefaultMode.nii.gz" )
10533                myol = glob.glob(mod_search_path_ol)
10534                if len( myol ) > 0:
10535                    myol = find_most_recent_file( myol )[0]
10536                    underlay = ants.image_read( myol )
10537                    if verbose:
10538                        print("BOLD overlay " + myol )
10539                    underlay = underlay * ants.threshold_image( underlay, 0.1, math.inf )
10540            myimgsr = glob.glob(mod_search_path)
10541            if len( myimgsr ) > 0:
10542                myimgsr=find_most_recent_file( myimgsr )[0]
10543                vimg=mm_read_to_3d( myimgsr )
10544            else:
10545                if verbose:
10546                    print("No " + overmodX)
10547                vimg = noizimg.clone()
10548        elif overmodX == 'perf':
10549            mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*nii.gz")
10550            if post:
10551                mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*cbf.nii.gz")
10552            myimgsr = glob.glob(mod_search_path)
10553            if len( myimgsr ) > 0:
10554                myimgsr=find_most_recent_file( myimgsr )[0]
10555                vimg=mm_read_to_3d( myimgsr )
10556            else:
10557                if verbose:
10558                    print("No " + overmodX)
10559                vimg = noizimg.clone()
10560        else :
10561            if verbose:
10562                print("Something else here")
10563            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10564            myimgsr = glob.glob(mod_search_path)
10565            if post:
10566                myimgsr=[]
10567            if len( myimgsr ) > 0:
10568                myimgsr=find_most_recent_file( myimgsr )[0]
10569                vimg=ants.image_read( myimgsr )
10570            else:
10571                if verbose:
10572                    print("No " + overmodX)
10573                vimg = noizimg
10574        if True:
10575            if extract_brain and overmodX == 'T1w' and post == False:
10576                vimg = vimg * antspyt1w.brain_extraction(vimg)
10577            if verbose:
10578                print(f"modality search path: {myimgsr}" + " num: " + str(nrgNum))
10579            if vimg.dimension == 4 and ( overmodX == "DTI2"  ):
10580                ttb0, ttdw=get_average_dwi_b0(vimg)
10581                vimg = ttdw
10582            elif vimg.dimension == 4 and overmodX == "DTI":
10583                ttb0, ttdw=get_average_dwi_b0(vimg)
10584                vimg = ttb0
10585            elif vimg.dimension == 4 :
10586                vimg=ants.get_average_of_timeseries(vimg)
10587            msk=ants.get_mask(vimg)
10588            if overmodX == 'T2Flair':
10589                msk=vimg*0+1
10590            if underlay is not None:
10591                print( overmodX + " has underlay" )
10592            else:
10593                underlay = vimg * 0.0
10594            if nrgNum == 0:
10595                refimg=ants.image_clone( vimg )
10596                noizimg = ants.add_noise_to_image( refimg*0, 'additivegaussian', [100,1] )
10597                vizlist.append( vimg )
10598                undlist.append( underlay )
10599            else:
10600                vimg = ants.iMath( vimg, 'TruncateIntensity',0.01,0.98)
10601                vizlist.append( ants.iMath( vimg, 'Normalize' ) * 255 )
10602                undlist.append( underlay )
10603
10604    # mask & crop systematically ...
10605    msk = ants.get_mask( refimg )
10606    refimg = ants.crop_image( refimg, msk )
10607
10608    for jj in range(len(vizlist)):
10609        vizlist[jj]=ants.resample_image_to_target( vizlist[jj], refimg )
10610        undlist[jj]=ants.resample_image_to_target( undlist[jj], refimg )
10611        print( 'viz: ' + str( jj ) )
10612        print( vizlist[jj] )
10613        print( 'und: ' + str( jj ) )
10614        print( undlist[jj] )
10615
10616
10617    xyz = [None]*3
10618    for i in range(3):
10619        if xyz[i] is None:
10620            xyz[i] = int(refimg.shape[i] * slice_factor )
10621
10622    if verbose:
10623        print('slice positions')
10624        print( xyz )
10625
10626    ants.plot_ortho_stack( vizlist, overlays=undlist, crop=False, reorient=False, filename=filename, xyz=xyz, orient_labels=False )
    if verbose:
        print("viz complete.")
    return
10645
10646
10647def blind_image_assessment(
10648    image,
10649    viz_filename=None,
10650    title=False,
10651    pull_rank=False,
10652    resample=None,
10653    n_to_skip = 10,
10654    verbose=False
10655):
10656    """
10657    quick blind image assessment and triplanar visualization of an image ... 4D input will be visualized and assessed in 3D.  produces a png and csv where csv contains:
10658
10659    * reflection error ( estimates asymmetry )
10660
    * noise level, SNR and CNR estimated from automatically derived foreground/background masks ( BRISQUE itself is currently disabled )
10662
10663    * patch eigenvalue ratio ( blind quality assessment )
10664
10665    * PSNR and SSIM vs a smoothed reference (4D or 3D appropriate)
10666
10667    * mask volume ( estimates foreground object size )
10668
10669    * spacing
10670
10671    * dimension after cropping by mask
10672
10673    image : character or image object usually a nifti image
10674
10675    viz_filename : character for a png output image
10676
10677    title : display a summary title on the png
10678
10679    pull_rank : boolean
10680
10681    resample : None, numeric max or min, resamples image to isotropy
10682
10683    n_to_skip : 10 by default; samples time series every n_to_skip volume
10684
10685    verbose : boolean
10686
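    Example (illustrative; the filename below is hypothetical):

    qcdf = antspymm.blind_image_assessment(
        '/data/nrg/PPMI/100018/20210202/T1w/1496225/PPMI-100018-20210202-T1w-1496225.nii.gz',
        viz_filename='/tmp/t1_qc.png' )
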
10687    """
10688    import glob as glob
10689    from os.path import exists
10690    import ants
10691    import matplotlib.pyplot as plt
10692    from PIL import Image
10693    from pathlib import Path
10694    import json
10695    import re
10696    from dipy.io.gradients import read_bvals_bvecs
10697    mystem=''
10698    if isinstance(image,list):
10699        isfilename=isinstance( image[0], str)
10700        image = image[0]
10701    else:
10702        isfilename=isinstance( image, str)
10703    outdf = pd.DataFrame()
10704    mymeta = None
10705    MagneticFieldStrength = None
10706    image_filename=''
10707    if isfilename:
10708        image_filename = image
10709        if isinstance(image,list):
10710            image_filename=image[0]
10711        json_name = re.sub(".nii.gz",".json",image_filename)
10712        if exists( json_name ):
10713            try:
10714                with open(json_name, 'r') as fcc_file:
10715                    mymeta = json.load(fcc_file)
10716                    if verbose:
10717                        print(json.dumps(mymeta, indent=4))
10718                    fcc_file.close()
10719            except:
10720                pass
10721        mystem=Path( image ).stem
10722        mystem=Path( mystem ).stem
10723        image_reference = ants.image_read( image )
10724        image = ants.image_read( image )
10725    else:
10726        image_reference = ants.image_clone( image )
10727    ntimepoints = 1
10728    bvalueMax=None
10729    bvecnorm=None
10730    if image_reference.dimension == 4:
10731        ntimepoints = image_reference.shape[3]
10732        if "DTI" in image_filename:
10733            myTSseg = segment_timeseries_by_meanvalue( image_reference )
10734            image_b0, image_dwi = get_average_dwi_b0( image_reference, fast=True )
10735            image_b0 = ants.iMath( image_b0, 'Normalize' )
10736            image_dwi = ants.iMath( image_dwi, 'Normalize' )
10737            bval_name = re.sub(".nii.gz",".bval",image_filename)
10738            bvec_name = re.sub(".nii.gz",".bvec",image_filename)
10739            if exists( bval_name ) and exists( bvec_name ):
10740                bvals, bvecs = read_bvals_bvecs( bval_name , bvec_name  )
10741                bvalueMax = bvals.max()
10742                bvecnorm = np.linalg.norm(bvecs,axis=1).reshape( bvecs.shape[0],1 )
10743                bvecnorm = bvecnorm.max()
10744        else:
10745            image_b0 = ants.get_average_of_timeseries( image_reference ).iMath("Normalize")
10746    else:
10747        image_compare = ants.smooth_image( image_reference, 3, sigma_in_physical_coordinates=False )
10748    for jjj in range(0,ntimepoints,n_to_skip):
10749        modality='unknown'
10750        if "rsfMRI" in image_filename:
10751            modality='rsfMRI'
10752        elif "perf" in image_filename:
10753            modality='perf'
10754        elif "DTI" in image_filename:
10755            modality='DTI'
10756        elif "T1w" in image_filename:
10757            modality='T1w'
10758        elif "T2Flair" in image_filename:
10759            modality='T2Flair'
10760        elif "NM2DMT" in image_filename:
10761            modality='NM2DMT'
10762        if image_reference.dimension == 4:
10763            image = ants.slice_image( image_reference, idx=int(jjj), axis=3 )
10764            if "DTI" in image_filename:
10765                if jjj in myTSseg['highermeans']:
10766                    image_compare = ants.image_clone( image_b0 )
10767                    modality='DTIb0'
10768                else:
10769                    image_compare = ants.image_clone( image_dwi )
10770                    modality='DTIdwi'
10771            else:
10772                image_compare = ants.image_clone( image_b0 )
10773        # image = ants.iMath( image, 'TruncateIntensity',0.01,0.995)
10774        minspc = np.min(ants.get_spacing(image))
10775        maxspc = np.max(ants.get_spacing(image))
10776        if resample is not None:
10777            if resample == 'min':
10778                if minspc < 1e-12:
10779                    minspc = np.max(ants.get_spacing(image))
10780                newspc = np.repeat( minspc, 3 )
10781            elif resample == 'max':
10782                newspc = np.repeat( maxspc, 3 )
10783            else:
10784                newspc = np.repeat( resample, 3 )
10785            image = ants.resample_image( image, newspc )
10786            image_compare = ants.resample_image( image_compare, newspc )
10787        else:
10788            # check for spc close to zero
10789            spc = list(ants.get_spacing(image))
10790            for spck in range(len(spc)):
10791                if spc[spck] < 1e-12:
10792                    spc[spck]=1
10793            ants.set_spacing( image, spc )
10794            ants.set_spacing( image_compare, spc )
10795        # if "NM2DMT" in image_filename or "FIXME" in image_filename or "SPECT" in image_filename or "UNKNOWN" in image_filename:
10796        minspc = np.min(ants.get_spacing(image))
10797        maxspc = np.max(ants.get_spacing(image))
10798        msk = ants.threshold_image( ants.iMath(image,'Normalize'), 0.15, 1.0 )
10799        # else:
10800        #    msk = ants.get_mask( image )
10801        msk = ants.morphology(msk, "close", 3 )
10802        bgmsk = msk*0+1-msk
10803        mskdil = ants.iMath(msk, "MD", 4 )
10804        # ants.plot_ortho( image, msk, crop=False )
10805        nvox = int( msk.sum() )
10806        spc = ants.get_spacing( image )
10807        org = ants.get_origin( image )
10808        if ( nvox > 0 ):
10809            image = ants.crop_image( image, mskdil ).iMath("Normalize")
10810            msk = ants.crop_image( msk, mskdil ).iMath("Normalize")
10811            bgmsk = ants.crop_image( bgmsk, mskdil ).iMath("Normalize")
10812            image_compare = ants.crop_image( image_compare, mskdil ).iMath("Normalize")           
10813            npatch = int( np.round(  0.1 * nvox ) )
10814            npatch = np.min(  [512,npatch ] )
10815            patch_shape = []
10816            for k in range( 3 ):
10817                p = int( 32.0 / ants.get_spacing( image  )[k] )
10818                if p > int( np.round( image.shape[k] * 0.5 ) ):
10819                    p = int( np.round( image.shape[k] * 0.5 ) )
10820                patch_shape.append( p )
10821            if verbose:
10822                print(image)
10823                print( patch_shape )
10824                print( npatch )
10825            myevr = math.nan # dont want to fail if something odd happens in patch extraction
10826            try:
10827                myevr = antspyt1w.patch_eigenvalue_ratio( image, npatch, patch_shape,
10828                    evdepth = 0.9, mask=msk )
10829            except:
10830                pass
10831            if pull_rank:
10832                image = ants.rank_intensity(image)
10833            imagereflect = ants.reflect_image(image, axis=0)
10834            asym_err = ( image - imagereflect ).abs().mean()
10835            # estimate noise by center cropping, denoizing and taking magnitude of difference
10836            nocrop=False
10837            if image.dimension == 3:
10838                if image.shape[2] == 1:
10839                    nocrop=True        
10840            if maxspc/minspc > 10:
10841                nocrop=True
10842            if nocrop:
10843                mycc = ants.image_clone( image )
10844            else:
10845                mycc = antspyt1w.special_crop( image,
10846                    ants.get_center_of_mass( msk *0 + 1 ), patch_shape )
10847            myccd = ants.denoise_image( mycc, p=2,r=2,noise_model='Gaussian' )
10848            noizlevel = ( mycc - myccd ).abs().mean()
10849    #        ants.plot_ortho( image, crop=False, filename=viz_filename, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
10850    #        from brisque import BRISQUE
10851    #        obj = BRISQUE(url=False)
10852    #        mybrisq = obj.score( np.array( Image.open( viz_filename )) )
10853            msk_vol = msk.sum() * np.prod( spc )
10854            bgstd = image[ bgmsk == 1 ].std()
10855            fgmean = image[ msk == 1 ].mean()
10856            bgmean = image[ bgmsk == 1 ].mean()
10857            snrref = fgmean / bgstd
10858            cnrref = ( fgmean - bgmean ) / bgstd
10859            psnrref = antspynet.psnr(  image_compare, image  )
10860            ssimref = antspynet.ssim(  image_compare, image  )
10861            if nocrop:
10862                mymi = math.inf
10863            else:
10864                mymi = ants.image_mutual_information( image_compare, image )
10865        else:
10866            msk_vol = 0
            myevr = mymi = ssimref = psnrref = cnrref = snrref = asym_err = noizlevel = math.nan
10868            
10869        mriseries=None
10870        mrimfg=None
10871        mrimodel=None
10872        mriSAR=None
10873        BandwidthPerPixelPhaseEncode=None
10874        PixelBandwidth=None
10875        if mymeta is not None:
10876            # mriseries=mymeta['']
10877            try:
10878                mrimfg=mymeta['Manufacturer']
10879            except:
10880                pass
10881            try:
10882                mrimodel=mymeta['ManufacturersModelName']
10883            except:
10884                pass
10885            try:
10886                MagneticFieldStrength=mymeta['MagneticFieldStrength']
10887            except:
10888                pass
10889            try:
10890                PixelBandwidth=mymeta['PixelBandwidth']
10891            except:
10892                pass
10893            try:
10894                BandwidthPerPixelPhaseEncode=mymeta['BandwidthPerPixelPhaseEncode']
10895            except:
10896                pass
10897            try:
10898                mriSAR=mymeta['SAR']
10899            except:
10900                pass
10901        ttl=mystem + ' '
10902        ttl=''
10903        ttl=ttl + "NZ: " + "{:0.4f}".format(noizlevel) + " SNR: " + "{:0.4f}".format(snrref) + " CNR: " + "{:0.4f}".format(cnrref) + " PS: " + "{:0.4f}".format(psnrref)+ " SS: " + "{:0.4f}".format(ssimref) + " EVR: " + "{:0.4f}".format(myevr)+ " MI: " + "{:0.4f}".format(mymi)
10904        if viz_filename is not None and ( jjj == 0 or (jjj % 30 == 0) ) and image.shape[2] < 685:
10905            viz_filename_use = re.sub( ".png", "_slice"+str(jjj).zfill(4)+".png", viz_filename )
10906            ants.plot_ortho( image, crop=False, filename=viz_filename_use, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0,  title=ttl, titlefontsize=12, title_dy=-0.02,textfontcolor='red' )
10907        df = pd.DataFrame([[ 
10908            mystem, 
10909            image_reference.dimension, 
10910            noizlevel, snrref, cnrref, psnrref, ssimref, mymi, asym_err, myevr, msk_vol, 
10911            spc[0], spc[1], spc[2],org[0], org[1], org[2], 
10912            image.shape[0], image.shape[1], image.shape[2], ntimepoints, 
10913            jjj, modality, mriseries, mrimfg, mrimodel, MagneticFieldStrength, mriSAR, PixelBandwidth, BandwidthPerPixelPhaseEncode, bvalueMax, bvecnorm ]], 
10914            columns=[
10915                'filename', 
10916                'dimensionality',
10917                'noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol', 'spc0','spc1','spc2','org0','org1','org2','dimx','dimy','dimz','dimt','slice','modality', 'mriseries', 'mrimfg', 'mrimodel', 'mriMagneticFieldStrength', 'mriSAR', 'mriPixelBandwidth', 'mriPixelBandwidthPE', 'dti_bvalueMax', 'dti_bvecnorm' ])
10918        outdf = pd.concat( [outdf, df ], axis=0, ignore_index=False )
10919        if verbose:
10920            print( outdf )
10921    if viz_filename is not None:
10922        csvfn = re.sub( "png", "csv", viz_filename )
10923        outdf.to_csv( csvfn )
10924    return outdf
10925
10926def remove_unwanted_columns(df):
10927    # Identify columns to drop: those named 'X' or starting with 'Unnamed'
10928    cols_to_drop = [col for col in df.columns if col == 'X' or col.startswith('Unnamed')]
10929    
10930    # Drop the identified columns from the DataFrame, if any
10931    df_cleaned = df.drop(columns=cols_to_drop, errors='ignore')
10932    
10933    return df_cleaned
10934
10935def process_dataframe_generalized(df, group_by_column):
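    """
    Group df by group_by_column and aggregate: numeric columns are averaged while all other
    columns take their most frequent value ( mode ), falling back to None when the mode is empty.
    """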
10936    # Make sure the group_by_column is excluded from both numeric and other columns calculations
10937    numeric_cols = df.select_dtypes(include='number').columns.difference([group_by_column])
10938    other_cols = df.columns.difference(numeric_cols).difference([group_by_column])
10939    
10940    # Define aggregation functions: mean for numeric cols, mode for other cols
10941    # Update to handle empty mode results safely
10942    agg_dict = {col: 'mean' for col in numeric_cols}
10943    agg_dict.update({
10944        col: lambda x: pd.Series.mode(x).iloc[0] if not pd.Series.mode(x).empty else None for col in other_cols
10945    })    
10946    # Group by the specified column, applying different aggregation functions to different columns
10947    processed_df = df.groupby(group_by_column, as_index=False).agg(agg_dict)
10948    return processed_df
10949
10950def average_blind_qc_by_modality(qc_full,verbose=False):
10951    """
10952    Averages time series qc results to yield one entry per image. this also filters to "known" columns.
10953
10954    Args:
10955    qc_full: pandas dataframe containing the full qc data.
10956
10957    Returns:
10958    pandas dataframe containing the processed qc data.
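
    Example (illustrative; assumes per-image qc csvs were written by blind_image_assessment):

    import glob
    import pandas as pd
    qc_full = pd.concat( [pd.read_csv(f) for f in glob.glob('/tmp/qc/*csv')] )
    qc_avg = antspymm.average_blind_qc_by_modality( qc_full, verbose=True )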
10959    """
10960    qc_full = remove_unwanted_columns( qc_full )
10961    # Get unique modalities
10962    modalities = qc_full['modality'].unique()
10963    modalities = modalities[modalities != 'unknown']
10964    # Get unique ids
10965    uid = qc_full['filename']
10966    to_average = uid.unique()
10967    meta = pd.DataFrame(columns=qc_full.columns )
10968    # Process each unique id
10969    n = len(to_average)
10970    for k in range(n):
10971        if verbose:
10972            if k % 100 == 0:
10973                progger = str( np.round( k / n * 100 ) )
10974                print( progger, end ="...", flush=True)
10975        m1sel = uid == to_average[k]
10976        if sum(m1sel) > 1:
10977            # If more than one entry for id, take the average of continuous columns,
10978            # maximum of the slice column, and the first entry of the other columns
10979            mfsub = process_dataframe_generalized(qc_full[m1sel],'filename')
10980        else:
10981            mfsub = qc_full[m1sel]
10982        meta.loc[k] = mfsub.iloc[0]
10983    meta['modality'] = meta['modality'].replace(['DTIdwi', 'DTIb0'], 'DTI', regex=True)
10984    return meta
10985
10986def wmh( flair, t1, t1seg,
10987    mmfromconvexhull = 3.0,
10988    strict=True,
10989    probability_mask=None,
10990    prior_probability=None,
10991    model='sysu',
10992    verbose=False ) :
10993    """
10994    Outputs the WMH probability mask and a summary single measurement
10995
10996    Arguments
10997    ---------
10998    flair : ANTsImage
10999        input 3-D FLAIR brain image (not skull-stripped).
11000
11001    t1 : ANTsImage
11002        input 3-D T1 brain image (not skull-stripped).
11003
11004    t1seg : ANTsImage
11005        T1 segmentation image
11006
11007    mmfromconvexhull : float
11008        restrict WMH to regions that are WM or mmfromconvexhull mm away from the
11009        convex hull of the cerebrum.   we choose a default value based on
11010        Figure 4 from:
11011        https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6240579/pdf/fnagi-10-00339.pdf
11012
11013    strict: boolean - if True, only use convex hull distance
11014
11015    probability_mask : None - use to compute wmh just once - then this function
11016        just does refinement and summary
11017
11018    prior_probability : optional prior probability image in space of the input t1
11019
11020    model : either sysu or hyper
11021
11022    verbose : boolean
11023
11024    Returns
11025    ---------
11026    WMH probability map and a summary single measurement which is the sum of the WMH map
11027
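    Example (illustrative; assumes flair, t1 and t1seg are already loaded and that t1seg uses the
    six-tissue convention in which label 3 is white matter and labels 1-4 cover the cerebrum):

    wmhres = antspymm.wmh( flair, t1, t1seg, mmfromconvexhull=3.0, model='sysu' )
    wmh_volume = wmhres['wmh_mass']
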
11028    """
11029    import numpy as np
11030    import math
11031    t1_2_flair_reg = ants.registration(flair, t1, type_of_transform = 'antsRegistrationSyNRepro[r]') # Register T1 to Flair
11032    if probability_mask is None and model == 'sysu':
11033        if verbose:
11034            print('sysu')
11035        probability_mask = antspynet.sysu_media_wmh_segmentation( flair )
11036    elif probability_mask is None and model == 'hyper':
11037        if verbose:
11038            print('hyper')
11039        probability_mask = antspynet.hypermapp3r_segmentation( t1_2_flair_reg['warpedmovout'], flair )
11040    # t1_2_flair_reg = tra_initializer( flair, t1, n_simulations=4, max_rotation=5, transform=['rigid'], verbose=False )
11041    prior_probability_flair = None
11042    if prior_probability is not None:
11043        prior_probability_flair = ants.apply_transforms( flair, prior_probability,
11044            t1_2_flair_reg['fwdtransforms'] )
11045    wmseg_mask = ants.threshold_image( t1seg,
11046        low_thresh = 3, high_thresh = 3).iMath("FillHoles")
11047    wmseg_mask_use = ants.image_clone( wmseg_mask )
11048    distmask = None
11049    if mmfromconvexhull > 0:
11050            convexhull = ants.threshold_image( t1seg, 1, 4 )
11051            spc2vox = np.prod( ants.get_spacing( t1seg ) )
11052            voxdist = 0.0
11053            myspc = ants.get_spacing( t1seg )
11054            for k in range( t1seg.dimension ):
11055                voxdist = voxdist + myspc[k] * myspc[k]
11056            voxdist = math.sqrt( voxdist )
11057            nmorph = round( 2.0 / voxdist )
11058            convexhull = ants.morphology( convexhull, "close", nmorph ).iMath("FillHoles")
11059            dist = ants.iMath( convexhull, "MaurerDistance" ) * -1.0
11060            distmask = ants.threshold_image( dist, mmfromconvexhull, 1.e80 )
11061            wmseg_mask = wmseg_mask + distmask
11062            if strict:
11063                wmseg_mask_use = ants.threshold_image( wmseg_mask, 2, 2 )
11064            else:
11065                wmseg_mask_use = ants.threshold_image( wmseg_mask, 1, 2 )
11066    ##############################################################################
11067    wmseg_2_flair = ants.apply_transforms(flair, wmseg_mask_use,
11068        transformlist = t1_2_flair_reg['fwdtransforms'],
11069        interpolator = 'nearestNeighbor' )
11070    seg_2_flair = ants.apply_transforms(flair, t1seg,
11071        transformlist = t1_2_flair_reg['fwdtransforms'],
11072        interpolator = 'nearestNeighbor' )
11073    csfmask = ants.threshold_image(seg_2_flair,1,1)
11074    flairsnr = mask_snr( flair, csfmask, wmseg_2_flair, bias_correct = False )
11075    probability_mask_WM = wmseg_2_flair * probability_mask # Remove WMH signal outside of WM
11076    wmh_sum = np.prod( ants.get_spacing( flair ) ) * probability_mask_WM.sum()
11077    wmh_sum_prior = math.nan
11078    probability_mask_posterior = None
11079    if prior_probability_flair is not None:
11080        probability_mask_posterior = prior_probability_flair * probability_mask # use prior
11081        wmh_sum_prior = np.prod( ants.get_spacing(flair) ) * probability_mask_posterior.sum()
11082    if math.isnan( wmh_sum ):
11083        wmh_sum=0
11084    if math.isnan( wmh_sum_prior ):
11085        wmh_sum_prior=0
11086    flair_evr = antspyt1w.patch_eigenvalue_ratio( flair, 512, [16,16,16], evdepth = 0.9, mask=wmseg_2_flair )
11087    return{
11088        'WMH_probability_map_raw': probability_mask,
11089        'WMH_probability_map' : probability_mask_WM,
11090        'WMH_posterior_probability_map' : probability_mask_posterior,
11091        'wmh_mass': wmh_sum,
11092        'wmh_mass_prior': wmh_sum_prior,
11093        'wmh_evr' : flair_evr,
11094        'wmh_SNR' : flairsnr,
11095        'convexhull_mask': distmask }
11096
11097
11098def replace_elements_in_numpy_array(original_array, indices_to_replace, new_value):
11099    """
11100    Replace specified elements or rows in a numpy array with a new value.
11101
11102    Parameters:
11103    original_array (numpy.ndarray): A numpy array in which elements or rows are to be replaced.
11104    indices_to_replace (list or numpy.ndarray): Indices of elements or rows to be replaced.
11105    new_value: The new value to replace the specified elements or rows.
11106
11107    Returns:
11108    numpy.ndarray: The input array with the specified elements or rows replaced in place.
11109                   If the input array is None, the function returns None.
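
         Example (a minimal sketch; the array below is illustrative, not part of this module):

             import numpy as np
             arr = np.arange(6.0)
             arr = replace_elements_in_numpy_array(arr, [1, 4], 0.0)  # entries 1 and 4 set to 0.0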
11110    """
11111
11112    if original_array is None:
11113        return None
11114
11115    max_index = original_array.size if original_array.ndim == 1 else original_array.shape[0]
11116
11117    # Filter out invalid indices and check for any out-of-bounds indices
11118    valid_indices = []
11119    for idx in indices_to_replace:
11120        if idx < max_index:
11121            valid_indices.append(idx)
11122        else:
11123            warnings.warn(f"Warning: Index {idx} is out of bounds and will be ignored.")
11124
11125    if original_array.ndim == 1:
11126        # Replace elements in a 1D array
11127        original_array[valid_indices] = new_value
11128    elif original_array.ndim == 2:
11129        # Replace rows in a 2D array
11130        original_array[valid_indices, :] = new_value
11131    else:
11132        raise ValueError("original_array must be either 1D or 2D.")
11133
11134    return original_array
11135
11136
11137
11138def remove_elements_from_numpy_array(original_array, indices_to_remove):
11139    """
11140    Remove specified elements or rows from a numpy array.
11141
11142    Parameters:
11143    original_array (numpy.ndarray): A numpy array from which elements or rows are to be removed.
11144    indices_to_remove (list or numpy.ndarray): Indices of elements or rows to be removed.
11145
11146    Returns:
11147    numpy.ndarray: A new numpy array with the specified elements or rows removed. If the input array is None,
11148                   the function returns None.
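
         Example (a minimal sketch; `bvals` here is an illustrative array, not module data):

             import numpy as np
             bvals = np.array([0, 1000, 1000, 2000])
             bvals = remove_elements_from_numpy_array(bvals, [0])  # drops the first entry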
11149    """
11150
11151    if original_array is None:
11152        return None
11153
11154    if original_array.ndim == 1:
11155        # Remove elements from a 1D array
11156        return np.delete(original_array, indices_to_remove)
11157    elif original_array.ndim == 2:
11158        # Remove rows from a 2D array
11159        return np.delete(original_array, indices_to_remove, axis=0)
11160    else:
11161        raise ValueError("original_array must be either 1D or 2D.")
11162
11163def remove_volumes_from_timeseries(time_series, volumes_to_remove):
11164    """
11165    Remove specified volumes from a time series.
11166
11167    :param time_series: ANTsImage representing the time series (4D image).
11168    :param volumes_to_remove: List of volume indices to remove.
11169    :return: ANTsImage with specified volumes removed.
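
         Example (a minimal sketch; a small synthetic 4D image stands in for real fMRI/DWI data):

             import ants, numpy as np
             ts = ants.from_numpy(np.random.rand(8, 8, 8, 10).astype('float32'))
             trimmed = remove_volumes_from_timeseries(ts, [0, 1])  # keeps volumes 2..9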
11170    """
11171    if not isinstance(time_series, ants.core.ants_image.ANTsImage):
11172        raise ValueError("time_series must be an ANTsImage.")
11173
11174    if time_series.dimension != 4:
11175        raise ValueError("time_series must be a 4D image.")
11176
11177    # Create a boolean index for volumes to keep
11178    volumes_to_keep = [i for i in range(time_series.shape[3]) if i not in volumes_to_remove]
11179
11180    # Select the volumes to keep
11181    filtered_time_series = ants.from_numpy( time_series.numpy()[..., volumes_to_keep] )
11182
11183    return ants.copy_image_info( time_series, filtered_time_series )
11184
11185def remove_elements_from_list(original_list, elements_to_remove):
11186    """
11187    Remove specified elements from a list.
11188
11189    Parameters:
11190    original_list (list): The original list from which elements will be removed.
11191    elements_to_remove (list): A list of elements that need to be removed from the original list.
11192
11193    Returns:
11194    list: A new list with the specified elements removed.
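
         Example (a minimal sketch):

             remove_elements_from_list([0, 1, 2, 3, 4], [1, 3])  # returns [0, 2, 4]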
11195    """
11196    return [element for element in original_list if element not in elements_to_remove]
11197
11198
11199def impute_timeseries(time_series, volumes_to_impute, method='linear', verbose=False):
11200    """
11201    Impute specified volumes from a time series with interpolated values.
11202
11203    :param time_series: ANTsImage representing the time series (4D image).
11204    :param volumes_to_impute: List of volume indices to impute.
11205    :param method: Interpolation method ('linear' or other methods if implemented).
11206    :param verbose: boolean
11207    :return: ANTsImage with specified volumes imputed.
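
         Example (a minimal sketch; a small synthetic 4D image stands in for real data):

             import ants, numpy as np
             ts = ants.from_numpy(np.random.rand(8, 8, 8, 12).astype('float32'))
             fixed = impute_timeseries(ts, [3, 7], method='linear')  # volumes 3 and 7 replaced by neighbor averages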
11208    """
11209    if not isinstance(time_series, ants.core.ants_image.ANTsImage):
11210        raise ValueError("time_series must be an ANTsImage.")
11211
11212    if time_series.dimension != 4:
11213        raise ValueError("time_series must be a 4D image.")
11214
11215    # Convert time_series to numpy for manipulation
11216    time_series_np = time_series.numpy()
11217    total_volumes = time_series_np.shape[3]
11218
11219    # Create a complement list of volumes not to impute
11220    volumes_not_to_impute = [i for i in range(total_volumes) if i not in volumes_to_impute]
11221
11222    # Define the lower and upper bounds
11223    min_valid_index = min(volumes_not_to_impute)
11224    max_valid_index = max(volumes_not_to_impute)
11225
11226    for vol_idx in volumes_to_impute:
11227        # Ensure the volume index is within the valid range
11228        if vol_idx < 0 or vol_idx >= total_volumes:
11229            raise ValueError(f"Volume index {vol_idx} is out of bounds.")
11230
11231        # Find the nearest valid lower index within the bounds
11232        lower_candidates = [v for v in volumes_not_to_impute if v <= vol_idx]
11233        lower_idx = max(lower_candidates) if lower_candidates else min_valid_index
11234
11235        # Find the nearest valid upper index within the bounds
11236        upper_candidates = [v for v in volumes_not_to_impute if v >= vol_idx]
11237        upper_idx = min(upper_candidates) if upper_candidates else max_valid_index
11238
11239        if verbose:
11240            print(f"Imputing volume {vol_idx} using indices {lower_idx} and {upper_idx}")
11241
11242        if method == 'linear':
11243            # Linear interpolation between the two nearest volumes
11244            lower_volume = time_series_np[..., lower_idx]
11245            upper_volume = time_series_np[..., upper_idx]
11246            interpolated_volume = (lower_volume + upper_volume) / 2
11247        else:
11248            # Placeholder for other interpolation methods
11249            raise NotImplementedError("Currently, only linear interpolation is implemented.")
11250
11251        # Replace the specified volume with the interpolated volume
11252        time_series_np[..., vol_idx] = interpolated_volume
11253
11254    # Convert the numpy array back to ANTsImage
11255    imputed_time_series = ants.from_numpy(time_series_np)
11256    imputed_time_series = ants.copy_image_info(time_series, imputed_time_series)
11257
11258    return imputed_time_series
11259
11260def impute_dwi( dwi, threshold = 0.20, imputeb0=False, mask=None, verbose=False ):
11261    """
11262    Identify bad volumes in a dwi and impute them fully automatically.
11263
11264    :param dwi: ANTsImage representing the time series (4D image).
11265    :param threshold: threshold (0,1) for outlierness (lower means impute more data)
11266    :param imputeb0: boolean; if True, first impute the b0 volumes from neighboring dwi volumes so that outlier detection focuses on the diffusion-weighted volumes
11267    :param mask: restricts to a region of interest
11268    :param verbose: boolean
11269    :return: ANTsImage with outlier diffusion-weighted volumes imputed.
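
         Example (a minimal sketch; assumes `dwi` is a 4D diffusion-weighted ANTsImage, e.g. from ants.image_read):

             dwi_fixed = impute_dwi(dwi, threshold=0.20, verbose=True)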
11270    """
11271    list1 = segment_timeseries_by_meanvalue( dwi )['highermeans']
11272    if imputeb0:
11273        dwib = impute_timeseries( dwi, list1 ) # focus on the dwi - not the b0
11274        looped, list2 = loop_timeseries_censoring( dwib, threshold, mask )
11275    else:
11276        looped, list2 = loop_timeseries_censoring( dwi, threshold, mask )
11277    if verbose:
11278        print( list1 )
11279        print( list2 )
11280    complement = remove_elements_from_list( list2, list1 )
11281    if verbose:
11282        print( "Imputing:")
11283        print( complement )
11284    if len( complement ) == 0:
11285        return dwi
11286    return impute_timeseries( dwi, complement )
11287
11288def censor_dwi( dwi, bval, bvec, threshold = 0.20, imputeb0=False, mask=None, verbose=False ):
11289    """
11290    Identify bad volumes in a dwi and censor (remove) them fully automatically.
11291
11292    :param dwi: ANTsImage representing the time series (4D image).
11293    :param bval: bval array
11294    :param bvec: bvec array
11295    :param threshold: threshold (0,1) for outlierness (lower means impute more data)
11296    :param imputeb0: boolean; if True, first impute the b0 volumes from neighboring dwi volumes so that outlier detection focuses on the diffusion-weighted volumes
11297    :param mask: restricts to a region of interest
11298    :param verbose: boolean
11299    :return: tuple of the censored dwi (ANTsImage), censored bval array and censored bvec array.
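
         Example (a minimal sketch; assumes `dwi`, `bval` and `bvec` were loaded elsewhere, e.g. via ants.image_read and numpy):

             dwi_c, bval_c, bvec_c = censor_dwi(dwi, bval, bvec, threshold=0.20, verbose=True)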
11300    """
11301    list1 = segment_timeseries_by_meanvalue( dwi )['highermeans']
11302    if imputeb0:
11303        dwib = impute_timeseries( dwi, list1 ) # focus on the dwi - not the b0
11304        looped, list2 = loop_timeseries_censoring( dwib, threshold, mask, verbose=verbose)
11305    else:
11306        looped, list2 = loop_timeseries_censoring( dwi, threshold, mask, verbose=verbose )
11307    if verbose:
11308        print( list1 )
11309        print( list2 )
11310    complement = remove_elements_from_list( list2, list1 )
11311    if verbose:
11312        print( "censoring:")
11313        print( complement )
11314    if len( complement ) == 0:
11315        return dwi, bval, bvec
11316    return remove_volumes_from_timeseries( dwi, complement ), remove_elements_from_numpy_array( bval, complement ), remove_elements_from_numpy_array( bvec, complement )
11317
11318
11319def flatten_time_series(time_series):
11320    """
11321    Flatten a 4D time series into a 2D array.
11322    
11323    :param time_series: A 4D numpy array where the last dimension is time.
11324    :return: A 2D numpy array where each row is a flattened volume.
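
         Example (a minimal sketch with a synthetic array):

             import numpy as np
             ts = np.random.rand(8, 8, 8, 10)
             mat = flatten_time_series(ts)  # shape (10, 512): one row per volume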
11325    """
11326    n_volumes = time_series.shape[3]
11327    return time_series.reshape(-1, n_volumes).T
11328
11329def calculate_loop_scores_full(flattened_series, n_neighbors=20, verbose=True ):
11330    """
11331    Calculate Local Outlier Probabilities for each volume.
11332    
11333    :param flattened_series: A 2D numpy array from flatten_time_series.
11334    :param n_neighbors: Number of neighbors to use for calculating LoOP scores.
11335    :param verbose: boolean
11336    :return: An array of LoOP scores.
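
         Example (a minimal sketch; requires the optional PyNomaly dependency):

             import numpy as np
             mat = np.random.rand(30, 500)  # 30 volumes x 500 voxels
             scores = calculate_loop_scores_full(mat, n_neighbors=10, verbose=False)  # one score in [0,1] per volume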
11337    """
11338    from PyNomaly import loop
11339    from sklearn.neighbors import NearestNeighbors
11340    from sklearn.preprocessing import StandardScaler
11341    # replace nans with zero
11342    if verbose:
11343        print("loop: nan_to_num")
11344    flattened_series=np.nan_to_num(flattened_series, nan=0)
11345    scaler = StandardScaler()
11346    scaler.fit(flattened_series)
11347    data = scaler.transform(flattened_series)
11348    data=np.nan_to_num(data, nan=0)
11349    if n_neighbors > int(flattened_series.shape[0]/2.0):
11350        n_neighbors = int(flattened_series.shape[0]/2.0)
11351    if verbose:
11352        print("loop: nearest neighbors init")
11353    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric='minkowski')
11354    if verbose:
11355        print("loop: nearest neighbors fit")
11356    neigh.fit(data)
11357    d, idx = neigh.kneighbors(data, return_distance=True)
11358    if verbose:
11359        print("loop: probability")
11360    m = loop.LocalOutlierProbability(distance_matrix=d, neighbor_matrix=idx, n_neighbors=n_neighbors).fit()
11361    return m.local_outlier_probabilities[:]
11362
11363
11364def calculate_loop_scores(
11365    flattened_series,
11366    n_neighbors=20,
11367    n_features_sample=0.02,
11368    n_feature_repeats=5,
11369    seed=42,
11370    use_approx_knn=True,
11371    verbose=True,
11372):
11373    """
11374    Memory-efficient and robust LoOP score estimation with optional approximate KNN
11375    and averaging over multiple random feature subsets.
11376
11377    Parameters:
11378        flattened_series (np.ndarray): 2D array (n_samples x n_features)
11379        n_neighbors (int): Number of neighbors for LoOP
11380        n_features_sample (int or float): Number or fraction of features to sample
11381        n_feature_repeats (int): How many independent feature subsets to sample and average over
11382        seed (int): Random seed
11383        use_approx_knn (bool): Whether to use fast approximate KNN (via pynndescent)
11384        verbose (bool): Verbose output
11385
11386    Returns:
11387        np.ndarray: Averaged local outlier probabilities (length n_samples)
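
         Example (a minimal sketch; pynndescent is optional and exact KNN is used when it is absent):

             import numpy as np
             mat = np.random.rand(40, 10000)  # 40 volumes x 10000 voxels
             scores = calculate_loop_scores(mat, n_neighbors=10, n_features_sample=0.05, verbose=False)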
11388    """
11389    import numpy as np
11390    from sklearn.preprocessing import StandardScaler
11391    from PyNomaly import loop
11392
11393    # Optional approximate nearest neighbors
11394    try:
11395        from pynndescent import NNDescent
11396        has_nn_descent = True
11397    except ImportError:
11398        has_nn_descent = False
11399
11400    rng = np.random.default_rng(seed)
11401    X = np.nan_to_num(flattened_series, nan=0).astype(np.float32)
11402    n_samples, n_features = X.shape
11403
11404    # Handle feature sampling
11405    if isinstance(n_features_sample, float):
11406        if 0 < n_features_sample <= 1.0:
11407            n_features_sample = max(1, int(n_features_sample * n_features))
11408        else:
11409            raise ValueError("If float, n_features_sample must be in (0, 1].")
11410
11411    n_features_sample = min(n_features, n_features_sample)
11412
11413    if n_neighbors >= n_samples:
11414        n_neighbors = max(1, n_samples // 2)
11415
11416    if verbose:
11417        print(f"[LoOP] Input shape: {X.shape}")
11418        print(f"[LoOP] Sampling {n_features_sample} features per repeat, {n_feature_repeats} repeats")
11419        print(f"[LoOP] Using {n_neighbors} neighbors")
11420
11421    loop_scores = []
11422
11423    for rep in range(n_feature_repeats):
11424        feature_idx = rng.choice(n_features, n_features_sample, replace=False)
11425        X_sub = X[:, feature_idx]
11426
11427        scaler = StandardScaler(copy=False)
11428        X_sub = scaler.fit_transform(X_sub)
11429        X_sub = np.nan_to_num(X_sub, nan=0)
11430
11431        # Approximate or exact KNN
11432        if use_approx_knn and has_nn_descent and n_samples > 1000:
11433            if verbose:
11434                print(f"  [Rep {rep+1}] Using NNDescent (approximate KNN)")
11435            ann = NNDescent(X_sub, n_neighbors=n_neighbors, random_state=seed + rep)
11436            indices, dists = ann.neighbor_graph
11437        else:
11438            from sklearn.neighbors import NearestNeighbors
11439            if verbose:
11440                print(f"  [Rep {rep+1}] Using NearestNeighbors (exact KNN)")
11441            nn = NearestNeighbors(n_neighbors=n_neighbors)
11442            nn.fit(X_sub)
11443            dists, indices = nn.kneighbors(X_sub)
11444
11445        # LoOP score for this repeat
11446        model = loop.LocalOutlierProbability(
11447            distance_matrix=dists,
11448            neighbor_matrix=indices,
11449            n_neighbors=n_neighbors
11450        ).fit()
11451        loop_scores.append(model.local_outlier_probabilities[:])
11452
11453    # Average over repeats
11454    loop_scores = np.stack(loop_scores)
11455    loop_scores_mean = loop_scores.mean(axis=0)
11456
11457    if verbose:
11458        print(f"[LoOP] Averaged over {n_feature_repeats} feature subsets. Final shape: {loop_scores_mean.shape}")
11459
11460    return loop_scores_mean
11461
11462
11463
11464def score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg ):
11465    """
11466    Process CBF time series to remove high-leverage points.
11467    Derived from the SCORE algorithm by Sudipto Dolui et al.
11468
11469    Parameters:
11470    cbfts (ANTsImage): 4D ANTsImage of CBF time series.
11471    csf_seg (ANTsImage): CSF binary map.
11472    gm_seg (ANTsImage): Gray matter binary map.
11473    wm_seg (ANTsImage): WM binary map.
11474
11475    Returns:
11476    ANTsImage: Processed CBF time series.
11477    ndarray: Index of removed volumes.
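
         Example (a minimal sketch; assumes the CBF time series and tissue masks were computed elsewhere in the pipeline):

             # cbf_ts, csf, gm, wm are ANTsImages (4D CBF series and binary tissue maps)
             cleaned, removed_idx = score_fmri_censoring(cbf_ts, csf, gm, wm)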
11478    """
11479    
11480    n_gm_voxels = np.sum(gm_seg.numpy()) - 1
11481    n_wm_voxels = np.sum(wm_seg.numpy()) - 1
11482    n_csf_voxels = np.sum(csf_seg.numpy()) - 1
11483    mask1img = gm_seg + wm_seg + csf_seg
11484    mask1 = (mask1img==1).numpy()
11485    
11486    cbfts_np = cbfts.numpy()
11487    gmbool = (gm_seg==1).numpy()
11488    csfbool = (csf_seg==1).numpy()
11489    wmbool = (wm_seg==1).numpy()
11490    gm_cbf_ts = ants.timeseries_to_matrix( cbfts, gm_seg )
11491    gm_cbf_ts = np.squeeze(np.mean(gm_cbf_ts, axis=1))
11492    
11493    median_gm_cbf = np.median(gm_cbf_ts)
11494    mad_gm_cbf = np.median(np.abs(gm_cbf_ts - median_gm_cbf)) / 0.675
11495    indx = np.abs(gm_cbf_ts - median_gm_cbf) > (2.5 * mad_gm_cbf)
11496    
11497    # the spatial mean
11498    spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11499    spatmean = ants.from_numpy( spatmeannp )
11500    V = (
11501        n_gm_voxels * np.var(spatmeannp[gmbool])
11502        + n_wm_voxels * np.var(spatmeannp[wmbool])
11503        + n_csf_voxels * np.var(spatmeannp[csfbool])
11504    )
11505    V1 = math.inf
11506    ct=0
11507    while V < V1:
11508        ct=ct+1
11509        V1 = V
11510        CC = np.zeros(cbfts_np.shape[3])
11511        for s in range(cbfts_np.shape[3]):
11512            if indx[s]:
11513                continue
11514            tmp1 = ants.from_numpy( cbfts_np[:, :, :, s] )
11515            CC[s] = ants.image_similarity( spatmean, tmp1, metric_type='Correlation', fixed_mask=mask1img )
11516        inx = np.argmin(CC)
11517        indx[inx] = True
11518        spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11519        spatmean = ants.from_numpy( spatmeannp )
11520        V = (
11521          n_gm_voxels * np.var(spatmeannp[gmbool]) + 
11522          n_wm_voxels * np.var(spatmeannp[wmbool]) + 
11523          n_csf_voxels * np.var(spatmeannp[csfbool])
11524        )
11525    cbfts_recon = cbfts_np[:, :, :, ~indx]
11526    cbfts_recon = np.nan_to_num(cbfts_recon)
11527    cbfts_recon_ants = ants.from_numpy(cbfts_recon)
11528    cbfts_recon_ants = ants.copy_image_info(cbfts, cbfts_recon_ants)
11529    return cbfts_recon_ants, indx
11530
11531def loop_timeseries_censoring(x, threshold=0.5, mask=None, n_features_sample=0.02, seed=42, verbose=True):
11532    """
11533    Censor high leverage volumes from a time series using Local Outlier Probabilities (LoOP).
11534
11535    Parameters:
11536    x (ANTsImage): A 4D time series image.
11537    threshold (float): Threshold for determining high leverage volumes based on LoOP scores.
11538    mask (antsImage): restricts to a ROI
11539    n_features_sample (int/float): feature sample size, default 0.02; if less than one this is interpreted as a fraction of the total number of features, otherwise it sets the number of features to be used
11540    seed (int): random seed
11541    verbose (bool)
11542
11543    Returns:
11544    tuple: A tuple containing the censored time series (ANTsImage) and the indices of the high leverage volumes.
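
         Example (a minimal sketch; assumes `bold` is a 4D ANTsImage with at least 20 volumes):

             censored, bad_vols = loop_timeseries_censoring(bold, threshold=0.5, verbose=False)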
11545    """
11546    import warnings
11547    if x.shape[3] < 20: # just a guess at what we need here ...
11548        warnings.warn("Warning: the time dimension is < 20 - too few samples for LoOP; returning the original data.")
11549        return x, []
11550    if mask is None:
11551        flattened_series = flatten_time_series(x.numpy())
11552    else:
11553        flattened_series = ants.timeseries_to_matrix( x, mask )
11554    if verbose:
11555        print("loop_timeseries_censoring: flattened")
11556    loop_scores = calculate_loop_scores(flattened_series, n_features_sample=n_features_sample, seed=seed, verbose=verbose )
11557    high_leverage_volumes = np.where(loop_scores > threshold)[0]
11558    if verbose:
11559        print("loop_timeseries_censoring: High Leverage Volumes:", high_leverage_volumes)
11560    new_asl = remove_volumes_from_timeseries(x, high_leverage_volumes)
11561    return new_asl, high_leverage_volumes
11562
11563
11564def novelty_detection_ee(df_train, df_test, contamination=0.05):
11565    """
11566    This function performs novelty detection using Elliptic Envelope.
11567
11568    Parameters:
11569
11570    - df_train (pandas dataframe): training data used to fit the model
11571
11572    - df_test (pandas dataframe): test data used to predict novelties
11573
11574    - contamination (float): parameter controlling the proportion of outliers in the data (default: 0.05)
11575
11576    Returns:
11577
11578    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
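
         Example (a minimal sketch with synthetic dataframes; the _svm and _lof variants share the same calling pattern):

             import numpy as np, pandas as pd
             df_train = pd.DataFrame(np.random.randn(100, 4), columns=list('abcd'))
             df_test = pd.DataFrame(np.random.randn(10, 4), columns=list('abcd'))
             flags = novelty_detection_ee(df_train, df_test, contamination=0.05)  # 1 flags a novelty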
11579    """
11580    import pandas as pd
11581    from sklearn.covariance import EllipticEnvelope
11582    # Fit the model on the training data
11583    clf = EllipticEnvelope(contamination=contamination,support_fraction=1)
11584    df_train[ df_train == math.inf ] = 0
11585    df_test[ df_test == math.inf ] = 0
11586    from sklearn.preprocessing import StandardScaler
11587    scaler = StandardScaler()
11588    scaler.fit(df_train)
11589    clf.fit(scaler.transform(df_train))
11590    predictions = clf.predict(scaler.transform(df_test))
11591    predictions[predictions==1]=0
11592    predictions[predictions==-1]=1
11593    if isinstance(df_train, pd.DataFrame):
11594        return pd.Series(predictions, index=df_test.index)
11595    else:
11596        return pd.Series(predictions)
11597
11598
11599
11600def novelty_detection_svm(df_train, df_test, nu=0.05, kernel='rbf'):
11601    """
11602    This function performs novelty detection using One-Class SVM.
11603
11604    Parameters:
11605
11606    - df_train (pandas dataframe): training data used to fit the model
11607
11608    - df_test (pandas dataframe): test data used to predict novelties
11609
11610    - nu (float): parameter controlling the fraction of training errors and the fraction of support vectors (default: 0.05)
11611
11612    - kernel (str): kernel type used in the SVM algorithm (default: 'rbf')
11613
11614    Returns:
11615
11616    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11617    """
11618    from sklearn.svm import OneClassSVM
11619    # Fit the model on the training data
11620    df_train[ df_train == math.inf ] = 0
11621    df_test[ df_test == math.inf ] = 0
11622    clf = OneClassSVM(nu=nu, kernel=kernel)
11623    from sklearn.preprocessing import StandardScaler
11624    scaler = StandardScaler()
11625    scaler.fit(df_train)
11626    clf.fit(scaler.transform(df_train))
11627    predictions = clf.predict(scaler.transform(df_test))
11628    predictions[predictions==1]=0
11629    predictions[predictions==-1]=1
11630    if isinstance(df_train, pd.DataFrame):
11631        return pd.Series(predictions, index=df_test.index)
11632    else:
11633        return pd.Series(predictions)
11634
11635
11636
11637def novelty_detection_lof(df_train, df_test, n_neighbors=20):
11638    """
11639    This function performs novelty detection using Local Outlier Factor (LOF).
11640
11641    Parameters:
11642
11643    - df_train (pandas dataframe): training data used to fit the model
11644
11645    - df_test (pandas dataframe): test data used to predict novelties
11646
11647    - n_neighbors (int): number of neighbors used to compute the LOF (default: 20)
11648
11649    Returns:
11650
11651    - predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11652
11653    """
11654    from sklearn.neighbors import LocalOutlierFactor
11655    # Fit the model on the training data
11656    df_train[ df_train == math.inf ] = 0
11657    df_test[ df_test == math.inf ] = 0
11658    clf = LocalOutlierFactor(n_neighbors=n_neighbors, algorithm='auto',contamination='auto', novelty=True)
11659    from sklearn.preprocessing import StandardScaler
11660    scaler = StandardScaler()
11661    scaler.fit(df_train)
11662    clf.fit(scaler.transform(df_train))
11663    predictions = clf.predict(scaler.transform(df_test))
11664    predictions[predictions==1]=0
11665    predictions[predictions==-1]=1
11666    if isinstance(df_train, pd.DataFrame):
11667        return pd.Series(predictions, index=df_test.index)
11668    else:
11669        return pd.Series(predictions)
11670
11671
11672def novelty_detection_loop(df_train, df_test, n_neighbors=20, distance_metric='minkowski'):
11673    """
11674    This function performs novelty detection using Local Outlier Probabilities (LoOP).
11675
11676    Parameters:
11677
11678    - df_train (pandas dataframe): training data used to fit the model
11679
11680    - df_test (pandas dataframe): test data used to predict novelties
11681
11682    - n_neighbors (int): number of neighbors used to compute the LoOP (default: 20)
11683
11684    - distance_metric : default minkowski
11685
11686    Returns:
11687
11688    - local outlier probabilities (array-like): values in [0,1] for each row of df_test; higher values indicate greater outlierness
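
         Example (a minimal sketch with synthetic dataframes; requires the optional PyNomaly dependency):

             import numpy as np, pandas as pd
             df_train = pd.DataFrame(np.random.randn(100, 4), columns=list('abcd'))
             df_test = pd.DataFrame(np.random.randn(5, 4), columns=list('abcd'))
             probs = novelty_detection_loop(df_train, df_test, n_neighbors=20)  # outlier probabilities for df_test rows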
11689
11690    """
11691    from PyNomaly import loop
11692    from sklearn.neighbors import NearestNeighbors
11693    from sklearn.preprocessing import StandardScaler
11694    scaler = StandardScaler()
11695    scaler.fit(df_train)
11696    data = np.vstack( [scaler.transform(df_test),scaler.transform(df_train)])
11697    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric=distance_metric)
11698    neigh.fit(data)
11699    d, idx = neigh.kneighbors(data, return_distance=True)
11700    m = loop.LocalOutlierProbability(distance_matrix=d, neighbor_matrix=idx, n_neighbors=n_neighbors).fit()
11701    return m.local_outlier_probabilities[range(df_test.shape[0])]
11702
11703
11704
11705def novelty_detection_quantile(df_train, df_test):
11706    """
11707    This function performs novelty detection using quantiles for each column.
11708
11709    Parameters:
11710
11711    - df_train (pandas dataframe): training data used to fit the model
11712
11713    - df_test (pandas dataframe): test data used to predict novelties
11714
11715    Returns:
11716
11717    - quantiles for the test sample at each column where values range in [0,1]
11718        and higher values mean the column is closer to the edge of the distribution
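
         Example (a minimal sketch; only the first row of df_test is used, so a single-row test frame is typical):

             import numpy as np, pandas as pd
             df_train = pd.DataFrame(np.random.randn(100, 3), columns=list('abc'))
             df_test = df_train.iloc[[0]].copy()
             q = novelty_detection_quantile(df_train, df_test)  # values near 1 are near the tails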
11719
11720    """
11721    myqs = df_test.copy()
11722    n = df_train.shape[0]
11723    df_trainkeys = df_train.keys()
11724    for k in range( df_train.shape[1] ):
11725        mykey = df_trainkeys[k]
11726        temp = (myqs[mykey][0] >  df_train[mykey]).sum() / n
11727        myqs[mykey] = abs( temp - 0.5 ) / 0.5
11728    return myqs
11729
11730
11731
11732def shorten_pymm_names(x):
11733    """
11734    Shortens antspymm (pymm) variable names by applying a series of regex substitutions.
11735    
11736    Parameters:
11737    x (str): The input string to be shortened
11738    
11739    Returns:
11740    str: The shortened string
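
         Example (a minimal sketch):

             shorten_pymm_names('T1Hier_vol_left_hippocampus')  # e.g. 't1.vol.left.hippocampus'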
11741    """
11742    xx = x.lower()
11743    xx = re.sub("_", ".", xx)  # Replace underscores with periods
11744    xx = re.sub("\.\.", ".", xx, flags=re.I)  # Replace double dots with single dot
11745    # Apply the following regex substitutions in order
11746    xx = re.sub("sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus.","ilf.and.ifo", xx, flags=re.I)
11747    xx = re.sub(r"sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus.", "ilf.and.ifo", xx, flags=re.I)
11748    xx = re.sub(r".cres.stria.terminalis.can.not.be.resolved.with.current.resolution.", "", 
11749xx, flags=re.I)
11750    xx = re.sub("_", ".", xx)  # Replace underscores with periods
11751    xx = re.sub(r"longitudinal.fasciculus", "l.fasc", xx, flags=re.I)
11752    xx = re.sub(r"corona.radiata", "cor.rad", xx, flags=re.I)
11753    xx = re.sub("central", "cent", xx, flags=re.I)
11754    xx = re.sub(r"deep.cit168", "dp.", xx, flags=re.I)
11755    xx = re.sub("cit168", "", xx, flags=re.I)
11756    xx = re.sub(".include", "", xx, flags=re.I)
11757    xx = re.sub("mtg.sn", "", xx, flags=re.I)
11758    xx = re.sub("brainstem", ".bst", xx, flags=re.I)
11759    xx = re.sub(r"rsfmri.", "rsf.", xx, flags=re.I)
11760    xx = re.sub(r"dti.mean.fa.", "dti.fa.", xx, flags=re.I)
11761    xx = re.sub("perf.cbf.mean.", "cbf.", xx, flags=re.I)
11762    xx = re.sub(".jhu.icbm.labels.1mm", "", xx, flags=re.I)
11763    xx = re.sub(".include.optic.radiation.", "", xx, flags=re.I)
11764    xx = re.sub(r"\.\.", ".", xx, flags=re.I)  # Replace double dots with a single dot
11765    xx = re.sub(r"\.\.", ".", xx, flags=re.I)  # Applied twice to collapse longer runs of dots
11766    xx = re.sub("cerebellar.peduncle", "cereb.ped", xx, flags=re.I)
11767    xx = re.sub(r"anterior.limb.of.internal.capsule", "ant.int.cap", xx, flags=re.I)
11768    xx = re.sub(r"posterior.limb.of.internal.capsule", "post.int.cap", xx, flags=re.I)
11769    xx = re.sub("t1hier.", "t1.", xx, flags=re.I)
11770    xx = re.sub("anterior", "ant", xx, flags=re.I)
11771    xx = re.sub("posterior", "post", xx, flags=re.I)
11772    xx = re.sub("inferior", "inf", xx, flags=re.I)
11773    xx = re.sub("superior", "sup", xx, flags=re.I)
11774    xx = re.sub(r"dktcortex", ".ctx", xx, flags=re.I)
11775    xx = re.sub(".lravg", "", xx, flags=re.I)
11776    xx = re.sub("dti.mean.fa", "dti.fa", xx, flags=re.I)
11777    xx = re.sub(r"retrolenticular.part.of.internal", "rent.int.cap", xx, flags=re.I)
11778    xx = re.sub(r"iculus.could.be.a.part.of.ant.internal.capsule", "", xx, flags=re.I)  # Twice
11779    xx = re.sub(".fronto.occipital.", ".frnt.occ.", xx, flags=re.I)
11780    xx = re.sub(r".longitidinal.fasciculus.", ".long.fasc.", xx, flags=re.I)  # Twice
11781    xx = re.sub(".external.capsule", ".ext.cap", xx, flags=re.I)
11782    xx = re.sub("of.internal.capsule", ".int.cap", xx, flags=re.I)
11783    xx = re.sub("fornix.cres.stria.terminalis", "fornix.", xx, flags=re.I)
11784    xx = re.sub("capsule", "", xx, flags=re.I)
11785    xx = re.sub("and.inf.frnt.occ.fasciculus.", "", xx, flags=re.I)
11786    xx = re.sub("crossing.tract.a.part.of.mcp.", "", xx, flags=re.I)
11787    return xx[:40]  # Truncate to first 40 characters
11788
11789
11790def shorten_pymm_names2(x, verbose=False ):
11791    """
11792    Shortens antspymm (pymm) variable names by applying a series of regex substitutions.
11793
11794    Parameters:
11795    x (str): The input string to be shortened
11796
11797    verbose (bool): explain the patterns and replacements and their impact
11798
11799    Returns:
11800    str: The shortened string
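
         Example (a minimal sketch):

             shorten_pymm_names2('T1Hier_vol_left_hippocampus', verbose=False)  # e.g. 't1.vol.left.hippocampus'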
11801    """
11802    # Define substitution patterns as tuples
11803    substitutions = [
11804        ("_", "."),  
11805        ("\.\.", "."),
11806        ("sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus.","ilf.and.ifo"),
11807        (r"sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus.", "ilf.and.ifo"),
11808        (r".cres.stria.terminalis.can.not.be.resolved.with.current.resolution.", ""),
11809        ("_", "."),
11810        (r"longitudinal.fasciculus", "l.fasc"),
11811        (r"corona.radiata", "cor.rad"),
11812        ("central", "cent"),
11813        (r"deep.cit168", "dp."),
11814        ("cit168", ""),
11815        (".include", ""),
11816        ("mtg.sn", ""),
11817        ("brainstem", ".bst"),
11818        (r"rsfmri.", "rsf."),
11819        (r"dti.mean.fa.", "dti.fa."),
11820        ("perf.cbf.mean.", "cbf."),
11821        (".jhu.icbm.labels.1mm", ""),
11822        (".include.optic.radiation.", ""),
11823        ("\.\.", "."),  # Replace double dots with single dot
11824        ("\.\.", "."),  # Replace double dots with single dot
11825        ("cerebellar.peduncle", "cereb.ped"),
11826        (r"anterior.limb.of.internal.capsule", "ant.int.cap"),
11827        (r"posterior.limb.of.internal.capsule", "post.int.cap"),
11828        ("t1hier.", "t1."),
11829        ("anterior", "ant"),
11830        ("posterior", "post"),
11831        ("inferior", "inf"),
11832        ("superior", "sup"),
11833        (r"dktcortex", ".ctx"),
11834        (".lravg", ""),
11835        ("dti.mean.fa", "dti.fa"),
11836        (r"retrolenticular.part.of.internal", "rent.int.cap"),
11837        (r"iculus.could.be.a.part.of.ant.internal.capsule", ""),  # Twice
11838        (".fronto.occipital.", ".frnt.occ."),
11839        (r".longitidinal.fasciculus.", ".long.fasc."),  # Twice
11840        (".external.capsule", ".ext.cap"),
11841        ("of.internal.capsule", ".int.cap"),
11842        ("fornix.cres.stria.terminalis", "fornix."),
11843        ("capsule", ""),
11844        ("and.inf.frnt.occ.fasciculus.", ""),
11845        ("crossing.tract.a.part.of.mcp.", "")
11846      ]
11847
11848    # Apply substitutions in order
11849    for pattern, replacement in substitutions:
11850        if verbose:
11851            print("Pre " + x + " pattern "+pattern + " repl " + replacement )
11852        x = re.sub(pattern, replacement, x.lower(), flags=re.IGNORECASE)
11853        if verbose:
11854            print("Post " + x)
11855
11856    return x[:40]  # Truncate to first 40 characters
11857
11858
11859def brainmap_figure(statistical_df, data_dictionary, output_prefix, brain_image, overlay_cmap='bwr', nslices=21, ncol=7, edge_image_dilation = 0, black_bg=True, axes = [0,1,2], fixed_overlay_range=None, crop=5, verbose=0 ):
11860    """
11861    Create figures based on statistical data and an underlying brain image.
11862
11863    Assumes both ~/.antspyt1w and ~/.antspymm data is available
11864
11865    Parameters:
11866    - statistical_df (pandas dataframe): with 2 columns named anat and values.
11867        The anat column should contain names that *partially match* regions
11868        measured in antspymm; the values column holds the value to be displayed.
11869        If multiple entries for a given region exist in statistical_df, the one
11870        with the largest absolute value is used for display.
11871    - data_dictionary (pandas dataframe): antspymm data dictionary.
11872    - output_prefix (str): Prefix for the output figure filenames.
11873    - brain_image (antsImage): the brain image on which results will overlay.
11874    - overlay_cmap (str): see matplotlib
11875    - nslices (int): number of slices to show
11876    - ncol (int): number of columns to show
11877    - edge_image_dilation (int): integer greater than or equal to zero
11878    - black_bg (bool): boolean
11879    - axes (list): integer list typically [0,1,2] sagittal coronal axial
11880    - fixed_overlay_range (list): scalar pair will try to keep a constant cbar and will truncate the overlay at these min/max values
11881    - crop (int): crops the image to display by the extent of the overlay; larger values dilate the masks more.
11882    - verbose (int or bool): verbosity level
11883
11884    Returns:
11885    an image with values mapped to the associated regions
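
         Example (a minimal sketch; assumes ~/.antspyt1w and ~/.antspymm data are installed and the
         dictionary/template file names below are hypothetical placeholders for your own files):

             import ants, pandas as pd
             stats = pd.DataFrame({'anat': ['t1.vol.hippocampus'], 'values': [2.5]})
             dd = pd.read_csv('antspymm_data_dictionary.csv')   # hypothetical path to the antspymm data dictionary
             t1 = ants.image_read('template_t1.nii.gz')         # hypothetical template brain image
             overlay = brainmap_figure(stats, dd, '/tmp/brainmap_', t1, nslices=21, verbose=1)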
11886    """
11887    import re
11888
11889    def is_bst_region(filename):
11890        return filename[-4:] == '.bst'
11891
11892    # Read the statistical file
11893    zz = statistical_df 
11894    
11895    # Read the data dictionary from a CSV file
11896    mydict = data_dictionary
11897    mydict = mydict[~mydict['Measurement'].str.contains("tractography-based connectivity", na=False)]
11898    mydict2=mydict.copy()
11899    mydict2['tidynames']=mydict2['tidynames'].str.replace(".left","")
11900    mydict2['tidynames']=mydict2['tidynames'].str.replace(".right","")
11901
11902    statistical_df['anat'] = statistical_df['anat'].str.replace("_", ".", regex=True)
11903
11904    # Load image and process it
11905    edgeimg = ants.iMath(brain_image,"Normalize")
11906    if edge_image_dilation > 0:
11907        edgeimg = ants.iMath( edgeimg, "MD", edge_image_dilation)
11908
11909    # Define lists and data frames
11910    postfix = ['bf', 'cit168lab', 'mtl', 'cerebellum', 'dkt_cortex','brainstem','JHU_wm','yeo']
11911    atlas = ['BF', 'CIT168', 'MTL', 'TustisonCobra', 'desikan-killiany-tourville','brainstem','JHU_wm','yeo']
11912    postdesc = ['nbm3CH13', 'CIT168_Reinf_Learn_v1_label_descriptions_pad', 'mtl_description', 'cerebellum', 'dkt','CIT168_T1w_700um_pad_adni_brainstem','FA_JHU_labels_edited','ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic']
11913    templateprefix = '~/.antspymm/PPMI_template0_'
11914    # Iterate through columns and create figures
11915    col2viz = 'values'
11916    if True:
11917        anattoshow = zz['anat'].unique()
11918        if verbose > 0:
11919            print(col2viz)
11920            print(anattoshow)
11921        # accumulate the per-region statistical values into a single overlay image
11922        addem = edgeimg * 0
11923        for k in range(len(anattoshow)):
11924            if verbose > 0 :
11925                print(str(k) +  " " + anattoshow[k]  )
11926            mysub = zz[zz['anat'].str.contains(anattoshow[k])]
11927            anatsear=shorten_pymm_names( anattoshow[k] )
11928            anatsear=re.sub(r'[()]', '.', anatsear )
11929            anatsear=re.sub(r'\.\.', '.', anatsear )
11930            anatsear=re.sub("dti.mean.md.snc","md.snc",anatsear)
11931            anatsear=re.sub("dti.mean.fa.snc","fa.snc",anatsear)
11932            anatsear=re.sub("dti.mean.md.snr","md.snr",anatsear)
11933            anatsear=re.sub("dti.mean.fa.snr","fa.snr",anatsear)
11934            anatsear=re.sub("dti.mean.md.","",anatsear)
11935            anatsear=re.sub("dti.mean.fa.","",anatsear)
11936            anatsear=re.sub("dti.md.","",anatsear)
11937            anatsear=re.sub("dti.fa.","",anatsear)
11938            anatsear=re.sub("dti.md","",anatsear)
11939            anatsear=re.sub("dti.fa","",anatsear)
11940            anatsear=re.sub("cbf.","",anatsear)
11941            anatsear=re.sub("rsfmri.fcnxpro122.","",anatsear)
11942            anatsear=re.sub("rsfmri.fcnxpro129.","",anatsear)
11943            anatsear=re.sub("rsfmri.fcnxpro134.","",anatsear)
11944            anatsear=re.sub("t1hier.vollravg","",anatsear)
11945            anatsear=re.sub("t1hier.volasym","",anatsear)
11946            anatsear=re.sub("t1hier.thkasym","",anatsear)
11947            anatsear=re.sub("t1hier.areaasym","",anatsear)
11948            anatsear=re.sub("t1hier.vol.","",anatsear)
11949            anatsear=re.sub("t1hier.thk.","",anatsear)
11950            anatsear=re.sub("t1hier.area.","",anatsear)
11951            anatsear=re.sub("t1.volasym","",anatsear)
11952            anatsear=re.sub("t1.thkasym","",anatsear)
11953            anatsear=re.sub("t1.areaasym","",anatsear)
11954            anatsear=re.sub("t1.vol.","",anatsear)
11955            anatsear=re.sub("t1.thk.","",anatsear)
11956            anatsear=re.sub("t1.area.","",anatsear)
11957            anatsear=re.sub("asymdp.","",anatsear)
11958            anatsear=re.sub("asym.","",anatsear)
11959            anatsear=re.sub("asym","",anatsear)
11960            anatsear=re.sub("lravg.","",anatsear)
11961            anatsear=re.sub("lravg","",anatsear)
11962            anatsear=re.sub("dktcortex","",anatsear)
11963            anatsear=re.sub("dktregions","",anatsear)
11964            anatsear=re.sub("_",".",anatsear)
11965            anatsear=re.sub("superior","sup",anatsear)
11966            anatsear=re.sub("cerebellum","",anatsear)
11967            anatsear=re.sub("brainstem","",anatsear)
11968            anatsear=re.sub("t.limb.int","t.int",anatsear)
11969            anatsear=re.sub("paracentral","paracent",anatsear)
11970            anatsear=re.sub("precentral","precent",anatsear)
11971            anatsear=re.sub("postcentral","postcent",anatsear)
11972            anatsear=re.sub("sup.cerebellar.peduncle","sup.cereb.ped",anatsear)
11973            anatsear=re.sub("inferior.cerebellar.peduncle","inf.cereb.ped",anatsear)
11974            anatsear=re.sub(".crossing.tract.a.part.of.mcp.","",anatsear)
11975            anatsear=re.sub(".crossing.tract.a.part.of.","",anatsear)
11976            anatsear=re.sub(".column.and.body.of.fornix.","",anatsear)
11977            anatsear=re.sub("fronto.occipital.fasciculus.could.be.a.part.of.ant.internal.capsule","frnt.occ",anatsear)
11978            anatsear=re.sub("inferior.fronto.occipital.fasciculus.could.be.a.part.of.anterior.internal.capsule","inf.frnt.occ",anatsear)
11979            anatsear=re.sub("fornix.cres.stria.terminalis.can.not.be.resolved.with.current.resolution","fornix.column.and.body.of.fornix",anatsear)
11980            anatsear=re.sub("external.capsule","ext.cap",anatsear)
11981            anatsear=re.sub(".jhu.icbm.labels.1mm","",anatsear)
11982            anatsear=re.sub("dp.",".",anatsear)
11983            anatsear=re.sub(".mtg.sn.snc.",".snc.",anatsear)
11984            anatsear=re.sub(".mtg.sn.snr.",".snr.",anatsear)
11985            anatsear=re.sub("mtg.sn.snc.",".snc.",anatsear)
11986            anatsear=re.sub("mtg.sn.snr.",".snr.",anatsear)
11987            anatsear=re.sub("mtg.sn.snc",".snc.",anatsear)
11988            anatsear=re.sub("mtg.sn.snr",".snr.",anatsear)
11989            anatsear=re.sub("anterior.","ant.",anatsear)
11990            anatsear=re.sub("rsf.","",anatsear)
11991            anatsear=re.sub("fcnxpro122.","",anatsear)
11992            anatsear=re.sub("fcnxpro129.","",anatsear)
11993            anatsear=re.sub("fcnxpro134.","",anatsear)
11994            anatsear=re.sub("ant.corona.radiata","ant.cor.rad",anatsear)
11995            anatsear=re.sub("sup.corona.radiata","sup.cor.rad",anatsear)
11996            anatsear=re.sub("posterior.thalamic.radiation.include.optic.radiation","post.thalamic.radiation",anatsear)
11997            anatsear=re.sub("retrolenticular.part.of.internal.capsule","rent.int.cap",anatsear)
11998            anatsear=re.sub("post.limb.of.internal.capsule","post.int.cap",anatsear)
11999            anatsear=re.sub("ant.limb.of.internal.capsule","ant.int.cap",anatsear)
12000            anatsear=re.sub("sagittal.stratum.include.inferior.longitidinal.fasciculus.and.inferior.fronto.occipital.fasciculus","ilf.and.ifo",anatsear)
12001            anatsear=re.sub("post.thalamic.radiation.optic.rad","post.thalamic.radiation",anatsear)
12002            atlassearch = mydict['tidynames'].str.contains(anatsear)
12003            if atlassearch.sum() == 0:
12004                atlassearch = mydict2['tidynames'].str.contains(anatsear)
12005            if verbose > 0 :
12006                print( " anatsear " + anatsear + " atlassearch " )
12007            if atlassearch.sum() > 0:
12008                whichatlas = mydict[atlassearch]['Atlas'].iloc[0]
12009                oglabelname = mydict[atlassearch]['Label'].iloc[0]
12010                oglabelname=re.sub("_",".",oglabelname)
12011                oglabelname=re.sub(r'\.\.','.',oglabelname)
12012            else:
12013                print(anatsear)
12014                oglabelname='unknown'
12015                whichatlas=None
12016            if verbose > 0:
12017                print("oglabelname " + oglabelname + " whichatlas " + str(whichatlas) )
12018            vals2viz = mysub[col2viz].agg(['min', 'max'])
12019            vals2viz = vals2viz[abs(vals2viz).idxmax()]
12020            myext = None
12021            if anatsear == 'cingulum.hippocampus':
12022                myext = 'JHU_wm'
12023            elif 'dktcortex' in anattoshow[k] or whichatlas == 'desikan-killiany-tourville' or 'dktregions' in anattoshow[k]:
12024                myext = 'dkt_cortex'
12025            elif ('cit168' in anattoshow[k] or whichatlas == 'CIT168') and not 'brainstem' in anattoshow[k] and not is_bst_region(anatsear):
12026                myext = 'cit168lab'
12027            elif 'mtl' in anattoshow[k]:
12028                myext = 'mtl'
12029                oglabelname=re.sub('mtl', '',anatsear)
12030            elif 'cerebellum' in anattoshow[k]:
12031                myext = 'cerebellum'
12032                oglabelname=re.sub('cerebellum', '',anatsear)
12033                oglabelname=re.sub('t1.vo','',oglabelname)
12034                # oglabelname=oglabelname[2:]
12035            elif 'brainstem' in anattoshow[k] or is_bst_region(anatsear):
12036                myext = 'brainstem'
12037            elif any(item in anattoshow[k] for item in ['nbm', 'bf']):
12038                myext = 'bf'
12039                oglabelname=re.sub('bf', '',oglabelname)
12040#                oglabelname=re.sub(r'\.', '_',anatsear)
12041            elif whichatlas == 'johns hopkins white matter':
12042                myext = 'JHU_wm'
12043            elif whichatlas == 'desikan-killiany-tourville':
12044                myext = 'dkt_cortex'
12045            elif whichatlas == 'CIT168':
12046                myext = 'cit168lab'
12047            elif whichatlas == 'BF':
12048                myext = 'bf'
12049                oglabelname=re.sub('bf', '',oglabelname)
12050            elif whichatlas == 'yeo_homotopic':
12051                myext = 'yeo'
12052            if myext is None and verbose > 0 :
12053                if whichatlas is None:
12054                    whichatlas='None'
12055                if anattoshow[k] is None:
12056                    anattoshow[k]='None'
12057                print( "MYEXT " + anattoshow[k] + ' unfound ' + whichatlas )
12058            else:
12059                if verbose > 0 :
12060                    print( "MYEXT " + myext )
12061
12062            if myext == 'cit168lab':
12063                oglabelname=re.sub("cit168","",oglabelname)
12064            
12065            for j in postfix:
12066                if j == "dkt_cortex":
12067                    j = 'dktcortex'
12068                if j == "deep_cit168lab":
12069                    j = 'deep_cit168'
12070                anattoshow[k] = anattoshow[k].replace(j, "")
12071            if verbose > 0:
12072                print( anattoshow[k] + " " + str( vals2viz ) )
12073            correctdescript = postdesc[postfix.index(myext)]
12074            locfilename =  templateprefix + myext + '.nii.gz'
12075            if verbose > 0:
12076                print( locfilename )
12077            if myext == 'yeo':
12078                oglabelname=oglabelname.lower()
12079                oglabelname=re.sub("rsfmri_fcnxpro122_","",oglabelname)
12080                oglabelname=re.sub("rsfmri_fcnxpro129_","",oglabelname)
12081                oglabelname=re.sub("rsfmri_fcnxpro134_","",oglabelname)
12082                oglabelname=re.sub("rsfmri.fcnxpro122.","",oglabelname)
12083                oglabelname=re.sub("rsfmri.fcnxpro129.","",oglabelname)
12084                oglabelname=re.sub("rsfmri.fcnxpro134.","",oglabelname)
12085                oglabelname=re.sub("_",".",oglabelname)
12086                locfilename = "~/.antspymm/ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic.nii.gz"
12087                atlasDescript = pd.read_csv(f"~/.antspymm/{correctdescript}.csv")
12088                atlasDescript.rename(columns={'SystemName': 'Description'}, inplace=True)
12089                atlasDescript.rename(columns={'ROI': 'Label'}, inplace=True)
12090                atlasDescript['Description'] = atlasDescript['Description'].str.lower()
12091            else:
12092                atlasDescript = pd.read_csv(f"~/.antspyt1w/{correctdescript}.csv")
12093                atlasDescript['Description'] = atlasDescript['Description'].str.lower()
12094                atlasDescript['Description'] = atlasDescript['Description'].str.replace(" ", "_")
12095                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_left_", "_")
12096                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_right_", "_")
12097                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_left", "")
12098                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_right", "")
12099                atlasDescript['Description'] = atlasDescript['Description'].str.replace("left_", "")
12100                atlasDescript['Description'] = atlasDescript['Description'].str.replace("right_", "")
12101                atlasDescript['Description'] = atlasDescript['Description'].str.replace("/",".")
12102                atlasDescript['Description'] = atlasDescript['Description'].str.replace("_",".")
12103                atlasDescript['Description'] = atlasDescript['Description'].str.replace(r'[()]', '', regex=True)
12104                atlasDescript['Description'] = atlasDescript['Description'].str.replace(r'\.\.', '.', regex=True)
12105                if myext == 'JHU_wm':
12106                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("-", ".")
12107                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("jhu.icbm.labels.1mm", "")
12108                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("fronto-occipital", "frnt.occ")
12109                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("superior", "sup")
12110                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("fa-", "")
12111                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("-left-", "")
12112                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("-right-", "")
12113                if myext == 'cerebellum':
12114                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("l_", "")
12115                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("r_", "")
12116                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("l.", "")
12117                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("r.", "")
12118                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("_",".")
12119
12120            if verbose > 0:
12121                print( atlasDescript )
12122            oglabelname = oglabelname.lower()
12123            oglabelname = re.sub(" ", "_",oglabelname)
12124            oglabelname = re.sub("_left_", "_",oglabelname)
12125            oglabelname = re.sub("_right_", "_",oglabelname)
12126            oglabelname = re.sub("_left", "",oglabelname)
12127            oglabelname = re.sub("_right", "",oglabelname)
12128            oglabelname = re.sub("t1hier_vol_", "",oglabelname)
12129            oglabelname = re.sub("t1hier_area_", "",oglabelname)
12130            oglabelname = re.sub("t1hier_thk_", "",oglabelname)
12131            oglabelname = re.sub("dktregions", "",oglabelname)
12132            oglabelname = re.sub("dktcortex", "",oglabelname)
12133
12134            oglabelname = re.sub(" ", ".",oglabelname)
12135            oglabelname = re.sub(".left.", ".",oglabelname)
12136            oglabelname = re.sub(".right.", ".",oglabelname)
12137            oglabelname = re.sub(".left", "",oglabelname)
12138            oglabelname = re.sub(".right", "",oglabelname)
12139            oglabelname = re.sub("t1hier.vol.", "",oglabelname)
12140            oglabelname = re.sub("t1hier.area.", "",oglabelname)
12141            oglabelname = re.sub("t1hier.thk.", "",oglabelname)
12142            oglabelname = re.sub("dktregions", "",oglabelname)
12143            oglabelname = re.sub("dktcortex", "",oglabelname)
12144            oglabelname=re.sub("brainstem","",oglabelname)
12145            if myext == 'JHU_wm':
12146                oglabelname = re.sub("dti_mean_fa.", "",oglabelname)
12147                oglabelname = re.sub("dti_mean_md.", "",oglabelname)
12148                oglabelname = re.sub("dti.mean.fa.", "",oglabelname)
12149                oglabelname = re.sub("dti.mean.md.", "",oglabelname)
12150                oglabelname = re.sub(".left.", "",oglabelname)
12151                oglabelname = re.sub(".right.", "",oglabelname)
12152                oglabelname = re.sub(".lravg.", "",oglabelname)
12153                oglabelname = re.sub(".asym.", "",oglabelname)
12154                oglabelname = re.sub(".jhu.icbm.labels.1mm", "",oglabelname)
12155                oglabelname = re.sub("superior", "sup",oglabelname)
12156
12157            if verbose > 0:
12158                print("oglabelname " + oglabelname )
12159
12160            if myext == 'cerebellum':
12161                if not atlasDescript.empty and 'Description' in atlasDescript.columns:
12162                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("l_", "")
12163                    atlasDescript['Description'] = atlasDescript['Description'].str.replace("r_", "")
12164                    oglabelname=re.sub("ravg","",oglabelname)
12165                    oglabelname=re.sub("lavg","",oglabelname)
12166                    whichindex = atlasDescript.index[atlasDescript['Description'] == oglabelname].values
12167                else:
12168                    if atlasDescript.empty:
12169                        print("The DataFrame 'atlasDescript' is empty.")
12170                    if 'Description' not in atlasDescript.columns:
12171                        print("The column 'Description' does not exist in 'atlasDescript'.")
12172            else:
12173                whichindex = atlasDescript.index[atlasDescript['Description'].str.contains(oglabelname)]
12174
12175            if type(whichindex) is np.int64:
12176                labelnums = atlasDescript.loc[whichindex, 'Label']
12177            else:
12178                labelnums = list(atlasDescript.loc[whichindex, 'Label'])
12179
12180            if myext == 'yeo':
12181                parts = re.findall(r'\D+', oglabelname)
12182                oglabelname = [part.replace('_', '') for part in parts if part.replace('_', '')]
12183                oglabelname = [part.replace('.', '') for part in parts if part.replace('.', '')]
12184                filtered_df = atlasDescript[atlasDescript['Description'].isin(oglabelname)]
12185                labelnums = filtered_df['Label'].tolist()
12186
12187            if not isinstance(labelnums, list):
12188                labelnums=[labelnums]
12189            addemiszero = ants.threshold_image(addem, 0, 0)
12190            temp = ants.image_read(locfilename)
12191            temp = ants.mask_image(temp, temp, level=labelnums, binarize=True)
12192            if verbose > 0:
12193                print("DEBUG")
12194                print(  temp.sum() ) 
12195                print( labelnums )
12196            temp[temp == 1] = (vals2viz)
12197            temp[addemiszero == 0] = 0
12198            addem = addem + temp
12199
12200        if verbose > 0:
12201            print('Done Adding')
12202        for axx in axes:
12203            figfn=output_prefix+f"fig{col2viz}ax{axx}_py.jpg"
12204            if crop > 0:
12205                cmask = ants.threshold_image( addem,1e-5, 1e9 ).iMath("MD",crop) + ants.threshold_image( addem,-1e9, -1e-5 ).iMath("MD",crop)
12206                addemC = ants.crop_image( addem, cmask )
12207                edgeimgC = ants.crop_image( edgeimg, cmask )
12208            else:
12209                addemC = addem
12210                edgeimgC = edgeimg
12211            if fixed_overlay_range is not None:
12212                addemC[0:3,0:3,0:3]=fixed_overlay_range[0]
12213                addemC[4:7,4:7,4:7]=fixed_overlay_range[1]
12214                addemC[ addemC <= fixed_overlay_range[0] ] = 0 # fixed_overlay_range[0]
12215                addemC[ addemC >= fixed_overlay_range[1] ] = fixed_overlay_range[1]
12216            ants.plot(edgeimgC, addemC, axis=axx, nslices=nslices, ncol=ncol,       
12217                overlay_cmap=overlay_cmap, resample=False, overlay_alpha=1.0,
12218                filename=figfn, cbar=axx==axes[0], crop=True, black_bg=black_bg )
12219        if verbose > 0:
12220            print(f"{col2viz} done")
12221    if verbose:
12222        print("DONE brain map figures")
12223    return addem
12224
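# Hypothetical helper (not part of the pipeline) that replays the JHU column-name cleanup
# used above on a single string, so the combined effect of the re.sub chain can be inspected
# in isolation; the example column name is made up.
def _demo_jhu_label_cleanup(colname="dti_mean_fa.superior_corona_radiata.left.jhu.icbm.labels.1mm"):
    import re
    x = colname
    for pattern in ["dti_mean_fa.", "dti_mean_md.", "dti.mean.fa.", "dti.mean.md.",
                    ".left.", ".right.", ".lravg.", ".asym.", ".jhu.icbm.labels.1mm"]:
        x = re.sub(pattern, "", x)
    x = re.sub("superior", "sup", x)
    print(colname + " -> " + x)
    return x
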
12225def filter_df(indf, myprefix):
12226    """
12227    Process and filter a pandas DataFrame, removing certain columns, 
12228    filtering based on data types, computing the mean of numeric columns, 
12229    and adding a prefix to column names.
12230
12231    Parameters:
12232    indf (pandas.DataFrame): The input DataFrame to be processed.
12233    myprefix (str): A string prefix to be added to the column names 
12234                    of the processed DataFrame.
12235
12236    Steps:
12237    1. Removes columns with names containing 'Unnamed'.
12238    2. If the DataFrame has no rows, it returns the empty DataFrame.
12239    3. Filters out columns based on the type of the first element, 
12240       keeping those that are of type `object`, `int`, or `float`.
12241    4. Removes columns that are of `object` dtype.
12242    5. Calculates the mean of the remaining columns, skipping NaN values.
12243    6. Adds the specified `myprefix` to the column names.
12244
12245    Returns:
12246    pandas.DataFrame: A transformed DataFrame with a single row containing 
12247                      the mean values of the filtered columns, and with 
12248                      column names prefixed as specified.
12249    """
12250    indf = indf.loc[:, ~indf.columns.str.contains('Unnamed', na=False, regex=False)]
12251    if indf.shape[0] == 0:
12252        return indf
12253    nums = [isinstance(indf[col].iloc[0], (object, int, float)) for col in indf.columns]
12254    indf = indf.loc[:, nums]
12255    indf = indf.loc[:, indf.dtypes != 'object']
12256    indf = pd.DataFrame(indf.mean(axis=0, skipna=True)).T
12257    indf = indf.add_prefix(myprefix)
12258    return indf
12259
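# Minimal usage sketch for filter_df; the toy DataFrame below is hypothetical.  'Unnamed'
# and object-dtype columns are dropped and the remaining numeric columns are averaged into
# a single prefixed row.
def _demo_filter_df():
    toy = pd.DataFrame({'Unnamed: 0': [0, 1],
                        'vol_hippocampus': [4100.0, 4300.0],
                        'label': ['a', 'b']})
    out = filter_df(toy, 'T1Hier_')
    print(out)  # expect a single row with T1Hier_vol_hippocampus == 4200.0
    return out
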
12260
12261def aggregate_antspymm_results(input_csv, subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Processed/ANTsExpArt/", hiervariable='T1wHierarchical', valid_modalities=None, verbose=False ):
12262    """
12263    Aggregate ANTsPyMM results from the specified CSV file and save the aggregated results to a new CSV file.
12264
12265    Parameters:
12266    - input_csv (str): File path of the input CSV file containing ANTsPyMM QC results averaged and with outlier measurements.
12267    - subject_col (str): Name of the column to store subject IDs.
12268    - date_col (str): Name of the column to store date information.
12269    - image_col (str): Name of the column to store image IDs.
12270    - date_column (str): The date/session value (e.g. 'ses-1') written into date_col and used to build the search path.
12271    - base_path (str): Base path for search paths. Defaults to "./Processed/ANTsExpArt/".
12272    - hiervariable (str) : the string variable denoting the Hierarchical output
12273    - valid_modalities (str array) : modality identifiers to aggregate; if None, will be replaced by get_valid_modalities('long')
12274    - verbose : boolean
12275
12276    Note:
12277    This function is tested under limited circumstances. Use with caution.
12278
12279    Example usage:
12280    agg_df = aggregate_antspymm_results("qcdfaol.csv", subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Your/Custom/Path/")
12281
12282    Author:
12283    Avants and ChatGPT
12284    """
12285    import pandas as pd
12286    import numpy as np
12287    from glob import glob
12288
12289    def myread_csv(x, cnms):
12290        """
12291        Reads a CSV file and returns a DataFrame excluding specified columns.
12292
12293        Parameters:
12294        - x (str): File path of the input CSV file describing the blind QC output
12295        - cnms (list): List of column names to exclude from the DataFrame.
12296
12297        Returns:
12298        pd.DataFrame: DataFrame with specified columns excluded.
12299        """
12300        df = pd.read_csv(x)
12301        return df.loc[:, ~df.columns.isin(cnms)]
12302
12303    import warnings
12304    # Warning message for untested function
12305    warnings.warn("Warning: This function is not well tested. Use with caution.")
12306
12307    if valid_modalities is None:
12308        valid_modalities = get_valid_modalities('long')
12309
12310    # Read the input CSV file
12311    df = pd.read_csv(input_csv)
12312
12313    # Filter rows where modality is 'T1w'
12314    df = df[df['modality'] == 'T1w']
12315    badnames = get_names_from_data_frame( ['Unnamed'], df )
12316    df=df.drop(badnames, axis=1)
12317
12318    # Add new columns for subject ID, date, and image ID
12319    df[subject_col] = np.nan
12320    df[date_col] = date_column
12321    df[image_col] = np.nan
12322    df = df.astype({subject_col: str, date_col: str, image_col: str })
12323
12324#    if verbose:
12325#        print( df.shape )
12326#        print( df.dtypes )
12327
12328    # prefilter df for data that exists
12329    keep = np.tile( False, df.shape[0] )
12330    for x in range(df.shape[0]):
12331        temp = df['filename'].iloc[x].split("_")
12332        # Generalized search paths
12333        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12334        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12335        if len( hierfn ) > 0:
12336            keep[x]=True
12337
12338    
12339    df=df[keep]
12340    
12341    if verbose:
12342        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12343        print( df.shape )
12344
12345    myct = 0
12346    for x in range( df.shape[0]):
12347        if verbose:
12348            print(f"{x}...")
12349        locind = df.index[x]
12350        temp = df['filename'].iloc[x].split("_")
12351        if verbose:
12352            print( temp )
12353        df.loc[locind, subject_col] = temp[0]
12354        df.loc[locind, date_col] = date_column
12355        df.loc[locind, image_col] = temp[1]
12356
12357        # Generalized search paths
12358        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12359        if verbose:
12360            print(path_template)
12361        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12362        if len( hierfn ) > 0:
12363            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12364            if verbose:
12365                print(hierfn)
12366            hdf = pd.read_csv(hierfn[0])
12367            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12368            hdf=hdf.drop(badnames, axis=1)
12369            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12370            corenames = list(np.array(hdf.columns)[nums])
12371            hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12372            myct = myct + 1
12373            dflist = [hdf]
12374
12375            for mymod in valid_modalities:
12376                t1wfn = sorted(glob( path_template+ "-" + mymod + "-*wide.csv" ) )
12377                if len( t1wfn ) > 0 :
12378                    if verbose:
12379                        print(t1wfn)
12380                    t1df = myread_csv(t1wfn[0], corenames)
12381                    t1df = filter_df( t1df, mymod+'_')
12382                    dflist = dflist + [t1df]
12383                
12384            hdf = pd.concat( dflist, axis=1, ignore_index=False )
12385            if verbose:
12386                print( df.loc[locind,'filename'] )
12387            if myct == 1:
12388                subdf = df.iloc[[x]]
12389                hdf.index = subdf.index.copy()
12390                df = pd.concat( [df,hdf], axis=1, ignore_index=False )
12391            else:
12392                commcols = list(set(hdf.columns).intersection(df.columns))
12393                df.loc[locind, commcols] = hdf.loc[0, commcols]
12394    badnames = get_names_from_data_frame( ['Unnamed'], df )
12395    df=df.drop(badnames, axis=1)
12396    return( df )
12397
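# Hypothetical illustration of the search pattern used by aggregate_antspymm_results: for a
# QC filename beginning with "sub001_...", the default arguments lead to a glob roughly like
# the one returned below, where the two wildcard directory levels usually correspond to the
# processing output folder and the image ID.
def _demo_aggregate_search_pattern(subject="sub001",
                                   base_path="./Processed/ANTsExpArt/",
                                   date_column="ses-1",
                                   hiervariable="T1wHierarchical"):
    path_template = f"{base_path}{subject}/{date_column}/*/*/*"
    return path_template + "-" + hiervariable + "-*wide.csv"
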
12398def find_most_recent_file(file_list):
12399    """
12400    Finds and returns the most recently modified file from a list of file paths.
12401    
12402    Parameters:
12403    - file_list: A list of strings, where each string is a path to a file.
12404    
12405    Returns:
12406    - The path to the most recently modified file in the list, or None if the list is empty or contains no valid files.
12407    """
12408    # Filter out items that are not files or do not exist
12409    valid_files = [f for f in file_list if os.path.isfile(f)]
12410    
12411    # Check if the filtered list is not empty
12412    if valid_files:
12413        # Find the file with the latest modification time
12414        most_recent_file = max(valid_files, key=os.path.getmtime)
12415        return [most_recent_file]
12416    else:
12417        return None
12418    
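# Self-contained sketch of find_most_recent_file using temporary files; names are arbitrary
# and only the (explicitly set) modification times matter.
def _demo_find_most_recent_file():
    import tempfile
    d = tempfile.mkdtemp()
    older = os.path.join(d, "a.csv")
    newer = os.path.join(d, "b.csv")
    open(older, "w").close()
    open(newer, "w").close()
    os.utime(older, (1_000_000, 1_000_000))  # older mtime
    os.utime(newer, (2_000_000, 2_000_000))  # newer mtime
    return find_most_recent_file([older, newer, os.path.join(d, "missing.csv")])  # -> [newer]
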
12419def aggregate_antspymm_results_sdf(
12420    study_df, 
12421    project_col='projectID',
12422    subject_col='subjectID', 
12423    date_col='date', 
12424    image_col='imageID', 
12425    base_path="./", 
12426    hiervariable='T1wHierarchical', 
12427    splitsep='-',
12428    idsep='-',
12429    wild_card_modality_id=False,
12430    second_split=False,
12431    verbose=False ):
12432    """
12433    Aggregate ANTsPyMM results from the specified study data frame and store the aggregated results in a new data frame.  This assumes data is organized on disk 
12434    as follows:  rootdir/projectID/subjectID/date/outputid/imageid/ where 
12435    outputid is modality-specific and created by ANTsPyMM processing.
12436
12437    Parameters:
12438    - study_df (pandas df): pandas data frame, output of generate_mm_dataframe.
12439    - project_col (str): Name of the column that stores the project ID
12440    - subject_col (str): Name of the column to store subject IDs.
12441    - date_col (str): Name of the column to store date information.
12442    - image_col (str): Name of the column to store image IDs.
12443    - base_path (str): Base path for searching for processing outputs of ANTsPyMM.
12444    - hiervariable (str) : the string variable denoting the Hierarchical output
12445    - splitsep (str):  the separator used to split the filename
12446    - idsep (str): the separator used to partition subjectid date and imageid 
12447        for example, if idsep is - then we have subjectid-date-imageid
12448    - wild_card_modality_id (bool): keep if False for safer execution
12449    - wild_card_modality_id (bool): keep this False for safer execution; if True, a wildcard replaces the modality image ID in the search path
12450    - verbose : boolean
12451
12452    Note:
12453    This function is tested under limited circumstances. Use with caution.
12454    One particular gotcha is if the imageID is stored as a numeric value in the dataframe 
12455    but is meant to be a string.  E.g. '000' (string) would be interpreted as 0 in the 
12456    file name glob.  This would miss the extant (on disk) csv.
12457
12458    Example usage:
12459    agg_df = aggregate_antspymm_results_sdf( studydf, subject_col='subjectID', date_col='date', image_col='imageID', base_path="./Your/Custom/Path/")
12460
12461    Author:
12462    Avants and ChatGPT
12463    """
12464    import pandas as pd
12465    import numpy as np
12466    from glob import glob
12467
12468    def progress_reporter(current_step, total_steps, width=50):
12469        # Calculate the proportion of progress
12470        progress = current_step / total_steps
12471        # Calculate the number of 'filled' characters in the progress bar
12472        filled_length = int(width * progress)
12473        # Create the progress bar string
12474        bar = '█' * filled_length + '-' * (width - filled_length)
12475        # Print the progress bar with percentage
12476        print(f'\rProgress: |{bar}| {int(100 * progress)}%', end='\r')
12477        # Print a new line when the progress is complete
12478        if current_step == total_steps:
12479            print()
12480
12481    def myread_csv(x, cnms):
12482        """
12483        Reads a CSV file and returns a DataFrame excluding specified columns.
12484
12485        Parameters:
12486        - x (str): File path of the input CSV file describing the blind QC output
12487        - cnms (list): List of column names to exclude from the DataFrame.
12488
12489        Returns:
12490        pd.DataFrame: DataFrame with specified columns excluded.
12491        """
12492        df = pd.read_csv(x)
12493        return df.loc[:, ~df.columns.isin(cnms)]
12494
12495    import warnings
12496    # Warning message for untested function
12497    warnings.warn("Warning: This function is not well tested. Use with caution.")
12498
12499    vmoddict = {}
12500    # Add key-value pairs
12501    vmoddict['imageID'] = 'T1w'
12502    vmoddict['flairid'] = 'T2Flair'
12503    vmoddict['perfid'] = 'perf'
12504    vmoddict['pet3did'] = 'pet3d'
12505    vmoddict['rsfid1'] = 'rsfMRI'
12506#    vmoddict['rsfid2'] = 'rsfMRI'
12507    vmoddict['dtid1'] = 'DTI'
12508#    vmoddict['dtid2'] = 'DTI'
12509    vmoddict['nmid1'] = 'NM2DMT'
12510#    vmoddict['nmid2'] = 'NM2DMT'
12511
12512    # Filter rows where modality is 'T1w'
12513    df = study_df[ study_df['modality'] == 'T1w']
12514    badnames = get_names_from_data_frame( ['Unnamed'], df )
12515    df=df.drop(badnames, axis=1)
12516    # prefilter df for data that exists
12517    keep = np.tile( False, df.shape[0] )
12518    for x in range(df.shape[0]):
12519        myfn = os.path.basename( df['filename'].iloc[x] )
12520        temp = myfn.split( splitsep )
12521        # Generalized search paths
12522        sid0 = str( temp[1] )
12523        sid = str( df[subject_col].iloc[x] )
12524        if sid0 != sid:
12525            warnings.warn("OUTER: the id derived from the filename " + sid0 + " does not match the id stored in the data frame " + sid )
12526            warnings.warn( "filename is : " +  myfn )
12527            warnings.warn( "sid is : " + sid )
12528            warnings.warn( "x is : " + str(x) )
12529        myproj = str(df[project_col].iloc[x])
12530        mydate = str(df[date_col].iloc[x])
12531        myid = str(df[image_col].iloc[x])
12532        if second_split:
12533            myid = myid.split(".")[0]
12534        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12535        hierfn = sorted(glob( path_template + "*" + hiervariable + "*wide.csv" ) )
12536        if len( hierfn ) == 0:
12537            print( hierfn )
12538            print( path_template )
12539            print( myproj )
12540            print( sid )
12541            print( mydate ) 
12542            print( myid )
12543        if len( hierfn ) > 0:
12544            keep[x]=True
12545
12546    # df=df[keep]
12547    if df.shape[0] == 0:
12548        warnings.warn("input data frame shape is filtered down to zero")
12549        return df
12550
12551    if not df.index.is_unique:
12552        warnings.warn("data frame does not have unique indices.  we therefore reset the index to allow the function to continue on." )
12553        df = df.reset_index()
12554
12555    
12556    if verbose:
12557        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12558        print( df.shape )
12559
12560    dfout = pd.DataFrame()
12561    myct = 0
12562    for x in range( df.shape[0]):
12563        if verbose:
12564            print("\n\n-------------------------------------------------")
12565            print(f"{x}...")
12566        else:
12567            progress_reporter(x, df.shape[0], width=500)
12568        locind = df.index[x]
12569        myfn = os.path.basename( df['filename'].iloc[x] )
12570        sid = str( df[subject_col].iloc[x] )
12571        tempB = myfn.split( splitsep )
12572        sid0 = str(tempB[1])
12573        if sid0 != sid and verbose:
12574            warnings.warn("INNER: the id derived from the filename " + str(sid0) + " does not match the id stored in the data frame " + str(sid) )
12575            warnings.warn( "filename is : " +  str(myfn) )
12576            warnings.warn( "sid is : " + str(sid) )
12577            warnings.warn( "x is : " + str(x) )
12578            warnings.warn( "index is : " + str(locind) )
12579        myproj = str(df[project_col].iloc[x])
12580        mydate = str(df[date_col].iloc[x])
12581        myid = str(df[image_col].iloc[x])
12582        if second_split:
12583            myid = myid.split(".")[0]
12584        if verbose:
12585            print( myfn )
12586            print( tempB )
12587            print( "id " + sid  )
12588        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12589        searchhier = path_template + "*" + hiervariable + "*wide.csv"
12590        if verbose:
12591            print( searchhier )
12592        hierfn = sorted( glob( searchhier ) )
12593        if len( hierfn ) > 1:
12594            raise ValueError("there are " + str( len( hierfn ) ) + " number of hier fns with search path " + searchhier )
12595        if len( hierfn ) == 1:
12596            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12597            if verbose:
12598                print(hierfn)
12599            hdf = pd.read_csv(hierfn[0])
12600            if verbose:
12601                print( hdf['vol_hemisphere_lefthemispheres'] )
12602            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12603            hdf=hdf.drop(badnames, axis=1)
12604            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12605            corenames = list(np.array(hdf.columns)[nums])
12606            # hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12607            hdf = hdf.add_prefix("T1Hier_")
12608            myct = myct + 1
12609            dflist = [hdf]
12610
12611            for mymod in vmoddict.keys():
12612                if verbose:
12613                    print("\n\n************************* " + mymod + " *************************")
12614                modalityclass = vmoddict[ mymod ]
12615                if wild_card_modality_id:
12616                    mymodid = '*'
12617                else:
12618                    mymodid = str( df[mymod].iloc[x] )
12619                    if mymodid.lower() != "nan" and mymodid.lower() != "na":
12620                        mymodid = os.path.basename( mymodid )
12621                        mymodid = os.path.splitext( mymodid )[0]
12622                        mymodid = os.path.splitext( mymodid )[0]
12623                        temp = mymodid.split( idsep )
12624                        mymodid = temp[ len( temp )-1 ]
12625                    else:
12626                        if verbose:
12627                            print("missing")
12628                        continue
12629                if verbose:
12630                    print( "modality id is " + mymodid + " for modality " + modalityclass + ' modality specific subj ' + sid + ' modality specific id is ' + myid + " its date " +  mydate )
12631                modalityclasssearch = modalityclass
12632                if modalityclass in ['rsfMRI','DTI']:
12633                    modalityclasssearch=modalityclass+"*"
12634                path_template_m = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + modalityclasssearch + '/' + mymodid + "/"
12635                modsearch = path_template_m + "*" + modalityclasssearch + "*wide.csv"
12636                if verbose:
12637                    print( modsearch )
12638                t1wfn = sorted( glob( modsearch ) )
12639                if len( t1wfn ) > 1:
12640                    nlarge = len(t1wfn)
12641                    t1wfn = find_most_recent_file( t1wfn )
12642                    warnings.warn("there are " + str( nlarge ) + " number of wide fns with search path " + modsearch + " we take the most recent of these " + t1wfn[0] )
12643                if len( t1wfn ) == 1:
12644                    if verbose:
12645                        print(t1wfn)
12646                    t1df = myread_csv(t1wfn[0], corenames)
12647                    t1df = filter_df( t1df, modalityclass+'_')
12648                    dflist = dflist + [t1df]
12649                else:
12650                    if verbose:
12651                        print( " cannot find " + modsearch )
12652                
12653            hdf = pd.concat( dflist, axis=1, ignore_index=False)
12654            if verbose:
12655                print( "count: " + str( myct ) )
12656            subdf = df.iloc[[x]]
12657            hdf.index = subdf.index.copy()
12658            subdf = pd.concat( [subdf,hdf], axis=1, ignore_index=False)
12659            dfout = pd.concat( [dfout,subdf], axis=0, ignore_index=False )
12660
12661    if dfout.shape[0] > 0:
12662        badnames = get_names_from_data_frame( ['Unnamed'], dfout )
12663        dfout=dfout.drop(badnames, axis=1)
12664    return dfout
12665
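# Layout sketch for aggregate_antspymm_results_sdf (values are hypothetical).  A study-df row
# with projectID "PPMI", subjectID "0001", date "20230215" and imageID "I1000" is expected to
# have its hierarchical output at
#   <base_path>/PPMI/0001/20230215/T1wHierarchical/I1000/*T1wHierarchical*wide.csv
# and each matched modality (see vmoddict above) at
#   <base_path>/PPMI/0001/20230215/<modality>/<modality imageID>/*<modality>*wide.csv
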
12666def enantiomorphic_filling_without_mask( image, axis=0, intensity='low' ):
12667    """
12668    Perform an enantiomorphic lesion filling on an image without a lesion mask.
12669
12670    Args:
12671    image (antsImage): The ants image to flip and fill
12672    axis ( int ): the axis along which to reflect the image
12673    intensity ( str ) : low or high
12674
12675    Returns:
12676    tuple: (filled_image, diffseg) where filled_image is the ANTsImage after enantiomorphic filling and diffseg is the Otsu segmentation of the original-minus-symmetric difference image.
12677    """
12678    imagen = ants.iMath( image, 'Normalize' )
12679    imagen = ants.iMath( imagen, "TruncateIntensity", 1e-6, 0.98 )
12680    imagen = ants.iMath( imagen, 'Normalize' )
12681    # Create a mirror image (flipping left and right)
12682    mirror_image = ants.reflect_image(imagen, axis=axis, tx='antsRegistrationSyNQuickRepro[s]' )['warpedmovout']
12683
12684    # Create a symmetric version of the image by averaging the original and the mirror image
12685    symmetric_image = imagen * 0.5 + mirror_image * 0.5
12686
12687    # Identify potential lesion areas by finding differences between the original and symmetric image
12688    difference_image = image - symmetric_image
12689    diffseg = ants.threshold_image(difference_image, "Otsu", 3 )
12690    if intensity == 'low':
12691        likely_lesion = ants.threshold_image( diffseg, 1,  1)
12692    else:
12693        likely_lesion = ants.threshold_image( diffseg, 3,  3)
12694    likely_lesion = ants.smooth_image( likely_lesion, 3.0 ).iMath("Normalize")
12695    lesionneg = ( imagen*0+1.0 ) - likely_lesion
12696    filled_image = ants.image_clone(imagen)    
12697    filled_image = imagen * lesionneg + mirror_image * likely_lesion
12698
12699    return filled_image, diffseg
12700
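# Minimal usage sketch for enantiomorphic_filling_without_mask; assumes antspyx is installed
# and uses the small 2D 'r16' example image that ships with ANTsPy, so this only illustrates
# the call pattern rather than a realistic 3D lesion-filling case.
def _demo_enantiomorphic_filling():
    import ants
    img = ants.image_read(ants.get_ants_data('r16'))
    filled, diffseg = enantiomorphic_filling_without_mask(img, axis=0, intensity='low')
    return filled, diffseg
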
12701
12702
12703def filter_image_files(image_paths, criteria='largest'):
12704    """
12705    Filters a list of image file paths based on specified criteria and returns 
12706    the path of the image that best matches that criteria (smallest, largest, or brightest).
12707
12708    Args:
12709    image_paths (list): A list of file paths to the images.
12710    criteria (str): Criteria for selecting the image ('smallest', 'largest', 'brightest').
12711
12712    Returns:
12713    str: The file path of the selected image, or None if no valid images are found.
12714    """
12715    import numpy as np
12716    if not image_paths:
12717        return None
12718
12719    selected_image_path = None
12720    if criteria == 'smallest' or criteria == 'largest':
12721        extreme_volume = None
12722
12723        for path in image_paths:
12724            try:
12725                image = ants.image_read(path)
12726                volume = np.prod(image.shape)
12727
12728                if criteria == 'largest':
12729                    if extreme_volume is None or volume > extreme_volume:
12730                        extreme_volume = volume
12731                        selected_image_path = path
12732                elif criteria == 'smallest':
12733                    if extreme_volume is None or volume < extreme_volume:
12734                        extreme_volume = volume
12735                        selected_image_path = path
12736
12737            except Exception as e:
12738                print(f"Error processing image {path}: {e}")
12739
12740    elif criteria == 'brightest':
12741        max_brightness = None
12742
12743        for path in image_paths:
12744            try:
12745                image = ants.image_read(path)
12746                brightness = np.mean(image.numpy())
12747
12748                if max_brightness is None or brightness > max_brightness:
12749                    max_brightness = brightness
12750                    selected_image_path = path
12751
12752            except Exception as e:
12753                print(f"Error processing image {path}: {e}")
12754
12755    else:
12756        raise ValueError("Criteria must be 'smallest', 'largest', or 'brightest'.")
12757
12758    return selected_image_path
12759
12760
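# Usage sketch for filter_image_files; the candidate list reuses the small example images
# that ship with ANTsPy, so the selection itself is not meaningful, only the call pattern.
def _demo_filter_image_files():
    import ants
    paths = [ants.get_ants_data('r16'), ants.get_ants_data('r64')]
    return filter_image_files(paths, criteria='brightest')
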
12761
12762def mm_match_by_qc_scoring(df_a, df_b, match_column, criteria, prefix='matched_', exclude_columns=None):
12763    """
12764    Match each row in df_a to a row in df_b based on a matching column and criteria for selecting the best match,
12765    with options to prefix column names from df_b and exclude certain columns from the final output. Additionally,
12766    returns a DataFrame containing rows from df_b that were not matched to any row in df_a.
12767
12768    Parameters:
12769    - df_a: DataFrame A.
12770    - df_b: DataFrame B.
12771    - match_column: The column name on which rows should match between DataFrame A and B.
12772    - criteria: A dictionary where keys are column names and values are 'min' or 'max', indicating whether
12773                the column should be minimized or maximized for the best match.
12774    - prefix: A string prefix to add to column names from df_b in the final output to avoid duplication.
12775    - exclude_columns: A list of column names from df_b to exclude from the final output.
12776    
12777    Returns:
12778    - A tuple of two DataFrames: 
12779        1. A new DataFrame combining df_a with matched rows from df_b.
12780        2. A DataFrame containing rows from df_b that were not matched to df_a.
12781    """
12782    from scipy.stats import zscore
12783    df_a = df_a.loc[:, ~df_a.columns.str.startswith('Unnamed:')].copy()
12784    if df_b is not None:
12785        df_b = df_b.loc[:, ~df_b.columns.str.startswith('Unnamed:')].copy()
12786    else:
12787        return df_a, pd.DataFrame()
12788    
12789    # Normalize df_b based on criteria
12790    for col, crit in criteria.items():
12791        if crit == 'max':
12792            df_b.loc[df_b.index, f'score_{col}'] = zscore(-df_b[col])
12793        elif crit == 'min':
12794            df_b.loc[df_b.index, f'score_{col}'] = zscore(df_b[col])
12795
12796    # Calculate 'best_score' by summing all score columns
12797    score_columns = [f'score_{col}' for col in criteria.keys()]
12798    df_b['best_score'] = df_b[score_columns].sum(axis=1)
12799
12800    matched_indices = []  # Track indices of matched rows in df_b
12801
12802    # Match rows
12803    matched_rows = []
12804    for _, row_a in df_a.iterrows():
12805        matches = df_b[df_b[match_column] == row_a[match_column]]
12806        if not matches.empty:
12807            best_idx = matches['best_score'].idxmin()
12808            best_match = matches.loc[best_idx]
12809            matched_indices.append(best_idx)  # Track this index as matched
12810            matched_rows.append(best_match)
12811        else:
12812            matched_rows.append(pd.Series(dtype='float64'))
12813
12814    # Create a DataFrame from matched rows
12815    df_matched = pd.DataFrame(matched_rows).reset_index(drop=True)
12816    
12817    # Exclude specified columns and add prefix
12818    if exclude_columns is not None:
12819        df_matched = df_matched.drop(columns=exclude_columns, errors='ignore')
12820    df_matched = df_matched.rename(columns=lambda x: f"{prefix}{x}" if x != match_column and x in df_matched.columns else x)
12821
12822    # Combine df_a with matched rows from df_b
12823    result_df = pd.concat([df_a.reset_index(drop=True), df_matched], axis=1)
12824    
12825    # Extract unmatched rows from df_b
12826    unmatched_df_b = df_b.drop(index=matched_indices).reset_index(drop=True)
12827
12828    return result_df, unmatched_df_b
12829
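# Toy example of mm_match_by_qc_scoring (all values hypothetical): each row in df_a is paired
# with the df_b row for the same subjectIDdate that has the best combined z-score under the
# criteria (here, lowest ol_loop and highest snr); unmatched df_b rows are returned separately.
def _demo_mm_match_by_qc_scoring():
    df_a = pd.DataFrame({'subjectIDdate': ['s1-20230101', 's2-20230101'],
                         'T1_snr': [12.0, 11.0]})
    df_b = pd.DataFrame({'subjectIDdate': ['s1-20230101', 's1-20230101', 's3-20230101'],
                         'ol_loop': [0.2, 0.6, 0.3],
                         'snr': [15.0, 9.0, 10.0]})
    matched, unmatched = mm_match_by_qc_scoring(
        df_a, df_b, 'subjectIDdate', {'ol_loop': 'min', 'snr': 'max'}, prefix='T2Flair_')
    return matched, unmatched
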
12830
12831def fix_LR_RL_stuff(df, col1, col2, size_col1, size_col2, id1, id2 ):
    """
    Resolve duplicate acquisitions (e.g. LR/RL phase-encoding pairs) stored in two candidate
    column slots. If both filenames look like LR/RL acquisitions the row is left unchanged;
    otherwise the acquisition with more time points (size_col1/size_col2) is kept in the first
    slot (col1, size_col1, id1) and the second slot is cleared. Returns a modified copy of df.
    """
12832    df_copy = df.copy()
12833    # Ensure columns contain strings for substring checks
12834    df_copy[col1] = df_copy[col1].astype(str)
12835    df_copy[col2] = df_copy[col2].astype(str)
12836    df_copy[id1] = df_copy[id1].astype(str)
12837    df_copy[id2] = df_copy[id2].astype(str)
12838    
12839    for index, row in df_copy.iterrows():
12840        col1_val = row[col1]
12841        col2_val = row[col2]
12842        size1 = row[size_col1]
12843        size2 = row[size_col2]
12844        
12845        # Check for 'RL' or 'LR' in each column and compare sizes
12846        if ('RL' in col1_val or 'LR' in col1_val) and ('RL' in col2_val or 'LR' in col2_val):
12847            continue
12848        elif 'RL' not in col1_val and 'LR' not in col1_val and 'RL' not in col2_val and 'LR' not in col2_val:
12849            if size1 < size2:
12850                df_copy.at[index, col1] = df_copy.at[index, col2]
12851                df_copy.at[index, size_col1] = df_copy.at[index, size_col2]
12852                df_copy.at[index, id1] = df_copy.at[index, id2]
12853                df_copy.at[index, size_col2] = 0
12854                df_copy.at[index, col2] = None
12855                df_copy.at[index, id2] = None
12856            else:
12857                df_copy.at[index, col2] = None
12858                df_copy.at[index, size_col2] = 0
12859                df_copy.at[index, id2] = None
12860        elif 'RL' in col1_val or 'LR' in col1_val:
12861            if size1 < size2:
12862                df_copy.at[index, col1] = df_copy.at[index, col2]
12863                df_copy.at[index, id1] = df_copy.at[index, id2]
12864                df_copy.at[index, size_col1] = df_copy.at[index, size_col2]
12865                df_copy.at[index, size_col2] = 0
12866                df_copy.at[index, col2] = None
12867                df_copy.at[index, id2] = None
12868            else:
12869                df_copy.at[index, col2] = None
12870                df_copy.at[index, id2] = None
12871                df_copy.at[index, size_col2] = 0
12872        elif 'RL' in col2_val or 'LR' in col2_val:
12873            if size2 < size1:
12874                df_copy.at[index, id2] = None
12875                df_copy.at[index, col2] = None
12876                df_copy.at[index, size_col2] = 0
12877            else:
12878                df_copy.at[index, col1] = df_copy.at[index, col2]
12879                df_copy.at[index, id1] = df_copy.at[index, id2]
12880                df_copy.at[index, size_col1] = df_copy.at[index, size_col2]
12881                df_copy.at[index, size_col2] = 0
12882                df_copy.at[index, col2] = None    
12883                df_copy.at[index, id2] = None    
12884    return df_copy
12885
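# Toy illustration of fix_LR_RL_stuff (hypothetical values): neither filename here is an LR/RL
# pair, so the acquisition with more time points is promoted to the first slot and the second
# slot is cleared.
def _demo_fix_LR_RL_stuff():
    toy = pd.DataFrame({'rsf1_filename': ['sub-01_rsfMRI_A'],
                        'rsf2_filename': ['sub-01_rsfMRI_B'],
                        'rsf1_dimt': [100.0], 'rsf2_dimt': [300.0],
                        'rsf1_imageID': ['I1'], 'rsf2_imageID': ['I2']})
    return fix_LR_RL_stuff(toy, 'rsf1_filename', 'rsf2_filename',
                           'rsf1_dimt', 'rsf2_dimt', 'rsf1_imageID', 'rsf2_imageID')
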
12886
12887def renameit(df, old_col_name, new_col_name):
12888    """
12889    Renames a column in a pandas DataFrame in place. Raises an error if the specified old column name does not exist.
12890
12891    Parameters:
12892    - df: pandas.DataFrame
12893        The DataFrame in which the column is to be renamed.
12894    - old_col_name: str
12895        The current name of the column to be renamed.
12896    - new_col_name: str
12897        The new name for the column.
12898    
12899    Warns:
12900    - UserWarning: If the old column name does not exist in the DataFrame; the DataFrame is left unchanged.
12901    
12902    Returns:
12903    None
12904    """
12905    import warnings
12906    # Check if the old column name exists in the DataFrame
12907    if old_col_name not in df.columns:
12908        warnings.warn(f"The column '{old_col_name}' does not exist in the DataFrame.")
12909        return
12910    
12911    # Proceed with renaming the column if it exists
12912    df.rename(columns={old_col_name: new_col_name}, inplace=True)
12913
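# Small sketch of renameit: the rename happens in place, and a missing source column only
# triggers a warning rather than an exception.
def _demo_renameit():
    toy = pd.DataFrame({'rsf1_imageID': ['I1']})
    renameit(toy, 'rsf1_imageID', 'rsfid1')    # column renamed in place
    renameit(toy, 'not_a_column', 'whatever')  # warns, leaves toy unchanged
    return toy
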
12914
12915def mm_match_by_qc_scoring_all( qc_dataframe, fix_LRRL=True, mysep='-', verbose=True ):
12916    """
12917    Processes a quality control (QC) DataFrame to perform modality-specific matching and filtering based
12918    on predefined criteria, optimizing for minimal outlierness and noise, and maximal signal-to-noise ratio (SNR),
12919    EVR, and number of time points (dimt).
12920
12921    This function iteratively matches dataframes derived from the QC dataframe for different imaging modalities,
12922    applying a series of filters to select the best matches based on the QC metrics. Matches are made with
12923    consideration to minimize outlier loop and noise, while maximizing SNR, EVR, and dimt for each modality.
12924
12925    Parameters:
12926    ----------
12927    qc_dataframe : pandas.DataFrame
12928        The DataFrame containing QC metrics for different modalities and imaging data.
12929    fix_LRRL : bool, optional ; if True, resolve duplicate LR/RL acquisitions (currently applied to the rsfMRI columns)
12930    mysep : string, character such as - or _ ; the usual antspymm separator argument
12931
12932    verbose : bool, optional
12933        If True, prints the progress and the shape of the DataFrame being processed in each step.
12934
12935    Process:
12936    -------
12937    1. Standardizes modalities by merging DTI-related entries.
12938    2. Converts specific columns to appropriate data types for processing.
12939    3. Performs modality-specific matching and filtering based on the outlier column and criteria for each modality.
12940    4. Iteratively processes unmatched data for predefined modalities with specific prefixes to find further matches.
12941    
12942    Returns:
12943    -------
12944    pandas.DataFrame
12945        The matched and filtered DataFrame after applying all QC scoring and matching operations across specified modalities.
12946
12947    """
12948    qc_dataframe=remove_unwanted_columns( qc_dataframe )
12949    qc_dataframe['modality'] = qc_dataframe['modality'].replace(['DTIdwi', 'DTIb0'], 'DTI', regex=True)
12950    qc_dataframe['filename']=qc_dataframe['filename'].astype(str)
12951    qc_dataframe['ol_loop']=qc_dataframe['ol_loop'].astype(float)
12952    qc_dataframe['ol_lof']=qc_dataframe['ol_lof'].astype(float)
12953    qc_dataframe['ol_lof_decision']=qc_dataframe['ol_lof_decision'].astype(float)
12954    outlier_column='ol_loop'
12955    mmdf0 = best_mmm( qc_dataframe, 'T1w', outlier_column=outlier_column, mysep=mysep )['filt']
12956    fldf = best_mmm( qc_dataframe, 'T2Flair', outlier_column=outlier_column, mysep=mysep  )['filt']
12957    nmdf = best_mmm( qc_dataframe, 'NM2DMT', outlier_column=outlier_column, mysep=mysep  )['filt']
12958    rsdf = best_mmm( qc_dataframe, 'rsfMRI', outlier_column=outlier_column, mysep=mysep  )['filt']
12959    dtdf = best_mmm( qc_dataframe, 'DTI', outlier_column=outlier_column, mysep=mysep  )['filt']
12960    pfdf = best_mmm( qc_dataframe, 'perf', outlier_column=outlier_column, mysep=mysep  )['filt']
12961
12962    criteria = {'ol_loop': 'min', 'noise': 'min', 'snr': 'max', 'EVR': 'max', 'reflection_err':'min'}
12963    xcl = [ 'mrimfg', 'mrimodel','mriMagneticFieldStrength', 'dti_failed', 'rsf_failed', 'subjectID', 'date', 'subjectIDdate','repeat']
12964    # Assuming df_a and df_b are already loaded
12965    mmdf, undffl = mm_match_by_qc_scoring(mmdf0, fldf, 'subjectIDdate', criteria, 
12966                        prefix='T2Flair_', exclude_columns=xcl )
12967
12968    mmdf, undfpf = mm_match_by_qc_scoring(mmdf, pfdf, 'subjectIDdate', criteria, 
12969                        prefix='perf_', exclude_columns=xcl )
12970
12971    prefixes = ['NM1_', 'NM2_', 'NM3_', 'NM4_', 'NM5_', 'NM6_']  
12972    undfmod = nmdf  # Initialize 'undfmod' with 'nmdf' for the first iteration
12973    if undfmod is not None:
12974        if verbose:
12975            print('start NM')
12976            print( undfmod.shape )
12977        for prefix in prefixes:
12978            if undfmod.shape[0] > 50:
12979                mmdf, undfmod = mm_match_by_qc_scoring(mmdf, undfmod, 'subjectIDdate', criteria, prefix=prefix, exclude_columns=xcl)
12980                if verbose:
12981                    print( prefix )
12982                    print( undfmod.shape )
12983
12984    criteria = {'ol_loop': 'min', 'noise': 'min', 'snr': 'max', 'EVR': 'max', 'dimt':'max'}
12985    # higher bvalues lead to more noise ...
12986    criteria = {'ol_loop': 'min', 'noise': 'min',  'dti_bvalueMax':'min',  'dimt':'max'}
12987    prefixes = ['DTI1_', 'DTI2_', 'DTI3_']  # List of prefixes for each matching iteration
12988    undfmod = dtdf
12989    if undfmod is not None:
12990        if verbose:
12991            print('start DT')
12992            print( undfmod.shape )
12993        for prefix in prefixes:
12994            if undfmod.shape[0] > 50:
12995                mmdf, undfmod = mm_match_by_qc_scoring(mmdf, undfmod, 'subjectIDdate', criteria, prefix=prefix, exclude_columns=xcl)
12996                if verbose:
12997                    print( prefix )
12998                    print( undfmod.shape )
12999
13000    prefixes = ['rsf1_', 'rsf2_', 'rsf3_']  # List of prefixes for each matching iteration
13001    undfmod = rsdf  # Initialize 'undfmod' with 'nmdf' for the first iteration
13002    if undfmod is not None:
13003        if verbose:
13004            print('start rsf')
13005            print( undfmod.shape )
13006        for prefix in prefixes:
13007            if undfmod.shape[0] > 50:
13008                mmdf, undfmod = mm_match_by_qc_scoring(mmdf, undfmod, 'subjectIDdate', criteria, prefix=prefix, exclude_columns=xcl)
13009                if verbose:
13010                    print( prefix )
13011                    print( undfmod.shape )
13012    
13013    if fix_LRRL:
13014        #        mmdf=fix_LR_RL_stuff( mmdf, 'DTI1_filename', 'DTI2_filename', 'DTI1_dimt', 'DTI2_dimt')
13015        mmdf=fix_LR_RL_stuff( mmdf, 'rsf1_filename', 'rsf2_filename', 'rsf1_dimt', 'rsf2_dimt', 'rsf1_imageID', 'rsf2_imageID'  )
13016    else:
13017        import warnings
13018        warnings.warn("FIXME: should fix LR and RL situation for the DTI and rsfMRI")
13019
13020    # now do the necessary replacements
13021    
13022    renameit( mmdf, 'perf_imageID', 'perfid' )
13023    renameit( mmdf, 'perf_filename', 'perffn' )
13024    renameit( mmdf, 'T2Flair_imageID', 'flairid' )
13025    renameit( mmdf, 'T2Flair_filename', 'flairfn' )
13026    renameit( mmdf, 'rsf1_imageID', 'rsfid1' )
13027    renameit( mmdf, 'rsf2_imageID', 'rsfid2' )
13028    renameit( mmdf, 'rsf1_filename', 'rsffn1' )
13029    renameit( mmdf, 'rsf2_filename', 'rsffn2' )
13030    renameit( mmdf, 'DTI1_imageID', 'dtid1' )
13031    renameit( mmdf, 'DTI2_imageID', 'dtid2' )
13032    renameit( mmdf, 'DTI3_imageID', 'dtid3' )
13033    renameit( mmdf, 'DTI1_filename', 'dtfn1' )
13034    renameit( mmdf, 'DTI2_filename', 'dtfn2' )
13035    renameit( mmdf, 'DTI3_filename', 'dtfn3' )
13036    for x in range(1,7):
13037        temp0="NM"+str(x)+"_imageID"
13038        temp1="nmid"+str(x)
13039        renameit( mmdf, temp0, temp1 )
13040        temp0="NM"+str(x)+"_filename"
13041        temp1="nmfn"+str(x)
13042        renameit( mmdf, temp0, temp1 )
13043    return mmdf
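
# Hedged usage sketch for mm_match_by_qc_scoring_all: the input is a blind-QC dataframe that
# already carries the ol_loop / ol_lof / ol_lof_decision columns (e.g. assembled and averaged
# from blind_image_assessment outputs); the result is a wide, matched dataframe with the
# flairid/rsfid*/dtid*/nmid* style columns used downstream.  The CSV name is hypothetical.
#   qc = pd.read_csv("my_blind_qc_averaged.csv")
#   matched = mm_match_by_qc_scoring_all(qc, mysep='-')
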
def version():
145def version( ):
146    """
147    report versions of this package and primary dependencies
148
149    Arguments
150    ---------
151    None
152
153    Returns
154    -------
155    a dictionary with package name and versions
156
157    Example
158    -------
159    >>> import antspymm
160    >>> antspymm.version()
161    """
162    import pkg_resources
163    return {
164              'tensorflow': pkg_resources.get_distribution("tensorflow").version,
165              'antspyx': pkg_resources.get_distribution("antspyx").version,
166              'antspynet': pkg_resources.get_distribution("antspynet").version,
167              'antspyt1w': pkg_resources.get_distribution("antspyt1w").version,
168              'antspymm': pkg_resources.get_distribution("antspymm").version
169              }

def mm_read(x, standardize_intensity=False, modality=''):
1671def mm_read( x, standardize_intensity=False, modality='' ):
1672    """
1673    read an image from a filename - same as ants.image_read (for now)
1674
1675    standardize_intensity : boolean ; if True will set negative values to zero and normalize into the range of zero to one
1676
1677    modality : not used
1678    """
1679    if x is None:
1680        raise ValueError( " None passed to function antspymm.mm_read." )
1681    if not isinstance(x,str):
1682        raise ValueError( " Non-string passed to function antspymm.mm_read." )
1683    if not os.path.exists( x ):
1684        raise ValueError( " file " + x + " does not exist." )
1685    img = ants.image_read( x, reorient=False )
1686    if standardize_intensity:
1687        img[img<0.0]=0.0
1688        img=ants.iMath(img,'Normalize')
1689    if modality == "T1w" and img.dimension == 4:
1690        print("WARNING: input image is 4D - we attempt a hack fix that works in some odd cases of PPMI data - please check this image: " + x, flush=True )
1691        i1=ants.slice_image(img,3,0)
1692        i2=ants.slice_image(img,3,1)
1693        kk=np.concatenate( [i1.numpy(),i2.numpy()], axis=2 )
1694        kk=ants.from_numpy(kk)
1695        img=ants.copy_image_info(i1,kk)
1696    return img

def mm_read_to_3d(x, slice=None, modality=''):
1698def mm_read_to_3d( x, slice=None, modality='' ):
1699    """
1700    read an image from a filename - and return as 3d or None if that is not possible
1701    """
1702    img = ants.image_read( x, reorient=False )
1703    if img.dimension <= 3:
1704        return img
1705    elif img.dimension == 4:
1706        nslices = img.shape[3]
1707        if slice is None:
1708            sl = np.round( nslices * 0.5 )
1709        else:
1710            sl = slice
1711        if sl > nslices:
1712            sl = nslices-1
1713        return ants.slice_image( img, axis=3, idx=int(sl) )
1714    elif img.dimension > 4:
1715        return img
1716    return None

def image_write_with_thumbnail(x, fn, y=None, thumb=True):
1754def image_write_with_thumbnail( x,  fn, y=None, thumb=True ):
1755    """
1756    will write the image and (optionally) a png thumbnail with (optional) overlay/underlay
1757    """
1758    ants.image_write( x, fn )
1759    if not thumb or x.components > 1:
1760        return
1761    thumb_fn=re.sub(".nii.gz","_3dthumb.png",fn)
1762    if thumb and x.dimension == 3:
1763        if y is None:
1764            try:
1765                ants.plot_ortho( x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1766            except:
1767                pass
1768        else:
1769            try:
1770                ants.plot_ortho( y, x, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1771            except:
1772                pass
1773    if thumb and x.dimension == 4:
1774        thumb_fn=re.sub(".nii.gz","_4dthumb.png",fn)
1775        nslices = x.shape[3]
1776        sl = np.round( nslices * 0.5 )
1777        if sl > nslices:
1778            sl = nslices-1
1779        xview = ants.slice_image( x, axis=3, idx=int(sl) )
1780        if y is None:
1781            try:
1782                ants.plot_ortho( xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1783            except:
1784                pass
1785        else:
1786            if y.dimension == 3:
1787                try:
1788                    ants.plot_ortho(y, xview, crop=True, filename=thumb_fn, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
1789                except:
1790                    pass
1791    return

def nrg_format_path(projectID, subjectID, date, modality, imageID, separator='-'):
1181def nrg_format_path( projectID, subjectID, date, modality, imageID, separator='-' ):
1182    """
1183    create the NRG path on disk given the project, subject id, date, modality and image id
1184
1185    Arguments
1186    ---------
1187
1188    projectID : string for the project e.g. PPMI
1189
1190    subjectID : string uniquely identifying the subject e.g. 0001
1191
1192    date : string for the date usually 20550228 ie YYYYMMDD format
1193
1194    modality : string should be one of T1w, T2Flair, rsfMRI, NM2DMT and DTI ... rsfMRI and DTI may also be DTI_LR, DTI_RL, rsfMRI_LR and rsfMRI_RL where the RL / LR relates to phase encoding direction (even if it is AP/PA)
1195
1196    imageID : string uniquely identifying the specific image
1197
1198    separator : default to -
1199
1200    Returns
1201    -------
1202    the path where one would write the image on disk
1203
1204    """
1205    thedirectory = os.path.join( str(projectID), str(subjectID), str(date), str(modality), str(imageID) )
1206    thefilename = str(projectID) + separator + str(subjectID) + separator + str(date) + separator + str(modality) + separator + str(imageID)
1207    return os.path.join( thedirectory, thefilename )

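# Worked example (hypothetical identifiers, POSIX path separators):
#   nrg_format_path('PPMI', '0001', '20230215', 'T1w', 'I123456')
#   -> 'PPMI/0001/20230215/T1w/I123456/PPMI-0001-20230215-T1w-I123456'
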
def highest_quality_repeat(mxdfin, idvar, visitvar, qualityvar):
1373def highest_quality_repeat(mxdfin, idvar, visitvar, qualityvar):
1374    """
1375    This function returns a subset of the input dataframe that retains only the rows
1376    that correspond to the highest quality observation for each combination of ID and visit.
1377
1378    Parameters:
1379    ----------
1380    mxdfin: pandas.DataFrame
1381        The input dataframe.
1382    idvar: str
1383        The name of the column that contains the ID variable.
1384    visitvar: str
1385        The name of the column that contains the visit variable.
1386    qualityvar: str
1387        The name of the column that contains the quality variable.
1388
1389    Returns:
1390    -------
1391    pandas.DataFrame
1392        A subset of the input dataframe that retains only the rows that correspond
1393        to the highest quality observation for each combination of ID and visit.
1394    """
1395    if visitvar not in mxdfin.columns:
1396        raise ValueError("visitvar not in dataframe")
1397    if idvar not in mxdfin.columns:
1398        raise ValueError("idvar not in dataframe")
1399    if qualityvar not in mxdfin.columns:
1400        raise ValueError("qualityvar not in dataframe")
1401
1402    mxdfin[qualityvar] = mxdfin[qualityvar].astype(float)
1403
1404    vizzes = mxdfin[visitvar].unique()
1405    uids = mxdfin[idvar].unique()
1406    useit = np.zeros(mxdfin.shape[0], dtype=bool)
1407
1408    for u in uids:
1409        losel = mxdfin[idvar] == u
1410        vizzesloc = mxdfin[losel][visitvar].unique()
1411
1412        for v in vizzesloc:
1413            losel = (mxdfin[idvar] == u) & (mxdfin[visitvar] == v)
1414            mysnr = mxdfin.loc[losel, qualityvar]
1415            myw = np.where(losel)[0]
1416
1417            if len(myw) > 1:
1418                if any(~np.isnan(mysnr)):
1419                    useit[myw[np.nanargmax(mysnr)]] = True
1420                else:
1421                    useit[myw] = True
1422            else:
1423                useit[myw] = True
1424
1425    return mxdfin[useit]

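# Toy example of highest_quality_repeat (hypothetical values): subject 's1' has two repeats at
# visit 'V1', and only the one with the higher snr is retained.
def _demo_highest_quality_repeat():
    toy = pd.DataFrame({'sid': ['s1', 's1', 's2'],
                        'visit': ['V1', 'V1', 'V1'],
                        'snr': [10.0, 14.0, 9.0]})
    return highest_quality_repeat(toy, 'sid', 'visit', 'snr')  # keeps the snr==14 and snr==9 rows
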
def match_modalities( qc_dataframe, unique_identifier='filename', outlier_column='ol_loop', mysep='-', verbose=False):
1428def match_modalities( qc_dataframe, unique_identifier='filename', outlier_column='ol_loop', mysep='-', verbose=False ):
1429    """
1430    Find the best multiple modality dataset at each time point
1431
1432    :param qc_dataframe: quality control data frame with one row per image and the QC/outlier columns (e.g. ol_loop, ol_lof_decision, snr, cnr, dimt)
1433    :param unique_identifier : the unique NRG filename for each image
1434    :param outlier_column: outlierness score used to identify the best image (pair) at a given date
1435    :param mysep (str, optional): the separator used in the image file names. Defaults to '-'.
1436    :param verbose: boolean
1437    :return: filtered matched modality data frame
1438    """
1439    import pandas as pd
1440    import numpy as np
1441    qc_dataframe['filename']=qc_dataframe['filename'].astype(str)
1442    qc_dataframe['ol_loop']=qc_dataframe['ol_loop'].astype(float)
1443    qc_dataframe['ol_lof']=qc_dataframe['ol_lof'].astype(float)
1444    qc_dataframe['ol_lof_decision']=qc_dataframe['ol_lof_decision'].astype(float)
1445    mmdf = best_mmm( qc_dataframe, 'T1w', outlier_column=outlier_column )['filt']
1446    fldf = best_mmm( qc_dataframe, 'T2Flair', outlier_column=outlier_column )['filt']
1447    nmdf = best_mmm( qc_dataframe, 'NM2DMT', outlier_column=outlier_column )['filt']
1448    rsdf = best_mmm( qc_dataframe, 'rsfMRI', outlier_column=outlier_column )['filt']
1449    dtdf = best_mmm( qc_dataframe, 'DTI', outlier_column=outlier_column )['filt']
1450    mmdf['flairid'] = None
1451    mmdf['flairfn'] = None
1452    mmdf['flairloop'] = None
1453    mmdf['flairlof'] = None
1454    mmdf['dtid1'] = None
1455    mmdf['dtfn1'] = None
1456    mmdf['dtntimepoints1'] = 0
1457    mmdf['dtloop1'] = math.nan
1458    mmdf['dtlof1'] = math.nan
1459    mmdf['dtid2'] = None
1460    mmdf['dtfn2'] = None
1461    mmdf['dtntimepoints2'] = 0
1462    mmdf['dtloop2'] = math.nan
1463    mmdf['dtlof2'] = math.nan
1464    mmdf['rsfid1'] = None
1465    mmdf['rsffn1'] = None
1466    mmdf['rsfntimepoints1'] = 0
1467    mmdf['rsfloop1'] = math.nan
1468    mmdf['rsflof1'] = math.nan
1469    mmdf['rsfid2'] = None
1470    mmdf['rsffn2'] = None
1471    mmdf['rsfntimepoints2'] = 0
1472    mmdf['rsfloop2'] = math.nan
1473    mmdf['rsflof2'] = math.nan
1474    for k in range(1,11):
1475        myid='nmid'+str(k)
1476        mmdf[myid] = None
1477        myid='nmfn'+str(k)
1478        mmdf[myid] = None
1479        myid='nmloop'+str(k)
1480        mmdf[myid] = math.nan
1481        myid='nmlof'+str(k)
1482        mmdf[myid] = math.nan
1483    if verbose:
1484        print( mmdf.shape )
1485    for k in range(mmdf.shape[0]):
1486        if verbose:
1487            if k % 100 == 0:
1488                progger = str( k ) # np.round( k / mmdf.shape[0] * 100 ) )
1489                print( progger, end ="...", flush=True)
1490        if dtdf is not None:
1491            locsel = (dtdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
1492            if sum(locsel) == 1:
1493                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = dtdf["imageID"][locsel].values[0]
1494                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = dtdf[unique_identifier][locsel].values[0]
1495                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = dtdf[outlier_column][locsel].values[0]
1496                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(dtdf['ol_lof_decision'][locsel].values[0])
1497                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(dtdf['dimt'][locsel].values[0])
1498            elif sum(locsel) > 1:
1499                locdf = dtdf[locsel]
1500                dedupe = locdf[["snr","cnr"]].duplicated()
1501                locdf = locdf[~dedupe]
1502                if locdf.shape[0] > 1:
1503                    locdf = locdf.sort_values(outlier_column).iloc[:2]
1504                mmdf.iloc[k, mmdf.columns.get_loc("dtid1")] = locdf["imageID"].values[0]
1505                mmdf.iloc[k, mmdf.columns.get_loc("dtfn1")] = locdf[unique_identifier].values[0]
1506                mmdf.iloc[k, mmdf.columns.get_loc("dtloop1")] = locdf[outlier_column].values[0]
1507                mmdf.iloc[k, mmdf.columns.get_loc("dtlof1")] = float(locdf['ol_lof_decision'][locsel].values[0])
1508                mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints1")] = float(dtdf['dimt'][locsel].values[0])
1509                if locdf.shape[0] > 1:
1510                    mmdf.iloc[k, mmdf.columns.get_loc("dtid2")] = locdf["imageID"].values[1]
1511                    mmdf.iloc[k, mmdf.columns.get_loc("dtfn2")] = locdf[unique_identifier].values[1]
1512                    mmdf.iloc[k, mmdf.columns.get_loc("dtloop2")] = locdf[outlier_column].values[1]
1513                    mmdf.iloc[k, mmdf.columns.get_loc("dtlof2")] = float(locdf['ol_lof_decision'][locsel].values[1])
1514                    mmdf.iloc[k, mmdf.columns.get_loc("dtntimepoints2")] = float(dtdf['dimt'][locsel].values[1])
1515        if rsdf is not None:
1516            locsel = (rsdf["subjectIDdate"] == mmdf["subjectIDdate"].iloc[k])
1517            if sum(locsel) == 1:
1518                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = rsdf["imageID"][locsel].values[0]
1519                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = rsdf[unique_identifier][locsel].values[0]
1520                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = rsdf[outlier_column][locsel].values[0]
1521                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(rsdf['ol_lof_decision'][locsel].values[0])
1522                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(rsdf['dimt'][locsel].values[0])
1523            elif sum(locsel) > 1:
1524                locdf = rsdf[locsel]
1525                dedupe = locdf[["snr","cnr"]].duplicated()
1526                locdf = locdf[~dedupe]
1527                if locdf.shape[0] > 1:
1528                    locdf = locdf.sort_values(outlier_column).iloc[:2]
1529                mmdf.iloc[k, mmdf.columns.get_loc("rsfid1")] = locdf["imageID"].values[0]
1530                mmdf.iloc[k, mmdf.columns.get_loc("rsffn1")] = locdf[unique_identifier].values[0]
1531                mmdf.iloc[k, mmdf.columns.get_loc("rsfloop1")] = locdf[outlier_column].values[0]
1532                mmdf.iloc[k, mmdf.columns.get_loc("rsflof1")] = float(locdf['ol_lof_decision'].values[0])
1533                mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints1")] = float(locdf['dimt'][locsel].values[0])
1534                if locdf.shape[0] > 1:
1535                    mmdf.iloc[k, mmdf.columns.get_loc("rsfid2")] = locdf["imageID"].values[1]
1536                    mmdf.iloc[k, mmdf.columns.get_loc("rsffn2")] = locdf[unique_identifier].values[1]
1537                    mmdf.iloc[k, mmdf.columns.get_loc("rsfloop2")] = locdf[outlier_column].values[1]
1538                    mmdf.iloc[k, mmdf.columns.get_loc("rsflof2")] = float(locdf['ol_lof_decision'].values[1])
1539                    mmdf.iloc[k, mmdf.columns.get_loc("rsfntimepoints2")] = float(locdf['dimt'].values[1])
1540
1541        if fldf is not None:
1542            locsel = fldf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
1543            if locsel.sum() == 1:
1544                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = fldf['imageID'][locsel].values[0]
1545                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = fldf[unique_identifier][locsel].values[0]
1546                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = fldf[outlier_column][locsel].values[0]
1547                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(fldf['ol_lof_decision'][locsel].values[0])
1548            elif sum(locsel) > 1:
1549                locdf = fldf[locsel]
1550                dedupe = locdf[["snr","cnr"]].duplicated()
1551                locdf = locdf[~dedupe]
1552                if locdf.shape[0] > 1:
1553                    locdf = locdf.sort_values(outlier_column).iloc[:2]
1554                mmdf.iloc[k, mmdf.columns.get_loc("flairid")] = locdf["imageID"].values[0]
1555                mmdf.iloc[k, mmdf.columns.get_loc("flairfn")] = locdf[unique_identifier].values[0]
1556                mmdf.iloc[k, mmdf.columns.get_loc("flairloop")] = locdf[outlier_column].values[0]
1557                mmdf.iloc[k, mmdf.columns.get_loc("flairlof")] = float(locdf['ol_lof_decision'].values[0])
1558
1559        if nmdf is not None:
1560            locsel = nmdf['subjectIDdate'] == mmdf['subjectIDdate'].iloc[k]
1561            if locsel.sum() > 0:
1562                locdf = nmdf[locsel]
1563                for i in range(np.min( [10,locdf.shape[0]])):
1564                    nmid = "nmid"+str(i+1)
1565                    mmdf.loc[k,nmid] = locdf['imageID'].iloc[i]
1566                    nmfn = "nmfn"+str(i+1)
1567                    mmdf.loc[k,nmfn] = locdf[unique_identifier].iloc[i]
1568                    nmloop = "nmloop"+str(i+1)
1569                    mmdf.loc[k,nmloop] = locdf[outlier_column].iloc[i]
1570                    nmloop = "nmlof"+str(i+1)
1571                    mmdf.loc[k,nmloop] = float(locdf['ol_lof_decision'].iloc[i])
1572
1573    mmdf['rsf_total_timepoints']=mmdf['rsfntimepoints1']+mmdf['rsfntimepoints2']
1574    mmdf['dt_total_timepoints']=mmdf['dtntimepoints1']+mmdf['dtntimepoints2']
1575    return mmdf

Find the best multiple modality dataset at each time point

Parameters
  • qc_dataframe: quality control data frame with per-image QC metrics and outlierness columns
  • unique_identifier: the unique NRG filename for each image
  • outlier_column: outlierness score used to identify the best image (pair) at a given date
  • mysep (str, optional): the separator used in the image file names. Defaults to '-'.
  • verbose: boolean
Returns

filtered matched modality data frame
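
A minimal usage sketch (presumably this documents match_modalities, given the parameter list above), assuming the QC table was built by the blind-QC tools in this package and scored by outlierness_by_modality; the CSV path is a placeholder:

>>> import pandas as pd
>>> import antspymm
>>> qcdf = pd.read_csv('blind_qc.csv')              # hypothetical blind-QC table
>>> qcdf = antspymm.outlierness_by_modality(qcdf)   # adds ol_loop / ol_lof columns
>>> matched = antspymm.match_modalities(qcdf, unique_identifier='filename', outlier_column='ol_loop')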

def mc_resample_image_to_target(x, y, interp_type='linear'):
1810def mc_resample_image_to_target( x , y, interp_type='linear' ):
1811    """
1812    multichannel version of resample_image_to_target
1813    """
1814    xx=ants.split_channels( x )
1815    yy=ants.split_channels( y )[0]
1816    newl=[]
1817    for k in range(len(xx)):
1818        newl.append(  ants.resample_image_to_target( xx[k], yy, interp_type=interp_type ) )
1819    return ants.merge_channels( newl )

multichannel version of resample_image_to_target
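
A brief sketch assuming two multichannel images assembled with ants.merge_channels; the single-channel inputs below are just ANTs sample data:

>>> import ants
>>> import antspymm
>>> a = ants.image_read(ants.get_ants_data('r16'))
>>> b = ants.image_read(ants.get_ants_data('r64'))
>>> x = ants.merge_channels([a, a])   # multichannel moving image
>>> y = ants.merge_channels([b, b])   # multichannel target image
>>> xr = antspymm.mc_resample_image_to_target(x, y, interp_type='linear')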

def nrg_filelist_to_dataframe(filename_list, myseparator='-'):
1821def nrg_filelist_to_dataframe( filename_list, myseparator="-" ):
1822    """
1823    convert a list of files in nrg format to a dataframe
1824
1825    Arguments
1826    ---------
1827    filename_list : globbed list of files
1828
1829    myseparator : string separator between nrg parts
1830
1831    Returns
1832    -------
1833
1834    df : pandas data frame
1835
1836    """
1837    def getmtime(x):
1838        x= dt.datetime.fromtimestamp(os.path.getmtime(x)).strftime("%Y-%m-%d %H:%M:%S")
1839        return x
1840    df=pd.DataFrame(columns=['filename','file_last_mod_t','else','sid','visitdate','modality','uid'])
1841    df.set_index('filename')
1842    df['filename'] = pd.Series([file for file in filename_list ])
1843    # record each file's last-modified time in df['file_last_mod_t'] via getmtime
1844    df['file_last_mod_t'] = df['filename'].apply(lambda x: getmtime(x))
1845    for k in range(df.shape[0]):
1846        locfn=df['filename'].iloc[k]
1847        splitter=os.path.basename(locfn).split( myseparator )
1848        df['sid'].iloc[k]=splitter[1]
1849        df['visitdate'].iloc[k]=splitter[2]
1850        df['modality'].iloc[k]=splitter[3]
1851        temp = os.path.splitext(splitter[4])[0]
1852        df['uid'].iloc[k]=os.path.splitext(temp)[0]
1853    return df

convert a list of files in nrg format to a dataframe

Arguments

filename_list : globbed list of files

myseparator : string separator between nrg parts

Returns

df : pandas data frame
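
A short sketch assuming NRG-formatted filenames on disk; the directory layout and glob pattern are hypothetical:

>>> import glob
>>> import antspymm
>>> fns = sorted(glob.glob('/data/NRG/MYSTUDY/*/*/*/*/*.nii.gz'))   # hypothetical layout
>>> df = antspymm.nrg_filelist_to_dataframe(fns, myseparator='-')
>>> df[['sid', 'visitdate', 'modality', 'uid']].head()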

def merge_timeseries_data(img_LR, img_RL, allow_resample=True):
1856def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
1857    """
1858    merge time series data into space of reference_image
1859
1860    img_LR : image
1861
1862    img_RL : image
1863
1864    allow_resample : boolean
1865
1866    """
1867    # concatenate the images into the reference space
1868    mimg=[]
1869    for kk in range( img_LR.shape[3] ):
1870        temp = ants.slice_image( img_LR, axis=3, idx=kk )
1871        mimg.append( temp )
1872    for kk in range( img_RL.shape[3] ):
1873        temp = ants.slice_image( img_RL, axis=3, idx=kk )
1874        if kk == 0:
1875            insamespace = ants.image_physical_space_consistency( temp, mimg[0] )
1876        if allow_resample and not insamespace :
1877            temp = ants.resample_image_to_target( temp, mimg[0] )
1878        mimg.append( temp )
1879    return ants.list_to_ndimage( img_LR, mimg )

merge time series data into space of reference_image

img_LR : image

img_RL : image

allow_resample : boolean
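
A minimal sketch assuming two 4D BOLD acquisitions (e.g. opposite phase-encoding runs) already loaded with ants.image_read; the filenames are placeholders:

>>> import ants
>>> import antspymm
>>> img_LR = ants.image_read('bold_LR.nii.gz')   # hypothetical filenames
>>> img_RL = ants.image_read('bold_RL.nii.gz')
>>> merged = antspymm.merge_timeseries_data(img_LR, img_RL, allow_resample=True)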

def timeseries_reg( image, avg_b0, type_of_transform='antsRegistrationSyNRepro[r]', total_sigma=1.0, fdOffset=2.0, trim=0, output_directory=None, return_numpy_motion_parameters=False, verbose=False, **kwargs):
1965def timeseries_reg(
1966    image,
1967    avg_b0,
1968    type_of_transform='antsRegistrationSyNRepro[r]',
1969    total_sigma=1.0,
1970    fdOffset=2.0,
1971    trim = 0,
1972    output_directory=None,
1973    return_numpy_motion_parameters=False,
1974    verbose=False, **kwargs
1975):
1976    """
1977    Correct time-series data for motion.
1978
1979    Arguments
1980    ---------
1981    image: antsImage, usually ND where D=4.
1982
1983    avg_b0: Fixed image b0 image
1984
1985    type_of_transform : string
1986            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
1987            See ants registration for details.
1988
1989    fdOffset: offset value to use in framewise displacement calculation
1990
1991    trim : integer - trim this many images off the front of the time series
1992
1993    output_directory : string
1994            output will be placed in this directory plus a numeric extension.
1995
1996    return_numpy_motion_parameters : boolean
1997
1998    verbose: boolean
1999
2000    kwargs: keyword args
2001            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
2002
2003    Returns
2004    -------
2005    dict containing the following key/value pairs:
2006        `motion_corrected`: Moving image warped to space of fixed image.
2007        `motion_parameters`: transforms for each image in the time series.
2008        `FD`: Framewise displacement generalized for arbitrary transformations.
2009
2010    Notes
2011    -----
2012    Control extra arguments via kwargs. see ants.registration for details.
2013
2014    Example
2015    -------
2016    >>> import ants
2017    """
2018    idim = image.dimension
2019    ishape = image.shape
2020    nTimePoints = ishape[idim - 1]
2021    FD = np.zeros(nTimePoints)
2022    if type_of_transform is None:
2023        return {
2024            "motion_corrected": image,
2025            "motion_parameters": None,
2026            "FD": FD
2027        }
2028
2029    remove_it=False
2030    if output_directory is None:
2031        remove_it=True
2032        output_directory = tempfile.mkdtemp()
2033    output_directory_w = output_directory + "/ts_reg/"
2034    os.makedirs(output_directory_w,exist_ok=True)
2035    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
2036    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
2037    if verbose:
2038        print('bold motcorr with ' + type_of_transform)
2039        print(output_directory_w)
2040        print(ofnG)
2041        print(ofnL)
2042        print("remove_it " + str( remove_it ) )
2043
2044    # get a local deformation from slice to local avg space
2045    motion_parameters = list()
2046    motion_corrected = list()
2047    mask = ants.get_mask( avg_b0 )
2048    centerOfMass = mask.get_center_of_mass()
2049    npts = pow(2, idim - 1)
2050    pointOffsets = np.zeros((npts, idim - 1))
2051    myrad = np.ones(idim - 1).astype(int).tolist()
2052    mask1vals = np.zeros(int(mask.sum()))
2053    mask1vals[round(len(mask1vals) / 2)] = 1
2054    mask1 = ants.make_image(mask, mask1vals)
2055    myoffsets = ants.get_neighborhood_in_mask(
2056        mask1, mask1, radius=myrad, spatial_info=True
2057    )["offsets"]
2058    mycols = list("xy")
2059    if idim - 1 == 3:
2060        mycols = list("xyz")
2061    useinds = list()
2062    for k in range(myoffsets.shape[0]):
2063        if abs(myoffsets[k, :]).sum() == (idim - 2):
2064            useinds.append(k)
2065        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
2066    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
2067    if verbose:
2068        print("Progress:")
2069    counter = round( nTimePoints / 10 ) + 1
2070    for k in range( nTimePoints):
2071        if verbose and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
2072            myperc = round( k / nTimePoints * 100)
2073            print(myperc, end="%.", flush=True)
2074        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2075        temp = ants.iMath(temp, "Normalize")
2076        txprefix = ofnL+str(k % 2).zfill(4)+"_"
2077        if temp.numpy().var() > 0:
2078            myrig = ants.registration(
2079                    avg_b0, temp,
2080                    type_of_transform='antsRegistrationSyNRepro[r]',
2081                    outprefix=txprefix
2082                )
2083            if type_of_transform == 'SyN':
2084                myreg = ants.registration(
2085                    avg_b0, temp,
2086                    type_of_transform='SyNOnly',
2087                    total_sigma=total_sigma,
2088                    initial_transform=myrig['fwdtransforms'][0],
2089                    outprefix=txprefix,
2090                    **kwargs
2091                )
2092            else:
2093                myreg = myrig
2094            fdptsTxI = ants.apply_transforms_to_points(
2095                idim - 1, fdpts, myrig["fwdtransforms"]
2096            )
2097            if k > 0 and motion_parameters[k - 1] != "NA":
2098                fdptsTxIminus1 = ants.apply_transforms_to_points(
2099                    idim - 1, fdpts, motion_parameters[k - 1]
2100                )
2101            else:
2102                fdptsTxIminus1 = fdptsTxI
2103            # take the absolute value, then the mean across columns, then the sum
2104            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
2105            motion_parameters.append(myreg["fwdtransforms"])
2106        else:
2107            motion_parameters.append("NA")
2108
2109        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2110        if temp.numpy().var() > 0:
2111            img1w = ants.apply_transforms( avg_b0,
2112                temp,
2113                motion_parameters[k] )
2114            motion_corrected.append(img1w)
2115        else:
2116            motion_corrected.append(avg_b0)
2117
2118    motion_parameters = motion_parameters[trim:len(motion_parameters)]
2119    if return_numpy_motion_parameters:
2120        motion_parameters = read_ants_transforms_to_numpy( motion_parameters )
2121
2122    if remove_it:
2123        import shutil
2124        shutil.rmtree(output_directory, ignore_errors=True )
2125
2126    if verbose:
2127        print("Done")
2128    d4siz = list(avg_b0.shape)
2129    d4siz.append( 2 )
2130    spc = list(ants.get_spacing( avg_b0 ))
2131    spc.append( ants.get_spacing(image)[3] )
2132    mydir = ants.get_direction( avg_b0 )
2133    mydir4d = ants.get_direction( image )
2134    mydir4d[0:3,0:3]=mydir
2135    myorg = list(ants.get_origin( avg_b0 ))
2136    myorg.append( 0.0 )
2137    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
2138    return {
2139        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected[trim:len(motion_corrected)]),
2140        "motion_parameters": motion_parameters,
2141        "FD": FD[trim:len(FD)]
2142    }

Correct time-series data for motion.

Arguments

image: antsImage, usually ND where D=4.

avg_b0: Fixed image b0 image

type_of_transform : string A linear or non-linear registration type. Mutual information metric and rigid transformation by default. See ants registration for details.

fdOffset: offset value to use in framewise displacement calculation

trim : integer - trim this many images off the front of the time series

output_directory : string output will be placed in this directory plus a numeric extension.

return_numpy_motion_parameters : boolean

verbose: boolean

kwargs: keyword args extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.

Returns

dict containing the following key/value pairs: motion_corrected: Moving image warped to space of fixed image. motion_parameters: transforms for each image in the time series. FD: Framewise displacement generalized for arbitrary transformations.

Notes

Control extra arguments via kwargs. see ants.registration for details.

Example

>>> import ants
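
A hedged usage sketch, assuming a 4D BOLD image and its temporal mean as the fixed target; the filename is a placeholder:

>>> import ants
>>> import antspymm
>>> bold = ants.image_read('bold.nii.gz')        # hypothetical 4D time series
>>> avg = ants.get_average_of_timeseries(bold)
>>> reg = antspymm.timeseries_reg(bold, avg, type_of_transform='antsRegistrationSyNRepro[r]')
>>> reg['FD'].mean()                             # summarize framewise displacement
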
def merge_dwi_data(img_LRdwp, bval_LR, bvec_LR, img_RLdwp, bval_RL, bvec_RL):
2145def merge_dwi_data( img_LRdwp, bval_LR, bvec_LR, img_RLdwp, bval_RL, bvec_RL ):
2146    """
2147    merge motion and distortion corrected data if possible
2148
2149    img_LRdwp : image
2150
2151    bval_LR : array
2152
2153    bvec_LR : array
2154
2155    img_RLdwp : image
2156
2157    bval_RL : array
2158
2159    bvec_RL : array
2160
2161    """
2162    import warnings
2163    insamespace = ants.image_physical_space_consistency( img_LRdwp, img_RLdwp )
2164    if not insamespace :
2165        warnings.warn('not insamespace ... corrected image pair should occupy the same physical space; returning only the 1st set and wont join these data.')
2166        return img_LRdwp, bval_LR, bvec_LR
2167    
2168    bval_LR = np.concatenate([bval_LR,bval_RL])
2169    bvec_LR = np.concatenate([bvec_LR,bvec_RL])
2170    # concatenate the images
2171    mimg=[]
2172    for kk in range( img_LRdwp.shape[3] ):
2173            mimg.append( ants.slice_image( img_LRdwp, axis=3, idx=kk ) )
2174    for kk in range( img_RLdwp.shape[3] ):
2175            mimg.append( ants.slice_image( img_RLdwp, axis=3, idx=kk ) )
2176    img_LRdwp = ants.list_to_ndimage( img_LRdwp, mimg )
2177    return img_LRdwp, bval_LR, bvec_LR

merge motion and distortion corrected data if possible

img_LRdwp : image

bval_LR : array

bvec_LR : array

img_RLdwp : image

bval_RL : array

bvec_RL : array
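
A short sketch, assuming both acquisitions were already motion- and distortion-corrected into the same physical space (e.g. by dti_reg); all filenames are placeholders:

>>> import ants
>>> import antspymm
>>> from dipy.io.gradients import read_bvals_bvecs
>>> img_LR = ants.image_read('dwi_LR_corrected.nii.gz')   # hypothetical corrected images
>>> img_RL = ants.image_read('dwi_RL_corrected.nii.gz')
>>> bval_LR, bvec_LR = read_bvals_bvecs('dwi_LR.bval', 'dwi_LR.bvec')
>>> bval_RL, bvec_RL = read_bvals_bvecs('dwi_RL.bval', 'dwi_RL.bvec')
>>> img, bval, bvec = antspymm.merge_dwi_data(img_LR, bval_LR, bvec_LR, img_RL, bval_RL, bvec_RL)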

def outlierness_by_modality( qcdf, uid='filename', outlier_columns=['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'], verbose=False):
1124def outlierness_by_modality( qcdf, uid='filename', outlier_columns = ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi','reflection_err', 'EVR', 'msk_vol'], verbose=False ):
1125    """
1126    Calculates outlierness scores for each modality in a dataframe based on given outlier columns using antspyt1w.loop_outlierness() and LOF.  LOF appears to be more conservative.  This function will impute missing columns with the mean.
1127
1128    Args:
1129    - qcdf: (Pandas DataFrame) Dataframe containing columns with outlier information for each modality.
1130    - uid: (str) Unique identifier for a subject. Default is 'filename'.
1131    - outlier_columns: (list) List of columns containing outlier information. Default is ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'].
1132    - verbose: (bool) If True, prints information for each modality. Default is False.
1133
1134    Returns:
1135    - qcdf: (Pandas DataFrame) Updated dataframe with outlierness scores for each modality in the 'ol_loop' and 'ol_lof' columns.  Higher values near 1 are more outlying.
1136
1137    Raises:
1138    - ValueError: If uid is not present in the dataframe.
1139
1140    Example:
1141    >>> df = pd.read_csv('data.csv')
1142    >>> outlierness_by_modality(df)
1143    """
1144    from PyNomaly import loop
1145    from sklearn.neighbors import LocalOutlierFactor
1146    qcdfout = qcdf.copy()
1147    pd.set_option('future.no_silent_downcasting', True)
1148    qcdfout.replace([np.inf, -np.inf], np.nan, inplace=True)
1149    if uid not in qcdfout.keys():
1150        raise ValueError( str(uid) + " not in dataframe")
1151    if 'ol_loop' not in qcdfout.keys():
1152        qcdfout['ol_loop']=math.nan
1153    if 'ol_lof' not in qcdfout.keys():
1154        qcdfout['ol_lof']=math.nan
1155    didit=False
1156    for mod in get_valid_modalities( qc=True ):
1157        didit=True
1158        lof = LocalOutlierFactor()
1159        locsel = qcdfout["modality"] == mod
1160        rr = qcdfout[locsel][outlier_columns]
1161        column_means = rr.mean()
1162        rr.fillna(column_means, inplace=True)
1163        if rr.shape[0] > 1:
1164            if verbose:
1165                print("calc: " + mod + " outlierness " )
1166            myneigh = np.min( [24, int(np.round(rr.shape[0]*0.5)) ] )
1167            temp = antspyt1w.loop_outlierness(rr.astype(float), standardize=True, extent=3, n_neighbors=myneigh, cluster_labels=None)
1168            qcdfout.loc[locsel,'ol_loop']=temp.astype('float64')
1169            yhat = lof.fit_predict(rr)
1170            temp = lof.negative_outlier_factor_*(-1.0)
1171            temp = temp - temp.min()
1172            yhat[ yhat == 1] = 0
1173            yhat[ yhat == -1] = 1 # these are outliers
1174            qcdfout.loc[locsel,'ol_lof_decision']=yhat
1175            qcdfout.loc[locsel,'ol_lof']=temp/temp.max()
1176    if verbose:
1177        print( didit )
1178    return qcdfout

Calculates outlierness scores for each modality in a dataframe based on given outlier columns using antspyt1w.loop_outlierness() and LOF. LOF appears to be more conservative. This function will impute missing columns with the mean.

Args:

  • qcdf: (Pandas DataFrame) Dataframe containing columns with outlier information for each modality.
  • uid: (str) Unique identifier for a subject. Default is 'filename'.
  • outlier_columns: (list) List of columns containing outlier information. Default is ['noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol'].
  • verbose: (bool) If True, prints information for each modality. Default is False.

Returns:

  • qcdf: (Pandas DataFrame) Updated dataframe with outlierness scores for each modality in the 'ol_loop' and 'ol_lof' columns. Higher values near 1 are more outlying.

Raises:

  • ValueError: If uid is not present in the dataframe.

Example:

>>> df = pd.read_csv('data.csv')
>>> outlierness_by_modality(df)
def bvec_reorientation(motion_parameters, bvecs, rebase=None):
2179def bvec_reorientation( motion_parameters, bvecs, rebase=None ):
2180    if motion_parameters is None:
2181        return bvecs
2182    n = len(motion_parameters)
2183    if n < 1:
2184        return bvecs
2185    from scipy.linalg import inv, polar
2186    from dipy.core.gradients import reorient_bvecs
2187    dipymoco = np.zeros( [n,3,3] )
2188    for myidx in range(n):
2189        if myidx < bvecs.shape[0]:
2190            dipymoco[myidx,:,:] = np.eye( 3 )
2191            if motion_parameters[myidx] != 'NA':
2192                temp = motion_parameters[myidx]
2193                if len(temp) == 4 :
2194                    temp1=temp[3] # FIXME should be composite of index 1 and 3
2195                    temp2=temp[1] # FIXME should be composite of index 1 and 3
2196                    txparam1 = ants.read_transform(temp1)
2197                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
2198                    txparam2 = ants.read_transform(temp2)
2199                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
2200                    Rinv = inv( np.dot( txparam2, txparam1 ) )
2201                elif len(temp) == 2 :
2202                    temp=temp[1] # FIXME should be composite of index 1 and 3
2203                    txparam = ants.read_transform(temp)
2204                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
2205                    Rinv = inv( txparam )
2206                elif len(temp) == 3 :
2207                    temp1=temp[2] # FIXME should be composite of index 1 and 3
2208                    temp2=temp[1] # FIXME should be composite of index 1 and 3
2209                    txparam1 = ants.read_transform(temp1)
2210                    txparam1 = ants.get_ants_transform_parameters(txparam1)[0:9].reshape( [3,3])
2211                    txparam2 = ants.read_transform(temp2)
2212                    txparam2 = ants.get_ants_transform_parameters(txparam2)[0:9].reshape( [3,3])
2213                    Rinv = inv( np.dot( txparam2, txparam1 ) )
2214                else:
2215                    temp=temp[0]
2216                    txparam = ants.read_transform(temp)
2217                    txparam = ants.get_ants_transform_parameters(txparam)[0:9].reshape( [3,3])
2218                    Rinv = inv( txparam )
2219                bvecs[myidx,:] = np.dot( Rinv, bvecs[myidx,:] )
2220                if rebase is not None:
2221                    # FIXME - should combine these operations
2222                    bvecs[myidx,:] = np.dot( rebase, bvecs[myidx,:] )
2223    return bvecs
def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False):
2259def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False ):
2260    """
2261    extract DTI data from a dipy tensormodel
2262
2263    reference_image : antsImage defining physical space (3D)
2264
2265    tensormodel : from dipy e.g. the variable myoutx['dtrecon_LR_dewarp']['tensormodel'] if myoutx is produced by joint_dti_recon
2266
2267    upper_triangular: boolean otherwise use lower triangular coding
2268
2269    return_image : boolean return the ANTsImage form of DTI otherwise return an array
2270
2271    Returns
2272    -------
2273    either an ANTsImage (dim=X.Y.Z with 6 component voxels, upper triangular form)
2274        or a 5D NumPy array (dim=X.Y.Z.3.3)
2275
2276    Notes
2277    -----
2278    DiPy returns lower triangular form but ANTs expects upper triangular.
2279        Here, we default to the ANTs standard but could generalize in the future 
2280        because not much here depends on ANTs standards of tensor data.
2281        ANTs xx,xy,xz,yy,yz,zz
2282        DiPy Dxx, Dxy, Dyy, Dxz, Dyz, Dzz
2283
2284    """
2285    # make the DTI - see 
2286    # https://dipy.org/documentation/1.7.0/examples_built/07_reconstruction/reconst_dti/#sphx-glr-examples-built-07-reconstruction-reconst-dti-py
2287    # By default, in DIPY, values are ordered as (Dxx, Dxy, Dyy, Dxz, Dyz, Dzz)
2288    # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
2289    reoind = np.array([0,1,3,2,4,5]) # arrays are faster than lists
2290    import dipy.reconst.dti as dti
2291    dtiut = dti.lower_triangular(tensormodel.quadratic_form)
2292    it = np.ndindex( reference_image.shape )
2293    yyind=2
2294    xzind=3
2295    if upper_triangular:
2296        yyind=3
2297        xzind=2
2298        for i in it: # convert to upper triangular
2299            dtiut[i] = dtiut[i][ reoind ] # do we care if this is doing extra work?
2300    if return_image:
2301        dtiAnts = ants.from_numpy(dtiut,has_components=True)
2302        ants.copy_image_info( reference_image, dtiAnts )
2303        return dtiAnts
2304    # copy these data into a tensor 
2305    dtinp = np.zeros(reference_image.shape + (3,3), dtype=float)  
2306    dtix = np.zeros((3,3), dtype=float)  
2307    it = np.ndindex( reference_image.shape )
2308    for i in it:
2309        dtivec = dtiut[i] # in ANTs - we have: [xx,xy,xz,yy,yz,zz]
2310        dtix[0,0]=dtivec[0]
2311        dtix[1,1]=dtivec[yyind] # 2 for LT
2312        dtix[2,2]=dtivec[5] 
2313        dtix[0,1]=dtix[1,0]=dtivec[1]
2314        dtix[0,2]=dtix[2,0]=dtivec[xzind] # 3 for LT
2315        dtix[1,2]=dtix[2,1]=dtivec[4]
2316        dtinp[i]=dtix
2317    return dtinp

extract DTI data from a dipy tensormodel

reference_image : antsImage defining physical space (3D)

tensormodel : from dipy e.g. the variable myoutx['dtrecon_LR_dewarp']['tensormodel'] if myoutx is produced by joint_dti_recon

upper_triangular: boolean otherwise use lower triangular coding

return_image : boolean return the ANTsImage form of DTI otherwise return an array

Returns

either an ANTsImage (dim=X.Y.Z with 6 component voxels, upper triangular form) or a 5D NumPy array (dim=X.Y.Z.3.3)

Notes

DiPy returns lower triangular form but ANTs expects upper triangular. Here, we default to the ANTs standard but could generalize in the future because not much here depends on ANTs standards of tensor data. ANTs ordering: xx, xy, xz, yy, yz, zz; DiPy ordering: Dxx, Dxy, Dyy, Dxz, Dyz, Dzz.
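
A hedged sketch assuming a tensor fit taken from a joint_dti_recon result; the recon variable and reference filename are hypothetical:

>>> import ants
>>> import antspymm
>>> ref = ants.image_read('avg_dwi.nii.gz')                   # hypothetical 3D reference
>>> fit = recon['dtrecon_LR_dewarp']['tensormodel']           # hypothetical joint_dti_recon output
>>> dti_img = antspymm.get_dti(ref, fit, upper_triangular=True, return_image=True)
>>> dti_arr = antspymm.get_dti(ref, fit, return_image=False)  # X.Y.Z.3.3 numpy array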

def dti_reg( image, avg_b0, avg_dwi, bvals=None, bvecs=None, b0_idx=None, type_of_transform='antsRegistrationSyNRepro[r]', total_sigma=3.0, fdOffset=2.0, mask_csf=False, brain_mask_eroded=None, output_directory=None, verbose=False, **kwargs):
2600def dti_reg(
2601    image,
2602    avg_b0,
2603    avg_dwi,
2604    bvals=None,
2605    bvecs=None,
2606    b0_idx=None,
2607    type_of_transform="antsRegistrationSyNRepro[r]",
2608    total_sigma=3.0,
2609    fdOffset=2.0,
2610    mask_csf=False,
2611    brain_mask_eroded=None,
2612    output_directory=None,
2613    verbose=False, **kwargs
2614):
2615    """
2616    Correct time-series data for motion - with optional deformation.
2617
2618    Arguments
2619    ---------
2620        image: antsImage, usually ND where D=4.
2621
2622        avg_b0: Fixed image b0 image
2623
2624        avg_dwi: Fixed dwi same space as b0 image
2625
2626        bvals: bvalues (file or array)
2627
2628        bvecs: bvecs (file or array)
2629
2630        b0_idx: indices of b0
2631
2632        type_of_transform : string
2633            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
2634            See ants registration for details.
2635
2636        fdOffset: offset value to use in framewise displacement calculation
2637
2638        mask_csf: boolean
2639
2640        brain_mask_eroded: optional mask that will trigger mixed interpolation
2641
2642        output_directory : string
2643            output will be placed in this directory plus a numeric extension.
2644
2645        verbose: boolean
2646
2647        kwargs: keyword args
2648            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
2649
2650    Returns
2651    -------
2652    dict containing the following key/value pairs:
2653        `motion_corrected`: Moving image warped to space of fixed image.
2654        `motion_parameters`: transforms for each image in the time series.
2655        `FD`: Framewise displacement generalized for arbitrary transformations.
2656
2657    Notes
2658    -----
2659    Control extra arguments via kwargs. see ants.registration for details.
2660
2661    Example
2662    -------
2663    >>> import ants
2664    """
2665
2666    idim = image.dimension
2667    ishape = image.shape
2668    nTimePoints = ishape[idim - 1]
2669    FD = np.zeros(nTimePoints)
2670    if bvals is not None and bvecs is not None:
2671        if isinstance(bvecs, str):
2672            bvals, bvecs = read_bvals_bvecs( bvals , bvecs  )
2673        else: # assume we already read them
2674            bvals = bvals.copy()
2675            bvecs = bvecs.copy()
2676    if type_of_transform is None:
2677        return {
2678            "motion_corrected": image,
2679            "motion_parameters": None,
2680            "FD": FD,
2681            'bvals':bvals,
2682            'bvecs':bvecs
2683        }
2684
2685    from scipy.linalg import inv, polar
2686    from dipy.core.gradients import reorient_bvecs
2687
2688    remove_it=False
2689    if output_directory is None:
2690        remove_it=True
2691        output_directory = tempfile.mkdtemp()
2692    output_directory_w = output_directory + "/dti_reg/"
2693    os.makedirs(output_directory_w,exist_ok=True)
2694    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
2695    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
2696    if verbose:
2697        print(output_directory_w)
2698        print(ofnG)
2699        print(ofnL)
2700        print("remove_it " + str( remove_it ) )
2701
2702    if b0_idx is None:
2703        # b0_idx = segment_timeseries_by_meanvalue( image )['highermeans']
2704        b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals']
2705
2706    # first get a local deformation from slice to local avg space
2707    # then get a global deformation from avg to ref space
2708    ab0, adw = get_average_dwi_b0( image )
2709    # mask is used to roughly locate middle of brain
2710    mask = ants.threshold_image( ants.iMath(adw,'Normalize'), 0.1, 1.0 )
2711    if brain_mask_eroded is None:
2712        brain_mask_eroded = mask * 0 + 1
2713    motion_parameters = list()
2714    motion_corrected = list()
2715    centerOfMass = mask.get_center_of_mass()
2716    npts = pow(2, idim - 1)
2717    pointOffsets = np.zeros((npts, idim - 1))
2718    myrad = np.ones(idim - 1).astype(int).tolist()
2719    mask1vals = np.zeros(int(mask.sum()))
2720    mask1vals[round(len(mask1vals) / 2)] = 1
2721    mask1 = ants.make_image(mask, mask1vals)
2722    myoffsets = ants.get_neighborhood_in_mask(
2723        mask1, mask1, radius=myrad, spatial_info=True
2724    )["offsets"]
2725    mycols = list("xy")
2726    if idim - 1 == 3:
2727        mycols = list("xyz")
2728    useinds = list()
2729    for k in range(myoffsets.shape[0]):
2730        if abs(myoffsets[k, :]).sum() == (idim - 2):
2731            useinds.append(k)
2732        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
2733    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
2734
2735
2736    if verbose:
2737        print("begin global distortion correction")
2738    # initrig = tra_initializer(avg_b0, ab0, max_rotation=60, transform=['rigid'], verbose=verbose)
2739    if mask_csf:
2740        bcsf = ants.threshold_image( avg_b0,"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
2741    else:
2742        bcsf = ab0 * 0 + 1
2743
2744    initrig = ants.registration( avg_b0, ab0,'antsRegistrationSyNRepro[r]',outprefix=ofnG)
2745    deftx = ants.registration( avg_dwi, adw, 'SyNOnly',
2746        syn_metric='CC', syn_sampling=2,
2747        reg_iterations=[50,50,20],
2748        multivariate_extras=[ [ "CC", avg_b0, ab0, 1, 2 ]],
2749        initial_transform=initrig['fwdtransforms'][0],
2750        outprefix=ofnG
2751        )['fwdtransforms']
2752    if verbose:
2753        print("end global distortion correction")
2754
2755    if verbose:
2756        print("Progress:")
2757    counter = round( nTimePoints / 10 ) + 1
2758    for k in range(nTimePoints):
2759        if verbose and nTimePoints > 0 and ( ( ( k % counter ) == 0 ) or ( k == (nTimePoints-1) ) ):
2760            myperc = round( k / nTimePoints * 100)
2761            print(myperc, end="%.", flush=True)
2762        if k in b0_idx:
2763            fixed=ants.image_clone( ab0 )
2764        else:
2765            fixed=ants.image_clone( adw )
2766        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2767        temp = ants.iMath(temp, "Normalize")
2768        txprefix = ofnL+str(k).zfill(4)+"rig_"
2769        txprefix2 = ofnL+str(k % 2).zfill(4)+"def_"
2770        if temp.numpy().var() > 0:
2771            myrig = ants.registration(
2772                    fixed, temp,
2773                    type_of_transform='antsRegistrationSyNRepro[r]',
2774                    outprefix=txprefix,
2775                    **kwargs
2776                )
2777            if type_of_transform == 'SyN':
2778                myreg = ants.registration(
2779                    fixed, temp,
2780                    type_of_transform='SyNOnly',
2781                    total_sigma=total_sigma, grad_step=0.1,
2782                    initial_transform=myrig['fwdtransforms'][0],
2783                    outprefix=txprefix2,
2784                    **kwargs
2785                )
2786            else:
2787                myreg = myrig
2788            fdptsTxI = ants.apply_transforms_to_points(
2789                idim - 1, fdpts, myrig["fwdtransforms"]
2790            )
2791            if k > 0 and motion_parameters[k - 1] != "NA":
2792                fdptsTxIminus1 = ants.apply_transforms_to_points(
2793                    idim - 1, fdpts, motion_parameters[k - 1]
2794                )
2795            else:
2796                fdptsTxIminus1 = fdptsTxI
2797            # take the absolute value, then the mean across columns, then the sum
2798            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
2799            motion_parameters.append(myreg["fwdtransforms"])
2800        else:
2801            motion_parameters.append("NA")
2802
2803        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2804        if k in b0_idx:
2805            fixed=ants.image_clone( ab0 )
2806        else:
2807            fixed=ants.image_clone( adw )
2808        if temp.numpy().var() > 0:
2809            motion_parameters[k]=deftx+motion_parameters[k]
2810            img1w = apply_transforms_mixed_interpolation( avg_dwi,
2811                ants.slice_image(image, axis=idim - 1, idx=k),
2812                motion_parameters[k], mask=brain_mask_eroded )
2813            motion_corrected.append(img1w)
2814        else:
2815            motion_corrected.append(fixed)
2816
2817    if verbose:
2818        print("Reorient bvecs")
2819    if bvecs is not None:
2820        #    direction = target->GetDirection().GetTranspose() * img_mov->GetDirection().GetVnlMatrix();
2821        rebase = np.dot( np.transpose( avg_b0.direction  ), ab0.direction )
2822        bvecs = bvec_reorientation( motion_parameters, bvecs, rebase )
2823
2824    if remove_it:
2825        import shutil
2826        shutil.rmtree(output_directory, ignore_errors=True )
2827
2828    if verbose:
2829        print("Done")
2830    d4siz = list(avg_b0.shape)
2831    d4siz.append( 2 )
2832    spc = list(ants.get_spacing( avg_b0 ))
2833    spc.append( 1.0 )
2834    mydir = ants.get_direction( avg_b0 )
2835    mydir4d = ants.get_direction( image )
2836    mydir4d[0:3,0:3]=mydir
2837    myorg = list(ants.get_origin( avg_b0 ))
2838    myorg.append( 0.0 )
2839    avg_b0_4d = ants.make_image(d4siz,0,spacing=spc,origin=myorg,direction=mydir4d)
2840    return {
2841        "motion_corrected": ants.list_to_ndimage(avg_b0_4d, motion_corrected),
2842        "motion_parameters": motion_parameters,
2843        "FD": FD,
2844        'bvals':bvals,
2845        'bvecs':bvecs
2846    }

Correct time-series data for motion - with optional deformation.

Arguments

image: antsImage, usually ND where D=4.

avg_b0: Fixed image b0 image

avg_dwi: Fixed dwi same space as b0 image

bvals: bvalues (file or array)

bvecs: bvecs (file or array)

b0_idx: indices of b0

type_of_transform : string
    A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
    See ants registration for details.

fdOffset: offset value to use in framewise displacement calculation

mask_csf: boolean

brain_mask_eroded: optional mask that will trigger mixed interpolation

output_directory : string
    output will be placed in this directory plus a numeric extension.

verbose: boolean

kwargs: keyword args
    extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.

Returns

dict containing the following key/value pairs: motion_corrected: Moving image warped to space of fixed image. motion_parameters: transforms for each image in the time series. FD: Framewise displacement generalized for arbitrary transformations.

Notes

Control extra arguments via kwargs. see ants.registration for details.

Example

>>> import ants
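
A minimal usage sketch, assuming a raw 4D DWI with bvals/bvecs files; the averages come from get_average_dwi_b0 and the filenames are placeholders:

>>> import ants
>>> import antspymm
>>> dwi = ants.image_read('dwi.nii.gz')                  # hypothetical 4D DWI
>>> avg_b0, avg_dwi = antspymm.get_average_dwi_b0(dwi)
>>> reg = antspymm.dti_reg(dwi, avg_b0, avg_dwi,
...     bvals='dwi.bval', bvecs='dwi.bvec',
...     type_of_transform='antsRegistrationSyNRepro[r]')
>>> moco, bvecs = reg['motion_corrected'], reg['bvecs']
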
def mc_reg( image, fixed=None, type_of_transform='antsRegistrationSyNRepro[r]', mask=None, total_sigma=3.0, fdOffset=2.0, output_directory=None, verbose=False, **kwargs):
2849def mc_reg(
2850    image,
2851    fixed=None,
2852    type_of_transform="antsRegistrationSyNRepro[r]",
2853    mask=None,
2854    total_sigma=3.0,
2855    fdOffset=2.0,
2856    output_directory=None,
2857    verbose=False, **kwargs
2858):
2859    """
2860    Correct time-series data for motion - with deformation.
2861
2862    Arguments
2863    ---------
2864        image: antsImage, usually ND where D=4.
2865
2866        fixed: Fixed image to register all timepoints to.  If not provided,
2867            mean image is used.
2868
2869        type_of_transform : string
2870            A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
2871            See ants registration for details.
2872
2873        fdOffset: offset value to use in framewise displacement calculation
2874
2875        output_directory : string
2876            output will be named with this prefix plus a numeric extension.
2877
2878        verbose: boolean
2879
2880        kwargs: keyword args
2881            extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.
2882
2883    Returns
2884    -------
2885    dict containing the following key/value pairs:
2886        `motion_corrected`: Moving image warped to space of fixed image.
2887        `motion_parameters`: transforms for each image in the time series.
2888        `FD`: Framewise displacement generalized for arbitrary transformations.
2889
2890    Notes
2891    -----
2892    Control extra arguments via kwargs. see ants.registration for details.
2893
2894    Example
2895    -------
2896    >>> import ants
2897    >>> fi = ants.image_read(ants.get_ants_data('ch2'))
2898    >>> mytx = ants.motion_correction( fi )
2899    """
2900    remove_it=False
2901    if output_directory is None:
2902        remove_it=True
2903        output_directory = tempfile.mkdtemp()
2904    output_directory_w = output_directory + "/mc_reg/"
2905    os.makedirs(output_directory_w,exist_ok=True)
2906    ofnG = tempfile.NamedTemporaryFile(delete=False,suffix='global_deformation',dir=output_directory_w).name
2907    ofnL = tempfile.NamedTemporaryFile(delete=False,suffix='local_deformation',dir=output_directory_w).name
2908    if verbose:
2909        print(output_directory_w)
2910        print(ofnG)
2911        print(ofnL)
2912
2913    idim = image.dimension
2914    ishape = image.shape
2915    nTimePoints = ishape[idim - 1]
2916    if fixed is None:
2917        fixed = ants.get_average_of_timeseries( image )
2918    if mask is None:
2919        mask = ants.get_mask(fixed)
2920    FD = np.zeros(nTimePoints)
2921    motion_parameters = list()
2922    motion_corrected = list()
2923    centerOfMass = mask.get_center_of_mass()
2924    npts = pow(2, idim - 1)
2925    pointOffsets = np.zeros((npts, idim - 1))
2926    myrad = np.ones(idim - 1).astype(int).tolist()
2927    mask1vals = np.zeros(int(mask.sum()))
2928    mask1vals[round(len(mask1vals) / 2)] = 1
2929    mask1 = ants.make_image(mask, mask1vals)
2930    myoffsets = ants.get_neighborhood_in_mask(
2931        mask1, mask1, radius=myrad, spatial_info=True
2932    )["offsets"]
2933    mycols = list("xy")
2934    if idim - 1 == 3:
2935        mycols = list("xyz")
2936    useinds = list()
2937    for k in range(myoffsets.shape[0]):
2938        if abs(myoffsets[k, :]).sum() == (idim - 2):
2939            useinds.append(k)
2940        myoffsets[k, :] = myoffsets[k, :] * fdOffset / 2.0 + centerOfMass
2941    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)
2942    if verbose:
2943        print("Progress:")
2944    counter = 0
2945    for k in range(nTimePoints):
2946        mycount = round(k / nTimePoints * 100)
2947        if verbose and mycount == counter:
2948            counter = counter + 10
2949            print(mycount, end="%.", flush=True)
2950        temp = ants.slice_image(image, axis=idim - 1, idx=k)
2951        temp = ants.iMath(temp, "Normalize")
2952        if temp.numpy().var() > 0:
2953            myrig = ants.registration(
2954                    fixed, temp,
2955                    type_of_transform='antsRegistrationSyNRepro[r]',
2956                    outprefix=ofnL+str(k).zfill(4)+"_",
2957                    **kwargs
2958                )
2959            if type_of_transform == 'SyN':
2960                myreg = ants.registration(
2961                    fixed, temp,
2962                    type_of_transform='SyNOnly',
2963                    total_sigma=total_sigma,
2964                    initial_transform=myrig['fwdtransforms'][0],
2965                    outprefix=ofnL+str(k).zfill(4)+"_",
2966                    **kwargs
2967                )
2968            else:
2969                myreg = myrig
2970            fdptsTxI = ants.apply_transforms_to_points(
2971                idim - 1, fdpts, myreg["fwdtransforms"]
2972            )
2973            if k > 0 and motion_parameters[k - 1] != "NA":
2974                fdptsTxIminus1 = ants.apply_transforms_to_points(
2975                    idim - 1, fdpts, motion_parameters[k - 1]
2976                )
2977            else:
2978                fdptsTxIminus1 = fdptsTxI
2979            # take the absolute value, then the mean across columns, then the sum
2980            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
2981            motion_parameters.append(myreg["fwdtransforms"])
2982            img1w = ants.apply_transforms( fixed,
2983                ants.slice_image(image, axis=idim - 1, idx=k),
2984                myreg["fwdtransforms"] )
2985            motion_corrected.append(img1w)
2986        else:
2987            motion_parameters.append("NA")
2988            motion_corrected.append(temp)
2989
2990    if remove_it:
2991        import shutil
2992        shutil.rmtree(output_directory, ignore_errors=True )
2993
2994    if verbose:
2995        print("Done")
2996    return {
2997        "motion_corrected": ants.list_to_ndimage(image, motion_corrected),
2998        "motion_parameters": motion_parameters,
2999        "FD": FD,
3000    }

Correct time-series data for motion - with deformation.

Arguments

image: antsImage, usually ND where D=4.

fixed: Fixed image to register all timepoints to.  If not provided,
    mean image is used.

type_of_transform : string
    A linear or non-linear registration type. Mutual information metric and rigid transformation by default.
    See ants registration for details.

fdOffset: offset value to use in framewise displacement calculation

output_directory : string
    output will be named with this prefix plus a numeric extension.

verbose: boolean

kwargs: keyword args
    extra arguments - these extra arguments will control the details of registration that is performed. see ants registration for more.

Returns

dict containing the following key/value pairs: motion_corrected: Moving image warped to space of fixed image. motion_parameters: transforms for each image in the time series. FD: Framewise displacement generalized for arbitrary transformations.

Notes

Control extra arguments via kwargs. see ants.registration for details.

Example

>>> import ants
>>> fi = ants.image_read(ants.get_ants_data('ch2'))
>>> mytx = ants.motion_correction( fi )
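
The docstring example above exercises ants.motion_correction; a hedged sketch of mc_reg itself on a hypothetical 4D time series might look like:

>>> import ants
>>> import antspymm
>>> ts = ants.image_read('bold.nii.gz')     # hypothetical 4D image
>>> mytx = antspymm.mc_reg(ts, type_of_transform='antsRegistrationSyNRepro[r]', verbose=True)
>>> mytx['FD'][:5]                          # framewise displacement per volume
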
def get_data(name=None, force_download=False, version=26, target_extension='.csv'):
3123def get_data( name=None, force_download=False, version=26, target_extension='.csv' ):
3124    """
3125    Get ANTsPyMM data filename
3126
3127    The first time this is called, it will download data to ~/.antspymm.
3128    After, it will just read data from disk.  The ~/.antspymm may need to
3129    be periodically deleted in order to ensure data is current.
3130
3131    Arguments
3132    ---------
3133    name : string
3134        name of data tag to retrieve
3135        Options:
3136            - 'all'
3137
3138    force_download: boolean
3139
3140    version: version of data to download (integer)
3141
3142    Returns
3143    -------
3144    string
3145        filepath of selected data
3146
3147    Example
3148    -------
3149    >>> import antspymm
3150    >>> antspymm.get_data()
3151    """
3152    os.makedirs(DATA_PATH, exist_ok=True)
3153
3154    def mv_subfolder_files(folder, verbose=False):
3155        """
3156        Move files from subfolders to the parent folder.
3157
3158        Parameters
3159        ----------
3160        folder : str
3161            Path to the folder.
3162        verbose : bool, optional
3163            Print information about the files and folders being processed (default is False).
3164
3165        Returns
3166        -------
3167        None
3168        """
3169        import os
3170        import shutil
3171        for root, dirs, files in os.walk(folder):
3172            if verbose:
3173                print(f"Processing directory: {root}")
3174                print(f"Subdirectories: {dirs}")
3175                print(f"Files: {files}")
3176            
3177            for file in files:
3178                if root != folder:
3179                    if verbose:
3180                        print(f"Moving file: {file} from {root} to {folder}")
3181                    shutil.move(os.path.join(root, file), folder)
3182            
3183            for dir in dirs:
3184                if root != folder:
3185                    if verbose:
3186                        print(f"Removing directory: {dir} from {root}")
3187                    shutil.rmtree(os.path.join(root, dir))
3188
3189    def download_data( version ):
3190        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
3191        target_file_name = "16912366.zip"
3192        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
3193            cache_subdir=DATA_PATH, extract = True )
3194        mv_subfolder_files( os.path.expanduser("~/.antspymm"), False )
3195        os.remove( DATA_PATH + target_file_name )
3196
3197    if force_download:
3198        download_data( version = version )
3199
3200
3201    files = []
3202    for fname in os.listdir(DATA_PATH):
3203        if ( fname.endswith(target_extension) ) :
3204            fname = os.path.join(DATA_PATH, fname)
3205            files.append(fname)
3206
3207    if len( files ) == 0 :
3208        download_data( version = version )
3209        for fname in os.listdir(DATA_PATH):
3210            if ( fname.endswith(target_extension) ) :
3211                fname = os.path.join(DATA_PATH, fname)
3212                files.append(fname)
3213
3214
3215    if name == 'all':
3216        return files
3217
3218    datapath = None
3219
3220    for fname in os.listdir(DATA_PATH):
3221        mystem = (Path(fname).resolve().stem)
3222        mystem = (Path(mystem).resolve().stem)
3223        mystem = (Path(mystem).resolve().stem)
3224        if ( name == mystem and fname.endswith(target_extension) ) :
3225            datapath = os.path.join(DATA_PATH, fname)
3226
3227    return datapath

Get ANTsPyMM data filename

The first time this is called, it will download data to ~/.antspymm. After, it will just read data from disk. The ~/.antspymm may need to be periodically deleted in order to ensure data is current.

Arguments

name : string name of data tag to retrieve Options: - 'all'

force_download: boolean

version: version of data to download (integer)

Returns

string filepath of selected data

Example

>>> import antspymm
>>> antspymm.get_data()
def get_models(version=3, force_download=True):
3230def get_models( version=3, force_download=True ):
3231    """
3232    Get ANTsPyMM data models
3233
3234    force_download: boolean
3235
3236    Returns
3237    -------
3238    None
3239
3240    """
3241    os.makedirs(DATA_PATH, exist_ok=True)
3242
3243    def download_data( version ):
3244        url = "https://figshare.com/ndownloader/articles/21718412/versions/"+str(version)
3245        target_file_name = "21718412.zip"
3246        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
3247            cache_subdir=DATA_PATH, extract = True )
3248        os.remove( DATA_PATH + target_file_name )
3249
3250    if force_download:
3251        download_data( version = version )
3252    return

Get ANTsPyMM data models

force_download: boolean

Returns

None
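
A short setup sketch; both calls download to ~/.antspymm on first use, per the get_data documentation above:

>>> import antspymm
>>> antspymm.get_data(force_download=True)   # reference CSVs and related data
>>> antspymm.get_models(version=3)           # network weights; force_download defaults to True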

def get_valid_modalities(long=False, asString=False, qc=False):
625def get_valid_modalities( long=False, asString=False, qc=False ):
626    """
627    return a list of valid modality identifiers used in NRG modality designation
628    and that can be processed by this package.
629
630    long - return the long version
631
632    asString - concat list to string
633    """
634    if long:
635        mymod = ["T1w", "NM2DMT", "rsfMRI", "rsfMRI_LR", "rsfMRI_RL", "rsfMRILR", "rsfMRIRL", "DTI", "DTI_LR","DTI_RL",  "DTILR","DTIRL","T2Flair", "dwi", "dwi_ap", "dwi_pa", "func", "func_ap", "func_pa", "perf", 'pet3d']
636    elif qc:
637        mymod = [ 'T1w', 'T2Flair', 'NM2DMT', 'DTI', 'DTIdwi','DTIb0', 'rsfMRI', "perf", 'pet3d' ]
638    else:
639        mymod = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI", "perf", 'pet3d' ]
640    if not asString:
641        return mymod
642    else:
643        mymodchar=""
644        for x in mymod:
645            mymodchar = mymodchar + " " + str(x)
646        return mymodchar

return a list of valid modality identifiers used in NRG modality designation and that can be processed by this package.

long - return the long version

asString - concat list to string
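
For instance (return values taken directly from the listing above):

>>> import antspymm
>>> antspymm.get_valid_modalities()
['T1w', 'NM2DMT', 'DTI', 'T2Flair', 'rsfMRI', 'perf', 'pet3d']
>>> antspymm.get_valid_modalities(qc=True)   # the set used for blind QC
['T1w', 'T2Flair', 'NM2DMT', 'DTI', 'DTIdwi', 'DTIb0', 'rsfMRI', 'perf', 'pet3d']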

def dewarp_imageset( image_list, initial_template=None, iterations=None, padding=0, target_idx=[0], **kwargs):
3256def dewarp_imageset( image_list, initial_template=None,
3257    iterations=None, padding=0, target_idx=[0], **kwargs ):
3258    """
3259    Dewarp a set of images
3260
3261    Makes simplifying heuristic decisions about how to transform an image set
3262    into an unbiased reference space.  Will handle plenty of decisions
3263    automatically so beware.  Computes an average shape space for the images
3264    and transforms them to that space.
3265
3266    Arguments
3267    ---------
3268    image_list : list containing antsImages 2D, 3D or 4D
3269
3270    initial_template : optional
3271
3272    iterations : number of template building iterations
3273
3274    padding:  will pad the images by an integer amount to limit edge effects
3275
3276    target_idx : the target indices for the time series over which we should average;
3277        a list of integer indices into the last axis of the input images.
3278
3279    kwargs : keyword args
3280        arguments passed to ants registration - these must be set explicitly
3281
3282    Returns
3283    -------
3284    a dictionary with the mean image and the list of the transformed images as
3285    well as motion correction parameters for each image in the input list
3286
3287    Example
3288    -------
3289    >>> import antspymm
3290    """
3291    outlist = []
3292    avglist = []
3293    if len(image_list[0].shape) > 3:
3294        imagetype = 3
3295        for k in range(len(image_list)):
3296            for j in range(len(target_idx)):
3297                avglist.append( ants.slice_image( image_list[k], axis=3, idx=target_idx[j] ) )
3298    else:
3299        imagetype = 0
3300        avglist=image_list
3301
3302    pw=[]
3303    for k in range(len(avglist[0].shape)):
3304        pw.append( padding )
3305    for k in range(len(avglist)):
3306        avglist[k] = ants.pad_image( avglist[k], pad_width=pw  )
3307
3308    if initial_template is None:
3309        initial_template = avglist[0] * 0
3310        for k in range(len(avglist)):
3311            initial_template = initial_template + avglist[k]/len(avglist)
3312
3313    if iterations is None:
3314        iterations = 2
3315
3316    btp = ants.build_template(
3317        initial_template=initial_template,
3318        image_list=avglist,
3319        gradient_step=0.5, blending_weight=0.8,
3320        iterations=iterations, verbose=False, **kwargs )
3321
3322    # last - warp all images to this frame
3323    mocoplist = []
3324    mocofdlist = []
3325    reglist = []
3326    for k in range(len(image_list)):
3327        if imagetype == 3:
3328            moco0 = ants.motion_correction( image=image_list[k], fixed=btp, type_of_transform='antsRegistrationSyNRepro[r]' )
3329            mocoplist.append( moco0['motion_parameters'] )
3330            mocofdlist.append( moco0['FD'] )
3331            locavg = ants.slice_image( moco0['motion_corrected'], axis=3, idx=0 ) * 0.0
3332            for j in range(len(target_idx)):
3333                locavg = locavg + ants.slice_image( moco0['motion_corrected'], axis=3, idx=target_idx[j] )
3334            locavg = locavg * 1.0 / len(target_idx)
3335        else:
3336            locavg = image_list[k]
3337        reg = ants.registration( btp, locavg, **kwargs )
3338        reglist.append( reg )
3339        if imagetype == 3:
3340            myishape = image_list[k].shape
3341            mytslength = myishape[ len(myishape) - 1 ]
3342            mywarpedlist = []
3343            for j in range(mytslength):
3344                locimg = ants.slice_image( image_list[k], axis=3, idx = j )
3345                mywarped = ants.apply_transforms( btp, locimg,
3346                    reg['fwdtransforms'] + moco0['motion_parameters'][j], imagetype=0 )
3347                mywarpedlist.append( mywarped )
3348            mywarped = ants.list_to_ndimage( image_list[k], mywarpedlist )
3349        else:
3350            mywarped = ants.apply_transforms( btp, image_list[k], reg['fwdtransforms'], imagetype=imagetype )
3351        outlist.append( mywarped )
3352
3353    return {
3354        'dewarpedmean':btp,
3355        'dewarped':outlist,
3356        'deformable_registrations': reglist,
3357        'FD': mocofdlist,
3358        'motionparameters': mocoplist }

Dewarp a set of images

Makes simplifying heuristic decisions about how to transform an image set into an unbiased reference space. Will handle plenty of decisions automatically so beware. Computes an average shape space for the images and transforms them to that space.

Arguments

image_list : list containing antsImages 2D, 3D or 4D

initial_template : optional

iterations : number of template building iterations

padding: will pad the images by an integer amount to limit edge effects

target_idx : the target indices for the time series over which we should average; a list of integer indices into the last axis of the input images.

kwargs : keyword args arguments passed to ants registration - these must be set explicitly

Returns

a dictionary with the mean image and the list of the transformed images as well as motion correction parameters for each image in the input list

Example

>>> import antspymm
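
A hedged sketch assuming two 3D images of the same subject; note that registration kwargs must be given explicitly, per the docstring, and the filenames are placeholders:

>>> import ants
>>> import antspymm
>>> imgs = [ants.image_read('run1.nii.gz'), ants.image_read('run2.nii.gz')]   # hypothetical inputs
>>> dw = antspymm.dewarp_imageset(imgs, iterations=2, padding=4, type_of_transform='SyN')
>>> mean_img, warped = dw['dewarpedmean'], dw['dewarped']
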
def super_res_mcimage( image, srmodel, truncation=[0.0001, 0.995], poly_order='hist', target_range=[0, 1], isotropic=False, verbose=False):
3361def super_res_mcimage( image,
3362    srmodel,
3363    truncation=[0.0001,0.995],
3364    poly_order='hist',
3365    target_range=[0,1],
3366    isotropic = False,
3367    verbose=False ):
3368    """
3369    Super resolution on a timeseries or multi-channel image
3370
3371    Arguments
3372    ---------
3373    image : an antsImage
3374
3375    srmodel : a tensorflow fully convolutional model
3376
3377    truncation :  quantiles at which we truncate intensities to limit impact of outliers e.g. [0.005,0.995]
3378
3379    poly_order : if not None, will fit a global regression model to map
3380        intensity back to original histogram space; if 'hist' will match
3381        by histogram matching - ants.histogram_match_image
3382
3383    target_range : 2-element tuple
3384        a tuple or array defining the (min, max) of the input image
3385        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
3386        intensity. This range should match the mapping used in the training
3387        of the network.
3388
3389    isotropic : boolean
3390
3391    verbose : boolean
3392
3393    Returns
3394    -------
3395    super resolution version of the image
3396
3397    Example
3398    -------
3399    >>> import antspymm
3400    """
3401    idim = image.dimension
3402    ishape = image.shape
3403    nTimePoints = ishape[idim - 1]
3404    mcsr = list()
3405    for k in range(nTimePoints):
3406        if verbose and (( k % 5 ) == 0 ):
3407            mycount = round(k / nTimePoints * 100)
3408            print(mycount, end="%.", flush=True)
3409        temp = ants.slice_image( image, axis=idim - 1, idx=k )
3410        temp = ants.iMath( temp, "TruncateIntensity", truncation[0], truncation[1] )
3411        mysr = antspynet.apply_super_resolution_model_to_image( temp, srmodel,
3412            target_range = target_range )
3413        if poly_order is not None:
3414            bilin = ants.resample_image_to_target( temp, mysr )
3415            if poly_order == 'hist':
3416                mysr = ants.histogram_match_image( mysr, bilin )
3417            else:
3418                mysr = antspynet.regression_match_image( mysr, bilin, poly_order = poly_order )
3419        if isotropic:
3420            mysr = down2iso( mysr )
3421        if k == 0:
3422            upshape = list()
3423            for j in range(len(ishape)-1):
3424                upshape.append( mysr.shape[j] )
3425            upshape.append( ishape[ idim-1 ] )
3426            if verbose:
3427                print("SR will be of voxel size:" + str(upshape) )
3428        mcsr.append( mysr )
3429
3430    upshape = list()
3431    for j in range(len(ishape)-1):
3432        upshape.append( mysr.shape[j] )
3433    upshape.append( ishape[ idim-1 ] )
3434    if verbose:
3435        print("SR will be of voxel size:" + str(upshape) )
3436
3437    imageup = ants.resample_image( image, upshape, use_voxels = True )
3438    if verbose:
3439        print("Done")
3440
3441    return ants.list_to_ndimage( imageup, mcsr )

Super resolution on a timeseries or multi-channel image

Arguments

image : an antsImage

srmodel : a tensorflow fully convolutional model

truncation : quantiles at which we truncate intensities to limit impact of outliers e.g. [0.005,0.995]

poly_order : if not None, will fit a global regression model to map intensity back to original histogram space; if 'hist' will match by histogram matching - ants.histogram_match_image

target_range : 2-element tuple or array defining the (min, max) of the input image (e.g., [-127.5, 127.5] or [0,1]). Output images will be scaled back to original intensity. This range should match the mapping used in the training of the network.

isotropic : boolean

verbose : boolean

Returns

super resolution version of the image

Example

>>> import antspymm
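
A hedged sketch of applying super_res_mcimage to a 4D acquisition, assuming a hypothetical Keras super-resolution model saved as sr_model.h5:

import ants
import tensorflow as tf
import antspymm

mcimg = ants.image_read("dwi.nii.gz")                 # hypothetical 4D image
srmodel = tf.keras.models.load_model("sr_model.h5",   # hypothetical SR weights
                                     compile=False)
sr = antspymm.super_res_mcimage(
    mcimg, srmodel,
    truncation=[0.0001, 0.995],   # clip intensity outliers before SR
    poly_order="hist",            # histogram-match SR output back to input
    isotropic=True,               # resample each SR frame to isotropic voxels
    verbose=True)
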
def segment_timeseries_by_meanvalue(image, quantile=0.995):
3485def segment_timeseries_by_meanvalue( image, quantile = 0.995 ):
3486    """
3487    Identify indices of a time series where we assume there is a different mean
3488    intensity over the volumes.  The indices of volumes with higher and lower
3489    intensities is returned.  Can be used to automatically identify B0 volumes
3490    in DWI timeseries.
3491
3492    Arguments
3493    ---------
3494    image : an antsImage holding B0 and DWI
3495
3496    quantile : a quantile for splitting the indices of the volume - should be greater than 0.5
3497
3498    Returns
3499    -------
3500    dictionary holding the two sets of indices
3501
3502    Example
3503    -------
3504    >>> import antspymm
3505    """
3506    ishape = image.shape
3507    lastdim = len(ishape)-1
3508    meanvalues = list()
3509    for x in range(ishape[lastdim]):
3510        meanvalues.append(  ants.slice_image( image, axis=lastdim, idx=x ).mean() )
3511    myhiq = np.quantile( meanvalues, quantile )
3512    myloq = np.quantile( meanvalues, 1.0 - quantile )
3513    lowerindices = list()
3514    higherindices = list()
3515    for x in range(len(meanvalues)):
3516        hiabs = abs( meanvalues[x] - myhiq )
3517        loabs = abs( meanvalues[x] - myloq )
3518        if hiabs < loabs:
3519            higherindices.append(x)
3520        else:
3521            lowerindices.append(x)
3522
3523    return {
3524    'lowermeans':lowerindices,
3525    'highermeans':higherindices }

Identify indices of a time series where we assume there is a different mean intensity over the volumes. The indices of volumes with higher and lower intensities are returned. Can be used to automatically identify B0 volumes in DWI timeseries.

Arguments

image : an antsImage holding B0 and DWI

quantile : a quantile for splitting the indices of the volume - should be greater than 0.5

Returns

dictionary holding the two sets of indices

Example

>>> import antspymm
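
For example, splitting a hypothetical DWI series into likely B0 and diffusion-weighted volumes by mean intensity:

import ants
import antspymm

dwi = ants.image_read("dwi.nii.gz")   # hypothetical 4D DWI
idx = antspymm.segment_timeseries_by_meanvalue(dwi, quantile=0.995)
b0_idx = idx["highermeans"]           # brighter volumes, presumed B0
dwi_idx = idx["lowermeans"]           # remaining diffusion-weighted volumes
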
def get_average_rsf(x, min_t=10, max_t=35):
3528def get_average_rsf( x, min_t=10, max_t=35 ):
3529    """
3530    automatically generates the average bold image with quick registration
3531
3532    returns:
3533        avg_bold
3534    """
3535    output_directory = tempfile.mkdtemp()
3536    ofn = output_directory + "/w"
3537    bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
3538    oavg = ants.slice_image( x, axis=3, idx=0 )
3539    if x.shape[3] <= min_t:
3540        min_t=0
3541    if x.shape[3] <= max_t:
3542        max_t=x.shape[3]
3543    for myidx in range(min_t,max_t):
3544        b0 = ants.slice_image( x, axis=3, idx=myidx)
3545        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3546    bavg = ants.iMath( bavg, 'Normalize' )
3547    oavg = ants.image_clone( bavg )
3548    bavg = oavg * 0.0
3549    for myidx in range(min_t,max_t):
3550        b0 = ants.slice_image( x, axis=3, idx=myidx)
3551        bavg = bavg + ants.registration(oavg,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3552    import shutil
3553    shutil.rmtree(output_directory, ignore_errors=True )
3554    bavg = ants.iMath( bavg, 'Normalize' )
3555    return bavg
3556    # return ants.n4_bias_field_correction(bavg, mask=ants.get_mask( bavg ) )

automatically generates the average bold image with quick registration

returns: avg_bold
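
A minimal sketch, assuming a hypothetical resting-state BOLD file rsfmri.nii.gz:

import ants
import antspymm

bold = ants.image_read("rsfmri.nii.gz")    # hypothetical 4D BOLD
avg_bold = antspymm.get_average_rsf(bold, min_t=10, max_t=35)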

def get_average_dwi_b0(x, fixed_b0=None, fixed_dwi=None, fast=False):
3559def get_average_dwi_b0( x, fixed_b0=None, fixed_dwi=None, fast=False ):
3560    """
3561    automatically generates the average b0 and dwi and outputs both;
3562    maps dwi to b0 space at end.
3563
3564    x : input image
3565
3566    fixed_b0 : alternative reference space
3567
3568    fixed_dwi : alternative reference space
3569
3570    fast : boolean
3571
3572    returns:
3573        avg_b0, avg_dwi
3574    """
3575    output_directory = tempfile.mkdtemp()
3576    ofn = output_directory + "/w"
3577    temp = segment_timeseries_by_meanvalue( x )
3578    b0_idx = temp['highermeans']
3579    non_b0_idx = temp['lowermeans']
3580    if ( fixed_b0 is None and fixed_dwi is None ) or fast:
3581        xavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
3582        bavg = ants.slice_image( x, axis=3, idx=0 ) * 0.0
3583        fixed_b0_use = ants.slice_image( x, axis=3, idx=b0_idx[0] )
3584        fixed_dwi_use = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
3585    else:
3586        temp_b0 = ants.slice_image( x, axis=3, idx=b0_idx[0] )
3587        temp_dwi = ants.slice_image( x, axis=3, idx=non_b0_idx[0] )
3588        xavg = fixed_b0 * 0.0
3589        bavg = fixed_b0 * 0.0
3590        tempreg = ants.registration( fixed_b0, temp_b0,'antsRegistrationSyNRepro[r]')
3591        fixed_b0_use = tempreg['warpedmovout']
3592        fixed_dwi_use = ants.apply_transforms( fixed_b0, temp_dwi, tempreg['fwdtransforms'] )
3593    for myidx in range(x.shape[3]):
3594        b0 = ants.slice_image( x, axis=3, idx=myidx)
3595        if not fast:
3596            if not myidx in b0_idx:
3597                xavg = xavg + ants.registration(fixed_dwi_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3598            else:
3599                bavg = bavg + ants.registration(fixed_b0_use,b0,'antsRegistrationSyNRepro[r]',outprefix=ofn)['warpedmovout']
3600        else:
3601            if not myidx in b0_idx:
3602                xavg = xavg + b0
3603            else:
3604                bavg = bavg + b0
3605    bavg = ants.iMath( bavg, 'Normalize' )
3606    xavg = ants.iMath( xavg, 'Normalize' )
3607    import shutil
3608    shutil.rmtree(output_directory, ignore_errors=True )
3609    avgb0=ants.n4_bias_field_correction(bavg)
3610    avgdwi=ants.n4_bias_field_correction(xavg)
3611    avgdwi=ants.registration( avgb0, avgdwi, 'antsRegistrationSyNRepro[r]' )['warpedmovout']
3612    return avgb0, avgdwi

automatically generates the average b0 and dwi and outputs both; maps dwi to b0 space at end.

x : input image

fixed_b0 : alternative reference space

fixed_dwi : alternative reference space

fast : boolean

returns: avg_b0, avg_dwi
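
A minimal sketch with a hypothetical DWI file:

import ants
import antspymm

dwi = ants.image_read("dwi.nii.gz")    # hypothetical 4D DWI
avg_b0, avg_dwi = antspymm.get_average_dwi_b0(dwi, fast=False)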

def dti_template( b_image_list=None, w_image_list=None, iterations=5, gradient_step=0.5, mask_csf=False, average_both=True, verbose=False):
3614def dti_template(
3615    b_image_list=None,
3616    w_image_list=None,
3617    iterations=5,
3618    gradient_step=0.5,
3619    mask_csf=False,
3620    average_both=True,
3621    verbose=False
3622):
3623    """
3624    two channel version of build_template
3625
3626    returns:
3627        avg_b0, avg_dwi
3628    """
3629    output_directory = tempfile.mkdtemp()
3630    mydeftx = tempfile.NamedTemporaryFile(delete=False,dir=output_directory).name
3631    tmp = tempfile.NamedTemporaryFile(delete=False,dir=output_directory,suffix=".nii.gz")
3632    wavgfn = tmp.name
3633    tmp2 = tempfile.NamedTemporaryFile(delete=False,dir=output_directory)
3634    comptx = tmp2.name
3635    weights = np.repeat(1.0 / len(b_image_list), len(b_image_list))
3636    weights = [x / sum(weights) for x in weights]
3637    w_initial_template = w_image_list[0]
3638    b_initial_template = b_image_list[0]
3639    b_initial_template = ants.iMath(b_initial_template,"Normalize")
3640    w_initial_template = ants.iMath(w_initial_template,"Normalize")
3641    if mask_csf:
3642        bcsf0 = ants.threshold_image( b_image_list[0],"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
3643        bcsf1 = ants.threshold_image( b_image_list[1],"Otsu",2).threshold_image(1,1).morphology("open",1).iMath("GetLargestComponent")
3644    else:
3645        bcsf0 = b_image_list[0] * 0 + 1
3646        bcsf1 = b_image_list[1] * 0 + 1
3647    bavg = b_initial_template.clone() * bcsf0
3648    wavg = w_initial_template.clone() * bcsf0
3649    bcsf = [ bcsf0, bcsf1 ]
3650    for i in range(iterations):
3651        for k in range(len(w_image_list)):
3652            fimg=wavg
3653            mimg=w_image_list[k] * bcsf[k]
3654            fimg2=bavg
3655            mimg2=b_image_list[k] * bcsf[k]
3656            w1 = ants.registration(
3657                fimg, mimg, type_of_transform='antsRegistrationSyNQuickRepro[s]',
3658                    multivariate_extras= [ [ "CC", fimg2, mimg2, 1, 2 ]],
3659                    outprefix=mydeftx,
3660                    verbose=0 )
3661            txname = ants.apply_transforms(wavg, wavg,
3662                w1["fwdtransforms"], compose=comptx )
3663            if k == 0:
3664                txavg = ants.image_read(txname) * weights[k]
3665                wavgnew = ants.apply_transforms( wavg,
3666                    w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3667                bavgnew = ants.apply_transforms( wavg,
3668                    b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3669            else:
3670                txavg = txavg + ants.image_read(txname) * weights[k]
3671                if i >= (iterations-2) and average_both:
3672                    wavgnew = wavgnew+ants.apply_transforms( wavg,
3673                        w_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3674                    bavgnew = bavgnew+ants.apply_transforms( wavg,
3675                        b_image_list[k] * bcsf[k], txname ).iMath("Normalize")
3676        if verbose:
3677            print("iteration:",str(i),str(txavg.abs().mean()))
3678        wscl = (-1.0) * gradient_step
3679        txavg = txavg * wscl
3680        ants.image_write( txavg, wavgfn )
3681        wavg = ants.apply_transforms(wavg, wavgnew, wavgfn).iMath("Normalize")
3682        bavg = ants.apply_transforms(bavg, bavgnew, wavgfn).iMath("Normalize")
3683    import shutil
3684    shutil.rmtree( output_directory, ignore_errors=True )
3685    if verbose:
3686        print("done")
3687    return bavg, wavg

two channel version of build_template

returns: avg_b0, avg_dwi
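
A sketch of building the two-channel template from two subjects' average B0 and DWI images (e.g. as produced by get_average_dwi_b0); the filenames are hypothetical:

import ants
import antspymm

b0_1, dwi_1 = antspymm.get_average_dwi_b0(ants.image_read("sub1_dwi.nii.gz"))
b0_2, dwi_2 = antspymm.get_average_dwi_b0(ants.image_read("sub2_dwi.nii.gz"))
b_template, w_template = antspymm.dti_template(
    b_image_list=[b0_1, b0_2],
    w_image_list=[dwi_1, dwi_2],
    iterations=3, verbose=True)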

def t1_based_dwi_brain_extraction( t1w_head, t1w, dwi, b0_idx=None, transform='antsRegistrationSyNRepro[r]', deform=None, verbose=False):
3709def t1_based_dwi_brain_extraction(
3710    t1w_head,
3711    t1w,
3712    dwi,
3713    b0_idx = None,
3714    transform='antsRegistrationSyNRepro[r]',
3715    deform=None,
3716    verbose=False
3717):
3718    """
3719    Map a t1-based brain extraction to b0 and return a mask and average b0
3720
3721    Arguments
3722    ---------
3723    t1w_head : an antsImage of the whole head
3724
3725    t1w : an antsImage probably but not necessarily T1-weighted
3726
3727    dwi : an antsImage holding B0 and DWI
3728
3729    b0_idx : the indices of the B0; if None, use segment_timeseries_by_meanvalue to guess
3730
3731    transform : string Rigid or other ants.registration tx type
3732
3733    deform : follow up transform with deformation
3734
3735    Returns
3736    -------
3737    dictionary holding the avg_b0 and its mask
3738
3739    Example
3740    -------
3741    >>> import antspymm
3742    """
3743    t1w_use = ants.iMath( t1w, "Normalize" )
3744    t1bxt = ants.threshold_image( t1w_use, 0.05, 1 ).iMath("FillHoles")
3745    if b0_idx is None:
3746        b0_idx = segment_timeseries_by_meanvalue( dwi )['highermeans']
3747    # first get the average b0
3748    if len( b0_idx ) > 1:
3749        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] ).iMath("Normalize")
3750        for n in range(1,len(b0_idx)):
3751            temp = ants.slice_image( dwi, axis=3, idx=b0_idx[n] )
3752            reg = ants.registration( b0_avg, temp, 'antsRegistrationSyNRepro[r]' )
3753            b0_avg = b0_avg + ants.iMath( reg['warpedmovout'], "Normalize")
3754    else:
3755        b0_avg = ants.slice_image( dwi, axis=3, idx=b0_idx[0] )
3756    b0_avg = ants.iMath(b0_avg,"Normalize")
3757    reg = tra_initializer( b0_avg, t1w, n_simulations=12,   verbose=verbose )
3758    if deform is not None:
3759        reg = ants.registration( b0_avg, t1w,
3760            'SyNOnly',
3761            total_sigma=0.5,
3762            initial_transform=reg['fwdtransforms'][0],
3763            verbose=False )
3764    outmsk = ants.apply_transforms( b0_avg, t1bxt, reg['fwdtransforms'], interpolator='linear').threshold_image( 0.5, 1.0 )
3765    return  {
3766    'b0_avg':b0_avg,
3767    'b0_mask':outmsk }

Map a t1-based brain extraction to b0 and return a mask and average b0

Arguments

t1w_head : an antsImage of the whole head

t1w : an antsImage probably but not necessarily T1-weighted

dwi : an antsImage holding B0 and DWI

b0_idx : the indices of the B0; if None, use segment_timeseries_by_meanvalue to guess

transform : string Rigid or other ants.registration tx type

deform : follow up transform with deformation

Returns

dictionary holding the avg_b0 and its mask

Example

>>> import antspymm
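
A hedged sketch with hypothetical whole-head and brain-extracted T1 images plus a 4D DWI:

import ants
import antspymm

t1_head = ants.image_read("t1.nii.gz")           # hypothetical whole-head T1
t1_brain = ants.image_read("t1_brain.nii.gz")    # hypothetical brain-extracted T1
dwi = ants.image_read("dwi.nii.gz")              # hypothetical 4D DWI
bxt = antspymm.t1_based_dwi_brain_extraction(t1_head, t1_brain, dwi, verbose=True)
b0_avg, b0_mask = bxt["b0_avg"], bxt["b0_mask"]
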
def mc_denoise(x, ratio=0.5):
3769def mc_denoise( x, ratio = 0.5 ):
3770    """
3771    ants denoising for timeseries (4D)
3772
3773    Arguments
3774    ---------
3775    x : an antsImage 4D
3776
3777    ratio : weight between 1 and 0 - lower weights bring result closer to initial image
3778
3779    Returns
3780    -------
3781    denoised time series
3782
3783    """
3784    dwpimage = []
3785    for myidx in range(x.shape[3]):
3786        b0 = ants.slice_image( x, axis=3, idx=myidx)
3787        dnzb0 = ants.denoise_image( b0, p=1,r=1,noise_model='Gaussian' )
3788        dwpimage.append( dnzb0 * ratio + b0 * (1.0-ratio) )
3789    return ants.list_to_ndimage( x, dwpimage )

ants denoising for timeseries (4D)

Arguments

x : an antsImage 4D

ratio : weight between 1 and 0 - lower weights bring result closer to initial image

Returns

denoised time series
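
A one-line usage sketch with a hypothetical 4D image:

import ants
import antspymm

bold = ants.image_read("rsfmri.nii.gz")          # hypothetical 4D image
bold_dn = antspymm.mc_denoise(bold, ratio=0.5)   # blend of denoised and original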

def tsnr(x, mask, indices=None):
3791def tsnr( x, mask, indices=None ):
3792    """
3793    3D temporal snr image from a 4D time series image ... the matrix is normalized to range of 0,1
3794
3795    x: image
3796
3797    mask : mask
3798
3799    indices: indices to use
3800
3801    returns a 3D image
3802    """
3803    M = ants.timeseries_to_matrix( x, mask )
3804    M = M - M.min()
3805    M = M / M.max()
3806    if indices is not None:
3807        M=M[indices,:]
3808    stdM = np.std(M, axis=0 )
3809    stdM[np.isnan(stdM)] = 0
3810    tt = round( 0.975*100 )
3811    threshold_std = np.percentile( stdM, tt )
3812    tsnrimage = ants.make_image( mask, stdM )
3813    return tsnrimage

3D temporal snr image from a 4D time series image ... the matrix is normalized to range of 0,1

x: image

mask : mask

indices: indices to use

returns a 3D image

def dvars(x, mask, indices=None):
3815def dvars( x,  mask, indices=None ):
3816    """
3817    dvars on a time series image ... the matrix is normalized to range of 0,1
3818
3819    x: image
3820
3821    mask : mask
3822
3823    indices: indices to use
3824
3825    returns an array
3826    """
3827    M = ants.timeseries_to_matrix( x, mask )
3828    M = M - M.min()
3829    M = M / M.max()
3830    if indices is not None:
3831        M=M[indices,:]
3832    DVARS = np.zeros( M.shape[0] )
3833    for i in range(1, M.shape[0] ):
3834        vecdiff = M[i-1,:] - M[i,:]
3835        DVARS[i] = np.sqrt( ( vecdiff * vecdiff ).mean() )
3836    DVARS[0] = DVARS.mean()
3837    return DVARS

dvars on a time series image ... the matrix is normalized to range of 0,1

x: image

mask : mask

indices: indices to use

returns an array

def slice_snr(x, background_mask, foreground_mask, indices=None):
3840def slice_snr( x,  background_mask, foreground_mask, indices=None ):
3841    """
3842    slice-wise SNR on a time series image
3843
3844    x: image
3845
3846    background_mask : mask - maybe CSF or background or dilated brain mask minus original brain mask
3847
3848    foreground_mask : mask - maybe cortex or WM or brain mask
3849
3850    indices: indices to use
3851
3852    returns an array
3853    """
3854    xuse=ants.iMath(x,"Normalize")
3855    MB = ants.timeseries_to_matrix( xuse, background_mask )
3856    MF = ants.timeseries_to_matrix( xuse, foreground_mask )
3857    if indices is not None:
3858        MB=MB[indices,:]
3859        MF=MF[indices,:]
3860    ssnr = np.zeros( MB.shape[0] )
3861    for i in range( MB.shape[0] ):
3862        ssnr[i]=MF[i,:].mean()/MB[i,:].std()
3863    ssnr[np.isnan(ssnr)] = 0
3864    return ssnr

slice-wise SNR on a time series image

x: image

background_mask : mask - maybe CSF or background or dilated brain mask minus original brain mask

foreground_mask : mask - maybe cortex or WM or brain mask

indices: indices to use

returns an array
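
A hedged sketch combining tsnr, dvars, and slice_snr for basic time-series QC; the masks here are derived crudely from the average volume and are only illustrative:

import ants
import antspymm

bold = ants.image_read("rsfmri.nii.gz")                 # hypothetical 4D image
avg = antspymm.get_average_rsf(bold)
brain = ants.get_mask(avg)                              # rough brain mask
background = ants.iMath(brain, "MD", 10) - brain        # dilated rim as "background"

tsnr_img = antspymm.tsnr(bold, brain)                   # 3D temporal variability image
dvars_arr = antspymm.dvars(bold, brain)                 # per-volume DVARS
ssnr_arr = antspymm.slice_snr(bold, background, brain)  # per-volume slice-wise SNR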

def impute_fa(fa, md):
3867def impute_fa( fa, md ):
3868    """
3869    impute bad values in dti, fa, md
3870    """
3871    def imputeit( x, fa ):
3872        badfa=ants.threshold_image(fa,1,1)
3873        if badfa.max() == 1:
3874            temp=ants.image_clone(x)
3875            temp[badfa==1]=0
3876            temp=ants.iMath(temp,'GD',2)
3877            x[ badfa==1 ]=temp[badfa==1]
3878        return x
3879    md=imputeit( md, fa )
3880    fa=imputeit( ants.image_clone(fa), fa )
3881    return fa, md

impute bad values in dti, fa, md

def trim_dti_mask(fa, mask, param=4.0):
3883def trim_dti_mask( fa, mask, param=4.0 ):
3884    """
3885    trim the dti mask to get rid of bright fa rim
3886
3887    this function erodes the famask by param amount then segments the rim into
3888    bright and less bright parts.  the bright parts are trimmed from the mask
3889    and the remaining edges are cleaned up a bit with closing.
3890
3891    param: closing radius unit is in physical space
3892    """
3893    spacing = ants.get_spacing(mask)
3894    spacing_product = np.prod( spacing )
3895    spcmin = min( spacing )
3896    paramVox = int(np.round( param / spcmin ))
3897    trim_mask = ants.image_clone( mask )
3898    trim_mask = ants.iMath( trim_mask, "FillHoles" )
3899    edgemask = trim_mask - ants.iMath( trim_mask, "ME", paramVox )
3900    maxk=4
3901    edgemask = ants.threshold_image( fa * edgemask, "Otsu", maxk )
3902    edgemask = ants.threshold_image( edgemask, maxk-1, maxk )
3903    trim_mask[edgemask >= 1 ]=0
3904    trim_mask = ants.iMath(trim_mask,"ME",paramVox-1)
3905    trim_mask = ants.iMath(trim_mask,'GetLargestComponent')
3906    trim_mask = ants.iMath(trim_mask,"MD",paramVox-1)
3907    return trim_mask

trim the dti mask to get rid of bright fa rim

this function erodes the famask by param amount then segments the rim into bright and less bright parts. the bright parts are trimmed from the mask and the remaining edges are cleaned up a bit with closing.

param: closing radius unit is in physical space

def dipy_dti_recon( image, bvalsfn, bvecsfn, mask=None, b0_idx=None, mask_dilation=2, mask_closing=5, fit_method='WLS', trim_the_mask=2.0, diffusion_model='DTI', verbose=False):
4271def dipy_dti_recon(
4272    image,
4273    bvalsfn,
4274    bvecsfn,
4275    mask = None,
4276    b0_idx = None,
4277    mask_dilation = 2,
4278    mask_closing = 5,
4279    fit_method='WLS',
4280    trim_the_mask=2.0,
4281    diffusion_model='DTI',
4282    verbose=False ):
4283    """
4284    DiPy DTI reconstruction - building on the DiPy basic DTI example
4285
4286    Arguments
4287    ---------
4288    image : an antsImage holding B0 and DWI
4289
4290    bvalsfn : bvalues  obtained by dipy read_bvals_bvecs or the values themselves
4291
4292    bvecsfn : bvectors obtained by dipy read_bvals_bvecs or the values themselves
4293
4294    mask : brain mask for the DWI/DTI reconstruction; if it is not in the same
4295        space as the image, we will resample directly to the image space.  This
4296        could lead to problems if the inputs are really incorrect.
4297
4298    b0_idx : the indices of the B0; if None, use segment_timeseries_by_bvalue
4299
4300    mask_dilation : integer zero or more dilates the brain mask
4301
4302    mask_closing : integer zero or more closes the brain mask
4303
4304    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel) ... if None, will not reconstruct DTI.
4305
4306    trim_the_mask : float >=0 post-hoc method for trimming the mask
4307
4308    diffusion_model : string
4309        DTI, FreeWater, DKI
4310
4311    verbose : boolean
4312
4313    Returns
4314    -------
4315    dictionary holding the tensorfit, MD, FA and RGB images and motion parameters (optional)
4316
4317    NOTE -- see dipy reorient_bvecs(gtab, affines, atol=1e-2)
4318
4319    NOTE -- if the bvec.shape[0] is smaller than the image.shape[3], we neglect
4320        the trailing image volumes.
4321
4322    Example
4323    -------
4324    >>> import antspymm
4325    """
4326
4327    import dipy.reconst.fwdti as fwdti
4328
4329    if isinstance(bvecsfn, str):
4330        bvals, bvecs = read_bvals_bvecs( bvalsfn , bvecsfn   )
4331    else: # assume we already read them
4332        bvals = bvalsfn.copy()
4333        bvecs = bvecsfn.copy()
4334
4335    if bvals.max() < 1.0:
4336        raise ValueError("DTI recon error: maximum bvalues are too small.")
4337
4338    b0_idx = segment_timeseries_by_bvalue( bvals )['lowbvals']
4339
4340    b0 = ants.slice_image( image, axis=3, idx=b0_idx[0] )
4341    bxtmod='bold'
4342    bxtmod='t2'
4343    constant_mask=False
4344    if verbose:
4345        print( np.unique( bvals ), flush=True )
4346    if mask is not None:
4347        if verbose:
4348            print("use set bxt in dipy_dti_recon", flush=True)
4349        constant_mask=True
4350        mask = ants.resample_image_to_target( mask, b0, interp_type='nearestNeighbor')
4351    else:
4352        if verbose:
4353            print("use deep learning bxt in dipy_dti_recon")
4354        mask = antspynet.brain_extraction( b0, bxtmod ).threshold_image(0.5,1).iMath("GetLargestComponent").morphology("close",2).iMath("FillHoles")
4355    if mask_closing > 0 and not constant_mask :
4356        mask = ants.morphology( mask, "close", mask_closing ) # good
4357    maskdil = ants.iMath( mask, "MD", mask_dilation )
4358
4359    if verbose:
4360        print("recon dti.TensorModel",flush=True)
4361
4362    bvecs = repair_bvecs( bvecs )
4363    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
4364    mynt=1
4365    threads_env = os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")
4366    if threads_env is not None:
4367        mynt = int(threads_env)
4368    tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, maskdil,
4369                                             num_threads=mynt )
4370    if verbose:
4371        print("recon dti.TensorModel done",flush=True)
4372
4373    # change the brain mask based on high FA values
4374    if trim_the_mask > 0 and fit_method is not None:
4375        mask = trim_dti_mask( FA, mask, trim_the_mask )
4376        tenfit, FA, MD1, RGB = efficient_dwi_fit( gtab, diffusion_model, image, maskdil,
4377                                             num_threads=mynt )
4378
4379    return {
4380        'tensormodel' : tenfit,
4381        'MD' : MD1 ,
4382        'FA' : FA ,
4383        'RGB' : RGB,
4384        'dwi_mask':mask,
4385        'bvals':bvals,
4386        'bvecs':bvecs
4387        }

DiPy DTI reconstruction - building on the DiPy basic DTI example

Arguments

image : an antsImage holding B0 and DWI

bvalsfn : bvalues obtained by dipy read_bvals_bvecs or the values themselves

bvecsfn : bvectors obtained by dipy read_bvals_bvecs or the values themselves

mask : brain mask for the DWI/DTI reconstruction; if it is not in the same space as the image, we will resample directly to the image space. This could lead to problems if the inputs are really incorrect.

b0_idx : the indices of the B0; if None, use segment_timeseries_by_bvalue

mask_dilation : integer zero or more dilates the brain mask

mask_closing : integer zero or more closes the brain mask

fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel) ... if None, will not reconstruct DTI.

trim_the_mask : float >=0 post-hoc method for trimming the mask

diffusion_model : string DTI, FreeWater, DKI

verbose : boolean

Returns

dictionary holding the tensorfit, MD, FA and RGB images and motion parameters (optional)

NOTE -- see dipy reorient_bvecs(gtab, affines, atol=1e-2)

NOTE -- if the bvec.shape[0] is smaller than the image.shape[3], we neglect the trailing image volumes.

Example

>>> import antspymm
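
A minimal sketch of a standalone reconstruction, assuming hypothetical dwi.nii.gz / dwi.bval / dwi.bvec files:

import ants
import antspymm

dwi = ants.image_read("dwi.nii.gz")    # hypothetical 4D DWI
recon = antspymm.dipy_dti_recon(
    dwi, "dwi.bval", "dwi.bvec",
    fit_method="WLS",
    diffusion_model="DTI",
    verbose=True)
fa, md, mask = recon["FA"], recon["MD"], recon["dwi_mask"]
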
def concat_dewarp( refimg, originalDWI, physSpaceDWI, dwpTx, motion_parameters, motion_correct=True, verbose=False):
4390def concat_dewarp(
4391        refimg,
4392        originalDWI,
4393        physSpaceDWI,
4394        dwpTx,
4395        motion_parameters,
4396        motion_correct=True,
4397        verbose=False ):
4398    """
4399    Apply concatenated motion correction and dewarping transforms to a timeseries image.
4400
4401    Arguments
4402    ---------
4403
4404    refimg : an antsImage defining the reference domain (3D)
4405
4406    originalDWI : the antsImage in original (not interpolated space) (4D)
4407
4408    physSpaceDWI : ants antsImage defining the physical space of the mapping (4D)
4409
4410    dwpTx : dewarping transform
4411
4412    motion_parameters : previously computed list of motion parameters
4413
4414    motion_correct : boolean
4415
4416    verbose : boolean
4417
4418    """
4419    # apply the dewarping tx to the original dwi and reconstruct again
4420    # NOTE: refimg must be in the same space for this to work correctly
4421    # due to the use of ants.list_to_ndimage( originalDWI, dwpimage )
4422    dwpimage = []
4423    for myidx in range(originalDWI.shape[3]):
4424        b0 = ants.slice_image( originalDWI, axis=3, idx=myidx)
4425        concatx = dwpTx.copy()
4426        if motion_correct:
4427            concatx = concatx + motion_parameters[myidx]
4428        if verbose and myidx == 0:
4429            print("dwp parameters")
4430            print( dwpTx )
4431            print("Motion parameters")
4432            print( motion_parameters[myidx] )
4433            print("concat parameters")
4434            print(concatx)
4435        warpedb0 = ants.apply_transforms( refimg, b0, concatx,
4436            interpolator='nearestNeighbor' )
4437        dwpimage.append( warpedb0 )
4438    return ants.list_to_ndimage( physSpaceDWI, dwpimage )

Apply concatenated motion correction and dewarping transforms to a timeseries image.

Arguments

refimg : an antsImage defining the reference domain (3D)

originalDWI : the antsImage in original (not interpolated space) (4D)

physSpaceDWI : ants antsImage defining the physical space of the mapping (4D)

dwpTx : dewarping transform

motion_parameters : previously computed list of motion parameters

motion_correct : boolean

verbose : boolean
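
A hedged sketch showing how concat_dewarp might be fed the outputs of a motion-correction step; dwp_transforms is a placeholder for a dewarping transform list computed elsewhere, and the filenames are hypothetical:

import ants
import antspymm

dwi = ants.image_read("dwi.nii.gz")            # hypothetical original 4D DWI
ref = ants.slice_image(dwi, axis=3, idx=0)     # 3D reference frame
moco = ants.motion_correction(dwi, fixed=ref)  # rigid motion correction
dwp_transforms = []                            # placeholder dewarping transforms
dewarped = antspymm.concat_dewarp(
    ref, dwi, dwi,
    dwp_transforms,
    moco["motion_parameters"],
    motion_correct=True)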

def joint_dti_recon( img_LR, bval_LR, bvec_LR, jhu_atlas, jhu_labels, reference_B0, reference_DWI, srmodel=None, img_RL=None, bval_RL=None, bvec_RL=None, t1w=None, brain_mask=None, motion_correct=None, dewarp_modality='FA', denoise=False, fit_method='WLS', impute=False, censor=True, diffusion_model='DTI', verbose=False):
4441def joint_dti_recon(
4442    img_LR,
4443    bval_LR,
4444    bvec_LR,
4445    jhu_atlas,
4446    jhu_labels,
4447    reference_B0,
4448    reference_DWI,
4449    srmodel = None,
4450    img_RL = None,
4451    bval_RL = None,
4452    bvec_RL = None,
4453    t1w = None,
4454    brain_mask = None,
4455    motion_correct = None,
4456    dewarp_modality = 'FA',
4457    denoise=False,
4458    fit_method='WLS',
4459    impute = False,
4460    censor = True,
4461    diffusion_model = 'DTI',
4462    verbose = False ):
4463    """
4464    1. pass in subject data and 1mm JHU atlas/labels
4465    2. perform initial LR, RL reconstruction (2nd is optional) and motion correction (optional)
4466    3. dewarp the images using dewarp_modality or T1w
4467    4. apply dewarping to the original data
4468        ===> may want to apply SR at this step
4469    5. reconstruct DTI again
4470    6. label images and do registration
4471    7. return relevant outputs
4472
4473    NOTE: RL images are optional; should pass t1w in this case.
4474
4475    Arguments
4476    ---------
4477
4478    img_LR : an antsImage holding B0 and DWI LR acquisition
4479
4480    bval_LR : bvalue filename LR
4481
4482    bvec_LR : bvector filename LR
4483
4484    jhu_atlas : atlas FA image
4485
4486    jhu_labels : atlas labels
4487
4488    reference_B0 : the "target" B0 image space
4489
4490    reference_DWI : the "target" DW image space
4491
4492    srmodel : optional h5 (tensorflow) model
4493
4494    img_RL : an antsImage holding B0 and DWI RL acquisition
4495
4496    bval_RL : bvalue filename RL
4497
4498    bvec_RL : bvector filename RL
4499
4500    t1w : antsimage t1w neuroimage (brain-extracted)
4501
4502    brain_mask : mask for the DWI - just 3D - provided brain mask should be in reference_B0 space
4503
4504    motion_correct : None Rigid or SyN
4505
4506    dewarp_modality : string average_dwi, average_b0, MD or FA
4507
4508    denoise: boolean
4509
4510    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
4511
4512    impute : boolean
4513
4514    censor : boolean
4515
4516    diffusion_model : string
4517        DTI, FreeWater, DKI
4518
4519    verbose : boolean
4520
4521    Returns
4522    -------
4523    dictionary holding the mean_fa, its summary statistics via JHU labels,
4524        the JHU registration, the JHU labels, the dewarping dictionary and the
4525        dti reconstruction dictionaries.
4526
4527    Example
4528    -------
4529    >>> import antspymm
4530    """
4531
4532    if verbose:
4533        print("Recon DTI on OR images ...")
4534
4535    def fix_dwi_shape( img, bvalfn, bvecfn ):
4536        if isinstance(bvecfn, str):
4537            bvals, bvecs = read_bvals_bvecs( bvalfn , bvecfn   )
4538        else:
4539            bvals = bvalfn
4540            bvecs = bvecfn
4541        if bvecs.shape[0] < img.shape[3]:
4542            imgout = ants.from_numpy( img[:,:,:,0:bvecs.shape[0]] )
4543            imgout = ants.copy_image_info( img, imgout )
4544            return( imgout )
4545        else:
4546            return( img )
4547
4548    img_LR = fix_dwi_shape( img_LR, bval_LR, bvec_LR )
4549    if denoise :
4550        img_LR = mc_denoise( img_LR )
4551    if img_RL is not None:
4552        img_RL = fix_dwi_shape( img_RL, bval_RL, bvec_RL )
4553        if denoise :
4554            img_RL = mc_denoise( img_RL )
4555
4556    brainmaske = None
4557    if brain_mask is not None:
4558        maskInRightSpace = ants.image_physical_space_consistency( brain_mask, reference_B0 )
4559        if not maskInRightSpace :
4560            raise ValueError('not maskInRightSpace ... provided brain mask should be in reference_B0 space')
4561        brainmaske = ants.iMath( brain_mask, "ME", 2 )
4562
4563    if img_RL is not None :
4564        if verbose:
4565            print("img_RL correction")
4566        reg_RL = dti_reg(
4567            img_RL,
4568            avg_b0=reference_B0,
4569            avg_dwi=reference_DWI,
4570            bvals=bval_RL,
4571            bvecs=bvec_RL,
4572            type_of_transform=motion_correct,
4573            brain_mask_eroded=brainmaske,
4574            verbose=True )
4575    else:
4576        reg_RL=None
4577
4578
4579    if verbose:
4580        print("img_LR correction")
4581    reg_LR = dti_reg(
4582            img_LR,
4583            avg_b0=reference_B0,
4584            avg_dwi=reference_DWI,
4585            bvals=bval_LR,
4586            bvecs=bvec_LR,
4587            type_of_transform=motion_correct,
4588            brain_mask_eroded=brainmaske,
4589            verbose=True )
4590
4591    ts_LR_avg = None
4592    ts_RL_avg = None
4593    reg_its = [100,50,10]
4594    img_LRdwp = ants.image_clone( reg_LR[ 'motion_corrected' ] )
4595    if img_RL is not None:
4596        img_RLdwp = ants.image_clone( reg_RL[ 'motion_corrected' ] )
4597        if srmodel is not None:
4598            if verbose:
4599                print("convert img_RL_dwp to img_RL_dwp_SR")
4600            img_RLdwp = super_res_mcimage( img_RLdwp, srmodel, isotropic=True,
4601                        verbose=verbose )
4602    if srmodel is not None:
4603        reg_its = [100] + reg_its
4604        if verbose:
4605            print("convert img_LR_dwp to img_LR_dwp_SR")
4606        img_LRdwp = super_res_mcimage( img_LRdwp, srmodel, isotropic=True,
4607                verbose=verbose )
4608    if verbose:
4609        print("recon after distortion correction", flush=True)
4610
4611    if impute:
4612        print("impute begin", flush=True)
4613        img_LRdwp=impute_dwi( img_LRdwp, verbose=True )
4614        print("impute done", flush=True)
4615    elif censor:
4616        print("censor begin", flush=True)
4617        img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'] = censor_dwi( img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'], verbose=True )
4618        print("censor done", flush=True)
4619    if impute and img_RL is not None:
4620        img_RLdwp=impute_dwi( img_RLdwp, verbose=True )
4621    elif censor and img_RL is not None:
4622        img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'] = censor_dwi( img_RLdwp, reg_RL['bvals'], reg_RL['bvecs'], verbose=True )
4623
4624    if img_RL is not None:
4625        img_LRdwp, bval_LR, bvec_LR = merge_dwi_data(
4626            img_LRdwp, reg_LR['bvals'], reg_LR['bvecs'],
4627            img_RLdwp, reg_RL['bvals'], reg_RL['bvecs']
4628        )
4629    else:
4630        bval_LR=reg_LR['bvals']
4631        bvec_LR=reg_LR['bvecs']
4632
4633    if verbose:
4634        print("final recon", flush=True)
4635        print(img_LRdwp)
4636
4637    recon_LR_dewarp = dipy_dti_recon(
4638            img_LRdwp, bval_LR, bvec_LR,
4639            mask = brain_mask,
4640            fit_method=fit_method,
4641            mask_dilation=0, diffusion_model=diffusion_model, verbose=True )
4642    if verbose:
4643        print("recon done", flush=True)
4644
4645    if img_RL is not None:
4646        fdjoin = [ reg_LR['FD'],
4647                   reg_RL['FD'] ]
4648        framewise_displacement=np.concatenate( fdjoin )
4649    else:
4650        framewise_displacement=reg_LR['FD']
4651
4652    motion_count = ( framewise_displacement > 1.5  ).sum()
4653    reconFA = recon_LR_dewarp['FA']
4654    reconMD = recon_LR_dewarp['MD']
4655
4656    if verbose:
4657        print("JHU reg",flush=True)
4658
4659    OR_FA2JHUreg = ants.registration( reconFA, jhu_atlas,
4660        type_of_transform = 'antsRegistrationSyNQuickRepro[s]', 
4661        reg_iterations=reg_its, verbose=False )
4662    OR_FA_jhulabels = ants.apply_transforms( reconFA, jhu_labels,
4663        OR_FA2JHUreg['fwdtransforms'], interpolator='genericLabel')
4664
4665    df_FA_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
4666        'FA_JHU_labels_edited',
4667        reconFA,
4668        OR_FA_jhulabels)
4669    df_FA_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
4670            {'df_FA_JHU_ORRL' : df_FA_JHU_ORRL},
4671            col_names = ['Mean'] )
4672
4673    df_MD_JHU_ORRL = antspyt1w.map_intensity_to_dataframe(
4674        'MD_JHU_labels_edited',
4675        reconMD,
4676        OR_FA_jhulabels)
4677    df_MD_JHU_ORRL_bfwide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
4678            {'df_MD_JHU_ORRL' : df_MD_JHU_ORRL},
4679            col_names = ['Mean'] )
4680
4681    temp = segment_timeseries_by_meanvalue( img_LRdwp )
4682    b0_idx = temp['highermeans']
4683    non_b0_idx = temp['lowermeans']
4684
4685    nonbrainmask = ants.iMath( recon_LR_dewarp['dwi_mask'], "MD",2) - recon_LR_dewarp['dwi_mask']
4686    fgmask = ants.threshold_image( reconFA, 0.5 , 1.0).iMath("GetLargestComponent")
4687    bgmask = ants.threshold_image( reconFA, 1e-4 , 0.1)
4688    fa_SNR = 0.0
4689    fa_SNR = mask_snr( reconFA, bgmask, fgmask, bias_correct=False )
4690    fa_evr = antspyt1w.patch_eigenvalue_ratio( reconFA, 512, [16,16,16], evdepth = 0.9, mask=recon_LR_dewarp['dwi_mask'] )
4691
4692    dti_itself = get_dti( reconFA, recon_LR_dewarp['tensormodel'], return_image=True )
4693    return convert_np_in_dict( {
4694        'dti': dti_itself,
4695        'recon_fa':reconFA,
4696        'recon_fa_summary':df_FA_JHU_ORRL_bfwide,
4697        'recon_md':reconMD,
4698        'recon_md_summary':df_MD_JHU_ORRL_bfwide,
4699        'jhu_labels':OR_FA_jhulabels,
4700        'jhu_registration':OR_FA2JHUreg,
4701        'reg_LR':reg_LR,
4702        'reg_RL':reg_RL,
4703        'dtrecon_LR_dewarp':recon_LR_dewarp,
4704        'dwi_LR_dewarped':img_LRdwp,
4705        'bval_unique_count': len(np.unique(bval_LR)),
4706        'bval_LR':bval_LR,
4707        'bvec_LR':bvec_LR,
4708        'bval_RL':bval_RL,
4709        'bvec_RL':bvec_RL,
4710        'b0avg': reference_B0,
4711        'dwiavg': reference_DWI,
4712        'framewise_displacement':framewise_displacement,
4713        'high_motion_count': motion_count,
4714        'tsnr_b0': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
4715        'tsnr_dwi': tsnr( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
4716        'dvars_b0': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], b0_idx),
4717        'dvars_dwi': dvars( img_LRdwp, recon_LR_dewarp['dwi_mask'], non_b0_idx),
4718        'ssnr_b0': slice_snr( img_LRdwp, bgmask , fgmask, b0_idx),
4719        'ssnr_dwi': slice_snr( img_LRdwp, bgmask, fgmask, non_b0_idx),
4720        'fa_evr': fa_evr,
4721        'fa_SNR': fa_SNR
4722    } )
  1. pass in subject data and 1mm JHU atlas/labels
  2. perform initial LR, RL reconstruction (2nd is optional) and motion correction (optional)
  3. dewarp the images using dewarp_modality or T1w
  4. apply dewarping to the original data ===> may want to apply SR at this step
  5. reconstruct DTI again
  6. label images and do registration
  7. return relevant outputs

NOTE: RL images are optional; should pass t1w in this case.

Arguments

img_LR : an antsImage holding B0 and DWI LR acquisition

bval_LR : bvalue filename LR

bvec_LR : bvector filename LR

jhu_atlas : atlas FA image

jhu_labels : atlas labels

reference_B0 : the "target" B0 image space

reference_DWI : the "target" DW image space

srmodel : optional h5 (tensorflow) model

img_RL : an antsImage holding B0 and DWI RL acquisition

bval_RL : bvalue filename RL

bvec_RL : bvector filename RL

t1w : antsimage t1w neuroimage (brain-extracted)

brain_mask : mask for the DWI - just 3D - provided brain mask should be in reference_B0 space

motion_correct : None Rigid or SyN

dewarp_modality : string average_dwi, average_b0, MD or FA

denoise: boolean

fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)

impute : boolean

censor : boolean

diffusion_model : string DTI, FreeWater, DKI

verbose : boolean

Returns

dictionary holding the mean_fa, its summary statistics via JHU labels, the JHU registration, the JHU labels, the dewarping dictionary and the dti reconstruction dictionaries.

Example

>>> import antspymm
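
A hedged end-to-end sketch assuming a single LR acquisition; the DWI, bval/bvec, and JHU atlas filenames are hypothetical placeholders:

import ants
import antspymm

dwi = ants.image_read("dwi_LR.nii.gz")              # hypothetical LR DWI
jhu_fa = ants.image_read("jhu_fa.nii.gz")           # hypothetical JHU FA atlas
jhu_labels = ants.image_read("jhu_labels.nii.gz")   # hypothetical JHU labels
avg_b0, avg_dwi = antspymm.get_average_dwi_b0(dwi)  # target spaces

recon = antspymm.joint_dti_recon(
    dwi, "dwi_LR.bval", "dwi_LR.bvec",
    jhu_atlas=jhu_fa, jhu_labels=jhu_labels,
    reference_B0=avg_b0, reference_DWI=avg_dwi,
    motion_correct="Rigid",
    denoise=True, verbose=True)
fa_summary = recon["recon_fa_summary"]              # wide JHU FA summary
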
def middle_slice_snr(x, background_dilation=5):
4725def middle_slice_snr( x, background_dilation=5 ):
4726    """
4727
4728    Estimate signal to noise ratio (SNR) in the 2D mid-slice of a 3D image.
4729    Estimates noise from a background mask which is a
4730    dilation of the foreground mask minus the foreground mask.
4731    Actually estimates the reciprocal of the coefficient of variation.
4732
4733    Arguments
4734    ---------
4735
4736    x : an antsImage
4737
4738    background_dilation : integer - amount to dilate foreground mask
4739
4740    """
4741    xshp = x.shape
4742    xmidslice = ants.slice_image( x, 2, int( xshp[2]/2 )  )
4743    xmidslice = ants.iMath( xmidslice - xmidslice.min(), "Normalize" )
4744    xmidslice = ants.n3_bias_field_correction( xmidslice )
4745    xmidslice = ants.n3_bias_field_correction( xmidslice )
4746    xmidslicemask = ants.threshold_image( xmidslice, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
4747    xbkgmask = ants.iMath( xmidslicemask, "MD", background_dilation ) - xmidslicemask
4748    signal = (xmidslice[ xmidslicemask == 1] ).mean()
4749    noise = (xmidslice[ xbkgmask == 1] ).std()
4750    return signal / noise

Estimate signal to noise ratio (SNR) in the 2D mid-slice of a 3D image. Estimates noise from a background mask which is a dilation of the foreground mask minus the foreground mask. Actually estimates the reciprocal of the coefficient of variation.

Arguments

x : an antsImage

background_dilation : integer - amount to dilate foreground mask

def foreground_background_snr(x, background_dilation=10, erode_foreground=False):
4752def foreground_background_snr( x, background_dilation=10,
4753        erode_foreground=False):
4754    """
4755
4756    Estimate signal to noise ratio (SNR) in an image.
4757    Estimates noise from a background mask which is a
4758    dilation of the foreground mask minus the foreground mask.
4759    Actually estimates the reciprocal of the coefficient of variation.
4760
4761    Arguments
4762    ---------
4763
4764    x : an antsImage
4765
4766    background_dilation : integer - amount to dilate foreground mask
4767
4768    erode_foreground : boolean - 2nd option which erodes the initial
4769    foreground mask  to create a new foreground mask.  the background
4770    mask is the initial mask minus the eroded mask.
4771
4772    """
4773    xshp = x.shape
4774    xbc = ants.iMath( x - x.min(), "Normalize" )
4775    xbc = ants.n3_bias_field_correction( xbc )
4776    xmask = ants.threshold_image( xbc, "Otsu", 1 ).morphology("close",2).iMath("FillHoles")
4777    xbkgmask = ants.iMath( xmask, "MD", background_dilation ) - xmask
4778    fgmask = xmask
4779    if erode_foreground:
4780        fgmask = ants.iMath( xmask, "ME", background_dilation )
4781        xbkgmask = xmask - fgmask
4782    signal = (xbc[ fgmask == 1] ).mean()
4783    noise = (xbc[ xbkgmask == 1] ).std()
4784    return signal / noise

Estimate signal to noise ratio (SNR) in an image. Estimates noise from a background mask which is a dilation of the foreground mask minus the foreground mask. Actually estimates the reciprocal of the coefficient of variation.

Arguments

x : an antsImage

background_dilation : integer - amount to dilate foreground mask

erode_foreground : boolean - 2nd option which erodes the initial foreground mask to create a new foreground mask. the background mask is the initial mask minus the eroded mask.

def quantile_snr( x, lowest_quantile=0.01, low_quantile=0.1, high_quantile=0.5, highest_quantile=0.95):
4786def quantile_snr( x,
4787    lowest_quantile=0.01,
4788    low_quantile=0.1,
4789    high_quantile=0.5,
4790    highest_quantile=0.95 ):
4791    """
4792
4793    Estimate signal to noise ratio (SNR) in an image.
4794    Estimates noise from a background mask which is a
4795    dilation of the foreground mask minus the foreground mask.
4796    Actually estimates the reciprocal of the coefficient of variation.
4797
4798    Arguments
4799    ---------
4800
4801    x : an antsImage
4802
4803    lowest_quantile : float value < 1 and > 0
4804
4805    low_quantile : float value < 1 and > 0
4806
4807    high_quantile : float value < 1 and > 0
4808
4809    highest_quantile : float value < 1 and > 0
4810
4811    """
4812    import numpy as np
4813    xshp = x.shape
4814    xbc = ants.iMath( x - x.min(), "Normalize" )
4815    xbc = ants.n3_bias_field_correction( xbc )
4816    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
4817    y = xbc.numpy()
4818    ylowest = np.quantile( y[y>0], lowest_quantile )
4819    ylo = np.quantile( y[y>0], low_quantile )
4820    yhi = np.quantile( y[y>0], high_quantile )
4821    yhiest = np.quantile( y[y>0], highest_quantile )
4822    xbkgmask = ants.threshold_image( xbc, ylowest, ylo )
4823    fgmask = ants.threshold_image( xbc, yhi, yhiest )
4824    signal = (xbc[ fgmask == 1] ).mean()
4825    noise = (xbc[ xbkgmask == 1] ).std()
4826    return signal / noise

Estimate signal to noise ratio (SNR) in an image. Estimates noise from a background mask which is a dilation of the foreground mask minus the foreground mask. Actually estimates the reciprocal of the coefficient of variation.

Arguments

x : an antsImage

lowest_quantile : float value < 1 and > 0

low_quantile : float value < 1 and > 0

high_quantile : float value < 1 and > 0

highest_quantile : float value < 1 and > 0

def mask_snr(x, background_mask, foreground_mask, bias_correct=True):
4828def mask_snr( x, background_mask, foreground_mask, bias_correct=True ):
4829    """
4830
4831    Estimate signal to noise ratio (SNR) in an image using
4832    a user-defined foreground and background mask.
4833    Actually estimates the reciprocal of the coefficient of variation.
4834
4835    Arguments
4836    ---------
4837
4838    x : an antsImage
4839
4840    background_mask : binary antsImage
4841
4842    foreground_mask : binary antsImage
4843
4844    bias_correct : boolean
4845
4846    """
4847    import numpy as np
4848    if foreground_mask.sum() <= 1 or background_mask.sum() <= 1:
4849        return 0
4850    xbc = ants.iMath( x - x.min(), "Normalize" )
4851    if bias_correct:
4852        xbc = ants.n3_bias_field_correction( xbc )
4853    xbc = ants.iMath( xbc - xbc.min(), "Normalize" )
4854    signal = (xbc[ foreground_mask == 1] ).mean()
4855    noise = (xbc[ background_mask == 1] ).std()
4856    return signal / noise

Estimate signal to noise ratio (SNR) in an image using a user-defined foreground and background mask. Actually estimates the reciprocal of the coefficient of variation.

Arguments

x : an antsImage

background_mask : binary antsImage

foreground_mask : binary antsImage

bias_correct : boolean
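
The four SNR helpers can be compared on the same image; a quick sketch with a hypothetical T1 file and crude Otsu-derived masks for mask_snr:

import ants
import antspymm

img = ants.image_read("t1.nii.gz")                        # hypothetical 3D image
snr_mid = antspymm.middle_slice_snr(img, background_dilation=5)
snr_fb = antspymm.foreground_background_snr(img, background_dilation=10)
snr_q = antspymm.quantile_snr(img)
fg = ants.threshold_image(img, "Otsu", 1)                 # rough foreground mask
bg = ants.iMath(fg, "MD", 10) - fg                        # dilated rim as background
snr_m = antspymm.mask_snr(img, bg, fg, bias_correct=True)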

def dwi_deterministic_tracking( dwi, fa, bvals, bvecs, num_processes=1, mask=None, label_image=None, seed_labels=None, fa_thresh=0.05, seed_density=1, step_size=0.15, peak_indices=None, fit_method='WLS', verbose=False):
4859def dwi_deterministic_tracking(
4860    dwi,
4861    fa,
4862    bvals,
4863    bvecs,
4864    num_processes=1,
4865    mask=None,
4866    label_image = None,
4867    seed_labels = None,
4868    fa_thresh = 0.05,
4869    seed_density = 1,
4870    step_size = 0.15,
4871    peak_indices = None,
4872    fit_method='WLS',
4873    verbose = False ):
4874    """
4875
4876    Performs deterministic tractography from the DWI and returns a tractogram
4877    and path length data frame.
4878
4879    Arguments
4880    ---------
4881
4882    dwi : an antsImage holding DWI acquisition
4883
4884    fa : an antsImage holding FA values
4885
4886    bvals : bvalues
4887
4888    bvecs : bvectors
4889
4890    num_processes : number of subprocesses
4891
4892    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
4893        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
4894
4895    label_image : atlas labels
4896
4897    seed_labels : list of label numbers from the atlas labels
4898
4899    fa_thresh : FA threshold for masking and stopping (signature default 0.05)
4900
4901    seed_density : 1 default number of seeds per voxel
4902
4903    step_size : for tracking
4904
4905    peak_indices : pass these in, if they are previously estimated.  otherwise, will
4906        compute on the fly (slow)
4907
4908    fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)
4909
4910    verbose : boolean
4911
4912    Returns
4913    -------
4914    dictionary holding tracts and stateful object.
4915
4916    Example
4917    -------
4918    >>> import antspymm
4919    """
4920    import os
4921    import re
4922    import nibabel as nib
4923    import numpy as np
4924    import ants
4925    from dipy.io.gradients import read_bvals_bvecs
4926    from dipy.core.gradients import gradient_table
4927    from dipy.tracking import utils
4928    import dipy.reconst.dti as dti
4929    from dipy.segment.clustering import QuickBundles
4930    from dipy.tracking.utils import path_length
4931    if verbose:
4932        print("begin tracking",flush=True)
4933
4934    affine = ants_to_nibabel_affine(dwi)
4935
4936    if isinstance( bvals, str ) or isinstance( bvecs, str ):
4937        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
4938    bvecs = repair_bvecs( bvecs )
4939    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
4940    if mask is None:
4941        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
4942    dwi_data = dwi.numpy()
4943    dwi_mask = mask.numpy() == 1
4944    dti_model = dti.TensorModel(gtab,fit_method=fit_method)
4945    if verbose:
4946        print("begin tracking fit",flush=True)
4947    dti_fit = dti_model.fit(dwi_data, mask=dwi_mask)  # This step may take a while
4948    evecs_img = dti_fit.evecs
4949    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
4950    stopping_criterion = ThresholdStoppingCriterion(fa.numpy(), fa_thresh)
4951    from dipy.data import get_sphere
4952    sphere = get_sphere(name='symmetric362')
4953    from dipy.direction import peaks_from_model
4954    if peak_indices is None:
4955        # problems with multi-threading ...
4956        # see https://github.com/dipy/dipy/issues/2519
4957        if verbose:
4958            print("begin peaks",flush=True)
4959        mynump=1
4960        # if os.getenv("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"):
4961        #    mynump = os.environ['ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS']
4962        # current_openblas = os.environ.get('OPENBLAS_NUM_THREADS', '')
4963        # current_mkl = os.environ.get('MKL_NUM_THREADS', '')
4964        # os.environ['DIPY_OPENBLAS_NUM_THREADS'] = current_openblas
4965        # os.environ['DIPY_MKL_NUM_THREADS'] = current_mkl
4966        # os.environ['OPENBLAS_NUM_THREADS'] = '1'
4967        # os.environ['MKL_NUM_THREADS'] = '1'
4968        peak_indices = peaks_from_model(
4969            model=dti_model,
4970            data=dwi_data,
4971            sphere=sphere,
4972            relative_peak_threshold=.5,
4973            min_separation_angle=25,
4974            mask=dwi_mask,
4975            npeaks=3, return_odf=False,
4976            return_sh=False,
4977            parallel=int(mynump) > 1,
4978            num_processes=int(mynump)
4979            )
4980        if False:
4981            if 'DIPY_OPENBLAS_NUM_THREADS' in os.environ:
4982                os.environ['OPENBLAS_NUM_THREADS'] = \
4983                    os.environ.pop('DIPY_OPENBLAS_NUM_THREADS', '')
4984                if os.environ['OPENBLAS_NUM_THREADS'] in ['', None]:
4985                    os.environ.pop('OPENBLAS_NUM_THREADS', '')
4986            if 'DIPY_MKL_NUM_THREADS' in os.environ:
4987                os.environ['MKL_NUM_THREADS'] = \
4988                    os.environ.pop('DIPY_MKL_NUM_THREADS', '')
4989                if os.environ['MKL_NUM_THREADS'] in ['', None]:
4990                    os.environ.pop('MKL_NUM_THREADS', '')
4991
4992    if label_image is None or seed_labels is None:
4993        seed_mask = fa.numpy().copy()
4994        seed_mask[seed_mask >= fa_thresh] = 1
4995        seed_mask[seed_mask < fa_thresh] = 0
4996    else:
4997        labels = label_image.numpy()
4998        seed_mask = labels * 0
4999        for u in seed_labels:
5000            seed_mask[ labels == u ] = 1
5001    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
5002    from dipy.tracking.local_tracking import LocalTracking
5003    from dipy.tracking.streamline import Streamlines
5004    if verbose:
5005        print("streamlines begin ...", flush=True)
5006    streamlines_generator = LocalTracking(
5007        peak_indices, stopping_criterion, seeds, affine=affine, step_size=step_size)
5008    streamlines = Streamlines(streamlines_generator)
5009    from dipy.io.stateful_tractogram import Space, StatefulTractogram
5010    from dipy.io.streamline import save_tractogram
5011    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
5012    if verbose:
5013        print("streamlines done", flush=True)
5014    return {
5015          'tractogram': sft,
5016          'streamlines': streamlines,
5017          'peak_indices': peak_indices
5018          }

Performs deterministic tractography from the DWI and returns a tractogram and path length data frame.

Arguments

dwi : an antsImage holding DWI acquisition

fa : an antsImage holding FA values

bvals : bvalues

bvecs : bvectors

num_processes : number of subprocesses

mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")

label_image : atlas labels

seed_labels : list of label numbers from the atlas labels

fa_thresh : FA threshold for masking and stopping (signature default 0.05)

seed_density : 1 default number of seeds per voxel

step_size : for tracking

peak_indices : pass these in, if they are previously estimated. otherwise, will compute on the fly (slow)

fit_method : string one of WLS LS NLLS or restore - see import dipy.reconst.dti as dti and help(dti.TensorModel)

verbose : boolean

Returns

dictionary holding tracts and stateful object.

Example

>>> import antspymm
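
A minimal usage sketch continuing the example above; the file names, and reading bvals/bvecs into arrays with dipy, are illustrative assumptions rather than part of the original example:

>>> import ants
>>> from dipy.io.gradients import read_bvals_bvecs
>>> dwi = ants.image_read('dwi_preprocessed.nii.gz')  # hypothetical motion-corrected DWI
>>> fa = ants.image_read('fa.nii.gz')                 # hypothetical FA map, e.g. from joint_dti_recon
>>> bvals, bvecs = read_bvals_bvecs('dwi.bval', 'dwi.bvec')
>>> trk = antspymm.dwi_deterministic_tracking(dwi, fa, bvals, bvecs,
...     fa_thresh=0.25, seed_density=1, verbose=True)
>>> streamlines = trk['streamlines']
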
def dwi_closest_peak_tracking( dwi, fa, bvals, bvecs, num_processes=1, mask=None, label_image=None, seed_labels=None, fa_thresh=0.05, seed_density=1, step_size=0.15, peak_indices=None, verbose=False):
5031def dwi_closest_peak_tracking(
5032    dwi,
5033    fa,
5034    bvals,
5035    bvecs,
5036    num_processes=1,
5037    mask=None,
5038    label_image = None,
5039    seed_labels = None,
5040    fa_thresh = 0.05,
5041    seed_density = 1,
5042    step_size = 0.15,
5043    peak_indices = None,
5044    verbose = False ):
5045    """
5046
5047    Performs deterministic tractography from the DWI and returns a tractogram
5048    and path length data frame.
5049
5050    Arguments
5051    ---------
5052
5053    dwi : an antsImage holding DWI acquisition
5054
5055    fa : an antsImage holding FA values
5056
5057    bvals : bvalues
5058
5059    bvecs : bvectors
5060
5061    num_processes : number of subprocesses
5062
5063    mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh
5064        and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
5065
5066    label_image : atlas labels
5067
5068    seed_labels : list of label numbers from the atlas labels
5069
5070    fa_thresh : FA threshold used for the seed mask (default 0.05)
5071
5072    seed_density : number of seeds per voxel (default 1)
5073
5074    step_size : step size for tracking (default 0.15)
5075
5076    peak_indices : accepted for interface compatibility; not currently used by
5077        this function
5078
5079    verbose : boolean
5080
5081    Returns
5082    -------
5083    dictionary holding tracts and stateful object.
5084
5085    Example
5086    -------
5087    >>> import antspymm
5088    """
5089    import os
5090    import re
5091    import nibabel as nib
5092    import numpy as np
5093    import ants
5094    from dipy.io.gradients import read_bvals_bvecs
5095    from dipy.core.gradients import gradient_table
5096    from dipy.tracking import utils
5097    import dipy.reconst.dti as dti
5098    from dipy.segment.clustering import QuickBundles
5099    from dipy.tracking.utils import path_length
5101    from dipy.data import small_sphere
5102    from dipy.direction import BootDirectionGetter, ClosestPeakDirectionGetter
5103    from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,
5104                                    auto_response_ssst)
5105    from dipy.reconst.shm import CsaOdfModel
5106    from dipy.tracking.local_tracking import LocalTracking
5107    from dipy.tracking.streamline import Streamlines
5108    from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
5109
5110    if verbose:
5111        print("begin tracking",flush=True)
5112
5113    affine = ants_to_nibabel_affine(dwi)
5114    if isinstance( bvals, str ) or isinstance( bvecs, str ):
5115        bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
5116    bvecs = repair_bvecs( bvecs )
5117    gtab = gradient_table(bvals, bvecs=bvecs, atol=2.0 )
5118    if mask is None:
5119        mask = ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")
5120    dwi_data = dwi.numpy()
5121    dwi_mask = mask.numpy() == 1
5122
5123
5124    response, ratio = auto_response_ssst(gtab, dwi_data, roi_radii=10, fa_thr=0.7)
5125    csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
5126    csd_fit = csd_model.fit(dwi_data, mask=dwi_mask)
5127    csa_model = CsaOdfModel(gtab, sh_order=6)
5128    gfa = csa_model.fit(dwi_data, mask=dwi_mask).gfa
5129    stopping_criterion = ThresholdStoppingCriterion(gfa, .25)
5130
5131
5132    if label_image is None or seed_labels is None:
5133        seed_mask = fa.numpy().copy()
5134        seed_mask[seed_mask >= fa_thresh] = 1
5135        seed_mask[seed_mask < fa_thresh] = 0
5136    else:
5137        labels = label_image.numpy()
5138        seed_mask = labels * 0
5139        for u in seed_labels:
5140            seed_mask[ labels == u ] = 1
5141    seeds = utils.seeds_from_mask(seed_mask, affine=affine, density=seed_density)
5142    if verbose:
5143        print("streamlines begin ...", flush=True)
5144
5145    pmf = csd_fit.odf(small_sphere).clip(min=0)
5146    if verbose:
5147        print("ClosestPeakDirectionGetter begin ...", flush=True)
5148    peak_dg = ClosestPeakDirectionGetter.from_pmf(pmf, max_angle=30.,
5149                                                sphere=small_sphere)
5150    if verbose:
5151        print("local tracking begin ...", flush=True)
5152    streamlines_generator = LocalTracking(peak_dg, stopping_criterion, seeds,
5153                                            affine, step_size=step_size)
5154    streamlines = Streamlines(streamlines_generator)
5155    from dipy.io.stateful_tractogram import Space, StatefulTractogram
5156    from dipy.io.streamline import save_tractogram
5157    sft = None # StatefulTractogram(streamlines, dwi_img, Space.RASMM)
5158    if verbose:
5159        print("streamlines done", flush=True)
5160    return {
5161          'tractogram': sft,
5162          'streamlines': streamlines
5163          }

Performs deterministic tractography from the DWI and returns a tractogram and path length data frame.

Arguments

dwi : an antsImage holding DWI acquisition

fa : an antsImage holding FA values

bvals : bvalues

bvecs : bvectors

num_processes : number of subprocesses

mask : mask within which to do tracking - if None, we will make a mask using the fa_thresh and the code ants.threshold_image( fa, fa_thresh, 2.0 ).iMath("GetLargestComponent")

label_image : atlas labels

seed_labels : list of label numbers from the atlas labels

fa_thresh : FA threshold used for the seed mask (default 0.05)

seed_density : number of seeds per voxel (default 1)

step_size : step size for tracking (default 0.15)

peak_indices : accepted for interface compatibility; not currently used by this function

verbose : boolean

Returns

dictionary holding tracts and stateful object.

Example

>>> import antspymm
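
A minimal usage sketch continuing the example above; file names are hypothetical. This variant fits a constrained spherical deconvolution model internally and tracks with a closest-peak direction getter:

>>> import ants
>>> from dipy.io.gradients import read_bvals_bvecs
>>> dwi = ants.image_read('dwi_preprocessed.nii.gz')  # hypothetical motion-corrected DWI
>>> fa = ants.image_read('fa.nii.gz')                 # hypothetical FA map
>>> bvals, bvecs = read_bvals_bvecs('dwi.bval', 'dwi.bvec')
>>> trk = antspymm.dwi_closest_peak_tracking(dwi, fa, bvals, bvecs,
...     fa_thresh=0.05, step_size=0.15, verbose=True)
>>> streamlines = trk['streamlines']
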
def dwi_streamline_pairwise_connectivity(streamlines, label_image, labels_to_connect=[1, None], verbose=False):
5165def dwi_streamline_pairwise_connectivity( streamlines, label_image, labels_to_connect=[1,None], verbose=False ):
5166    """
5167
5168    Return streamlines connecting all of the regions in the label set. Ideal
5169    for just 2 regions.
5170
5171    Arguments
5172    ---------
5173
5174    streamlines : streamline object from dipy
5175
5176    label_image : atlas labels
5177
5178    labels_to_connect : list of 2 labels or [label,None]
5179
5180    verbose : boolean
5181
5182    Returns
5183    -------
5184    the subset of streamlines and a streamline count
5185
5186    Example
5187    -------
5188    >>> import antspymm
5189    """
5190    from dipy.tracking.streamline import Streamlines
5191    keep_streamlines = Streamlines()
5192
5193    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
5194
5195    lin_T, offset = utils._mapping_to_voxel(affine)
5196    label_image_np = label_image.numpy()
5197    def check_it( sl, target_label, label_image, index, full=False ):
5198        if full:
5199            maxind=sl.shape[0]
5200            for index in range(maxind):
5201                pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
5202                mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
5203                if mylab == target_label[0] or mylab == target_label[1]:
5204                    return { 'ok': True, 'label':mylab }
5205        else:
5206            pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
5207            mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
5208            if mylab == target_label[0] or mylab == target_label[1]:
5209                return { 'ok': True, 'label':mylab }
5210        return { 'ok': False, 'label':None }
5211    ct=0
5212    for k in range( len( streamlines ) ):
5213        sl = streamlines[k]
5214        mycheck = check_it( sl, labels_to_connect, label_image_np, index=0, full=True )
5215        if mycheck['ok']:
5216            otherind=1
5217            if mycheck['label'] == labels_to_connect[1]:
5218                otherind=0
5219            lsl = len( sl )-1
5220            pt = utils._to_voxel_coordinates(sl[lsl,:], lin_T, offset)
5221            mylab_end = (label_image_np[ pt[0], pt[1], pt[2] ]).astype(int)
5222            accept_point = mylab_end == labels_to_connect[otherind]
5223            if verbose and accept_point:
5224                print( mylab_end )
5225            if labels_to_connect[1] is None:
5226                accept_point = mylab_end != 0
5227            if accept_point:
5228                keep_streamlines.append(sl)
5229                ct=ct+1
5230    return { 'streamlines': keep_streamlines, 'count': ct }

Return streamlines connecting all of the regions in the label set. Ideal for just 2 regions.

Arguments

streamlines : streamline object from dipy

label_image : atlas labels

labels_to_connect : list of 2 labels or [label,None]

verbose : boolean

Returns

the subset of streamlines and a streamline count

Example

>>> import antspymm
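
A minimal usage sketch; the streamlines would come from one of the tracking functions above, and the label image and label values are hypothetical:

>>> import ants
>>> labels = ants.image_read('atlas_labels_in_dwi_space.nii.gz')  # hypothetical atlas labels
>>> sel = antspymm.dwi_streamline_pairwise_connectivity(
...     trk['streamlines'], labels, labels_to_connect=[1017, 2017])
>>> print(sel['count'])
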
def dwi_streamline_connectivity(streamlines, label_image, label_dataframe, verbose=False):
5287def dwi_streamline_connectivity(
5288    streamlines,
5289    label_image,
5290    label_dataframe,
5291    verbose = False ):
5292    """
5293
5294    Summarize network connectivity of the input streamlines between all of the
5295    regions in the label set.
5296
5297    Arguments
5298    ---------
5299
5300    streamlines : streamline object from dipy
5301
5302    label_image : atlas labels
5303
5304    label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)
5305
5306    verbose : boolean
5307
5308    Returns
5309    -------
5310    dictionary holding summary connection statistics in wide format and matrix format.
5311
5312    Example
5313    -------
5314    >>> import antspymm
5315    """
5316    import os
5317    import re
5318    import nibabel as nib
5319    import numpy as np
5320    import ants
5321    from dipy.io.gradients import read_bvals_bvecs
5322    from dipy.core.gradients import gradient_table
5323    from dipy.tracking import utils
5324    import dipy.reconst.dti as dti
5325    from dipy.segment.clustering import QuickBundles
5326    from dipy.tracking.utils import path_length
5327    from dipy.tracking.local_tracking import LocalTracking
5328    from dipy.tracking.streamline import Streamlines
5342    volUnit = np.prod( ants.get_spacing( label_image ) )
5343    labels = label_image.numpy()
5344
5345    affine = ants_to_nibabel_affine(label_image) # to_nibabel(label_image).affine
5346
5347    import numpy as np
5348    from dipy.io.image import load_nifti_data, load_nifti, save_nifti
5349    import pandas as pd
5350    ulabs = label_dataframe['Label']
5351    labels_to_connect = ulabs[ulabs > 0]
5352    Ctdf = None
5353    lin_T, offset = utils._mapping_to_voxel(affine)
5354    label_image_np = label_image.numpy()
5355    def check_it( sl, target_label, label_image, index, not_label = None ):
5356        pt = utils._to_voxel_coordinates(sl[index,:], lin_T, offset)
5357        mylab = (label_image[ pt[0], pt[1], pt[2] ]).astype(int)
5358        if not_label is None:
5359            if ( mylab == target_label ).sum() > 0 :
5360                return { 'ok': True, 'label':mylab }
5361        else:
5362            if ( mylab == target_label ).sum() > 0 and ( mylab == not_label ).sum() == 0:
5363                return { 'ok': True, 'label':mylab }
5364        return { 'ok': False, 'label':None }
5365    ct=0
5366    which = lambda lst:list(np.where(lst)[0])
5367    myCount = np.zeros( [len(ulabs),len(ulabs)])
5368    for k in range( len( streamlines ) ):
5369            sl = streamlines[k]
5370            mycheck = check_it( sl, labels_to_connect, label_image_np, index=0 )
5371            if mycheck['ok']:
5372                exclabel=mycheck['label']
5373                lsl = len( sl )-1
5374                mycheck2 = check_it( sl, labels_to_connect, label_image_np, index=lsl, not_label=exclabel )
5375                if mycheck2['ok']:
5376                    myCount[ulabs == mycheck['label'],ulabs == mycheck2['label']]+=1
5377                    ct=ct+1
5378    Ctdf = label_dataframe.copy()
5379    for k in range(len(ulabs)):
5380            nn3 = "CnxCount"+str(k).zfill(3)
5381            Ctdf.insert(Ctdf.shape[1], nn3, myCount[k,:] )
5382    Ctdfw = antspyt1w.merge_hierarchical_csvs_to_wide_format( { 'networkc': Ctdf },  Ctdf.keys()[2:Ctdf.shape[1]] )
5383    return { 'connectivity_matrix' :  myCount, 'connectivity_wide' : Ctdfw }

Summarize network connectivity of the input streamlines between all of the regions in the label set.

Arguments

streamlines : streamline object from dipy

label_image : atlas labels

label_dataframe : pandas dataframe containing descriptions for the labels in antspy style (Label,Description columns)

verbose : boolean

Returns

dictionary holding summary connection statistics in wide format and matrix format.

Example

>>> import antspymm
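
A minimal usage sketch; the label image and the two-row label_dataframe below are hypothetical stand-ins for a real atlas and its Label/Description table:

>>> import ants
>>> import pandas as pd
>>> labels = ants.image_read('atlas_labels_in_dwi_space.nii.gz')  # hypothetical atlas labels
>>> ldf = pd.DataFrame({'Label': [1017, 2017],
...                     'Description': ['left precuneus', 'right precuneus']})
>>> cnx = antspymm.dwi_streamline_connectivity(trk['streamlines'], labels, ldf)
>>> print(cnx['connectivity_matrix'].shape)
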
def hierarchical_modality_summary( target_image, hier, transformlist, modality_name, return_keys=['Mean', 'Volume'], verbose=False):
5526def hierarchical_modality_summary(
5527    target_image,
5528    hier,
5529    transformlist,
5530    modality_name,
5531    return_keys = ["Mean","Volume"],
5532    verbose = False ):
5533    """
5534
5535    Use output of antspyt1w.hierarchical to summarize a modality
5536
5537    Arguments
5538    ---------
5539
5540    target_image : the image to summarize - should be brain extracted
5541
5542    hier : dictionary holding antspyt1w.hierarchical output
5543
5544    transformlist : spatial transformations mapping from T1 to this modality (e.g. from ants.registration)
5545
5546    modality_name : adds the modality name to the data frame columns
5547
5548    return_keys = ["Mean","Volume"] keys to return
5549
5550    verbose : boolean
5551
5552    Returns
5553    -------
5554    data frame holding summary statistics in wide format
5555
5556    Example
5557    -------
5558    >>> import antspymm
5559    """
5560    dfout = pd.DataFrame()
5561    def myhelper( target_image, seg, mytx, mapname, modname, mydf, extra='', verbose=False ):
5562        if verbose:
5563            print( mapname )
5564        target_image_mask = ants.image_clone( target_image ) * 0.0
5565        target_image_mask[ target_image != 0 ] = 1
5566        cortmapped = ants.apply_transforms(
5567            target_image,
5568            seg,
5569            mytx, interpolator='nearestNeighbor' ) * target_image_mask
5570        mapped = antspyt1w.map_intensity_to_dataframe(
5571            mapname,
5572            target_image,
5573            cortmapped )
5574        mapped.iloc[:,1] = modname + '_' + extra + mapped.iloc[:,1]
5575        mappedw = antspyt1w.merge_hierarchical_csvs_to_wide_format(
5576            { 'x' : mapped},
5577            col_names = return_keys )
5578        if verbose:
5579            print( mappedw.keys() )
5580        if mydf.shape[0] > 0:
5581            mydf = pd.concat( [ mydf, mappedw], axis=1, ignore_index=False )
5582        else:
5583            mydf = mappedw
5584        return mydf
5585    if hier['dkt_parc']['dkt_cortex'] is not None:
5586        dfout = myhelper( target_image, hier['dkt_parc']['dkt_cortex'], transformlist,
5587            "dkt", modality_name, dfout, extra='', verbose=verbose )
5588    if hier['deep_cit168lab'] is not None:
5589        dfout = myhelper( target_image, hier['deep_cit168lab'], transformlist,
5590            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='deep_', verbose=verbose )
5591    if hier['cit168lab'] is not None:
5592        dfout = myhelper( target_image, hier['cit168lab'], transformlist,
5593            "CIT168_Reinf_Learn_v1_label_descriptions_pad", modality_name, dfout, extra='', verbose=verbose  )
5594    if hier['bf'] is not None:
5595        dfout = myhelper( target_image, hier['bf'], transformlist,
5596            "nbm3CH13", modality_name, dfout, extra='', verbose=verbose  )
5597    # if hier['mtl'] is not None:
5598    #    dfout = myhelper( target_image, hier['mtl'], reg,
5599    #        "mtl_description", modality_name, dfout, extra='', verbose=verbose  )
5600    return dfout

Use output of antspyt1w.hierarchical to summarize a modality

Arguments

target_image : the image to summarize - should be brain extracted

hier : dictionary holding antspyt1w.hierarchical output

transformlist : spatial transformations mapping from T1 to this modality (e.g. from ants.registration)

modality_name : adds the modality name to the data frame columns

return_keys = ["Mean","Volume"] keys to return

verbose : boolean

Returns

data frame holding summary statistics in wide format

Example

>>> import antspymm
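
A minimal usage sketch; the T1 and FLAIR file names, the antspyt1w.hierarchical call, its 'brain_n4_dnz' output key, and the registration step are assumed upstream inputs, not part of the original example:

>>> import ants
>>> import antspyt1w
>>> t1 = ants.image_read('t1.nii.gz')              # hypothetical T1
>>> flair = ants.image_read('flair_brain.nii.gz')  # hypothetical brain-extracted modality to summarize
>>> hier = antspyt1w.hierarchical(t1, output_prefix='/tmp/hier')  # assumed antspyt1w interface
>>> reg = ants.registration(flair, hier['brain_n4_dnz'], 'antsRegistrationSyNQuickRepro[s]')
>>> df = antspymm.hierarchical_modality_summary(flair, hier, reg['fwdtransforms'], 'FLAIR')
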
def tra_initializer( fixed, moving, n_simulations=32, max_rotation=30, transform=['rigid'], compreg=None, random_seed=42, verbose=False):
5612def tra_initializer( fixed, moving, n_simulations=32, max_rotation=30,
5613    transform=['rigid'], compreg=None, random_seed=42, verbose=False ):
5614    """
5615    multi-start multi-transform registration solution - based on ants.registration
5616
5617    fixed: fixed image
5618
5619    moving: moving image
5620
5621    n_simulations : number of simulations
5622
5623    max_rotation : maximum rotation angle
5624
5625    transform : list of transforms to loop through
5626
5627    compreg : registration results against which to compare
5628
5629    random_seed : random seed for reproducibility
5630
5631    verbose : boolean
5632
5633    """
5634    import random
5635    if random_seed is not None:
5636        random.seed(random_seed)
5637    if True:
5638        output_directory = tempfile.mkdtemp()
5639        output_directory_w = output_directory + "/tra_reg/"
5640        os.makedirs(output_directory_w,exist_ok=True)
5641        bestmi = math.inf
5642        bestvar = 0.0
5643        myorig = list(ants.get_origin( fixed ))
5644        mymax = 0;
5645        for k in range(len( myorig ) ):
5646            if abs(myorig[k]) > mymax:
5647                mymax = abs(myorig[k])
5648        maxtrans = mymax * 0.05
5649        if compreg is None:
5650            bestreg=ants.registration( fixed,moving,'Translation',
5651                outprefix=output_directory_w+"trans")
5652            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
5653        else :
5654            bestreg=compreg
5655            initx = ants.read_transform( bestreg['fwdtransforms'][0] )
5656        for mytx in transform:
5657            regtx = 'antsRegistrationSyNRepro[r]'
5658            with tempfile.NamedTemporaryFile(suffix='.h5') as tp:
5659                if mytx == 'translation':
5660                    regtx = 'Translation'
5661                    rRotGenerator = ants.contrib.RandomTranslate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
5662                elif mytx == 'affine':
5663                    regtx = 'Affine'
5664                    rRotGenerator = ants.contrib.RandomRotate3D( ( maxtrans*(-1.0), maxtrans ), reference=fixed )
5665                else:
5666                    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rotation*(-1.0), max_rotation ), reference=fixed )
5667                for k in range(n_simulations):
5668                    simtx = ants.compose_ants_transforms( [rRotGenerator.transform(), initx] )
5669                    ants.write_transform( simtx, tp.name )
5670                    if k > 0:
5671                        reg = ants.registration( fixed, moving, regtx,
5672                            initial_transform=tp.name,
5673                            outprefix=output_directory_w+"reg"+str(k),
5674                            verbose=False )
5675                    else:
5676                        reg = ants.registration( fixed, moving,
5677                            regtx,
5678                            outprefix=output_directory_w+"reg"+str(k),
5679                            verbose=False )
5680                    mymi = math.inf
5681                    temp = reg['warpedmovout']
5682                    myvar = temp.numpy().var()
5683                    if verbose:
5684                        print( str(k) + " : " + regtx  + " : " + mytx + " _var_ " + str( myvar ) )
5685                    if myvar > 0 :
5686                        mymi = ants.image_mutual_information( fixed, temp )
5687                        if mymi < bestmi:
5688                            if verbose:
5689                                print( "mi @ " + str(k) + " : " + str(mymi), flush=True)
5690                            bestmi = mymi
5691                            bestreg = reg
5692                            bestvar = myvar
5693        if bestvar == 0.0 and compreg is not None:
5694            return compreg        
5695        return bestreg

multi-start multi-transform registration solution - based on ants.registration

fixed: fixed image

moving: moving image

n_simulations : number of simulations

max_rotation : maximum rotation angle

transform : list of transforms to loop through

compreg : registration results against which to compare

random_seed : random seed for reproducibility

verbose : boolean
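
A minimal usage sketch; the images are hypothetical and the returned object has the structure of an ants.registration result (fwdtransforms, warpedmovout, ...):

>>> import ants
>>> import antspymm
>>> fixed = ants.image_read('nm_avg.nii.gz')    # hypothetical fixed image
>>> moving = ants.image_read('t1_slab.nii.gz')  # hypothetical moving image
>>> reg = antspymm.tra_initializer(fixed, moving, n_simulations=8, max_rotation=20,
...     transform=['rigid'], verbose=True)
>>> warped = reg['warpedmovout']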

def neuromelanin( list_nm_images, t1, t1_head, t1lab, brain_stem_dilation=8, bias_correct=True, denoise=None, srmodel=None, target_range=[0, 1], poly_order='hist', normalize_nm=False, verbose=False):
5697def neuromelanin( list_nm_images, t1, t1_head, t1lab, brain_stem_dilation=8,
5698    bias_correct=True,
5699    denoise=None,
5700    srmodel=None,
5701    target_range=[0,1],
5702    poly_order='hist',
5703    normalize_nm = False,
5704    verbose=False ) :
5705
5706  """
5707  Outputs the averaged and registered neuromelanin image, and neuromelanin labels
5708
5709  Arguments
5710  ---------
5711  list_nm_images : list of ANTsImages
5712    list of neuromelanin repeat images
5713
5714  t1 : ANTsImage
5715    input 3-D T1 brain image
5716
5717  t1_head : ANTsImage
5718    input 3-D T1 head image
5719
5720  t1lab : ANTsImage
5721    t1 labels that will be propagated to the NM
5722
5723  brain_stem_dilation : integer default 8
5724    dilates the brain stem mask to better match coverage of NM
5725
5726  bias_correct : boolean
5727
5728  denoise : None or integer
5729
5730  srmodel : None -- this is a work in progress feature, probably not optimal
5731
5732  target_range : 2-element tuple
5733        a tuple or array defining the (min, max) of the input image
5734        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
5735        intensity. This range should match the mapping used in the training
5736        of the network.
5737
5738  poly_order : if not None, will fit a global regression model to map
5739      intensity back to original histogram space; if 'hist' will match
5740      by histogram matching - ants.histogram_match_image
5741
5742  normalize_nm : boolean - WIP not validated
5743
5744  verbose : boolean
5745
5746  Returns
5747  ---------
5748  Averaged and registered neuromelanin image and neuromelanin labels and wide csv
5749
5750  """
5751
5752  fnt=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz" )
5753  fntNM=os.path.expanduser("~/.antspymm/CIT168_T1w_700um_pad_adni_NM_norm_avg.nii.gz" )
5754  fntbst=os.path.expanduser("~/.antspyt1w/CIT168_T1w_700um_pad_adni_brainstem.nii.gz")
5755  fnslab=os.path.expanduser("~/.antspyt1w/CIT168_MT_Slab_adni.nii.gz")
5756  fntseg=os.path.expanduser("~/.antspyt1w/det_atlas_25_pad_LR_adni.nii.gz")
5757
5758  template = mm_read( fnt )
5759  templateNM = ants.iMath( mm_read( fntNM ), "Normalize" )
5760  templatebstem = mm_read( fntbst ).threshold_image( 1, 1000 )
5761  # reg = ants.registration( t1, template, 'antsRegistrationSyNQuickRepro[s]' )
5762  reg = ants.registration( t1, template, 'antsRegistrationSyNQuickRepro[s]' )
5763  # map NM avg to t1 for neuromelanin processing
5764  nmavg2t1 = ants.apply_transforms( t1, templateNM,
5765    reg['fwdtransforms'], interpolator='linear' )
5766  slab2t1 = ants.threshold_image( nmavg2t1, "Otsu", 2 ).threshold_image(1,2).iMath("MD",1).iMath("FillHoles")
5767  # map brain stem and slab to t1 for neuromelanin processing
5768  bstem2t1 = ants.apply_transforms( t1, templatebstem,
5769    reg['fwdtransforms'],
5770    interpolator='nearestNeighbor' ).iMath("MD",1)
5771  slab2t1B = ants.apply_transforms( t1, mm_read( fnslab ),
5772    reg['fwdtransforms'], interpolator = 'nearestNeighbor')
5773  bstem2t1 = ants.crop_image( bstem2t1, slab2t1 )
5774  cropper = ants.decrop_image( bstem2t1, slab2t1 ).iMath("MD",brain_stem_dilation)
5775
5776  # Average images in image_list
5777  nm_avg = list_nm_images[0]*0.0
5778  for k in range(len( list_nm_images )):
5779    if denoise is not None:
5780        list_nm_images[k] = ants.denoise_image( list_nm_images[k],
5781            shrink_factor=1,
5782            p=denoise,
5783            r=denoise+1,
5784            noise_model='Gaussian' )
5785    if bias_correct :
5786        n4mask = ants.threshold_image( ants.iMath(list_nm_images[k], "Normalize" ), 0.05, 1 )
5787        list_nm_images[k] = ants.n4_bias_field_correction( list_nm_images[k], mask=n4mask )
5788    nm_avg = nm_avg + ants.resample_image_to_target( list_nm_images[k], nm_avg ) / len( list_nm_images )
5789
5790  if verbose:
5791      print("Register each nm image in list_nm_images to the averaged nm image (avg)")
5792  nm_avg_new = nm_avg * 0.0
5793  txlist = []
5794  for k in range(len( list_nm_images )):
5795    if verbose:
5796        print(str(k) + " of " + str(len( list_nm_images ) ) )
5797    current_image = ants.registration( list_nm_images[k], nm_avg,
5798        type_of_transform = 'antsRegistrationSyNRepro[r]' )
5799    txlist.append( current_image['fwdtransforms'][0] )
5800    current_image = current_image['warpedfixout']
5801    nm_avg_new = nm_avg_new + current_image / len( list_nm_images )
5802  nm_avg = nm_avg_new
5803
5804  if verbose:
5805      print("do slab registration to map anatomy to NM space")
5806  t1c = ants.crop_image( t1_head, slab2t1 ).iMath("Normalize") # old way
5807  nmavg2t1c = ants.crop_image( nmavg2t1, slab2t1 ).iMath("Normalize")
5808  # slabreg = ants.registration( nm_avg, nmavg2t1c, 'antsRegistrationSyNRepro[r]' )
5809  slabreg = tra_initializer( nm_avg, t1c, verbose=verbose )
5810  if False:
5811      slabregT1 = tra_initializer( nm_avg, t1c, verbose=verbose  )
5812      miNM = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
5813            ants.iMath(slabreg['warpedmovout'],"Normalize") )
5814      miT1 = ants.image_mutual_information( ants.iMath(nm_avg,"Normalize"),
5815            ants.iMath(slabregT1['warpedmovout'],"Normalize") )
5816      if miT1 < miNM:
5817        slabreg = slabregT1
5818  labels2nm = ants.apply_transforms( nm_avg, t1lab, slabreg['fwdtransforms'],
5819    interpolator = 'genericLabel' )
5820  cropper2nm = ants.apply_transforms( nm_avg, cropper, slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
5821  nm_avg_cropped = ants.crop_image( nm_avg, cropper2nm )
5822
5823  if verbose:
5824      print("now map these labels to each individual nm")
5825  crop_mask_list = []
5826  crop_nm_list = []
5827  for k in range(len( list_nm_images )):
5828      concattx = []
5829      concattx.append( txlist[k] )
5830      concattx.append( slabreg['fwdtransforms'][0] )
5831      cropmask = ants.apply_transforms( list_nm_images[k], cropper,
5832        concattx, interpolator = 'nearestNeighbor' )
5833      crop_mask_list.append( cropmask )
5834      temp = ants.crop_image( list_nm_images[k], cropmask )
5835      crop_nm_list.append( temp )
5836
5837  if srmodel is not None:
5838      if verbose:
5839          print( " start sr " + str(len( crop_nm_list )) )
5840      for k in range(len( crop_nm_list )):
5841          if verbose:
5842              print( " do sr " + str(k) )
5843              print( crop_nm_list[k] )
5844          temp = antspynet.apply_super_resolution_model_to_image(
5845                crop_nm_list[k], srmodel, target_range=target_range,
5846                regression_order=None )
5847          if poly_order is not None:
5848              bilin = ants.resample_image_to_target( crop_nm_list[k], temp )
5849              if poly_order == 'hist':
5850                  temp = ants.histogram_match_image( temp, bilin )
5851              else:
5852                  temp = antspynet.regression_match_image( temp, bilin, poly_order = poly_order )
5853          crop_nm_list[k] = temp
5854
5855  nm_avg_cropped = crop_nm_list[0]*0.0
5856  if verbose:
5857      print( "cropped average" )
5858      print( nm_avg_cropped )
5859  for k in range(len( crop_nm_list )):
5860      nm_avg_cropped = nm_avg_cropped + ants.apply_transforms( nm_avg_cropped,
5861        crop_nm_list[k], txlist[k] ) / len( crop_nm_list )
5862  for loop in range( 3 ):
5863      nm_avg_cropped_new = nm_avg_cropped * 0.0
5864      for k in range(len( crop_nm_list )):
5865            myreg = ants.registration(
5866                ants.iMath(nm_avg_cropped,"Normalize"),
5867                ants.iMath(crop_nm_list[k],"Normalize"),
5868                'antsRegistrationSyNRepro[r]' )
5869            warpednext = ants.apply_transforms(
5870                nm_avg_cropped_new,
5871                crop_nm_list[k],
5872                myreg['fwdtransforms'] )
5873            nm_avg_cropped_new = nm_avg_cropped_new + warpednext
5874      nm_avg_cropped = nm_avg_cropped_new / len( crop_nm_list )
5875
5876  slabregUpdated = tra_initializer( nm_avg_cropped, t1c, compreg=slabreg,verbose=verbose  )
5877  tempOrig = ants.apply_transforms( nm_avg_cropped_new, t1c, slabreg['fwdtransforms'] )
5878  tempUpdate = ants.apply_transforms( nm_avg_cropped_new, t1c, slabregUpdated['fwdtransforms'] )
5879  miUpdate = ants.image_mutual_information(
5880    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempUpdate,"Normalize") )
5881  miOrig = ants.image_mutual_information(
5882    ants.iMath(nm_avg_cropped,"Normalize"), ants.iMath(tempOrig,"Normalize") )
5883  if miUpdate < miOrig :
5884      slabreg = slabregUpdated
5885
5886  if normalize_nm:
5887      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
5888      nm_avg_cropped = ants.iMath( nm_avg_cropped, "TruncateIntensity",0.05,0.95)
5889      nm_avg_cropped = ants.iMath( nm_avg_cropped, "Normalize" )
5890
5891  labels2nm = ants.apply_transforms( nm_avg_cropped, t1lab,
5892        slabreg['fwdtransforms'], interpolator='nearestNeighbor' )
5893
5894  # fix the reference region - keep top two parts
5895  def get_biggest_part( x, labeln ):
5896      temp33 = ants.threshold_image( x, labeln, labeln ).iMath("GetLargestComponent")
5897      x[ x == labeln] = 0
5898      x[ temp33 == 1 ] = labeln
5899
5900  get_biggest_part( labels2nm, 33 )
5901  get_biggest_part( labels2nm, 34 )
5902
5903  if verbose:
5904      print( "map summary measurements to wide format" )
5905  nmdf = antspyt1w.map_intensity_to_dataframe(
5906          'CIT168_Reinf_Learn_v1_label_descriptions_pad',
5907          nm_avg_cropped,
5908          labels2nm)
5909  if verbose:
5910      print( "merge to wide format" )
5911  nmdf_wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
5912              {'NM' : nmdf},
5913              col_names = ['Mean'] )
5914
5915  rr_mask = ants.mask_image( labels2nm, labels2nm, [33,34] , binarize=True )
5916  sn_mask = ants.mask_image( labels2nm, labels2nm, [7,9,23,25] , binarize=True )
5917  nmavgsnr = mask_snr( nm_avg_cropped, rr_mask, sn_mask, bias_correct = False )
5918
5919  snavg = nm_avg_cropped[ sn_mask == 1].mean()
5920  rravg = nm_avg_cropped[ rr_mask == 1].mean()
5921  snstd = nm_avg_cropped[ sn_mask == 1].std()
5922  rrstd = nm_avg_cropped[ rr_mask == 1].std()
5923  vol_element = np.prod( ants.get_spacing(sn_mask) )
5924  snvol = vol_element * sn_mask.sum()
5925
5926  # get the mean voxel position of the SN
5927  if snvol > 0:
5928      sn_z = ants.transform_physical_point_to_index( sn_mask, ants.get_center_of_mass(sn_mask ))[2]
5929      sn_z = sn_z/sn_mask.shape[2] # around 0.5 would be nice
5930  else:
5931      sn_z = math.nan
5932
5933  nm_evr = 0.0
5934  if cropper2nm.sum() > 0:
5935    nm_evr = antspyt1w.patch_eigenvalue_ratio( nm_avg, 512, [6,6,6], 
5936        evdepth = 0.9, mask=cropper2nm )
5937
5938  simg = ants.smooth_image( nm_avg_cropped, np.min(ants.get_spacing(nm_avg_cropped)) )
5939  k = 2.0
5940  rrthresh = (rravg + k * rrstd)
5941  nmabovekthresh_mask = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
5942  snvolabovethresh = vol_element * nmabovekthresh_mask.sum()
5943  snintmeanabovethresh = float( ( simg * nmabovekthresh_mask ).mean() )
5944  snintsumabovethresh = float( ( simg * nmabovekthresh_mask ).sum() )
5945
5946  k = 3.0
5947  rrthresh = (rravg + k * rrstd)
5948  nmabovekthresh_mask3 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
5949  snvolabovethresh3 = vol_element * nmabovekthresh_mask3.sum()
5950
5951  k = 1.0
5952  rrthresh = (rravg + k * rrstd)
5953  nmabovekthresh_mask1 = sn_mask * ants.threshold_image( simg, rrthresh, math.inf)
5954  snvolabovethresh1 = vol_element * nmabovekthresh_mask1.sum()
5955  
5956  if verbose:
5957    print( "nm vol @2std above rrmean: " + str( snvolabovethresh ) )
5958    print( "nm intmean @2std above rrmean: " + str( snintmeanabovethresh ) )
5959    print( "nm intsum @2std above rrmean: " + str( snintsumabovethresh ) )
5960    print( "nm done" )
5961
5962  return convert_np_in_dict( {
5963      'NM_avg' : nm_avg,
5964      'NM_avg_cropped' : nm_avg_cropped,
5965      'NM_labels': labels2nm,
5966      'NM_cropped': crop_nm_list,
5967      'NM_midbrainROI': cropper2nm,
5968      'NM_dataframe': nmdf,
5969      'NM_dataframe_wide': nmdf_wide,
5970      't1_to_NM': slabreg['warpedmovout'],
5971      't1_to_NM_transform' : slabreg['fwdtransforms'],
5972      'NM_avg_signaltonoise' : nmavgsnr,
5973      'NM_avg_substantianigra' : snavg,
5974      'NM_std_substantianigra' : snstd,
5975      'NM_volume_substantianigra' : snvol,
5976      'NM_volume_substantianigra_1std' : snvolabovethresh1,
5977      'NM_volume_substantianigra_2std' : snvolabovethresh,
5978      'NM_intmean_substantianigra_2std' : snintmeanabovethresh,
5979      'NM_intsum_substantianigra_2std' : snintsumabovethresh,
5980      'NM_volume_substantianigra_3std' : snvolabovethresh3,
5981      'NM_avg_refregion' : rravg,
5982      'NM_std_refregion' : rrstd,
5983      'NM_min' : nm_avg_cropped.min(),
5984      'NM_max' : nm_avg_cropped.max(),
5985      'NM_mean' : nm_avg_cropped.numpy().mean(),
5986      'NM_sd' : np.std( nm_avg_cropped.numpy() ),
5987      'NM_q0pt05' : np.quantile( nm_avg_cropped.numpy(), 0.05 ),
5988      'NM_q0pt10' : np.quantile( nm_avg_cropped.numpy(), 0.10 ),
5989      'NM_q0pt90' : np.quantile( nm_avg_cropped.numpy(), 0.90 ),
5990      'NM_q0pt95' : np.quantile( nm_avg_cropped.numpy(), 0.95 ),
5991      'NM_substantianigra_z_coordinate' : sn_z,
5992      'NM_evr' : nm_evr,
5993      'NM_count': len( list_nm_images )
5994       } )

Outputs the averaged and registered neuromelanin image, and neuromelanin labels

Arguments

list_nm_images : list of ANTsImages - list of neuromelanin repeat images

t1 : ANTsImage input 3-D T1 brain image

t1_head : ANTsImage input 3-D T1 head image

t1lab : ANTsImage t1 labels that will be propagated to the NM

brain_stem_dilation : integer default 8 dilates the brain stem mask to better match coverage of NM

bias_correct : boolean

denoise : None or integer

srmodel : None -- this is a work in progress feature, probably not optimal

target_range : 2-element tuple a tuple or array defining the (min, max) of the input image (e.g., [-127.5, 127.5] or [0,1]). Output images will be scaled back to original intensity. This range should match the mapping used in the training of the network.

poly_order : if not None, will fit a global regression model to map intensity back to original histogram space; if 'hist' will match by histogram matching - ants.histogram_match_image

normalize_nm : boolean - WIP not validated

verbose : boolean

Returns

Averaged and registered neuromelanin image and neuromelanin labels and wide csv
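
A minimal usage sketch; the file names and the upstream T1 processing (brain extraction and CIT168 labels, e.g. from antspyt1w.hierarchical) are hypothetical assumptions:

>>> import ants
>>> import antspymm
>>> nm_list = [ants.image_read(f) for f in ['nm_rep1.nii.gz', 'nm_rep2.nii.gz']]  # hypothetical NM repeats
>>> t1_brain = ants.image_read('t1_brain.nii.gz')        # hypothetical brain-extracted T1
>>> t1_head = ants.image_read('t1_head.nii.gz')          # hypothetical whole-head T1
>>> t1_labels = ants.image_read('cit168_labels.nii.gz')  # hypothetical CIT168 labels in T1 space
>>> nm = antspymm.neuromelanin(nm_list, t1_brain, t1_head, t1_labels, verbose=True)
>>> nm_wide = nm['NM_dataframe_wide']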

def resting_state_fmri_networks( fmri, fmri_template, t1, t1segmentation, f=[0.03, 0.08], FD_threshold=5.0, spa=None, spt=None, nc=5, outlier_threshold=0.25, ica_components=0, impute=True, censor=True, despike=2.5, motion_as_nuisance=True, powers=False, upsample=3.0, clean_tmp=None, paramset='unset', verbose=False):
6092def resting_state_fmri_networks( fmri, fmri_template, t1, t1segmentation,
6093    f=[0.03, 0.08],
6094    FD_threshold=5.0,
6095    spa = None,
6096    spt = None,
6097    nc = 5,
6098    outlier_threshold=0.250,
6099    ica_components = 0,
6100    impute = True,
6101    censor = True,
6102    despike = 2.5,
6103    motion_as_nuisance = True,
6104    powers = False,
6105    upsample = 3.0,
6106    clean_tmp = None,
6107    paramset='unset',
6108    verbose=False ):
6109  """
6110  Compute resting state network correlation maps based on the J Power labels
6111  (powers=True) or the 2023 Yeo 500-parcel homotopic labels (the default). This will
6112  output a map for each of the major network systems.  This function will optionally
6113  upsample the data during registration if its resolution is coarser than the upsample argument.
6114
6115  registration - despike - anatomy - smooth - nuisance - bandpass - regress.nuisance - censor - falff - correlations
6116
6117  Arguments
6118  ---------
6119  fmri : BOLD fmri antsImage
6120
6121  fmri_template : reference space for BOLD
6122
6123  t1 : ANTsImage
6124    input 3-D T1 brain image (brain extracted)
6125
6126  t1segmentation : ANTsImage
6127    t1 segmentation - a six tissue segmentation image in T1 space
6128
6129  f : band pass limits for frequency filtering; we use high-pass here as per Shirer 2015
6130
6131  spa : gaussian smoothing for spatial component (physical coordinates)
6132
6133  spt : gaussian smoothing for temporal component
6134
6135  nc  : number of components for compcor filtering; if less than 1 we estimate on the fly based on explained variance; Shirer 2015 used 10 (5 from csf and 5 from wm)
6136
6137  ica_components : integer if greater than 0 then include ica components
6138
6139  impute : boolean if True, then use imputation in f/ALFF, PerAF calculation
6140
6141  censor : boolean; if True, censor high-motion and outlier volumes
6142
6143  despike : if this is greater than zero will run voxel-wise despiking in the 3dDespike (afni) sense; after motion-correction
6144
6145  motion_as_nuisance: boolean will add motion and first derivative of motion as nuisance
6146
6147  powers : boolean if True use Powers nodes otherwise 2023 Yeo 500 homotopic nodes (10.1016/j.neuroimage.2023.120010)
6148
6149  upsample : float optionally isotropically upsample data to upsample (the parameter value) in mm during the registration process if data is below that resolution; if the input spacing is less than that provided by the user, the data will simply be resampled to isotropic resolution
6150
6151  clean_tmp : will automatically try to clean the tmp directory - not recommended but can be used in distributed computing systems to help prevent failures due to accumulation of tmp files when doing large-scale processing.  if this is set, the float value clean_tmp will be interpreted as the age in hours of files to be cleaned.
6152
6153  verbose : boolean
6154
6155  Returns
6156  ---------
6157  a dictionary containing the derived network maps
6158
6159  References
6160  ---------
6161
6162  10.1162/netn_a_00071 "Methods that included global signal regression were the most consistently effective de-noising strategies."
6163
6164  10.1016/j.neuroimage.2019.116157 "frontal and default mode networks are most reliable whereas subcortical networks are least reliable"  "the most comprehensive studies of pipeline effects on edge-level reliability have been done by Shirer (2015) and Parkes (2018)" "slice timing correction has minimal impact" "use of low-pass or narrow filter (discarding high frequency information) reduced both reliability and signal-noise separation"
6165
6166  10.1016/j.neuroimage.2017.12.073: Our results indicate that (1) simple linear regression of regional fMRI time series against head motion parameters and WM/CSF signals (with or without expansion terms) is not sufficient to remove head motion artefacts; (2) aCompCor pipelines may only be viable in low-motion data; (3) volume censoring performs well at minimising motion-related artefact but a major benefit of this approach derives from the exclusion of high-motion individuals; (4) while not as effective as volume censoring, ICA-AROMA performed well across our benchmarks for relatively low cost in terms of data loss; (5) the addition of global signal regression improved the performance of nearly all pipelines on most benchmarks, but exacerbated the distance-dependence of correlations between motion and functional connectivity; and (6) group comparisons in functional connectivity between healthy controls and schizophrenia patients are highly dependent on preprocessing strategy. We offer some recommendations for best practice and outline simple analyses to facilitate transparent reporting of the degree to which a given set of findings may be affected by motion-related artefact.
6167
6168  10.1016/j.dcn.2022.101087 : We found that: 1) the most efficacious pipeline for both noise removal and information recovery included censoring, GSR, bandpass filtering, and head motion parameter (HMP) regression, 2) ICA-AROMA performed similarly to HMP regression and did not obviate the need for censoring, 3) GSR had a minimal impact on connectome fingerprinting but improved ISC, and 4) the strictest censoring approaches reduced motion correlated edges but negatively impacted identifiability.
6169
6170  """
6171
6172  import warnings
6173
6174  if clean_tmp is not None:
6175    clean_tmp_directory( age_hours = clean_tmp )
6176
6177  if nc > 1:
6178    nc = int(nc)
6179  else:
6180    nc=float(nc)
6181
6182  type_of_transform="antsRegistrationSyNQuickRepro[r]" # should probably not change this
6183  remove_it=True
6184  output_directory = tempfile.mkdtemp()
6185  output_directory_w = output_directory + "/ts_t1_reg/"
6186  os.makedirs(output_directory_w,exist_ok=True)
6187  ofnt1tx = tempfile.NamedTemporaryFile(delete=False,suffix='t1_deformation',dir=output_directory_w).name
6188
6189  import numpy as np
6191
6192  if upsample > 0.0:
6193      spc = ants.get_spacing( fmri )
6194      minspc = upsample
6195      if min(spc[0:3]) < minspc:
6196          minspc = min(spc[0:3])
6197      newspc = [minspc,minspc,minspc]
6198      fmri_template = ants.resample_image( fmri_template, newspc, interp_type=0 )
6199
6200  def temporal_derivative_same_shape(array):
6201    """
6202    Compute the temporal derivative of a 2D numpy array along the 0th axis (time)
6203    and ensure the output has the same shape as the input.
6204
6205    :param array: 2D numpy array with time as the 0th axis.
6206    :return: 2D numpy array of the temporal derivative with the same shape as input.
6207    """
6208    derivative = np.diff(array, axis=0)
6209    
6210    # Append a row to maintain the same shape
6211    # You can choose to append a row of zeros or the last row of the derivative
6212    # Here, a row of zeros is appended
6213    zeros_row = np.zeros((1, array.shape[1]))
6214    return np.vstack((zeros_row, derivative ))
6215
6216  def compute_tSTD(M, quantile, x=0, axis=0):
6217    stdM = np.std(M, axis=axis)
6218    # set bad values to x
6219    stdM[stdM == 0] = x
6220    stdM[np.isnan(stdM)] = x
6221    tt = round(quantile * 100)
6222    threshold_std = np.percentile(stdM, tt)
6223    return {'tSTD': stdM, 'threshold_std': threshold_std}
6224
6225  def get_compcor_matrix(boldImage, mask, quantile):
6226    """
6227    Compute the compcor matrix.
6228
6229    :param boldImage: The bold image.
6230    :param mask: The mask to apply, if None, it will be computed.
6231    :param quantile: Quantile for computing threshold in tSTD.
6232    :return: The compcor matrix.
6233    """
6234    if mask is None:
6235        temp = ants.slice_image(boldImage, axis=boldImage.dimension - 1, idx=0)
6236        mask = ants.get_mask(temp)
6237
6238    imagematrix = ants.timeseries_to_matrix(boldImage, mask)
6239    temp = compute_tSTD(imagematrix, quantile, 0)
6240    tsnrmask = ants.make_image(mask, temp['tSTD'])
6241    tsnrmask = ants.threshold_image(tsnrmask, temp['threshold_std'], temp['tSTD'].max())
6242    M = ants.timeseries_to_matrix(boldImage, tsnrmask)
6243    return M
6244
6245
6246  from sklearn.decomposition import FastICA
6247  def find_indices(lst, value):
6248    return [index for index, element in enumerate(lst) if element > value]
6249
6250  def mean_of_list(lst):
6251    if not lst:  # Check if the list is not empty
6252        return 0  # Return 0 or appropriate value for an empty list
6253    return sum(lst) / len(lst)
6254  fmrispc = list( ants.get_spacing( fmri ) )
6255  if spa is None:
6256    spa = mean_of_list( fmrispc[0:3] ) * 1.0
6257  if spt is None:
6258    spt = fmrispc[3] * 0.5
6259      
6260  import numpy as np
6261  import pandas as pd
6262  import re
6263  import math
6264  # point data resources
6265  A = np.zeros((1,1))
6266  dfnname='DefaultMode'
6267  if powers:
6268      powers_areal_mni_itk = pd.read_csv( get_data('powers_mni_itk', target_extension=".csv")) # power coordinates
6269      coords='powers'
6270  else:
6271      powers_areal_mni_itk = pd.read_csv( get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".csv")) # yeo 2023 coordinates
6272      coords='yeo_17_500_2023'
6273  fmri = ants.iMath( fmri, 'Normalize' )
6274  bmask = antspynet.brain_extraction( fmri_template, 'bold' ).threshold_image(0.5,1).iMath("FillHoles")
6275  if verbose:
6276      print("Begin rsfmri motion correction")
6277  debug=False
6278  if debug:
6279      ants.image_write( fmri_template, '/tmp/fmri_template.nii.gz' )
6280      ants.image_write( fmri, '/tmp/fmri.nii.gz' )
6281      print("debug wrote fmri and fmri_template")
6282  # mot-co
6283  corrmo = timeseries_reg(
6284    fmri, fmri_template,
6285    type_of_transform=type_of_transform,
6286    total_sigma=0.5,
6287    fdOffset=2.0,
6288    trim = 8,
6289    output_directory=None,
6290    verbose=verbose,
6291    syn_metric='CC',
6292    syn_sampling=2,
6293    reg_iterations=[40,20,5],
6294    return_numpy_motion_parameters=True )
6295  
6296  if verbose:
6297      print("End rsfmri motion correction")
6298      print("--maximum motion : " + str(corrmo['FD'].max()) )
6299      print("=== next anatomically based mapping ===")
6300
6301  despiking_count = np.zeros( corrmo['motion_corrected'].shape[3] )
6302  if despike > 0.0:
6303      corrmo['motion_corrected'], despiking_count = despike_time_series_afni( corrmo['motion_corrected'], c1=despike )
6304
6305  despiking_count_summary = despiking_count.sum() / np.prod( corrmo['motion_corrected'].shape )
6306  high_motion_count=(corrmo['FD'] > FD_threshold ).sum()
6307  high_motion_pct=high_motion_count / fmri.shape[3]
6308
6309  # filter mask based on TSNR
6310  mytsnr = tsnr( corrmo['motion_corrected'], bmask )
6311  mytsnrThresh = np.quantile( mytsnr.numpy(), 0.995 )
6312  tsnrmask = ants.threshold_image( mytsnr, 0, mytsnrThresh ).morphology("close",2)
6313  bmask = bmask * tsnrmask
6314
6315  # anatomical mapping
6316  und = fmri_template * bmask
6317  t1reg = ants.registration( und, t1,
6318     "antsRegistrationSyNQuickRepro[s]", outprefix=ofnt1tx )
6319  if verbose:
6320    print("t1 2 bold done")
6321  gmseg = ants.threshold_image( t1segmentation, 2, 2 )
6322  gmseg = gmseg + ants.threshold_image( t1segmentation, 4, 4 )
6323  gmseg = ants.threshold_image( gmseg, 1, 4 )
6324  gmseg = ants.iMath( gmseg, 'MD', 1 ) # FIXMERSF
6325  gmseg = ants.apply_transforms( und, gmseg,
6326    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' ) * bmask
6327  csfAndWM = ( ants.threshold_image( t1segmentation, 1, 1 ) +
6328               ants.threshold_image( t1segmentation, 3, 3 ) ).morphology("erode",1)
6329  csfAndWM = ants.apply_transforms( und, csfAndWM,
6330    t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
6331  csf = ants.threshold_image( t1segmentation, 1, 1 )
6332  csf = ants.apply_transforms( und, csf, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
6333  wm = ants.threshold_image( t1segmentation, 3, 3 ).morphology("erode",1)
6334  wm = ants.apply_transforms( und, wm, t1reg['fwdtransforms'], interpolator = 'nearestNeighbor' )  * bmask
6335  if powers:
6336    ch2 = mm_read( ants.get_ants_data( "ch2" ) )
6337  else:
6338    ch2 = mm_read( get_data( "PPMI_template0_brain", target_extension='.nii.gz' ) )
6339  treg = ants.registration( 
6340    # this is to make the impact of resolution consistent
6341    ants.resample_image(t1, [1.0,1.0,1.0], interp_type=0), 
6342    ch2, "antsRegistrationSyNQuickRepro[s]" )
6343  if powers:
6344    concatx2 = treg['invtransforms'] + t1reg['invtransforms']
6345    pts2bold = ants.apply_transforms_to_points( 3, powers_areal_mni_itk, concatx2,
6346        whichtoinvert = ( True, False, True, False ) )
6347    locations = pts2bold.iloc[:,:3].values
6348    ptImg = ants.make_points_image( locations, bmask, radius = 2 )
6349  else:
6350    concatx2 = t1reg['fwdtransforms'] + treg['fwdtransforms']    
6351    rsfsegfn = get_data('ppmi_template_500Parcels_Yeo2011_17Networks_2023_homotopic', target_extension=".nii.gz")
6352    rsfsegimg = ants.image_read( rsfsegfn )
6353    ptImg = ants.apply_transforms( und, rsfsegimg, concatx2, interpolator='nearestNeighbor' ) * bmask
6354    pts2bold = powers_areal_mni_itk
6355    # ants.plot( und, ptImg, crop=True, axis=2 )
6356
6357  # optional smoothing
6358  tr = ants.get_spacing( corrmo['motion_corrected'] )[3]
6359  smth = ( spa, spa, spa, spt ) # this is for sigmaInPhysicalCoordinates = TRUE
6360  simg = ants.smooth_image( corrmo['motion_corrected'], smth, sigma_in_physical_coordinates = True )
6361
6362  # collect censoring indices
6363  hlinds = find_indices( corrmo['FD'], FD_threshold )
6364  if verbose:
6365    print("high motion indices")
6366    print( hlinds )
6367  if outlier_threshold < 1.0 and outlier_threshold > 0.0:
6368    fmrimotcorr, hlinds2 = loop_timeseries_censoring( corrmo['motion_corrected'], 
6369      threshold=outlier_threshold, verbose=verbose )
6370    hlinds.extend( hlinds2 )
6371    del fmrimotcorr
6372  hlinds = list(set(hlinds)) # make unique
6373
6374  # nuisance
6375  globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], bmask )
6376  globalsignal = np.nanmean( globalmat, axis = 1 )
6377  del globalmat
6378  compcorquantile=0.50
6379  nc_wm=nc_csf=nc
6380  if nc < 1:
6381    globalmat = get_compcor_matrix( corrmo['motion_corrected'], wm, compcorquantile )
6382    nc_wm = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
6383    globalmat = get_compcor_matrix( corrmo['motion_corrected'], csf, compcorquantile )
6384    nc_csf = int(estimate_optimal_pca_components( data=globalmat, variance_threshold=nc))
6385    del globalmat
6386  if verbose:
6387    print("include compcor components as nuisance: csf " + str(nc_csf) + " wm " + str(nc_wm))
6388  mycompcor_csf = ants.compcor( corrmo['motion_corrected'],
6389    ncompcor=nc_csf, quantile=compcorquantile, mask = csf,
6390    filter_type='polynomial', degree=2 )
6391  mycompcor_wm = ants.compcor( corrmo['motion_corrected'],
6392    ncompcor=nc_wm, quantile=compcorquantile, mask = wm,
6393    filter_type='polynomial', degree=2 )
6394  nuisance = np.c_[ mycompcor_csf[ 'components' ], mycompcor_wm[ 'components' ] ]
6395
6396  if motion_as_nuisance:
6397      if verbose:
6398          print("include motion as nuisance")
6399          print( corrmo['motion_parameters'].shape )
6400      deriv = temporal_derivative_same_shape( corrmo['motion_parameters']  )
6401      nuisance = np.c_[ nuisance, corrmo['motion_parameters'], deriv ]
6402
6403  if ica_components > 0:
6404    if verbose:
6405        print("include ica components as nuisance: " + str(ica_components))
6406    ica = FastICA(n_components=ica_components, max_iter=10000, tol=0.001, random_state=42 )
6407    globalmat = ants.timeseries_to_matrix( corrmo['motion_corrected'], csfAndWM )
6408    nuisance_ica = ica.fit_transform(globalmat)  # Reconstruct signals
6409    nuisance = np.c_[ nuisance, nuisance_ica ]
6410    del globalmat
6411
6412  # concat all nuisance data
6413  # nuisance = np.c_[ nuisance, mycompcor['basis'] ]
6414  # nuisance = np.c_[ nuisance, corrmo['FD'] ]
6415  nuisance = np.c_[ nuisance, globalsignal ]
6416
6417  if impute:
6418    simgimp = impute_timeseries( simg, hlinds, method='linear')
6419  else:
6420    simgimp = simg
6421
6422  # falff/alff stuff  def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
6423  myfalff=alff_image( simgimp, bmask, flo=f[0], fhi=f[1], nuisance=nuisance  )
6424
6425  # bandpass any data collected before here -- if bandpass requested
6426  if f[0] > 0 and f[1] < 1.0:
6427    if verbose:
6428        print( "bandpass: " + str(f[0]) + " <=> " + str( f[1] ) )
6429    nuisance = ants.bandpass_filter_matrix( nuisance, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
6430    globalmat = ants.timeseries_to_matrix( simg, bmask )
6431    globalmat = ants.bandpass_filter_matrix( globalmat, tr = tr, lowf=f[0], highf=f[1] ) # some would argue against this
6432    simg = ants.matrix_to_timeseries( simg, globalmat, bmask )
6433
6434  if verbose:
6435    print("now regress nuisance")
6436
6437
6438  if len( hlinds ) > 0 :
6439    if censor:
6440        nuisance = remove_elements_from_numpy_array( nuisance, hlinds  )
6441        simg = remove_volumes_from_timeseries( simg, hlinds )
6442
6443  gmmat = ants.timeseries_to_matrix( simg, bmask )
6444  gmmat = ants.regress_components( gmmat, nuisance )
6445  simg = ants.matrix_to_timeseries(simg, gmmat, bmask)
6446
6447
6448  # structure the output data
6449  outdict = {}
6450  outdict['paramset'] = paramset
6451  outdict['upsampling'] = upsample
6452  outdict['coords'] = coords
6453  outdict['dfnname']=dfnname
6454  outdict['meanBold'] = und
6455
6456  # add correlation matrix that captures each node pair
6457  # some of the spheres overlap so extract separately from each ROI
6458  if powers:
6459    nPoints = int(pts2bold['ROI'].max())
6460    pointrange = list(range(int(nPoints)))
6461  else:
6462    nPoints = int(ptImg.max())
6463    pointrange = list(range(int(nPoints)))
6464  nVolumes = simg.shape[3]
6465  meanROI = np.zeros([nVolumes, nPoints])
6466  roiNames = []
6467  if debug:
6468      ptImgAll = und * 0.
6469  for i in pointrange:
6470    # specify name for matrix entries that's links back to ROI number and network; e.g., ROI1_Uncertain
6471    netLabel = re.sub( " ", "", pts2bold.loc[i,'SystemName'])
6472    netLabel = re.sub( "-", "", netLabel )
6473    netLabel = re.sub( "/", "", netLabel )
6474    roiLabel = "ROI" + str(pts2bold.loc[i,'ROI']) + '_' + netLabel
6475    roiNames.append( roiLabel )
6476    if powers:
6477        ptImage = ants.make_points_image(pts2bold.iloc[[i],:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
6478    else:
6479        #print("Doing " + pts2bold.loc[i,'SystemName'] + " at " + str(i) )
6480        #ptImage = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==pts2bold.loc[i,'SystemName']],binarize=True)
6481        ptImage=ants.threshold_image( ptImg, pts2bold.loc[i,'ROI'], pts2bold.loc[i,'ROI'] )
6482    if debug:
6483      ptImgAll = ptImgAll + ptImage
6484    if ptImage.sum() > 0 :
6485        meanROI[:,i] = ants.timeseries_to_matrix( simg, ptImage).mean(axis=1)
6486
6487  if debug:
6488      ants.image_write( simg, '/tmp/simg.nii.gz' )
6489      ants.image_write( ptImgAll, '/tmp/ptImgAll.nii.gz' )
6490      ants.image_write( und, '/tmp/und.nii.gz' )
6492
6493  # get full correlation matrix
6494  corMat = np.corrcoef(meanROI, rowvar=False)
6495  outputMat = pd.DataFrame(corMat)
6496  outputMat.columns = roiNames
6497  outputMat['ROIs'] = roiNames
6498  # add to dictionary
6499  outdict['fullCorrMat'] = outputMat
6500
6501  networks = powers_areal_mni_itk['SystemName'].unique()
6502  # this is just for human readability - reminds us which networks we choose by default
6503  if powers:
6504    netnames = ['Cingulo-opercular Task Control', 'Default Mode',
6505                    'Memory Retrieval', 'Ventral Attention', 'Visual',
6506                    'Fronto-parietal Task Control', 'Salience', 'Subcortical',
6507                    'Dorsal Attention']
6508    numofnets = [3,5,6,7,8,9,10,11,13]
6509  else:
6510    netnames = networks
6511    numofnets = list(range(len(netnames)))
6512 
6513  ct = 0
6514  for mynet in numofnets:
6515    netname = re.sub( " ", "", networks[mynet] )
6516    netname = re.sub( "-", "", netname )
6517    ww = np.where( powers_areal_mni_itk['SystemName'] == networks[mynet] )[0]
6518    if powers:
6519        dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
6520    else:
6521        dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[mynet]],binarize=True)
6522    if dfnImg.max() >= 1:
6523        if verbose:
6524            print("DO: " + coords + " " + netname )
6525        dfnmat = ants.timeseries_to_matrix( simg, ants.threshold_image( dfnImg, 1, dfnImg.max() ) )
6526        dfnsignal = np.nanmean( dfnmat, axis = 1 )
6527        nan_count_dfn = np.count_nonzero( np.isnan( dfnsignal) )
6528        if nan_count_dfn > 0 :
6529            warnings.warn( "network " + netname + " mean-signal has nans: " + str( nan_count_dfn ) )
6530        gmmatDFNCorr = np.zeros( gmmat.shape[1] )
6531        if nan_count_dfn == 0:
6532            for k in range( gmmat.shape[1] ):
6533                nan_count_gm = np.count_nonzero( np.isnan( gmmat[:,k]) )
6534                if debug and False:
6535                    print( str( k ) +  " nans gm " + str(nan_count_gm)  )
6536                if nan_count_gm == 0:
6537                    gmmatDFNCorr[ k ] = pearsonr( dfnsignal, gmmat[:,k] )[0]
6538        corrImg = ants.make_image( bmask, gmmatDFNCorr  )
6539        outdict[ netname ] = corrImg * gmseg
6540    else:
6541        outdict[ netname ] = None
6542    ct = ct + 1
6543
6544  A = np.zeros( ( len( numofnets ) , len( numofnets ) ) )
6545  A_wide = np.zeros( ( 1, len( numofnets ) * len( numofnets ) ) )
6546  newnames=[]
6547  newnames_wide=[]
6548  ct = 0
6549  for i in range( len( numofnets ) ):
6550      netnamei = re.sub( " ", "", networks[numofnets[i]] )
6551      netnamei = re.sub( "-", "", netnamei )
6552      newnames.append( netnamei  )
6553      ww = np.where( powers_areal_mni_itk['SystemName'] == networks[numofnets[i]] )[0]
6554      if powers:
6555          dfnImg = ants.make_points_image(pts2bold.iloc[ww,:3].values, bmask, radius=1).threshold_image( 1, 1e9 )
6556      else:
6557          dfnImg = ants.mask_image( ptImg, ptImg, level=pts2bold['ROI'][pts2bold['SystemName']==networks[numofnets[i]]],binarize=True)
6558      for j in range( len( numofnets ) ):
6559          netnamej = re.sub( " ", "", networks[numofnets[j]] )
6560          netnamej = re.sub( "-", "", netnamej )
6561          newnames_wide.append( netnamei + "_2_" + netnamej )
6562          A[i,j] = 0
6563          if dfnImg is not None and netnamej is not None:
6564            subbit = dfnImg == 1
6565            if subbit is not None:
6566                if subbit.sum() > 0 and netnamej in outdict:
6567                    A[i,j] = outdict[ netnamej ][ subbit ].mean()
6568          A_wide[0,ct] = A[i,j]
6569          ct=ct+1
6570
6571  A = pd.DataFrame( A )
6572  A.columns = newnames
6573  A['networks']=newnames
6574  A_wide = pd.DataFrame( A_wide )
6575  A_wide.columns = newnames_wide
6576  outdict['corr'] = A
6577  outdict['corr_wide'] = A_wide
6578  outdict['fmri_template'] = fmri_template
6579  outdict['brainmask'] = bmask
6580  outdict['gmmask'] = gmseg
6581  outdict['alff'] = myfalff['alff']
6582  outdict['falff'] = myfalff['falff']
6583  # add global mean and standard deviation for post-hoc z-scoring
6584  outdict['alff_mean'] = (myfalff['alff'][myfalff['alff']!=0]).mean()
6585  outdict['alff_sd'] = (myfalff['alff'][myfalff['alff']!=0]).std()
6586  outdict['falff_mean'] = (myfalff['falff'][myfalff['falff']!=0]).mean()
6587  outdict['falff_sd'] = (myfalff['falff'][myfalff['falff']!=0]).std()
6588
6589  perafimg = PerAF( simgimp, bmask )
6590  for k in pointrange:
6591    anatname=( pts2bold['AAL'][k] )
6592    if isinstance(anatname, str):
6593        anatname = re.sub("_","",anatname)
6594    else:
6595        anatname='Unk'
6596    if powers:
6597        kk = f"{k:0>3}"+"_"
6598    else:
6599        kk = f"{k % int(nPoints/2):0>3}"+"_"
6600    fname='falffPoint'+kk+anatname
6601    aname='alffPoint'+kk+anatname
6602    pname='perafPoint'+kk+anatname
6603    localsel = ptImg == k
6604    if localsel.sum() > 0 : # check if non-empty
6605        outdict[fname]=(outdict['falff'][localsel]).mean()
6606        outdict[aname]=(outdict['alff'][localsel]).mean()
6607        outdict[pname]=(perafimg[localsel]).mean()
6608    else:
6609        outdict[fname]=math.nan
6610        outdict[aname]=math.nan
6611        outdict[pname]=math.nan
6612
6613  rsfNuisance = pd.DataFrame( nuisance )
6614  if remove_it:
6615    import shutil
6616    shutil.rmtree(output_directory, ignore_errors=True )
6617
6618  if not powers:
6619    dfnsum=outdict['DefaultA']+outdict['DefaultB']+outdict['DefaultC']
6620    outdict['DefaultMode']=dfnsum
6621    dfnsum=outdict['VisCent']+outdict['VisPeri']
6622    outdict['Visual']=dfnsum
6623
6624  nonbrainmask = ants.iMath( bmask, "MD",2) - bmask
6625  trimmask = ants.iMath( bmask, "ME",2)
6626  edgemask = ants.iMath( bmask, "ME",1) - trimmask
6627  outdict['motion_corrected'] = corrmo['motion_corrected']
6628  outdict['nuisance'] = rsfNuisance
6629  outdict['PerAF'] = perafimg
6630  outdict['tsnr'] = mytsnr
6631  outdict['ssnr'] = slice_snr( corrmo['motion_corrected'], csfAndWM, gmseg )
6632  outdict['dvars'] = dvars( corrmo['motion_corrected'], gmseg )
6633  outdict['bandpass_freq_0']=f[0]
6634  outdict['bandpass_freq_1']=f[1]
6635  outdict['censor']=int(censor)
6636  outdict['spatial_smoothing']=spa
6637  outdict['outlier_threshold']=outlier_threshold
6638  outdict['FD_threshold']=FD_threshold
6639  outdict['high_motion_count'] = high_motion_count
6640  outdict['high_motion_pct'] = high_motion_pct
6641  outdict['despiking_count_summary'] = despiking_count_summary
6642  outdict['FD_max'] = corrmo['FD'].max()
6643  outdict['FD_mean'] = corrmo['FD'].mean()
6644  outdict['FD_sd'] = corrmo['FD'].std()
6645  outdict['bold_evr'] =  antspyt1w.patch_eigenvalue_ratio( und, 512, [16,16,16], evdepth = 0.9, mask = bmask )
6646  outdict['n_outliers'] = len(hlinds)
6647  outdict['nc_wm'] = int(nc_wm)
6648  outdict['nc_csf'] = int(nc_csf)
6649  outdict['minutes_original_data'] = ( tr * fmri.shape[3] ) / 60.0 # minutes of useful data
6650  outdict['minutes_censored_data'] = ( tr * simg.shape[3] ) / 60.0 # minutes of useful data
6651  return convert_np_in_dict( outdict )

Compute resting state network correlation maps based on the J Power labels. This will output a map for each of the major network systems. This function will optionally upsample the data to 2mm during the registration process if the data resolution is coarser than that.

registration - despike - anatomy - smooth - nuisance - bandpass - regress.nuisance - censor - falff - correlations

Arguments

fmri : BOLD fmri antsImage

fmri_template : reference space for BOLD

t1 : ANTsImage input 3-D T1 brain image (brain extracted)

t1segmentation : ANTsImage t1 segmentation - a six tissue segmentation image in T1 space

f : band pass limits for frequency filtering; we use high-pass here as per Shirer 2015

spa : gaussian smoothing for spatial component (physical coordinates)

spt : gaussian smoothing for temporal component

nc : number of components for CompCor filtering; if less than 1, the number is estimated on the fly based on explained variance; Shirer 2015 used 10 (5 from CSF and 5 from WM)

ica_components : integer; if greater than 0, ICA nuisance components are included

impute : boolean; if True, use (linear) imputation of censored volumes in the f/ALFF and PerAF calculations

censor : boolean; if True, censor (remove) high-motion volumes

despike : if greater than zero, run voxel-wise despiking (in the AFNI 3dDespike sense) after motion correction

motion_as_nuisance : boolean; will add motion parameters and their first derivatives as nuisance regressors

powers : boolean; if True, use the Power nodes, otherwise the 2023 Yeo 500 homotopic nodes (10.1016/j.neuroimage.2023.120010)

upsample : float; optionally resample the data isotropically to this value (in mm) during the registration process if the data resolution is coarser than that; if the input spacing is already finer than the requested value, the data will simply be resampled to isotropic resolution

clean_tmp : will automatically try to clean the tmp directory - not recommended but can be used in distributed computing systems to help prevent failures due to accumulation of tmp files when doing large-scale processing. if this is set, the float value clean_tmp will be interpreted as the age in hours of files to be cleaned.

verbose : boolean

Returns

a dictionary containing the derived network maps

References

10.1162/netn_a_00071 "Methods that included global signal regression were the most consistently effective de-noising strategies."

10.1016/j.neuroimage.2019.116157 "frontal and default mode networks are most reliable whereas subcortical networks are least reliable" "the most comprehensive studies of pipeline effects on edge-level reliability have been done by Shirer (2015) and Parkes (2018)" "slice timing correction has minimal impact" "use of low-pass or narrow filter (discarding high frequency information) reduced both reliability and signal-noise separation"

10.1016/j.neuroimage.2017.12.073: Our results indicate that (1) simple linear regression of regional fMRI time series against head motion parameters and WM/CSF signals (with or without expansion terms) is not sufficient to remove head motion artefacts; (2) aCompCor pipelines may only be viable in low-motion data; (3) volume censoring performs well at minimising motion-related artefact but a major benefit of this approach derives from the exclusion of high-motion individuals; (4) while not as effective as volume censoring, ICA-AROMA performed well across our benchmarks for relatively low cost in terms of data loss; (5) the addition of global signal regression improved the performance of nearly all pipelines on most benchmarks, but exacerbated the distance-dependence of correlations between motion and functional connectivity; and (6) group comparisons in functional connectivity between healthy controls and schizophrenia patients are highly dependent on preprocessing strategy. We offer some recommendations for best practice and outline simple analyses to facilitate transparent reporting of the degree to which a given set of findings may be affected by motion-related artefact.

10.1016/j.dcn.2022.101087 : We found that: 1) the most efficacious pipeline for both noise removal and information recovery included censoring, GSR, bandpass filtering, and head motion parameter (HMP) regression, 2) ICA-AROMA performed similarly to HMP regression and did not obviate the need for censoring, 3) GSR had a minimal impact on connectome fingerprinting but improved ISC, and 4) the strictest censoring approaches reduced motion correlated edges but negatively impacted identifiability.
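
A minimal usage sketch follows; the file names, output prefix, and parameter values are hypothetical, and default keyword values are assumed for any arguments not shown. The call pattern mirrors how mm() invokes this function internally.

    import antspyt1w
    import antspymm

    # hypothetical inputs -- substitute your own data
    t1 = antspymm.mm_read('sub-01_T1w.nii.gz')
    bold = antspymm.mm_read('sub-01_rsfMRI.nii.gz')

    # T1 hierarchical processing supplies the brain image and the six-tissue segmentation
    hier = antspyt1w.hierarchical(t1, '/tmp/sub-01_t1hier_', labels_to_register=None)

    # motion-robust BOLD reference; mm() builds this via loop_timeseries_censoring + get_average_rsf
    bold_template = antspymm.get_average_rsf(bold)

    rsf = antspymm.resting_state_fmri_networks(
        bold,
        bold_template,
        hier['brain_n4_dnz'],
        hier['dkt_parc']['tissue_segmentation'],
        f=[0.03, 0.08],         # "tight" band-pass limits
        nc=5,                   # CompCor components
        outlier_threshold=0.5,  # loop censoring threshold
        censor=True,
        verbose=True)

    # the returned dictionary holds network maps plus QC scalars, e.g.
    dmn_map = rsf['DefaultMode']  # may be None if the network mask is empty
    alff_z = (rsf['alff'] - rsf['alff_mean']) / rsf['alff_sd']  # post-hoc z-scoring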

def write_bvals_bvecs(bvals, bvecs, prefix):
7563def write_bvals_bvecs(bvals, bvecs, prefix ):
7564    ''' Write FSL FDT bvals and bvecs files
7565
7566    adapted from dipy.external code
7567
7568    Parameters
7569    -------------
7570    bvals : (N,) sequence
7571       Vector with diffusion gradient strength (one per diffusion
7572       acquisition, N=no of acquisitions)
7573    bvecs : (N, 3) array-like
7574       diffusion gradient directions
7575    prefix : string
7576       path to write FDT bvals, bvecs text files
7577       None results in current working directory.
7578    '''
7579    _VAL_FMT = '   %e'
7580    bvals = tuple(bvals)
7581    bvecs = np.asarray(bvecs)
7582    bvecs[np.isnan(bvecs)] = 0
7583    N = len(bvals)
7584    fname = prefix + '.bval'
7585    fmt = _VAL_FMT * N + '\n'
7586    myfile = open(fname, 'wt')
7587    myfile.write(fmt % bvals)
7588    myfile.close()
7589    fname = prefix + '.bvec'
7590    bvf = open(fname, 'wt')
7591    for dim_vals in bvecs.T:
7592        bvf.write(fmt % tuple(dim_vals))
7593    bvf.close()

Write FSL FDT bvals and bvecs files

adapted from dipy.external code

Parameters

bvals : (N,) sequence Vector with diffusion gradient strength (one per diffusion acquisition, N=no of acquisitions)

bvecs : (N, 3) array-like diffusion gradient directions

prefix : string path to write FDT bvals, bvecs text files; None results in current working directory.
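
A minimal sketch, assuming a hypothetical four-volume gradient table and output prefix:

    import numpy as np
    import antspymm

    # hypothetical acquisition: one b=0 plus three diffusion-weighted directions
    bvals = [0, 1000, 1000, 1000]
    bvecs = np.array([[0., 0., 0.],
                      [1., 0., 0.],
                      [0., 1., 0.],
                      [0., 0., 1.]])

    # writes /tmp/example.bval and /tmp/example.bvec in FSL FDT layout
    antspymm.write_bvals_bvecs(bvals, bvecs, '/tmp/example')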

def crop_mcimage(x, mask, padder=None):
7596def crop_mcimage( x, mask, padder=None ):
7597    """
7598    crop a time series (4D) image by a 3D mask
7599
7600    Parameters
7601    -------------
7602
7603    x : raw image
7604
7605    mask  : mask for cropping
7606
7607    """
7608    cropmask = ants.crop_image( mask, mask )
7609    myorig = list( ants.get_origin(cropmask) )
7610    myorig.append( ants.get_origin( x )[3] )
7611    croplist = []
7612    if len(x.shape) > 3:
7613        for k in range(x.shape[3]):
7614            temp = ants.slice_image( x, axis=3, idx=k )
7615            temp = ants.crop_image( temp, mask )
7616            if padder is not None:
7617                temp = ants.pad_image( temp, pad_width=padder )
7618            croplist.append( temp )
7619        temp = ants.list_to_ndimage( x, croplist )
7620        temp.set_origin( myorig )
7621        return temp
7622    else:
7623        return( ants.crop_image( x, mask ) )

crop a time series (4D) image by a 3D mask

Parameters

x : raw image

mask : mask for cropping
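
A minimal sketch, assuming a hypothetical 4D BOLD file and a brain mask derived from its temporal average; the padder value is illustrative:

    import ants
    import antspymm

    bold = ants.image_read('sub-01_rsfMRI.nii.gz')   # hypothetical 4D input
    avg = ants.get_average_of_timeseries(bold)
    mask = ants.get_mask(avg)

    # crop every volume to the mask bounding box; padder (optional) pads each 3D volume
    bold_cropped = antspymm.crop_mcimage(bold, mask, padder=[4, 4, 4])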

def mm( t1_image, hier, rsf_image=[], flair_image=None, nm_image_list=None, dw_image=[], bvals=[], bvecs=[], perfusion_image=None, srmodel=None, do_tractography=False, do_kk=False, do_normalization=None, group_template=None, group_transform=None, target_range=[0, 1], dti_motion_correct='antsRegistrationSyNQuickRepro[r]', dti_denoise=False, perfusion_trim=10, perfusion_m0_image=None, perfusion_m0=None, rsf_upsampling=3.0, pet_3d_image=None, test_run=False, verbose=False):
7626def mm(
7627    t1_image,
7628    hier,
7629    rsf_image=[],
7630    flair_image=None,
7631    nm_image_list=None,
7632    dw_image=[], bvals=[], bvecs=[],
7633    perfusion_image=None,
7634    srmodel=None,
7635    do_tractography = False,
7636    do_kk = False,
7637    do_normalization = None,
7638    group_template = None,
7639    group_transform = None,
7640    target_range = [0,1],
7641    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
7642    dti_denoise = False,
7643    perfusion_trim=10,
7644    perfusion_m0_image=None,
7645    perfusion_m0=None,
7646    rsf_upsampling=3.0,
7647    pet_3d_image=None,
7648    test_run = False,
7649    verbose = False ):
7650    """
7651    Multiple modality processing and normalization
7652
7653    aggregates modality-specific processing under one roof.  see individual
7654    modality specific functions for details.
7655
7656    Parameters
7657    -------------
7658
7659    t1_image : raw t1 image
7660
7661    hier  : output of antspyt1w.hierarchical ( see read hierarchical )
7662
7663    rsf_image : list of resting state fmri
7664
7665    flair_image : flair
7666
7667    nm_image_list : list of neuromelanin images
7668
7669    dw_image : list of diffusion weighted images
7670
7671    bvals : list of bvals file names
7672
7673    bvecs : list of bvecs file names
7674
7675    perfusion_image : single perfusion image
7676
7677    srmodel : optional srmodel
7678
7679    do_tractography : boolean
7680
7681    do_kk : boolean to control whether we compute kelly kapowski thickness image (slow)
7682
7683    do_normalization : template transformation if available
7684
7685    group_template : optional reference template corresponding to the group_transform
7686
7687    group_transform : optional transforms corresponding to the group_template
7688
7689    target_range : 2-element tuple
7690        a tuple or array defining the (min, max) of the input image
7691        (e.g., [-127.5, 127.5] or [0,1]).  Output images will be scaled back to original
7692        intensity. This range should match the mapping used in the training
7693        of the network.
7694    
7695    dti_motion_correct : None, Rigid, or SyN
7696
7697    dti_denoise : boolean
7698
7699    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
7700
7701    perfusion_m0_image : optional antsImage m0 associated with the perfusion time series
7702
7703    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
7704
7705    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
7706
7707    pet_3d_image : optional antsImage for a 3D pet; we make no assumptions about the contents of 
7708        this image.  we just process it and provide summary information.
7709
7710    test_run : boolean 
7711
7712    verbose : boolean
7713
7714    """
7715    from os.path import exists
7716    ex_path = os.path.expanduser( "~/.antspyt1w/" )
7717    ex_path_mm = os.path.expanduser( "~/.antspymm/" )
7718    mycsvfn = ex_path + "FA_JHU_labels_edited.csv"
7719    citcsvfn = ex_path + "CIT168_Reinf_Learn_v1_label_descriptions_pad.csv"
7720    dktcsvfn = ex_path + "dkt.csv"
7721    cnxcsvfn = ex_path + "dkt_cortex_cit_deep_brain.csv"
7722    JHU_atlasfn = ex_path + 'JHU-ICBM-FA-1mm.nii.gz' # Read in JHU atlas
7723    JHU_labelsfn = ex_path + 'JHU-ICBM-labels-1mm.nii.gz' # Read in JHU labels
7724    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
7725    if not exists( mycsvfn ) or not exists( citcsvfn ) or not exists( cnxcsvfn ) or not exists( dktcsvfn ) or not exists( JHU_atlasfn ) or not exists( JHU_labelsfn ) or not exists( templatefn ):
7726        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
7727        raise ValueError('**missing files** => call get_data from latest antspyt1w and antspymm.')
7728    mycsv = pd.read_csv(  mycsvfn )
7729    citcsv = pd.read_csv(  os.path.expanduser( citcsvfn ) )
7730    dktcsv = pd.read_csv(  os.path.expanduser( dktcsvfn ) )
7731    cnxcsv = pd.read_csv(  os.path.expanduser( cnxcsvfn ) )
7732    JHU_atlas = mm_read( JHU_atlasfn ) # Read in JHU atlas
7733    JHU_labels = mm_read( JHU_labelsfn ) # Read in JHU labels
7734    template = mm_read( templatefn ) # Read in template
7735    if group_template is None:
7736        group_template = template
7737        group_transform = do_normalization['fwdtransforms']
7738    if verbose:
7739        print("Using group template:")
7740        print( group_template )
7741    #####################
7742    #  T1 hierarchical  #
7743    #####################
7744    t1imgbrn = hier['brain_n4_dnz']
7745    t1atropos = hier['dkt_parc']['tissue_segmentation']
7746    output_dict = {
7747        'kk': None,
7748        'rsf': None,
7749        'flair' : None,
7750        'NM' : None,
7751        'DTI' : None,
7752        'FA_summ' : None,
7753        'MD_summ' : None,
7754        'tractography' : None,
7755        'tractography_connectivity' : None,
7756        'perf' : None,
7757        'pet3d' : None,
7758    }
7759    normalization_dict = {
7760        'kk_norm': None,
7761        'NM_norm' : None,
7762        'DTI_norm': None,
7763        'FA_norm' : None,
7764        'MD_norm' : None,
7765        'perf_norm' : None,
7766        'alff_norm' : None,
7767        'falff_norm' : None,
7768        'CinguloopercularTaskControl_norm' : None,
7769        'DefaultMode_norm' : None,
7770        'MemoryRetrieval_norm' : None,
7771        'VentralAttention_norm' : None,
7772        'Visual_norm' : None,
7773        'FrontoparietalTaskControl_norm' : None,
7774        'Salience_norm' : None,
7775        'Subcortical_norm' : None,
7776        'DorsalAttention_norm' : None,
7777        'pet3d_norm' : None
7778    }
7779    if test_run:
7780        return output_dict, normalization_dict
7781
7782    if do_kk:
7783        if verbose:
7784            print('kk in mm')
7785        output_dict['kk'] = antspyt1w.kelly_kapowski_thickness( t1_image,
7786            labels=hier['dkt_parc']['dkt_cortex'], iterations=45 )
7787
7788    if perfusion_image is not None:
7789        if perfusion_image.shape[3] > 1: # FIXME - better heuristic?
7790            output_dict['perf'] = bold_perfusion(
7791                perfusion_image,
7792                t1_image,
7793                hier['brain_n4_dnz'],
7794                t1atropos,
7795                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
7796                n_to_trim = perfusion_trim,
7797                m0_image = perfusion_m0_image,
7798                m0_indices = perfusion_m0,
7799                verbose=verbose )
7800
7801    if pet_3d_image is not None:
7802        if pet_3d_image.dimension == 3: # FIXME - better heuristic?
7803            output_dict['pet3d'] = pet3d_summary(
7804                pet_3d_image,
7805                t1_image,
7806                hier['brain_n4_dnz'],
7807                t1atropos,
7808                hier['dkt_parc']['dkt_cortex'] + hier['cit168lab'],
7809                verbose=verbose )
7810    ################################## do the rsf .....
7811    if len(rsf_image) > 0:
7812        my_motion_tx = 'antsRegistrationSyNRepro[r]'
7813        rsf_image = [i for i in rsf_image if i is not None]
7814        if verbose:
7815            print('rsf length ' + str( len( rsf_image ) ) )
7816        if len( rsf_image ) >= 2: # assume 2 is the largest possible value
7817            rsf_image1 = rsf_image[0]
7818            rsf_image2 = rsf_image[1]
7819            # build a template then join the images
7820            if verbose:
7821                print("initial average for rsf")
7822            rsfavg1, hlinds = loop_timeseries_censoring( rsf_image1, 0.1 )
7823            rsfavg1=get_average_rsf(rsfavg1)
7824            rsfavg2, hlinds = loop_timeseries_censoring( rsf_image2, 0.1 )
7825            rsfavg2=get_average_rsf(rsfavg2)
7826            if verbose:
7827                print("template average for rsf")
7828            init_temp = ants.image_clone( rsfavg1 )
7829            if rsf_image1.shape[3] < rsf_image2.shape[3]:
7830                init_temp = ants.image_clone( rsfavg2 )
7831            boldTemplate = ants.build_template(
7832                initial_template = init_temp,
7833                image_list=[rsfavg1,rsfavg2],
7834                type_of_transform="antsRegistrationSyNQuickRepro[s]",
7835                iterations=5, verbose=False )
7836            if verbose:
7837                print("join the 2 rsf")
7838            if rsf_image1.shape[3] > 10 and rsf_image2.shape[3] > 10:
7839                leadvols = list(range(8))
7840                rsf_image2 = remove_volumes_from_timeseries( rsf_image2, leadvols )
7841                rsf_image = merge_timeseries_data( rsf_image1, rsf_image2 )
7842            elif rsf_image1.shape[3] > rsf_image2.shape[3]:
7843                rsf_image = rsf_image1
7844            else:
7845                rsf_image = rsf_image2
7846        elif len( rsf_image ) == 1:
7847            rsf_image = rsf_image[0]
7848            boldTemplate, hlinds = loop_timeseries_censoring( rsf_image, 0.1 )
7849            boldTemplate = get_average_rsf(boldTemplate)
7850        if rsf_image.shape[3] > 10: # FIXME - better heuristic?
7851            rsfprolist = [] # FIXMERSF
7852            # Create the parameter DataFrame
7853            df = pd.DataFrame({
7854                "num": [134, 122, 129],
7855                "loop": [0.50, 0.25, 0.50],
7856                "cens": [True, True, True],
7857                "HM": [1.0, 5.0, 0.5],
7858                "ff": ["tight", "tight", "tight"],
7859                "CC": [5, 5, 0.8],
7860                "imp": [True, True, True],
7861                "up": [rsf_upsampling, rsf_upsampling, rsf_upsampling],
7862                "coords": [False,False,False]
7863            }, index=[0, 1, 2])
7864            for p in range(df.shape[0]):
7865                if verbose:
7866                    print("rsf parameters")
7867                    print( df.iloc[p] )
7868                if df['ff'].iloc[p] == 'broad':
7869                    f=[ 0.008, 0.15 ]
7870                elif df['ff'].iloc[p] == 'tight':
7871                    f=[ 0.03, 0.08 ]
7872                elif df['ff'].iloc[p] == 'mid':
7873                    f=[ 0.01, 0.1 ]
7874                elif df['ff'].iloc[p] == 'mid2':
7875                    f=[ 0.01, 0.08 ]
7876                else:
7877                    raise ValueError("we do not recognize this parameter choice for frequency filtering: " + df['ff'].iloc[p] )
7878                HM = df['HM'].iloc[p]
7879                CC = df['CC'].iloc[p]
7880                loop= df['loop'].iloc[p]
7881                cens =df['cens'].iloc[p]
7882                imp = df['imp'].iloc[p]
7883                rsf0 = resting_state_fmri_networks(
7884                                            rsf_image,
7885                                            boldTemplate,
7886                                            hier['brain_n4_dnz'],
7887                                            t1atropos,
7888                                            f=f,
7889                                            FD_threshold=HM, 
7890                                            spa = None, 
7891                                            spt = None, 
7892                                            nc = CC,
7893                                            outlier_threshold=loop,
7894                                            ica_components = 0,
7895                                            impute = imp,
7896                                            censor = cens,
7897                                            despike = 2.5,
7898                                            motion_as_nuisance = True,
7899                                            upsample=df['up'].iloc[p],
7900                                            clean_tmp=0.66,
7901                                            paramset=df['num'].iloc[p],
7902                                            powers=df['coords'].iloc[p],
7903                                            verbose=verbose ) # default
7904                rsfprolist.append( rsf0 )
7905            output_dict['rsf'] = rsfprolist
7906
7907    if nm_image_list is not None:
7908        if verbose:
7909            print('nm')
7910        if srmodel is None:
7911            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], verbose=verbose )
7912        else:
7913            output_dict['NM'] = neuromelanin( nm_image_list, t1imgbrn, t1_image, hier['deep_cit168lab'], srmodel=srmodel, target_range=target_range, verbose=verbose  )
7914################################## do the dti .....
7915    if len(dw_image) > 0 :
7916        if verbose:
7917            print('dti-x')
7918        if len( dw_image ) == 1: # use T1 for distortion correction and brain extraction
7919            if verbose:
7920                print("We have only one DTI: " + str(len(dw_image)))
7921            dw_image = dw_image[0]
7922            btpB0, btpDW = get_average_dwi_b0(dw_image)
7923            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
7924            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
7925                syn_metric='CC', syn_sampling=2,
7926                reg_iterations=[50,50,20],
7927                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
7928                initial_transform=initrig
7929                )
7930            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
7931            btpDW = ants.apply_transforms( btpDW, btpDW,
7932                tempreg['invtransforms'][1], interpolator='linear')
7933            btpB0 = ants.apply_transforms( btpB0, btpB0,
7934                tempreg['invtransforms'][1], interpolator='linear')
7935            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'][1], interpolator='nearestNeighbor')
7936            # dwimask = ants.iMath(dwimask,'MD',1)
7937            t12dwi = ants.apply_transforms( btpDW, hier['brain_n4_dnz'], tempreg['fwdtransforms'][1], interpolator='linear')
7938            output_dict['DTI'] = joint_dti_recon(
7939                dw_image,
7940                bvals[0],
7941                bvecs[0],
7942                jhu_atlas=JHU_atlas,
7943                jhu_labels=JHU_labels,
7944                brain_mask = dwimask,
7945                reference_B0 = btpB0,
7946                reference_DWI = btpDW,
7947                srmodel=srmodel,
7948                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
7949                denoise=dti_denoise,
7950                verbose = verbose)
7951        else :  # use phase encoding acquisitions for distortion correction and T1 for brain extraction
7952            if verbose:
7953                print("We have both DTI_LR and DTI_RL: " + str(len(dw_image)))
7954            a1b,a1w=get_average_dwi_b0(dw_image[0])
7955            a2b,a2w=get_average_dwi_b0(dw_image[1],fixed_b0=a1b,fixed_dwi=a1w)
7956            btpB0, btpDW = dti_template(
7957                b_image_list=[a1b,a2b],
7958                w_image_list=[a1w,a2w],
7959                iterations=7, verbose=verbose )
7960            initrig = ants.registration( btpDW, hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[r]' )['fwdtransforms'][0]
7961            tempreg = ants.registration( btpDW, hier['brain_n4_dnz'], 'SyNOnly',
7962                syn_metric='CC', syn_sampling=2,
7963                reg_iterations=[50,50,20],
7964                multivariate_extras=[ [ "CC", btpB0, hier['brain_n4_dnz'], 1, 2 ]],
7965                initial_transform=initrig
7966                )
7967            mybxt = ants.threshold_image( ants.iMath(hier['brain_n4_dnz'], "Normalize" ), 0.001, 1 )
7968            dwimask = ants.apply_transforms( btpDW, mybxt, tempreg['fwdtransforms'], interpolator='nearestNeighbor')
7969            output_dict['DTI'] = joint_dti_recon(
7970                dw_image[0],
7971                bvals[0],
7972                bvecs[0],
7973                jhu_atlas=JHU_atlas,
7974                jhu_labels=JHU_labels,
7975                brain_mask = dwimask,
7976                reference_B0 = btpB0,
7977                reference_DWI = btpDW,
7978                srmodel=srmodel,
7979                img_RL=dw_image[1],
7980                bval_RL=bvals[1],
7981                bvec_RL=bvecs[1],
7982                motion_correct=dti_motion_correct, # set to False if using input from qsiprep
7983                denoise=dti_denoise,
7984                verbose = verbose)
7985        mydti = output_dict['DTI']
7986        # summarize dwi with T1 outputs
7987        # first - register ....
7988        reg = ants.registration( mydti['recon_fa'], hier['brain_n4_dnz'], 'antsRegistrationSyNRepro[s]', total_sigma=1.0 )
7989        ##################################################
7990        output_dict['FA_summ'] = hierarchical_modality_summary(
7991            mydti['recon_fa'],
7992            hier=hier,
7993            modality_name='fa',
7994            transformlist=reg['fwdtransforms'],
7995            verbose = False )
7996        ##################################################
7997        output_dict['MD_summ'] = hierarchical_modality_summary(
7998            mydti['recon_md'],
7999            hier=hier,
8000            modality_name='md',
8001            transformlist=reg['fwdtransforms'],
8002            verbose = False )
8003        # these inputs should come from nicely processed data
8004        dktmapped = ants.apply_transforms(
8005            mydti['recon_fa'],
8006            hier['dkt_parc']['dkt_cortex'],
8007            reg['fwdtransforms'], interpolator='nearestNeighbor' )
8008        citmapped = ants.apply_transforms(
8009            mydti['recon_fa'],
8010            hier['cit168lab'],
8011            reg['fwdtransforms'], interpolator='nearestNeighbor' )
8012        dktmapped[ citmapped > 0]=0
8013        mask = ants.threshold_image( mydti['recon_fa'], 0.01, 2.0 ).iMath("GetLargestComponent")
8014        if do_tractography: # dwi_deterministic_tracking dwi_closest_peak_tracking
8015            output_dict['tractography'] = dwi_deterministic_tracking(
8016                mydti['dwi_LR_dewarped'],
8017                mydti['recon_fa'],
8018                mydti['bval_LR'],
8019                mydti['bvec_LR'],
8020                seed_density = 1,
8021                mask=mask,
8022                verbose = verbose )
8023            mystr = output_dict['tractography']
8024            output_dict['tractography_connectivity'] = dwi_streamline_connectivity( mystr['streamlines'], dktmapped+citmapped, cnxcsv, verbose=verbose )
8025    ################################## do the flair .....
8026    if flair_image is not None:
8027        if verbose:
8028            print('flair')
8029        wmhprior = None
8030        priorfn = ex_path_mm + 'CIT168_wmhprior_700um_pad_adni.nii.gz'
8031        if ( exists( priorfn ) ):
8032            wmhprior = ants.image_read( priorfn )
8033            wmhprior = ants.apply_transforms( t1_image, wmhprior, do_normalization['invtransforms'] )
8034        output_dict['flair'] = boot_wmh( flair_image, t1_image, t1atropos,
8035            prior_probability=wmhprior, verbose=verbose )
8036    #################################################################
8037    ### NOTES: deforming to a common space and writing out images ###
8038    ### images we want come from: DTI, NM, rsf, thickness ###########
8039    #################################################################
8040    if do_normalization is not None:
8041        if verbose:
8042            print('normalization')
8043        # might reconsider this template space - cropped and/or higher res?
8044        # template = ants.resample_image( template, [1,1,1], use_voxels=False )
8045        # t1reg = ants.registration( template, hier['brain_n4_dnz'], "antsRegistrationSyNQuickRepro[s]")
8046        t1reg = do_normalization
8047        if do_kk:
8048            normalization_dict['kk_norm'] = ants.apply_transforms( group_template, output_dict['kk']['thickness_image'], group_transform )
8049        if output_dict['DTI'] is not None:
8050            mydti = output_dict['DTI']
8051            dtirig = ants.registration( hier['brain_n4_dnz'], mydti['recon_fa'], 'antsRegistrationSyNRepro[r]' )
8052            normalization_dict['MD_norm'] = ants.apply_transforms( group_template, mydti['recon_md'],group_transform+dtirig['fwdtransforms'] )
8053            normalization_dict['FA_norm'] = ants.apply_transforms( group_template, mydti['recon_fa'],group_transform+dtirig['fwdtransforms'] )
8054            output_directory = tempfile.mkdtemp()
8055            do_dti_norm=False
8056            if do_dti_norm:
8057                comptx = ants.apply_transforms( group_template, group_template, group_transform+dtirig['fwdtransforms'], compose = output_directory + '/xxx' )
8058                tspc=[2.,2.,2.]
8059                if srmodel is not None:
8060                    tspc=[1.,1.,1.]
8061                group_template2mm = ants.resample_image( group_template, tspc  )
8062                normalization_dict['DTI_norm'] = transform_and_reorient_dti( group_template2mm, mydti['dti'], comptx, verbose=False )
8063            import shutil
8064            shutil.rmtree(output_directory, ignore_errors=True )
8065        if output_dict['rsf'] is not None:
8066            if False:
8067                rsfpro = output_dict['rsf'] # FIXME
8068                rsfrig = ants.registration( hier['brain_n4_dnz'], rsfpro['meanBold'], 'antsRegistrationSyNRepro[r]' )
8069                for netid in get_antsimage_keys( rsfpro ):
8070                    rsfkey = netid + "_norm"
8071                    normalization_dict[rsfkey] = ants.apply_transforms(
8072                        group_template, rsfpro[netid],
8073                        group_transform+rsfrig['fwdtransforms'] )
8074        if output_dict['perf'] is not None: # zizzer
8075            comptx = group_transform + output_dict['perf']['t1reg']['invtransforms']
8076            normalization_dict['perf_norm'] = ants.apply_transforms( group_template,
8077                output_dict['perf']['perfusion'], comptx,
8078                whichtoinvert=[False,False,True,False] )
8079            normalization_dict['cbf_norm'] = ants.apply_transforms( group_template,
8080                output_dict['perf']['cbf'], comptx,
8081                whichtoinvert=[False,False,True,False] )
8082        if output_dict['pet3d'] is not None: # zizzer
8083            secondTx=output_dict['pet3d']['t1reg']['invtransforms']
8084            comptx = group_transform + secondTx
8085            if len( secondTx ) == 2:
8086                wti=[False,False,True,False]
8087            else:
8088                wti=[False,False,True]
8089            normalization_dict['pet3d_norm'] = ants.apply_transforms( group_template,
8090                output_dict['pet3d']['pet3d'], comptx,
8091                whichtoinvert=wti )
8092        if nm_image_list is not None:
8093            nmpro = output_dict['NM']
8094            nmrig = nmpro['t1_to_NM_transform'] # this is an inverse tx
8095            normalization_dict['NM_norm'] = ants.apply_transforms( group_template, nmpro['NM_avg'], group_transform+nmrig,
8096                whichtoinvert=[False,False,True])
8097
8098    if verbose:
8099        print('mm done')
8100    return output_dict, normalization_dict

Multiple modality processing and normalization

aggregates modality-specific processing under one roof. see individual modality specific functions for details.

Parameters

t1_image : raw t1 image

hier : output of antspyt1w.hierarchical ( see read hierarchical )

rsf_image : list of resting state fmri

flair_image : flair

nm_image_list : list of neuromelanin images

dw_image : list of diffusion weighted images

bvals : list of bvals file names

bvecs : list of bvecs file names

perfusion_image : single perfusion image

srmodel : optional srmodel

do_tractography : boolean

do_kk : boolean to control whether we compute kelly kapowski thickness image (slow)

do_normalization : template transformation if available

group_template : optional reference template corresponding to the group_transform

group_transform : optional transforms corresponding to the group_template

target_range : 2-element tuple a tuple or array defining the (min, max) of the input image (e.g., [-127.5, 127.5] or [0,1]). Output images will be scaled back to original intensity. This range should match the mapping used in the training of the network.

dti_motion_correct : None, Rigid, or SyN

dti_denoise : boolean

perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series

perfusion_m0_image : optional antsImage m0 associated with the perfusion time series

perfusion_m0 : optional list containing indices of the m0 in the perfusion time series

rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done

pet_3d_image : optional antsImage for a 3D pet; we make no assumptions about the contents of this image. we just process it and provide summary information.

test_run : boolean

verbose : boolean
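
A hedged end-to-end sketch using hypothetical file names, a single resting-state acquisition, and default keyword values for everything not shown; it assumes the reference data have been fetched via get_data:

    import os
    import ants
    import antspyt1w
    import antspymm

    # one-time download of reference templates and label data
    antspyt1w.get_data()
    antspymm.get_data()

    # hypothetical inputs -- substitute your own files
    t1 = antspymm.mm_read('sub-01_T1w.nii.gz')
    rsf = antspymm.mm_read('sub-01_rsfMRI.nii.gz')

    hier = antspyt1w.hierarchical(t1, '/tmp/sub-01_t1hier_', labels_to_register=None)

    # group-template registration supplies the do_normalization transforms
    template = antspymm.mm_read(
        os.path.expanduser('~/.antspyt1w/CIT168_T1w_700um_pad_adni.nii.gz'))
    t1reg = ants.registration(
        template, hier['brain_n4_dnz'], 'antsRegistrationSyNQuickRepro[s]')

    mm_out, mm_norm = antspymm.mm(
        t1, hier,
        rsf_image=[rsf],
        do_normalization=t1reg,
        do_kk=False,
        verbose=True)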

def write_mm( output_prefix, mm, mm_norm=None, t1wide=None, separator='_', verbose=False):
8103def write_mm( output_prefix, mm, mm_norm=None, t1wide=None, separator='_', verbose=False ):
8104    """
8105    write the tabular and normalization output of the mm function
8106
8107    Parameters
8108    -------------
8109
8110    output_prefix : prefix for file outputs - modality specific postfix will be added
8111
8112    mm  : output of mm function for modality-space processing; should be a dictionary
8113        with an entry for each modality.
8114
8115    mm_norm : output of mm function for normalized processing
8116
8117    t1wide : wide output data frame from t1 hierarchical
8118
8119    separator : string or character separator for filenames
8120
8121    verbose : boolean
8122
8123    Returns
8124    ---------
8125
8126    both csv and image files written to disk.  the primary outputs will be
8127    output_prefix + separator + 'mmwide.csv' and *norm.nii.gz images
8128
8129    """
8130    from dipy.io.streamline import save_tractogram
8131    if mm_norm is not None:
8132        for mykey in mm_norm:
8133            tempfn = output_prefix + separator + mykey + '.nii.gz'
8134            if mm_norm[mykey] is not None:
8135                image_write_with_thumbnail( mm_norm[mykey], tempfn )
8136    thkderk = None
8137    if t1wide is not None:
8138        thkderk = t1wide.iloc[: , 1:]
8139    kkderk = None
8140    if 'kk' in mm:
8141        if mm['kk'] is not None:
8142            kkderk = mm['kk']['thickness_dataframe'].iloc[: , 1:]
8143            mykey='thickness_image'
8144            tempfn = output_prefix + separator + mykey + '.nii.gz'
8145            image_write_with_thumbnail( mm['kk'][mykey], tempfn )
8146    nmderk = None
8147    if 'NM' in mm:
8148        if mm['NM'] is not None:
8149            nmderk = mm['NM']['NM_dataframe_wide'].iloc[: , 1:]
8150            for mykey in get_antsimage_keys( mm['NM'] ):
8151                tempfn = output_prefix + separator + mykey + '.nii.gz'
8152                image_write_with_thumbnail( mm['NM'][mykey], tempfn, thumb=False )
8153
8154    faderk = mdderk = fat1derk = mdt1derk = None
8155
8156    if 'DTI' in mm:
8157        if mm['DTI'] is not None:
8158            mydti = mm['DTI']
8159            myop = output_prefix + separator
8160            ants.image_write( mydti['dti'],  myop + 'dti.nii.gz' )
8161            write_bvals_bvecs( mydti['bval_LR'], mydti['bvec_LR'], myop + 'reoriented' )
8162            image_write_with_thumbnail( mydti['dwi_LR_dewarped'],  myop + 'dwi.nii.gz' )
8163            image_write_with_thumbnail( mydti['dtrecon_LR_dewarp']['RGB'] ,  myop + 'DTIRGB.nii.gz' )
8164            image_write_with_thumbnail( mydti['jhu_labels'],  myop+'dtijhulabels.nii.gz', mydti['recon_fa'] )
8165            image_write_with_thumbnail( mydti['recon_fa'],  myop+'dtifa.nii.gz' )
8166            image_write_with_thumbnail( mydti['recon_md'],  myop+'dtimd.nii.gz' )
8167            image_write_with_thumbnail( mydti['b0avg'],  myop+'b0avg.nii.gz' )
8168            image_write_with_thumbnail( mydti['dwiavg'],  myop+'dwiavg.nii.gz' )
8169            faderk = mm['DTI']['recon_fa_summary'].iloc[: , 1:]
8170            mdderk = mm['DTI']['recon_md_summary'].iloc[: , 1:]
8171            fat1derk = mm['FA_summ'].iloc[: , 1:]
8172            mdt1derk = mm['MD_summ'].iloc[: , 1:]
8173    if 'tractography' in mm:
8174        if mm['tractography'] is not None:
8175            ofn = output_prefix + separator + 'tractogram.trk'
8176            if mm['tractography']['tractogram'] is not None:
8177                save_tractogram( mm['tractography']['tractogram'], ofn )
8178    cnxderk = None
8179    if 'tractography_connectivity' in mm:
8180        if mm['tractography_connectivity'] is not None:
8181            cnxderk = mm['tractography_connectivity']['connectivity_wide'].iloc[: , 1:] # NOTE: connectivity_wide is not much tested
8182            ofn = output_prefix + separator + 'dtistreamlineconn.csv'
8183            pd.DataFrame(mm['tractography_connectivity']['connectivity_matrix']).to_csv( ofn )
8184
8185    dlist = [
8186        thkderk,
8187        kkderk,
8188        nmderk,
8189        faderk,
8190        mdderk,
8191        fat1derk,
8192        mdt1derk,
8193        cnxderk
8194        ]
8195    is_all_none = all(element is None for element in dlist)
8196    if is_all_none:
8197        mm_wide = pd.DataFrame({'u_hier_id': [output_prefix] })
8198    else:
8199        mm_wide = pd.concat( dlist, axis=1, ignore_index=False )
8200
8201    mm_wide = mm_wide.copy()
8202    if 'NM' in mm:
8203        if mm['NM'] is not None:
8204            nmwide = dict_to_dataframe( mm['NM'] )
8205            if mm_wide.shape[0] > 0 and nmwide.shape[0] > 0:
8206                nmwide.set_index( mm_wide.index, inplace=True )
8207            mm_wide = pd.concat( [mm_wide, nmwide ], axis=1, ignore_index=False )
8208    if 'flair' in mm:
8209        if mm['flair'] is not None:
8210            myop = output_prefix + separator + 'wmh.nii.gz'
8211            pngfnb = output_prefix + separator + 'wmh_seg.png'
8212            ants.plot( mm['flair']['flair'], mm['flair']['WMH_posterior_probability_map'], axis=2, nslices=21, ncol=7, filename=pngfnb, crop=True )
8213            if mm['flair']['WMH_probability_map'] is not None:
8214                image_write_with_thumbnail( mm['flair']['WMH_probability_map'], myop, thumb=False )
8215            flwide = dict_to_dataframe( mm['flair'] )
8216            if mm_wide.shape[0] > 0 and flwide.shape[0] > 0:
8217                flwide.set_index( mm_wide.index, inplace=True )
8218            mm_wide = pd.concat( [mm_wide, flwide ], axis=1, ignore_index=False )
8219    if 'rsf' in mm:
8220        if mm['rsf'] is not None:
8221            fcnxpro=99
8222            rsfdata = mm['rsf']
8223            if not isinstance( rsfdata, list ):
8224                rsfdata = [ rsfdata ]
8225            for rsfpro in rsfdata:
8226                fcnxpro=str( rsfpro['paramset']  )
8227                pronum = 'fcnxpro'+str(fcnxpro)+"_"
8228                if verbose:
8229                    print("Collect rsf data " + pronum)
8230                new_rsf_wide = dict_to_dataframe( rsfpro )
8231                new_rsf_wide = pd.concat( [new_rsf_wide, rsfpro['corr_wide'] ], axis=1, ignore_index=False )
8232                new_rsf_wide = new_rsf_wide.add_prefix( pronum )
8233                new_rsf_wide.set_index( mm_wide.index, inplace=True )
8234                ofn = output_prefix + separator + pronum + '.csv'
8235                new_rsf_wide.to_csv( ofn )
8236                mm_wide = pd.concat( [mm_wide, new_rsf_wide ], axis=1, ignore_index=False )
8237                for mykey in get_antsimage_keys( rsfpro ):
8238                    myop = output_prefix + separator + pronum + mykey + '.nii.gz'
8239                    image_write_with_thumbnail( rsfpro[mykey], myop, thumb=True )
8240                ofn = output_prefix + separator + pronum + 'rsfcorr.csv'
8241                rsfpro['corr'].to_csv( ofn )
8242                # apply same principle to new correlation matrix, doesn't need to be incorporated with mm_wide
8243                ofn2 = output_prefix + separator + pronum + 'nodescorr.csv'
8244                rsfpro['fullCorrMat'].to_csv( ofn2 )
8245    if 'DTI' in mm:
8246        if mm['DTI'] is not None:
8247            mydti = mm['DTI']
8248            mm_wide['dti_tsnr_b0_mean'] =  mydti['tsnr_b0'].mean()
8249            mm_wide['dti_tsnr_dwi_mean'] =  mydti['tsnr_dwi'].mean()
8250            mm_wide['dti_dvars_b0_mean'] =  mydti['dvars_b0'].mean()
8251            mm_wide['dti_dvars_dwi_mean'] =  mydti['dvars_dwi'].mean()
8252            mm_wide['dti_ssnr_b0_mean'] =  mydti['ssnr_b0'].mean()
8253            mm_wide['dti_ssnr_dwi_mean'] =  mydti['ssnr_dwi'].mean()
8254            mm_wide['dti_fa_evr'] =  mydti['fa_evr']
8255            mm_wide['dti_fa_SNR'] =  mydti['fa_SNR']
8256            if mydti['framewise_displacement'] is not None:
8257                mm_wide['dti_high_motion_count'] =  mydti['high_motion_count']
8258                mm_wide['dti_FD_mean'] = mydti['framewise_displacement'].mean()
8259                mm_wide['dti_FD_max'] = mydti['framewise_displacement'].max()
8260                mm_wide['dti_FD_sd'] = mydti['framewise_displacement'].std()
8261                fdfn = output_prefix + separator + '_fd.csv'
8262            else:
8263                mm_wide['dti_FD_mean'] = mm_wide['dti_FD_max'] = mm_wide['dti_FD_sd'] = 'NA'
8264
8265    if 'perf' in mm:
8266        if mm['perf'] is not None:
8267            perfpro = mm['perf']
8268            prwide = dict_to_dataframe( perfpro )
8269            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
8270                prwide.set_index( mm_wide.index, inplace=True )
8271            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
8272            if 'perf_dataframe' in perfpro.keys():
8273                pderk = perfpro['perf_dataframe'].iloc[: , 1:]
8274                pderk.set_index( mm_wide.index, inplace=True )
8275                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
8276            else:
8277                print("FIXME - perfusion dataframe")
8278            for mykey in get_antsimage_keys( mm['perf'] ):
8279                tempfn = output_prefix + separator + mykey + '.nii.gz'
8280                image_write_with_thumbnail( mm['perf'][mykey], tempfn, thumb=False )
8281
8282    if 'pet3d' in mm:
8283        if mm['pet3d'] is not None:
8284            pet3dpro = mm['pet3d']
8285            prwide = dict_to_dataframe( pet3dpro )
8286            if mm_wide.shape[0] > 0 and prwide.shape[0] > 0:
8287                prwide.set_index( mm_wide.index, inplace=True )
8288            mm_wide = pd.concat( [mm_wide, prwide ], axis=1, ignore_index=False )
8289            if 'pet3d_dataframe' in pet3dpro.keys():
8290                pderk = pet3dpro['pet3d_dataframe'].iloc[: , 1:]
8291                pderk.set_index( mm_wide.index, inplace=True )
8292                mm_wide = pd.concat( [ mm_wide, pderk ], axis=1, ignore_index=False )
8293            else:
8294                print("FIXME - pet3d dataframe")
8295            for mykey in get_antsimage_keys( mm['pet3d'] ):
8296                tempfn = output_prefix + separator + mykey + '.nii.gz'
8297                image_write_with_thumbnail( mm['pet3d'][mykey], tempfn, thumb=False )
8298
8299    mmwidefn = output_prefix + separator + 'mmwide.csv'
8300    mm_wide.to_csv( mmwidefn )
8301    if verbose:
8302        print( output_prefix + " write_mm done." )
8303    return

write the tabular and normalization output of the mm function

Parameters

output_prefix : prefix for file outputs - modality specific postfix will be added

mm : output of the mm function for modality-space processing; should be a dictionary with an entry for each modality.

mm_norm : output of mm function for normalized processing

t1wide : wide output data frame from t1 hierarchical

separator : string or character separator for filenames

verbose : boolean

Returns

both csv and image files written to disk. the primary outputs will be output_prefix + separator + 'mmwide.csv' and *norm.nii.gz images
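
A brief sketch; mm_out and mm_norm are assumed to come from a prior antspymm.mm(...) call and t1wide (optional) from antspyt1w.merge_hierarchical_csvs_to_wide_format(...):

    import antspymm

    # primary outputs: /tmp/sub-01-20220228-mm-mmwide.csv plus per-modality *_norm.nii.gz images
    antspymm.write_mm(
        '/tmp/sub-01-20220228-mm',
        mm_out,
        mm_norm=mm_norm,
        t1wide=t1wide,
        separator='-',
        verbose=True)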

def mm_nrg( studyid, sourcedir='/Users/stnava/data/PPMI/MV/example_s3_b/images/PPMI/', sourcedatafoldername='images', processDir='processed', mysep='-', srmodel_T1=False, srmodel_NM=False, srmodel_DTI=False, visualize=True, nrg_modality_list=['T1w', 'NM2DMT', 'DTI', 'T2Flair', 'rsfMRI'], verbose=True):
8306def mm_nrg(
8307    studyid,   # pandas data frame
8308    sourcedir = os.path.expanduser( "~/data/PPMI/MV/example_s3_b/images/PPMI/" ),
8309    sourcedatafoldername = 'images', # root for source data
8310    processDir = "processed", # where output will go - parallel to sourcedatafoldername
8311    mysep = '-', # define a separator for filename components
8312    srmodel_T1 = False, # optional - will add a great deal of time
8313    srmodel_NM = False, # optional - will add a great deal of time
8314    srmodel_DTI = False, # optional - will add a great deal of time
8315    visualize = True,
8316    nrg_modality_list = ["T1w", "NM2DMT", "DTI","T2Flair", "rsfMRI" ],
8317    verbose = True
8318):
8319    """
8320    too dangerous to document ... use with care.
8321
8322    processes multiple modality MRI specifically:
8323
8324    * T1w
8325    * T2Flair
8326    * DTI, DTI_LR, DTI_RL
8327    * rsfMRI, rsfMRI_LR, rsfMRI_RL
8328    * NM2DMT (neuromelanin)
8329
8330    other modalities may be added later ...
8331
8332    "trust me, i know what i'm doing" - sledgehammer
8333
8334    convert to pynb via:
8335        p2j mm.py -o
8336
8337    convert the ipynb to html via:
8338        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
8339
8340    this function assumes NRG format for the input data ....
8341    we also assume that t1w hierarchical (if already done) was written
8342    via its standardized write function.
8343    NRG = https://github.com/stnava/biomedicalDataOrganization
8344
8345    this function is verbose
8346
8347    Parameters
8348    -------------
8349
8350    studyid : must have columns 1. subjectID 2. date (in form 20220228) and 3. imageID
8351        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
8352        these provide unique image IDs for these modalities: nm=neuromelanin, dti=diffusion tensor,
8353        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
8354        t1 is required.  rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.  see antspymm.generate_mm_dataframe
8355
8356    sourcedir : a study specific folder containing individual subject folders
8357
8358    sourcedatafoldername : root for source data e.g. "images"
8359
8360    processDir : where output will go - parallel to sourcedatafoldername e.g.
8361        "processed"
8362
8363    mysep : define a character separator for filename components
8364
8365    srmodel_T1 : False (default) - will add a great deal of time - or h5 filename, 2 chan
8366
8367    srmodel_NM : False (default) - will add a great deal of time - or h5 filename, 1 chan
8368
8369    srmodel_DTI : False (default) - will add a great deal of time - or h5 filename, 1 chan
8370
8371    visualize : True - will plot some results to png
8372
8373    nrg_modality_list : list of permissible modalities - always include [T1w] as base
8374
8375    verbose : boolean
8376
8377    Returns
8378    ---------
8379
8380    writes output to disk and potentially produces figures that may be
8381    captured in a ipynb / html file.
8382
8383    """
8384    studyid = studyid.dropna(axis=1)
8385    if studyid.shape[0] < 1:
8386        raise ValueError('studyid has no rows')
8387    musthavecols = ['subjectID','date','imageID']
8388    for k in range(len(musthavecols)):
8389        if not musthavecols[k] in studyid.keys():
8390            raise ValueError('studyid is missing column ' +musthavecols[k] )
8391    def makewideout( x, separator = '-' ):
8392        return x + separator + 'mmwide.csv'
8393    if nrg_modality_list[0] != 'T1w':
8394        nrg_modality_list.insert(0, "T1w" )
8395    testloop = False
8396    counter=0
8397    import glob as glob
8398    from os.path import exists
8399    ex_path = os.path.expanduser( "~/.antspyt1w/" )
8400    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
8401    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
8402    if not exists( templatefn ):
8403        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
8404        antspyt1w.get_data( force_download=True )
8405        get_data( force_download=True )
8406    temp = sourcedir.split( "/" )
8407    splitCount = len( temp )
8408    template = mm_read( templatefn ) # Read in template
8409    test_run = False
8410    if test_run:
8411        visualize=False
8412    # get sid and dtid from studyid
8413    sid = str(studyid['subjectID'].iloc[0])
8414    dtid = str(studyid['date'].iloc[0])
8415    iid = str(studyid['imageID'].iloc[0])
8416    subjectrootpath = os.path.join(sourcedir,sid, dtid)
8417    if verbose:
8418        print("subjectrootpath: "+ subjectrootpath )
8419    myimgsInput = glob.glob( subjectrootpath+"/*" )
8420    myimgsInput.sort( )
8421    if verbose:
8422        print( myimgsInput )
8423    # hierarchical
8424    # NOTE: if there are multiple T1s for this time point, should take
8425    # the one with the highest resnetGrade
8426    t1_search_path = os.path.join(subjectrootpath, "T1w", iid, "*nii.gz")
8427    if verbose:
8428        print(f"t1 search path: {t1_search_path}")
8429    t1fn = glob.glob(t1_search_path)
8430    t1fn.sort()
8431    if len( t1fn ) < 1:
8432        raise ValueError('mm_nrg cannot find the T1w with uid ' + iid + ' @ ' + subjectrootpath )
8433    t1fn = t1fn[0]
8434    t1 = mm_read( t1fn )
8435    hierfn0 = re.sub( sourcedatafoldername, processDir, t1fn)
8436    hierfn0 = re.sub( ".nii.gz", "", hierfn0)
8437    hierfn = re.sub( "T1w", "T1wHierarchical", hierfn0)
8438    hierfn = hierfn + mysep
8439    hierfntest = hierfn + 'snseg.csv'
8440    regout = hierfn0 + mysep + "syn"
8441    templateTx = {
8442        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
8443        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
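    # note: templateTx mirrors the 'fwdtransforms' / 'invtransforms' naming returned by
    # ants.registration; the T1w-to-template registration further below is only recomputed
    # when these files are missing, and e.g. the forward list can be reused directly as a
    # transformlist for ants.apply_transforms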
8444    if verbose:
8445        print( "-<REGISTRATION EXISTENCE>-: \n" + 
8446              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
8447            str(exists( templateTx['fwdtransforms'][0])) + " " +
8448            str(exists( templateTx['fwdtransforms'][1])) + " " +
8449            str(exists( templateTx['invtransforms'][0])) + " " +
8450            str(exists( templateTx['invtransforms'][1])) )
8451    if verbose:
8452        print( hierfntest )
8453    hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
8454    hier = None
8455    if not hierexists and not testloop:
8456        subjectpropath = os.path.dirname( hierfn )
8457        if verbose:
8458            print( subjectpropath )
8459        os.makedirs( subjectpropath, exist_ok=True  )
8460        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
8461        antspyt1w.write_hierarchical( hier, hierfn )
8462        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8463                hier['dataframes'], identifier=None )
8464        t1wide.to_csv( hierfn + 'mmwide.csv' )
8465    ################# read the hierarchical data ###############################
8466    hier = antspyt1w.read_hierarchical( hierfn )
8467    if exists( hierfn + 'mmwide.csv' ) :
8468        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
8469    elif not testloop:
8470        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8471                hier['dataframes'], identifier=None )
8472    if srmodel_T1 is not False :
8473        hierfnSR = re.sub( sourcedatafoldername, processDir, t1fn)
8474        hierfnSR = re.sub( "T1w", "T1wHierarchicalSR", hierfnSR)
8475        hierfnSR = re.sub( ".nii.gz", "", hierfnSR)
8476        hierfnSR = hierfnSR + mysep
8477        hierfntest = hierfnSR + 'mtl.csv'
8478        if verbose:
8479            print( hierfntest )
8480        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
8481        if not hierexists:
8482            subjectpropath = os.path.dirname( hierfnSR )
8483            if verbose:
8484                print( subjectpropath )
8485            os.makedirs( subjectpropath, exist_ok=True  )
8486            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
8487            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
8488            mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_2chan_featvggL6_postseg_best_mdl.h5"
8489            if isinstance( srmodel_T1, str ):
8490                mdlfn = os.path.join( ex_pathmm, srmodel_T1 )
8491            if verbose:
8492                print( mdlfn )
8493            if exists( mdlfn ):
8494                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
8495            else:
8496                print( mdlfn + " does not exist - will not run.")
8497            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
8498            antspyt1w.write_hierarchical( hierSR, hierfnSR )
8499            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8500                    hierSR['dataframes'], identifier=None )
8501            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
8502    hier = antspyt1w.read_hierarchical( hierfn )
8503    if exists( hierfn + 'mmwide.csv' ) :
8504        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
8505    elif not testloop:
8506        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
8507                hier['dataframes'], identifier=None )
8508    if not testloop:
8509        t1imgbrn = hier['brain_n4_dnz']
8510        t1atropos = hier['dkt_parc']['tissue_segmentation']
8511    # loop over modalities and then unique image IDs
8512    # we treat NM in a "special" way -- aggregating repeats
8513    # other modalities (beyond T1) are treated individually
8514    nimages = len(myimgsInput)
8515    if verbose:
8516        print(  " we have : " + str(nimages) + " modalities.")
8517    for overmodX in nrg_modality_list:
8518        counter=counter+1
8519        if counter > (len(nrg_modality_list)+1):
8520            print("This is weird. " + str(counter))
8521            return
8522        if overmodX == 'T1w':
8523            iidOtherMod = iid
8524            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8525            myimgsr = glob.glob(mod_search_path)
8526        elif overmodX == 'NM2DMT' and ('nmid1' in studyid.keys() ):
8527            iidOtherMod = str( int(studyid['nmid1'].iloc[0]) )
8528            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8529            myimgsr = glob.glob(mod_search_path)
8530            for nmnum in range(2,11):
8531                locnmnum = 'nmid'+str(nmnum)
8532                if locnmnum in studyid.keys() :
8533                    iidOtherMod = str( int(studyid[locnmnum].iloc[0]) )
8534                    mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8535                    myimgsr.append( glob.glob(mod_search_path)[0] )
8536        elif 'rsfMRI' in overmodX and ( ( 'rsfid1' in studyid.keys() ) or ('rsfid2' in studyid.keys() ) ):
8537            myimgsr = []
8538            if  'rsfid1' in studyid.keys():
8539                iidOtherMod = str( int(studyid['rsfid1'].iloc[0]) )
8540                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8541                myimgsr.append( glob.glob(mod_search_path)[0] )
8542            if  'rsfid2' in studyid.keys():
8543                iidOtherMod = str( int(studyid['rsfid2'].iloc[0]) )
8544                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8545                myimgsr.append( glob.glob(mod_search_path)[0] )
8546        elif 'DTI' in overmodX and (  'dtid1' in studyid.keys() or  'dtid2' in studyid.keys() ):
8547            myimgsr = []
8548            if  'dtid1' in studyid.keys():
8549                iidOtherMod = str( int(studyid['dtid1'].iloc[0]) )
8550                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8551                myimgsr.append( glob.glob(mod_search_path)[0] )
8552            if  'dtid2' in studyid.keys():
8553                iidOtherMod = str( int(studyid['dtid2'].iloc[0]) )
8554                mod_search_path = os.path.join(subjectrootpath, overmodX+"*", iidOtherMod, "*nii.gz")
8555                myimgsr.append( glob.glob(mod_search_path)[0] )
8556        elif 'T2Flair' in overmodX and ('flairid' in studyid.keys() ):
8557            iidOtherMod = str( int(studyid['flairid'].iloc[0]) )
8558            mod_search_path = os.path.join(subjectrootpath, overmodX, iidOtherMod, "*nii.gz")
8559            myimgsr = glob.glob(mod_search_path)
8560        if verbose:
8561            print( "overmod " + overmodX + " " + iidOtherMod )
8562            print(f"modality search path: {mod_search_path}")
8563        myimgsr.sort()
8564        if len(myimgsr) > 0:
8565            overmodXx = str(overmodX)
8566            dowrite=False
8567            if verbose:
8568                print( 'overmodX is : ' + overmodXx )
8569                print( 'example image name is : '  )
8570                print( myimgsr )
8571            if overmodXx == 'NM2DMT':
8572                myimgsr2 = myimgsr
8573                myimgsr2.sort()
8574                is4d = False
8575                temp = ants.image_read( myimgsr2[0] )
8576                if temp.dimension == 4:
8577                    is4d = True
8578                if len( myimgsr2 ) == 1 and not is4d: # check dimension
8579                    myimgsr2 = myimgsr2 + myimgsr2
8580                subjectpropath = os.path.dirname( myimgsr2[0] )
8581                subjectpropath = re.sub( sourcedatafoldername, processDir,subjectpropath )
8582                if verbose:
8583                    print( "subjectpropath " + subjectpropath )
8584                mysplit = subjectpropath.split( "/" )
8585                os.makedirs( subjectpropath, exist_ok=True  )
8586                mysplitCount = len( mysplit )
8587                project = mysplit[mysplitCount-5]
8588                subject = mysplit[mysplitCount-4]
8589                date = mysplit[mysplitCount-3]
8590                modality = mysplit[mysplitCount-2]
8591                uider = mysplit[mysplitCount-1]
8592                identifier = mysep.join([project, subject, date, modality ])
8593                identifier = identifier + "_" + iid
8594                mymm = subjectpropath + "/" + identifier
8595                mymmout = makewideout( mymm )
8596                if verbose and not exists( mymmout ):
8597                    print( "NM " + mymm  + ' execution ')
8598                elif verbose and exists( mymmout ) :
8599                    print( "NM " + mymm + ' complete ' )
8600                if exists( mymmout ):
8601                    continue
8602                if is4d:
8603                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
8604                else:
8605                    nmlist = []
8606                    for zz in myimgsr2:
8607                        nmlist.append( mm_read( zz ) )
8608                srmodel_NM_mdl = None
8609                if srmodel_NM is not False:
8610                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
8611                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
8612                    if isinstance( srmodel_NM, str ):
8613                        srmodel_NM = re.sub( "bestup", bestup, srmodel_NM )
8614                        mdlfn = os.path.join( ex_pathmm, srmodel_NM )
8615                    if exists( mdlfn ):
8616                        if verbose:
8617                            print(mdlfn)
8618                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
8619                    else:
8620                        print( mdlfn + " does not exist - won't use SR")
8621                if not testloop:
8622                    tabPro, normPro = mm( t1, hier,
8623                            nm_image_list = nmlist,
8624                            srmodel=srmodel_NM_mdl,
8625                            do_tractography=False,
8626                            do_kk=False,
8627                            do_normalization=templateTx,
8628                            test_run=test_run,
8629                            verbose=True )
8630                    if not test_run:
8631                        write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=None, separator=mysep )
8632                        nmpro = tabPro['NM']
8633                        mysl = range( nmpro['NM_avg'].shape[2] )
8634                    if visualize:
8635                        mysl = range( nmpro['NM_avg'].shape[2] )
8636                        ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
8637                        mysl = range( nmpro['NM_avg_cropped'].shape[2] )
8638                        ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
8639                        ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
8640                        ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
8641            else :
8642                if len( myimgsr ) > 0:
8643                    dowrite=False
8644                    myimgcount = 0
8645                    if len( myimgsr ) > 0 :
8646                        myimg = myimgsr[myimgcount]
8647                        subjectpropath = os.path.dirname( myimg )
8648                        subjectpropath = re.sub( sourcedatafoldername, processDir, subjectpropath )
8649                        mysplit = subjectpropath.split("/")
8650                        mysplitCount = len( mysplit )
8651                        project = mysplit[mysplitCount-5]
8652                        date = mysplit[mysplitCount-4] # NOTE: in an NRG tree this element is actually the subject ID
8653                        subject = mysplit[mysplitCount-3] # NOTE: and this one the date; the join below still yields project-subject-date order
8654                        mymod = mysplit[mysplitCount-2] # FIXME system dependent
8655                        uid = mysplit[mysplitCount-1] # unique image id
8656                        os.makedirs( subjectpropath, exist_ok=True  )
8657                        if mymod == 'T1w':
8658                            identifier = mysep.join([project, date, subject, mymod, uid])
8659                        else:  # add the T1 unique id since that drives a lot of the analysis
8660                            identifier = mysep.join([project, date, subject, mymod, uid ])
8661                            identifier = identifier + "_" + iid
8662                        mymm = subjectpropath + "/" + identifier
8663                        mymmout = makewideout( mymm )
8664                        if verbose and not exists( mymmout ):
8665                            print("Modality specific processing: " + mymod + " execution " )
8666                            print( mymm )
8667                        elif verbose and exists( mymmout ) :
8668                            print("Modality specific processing: " + mymod + " complete " )
8669                        if exists( mymmout ) :
8670                            continue
8671                        if verbose:
8672                            print(subjectpropath)
8673                            print(identifier)
8674                            print( myimg )
8675                        if not testloop:
8676                            img = mm_read( myimg )
8677                            ishapelen = len( img.shape )
8678                            if mymod == 'T1w' and ishapelen == 3: # for a real run, set to True
8679                                if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
8680                                    if verbose:
8681                                        print('start t1 registration')
8682                                    ex_path = os.path.expanduser( "~/.antspyt1w/" )
8683                                    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
8684                                    template = mm_read( templatefn )
8685                                    template = ants.resample_image( template, [1,1,1], use_voxels=False )
8686                                    t1reg = ants.registration( template, hier['brain_n4_dnz'],
8687                                        "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
8688                                    myjac = ants.create_jacobian_determinant_image( template,
8689                                        t1reg['fwdtransforms'][0], do_log=True, geom=True )
8690                                    image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
8691                                    if visualize:
8692                                        ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
8693                                        ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
8694                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
8695                                    dowrite=True
8696                                    if verbose:
8697                                        print('start kk')
8698                                    tabPro, normPro = mm( t1, hier,
8699                                        srmodel=None,
8700                                        do_tractography=False,
8701                                        do_kk=True,
8702                                        do_normalization=templateTx,
8703                                        test_run=test_run,
8704                                        verbose=True )
8705                                    if visualize:
8706                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
8707                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
8708                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
8709                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
8710                            if mymod == 'T2Flair' and ishapelen == 3:
8711                                dowrite=True
8712                                tabPro, normPro = mm( t1, hier,
8713                                    flair_image = img,
8714                                    srmodel=None,
8715                                    do_tractography=False,
8716                                    do_kk=False,
8717                                    do_normalization=templateTx,
8718                                    test_run=test_run,
8719                                    verbose=True )
8720                                if visualize:
8721                                    maxslice = np.min( [21, img.shape[2] ] )
8722                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
8723                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
8724                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
8725                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
8726                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
8727                                img2 = None
8728                                if len( myimgsr ) > 1:
8729                                    img2 = mm_read( myimgsr[myimgcount+1] )
8730                                    ishapelen2 = len( img2.shape )
8731                                    if ishapelen2 != 4 :
8732                                        img2 = None
8733                                dowrite=True
8734                                tabPro, normPro = mm( t1, hier,
8735                                    rsf_image=[img,img2],
8736                                    srmodel=None,
8737                                    do_tractography=False,
8738                                    do_kk=False,
8739                                    do_normalization=templateTx,
8740                                    test_run=test_run,
8741                                    verbose=True )
8742                                if tabPro['rsf'] is not None and visualize:
8743                                    dfn=tabPro['rsf']['dfnname']
8744                                    maxslice = np.min( [21, tabPro['rsf']['meanBold'].shape[2] ] )
8745                                    ants.plot( tabPro['rsf']['meanBold'],
8746                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=mymm+mysep+"meanBOLD.png" )
8747                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['alff'],"Normalize"),
8748                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=mymm+mysep+"boldALFF.png" )
8749                                    ants.plot( tabPro['rsf']['meanBold'], ants.iMath(tabPro['rsf']['falff'],"Normalize"),
8750                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=mymm+mysep+"boldfALFF.png" )
8751                                    ants.plot( tabPro['rsf']['meanBold'], tabPro['rsf'][dfn],
8752                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='DefaultMode', filename=mymm+mysep+"boldDefaultMode.png" )
8753                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
8754                                dowrite=True
8755                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
8756                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
8757                                imgList = [ img ]
8758                                bvalfnList = [ bvalfn ]
8759                                bvecfnList = [ bvecfn ]
8760                                if len( myimgsr ) > 1:  # find DTI_RL
8761                                    dtilrfn = myimgsr[myimgcount+1]
8762                                    if len( dtilrfn ) == 1:
8763                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
8764                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
8765                                        imgRL = ants.image_read( dtilrfn )
8766                                        imgList.append( imgRL )
8767                                        bvalfnList.append( bvalfnRL )
8768                                        bvecfnList.append( bvecfnRL )
8769                                srmodel_DTI_mdl=None
8770                                if srmodel_DTI is not False:
8771                                    temp = ants.get_spacing(img)
8772                                    dtspc=[temp[0],temp[1],temp[2]]
8773                                    bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
8774                                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
8775                                    if isinstance( srmodel_DTI, str ):
8776                                        srmodel_DTI = re.sub( "bestup", bestup, srmodel_DTI )
8777                                        mdlfn = os.path.join( ex_pathmm, srmodel_DTI )
8778                                    if exists( mdlfn ):
8779                                        if verbose:
8780                                            print(mdlfn)
8781                                        srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
8782                                    else:
8783                                        print(mdlfn + " does not exist - won't use SR")
8784                                tabPro, normPro = mm( t1, hier,
8785                                    dw_image=imgList,
8786                                    bvals = bvalfnList,
8787                                    bvecs = bvecfnList,
8788                                    srmodel=srmodel_DTI_mdl,
8789                                    do_tractography=not test_run,
8790                                    do_kk=False,
8791                                    do_normalization=templateTx,
8792                                    test_run=test_run,
8793                                    verbose=True )
8794                                mydti = tabPro['DTI']
8795                                if visualize:
8796                                    maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
8797                                    ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
8798                                    ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
8799                                    ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
8800                            if dowrite:
8801                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep, verbose=True )
8802                                for mykey in normPro.keys():
8803                                    if normPro[mykey] is not None:
8804                                        if visualize and normPro[mykey].components == 1 and False:
8805                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
8806        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
8807            return
8808        if verbose:
8809            print("done with " + overmodX )
8810    if verbose:
8811        print("mm_nrg complete.")
8812    return

8816def mm_csv(
8817    studycsv,   # pandas data frame
8818    mysep = '-', # or "_" for BIDS
8819    srmodel_T1 = False, # optional - will add a great deal of time
8820    srmodel_NM = False, # optional - will add a great deal of time
8821    srmodel_DTI = False, # optional - will add a great deal of time
8822    dti_motion_correct = 'antsRegistrationSyNQuickRepro[r]',
8823    dti_denoise = True,
8824    nrg_modality_list = None,
8825    normalization_template = None,
8826    normalization_template_output = None,
8827    normalization_template_transform_type = "antsRegistrationSyNRepro[s]",
8828    normalization_template_spacing=None,
8829    enantiomorphic=False,
8830    perfusion_trim = 10,
8831    perfusion_m0_image = None,
8832    perfusion_m0 = None,
8833    rsf_upsampling = 3.0,
8834    pet3d = None,
8835):
8836    """
8837    too dangerous to document ... use with care.
8838
8839    processes multiple modality MRI specifically:
8840
8841    * T1w
8842    * T2Flair
8843    * DTI, DTI_LR, DTI_RL
8844    * rsfMRI, rsfMRI_LR, rsfMRI_RL
8845    * NM2DMT (neuromelanin)
8846
8847    other modalities may be added later ...
8848
8849    "trust me, i know what i'm doing" - sledgehammer
8850
8851    convert to pynb via:
8852        p2j mm.py -o
8853
8854    convert the ipynb to html via:
8855        jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html
8856
8857    this function does not assume NRG format for the input data ....
8858
8859    Parameters
8860    -------------
8861
8862    studycsv : must have columns:
8863        - subjectID
8864        - date or session
8865        - imageID
8866        - modality
8867        - sourcedir
8868        - outputdir
8869        - filename (path to the t1 image)
8870        other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2, flairid;
8871        these provide filenames for these modalities: nm=neuromelanin, dti=diffusion tensor,
8872        rsf=resting state fmri, flair=T2Flair.  none of these are required. only
8873        t1 is required. rsfid1/rsfid2 will be processed jointly. same for dtid1/dtid2 and nmid*.
8874        see antspymm.generate_mm_dataframe
8875
8876    sourcedir : a study specific folder containing individual subject folders
8877
8878    outputdir : a study specific folder where individual output subject folders will go
8879
8880    filename : the raw image filename (full path)
8881
8882    srmodel_T1 : False (default) or an h5 model filename (2 channel); enabling super-resolution adds a great deal of processing time
8883
8884    srmodel_NM : False (default) or an h5 model filename (1 channel); if the filename contains the token 'bestup' it is replaced by the automatically selected upsampling shape; super-resolution adds a great deal of processing time
8885
8886    srmodel_DTI : False (default) or an h5 model filename (1 channel); if the filename contains the token 'bestup' it is replaced by the automatically selected upsampling shape; super-resolution adds a great deal of processing time
8887
8888    dti_motion_correct : None or an ants transform type string for motion correction, e.g. 'Rigid', 'SyN' or the default 'antsRegistrationSyNQuickRepro[r]'
8889
8890    dti_denoise : boolean
8891
8892    nrg_modality_list : optional; defaults to None; use to focus on a given modality
8893
8894    normalization_template : optional; defaults to None; if present, all images will
8895        be deformed into this space and the deformation will be stored with an extension
8896        related to this variable.  this should be a brain extracted T1w image.
8897
8898    normalization_template_output : optional string; defaults to None; naming for the 
8899        normalization_template outputs which will be in the T1w directory.
8900
8901    normalization_template_transform_type : optional string transform type passed to ants.registration
8902
8903    normalization_template_spacing : 3-tuple controlling the resolution at which registration is computed 
8904    
8905    enantiomorphic : boolean (WIP)
8906
8907    perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series
8908
8909    perfusion_m0_image : optional m0 antsImage associated with the perfusion time series
8910
8911    perfusion_m0 : optional list containing indices of the m0 in the perfusion time series
8912
8913    rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done
8914
8915    pet3d : optional antsImage for PET (or other 3d scalar) data which we want to summarize
8916
8917    Returns
8918    ---------
8919
8920    writes output to disk and produces figures
8921
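    Example
    ---------

    a minimal sketch with hypothetical paths; a one-row studycsv of this shape is
    normally produced by antspymm.generate_mm_dataframe:

    >>> import pandas as pd
    >>> import antspymm
    >>> studycsv = pd.DataFrame({
    ...     'projectID': ['MYSTUDY'],
    ...     'subjectID': ['sub001'],
    ...     'date': ['20220228'],
    ...     'imageID': ['000'],
    ...     'modality': ['T1w'],
    ...     'sourcedir': ['/data/MYSTUDY/images/'],      # hypothetical
    ...     'outputdir': ['/data/MYSTUDY/processed/'],   # hypothetical
    ...     'filename': ['/data/MYSTUDY/images/sub001_T1w.nii.gz'] })
    >>> antspymm.mm_csv( studycsv, mysep='-' )

    optionally pass normalization_template (a brain extracted T1w antsImage) together
    with normalization_template_output to also write group-template transforms.
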
8922    """
8923    import traceback
8924    visualize = True
8925    verbose = True
8926    if verbose:
8927        print( version() )
8928    if nrg_modality_list is None:
8929        nrg_modality_list = get_valid_modalities()
8930    if studycsv.shape[0] < 1:
8931        raise ValueError('studycsv has no rows')
8932    musthavecols = ['projectID', 'subjectID','date','imageID','modality','sourcedir','outputdir','filename']
8933    for k in range(len(musthavecols)):
8934        if not musthavecols[k] in studycsv.keys():
8935            raise ValueError('studycsv is missing column ' +musthavecols[k] )
8936    def makewideout( x, separator = mysep ):
8937        return x + separator + 'mmwide.csv'
8938    testloop = False
8939    counter=0
8940    import glob as glob
8941    from os.path import exists
8942    ex_path = os.path.expanduser( "~/.antspyt1w/" )
8943    ex_pathmm = os.path.expanduser( "~/.antspymm/" )
8944    templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
8945    if not exists( templatefn ):
8946        print( "**missing files** => call get_data from latest antspyt1w and antspymm." )
8947        antspyt1w.get_data( force_download=True )
8948        get_data( force_download=True )
8949    template = mm_read( templatefn ) # Read in template
8950    test_run = False
8951    if test_run:
8952        visualize=False
8953    # get sid and dtid from studycsv
8954    # musthavecols = ['projectID','subjectID','date','imageID','modality','sourcedir','outputdir','filename']
8955    projid = str(studycsv['projectID'].iloc[0])
8956    sid = str(studycsv['subjectID'].iloc[0])
8957    dtid = str(studycsv['date'].iloc[0])
8958    iid = str(studycsv['imageID'].iloc[0])
8959    t1iidUse=iid
8960    modality = str(studycsv['modality'].iloc[0])
8961    sourcedir = str(studycsv['sourcedir'].iloc[0])
8962    outputdir = str(studycsv['outputdir'].iloc[0])
8963    filename = str(studycsv['filename'].iloc[0])
8964    if not exists(filename):
8965        raise ValueError('mm_csv cannot find filename ' + filename )
8966
8967    # hierarchical
8968    # NOTE: if there are multiple T1s for this time point, should take
8969    # the one with the highest resnetGrade
8970    t1fn = filename
8971    if not exists( t1fn ):
8972        raise ValueError('mm_csv cannot find the T1w file ' + t1fn )
8973    t1 = mm_read( t1fn, modality='T1w' )
8974    minspc = np.min(ants.get_spacing(t1))
8975    minshape = np.min(t1.shape)
8976    if minspc < 1e-16:
8977        warnings.warn('minimum spacing in T1w is too small - cannot process. ' + str(minspc) )
8978        return
8979    if minshape < 32:
8980        warnings.warn('minimum shape in T1w is too small - cannot process. ' + str(minshape) )
8981        return
8982
8983    if enantiomorphic:
8984        t1 = enantiomorphic_filling_without_mask( t1, axis=0 )[0]
8985    hierfn = outputdir + "/"  + projid + "/" + sid + "/" + dtid + "/" + "T1wHierarchical" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchical" + mysep + iid + mysep
8986    hierfnSR = outputdir + "/" + projid + "/"  + sid + "/" + dtid + "/" + "T1wHierarchicalSR" + '/' + iid + "/" + projid + mysep + sid + mysep + dtid + mysep + "T1wHierarchicalSR" + mysep + iid + mysep
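    # e.g. with hypothetical IDs (projid='MYSTUDY', sid='sub001', dtid='20220228', iid='000')
    # and mysep='-', hierfn becomes
    #   <outputdir>/MYSTUDY/sub001/20220228/T1wHierarchical/000/MYSTUDY-sub001-20220228-T1wHierarchical-000-
    # i.e. an NRG-style directory tree plus a filename prefix ending in mysep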
8987    hierfntest = hierfn + 'cerebellum.csv'
8988    if verbose:
8989        print( hierfntest )
8990    regout = re.sub("T1wHierarchical","T1w",hierfn) + "syn"
8991    templateTx = {
8992        'fwdtransforms': [ regout+'1Warp.nii.gz', regout+'0GenericAffine.mat'],
8993        'invtransforms': [ regout+'0GenericAffine.mat', regout+'1InverseWarp.nii.gz']  }
8994    groupTx = None
8995    # make the T1w directory
8996    os.makedirs( os.path.dirname(re.sub("T1wHierarchical","T1w",hierfn)), exist_ok=True  )
8997    if normalization_template_output is not None:
8998        normout = re.sub("T1wHierarchical","T1w",hierfn) +  normalization_template_output
8999        templateNormTx = {
9000            'fwdtransforms': [ normout+'1Warp.nii.gz', normout+'0GenericAffine.mat'],
9001            'invtransforms': [ normout+'0GenericAffine.mat', normout+'1InverseWarp.nii.gz']  }
9002        groupTx = templateNormTx['fwdtransforms']
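        # groupTx holds the forward transforms to the user-supplied group template and is
        # later passed to mm() as group_transform for group-template normalization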
9003    if verbose:
9004        print( "-<REGISTRATION EXISTENCE>-: \n" + 
9005              "NAMING: " + regout+'0GenericAffine.mat' + " \n " +
9006            str(exists( templateTx['fwdtransforms'][0])) + " " +
9007            str(exists( templateTx['fwdtransforms'][1])) + " " +
9008            str(exists( templateTx['invtransforms'][0])) + " " +
9009            str(exists( templateTx['invtransforms'][1])) )
9010    if verbose:
9011        print( hierfntest )
9012    hierexists = exists( hierfntest ) and exists( templateTx['fwdtransforms'][0]) and exists( templateTx['fwdtransforms'][1]) and exists( templateTx['invtransforms'][0]) and exists( templateTx['invtransforms'][1])
9013    hier = None
9014    if not hierexists and not testloop:
9015        subjectpropath = os.path.dirname( hierfn )
9016        if verbose:
9017            print( subjectpropath )
9018        os.makedirs( subjectpropath, exist_ok=True  )
9019        hier = antspyt1w.hierarchical( t1, hierfn, labels_to_register=None )
9020        antspyt1w.write_hierarchical( hier, hierfn )
9021        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9022                hier['dataframes'], identifier=None )
9023        t1wide.to_csv( hierfn + 'mmwide.csv' )
9024    ################# read the hierarchical data ###############################
9025    # over-write the rbp data with a consistent and recent approach ############
9026    redograding = True
9027    if redograding:
9028        myx = antspyt1w.inspect_raw_t1( t1, hierfn + 'rbp' , option='both' )
9029        myx['brain'].to_csv( hierfn + 'rbp.csv', index=False )
9030        myx['brain'].to_csv( hierfn + 'rbpbrain.csv', index=False )
9031        del myx
9032
9033    hier = antspyt1w.read_hierarchical( hierfn )
9034    t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9035        hier['dataframes'], identifier=None )
9036    rgrade = str( t1wide['resnetGrade'].iloc[0] )
9037    if t1wide['resnetGrade'].iloc[0] < 0.20:
9038        warnings.warn('T1w quality check indicates failure: ' + rgrade + " will not process." )
9039        return
9040    else:
9041        print('T1w quality check indicates success: ' + rgrade + " will process." )
9042
9043    if srmodel_T1 is not False :
9044        hierfntest = hierfnSR + 'mtl.csv'
9045        if verbose:
9046            print( hierfntest )
9047        hierexists = exists( hierfntest ) # FIXME should test this explicitly but we assume it here
9048        if not hierexists:
9049            subjectpropath = os.path.dirname( hierfnSR )
9050            if verbose:
9051                print( subjectpropath )
9052            os.makedirs( subjectpropath, exist_ok=True  )
9053            # hierarchical_to_sr(t1hier, sr_model, tissue_sr=False, blending=0.5, verbose=False)
9054            bestup = siq.optimize_upsampling_shape( ants.get_spacing(t1), modality='T1' )
9055            mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_2chan_featvggL6_postseg_best_mdl.h5"
9056            if isinstance( srmodel_T1, str ):
9057                mdlfn = os.path.join( ex_pathmm, srmodel_T1 )
9058            if verbose:
9059                print( mdlfn )
9060            if exists( mdlfn ):
9061                srmodel_T1_mdl = tf.keras.models.load_model( mdlfn, compile=False )
9062            else:
9063                print( mdlfn + " does not exist - will not run.")
9064            hierSR = antspyt1w.hierarchical_to_sr( hier, srmodel_T1_mdl, blending=None, tissue_sr=False )
9065            antspyt1w.write_hierarchical( hierSR, hierfnSR )
9066            t1wideSR = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9067                    hierSR['dataframes'], identifier=None )
9068            t1wideSR.to_csv( hierfnSR + 'mmwide.csv' )
9069    hier = antspyt1w.read_hierarchical( hierfn )
9070    if exists( hierfn + 'mmwide.csv' ) :
9071        t1wide = pd.read_csv( hierfn + 'mmwide.csv' )
9072    elif not testloop:
9073        t1wide = antspyt1w.merge_hierarchical_csvs_to_wide_format(
9074                hier['dataframes'], identifier=None )
9075    if not testloop:
9076        t1imgbrn = hier['brain_n4_dnz']
9077        t1atropos = hier['dkt_parc']['tissue_segmentation']
9078
9079    if not exists( regout + "logjacobian.nii.gz" ) or not exists( regout+'1Warp.nii.gz' ):
9080        if verbose:
9081            print('start t1 registration')
9082        ex_path = os.path.expanduser( "~/.antspyt1w/" )
9083        templatefn = ex_path + 'CIT168_T1w_700um_pad_adni.nii.gz'
9084        template = mm_read( templatefn )
9085        template = ants.resample_image( template, [1,1,1], use_voxels=False )
9086        t1reg = ants.registration( template, 
9087            hier['brain_n4_dnz'],
9088            "antsRegistrationSyNQuickRepro[s]", outprefix = regout, verbose=False )
9089        myjac = ants.create_jacobian_determinant_image( template,
9090            t1reg['fwdtransforms'][0], do_log=True, geom=True )
9091        image_write_with_thumbnail( myjac, regout + "logjacobian.nii.gz", thumb=False )
9092        if visualize:
9093            ants.plot( ants.iMath(t1reg['warpedmovout'],"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='warped to template', filename=regout+"totemplate.png" )
9094            ants.plot( ants.iMath(myjac,"Normalize"),  axis=2, nslices=21, ncol=7, crop=True, title='jacobian', filename=regout+"jacobian.png" )
9095
9096    if normalization_template_output is not None and normalization_template is not None:
9097        if verbose:
9098            print("begin group template registration")
9099        if not exists( normout+'0GenericAffine.mat' ):
9100            if normalization_template_spacing is not None:
9101                normalization_template_rr=ants.resample_image(normalization_template,normalization_template_spacing)
9102            else:
9103                normalization_template_rr=normalization_template
9104            greg = ants.registration( 
9105                normalization_template_rr, 
9106                hier['brain_n4_dnz'],
9107                normalization_template_transform_type,
9108                outprefix = normout, verbose=False )
9109            myjac = ants.create_jacobian_determinant_image( template,
9110                    greg['fwdtransforms'][0], do_log=True, geom=True )
9111            image_write_with_thumbnail( myjac, normout + "logjacobian.nii.gz", thumb=False )
9112            if verbose:
9113                print("end group template registration")
9114        else:
9115            if verbose:
9116                print("group template registration already done")
9117
9118    # loop over modalities and then unique image IDs
9119    # we treat NM in a "special" way -- aggregating repeats
9120    # other modalities (beyond T1) are treated individually
9121    for overmodX in nrg_modality_list:
9122        # define 1. input images 2. output prefix
9123        mydoc = docsamson( overmodX, studycsv=studycsv, outputdir=outputdir, projid=projid, sid=sid, dtid=dtid, mysep=mysep,t1iid=t1iidUse )
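        # docsamson is expected to return at least 'images' (list of input filenames),
        # 'outprefix' (output filename prefix) and 'modality' (NRG modality name), used below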
9124        myimgsr = mydoc['images']
9125        mymm = mydoc['outprefix']
9126        mymod = mydoc['modality']
9127        if verbose:
9128            print( mydoc )
9129        if len(myimgsr) > 0:
9130            dowrite=False
9131            if verbose:
9132                print( 'overmodX is : ' + overmodX )
9133                print( 'example image name is : '  )
9134                print( myimgsr )
9135            if overmodX == 'NM2DMT':
9136                dowrite = True
9137                visualize = True
9138                subjectpropath = os.path.dirname( mydoc['outprefix'] )
9139                if verbose:
9140                    print("subjectpropath is")
9141                    print(subjectpropath)
9142                    os.makedirs( subjectpropath, exist_ok=True  )
9143                myimgsr2 = myimgsr
9144                myimgsr2.sort()
9145                is4d = False
9146                temp = ants.image_read( myimgsr2[0] )
9147                if temp.dimension == 4:
9148                    is4d = True
9149                if len( myimgsr2 ) == 1 and not is4d: # check dimension
9150                    myimgsr2 = myimgsr2 + myimgsr2
9151                mymmout = makewideout( mymm )
9152                if verbose and not exists( mymmout ):
9153                    print( "NM " + mymm  + ' execution ')
9154                elif verbose and exists( mymmout ) :
9155                    print( "NM " + mymm + ' complete ' )
9156                if exists( mymmout ):
9157                    continue
9158                if is4d:
9159                    nmlist = ants.ndimage_to_list( mm_read( myimgsr2[0] ) )
9160                else:
9161                    nmlist = []
9162                    for zz in myimgsr2:
9163                        nmlist.append( mm_read( zz ) )
9164                srmodel_NM_mdl = None
9165                if srmodel_NM is not False:
9166                    bestup = siq.optimize_upsampling_shape( ants.get_spacing(nmlist[0]), modality='NM', roundit=True )
9167                    mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
9168                    if isinstance( srmodel_NM, str ):
9169                        srmodel_NM = re.sub( "bestup", bestup, srmodel_NM )
9170                        mdlfn = os.path.join( ex_pathmm, srmodel_NM )
9171                    if exists( mdlfn ):
9172                        if verbose:
9173                            print(mdlfn)
9174                        srmodel_NM_mdl = tf.keras.models.load_model( mdlfn, compile=False  )
9175                    else:
9176                        print( mdlfn + " does not exist - won't use SR")
9177                if not testloop:
9178                    try:
9179                        tabPro, normPro = mm( t1, hier,
9180                            nm_image_list = nmlist,
9181                            srmodel=srmodel_NM_mdl,
9182                            do_tractography=False,
9183                            do_kk=False,
9184                            do_normalization=templateTx,
9185                            group_template = normalization_template,
9186                            group_transform = groupTx,
9187                            test_run=test_run,
9188                            verbose=True )
9189                    except Exception as e:
9190                        error_info = traceback.format_exc()
9191                        print(error_info)
9192                        visualize=False
9193                        dowrite=False
9194                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9195                        pass
9196                    if not test_run:
9197                        if dowrite:
9198                            write_mm( output_prefix=mymm, mm=tabPro,
9199                                mm_norm=normPro, t1wide=None, separator=mysep )
9200                        if visualize :
9201                            nmpro = tabPro['NM']
9202                            mysl = range( nmpro['NM_avg'].shape[2] )
9203                            ants.plot( nmpro['NM_avg'],  nmpro['t1_to_NM'], slices=mysl, axis=2, title='nm + t1', filename=mymm+mysep+"NMavg.png" )
9204                            mysl = range( nmpro['NM_avg_cropped'].shape[2] )
9205                            ants.plot( nmpro['NM_avg_cropped'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop', filename=mymm+mysep+"NMavgcrop.png" )
9206                            ants.plot( nmpro['NM_avg_cropped'], nmpro['t1_to_NM'], axis=2, slices=mysl, overlay_alpha=0.3, title='nm crop + t1', filename=mymm+mysep+"NMavgcropt1.png" )
9207                            ants.plot( nmpro['NM_avg_cropped'], nmpro['NM_labels'], axis=2, slices=mysl, title='nm crop + labels', filename=mymm+mysep+"NMavgcroplabels.png" )
9208            else :
9209                if len( myimgsr ) > 0 :
9210                    dowrite=False
9211                    myimgcount=0
9212                    if len( myimgsr ) > 0 :
9213                        myimg = myimgsr[ myimgcount ]
9214                        subjectpropath = os.path.dirname( mydoc['outprefix'] )
9215                        if verbose:
9216                            print("subjectpropath is")
9217                            print(subjectpropath)
9218                        os.makedirs( subjectpropath, exist_ok=True  )
9219                        mymmout = makewideout( mymm )
9220                        if verbose and not exists( mymmout ):
9221                            print( "Modality specific processing: " + mymod + " execution " )
9222                            print( mymm )
9223                        elif verbose and exists( mymmout ) :
9224                            print("Modality specific processing: " + mymod + " complete " )
9225                        if exists( mymmout ) :
9226                            continue
9227                        if verbose:
9228                            print( subjectpropath )
9229                            print( myimg )
9230                        if not testloop:
9231                            img = mm_read( myimg )
9232                            ishapelen = len( img.shape )
9233                            if mymod == 'T1w' and ishapelen == 3:
9234                                if not exists( mymm + mysep + "kk_norm.nii.gz" ):
9235                                    dowrite=True
9236                                    if verbose:
9237                                        print('start kk')
9238                                    try:
9239                                        tabPro, normPro = mm( t1, hier,
9240                                            srmodel=None,
9241                                            do_tractography=False,
9242                                            do_kk=True,
9243                                            do_normalization=templateTx,
9244                                            group_template = normalization_template,
9245                                            group_transform = groupTx,
9246                                            test_run=test_run,
9247                                            verbose=True )
9248                                    except Exception as e:
9249                                        error_info = traceback.format_exc()
9250                                        print(error_info)
9251                                        visualize=False
9252                                        dowrite=False
9253                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9254                                        pass
9255                                    if visualize:
9256                                        maxslice = np.min( [21, hier['brain_n4_dnz'].shape[2] ] )
9257                                        ants.plot( hier['brain_n4_dnz'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='brain extraction', filename=mymm+mysep+"brainextraction.png" )
9258                                        ants.plot( tabPro['kk']['thickness_image'], axis=2, nslices=maxslice, ncol=7, crop=True, title='kk',
9259                                        cmap='plasma', filename=mymm+mysep+"kkthickness.png" )
9260                            if mymod == 'T2Flair' and ishapelen == 3 and np.min(img.shape) > 15:
9261                                dowrite=True
9262                                try:
9263                                    tabPro, normPro = mm( t1, hier,
9264                                        flair_image = img,
9265                                        srmodel=None,
9266                                        do_tractography=False,
9267                                        do_kk=False,
9268                                        do_normalization=templateTx,
9269                                        group_template = normalization_template,
9270                                        group_transform = groupTx,
9271                                        test_run=test_run,
9272                                        verbose=True )
9273                                except Exception as e:
9274                                    error_info = traceback.format_exc()
9275                                    print(error_info)
9276                                    visualize=False
9277                                    dowrite=False
9278                                    print(f"antspymmerror occurred while processing {overmodX}: {e}")
9279                                    pass
9280                                if visualize:
9281                                    maxslice = np.min( [21, img.shape[2] ] )
9282                                    ants.plot_ortho( img, crop=True, title='Flair', filename=mymm+mysep+"flair.png", flat=True )
9283                                    ants.plot_ortho( img, tabPro['flair']['WMH_probability_map'], crop=True, title='Flair + WMH', filename=mymm+mysep+"flairWMH.png", flat=True )
9284                                    if tabPro['flair']['WMH_posterior_probability_map'] is not None:
9285                                        ants.plot_ortho( img, tabPro['flair']['WMH_posterior_probability_map'],  crop=True, title='Flair + prior WMH', filename=mymm+mysep+"flairpriorWMH.png", flat=True )
9286                            if ( mymod == 'rsfMRI_LR' or mymod == 'rsfMRI_RL' or mymod == 'rsfMRI' )  and ishapelen == 4:
9287                                img2 = None
9288                                if len( myimgsr ) > 1:
9289                                    img2 = mm_read( myimgsr[myimgcount+1] )
9290                                    ishapelen2 = len( img2.shape )
9291                                    if ishapelen2 != 4 or 1 in img2.shape:
9292                                        img2 = None
9293                                if 1 in img.shape:
9294                                    warnings.warn( 'rsfMRI image shape suggests it is an incorrectly converted mosaic image - will not process.')
9295                                    dowrite=False
9296                                    tabPro={'rsf':None}
9297                                    normPro={'rsf':None}
9298                                else:
9299                                    dowrite=True
9300                                    try:
9301                                        tabPro, normPro = mm( t1, hier,
9302                                            rsf_image=[img,img2],
9303                                            srmodel=None,
9304                                            do_tractography=False,
9305                                            do_kk=False,
9306                                            do_normalization=templateTx,
9307                                            group_template = normalization_template,
9308                                            group_transform = groupTx,
9309                                            rsf_upsampling = rsf_upsampling,
9310                                            test_run=test_run,
9311                                            verbose=True )
9312                                    except Exception as e:
9313                                        error_info = traceback.format_exc()
9314                                        print(error_info)
9315                                        visualize=False
9316                                        dowrite=False
9317                                        tabPro={'rsf':None}
9318                                        normPro={'rsf':None}
9319                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9320                                        pass
9321                                if tabPro['rsf'] is not None and visualize:
9322                                    for tpro in tabPro['rsf']: # FIXMERSF
9323                                        maxslice = np.min( [21, tpro['meanBold'].shape[2] ] )
9324                                        tproprefix = mymm+mysep+str(tpro['paramset'])+mysep
9325                                        ants.plot( tpro['meanBold'],
9326                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='meanBOLD', filename=tproprefix+"meanBOLD.png" )
9327                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['alff'],"Normalize"),
9328                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='ALFF', filename=tproprefix+"boldALFF.png" )
9329                                        ants.plot( tpro['meanBold'], ants.iMath(tpro['falff'],"Normalize"),
9330                                            axis=2, nslices=maxslice, ncol=7, crop=True, title='fALFF', filename=tproprefix+"boldfALFF.png" )
9331                                        dfn=tpro['dfnname']
9332                                        ants.plot( tpro['meanBold'], tpro[dfn],
9333                                            axis=2, nslices=maxslice, ncol=7, crop=True, title=dfn, filename=tproprefix+"boldDefaultMode.png" )
9334                            if ( mymod == 'perf' ) and ishapelen == 4:
9335                                dowrite=True
9336                                try:
9337                                    tabPro, normPro = mm( t1, hier,
9338                                        perfusion_image=img,
9339                                        srmodel=None,
9340                                        do_tractography=False,
9341                                        do_kk=False,
9342                                        do_normalization=templateTx,
9343                                        group_template = normalization_template,
9344                                        group_transform = groupTx,
9345                                        test_run=test_run,
9346                                        perfusion_trim=perfusion_trim,
9347                                        perfusion_m0_image=perfusion_m0_image,
9348                                        perfusion_m0=perfusion_m0,
9349                                        verbose=True )
9350                                except Exception as e:
9351                                        error_info = traceback.format_exc()
9352                                        print(error_info)
9353                                        visualize=False
9354                                        dowrite=False
9355                                        tabPro={'perf':None}
9356                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9357                                        pass
9358                                if tabPro['perf'] is not None and visualize:
9359                                    maxslice = np.min( [21, tabPro['perf']['meanBold'].shape[2] ] )
9360                                    ants.plot( tabPro['perf']['perfusion'],
9361                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='perfusion image', filename=mymm+mysep+"perfusion.png" )
9362                                    ants.plot( tabPro['perf']['cbf'],
9363                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='CBF image', filename=mymm+mysep+"cbf.png" )
9364                                    ants.plot( tabPro['perf']['m0'],
9365                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='M0 image', filename=mymm+mysep+"m0.png" )
9366
9367                            if ( mymod == 'pet3d' ) and ishapelen == 3:
9368                                dowrite=True
9369                                try:
9370                                    tabPro, normPro = mm( t1, hier,
9371                                        srmodel=None,
9372                                        do_tractography=False,
9373                                        do_kk=False,
9374                                        do_normalization=templateTx,
9375                                        group_template = normalization_template,
9376                                        group_transform = groupTx,
9377                                        test_run=test_run,
9378                                        pet_3d_image=img,
9379                                        verbose=True )
9380                                except Exception as e:
9381                                        error_info = traceback.format_exc()
9382                                        print(error_info)
9383                                        visualize=False
9384                                        dowrite=False
9385                                        tabPro={'pet3d':None}
9386                                        print(f"antspymmerror occurred while processing {overmodX}: {e}")
9387                                        pass
9388                                if tabPro['pet3d'] is not None and visualize:
9389                                    maxslice = np.min( [21, tabPro['pet3d']['pet3d'].shape[2] ] )
9390                                    ants.plot( tabPro['pet3d']['pet3d'],
9391                                        axis=2, nslices=maxslice, ncol=7, crop=True, title='PET image', filename=mymm+mysep+"pet3d.png" )
9392                            if ( mymod == 'DTI_LR' or mymod == 'DTI_RL' or mymod == 'DTI' ) and ishapelen == 4:
9393                                bvalfn = re.sub( '.nii.gz', '.bval' , myimg )
9394                                bvecfn = re.sub( '.nii.gz', '.bvec' , myimg )
9395                                imgList = [ img ]
9396                                bvalfnList = [ bvalfn ]
9397                                bvecfnList = [ bvecfn ]
9398                                missing_dti_data=False # bval, bvec or images
9399                                if len( myimgsr ) == 2:  # find DTI_RL
9400                                    dtilrfn = myimgsr[myimgcount+1]
9401                                    if exists( dtilrfn ):
9402                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
9403                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
9404                                        imgRL = ants.image_read( dtilrfn )
9405                                        imgList.append( imgRL )
9406                                        bvalfnList.append( bvalfnRL )
9407                                        bvecfnList.append( bvecfnRL )
9408                                elif len( myimgsr ) == 3:  # three DTI acquisitions: merge the 2nd and 3rd
9409                                    print("DTI trinity")
9410                                    dtilrfn = myimgsr[myimgcount+1]
9411                                    dtilrfn2 = myimgsr[myimgcount+2]
9412                                    if exists( dtilrfn ) and exists( dtilrfn2 ):
9413                                        bvalfnRL = re.sub( '.nii.gz', '.bval' , dtilrfn )
9414                                        bvecfnRL = re.sub( '.nii.gz', '.bvec' , dtilrfn )
9415                                        bvalfnRL2 = re.sub( '.nii.gz', '.bval' , dtilrfn2 )
9416                                        bvecfnRL2 = re.sub( '.nii.gz', '.bvec' , dtilrfn2 )
9417                                        imgRL = ants.image_read( dtilrfn )
9418                                        imgRL2 = ants.image_read( dtilrfn2 )
9419                                        bvals, bvecs = read_bvals_bvecs( bvalfnRL , bvecfnRL  )
9420                                        print( bvals.max() )
9421                                        bvals2, bvecs2 = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
9422                                        print( bvals2.max() )
9423                                        temp = merge_dwi_data( imgRL, bvals, bvecs, imgRL2, bvals2, bvecs2  )
9424                                        imgList.append( temp[0] )
9425                                        bvalfnList.append( mymm+mysep+'joined.bval' )
9426                                        bvecfnList.append( mymm+mysep+'joined.bvec' )
9427                                        write_bvals_bvecs( temp[1], temp[2], mymm+mysep+'joined' )
9428                                        bvalsX, bvecsX = read_bvals_bvecs( bvalfnRL2 , bvecfnRL2  )
9429                                        print( bvalsX.max() )
9430                                # check existence of all files expected ...
9431                                for dtiex in bvalfnList+bvecfnList+myimgsr:
9432                                    if not exists(dtiex):
9433                                        print('mm_csv: missing dti data ' + dtiex )
9434                                        missing_dti_data=True
9435                                        dowrite=False
9436                                if not missing_dti_data:
9437                                    dowrite=True
9438                                    srmodel_DTI_mdl=None
9439                                    if srmodel_DTI is not False:
9440                                        temp = ants.get_spacing(img)
9441                                        dtspc=[temp[0],temp[1],temp[2]]
9442                                        bestup = siq.optimize_upsampling_shape( dtspc, modality='DTI' )
9443                                        mdlfn = ex_pathmm + "siq_default_sisr_" + bestup + "_1chan_featvggL6_best_mdl.h5"
9444                                        if isinstance( srmodel_DTI, str ):
9445                                            srmodel_DTI = re.sub( "bestup", bestup, srmodel_DTI )
9446                                            mdlfn = os.path.join( ex_pathmm, srmodel_DTI )
9447                                        if exists( mdlfn ):
9448                                            if verbose:
9449                                                print(mdlfn)
9450                                            srmodel_DTI_mdl = tf.keras.models.load_model( mdlfn, compile=False )
9451                                        else:
9452                                            print(mdlfn + " does not exist - wont use SR")
9453                                    try:
9454                                        tabPro, normPro = mm( t1, hier,
9455                                            dw_image=imgList,
9456                                            bvals = bvalfnList,
9457                                            bvecs = bvecfnList,
9458                                            srmodel=srmodel_DTI_mdl,
9459                                            do_tractography=not test_run,
9460                                            do_kk=False,
9461                                            do_normalization=templateTx,
9462                                            group_template = normalization_template,
9463                                            group_transform = groupTx,
9464                                            dti_motion_correct = dti_motion_correct,
9465                                            dti_denoise = dti_denoise,
9466                                            test_run=test_run,
9467                                            verbose=True )
9468                                    except Exception as e:
9469                                            error_info = traceback.format_exc()
9470                                            print(error_info)
9471                                            visualize=False
9472                                            dowrite=False
9473                                            tabPro={'DTI':None}
9474                                            print(f"antspymmerror occurred while processing {overmodX}: {e}")
9475                                            pass
9476                                    mydti = tabPro['DTI']
9477                                    if visualize and tabPro['DTI'] is not None:
9478                                        maxslice = np.min( [21, mydti['recon_fa'].shape[2] ] )
9479                                        ants.plot( mydti['recon_fa'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='FA', filename=mymm+mysep+"FAbetter.png"  )
9480                                        ants.plot( mydti['recon_fa'], mydti['jhu_labels'], axis=2, nslices=maxslice, ncol=7, crop=True, title='FA + JHU', filename=mymm+mysep+"FAJHU.png"  )
9481                                        ants.plot( mydti['recon_md'],  axis=2, nslices=maxslice, ncol=7, crop=True, title='MD', filename=mymm+mysep+"MD.png"  )
9482                            if dowrite:
9483                                write_mm( output_prefix=mymm, mm=tabPro, mm_norm=normPro, t1wide=t1wide, separator=mysep )
9484                                for mykey in normPro.keys():
9485                                    if normPro[mykey] is not None and normPro[mykey].components == 1:
9486                                        if visualize and False:
9487                                            ants.plot( template, normPro[mykey], axis=2, nslices=21, ncol=7, crop=True, title=mykey, filename=mymm+mysep+mykey+".png"   )
9488        if overmodX == nrg_modality_list[ len( nrg_modality_list ) - 1 ]:
9489            return
9490        if verbose:
9491            print("done with " + overmodX )
9492    if verbose:
9493        print("mm_nrg complete.")
9494    return

too dangerous to document ... use with care.

processes multiple MRI modalities, specifically:

  • T1w
  • T2Flair
  • DTI, DTI_LR, DTI_RL
  • rsfMRI, rsfMRI_LR, rsfMRI_RL
  • NM2DMT (neuromelanin)

other modalities may be added later ...

"trust me, i know what i'm doing" - sledgehammer

convert to pynb via: p2j mm.py -o

convert the ipynb to html via: jupyter nbconvert ANTsPyMM/tests/mm.ipynb --execute --to html

this function does not assume NRG format for the input data ....

Parameters

studycsv : must have the columns subjectID, date (or session), imageID, modality, sourcedir, outputdir and filename (path to the T1w image). Other relevant columns include nmid1-10, rsfid1, rsfid2, dtid1, dtid2 and flairid; these provide filenames for the additional modalities (nm = neuromelanin, dti = diffusion tensor, rsf = resting state fMRI, flair = T2Flair). None of these are required; only the T1w entry is required. rsfid1/rsfid2 are processed jointly, as are dtid1/dtid2 and the nmid* entries. See antspymm.generate_mm_dataframe.

sourcedir : a study specific folder containing individual subject folders

outputdir : a study specific folder where individual output subject folders will go

filename : the raw image filename (full path)

srmodel_T1 : False (default) or an h5 filename for a 2-channel super-resolution model; enabling this adds a great deal of runtime

srmodel_NM : False (default) or an h5 filename for a 1-channel super-resolution model; enabling this adds a great deal of runtime

srmodel_DTI : False (default) or an h5 filename for a 1-channel super-resolution model; enabling this adds a great deal of runtime

dti_motion_correct : None, Rigid or SyN

dti_denoise : boolean

nrg_modality_list : optional; defaults to None; use to focus on a given modality

normalization_template : optional; defaults to None; if present, all images will be deformed into this space and the deformation will be stored with an extension related to this variable. this should be a brain extracted T1w image.

normalization_template_output : optional string; defaults to None; naming for the normalization_template outputs which will be in the T1w directory.

normalization_template_transform_type : optional string transform type passed to ants.registration

normalization_template_spacing : 3-tuple controlling the resolution at which registration is computed

enantiomorphic: boolean (WIP)

perfusion_trim : optional integer number of time volumes to exclude from the front of the perfusion time series

perfusion_m0_image : optional m0 antsImage associated with the perfusion time series

perfusion_m0 : optional list containing indices of the m0 in the perfusion time series

rsf_upsampling : optional upsampling parameter value in mm; if set to zero, no upsampling is done

pet3d : optional antsImage for PET (or other 3d scalar) data which we want to summarize

Returns

writes output to disk and produces figures
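
Assuming the entry point documented here is antspymm.mm_csv (its first argument is the studycsv table described above), a minimal call might look like the sketch below; the CSV path, column contents and option values are placeholders to adapt to your own study.

import pandas as pd
import antspymm

# hypothetical study table containing the required columns listed above:
# subjectID, date (or session), imageID, modality, sourcedir, outputdir, filename
studycsv = pd.read_csv('my_study_table.csv')

antspymm.mm_csv(
    studycsv,
    srmodel_T1=False,            # or an h5 filename to enable super-resolution
    dti_motion_correct='Rigid',  # None, 'Rigid' or 'SyN'
    dti_denoise=False )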

def collect_blind_qc_by_modality(modality_path, set_index_to_fn=True):
1094def collect_blind_qc_by_modality( modality_path, set_index_to_fn=True ):
1095    """
1096    Collects blind QC data from multiple CSV files with the same modality.
1097
1098    Args:
1099
1100    modality_path (str): A glob pattern matching the per-image blind QC CSV files for one modality (e.g. 'qcfolder/*T1w*csv'); the pattern is passed directly to glob.glob.
1101
1102    set_index_to_fn (bool): If True, reset the row index and index the returned DataFrame by the 'filename' column.
1103
1104    Returns:
1105    Pandas DataFrame: A DataFrame containing all the blind QC data from the CSV files.
1106    """
1107    import glob as glob
1108    fns = glob.glob( modality_path )
1109    fns.sort()
1110    jdf = pd.DataFrame()
1111    for k in range(len(fns)):
1112        temp=pd.read_csv(fns[k])
1113        if not 'filename' in temp.keys():
1114            temp['filename']=fns[k]
1115        jdf=pd.concat( [jdf,temp], axis=0, ignore_index=False )
1116    if set_index_to_fn:
1117        jdf = jdf.reset_index(drop=True)
1118        if "Unnamed: 0" in jdf.columns:
1119            holder=jdf.pop( "Unnamed: 0" )
1120        jdf = jdf.set_index('filename')
1121    return jdf

Collects blind QC data from multiple CSV files with the same modality.

Args:

modality_path (str): A glob pattern matching the per-image blind QC CSV files for one modality (e.g. 'qcfolder/*T1w*csv'); the pattern is passed directly to glob.glob.

set_index_to_fn (bool): If True, reset the row index and index the returned DataFrame by the 'filename' column.

Returns: Pandas DataFrame: A DataFrame containing all the blind QC data from the CSV files.
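
For example, assuming the blind QC CSVs were written to a folder named blind_qc_dir (both the folder name and the pattern are placeholders):

import antspymm

# gather every T1w blind QC csv produced by blind_image_assessment
t1qc = antspymm.collect_blind_qc_by_modality('blind_qc_dir/*T1w*csv')
print(t1qc.shape)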

def alffmap(x, flo=0.01, fhi=0.1, tr=1, detrend=True):
9755def alffmap( x, flo=0.01, fhi=0.1, tr=1, detrend = True ):
9756    """
9757    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
9758    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
9759    are related measures that quantify the amplitude of low frequency
9760    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
9761    same function in ANTsR.
9762
9763    x input vector for the time series of interest
9764    flo low frequency, typically 0.01
9765    fhi high frequency, typically 0.1
9766    tr the period associated with the vector x (inverse of frequency)
9767    detrend detrend the input time series
9768
9769    return vector is output showing ALFF and fALFF values
9770    """
9771    temp = spec_pgram( x, xfreq=1.0/tr, demean=False, detrend=detrend, taper=0, fast=True, plot=False )
9772    fselect = np.logical_and( temp['freq'] >= flo, temp['freq'] <= fhi )
9773    denom = (temp['spec']).sum()
9774    numer = (temp['spec'][fselect]).sum()
9775    return {  'alff':numer, 'falff': numer/denom }

Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008) are related measures that quantify the amplitude of low frequency oscillations (LFOs). This function outputs ALFF and fALFF for the input. same function in ANTsR.

x : input vector for the time series of interest

flo : low frequency, typically 0.01

fhi : high frequency, typically 0.1

tr : the period associated with the vector x (inverse of frequency)

detrend : detrend the input time series

return vector is output showing ALFF and fALFF values
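
A quick sketch on a synthetic time series (the TR of 2.5 s is an arbitrary example value):

import numpy as np
import antspymm

ts = np.random.randn(200)                               # one voxel's time series
vals = antspymm.alffmap(ts, flo=0.01, fhi=0.1, tr=2.5)  # tr is the sampling period in seconds
print(vals['alff'], vals['falff'])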

def alff_image(x, mask, flo=0.01, fhi=0.1, nuisance=None):
9778def alff_image( x, mask, flo=0.01, fhi=0.1, nuisance=None ):
9779    """
9780    Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and
9781    fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008)
9782    are related measures that quantify the amplitude of low frequency
9783    oscillations (LFOs).  This function outputs ALFF and fALFF for the input.
9784
9785    x - input clean resting state fmri
9786    mask - mask over which to compute f/alff
9787    flo - low frequency, typically 0.01
9788    fhi - high frequency, typically 0.1
9789    nuisance - optional nuisance matrix
9790
9791    return dictionary with ALFF and fALFF images
9792    """
9793    xmat = ants.timeseries_to_matrix( x, mask )
9794    if nuisance is not None:
9795        xmat = ants.regress_components( xmat, nuisance )
9796    alffvec = xmat[0,:]*0
9797    falffvec = xmat[0,:]*0
9798    mytr = ants.get_spacing( x )[3]
9799    for n in range( xmat.shape[1] ):
9800        temp = alffmap( xmat[:,n], flo=flo, fhi=fhi, tr=mytr )
9801        alffvec[n]=temp['alff']
9802        falffvec[n]=temp['falff']
9803    alffi=ants.make_image( mask, alffvec )
9804    falffi=ants.make_image( mask, falffvec )
9805    alfftrimmedmean = calculate_trimmed_mean( alffvec, 0.01 )
9806    falfftrimmedmean = calculate_trimmed_mean( falffvec, 0.01 )
9807    alffi=alffi / alfftrimmedmean
9808    falffi=falffi / falfftrimmedmean
9809    return {  'alff': alffi, 'falff': falffi }

Amplitude of Low Frequency Fluctuations (ALFF; Zang et al., 2007) and fractional Amplitude of Low Frequency Fluctuations (f/ALFF; Zou et al., 2008) are related measures that quantify the amplitude of low frequency oscillations (LFOs). This function outputs ALFF and fALFF for the input.

x - input clean resting state fmri

mask - mask over which to compute f/alff

flo - low frequency, typically 0.01

fhi - high frequency, typically 0.1

nuisance - optional nuisance matrix

return dictionary with ALFF and fALFF images
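
A hedged usage sketch (the input filename is a placeholder and the image is assumed to be an already-cleaned 4D rsfMRI series):

import ants
import antspymm

bold = ants.image_read('cleaned_rsfMRI.nii.gz')   # 4D, motion-corrected / nuisance-cleaned
ref  = ants.get_average_of_timeseries(bold)
mask = ants.get_mask(ref)
amaps = antspymm.alff_image(bold, mask, flo=0.01, fhi=0.1)
ants.image_write(amaps['falff'], 'falff.nii.gz')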

def down2iso(x, interpolation='linear', takemin=False):
9812def down2iso( x, interpolation='linear', takemin=False ):
9813    """
9814    will downsample an anisotropic image to an isotropic resolution
9815
9816    x: input image
9817
9818    interpolation: linear or nearestneighbor
9819
9820    takemin : boolean map to min space; otherwise max
9821
9822    return image downsampled to isotropic resolution
9823    """
9824    spc = ants.get_spacing( x )
9825    if takemin:
9826        newspc = np.asarray(spc).min()
9827    else:
9828        newspc = np.asarray(spc).max()
9829    newspc = np.repeat( newspc, x.dimension )
9830    if interpolation == 'linear':
9831        xs = ants.resample_image( x, newspc, interp_type=0)
9832    else:
9833        xs = ants.resample_image( x, newspc, interp_type=1)
9834    return xs

will downsample an anisotropic image to an isotropic resolution

x: input image

interpolation: linear or nearestneighbor

takemin : boolean map to min space; otherwise max

return image downsampled to isotropic resolution
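
For illustration, the sketch below makes an anisotropic copy of an ANTsPy sample template and then maps it back to isotropy (the spacing values are arbitrary):

import ants
import antspymm

img  = ants.image_read(ants.get_ants_data('mni'))                           # isotropic template
anis = ants.resample_image(img, (1.0, 1.0, 3.0), use_voxels=False, interp_type=0)
iso  = antspymm.down2iso(anis)                                              # 3.0 mm isotropic (takemin=False)
print(ants.get_spacing(anis), ants.get_spacing(iso))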

def read_mm_csv(x, is_t1=False, colprefix=None, separator='-', verbose=False):
9837def read_mm_csv( x, is_t1=False, colprefix=None, separator='-', verbose=False ):
9838    splitter=os.path.basename(x).split( separator )
9839    lensplit = len( splitter )-1
9840    temp = os.path.basename(x)
9841    temp = os.path.splitext(temp)[0]
9842    temp = re.sub(separator+'mmwide','',temp)
9843    idcols = ['u_hier_id','sid','visitdate','modality','mmimageuid','t1imageuid']
9844    df = pd.DataFrame( columns = idcols, index=range(1) )
9845    valstoadd = [temp] + splitter[1:(lensplit-1)]
9846    if is_t1:
9847        valstoadd = valstoadd + [splitter[(lensplit-1)],splitter[(lensplit-1)]]
9848    else:
9849        split2=splitter[(lensplit-1)].split( "_" )
9850        if len(split2) == 1:
9851            split2.append( split2[0] )
9852        if len(valstoadd) == 3:
9853            valstoadd = valstoadd + [split2[0]] + [math.nan] + [split2[1]]
9854        else:
9855            valstoadd = valstoadd + [split2[0],split2[1]]
9856    if verbose:
9857        print( valstoadd )
9858    df.iloc[0] = valstoadd
9859    if verbose:
9860        print( "read xdf: " + x )
9861    xdf = pd.read_csv( x )
9862    df.reset_index()
9863    xdf.reset_index(drop=True)
9864    if "Unnamed: 0" in xdf.columns:
9865        holder=xdf.pop( "Unnamed: 0" )
9866    if "Unnamed: 1" in xdf.columns:
9867        holder=xdf.pop( "Unnamed: 1" )
9868    if "u_hier_id.1" in xdf.columns:
9869        holder=xdf.pop( "u_hier_id.1" )
9870    if "u_hier_id" in xdf.columns:
9871        holder=xdf.pop( "u_hier_id" )
9872    if not is_t1:
9873        if 'resnetGrade' in xdf.columns:
9874            index_no = xdf.columns.get_loc('resnetGrade')
9875            xdf = xdf.drop( xdf.columns[range(index_no+1)] , axis=1)
9876
9877    if xdf.shape[0] == 2:
9878        xdfcols = xdf.columns
9879        xdf = xdf.iloc[1]
9880        ddnum = xdf.to_numpy()
9881        ddnum = ddnum.reshape([1,ddnum.shape[0]])
9882        newcolnames = xdf.index.to_list()
9883        if len(newcolnames) != ddnum.shape[1]:
9884            print("Cannot Merge : Shape MisMatch " + str( len(newcolnames) ) + " " + str(ddnum.shape[1]))
9885        else:
9886            xdf = pd.DataFrame(ddnum, columns=xdfcols )
9887    if xdf.shape[1] == 0:
9888        return None
9889    if colprefix is not None:
9890        xdf.columns=colprefix + xdf.columns
9891    return pd.concat( [df,xdf], axis=1, ignore_index=False )
def assemble_modality_specific_dataframes( mm_wide_csvs, hierdfin, nrg_modality, separator='-', progress=None, verbose=False):
 9997def assemble_modality_specific_dataframes( mm_wide_csvs, hierdfin, nrg_modality, separator='-', progress=None, verbose=False ):
 9998    moddersub = re.sub( "[*]","",nrg_modality)
 9999    nmdf=pd.DataFrame()
10000    for k in range( hierdfin.shape[0] ):
10001        if progress is not None:
10002            if k % progress == 0:
10003                progger = str( np.round( k / hierdfin.shape[0] * 100 ) )
10004                print( progger, end ="...", flush=True)
10005        temp = mm_wide_csvs[k]
10006        mypartsf = temp.split("T1wHierarchical")
10007        myparts = mypartsf[0]
10008        t1iid = str(mypartsf[1].split("/")[1])
10009        fnsnm = glob.glob(myparts+"/" + nrg_modality + "/*/*" + t1iid + "*wide.csv")
10010        if len( fnsnm ) > 0 :
10011            for y in fnsnm:
10012                temp=read_mm_csv( y, colprefix=moddersub+'_', is_t1=False, separator=separator, verbose=verbose )
10013                if temp is not None:
10014                    nmdf=pd.concat( [nmdf, temp], axis=0, ignore_index=False )
10015    return nmdf
def bind_wide_mm_csvs(mm_wide_csvs, merge=True, separator='-', verbose=0):
10017def bind_wide_mm_csvs( mm_wide_csvs, merge=True, separator='-', verbose = 0 ) :
10018    """
10019    will convert a list of t1w hierarchical csv filenames to a merged dataframe
10020
10021    returns a pair of data frames, the left side having all entries and the
10022        right side having row averaged entries i.e. unique values for each visit
10023
10024    set merge to False to return individual dataframes ( for debugging )
10025
10026    return alldata, row_averaged_data
10027    """
10028    mm_wide_csvs.sort()
10029    if not mm_wide_csvs:
10030        print("No files found with specified pattern")
10031        return
10032    # 1. row-bind the t1whier data
10033    # 2. same for each other modality
10034    # 3. merge the modalities by the keys
10035    hierdf = pd.DataFrame()
10036    for y in mm_wide_csvs:
10037        temp=read_mm_csv( y, colprefix='T1Hier_', separator=separator, is_t1=True )
10038        if temp is not None:
10039            hierdf=pd.concat( [hierdf, temp], axis=0, ignore_index=False )
10040    if verbose > 0:
10041        mypro=50
10042    else:
10043        mypro=None
10044    if verbose > 0:
10045        print("thickness")
10046    thkdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T1w', progress=mypro, verbose=verbose==2)
10047    if verbose > 0:
10048        print("flair")
10049    flairdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'T2Flair', progress=mypro, verbose=verbose==2)
10050    if verbose > 0:
10051        print("NM")
10052    nmdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'NM2DMT', progress=mypro, verbose=verbose==2)
10053    if verbose > 0:
10054        print("rsf")
10055    rsfdf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'rsfMRI*', progress=mypro, verbose=verbose==2)
10056    if verbose > 0:
10057        print("dti")
10058    dtidf = assemble_modality_specific_dataframes( mm_wide_csvs, hierdf, 'DTI*', progress=mypro, verbose=verbose==2 )
10059    if not merge:
10060        return hierdf, thkdf, flairdf, nmdf, rsfdf, dtidf
10061    hierdfmix = hierdf.copy()
10062    modality_df_suffixes = [
10063        (thkdf, "_thk"),
10064        (flairdf, "_flair"),
10065        (nmdf, "_nm"),
10066        (rsfdf, "_rsf"),
10067        (dtidf, "_dti"),
10068    ]
10069    for pair in modality_df_suffixes:
10070        hierdfmix = merge_mm_dataframe(hierdfmix, pair[0], pair[1])
10071    hierdfmix = hierdfmix.replace(r'^\s*$', np.nan, regex=True)
10072    return hierdfmix, hierdfmix.groupby("u_hier_id", as_index=False).mean(numeric_only=True)

will convert a list of t1w hierarchical csv filenames to a merged dataframe

returns a pair of data frames, the left side having all entries and the right side having row averaged entries i.e. unique values for each visit

set merge to False to return individual dataframes ( for debugging )

return alldata, row_averaged_data
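
A usage sketch; the glob pattern is an assumption about where the processed, NRG-organized wide CSVs live:

import glob
import antspymm

csvs = glob.glob('/processed/**/*T1wHierarchical*mmwide.csv', recursive=True)
alldata, averaged = antspymm.bind_wide_mm_csvs(csvs, merge=True, verbose=1)
averaged.to_csv('study_wide_averaged.csv')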

def merge_mm_dataframe(hierdf, mmdf, mm_suffix):
10074def merge_mm_dataframe(hierdf, mmdf, mm_suffix):
10075    try:
10076        hierdf = hierdf.merge(mmdf, on=['sid', 'visitdate', 't1imageuid'], suffixes=("",mm_suffix),how='left')
10077        return hierdf
10078    except KeyError:
10079        return hierdf
def augment_image(x, max_rot=10, nzsd=1):
10081def augment_image( x,  max_rot=10, nzsd=1 ):
10082    rRotGenerator = ants.contrib.RandomRotate3D( ( max_rot*(-1.0), max_rot ), reference=x )
10083    tx = rRotGenerator.transform()
10084    itx = ants.invert_ants_transform(tx)
10085    y = ants.apply_ants_transform_to_image( tx, x, x, interpolation='linear')
10086    y = ants.add_noise_to_image( y,'additivegaussian', [0,nzsd] )
10087    return y, tx, itx
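
augment_image returns the rotated, noise-perturbed image together with the forward and inverse transforms, e.g. (using an ANTsPy sample image):

import ants
import antspymm

img = ants.image_read(ants.get_ants_data('mni'))
aug, fwd_tx, inv_tx = antspymm.augment_image(img, max_rot=5, nzsd=0.01)
# inv_tx can map results computed on the augmented image back to the original space
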
def boot_wmh( flair, t1, t1seg, mmfromconvexhull=0.0, strict=True, probability_mask=None, prior_probability=None, n_simulations=16, random_seed=42, verbose=False):
10089def boot_wmh( flair, t1, t1seg, mmfromconvexhull = 0.0, strict=True,
10090        probability_mask=None, prior_probability=None, n_simulations=16,
10091        random_seed = 42,
10092        verbose=False ) :
10093    import random
10094    random.seed( random_seed )
10095    if verbose and prior_probability is None:
10096        print("augmented flair")
10097    if verbose and prior_probability is not None:
10098        print("augmented flair with prior")
10099    wmh_sum_aug = 0
10100    wmh_sum_prior_aug = 0
10101    augprob = flair * 0.0
10102    augprob_prior = None
10103    if prior_probability is not None:
10104        augprob_prior = flair * 0.0
10105    for n in range(n_simulations):
10106        augflair, tx, itx = augment_image( ants.iMath(flair,"Normalize"), 5, 0.01 )
10107        locwmh = wmh( augflair, t1, t1seg, mmfromconvexhull = mmfromconvexhull,
10108            strict=strict, probability_mask=None, prior_probability=prior_probability )
10109        if verbose:
10110            print( "flair sim: " + str(n) + " vol: " + str( locwmh['wmh_mass'] )+ " vol-prior: " + str( locwmh['wmh_mass_prior'] )+ " snr: " + str( locwmh['wmh_SNR'] ) )
10111        wmh_sum_aug = wmh_sum_aug + locwmh['wmh_mass']
10112        wmh_sum_prior_aug = wmh_sum_prior_aug + locwmh['wmh_mass_prior']
10113        temp = locwmh['WMH_probability_map']
10114        augprob = augprob + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10115        if prior_probability is not None:
10116            temp = locwmh['WMH_posterior_probability_map']
10117            augprob_prior = augprob_prior + ants.apply_ants_transform_to_image( itx, temp, flair, interpolation='linear')
10118    augprob = augprob * (1.0/float( n_simulations ))
10119    if prior_probability is not None:
10120        augprob_prior = augprob_prior * (1.0/float( n_simulations ))
10121    wmh_sum_aug = wmh_sum_aug / float( n_simulations )
10122    wmh_sum_prior_aug = wmh_sum_prior_aug / float( n_simulations )
10123    return{
10124      'flair' : ants.iMath(flair,"Normalize"),
10125      'WMH_probability_map' : augprob,
10126      'WMH_posterior_probability_map' : augprob_prior,
10127      'wmh_mass': wmh_sum_aug,
10128      'wmh_mass_prior': wmh_sum_prior_aug,
10129      'wmh_evr': locwmh['wmh_evr'],
10130      'wmh_SNR': locwmh['wmh_SNR']  }
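
A hedged sketch of calling boot_wmh; the filenames are placeholders and t1seg is assumed to be the tissue segmentation produced upstream in this pipeline:

import ants
import antspymm

flair = ants.image_read('flair.nii.gz')
t1    = ants.image_read('t1_brain.nii.gz')
t1seg = ants.image_read('t1_tissue_segmentation.nii.gz')
bw = antspymm.boot_wmh(flair, t1, t1seg, n_simulations=8, verbose=True)
print(bw['wmh_mass'], bw['wmh_SNR'])
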
def threaded_bind_wide_mm_csvs(mm_wide_csvs, n_workers):
10133def threaded_bind_wide_mm_csvs( mm_wide_csvs, n_workers ):
10134    from concurrent.futures import as_completed
10135    from concurrent import futures
10136    import concurrent.futures
10137    def chunks(l, n):
10138        """Yield n number of sequential chunks from l."""
10139        d, r = divmod(len(l), n)
10140        for i in range(n):
10141            si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r)
10142            yield l[si:si+(d+1 if i < r else d)]
10143    import numpy as np
10144    newx = list( chunks( mm_wide_csvs, n_workers ) )
10145    import pandas as pd
10146    alldf = pd.DataFrame()
10147    alldfavg = pd.DataFrame()
10148    with futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
10149        to_do = []
10150        for group in range(len(newx)) :
10151            future = executor.submit(bind_wide_mm_csvs, newx[group] )
10152            to_do.append(future)
10153        results = []
10154        for future in futures.as_completed(to_do):
10155            res0, res1 = future.result()
10156            alldf=pd.concat(  [alldf, res0 ], axis=0, ignore_index=False )
10157            alldfavg=pd.concat(  [alldfavg, res1 ], axis=0, ignore_index=False )
10158    return alldf, alldfavg
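
This simply chunks the csv list across threads and concatenates the per-chunk results of bind_wide_mm_csvs; a sketch (the glob pattern is again an assumption about your output layout):

import glob
import antspymm

csvs = glob.glob('/processed/**/*T1wHierarchical*mmwide.csv', recursive=True)
alldf, alldfavg = antspymm.threaded_bind_wide_mm_csvs(csvs, n_workers=8)
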
def get_names_from_data_frame(x, demogIn, exclusions=None):
10161def get_names_from_data_frame(x, demogIn, exclusions=None):
10162    """
10163    data = {'Name':['Tom', 'nick', 'krish', 'jack'], 'Age':[20, 21, 19, 18]}; df = pd.DataFrame(data)
10164    antspymm.get_names_from_data_frame( ['e'], df )
10165    antspymm.get_names_from_data_frame( ['a','e'], df )
10166    antspymm.get_names_from_data_frame( ['e'], df, exclusions='N' )
10167    """
10168    # Check if x is a string and convert it to a list
10169    if isinstance(x, str):
10170        x = [x]
10171    def get_unique( qq ):
10172        unique = []
10173        for number in qq:
10174            if number in unique:
10175                continue
10176            else:
10177                unique.append(number)
10178        return unique
10179    outnames = list(demogIn.columns[demogIn.columns.str.contains(x[0])])
10180    if len(x) > 1:
10181        for y in x[1:]:
10182            outnames = [i for i in outnames if y in i]
10183    outnames = get_unique( outnames )
10184    if exclusions is not None:
10185        toexclude = [name for name in outnames if exclusions[0] in name ]
10186        if len(exclusions) > 1:
10187            for zz in exclusions[1:]:
10188                toexclude.extend([name for name in outnames if zz in name ])
10189        if len(toexclude) > 0:
10190            outnames = [name for name in outnames if name not in toexclude]
10191    return outnames

data = {'Name':['Tom', 'nick', 'krish', 'jack'], 'Age':[20, 21, 19, 18]}
df = pd.DataFrame(data)
antspymm.get_names_from_data_frame( ['e'], df )
antspymm.get_names_from_data_frame( ['a','e'], df )
antspymm.get_names_from_data_frame( ['e'], df, exclusions='N' )

def average_mm_df(jmm_in, diagnostic_n=25, corr_thresh=0.9, verbose=False):
10194def average_mm_df( jmm_in, diagnostic_n=25, corr_thresh=0.9, verbose=False ):
10195    """
10196    jmrowavg, jmmcolavg, diagnostics = antspymm.average_mm_df( jmm_in, verbose=True )
10197    """
10198
10199    jmm = jmm_in.copy()
10200    dxcols=['subjectid1','subjectid2','modalityid','joinid','correlation','distance']
10201    joinDiagnostics = pd.DataFrame( columns = dxcols )
10202    nanList=[math.nan]
10203    def rob(x, y=0.99):
10204        x[x > np.nanquantile(x, y)] = np.nan
10205        return x
10206
10207    jmm = jmm.replace(r'^\s*$', np.nan, regex=True)
10208
10209    if verbose:
10210        print("do rsfMRI")
10211    # here - we first have to average within each row
10212    dt0 = get_names_from_data_frame(["rsfMRI"], jmm, exclusions=["Unnamed", "rsfMRI_LR", "rsfMRI_RL"])
10213    dt1 = get_names_from_data_frame(["rsfMRI_RL"], jmm, exclusions=["Unnamed"])
10214    if len( dt0 ) > 0 and len( dt1 ) > 0:
10215        flid = dt0[0]
10216        wrows = []
10217        for i in range(jmm.shape[0]):
10218            if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) :
10219                wrows.append(i)
10220        for k in wrows:
10221            v1 = jmm.iloc[k][dt0[1:]].astype(float)
10222            v2 = jmm.iloc[k][dt1[1:]].astype(float)
10223            vvec = [v1[0], v2[0]]
10224            if any(~np.isnan(vvec)):
10225                mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10226                jmm.loc[k, dt0[0]] = 'rsfMRI'
10227                if len(mynna) == 1:
10228                    if mynna[0] == 0:
10229                        jmm.loc[k, dt0[1:]] = v1
10230                    if mynna[0] == 1:
10231                        jmm.loc[k, dt0[1:]] = v2
10232                elif len(mynna) > 1:
10233                    if len(v2) > diagnostic_n:
10234                        v1dx=v1[0:diagnostic_n]
10235                        v2dx=v2[0:diagnostic_n]
10236                    else :
10237                        v1dx=v1
10238                        v2dx=v2
10239                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10240                    mycorr = np.corrcoef( v1dx.values, v2dx.values )[0,1]
10241                    myerr=np.sqrt(np.mean((v1dx.values - v2dx.values)**2))
10242                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'rsfMRI','colavg',mycorr,myerr]
10243                    if mycorr > corr_thresh:
10244                        jmm.loc[k, dt0[1:]] = v1.values*0.5 + v2.values*0.5
10245                    else:
10246                        jmm.loc[k, dt0[1:]] = nanList * len(v1)
10247                    if verbose:
10248                        print( joinDiagnosticsLoc )
10249                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10250
10251    if verbose:
10252        print("do DTI")
10253    # here - we first have to average within each row
10254    dt0 = get_names_from_data_frame(["DTI"], jmm, exclusions=["Unnamed", "DTI_LR", "DTI_RL"])
10255    dt1 = get_names_from_data_frame(["DTI_LR"], jmm, exclusions=["Unnamed"])
10256    dt2 = get_names_from_data_frame( ["DTI_RL"], jmm, exclusions=["Unnamed"])
10257    flid = dt0[0]
10258    wrows = []
10259    for i in range(jmm.shape[0]):
10260        if not pd.isna(jmm[dt0[1]][i]) or not pd.isna(jmm[dt1[1]][i]) or not pd.isna(jmm[dt2[1]][i]):
10261            wrows.append(i)
10262    for k in wrows:
10263        v1 = jmm.loc[k, dt0[1:]].astype(float)
10264        v2 = jmm.loc[k, dt1[1:]].astype(float)
10265        v3 = jmm.loc[k, dt2[1:]].astype(float)
10266        checkcol = dt0[5]
10267        if not np.isnan(v1[checkcol]):
10268            if v1[checkcol] < 0.25:
10269                v1.replace(np.nan, inplace=True)
10270        checkcol = dt1[5]
10271        if not np.isnan(v2[checkcol]):
10272            if v2[checkcol] < 0.25:
10273                v2.replace(np.nan, inplace=True)
10274        checkcol = dt2[5]
10275        if not np.isnan(v3[checkcol]):
10276            if v3[checkcol] < 0.25:
10277                v3.replace(np.nan, inplace=True)
10278        vvec = [v1[0], v2[0], v3[0]]
10279        if any(~np.isnan(vvec)):
10280            mynna = [i for i, x in enumerate(vvec) if ~np.isnan(x)]
10281            jmm.loc[k, dt0[0]] = 'DTI'
10282            if len(mynna) == 1:
10283                if mynna[0] == 0:
10284                    jmm.loc[k, dt0[1:]] = v1
10285                if mynna[0] == 1:
10286                    jmm.loc[k, dt0[1:]] = v2
10287                if mynna[0] == 2:
10288                    jmm.loc[k, dt0[1:]] = v3
10289            elif len(mynna) > 1:
10290                if mynna[0] == 0:
10291                    jmm.loc[k, dt0[1:]] = v1
10292                else:
10293                    joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10294                    mycorr = np.corrcoef( v2[0:diagnostic_n].values, v3[0:diagnostic_n].values )[0,1]
10295                    myerr=np.sqrt(np.mean((v2[0:diagnostic_n].values - v3[0:diagnostic_n].values)**2))
10296                    joinDiagnosticsLoc.iloc[0] = [jmm.loc[k,'u_hier_id'],math.nan,'DTI','colavg',mycorr,myerr]
10297                    if mycorr > corr_thresh:
10298                        jmm.loc[k, dt0[1:]] = v2.values*0.5 + v3.values*0.5
10299                    else: #
10300                        jmm.loc[k, dt0[1:]] = nanList * len( dt0[1:] )
10301                    if verbose:
10302                        print( joinDiagnosticsLoc )
10303                    joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], axis=0, ignore_index=False )
10304
10305
10306    # first task - sort by u_hier_id
10307    jmm = jmm.sort_values( "u_hier_id" )
10308    # get rid of junk columns
10309    badnames = get_names_from_data_frame( ['Unnamed'], jmm )
10310    jmm=jmm.drop(badnames, axis=1)
10311    jmm=jmm.set_index("u_hier_id",drop=False)
10312    # 2nd - get rid of duplicated u_hier_id
10313    jmmUniq = jmm.drop_duplicates( subset="u_hier_id" ) # fast and easy
10314    # for each modality, count which ids have more than one
10315    mod_names = get_valid_modalities()
10316    for mod_name in mod_names:
10317        fl_names = get_names_from_data_frame([mod_name], jmm,
10318            exclusions=['Unnamed',"DTI_LR","DTI_RL","rsfMRI_RL","rsfMRI_LR"])
10319        if len( fl_names ) > 1:
10320            if verbose:
10321                print(mod_name)
10322                print(fl_names)
10323            fl_id = fl_names[0]
10324            n_names = len(fl_names)
10325            locvec = jmm[fl_names[n_names-1]].astype(float)
10326            boolvec=~pd.isna(locvec)
10327            jmmsub = jmm[boolvec][ ['u_hier_id']+fl_names]
10328            my_tbl = Counter(jmmsub['u_hier_id'])
10329            gtoavg = [name for name in my_tbl.keys() if my_tbl[name] == 1]
10330            gtoavgG1 = [name for name in my_tbl.keys() if my_tbl[name] > 1]
10331            if verbose:
10332                print("Join 1")
10333            jmmsub1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavg)][['u_hier_id']+fl_names]
10334            for u in gtoavg:
10335                jmmUniq.loc[u, fl_names[1:]] = jmmsub1.loc[u][fl_names[1:]]
10336            if verbose and len(gtoavgG1) > 1:
10337                print("Join >1")
10338            jmmsubG1 = jmmsub.loc[jmmsub['u_hier_id'].isin(gtoavgG1)][['u_hier_id']+fl_names]
10339            for u in gtoavgG1:
10340                temp = jmmsubG1.loc[u][ ['u_hier_id']+fl_names ]
10341                dropnames = get_names_from_data_frame( ['MM.ID'], temp )
10342                tempVec = temp.drop(columns=dropnames)
10343                joinDiagnosticsLoc = pd.DataFrame( columns = dxcols, index=range(1) )
10344                id1=temp[fl_id].iloc[0]
10345                id2=temp[fl_id].iloc[1]
10346                v1=tempVec.iloc[0][1:].astype(float).to_numpy()
10347                v2=tempVec.iloc[1][1:].astype(float).to_numpy()
10348                if len(v2) > diagnostic_n:
10349                    v1=v1[0:diagnostic_n]
10350                    v2=v2[0:diagnostic_n]
10351                mycorr = np.corrcoef( v1, v2 )[0,1]
10352                # mycorr=temparr[np.triu_indices_from(temparr, k=1)].mean()
10353                myerr=np.sqrt(np.mean((v1 - v2)**2))
10354                joinDiagnosticsLoc.iloc[0] = [id1,id2,mod_name,'rowavg',mycorr,myerr]
10355                if verbose:
10356                    print( joinDiagnosticsLoc )
10357                temp = jmmsubG1.loc[u][fl_names[1:]].astype(float)
10358                if mycorr > corr_thresh or len( v1 ) < 10:
10359                    jmmUniq.loc[u, fl_names[1:]] = temp.mean(axis=0)
10360                else:
10361                    jmmUniq.loc[u, fl_names[1:]] = nanList * temp.shape[1]
10362                joinDiagnostics = pd.concat( [joinDiagnostics, joinDiagnosticsLoc], 
10363                                            axis=0, ignore_index=False )
10364
10365    return jmmUniq, jmm, joinDiagnostics

jmrowavg, jmmcolavg, diagnostics = antspymm.average_mm_df( jmm_in, verbose=True )

def quick_viz_mm_nrg( sourcedir, projectid, sid, dtid, extract_brain=True, slice_factor=0.55, post=False, original_sourcedir=None, filename=None, verbose=True):
10369def quick_viz_mm_nrg(
10370    sourcedir, # root folder
10371    projectid, # project name
10372    sid , # subject unique id
10373    dtid, # date
10374    extract_brain=True,
10375    slice_factor = 0.55,
10376    post = False,
10377    original_sourcedir = None,
10378    filename = None, # output path
10379    verbose = True
10380):
10381    """
10382    This function creates visualizations of brain images for a specific subject in a project using ANTsPy.
10383
10384    Args:
10385
10386    sourcedir (str): Root folder for original data (if post=False) or processed data (post=True)
10387    
10388    projectid (str): Project name.
10389    
10390    sid (str): Subject unique id.
10391    
10392    dtid (str): Date.
10393    
10394    extract_brain (bool): If True, the function extracts the brain from the T1w image. Default is True.
10395    
10396    slice_factor (float): The slice to be visualized is determined by multiplying the image size by this factor. Default is 0.55.
10397
10398    post ( bool ) : if True, will visualize example post-processing results.
10399    
10400    original_sourcedir (str): Root folder for original data (used if post=True)
10401    
10402    filename (str): Output path with extension (.png)
10403    
10404    verbose (bool): If True, information will be printed while running the function. Default is True.
10405
10406    Returns:
10407    None
10408
10409    """
10410    iid='*'
10411    import glob as glob
10412    from os.path import exists
10413    import ants
10414    temp = sourcedir.split( "/" )
10415    subjectrootpath = os.path.join(sourcedir, projectid, sid, dtid)
10416    if verbose:
10417        print( 'subjectrootpath' )
10418        print( subjectrootpath )
10419    t1_search_path = os.path.join(subjectrootpath, "T1w", "*", "*nii.gz")
10420    if verbose:
10421        print(f"t1 search path: {t1_search_path}")
10422    t1fn = glob.glob(t1_search_path)
10423    if len( t1fn ) < 1:
10424        raise ValueError('quick_viz_mm_nrg cannot find the T1w @ ' + subjectrootpath )
10425    vizlist=[]
10426    undlist=[]
10427    nrg_modality_list = [ 'T1w', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10428    if post:
10429        nrg_modality_list = [ 'T1wHierarchical', 'DTI', 'rsfMRI', 'perf', 'T2Flair', 'NM2DMT' ]
10430    for nrgNum in [0,1,2,3,4,5]:
10431        underlay = None
10432        overmodX = nrg_modality_list[nrgNum]
10433        if  'T1w' in overmodX :
10434            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10435            if post:
10436                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*brain_n4_dnz.nii.gz")
10437                mod_search_path_ol = os.path.join(subjectrootpath, overmodX, iid, "*thickness_image.nii.gz" )
10438                mod_search_path_ol = re.sub( "T1wHierarchical","T1w",mod_search_path_ol)
10439                myol = glob.glob(mod_search_path_ol)
10440                if len( myol ) > 0:
10441                    temper = find_most_recent_file( myol )[0]
10442                    underlay = ants.image_read(  temper )
10443                    if verbose:
10444                        print("T1w overlay " + temper )
10445                    underlay = underlay * ants.threshold_image( underlay, 0.2, math.inf )
10446            myimgsr = glob.glob(mod_search_path)
10447            if len( myimgsr ) == 0:
10448                if verbose:
10449                    print("No t1 images: " + sid + dtid )
10450                return None
10451            myimgsr=find_most_recent_file( myimgsr )[0]
10452            vimg=ants.image_read( myimgsr )
10453        elif  'T2Flair' in overmodX :
10454            if verbose:
10455                print("search flair")
10456            mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*nii.gz")
10457            if post and original_sourcedir is not None:
10458                if verbose:
10459                    print("post in flair")
10460                mysubdir = os.path.join(original_sourcedir, projectid, sid, dtid)
10461                mod_search_path_under = os.path.join(mysubdir, overmodX, iid, "*T2Flair*.nii.gz")
10462                if verbose:
10463                    print("post in flair mod_search_path_under " + mod_search_path_under)
10464                mod_search_path = os.path.join(subjectrootpath, overmodX, iid, "*wmh.nii.gz")
10465                if verbose:
10466                    print("post in flair mod_search_path " + mod_search_path )
10467                myimgul = glob.glob(mod_search_path_under)
10468                if len( myimgul ) > 0:
10469                    myimgul = find_most_recent_file( myimgul )[0]
10470                    if verbose:
10471                        print("Flair  " + myimgul )
10472                    vimg = ants.image_read( myimgul )
10473                    myol = glob.glob(mod_search_path)
10474                    if len( myol ) == 0:
10475                        underlay = vimg * 0.0
10476                    else:
10477                        myol = find_most_recent_file( myol )[0]
10478                        if verbose:
10479                            print("Flair overlay " + myol )
10480                        underlay=ants.image_read( myol )
10481                        underlay=underlay*ants.threshold_image(underlay,0.05,math.inf)
10482                else:
10483                    vimg = noizimg.clone()
10484                    underlay = vimg * 0.0
10485            if original_sourcedir is None:
10486                myimgsr = glob.glob(mod_search_path)
10487                if len( myimgsr ) == 0:
10488                    vimg = noizimg.clone()
10489                else:
10490                    myimgsr=find_most_recent_file( myimgsr )[0]
10491                    vimg=ants.image_read( myimgsr )
10492        elif overmodX == 'DTI':
10493            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10494            if post:
10495                mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*fa.nii.gz")
10496            myimgsr = glob.glob(mod_search_path)
10497            if len( myimgsr ) > 0:
10498                myimgsr=find_most_recent_file( myimgsr )[0]
10499                vimg=ants.image_read( myimgsr )
10500            else:
10501                if verbose:
10502                    print("No " + overmodX)
10503                vimg = noizimg.clone()
10504        elif overmodX == 'DTI2':
10505            mod_search_path = os.path.join(subjectrootpath, 'DTI*', "*", "*nii.gz")
10506            myimgsr = glob.glob(mod_search_path)
10507            if len( myimgsr ) > 0:
10508                myimgsr.sort()
10509                myimgsr=myimgsr[len(myimgsr)-1]
10510                vimg=ants.image_read( myimgsr )
10511            else:
10512                if verbose:
10513                    print("No " + overmodX)
10514                vimg = noizimg.clone()
10515        elif overmodX == 'NM2DMT':
10516            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10517            if post:
10518                mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*NM_avg.nii.gz" )
10519            myimgsr = glob.glob(mod_search_path)
10520            if len( myimgsr ) > 0:
10521                myimgsr0=myimgsr[0]
10522                vimg=ants.image_read( myimgsr0 )
10523                for k in range(1,len(myimgsr)):
10524                    temp = ants.image_read( myimgsr[k])
10525                    vimg=vimg+ants.resample_image_to_target(temp,vimg)
10526            else:
10527                if verbose:
10528                    print("No " + overmodX)
10529                vimg = noizimg.clone()
10530        elif overmodX == 'rsfMRI':
10531            mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*nii.gz")
10532            if post:
10533                mod_search_path = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_meanBold.nii.gz" )
10534                mod_search_path_ol = os.path.join(subjectrootpath, 'rsfMRI*', "*", "*fcnxpro122_DefaultMode.nii.gz" )
10535                myol = glob.glob(mod_search_path_ol)
10536                if len( myol ) > 0:
10537                    myol = find_most_recent_file( myol )[0]
10538                    underlay = ants.image_read( myol )
10539                    if verbose:
10540                        print("BOLD overlay " + myol )
10541                    underlay = underlay * ants.threshold_image( underlay, 0.1, math.inf )
10542            myimgsr = glob.glob(mod_search_path)
10543            if len( myimgsr ) > 0:
10544                myimgsr=find_most_recent_file( myimgsr )[0]
10545                vimg=mm_read_to_3d( myimgsr )
10546            else:
10547                if verbose:
10548                    print("No " + overmodX)
10549                vimg = noizimg.clone()
10550        elif overmodX == 'perf':
10551            mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*nii.gz")
10552            if post:
10553                mod_search_path = os.path.join(subjectrootpath, 'perf*', "*", "*cbf.nii.gz")
10554            myimgsr = glob.glob(mod_search_path)
10555            if len( myimgsr ) > 0:
10556                myimgsr=find_most_recent_file( myimgsr )[0]
10557                vimg=mm_read_to_3d( myimgsr )
10558            else:
10559                if verbose:
10560                    print("No " + overmodX)
10561                vimg = noizimg.clone()
10562        else :
10563            if verbose:
10564                print("Something else here")
10565            mod_search_path = os.path.join(subjectrootpath, overmodX, "*", "*nii.gz")
10566            myimgsr = glob.glob(mod_search_path)
10567            if post:
10568                myimgsr=[]
10569            if len( myimgsr ) > 0:
10570                myimgsr=find_most_recent_file( myimgsr )[0]
10571                vimg=ants.image_read( myimgsr )
10572            else:
10573                if verbose:
10574                    print("No " + overmodX)
10575                vimg = noizimg
10576        if True:
10577            if extract_brain and overmodX == 'T1w' and post == False:
10578                vimg = vimg * antspyt1w.brain_extraction(vimg)
10579            if verbose:
10580                print(f"modality search path: {myimgsr}" + " num: " + str(nrgNum))
10581            if vimg.dimension == 4 and ( overmodX == "DTI2"  ):
10582                ttb0, ttdw=get_average_dwi_b0(vimg)
10583                vimg = ttdw
10584            elif vimg.dimension == 4 and overmodX == "DTI":
10585                ttb0, ttdw=get_average_dwi_b0(vimg)
10586                vimg = ttb0
10587            elif vimg.dimension == 4 :
10588                vimg=ants.get_average_of_timeseries(vimg)
10589            msk=ants.get_mask(vimg)
10590            if overmodX == 'T2Flair':
10591                msk=vimg*0+1
10592            if underlay is not None:
10593                print( overmodX + " has underlay" )
10594            else:
10595                underlay = vimg * 0.0
10596            if nrgNum == 0:
10597                refimg=ants.image_clone( vimg )
10598                noizimg = ants.add_noise_to_image( refimg*0, 'additivegaussian', [100,1] )
10599                vizlist.append( vimg )
10600                undlist.append( underlay )
10601            else:
10602                vimg = ants.iMath( vimg, 'TruncateIntensity',0.01,0.98)
10603                vizlist.append( ants.iMath( vimg, 'Normalize' ) * 255 )
10604                undlist.append( underlay )
10605
10606    # mask & crop systematically ...
10607    msk = ants.get_mask( refimg )
10608    refimg = ants.crop_image( refimg, msk )
10609
10610    for jj in range(len(vizlist)):
10611        vizlist[jj]=ants.resample_image_to_target( vizlist[jj], refimg )
10612        undlist[jj]=ants.resample_image_to_target( undlist[jj], refimg )
10613        print( 'viz: ' + str( jj ) )
10614        print( vizlist[jj] )
10615        print( 'und: ' + str( jj ) )
10616        print( undlist[jj] )
10617
10618
10619    xyz = [None]*3
10620    for i in range(3):
10621        if xyz[i] is None:
10622            xyz[i] = int(refimg.shape[i] * slice_factor )
10623
10624    if verbose:
10625        print('slice positions')
10626        print( xyz )
10627
10628    ants.plot_ortho_stack( vizlist, overlays=undlist, crop=False, reorient=False, filename=filename, xyz=xyz, orient_labels=False )
10629    return  # NOTE: the code below this return is unreachable legacy code (listlen and show_it are undefined here)
10630    # listlen = len( vizlist )
10631    # vizlist = np.asarray( vizlist )
10632    if show_it is not None:
10633        filenameout=None
10634        if verbose:
10635            print( show_it )
10636        for a in [0,1,2]:
10637            n=int(np.round( refimg.shape[a] * slice_factor ))
10638            slices=np.repeat( int(n), listlen  )
10639            if isinstance(show_it,str):
10640                filenameout=show_it+'_ax'+str(int(a))+'_sl'+str(n)+'.png'
10641                if verbose:
10642                    print( filenameout )
10643#            ants.plot_grid(vizlist.reshape(2,3), slices.reshape(2,3), title='MM Subject ' + sid + ' ' + dtid, rfacecolor='white', axes=a, filename=filenameout )
10644    if verbose:
10645        print("viz complete.")
10646    return vizlist

This function creates visualizations of brain images for a specific subject in a project using ANTsPy.

Args:

sourcedir (str): Root folder for original data (if post=False) or processed data (post=True)

projectid (str): Project name.

sid (str): Subject unique id.

dtid (str): Date.

extract_brain (bool): If True, the function extracts the brain from the T1w image. Default is True.

slice_factor (float): The slice to be visualized is determined by multiplying the image size by this factor. Default is 0.55.

post (bool): If True, visualizes example post-processing results.

original_sourcedir (str): Root folder for original data (used if post=True)

filename (str): Output path with extension (.png)

verbose (bool): If True, information will be printed while running the function. Default is True.

Returns: None
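
Example usage (a minimal sketch; the paths, project, subject and date below are placeholders, and the keyword names follow the argument list documented above):

import antspymm
antspymm.quick_viz_mm_nrg(
    sourcedir='/data/NRG/',            # placeholder root of the NRG-organized data
    projectid='PPMI',
    sid='101018',
    dtid='20210412',
    extract_brain=True,
    slice_factor=0.55,
    post=False,
    filename='/tmp/PPMI-101018-20210412-viz.png',
    verbose=True )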

def blind_image_assessment( image, viz_filename=None, title=False, pull_rank=False, resample=None, n_to_skip=10, verbose=False):
10649def blind_image_assessment(
10650    image,
10651    viz_filename=None,
10652    title=False,
10653    pull_rank=False,
10654    resample=None,
10655    n_to_skip = 10,
10656    verbose=False
10657):
10658    """
10659    quick blind image assessment and triplanar visualization of an image ... 4D input will be visualized and assessed in 3D.  produces a png and csv where csv contains:
10660
10661    * reflection error ( estimates asymmetry )
10662
10663    * brisq ( blind quality assessment )
10664
10665    * patch eigenvalue ratio ( blind quality assessment )
10666
10667    * PSNR and SSIM vs a smoothed reference (4D or 3D appropriate)
10668
10669    * mask volume ( estimates foreground object size )
10670
10671    * spacing
10672
10673    * dimension after cropping by mask
10674
10675    image : character or image object usually a nifti image
10676
10677    viz_filename : character for a png output image
10678
10679    title : display a summary title on the png
10680
10681    pull_rank : boolean
10682
10683    resample : None, numeric max or min, resamples image to isotropy
10684
10685    n_to_skip : 10 by default; samples time series every n_to_skip volume
10686
10687    verbose : boolean
10688
10689    """
10690    import glob as glob
10691    from os.path import exists
10692    import ants
10693    import matplotlib.pyplot as plt
10694    from PIL import Image
10695    from pathlib import Path
10696    import json
10697    import re
10698    from dipy.io.gradients import read_bvals_bvecs
10699    mystem=''
10700    if isinstance(image,list):
10701        isfilename=isinstance( image[0], str)
10702        image = image[0]
10703    else:
10704        isfilename=isinstance( image, str)
10705    outdf = pd.DataFrame()
10706    mymeta = None
10707    MagneticFieldStrength = None
10708    image_filename=''
10709    if isfilename:
10710        image_filename = image
10711        if isinstance(image,list):
10712            image_filename=image[0]
10713        json_name = re.sub(".nii.gz",".json",image_filename)
10714        if exists( json_name ):
10715            try:
10716                with open(json_name, 'r') as fcc_file:
10717                    mymeta = json.load(fcc_file)
10718                    if verbose:
10719                        print(json.dumps(mymeta, indent=4))
10720                    fcc_file.close()
10721            except:
10722                pass
10723        mystem=Path( image ).stem
10724        mystem=Path( mystem ).stem
10725        image_reference = ants.image_read( image )
10726        image = ants.image_read( image )
10727    else:
10728        image_reference = ants.image_clone( image )
10729    ntimepoints = 1
10730    bvalueMax=None
10731    bvecnorm=None
10732    if image_reference.dimension == 4:
10733        ntimepoints = image_reference.shape[3]
10734        if "DTI" in image_filename:
10735            myTSseg = segment_timeseries_by_meanvalue( image_reference )
10736            image_b0, image_dwi = get_average_dwi_b0( image_reference, fast=True )
10737            image_b0 = ants.iMath( image_b0, 'Normalize' )
10738            image_dwi = ants.iMath( image_dwi, 'Normalize' )
10739            bval_name = re.sub(".nii.gz",".bval",image_filename)
10740            bvec_name = re.sub(".nii.gz",".bvec",image_filename)
10741            if exists( bval_name ) and exists( bvec_name ):
10742                bvals, bvecs = read_bvals_bvecs( bval_name , bvec_name  )
10743                bvalueMax = bvals.max()
10744                bvecnorm = np.linalg.norm(bvecs,axis=1).reshape( bvecs.shape[0],1 )
10745                bvecnorm = bvecnorm.max()
10746        else:
10747            image_b0 = ants.get_average_of_timeseries( image_reference ).iMath("Normalize")
10748    else:
10749        image_compare = ants.smooth_image( image_reference, 3, sigma_in_physical_coordinates=False )
10750    for jjj in range(0,ntimepoints,n_to_skip):
10751        modality='unknown'
10752        if "rsfMRI" in image_filename:
10753            modality='rsfMRI'
10754        elif "perf" in image_filename:
10755            modality='perf'
10756        elif "DTI" in image_filename:
10757            modality='DTI'
10758        elif "T1w" in image_filename:
10759            modality='T1w'
10760        elif "T2Flair" in image_filename:
10761            modality='T2Flair'
10762        elif "NM2DMT" in image_filename:
10763            modality='NM2DMT'
10764        if image_reference.dimension == 4:
10765            image = ants.slice_image( image_reference, idx=int(jjj), axis=3 )
10766            if "DTI" in image_filename:
10767                if jjj in myTSseg['highermeans']:
10768                    image_compare = ants.image_clone( image_b0 )
10769                    modality='DTIb0'
10770                else:
10771                    image_compare = ants.image_clone( image_dwi )
10772                    modality='DTIdwi'
10773            else:
10774                image_compare = ants.image_clone( image_b0 )
10775        # image = ants.iMath( image, 'TruncateIntensity',0.01,0.995)
10776        minspc = np.min(ants.get_spacing(image))
10777        maxspc = np.max(ants.get_spacing(image))
10778        if resample is not None:
10779            if resample == 'min':
10780                if minspc < 1e-12:
10781                    minspc = np.max(ants.get_spacing(image))
10782                newspc = np.repeat( minspc, 3 )
10783            elif resample == 'max':
10784                newspc = np.repeat( maxspc, 3 )
10785            else:
10786                newspc = np.repeat( resample, 3 )
10787            image = ants.resample_image( image, newspc )
10788            image_compare = ants.resample_image( image_compare, newspc )
10789        else:
10790            # check for spc close to zero
10791            spc = list(ants.get_spacing(image))
10792            for spck in range(len(spc)):
10793                if spc[spck] < 1e-12:
10794                    spc[spck]=1
10795            ants.set_spacing( image, spc )
10796            ants.set_spacing( image_compare, spc )
10797        # if "NM2DMT" in image_filename or "FIXME" in image_filename or "SPECT" in image_filename or "UNKNOWN" in image_filename:
10798        minspc = np.min(ants.get_spacing(image))
10799        maxspc = np.max(ants.get_spacing(image))
10800        msk = ants.threshold_image( ants.iMath(image,'Normalize'), 0.15, 1.0 )
10801        # else:
10802        #    msk = ants.get_mask( image )
10803        msk = ants.morphology(msk, "close", 3 )
10804        bgmsk = msk*0+1-msk
10805        mskdil = ants.iMath(msk, "MD", 4 )
10806        # ants.plot_ortho( image, msk, crop=False )
10807        nvox = int( msk.sum() )
10808        spc = ants.get_spacing( image )
10809        org = ants.get_origin( image )
10810        if ( nvox > 0 ):
10811            image = ants.crop_image( image, mskdil ).iMath("Normalize")
10812            msk = ants.crop_image( msk, mskdil ).iMath("Normalize")
10813            bgmsk = ants.crop_image( bgmsk, mskdil ).iMath("Normalize")
10814            image_compare = ants.crop_image( image_compare, mskdil ).iMath("Normalize")           
10815            npatch = int( np.round(  0.1 * nvox ) )
10816            npatch = np.min(  [512,npatch ] )
10817            patch_shape = []
10818            for k in range( 3 ):
10819                p = int( 32.0 / ants.get_spacing( image  )[k] )
10820                if p > int( np.round( image.shape[k] * 0.5 ) ):
10821                    p = int( np.round( image.shape[k] * 0.5 ) )
10822                patch_shape.append( p )
10823            if verbose:
10824                print(image)
10825                print( patch_shape )
10826                print( npatch )
10827            myevr = math.nan # dont want to fail if something odd happens in patch extraction
10828            try:
10829                myevr = antspyt1w.patch_eigenvalue_ratio( image, npatch, patch_shape,
10830                    evdepth = 0.9, mask=msk )
10831            except:
10832                pass
10833            if pull_rank:
10834                image = ants.rank_intensity(image)
10835            imagereflect = ants.reflect_image(image, axis=0)
10836            asym_err = ( image - imagereflect ).abs().mean()
10837            # estimate noise by center cropping, denoizing and taking magnitude of difference
10838            nocrop=False
10839            if image.dimension == 3:
10840                if image.shape[2] == 1:
10841                    nocrop=True        
10842            if maxspc/minspc > 10:
10843                nocrop=True
10844            if nocrop:
10845                mycc = ants.image_clone( image )
10846            else:
10847                mycc = antspyt1w.special_crop( image,
10848                    ants.get_center_of_mass( msk *0 + 1 ), patch_shape )
10849            myccd = ants.denoise_image( mycc, p=2,r=2,noise_model='Gaussian' )
10850            noizlevel = ( mycc - myccd ).abs().mean()
10851    #        ants.plot_ortho( image, crop=False, filename=viz_filename, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0 )
10852    #        from brisque import BRISQUE
10853    #        obj = BRISQUE(url=False)
10854    #        mybrisq = obj.score( np.array( Image.open( viz_filename )) )
10855            msk_vol = msk.sum() * np.prod( spc )
10856            bgstd = image[ bgmsk == 1 ].std()
10857            fgmean = image[ msk == 1 ].mean()
10858            bgmean = image[ bgmsk == 1 ].mean()
10859            snrref = fgmean / bgstd
10860            cnrref = ( fgmean - bgmean ) / bgstd
10861            psnrref = antspynet.psnr(  image_compare, image  )
10862            ssimref = antspynet.ssim(  image_compare, image  )
10863            if nocrop:
10864                mymi = math.inf
10865            else:
10866                mymi = ants.image_mutual_information( image_compare, image )
10867        else:
10868            msk_vol = 0
10869            myevr = mymi = ssimref = psnrref = cnrref = asym_err = noizlevel = snrref = math.nan  # include snrref so the title string below is always defined
10870            
10871        mriseries=None
10872        mrimfg=None
10873        mrimodel=None
10874        mriSAR=None
10875        BandwidthPerPixelPhaseEncode=None
10876        PixelBandwidth=None
10877        if mymeta is not None:
10878            # mriseries=mymeta['']
10879            try:
10880                mrimfg=mymeta['Manufacturer']
10881            except:
10882                pass
10883            try:
10884                mrimodel=mymeta['ManufacturersModelName']
10885            except:
10886                pass
10887            try:
10888                MagneticFieldStrength=mymeta['MagneticFieldStrength']
10889            except:
10890                pass
10891            try:
10892                PixelBandwidth=mymeta['PixelBandwidth']
10893            except:
10894                pass
10895            try:
10896                BandwidthPerPixelPhaseEncode=mymeta['BandwidthPerPixelPhaseEncode']
10897            except:
10898                pass
10899            try:
10900                mriSAR=mymeta['SAR']
10901            except:
10902                pass
10903        ttl=mystem + ' '
10904        ttl=''
10905        ttl=ttl + "NZ: " + "{:0.4f}".format(noizlevel) + " SNR: " + "{:0.4f}".format(snrref) + " CNR: " + "{:0.4f}".format(cnrref) + " PS: " + "{:0.4f}".format(psnrref)+ " SS: " + "{:0.4f}".format(ssimref) + " EVR: " + "{:0.4f}".format(myevr)+ " MI: " + "{:0.4f}".format(mymi)
10906        if viz_filename is not None and ( jjj == 0 or (jjj % 30 == 0) ) and image.shape[2] < 685:
10907            viz_filename_use = re.sub( ".png", "_slice"+str(jjj).zfill(4)+".png", viz_filename )
10908            ants.plot_ortho( image, crop=False, filename=viz_filename_use, flat=True, xyz_lines=False, orient_labels=False, xyz_pad=0,  title=ttl, titlefontsize=12, title_dy=-0.02,textfontcolor='red' )
10909        df = pd.DataFrame([[ 
10910            mystem, 
10911            image_reference.dimension, 
10912            noizlevel, snrref, cnrref, psnrref, ssimref, mymi, asym_err, myevr, msk_vol, 
10913            spc[0], spc[1], spc[2],org[0], org[1], org[2], 
10914            image.shape[0], image.shape[1], image.shape[2], ntimepoints, 
10915            jjj, modality, mriseries, mrimfg, mrimodel, MagneticFieldStrength, mriSAR, PixelBandwidth, BandwidthPerPixelPhaseEncode, bvalueMax, bvecnorm ]], 
10916            columns=[
10917                'filename', 
10918                'dimensionality',
10919                'noise', 'snr', 'cnr', 'psnr', 'ssim', 'mi', 'reflection_err', 'EVR', 'msk_vol', 'spc0','spc1','spc2','org0','org1','org2','dimx','dimy','dimz','dimt','slice','modality', 'mriseries', 'mrimfg', 'mrimodel', 'mriMagneticFieldStrength', 'mriSAR', 'mriPixelBandwidth', 'mriPixelBandwidthPE', 'dti_bvalueMax', 'dti_bvecnorm' ])
10920        outdf = pd.concat( [outdf, df ], axis=0, ignore_index=False )
10921        if verbose:
10922            print( outdf )
10923    if viz_filename is not None:
10924        csvfn = re.sub( "png", "csv", viz_filename )
10925        outdf.to_csv( csvfn )
10926    return outdf

Quick blind image assessment and triplanar visualization of an image. 4D input will be visualized and assessed in 3D. Produces a PNG and a CSV, where the CSV contains:

  • reflection error ( estimates asymmetry )

  • brisq ( blind quality assessment )

  • patch eigenvalue ratio ( blind quality assessment )

  • PSNR and SSIM vs a smoothed reference (4D or 3D appropriate)

  • mask volume ( estimates foreground object size )

  • spacing

  • dimension after cropping by mask

image : filename (str) or ANTs image object, usually a NIfTI image

viz_filename : str, output path for a PNG image (a CSV with the same name is also written)

title : if True, display a summary title on the PNG

pull_rank : boolean

resample : None, a numeric spacing, or 'max'/'min'; resamples the image to isotropic spacing

n_to_skip : 10 by default; samples time series every n_to_skip volume

verbose : boolean
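
Example usage (a short sketch using the sample MNI image shipped with ANTsPy; the output paths are placeholders, and a CSV is written alongside the PNG):

import ants
import antspymm
img = ants.image_read( ants.get_ants_data( 'mni' ) )   # small 3D example image
qcdf = antspymm.blind_image_assessment( img, viz_filename='/tmp/mni_qc.png' )
print( qcdf[ ['snr', 'cnr', 'psnr', 'ssim', 'EVR', 'msk_vol'] ].round( 3 ) )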

def average_blind_qc_by_modality(qc_full, verbose=False):
10952def average_blind_qc_by_modality(qc_full,verbose=False):
10953    """
10954    Averages time series qc results to yield one entry per image. this also filters to "known" columns.
10955
10956    Args:
10957    qc_full: pandas dataframe containing the full qc data.
10958
10959    Returns:
10960    pandas dataframe containing the processed qc data.
10961    """
10962    qc_full = remove_unwanted_columns( qc_full )
10963    # Get unique modalities
10964    modalities = qc_full['modality'].unique()
10965    modalities = modalities[modalities != 'unknown']
10966    # Get unique ids
10967    uid = qc_full['filename']
10968    to_average = uid.unique()
10969    meta = pd.DataFrame(columns=qc_full.columns )
10970    # Process each unique id
10971    n = len(to_average)
10972    for k in range(n):
10973        if verbose:
10974            if k % 100 == 0:
10975                progger = str( np.round( k / n * 100 ) )
10976                print( progger, end ="...", flush=True)
10977        m1sel = uid == to_average[k]
10978        if sum(m1sel) > 1:
10979            # If more than one entry for id, take the average of continuous columns,
10980            # maximum of the slice column, and the first entry of the other columns
10981            mfsub = process_dataframe_generalized(qc_full[m1sel],'filename')
10982        else:
10983            mfsub = qc_full[m1sel]
10984        meta.loc[k] = mfsub.iloc[0]
10985    meta['modality'] = meta['modality'].replace(['DTIdwi', 'DTIb0'], 'DTI', regex=True)
10986    return meta

Averages time-series QC results to yield one entry per image. This also filters to "known" columns.

Args: qc_full: pandas dataframe containing the full qc data.

Returns: pandas dataframe containing the processed qc data.
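
Example usage (a brief sketch of feeding per-volume QC tables into this averaging step; the glob pattern is a placeholder for wherever blind_image_assessment wrote its CSVs):

import glob
import pandas as pd
import antspymm
qc_files = glob.glob( '/tmp/qc/*_qc.csv' )                                # placeholder location
qc_full = pd.concat( [ pd.read_csv( f ) for f in qc_files ], ignore_index=True )
qc_avg = antspymm.average_blind_qc_by_modality( qc_full, verbose=True )
print( qc_avg['modality'].value_counts() )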

def best_mmm(mmdf, wmod, mysep='-', outlier_column='ol_loop', verbose=False):
1599def best_mmm( mmdf, wmod, mysep='-', outlier_column='ol_loop', verbose=False):
1600    """
1601    Selects the best repeats per modality.
1602
1603    Args:
1604    wmod (str): the modality of the image ( 'T1w', 'T2Flair', 'NM2DMT', 'rsfMRI', 'DTI')
1605
1606    mysep (str, optional): the separator used in the image file names. Defaults to '-'.
1607
1608    outlier_column : column name for the outlier score (default 'ol_loop')
1609
1610    verbose (bool, optional): default False
1611
1612    Returns:
1613
1614    dict: a dictionary containing two metadata dataframes - 'raw' and 'filt'. raw contains all the metadata for the selected modality and filt contains the metadata filtered for highest quality repeats.
1615
1616    """
1617#    mmdf = mmdf.astype(str)
1618    mmdf[outlier_column]=mmdf[outlier_column].astype(float)
1619    msel = mmdf['modality'] == wmod
1620    if wmod == 'rsfMRI':
1621        msel1 = mmdf['modality'] == 'rsfMRI'
1622        msel2 = mmdf['modality'] == 'rsfMRI_LR'
1623        msel3 = mmdf['modality'] == 'rsfMRI_RL'
1624        msel = msel1 | msel2
1625        msel = msel | msel3
1626    if wmod == 'DTI':
1627        msel1 = mmdf['modality'] == 'DTI'
1628        msel2 = mmdf['modality'] == 'DTI_LR'
1629        msel3 = mmdf['modality'] == 'DTI_RL'
1630        msel4 = mmdf['modality'] == 'DTIdwi'
1631        msel5 = mmdf['modality'] == 'DTIb0'
1632        msel = msel1 | msel2 | msel3 | msel4 | msel5
1633    if sum(msel) == 0:
1634        return {'raw': None, 'filt': None}
1635    metasub = mmdf[msel].copy()
1636
1637    if verbose:
1638        print(f"{wmod} {(metasub.shape[0])} pre")
1639
1640    metasub['subjectID']=None
1641    metasub['date']=None
1642    metasub['subjectIDdate']=None
1643    metasub['imageID']=None
1644    metasub['negol']=math.nan
1645    for k in metasub.index:
1646        temp = metasub.loc[k, 'filename'].split( mysep )
1647        metasub.loc[k,'subjectID'] = str( temp[1] )
1648        metasub.loc[k,'date'] = str( temp[2] )
1649        metasub.loc[k,'subjectIDdate'] = str( temp[1] + mysep + temp[2] )
1650        metasub.loc[k,'imageID'] = str( temp[4])
1651
1652
1653    if 'ol_' in outlier_column:
1654        metasub['negol'] = metasub[outlier_column].max() - metasub[outlier_column]
1655    else:
1656        metasub['negol'] = metasub[outlier_column]
1657    if 'date' not in metasub.keys():
1658        metasub['date']=None
1659    metasubq = add_repeat_column( metasub, 'subjectIDdate' )
1660    metasubq = highest_quality_repeat(metasubq, 'filename', 'date', 'negol')
1661
1662    if verbose:
1663        print(f"{wmod} {metasubq.shape[0]} post")
1664
1665#    metasub = metasub.astype(str)
1666#    metasubq = metasubq.astype(str)
1667    metasub[outlier_column]=metasub[outlier_column].astype(float)
1668    metasubq[outlier_column]=metasubq[outlier_column].astype(float)
1669    return {'raw': metasub, 'filt': metasubq}

Selects the best repeats per modality.

Args: mmdf (pandas dataframe): the combined QC/metadata dataframe. wmod (str): the modality of the image ('T1w', 'T2Flair', 'NM2DMT', 'rsfMRI', 'DTI')

mysep (str, optional): the separator used in the image file names. Defaults to '-'.

outlier_column (str, optional): column name for the outlier score. Defaults to 'ol_loop'.

verbose (bool, optional): Defaults to False.

Returns:

dict: a dictionary containing two metadata dataframes, 'raw' and 'filt'. raw contains all the metadata for the selected modality and filt contains the metadata filtered for the highest-quality repeats.
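
Example usage (a self-contained sketch with a toy dataframe; in practice mmdf would come from the averaged QC tables, and 'ol_loop' stands in for whatever outlier column you computed):

import pandas as pd
import antspymm
qc = pd.DataFrame( {
    'filename': ['PPMI-101018-20210412-T1w-000', 'PPMI-101018-20210412-T1w-001'],
    'modality': ['T1w', 'T1w'],
    'ol_loop':  [0.12, 0.57] } )
best = antspymm.best_mmm( qc, 'T1w', mysep='-', outlier_column='ol_loop' )
t1_all  = best['raw']    # all T1w rows, with subjectID/date/imageID parsed from the filenames
t1_best = best['filt']   # highest-quality repeat per subject and date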

def nrg_2_bids(nrg_filename):
 990def nrg_2_bids( nrg_filename ):
 991    """
 992    Convert an NRG filename to BIDS path/filename.
 993
 994    Parameters:
 995    nrg_filename (str): The NRG filename to convert.
 996
 997    Returns:
 998    str: The BIDS path/filename.
 999    """
1000
1001    # Split the NRG filename into its components
1002    nrg_dirname, nrg_basename = os.path.split(nrg_filename)
1003    nrg_suffix = '.' + nrg_basename.split('.',1)[-1]
1004    nrg_basename = nrg_basename.replace(nrg_suffix, '') # remove ext
1005    nrg_parts = nrg_basename.split('-')
1006    nrg_subject_id = nrg_parts[1]
1007    nrg_modality = nrg_parts[3]
1008    nrg_repeat= nrg_parts[4]
1009
1010    # Build the BIDS path/filename
1011    bids_dirname = os.path.join(nrg_dirname, 'bids')
1012    bids_subject = f'sub-{nrg_subject_id}'
1013    bids_session = f'ses-{nrg_repeat}'
1014
1015    valid_modalities = get_valid_modalities()
1016    if nrg_modality is not None:
1017        if not nrg_modality in valid_modalities:
1018            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
1019
1020    if nrg_modality == 'T1w' :
1021        bids_modality_folder = 'anat'
1022        bids_modality_filename = 'T1w'
1023
1024    if nrg_modality == 'T2Flair' :
1025        bids_modality_folder = 'anat'
1026        bids_modality_filename = 'flair'
1027
1028    if nrg_modality == 'NM2DMT' :
1029        bids_modality_folder = 'anat'
1030        bids_modality_filename = 'nm2dmt'
1031
1032    if nrg_modality == 'DTI' or nrg_modality == 'DTI_RL' or nrg_modality == 'DTI_LR' :
1033        bids_modality_folder = 'dwi'
1034        bids_modality_filename = 'dwi'
1035
1036    if nrg_modality == 'rsfMRI' or nrg_modality == 'rsfMRI_RL' or nrg_modality == 'rsfMRI_LR' :
1037        bids_modality_folder = 'func'
1038        bids_modality_filename = 'func'
1039
1040    if nrg_modality == 'perf'  :
1041        bids_modality_folder = 'perf'
1042        bids_modality_filename = 'perf'
1043
1044    bids_suffix = nrg_suffix[1:]
1045    bids_filename = f'{bids_subject}_{bids_session}_{bids_modality_filename}.{bids_suffix}'
1046
1047    # Return bids filepath/filename
1048    return os.path.join(bids_dirname, bids_subject, bids_session, bids_modality_folder, bids_filename)

Convert an NRG filename to BIDS path/filename.

Parameters: nrg_filename (str): The NRG filename to convert.

Returns: str: The BIDS path/filename.
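
Example usage (a quick sketch with a placeholder NRG path, following the PROJECT-SUBJECT-DATE-MODALITY-IMAGEID naming that the function expects):

import antspymm
nrg_fn = 'PPMI/101018/20210412/T1w/000/PPMI-101018-20210412-T1w-000.nii.gz'   # placeholder
print( antspymm.nrg_2_bids( nrg_fn ) )
# expected: PPMI/101018/20210412/T1w/000/bids/sub-101018/ses-000/anat/sub-101018_ses-000_T1w.nii.gz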

def bids_2_nrg(bids_filename, project_name, date, nrg_modality=None):
1051def bids_2_nrg( bids_filename, project_name, date, nrg_modality=None ):
1052    """
1053    Convert a BIDS filename to NRG path/filename.
1054
1055    Parameters:
1056    bids_filename (str): The BIDS filename to convert
1057    project_name (str) : Name of project (i.e. PPMI)
1058    date (str) : Date of image acquisition
1059
1060
1061    Returns:
1062    str: The NRG path/filename.
1063    """
1064
1065    bids_dirname, bids_basename = os.path.split(bids_filename)
1066    bids_suffix = '.'+ bids_basename.split('.',1)[-1]
1067    bids_basename = bids_basename.replace(bids_suffix, '') # remove ext
1068    bids_parts = bids_basename.split('_')
1069    nrg_subject_id = bids_parts[0].replace('sub-','')
1070    nrg_image_id = bids_parts[1].replace('ses-', '')
1071    bids_modality = bids_parts[2]
1072    valid_modalities = get_valid_modalities()
1073    if nrg_modality is not None:
1074        if not nrg_modality in valid_modalities:
1075            raise ValueError('nrg_modality ' + str(nrg_modality) + " not a valid mm modality: " + get_valid_modalities(asString=True))
1076
1077    if bids_modality == 'anat' and nrg_modality is None :
1078        nrg_modality = 'T1w'
1079
1080    if bids_modality == 'dwi' and nrg_modality is None  :
1081        nrg_modality = 'DTI'
1082
1083    if bids_modality == 'func' and nrg_modality is None  :
1084        nrg_modality = 'rsfMRI'
1085
1086    if bids_modality == 'perf' and nrg_modality is None  :
1087        nrg_modality = 'perf'
1088
1089    nrg_suffix = bids_suffix[1:]
1090    nrg_filename = f'{project_name}-{nrg_subject_id}-{date}-{nrg_modality}-{nrg_image_id}.{nrg_suffix}'
1091
1092    return os.path.join(project_name, nrg_subject_id, date, nrg_modality, nrg_image_id,nrg_filename)

Convert a BIDS filename to NRG path/filename.

Parameters: bids_filename (str): The BIDS filename to convert. project_name (str): Name of the project (e.g. PPMI). date (str): Date of image acquisition. nrg_modality (str, optional): NRG modality to use when it cannot be inferred from the BIDS token.

Returns: str: The NRG path/filename.
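
Example usage (a quick sketch with placeholder inputs; when the BIDS token maps directly (anat, dwi, func, perf) the NRG modality is inferred, otherwise pass nrg_modality explicitly):

import antspymm
print( antspymm.bids_2_nrg( 'sub-101018_ses-000_func.nii.gz', 'PPMI', '20210412' ) )
# expected: PPMI/101018/20210412/rsfMRI/000/PPMI-101018-20210412-rsfMRI-000.nii.gz
print( antspymm.bids_2_nrg( 'sub-101018_ses-000_T1w.nii.gz', 'PPMI', '20210412', nrg_modality='T1w' ) )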

def parse_nrg_filename(x, separator='-'):
973def parse_nrg_filename( x, separator='-' ):
974    """
975    split a NRG filename into its named parts
976    """
977    temp = x.split( separator )
978    if len(temp) != 5:
979        raise ValueError(x + " not a valid NRG filename")
980    return {
981        'project':temp[0],
982        'subjectID':temp[1],
983        'date':temp[2],
984        'modality':temp[3],
985        'imageID':temp[4]
986    }

split a NRG filename into its named parts
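
Example usage (placeholder identifiers):

import antspymm
parts = antspymm.parse_nrg_filename( 'PPMI-101018-20210412-T1w-000' )
# parts == {'project': 'PPMI', 'subjectID': '101018', 'date': '20210412', 'modality': 'T1w', 'imageID': '000'}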

def novelty_detection_svm(df_train, df_test, nu=0.05, kernel='rbf'):
11602def novelty_detection_svm(df_train, df_test, nu=0.05, kernel='rbf'):
11603    """
11604    This function performs novelty detection using One-Class SVM.
11605
11606    Parameters:
11607
11608    - df_train (pandas dataframe): training data used to fit the model
11609
11610    - df_test (pandas dataframe): test data used to predict novelties
11611
11612    - nu (float): parameter controlling the fraction of training errors and the fraction of support vectors (default: 0.05)
11613
11614    - kernel (str): kernel type used in the SVM algorithm (default: 'rbf')
11615
11616    Returns:
11617
11618    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11619    """
11620    from sklearn.svm import OneClassSVM
11621    # Fit the model on the training data
11622    df_train[ df_train == math.inf ] = 0
11623    df_test[ df_test == math.inf ] = 0
11624    clf = OneClassSVM(nu=nu, kernel=kernel)
11625    from sklearn.preprocessing import StandardScaler
11626    scaler = StandardScaler()
11627    scaler.fit(df_train)
11628    clf.fit(scaler.transform(df_train))
11629    predictions = clf.predict(scaler.transform(df_test))
11630    predictions[predictions==1]=0
11631    predictions[predictions==-1]=1
11632    if str(type(df_train))=="<class 'pandas.core.frame.DataFrame'>":
11633        return pd.Series(predictions, index=df_test.index)
11634    else:
11635        return pd.Series(predictions)

This function performs novelty detection using One-Class SVM.

Parameters:

  • df_train (pandas dataframe): training data used to fit the model

  • df_test (pandas dataframe): test data used to predict novelties

  • nu (float): parameter controlling the fraction of training errors and the fraction of support vectors (default: 0.05)

  • kernel (str): kernel type used in the SVM algorithm (default: 'rbf')

Returns:

predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
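
Example usage (a self-contained sketch on synthetic data; the column names are arbitrary stand-ins for the numeric QC columns you would normally pass):

import numpy as np
import pandas as pd
import antspymm
rng = np.random.default_rng( 0 )
train = pd.DataFrame( rng.normal( size=(200, 4) ), columns=['snr', 'cnr', 'psnr', 'EVR'] )
test = pd.DataFrame( np.vstack( [ rng.normal( size=(4, 4) ),
                                  rng.normal( 5.0, 1.0, size=(2, 4) ) ] ),
                     columns=['snr', 'cnr', 'psnr', 'EVR'] )
flags = antspymm.novelty_detection_svm( train, test, nu=0.05, kernel='rbf' )
print( flags )   # 1 flags likely novelties, 0 flags inliers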

def novelty_detection_ee(df_train, df_test, contamination=0.05):
11566def novelty_detection_ee(df_train, df_test, contamination=0.05):
11567    """
11568    This function performs novelty detection using Elliptic Envelope.
11569
11570    Parameters:
11571
11572    - df_train (pandas dataframe): training data used to fit the model
11573
11574    - df_test (pandas dataframe): test data used to predict novelties
11575
11576    - contamination (float): parameter controlling the proportion of outliers in the data (default: 0.05)
11577
11578    Returns:
11579
11580    predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11581    """
11582    import pandas as pd
11583    from sklearn.covariance import EllipticEnvelope
11584    # Fit the model on the training data
11585    clf = EllipticEnvelope(contamination=contamination,support_fraction=1)
11586    df_train[ df_train == math.inf ] = 0
11587    df_test[ df_test == math.inf ] = 0
11588    from sklearn.preprocessing import StandardScaler
11589    scaler = StandardScaler()
11590    scaler.fit(df_train)
11591    clf.fit(scaler.transform(df_train))
11592    predictions = clf.predict(scaler.transform(df_test))
11593    predictions[predictions==1]=0
11594    predictions[predictions==-1]=1
11595    if str(type(df_train))=="<class 'pandas.core.frame.DataFrame'>":
11596        return pd.Series(predictions, index=df_test.index)
11597    else:
11598        return pd.Series(predictions)

This function performs novelty detection using Elliptic Envelope.

Parameters:

  • df_train (pandas dataframe): training data used to fit the model

  • df_test (pandas dataframe): test data used to predict novelties

  • contamination (float): parameter controlling the proportion of outliers in the data (default: 0.05)

Returns:

predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
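
Example usage (a comparable sketch on synthetic data; columns are arbitrary placeholders):

import numpy as np
import pandas as pd
import antspymm
rng = np.random.default_rng( 1 )
train = pd.DataFrame( rng.normal( size=(200, 3) ), columns=['a', 'b', 'c'] )
test = pd.DataFrame( rng.normal( 4.0, 1.0, size=(5, 3) ), columns=['a', 'b', 'c'] )
flags = antspymm.novelty_detection_ee( train, test, contamination=0.05 )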

def novelty_detection_lof(df_train, df_test, n_neighbors=20):
11639def novelty_detection_lof(df_train, df_test, n_neighbors=20):
11640    """
11641    This function performs novelty detection using Local Outlier Factor (LOF).
11642
11643    Parameters:
11644
11645    - df_train (pandas dataframe): training data used to fit the model
11646
11647    - df_test (pandas dataframe): test data used to predict novelties
11648
11649    - n_neighbors (int): number of neighbors used to compute the LOF (default: 20)
11650
11651    Returns:
11652
11653    - predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
11654
11655    """
11656    from sklearn.neighbors import LocalOutlierFactor
11657    # Fit the model on the training data
11658    df_train[ df_train == math.inf ] = 0
11659    df_test[ df_test == math.inf ] = 0
11660    clf = LocalOutlierFactor(n_neighbors=n_neighbors, algorithm='auto',contamination='auto', novelty=True)
11661    from sklearn.preprocessing import StandardScaler
11662    scaler = StandardScaler()
11663    scaler.fit(df_train)
11664    clf.fit(scaler.transform(df_train))
11665    predictions = clf.predict(scaler.transform(df_test))
11666    predictions[predictions==1]=0
11667    predictions[predictions==-1]=1
11668    if str(type(df_train))=="<class 'pandas.core.frame.DataFrame'>":
11669        return pd.Series(predictions, index=df_test.index)
11670    else:
11671        return pd.Series(predictions)

This function performs novelty detection using Local Outlier Factor (LOF).

Parameters:

  • df_train (pandas dataframe): training data used to fit the model

  • df_test (pandas dataframe): test data used to predict novelties

  • n_neighbors (int): number of neighbors used to compute the LOF (default: 20)

Returns:

  • predictions (pandas series): predicted labels for the test data (1 for novelties, 0 for inliers)
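
Example usage (a comparable sketch on synthetic data; columns are arbitrary placeholders):

import numpy as np
import pandas as pd
import antspymm
rng = np.random.default_rng( 2 )
train = pd.DataFrame( rng.normal( size=(200, 3) ), columns=['a', 'b', 'c'] )
test = pd.DataFrame( rng.normal( 4.0, 1.0, size=(5, 3) ), columns=['a', 'b', 'c'] )
flags = antspymm.novelty_detection_lof( train, test, n_neighbors=20 )
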
def novelty_detection_loop(df_train, df_test, n_neighbors=20, distance_metric='minkowski'):
11674def novelty_detection_loop(df_train, df_test, n_neighbors=20, distance_metric='minkowski'):
11675    """
11676    This function performs novelty detection using the Local Outlier Probability (LoOP) method.
11677
11678    Parameters:
11679
11680    - df_train (pandas dataframe): training data used to fit the model
11681
11682    - df_test (pandas dataframe): test data used to predict novelties
11683
11684    - n_neighbors (int): number of neighbors used to compute the LOOP (default: 20)
11685
11686    - distance_metric : default minkowski
11687
11688    Returns:
11689
11690    - local outlier probabilities for the test data (values in [0,1]; higher means more outlying)
11691
11692    """
11693    from PyNomaly import loop
11694    from sklearn.neighbors import NearestNeighbors
11695    from sklearn.preprocessing import StandardScaler
11696    scaler = StandardScaler()
11697    scaler.fit(df_train)
11698    data = np.vstack( [scaler.transform(df_test),scaler.transform(df_train)])
11699    neigh = NearestNeighbors(n_neighbors=n_neighbors, metric=distance_metric)
11700    neigh.fit(data)
11701    d, idx = neigh.kneighbors(data, return_distance=True)
11702    m = loop.LocalOutlierProbability(distance_matrix=d, neighbor_matrix=idx, n_neighbors=n_neighbors).fit()
11703    return m.local_outlier_probabilities[range(df_test.shape[0])]

This function performs novelty detection using the Local Outlier Probability (LoOP) method.

Parameters:

  • df_train (pandas dataframe): training data used to fit the model

  • df_test (pandas dataframe): test data used to predict novelties

  • n_neighbors (int): number of neighbors used to compute the LOOP (default: 20)

  • distance_metric : default minkowski

Returns:

  • local outlier probabilities for the test data (values in [0,1]; higher values indicate greater outlierness)
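
Example usage (a comparable sketch on synthetic data; unlike the LOF variant above, this returns probabilities rather than 0/1 labels, and it requires the PyNomaly package):

import numpy as np
import pandas as pd
import antspymm
rng = np.random.default_rng( 3 )
train = pd.DataFrame( rng.normal( size=(200, 3) ), columns=['a', 'b', 'c'] )
test = pd.DataFrame( rng.normal( 4.0, 1.0, size=(5, 3) ), columns=['a', 'b', 'c'] )
probs = antspymm.novelty_detection_loop( train, test, n_neighbors=20 )   # values in [0, 1]
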
def novelty_detection_quantile(df_train, df_test):
11707def novelty_detection_quantile(df_train, df_test):
11708    """
11709    This function performs novelty detection using quantiles for each column.
11710
11711    Parameters:
11712
11713    - df_train (pandas dataframe): training data used to fit the model
11714
11715    - df_test (pandas dataframe): test data used to predict novelties
11716
11717    Returns:
11718
11719    - quantiles for the test sample at each column where values range in [0,1]
11720        and higher values mean the column is closer to the edge of the distribution
11721
11722    """
11723    myqs = df_test.copy()
11724    n = df_train.shape[0]
11725    df_trainkeys = df_train.keys()
11726    for k in range( df_train.shape[1] ):
11727        mykey = df_trainkeys[k]
11728        temp = (myqs[mykey][0] >  df_train[mykey]).sum() / n
11729        myqs[mykey] = abs( temp - 0.5 ) / 0.5
11730    return myqs

This function performs novelty detection using quantiles for each column.

Parameters:

  • df_train (pandas dataframe): training data used to fit the model

  • df_test (pandas dataframe): test data used to predict novelties

Returns:

  • quantiles for the test sample at each column where values range in [0,1] and higher values mean the column is closer to the edge of the distribution
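
Example usage (a small sketch with a single test row, since the implementation above reads the first test value per column; column names are placeholders):

import numpy as np
import pandas as pd
import antspymm
rng = np.random.default_rng( 4 )
train = pd.DataFrame( rng.normal( size=(200, 3) ), columns=['snr', 'cnr', 'EVR'] )
test = pd.DataFrame( [[0.1, 3.5, -0.2]], columns=['snr', 'cnr', 'EVR'] )
q = antspymm.novelty_detection_quantile( train, test )
# values near 1 mean the test value sits at the edge of the training distribution
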
def generate_mm_dataframe( projectID, subjectID, date, imageUniqueID, modality, source_image_directory, output_image_directory, t1_filename, flair_filename=[], rsf_filenames=[], dti_filenames=[], nm_filenames=[], perf_filename=[], pet3d_filename=[]):
649def generate_mm_dataframe(
650        projectID,
651        subjectID,
652        date,
653        imageUniqueID,
654        modality,
655        source_image_directory,
656        output_image_directory,
657        t1_filename,
658        flair_filename=[],
659        rsf_filenames=[],
660        dti_filenames=[],
661        nm_filenames=[],
662        perf_filename=[],
663        pet3d_filename=[],
664):
665    """
666    Generate a DataFrame for medical imaging data with extensive validation of input parameters.
667
668    This function creates a DataFrame containing information about medical imaging files,
669    ensuring that filenames match expected patterns for their modalities and that all
670    required images exist. It also validates the number of filenames provided for specific
671    modalities like rsfMRI, DTI, and NM.
672
673    Parameters:
674    - projectID (str): Project identifier.
675    - subjectID (str): Subject identifier.
676    - date (str): Date of the imaging study.
677    - imageUniqueID (str): Unique image identifier.
678    - modality (str): Modality of the imaging study.
679    - source_image_directory (str): Directory of the source images.
680    - output_image_directory (str): Directory for output images.
681    - t1_filename (str): Filename of the T1-weighted image.
682    - flair_filename (list): List of filenames for FLAIR images.
683    - rsf_filenames (list): List of filenames for rsfMRI images.
684    - dti_filenames (list): List of filenames for DTI images.
685    - nm_filenames (list): List of filenames for NM images.
686    - perf_filename (list): List of filenames for perfusion images.
687    - pet3d_filename (list): List of filenames for pet3d images.
688
689    Returns:
690    - pandas.DataFrame: A DataFrame containing the validated imaging study information.
691
692    Raises:
693    - ValueError: If any validation checks fail or if the number of columns does not match the data.
694    """
695    def check_pd_construction(data, columns):
696        # Check if the length of columns matches the length of data in each row
697        if all(len(row) == len(columns) for row in data):
698            return True
699        else:
700            return False
701    from os.path import exists
702    valid_modalities = get_valid_modalities()
703    if not isinstance(t1_filename, str):
704        raise ValueError("t1_filename is not a string")
705    if not exists(t1_filename):
706        raise ValueError("t1_filename does not exist")
707    if modality not in valid_modalities:
708        raise ValueError('modality ' + str(modality) + " not a valid mm modality:  " + get_valid_modalities(asString=True))
709    # if not exists( output_image_directory ):
710    #    raise ValueError("output_image_directory does not exist")
711    if not exists( source_image_directory ):
712        raise ValueError("source_image_directory does not exist")
713    if len( rsf_filenames ) > 2:
714        raise ValueError("len( rsf_filenames ) > 2")
715    if len( dti_filenames ) > 3:
716        raise ValueError("len( dti_filenames ) > 3")
717    if len( nm_filenames ) > 11:
718        raise ValueError("len( nm_filenames ) > 11")
719    if len( rsf_filenames ) < 2:
720        for k in range(len(rsf_filenames),2):
721            rsf_filenames.append(None)
722    if len( dti_filenames ) < 3:
723        for k in range(len(dti_filenames),3):
724            dti_filenames.append(None)
725    if len( nm_filenames ) < 10:
726        for k in range(len(nm_filenames),10):
727            nm_filenames.append(None)
728    # check modality names
729    if not "T1w" in t1_filename:
730        raise ValueError("T1w is not in t1 filename " + t1_filename)
731    if flair_filename is not None:
732        if isinstance(flair_filename,list):
733            if (len(flair_filename) == 0):
734                flair_filename=None
735            else:
736                print("Take first entry from flair_filename list")
737                flair_filename=flair_filename[0]
738    if flair_filename is not None and not "lair" in flair_filename:
739            raise ValueError("flair is not flair filename " + flair_filename)
740    ## perfusion
741    if perf_filename is not None:
742        if isinstance(perf_filename,list):
743            if (len(perf_filename) == 0):
744                perf_filename=None
745            else:
746                print("Take first entry from perf_filename list")
747                perf_filename=perf_filename[0]
748    if perf_filename is not None and not "perf" in perf_filename:
749            raise ValueError("perf_filename is not perf filename " + perf_filename)
750
751    if pet3d_filename is not None:
752        if isinstance(pet3d_filename,list):
753            if (len(pet3d_filename) == 0):
754                pet3d_filename=None
755            else:
756                print("Take first entry from pet3d_filename list")
757                pet3d_filename=pet3d_filename[0]
758    if pet3d_filename is not None and not "pet" in pet3d_filename:
759            raise ValueError("pet3d_filename is not pet filename " + pet3d_filename)
760    
761    for k in nm_filenames:
762        if k is not None:
763            if not "NM" in k:
764                raise ValueError("NM is not in NM filename " + k)
765    for k in dti_filenames:
766        if k is not None:
767            if not "DTI" in k and not "dwi" in k:
768                raise ValueError("DTI/DWI is not dti filename " + k)
769    for k in rsf_filenames:
770        if k is not None:
771            if not "fMRI" in k and not "func" in k:
772                raise ValueError("rsfMRI/func is not rsfmri filename " + k)
773    if perf_filename is not None:
774        if not "perf" in perf_filename:
775                raise ValueError("perf_filename is not a valid perfusion (perf) filename " + perf_filename)
776    allfns = [t1_filename] + [flair_filename] + nm_filenames + dti_filenames + rsf_filenames + [perf_filename] + [pet3d_filename]
777    for k in allfns:
778        if k is not None:
779            if not isinstance(k, str):
780                raise ValueError(str(k) + " is not a string")
781            if not exists( k ):
782                raise ValueError( "image " + k + " does not exist")
783    coredata = [
784        projectID,
785        subjectID,
786        date,
787        imageUniqueID,
788        modality,
789        source_image_directory,
790        output_image_directory,
791        t1_filename,
792        flair_filename, 
793        perf_filename,
794        pet3d_filename]
795    mydata0 = coredata +  rsf_filenames + dti_filenames
796    mydata = mydata0 + nm_filenames
797    corecols = [
798        'projectID',
799        'subjectID',
800        'date',
801        'imageID',
802        'modality',
803        'sourcedir',
804        'outputdir',
805        'filename',
806        'flairid',
807        'perfid',
808        'pet3did']
809    mycols0 = corecols + [
810        'rsfid1', 'rsfid2',
811        'dtid1', 'dtid2','dtid3']
812    nmext = [
813        'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5',
814        'nmid6', 'nmid7','nmid8', 'nmid9', 'nmid10' #, 'nmid11'
815    ]
816    mycols = mycols0 + nmext
817    if not check_pd_construction( [mydata], mycols ) :
818#        print( mydata )
819#        print( len(mydata ))
820#        print( mycols )
821#        print( len(mycols ))
822        raise ValueError( "Error in generate_mm_dataframe: len( mycols ) != len( mydata ) which indicates a bad input parameter to this function." )
823    studycsv = pd.DataFrame([ mydata ], columns=mycols)
824    return studycsv

Generate a DataFrame for medical imaging data with extensive validation of input parameters.

This function creates a DataFrame containing information about medical imaging files, ensuring that filenames match expected patterns for their modalities and that all required images exist. It also validates the number of filenames provided for specific modalities like rsfMRI, DTI, and NM.

Parameters:

  • projectID (str): Project identifier.
  • subjectID (str): Subject identifier.
  • date (str): Date of the imaging study.
  • imageUniqueID (str): Unique image identifier.
  • modality (str): Modality of the imaging study.
  • source_image_directory (str): Directory of the source images.
  • output_image_directory (str): Directory for output images.
  • t1_filename (str): Filename of the T1-weighted image.
  • flair_filename (list): List of filenames for FLAIR images.
  • rsf_filenames (list): List of filenames for rsfMRI images.
  • dti_filenames (list): List of filenames for DTI images.
  • nm_filenames (list): List of filenames for NM images.
  • perf_filename (list): List of filenames for perfusion images.
  • pet3d_filename (list): List of filenames for pet3d images.

Returns:

  • pandas.DataFrame: A DataFrame containing the validated imaging study information.

Raises:

  • ValueError: If any validation checks fail or if the number of columns does not match the data.
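
Example usage (a sketch with placeholder NRG paths; every file passed must exist on disk or the function raises ValueError, and filenames must contain the expected modality substrings, e.g. T1w, lair, DTI/dwi, fMRI/func, perf, NM):

import antspymm
studycsv = antspymm.generate_mm_dataframe(
    projectID='PPMI',
    subjectID='101018',
    date='20210412',
    imageUniqueID='000',
    modality='T1w',
    source_image_directory='/data/NRG/PPMI/',
    output_image_directory='/data/processed/',
    t1_filename='/data/NRG/PPMI/101018/20210412/T1w/000/PPMI-101018-20210412-T1w-000.nii.gz',
    rsf_filenames=['/data/NRG/PPMI/101018/20210412/rsfMRI/000/PPMI-101018-20210412-rsfMRI-000.nii.gz'],
    dti_filenames=['/data/NRG/PPMI/101018/20210412/DTI/000/PPMI-101018-20210412-DTI-000.nii.gz'] )
studycsv.to_csv( '/tmp/PPMI-101018-20210412-mm.csv', index=False )
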
def aggregate_antspymm_results( input_csv, subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path='./Processed/ANTsExpArt/', hiervariable='T1wHierarchical', valid_modalities=None, verbose=False):
12263def aggregate_antspymm_results(input_csv, subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Processed/ANTsExpArt/", hiervariable='T1wHierarchical', valid_modalities=None, verbose=False ):
12264    """
12265    Aggregate ANTsPyMM results from the specified CSV file and save the aggregated results to a new CSV file.
12266
12267    Parameters:
12268    - input_csv (str): File path of the input CSV file containing ANTsPyMM QC results averaged and with outlier measurements.
12269    - subject_col (str): Name of the column to store subject IDs.
12270    - date_col (str): Name of the column to store date information.
12271    - image_col (str): Name of the column to store image IDs.
12272    - date_column (str): Name of the column representing the date information.
12273    - base_path (str): Base path for search paths. Defaults to "./Processed/ANTsExpArt/".
12274    - hiervariable (str) : the string variable denoting the Hierarchical output
12275    - valid_modalities (str array) : identifies for each modality; if None will be replaced by get_valid_modalities(long=True)
12276    - verbose : boolean
12277
12278    Note:
12279    This function is tested under limited circumstances. Use with caution.
12280
12281    Example usage:
12282    agg_df = aggregate_antspymm_results("qcdfaol.csv", subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Your/Custom/Path/")
12283
12284    Author:
12285    Avants and ChatGPT
12286    """
12287    import pandas as pd
12288    import numpy as np
12289    from glob import glob
12290
12291    def myread_csv(x, cnms):
12292        """
12293        Reads a CSV file and returns a DataFrame excluding specified columns.
12294
12295        Parameters:
12296        - x (str): File path of the input CSV file describing the blind QC output
12297        - cnms (list): List of column names to exclude from the DataFrame.
12298
12299        Returns:
12300        pd.DataFrame: DataFrame with specified columns excluded.
12301        """
12302        df = pd.read_csv(x)
12303        return df.loc[:, ~df.columns.isin(cnms)]
12304
12305    import warnings
12306    # Warning message for untested function
12307    warnings.warn("Warning: This function is not well tested. Use with caution.")
12308
12309    if valid_modalities is None:
12310        valid_modalities = get_valid_modalities('long')
12311
12312    # Read the input CSV file
12313    df = pd.read_csv(input_csv)
12314
12315    # Filter rows where modality is 'T1w'
12316    df = df[df['modality'] == 'T1w']
12317    badnames = get_names_from_data_frame( ['Unnamed'], df )
12318    df=df.drop(badnames, axis=1)
12319
12320    # Add new columns for subject ID, date, and image ID
12321    df[subject_col] = np.nan
12322    df[date_col] = date_column
12323    df[image_col] = np.nan
12324    df = df.astype({subject_col: str, date_col: str, image_col: str })
12325
12326#    if verbose:
12327#        print( df.shape )
12328#        print( df.dtypes )
12329
12330    # prefilter df for data that exists
12331    keep = np.tile( False, df.shape[0] )
12332    for x in range(df.shape[0]):
12333        temp = df['filename'].iloc[x].split("_")
12334        # Generalized search paths
12335        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12336        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12337        if len( hierfn ) > 0:
12338            keep[x]=True
12339
12340    
12341    df=df[keep]
12342    
12343    if verbose:
12344        print( "original input had shape " + str( df.shape[0] ) + " (T1 only) and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12345        print( df.shape )
12346
12347    myct = 0
12348    for x in range( df.shape[0]):
12349        if verbose:
12350            print(f"{x}...")
12351        locind = df.index[x]
12352        temp = df['filename'].iloc[x].split("_")
12353        if verbose:
12354            print( temp )
12355        df[subject_col].iloc[x]=temp[0]
12356        df[date_col].iloc[x]=date_column
12357        df[image_col].iloc[x]=temp[1]
12358
12359        # Generalized search paths
12360        path_template = f"{base_path}{temp[0]}/{date_column}/*/*/*"
12361        if verbose:
12362            print(path_template)
12363        hierfn = sorted(glob( path_template + "-" + hiervariable + "-*wide.csv" ) )
12364        if len( hierfn ) > 0:
12365            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12366            if verbose:
12367                print(hierfn)
12368            hdf = pd.read_csv(hierfn[0])
12369            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12370            hdf=hdf.drop(badnames, axis=1)
12371            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12372            corenames = list(np.array(hdf.columns)[nums])
12373            hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12374            myct = myct + 1
12375            dflist = [hdf]
12376
12377            for mymod in valid_modalities:
12378                t1wfn = sorted(glob( path_template+ "-" + mymod + "-*wide.csv" ) )
12379                if len( t1wfn ) > 0 :
12380                    if verbose:
12381                        print(t1wfn)
12382                    t1df = myread_csv(t1wfn[0], corenames)
12383                    t1df = filter_df( t1df, mymod+'_')
12384                    dflist = dflist + [t1df]
12385                
12386            hdf = pd.concat( dflist, axis=1, ignore_index=False )
12387            if verbose:
12388                print( df.loc[locind,'filename'] )
12389            if myct == 1:
12390                subdf = df.iloc[[x]]
12391                hdf.index = subdf.index.copy()
12392                df = pd.concat( [df,hdf], axis=1, ignore_index=False )
12393            else:
12394                commcols = list(set(hdf.columns).intersection(df.columns))
12395                df.loc[locind, commcols] = hdf.loc[0, commcols]
12396    badnames = get_names_from_data_frame( ['Unnamed'], df )
12397    df=df.drop(badnames, axis=1)
12398    return( df )

Aggregate ANTsPyMM results from the specified CSV file and save the aggregated results to a new CSV file.

Parameters:

  • input_csv (str): File path of the input CSV file containing ANTsPyMM QC results averaged and with outlier measurements.
  • subject_col (str): Name of the column to store subject IDs.
  • date_col (str): Name of the column to store date information.
  • image_col (str): Name of the column to store image IDs.
  • date_column (str): Date/session value used in the search paths and stored in the date column (default 'ses-1').
  • base_path (str): Base path for search paths. Defaults to "./Processed/ANTsExpArt/".
  • hiervariable (str) : the string variable denoting the Hierarchical output
  • valid_modalities (list of str) : modalities to aggregate; if None, replaced by get_valid_modalities('long')
  • verbose : boolean

Note: This function is tested under limited circumstances. Use with caution.

Example usage: agg_df = aggregate_antspymm_results("qcdfaol.csv", subject_col='subjectID', date_col='date', image_col='imageID', date_column='ses-1', base_path="./Your/Custom/Path/")

Author: Avants and ChatGPT

def aggregate_antspymm_results_sdf( study_df, project_col='projectID', subject_col='subjectID', date_col='date', image_col='imageID', base_path='./', hiervariable='T1wHierarchical', splitsep='-', idsep='-', wild_card_modality_id=False, second_split=False, verbose=False):
12421def aggregate_antspymm_results_sdf(
12422    study_df, 
12423    project_col='projectID',
12424    subject_col='subjectID', 
12425    date_col='date', 
12426    image_col='imageID', 
12427    base_path="./", 
12428    hiervariable='T1wHierarchical', 
12429    splitsep='-',
12430    idsep='-',
12431    wild_card_modality_id=False,
12432    second_split=False,
12433    verbose=False ):
12434    """
12435    Aggregate ANTsPyMM results from the specified study data frame and store the aggregated results in a new data frame.  This assumes data is organized on disk 
12436    as follows:  rootdir/projectID/subjectID/date/outputid/imageid/ where 
12437    outputid is modality-specific and created by ANTsPyMM processing.
12438
12439    Parameters:
12440    - study_df (pandas df): pandas data frame, output of generate_mm_dataframe.
12441    - project_col (str): Name of the column that stores the project ID
12442    - subject_col (str): Name of the column to store subject IDs.
12443    - date_col (str): Name of the column to store date information.
12444    - image_col (str): Name of the column to store image IDs.
12445    - base_path (str): Base path for searching for processing outputs of ANTsPyMM.
12446    - hiervariable (str) : the string variable denoting the Hierarchical output
12447    - splitsep (str):  the separator used to split the filename
12448    - idsep (str): the separator used to partition subjectid date and imageid 
12449        for example, if idsep is - then we have subjectid-date-imageid
12450    - wild_card_modality_id (bool): keep if False for safer execution
12451    - second_split (bool): this is a hack that will split the imageID by . and keep the first part of the split; may be needed when the input filenames contain .
12452    - verbose : boolean
12453
12454    Note:
12455    This function is tested under limited circumstances. Use with caution.
12456    One particular gotcha is if the imageID is stored as a numeric value in the dataframe 
12457    but is meant to be a string.  E.g. '000' (string) would be interpreted as 0 in the 
12458    file name glob.  This would miss the extant (on disk) csv.
12459
12460    Example usage:
12461    agg_df = aggregate_antspymm_results_sdf( studydf, subject_col='subjectID', date_col='date', image_col='imageID', base_path="./Your/Custom/Path/")
12462
12463    Author:
12464    Avants and ChatGPT
12465    """
12466    import pandas as pd
12467    import numpy as np
12468    from glob import glob
12469
12470    def progress_reporter(current_step, total_steps, width=50):
12471        # Calculate the proportion of progress
12472        progress = current_step / total_steps
12473        # Calculate the number of 'filled' characters in the progress bar
12474        filled_length = int(width * progress)
12475        # Create the progress bar string
12476        bar = '█' * filled_length + '-' * (width - filled_length)
12477        # Print the progress bar with percentage
12478        print(f'\rProgress: |{bar}| {int(100 * progress)}%', end='\r')
12479        # Print a new line when the progress is complete
12480        if current_step == total_steps:
12481            print()
12482
12483    def myread_csv(x, cnms):
12484        """
12485        Reads a CSV file and returns a DataFrame excluding specified columns.
12486
12487        Parameters:
12488        - x (str): File path of the input CSV file describing the blind QC output
12489        - cnms (list): List of column names to exclude from the DataFrame.
12490
12491        Returns:
12492        pd.DataFrame: DataFrame with specified columns excluded.
12493        """
12494        df = pd.read_csv(x)
12495        return df.loc[:, ~df.columns.isin(cnms)]
12496
12497    import warnings
12498    # Warning message for untested function
12499    warnings.warn("Warning: This function is not well tested. Use with caution.")
12500
12501    vmoddict = {}
12502    # Add key-value pairs
12503    vmoddict['imageID'] = 'T1w'
12504    vmoddict['flairid'] = 'T2Flair'
12505    vmoddict['perfid'] = 'perf'
12506    vmoddict['pet3did'] = 'pet3d'
12507    vmoddict['rsfid1'] = 'rsfMRI'
12508#    vmoddict['rsfid2'] = 'rsfMRI'
12509    vmoddict['dtid1'] = 'DTI'
12510#    vmoddict['dtid2'] = 'DTI'
12511    vmoddict['nmid1'] = 'NM2DMT'
12512#    vmoddict['nmid2'] = 'NM2DMT'
12513
12514    # Filter rows where modality is 'T1w'
12515    df = study_df[ study_df['modality'] == 'T1w']
12516    badnames = get_names_from_data_frame( ['Unnamed'], df )
12517    df=df.drop(badnames, axis=1)
12518    # prefilter df for data that exists
12519    keep = np.tile( False, df.shape[0] )
12520    for x in range(df.shape[0]):
12521        myfn = os.path.basename( df['filename'].iloc[x] )
12522        temp = myfn.split( splitsep )
12523        # Generalized search paths
12524        sid0 = str( temp[1] )
12525        sid = str( df[subject_col].iloc[x] )
12526        if sid0 != sid:
12527            warnings.warn("OUTER: the id derived from the filename " + sid0 + " does not match the id stored in the data frame " + sid )
12528            warnings.warn( "filename is : " +  myfn )
12529            warnings.warn( "sid is : " + sid )
12530            warnings.warn( "x is : " + str(x) )
12531        myproj = str(df[project_col].iloc[x])
12532        mydate = str(df[date_col].iloc[x])
12533        myid = str(df[image_col].iloc[x])
12534        if second_split:
12535            myid = myid.split(".")[0]
12536        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12537        hierfn = sorted(glob( path_template + "*" + hiervariable + "*wide.csv" ) )
12538        if len( hierfn ) == 0:
12539            print( hierfn )
12540            print( path_template )
12541            print( myproj )
12542            print( sid )
12543            print( mydate ) 
12544            print( myid )
12545        if len( hierfn ) > 0:
12546            keep[x]=True
12547
12548    # df=df[keep]
12549    if df.shape[0] == 0:
12550        warnings.warn("input data frame shape is filtered down to zero")
12551        return df
12552
12553    if not df.index.is_unique:
12554        warnings.warn("data frame does not have unique indices.  we therefore reset the index to allow the function to continue on." )
12555        df = df.reset_index()
12556
12557    
12558    if verbose:
12559        print( "original input had " + str( df.shape[0] ) + " T1 rows and we find " + str( (keep).sum() ) + " with hierarchical output defined by variable: " + hiervariable )
12560        print( df.shape )
12561
12562    dfout = pd.DataFrame()
12563    myct = 0
12564    for x in range( df.shape[0]):
12565        if verbose:
12566            print("\n\n-------------------------------------------------")
12567            print(f"{x}...")
12568        else:
12569            progress_reporter(x, df.shape[0], width=500)
12570        locind = df.index[x]
12571        myfn = os.path.basename( df['filename'].iloc[x] )
12572        sid = str( df[subject_col].iloc[x] )
12573        tempB = myfn.split( splitsep )
12574        sid0 = str(tempB[1])
12575        if sid0 != sid and verbose:
12576            warnings.warn("INNER: the id derived from the filename " + str(sid0) + " does not match the id stored in the data frame " + str(sid) )
12577            warnings.warn( "filename is : " +  str(myfn) )
12578            warnings.warn( "sid is : " + str(sid) )
12579            warnings.warn( "x is : " + str(x) )
12580            warnings.warn( "index is : " + str(locind) )
12581        myproj = str(df[project_col].iloc[x])
12582        mydate = str(df[date_col].iloc[x])
12583        myid = str(df[image_col].iloc[x])
12584        if second_split:
12585            myid = myid.split(".")[0]
12586        if verbose:
12587            print( myfn )
12588            print( tempB )
12589            print( "id " + sid  )
12590        path_template = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + hiervariable + '/' + str(myid) + "/"
12591        searchhier = path_template + "*" + hiervariable + "*wide.csv"
12592        if verbose:
12593            print( searchhier )
12594        hierfn = sorted( glob( searchhier ) )
12595        if len( hierfn ) > 1:
12596            raise ValueError("there are " + str( len( hierfn ) ) + " hierarchical wide csvs matching search path " + searchhier )
12597        if len( hierfn ) == 1:
12598            hdf=t1df=dtdf=rsdf=perfdf=nmdf=flairdf=None
12599            if verbose:
12600                print(hierfn)
12601            hdf = pd.read_csv(hierfn[0])
12602            if verbose:
12603                print( hdf['vol_hemisphere_lefthemispheres'] )
12604            badnames = get_names_from_data_frame( ['Unnamed'], hdf )
12605            hdf=hdf.drop(badnames, axis=1)
12606            nums = [isinstance(hdf[col].iloc[0], (int, float)) for col in hdf.columns]
12607            corenames = list(np.array(hdf.columns)[nums])
12608            # hdf.loc[:, nums] = hdf.loc[:, nums].add_prefix("T1Hier_")
12609            hdf = hdf.add_prefix("T1Hier_")
12610            myct = myct + 1
12611            dflist = [hdf]
12612
12613            for mymod in vmoddict.keys():
12614                if verbose:
12615                    print("\n\n************************* " + mymod + " *************************")
12616                modalityclass = vmoddict[ mymod ]
12617                if wild_card_modality_id:
12618                    mymodid = '*'
12619                else:
12620                    mymodid = str( df[mymod].iloc[x] )
12621                    if mymodid.lower() != "nan" and mymodid.lower() != "na":
12622                        mymodid = os.path.basename( mymodid )
12623                        mymodid = os.path.splitext( mymodid )[0]
12624                        mymodid = os.path.splitext( mymodid )[0]
12625                        temp = mymodid.split( idsep )
12626                        mymodid = temp[ len( temp )-1 ]
12627                    else:
12628                        if verbose:
12629                            print("missing")
12630                        continue
12631                if verbose:
12632                    print( "modality id is " + mymodid + " for modality " + modalityclass + ' modality specific subj ' + sid + ' modality specific id is ' + myid + " its date " +  mydate )
12633                modalityclasssearch = modalityclass
12634                if modalityclass in ['rsfMRI','DTI']:
12635                    modalityclasssearch=modalityclass+"*"
12636                path_template_m = base_path + "/" + myproj +  "/" + sid + "/" + mydate + '/' + modalityclasssearch + '/' + mymodid + "/"
12637                modsearch = path_template_m + "*" + modalityclasssearch + "*wide.csv"
12638                if verbose:
12639                    print( modsearch )
12640                t1wfn = sorted( glob( modsearch ) )
12641                if len( t1wfn ) > 1:
12642                    nlarge = len(t1wfn)
12643                    t1wfn = find_most_recent_file( t1wfn )
12644                    warnings.warn("there are " + str( nlarge ) + " wide csvs matching search path " + modsearch + "; we take the most recent of these: " + t1wfn[0] )
12645                if len( t1wfn ) == 1:
12646                    if verbose:
12647                        print(t1wfn)
12648                    t1df = myread_csv(t1wfn[0], corenames)
12649                    t1df = filter_df( t1df, modalityclass+'_')
12650                    dflist = dflist + [t1df]
12651                else:
12652                    if verbose:
12653                        print( " cannot find " + modsearch )
12654                
12655            hdf = pd.concat( dflist, axis=1, ignore_index=False)
12656            if verbose:
12657                print( "count: " + str( myct ) )
12658            subdf = df.iloc[[x]]
12659            hdf.index = subdf.index.copy()
12660            subdf = pd.concat( [subdf,hdf], axis=1, ignore_index=False)
12661            dfout = pd.concat( [dfout,subdf], axis=0, ignore_index=False )
12662
12663    if dfout.shape[0] > 0:
12664        badnames = get_names_from_data_frame( ['Unnamed'], dfout )
12665        dfout=dfout.drop(badnames, axis=1)
12666    return dfout

Aggregate ANTsPyMM results from the specified study data frame and store the aggregated results in a new data frame. This assumes data is organized on disk as follows: rootdir/projectID/subjectID/date/outputid/imageid/ where outputid is modality-specific and created by ANTsPyMM processing.

Parameters:

  • study_df (pandas df): pandas data frame, output of generate_mm_dataframe.
  • project_col (str): Name of the column that stores the project ID.
  • subject_col (str): Name of the column that stores subject IDs.
  • date_col (str): Name of the column that stores date information.
  • image_col (str): Name of the column that stores image IDs.
  • base_path (str): Base path for searching for processing outputs of ANTsPyMM.
  • hiervariable (str): the string variable denoting the hierarchical output.
  • splitsep (str): the separator used to split the filename.
  • idsep (str): the separator used to partition subjectid, date and imageid; for example, if idsep is - then we have subjectid-date-imageid.
  • wild_card_modality_id (bool): keep False for safer execution.
  • second_split (bool): a workaround that splits the imageID on '.' and keeps the first part; may be needed when input filenames contain '.'.
  • verbose : boolean

Note: This function is tested under limited circumstances. Use with caution. One particular gotcha is if the imageID is stored as a numeric value in the dataframe but is meant to be a string. E.g. '000' (string) would be interpreted as 0 in the file name glob. This would miss the extant (on disk) csv.

Example usage: agg_df = aggregate_antspymm_results_sdf( studydf, subject_col='subjectID', date_col='date', image_col='imageID', base_path="./Your/Custom/Path/")

Author: Avants and ChatGPT
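
Below is a minimal sketch (with hypothetical project, subject, date and image IDs) of the search path this function builds for each T1 row; the prefilter marks a row as usable when at least one matching wide csv is found, and during aggregation exactly one match is expected:

base_path = "/processed"                     # hypothetical values throughout
path_template = base_path + "/PPMI/12345/20230101/T1wHierarchical/000/"
searchhier = path_template + "*T1wHierarchical*wide.csv"
# sorted(glob(searchhier)) should return exactly one csv for the row to be aggregated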

def study_dataframe_from_matched_dataframe(matched_dataframe, rootdir, outputdir, verbose=False):
1231def study_dataframe_from_matched_dataframe( matched_dataframe, rootdir, outputdir, verbose=False ):
1232    """
1233    Converts one row of the antspymm.match_modalities output dataframe into the study-driving dataframe needed as input to mm_csv.
1234
1235    matched_dataframe : output of antspymm.match_modalities
1236
1237    rootdir : location for the input data root folder (in e.g. NRG format)
1238
1239    outputdir : location for the output data
1240
1241    verbose : boolean
1242    """
1243    iext='.nii.gz'
1244    from os.path import exists
1245    musthavecols = ['projectID', 'subjectID','date','imageID','filename']
1246    for k in range(len(musthavecols)):
1247        if not musthavecols[k] in matched_dataframe.keys():
1248            raise ValueError('matched_dataframe is missing column ' + musthavecols[k] + ' in study_dataframe_from_qc_dataframe' )
1249    csvrow=matched_dataframe.dropna(axis=1)
1250    pid=get_first_item_as_string( csvrow, 'projectID'  )
1251    sid=get_first_item_as_string( csvrow, 'subjectID'  ) # str(csvrow['subjectID'].iloc[0] )
1252    dt=get_first_item_as_string( csvrow, 'date'  )  # str(csvrow['date'].iloc[0])
1253    iid=get_first_item_as_string( csvrow, 'imageID'  ) # str(csvrow['imageID'].iloc[0])
1254    nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
1255    if not exists( nrgt1fn ) and iid == '0':
1256        iid='000'
1257        nrgt1fn=os.path.join( rootdir, pid, sid, dt, 'T1w', iid, str(csvrow['filename'].iloc[0]+iext) )
1258    if not exists( nrgt1fn ):
1259        raise ValueError("T1 " + nrgt1fn + " does not exist in study_dataframe_from_qc_dataframe")
1260    flList=[]
1261    dtList=[]
1262    rsfList=[]
1263    nmList=[]
1264    perfList=[]
1265    if 'flairfn' in csvrow.keys():
1266        flid=get_first_item_as_string( csvrow, 'flairid' )
1267        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
1268        if not exists( nrgt2fn ) and flid == '0':
1269            flid='000'
1270            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'T2Flair', flid, str(csvrow['flairfn'].iloc[0]+iext) )
1271        if verbose:
1272            print("Trying " + nrgt2fn )
1273        if exists( nrgt2fn ):
1274            if verbose:
1275                print("success" )
1276            flList.append( nrgt2fn )
1277    if 'perffn' in csvrow.keys():
1278        flid=get_first_item_as_string( csvrow, 'perfid' )
1279        nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
1280        if not exists( nrgt2fn ) and flid == '0':
1281            flid='000'
1282            nrgt2fn=os.path.join( rootdir, pid, sid, dt, 'perf', flid, str(csvrow['perffn'].iloc[0]+iext) )
1283        if verbose:
1284            print("Trying " + nrgt2fn )
1285        if exists( nrgt2fn ):
1286            if verbose:
1287                print("success" )
1288            perfList.append( nrgt2fn )
1289    if 'dtfn1' in csvrow.keys():
1290        dtid=get_first_item_as_string( csvrow, 'dtid1' )
1291        dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
1292        if len( dtfn1) == 0 :
1293            dtid = '000'
1294            dtfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn1'].iloc[0]+iext) ))
1295        dtfn1=dtfn1[0]
1296        if exists( dtfn1 ):
1297            dtList.append( dtfn1 )
1298    if 'dtfn2' in csvrow.keys():
1299        dtid=get_first_item_as_string( csvrow, 'dtid2' )
1300        dtfn2=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
1301        if len( dtfn2) == 0 :
1302            dtid = '000'
1303            dtfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn2'].iloc[0]+iext) ))
1304        dtfn2=dtfn2[0]
1305        if exists( dtfn2 ):
1306            dtList.append( dtfn2 )
1307    if 'dtfn3' in csvrow.keys():
1308        dtid=get_first_item_as_string( csvrow, 'dtid3' )
1309        dtfn3=glob.glob(os.path.join(rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
1310        if len( dtfn3) == 0 :
1311            dtid = '000'
1312            dtfn3=glob.glob(os.path.join( rootdir, pid, sid, dt, 'DTI*', dtid, str(csvrow['dtfn3'].iloc[0]+iext) ))
1313        dtfn3=dtfn3[0]
1314        if exists( dtfn3 ):
1315            dtList.append( dtfn3 )
1316    if 'rsffn1' in csvrow.keys():
1317        rsid=get_first_item_as_string( csvrow, 'rsfid1' )
1318        rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
1319        if len( rsfn1 ) == 0 :
1320            rsid = '000'
1321            rsfn1=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn1'].iloc[0]+iext) ))
1322        rsfn1=rsfn1[0]
1323        if exists( rsfn1 ):
1324            rsfList.append( rsfn1 )
1325    if 'rsffn2' in csvrow.keys():
1326        rsid=get_first_item_as_string( csvrow, 'rsfid2' )
1327        rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))
1328        if len( rsfn2 ) == 0 :
1329            rsid = '000'
1330            rsfn2=glob.glob(os.path.join( rootdir, pid, sid, dt, 'rsfMRI*', rsid, str(csvrow['rsffn2'].iloc[0]+iext) ))
1331        rsfn2=rsfn2[0]
1332        if exists( rsfn2 ):
1333            rsfList.append( rsfn2 )
1334    for j in range(11):
1335        keyname="nmfn"+str(j)
1336        keynameid="nmid"+str(j)
1337        if keyname in csvrow.keys() and keynameid in csvrow.keys():
1338            nmid=get_first_item_as_string( csvrow, keynameid )
1339            nmsearchpath=os.path.join( rootdir, pid, sid, dt, 'NM2DMT', nmid, "*"+nmid+iext)
1340            nmfn=glob.glob( nmsearchpath )
1341            nmfn=nmfn[0]
1342            if exists( nmfn ):
1343                nmList.append( nmfn )
1344    if verbose:
1345        print("assembled the image lists mapping to ....")
1346        print(nrgt1fn)
1347        print("NM")
1348        print(nmList)
1349        print("FLAIR")
1350        print(flList)
1351        print("DTI")
1352        print(dtList)
1353        print("rsfMRI")
1354        print(rsfList)
1355        print("perf")
1356        print(perfList)
1357    studycsv = generate_mm_dataframe(
1358        pid,
1359        sid,
1360        dt,
1361        iid, # the T1 id
1362        'T1w',
1363        rootdir,
1364        outputdir,
1365        t1_filename=nrgt1fn,
1366        flair_filename=flList,
1367        dti_filenames=dtList,
1368        rsf_filenames=rsfList,
1369        nm_filenames=nmList,
1370        perf_filename=perfList)
1371    return studycsv.dropna(axis=1)

Converts one row of the antspymm.match_modalities output dataframe into the study-driving dataframe needed as input to mm_csv.

matched_dataframe : output of antspymm.match_modalities

rootdir : location for the input data root folder (in e.g. NRG format)

outputdir : location for the output data

verbose : boolean
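
As a hedged usage sketch (the matched csv and file paths are hypothetical), a single matched row can be converted and then handed to mm_csv:

import antspymm
import pandas as pd
matched = pd.read_csv("matched_modalities.csv")   # hypothetical output of antspymm.match_modalities
row = matched.iloc[[0]]                           # one row, as expected by this function
studycsv = antspymm.study_dataframe_from_matched_dataframe(
    row, rootdir="/data/NRG/", outputdir="/data/processed/", verbose=True)
# studycsv is the study-driving dataframe expected by antspymm.mm_csv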

def merge_wides_to_study_dataframe( sdf, processing_dir, separator='-', sid_is_int=True, id_is_int=True, date_is_int=True, report_missing=False, progress=False, verbose=False):
9893def merge_wides_to_study_dataframe( sdf, processing_dir, separator='-', sid_is_int=True, id_is_int=True, date_is_int=True, report_missing=False,
9894progress=False, verbose=False ):
9895    """
9896    extend a study data frame with wide outputs
9897
9898    sdf : the input study dataframe from antspymm QC output
9899
9900    processing_dir:  the directory location of the processed data 
9901
9902    separator : string usually '-' or '_'
9903
9904    sid_is_int : boolean set to True to cast unique subject ids to int; can be useful if they are inadvertently stored as float by pandas
9905
9906    date_is_int : boolean set to True to cast date to int; can be useful if they are inadvertently stored as float by pandas
9907
9908    id_is_int : boolean set to True to cast unique image ids to int; can be useful if they are inadvertently stored as float by pandas
9909
9910    report_missing : boolean combined with verbose will report missing modalities
9911
9912    progress : integer; report percent progress whenever the row index is divisible by this value (0 disables reporting)
9913
9914    verbose : boolean
9915    """
9916    from os.path import exists
9917    musthavecols = ['projectID', 'subjectID','date','imageID']
9918    for k in range(len(musthavecols)):
9919        if not musthavecols[k] in sdf.keys():
9920            raise ValueError('sdf is missing column ' +musthavecols[k] + ' in merge_wides_to_study_dataframe' )
9921    possible_iids = [ 'imageID', 'imageID', 'imageID', 'flairid', 'dtid1', 'dtid2', 'rsfid1', 'rsfid2', 'nmid1', 'nmid2', 'nmid3', 'nmid4', 'nmid5', 'nmid6', 'nmid7', 'nmid8', 'nmid9', 'nmid10', 'perfid' ]
9922    modality_ids = [ 'T1wHierarchical', 'T1wHierarchicalSR', 'T1w', 'T2Flair', 'DTI', 'DTI', 'rsfMRI', 'rsfMRI', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'NM2DMT', 'perf']
9923    alldf=pd.DataFrame()
9924    for myk in sdf.index:
9925        if progress > 0 and int(myk) % int(progress) == 0:
9926            print( str( round( myk/sdf.shape[0]*100.0)) + "%...", end='', flush=True)
9927        if verbose:
9928            print( "DOROW " + str(myk) + ' of ' + str( sdf.shape[0] ) )
9929        csvrow = sdf.loc[sdf.index == myk].dropna(axis=1)
9930        ct=-1
9931        for iidkey in possible_iids:
9932            ct=ct+1
9933            mod_name = modality_ids[ct]
9934            if iidkey in csvrow.keys():
9935                if id_is_int:
9936                    iid = str( int( csvrow[iidkey].iloc[0] ) )
9937                else:
9938                    iid = str( csvrow[iidkey].iloc[0] )
9939                if verbose:
9940                    print( "iidkey " + iidkey + " modality " + mod_name + ' iid '+ iid )
9941                pid=str(csvrow['projectID'].iloc[0] )
9942                if sid_is_int:
9943                    sid=str(int(csvrow['subjectID'].iloc[0] ))
9944                else:
9945                    sid=str(csvrow['subjectID'].iloc[0] )
9946                if date_is_int:
9947                    dt=str(int(csvrow['date'].iloc[0]))
9948                else:
9949                    dt=str(csvrow['date'].iloc[0])
9950                if id_is_int:
9951                    t1iid=str(int(csvrow['imageID'].iloc[0]))
9952                else:
9953                    t1iid=str(csvrow['imageID'].iloc[0])
9954                if t1iid != iid:
9955                    iidj=iid+"_"+t1iid
9956                else:
9957                    iidj=iid
9958                rootid = pid +separator+ sid +separator+dt+separator+mod_name+separator+iidj
9959                myext = rootid +separator+'mmwide.csv'
9960                nrgwidefn=os.path.join( processing_dir, pid, sid, dt, mod_name, iid, myext )
9961                moddersub = mod_name
9962                is_t1=False
9963                if mod_name == 'T1wHierarchical':
9964                    is_t1=True
9965                    moddersub='T1Hier'
9966                elif mod_name == 'T1wHierarchicalSR':
9967                    is_t1=True
9968                    moddersub='T1HSR'
9969                if exists( nrgwidefn ):
9970                    if verbose:
9971                        print( nrgwidefn + " exists")
9972                    mm=read_mm_csv( nrgwidefn, colprefix=moddersub+'_', is_t1=is_t1, separator=separator, verbose=verbose )
9973                    if mm is not None:
9974                        if mod_name == 'T1wHierarchical':
9975                            a=list( csvrow.keys() )
9976                            b=list( mm.keys() )
9977                            abintersect=list(set(b).intersection( set(a) ) )
9978                            if len( abintersect  ) > 0 :
9979                                for qq in abintersect:
9980                                    mm.pop( qq )
9981                        # mm.index=csvrow.index
9982                        uidname = mod_name + '_mmwide_filename'
9983                        mm[ uidname ] = rootid
9984                        csvrow=pd.concat( [csvrow,mm], axis=1, ignore_index=False )
9985                else:
9986                    if verbose and report_missing:
9987                        print( nrgwidefn + " absent")
9988        if alldf.shape[0] == 0:
9989            alldf = csvrow.copy()
9990            alldf = alldf.loc[:,~alldf.columns.duplicated()]
9991        else:
9992            csvrow=csvrow.loc[:,~csvrow.columns.duplicated()]
9993            alldf = alldf.loc[:,~alldf.columns.duplicated()]
9994            alldf = pd.concat( [alldf, csvrow], axis=0, ignore_index=True )
9995    return alldf

extend a study data frame with wide outputs

sdf : the input study dataframe from antspymm QC output

processing_dir: the directory location of the processed data

separator : string usually '-' or '_'

sid_is_int : boolean set to True to cast unique subject ids to int; can be useful if they are inadvertently stored as float by pandas

date_is_int : boolean set to True to cast date to int; can be useful if they are inadvertently stored as float by pandas

id_is_int : boolean set to True to cast unique image ids to int; can be useful if they are inadvertently stored as float by pandas

report_missing : boolean combined with verbose will report missing modalities

progress : integer; report percent progress whenever the row index is divisible by this value (0 disables reporting)

verbose : boolean
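
A minimal usage sketch, assuming a QC-derived study dataframe saved as csv and a processing directory in NRG layout (both hypothetical):

import antspymm
import pandas as pd
sdf = pd.read_csv("study_qc.csv")                 # hypothetical antspymm QC study dataframe
wide = antspymm.merge_wides_to_study_dataframe(
    sdf, processing_dir="/data/processed/", separator='-',
    sid_is_int=False, id_is_int=False, date_is_int=False,
    report_missing=True, progress=10, verbose=False)
wide.to_csv("study_with_wides.csv", index=False)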

def filter_image_files(image_paths, criteria='largest'):
12705def filter_image_files(image_paths, criteria='largest'):
12706    """
12707    Filters a list of image file paths based on specified criteria and returns 
12708    the path of the image that best matches that criteria (smallest, largest, or brightest).
12709
12710    Args:
12711    image_paths (list): A list of file paths to the images.
12712    criteria (str): Criteria for selecting the image ('smallest', 'largest', 'brightest').
12713
12714    Returns:
12715    str: The file path of the selected image, or None if no valid images are found.
12716    """
12717    import numpy as np
12718    if not image_paths:
12719        return None
12720
12721    selected_image_path = None
12722    if criteria == 'smallest' or criteria == 'largest':
12723        extreme_volume = None
12724
12725        for path in image_paths:
12726            try:
12727                image = ants.image_read(path)
12728                volume = np.prod(image.shape)
12729
12730                if criteria == 'largest':
12731                    if extreme_volume is None or volume > extreme_volume:
12732                        extreme_volume = volume
12733                        selected_image_path = path
12734                elif criteria == 'smallest':
12735                    if extreme_volume is None or volume < extreme_volume:
12736                        extreme_volume = volume
12737                        selected_image_path = path
12738
12739            except Exception as e:
12740                print(f"Error processing image {path}: {e}")
12741
12742    elif criteria == 'brightest':
12743        max_brightness = None
12744
12745        for path in image_paths:
12746            try:
12747                image = ants.image_read(path)
12748                brightness = np.mean(image.numpy())
12749
12750                if max_brightness is None or brightness > max_brightness:
12751                    max_brightness = brightness
12752                    selected_image_path = path
12753
12754            except Exception as e:
12755                print(f"Error processing image {path}: {e}")
12756
12757    else:
12758        raise ValueError("Criteria must be 'smallest', 'largest', or 'brightest'.")
12759
12760    return selected_image_path

Filters a list of image file paths based on specified criteria and returns the path of the image that best matches that criteria (smallest, largest, or brightest).

Parameters
  • image_paths (list): A list of file paths to the images.
  • criteria (str): Criteria for selecting the image ('smallest', 'largest', 'brightest').
Returns

str: The file path of the selected image, or None if no valid images are found.
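
A short usage sketch (the glob pattern is hypothetical); the function reads each candidate with ANTs and returns the single best path:

import glob
candidates = glob.glob("/data/NRG/PPMI/12345/20230101/T1w/*/*.nii.gz")  # hypothetical paths
best_t1 = filter_image_files(candidates, criteria='largest')
if best_t1 is not None:
    print("selected:", best_t1)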

def docsamson( locmod, studycsv, outputdir, projid, sid, dtid, mysep, t1iid=None, verbose=True):
522def docsamson(locmod, studycsv, outputdir, projid, sid, dtid, mysep, t1iid=None, verbose=True):
523    """
524    Processes image file names based on the specified imaging modality and other parameters.
525
526    The function selects file names from the provided dictionary `studycsv` based on the imaging modality.
527    It supports various modalities like T1w, T2Flair, perf, NM2DMT, rsfMRI, DTI, and configures the filenames accordingly.
528    The function can optionally print verbose output during processing.
529
530    Parameters:
531    locmod (str): The imaging modality. Options include 'T1w', 'T2Flair', 'perf', 'NM2DMT', 'rsfMRI', 'DTI'.
532    studycsv (dict): A dictionary with keys corresponding to imaging modalities and values as file names.
533    outputdir (str): Base directory for output files.
534    projid (str): Project identifier.
535    sid (str): Subject identifier.
536    dtid (str): Data acquisition time identifier.
537    mysep (str): Separator used in file naming.
538    t1iid (str, optional): Identifier related to T1-weighted images, used in naming output files when locmod is not 'T1w'.
539    verbose (bool, optional): If True, prints detailed information during execution.
540
541    Returns:
542    dict: A dictionary with keys 'modality', 'outprefix', and 'images'.
543        - 'modality' (str): The imaging modality used.
544        - 'outprefix' (str): The prefix for output file paths.
545        - 'images' (list): A list of processed image file names.
546
547    Notes:
548    - The function is designed to work within a specific workflow and might require adaptation for general use.
549
550    Examples:
551    >>> result = docsamson('T1w', studycsv, outputdir, projid, sid, dtid, mysep)
552    >>> print(result['modality'])
553    'T1w'
554    >>> print(result['outprefix'])
555    '/path/to/output/directory/T1w/some_identifier'
556    >>> print(result['images'])
557    ['image1.nii', 'image2.nii']
558    """
559
560    import os
561    import re
562
563    myimgsInput = []
564    myoutputPrefix = None
565    imfns = ['filename', 'rsfid1', 'rsfid2', 'dtid1', 'dtid2', 'flairid']
566    
567    # Define image file names based on the modality
568    if locmod == 'T1w':
569        imfns=['filename']
570    elif locmod == 'T2Flair':
571        imfns=['flairid']
572    elif locmod == 'perf':
573        imfns=['perfid']
574    elif locmod == 'pet3d':
575        imfns=['pet3did']
576    elif locmod == 'NM2DMT':
577        imfns=[]
578        for i in range(11):
579            imfns.append('nmid' + str(i))
580    elif locmod == 'rsfMRI':
581        imfns=[]
582        for i in range(4):
583            imfns.append('rsfid' + str(i))
584    elif locmod == 'DTI':
585        imfns=[]
586        for i in range(4):
587            imfns.append('dtid' + str(i))
588    else:
589        raise ValueError("docsamson: no match of modality to filename id " + locmod )
590
591    # Process each file name
592    for i in imfns:
593        if verbose:
594            print(i + " " + locmod)
595        if i in studycsv.keys():
596            fni = str(studycsv[i].iloc[0])
597            if verbose:
598                print(i + " " + fni + ' exists ' + str(os.path.exists(fni)))
599            if os.path.exists(fni):
600                myimgsInput.append(fni)
601                temp = os.path.basename(fni)
602                mysplit = temp.split(mysep)
603                iid = re.sub(".nii.gz", "", mysplit[-1])
604                iid = re.sub(".mha", "", iid)
605                iid = re.sub(".nii", "", iid)
606                iid2 = iid
607                if locmod != 'T1w' and t1iid is not None:
608                    iid2 = iid + "_" + t1iid
609                else:
610                    iid2 = t1iid
611                myoutputPrefix = os.path.join(outputdir, projid, sid, dtid, locmod, iid, projid + mysep + sid + mysep + dtid + mysep + locmod + mysep + iid2)
612    
613    if verbose:
614        print(locmod)
615        print(myimgsInput)
616        print(myoutputPrefix)
617    
618    return {
619        'modality': locmod,
620        'outprefix': myoutputPrefix,
621        'images': myimgsInput
622    }

Processes image file names based on the specified imaging modality and other parameters.

The function selects file names from the provided dictionary studycsv based on the imaging modality. It supports various modalities like T1w, T2Flair, perf, NM2DMT, rsfMRI, DTI, and configures the filenames accordingly. The function can optionally print verbose output during processing.

Parameters
  • locmod (str): The imaging modality. Options include 'T1w', 'T2Flair', 'perf', 'NM2DMT', 'rsfMRI', 'DTI'.
  • studycsv (dict): A dictionary with keys corresponding to imaging modalities and values as file names.
  • outputdir (str): Base directory for output files.
  • projid (str): Project identifier.
  • sid (str): Subject identifier.
  • dtid (str): Data acquisition time identifier.
  • mysep (str): Separator used in file naming.
  • t1iid (str, optional): Identifier related to T1-weighted images, used in naming output files when locmod is not 'T1w'.
  • verbose (bool, optional): If True, prints detailed information during execution.
Returns

dict: A dictionary with the following keys:
  • 'modality' (str): The imaging modality used.
  • 'outprefix' (str): The prefix for output file paths.
  • 'images' (list): A list of processed image file names.

Notes:

  • The function is designed to work within a specific workflow and might require adaptation for general use.

Examples:

>>> result = docsamson('T1w', studycsv, outputdir, projid, sid, dtid, mysep)
>>> print(result['modality'])
'T1w'
>>> print(result['outprefix'])
'/path/to/output/directory/T1w/some_identifier'
>>> print(result['images'])
['image1.nii', 'image2.nii']

def enantiomorphic_filling_without_mask(image, axis=0, intensity='low'):
12668def enantiomorphic_filling_without_mask( image, axis=0, intensity='low' ):
12669    """
12670    Perform an enantiomorphic lesion filling on an image without a lesion mask.
12671
12672    Args:
12673    image (antsImage): The ants image to flip and fill
12674    axis ( int ): the axis along which to reflect the image
12675    intensity ( str ) : low or high
12676
12677    Returns:
12678    tuple: the filled image (ants.ANTsImage) and the Otsu difference segmentation (ants.ANTsImage).
12679    """
12680    imagen = ants.iMath( image, 'Normalize' )
12681    imagen = ants.iMath( imagen, "TruncateIntensity", 1e-6, 0.98 )
12682    imagen = ants.iMath( imagen, 'Normalize' )
12683    # Create a mirror image (flipping left and right)
12684    mirror_image = ants.reflect_image(imagen, axis=axis, tx='antsRegistrationSyNQuickRepro[s]' )['warpedmovout']
12685
12686    # Create a symmetric version of the image by averaging the original and the mirror image
12687    symmetric_image = imagen * 0.5 + mirror_image * 0.5
12688
12689    # Identify potential lesion areas by finding differences between the original and symmetric image
12690    difference_image = image - symmetric_image
12691    diffseg = ants.threshold_image(difference_image, "Otsu", 3 )
12692    if intensity == 'low':
12693        likely_lesion = ants.threshold_image( diffseg, 1,  1)
12694    else:
12695        likely_lesion = ants.threshold_image( diffseg, 3,  3)
12696    likely_lesion = ants.smooth_image( likely_lesion, 3.0 ).iMath("Normalize")
12697    lesionneg = ( imagen*0+1.0 ) - likely_lesion
12698    filled_image = ants.image_clone(imagen)    
12699    filled_image = imagen * lesionneg + mirror_image * likely_lesion
12700
12701    return filled_image, diffseg

Perform an enantiomorphic lesion filling on an image without a lesion mask.

Parameters
  • image (antsImage): The ants image to flip and fill.
  • axis (int): the axis along which to reflect the image.
  • intensity (str): 'low' or 'high'.
Returns

tuple: the filled image (ants.ANTsImage) and the Otsu difference segmentation (ants.ANTsImage).
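
A minimal usage sketch (the filenames are hypothetical); note that the function returns both the filled image and the difference segmentation:

import ants
t1 = ants.image_read("t1_with_lesion.nii.gz")     # hypothetical input
filled, diffseg = enantiomorphic_filling_without_mask(t1, axis=0, intensity='low')
ants.image_write(filled, "t1_lesion_filled.nii.gz")
ants.image_write(diffseg, "t1_asymmetry_segmentation.nii.gz")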

def wmh( flair, t1, t1seg, mmfromconvexhull=3.0, strict=True, probability_mask=None, prior_probability=None, model='sysu', verbose=False):
10988def wmh( flair, t1, t1seg,
10989    mmfromconvexhull = 3.0,
10990    strict=True,
10991    probability_mask=None,
10992    prior_probability=None,
10993    model='sysu',
10994    verbose=False ) :
10995    """
10996    Outputs the WMH probability mask and a summary single measurement
10997
10998    Arguments
10999    ---------
11000    flair : ANTsImage
11001        input 3-D FLAIR brain image (not skull-stripped).
11002
11003    t1 : ANTsImage
11004        input 3-D T1 brain image (not skull-stripped).
11005
11006    t1seg : ANTsImage
11007        T1 segmentation image
11008
11009    mmfromconvexhull : float
11010        restrict WMH to regions that are WM or mmfromconvexhull mm away from the
11011        convex hull of the cerebrum.   we choose a default value based on
11012        Figure 4 from:
11013        https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6240579/pdf/fnagi-10-00339.pdf
11014
11015    strict: boolean - if True, only use convex hull distance
11016
11017    probability_mask : None - use to compute wmh just once - then this function
11018        just does refinement and summary
11019
11020    prior_probability : optional prior probability image in space of the input t1
11021
11022    model : either sysu or hyper
11023
11024    verbose : boolean
11025
11026    Returns
11027    ---------
11028    WMH probability map and a summary single measurement which is the sum of the WMH map
11029
11030    """
11031    import numpy as np
11032    import math
11033    t1_2_flair_reg = ants.registration(flair, t1, type_of_transform = 'antsRegistrationSyNRepro[r]') # Register T1 to Flair
11034    if probability_mask is None and model == 'sysu':
11035        if verbose:
11036            print('sysu')
11037        probability_mask = antspynet.sysu_media_wmh_segmentation( flair )
11038    elif probability_mask is None and model == 'hyper':
11039        if verbose:
11040            print('hyper')
11041        probability_mask = antspynet.hypermapp3r_segmentation( t1_2_flair_reg['warpedmovout'], flair )
11042    # t1_2_flair_reg = tra_initializer( flair, t1, n_simulations=4, max_rotation=5, transform=['rigid'], verbose=False )
11043    prior_probability_flair = None
11044    if prior_probability is not None:
11045        prior_probability_flair = ants.apply_transforms( flair, prior_probability,
11046            t1_2_flair_reg['fwdtransforms'] )
11047    wmseg_mask = ants.threshold_image( t1seg,
11048        low_thresh = 3, high_thresh = 3).iMath("FillHoles")
11049    wmseg_mask_use = ants.image_clone( wmseg_mask )
11050    distmask = None
11051    if mmfromconvexhull > 0:
11052            convexhull = ants.threshold_image( t1seg, 1, 4 )
11053            spc2vox = np.prod( ants.get_spacing( t1seg ) )
11054            voxdist = 0.0
11055            myspc = ants.get_spacing( t1seg )
11056            for k in range( t1seg.dimension ):
11057                voxdist = voxdist + myspc[k] * myspc[k]
11058            voxdist = math.sqrt( voxdist )
11059            nmorph = round( 2.0 / voxdist )
11060            convexhull = ants.morphology( convexhull, "close", nmorph ).iMath("FillHoles")
11061            dist = ants.iMath( convexhull, "MaurerDistance" ) * -1.0
11062            distmask = ants.threshold_image( dist, mmfromconvexhull, 1.e80 )
11063            wmseg_mask = wmseg_mask + distmask
11064            if strict:
11065                wmseg_mask_use = ants.threshold_image( wmseg_mask, 2, 2 )
11066            else:
11067                wmseg_mask_use = ants.threshold_image( wmseg_mask, 1, 2 )
11068    ##############################################################################
11069    wmseg_2_flair = ants.apply_transforms(flair, wmseg_mask_use,
11070        transformlist = t1_2_flair_reg['fwdtransforms'],
11071        interpolator = 'nearestNeighbor' )
11072    seg_2_flair = ants.apply_transforms(flair, t1seg,
11073        transformlist = t1_2_flair_reg['fwdtransforms'],
11074        interpolator = 'nearestNeighbor' )
11075    csfmask = ants.threshold_image(seg_2_flair,1,1)
11076    flairsnr = mask_snr( flair, csfmask, wmseg_2_flair, bias_correct = False )
11077    probability_mask_WM = wmseg_2_flair * probability_mask # Remove WMH signal outside of WM
11078    wmh_sum = np.prod( ants.get_spacing( flair ) ) * probability_mask_WM.sum()
11079    wmh_sum_prior = math.nan
11080    probability_mask_posterior = None
11081    if prior_probability_flair is not None:
11082        probability_mask_posterior = prior_probability_flair * probability_mask # use prior
11083        wmh_sum_prior = np.prod( ants.get_spacing(flair) ) * probability_mask_posterior.sum()
11084    if math.isnan( wmh_sum ):
11085        wmh_sum=0
11086    if math.isnan( wmh_sum_prior ):
11087        wmh_sum_prior=0
11088    flair_evr = antspyt1w.patch_eigenvalue_ratio( flair, 512, [16,16,16], evdepth = 0.9, mask=wmseg_2_flair )
11089    return{
11090        'WMH_probability_map_raw': probability_mask,
11091        'WMH_probability_map' : probability_mask_WM,
11092        'WMH_posterior_probability_map' : probability_mask_posterior,
11093        'wmh_mass': wmh_sum,
11094        'wmh_mass_prior': wmh_sum_prior,
11095        'wmh_evr' : flair_evr,
11096        'wmh_SNR' : flairsnr,
11097        'convexhull_mask': distmask }

Outputs the WMH probability mask and a summary single measurement

Arguments

flair : ANTsImage input 3-D FLAIR brain image (not skull-stripped).

t1 : ANTsImage input 3-D T1 brain image (not skull-stripped).

t1seg : ANTsImage T1 segmentation image

mmfromconvexhull : float restrict WMH to regions that are WM or mmfromconvexhull mm away from the convex hull of the cerebrum. we choose a default value based on Figure 4 from: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6240579/pdf/fnagi-10-00339.pdf

strict: boolean - if True, only use convex hull distance

probability_mask : None - use to compute wmh just once - then this function just does refinement and summary

prior_probability : optional prior probability image in space of the input t1

model : either sysu or hyper

verbose : boolean

Returns

WMH probability map and a summary single measurement which is the sum of the WMH map
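
A minimal usage sketch, assuming hypothetical filenames and a tissue segmentation whose labels follow the convention implied by the code above (CSF=1, WM=3):

import ants
flair = ants.image_read("flair.nii.gz")            # hypothetical inputs
t1 = ants.image_read("t1.nii.gz")
t1seg = ants.image_read("t1_segmentation.nii.gz")  # tissue labels, e.g. CSF=1 ... WM=3
wmhres = wmh(flair, t1, t1seg, mmfromconvexhull=3.0, model='sysu', verbose=True)
print("WMH mass:", wmhres['wmh_mass'])
ants.image_write(wmhres['WMH_probability_map'], "wmh_probability.nii.gz")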

def remove_elements_from_numpy_array(original_array, indices_to_remove):
11140def remove_elements_from_numpy_array(original_array, indices_to_remove):
11141    """
11142    Remove specified elements or rows from a numpy array.
11143
11144    Parameters:
11145    original_array (numpy.ndarray): A numpy array from which elements or rows are to be removed.
11146    indices_to_remove (list or numpy.ndarray): Indices of elements or rows to be removed.
11147
11148    Returns:
11149    numpy.ndarray: A new numpy array with the specified elements or rows removed. If the input array is None,
11150                   the function returns None.
11151    """
11152
11153    if original_array is None:
11154        return None
11155
11156    if original_array.ndim == 1:
11157        # Remove elements from a 1D array
11158        return np.delete(original_array, indices_to_remove)
11159    elif original_array.ndim == 2:
11160        # Remove rows from a 2D array
11161        return np.delete(original_array, indices_to_remove, axis=0)
11162    else:
11163        raise ValueError("original_array must be either 1D or 2D.")

Remove specified elements or rows from a numpy array.

Parameters
  • original_array (numpy.ndarray): A numpy array from which elements or rows are to be removed.
  • indices_to_remove (list or numpy.ndarray): Indices of elements or rows to be removed.
Returns

numpy.ndarray: A new numpy array with the specified elements or rows removed. If the input array is None, the function returns None.
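
For illustration, this is how one might drop the same indices from a 1D bval array and the matching rows of a 2D bvec array (the values are made up):

import numpy as np
bvals = np.array([0, 1000, 1000, 2000, 1000])
bvecs = np.random.rand(5, 3)
bvals_kept = remove_elements_from_numpy_array(bvals, [0, 3])  # drops elements 0 and 3
bvecs_kept = remove_elements_from_numpy_array(bvecs, [0, 3])  # drops rows 0 and 3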

def score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg):
11466def score_fmri_censoring(cbfts, csf_seg, gm_seg, wm_seg ):
11467    """
11468    Process CBF time series to remove high-leverage points.
11469    Derived from the SCORE algorithm by Sudipto Dolui et al.
11470
11471    Parameters:
11472    cbfts (ANTsImage): 4D ANTsImage of CBF time series.
11473    csf_seg (ANTsImage): CSF binary map.
11474    gm_seg (ANTsImage): Gray matter binary map.
11475    wm_seg (ANTsImage): WM binary map.
11476
11477    Returns:
11478    ANTsImage: Processed CBF time series.
11479    ndarray: Index of removed volumes.
11480    """
11481    
11482    n_gm_voxels = np.sum(gm_seg.numpy()) - 1
11483    n_wm_voxels = np.sum(wm_seg.numpy()) - 1
11484    n_csf_voxels = np.sum(csf_seg.numpy()) - 1
11485    mask1img = gm_seg + wm_seg + csf_seg
11486    mask1 = (mask1img==1).numpy()
11487    
11488    cbfts_np = cbfts.numpy()
11489    gmbool = (gm_seg==1).numpy()
11490    csfbool = (csf_seg==1).numpy()
11491    wmbool = (wm_seg==1).numpy()
11492    gm_cbf_ts = ants.timeseries_to_matrix( cbfts, gm_seg )
11493    gm_cbf_ts = np.squeeze(np.mean(gm_cbf_ts, axis=1))
11494    
11495    median_gm_cbf = np.median(gm_cbf_ts)
11496    mad_gm_cbf = np.median(np.abs(gm_cbf_ts - median_gm_cbf)) / 0.675
11497    indx = np.abs(gm_cbf_ts - median_gm_cbf) > (2.5 * mad_gm_cbf)
11498    
11499    # the spatial mean
11500    spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11501    spatmean = ants.from_numpy( spatmeannp )
11502    V = (
11503        n_gm_voxels * np.var(spatmeannp[gmbool])
11504        + n_wm_voxels * np.var(spatmeannp[wmbool])
11505        + n_csf_voxels * np.var(spatmeannp[csfbool])
11506    )
11507    V1 = math.inf
11508    ct=0
11509    while V < V1:
11510        ct=ct+1
11511        V1 = V
11512        CC = np.zeros(cbfts_np.shape[3])
11513        for s in range(cbfts_np.shape[3]):
11514            if indx[s]:
11515                continue
11516            tmp1 = ants.from_numpy( cbfts_np[:, :, :, s] )
11517            CC[s] = ants.image_similarity( spatmean, tmp1, metric_type='Correlation', fixed_mask=mask1img )
11518        inx = np.argmin(CC)
11519        indx[inx] = True
11520        spatmeannp = np.mean(cbfts_np[:, :, :, ~indx], axis=3)
11521        spatmean = ants.from_numpy( spatmeannp )
11522        V = (
11523          n_gm_voxels * np.var(spatmeannp[gmbool]) + 
11524          n_wm_voxels * np.var(spatmeannp[wmbool]) + 
11525          n_csf_voxels * np.var(spatmeannp[csfbool])
11526        )
11527    cbfts_recon = cbfts_np[:, :, :, ~indx]
11528    cbfts_recon = np.nan_to_num(cbfts_recon)
11529    cbfts_recon_ants = ants.from_numpy(cbfts_recon)
11530    cbfts_recon_ants = ants.copy_image_info(cbfts, cbfts_recon_ants)
11531    return cbfts_recon_ants, indx

Process CBF time series to remove high-leverage points. Derived from the SCORE algorithm by Sudipto Dolui et al.

Parameters
  • cbfts (ANTsImage): 4D ANTsImage of CBF time series.
  • csf_seg (ANTsImage): CSF binary map.
  • gm_seg (ANTsImage): Gray matter binary map.
  • wm_seg (ANTsImage): WM binary map.
Returns

ANTsImage: Processed CBF time series.
ndarray: Index of removed volumes.
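
A hedged usage sketch with hypothetical filenames; the second return value is a boolean index over the original volumes:

import ants
import numpy as np
cbf_ts = ants.image_read("cbf_timeseries.nii.gz")  # hypothetical 4D CBF series
csf = ants.image_read("csf_mask.nii.gz")
gm = ants.image_read("gm_mask.nii.gz")
wm = ants.image_read("wm_mask.nii.gz")
cbf_censored, removed = score_fmri_censoring(cbf_ts, csf, gm, wm)
print("volumes flagged for removal:", np.where(removed)[0])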

def remove_volumes_from_timeseries(time_series, volumes_to_remove):
11165def remove_volumes_from_timeseries(time_series, volumes_to_remove):
11166    """
11167    Remove specified volumes from a time series.
11168
11169    :param time_series: ANTsImage representing the time series (4D image).
11170    :param volumes_to_remove: List of volume indices to remove.
11171    :return: ANTsImage with specified volumes removed.
11172    """
11173    if not isinstance(time_series, ants.core.ants_image.ANTsImage):
11174        raise ValueError("time_series must be an ANTsImage.")
11175
11176    if time_series.dimension != 4:
11177        raise ValueError("time_series must be a 4D image.")
11178
11179    # Create a boolean index for volumes to keep
11180    volumes_to_keep = [i for i in range(time_series.shape[3]) if i not in volumes_to_remove]
11181
11182    # Select the volumes to keep
11183    filtered_time_series = ants.from_numpy( time_series.numpy()[..., volumes_to_keep] )
11184
11185    return ants.copy_image_info( time_series, filtered_time_series )

Remove specified volumes from a time series.

Parameters
  • time_series: ANTsImage representing the time series (4D image).
  • volumes_to_remove: List of volume indices to remove.
Returns

ANTsImage with specified volumes removed.
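
A short usage sketch (filename hypothetical) dropping the first four volumes of a 4D series:

import ants
bold = ants.image_read("rsfmri.nii.gz")                       # hypothetical 4D time series
trimmed = remove_volumes_from_timeseries(bold, [0, 1, 2, 3])  # drop the first four volumes
print(bold.shape[3], "->", trimmed.shape[3])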

def loop_timeseries_censoring( x, threshold=0.5, mask=None, n_features_sample=0.02, seed=42, verbose=True):
11533def loop_timeseries_censoring(x, threshold=0.5, mask=None, n_features_sample=0.02, seed=42, verbose=True):
11534    """
11535    Censor high leverage volumes from a time series using Local Outlier Probabilities (LoOP).
11536
11537    Parameters:
11538    x (ANTsImage): A 4D time series image.
11539    threshold (float): Threshold for determining high leverage volumes based on LoOP scores.
11540    mask (antsImage): restricts to a ROI
11541    n_features_sample (int/float): feature sample size, default 0.02; if less than one this is interpreted as a fraction of the total features, otherwise it sets the number of features to be used
11542    seed (int): random seed
11543    verbose (bool)
11544
11545    Returns:
11546    tuple: A tuple containing the censored time series (ANTsImage) and the indices of the high leverage volumes.
11547    """
11548    import warnings
11549    if x.shape[3] < 20: # just a guess at what we need here ...
11550        warnings.warn("Warning: the time dimension is < 20 - too few samples for loop. just return the original data.")
11551        return x, []
11552    if mask is None:
11553        flattened_series = flatten_time_series(x.numpy())
11554    else:
11555        flattened_series = ants.timeseries_to_matrix( x, mask )
11556    if verbose:
11557        print("loop_timeseries_censoring: flattened")
11558    loop_scores = calculate_loop_scores(flattened_series, n_features_sample=n_features_sample, seed=seed, verbose=verbose )
11559    high_leverage_volumes = np.where(loop_scores > threshold)[0]
11560    if verbose:
11561        print("loop_timeseries_censoring: High Leverage Volumes:", high_leverage_volumes)
11562    new_asl = remove_volumes_from_timeseries(x, high_leverage_volumes)
11563    return new_asl, high_leverage_volumes

Censor high leverage volumes from a time series using Local Outlier Probabilities (LoOP).

Parameters
  • x (ANTsImage): A 4D time series image.
  • threshold (float): Threshold for determining high leverage volumes based on LoOP scores.
  • mask (antsImage): restricts the computation to a ROI.
  • n_features_sample (int/float): feature sample size, default 0.02; if less than one this is interpreted as a fraction of the total features, otherwise it sets the number of features to be used.
  • seed (int): random seed.
  • verbose (bool)
Returns

tuple: A tuple containing the censored time series (ANTsImage) and the indices of the high leverage volumes.
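
A hedged usage sketch; the brain-mask construction is just one reasonable choice, and the filename is hypothetical:

import ants
bold = ants.image_read("rsfmri.nii.gz")                        # hypothetical 4D time series
mask = ants.get_mask(ants.get_average_of_timeseries(bold))     # restrict LoOP features to the brain
censored, bad_vols = loop_timeseries_censoring(bold, threshold=0.5, mask=mask, verbose=True)
print("censored volumes:", bad_vols)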

def clean_tmp_directory( age_hours=1.0, use_sudo=False, extensions=['.nii', '.nii.gz'], log_file_path=None):
469def clean_tmp_directory(age_hours=1., use_sudo=False, extensions=[ '.nii', '.nii.gz' ], log_file_path=None):
470    """
471    Clean the /tmp directory by removing files and directories older than a certain number of hours.
472    Optionally uses sudo and can filter files by extensions.
473
474    :param age_hours: Age in hours to consider files and directories for deletion.
475    :param use_sudo: Whether to use sudo for removal commands.
476    :param extensions: List of file extensions to delete. If None, all files are considered.
477    :param log_file_path: Path to the log file. If None, a default path will be used based on the OS.
478
479    # Usage
480    # Example: clean_tmp_directory(age_hours=1, use_sudo=True, extensions=['.log', '.tmp'])
481    """
482    import os
483    import platform
484    import subprocess
485    from datetime import datetime, timedelta
486
487    if not isinstance(age_hours, (int, float)):
488        return
489
490    # Determine the tmp directory based on the operating system
491    tmp_dir = '/tmp'
492
493    # Set the log file path
494    if log_file_path is not None:
495        log_file = log_file_path
496
497    current_time = datetime.now()
498    for item in os.listdir(tmp_dir):
499        try:
500            item_path = os.path.join(tmp_dir, item)
501            item_stat = os.stat(item_path)
502
503            # Calculate the age of the file/directory
504            item_age = current_time - datetime.fromtimestamp(item_stat.st_mtime)
505            if item_age > timedelta(hours=age_hours):
506                # Check for file extensions if provided
507                if extensions is None or any(item.endswith(ext) for ext in extensions):
508                    # Construct the removal command
509                    rm_command = ['sudo', 'rm', '-rf', item_path] if use_sudo else ['rm', '-rf', item_path]
510                    subprocess.run(rm_command)
511
512                if log_file_path is not None:
513                    with open(log_file, 'a') as log:
514                        log.write(f"{datetime.now()}: Deleted {item_path}\n")
515        except Exception as e:
516            if log_file_path is not None:
517                with open(log_file, 'a') as log:
518                    log.write(f"{datetime.now()}: Error deleting {item_path}: {e}\n")

Clean the /tmp directory by removing files and directories older than a certain number of hours. Optionally uses sudo and can filter files by extensions.

Parameters
  • age_hours: Age in hours to consider files and directories for deletion.
  • use_sudo: Whether to use sudo for removal commands.
  • extensions: List of file extensions to delete. If None, all files are considered.
  • log_file_path: Path to the log file. If None, a default path will be used based on the OS.

Usage

Example: clean_tmp_directory(age_hours=1, use_sudo=True, extensions=['.log', '.tmp'])
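
A hedged example (the log path is hypothetical); note that age_hours must be numeric or the function returns without doing anything:

# remove NIfTI files in /tmp older than two hours, logging what was deleted
clean_tmp_directory(age_hours=2.0, use_sudo=False,
                    extensions=['.nii', '.nii.gz'],
                    log_file_path='/tmp/antspymm_tmp_clean.log')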

def validate_nrg_file_format(path, separator):
189def validate_nrg_file_format(path, separator):
190    """
191    is your path nrg-etic?
192    Validates if a given path conforms to the NRG file format, taking into account known extensions
193    and the expected directory structure.
194
195    :param path: The file path to validate.
196    :param separator: The separator used in the filename and directory structure.
197    :return: A tuple (bool, str) indicating whether the path is valid and a message explaining the validation result.
198
199    : example
200
201    ntfn='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses-1/rsfMRI_RL/000/ANTsLIMBIC_sub08C105120Yr_ses-1_rsfMRI_RL_000.nii.gz'
202    ntfngood='/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses_1/rsfMRI_RL/000/ANTsLIMBIC-sub08C105120Yr-ses_1-rsfMRI_RL-000.nii.gz'
203
204    validate_nrg_file_format(ntfngood, '-')
205    print( validate_nrg_file_format(ntfn, '-') )
206    print( validate_nrg_file_format(ntfn, '_') )
207
208    """
209    import re    
210
211    def normalize_path(path):
212        """
213        Replace multiple repeated '/' with just a single '/'
214        
215        :param path: The file path to normalize.
216        :return: The normalized file path with single '/'.
217        """
218        normalized_path = re.sub(r'/+', '/', path)
219        return normalized_path
220
221    def strip_known_extension(filename, known_extensions):
222        """
223        Strips a known extension from the filename.
224
225        :param filename: The filename from which to strip the extension.
226        :param known_extensions: A list of known extensions to strip from the filename.
227        :return: The filename with the known extension stripped off, if found.
228        """
229        for ext in known_extensions:
230            if filename.endswith(ext):
231                # Strip the extension and return the modified filename
232                return filename[:-len(ext)]
233        # If no known extension is found, return the original filename
234        return filename
235
236    import warnings
237    if normalize_path( path ) != path:
238        path = normalize_path( path )
239        warnings.warn("Probably had multiple repeated slashes eg /// in the file path.  this might cause issues. clean up with re.sub(r'/+', '/', path)")
240
241    known_extensions = [".nii.gz", ".nii", ".mhd", ".nrrd", ".mha", ".json", ".bval", ".bvec"]
242    known_extensions2 = [ext.lstrip('.') for ext in known_extensions]
243    def get_extension(filename, known_extensions ):
244        # List of known extensions in priority order
245        for ext in known_extensions:
246            if filename.endswith(ext):
247                return ext.strip('.')
248        return "Invalid extension"
249    
250    parts = path.split('/')
251    if len(parts) < 7:  # Checking for minimum path structure
252        return False, "Path structure is incomplete. Expected at least 7 components, found {}.".format(len(parts))
253    
254    # Extract directory components and filename
255    directory_components = parts[1:-1]  # Exclude the root '/' and filename
256    filename = parts[-1]
257    filename_without_extension = strip_known_extension( filename, known_extensions )
258    file_extension = get_extension( filename, known_extensions )
259    
260    # Validating file extension
261    if file_extension not in known_extensions2:
262        print( file_extension )
263        return False, "Invalid file extension: {}. Expected one of: {}.".format(file_extension, ", ".join(known_extensions2))
264    
265    # Splitting the filename to validate individual parts
266    filename_parts = filename_without_extension.split(separator)
267    if len(filename_parts) != 5:  # Expecting 5 parts based on the NRG format
268        print( filename_parts )
269        return False, "Filename does not have exactly 5 parts separated by '{}'. Found {} parts.".format(separator, len(filename_parts))
270    
271    # Reconstruct expected filename from directory components
272    expected_filename_parts = directory_components[-5:]
273    expected_filename = separator.join(expected_filename_parts)
274    if filename_without_extension != expected_filename:
275        print( filename_without_extension )
276        print("--- vs expected ---")
277        print( expected_filename )
278        return False, "Filename structure does not match directory structure. Expected filename: {}.".format(expected_filename)
279    
280    # Validate directory structure against NRG format
281    study_name, subject_id, session, modality = directory_components[-4:-1] + [directory_components[-1].split('/')[0]]
282    if not all([study_name, subject_id, session, modality]):
283        return False, "Directory structure does not follow NRG format. Ensure StudyName, SubjectID, Session (ses_x), and Modality are correctly specified."
284    
285    # If all checks pass
286    return True, "The path conforms to the NRG format."

Is your path NRG-etic? Validates whether a given path conforms to the NRG file format, taking into account known extensions and the expected directory structure.

Parameters
  • path: The file path to validate.
  • separator: The separator used in the filename and directory structure.
Returns

A tuple (bool, str) indicating whether the path is valid and a message explaining the validation result.

Example

ntfn = '/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses-1/rsfMRI_RL/000/ANTsLIMBIC_sub08C105120Yr_ses-1_rsfMRI_RL_000.nii.gz'
ntfngood = '/Users/ntustison/Data/Stone/LIMBIC/NRG/ANTsLIMBIC/sub08C105120Yr/ses_1/rsfMRI_RL/000/ANTsLIMBIC-sub08C105120Yr-ses_1-rsfMRI_RL-000.nii.gz'

validate_nrg_file_format(ntfngood, '-')
print(validate_nrg_file_format(ntfn, '-'))
print(validate_nrg_file_format(ntfn, '_'))
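
For orientation, here is a minimal, self-contained sketch of the core rule the validator enforces: the filename, minus its compound extension, must equal the last five directory components (StudyName, SubjectID, Session, Modality, Repeat) joined by the separator. The path and names below are made up for illustration only.

# hypothetical NRG-style path, for illustration only
nrg_path = '/data/MyStudy/sub001/ses_1/T1w/000/MyStudy-sub001-ses_1-T1w-000.nii.gz'
parts = nrg_path.split('/')
dirs, fname = parts[1:-1], parts[-1]
stem = fname[:-len('.nii.gz')]          # strip the compound extension first
assert stem == '-'.join(dirs[-5:])      # filename must mirror StudyName-SubjectID-Session-Modality-Repeat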

def ants_to_nibabel_affine(ants_img):
def ants_to_nibabel_affine(ants_img):
    """
    Convert an ANTsPy image (in LPS space) to a Nibabel-compatible affine (in RAS space).
    Handles 2D, 3D, 4D input (only spatial dimensions are encoded in the affine).

    Returns:
        4x4 np.ndarray affine matrix in RAS space.
    """
    spatial_dim = ants_img.dimension
    spacing = np.array(ants_img.spacing)
    origin = np.array(ants_img.origin)
    direction = np.array(ants_img.direction).reshape((spatial_dim, spatial_dim))
    # Compute the rotation-scale matrix
    affine_linear = direction @ np.diag(spacing)
    # Build the full 4x4 affine with an identity homogeneous bottom row
    affine = np.eye(4)
    affine[:spatial_dim, :spatial_dim] = affine_linear
    affine[:spatial_dim, 3] = origin
    affine[3, 3] = 1  # ensure the homogeneous entry is 1 (it is overwritten above when spatial_dim == 4)
    # Convert LPS -> RAS by flipping x and y
    lps_to_ras = np.diag([-1, -1, 1, 1])
    affine = lps_to_ras @ affine
    return affine

Convert an ANTsPy image (in LPS space) to a Nibabel-compatible affine (in RAS space). Handles 2D, 3D, 4D input (only spatial dimensions are encoded in the affine).

Returns: 4x4 np.ndarray affine matrix in RAS space.
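
A short usage sketch follows; it assumes antspymm is installed and that ants_to_nibabel_affine is importable from the package, and it uses a synthetic image purely for illustration.

import numpy as np
import ants
import nibabel as nib
from antspymm import ants_to_nibabel_affine

# synthetic 3D ANTs image with non-unit spacing and a shifted origin
img = ants.from_numpy(np.zeros((8, 8, 8), dtype='float32'),
                      spacing=(1.0, 1.25, 2.0), origin=(10.0, -5.0, 0.0))
affine = ants_to_nibabel_affine(img)        # 4x4 RAS affine
nii = nib.Nifti1Image(img.numpy(), affine)  # wrap the voxel data for nibabel
print(affine)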

def dict_to_dataframe(data_dict, convert_lists=True, convert_arrays=True, convert_images=True, verbose=False):
def dict_to_dataframe(data_dict, convert_lists=True, convert_arrays=True, convert_images=True, verbose=False):
    """
    Convert a dictionary to a pandas DataFrame, excluding items that cannot be processed by pandas.

    :param data_dict: Dictionary to be converted.
    :param convert_lists: boolean; if True, numeric lists are summarized by their mean in a new '<key>_mean' column.
    :param convert_arrays: boolean; if True, numeric numpy arrays are summarized by their mean in a new '<key>_mean' column.
    :param convert_images: boolean; if True, ANTsImage values are summarized by their mean intensity in a new '<key>_mean' column.
    :param verbose: boolean; if True, print how each key is handled.
    :return: DataFrame representation of the dictionary.
    """
    processed_data = {}

    def mean_of_list(lst):
        if not lst:  # Return 0 for an empty list
            return 0
        all_numeric = all(isinstance(item, (int, float)) for item in lst)
        if all_numeric:
            return sum(lst) / len(lst)
        return None

    for key, value in data_dict.items():
        # Scalars are kept as single-element columns
        if isinstance(value, (int, float, str, bool)):
            processed_data[key] = [value]
        # Lists of scalars are summarized by their mean (None if not all numeric)
        elif isinstance(value, list) and all(isinstance(item, (int, float, str, bool)) for item in value) and convert_lists:
            meanvalue = mean_of_list(value)
            newkey = key + "_mean"
            if verbose:
                print(" Key " + key + " is list with mean " + str(meanvalue) + " to " + newkey)
            if newkey not in data_dict:
                processed_data[newkey] = [meanvalue]
        # Numeric numpy arrays are summarized by their mean
        elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.number) and convert_arrays:
            meanvalue = value.mean()
            newkey = key + "_mean"
            if verbose:
                print(" Key " + key + " is nparray with mean " + str(meanvalue) + " to " + newkey)
            if newkey not in data_dict:
                processed_data[newkey] = [meanvalue]
        # ANTs images are summarized by their mean intensity
        elif isinstance(value, ants.core.ants_image.ANTsImage) and convert_images:
            meanvalue = value.mean()
            newkey = key + "_mean"
            if newkey not in data_dict:
                if verbose:
                    print(" Key " + key + " is antsimage with mean " + str(meanvalue) + " to " + newkey)
                processed_data[newkey] = [meanvalue]
            else:
                if verbose:
                    print(" Key " + key + " is antsimage with mean " + str(meanvalue) + " but " + newkey + " already exists")

    return pd.DataFrame.from_dict(processed_data)

Convert a dictionary to a pandas DataFrame, excluding items that cannot be processed by pandas.

Parameters
  • data_dict: Dictionary to be converted.
  • convert_lists: boolean; if True, numeric lists are summarized by their mean in a new '<key>_mean' column.
  • convert_arrays: boolean; if True, numeric numpy arrays are summarized by their mean in a new '<key>_mean' column.
  • convert_images: boolean; if True, ANTsImage values are summarized by their mean intensity in a new '<key>_mean' column.
  • verbose: boolean; if True, print how each key is handled.
Returns

DataFrame representation of the dictionary.
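
A small usage sketch (the record keys and values below are made up for illustration, and the import assumes the function is exported at the package level):

import numpy as np
from antspymm import dict_to_dataframe

record = {
    'subject_id': 'sub-001',                       # scalar -> kept as a column
    'snr': 41.7,                                   # scalar -> kept as a column
    'framewise_displacement': [0.10, 0.20, 0.15],  # list -> 'framewise_displacement_mean'
    'dvars': np.array([1.1, 0.9, 1.3]),            # array -> 'dvars_mean'
}
df = dict_to_dataframe(record, verbose=True)
print(df.columns.tolist())
# expected, roughly: ['subject_id', 'snr', 'framewise_displacement_mean', 'dvars_mean']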