From c8c8940b3af15dbc2ae61590a15dad204157ad35 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 27 Jan 2026 15:12:58 +0900 Subject: [PATCH 01/51] start allowing mixed bundle defs --- AFQ/api/bundle_dict.py | 48 ++++++++++++++++++++++++++---- AFQ/api/group.py | 3 +- AFQ/data/fetch.py | 59 +++++++++++++++++++++++++++++++++++++ AFQ/definitions/image.py | 2 +- AFQ/nn/synthseg.py | 4 +-- AFQ/recognition/criteria.py | 4 +-- AFQ/tasks/decorators.py | 1 - AFQ/tasks/mapping.py | 2 +- docs/source/references.bib | 21 +++++++++++++ 9 files changed, 128 insertions(+), 16 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 22c8e96e..9168da7b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -10,7 +10,7 @@ from AFQ.definitions.utils import find_file from AFQ.tasks.utils import get_fname, str_to_desc -logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("AFQ") __all__ = [ @@ -1111,7 +1111,7 @@ def update_max_includes(self, new_max): self.max_includes = new_max def _use_bids_info(self, roi_or_sl, bids_layout, bids_path, subject, session): - if isinstance(roi_or_sl, dict): + if isinstance(roi_or_sl, dict) and "roi" not in roi_or_sl: suffix = roi_or_sl.get("suffix", "dwi") roi_or_sl = find_file( bids_layout, bids_path, roi_or_sl, suffix, session, subject @@ -1124,6 +1124,33 @@ def _cond_load(self, roi_or_sl, resample_to): """ Load ROI or streamline if not already loaded """ + if isinstance(roi_or_sl, dict): + space = roi_or_sl.get("space", None) + roi_or_sl = roi_or_sl.get("roi", None) + if roi_or_sl is None or space is None: + raise ValueError( + ( + f"Unclear ROI definition for {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." + ) + ) + if space == "template": + resample_to = self.resample_to + elif space == "subject": + resample_to = self.resample_subject_to + else: + raise ValueError( + ( + f"Unknown space {space} for ROI definition {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." 
+ ) + ) + + logger.debug(f"Loading ROI or streamlines: {roi_or_sl}") + logger.debug(f"Loading ROI or streamlines from space: {resample_to}") + if isinstance(roi_or_sl, str): if ".nii" in roi_or_sl: return afd.read_resample_roi(roi_or_sl, resample_to=resample_to) @@ -1261,11 +1288,20 @@ def is_bundle_in_template(self, bundle_name): return ( "space" not in self._dict[bundle_name] or self._dict[bundle_name]["space"] == "template" + or self._dict[bundle_name]["space"] == "mixed" ) - def _roi_transform_helper(self, roi_or_sl, mapping, new_affine, bundle_name): + def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): roi_or_sl = self._cond_load(roi_or_sl, self.resample_to) if isinstance(roi_or_sl, nib.Nifti1Image): + if ( + np.allclose(roi_or_sl.affine, new_img.affine) + and roi_or_sl.shape == new_img.shape + ): + # This is the case of a mixed bundle definition, where + # some ROIs need transformed and others do not + return roi_or_sl + fdata = roi_or_sl.get_fdata() if len(np.unique(fdata)) <= 2: boolean_ = True @@ -1278,7 +1314,7 @@ def _roi_transform_helper(self, roi_or_sl, mapping, new_affine, bundle_name): if boolean_: warped_img = warped_img.astype(np.uint8) - warped_img = nib.Nifti1Image(warped_img, new_affine) + warped_img = nib.Nifti1Image(warped_img, new_img.affine) return warped_img else: return roi_or_sl @@ -1287,7 +1323,7 @@ def transform_rois( self, bundle_name, mapping, - new_affine, + new_img, base_fname=None, to_space="subject", apply_to_recobundles=False, @@ -1333,7 +1369,7 @@ def transform_rois( bundle_name, self._roi_transform_helper, mapping, - new_affine, + new_img, bundle_name, dry_run=True, apply_to_recobundles=apply_to_recobundles, diff --git a/AFQ/api/group.py b/AFQ/api/group.py index a94c6949..94b1e5b4 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -50,7 +50,6 @@ logger = logging.getLogger("AFQ") -logger.setLevel(logging.INFO) warnings.simplefilter(action="ignore", category=FutureWarning) @@ -142,7 +141,7 @@ def __init__( api.GroupAFQ(my_path, csd_sh_order_max=4) api.GroupAFQ( my_path, - reg_template_spec="mni_t2", reg_subject_spec="b0") + _spec="mni_t2", reg_subject_spec="b0") """ if bids_layout_kwargs is None: bids_layout_kwargs = {} diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 9400509e..e9da8a99 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1089,6 +1089,65 @@ def read_oton_templates(as_img=True, resample_to=False): return template_dict +massp_fnames = [ + "left_VTA.nii.gz", + "right_VTA.nii.gz", +] + +massp_remote_fnames = [ + "34892325", + "34892319", +] + +massp_md5_hashes = [ + "03d65d85abb161ea25501c343c136e40", + "440874b899d2c1057e5fd77b8b350bc4", +] + +fetch_massp_templates = _make_reusable_fetcher( + "fetch_massp_templates", + op.join(afq_home, "massp_templates"), + baseurl, + massp_remote_fnames, + massp_fnames, + md5_list=massp_md5_hashes, + doc="Download AFQ MassP templates", +) + + +def read_massp_templates(as_img=True, resample_to=False): + """Load AFQ MASSSP templates from file + + Parameters + ---------- + as_img : bool, optional + If True, values are `Nifti1Image`. Otherwise, values are + paths to Nifti files. Default: True + resample_to : str or nibabel image class instance, optional + A template image to resample to. Typically, this should be the + template to which individual-level data are registered. Defaults to + the MNI template. Default: False + + Returns + ------- + dict with: keys: names of template ROIs and values: nibabel Nifti1Image + objects from each of the ROI nifti files. 
+ """ + logger = logging.getLogger("AFQ") + + logger.debug("loading oton templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template( + fetch_massp_templates, as_img=as_img, resample_to=resample_to + ) + + toc = time.perf_counter() + logger.debug(f"MASSSP templates loaded in {toc - tic:0.4f} seconds") + + return template_dict + + cp_fnames = [ "ICP_L_inferior_prob.nii.gz", "ICP_L_superior_prob.nii.gz", diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 4208056c..985d270b 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -384,7 +384,7 @@ def _image_getter_helper( for bundle_name in bundle_dict: bundle_entry = bundle_dict.transform_rois( - bundle_name, mapping_imap["mapping"], data_imap["dwi_affine"] + bundle_name, mapping_imap["mapping"], data_imap["dwi"] ) rois = [] if self.use_endpoints: diff --git a/AFQ/nn/synthseg.py b/AFQ/nn/synthseg.py index 45f80f52..6ea9d0c2 100644 --- a/AFQ/nn/synthseg.py +++ b/AFQ/nn/synthseg.py @@ -86,8 +86,8 @@ def pve_from_synthseg(synthseg_data): PVE data with CSF, GM, and WM segmentations. """ - CSF_labels = [0, 3, 4, 11, 12, 21, 22, 17] - GM_labels = [2, 7, 8, 9, 10, 14, 15, 16, 20, 25, 26, 27, 28, 29, 30, 31] + CSF_labels = [0, 3, 4, 11, 12, 21, 22, 16] + GM_labels = [2, 7, 8, 9, 10, 14, 15, 17, 20, 25, 26, 27, 28, 29, 30, 31] WM_labels = [1, 5, 19, 23] mixed_labels = [13, 18, 32] diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index e68e6251..1e0bdccb 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -392,9 +392,7 @@ def run_bundle_rec_plan( start_time = time() bundle_def = dict(bundle_dict.get_b_info(bundle_name)) bundle_def.update( - bundle_dict.transform_rois( - bundle_name, mapping, img.affine, apply_to_recobundles=True - ) + bundle_dict.transform_rois(bundle_name, mapping, img, apply_to_recobundles=True) ) def check_space(roi): diff --git a/AFQ/tasks/decorators.py b/AFQ/tasks/decorators.py index bb1d750c..3c041287 100644 --- a/AFQ/tasks/decorators.py +++ b/AFQ/tasks/decorators.py @@ -26,7 +26,6 @@ logger = logging.getLogger("AFQ") -logger.setLevel(logging.INFO) def get_new_signature(og_func, needed_args): diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 33006030..c593371a 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -85,7 +85,7 @@ def export_rois(base_fname, output_dir, dwi_data_file, data_imap, mapping): *bundle_dict.transform_rois( bundle_name, mapping, - data_imap["dwi_affine"], + data_imap["dwi"], base_fname=base_roi_fname, to_space=to_space, ) diff --git a/docs/source/references.bib b/docs/source/references.bib index 3dba804f..ec7b1eb7 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -9,6 +9,27 @@ @article{Grotheer2022 publisher={Nature Publishing Group} } +@article{leong2016white, + title={White-matter tract connecting anterior insula to nucleus accumbens correlates with reduced preference for positively skewed gambles}, + author={Leong, Josiah K and Pestilli, Franco and Wu, Charlene C and Samanez-Larkin, Gregory R and Knutson, Brian}, + journal={Neuron}, + volume={89}, + number={1}, + pages={63--69}, + year={2016}, + publisher={Elsevier} +} + +@article{alkemade2020amsterdam, + title={The Amsterdam Ultra-high field adult lifespan database (AHEAD): A freely available multimodal 7 Tesla submillimeter magnetic resonance imaging database}, + author={Alkemade, Anneke and Mulder, Martijn J and Groot, Josephine M and Isaacs, Bethany R and van Berendonk, Nikita and Lute, Nicky and 
Isherwood, Scott JS and Bazin, Pierre-Louis and Forstmann, Birte U}, + journal={NeuroImage}, + volume={221}, + pages={117200}, + year={2020}, + publisher={Elsevier} +} + @article{grotheer2023human, title={Human white matter myelinates faster in utero than ex utero}, author={Grotheer, Mareike and Bloom, David and Kruper, John and Richie-Halford, Adam and Zika, Stephanie and Aguilera González, Vicente A and Yeatman, Jason D and Grill-Spector, Kalanit and Rokem, Ariel}, From 2dac434424f7a5e4ead69c7abab3a80aef30215a Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 10:34:52 +0900 Subject: [PATCH 02/51] finish up wmgmi seeding changes --- AFQ/definitions/image.py | 52 ++++++++++++++++++++++++++++++----- AFQ/recognition/preprocess.py | 16 +---------- AFQ/recognition/utils.py | 17 ++++++++++++ AFQ/tasks/utils.py | 3 +- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 985d270b..85454908 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -3,8 +3,10 @@ import nibabel as nib import numpy as np from dipy.align import resample +from scipy.ndimage import distance_transform_edt from AFQ.definitions.utils import Definition, find_file, name_from_path +from AFQ.recognition.utils import tolerance_mm_to_vox from AFQ.tasks.utils import get_tp __all__ = [ @@ -324,6 +326,13 @@ class RoiImage(ImageDefinition): use_endpoints : bool Whether to use the endpoints ("start" and "end") to generate the image. + only_wmgmi : bool + Whether to only include portion of ROIs in the WM-GM interface. + only_wm : bool + Whether to only include portion of ROIs in the white matter. + dilate : bool + Whether to dilate the ROIs before combining them, according to the + tolerance that will be used during bundle recognition. tissue_property : str or None Tissue property from `scalars` to multiply the ROI image with. Can be useful to limit seed mask to the core white matter. 
@@ -350,14 +359,20 @@ def __init__( use_presegment=False, use_endpoints=False, only_wmgmi=False, + only_wm=False, + dilate=True, tissue_property=None, tissue_property_n_voxel=None, tissue_property_threshold=None, ): + if only_wmgmi and only_wm: + raise ValueError("only_wmgmi and only_wm cannot both be True") self.use_waypoints = use_waypoints self.use_presegment = use_presegment self.use_endpoints = use_endpoints self.only_wmgmi = only_wmgmi + self.only_wm = only_wm + self.dilate = dilate self.tissue_property = tissue_property self.tissue_property_n_voxel = tissue_property_n_voxel self.tissue_property_threshold = tissue_property_threshold @@ -386,21 +401,40 @@ def _image_getter_helper( bundle_entry = bundle_dict.transform_rois( bundle_name, mapping_imap["mapping"], data_imap["dwi"] ) - rois = [] + rois = {} if self.use_endpoints: - rois.extend( - [ - bundle_entry[end_type] + rois.update( + { + bundle_entry[end_type]: end_type for end_type in ["start", "end"] if end_type in bundle_entry - ] + } ) if self.use_waypoints: - rois.extend(bundle_entry.get("include", [])) - for roi in rois: + rois.update( + dict.fromkeys(bundle_entry.get("include", []), "waypoint") + ) + for roi, roi_type in rois.items(): warped_roi = roi.get_fdata() if image_data is None: image_data = np.zeros(warped_roi.shape) + if self.dilate: + dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( + data_imap["dwi"], + segmentation_params["dist_to_waypoint"], + segmentation_params["dist_to_atlas"], + ) + edt = distance_transform_edt(np.where(warped_roi == 0, 1, 0)) + if roi_type == "waypoint": + warped_roi = edt <= dist_to_waypoint + else: + warped_roi = edt <= dist_to_atlas + warped_roi = ( + edt <= dist_to_waypoint + if roi_type == "waypoint" + else edt <= dist_to_atlas + ) + image_data = np.logical_or(image_data, warped_roi.astype(bool)) if self.tissue_property is not None: tp = nib.load( @@ -456,6 +490,10 @@ def _image_getter_helper( ) ) + if self.only_wm: + wm = nib.load(tissue_imap["pve_internal"]).get_fdata()[..., 2] >= 0.5 + image_data = np.logical_and(image_data, wm) + return nib.Nifti1Image( image_data.astype(np.float32), data_imap["dwi_affine"] ), dict(source="ROIs") diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 8bb41e67..8b3ce657 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -1,7 +1,6 @@ import logging from time import time -import dipy.tracking.streamline as dts import immlib import nibabel as nib import numpy as np @@ -13,20 +12,7 @@ @immlib.calc("tol", "dist_to_atlas", "vox_dim") def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): - # We need to calculate the size of a voxel, so we can transform - # from mm to voxel units: - R = img.affine[0:3, 0:3] - vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) - - # Tolerance is set to the square of the distance to the corner - # because we are using the squared Euclidean distance in calls to - # `cdist` to make those calls faster. 
- if dist_to_waypoint is None: - tol = dts.dist_to_corner(img.affine) - else: - tol = dist_to_waypoint / vox_dim - dist_to_atlas = int(input_dist_to_atlas / vox_dim) - return tol, dist_to_atlas, vox_dim + return abu.tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas) @immlib.calc("fgarray") diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index fd34aa18..505f8300 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -14,6 +14,23 @@ logger = logging.getLogger("AFQ") +def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): + # We need to calculate the size of a voxel, so we can transform + # from mm to voxel units: + R = img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + + # Tolerance is set to the square of the distance to the corner + # because we are using the squared Euclidean distance in calls to + # `cdist` to make those calls faster. + if dist_to_waypoint is None: + tol = dts.dist_to_corner(img.affine) + else: + tol = dist_to_waypoint / vox_dim + dist_to_atlas = int(input_dist_to_atlas / vox_dim) + return tol, dist_to_atlas, vox_dim + + def flip_sls(select_sl, idx_to_flip, in_place=False): """ Helper function to flip streamlines diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index 128fd157..e3cfb1fe 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -29,7 +29,8 @@ def get_base_fname(output_dir, dwi_data_file): key = key_val_pair.split("-")[0] if key not in used_key_list: fname = fname + key_val_pair + "_" - fname = fname[:-1] + if fname[-1] == "_": + fname = fname[:-1] return fname From 2538e39f77569fec2f951f3d0426e965cd4e9fd9 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 13:58:40 +0900 Subject: [PATCH 03/51] more mixed ROI fixes --- AFQ/api/bundle_dict.py | 15 +++++++++++++-- AFQ/api/group.py | 2 ++ AFQ/api/participant.py | 4 ++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 9168da7b..b9735357 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -11,6 +11,7 @@ from AFQ.tasks.utils import get_fname, str_to_desc logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) __all__ = [ @@ -1139,6 +1140,14 @@ def _cond_load(self, roi_or_sl, resample_to): resample_to = self.resample_to elif space == "subject": resample_to = self.resample_subject_to + if resample_to is False: + raise ValueError( + ( + "When using mixed ROI bundle definitions, " + "and subject space ROIs, " + "resample_subject_to cannot be False." 
+ ) + ) else: raise ValueError( ( @@ -1296,7 +1305,7 @@ def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): if isinstance(roi_or_sl, nib.Nifti1Image): if ( np.allclose(roi_or_sl.affine, new_img.affine) - and roi_or_sl.shape == new_img.shape + and roi_or_sl.shape == new_img.shape[:3] ): # This is the case of a mixed bundle definition, where # some ROIs need transformed and others do not @@ -1418,7 +1427,9 @@ def transform_rois( def __add__(self, other): for resample in ["resample_to", "resample_subject_to"]: - if ( + if getattr(self, resample) == getattr(other, resample): + pass + elif ( not getattr(self, resample) or not getattr(other, resample) or getattr(self, resample) is None diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 94b1e5b4..05120f92 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -50,6 +50,8 @@ logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) + warnings.simplefilter(action="ignore", category=FutureWarning) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index fd0e08e6..927f64ae 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -34,6 +34,10 @@ __all__ = ["ParticipantAFQ"] +logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) + + class ParticipantAFQ(object): f"""{AFQclass_doc}""" From 7ddab0fa25711dc81ee4de6290fae36455ce734e Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 15:07:08 +0900 Subject: [PATCH 04/51] BFs --- AFQ/api/bundle_dict.py | 2 +- AFQ/tasks/utils.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index b9735357..ed37dd16 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -302,7 +302,7 @@ def default_bd(): "primary_axis_percentage": 40, }, }, - citations={"Yeatman2012", "takemura2017occipital"}, + citations={"Yeatman2012", "takemura2017occipital", "Tzourio-Mazoyer2002"}, ) diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index e3cfb1fe..e57eb393 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -31,6 +31,10 @@ def get_base_fname(output_dir, dwi_data_file): fname = fname + key_val_pair + "_" if fname[-1] == "_": fname = fname[:-1] + else: + # if no key value pairs found, + # have some default base file name + fname = fname + "subject" return fname From 475d9ae50e398e571d3c5fa6ddc5b7f95b521f02 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 15:29:26 +0900 Subject: [PATCH 05/51] return to setting logging to info --- AFQ/api/group.py | 1 + AFQ/api/participant.py | 1 + 2 files changed, 2 insertions(+) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 05120f92..b831e62c 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -49,6 +49,7 @@ __all__ = ["GroupAFQ"] +logging.basicConfig(level=logging.INFO) logger = logging.getLogger("AFQ") logger.setLevel(logging.INFO) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 927f64ae..72beb926 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -34,6 +34,7 @@ __all__ = ["ParticipantAFQ"] +logging.basicConfig(level=logging.INFO) logger = logging.getLogger("AFQ") logger.setLevel(logging.INFO) From e7f2d67d3433210773af2d7212974b1d677d0ff0 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 28 Jan 2026 15:53:34 +0900 Subject: [PATCH 06/51] minor docs fixes from copilot --- AFQ/api/bundle_dict.py | 4 ++-- AFQ/api/group.py | 2 +- AFQ/data/fetch.py | 6 +++--- AFQ/definitions/image.py | 16 ++++++---------- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git 
a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index ed37dd16..273bd72e 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -1351,8 +1351,8 @@ def transform_rois( Name of the bundle to be transformed. mapping : DiffeomorphicMap object A mapping between DWI space and a template. - new_affine : array - Affine of space transformed into. + new_img : Nifti1Image + Image of space transformed into. base_fname : str, optional Base file path to construct file path from. Additional BIDS descriptors will be added to this file path. If None, diff --git a/AFQ/api/group.py b/AFQ/api/group.py index b831e62c..e1c9ebeb 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -144,7 +144,7 @@ def __init__( api.GroupAFQ(my_path, csd_sh_order_max=4) api.GroupAFQ( my_path, - _spec="mni_t2", reg_subject_spec="b0") + reg_template_spec="mni_t2", reg_subject_spec="b0") """ if bids_layout_kwargs is None: bids_layout_kwargs = {} diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index e9da8a99..c88fc74e 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1116,7 +1116,7 @@ def read_oton_templates(as_img=True, resample_to=False): def read_massp_templates(as_img=True, resample_to=False): - """Load AFQ MASSSP templates from file + """Load AFQ MASSP templates from file Parameters ---------- @@ -1135,7 +1135,7 @@ def read_massp_templates(as_img=True, resample_to=False): """ logger = logging.getLogger("AFQ") - logger.debug("loading oton templates") + logger.debug("loading MASSP templates") tic = time.perf_counter() template_dict = _fetcher_to_template( @@ -1143,7 +1143,7 @@ def read_massp_templates(as_img=True, resample_to=False): ) toc = time.perf_counter() - logger.debug(f"MASSSP templates loaded in {toc - tic:0.4f} seconds") + logger.debug(f"MASSP templates loaded in {toc - tic:0.4f} seconds") return template_dict diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 85454908..df5df2b4 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -414,26 +414,22 @@ def _image_getter_helper( rois.update( dict.fromkeys(bundle_entry.get("include", []), "waypoint") ) + + dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( + data_imap["dwi"], + segmentation_params["dist_to_waypoint"], + segmentation_params["dist_to_atlas"], + ) for roi, roi_type in rois.items(): warped_roi = roi.get_fdata() if image_data is None: image_data = np.zeros(warped_roi.shape) if self.dilate: - dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( - data_imap["dwi"], - segmentation_params["dist_to_waypoint"], - segmentation_params["dist_to_atlas"], - ) edt = distance_transform_edt(np.where(warped_roi == 0, 1, 0)) if roi_type == "waypoint": warped_roi = edt <= dist_to_waypoint else: warped_roi = edt <= dist_to_atlas - warped_roi = ( - edt <= dist_to_waypoint - if roi_type == "waypoint" - else edt <= dist_to_atlas - ) image_data = np.logical_or(image_data, warped_roi.astype(bool)) if self.tissue_property is not None: From 8afa112376ba5a06f74644b799cd7a93105503a3 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 29 Jan 2026 13:38:06 +0900 Subject: [PATCH 07/51] toy with vof --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 273bd72e..8408bb2b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -279,7 +279,7 @@ def default_bd(): "entire_core": "Anterior", }, "Left Inferior Fronto-occipital": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + 
"orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25}, "isolation_forest": {}, "primary_axis": "I/S", @@ -295,7 +295,7 @@ def default_bd(): "entire_core": "Anterior", }, "Right Inferior Fronto-occipital": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25}, "isolation_forest": {}, "primary_axis": "I/S", From 422b0652647fac18a3bffb8a64056fe9b997f558 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 29 Jan 2026 15:20:07 +0900 Subject: [PATCH 08/51] much needed speedups in bundle recognition --- AFQ/api/bundle_dict.py | 12 +++++++++-- AFQ/recognition/criteria.py | 33 ++++++++++++++++------------- AFQ/recognition/roi.py | 42 ++++++++++++++++--------------------- 3 files changed, 46 insertions(+), 41 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 8408bb2b..be14a633 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -185,6 +185,7 @@ def default_bd(): "prob_map": templates["IFO_L_prob_map"], "end": templates["IFO_L_start"], "start": templates["IFO_L_end"], + "length": {"min_len": 100, "max_len": 250}, }, "Right Inferior Fronto-occipital": { "cross_midline": False, @@ -194,6 +195,7 @@ def default_bd(): "prob_map": templates["IFO_R_prob_map"], "end": templates["IFO_R_start"], "start": templates["IFO_R_end"], + "length": {"min_len": 100, "max_len": 250}, }, "Left Inferior Longitudinal": { "cross_midline": False, @@ -221,6 +223,7 @@ def default_bd(): "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], "end": templates["ARC_L_end"], + "length": {"min_len": 50, "max_len": 250}, }, "Right Arcuate": { "cross_midline": False, @@ -230,6 +233,7 @@ def default_bd(): "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], "end": templates["ARC_R_end"], + "length": {"min_len": 50, "max_len": 250}, }, "Left Uncinate": { "cross_midline": False, @@ -254,8 +258,10 @@ def default_bd(): "include": [templates["SLFt_roi2_L"]], "exclude": [templates["SLF_roi1_L"]], "space": "template", + "prob_map": templates["ARC_L_prob_map"], # Better than nothing "start": templates["pARC_L_start"], "Left Arcuate": {"overlap": 30}, + "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, }, @@ -264,8 +270,10 @@ def default_bd(): "include": [templates["SLFt_roi2_R"]], "exclude": [templates["SLF_roi1_R"]], "space": "template", + "prob_map": templates["ARC_R_prob_map"], # Better than nothing "start": templates["pARC_R_start"], "Right Arcuate": {"overlap": 30}, + "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, }, @@ -280,7 +288,7 @@ def default_bd(): }, "Left Inferior Fronto-occipital": {"core": "Right"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, - "length": {"min_len": 25}, + "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -296,7 +304,7 @@ def default_bd(): }, "Right Inferior Fronto-occipital": {"core": "Left"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, - "length": {"min_len": 25}, + "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", "primary_axis_percentage": 40, diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 1e0bdccb..5e078006 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -24,9 +24,9 @@ 
criteria_order_pre_other_bundles = [ "prob_map", "cross_midline", + "length", "start", "end", - "length", "primary_axis", "include", "exclude", @@ -69,18 +69,17 @@ def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): def start(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("Startpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("Startpoint") + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], 0, tol=preproc_imap["dist_to_atlas"], flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], -1, tol=preproc_imap["dist_to_atlas"], @@ -91,18 +90,17 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): def end(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("endpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("endpoint") + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], -1, tol=preproc_imap["dist_to_atlas"], flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], 0, tol=preproc_imap["dist_to_atlas"], @@ -116,10 +114,15 @@ def length(b_sls, bundle_def, preproc_imap, **kwargs): accept_idx = b_sls.initiate_selection("length") min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] - for idx, sl in enumerate(b_sls.get_selected_sls()): - sl_len = np.sum(np.linalg.norm(np.diff(sl, axis=0), axis=1)) - if sl_len >= min_len and sl_len <= max_len: - accept_idx[idx] = 1 + + # Using resampled fgarray biases lengths to be lower. However, + # this is not meant to be a precise selection requirement, and + # is more meant for efficiency. + segments = np.diff(preproc_imap["fgarray"][b_sls.selected_fiber_idxs], axis=1) + segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) + sl_lens = np.sum(segment_lengths, axis=1) + + accept_idx = (sl_lens >= min_len) & (sl_lens <= max_len) b_sls.select(accept_idx, "length") diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index d87062f5..83e1cf53 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -6,8 +6,7 @@ def _interp3d(roi, sl): return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0] -def check_sls_with_inclusion( - sls, include_rois, include_roi_tols): +def check_sls_with_inclusion(sls, include_rois, include_roi_tols): inc_results = np.zeros(len(sls), dtype=tuple) include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): @@ -30,9 +29,8 @@ def check_sls_with_inclusion( return inc_results -def check_sl_with_exclusion(sl, exclude_rois, - exclude_roi_tols): - """ Helper function to check that a streamline is not too close to a +def check_sl_with_exclusion(sl, exclude_rois, exclude_roi_tols): + """Helper function to check that a streamline is not too close to a list of exclusion ROIs. 
""" for ii, roi in enumerate(exclude_rois): @@ -44,17 +42,15 @@ def check_sl_with_exclusion(sl, exclude_rois, return True -def clean_by_endpoints(streamlines, target, target_idx, tol=0, - flip_sls=None, accepted_idxs=None): +def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): """ Clean a collection of streamlines based on an endpoint ROI. Filters down to only include items that have their start or end points close to the targets. Parameters ---------- - streamlines : sequence of N by 3 arrays - Where N is number of nodes in the array, the collection of - streamlines to filter down to. + fgarray : ndarray of shape (N, M, 3) + Where N is number of streamlines, M is number of nodes. target: Nifti1Image Nifti1Image containing a distance transform of the ROI. target_idx: int. @@ -67,24 +63,22 @@ def clean_by_endpoints(streamlines, target, target_idx, tol=0, the endpoint is exactly in the coordinate of the target ROI. flip_sls : 1d array, optional Length is len(streamlines), whether to flip the streamline. - accepted_idxs : 1d array, optional - Boolean array, where entries correspond to eachs streamline, - and streamlines that pass cleaning will be set to 1. Yields ------- boolean array of streamlines that survive cleaning. """ - if accepted_idxs is None: - accepted_idxs = np.zeros(len(streamlines), dtype=np.bool_) + n_sls, n_nodes, _ = fgarray.shape - if flip_sls is None: - flip_sls = np.zeros(len(streamlines)) - flip_sls = flip_sls.astype(int) + # handle target_idx negative values as wrapping around + effective_idx = target_idx if target_idx >= 0 else (n_nodes + target_idx) + indices = np.full(n_sls, effective_idx) - for ii, sl in enumerate(streamlines): - this_idx = target_idx - if flip_sls[ii]: - this_idx = (len(sl) - this_idx - 1) % len(sl) - accepted_idxs[ii] = _interp3d(target, [sl[this_idx]])[0] <= tol + if flip_sls is not None: + flipped_indices = n_nodes - 1 - effective_idx + indices = np.where(flip_sls.astype(bool), flipped_indices, indices) - return accepted_idxs + distances = interpolate_scalar_3d( + target.get_fdata(), fgarray[np.arange(n_sls), indices] + )[0] + + return distances <= tol From 9e2d9b1fc590c1a911a114f95806b1271ba9b47b Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 30 Jan 2026 15:52:21 +0900 Subject: [PATCH 09/51] tighter node thresh --- AFQ/api/bundle_dict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index be14a633..5044d2e4 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -281,9 +281,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 20}, + "Left Arcuate": {"node_thresh": 10}, "Left Posterior Arcuate": { - "node_thresh": 20, + "node_thresh": 10, "entire_core": "Anterior", }, "Left Inferior Fronto-occipital": {"core": "Right"}, @@ -297,9 +297,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 20}, + "Right Arcuate": {"node_thresh": 10}, "Right Posterior Arcuate": { - "node_thresh": 20, + "node_thresh": 10, "entire_core": "Anterior", }, "Right Inferior Fronto-occipital": {"core": "Left"}, From 90834a09572caff4a0d1f7a3ad6477d53cdbae0c Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 30 Jan 2026 17:06:08 +0900 Subject: [PATCH 10/51] bring back parietal endpoint ROIs --- AFQ/api/bundle_dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AFQ/api/bundle_dict.py 
b/AFQ/api/bundle_dict.py index 5044d2e4..d670086b 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -280,6 +280,7 @@ def default_bd(): "Left Vertical Occipital": { "cross_midline": False, "space": "template", + "start": templates["VOF_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"node_thresh": 10}, "Left Posterior Arcuate": { @@ -296,6 +297,7 @@ def default_bd(): "Right Vertical Occipital": { "cross_midline": False, "space": "template", + "start": templates["VOF_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"node_thresh": 10}, "Right Posterior Arcuate": { From 84d2e1833849e341173db4ae78c5a8d17c1445d9 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Feb 2026 10:17:07 +0900 Subject: [PATCH 11/51] Add projection to node threshold --- AFQ/api/bundle_dict.py | 8 ++++---- AFQ/recognition/criteria.py | 6 ++++-- AFQ/recognition/other_bundles.py | 22 +++++++++++++++++++++- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index d670086b..2dceec74 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -280,11 +280,11 @@ def default_bd(): "Left Vertical Occipital": { "cross_midline": False, "space": "template", - "start": templates["VOF_L_start"], "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 10}, + "Left Arcuate": {"node_thresh": 10, "project": "L/R"}, "Left Posterior Arcuate": { "node_thresh": 10, + "project": "L/R", "entire_core": "Anterior", }, "Left Inferior Fronto-occipital": {"core": "Right"}, @@ -297,11 +297,11 @@ def default_bd(): "Right Vertical Occipital": { "cross_midline": False, "space": "template", - "start": templates["VOF_R_start"], "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 10}, + "Right Arcuate": {"node_thresh": 10, "project": "L/R"}, "Right Posterior Arcuate": { "node_thresh": 10, + "project": "L/R", "entire_core": "Anterior", }, "Right Inferior Fronto-occipital": {"core": "Left"}, diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 5e078006..0514e445 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -307,7 +307,8 @@ def clean_by_other_bundle( other_bundle_sls, bundle_def[other_bundle_name]["overlap"], img, - False, + remove=False, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_overlap) @@ -317,7 +318,8 @@ def clean_by_other_bundle( other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], img, - True, + remove=True, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh) diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 1e94106b..9eaa4694 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -9,7 +9,9 @@ logger = logging.getLogger("AFQ") -def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=False): +def clean_by_overlap( + this_bundle_sls, other_bundle_sls, overlap, img, remove=False, project=None +): """ Cleans a set of streamlines by only keeping (or removing) those with significant overlap with another set of streamlines. @@ -32,6 +34,11 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal removed. If False, streamlines that overlap in more than `overlap` nodes are removed. Default: False. 
+ project : {'A/P', 'I/S', 'L/R', None}, optional + If specified, the overlap calculation is projected along the given axis + before cleaning. For example, 'A/P' projects the streamlines along the + anterior-posterior axis. + Default: None. Returns ------- @@ -56,6 +63,19 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal other_bundle_density_map = dtu.density_map( other_bundle_sls, np.eye(4), img.shape[:3] ) + + if project is not None: + orientation = nib.orientations.aff2axcodes(img.affine) + core_axis = next( + idx for idx, label in enumerate(orientation) if label in project.upper() + ) + + projection = np.sum(other_bundle_density_map, axis=core_axis) + + other_bundle_density_map = np.broadcast_to( + np.expand_dims(projection, axis=core_axis), other_bundle_density_map.shape + ) + fiber_probabilities = dts.values_from_volume( other_bundle_density_map, this_bundle_sls, np.eye(4) ) From 348d2c2858821f726e1dd5a62bcd951454456bb3 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Feb 2026 15:01:41 +0900 Subject: [PATCH 12/51] for large tractographies, this 5 percent rule may be necessary --- AFQ/api/bundle_dict.py | 8 ++++---- AFQ/recognition/other_bundles.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 2dceec74..6e69995a 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -281,9 +281,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 10, "project": "L/R"}, + "Left Arcuate": {"node_thresh": 20, "project": "L/R"}, "Left Posterior Arcuate": { - "node_thresh": 10, + "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, @@ -298,9 +298,9 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 10, "project": "L/R"}, + "Right Arcuate": {"node_thresh": 20, "project": "L/R"}, "Right Posterior Arcuate": { - "node_thresh": 10, + "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 9eaa4694..e975765a 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -10,7 +10,13 @@ def clean_by_overlap( - this_bundle_sls, other_bundle_sls, overlap, img, remove=False, project=None + this_bundle_sls, + other_bundle_sls, + overlap, + img, + remove=False, + project=None, + other_bundle_min_density=0.05, ): """ Cleans a set of streamlines by only keeping (or removing) those with @@ -39,6 +45,11 @@ def clean_by_overlap( before cleaning. For example, 'A/P' projects the streamlines along the anterior-posterior axis. Default: None. + other_bundle_min_density : float, optional + A threshold to binarize the density map of `other_bundle_sls`. Voxels + with density values above this threshold (as a fraction of the maximum + density) are considered occupied. + Default: 0.05. 
Returns ------- @@ -64,6 +75,10 @@ def clean_by_overlap( other_bundle_sls, np.eye(4), img.shape[:3] ) + other_bundle_density_map = ( + other_bundle_density_map / other_bundle_density_map.max() + ) > other_bundle_min_density + if project is not None: orientation = nib.orientations.aff2axcodes(img.affine) core_axis = next( From 0c6c93e2f09c80aef8432255f515895d34adaacd Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 2 Feb 2026 15:52:40 +0900 Subject: [PATCH 13/51] add exclude ROI to pAF --- AFQ/api/bundle_dict.py | 4 ++-- AFQ/recognition/other_bundles.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 6e69995a..e5aedad7 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -256,7 +256,7 @@ def default_bd(): "Left Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_L"]], - "exclude": [templates["SLF_roi1_L"]], + "exclude": [templates["SLF_roi1_L"], templates["IFO_roi1_L"]], "space": "template", "prob_map": templates["ARC_L_prob_map"], # Better than nothing "start": templates["pARC_L_start"], @@ -268,7 +268,7 @@ def default_bd(): "Right Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_R"]], - "exclude": [templates["SLF_roi1_R"]], + "exclude": [templates["SLF_roi1_R"], templates["IFO_roi1_R"]], "space": "template", "prob_map": templates["ARC_R_prob_map"], # Better than nothing "start": templates["pARC_R_start"], diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index e975765a..3b04cee0 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -75,9 +75,10 @@ def clean_by_overlap( other_bundle_sls, np.eye(4), img.shape[:3] ) - other_bundle_density_map = ( - other_bundle_density_map / other_bundle_density_map.max() - ) > other_bundle_min_density + if remove: + other_bundle_density_map = ( + other_bundle_density_map / other_bundle_density_map.max() + ) > other_bundle_min_density if project is not None: orientation = nib.orientations.aff2axcodes(img.affine) From 34ac0f4385cabe64d1f2533ca0b98809d440aa98 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 3 Feb 2026 16:58:11 +0900 Subject: [PATCH 14/51] update montage code --- AFQ/api/participant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 72beb926..2c87599c 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -278,7 +278,7 @@ def participant_montage(self, images_per_row=2): bundle_dict = self.export("bundle_dict") self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") - best_scalar = self.export(self.export("best_scalar")) + best_scalar = self.kwargs["best_scalar"] t1 = nib.load(self.export("t1_masked")) size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): From bd0b0f9ad568aabf77e7d094c6bf61c6e3af2faa Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 09:35:16 +0900 Subject: [PATCH 15/51] fix participant montage --- AFQ/api/participant.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 2c87599c..9dc59ff4 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -6,6 +6,7 @@ import nibabel as nib import numpy as np +from dipy.align import resample from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm @@ -278,8 +279,9 @@ def 
participant_montage(self, images_per_row=2): bundle_dict = self.export("bundle_dict") self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") - best_scalar = self.kwargs["best_scalar"] t1 = nib.load(self.export("t1_masked")) + best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) + best_scalar = resample(best_scalar, t1) size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): flip_axes = [False, False, False] @@ -287,12 +289,12 @@ def participant_montage(self, images_per_row=2): flip_axes[i] = self.export("dwi_affine")[i, i] < 0 figure = viz_backend.visualize_volume( - t1, flip_axes=flip_axes, interact=False, inline=False + t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False ) figure = viz_backend.visualize_bundles( self.export("bundles"), - affine=t1.affine, - shade_by_volume=best_scalar, + img=t1, + shade_by_volume=best_scalar.get_fdata(), color_by_direction=True, flip_axes=flip_axes, bundle=bundle_name, From 24026e54da1a6e0388590cc97e8d9cbf665c33a5 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 09:52:05 +0900 Subject: [PATCH 16/51] improve participant montage --- AFQ/api/participant.py | 120 ++++++++++++++++++++++------------------- AFQ/viz/utils.py | 44 +++++++-------- 2 files changed, 86 insertions(+), 78 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 9dc59ff4..ed8600a4 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -299,42 +299,49 @@ def participant_montage(self, images_per_row=2): flip_axes=flip_axes, bundle=bundle_name, figure=figure, + n_points=40, interact=False, inline=False, ) - view, direc = BEST_BUNDLE_ORIENTATIONS.get(bundle_name, ("Axial", "Top")) - eye = get_eye(view, direc) - - this_fname = tdir + f"/t{ii}.png" - if "plotly" in viz_backend.backend: - figure.update_layout( - scene_camera=dict( - projection=dict(type="orthographic"), - up={"x": 0, "y": 0, "z": 1}, - eye=eye, - center=dict(x=0, y=0, z=0), - ), - showlegend=False, - ) - figure.write_image(this_fname, scale=4) + for jj, view in enumerate(["Sagittal", "Coronal", "Axial"]): + direc = BEST_BUNDLE_ORIENTATIONS.get( + bundle_name, ("Left", "Front", "Top") + )[jj] + + eye = get_eye(view, direc) + + this_fname = tdir + f"/t{ii}_{view}.png" + if "plotly" in viz_backend.backend: + figure.update_layout( + scene_camera=dict( + projection=dict(type="orthographic"), + up={"x": 0, "y": 0, "z": 1}, + eye=eye, + center=dict(x=0, y=0, z=0), + ), + showlegend=False, + ) + figure.write_image(this_fname, scale=4) - # temporary fix for memory leak - import plotly.io as pio + # temporary fix for memory leak + import plotly.io as pio - pio.kaleido.scope._shutdown_kaleido() - else: - from dipy.viz import window - - direc = np.fromiter(eye.values(), dtype=int) - data_shape = np.asarray(nib.load(self.export("b0")).get_fdata().shape) - figure.set_camera( - position=direc * data_shape, - focal_point=data_shape // 2, - view_up=(0, 0, 1), - ) - figure.zoom(0.5) - window.snapshot(figure, fname=this_fname, size=(600, 600)) + pio.kaleido.scope._shutdown_kaleido() + else: + from dipy.viz import window + + direc = np.fromiter(eye.values(), dtype=int) + data_shape = np.asarray( + nib.load(self.export("b0")).get_fdata().shape + ) + figure.set_camera( + position=direc * data_shape, + focal_point=data_shape // 2, + view_up=(0, 0, 1), + ) + figure.zoom(0.5) + window.snapshot(figure, fname=this_fname, size=(600, 600)) def _save_file(curr_img): save_path = op.abspath( @@ 
-347,33 +354,34 @@ def _save_file(curr_img): max_height = 0 max_width = 0 for ii, bundle_name in enumerate(bundle_dict): - this_img = Image.open(tdir + f"/t{ii}.png") - try: - this_img_trimmed[ii] = trim(this_img) - except IndexError: # this_img is a picture of nothing - this_img_trimmed[ii] = this_img - - text_sz = 70 - width, height = this_img_trimmed[ii].size - height = height + text_sz - result = Image.new( - this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) - ) - result.paste(this_img_trimmed[ii], (0, text_sz)) - this_img_trimmed[ii] = result - - draw = ImageDraw.Draw(this_img_trimmed[ii]) - draw.text( - (0, 0), - bundle_name, - (0, 0, 0), - font=ImageFont.truetype("Arial", text_sz), - ) + for view in ["Axial", "Coronal", "Sagittal"]: + this_img = Image.open(tdir + f"/t{ii}_{view}.png") + try: + this_img_trimmed[ii] = trim(this_img) + except IndexError: # this_img is a picture of nothing + this_img_trimmed[ii] = this_img + + text_sz = 70 + width, height = this_img_trimmed[ii].size + height = height + text_sz + result = Image.new( + this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) + ) + result.paste(this_img_trimmed[ii], (0, text_sz)) + this_img_trimmed[ii] = result + + draw = ImageDraw.Draw(this_img_trimmed[ii]) + draw.text( + (0, 0), + bundle_name, + (0, 0, 0), + font=ImageFont.truetype("Arial", text_sz), + ) - if this_img_trimmed[ii].size[0] > max_width: - max_width = this_img_trimmed[ii].size[0] - if this_img_trimmed[ii].size[1] > max_height: - max_height = this_img_trimmed[ii].size[1] + if this_img_trimmed[ii].size[0] > max_width: + max_width = this_img_trimmed[ii].size[0] + if this_img_trimmed[ii].size[1] > max_height: + max_height = this_img_trimmed[ii].size[1] curr_img = Image.new( "RGB", (max_width * size[0], max_height * size[1]), color="white" diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index def69151..4ffe316a 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -150,28 +150,28 @@ RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"] BEST_BUNDLE_ORIENTATIONS = { - "Left Anterior Thalamic": ("Sagittal", "Left"), - "Right Anterior Thalamic": ("Sagittal", "Right"), - "Left Corticospinal": ("Sagittal", "Left"), - "Right Corticospinal": ("Sagittal", "Right"), - "Left Cingulum Cingulate": ("Sagittal", "Left"), - "Right Cingulum Cingulate": ("Sagittal", "Right"), - "Forceps Minor": ("Axial", "Top"), - "Forceps Major": ("Axial", "Top"), - "Left Inferior Fronto-occipital": ("Sagittal", "Left"), - "Right Inferior Fronto-occipital": ("Sagittal", "Right"), - "Left Inferior Longitudinal": ("Sagittal", "Left"), - "Right Inferior Longitudinal": ("Sagittal", "Right"), - "Left Superior Longitudinal": ("Axial", "Top"), - "Right Superior Longitudinal": ("Axial", "Top"), - "Left Uncinate": ("Axial", "Bottom"), - "Right Uncinate": ("Axial", "Bottom"), - "Left Arcuate": ("Sagittal", "Left"), - "Right Arcuate": ("Sagittal", "Right"), - "Left Vertical Occipital": ("Coronal", "Back"), - "Right Vertical Occipital": ("Coronal", "Back"), - "Left Posterior Arcuate": ("Coronal", "Back"), - "Right Posterior Arcuate": ("Coronal", "Back"), + "Left Anterior Thalamic": ("Left", "Front", "Top"), + "Right Anterior Thalamic": ("Right", "Front", "Top"), + "Left Corticospinal": ("Left", "Front", "Top"), + "Right Corticospinal": ("Right", "Front", "Top"), + "Left Cingulum Cingulate": ("Left", "Front", "Top"), + "Right Cingulum Cingulate": ("Right", "Front", "Top"), + "Forceps Minor": ("Left", "Front", "Top"), + "Forceps Major": ("Left", "Back", "Top"), + "Left Inferior 
Fronto-occipital": ("Left", "Front", "Bottom"), + "Right Inferior Fronto-occipital": ("Right", "Front", "Bottom"), + "Left Inferior Longitudinal": ("Left", "Front", "Bottom"), + "Right Inferior Longitudinal": ("Right", "Front", "Bottom"), + "Left Superior Longitudinal": ("Left", "Front", "Top"), + "Right Superior Longitudinal": ("Right", "Front", "Top"), + "Left Uncinate": ("Left", "Front", "Bottom"), + "Right Uncinate": ("Right", "Front", "Bottom"), + "Left Arcuate": ("Left", "Front", "Top"), + "Right Arcuate": ("Right", "Front", "Top"), + "Left Vertical Occipital": ("Left", "Back", "Top"), + "Right Vertical Occipital": ("Right", "Back", "Top"), + "Left Posterior Arcuate": ("Left", "Back", "Top"), + "Right Posterior Arcuate": ("Right", "Back", "Top"), } From 1954cc59196b4d79c90c2ec9a073547ac17d6e93 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 10:13:44 +0900 Subject: [PATCH 17/51] More participant montage improvements --- AFQ/api/participant.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index ed8600a4..380da7ed 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -259,7 +259,7 @@ def export_all(self, viz=True, xforms=True, indiv=True): export_all_helper(self, xforms, indiv, viz) self.logger.info(f"Time taken for export all: {time() - start_time}") - def participant_montage(self, images_per_row=2): + def participant_montage(self, images_per_row=3): """ Generate montage of all bundles for a given subject. @@ -267,7 +267,7 @@ def participant_montage(self, images_per_row=2): ---------- images_per_row : int Number of bundle images per row in output file. - Default: 2 + Default: 3 Returns ------- @@ -282,7 +282,7 @@ def participant_montage(self, images_per_row=2): t1 = nib.load(self.export("t1_masked")) best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) best_scalar = resample(best_scalar, t1) - size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) + size = (images_per_row, math.ceil(3 * len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): flip_axes = [False, False, False] for i in range(3): @@ -353,15 +353,16 @@ def _save_file(curr_img): this_img_trimmed = {} max_height = 0 max_width = 0 - for ii, bundle_name in enumerate(bundle_dict): + ii = 0 + for b_idx, bundle_name in enumerate(bundle_dict): for view in ["Axial", "Coronal", "Sagittal"]: - this_img = Image.open(tdir + f"/t{ii}_{view}.png") + this_img = Image.open(tdir + f"/t{b_idx}_{view}.png") try: this_img_trimmed[ii] = trim(this_img) except IndexError: # this_img is a picture of nothing this_img_trimmed[ii] = this_img - text_sz = 70 + text_sz = 40 width, height = this_img_trimmed[ii].size height = height + text_sz result = Image.new( @@ -373,26 +374,27 @@ def _save_file(curr_img): draw = ImageDraw.Draw(this_img_trimmed[ii]) draw.text( (0, 0), - bundle_name, + f"{bundle_name} - {view}", (0, 0, 0), - font=ImageFont.truetype("Arial", text_sz), + font=ImageFont.load_default(text_sz), ) if this_img_trimmed[ii].size[0] > max_width: max_width = this_img_trimmed[ii].size[0] if this_img_trimmed[ii].size[1] > max_height: max_height = this_img_trimmed[ii].size[1] + ii += 1 curr_img = Image.new( "RGB", (max_width * size[0], max_height * size[1]), color="white" ) - for ii in range(len(bundle_dict)): - x_pos = ii % size[0] - _ii = ii // size[0] + for jj in range(ii): + x_pos = jj % size[0] + _ii = jj // size[0] y_pos = _ii % size[1] _ii = _ii // size[1] - 
this_img = this_img_trimmed[ii].resize((max_width, max_height)) + this_img = this_img_trimmed[jj].resize((max_width, max_height)) curr_img.paste(this_img, (x_pos * max_width, y_pos * max_height)) _save_file(curr_img) From 4f834fc7920d85505fa70bce96c330925505229e Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 4 Feb 2026 10:29:19 +0900 Subject: [PATCH 18/51] add more options to p montage --- AFQ/api/participant.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 380da7ed..18415712 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -259,7 +259,7 @@ def export_all(self, viz=True, xforms=True, indiv=True): export_all_helper(self, xforms, indiv, viz) self.logger.info(f"Time taken for export all: {time() - start_time}") - def participant_montage(self, images_per_row=3): + def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None): """ Generate montage of all bundles for a given subject. @@ -269,6 +269,14 @@ def participant_montage(self, images_per_row=3): Number of bundle images per row in output file. Default: 3 + anatomy : bool + Whether to include anatomical images in the montage. + Default: True + + bundle_names : list of str or None + List of bundle names to include in the montage. + Default: None (includes all bundles) + Returns ------- filename of montage images @@ -276,21 +284,26 @@ def participant_montage(self, images_per_row=3): tdir = tempfile.gettempdir() all_fnames = [] - bundle_dict = self.export("bundle_dict") + if bundle_names is None: + bundle_dict = self.export("bundle_dict") + bundle_names = list(bundle_dict.keys()) self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") t1 = nib.load(self.export("t1_masked")) best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) best_scalar = resample(best_scalar, t1) - size = (images_per_row, math.ceil(3 * len(bundle_dict) / images_per_row)) - for ii, bundle_name in enumerate(tqdm(bundle_dict)): + size = (images_per_row, math.ceil(3 * len(bundle_names) / images_per_row)) + for ii, bundle_name in enumerate(tqdm(bundle_names)): flip_axes = [False, False, False] for i in range(3): flip_axes[i] = self.export("dwi_affine")[i, i] < 0 - figure = viz_backend.visualize_volume( - t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False - ) + if anatomy: + figure = viz_backend.visualize_volume( + t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False + ) + else: + figure = None figure = viz_backend.visualize_bundles( self.export("bundles"), img=t1, @@ -354,7 +367,7 @@ def _save_file(curr_img): max_height = 0 max_width = 0 ii = 0 - for b_idx, bundle_name in enumerate(bundle_dict): + for b_idx, bundle_name in enumerate(bundle_names): for view in ["Axial", "Coronal", "Sagittal"]: this_img = Image.open(tdir + f"/t{b_idx}_{view}.png") try: From 6d71f3bf4ce47cccb44196744f93ec339d12cf03 Mon Sep 17 00:00:00 2001 From: 36000 Date: Fri, 6 Feb 2026 22:02:03 +0900 Subject: [PATCH 19/51] remove warnings from segmentedsft --- AFQ/tasks/segmentation.py | 2 +- AFQ/utils/streamlines.py | 19 ++++++++----------- AFQ/utils/tests/test_streamlines.py | 2 +- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 2781778e..79cd64cb 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -113,7 +113,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): 
**segmentation_params, ) - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) if len(seg_sft.sft) < 1: raise ValueError("Fatal: No bundles recognized.") diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 8c37801c..1bb027c0 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -15,10 +15,9 @@ class SegmentedSFT: - def __init__(self, bundles, space, sidecar_info=None): + def __init__(self, bundles, sidecar_info=None): if sidecar_info is None: sidecar_info = {} - reference = None self.bundle_names = [] sls = [] idxs = {} @@ -26,20 +25,17 @@ def __init__(self, bundles, space, sidecar_info=None): idx_count = 0 for b_name in bundles: if isinstance(bundles[b_name], dict): - this_sls = bundles[b_name]["sl"] + this_sft = bundles[b_name]["sl"] this_tracking_idxs[b_name] = bundles[b_name]["idx"] else: - this_sls = bundles[b_name] - if reference is None: - reference = this_sls - this_sls = list(this_sls.streamlines) + this_sft = bundles[b_name] + this_sls = list(this_sft.streamlines) sls.extend(this_sls) new_idx_count = idx_count + len(this_sls) idxs[b_name] = np.arange(idx_count, new_idx_count, dtype=np.uint32) idx_count = new_idx_count self.bundle_names.append(b_name) - self.sft = StatefulTractogram(sls, reference, space) self.bundle_idxs = idxs if len(this_tracking_idxs) > 1: self.this_tracking_idxs = this_tracking_idxs @@ -48,12 +44,13 @@ def __init__(self, bundles, space, sidecar_info=None): self.sidecar_info = sidecar_info self.sidecar_info["bundle_ids"] = {} - dps = np.zeros(len(self.sft.streamlines)) + dps = np.zeros(len(sls)) for ii, bundle_name in enumerate(self.bundle_names): self.sidecar_info["bundle_ids"][f"{bundle_name}"] = ii + 1 dps[self.bundle_idxs[bundle_name]] = ii + 1 - dps = {"bundle": dps} - self.sft.data_per_streamline = dps + self.sft = StatefulTractogram.from_sft( + sls, this_sft, data_per_streamline={"bundle": dps} + ) if self.this_tracking_idxs is not None: for kk, _vv in self.this_tracking_idxs.items(): self.this_tracking_idxs[kk] = ( diff --git a/AFQ/utils/tests/test_streamlines.py b/AFQ/utils/tests/test_streamlines.py index 39c624c7..8931670f 100644 --- a/AFQ/utils/tests/test_streamlines.py +++ b/AFQ/utils/tests/test_streamlines.py @@ -37,7 +37,7 @@ def test_SegmentedSFT(): ), } - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) for k1 in bundles.keys(): for sl1, sl2 in zip( bundles[k1].streamlines, seg_sft.get_bundle(k1).streamlines From 5f0b277b919468b1e6036f4f42f7f02e703efec4 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 09:54:18 +0900 Subject: [PATCH 20/51] bf --- AFQ/utils/streamlines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 1bb027c0..571534a1 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -105,7 +105,7 @@ def fromfile(cls, trk_or_trx_file, reference="same", sidecar_file=None): else: bundles["whole_brain"] = sft - return cls(bundles, Space.RASMM, sidecar_info) + return cls(bundles, sidecar_info) def split_streamline(streamlines, sl_to_split, split_idx): From 564ef4f909bb36da8458c14c76df2a163038f106 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 10:36:10 +0900 Subject: [PATCH 21/51] viz bug fix --- AFQ/viz/plotly_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 0076c1b9..1af6acb4 100644 --- a/AFQ/viz/plotly_backend.py +++ 
b/AFQ/viz/plotly_backend.py @@ -512,7 +512,7 @@ def create_gif(figure, file_name, n_frames=30, zoom=2.5, z_offset=0.5, size=(600 def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): - roi = np.where(roi == 1) + roi = np.where(roi > 0) pts = [] for i, flip in enumerate(flip_axes): if flip: From 3637ee1c3bbcce8af302e988dd34c2fd85e6c985 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 10:36:36 +0900 Subject: [PATCH 22/51] further restrict pAF --- AFQ/api/bundle_dict.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index e5aedad7..de2b444f 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -256,7 +256,11 @@ def default_bd(): "Left Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_L"]], - "exclude": [templates["SLF_roi1_L"], templates["IFO_roi1_L"]], + "exclude": [ + templates["SLF_roi1_L"], + templates["IFO_roi1_L"], + templates["ILF_L_end"], + ], "space": "template", "prob_map": templates["ARC_L_prob_map"], # Better than nothing "start": templates["pARC_L_start"], @@ -268,7 +272,11 @@ def default_bd(): "Right Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_R"]], - "exclude": [templates["SLF_roi1_R"], templates["IFO_roi1_R"]], + "exclude": [ + templates["SLF_roi1_R"], + templates["IFO_roi1_R"], + templates["ILF_R_end"], + ], "space": "template", "prob_map": templates["ARC_R_prob_map"], # Better than nothing "start": templates["pARC_R_start"], From bacc2c94bbd97b3b9ddea01b39b23a5eef29cc4d Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 11:34:14 +0900 Subject: [PATCH 23/51] try more constrained pAF/ARC defs --- AFQ/api/bundle_dict.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index de2b444f..aad19247 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -218,7 +218,7 @@ def default_bd(): "Left Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_L"], templates["SLFt_roi2_L"]], - "exclude": [], + "exclude": [templates["IFO_roi1_L"]], "space": "template", "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], @@ -228,7 +228,7 @@ def default_bd(): "Right Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_R"], templates["SLFt_roi2_R"]], - "exclude": [], + "exclude": [templates["IFO_roi1_R"]], "space": "template", "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], @@ -259,11 +259,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_L"], templates["IFO_roi1_L"], - templates["ILF_L_end"], + templates["ILF_roi2_L"], + templates["HCC_roi2_L"], ], "space": "template", - "prob_map": templates["ARC_L_prob_map"], # Better than nothing + "prob_map": templates["ARC_L_prob_map"], "start": templates["pARC_L_start"], + "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", @@ -275,11 +277,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_R"], templates["IFO_roi1_R"], - templates["ILF_R_end"], + templates["ILF_roi2_R"], + templates["HCC_roi2_R"], ], "space": "template", - "prob_map": templates["ARC_R_prob_map"], # Better than nothing + "prob_map": templates["ARC_R_prob_map"], "start": templates["pARC_R_start"], + "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", From 
df0e54546e1f05996605ac72192a5799207635c4 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 12:04:29 +0900 Subject: [PATCH 24/51] try this --- AFQ/api/bundle_dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index aad19247..8b253ae0 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -267,6 +267,7 @@ def default_bd(): "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, + "Left Inferior Longitudinal": {"node_thresh": 40}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -285,6 +286,7 @@ def default_bd(): "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, + "Right Inferior Longitudinal": {"node_thresh": 40}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, From 5c8c860a3898c3e558f0e37fb6296212cbc8d428 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 12:21:39 +0900 Subject: [PATCH 25/51] tighten ILF constraint --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 8b253ae0..47e8ce00 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -267,7 +267,7 @@ def default_bd(): "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, - "Left Inferior Longitudinal": {"node_thresh": 40}, + "Left Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -286,7 +286,7 @@ def default_bd(): "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, - "Right Inferior Longitudinal": {"node_thresh": 40}, + "Right Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, From 0bc7704d9eaa9b57f397e1477b0534018ed9c64c Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 14:26:44 +0900 Subject: [PATCH 26/51] solve pAF issues with new exclusion ROI --- AFQ/api/bundle_dict.py | 8 ++------ AFQ/data/fetch.py | 6 ++++++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 47e8ce00..553ede71 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -259,15 +259,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_L"], templates["IFO_roi1_L"], - templates["ILF_roi2_L"], - templates["HCC_roi2_L"], + templates["pARC_xroi1_L"], ], "space": "template", "prob_map": templates["ARC_L_prob_map"], "start": templates["pARC_L_start"], "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, - "Left Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, @@ -278,15 +276,13 @@ def default_bd(): "exclude": [ templates["SLF_roi1_R"], templates["IFO_roi1_R"], - templates["ILF_roi2_R"], - templates["HCC_roi2_R"], + templates["pARC_xroi1_R"], ], "space": "template", "prob_map": templates["ARC_R_prob_map"], "start": templates["pARC_R_start"], "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, - "Right Inferior Longitudinal": {"node_thresh": 20}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", "primary_axis_percentage": 40, diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 
c88fc74e..eae03481 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -759,6 +759,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ATR_R_start.nii.gz", "ATR_L_end.nii.gz", "ATR_L_start.nii.gz", + "pARC_xroi1_L.nii.gz", + "pARC_xroi1_R.nii.gz", ] @@ -861,6 +863,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "40944074", "40944077", "40944080", + "61737616", + "61737619", ] @@ -964,6 +968,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ffc157e9f73a43eff23821f2cfca614a", "a8d308a93b26242c04b878c733cb252f", "1c0b570bb2d622718b01ee2c429a5d15", + "51c8a6b5fbb0834b03986093b9ee4fa3", + "7cf5800a4efa6bac7e70d84095bc259b", ] fetch_templates = _make_reusable_fetcher( From f91b45b940cc4b0f3a48490f6f1515b23a07b387 Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 9 Feb 2026 16:34:22 +0900 Subject: [PATCH 27/51] return to strict VOF seg --- AFQ/api/bundle_dict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 553ede71..dc96824a 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -297,8 +297,8 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Left Inferior Fronto-occipital": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, + "Left Inferior Longitudinal": {"core": "Right"}, + "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", @@ -314,8 +314,8 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Right Inferior Fronto-occipital": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, + "Right Inferior Longitudinal": {"core": "Left"}, + "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", From a8e5cbf0396f604e238bc4605c694542f3f89600 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 10:18:51 +0900 Subject: [PATCH 28/51] return to stricter cleaning --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index dc96824a..917f29ee 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -298,7 +298,7 @@ def default_bd(): "entire_core": "Anterior", }, "Left Inferior Longitudinal": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", @@ -315,7 +315,7 @@ def default_bd(): "entire_core": "Anterior", }, "Right Inferior Longitudinal": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, + "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, "primary_axis": "I/S", From db6c583f037bafd136c111a867f16818875524a6 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 11:40:41 +0900 Subject: [PATCH 29/51] cleaning by other core requires higher levels of precision --- AFQ/recognition/criteria.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 0514e445..c0086c60 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -327,7 +327,7 @@ def clean_by_other_bundle( 
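# Context for the hunk below: cleaning a candidate bundle relative to another
# bundle's core compares positions along resampled streamlines, so the number
# of resampling points sets the precision of that comparison (raised here from
# 20 to 100 points). A small, self-contained sketch of the underlying
# resampling step using DIPY directly; toy coordinates only, and this is not
# the pyAFQ resample_tg helper itself.
import numpy as np
from dipy.tracking.streamline import set_number_of_points

sl = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [20.0, 5.0, 0.0]])
dense = set_number_of_points([sl], 100)  # 100 equally spaced points along the path
print(len(dense), dense[0].shape)  # 1 (100, 3)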
cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, ) @@ -337,7 +337,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, True, ) From 13d016b812df089f8efa6a10929c614bb415ac74 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 13:47:44 +0900 Subject: [PATCH 30/51] bf --- AFQ/recognition/criteria.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index c0086c60..eaa224b3 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -327,6 +327,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + # the extra specificity of 100 points is needed np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, @@ -337,7 +338,7 @@ def clean_by_other_bundle( cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 20)), img.affine, True, ) From f2b9f5f8a713d2ab07835ebf4b85d7980d231a22 Mon Sep 17 00:00:00 2001 From: 36000 Date: Tue, 10 Feb 2026 15:53:26 +0900 Subject: [PATCH 31/51] maybe we can do this after clustering --- AFQ/api/bundle_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 917f29ee..c5102f50 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -297,7 +297,7 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Left Inferior Longitudinal": {"core": "Right"}, + # "Left Inferior Longitudinal": {"core": "Right"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, @@ -314,7 +314,7 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - "Right Inferior Longitudinal": {"core": "Left"}, + # "Right Inferior Longitudinal": {"core": "Left"}, "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, "isolation_forest": {}, From 6db5c58a48e08ec722363666490342e5ce23b0c7 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 11 Feb 2026 12:12:44 +0900 Subject: [PATCH 32/51] Major registration overhaul --- AFQ/_fixes.py | 71 ++++++ AFQ/api/bundle_dict.py | 34 +-- AFQ/api/group.py | 11 +- AFQ/data/fetch.py | 109 ++++++++ AFQ/definitions/image.py | 2 +- AFQ/definitions/mapping.py | 232 ++++++++---------- AFQ/recognition/recognize.py | 16 +- AFQ/recognition/tests/test_recognition.py | 2 +- AFQ/recognition/utils.py | 8 +- AFQ/registration.py | 209 +++++++--------- AFQ/tasks/mapping.py | 6 +- AFQ/tasks/viz.py | 18 +- AFQ/tests/test_api.py | 5 - AFQ/tests/test_registration.py | 43 ++-- AFQ/utils/volume.py | 6 +- AFQ/viz/fury_backend.py | 28 +-- AFQ/viz/plotly_backend.py | 28 +-- AFQ/viz/utils.py | 65 ++--- .../plot_001_group_afq_api.py | 6 +- 19 files changed, 439 
insertions(+), 460 deletions(-) diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 3d92664c..fb3c6db1 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -2,6 +2,8 @@ import math import numpy as np +from dipy.align import vector_fields as vfu +from dipy.align.imwarp import DiffeomorphicMap, mult_aff from dipy.data import default_sphere from dipy.reconst.gqi import squared_radial_component from dipy.tracking.streamline import set_number_of_points @@ -12,6 +14,75 @@ logger = logging.getLogger("AFQ") +def get_simplified_transform(self): + """Constructs a simplified version of this Diffeomorhic Map + + The simplified version incorporates the pre-align transform, as well as + the domain and codomain affine transforms into the displacement field. + The resulting transformation may be regarded as operating on the + image spaces given by the domain and codomain discretization. As a + result, self.prealign, self.disp_grid2world, self.domain_grid2world and + self.codomain affine will be None (denoting Identity) in the resulting + diffeomorphic map. + """ + if self.dim == 2: + simplify_f = vfu.simplify_warp_function_2d + else: + simplify_f = vfu.simplify_warp_function_3d + # Simplify the forward transform + D = self.domain_grid2world + P = self.prealign + Rinv = self.disp_world2grid + Cinv = self.codomain_world2grid + + # this is the matrix which we need to multiply the voxel coordinates + # to interpolate on the forward displacement field ("in"side the + # 'forward' brackets in the expression above) + affine_idx_in = mult_aff(Rinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the voxel coordinates + # to add to the displacement ("out"side the 'forward' brackets in the + # expression above) + affine_idx_out = mult_aff(Cinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the displacement vector + # prior to adding to the transformed input point + affine_disp = Cinv + + new_forward = simplify_f( + self.forward, affine_idx_in, affine_idx_out, affine_disp, self.domain_shape + ) + + # Simplify the backward transform + C = self.codomain_grid2world + Pinv = self.prealign_inv + Dinv = self.domain_world2grid + + affine_idx_in = mult_aff(Rinv, C) + affine_idx_out = mult_aff(Dinv, mult_aff(Pinv, C)) + affine_disp = mult_aff(Dinv, Pinv) + new_backward = simplify_f( + self.backward, + affine_idx_in, + affine_idx_out, + affine_disp, + self.codomain_shape, + ) + simplified = DiffeomorphicMap( + dim=self.dim, + disp_shape=self.disp_shape, + disp_grid2world=None, + domain_shape=self.domain_shape, + domain_grid2world=None, + codomain_shape=self.codomain_shape, + codomain_grid2world=None, + prealign=None, + ) + simplified.forward = new_forward + simplified.backward = new_backward + return simplified + + def gwi_odf(gqmodel, data): gqi_vector = np.real( squared_radial_component( diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index c5102f50..2c7497bb 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -334,61 +334,31 @@ def slf_bd(): "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Left Superior Longitudinal III": { "include": [templates["PrgL"], 
templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, - }, }, }, citations={"Sagi2024"}, @@ -1337,9 +1307,7 @@ def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): else: boolean_ = False - warped_img = auv.transform_inverse_roi( - fdata, mapping, bundle_name=bundle_name - ) + warped_img = auv.transform_roi(fdata, mapping, bundle_name=bundle_name) if boolean_: warped_img = warped_img.astype(np.uint8) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index e1c9ebeb..b945dd05 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -8,7 +8,6 @@ import warnings from time import time -import dipy.tracking.streamline as dts import dipy.tracking.streamlinespeed as dps import nibabel as nib import numpy as np @@ -584,12 +583,7 @@ def load_next_subject(): these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - delta = dts.values_from_volume( - mapping.forward, tg.streamlines, np.eye(4) - ) - moved_sl = dts.Streamlines( - [d + s for d, s in zip(delta, tg.streamlines)] - ) + moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sl = np.asarray(moved_sl) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} @@ -1030,8 +1024,7 @@ def combine_bundle(self, bundle_name): mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - delta = dts.values_from_volume(mapping.forward, sls, np.eye(4)) - sls_mni.extend([d + s for d, s in zip(delta, sls)]) + sls_mni = mapping.tranform_points(sls) moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index eae03481..b40f34ad 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1095,6 +1095,115 @@ def read_oton_templates(as_img=True, resample_to=False): return template_dict +org800_fnames = [ + "ORG_atlas_tracks_reoriented.trx", + "ORG800_atlas_centroids.npy", + "ORG800_atlas_e_val.npy", + "ORG800_atlas_e_vec_norm.npy", + "ORG800_atlas_e_vec.npy", + "ORG800_atlas_number_of_eigenvectors.npy", + "ORG800_atlas_row_sum_1.npy", + "ORG800_atlas_row_sum_matrix.npy", + "ORG800_atlas_sigma.npy", +] + + +org800_remote_fnames = [ + "61762231", + "61762267", + "61762270", + "61762273", + "61762276", + "61762279", + "61762282", + "61762285", + "61762288", +] + + +org800_md5_hashes = [ + "9022799a73359209080ea832b22ec09b", + "09bfa384f5c44801dfa382d31392a979", + "bab61eb26cb21035e38b5f68b5fdad3e", + "9325f4cb168624d4f275785b18c9f859", + "12d426c5a6fcfbe3b8146bc335bdac96", + "db74e055c47c5b6354c3cb7bbf165f2c", + "e7e51f53b30764f104b93f50d94b6c3c", + "7e894c57a820cd7604a1db6b7ab8cce6", + 
"1e195b7055e98eb473bbb5af05d48f7d", +] + +fetch_org800_templates = _make_reusable_fetcher( + "fetch_org800_templates", + op.join(afq_home, "org800_templates"), + baseurl, + org800_remote_fnames, + org800_fnames, + md5_list=org800_md5_hashes, + doc="Download AFQ org800 templates", +) + + +def read_org800_templates(load_npy=True, load_trx=True): + """ + Load O'Donnell Research Group (ORG) Fiber Clustering White + Matter Atlas 800 modified for pyAFQ templates from file + + Parameters + ---------- + load_npy : bool, optional + If True, values are loaded as numpy arrays. Otherwise, values are + paths to npy files. Default: True + load_trx : bool, optional + If True, the tractogram is loaded as a StatefulTractogram. Otherwise, + the value is the path to the trx file. Default: True + + Returns + ------- + dict with: keys: names of atlas info + values: Floats, arrays, and StatefulTractogram for the atlas. + Any unloaded will instead be paths. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading org800 templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template(fetch_org800_templates) + + if load_trx: + template_dict["tracks_reoriented"] = load_tractogram( + template_dict.pop("ORG_atlas_tracks_reoriented"), + "same", + ) + if load_npy: + template_dict["centroids"] = np.load( + template_dict.pop("ORG800_atlas_centroids") + ) + template_dict["e_val"] = np.load(template_dict.pop("ORG800_atlas_e_val")) + template_dict["e_vec_norm"] = np.load( + template_dict.pop("ORG800_atlas_e_vec_norm") + ) + template_dict["e_vec"] = np.load(template_dict.pop("ORG800_atlas_e_vec")) + template_dict["number_of_eigenvectors"] = float( + np.load(template_dict.pop("ORG800_atlas_number_of_eigenvectors")) + ) + template_dict["row_sum_1"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_1") + ) + template_dict["row_sum_matrix"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_matrix") + ) + template_dict["sigma"] = float(np.load(template_dict.pop("ORG800_atlas_sigma"))) + + toc = time.perf_counter() + logger.debug( + f"O'Donnell Research Group 800 templates loaded in {toc - tic:0.4f} seconds" + ) + + return template_dict + + massp_fnames = [ "left_VTA.nii.gz", "right_VTA.nii.gz", diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index df5df2b4..4d68fb8b 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -932,7 +932,7 @@ def _image_getter_helper(mapping, reg_template, reg_subject): static_affine=reg_template.affine, ).get_fdata() - scalar_data = mapping.transform_inverse(img_data, interpolation="nearest") + scalar_data = mapping.transform(img_data, interpolation="nearest") return nib.Nifti1Image( scalar_data.astype(np.float32), reg_subject.affine ), dict(source=self.path) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index dfcedef7..17171faf 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -5,9 +5,10 @@ import nibabel as nib import numpy as np from dipy.align import affine_registration, syn_registration -from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr import AFQ.registration as reg +from AFQ._fixes import get_simplified_transform from AFQ.definitions.utils import Definition, find_file from AFQ.tasks.utils import get_fname from AFQ.utils.path import space_from_fname, write_json @@ -177,11 +178,11 @@ def __init__(self, warp, ref_affine): self.ref_affine = ref_affine self.warp = warp - def transform_inverse(self, data, **kwargs): + def transform(self, 
data, **kwargs): data_img = Image(nib.Nifti1Image(data.astype(np.float32), self.ref_affine)) return np.asarray(applyDeformation(data_img, self.warp).data) - def transform_inverse_pts(self, pts): + def transform_pts(self, pts): # This should only be used for curvature analysis, # Because I think the results still need to be shifted pts = nib.affines.apply_affine(self.warp.src.getAffine("voxel", "world"), pts) @@ -189,39 +190,13 @@ def transform_inverse_pts(self, pts): pts = self.warp.transform(pts, "fsl", "world") return pts - def transform(self, data, **kwargs): + def transform_inverse(self, data, **kwargs): raise NotImplementedError( "Fnirt based mappings can currently" + " only transform from template to subject space" ) -class IdentityMap(Definition): - """ - Does not perform any transformations from MNI to subject where - pyAFQ normally would. - - Examples - -------- - my_example_mapping = IdentityMap() - api.GroupAFQ(mapping=my_example_mapping) - """ - - def __init__(self): - pass - - def get_for_subses( - self, base_fname, dwi, dwi_data_file, reg_subject, reg_template, tmpl_name - ): - return ConformedAffineMapping( - np.identity(4), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) - - class GeneratedMapMixin(object): """ Helper Class @@ -236,24 +211,16 @@ def get_fnames(self, extension, base_fname, sub_name, tmpl_name): mapping_file = mapping_file + extension return mapping_file, meta_fname - def prealign( - self, base_fname, sub_name, tmpl_name, reg_subject, reg_template, save=True - ): - prealign_file_desc = f"_desc-prealign_from-{sub_name}_to-{tmpl_name}_xform" - prealign_file = get_fname(base_fname, f"{prealign_file_desc}.npy") - if not op.exists(prealign_file): - start_time = time() - _, aff = affine_registration( - reg_subject, reg_template, **self.affine_kwargs - ) - meta = dict(type="rigid", dependent="dwi", timing=time() - start_time) - if not save: - return aff - logger.info(f"Saving {prealign_file}") - np.save(prealign_file, aff) - meta_fname = get_fname(base_fname, f"{prealign_file_desc}.json") - write_json(meta_fname, meta) - return prealign_file if save else np.load(prealign_file) + def prealign(self, reg_subject, reg_template): + _, aff = affine_registration(reg_subject, reg_template, **self.affine_kwargs) + return aff + + +class AffineMapMixin(GeneratedMapMixin): + """ + Helper Class + Useful for maps that are generated by pyAFQ + """ def get_for_subses( self, @@ -268,34 +235,22 @@ def get_for_subses( ): sub_space = space_from_fname(dwi_data_file) mapping_file, meta_fname = self.get_fnames( - self.extension, base_fname, sub_space, tmpl_name + ".npy", base_fname, sub_space, tmpl_name ) - if self.use_prealign: - reg_prealign = np.load( - self.prealign( - base_fname, sub_space, tmpl_name, reg_subject, reg_template - ) - ) - else: - reg_prealign = None if not op.exists(mapping_file): start_time = time() - mapping = self.gen_mapping( - base_fname, - sub_space, - tmpl_name, - reg_subject, + self.gen_mapping( reg_template, + reg_subject, subject_sls, template_sls, - reg_prealign, ) total_time = time() - start_time logger.info(f"Saving {mapping_file}") - reg.write_mapping(mapping, mapping_file) - meta = dict(type="displacementfield", timing=total_time) + np.save(mapping_file, mapping.affine) + meta = dict(type="affine", timing=total_time) if subject_sls is None: meta["dependent"] = "dwi" else: @@ -305,10 +260,7 @@ def 
get_for_subses( if isinstance(reg_template, str): meta["reg_template"] = reg_template write_json(meta_fname, meta) - reg_prealign_inv = np.linalg.inv(reg_prealign) if self.use_prealign else None - mapping = reg.read_mapping( - mapping_file, dwi, reg_template, prealign=reg_prealign_inv - ) + mapping = reg.read_affine_mapping(mapping_file, dwi, reg_template) return mapping @@ -353,33 +305,69 @@ def __init__(self, use_prealign=True, affine_kwargs=None, syn_kwargs=None): self.use_prealign = use_prealign self.affine_kwargs = affine_kwargs self.syn_kwargs = syn_kwargs - self.extension = ".nii.gz" - def gen_mapping( + def get_for_subses( self, base_fname, - sub_space, - tmpl_name, + dwi, + dwi_data_file, reg_subject, reg_template, - subject_sls, - template_sls, - reg_prealign, + tmpl_name, + subject_sls=None, + template_sls=None, ): - _, mapping = syn_registration( - reg_subject.get_fdata(), - reg_template.get_fdata(), - moving_affine=reg_subject.affine, - static_affine=reg_template.affine, - prealign=reg_prealign, - **self.syn_kwargs, + sub_space = space_from_fname(dwi_data_file) + mapping_file_forward, meta_forward_fname = self.get_fnames( + ".nii.gz", base_fname, sub_space, tmpl_name ) - if self.use_prealign: - mapping.codomain_world2grid = np.linalg.inv(reg_prealign) + mapping_file_backward, meta_backward_fname = self.get_fnames( + ".nii.gz", base_fname, tmpl_name, sub_space + ) + + if not op.exists(mapping_file_forward) or not op.exists(mapping_file_backward): + meta = dict(type="displacementfield") + meta["dependent"] = "dwi" + if isinstance(reg_subject, str): + meta["reg_subject"] = reg_subject + if isinstance(reg_template, str): + meta["reg_template"] = reg_template + + start_time = time() + if self.use_prealign: + reg_prealign = self.prealign(reg_subject, reg_template) + else: + reg_prealign = None + _, mapping = syn_registration( + reg_subject.get_fdata(), + reg_template.get_fdata(), + moving_affine=reg_subject.affine, + static_affine=reg_template.affine, + prealign=np.linalg.inv(reg_prealign), + **self.syn_kwargs, + ) + mapping = get_simplified_transform(mapping) + + total_time = time() - start_time + meta["total_time"] = total_time + + logger.info(f"Saving {mapping_file_forward}") + nib.save( + nib.Nifti1Image(mapping.forward, reg_subject.affine), + mapping_file_forward, + ) + write_json(meta_forward_fname, meta) + logger.info(f"Saving {mapping_file_backward}") + nib.save( + nib.Nifti1Image(mapping.backward, reg_template.affine), + mapping_file_backward, + ) + write_json(meta_backward_fname, meta) + mapping = reg.read_syn_mapping(mapping_file_forward, mapping_file_backward) return mapping -class SlrMap(GeneratedMapMixin, Definition): +class SlrMap(AffineMapMixin, Definition): """ Calculate a SLR registration for each subject/session using reg_subject and reg_template. 
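# For reference, a small self-contained sketch of how the two displacement
# fields written above (labeled subject-to-template and template-to-subject in
# read_syn_mapping) can be reassembled into a DiffeomorphicMap with the
# read_syn_mapping helper introduced in this patch series. Zero displacement
# fields are used so the warp is an identity; shapes and affines are toy
# values, not real registration output.
import nibabel as nib
import numpy as np

from AFQ.registration import read_syn_mapping

shape = (10, 12, 14)
zero_field = np.zeros((*shape, 3), dtype=np.float32)
forward_img = nib.Nifti1Image(zero_field, np.eye(4))   # subject-to-template field
backward_img = nib.Nifti1Image(zero_field, np.eye(4))  # template-to-subject field
mapping = read_syn_mapping(forward_img, backward_img)

template_like = np.random.rand(*shape).astype(np.float32)
in_subject_space = mapping.transform(template_like)          # template grid -> subject grid
back_to_template = mapping.transform_inverse(in_subject_space)
print(in_subject_space.shape, np.allclose(back_to_template, template_like, atol=1e-4))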
@@ -407,33 +395,23 @@ class SlrMap(GeneratedMapMixin, Definition): def __init__(self, slr_kwargs=None): if slr_kwargs is None: slr_kwargs = {} - self.slr_kwargs = {} - self.use_prealign = False - self.extension = ".npy" + self.slr_kwargs = slr_kwargs def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, reg_template, reg_subject, subject_sls, template_sls, - reg_prealign, ): - return reg.slr_registration( - subject_sls, - template_sls, - moving_affine=reg_subject.affine, - moving_shape=reg_subject.shape, - static_affine=reg_template.affine, - static_shape=reg_template.shape, - **self.slr_kwargs, + _, transform, _, _ = whole_brain_slr( + subject_sls, template_sls, x0="affine", verbose=False, **self.slr_kwargs ) + return transform + -class AffMap(GeneratedMapMixin, Definition): +class AffMap(AffineMapMixin, Definition): """ Calculate an affine registration for each subject/session using reg_subject and reg_template. @@ -457,47 +435,39 @@ class AffMap(GeneratedMapMixin, Definition): def __init__(self, affine_kwargs=None): if affine_kwargs is None: affine_kwargs = {} - self.use_prealign = False self.affine_kwargs = affine_kwargs - self.extension = ".npy" def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, reg_subject, reg_template, subject_sls, template_sls, - reg_prealign, ): - return ConformedAffineMapping( - np.linalg.inv( - self.prealign( - base_fname, - sub_space, - tmpl_name, - reg_subject, - reg_template, - save=False, - ) - ), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) + return np.linalg.inv( + self.prealign(reg_subject, reg_template) + ) # TODO: test: this still needs to be inverted? -class ConformedAffineMapping(AffineMap): +class IdentityMap(AffineMapMixin, Definition): """ - Modifies AffineMap API to match DiffeomorphicMap API. - Important for SLR maps API to be indistinguishable from SYN maps API. + Does not perform any transformations from MNI to subject where + pyAFQ normally would. + + Examples + -------- + my_example_mapping = IdentityMap() + api.GroupAFQ(mapping=my_example_mapping) """ - def transform(self, *args, **kwargs): - return super().transform_inverse(*args, **kwargs) + def __init__(self): + pass - def transform_inverse(self, *args, **kwargs): - return super().transform(*args, **kwargs) + def gen_mapping( + self, + reg_subject, + reg_template, + subject_sls, + template_sls, + ): + return np.identity(4) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 1727bd0f..89b0ae24 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -155,10 +155,10 @@ def recognize( tg.to_vox() n_streamlines = len(tg) - bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) + bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.float32) bundle_to_flip = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) bundle_roi_closest = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.uint32 + (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.int32 ) fiber_groups = {} @@ -205,10 +205,20 @@ def recognize( "Conflicts in bundle assignment detected. " f"{conflicts} conflicts detected in total out of " f"{n_streamlines} total streamlines. 
" - "Defaulting to whichever bundle appears first " + "Defaulting to whichever bundle is closest to the include ROI," + "followed by whichever appears first " "in the bundle_dict." ) ) + + # Weight by distance to ROI + valid_dists = bundle_roi_closest != -1 + dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) + has_any_valid_roi = np.any(valid_dists, axis=2) + max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) + final_mask = (bundle_decisions > 0) & has_any_valid_roi + bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) + bundle_decisions = np.concatenate( (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 ) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index ef513f14..3fdc0f52 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -22,7 +22,7 @@ hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec") file_dict = afd.read_stanford_hardi_tractography() reg_template = afd.read_mni_template() -mapping = reg.read_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) +mapping = reg.read_old_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) streamlines = file_dict["tractography_subsampled.trk"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) tg.to_vox() diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 505f8300..8258dea9 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -129,15 +129,13 @@ def move_streamlines(tg, to, mapping, img, save_intermediates=None): tg.to_vox() moved_sl = [] for sl in tg.streamlines: - moved_sl.append(mapping.transform_inverse_pts(sl)) + moved_sl.append(mapping.transform_pts(sl)) else: tg.to_rasmm() if to == "template": - volume = mapping.forward + moved_sl = mapping.transform_points_inverse(tg.streamlines) else: - volume = mapping.backward - delta = dts.values_from_volume(volume, tg.streamlines, np.eye(4)) - moved_sl = dts.Streamlines([d + s for d, s in zip(delta, tg.streamlines)]) + moved_sl = mapping.transform_points(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) if save_intermediates is not None: save_tractogram( diff --git a/AFQ/registration.py b/AFQ/registration.py index f63e784b..7a61c907 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -4,11 +4,14 @@ import nibabel as nib import numpy as np -from dipy.align import syn_registration +from dipy.align.imaffine import AffineMap from dipy.align.imwarp import DiffeomorphicMap -from dipy.align.streamlinear import whole_brain_slr -__all__ = ["syn_register_dwi", "write_mapping", "read_mapping", "slr_registration"] +__all__ = [ + "read_affine_mapping", + "read_syn_mapping", + "read_old_mapping", +] def reduce_shape(shape): @@ -21,78 +24,55 @@ def reduce_shape(shape): return shape -def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs): +def read_syn_mapping(disp, codisp): """ - Register DWI data to a template. + Read a syn registration mapping from a nifti file Parameters - ----------- - dwi : nifti image or str - Image containing DWI data, or full path to a nifti file with DWI. - gtab : GradientTable - The gradients associated with the DWI data - template : nifti image or str, optional + ---------- + disp : str or Nifti1Image + If string, file must of an image or ndarray. 
+ If image, contains the mapping displacement field in each voxel + from subject to template - syn_kwargs : key-word arguments for :func:`syn_registration` + codisp : str or Nifti1Image + If string, file must of an image or ndarray. + If image, contains the mapping displacement field in each voxel + from template to subject Returns ------- - DiffeomorphicMap object + A :class:`DiffeomorphicMap` object """ - if template is None: - import AFQ.data.fetch as afd - - template = afd.read_mni_template() - if isinstance(template, str): - template = nib.load(template) - - template_data = template.get_fdata() - template_affine = template.affine - - if isinstance(dwi, str): - dwi = nib.load(dwi) - - dwi_affine = dwi.affine - dwi_data = dwi.get_fdata() - mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1) - warped_b0, mapping = syn_registration( - mean_b0, - template_data, - moving_affine=dwi_affine, - static_affine=template_affine, - **syn_kwargs, + if isinstance(disp, str): + disp = nib.load(disp) + + if isinstance(codisp, str): + codisp = nib.load(codisp) + + mapping = DiffeomorphicMap( + dim=3, + disp_shape=codisp.get_fdata().shape[:3], + disp_grid2world=None, + domain_shape=disp.get_fdata().shape[:3], + domain_grid2world=None, + codomain_shape=codisp.get_fdata().shape[:3], + codomain_grid2world=None, ) - return warped_b0, mapping - - -def write_mapping(mapping, fname): - """ - Write out a syn registration mapping to file + mapping.forward = disp.get_fdata().astype(np.float32) + mapping.backward = codisp.get_fdata().astype(np.float32) - Parameters - ---------- - mapping : a DiffeomorphicMap object derived from :func:`syn_registration` - fname : str - Full path to the nifti file storing the mapping - - """ - if isinstance(mapping, DiffeomorphicMap): - mapping_imap = np.array([mapping.forward.T, mapping.backward.T]).T - nib.save(nib.Nifti1Image(mapping_imap, mapping.codomain_world2grid), fname) - else: - np.save(fname, mapping.affine) + return mapping -def read_mapping(disp, domain_img, codomain_img, prealign=None): +def read_affine_mapping(affine, domain_img, codomain_img): """ Read a syn registration mapping from a nifti file Parameters ---------- - disp : str, Nifti1Image, or ndarray - If string, file must of an image or ndarray. - If image, contains the mapping displacement field in each voxel - Shape (x, y, z, 3, 2) + affine : str or ndarray + If string, file must of an ndarray. 
If ndarray, contains affine transformation used for mapping domain_img : str or Nifti1Image @@ -101,13 +81,10 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): Returns ------- - A :class:`DiffeomorphicMap` object + A :class:`AffineMap` object """ - if isinstance(disp, str): - if "nii.gz" in disp: - disp = nib.load(disp) - else: - disp = np.load(disp) + if isinstance(affine, str): + affine = np.load(affine) if isinstance(domain_img, str): domain_img = nib.load(domain_img) @@ -115,79 +92,59 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): if isinstance(codomain_img, str): codomain_img = nib.load(codomain_img) - if isinstance(disp, nib.Nifti1Image): - mapping = DiffeomorphicMap( - 3, - disp.shape[:3], - disp_grid2world=np.linalg.inv(disp.affine), - domain_shape=domain_img.shape[:3], - domain_grid2world=domain_img.affine, - codomain_shape=codomain_img.shape, - codomain_grid2world=codomain_img.affine, - prealign=prealign, - ) - - disp_data = disp.get_fdata().astype(np.float32) - mapping.forward = disp_data[..., 0] - mapping.backward = disp_data[..., 1] - mapping.is_inverse = True - else: - from AFQ.definitions.mapping import ConformedAffineMapping - - mapping = ConformedAffineMapping( - disp, - domain_grid_shape=reduce_shape(domain_img.shape), - domain_grid2world=domain_img.affine, - codomain_grid_shape=reduce_shape(codomain_img.shape), - codomain_grid2world=codomain_img.affine, - ) + mapping = AffineMap( + affine, + domain_grid_shape=reduce_shape(domain_img.shape), + domain_grid2world=domain_img.affine, + codomain_grid_shape=reduce_shape(codomain_img.shape), + codomain_grid2world=codomain_img.affine, + ) return mapping -def slr_registration( - moving_data, - static_data, - moving_affine=None, - static_affine=None, - moving_shape=None, - static_shape=None, - **kwargs, -): - """Register a source image (moving) to a target image (static). +def read_old_mapping(disp, domain_img, codomain_img): + """ + Warning: This is only used for pyAFQ tests and backwards compatibility. + Read old-style registration mapping from a nifti file. Parameters ---------- - moving : ndarray - The source tractography data to be registered - moving_affine : ndarray - The affine associated with the moving (source) data. - moving_shape : ndarray - The shape of the space associated with the static (target) data. - static : ndarray - The target tractography data for registration - static_affine : ndarray - The affine associated with the static (target) data. - static_shape : ndarray - The shape of the space associated with the static (target) data. - - **kwargs: - kwargs are passed into whole_brain_slr + disp : str or Nifti1Image + If string, file must of an image or ndarray. 
+ If image, contains the mapping displacement field in each voxel + Shape (x, y, z, 3, 2) + + domain_img : str or Nifti1Image + + codomain_img : str or Nifti1Image Returns ------- - AffineMap + A :class:`DiffeomorphicMap` object """ - from AFQ.definitions.mapping import ConformedAffineMapping + if isinstance(disp, str): + disp = nib.load(disp) - _, transform, _, _ = whole_brain_slr( - static_data, moving_data, x0="affine", verbose=False, **kwargs - ) + if isinstance(domain_img, str): + domain_img = nib.load(domain_img) + + if isinstance(codomain_img, str): + codomain_img = nib.load(codomain_img) - return ConformedAffineMapping( - transform, - codomain_grid_shape=reduce_shape(static_shape), - codomain_grid2world=static_affine, - domain_grid_shape=reduce_shape(moving_shape), - domain_grid2world=moving_affine, + mapping = DiffeomorphicMap( + 3, + disp.shape[:3], + disp_grid2world=np.linalg.inv(disp.affine), + domain_shape=domain_img.shape[:3], + domain_grid2world=domain_img.affine, + codomain_shape=codomain_img.shape, + codomain_grid2world=codomain_img.affine, ) + + disp_data = disp.get_fdata().astype(np.float32) + mapping.forward = disp_data[..., 0] + mapping.backward = disp_data[..., 1] + mapping.is_inverse = True + + return mapping diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index c593371a..724711ea 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -30,7 +30,7 @@ def export_registered_b0(base_fname, data_imap, mapping): ) if not op.exists(warped_b0_fname): mean_b0 = nib.load(data_imap["b0"]).get_fdata() - warped_b0 = mapping.transform(mean_b0) + warped_b0 = mapping.transform_inverse(mean_b0) warped_b0 = nib.Nifti1Image(warped_b0, data_imap["reg_template"].affine) logger.info(f"Saving {warped_b0_fname}") nib.save(warped_b0, warped_b0_fname) @@ -54,9 +54,7 @@ def template_xform(base_fname, dwi_data_file, data_imap, mapping): base_fname, f"_space-{subject_space}_desc-template_anat.nii.gz" ) if not op.exists(template_xform_fname): - template_xform = mapping.transform_inverse( - data_imap["reg_template"].get_fdata() - ) + template_xform = mapping.transform(data_imap["reg_template"].get_fdata()) template_xform = nib.Nifti1Image(template_xform, data_imap["dwi_affine"]) logger.info(f"Saving {template_xform_fname}") nib.save(template_xform, template_xform_fname) diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 065f8799..1580c24a 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -21,7 +21,7 @@ logger = logging.getLogger("AFQ") -def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): +def _viz_prepare_vol(vol, scalar_dict, ref): if vol in scalar_dict.keys(): vol = scalar_dict[vol] @@ -31,8 +31,6 @@ def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): vol = resample(vol, ref) vol = vol.get_fdata() - if xform: - vol = mapping.transform_inverse(vol) vol[np.isnan(vol)] = 0 return vol @@ -81,15 +79,12 @@ def viz_bundles( """ if sbv_lims_bundles is None: sbv_lims_bundles = [None, None] - mapping = mapping_imap["mapping"] scalar_dict = segmentation_imap["scalar_dict"] profiles_file = segmentation_imap["profiles"] t1_img = nib.load(structural_imap["t1_masked"]) shade_by_volume = get_tp(best_scalar, structural_imap, data_imap, tissue_imap) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, t1_img - ) - volume = _viz_prepare_vol(t1_img, False, mapping, scalar_dict, t1_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, t1_img) + volume = _viz_prepare_vol(t1_img, scalar_dict, t1_img) flip_axes = [False, 
False, False] for i in range(3): @@ -183,7 +178,6 @@ def viz_indivBundle( """ if sbv_lims_indiv is None: sbv_lims_indiv = [None, None] - mapping = mapping_imap["mapping"] bundle_dict = data_imap["bundle_dict"] scalar_dict = segmentation_imap["scalar_dict"] volume_img = nib.load(structural_imap["t1_masked"]) @@ -191,10 +185,8 @@ def viz_indivBundle( profiles = pd.read_csv(segmentation_imap["profiles"]) start_time = time() - volume = _viz_prepare_vol(volume_img, False, mapping, scalar_dict, volume_img) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, volume_img - ) + volume = _viz_prepare_vol(volume_img, scalar_dict, volume_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, volume_img) flip_axes = [False, False, False] for i in range(3): diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 9771ea68..5c5b1f8b 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -791,11 +791,6 @@ def test_AFQ_data_waypoint(): "sub-01_ses-01_desc-mapping_from-subject_to-mni_xform.nii.gz", ) nib.save(mapping, mapping_file) - reg_prealign_file = op.join( - myafq.export("output_dir"), - "sub-01_ses-01_desc-prealign_from-subject_to-mni_xform.npy", - ) - np.save(reg_prealign_file, np.eye(4)) # Test ROI exporting: myafq.export("rois") diff --git a/AFQ/tests/test_registration.py b/AFQ/tests/test_registration.py index c1b394c1..d1b3144c 100644 --- a/AFQ/tests/test_registration.py +++ b/AFQ/tests/test_registration.py @@ -5,16 +5,12 @@ import nibabel.tmpdirs as nbtmp import numpy as np import numpy.testing as npt -from dipy.align.imwarp import DiffeomorphicMap +from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr from dipy.io.streamline import load_tractogram import AFQ.data.fetch as afd -from AFQ.registration import ( - read_mapping, - slr_registration, - syn_register_dwi, - write_mapping, -) +from AFQ.registration import read_affine_mapping, reduce_shape MNI_T2 = afd.read_mni_template() hardi_img, gtab = dpd.read_stanford_hardi() @@ -50,27 +46,34 @@ def test_slr_registration(): hcp_atlas = load_tractogram(atlas_fname, "same", bbox_valid_check=False) with nbtmp.InTemporaryDirectory() as tmpdir: - mapping = slr_registration( + _, transform, _, _ = whole_brain_slr( streamlines, hcp_atlas.streamlines, - moving_affine=subset_b0_img.affine, - static_affine=subset_t2_img.affine, - moving_shape=subset_b0_img.shape, - static_shape=subset_t2_img.shape, + x0="affine", + verbose=False, progressive=False, greater_than=10, rm_small_clusters=1, rng=np.random.RandomState(seed=8), ) - warped_moving = mapping.transform(subset_b0) + + mapping = AffineMap( + transform, + domain_grid_shape=reduce_shape(subset_b0_img.shape), + domain_grid2world=subset_b0_img.affine, + codomain_grid_shape=reduce_shape(subset_t2_img.shape), + codomain_grid2world=subset_t2_img.affine, + ) + + warped_moving = mapping.transform_inverse(subset_b0) npt.assert_equal(warped_moving.shape, subset_t2.shape) mapping_fname = op.join(tmpdir, "mapping.npy") - write_mapping(mapping, mapping_fname) - file_mapping = read_mapping(mapping_fname, subset_b0_img, subset_t2_img) + np.save(mapping_fname, transform) + file_mapping = read_affine_mapping(mapping_fname, subset_b0_img, subset_t2_img) # Test that it has the same effect on the data: - warped_from_file = file_mapping.transform(subset_b0) + warped_from_file = file_mapping.transform_inverse(subset_b0) npt.assert_equal(warped_from_file, warped_moving) # Test that it is, attribute by attribute, identical: 
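# The rewritten test above wraps the 4x4 transform returned by whole_brain_slr
# in a DIPY AffineMap whose domain is the moving (subject) image and whose
# codomain is the static (template) image, so transform_inverse resamples
# subject data onto the template grid. A toy, self-contained illustration of
# that convention with identity affines; shapes and data are arbitrary and not
# taken from the test.
import numpy as np
from dipy.align.imaffine import AffineMap

subject_shape, template_shape = (10, 10, 10), (12, 12, 12)
toy_map = AffineMap(
    np.eye(4),
    domain_grid_shape=subject_shape,
    domain_grid2world=np.eye(4),
    codomain_grid_shape=template_shape,
    codomain_grid2world=np.eye(4),
)
subject_data = np.random.rand(*subject_shape).astype(np.float32)
template_data = np.random.rand(*template_shape).astype(np.float32)
print(toy_map.transform_inverse(subject_data).shape)  # sampled on the template grid: (12, 12, 12)
print(toy_map.transform(template_data).shape)         # sampled on the subject grid: (10, 10, 10)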
@@ -78,11 +81,3 @@ def test_slr_registration(): assert np.all( mapping.__getattribute__(k) == file_mapping.__getattribute__(k) ) - - -def test_syn_register_dwi(): - warped_b0, mapping = syn_register_dwi( - subset_dwi_data, gtab, template=subset_t2_img, radius=1 - ) - npt.assert_equal(isinstance(mapping, DiffeomorphicMap), True) - npt.assert_equal(warped_b0.shape, subset_t2_img.shape) diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index cea03655..782ac391 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -12,7 +12,7 @@ logger = logging.getLogger("AFQ") -def transform_inverse_roi(roi, mapping, bundle_name="ROI"): +def transform_roi(roi, mapping, bundle_name="ROI"): """ After being non-linearly transformed, ROIs tend to have holes in them. We perform a couple of computational geometry operations on the ROI to @@ -40,12 +40,12 @@ def transform_inverse_roi(roi, mapping, bundle_name="ROI"): if isinstance(roi, nib.Nifti1Image): roi = roi.get_fdata() - _roi = mapping.transform_inverse(roi, interpolation="linear") + _roi = mapping.transform(roi.astype(float), interpolation="linear") if np.sum(_roi) == 0: logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") _roi = binary_dilation(roi) - _roi = mapping.transform_inverse(_roi, interpolation="linear") + _roi = mapping.transform(_roi.astype(float), interpolation="linear") _roi = patch_up_roi(_roi > 0, bundle_name=bundle_name).astype(np.int32) diff --git a/AFQ/viz/fury_backend.py b/AFQ/viz/fury_backend.py index d8771f20..6ee4d0cf 100644 --- a/AFQ/viz/fury_backend.py +++ b/AFQ/viz/fury_backend.py @@ -203,11 +203,7 @@ def create_gif( def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, color=None, @@ -224,22 +220,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Default: None name: str, optional @@ -275,9 +257,7 @@ def visualize_roi( """ if color is None: color = np.array([1, 0, 0]) - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) for i, flip in enumerate(flip_axes): if flip: roi = np.flip(roi, axis=i) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 1af6acb4..07d467e0 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -535,11 +535,7 @@ def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, flip_axes=None, @@ -556,22 +552,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. 
- - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Default: None name: str, optional @@ -612,9 +594,7 @@ def visualize_roi( color = np.array([0.9999, 0, 0]) if flip_axes is None: flip_axes = [False, False, False] - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) if figure is None: figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]]) diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index 4ffe316a..1b82a306 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -13,9 +13,7 @@ from dipy.tracking.streamline import transform_streamlines from PIL import Image, ImageChops -import AFQ.registration as reg import AFQ.utils.streamlines as aus -import AFQ.utils.volume as auv __all__ = ["Viz"] @@ -591,9 +589,7 @@ def gif_from_pngs(tdir, gif_fname, n_frames, png_fname="tgif", add_zeros=False): io.mimsave(gif_fname, angles) -def prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template -): +def prepare_roi(roi, resample_to=None): """ Load the ROI Possibly perform a transformation on an ROI @@ -605,60 +601,27 @@ def prepare_roi( The ROI information. If str, ROI will be loaded using the str as a path. - affine_or_mapping : ndarray, Nifti1Image, or str - An affine transformation or mapping to apply to the ROI before - visualization. Default: no transform. - - static_img: str or Nifti1Image - Template to resample roi to. - - roi_affine: ndarray - - static_affine: ndarray - - reg_template: str or Nifti1Image - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. 
Returns ------- ndarray """ viz_logger.info("Preparing ROI...") + if isinstance(roi, str): + roi = nib.load(roi) + + if resample_to is not None: + if not isinstance(roi, nib.Nifti1Image): + raise ValueError( + ("If resampling, roi must be a Nifti1Image or a path to one.") + ) + roi = resample(roi, resample_to) + if not isinstance(roi, np.ndarray): - if isinstance(roi, str): - roi = nib.load(roi).get_fdata() - else: - roi = roi.get_fdata() - - if affine_or_mapping is not None: - if isinstance(affine_or_mapping, np.ndarray): - # This is an affine: - if static_img is None or roi_affine is None or static_affine is None: - raise ValueError( - "If using an affine to transform an ROI, " - "need to also specify all of the following", - "inputs: `static_img`, `roi_affine`, ", - "`static_affine`", - ) - roi = resample( - roi, static_img, moving_affine=roi_affine, static_affine=static_affine - ).get_fdata() - else: - # Assume it is a mapping: - if isinstance(affine_or_mapping, str) or isinstance( - affine_or_mapping, nib.Nifti1Image - ): - if reg_template is None or static_img is None: - raise ValueError( - "If using a mapping to transform an ROI, need to ", - "also specify all of the following inputs: ", - "`reg_template`, `static_img`", - ) - affine_or_mapping = reg.read_mapping( - affine_or_mapping, static_img, reg_template - ) + roi = roi.get_fdata() - roi = auv.transform_inverse_roi(roi, affine_or_mapping).astype(bool) return roi diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 22fc7542..0b8f0980 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -269,9 +269,9 @@ "NDARAA948VFH"]["HBNsiteRU"], index_col=[0]) for ind in bundle_counts.index: if ind == "Total Recognized": - threshold = 1000 - elif "Fronto-occipital" in ind or "Orbital" in ind: - threshold = 5 + threshold = 3000 + elif "Fronto-occipital" in ind: + threshold = 10 else: threshold = 15 if bundle_counts["n_streamlines"][ind] < threshold: From 2dd6049bc0f653cc52c173399f6371237345ae90 Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 11 Feb 2026 12:42:16 +0900 Subject: [PATCH 33/51] the transform points direction is opposite for some reason --- AFQ/api/group.py | 4 ++-- AFQ/recognition/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index b945dd05..0c57ba35 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -583,7 +583,7 @@ def load_next_subject(): these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - moved_sl = mapping.transform_points_inverse(tg.streamlines) + moved_sl = mapping.transform_points(tg.streamlines) moved_sl = np.asarray(moved_sl) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} @@ -1024,7 +1024,7 @@ def combine_bundle(self, bundle_name): mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - sls_mni = mapping.tranform_points(sls) + sls_mni = mapping.transform_points(sls) moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 8258dea9..57887d26 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -133,9 +133,9 @@ def move_streamlines(tg, to, mapping, img, save_intermediates=None): else: tg.to_rasmm() if to == "template": - moved_sl = 
mapping.transform_points_inverse(tg.streamlines)
-    else:
         moved_sl = mapping.transform_points(tg.streamlines)
+    else:
+        moved_sl = mapping.transform_points_inverse(tg.streamlines)
     moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM)
     if save_intermediates is not None:
         save_tractogram(

From b1878b19e6794671a61225c2eddf480dc3963d53 Mon Sep 17 00:00:00 2001
From: 36000
Date: Wed, 11 Feb 2026 13:22:53 +0900
Subject: [PATCH 34/51] verified

---
 AFQ/definitions/mapping.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py
index 17171faf..5cf326e4 100644
--- a/AFQ/definitions/mapping.py
+++ b/AFQ/definitions/mapping.py
@@ -444,9 +444,7 @@ def gen_mapping(
         subject_sls,
         template_sls,
     ):
-        return np.linalg.inv(
-            self.prealign(reg_subject, reg_template)
-        )  # TODO: test: this still needs to be inverted?
+        return np.linalg.inv(self.prealign(reg_subject, reg_template))
 
 
 class IdentityMap(AffineMapMixin, Definition):

From d9f62c12997081e1b5720c7050bbbd03144dd161 Mon Sep 17 00:00:00 2001
From: 36000
Date: Wed, 11 Feb 2026 21:06:53 +0900
Subject: [PATCH 35/51] this should not be the inverse

---
 AFQ/definitions/mapping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py
index 5cf326e4..f5c420e1 100644
--- a/AFQ/definitions/mapping.py
+++ b/AFQ/definitions/mapping.py
@@ -343,7 +343,7 @@ def get_for_subses(
             reg_template.get_fdata(),
             moving_affine=reg_subject.affine,
             static_affine=reg_template.affine,
-            prealign=np.linalg.inv(reg_prealign),
+            prealign=reg_prealign,
             **self.syn_kwargs,
         )
         mapping = get_simplified_transform(mapping)

From 59fea7660bfcf094a4e6a564d10db5d29ee41f2b Mon Sep 17 00:00:00 2001
From: 36000
Date: Thu, 12 Feb 2026 13:42:43 +0900
Subject: [PATCH 36/51] add ROI transformation fix

---
 AFQ/registration.py |  3 ++-
 AFQ/utils/volume.py | 10 +++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/AFQ/registration.py b/AFQ/registration.py
index 7a61c907..708492ca 100644
--- a/AFQ/registration.py
+++ b/AFQ/registration.py
@@ -103,7 +103,7 @@ def read_affine_mapping(affine, domain_img, codomain_img):
     return mapping
 
 
-def read_old_mapping(disp, domain_img, codomain_img):
+def read_old_mapping(disp, domain_img, codomain_img, prealign=None):
     """
     Warning: This is only used for pyAFQ tests and backwards compatibility.
     Read old-style registration mapping from a nifti file.
@@ -140,6 +140,7 @@ def read_old_mapping(disp, domain_img, codomain_img): domain_grid2world=domain_img.affine, codomain_shape=codomain_img.shape, codomain_grid2world=codomain_img.affine, + prealign=prealign, ) disp_data = disp.get_fdata().astype(np.float32) diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index 782ac391..a8269811 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -40,7 +40,15 @@ def transform_roi(roi, mapping, bundle_name="ROI"): if isinstance(roi, nib.Nifti1Image): roi = roi.get_fdata() - _roi = mapping.transform(roi.astype(float), interpolation="linear") + # dilate binary images to avoid losing small ROIs + if np.unique(roi).size < 3: + scale_factor = max( + np.asarray(mapping.codomain_shape) / np.asarray(mapping.domain_shape) + ) + for _ in range(max(np.ceil(scale_factor) - 1, 0).astype(int)): + roi = binary_dilation(roi) + + _roi = mapping.transform((roi.astype(float)), interpolation="linear") if np.sum(_roi) == 0: logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") From 258b492d31eccd2416466c82d6efb4c5de4913a9 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 14:25:46 +0900 Subject: [PATCH 37/51] cleanup some minor bugs from tests --- AFQ/recognition/recognize.py | 9 +++++---- AFQ/recognition/roi.py | 8 ++++++++ AFQ/recognition/tests/test_recognition.py | 2 +- AFQ/recognition/tests/test_rois.py | 10 ++++++---- AFQ/registration.py | 2 +- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 89b0ae24..f9a79e63 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -213,11 +213,12 @@ def recognize( # Weight by distance to ROI valid_dists = bundle_roi_closest != -1 - dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) has_any_valid_roi = np.any(valid_dists, axis=2) - max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) - final_mask = (bundle_decisions > 0) & has_any_valid_roi - bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) + if np.any(has_any_valid_roi): + dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) + max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) + final_mask = (bundle_decisions > 0) & has_any_valid_roi + bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) bundle_decisions = np.concatenate( (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 83e1cf53..474bc360 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -67,6 +67,14 @@ def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): ------- boolean array of streamlines that survive cleaning. 
""" + if not isinstance(fgarray, np.ndarray): + raise ValueError( + ( + "fgarray must be a numpy ndarray, you can resample " + "your streamlines using resample_tg in AFQ.recognition.utils" + ) + ) + n_sls, n_nodes, _ = fgarray.shape # handle target_idx negative values as wrapping around diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 3fdc0f52..c03f72b7 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -152,7 +152,7 @@ def test_segment_reco(): # This condition should still hold npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_R"]) > 0) + npt.assert_(len(fiber_groups["CST_L"]) > 0) def test_exclusion_ROI(): diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index f5140660..0f560115 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -4,6 +4,7 @@ from scipy.ndimage import distance_transform_edt import AFQ.recognition.roi as abr +import AFQ.recognition.utils as abu from AFQ.recognition.roi import check_sl_with_exclusion, check_sls_with_inclusion shape = (15, 15, 15) @@ -17,6 +18,7 @@ np.array([[1, 1, 1], [2, 1, 1], [3, 1, 1]]), np.array([[1, 1, 1], [2, 1, 1]]), ] +fgarray = np.array(abu.resample_tg(streamlines, 20)) roi1 = np.ones(shape, dtype=np.float32) roi1[1, 2, 3] = 0 @@ -43,15 +45,15 @@ def test_clean_by_endpoints(): - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 0, 0]) ) # If tol=1, the third streamline also gets included - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0, tol=1)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1, tol=1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0, tol=1)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1, tol=1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0]) ) diff --git a/AFQ/registration.py b/AFQ/registration.py index 708492ca..28b72799 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -146,6 +146,6 @@ def read_old_mapping(disp, domain_img, codomain_img, prealign=None): disp_data = disp.get_fdata().astype(np.float32) mapping.forward = disp_data[..., 0] mapping.backward = disp_data[..., 1] - mapping.is_inverse = True + mapping.is_inverse = False return mapping From c56819bc64028f52bb28522d9169a89aaacfe028 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 15:10:41 +0900 Subject: [PATCH 38/51] add documentation and tests for mixed bundle definitions --- AFQ/recognition/tests/test_recognition.py | 42 +++++++++++++++++++++++ docs/source/reference/bundledict.rst | 38 ++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index c03f72b7..95736aa2 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -9,6 +9,7 @@ from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.stats.analysis import afq_profile +import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd import AFQ.recognition.cleaning as 
abc import AFQ.registration as reg @@ -83,6 +84,47 @@ def test_segment(): npt.assert_equal(len(clean_sl), len(CST_R_sl)) +def test_segment_mixed_roi(): + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + bundle_info = { + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed", + } + } + + with pytest.raises( + ValueError, + match=( + "When using mixed ROI bundle definitions, and subject space ROIs, " + "resample_subject_to cannot be False." + ), + ): + fiber_groups, _ = recognize( + tg, nib.load(hardi_fdata), mapping, bundle_info, reg_template, 2 + ) + + bundle_info = abd.BundleDict(bundle_info, resample_subject_to=hardi_fdata) + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundle_info, + reg_template, + 2, + dist_to_atlas=10, + ) + + # We asked for 2 fiber groups: + npt.assert_equal(len(fiber_groups), 1) + OR_LV1_sl = fiber_groups["OR LV1"] + npt.assert_(len(OR_LV1_sl) == 2) + + @pytest.mark.nightly def test_segment_no_prob(): # What if you don't have probability maps? diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index 5927e818..43d73d19 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -86,6 +86,44 @@ relation to the Left Arcuate and Inferior Longitudinal fasciculi: 'Left Inferior Longitudinal': {'core': 'Left'}, } + +Mixed space ROIs +================ +Everywhere in the bundle dictionary where an ROI is specified as a path, +be it start, end, include, exclude, or probability map, you can in fact input +a dictionary instead. This dictionary should have two keys: +- 'roi' : path to the ROI Nifti file +- 'space' : either 'template' or 'subject', describing the space the ROI + +Then, for the whole bundle, set "space" to 'mixed'. This allows you to +specify some ROIs in template space and some in subject space for the same +bundle. For example: + +.. 
code-block:: python + import os.path as op + import AFQ.api.bundle_dict as abd + import AFQ.data.fetch as afd + + # First, organize the data + afd.organize_stanford_data() + bids_path = op.join(op.expanduser('~'), 'AFQ_data', 'stanford_hardi') + sub_path = op.join(bids_path, 'derivatives', 'vistasoft', 'sub-01', 'ses-01') + dwi_path = op.join(sub_path, 'dwi', 'sub-01_ses-01_dwi.nii.gz') + + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + # Then, prepare the bundle dictionary + bundle_info = abd.BundleDict({ + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed" + } + }, resample_subject_to=dwi_path) + + Filtering Order =============== When doing bundle recognition, streamlines are filtered out from the whole From eccf402722f80ad3175dcb2656f0709fd5f6bf80 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:10:03 +0900 Subject: [PATCH 39/51] actually transform points / transform inverse points makes sense --- AFQ/api/group.py | 7 +++---- AFQ/recognition/other_bundles.py | 12 +++++++++--- AFQ/recognition/utils.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 0c57ba35..8cafd125 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -583,7 +583,7 @@ def load_next_subject(): these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - moved_sl = mapping.transform_points(tg.streamlines) + moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sl = np.asarray(moved_sl) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} @@ -1019,14 +1019,13 @@ def combine_bundle(self, bundle_name): this_sub = self.valid_sub_list[ii] this_ses = self.valid_ses_list[ii] seg_sft = aus.SegmentedSFT.fromfile(bundles_dict[this_sub][this_ses]) - seg_sft.sft.to_vox() sls = seg_sft.get_bundle(bundle_name).streamlines mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - sls_mni = mapping.transform_points(sls) + sls_mni.extend(mapping.transform_points_inverse(sls)) - moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) + moved_sft = StatefulTractogram(sls_mni, reg_template, Space.RASMM) save_path = op.abspath( op.join(self.afq_path, f"bundle-{bundle_name}_subjects-all_MNI.trk") diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 3b04cee0..618daf45 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -76,9 +76,15 @@ def clean_by_overlap( ) if remove: - other_bundle_density_map = ( - other_bundle_density_map / other_bundle_density_map.max() - ) > other_bundle_min_density + max_val = other_bundle_density_map.max() + if max_val > 0: + other_bundle_density_map = ( + other_bundle_density_map / max_val + ) > other_bundle_min_density + else: + other_bundle_density_map = np.zeros_like( + other_bundle_density_map, dtype=bool + ) if project is not None: orientation = nib.orientations.aff2axcodes(img.affine) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 57887d26..8258dea9 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -133,9 +133,9 @@ def move_streamlines(tg, to, mapping, img, save_intermediates=None): else: tg.to_rasmm() if to == "template": - moved_sl = 
mapping.transform_points(tg.streamlines) - else: moved_sl = mapping.transform_points_inverse(tg.streamlines) + else: + moved_sl = mapping.transform_points(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) if save_intermediates is not None: save_tractogram( From 28af5781bbd45368e283aa8c26c8c126865ddadf Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:21:51 +0900 Subject: [PATCH 40/51] BFs --- AFQ/api/participant.py | 1 - 1 file changed, 1 deletion(-) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index 18415712..07730bc5 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -406,7 +406,6 @@ def _save_file(curr_img): x_pos = jj % size[0] _ii = jj // size[0] y_pos = _ii % size[1] - _ii = _ii // size[1] this_img = this_img_trimmed[jj].resize((max_width, max_height)) curr_img.paste(this_img, (x_pos * max_width, y_pos * max_height)) From 3148aa1e388cf435b99d98c5c5d99b5809c06e00 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:22:47 +0900 Subject: [PATCH 41/51] BFs --- AFQ/definitions/mapping.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index f5c420e1..ce9ee6ec 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -240,16 +240,16 @@ def get_for_subses( if not op.exists(mapping_file): start_time = time() - self.gen_mapping( - reg_template, + affine = self.gen_mapping( reg_subject, + reg_template, subject_sls, template_sls, ) total_time = time() - start_time logger.info(f"Saving {mapping_file}") - np.save(mapping_file, mapping.affine) + np.save(mapping_file, affine) meta = dict(type="affine", timing=total_time) if subject_sls is None: meta["dependent"] = "dwi" @@ -399,8 +399,8 @@ def __init__(self, slr_kwargs=None): def gen_mapping( self, - reg_template, reg_subject, + reg_template, subject_sls, template_sls, ): From 37083932ca73ea718cd1a2bfcea6ca45cc34b877 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:44:34 +0900 Subject: [PATCH 42/51] fix roi dist priority --- AFQ/recognition/criteria.py | 11 ++++++++++- AFQ/recognition/recognize.py | 8 ++++++-- AFQ/recognition/roi.py | 6 ++++-- AFQ/recognition/tests/test_rois.py | 9 +++++++-- AFQ/recognition/utils.py | 2 ++ 5 files changed, 29 insertions(+), 7 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index eaa224b3..8e58df9a 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -176,14 +176,16 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): ) roi_closest = -np.ones((max_includes, len(b_sls)), dtype=np.int32) + roi_dists = -np.ones((max_includes, len(b_sls)), dtype=np.float32) if flip_using_include: to_flip = np.ones_like(accept_idx, dtype=np.bool_) for sl_idx, inc_result in enumerate(inc_results): - sl_accepted, sl_closest = inc_result + sl_accepted, sl_closest, sl_dists = inc_result if sl_accepted: if len(sl_closest) > 1: roi_closest[: len(sl_closest), sl_idx] = sl_closest + roi_dists[: len(sl_dists), sl_idx] = sl_dists # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: # Flip sl if it is close to second ROI @@ -192,11 +194,13 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: roi_closest[: len(sl_closest), sl_idx] = np.flip(sl_closest) + roi_dists[: len(sl_dists), sl_idx] = np.flip(sl_dists) 
accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 b_sls.roi_closest = roi_closest.T + b_sls.roi_dists = roi_dists.T if flip_using_include: b_sls.reorient(to_flip) b_sls.select(accept_idx, "include") @@ -390,6 +394,7 @@ def run_bundle_rec_plan( bundle_idx, bundle_to_flip, bundle_roi_closest, + bundle_roi_dists, bundle_decisions, **segmentation_params, ): @@ -500,3 +505,7 @@ def check_space(roi): bundle_roi_closest[b_sls.selected_fiber_idxs, bundle_idx, :] = ( b_sls.roi_closest.copy() ) + if hasattr(b_sls, "roi_dists"): + bundle_roi_dists[b_sls.selected_fiber_idxs, bundle_idx, :] = ( + b_sls.roi_dists.copy() + ) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index f9a79e63..0d2c9711 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -160,6 +160,9 @@ def recognize( bundle_roi_closest = -np.ones( (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.int32 ) + bundle_roi_dists = -np.ones( + (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.float32 + ) fiber_groups = {} meta = {} @@ -180,6 +183,7 @@ def recognize( bundle_idx, bundle_to_flip, bundle_roi_closest, + bundle_roi_dists, bundle_decisions, clip_edges=clip_edges, n_cpus=n_cpus, @@ -212,10 +216,10 @@ def recognize( ) # Weight by distance to ROI - valid_dists = bundle_roi_closest != -1 + valid_dists = bundle_roi_dists > 0 has_any_valid_roi = np.any(valid_dists, axis=2) if np.any(has_any_valid_roi): - dist_sums = np.sum(np.where(valid_dists, bundle_roi_closest, 0), axis=2) + dist_sums = np.sum(np.where(valid_dists, bundle_roi_dists, 0), axis=2) max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) final_mask = (bundle_decisions > 0) & has_any_valid_roi bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 474bc360..4ae02310 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -11,21 +11,23 @@ def check_sls_with_inclusion(sls, include_rois, include_roi_tols): include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): closest = np.zeros(len(include_rois), dtype=np.int32) + dists = np.zeros(len(include_rois), dtype=np.float32) sl = np.asarray(sl) valid = True for ii, roi in enumerate(include_rois): dist = interpolate_scalar_3d(roi, sl)[0] closest[ii] = np.argmin(dist) + dists[ii] = dist[closest[ii]] if dist[closest[ii]] > include_roi_tols[ii]: # Too far from one of them: - inc_results[jj] = (False, []) + inc_results[jj] = (False, [], []) valid = False break # Checked all the ROIs and it was close to all of them if valid: - inc_results[jj] = (True, closest) + inc_results[jj] = (True, closest, dists) return inc_results diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index 0f560115..67ef384e 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -65,22 +65,27 @@ def test_check_sls_with_inclusion(): assert result[0][0] is True assert np.allclose(result[0][1][0], 0) assert np.allclose(result[0][1][1], 2) + assert np.allclose(result[0][2][0], 0) + assert np.allclose(result[0][2][1], 0) assert result[1][0] is False def test_check_sl_with_inclusion_pass(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline1], include_rois, include_roi_tols )[0] assert result is True assert len(dists) == 2 + assert np.allclose(dist_idxs[0], 0) + assert np.allclose(dist_idxs[1], 2) def 
test_check_sl_with_inclusion_fail(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline2], include_rois, include_roi_tols )[0] assert result is False + assert dist_idxs == [] assert dists == [] diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 8258dea9..5c751d1b 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -184,6 +184,8 @@ def select(self, idx, clean_name, cut=False): self.sls_flipped = self.sls_flipped[idx] if hasattr(self, "roi_closest"): self.roi_closest = self.roi_closest[idx] + if hasattr(self, "roi_dists"): + self.roi_dists = self.roi_dists[idx] time_taken = time() - self.start_time self.logger.info( f"After filtering by {clean_name} (time: {time_taken}s), " From 0a3a0cc58cd479d61ab6fcc1cd39522c2e671696 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:46:29 +0900 Subject: [PATCH 43/51] tweaks --- AFQ/recognition/criteria.py | 2 +- docs/source/reference/bundledict.rst | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 8e58df9a..7b6f8134 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -111,7 +111,7 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): def length(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("length") + b_sls.initiate_selection("length") min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index 43d73d19..581e662b 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -94,12 +94,14 @@ be it start, end, include, exclude, or probability map, you can in fact input a dictionary instead. This dictionary should have two keys: - 'roi' : path to the ROI Nifti file - 'space' : either 'template' or 'subject', describing the space the ROI + is currently in. Then, for the whole bundle, set "space" to 'mixed'. This allows you to specify some ROIs in template space and some in subject space for the same bundle. For example: .. 
code-block:: python + import os.path as op import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd From 43513169a2782b16f65d31d0ee5bd31f0b6e019b Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 17:55:22 +0900 Subject: [PATCH 44/51] small BF --- AFQ/recognition/recognize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 0d2c9711..8fb02fec 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -216,7 +216,7 @@ def recognize( ) # Weight by distance to ROI - valid_dists = bundle_roi_dists > 0 + valid_dists = bundle_roi_dists >= -0.5 # i.e., not -1 has_any_valid_roi = np.any(valid_dists, axis=2) if np.any(has_any_valid_roi): dist_sums = np.sum(np.where(valid_dists, bundle_roi_dists, 0), axis=2) From a6909496811ea0cedf5be96ac363da5a52eaf005 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 12 Feb 2026 23:20:15 +0900 Subject: [PATCH 45/51] put this back --- AFQ/recognition/tests/test_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 95736aa2..66c3133a 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -194,7 +194,7 @@ def test_segment_reco(): # This condition should still hold npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_L"]) > 0) + npt.assert_(len(fiber_groups["CST_R"]) > 0) def test_exclusion_ROI(): From 57fe0a71541352384d77fc7aa60cc60be29b34ec Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 14 Feb 2026 12:53:46 +0900 Subject: [PATCH 46/51] implement moving streamlines with new mapping system --- AFQ/api/group.py | 23 ++++++++++----- AFQ/recognition/criteria.py | 5 ++-- AFQ/recognition/utils.py | 41 -------------------------- AFQ/utils/streamlines.py | 58 ++++++++++++++++++++++++++++++++++++- 4 files changed, 76 insertions(+), 51 deletions(-) diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 8cafd125..aaaea026 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -554,9 +554,12 @@ def load_next_subject(): this_bundles_file = self.export("bundles", collapse=False)[sub][ses] this_mapping = self.export("mapping", collapse=False)[sub][ses] this_img = self.export("dwi", collapse=False)[sub][ses] + this_reg_template = self.export("reg_template", collapse=False)[sub][ + ses + ] seg_sft = aus.SegmentedSFT.fromfile(this_bundles_file, this_img) seg_sft.sft.to_rasmm() - subses_info.append((seg_sft, this_mapping)) + subses_info.append((seg_sft, this_mapping, this_img, this_reg_template)) bundle_dict = self.export("bundle_dict", collapse=False)[ self.valid_sub_list[0] @@ -566,7 +569,7 @@ def load_next_subject(): load_next_subject() # load first subject for b in bundle_dict.bundle_names: for i in range(len(self.valid_sub_list)): - seg_sft, mapping = subses_info[i] + seg_sft, mapping, img, reg_template = subses_info[i] idx = seg_sft.bundle_idxs[b] # use the first subses that works # otherwise try each successive subses @@ -582,9 +585,11 @@ def load_next_subject(): idx = np.random.choice(idx, size=100, replace=False) these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) - tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - moved_sl = mapping.transform_points_inverse(tg.streamlines) - moved_sl = np.asarray(moved_sl) + tg = StatefulTractogram(these_sls, img, Space.RASMM) + moved_sl = aus.move_streamlines( + tg, "template", 
mapping, reg_template + ) + moved_sl = np.asarray(moved_sl.streamlines) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} for ii, sl_idx in enumerate(idx): @@ -1019,11 +1024,15 @@ def combine_bundle(self, bundle_name): this_sub = self.valid_sub_list[ii] this_ses = self.valid_ses_list[ii] seg_sft = aus.SegmentedSFT.fromfile(bundles_dict[this_sub][this_ses]) - sls = seg_sft.get_bundle(bundle_name).streamlines + sls = seg_sft.get_bundle(bundle_name) mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - sls_mni.extend(mapping.transform_points_inverse(sls)) + sls_mni.extend( + aus.move_streamlines( + sls, "template", mapping, reg_template + ).streamlines + ) moved_sft = StatefulTractogram(sls_mni, reg_template, Space.RASMM) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 7b6f8134..c931dcbf 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -20,6 +20,7 @@ import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict from AFQ.utils.stats import chunk_indices +from AFQ.utils.streamlines import move_streamlines criteria_order_pre_other_bundles = [ "prob_map", @@ -218,7 +219,7 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): ref_sl = load_tractogram( bundle_def["curvature"]["path"], "same", bbox_valid_check=False ) - moved_ref_sl = abu.move_streamlines( + moved_ref_sl = move_streamlines( ref_sl, "subject", mapping, img, save_intermediates=save_intermediates ) moved_ref_sl.to_vox() @@ -264,7 +265,7 @@ def recobundles( **kwargs, ): b_sls.initiate_selection("Recobundles") - moved_sl = abu.move_streamlines( + moved_sl = move_streamlines( StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), "template", mapping, diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 5c751d1b..0a60552d 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -9,8 +9,6 @@ from dipy.io.streamline import save_tractogram from dipy.tracking.distances import bundles_distances_mdf -from AFQ.definitions.mapping import ConformedFnirtMapping - logger = logging.getLogger("AFQ") @@ -108,45 +106,6 @@ def orient_by_streamline(sls, template_sl): return DM[:, 0] > DM[:, 1] -def move_streamlines(tg, to, mapping, img, save_intermediates=None): - """Move streamlines to or from template space. - - to : str - Either "template" or "subject". - mapping : ConformedMapping - Mapping to use to move streamlines. - img : Nifti1Image - Space to move streamlines to. - """ - tg_og_space = tg.space - if isinstance(mapping, ConformedFnirtMapping): - if to != "subject": - raise ValueError( - "Attempted to transform streamlines to template using " - "unsupported mapping. " - "Use something other than Fnirt." 
- ) - tg.to_vox() - moved_sl = [] - for sl in tg.streamlines: - moved_sl.append(mapping.transform_pts(sl)) - else: - tg.to_rasmm() - if to == "template": - moved_sl = mapping.transform_points_inverse(tg.streamlines) - else: - moved_sl = mapping.transform_points(tg.streamlines) - moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) - if save_intermediates is not None: - save_tractogram( - moved_sft, - op.join(save_intermediates, f"sls_in_{to}.trk"), - bbox_valid_check=False, - ) - tg.to_space(tg_og_space) - return moved_sft - - def resample_tg(tg, n_points): # reformat for dipy's set_number_of_points if isinstance(tg, np.ndarray): diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 571534a1..52c7051c 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -2,7 +2,7 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram -from dipy.io.streamline import load_tractogram +from dipy.io.streamline import load_tractogram, save_tractogram try: from trx.io import load as load_trx @@ -11,6 +11,7 @@ except ModuleNotFoundError: has_trx = False +from AFQ.definitions.mapping import ConformedFnirtMapping from AFQ.utils.path import drop_extension, read_json @@ -137,3 +138,58 @@ def split_streamline(streamlines, sl_to_split, split_idx): ) return streamlines + + +def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=None): + """Move streamlines to or from template space. + + to : str + Either "template" or "subject". This determines + whether we will use the forward or backwards displacement field. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + img : Nifti1Image + Image defining reference for where the streamlines move to. + to_space : Space or None + If not None, space to move streamlines to after moving them to the + template or subject space. If None, streamlines will be moved back to + their original space. + Default: None. + save_intermediates : str or None + If not None, path to save intermediate tractogram after moving to template + or subject space. + Default: None. + """ + tg_og_space = tg.space + if isinstance(mapping, ConformedFnirtMapping): + if to != "subject": + raise ValueError( + "Attempted to transform streamlines to template using " + "unsupported mapping. " + "Use something other than Fnirt." 
+ ) + tg.to_vox() + moved_sl = [] + for sl in tg.streamlines: + moved_sl.append(mapping.transform_pts(sl)) + moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) + else: + tg.to_vox() + if to == "template": + moved_sl = mapping.transform_points(tg.streamlines) + else: + moved_sl = mapping.transform_points_inverse(tg.streamlines) + moved_sft = StatefulTractogram(moved_sl, img, Space.VOX) + moved_sft.to_rasmm() + + if save_intermediates is not None: + save_tractogram( + moved_sft, + op.join(save_intermediates, f"sls_in_{to}.trk"), + bbox_valid_check=False, + ) + if to_space is None: + tg.to_space(tg_og_space) + else: + tg.to_space(to_space) + return moved_sft From f18713ab8524f14c73fe0159cbc2a95492a27e4f Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 14 Feb 2026 23:41:16 +0900 Subject: [PATCH 47/51] add ORG VOF subclusters --- .codespellrc | 2 +- AFQ/api/bundle_dict.py | 211 ++++++++++++++++-- AFQ/definitions/mapping.py | 3 + AFQ/recognition/cleaning.py | 3 + AFQ/recognition/clustering.py | 195 ++++++++++++++++ AFQ/recognition/criteria.py | 135 +++++++---- AFQ/recognition/preprocess.py | 12 +- AFQ/recognition/recognize.py | 98 ++++---- AFQ/recognition/sparse_decisions.py | 116 ++++++++++ AFQ/recognition/utils.py | 25 +++ AFQ/tasks/viz.py | 3 + docs/source/references.bib | 10 + .../plot_001_group_afq_api.py | 4 +- 13 files changed, 695 insertions(+), 122 deletions(-) create mode 100644 AFQ/recognition/clustering.py create mode 100644 AFQ/recognition/sparse_decisions.py diff --git a/.codespellrc b/.codespellrc index 7035fe17..fdf6f1d6 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] skip = [setup.cfg] -ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus \ No newline at end of file +ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo \ No newline at end of file diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 2c7497bb..51f4a0c6 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -297,12 +297,75 @@ def default_bd(): "project": "L/R", "entire_core": "Anterior", }, - # "Left Inferior Longitudinal": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, - "isolation_forest": {}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", "primary_axis_percentage": 40, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Left Vertical Occipital I": { + "cluster_ID": 89, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital II": { + "cluster_ID": 82, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital III": { + "cluster_ID": 83, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital IV": { + "cluster_ID": 21, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Left Vertical Occipital V": { + "cluster_ID": 454, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + }, + remove_cluster_IDs=[ + 27, + 100, + 4, + 6, + 13, + 17, + 22, + 23, + 38, + 48, + 50, + 53, + 64, + 65, + 66, + 84, + 87, + 88, + 98, + ], + ), }, "Right Vertical Occipital": { "cross_midline": False, @@ -314,15 +377,84 @@ def default_bd(): "project": "L/R", "entire_core": 
"Anterior", }, - # "Right Inferior Longitudinal": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 2, "clean_rounds": 1}, "length": {"min_len": 25, "max_len": 60}, - "isolation_forest": {}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", "primary_axis_percentage": 40, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Right Vertical Occipital I": { + "cluster_ID": 89, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital II": { + "cluster_ID": 82, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital III": { + "cluster_ID": 83, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital IV": { + "cluster_ID": 21, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + "Right Vertical Occipital V": { + "cluster_ID": 454, + "isolation_forest": {}, + "orient_mahal": { + "distance_threshold": 4, + "clean_rounds": 3, + }, + }, + }, + remove_cluster_IDs=[ + 27, + 100, + 4, + 6, + 13, + 17, + 22, + 23, + 38, + 48, + 50, + 53, + 64, + 65, + 66, + 84, + 87, + 88, + 98, + ], + ), }, }, - citations={"Yeatman2012", "takemura2017occipital", "Tzourio-Mazoyer2002"}, + citations={ + "Yeatman2012", + "takemura2017occipital", + "Tzourio-Mazoyer2002", + "zhang2018anatomically", + "Hua2008", + }, ) @@ -334,31 +466,49 @@ def slf_bd(): "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, + "Left Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, + "Left Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Left Superior Longitudinal III": { "include": [templates["PrgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, + "Left Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, + "Right Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, + "Right Cingulum Cingulate": { + "node_thresh": 20, + }, }, "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, + "Right Cingulum Cingulate": { + "node_thresh": 20, + }, }, }, citations={"Sagi2024"}, @@ -1043,7 +1193,6 @@ def __init__( self.resample_to = resample_to self.resample_subject_to = resample_subject_to self.keep_in_memory = keep_in_memory - self.max_includes = 3 self.citations = citations if self.citations is None: self.citations = set() @@ -1097,10 +1246,6 @@ def __init__( def __print__(self): print(self._dict) - def update_max_includes(self, new_max): - if new_max > self.max_includes: - self.max_includes = new_max - def _use_bids_info(self, roi_or_sl, bids_layout, bids_path, subject, session): if isinstance(roi_or_sl, dict) and "roi" not in roi_or_sl: suffix = roi_or_sl.get("suffix", "dwi") @@ -1203,8 +1348,6 @@ def __getitem__(self, key): def __setitem__(self, key, item): self._dict[key] = item - if hasattr(item, "get"): - 
self.update_max_includes(len(item.get("include", []))) if key not in self.bundle_names: self.bundle_names.append(key) @@ -1464,6 +1607,46 @@ def __add__(self, other): ) +class SpectralSubbundleDict(BundleDict): + """ + A BundleDict where each bundle is defined as a spectral subbundle of a + larger bundle. See `Defining Custom Bundle Dictionaries` in the documentation + for details. + """ + + def __init__( + self, + bundle_info, + resample_to=None, + resample_subject_to=False, + keep_in_memory=False, + citations=None, + remove_cluster_IDs=None, + ): + super().__init__( + bundle_info, resample_to, resample_subject_to, keep_in_memory, citations + ) + if remove_cluster_IDs is None: + remove_cluster_IDs = [] + self.remove_cluster_IDs = remove_cluster_IDs + self.cluster_IDs = [] + self.id_to_name = {} + for b_name, b_info in bundle_info.items(): + if "cluster_ID" not in b_info: + raise ValueError( + ( + f"Bundle {b_name} does not have a cluster_ID. " + "All bundles in a SpectralSubbundleDict must have a cluster_ID." + ) + ) + self.cluster_IDs.append(b_info["cluster_ID"]) + self.id_to_name[b_info["cluster_ID"]] = b_name + self.all_cluster_IDs = self.remove_cluster_IDs + self.cluster_IDs + + def get_subbundle_name(self, cluster_id): + return self.id_to_name.get(cluster_id, None) + + def apply_to_roi_dict( dict_, func, diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index ce9ee6ec..ed5ca043 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -212,6 +212,7 @@ def get_fnames(self, extension, base_fname, sub_name, tmpl_name): return mapping_file, meta_fname def prealign(self, reg_subject, reg_template): + logger.info("Calculating affine pre-alignment...") _, aff = affine_registration(reg_subject, reg_template, **self.affine_kwargs) return aff @@ -338,6 +339,8 @@ def get_for_subses( reg_prealign = self.prealign(reg_subject, reg_template) else: reg_prealign = None + + logger.info("Calculating SyN registration...") _, mapping = syn_registration( reg_subject.get_fdata(), reg_template.get_fdata(), diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index db741fe0..59d3687f 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -101,6 +101,7 @@ def clean_by_orientation_mahalanobis( if np.sum(idx_dist) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) @@ -231,6 +232,7 @@ def clean_bundle( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) @@ -360,6 +362,7 @@ def clean_by_isolation_forest( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(-sl_outliers)[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) diff --git a/AFQ/recognition/clustering.py b/AFQ/recognition/clustering.py new file mode 100644 index 00000000..2d33b6c3 --- /dev/null +++ b/AFQ/recognition/clustering.py @@ -0,0 +1,195 @@ +# Original source: github.com/SlicerDMRI/whitematteranalysis +# Copyright 2026 BWH and 3D Slicer contributors +# Licensed under 3D Slicer license (BSD style; https://github.com/SlicerDMRI/whitematteranalysis/blob/master/License.txt) # noqa +# Modified by 
John Kruper for pyAFQ +# Modifications: +# 1. Only mean distance included, and mean distance replaced with numba version. +# 2. Uses atlas data from dictionary and numpy files rather than pickled files, +# to avoid additional dependencies. +# 3. Added function to move template streamlines +# to subject space to calculate distances. + +import numpy as np +import scipy +from dipy.io.stateful_tractogram import Space +from numba import njit, prange + +import AFQ.data.fetch as afd +import AFQ.recognition.utils as abu +import AFQ.utils.streamlines as aus + + +@njit(parallel=True) +def _compute_mean_euclidean_matrix(group_n, group_m): + len_n = group_n.shape[0] + len_m = group_m.shape[0] + num_points = group_n.shape[1] + + dist_matrix = np.empty((len_n, len_m), dtype=np.float64) + + for i in prange(len_n): + for j in range(len_m): + sum_dist = 0.0 + sum_dist_ref = 0.0 + + for k in range(num_points): + dx = group_n[i, k, 0] - group_m[j, k, 0] + dx_ref = group_n[i, k, 0] + group_m[j, k, 0] + dy = group_n[i, k, 1] - group_m[j, k, 1] + dz = group_n[i, k, 2] - group_m[j, k, 2] + + sum_dist += np.sqrt(dx * dx + dy * dy + dz * dz) + sum_dist_ref += np.sqrt(dx_ref * dx_ref + dy * dy + dz * dz) + + mean_d = sum_dist / num_points + mean_d_ref = sum_dist_ref / num_points + + final_d = min(mean_d, mean_d_ref) + dist_matrix[i, j] = final_d * final_d + + return dist_matrix.T + + +def _distance_to_similarity(distance, sigmasq): + similarities = np.exp(-distance / (sigmasq)) + + return similarities + + +def _rectangular_similarity_matrix(fgarray_sub, fgarray_atlas, sigma): + distances = _compute_mean_euclidean_matrix(fgarray_sub, fgarray_atlas) + + sigmasq = sigma * sigma + similarity_matrix = _distance_to_similarity(distances, sigmasq) + + return similarity_matrix + + +def spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=None, + sigma_multiplier=1.0, + cluster_indices=None, +): + """ + Use an existing atlas to label a new streamlines. + + Parameters + ---------- + sub_fgarray : ndarray + Resampled fiber group to be labeled. + atlas_fgarray : ndarray + Resampled atlas to use for labelling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. + sigma_multiplier : float, optional + Multiplier for the sigma value used in computing the similarity + matrix. Default is 1.0. + cluster_indices : list of int, optional + If provided, only these cluster indices from the atlas will be used + for labeling. Default is None, which uses all clusters. + + Returns + ------- + tuple of (ndarray, ndarray) + Cluster indices for all the fibers and their embedding + """ + if atlas_data is None: + atlas_data = afd.read_org800_templates(load_trx=False) + + number_fibers = sub_fgarray.shape[0] + sz = atlas_fgarray.shape[0] + + # Compute fiber similarities. + B = _rectangular_similarity_matrix( + sub_fgarray, atlas_fgarray, sigma=atlas_data["sigma"] * sigma_multiplier + ) + + # Do Normalized Cuts transform of similarity matrix. + # row sum estimate for current B part of the matrix + row_sum_2 = np.sum(B, axis=0) + np.dot(atlas_data["row_sum_matrix"], B) + + # This happens plenty in our cases. Why? + # Maybe a probabilistic vs UKF thing? + # In practice, this is not an issue since we just set to a small value. 
+    if any(row_sum_2 <= 0):
+        row_sum_2[row_sum_2 < 0] = 1e-4
+
+    # Normalized cuts normalization
+    row_sum = np.concatenate((atlas_data["row_sum_1"], row_sum_2))
+    dhat = np.sqrt(np.divide(1, row_sum))
+    B = np.multiply(B, np.outer(dhat[0:sz], dhat[sz:].T))
+
+    # Compute embedding using eigenvectors
+    V = np.dot(
+        np.dot(B.T, atlas_data["e_vec"]), np.diag(np.divide(1.0, atlas_data["e_val"]))
+    )
+    V = np.divide(V, atlas_data["e_vec_norm"])
+    n_eigen = int(atlas_data["number_of_eigenvectors"])
+    embed = np.zeros((number_fibers, n_eigen))
+    for i in range(0, n_eigen):
+        embed[:, i] = np.divide(V[:, -(i + 2)], V[:, -1])
+
+    # Label streamlines using centroids from atlas
+    if cluster_indices is not None:
+        centroids = atlas_data["centroids"][cluster_indices, :]
+        cluster_idx, _ = scipy.cluster.vq.vq(embed, centroids)
+        cluster_idx = np.array([cluster_indices[i] for i in cluster_idx])
+    else:
+        cluster_idx, _ = scipy.cluster.vq.vq(embed, atlas_data["centroids"])
+
+    return cluster_idx, embed
+
+
+def subcluster_by_atlas(
+    sub_trk, mapping, dwi_ref, cluster_indices, atlas_data=None, n_points=20
+):
+    """
+    Use an existing atlas to label a new set of streamlines, and return the
+    cluster indices for each streamline.
+
+    Parameters
+    ----------
+    sub_trk : StatefulTractogram
+        Tractogram containing the fiber group to be labeled.
+    mapping : DIPY or pyAFQ mapping
+        Mapping to use to move streamlines.
+    dwi_ref : Nifti1Image
+        Image defining reference for where the atlas streamlines move to.
+    cluster_indices : list of int
+        Cluster indices from the atlas to use for labeling.
+    atlas_data : dict, optional
+        Precomputed atlas data formatted as a dictionary of arrays and floats.
+        See `afd.read_org800_templates` as a reference.
+    n_points : int, optional
+        Number of points to resample streamlines to for labeling. Default is 20.
+ """ + + if atlas_data is None: + atlas_data = afd.read_org800_templates() + atlas_sft = atlas_data["tracks_reoriented"] + + moved_atlas_sft = aus.move_streamlines( + atlas_sft, "subject", mapping, dwi_ref, to_space=Space.RASMM + ) + atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points)) + + # Note: if we need more efficiency, + # we could modify the code to consider: + # voxel size, midline axis, and midline location + # then we should be able to do these calculations in + # voxel space without having to move the subject streamlines + # to rasmm (but this is not a bottleneck right now) + sub_trk.to_rasmm() + sub_fgarray = np.array(abu.resample_tg(sub_trk.streamlines, n_points)) + + cluster_idxs, _ = spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=atlas_data, + cluster_indices=cluster_indices, + ) + + return cluster_idxs diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index c931dcbf..3f454dcc 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -19,6 +19,7 @@ import AFQ.recognition.roi as abr import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict +from AFQ.recognition.clustering import subcluster_by_atlas from AFQ.utils.stats import chunk_indices from AFQ.utils.streamlines import move_streamlines @@ -45,6 +46,8 @@ "primary_axis_percentage", "inc_addtol", "exc_addtol", + "ORG_spectral_subbundles", + "cluster_ID", ] @@ -138,7 +141,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): b_sls.select(accept_idx, "orientation") -def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): +def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet @@ -176,17 +179,18 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols ) - roi_closest = -np.ones((max_includes, len(b_sls)), dtype=np.int32) - roi_dists = -np.ones((max_includes, len(b_sls)), dtype=np.float32) + n_inc = len(bundle_def["include"]) + roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32) + roi_dists = np.zeros((n_inc, len(b_sls)), dtype=np.float32) if flip_using_include: to_flip = np.ones_like(accept_idx, dtype=np.bool_) for sl_idx, inc_result in enumerate(inc_results): sl_accepted, sl_closest, sl_dists = inc_result if sl_accepted: + roi_closest[:, sl_idx] = sl_closest + roi_dists[:, sl_idx] = sl_dists if len(sl_closest) > 1: - roi_closest[: len(sl_closest), sl_idx] = sl_closest - roi_dists[: len(sl_dists), sl_idx] = sl_dists # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: # Flip sl if it is close to second ROI @@ -194,8 +198,8 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): if flip_using_include: to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: - roi_closest[: len(sl_closest), sl_idx] = np.flip(sl_closest) - roi_dists[: len(sl_dists), sl_idx] = np.flip(sl_dists) + roi_closest[:, sl_idx] = np.flip(sl_closest) + roi_dists[:, sl_idx] = np.flip(sl_dists) accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 @@ -301,14 +305,15 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): def clean_by_other_bundle( - b_sls, bundle_def, img, preproc_imap, other_bundle_name, other_bundle_sls, **kwargs + b_sls, bundle_def, img, 
other_bundle_name, other_bundle_sls, **kwargs ): cleaned_idx = b_sls.initiate_selection(other_bundle_name) cleaned_idx = 1 + flipped_sls = b_sls.get_selected_sls(flip=True) if "overlap" in bundle_def[other_bundle_name]: cleaned_idx_overlap = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["overlap"], img, @@ -319,7 +324,7 @@ def clean_by_other_bundle( if "node_thresh" in bundle_def[other_bundle_name]: cleaned_idx_node_thresh = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], img, @@ -331,8 +336,7 @@ def clean_by_other_bundle( if "core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - # the extra specificity of 100 points is needed + np.array(abu.resample_tg(flipped_sls, 100)), np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, @@ -342,8 +346,8 @@ def clean_by_other_bundle( if "entire_core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(flipped_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, True, ) @@ -392,11 +396,7 @@ def run_bundle_rec_plan( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_roi_dists, - bundle_decisions, + recognized_bundles_dict, **segmentation_params, ): # Warp ROIs @@ -430,20 +430,25 @@ def check_space(roi): ) logger.info(f"Time to prep ROIs: {time() - start_time}s") - b_sls = abu.SlsBeingRecognized( - tg.streamlines, - logger, - segmentation_params["save_intermediates"], - bundle_name, - img, - len(bundle_def.get("include", [])), - ) + if isinstance(tg, abu.SlsBeingRecognized): + # This only occurs when your inside a subbundle, + # in which case we want to keep the same SlsBeingRecognized object so that + # we can keep track of the same streamlines and their orientations + b_sls = tg + else: + b_sls = abu.SlsBeingRecognized( + tg.streamlines, + logger, + segmentation_params["save_intermediates"], + bundle_name, + img, + len(bundle_def.get("include", [])), + ) inputs = {} inputs["b_sls"] = b_sls inputs["preproc_imap"] = preproc_imap inputs["bundle_def"] = bundle_def - inputs["max_includes"] = bundle_dict.max_includes inputs["mapping"] = mapping inputs["img"] = img inputs["reg_template"] = reg_template @@ -454,7 +459,7 @@ def check_space(roi): if ( (potential_criterion not in criteria_order_post_other_bundles) and (potential_criterion not in criteria_order_pre_other_bundles) - and (potential_criterion not in bundle_dict.bundle_names) + and (potential_criterion not in recognized_bundles_dict.keys()) and (potential_criterion not in valid_noncriterion) ): raise ValueError( @@ -464,7 +469,7 @@ def check_space(roi): "Valid criteria are:\n" f"{criteria_order_pre_other_bundles}\n" f"{criteria_order_post_other_bundles}\n" - f"{bundle_dict.bundle_names}\n" + f"{recognized_bundles_dict.keys()}\n" f"{valid_noncriterion}\n" ) ) @@ -473,13 +478,14 @@ def check_space(roi): if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) if b_sls: - for ii, bundle_name in enumerate(bundle_dict.bundle_names): - if bundle_name in bundle_def.keys(): - idx = 
np.where(bundle_decisions[:, ii])[0] + for o_bundle_name in recognized_bundles_dict.keys(): + if o_bundle_name in bundle_def.keys(): clean_by_other_bundle( **inputs, - other_bundle_name=bundle_name, - other_bundle_sls=tg.streamlines[idx], + other_bundle_name=o_bundle_name, + other_bundle_sls=recognized_bundles_dict[ + o_bundle_name + ].get_selected_sls(flip=True), ) for criterion in criteria_order_post_other_bundles: if b_sls and criterion in bundle_def: @@ -490,6 +496,22 @@ def check_space(roi): ): mahalanobis(**inputs) + # If you don't cross the midline, we remove streamliens + # entirely on the wrong side of the midline here after filtering + if b_sls and "cross_midline" in bundle_def and not bundle_def["cross_midline"]: + b_sls.initiate_selection("Wrong side of mid.") + zero_coord = preproc_imap["zero_coord"] + lr_axis = preproc_imap["lr_axis"] + avg_side = np.sign( + np.mean( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, lr_axis] + - zero_coord, + axis=1, + ) + ) + majority_side = np.sign(np.sum(avg_side)) + b_sls.select(avg_side == majority_side, "Wrong side of mid.") + if b_sls and not b_sls.oriented_yet: raise ValueError( "pyAFQ was unable to consistently orient streamlines " @@ -500,13 +522,40 @@ def check_space(roi): ) if b_sls: - bundle_to_flip[b_sls.selected_fiber_idxs, bundle_idx] = b_sls.sls_flipped.copy() - bundle_decisions[b_sls.selected_fiber_idxs, bundle_idx] = 1 - if hasattr(b_sls, "roi_closest"): - bundle_roi_closest[b_sls.selected_fiber_idxs, bundle_idx, :] = ( - b_sls.roi_closest.copy() + if "ORG_spectral_subbundles" in bundle_def: + subdict = bundle_def["ORG_spectral_subbundles"] + c_ids = subdict.cluster_IDs + b_sls.initiate_selection( + (f"ORG spectral clustering, {len(c_ids)} subbundles being recognized") ) - if hasattr(b_sls, "roi_dists"): - bundle_roi_dists[b_sls.selected_fiber_idxs, bundle_idx, :] = ( - b_sls.roi_dists.copy() + + sub_sft = StatefulTractogram( + b_sls.get_selected_sls(flip=True), img, Space.VOX + ) + cluster_labels = subcluster_by_atlas( + sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 ) + clusters_being_recognized = [] + for c_id in c_ids: + bundle_name = subdict.get_subbundle_name(c_id) + n_roi = len(subdict[bundle_name].get("include", [])) + cluster_b_sls = b_sls.copy(bundle_name, n_roi) + cluster_b_sls.select(cluster_labels == c_id, f"Cluster {c_id}") + clusters_being_recognized.append(cluster_b_sls) + + for ii, c_id in enumerate(c_ids): + bundle_name = subdict.get_subbundle_name(c_id) + run_bundle_rec_plan( + bundle_def["ORG_spectral_subbundles"], + clusters_being_recognized[ii], + mapping, + img, + reg_template, + preproc_imap, + bundle_name, + recognized_bundles_dict, + **segmentation_params, + ) + else: + b_sls.bundle_def = bundle_def + recognized_bundles_dict[bundle_name] = b_sls diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 8b3ce657..52531045 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -27,7 +27,7 @@ def fgarray(tg): return fg_array -@immlib.calc("crosses") +@immlib.calc("crosses", "lr_axis", "zero_coord") def crosses(fgarray, img): """ Classify the streamlines by whether they cross the midline. 
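
A minimal standalone sketch of the midline logic used here (the crossing test in crosses and the majority-side filter above), assuming a made-up (n_streamlines, n_points, 3) coordinate array, a left-right axis of 0, and a midline at x = 0 (this sketch is not from the patch):

# Illustrative sketch only: midline crossing and majority-side selection.
import numpy as np

rng = np.random.default_rng(0)
fgarray = rng.uniform(-50, 50, size=(200, 20, 3))  # assumed streamline coordinates
lr_axis = 0          # assumed left-right axis
zero_coord = 0.0     # assumed left-right coordinate of the midline

# A streamline "crosses" if it has points on both sides of the midline:
crosses = np.logical_and(
    np.any(fgarray[:, :, lr_axis] > zero_coord, axis=1),
    np.any(fgarray[:, :, lr_axis] < zero_coord, axis=1),
)

# For a non-crossing bundle, keep only the side most streamlines are on:
avg_side = np.sign(np.mean(fgarray[:, :, lr_axis] - zero_coord, axis=1))
majority_side = np.sign(np.sum(avg_side))
keep = avg_side == majority_side
print(int(crosses.sum()), "crossing;", int(keep.sum()), "kept on the majority side")
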
@@ -45,9 +45,13 @@ def crosses(fgarray, img): lr_axis = idx break - return np.logical_and( - np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), - np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + return ( + np.logical_and( + np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), + np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + ), + lr_axis, + zero_coord[lr_axis], ) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 8fb02fec..73602909 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -6,10 +6,12 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram +import AFQ.recognition.sparse_decisions as ars import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import BundleDict from AFQ.recognition.criteria import run_bundle_rec_plan from AFQ.recognition.preprocess import get_preproc_plan +from AFQ.utils.path import write_json logger = logging.getLogger("AFQ") @@ -155,14 +157,7 @@ def recognize( tg.to_vox() n_streamlines = len(tg) - bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.float32) - bundle_to_flip = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) - bundle_roi_closest = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.int32 - ) - bundle_roi_dists = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.float32 - ) + recognized_bundles_dict = {} fiber_groups = {} meta = {} @@ -170,7 +165,7 @@ def recognize( preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas) logger.info("Assigning Streamlines to Bundles") - for bundle_idx, bundle_name in enumerate(bundle_dict.bundle_names): + for bundle_name in bundle_dict.bundle_names: logger.info(f"Finding Streamlines for {bundle_name}") run_bundle_rec_plan( bundle_dict, @@ -180,11 +175,7 @@ def recognize( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_roi_dists, - bundle_decisions, + recognized_bundles_dict, clip_edges=clip_edges, n_cpus=n_cpus, rb_recognize_params=rb_recognize_params, @@ -199,10 +190,18 @@ def recognize( if save_intermediates is not None: os.makedirs(save_intermediates, exist_ok=True) - bc_path = op.join(save_intermediates, "sls_bundle_decisions.npy") - np.save(bc_path, bundle_decisions) + bc_path = op.join(save_intermediates, "sls_bundle_decisions.json") + write_json( + bc_path, + { + b_name: b_sls.selected_fiber_idxs.tolist() + for b_name, b_sls in recognized_bundles_dict.items() + }, + ) + + sparse_dists = ars.compute_sparse_decisions(recognized_bundles_dict, n_streamlines) - conflicts = np.sum(np.sum(bundle_decisions, axis=1) > 1) + conflicts = ars.get_conflict_count(sparse_dists) if conflicts > 0: logger.info( ( @@ -215,63 +214,38 @@ def recognize( ) ) - # Weight by distance to ROI - valid_dists = bundle_roi_dists >= -0.5 # i.e., not -1 - has_any_valid_roi = np.any(valid_dists, axis=2) - if np.any(has_any_valid_roi): - dist_sums = np.sum(np.where(valid_dists, bundle_roi_dists, 0), axis=2) - max_roi_dist_sum = float(dist_sums[has_any_valid_roi].max() + 1) - final_mask = (bundle_decisions > 0) & has_any_valid_roi - bundle_decisions[final_mask] = 2 - (dist_sums[final_mask] / max_roi_dist_sum) - - bundle_decisions = np.concatenate( - (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 - ) - bundle_decisions = np.argmax(bundle_decisions, -1) + ars.remove_conflicts(sparse_dists, recognized_bundles_dict) # We do another 
round through, so that we can: # 1. Clip streamlines according to ROIs # 2. Re-orient streamlines logger.info("Re-orienting streamlines to consistent directions") - for bundle_idx, bundle in enumerate(bundle_dict.bundle_names): - logger.info(f"Processing {bundle}") + for b_name, r_bd in recognized_bundles_dict.items(): + logger.info(f"Processing {b_name}") - select_idx = np.where(bundle_decisions == bundle_idx)[0] - - if len(select_idx) == 0: + if len(r_bd.selected_fiber_idxs) == 0: # There's nothing here, set and move to the next bundle: - if "bundlesection" in bundle_dict.get_b_info(bundle): - for sb_name in bundle_dict.get_b_info(bundle)["bundlesection"]: + if "bundlesection" in bundle_dict.get_b_info(b_name): + for sb_name in bundle_dict.get_b_info(b_name)["bundlesection"]: _return_empty(sb_name, return_idx, fiber_groups, img) else: - _return_empty(bundle, return_idx, fiber_groups, img) + _return_empty(b_name, return_idx, fiber_groups, img) continue - # Use a list here, because ArraySequence doesn't support item - # assignment: - select_sl = list(tg.streamlines[select_idx]) - roi_closest = bundle_roi_closest[select_idx, bundle_idx, :] - n_includes = len(bundle_dict.get_b_info(bundle).get("include", [])) - if clip_edges and n_includes > 1: - logger.info("Clipping Streamlines by ROI") - select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, (0, n_includes - 1), in_place=True - ) - - to_flip = bundle_to_flip[select_idx, bundle_idx] - b_def = dict(bundle_dict.get_b_info(bundle_name)) + b_def = r_bd.bundle_def if "bundlesection" in b_def: - for sb_name, sb_include_cuts in bundle_dict.get_b_info(bundle)[ - "bundlesection" - ].items(): + for sb_name, sb_include_cuts in b_def["bundlesection"].items(): bundlesection_select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, sb_include_cuts, in_place=False + r_bd.get_selected_sls(), + r_bd.roi_closest, + sb_include_cuts, + in_place=False, ) _add_bundle_to_fiber_group( sb_name, bundlesection_select_sl, - select_idx, - to_flip, + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, return_idx, fiber_groups, img, @@ -279,9 +253,15 @@ def recognize( _add_bundle_to_meta(sb_name, b_def, meta) else: _add_bundle_to_fiber_group( - bundle, select_sl, select_idx, to_flip, return_idx, fiber_groups, img + b_name, + r_bd.get_selected_sls(cut=clip_edges), + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, + return_idx, + fiber_groups, + img, ) - _add_bundle_to_meta(bundle, b_def, meta) + _add_bundle_to_meta(b_name, b_def, meta) return fiber_groups, meta diff --git a/AFQ/recognition/sparse_decisions.py b/AFQ/recognition/sparse_decisions.py new file mode 100644 index 00000000..114cc8ee --- /dev/null +++ b/AFQ/recognition/sparse_decisions.py @@ -0,0 +1,116 @@ +import numpy as np +from scipy.sparse import csr_matrix + + +def compute_sparse_decisions(bundles_being_recognized, n_streamlines): + """ + Compute a sparse matrix of distances to ROIs for the streamlines that are + currently being recognized. This can be used to weight decisions by distance + to ROIs, without having to create a dense matrix of distances for all + streamlines and all bundles. + + Parameters + ---------- + bundles_being_recognized : dict + A dictionary of SlsBeingRecognized objects, keyed by bundle name. + n_streamlines : int + The total number of streamlines in the original tractogram. 
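
A minimal standalone sketch of the scoring scheme this docstring describes, assuming hypothetical bundles with made-up streamline indices and summed ROI distances (this sketch is not from the patch). It shows how the per-bundle scores can be assembled as a scipy.sparse.csr_matrix without ever allocating a dense (n_bundles, n_streamlines) array:

# Illustrative sketch only: sparse per-bundle streamline scores.
import numpy as np
from scipy.sparse import csr_matrix

n_streamlines = 10
# Hypothetical bundles: selected streamline indices and optional summed ROI distances
bundles = {
    "A": {"idxs": np.array([0, 2, 5]), "roi_dists": np.array([1.0, 4.0, 2.0])},
    "B": {"idxs": np.array([2, 3]), "roi_dists": None},  # no ROIs -> flat weight
}

rows, cols, data = [], [], []
max_dist = max(
    (b["roi_dists"].max() for b in bundles.values() if b["roi_dists"] is not None),
    default=0.0,
)
norm = max_dist + 1.0
for b_idx, b in enumerate(bundles.values()):
    if b["roi_dists"] is not None:
        weights = np.maximum(b["roi_dists"], 1e-6) / norm  # in (0, 1]
    else:
        weights = np.full(len(b["idxs"]), 2.0)
    rows.extend([b_idx] * len(b["idxs"]))
    cols.extend(b["idxs"])
    data.extend(weights)

scores = csr_matrix((data, (rows, cols)), shape=(len(bundles), n_streamlines))
scores.data = 3.0 - scores.data  # ROI bundles land in [2, 3), no-ROI bundles at 1.0
print(scores.toarray())
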
+
+    Returns
+    -------
+    csr_matrix
+        A sparse matrix of shape (number of bundles being recognized,
+        n_streamlines), where the entry (i, j) is a score:
+        bundles with ROIs result in weights in [2.0, 3.0), with higher scores
+        for streamlines closer to their ROIs;
+        bundles without ROIs result in a weight of 1.0;
+        everything else is 0.0 (implicit in sparse matrices).
+    """
+    rows, cols, data = [], [], []
+    epsilon = 1e-6
+
+    global_max_dist = 0.0
+    for b in bundles_being_recognized.values():
+        if hasattr(b, "roi_dists"):
+            global_max_dist = max(global_max_dist, np.sum(b.roi_dists, axis=-1).max())
+
+    norm_factor = global_max_dist + 1.0
+
+    for b_idx, name in enumerate(bundles_being_recognized.keys()):
+        bundle = bundles_being_recognized[name]
+        indices = bundle.selected_fiber_idxs
+
+        if hasattr(bundle, "roi_dists"):
+            dists = np.sum(bundle.roi_dists, axis=-1)
+            dists = np.maximum(dists, epsilon)
+            bundle_weights = dists / norm_factor
+        else:
+            bundle_weights = np.full(len(indices), 2.0, dtype=np.float32)
+
+        rows.extend([b_idx] * len(indices))
+        cols.extend(indices)
+        data.extend(bundle_weights)
+
+    sparse_scores = csr_matrix(
+        (data, (rows, cols)), shape=(len(bundles_being_recognized), n_streamlines)
+    )
+
+    # Final Decision: 3.0 - Score
+    # ROI bundles result in weights [2.0 to 3.0]
+    # No-ROI bundles result in weight 1.0
+    sparse_scores.data = 3.0 - sparse_scores.data
+
+    return sparse_scores
+
+
+def get_conflict_count(sparse_scores):
+    """
+    Count how many streamlines are being considered for more than one bundle.
+    """
+    sorted_indices = np.sort(sparse_scores.indices)
+    is_duplicate = np.diff(sorted_indices) == 0
+    num_conflicts = np.sum(is_duplicate)
+
+    return num_conflicts
+
+
+def remove_conflicts(sparse_scores, bundles_being_recognized):
+    """
+    Resolve conflicts in place: each streamline claimed by more than one
+    bundle is kept only in the bundle with the highest score in
+    ``sparse_scores``; bundles left with no streamlines are removed from
+    ``bundles_being_recognized``.
+    """
+    coo = sparse_scores.tocoo()
+
+    order = np.lexsort((-coo.data, coo.col))
+
+    mask = np.concatenate(([True], np.diff(coo.col[order]) != 0))
+    winner_rows = coo.row[order][mask]
+    winner_cols = coo.col[order][mask]
+
+    row_sort = np.argsort(winner_rows)
+    winner_rows = winner_rows[row_sort]
+    winner_cols = winner_cols[row_sort]
+
+    num_bundles = len(bundles_being_recognized)
+    split_indices = np.searchsorted(winner_rows, np.arange(num_bundles + 1))
+
+    for i, b_name in enumerate(list(bundles_being_recognized.keys())):
+        b_sls = bundles_being_recognized[b_name]
+        if np.any(b_sls.selected_fiber_idxs[:-1] > b_sls.selected_fiber_idxs[1:]):
+            raise NotImplementedError(
+                f"Bundle '{b_name}' has unsorted selected_fiber_idxs. "
+                "The searchsorted optimization requires sorted indices. "
+                "This is a bug in the implementation of the bundle "
+                "recognition procedure, please report it to the developers."
+ ) + + accept_idx = b_sls.initiate_selection(f"{b_name} conflicts") + start, end = split_indices[i], split_indices[i + 1] + bundle_winners = winner_cols[start:end] + + if len(bundle_winners) > 0: + local_positions = np.searchsorted(b_sls.selected_fiber_idxs, bundle_winners) + accept_idx[local_positions] = True + b_sls.select(local_positions, "conflicts") + else: + b_sls.select(accept_idx, "conflicts") + bundles_being_recognized.pop(b_name) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index 0a60552d..678c4eed 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -1,3 +1,4 @@ +import copy import logging import os.path as op from time import time @@ -188,3 +189,27 @@ def __bool__(self): def __len__(self): return len(self.selected_fiber_idxs) + + def copy(self, new_name, n_roi): + new_copy = copy.copy(self) + new_copy.b_name = new_name + if n_roi > 0: + if self.n_roi > 0: + raise NotImplementedError( + ( + "You cannot have includes in the original bundle and" + " subbundles; only one or the other." + ) + ) + else: + new_copy.n_roi = n_roi + + new_copy.selected_fiber_idxs = self.selected_fiber_idxs.copy() + new_copy.sls_flipped = self.sls_flipped.copy() + + if hasattr(self, "roi_closest"): + new_copy.roi_closest = self.roi_closest.copy() + if hasattr(self, "roi_dists"): + new_copy.roi_dists = self.roi_dists.copy() + + return new_copy diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 1580c24a..a03b92b6 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -203,6 +203,9 @@ def viz_indivBundle( if "bundlesection" in b_info: for sb_name in b_info["bundlesection"]: segmented_bname_to_roi_bname[sb_name] = b_name + elif "ORG_spectral_subbundles" in b_info: + for sb_name in b_info["ORG_spectral_subbundles"]: + segmented_bname_to_roi_bname[sb_name] = b_name else: segmented_bname_to_roi_bname[b_name] = b_name diff --git a/docs/source/references.bib b/docs/source/references.bib index ec7b1eb7..0588b036 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -538,6 +538,16 @@ @InProceedings{Kruper2023 isbn="978-3-031-47292-3" } +@article{zhang2018anatomically, + title={An anatomically curated fiber clustering white matter atlas for consistent white matter tract parcellation across the lifespan}, + author={Zhang, Fan and Wu, Ye and Norton, Isaiah and Rigolo, Laura and Rathi, Yogesh and Makris, Nikos and O'Donnell, Lauren J}, + journal={Neuroimage}, + volume={179}, + pages={429--447}, + year={2018}, + publisher={Elsevier} +} + @ARTICLE{Hua2008, title = "Tract probability maps in stereotaxic spaces: analyses of white matter anatomy and tract-specific quantification", diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 0b8f0980..749e470d 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -46,7 +46,7 @@ bids_path = afd.fetch_hbn_preproc( ["NDARAA948VFH"], - clear_previous_afq="all")[1] + clear_previous_afq="recog")[1] ########################################################################## # Set tractography parameters (optional) @@ -272,6 +272,8 @@ threshold = 3000 elif "Fronto-occipital" in ind: threshold = 10 + elif "Vertical Occipital" in ind: + threshold = 5 else: threshold = 15 if bundle_counts["n_streamlines"][ind] < threshold: From 2e21c5be77deb53609504ba3b346f6b5dfb90f37 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 14 Feb 2026 23:55:58 +0900 Subject: [PATCH 48/51] put this back --- 
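
A minimal standalone sketch of the winner-per-streamline trick used in remove_conflicts above: sorting the COO entries by column, with the negated score as a secondary key, and then keeping the first entry of each column picks the highest-scoring bundle for every streamline without a Python loop. The small score matrix below is made up for illustration (this sketch is not from the patch):

# Illustrative sketch only: highest-score row per column via np.lexsort
# and a first-occurrence mask, as in remove_conflicts.
import numpy as np
from scipy.sparse import csr_matrix

scores = csr_matrix(np.array([
    [2.5, 0.0, 1.0, 0.0],
    [0.0, 2.9, 2.7, 0.0],
    [2.1, 0.0, 0.0, 1.0],
]))
coo = scores.tocoo()

# Sort by column; within a column, highest score first:
order = np.lexsort((-coo.data, coo.col))
first_of_each_col = np.concatenate(([True], np.diff(coo.col[order]) != 0))
winner_rows = coo.row[order][first_of_each_col]
winner_cols = coo.col[order][first_of_each_col]

# Column 0 -> row 0 (2.5 beats 2.1), column 2 -> row 1 (2.7 beats 1.0), etc.
print(dict(zip(winner_cols.tolist(), winner_rows.tolist())))
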
examples/tutorial_examples/plot_001_group_afq_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 749e470d..3f2dbd00 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -46,7 +46,7 @@ bids_path = afd.fetch_hbn_preproc( ["NDARAA948VFH"], - clear_previous_afq="recog")[1] + clear_previous_afq="all")[1] ########################################################################## # Set tractography parameters (optional) From aa4ac6e2dc09565600977d855023b11559fb2a75 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sun, 15 Feb 2026 11:26:17 +0900 Subject: [PATCH 49/51] fix up reco --- AFQ/recognition/criteria.py | 1 + AFQ/recognition/tests/test_recognition.py | 6 +++--- AFQ/utils/streamlines.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 3f454dcc..8f227e85 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -274,6 +274,7 @@ def recobundles( "template", mapping, reg_template, + to_space=Space.RASMM, save_intermediates=save_intermediates, ).streamlines moved_sl_resampled = abu.resample_tg(moved_sl, 100) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 66c3133a..6e4812de 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -176,7 +176,7 @@ def test_segment_clip_edges_api(): def test_segment_reco(): # get bundles for reco method bundles_reco = afd.read_hcp_atlas(16) - bundle_names = ["CST_R", "CST_L"] + bundle_names = ["MCP"] for key in list(bundles_reco): if key not in bundle_names: bundles_reco.pop(key, None) @@ -193,8 +193,8 @@ def test_segment_reco(): ) # This condition should still hold - npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_R"]) > 0) + npt.assert_equal(len(fiber_groups), 1) + npt.assert_(len(fiber_groups["MCP"]) > 0) def test_exclusion_ROI(): diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 52c7051c..6d64003f 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -180,7 +180,6 @@ def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=Non else: moved_sl = mapping.transform_points_inverse(tg.streamlines) moved_sft = StatefulTractogram(moved_sl, img, Space.VOX) - moved_sft.to_rasmm() if save_intermediates is not None: save_tractogram( @@ -189,7 +188,7 @@ def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=Non bbox_valid_check=False, ) if to_space is None: - tg.to_space(tg_og_space) + moved_sft.to_space(tg_og_space) else: - tg.to_space(to_space) + moved_sft.to_space(to_space) return moved_sft From e6dbda821d0fbe23c750cb0ddad8529ff5da51ce Mon Sep 17 00:00:00 2001 From: 36000 Date: Mon, 16 Feb 2026 11:01:35 +0900 Subject: [PATCH 50/51] color new bundles --- AFQ/viz/plotly_backend.py | 19 +++++++------ AFQ/viz/utils.py | 60 ++++++++++++++++++++++++++++++++++----- 2 files changed, 64 insertions(+), 15 deletions(-) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 07d467e0..8b2f8745 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -40,8 +40,6 @@ def _inline_interact(figure, show, show_inline): def _to_color_range(num): - if num < 0: - num = 0 if num >= 0.999: num = 0.999 if num <= 0.001: @@ -232,9 +230,10 @@ def 
_draw_streamlines( def _plot_profiles(profiles, bundle_name, color, fig, scalar): if isinstance(profiles, pd.DataFrame): - sc_max = np.max(profiles[scalar].to_numpy()) - sc_90 = np.percentile(profiles[scalar].to_numpy(), 10) - sc_1 = np.percentile(profiles[scalar].to_numpy(), 99) + all_tp = profiles[scalar].to_numpy() + all_tp = np.max(all_tp) - all_tp + lim_0 = np.percentile(all_tp, 1) + lim_1 = np.percentile(all_tp, 90) profiles = profiles[profiles.tractID == bundle_name] x = profiles["nodeID"] @@ -242,10 +241,14 @@ def _plot_profiles(profiles, bundle_name, color, fig, scalar): line_color = [] for scalar_val in profiles[scalar].to_numpy(): - xformed_scalar = np.minimum( - (sc_max - scalar_val) / (sc_1 - sc_90) + sc_90 + 0.1, 0.999 + brightness = np.minimum( + np.maximum( + scalar_val - lim_0, + 0, + ), + lim_1, ) - line_color.append(_color_arr2str(xformed_scalar * color)) + line_color.append(_color_arr2str(brightness * color)) else: x = np.arange(len(profiles)) y = profiles diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index 1b82a306..7b4bd2d6 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -1,3 +1,4 @@ +import colorsys import logging import os.path as op from collections import OrderedDict @@ -17,6 +18,23 @@ __all__ = ["Viz"] + +def get_distinct_shades(base_rgb, n_steps, hue_shift): + """ + Creates distinct shades by shifting Hue + """ + hh, ll, ss = colorsys.rgb_to_hls(*base_rgb) + shades = [] + + for i in range(n_steps): + offset = i - (n_steps - 1) / 2 + + new_h = (hh + (offset * hue_shift)) % 1.0 + + shades.append(colorsys.hls_to_rgb(new_h, ll, ss)) + return shades + + viz_logger = logging.getLogger("AFQ") tableau_20 = [ (0.12156862745098039, 0.4666666666666667, 0.7058823529411765), @@ -51,6 +69,18 @@ small_font = 20 marker_size = 200 +slf_l_base = tableau_extension[0] +slf_r_base = tableau_extension[1] + +vof_l_base = tableau_20[6] +vof_r_base = tableau_20[7] + +slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1) +slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1) + +vof_l_shades = get_distinct_shades(vof_l_base, 5, hue_shift=0.12) +vof_r_shades = get_distinct_shades(vof_r_base, 5, hue_shift=0.12) + COLOR_DICT = OrderedDict( { "Left Anterior Thalamic": tableau_20[0], @@ -75,8 +105,14 @@ "F_L": tableau_20[12], "Right Inferior Longitudinal": tableau_20[13], "F_R": tableau_20[13], - "Left Superior Longitudinal": tableau_20[14], - "Right Superior Longitudinal": tableau_20[15], + "Left Superior Longitudinal": slf_l_base, + "Right Superior Longitudinal": slf_r_base, + "Left Superior Longitudinal I": slf_l_shades[0], + "Left Superior Longitudinal II": slf_l_shades[1], + "Left Superior Longitudinal III": slf_l_shades[2], + "Right Superior Longitudinal I": slf_r_shades[0], + "Right Superior Longitudinal II": slf_r_shades[1], + "Right Superior Longitudinal III": slf_r_shades[2], "Left Uncinate": tableau_20[16], "UF_L": tableau_20[16], "Right Uncinate": tableau_20[17], @@ -85,10 +121,20 @@ "AF_L": tableau_20[18], "Right Arcuate": tableau_20[19], "AF_R": tableau_20[19], - "Left Posterior Arcuate": tableau_20[6], - "Right Posterior Arcuate": tableau_20[7], - "Left Vertical Occipital": tableau_extension[0], - "Right Vertical Occipital": tableau_extension[1], + "Left Posterior Arcuate": tableau_20[14], + "Right Posterior Arcuate": tableau_20[15], + "Left Vertical Occipital": vof_l_base, + "Right Vertical Occipital": vof_r_base, + "Left Vertical Occipital I": vof_l_shades[0], + "Left Vertical Occipital II": vof_l_shades[1], + "Left Vertical Occipital III": 
vof_l_shades[2], + "Left Vertical Occipital IV": vof_l_shades[3], + "Left Vertical Occipital V": vof_l_shades[4], + "Right Vertical Occipital I": vof_r_shades[0], + "Right Vertical Occipital II": vof_r_shades[1], + "Right Vertical Occipital III": vof_r_shades[2], + "Right Vertical Occipital IV": vof_r_shades[3], + "Right Vertical Occipital V": vof_r_shades[4], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53), @@ -510,7 +556,7 @@ def tract_generator( else: if bundle is None: # No selection: visualize all of them: - for bundle_name in seg_sft.bundle_names: + for bundle_name in sorted(seg_sft.bundle_names): idx = seg_sft.bundle_idxs[bundle_name] if len(idx) == 0: continue From 48be3d9d48eff68db8cd7e5ac8fb1be0f9611f2b Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 18 Feb 2026 21:34:36 +0900 Subject: [PATCH 51/51] VOF refinements --- AFQ/api/bundle_dict.py | 159 ++++++++++++++++++++----------- AFQ/data/fetch.py | 6 ++ AFQ/recognition/cleaning.py | 33 +++++-- AFQ/recognition/criteria.py | 18 +++- AFQ/tractography/tractography.py | 4 +- AFQ/viz/utils.py | 8 +- 6 files changed, 148 insertions(+), 80 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 51f4a0c6..ae7b8d59 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -268,7 +268,7 @@ def default_bd(): "Left Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 50, }, "Right Posterior Arcuate": { "cross_midline": False, @@ -285,68 +285,77 @@ def default_bd(): "Right Arcuate": {"overlap": 30}, "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 50, }, "Left Vertical Occipital": { "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], + "exclude": [ + templates["Cerebellar_Hemi_L"], + ], "Left Arcuate": {"node_thresh": 20, "project": "L/R"}, "Left Posterior Arcuate": { "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, - "length": {"min_len": 25, "max_len": 60}, + "length": {"min_len": 30, "max_len": 70}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Left Vertical Occipital I": { - "cluster_ID": 89, - "isolation_forest": {}, + "cluster_ID": 82, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, "Left Vertical Occipital II": { - "cluster_ID": 82, - "isolation_forest": {}, + "cluster_ID": 75, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Left Vertical Occipital III": { - "cluster_ID": 83, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, - "Left Vertical Occipital IV": { + "Left Vertical Occipital III": { "cluster_ID": 21, - "isolation_forest": {}, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Left Vertical Occipital V": { - "cluster_ID": 454, - "isolation_forest": {}, - "orient_mahal": { - 
"distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, }, remove_cluster_IDs=[ + 89, + 93, 27, 100, + 102, + 454, + 27, + 555, + 118, 4, 6, 13, @@ -371,62 +380,71 @@ def default_bd(): "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], + "exclude": [ + templates["Cerebellar_Hemi_R"], + ], "Right Arcuate": {"node_thresh": 20, "project": "L/R"}, "Right Posterior Arcuate": { "node_thresh": 20, "project": "L/R", "entire_core": "Anterior", }, - "length": {"min_len": 25, "max_len": 60}, + "length": {"min_len": 30, "max_len": 70}, "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 60, "ORG_spectral_subbundles": SpectralSubbundleDict( { "Right Vertical Occipital I": { - "cluster_ID": 89, - "isolation_forest": {}, + "cluster_ID": 82, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, "Right Vertical Occipital II": { - "cluster_ID": 82, - "isolation_forest": {}, + "cluster_ID": 75, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Right Vertical Occipital III": { - "cluster_ID": 83, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, - "Right Vertical Occipital IV": { + "Right Vertical Occipital III": { "cluster_ID": 21, - "isolation_forest": {}, "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, }, - }, - "Right Vertical Occipital V": { - "cluster_ID": 454, - "isolation_forest": {}, - "orient_mahal": { - "distance_threshold": 4, - "clean_rounds": 3, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, }, }, }, remove_cluster_IDs=[ + 89, + 93, 27, 100, + 102, + 454, + 27, + 555, + 118, 4, 6, 13, @@ -1310,6 +1328,31 @@ def _cond_load(self, roi_or_sl, resample_to): def get_b_info(self, b_name): return self._dict[b_name] + def relax_cleaning(self, delta_distance=1, delta_length=1): + """ + This can be useful for PTT + """ + cleaner_keys = ["mahal", "isolation_forest", "orient_mahal"] + + for b_name in self.bundle_names: + bundle_data = self._dict[b_name] + + for key in cleaner_keys: + if key in bundle_data: + target = bundle_data[key] + if ( + "distance_threshold" in target + and target["distance_threshold"] != 0 + ): + target["distance_threshold"] += delta_distance + if "length_threshold" in target and target["length_threshold"] != 0: + target["length_threshold"] += delta_length + + if "ORG_spectral_subbundles" in bundle_data: + bundle_data["ORG_spectral_subbundles"].relax_cleaning( + delta_distance, delta_length + ) + def __getitem__(self, key): if isinstance(key, tuple) or isinstance(key, list): # Generates a copy of this BundleDict with only the bundle names diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index b40f34ad..9cb70d23 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -761,6 +761,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ATR_L_start.nii.gz", "pARC_xroi1_L.nii.gz", "pARC_xroi1_R.nii.gz", + "Cerebellar_Hemi_L.nii.gz", + "Cerebellar_Hemi_R.nii.gz", ] @@ 
-865,6 +867,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "40944080", "61737616", "61737619", + "61970155", + "61970158", ] @@ -970,6 +974,8 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "1c0b570bb2d622718b01ee2c429a5d15", "51c8a6b5fbb0834b03986093b9ee4fa3", "7cf5800a4efa6bac7e70d84095bc259b", + "f65b3f9133820921d023517a68d4ea41", + "4476935f5aadfcdd633b9a23779625ef", ] fetch_templates = _make_reusable_fetcher( diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 59d3687f..eaaf0cf2 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -50,7 +50,7 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): along_accepted_idx = orientation_along == primary_axis if tol is not None: percentage_primary = ( - 100 * axis_diff[:, primary_axis] / np.sum(axis_diff, axis=1) + 100 * endpoint_diff[:, primary_axis] / np.sum(endpoint_diff, axis=1) ) logger.debug( (f"Maximum primary percentage found: {np.max(percentage_primary)}") @@ -73,33 +73,42 @@ def clean_by_orientation_mahalanobis( core_only=0.6, min_sl=20, distance_threshold=3, + length_threshold=4, clean_rounds=5, ): + if length_threshold == 0: + length_threshold = np.inf fgarray = np.array(abu.resample_tg(streamlines, n_points)) if core_only != 0: crop_edge = (1.0 - core_only) / 2 fgarray = fgarray[ :, int(n_points * crop_edge) : int(n_points * (1 - crop_edge)), : - ] # Crop to middle 60% + ] fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :] + lengths = np.array([sl.shape[0] for sl in streamlines]) idx = np.arange(len(fgarray)) rounds_elapsed = 0 while rounds_elapsed < clean_rounds: - # This calculates the Mahalanobis for each streamline/node: m_dist = gaussian_weights( fgarray_dists, return_mahalnobis=True, n_points=None, stat=np.mean ) + length_z = zscore(lengths) + logger.debug(f"Shape of fgarray: {np.asarray(fgarray_dists).shape}") logger.debug((f"Maximum m_dist for each fiber: {np.max(m_dist, axis=1)}")) - if not (np.any(m_dist >= distance_threshold)): + if not ( + np.any(m_dist >= distance_threshold) or np.any(length_z >= length_threshold) + ): break + idx_dist = np.all(m_dist < distance_threshold, axis=-1) + idx_len = length_z < length_threshold + idx_belong = np.logical_and(idx_dist, idx_len) - if np.sum(idx_dist) < min_sl: - # need to sort and return exactly min_sl: + if np.sum(idx_belong) < min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] idx = np.sort(idx) logger.debug( @@ -107,9 +116,9 @@ def clean_by_orientation_mahalanobis( ) break else: - # Update by selection: - idx = idx[idx_dist] - fgarray_dists = fgarray_dists[idx_dist] + idx = idx[idx_belong] + fgarray_dists = fgarray_dists[idx_belong] + lengths = lengths[idx_belong] rounds_elapsed += 1 logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}")) logger.debug(f"Kept indices: {idx}") @@ -190,6 +199,9 @@ def clean_bundle( else: return tg + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) if core_only != 0: @@ -320,6 +332,9 @@ def clean_by_isolation_forest( ) return np.ones(len(streamlines), dtype=bool) + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) fgarray_dists = np.zeros_like(fgarray) diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 8f227e85..89bffa7f 100644 --- a/AFQ/recognition/criteria.py +++ 
b/AFQ/recognition/criteria.py @@ -1,10 +1,10 @@ import logging from time import time -import dipy.tracking.streamline as dts import nibabel as nib import numpy as np import ray +from dipy.core.interpolation import interpolate_scalar_3d from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import load_tractogram from dipy.segment.bundles import RecoBundles @@ -56,10 +56,9 @@ def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs): b_sls.initiate_selection("Prob. Map") - # using entire fgarray here only because it is the first step - fiber_probabilities = dts.values_from_volume( - bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"], np.eye(4) - ) + fiber_probabilities = interpolate_scalar_3d( + bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"].reshape(-1, 3) + )[0].reshape(-1, 20) fiber_probabilities = np.mean(fiber_probabilities, -1) b_sls.select(fiber_probabilities > prob_threshold, "Prob. Map") @@ -154,6 +153,15 @@ def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): else: include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"]) + # For now I am turning ray parallelization here off. + # It is never worthwhile considering other changes we + # have made to speed up this step, + # so spinning up ray and transferring data back + # and forth is not worth it. + # In the future, I think we should redo this with numba and + # use multithreading + n_cpus = 1 + # with parallel segmentation, the first for loop will # only collect streamlines and does not need tqdm if n_cpus > 1: diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 1624986f..de99c63d 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -30,7 +30,7 @@ def track( seed_mask=None, seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=2000000, + n_seeds=5000000, random_seeds=True, rng_seed=None, step_size=0.5, @@ -75,7 +75,7 @@ def track( voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D array, these are the coordinates of the seeds. Unless random_seeds is set to True, in which case this is the total number of random seeds - to generate within the mask. Default: 2000000 + to generate within the mask. Default: 5000000 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. 
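
A minimal standalone sketch of the probability-map criterion touched above: sample the probability volume at each resampled streamline point, average along the streamline, and keep streamlines whose mean exceeds prob_threshold. This sketch (not from the patch) uses scipy.ndimage.map_coordinates for the trilinear interpolation rather than DIPY's interpolate_scalar_3d, purely to stay self-contained; the volume, coordinates, and threshold are made up:

# Illustrative sketch only: mean probability-map value per streamline.
import numpy as np
from scipy.ndimage import map_coordinates

rng = np.random.default_rng(0)
prob_map = rng.uniform(size=(40, 40, 40))        # assumed probability volume
fgarray = rng.uniform(5, 35, size=(50, 20, 3))   # 50 streamlines x 20 points, voxel coords
prob_threshold = 0.5

points = fgarray.reshape(-1, 3).T                # shape (3, n_total_points)
values = map_coordinates(prob_map, points, order=1).reshape(50, 20)
mean_prob = values.mean(axis=-1)
keep = mean_prob > prob_threshold
print(int(keep.sum()), "of", len(keep), "streamlines kept")
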
Default: True diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index 7b4bd2d6..e3d2c6c9 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -78,8 +78,8 @@ def get_distinct_shades(base_rgb, n_steps, hue_shift): slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1) slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1) -vof_l_shades = get_distinct_shades(vof_l_base, 5, hue_shift=0.12) -vof_r_shades = get_distinct_shades(vof_r_base, 5, hue_shift=0.12) +vof_l_shades = get_distinct_shades(vof_l_base, 3, hue_shift=0.15) +vof_r_shades = get_distinct_shades(vof_r_base, 3, hue_shift=0.15) COLOR_DICT = OrderedDict( { @@ -128,13 +128,9 @@ def get_distinct_shades(base_rgb, n_steps, hue_shift): "Left Vertical Occipital I": vof_l_shades[0], "Left Vertical Occipital II": vof_l_shades[1], "Left Vertical Occipital III": vof_l_shades[2], - "Left Vertical Occipital IV": vof_l_shades[3], - "Left Vertical Occipital V": vof_l_shades[4], "Right Vertical Occipital I": vof_r_shades[0], "Right Vertical Occipital II": vof_r_shades[1], "Right Vertical Occipital III": vof_r_shades[2], - "Right Vertical Occipital IV": vof_r_shades[3], - "Right Vertical Occipital V": vof_r_shades[4], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53),
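
A minimal standalone sketch of the hue-shifting idea behind get_distinct_shades, which these color-table changes rely on: rotate the hue of a base RGB color in HLS space while keeping lightness and saturation, producing visually distinct but related shades for sub-bundles. The base color and parameters below are made up for illustration (this sketch is not from the patch):

# Illustrative sketch only: hue-shifted shades of a base color.
import colorsys

def hue_shifted_shades(base_rgb, n_steps, hue_shift):
    # Convert to HLS, rotate the hue symmetrically around the base color,
    # and convert back to RGB; lightness and saturation are preserved.
    hh, ll, ss = colorsys.rgb_to_hls(*base_rgb)
    shades = []
    for i in range(n_steps):
        offset = i - (n_steps - 1) / 2
        shades.append(colorsys.hls_to_rgb((hh + offset * hue_shift) % 1.0, ll, ss))
    return shades

# Three shades for a hypothetical base color, e.g. for three VOF sub-bundles:
for shade in hue_shifted_shades((0.12, 0.47, 0.71), 3, hue_shift=0.15):
    print(tuple(round(c, 3) for c in shade))
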