diff --git a/.codespellrc b/.codespellrc index 7035fe17..fdf6f1d6 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] skip = [setup.cfg] -ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus \ No newline at end of file +ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo \ No newline at end of file diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 3d92664c..fb3c6db1 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -2,6 +2,8 @@ import math import numpy as np +from dipy.align import vector_fields as vfu +from dipy.align.imwarp import DiffeomorphicMap, mult_aff from dipy.data import default_sphere from dipy.reconst.gqi import squared_radial_component from dipy.tracking.streamline import set_number_of_points @@ -12,6 +14,75 @@ logger = logging.getLogger("AFQ") +def get_simplified_transform(self): + """Constructs a simplified version of this Diffeomorphic Map + + The simplified version incorporates the pre-align transform, as well as + the domain and codomain affine transforms into the displacement field. + The resulting transformation may be regarded as operating on the + image spaces given by the domain and codomain discretization. As a + result, self.prealign, self.disp_grid2world, self.domain_grid2world and + self.codomain_grid2world will be None (denoting Identity) in the resulting + diffeomorphic map. 
+ """ + if self.dim == 2: + simplify_f = vfu.simplify_warp_function_2d + else: + simplify_f = vfu.simplify_warp_function_3d + # Simplify the forward transform + D = self.domain_grid2world + P = self.prealign + Rinv = self.disp_world2grid + Cinv = self.codomain_world2grid + + # this is the matrix which we need to multiply the voxel coordinates + # to interpolate on the forward displacement field ("in"side the + # 'forward' brackets in the expression above) + affine_idx_in = mult_aff(Rinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the voxel coordinates + # to add to the displacement ("out"side the 'forward' brackets in the + # expression above) + affine_idx_out = mult_aff(Cinv, mult_aff(P, D)) + + # this is the matrix which we need to multiply the displacement vector + # prior to adding to the transformed input point + affine_disp = Cinv + + new_forward = simplify_f( + self.forward, affine_idx_in, affine_idx_out, affine_disp, self.domain_shape + ) + + # Simplify the backward transform + C = self.codomain_grid2world + Pinv = self.prealign_inv + Dinv = self.domain_world2grid + + affine_idx_in = mult_aff(Rinv, C) + affine_idx_out = mult_aff(Dinv, mult_aff(Pinv, C)) + affine_disp = mult_aff(Dinv, Pinv) + new_backward = simplify_f( + self.backward, + affine_idx_in, + affine_idx_out, + affine_disp, + self.codomain_shape, + ) + simplified = DiffeomorphicMap( + dim=self.dim, + disp_shape=self.disp_shape, + disp_grid2world=None, + domain_shape=self.domain_shape, + domain_grid2world=None, + codomain_shape=self.codomain_shape, + codomain_grid2world=None, + prealign=None, + ) + simplified.forward = new_forward + simplified.backward = new_backward + return simplified + + def gwi_odf(gqmodel, data): gqi_vector = np.real( squared_radial_component( diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 22c8e96e..ae7b8d59 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -10,7 +10,8 @@ from AFQ.definitions.utils import find_file 
from AFQ.tasks.utils import get_fname, str_to_desc -logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) __all__ = [ @@ -184,6 +185,7 @@ def default_bd(): "prob_map": templates["IFO_L_prob_map"], "end": templates["IFO_L_start"], "start": templates["IFO_L_end"], + "length": {"min_len": 100, "max_len": 250}, }, "Right Inferior Fronto-occipital": { "cross_midline": False, @@ -193,6 +195,7 @@ def default_bd(): "prob_map": templates["IFO_R_prob_map"], "end": templates["IFO_R_start"], "start": templates["IFO_R_end"], + "length": {"min_len": 100, "max_len": 250}, }, "Left Inferior Longitudinal": { "cross_midline": False, @@ -215,20 +218,22 @@ def default_bd(): "Left Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_L"], templates["SLFt_roi2_L"]], - "exclude": [], + "exclude": [templates["IFO_roi1_L"]], "space": "template", "prob_map": templates["ARC_L_prob_map"], "start": templates["ARC_L_start"], "end": templates["ARC_L_end"], + "length": {"min_len": 50, "max_len": 250}, }, "Right Arcuate": { "cross_midline": False, "include": [templates["SLF_roi1_R"], templates["SLFt_roi2_R"]], - "exclude": [], + "exclude": [templates["IFO_roi1_R"]], "space": "template", "prob_map": templates["ARC_R_prob_map"], "start": templates["ARC_R_start"], "end": templates["ARC_R_end"], + "length": {"min_len": 50, "max_len": 250}, }, "Left Uncinate": { "cross_midline": False, @@ -251,57 +256,223 @@ def default_bd(): "Left Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_L"]], - "exclude": [templates["SLF_roi1_L"]], + "exclude": [ + templates["SLF_roi1_L"], + templates["IFO_roi1_L"], + templates["pARC_xroi1_L"], + ], "space": "template", + "prob_map": templates["ARC_L_prob_map"], "start": templates["pARC_L_start"], + "end": templates["VOF_L_end"], "Left Arcuate": {"overlap": 30}, + "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + 
"primary_axis_percentage": 50, }, "Right Posterior Arcuate": { "cross_midline": False, "include": [templates["SLFt_roi2_R"]], - "exclude": [templates["SLF_roi1_R"]], + "exclude": [ + templates["SLF_roi1_R"], + templates["IFO_roi1_R"], + templates["pARC_xroi1_R"], + ], "space": "template", + "prob_map": templates["ARC_R_prob_map"], "start": templates["pARC_R_start"], + "end": templates["VOF_R_end"], "Right Arcuate": {"overlap": 30}, + "length": {"min_len": 30, "max_len": 120}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 50, }, "Left Vertical Occipital": { "cross_midline": False, "space": "template", "end": templates["VOF_L_end"], - "Left Arcuate": {"node_thresh": 20}, + "exclude": [ + templates["Cerebellar_Hemi_L"], + ], + "Left Arcuate": {"node_thresh": 20, "project": "L/R"}, "Left Posterior Arcuate": { "node_thresh": 20, + "project": "L/R", "entire_core": "Anterior", }, - "Left Inferior Fronto-occipital": {"core": "Right"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, - "length": {"min_len": 25}, - "isolation_forest": {}, + "length": {"min_len": 30, "max_len": 70}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 60, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Left Vertical Occipital I": { + "cluster_ID": 82, + "orient_mahal": { + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, + }, + }, + "Left Vertical Occipital II": { + "cluster_ID": 75, + "orient_mahal": { + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, + }, + }, + "Left Vertical Occipital III": { + "cluster_ID": 21, + "orient_mahal": { + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + 
"distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, + }, + }, + }, + remove_cluster_IDs=[ + 89, + 93, + 27, + 100, + 102, + 454, + 27, + 555, + 118, + 4, + 6, + 13, + 17, + 22, + 23, + 38, + 48, + 50, + 53, + 64, + 65, + 66, + 84, + 87, + 88, + 98, + ], + ), }, "Right Vertical Occipital": { "cross_midline": False, "space": "template", "end": templates["VOF_R_end"], - "Right Arcuate": {"node_thresh": 20}, + "exclude": [ + templates["Cerebellar_Hemi_R"], + ], + "Right Arcuate": {"node_thresh": 20, "project": "L/R"}, "Right Posterior Arcuate": { "node_thresh": 20, + "project": "L/R", "entire_core": "Anterior", }, - "Right Inferior Fronto-occipital": {"core": "Left"}, - "orient_mahal": {"distance_threshold": 3, "clean_rounds": 5}, - "length": {"min_len": 25}, - "isolation_forest": {}, + "length": {"min_len": 30, "max_len": 70}, + "mahal": {"clean_rounds": 0}, "primary_axis": "I/S", - "primary_axis_percentage": 40, + "primary_axis_percentage": 60, + "ORG_spectral_subbundles": SpectralSubbundleDict( + { + "Right Vertical Occipital I": { + "cluster_ID": 82, + "orient_mahal": { + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, + }, + }, + "Right Vertical Occipital II": { + "cluster_ID": 75, + "orient_mahal": { + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, + }, + }, + "Right Vertical Occipital III": { + "cluster_ID": 21, + "orient_mahal": { + "distance_threshold": 2, + "length_threshold": 5, + "clean_rounds": 1, + }, + "mahal": { + "distance_threshold": 3, + "length_threshold": 0, + "clean_rounds": 5, + }, + }, + }, + remove_cluster_IDs=[ + 89, + 93, + 27, + 100, + 102, + 454, + 27, + 555, + 118, + 4, + 6, + 13, + 17, + 22, + 23, + 38, + 48, + 50, + 53, + 64, + 65, + 66, + 84, + 87, + 88, + 98, + ], + ), }, }, - 
citations={"Yeatman2012", "takemura2017occipital"}, + citations={ + "Yeatman2012", + "takemura2017occipital", + "Tzourio-Mazoyer2002", + "zhang2018anatomically", + "Hua2008", + }, ) @@ -313,60 +484,48 @@ def slf_bd(): "include": [templates["SFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Left Cingulum Cingulate": { + "node_thresh": 20, }, }, "Left Superior Longitudinal II": { "include": [templates["MFgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Left Cingulum Cingulate": { + "node_thresh": 20, }, }, "Left Superior Longitudinal III": { "include": [templates["PrgL"], templates["PaL"]], "exclude": [templates["SLFt_roi2_L"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Left Cingulum Cingulate": { + "node_thresh": 20, }, }, "Right Superior Longitudinal I": { "include": [templates["SFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Right Cingulum Cingulate": { + "node_thresh": 20, }, }, "Right Superior Longitudinal II": { "include": [templates["MFgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Right Cingulum Cingulate": { + "node_thresh": 20, }, }, "Right Superior Longitudinal III": { "include": [templates["PrgR"], templates["PaR"]], "exclude": [templates["SLFt_roi2_R"]], "cross_midline": False, - "mahal": { - "clean_rounds": 20, - "length_threshold": 4, - "distance_threshold": 2, + "Right Cingulum Cingulate": { + "node_thresh": 20, }, }, }, @@ -1052,7 +1211,6 @@ def __init__( 
self.resample_to = resample_to self.resample_subject_to = resample_subject_to self.keep_in_memory = keep_in_memory - self.max_includes = 3 self.citations = citations if self.citations is None: self.citations = set() @@ -1106,12 +1264,8 @@ def __init__( def __print__(self): print(self._dict) - def update_max_includes(self, new_max): - if new_max > self.max_includes: - self.max_includes = new_max - def _use_bids_info(self, roi_or_sl, bids_layout, bids_path, subject, session): - if isinstance(roi_or_sl, dict): + if isinstance(roi_or_sl, dict) and "roi" not in roi_or_sl: suffix = roi_or_sl.get("suffix", "dwi") roi_or_sl = find_file( bids_layout, bids_path, roi_or_sl, suffix, session, subject @@ -1124,6 +1278,41 @@ def _cond_load(self, roi_or_sl, resample_to): """ Load ROI or streamline if not already loaded """ + if isinstance(roi_or_sl, dict): + space = roi_or_sl.get("space", None) + roi_or_sl = roi_or_sl.get("roi", None) + if roi_or_sl is None or space is None: + raise ValueError( + ( + f"Unclear ROI definition for {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." + ) + ) + if space == "template": + resample_to = self.resample_to + elif space == "subject": + resample_to = self.resample_subject_to + if resample_to is False: + raise ValueError( + ( + "When using mixed ROI bundle definitions, " + "and subject space ROIs, " + "resample_subject_to cannot be False." + ) + ) + else: + raise ValueError( + ( + f"Unknown space {space} for ROI definition {roi_or_sl}. " + "See 'Defining Custom Bundle Dictionaries' " + "in the documentation for details." 
+ ) + ) + + logger.debug(f"Loading ROI or streamlines: {roi_or_sl}") + logger.debug(f"Loading ROI or streamlines from space: {resample_to}") + if isinstance(roi_or_sl, str): if ".nii" in roi_or_sl: return afd.read_resample_roi(roi_or_sl, resample_to=resample_to) @@ -1139,6 +1328,31 @@ def _cond_load(self, roi_or_sl, resample_to): def get_b_info(self, b_name): return self._dict[b_name] + def relax_cleaning(self, delta_distance=1, delta_length=1): + """ + This can be useful for PTT + """ + cleaner_keys = ["mahal", "isolation_forest", "orient_mahal"] + + for b_name in self.bundle_names: + bundle_data = self._dict[b_name] + + for key in cleaner_keys: + if key in bundle_data: + target = bundle_data[key] + if ( + "distance_threshold" in target + and target["distance_threshold"] != 0 + ): + target["distance_threshold"] += delta_distance + if "length_threshold" in target and target["length_threshold"] != 0: + target["length_threshold"] += delta_length + + if "ORG_spectral_subbundles" in bundle_data: + bundle_data["ORG_spectral_subbundles"].relax_cleaning( + delta_distance, delta_length + ) + def __getitem__(self, key): if isinstance(key, tuple) or isinstance(key, list): # Generates a copy of this BundleDict with only the bundle names @@ -1177,8 +1391,6 @@ def __getitem__(self, key): def __setitem__(self, key, item): self._dict[key] = item - if hasattr(item, "get"): - self.update_max_includes(len(item.get("include", []))) if key not in self.bundle_names: self.bundle_names.append(key) @@ -1261,24 +1473,31 @@ def is_bundle_in_template(self, bundle_name): return ( "space" not in self._dict[bundle_name] or self._dict[bundle_name]["space"] == "template" + or self._dict[bundle_name]["space"] == "mixed" ) - def _roi_transform_helper(self, roi_or_sl, mapping, new_affine, bundle_name): + def _roi_transform_helper(self, roi_or_sl, mapping, new_img, bundle_name): roi_or_sl = self._cond_load(roi_or_sl, self.resample_to) if isinstance(roi_or_sl, nib.Nifti1Image): + if ( + 
np.allclose(roi_or_sl.affine, new_img.affine) + and roi_or_sl.shape == new_img.shape[:3] + ): + # This is the case of a mixed bundle definition, where + # some ROIs need transformed and others do not + return roi_or_sl + fdata = roi_or_sl.get_fdata() if len(np.unique(fdata)) <= 2: boolean_ = True else: boolean_ = False - warped_img = auv.transform_inverse_roi( - fdata, mapping, bundle_name=bundle_name - ) + warped_img = auv.transform_roi(fdata, mapping, bundle_name=bundle_name) if boolean_: warped_img = warped_img.astype(np.uint8) - warped_img = nib.Nifti1Image(warped_img, new_affine) + warped_img = nib.Nifti1Image(warped_img, new_img.affine) return warped_img else: return roi_or_sl @@ -1287,7 +1506,7 @@ def transform_rois( self, bundle_name, mapping, - new_affine, + new_img, base_fname=None, to_space="subject", apply_to_recobundles=False, @@ -1306,8 +1525,8 @@ def transform_rois( Name of the bundle to be transformed. mapping : DiffeomorphicMap object A mapping between DWI space and a template. - new_affine : array - Affine of space transformed into. + new_img : Nifti1Image + Image of space transformed into. base_fname : str, optional Base file path to construct file path from. Additional BIDS descriptors will be added to this file path. If None, @@ -1333,7 +1552,7 @@ def transform_rois( bundle_name, self._roi_transform_helper, mapping, - new_affine, + new_img, bundle_name, dry_run=True, apply_to_recobundles=apply_to_recobundles, @@ -1382,7 +1601,9 @@ def transform_rois( def __add__(self, other): for resample in ["resample_to", "resample_subject_to"]: - if ( + if getattr(self, resample) == getattr(other, resample): + pass + elif ( not getattr(self, resample) or not getattr(other, resample) or getattr(self, resample) is None @@ -1429,6 +1650,46 @@ def __add__(self, other): ) +class SpectralSubbundleDict(BundleDict): + """ + A BundleDict where each bundle is defined as a spectral subbundle of a + larger bundle. 
See `Defining Custom Bundle Dictionaries` in the documentation + for details. + """ + + def __init__( + self, + bundle_info, + resample_to=None, + resample_subject_to=False, + keep_in_memory=False, + citations=None, + remove_cluster_IDs=None, + ): + super().__init__( + bundle_info, resample_to, resample_subject_to, keep_in_memory, citations + ) + if remove_cluster_IDs is None: + remove_cluster_IDs = [] + self.remove_cluster_IDs = remove_cluster_IDs + self.cluster_IDs = [] + self.id_to_name = {} + for b_name, b_info in bundle_info.items(): + if "cluster_ID" not in b_info: + raise ValueError( + ( + f"Bundle {b_name} does not have a cluster_ID. " + "All bundles in a SpectralSubbundleDict must have a cluster_ID." + ) + ) + self.cluster_IDs.append(b_info["cluster_ID"]) + self.id_to_name[b_info["cluster_ID"]] = b_name + self.all_cluster_IDs = self.remove_cluster_IDs + self.cluster_IDs + + def get_subbundle_name(self, cluster_id): + return self.id_to_name.get(cluster_id, None) + + def apply_to_roi_dict( dict_, func, diff --git a/AFQ/api/group.py b/AFQ/api/group.py index a94c6949..aaaea026 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -8,7 +8,6 @@ import warnings from time import time -import dipy.tracking.streamline as dts import dipy.tracking.streamlinespeed as dps import nibabel as nib import numpy as np @@ -49,9 +48,11 @@ __all__ = ["GroupAFQ"] +logging.basicConfig(level=logging.INFO) logger = logging.getLogger("AFQ") logger.setLevel(logging.INFO) + warnings.simplefilter(action="ignore", category=FutureWarning) @@ -553,9 +554,12 @@ def load_next_subject(): this_bundles_file = self.export("bundles", collapse=False)[sub][ses] this_mapping = self.export("mapping", collapse=False)[sub][ses] this_img = self.export("dwi", collapse=False)[sub][ses] + this_reg_template = self.export("reg_template", collapse=False)[sub][ + ses + ] seg_sft = aus.SegmentedSFT.fromfile(this_bundles_file, this_img) seg_sft.sft.to_rasmm() - subses_info.append((seg_sft, this_mapping)) + 
subses_info.append((seg_sft, this_mapping, this_img, this_reg_template)) bundle_dict = self.export("bundle_dict", collapse=False)[ self.valid_sub_list[0] @@ -565,7 +569,7 @@ def load_next_subject(): load_next_subject() # load first subject for b in bundle_dict.bundle_names: for i in range(len(self.valid_sub_list)): - seg_sft, mapping = subses_info[i] + seg_sft, mapping, img, reg_template = subses_info[i] idx = seg_sft.bundle_idxs[b] # use the first subses that works # otherwise try each successive subses @@ -581,14 +585,11 @@ def load_next_subject(): idx = np.random.choice(idx, size=100, replace=False) these_sls = seg_sft.sft.streamlines[idx] these_sls = dps.set_number_of_points(these_sls, 100) - tg = StatefulTractogram(these_sls, seg_sft.sft, Space.RASMM) - delta = dts.values_from_volume( - mapping.forward, tg.streamlines, np.eye(4) - ) - moved_sl = dts.Streamlines( - [d + s for d, s in zip(delta, tg.streamlines)] + tg = StatefulTractogram(these_sls, img, Space.RASMM) + moved_sl = aus.move_streamlines( + tg, "template", mapping, reg_template ) - moved_sl = np.asarray(moved_sl) + moved_sl = np.asarray(moved_sl.streamlines) median_sl = np.median(moved_sl, axis=0) sls_dict[b] = {"coreFiber": median_sl.tolist()} for ii, sl_idx in enumerate(idx): @@ -1023,15 +1024,17 @@ def combine_bundle(self, bundle_name): this_sub = self.valid_sub_list[ii] this_ses = self.valid_ses_list[ii] seg_sft = aus.SegmentedSFT.fromfile(bundles_dict[this_sub][this_ses]) - seg_sft.sft.to_vox() - sls = seg_sft.get_bundle(bundle_name).streamlines + sls = seg_sft.get_bundle(bundle_name) mapping = mapping_dict[this_sub][this_ses] if len(sls) > 0: - delta = dts.values_from_volume(mapping.forward, sls, np.eye(4)) - sls_mni.extend([d + s for d, s in zip(delta, sls)]) + sls_mni.extend( + aus.move_streamlines( + sls, "template", mapping, reg_template + ).streamlines + ) - moved_sft = StatefulTractogram(sls_mni, reg_template, Space.VOX) + moved_sft = StatefulTractogram(sls_mni, reg_template, Space.RASMM) 
save_path = op.abspath( op.join(self.afq_path, f"bundle-{bundle_name}_subjects-all_MNI.trk") diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index fd0e08e6..07730bc5 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -6,6 +6,7 @@ import nibabel as nib import numpy as np +from dipy.align import resample from PIL import Image, ImageDraw, ImageFont from tqdm import tqdm @@ -34,6 +35,11 @@ __all__ = ["ParticipantAFQ"] +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("AFQ") +logger.setLevel(logging.INFO) + + class ParticipantAFQ(object): f"""{AFQclass_doc}""" @@ -253,7 +259,7 @@ def export_all(self, viz=True, xforms=True, indiv=True): export_all_helper(self, xforms, indiv, viz) self.logger.info(f"Time taken for export all: {time() - start_time}") - def participant_montage(self, images_per_row=2): + def participant_montage(self, images_per_row=3, anatomy=True, bundle_names=None): """ Generate montage of all bundles for a given subject. @@ -261,7 +267,15 @@ def participant_montage(self, images_per_row=2): ---------- images_per_row : int Number of bundle images per row in output file. - Default: 2 + Default: 3 + + anatomy : bool + Whether to include anatomical images in the montage. + Default: True + + bundle_names : list of str or None + List of bundle names to include in the montage. 
+ Default: None (includes all bundles) Returns ------- @@ -270,64 +284,77 @@ def participant_montage(self, images_per_row=2): tdir = tempfile.gettempdir() all_fnames = [] - bundle_dict = self.export("bundle_dict") + if bundle_names is None: + bundle_dict = self.export("bundle_dict") + bundle_names = list(bundle_dict.keys()) self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") - best_scalar = self.export(self.export("best_scalar")) t1 = nib.load(self.export("t1_masked")) - size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) - for ii, bundle_name in enumerate(tqdm(bundle_dict)): + best_scalar = nib.load(self.export(self.kwargs["best_scalar"])) + best_scalar = resample(best_scalar, t1) + size = (images_per_row, math.ceil(3 * len(bundle_names) / images_per_row)) + for ii, bundle_name in enumerate(tqdm(bundle_names)): flip_axes = [False, False, False] for i in range(3): flip_axes[i] = self.export("dwi_affine")[i, i] < 0 - figure = viz_backend.visualize_volume( - t1, flip_axes=flip_axes, interact=False, inline=False - ) + if anatomy: + figure = viz_backend.visualize_volume( + t1.get_fdata(), flip_axes=flip_axes, interact=False, inline=False + ) + else: + figure = None figure = viz_backend.visualize_bundles( self.export("bundles"), - affine=t1.affine, - shade_by_volume=best_scalar, + img=t1, + shade_by_volume=best_scalar.get_fdata(), color_by_direction=True, flip_axes=flip_axes, bundle=bundle_name, figure=figure, + n_points=40, interact=False, inline=False, ) - view, direc = BEST_BUNDLE_ORIENTATIONS.get(bundle_name, ("Axial", "Top")) - eye = get_eye(view, direc) - - this_fname = tdir + f"/t{ii}.png" - if "plotly" in viz_backend.backend: - figure.update_layout( - scene_camera=dict( - projection=dict(type="orthographic"), - up={"x": 0, "y": 0, "z": 1}, - eye=eye, - center=dict(x=0, y=0, z=0), - ), - showlegend=False, - ) - figure.write_image(this_fname, scale=4) + for jj, view in enumerate(["Sagittal", "Coronal", "Axial"]): + 
direc = BEST_BUNDLE_ORIENTATIONS.get( + bundle_name, ("Left", "Front", "Top") + )[jj] + + eye = get_eye(view, direc) + + this_fname = tdir + f"/t{ii}_{view}.png" + if "plotly" in viz_backend.backend: + figure.update_layout( + scene_camera=dict( + projection=dict(type="orthographic"), + up={"x": 0, "y": 0, "z": 1}, + eye=eye, + center=dict(x=0, y=0, z=0), + ), + showlegend=False, + ) + figure.write_image(this_fname, scale=4) - # temporary fix for memory leak - import plotly.io as pio + # temporary fix for memory leak + import plotly.io as pio - pio.kaleido.scope._shutdown_kaleido() - else: - from dipy.viz import window - - direc = np.fromiter(eye.values(), dtype=int) - data_shape = np.asarray(nib.load(self.export("b0")).get_fdata().shape) - figure.set_camera( - position=direc * data_shape, - focal_point=data_shape // 2, - view_up=(0, 0, 1), - ) - figure.zoom(0.5) - window.snapshot(figure, fname=this_fname, size=(600, 600)) + pio.kaleido.scope._shutdown_kaleido() + else: + from dipy.viz import window + + direc = np.fromiter(eye.values(), dtype=int) + data_shape = np.asarray( + nib.load(self.export("b0")).get_fdata().shape + ) + figure.set_camera( + position=direc * data_shape, + focal_point=data_shape // 2, + view_up=(0, 0, 1), + ) + figure.zoom(0.5) + window.snapshot(figure, fname=this_fname, size=(600, 600)) def _save_file(curr_img): save_path = op.abspath( @@ -339,45 +366,47 @@ def _save_file(curr_img): this_img_trimmed = {} max_height = 0 max_width = 0 - for ii, bundle_name in enumerate(bundle_dict): - this_img = Image.open(tdir + f"/t{ii}.png") - try: - this_img_trimmed[ii] = trim(this_img) - except IndexError: # this_img is a picture of nothing - this_img_trimmed[ii] = this_img - - text_sz = 70 - width, height = this_img_trimmed[ii].size - height = height + text_sz - result = Image.new( - this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) - ) - result.paste(this_img_trimmed[ii], (0, text_sz)) - this_img_trimmed[ii] = result - - draw = 
ImageDraw.Draw(this_img_trimmed[ii]) - draw.text( - (0, 0), - bundle_name, - (0, 0, 0), - font=ImageFont.truetype("Arial", text_sz), - ) + ii = 0 + for b_idx, bundle_name in enumerate(bundle_names): + for view in ["Axial", "Coronal", "Sagittal"]: + this_img = Image.open(tdir + f"/t{b_idx}_{view}.png") + try: + this_img_trimmed[ii] = trim(this_img) + except IndexError: # this_img is a picture of nothing + this_img_trimmed[ii] = this_img + + text_sz = 40 + width, height = this_img_trimmed[ii].size + height = height + text_sz + result = Image.new( + this_img_trimmed[ii].mode, (width, height), color=(255, 255, 255) + ) + result.paste(this_img_trimmed[ii], (0, text_sz)) + this_img_trimmed[ii] = result + + draw = ImageDraw.Draw(this_img_trimmed[ii]) + draw.text( + (0, 0), + f"{bundle_name} - {view}", + (0, 0, 0), + font=ImageFont.load_default(text_sz), + ) - if this_img_trimmed[ii].size[0] > max_width: - max_width = this_img_trimmed[ii].size[0] - if this_img_trimmed[ii].size[1] > max_height: - max_height = this_img_trimmed[ii].size[1] + if this_img_trimmed[ii].size[0] > max_width: + max_width = this_img_trimmed[ii].size[0] + if this_img_trimmed[ii].size[1] > max_height: + max_height = this_img_trimmed[ii].size[1] + ii += 1 curr_img = Image.new( "RGB", (max_width * size[0], max_height * size[1]), color="white" ) - for ii in range(len(bundle_dict)): - x_pos = ii % size[0] - _ii = ii // size[0] + for jj in range(ii): + x_pos = jj % size[0] + _ii = jj // size[0] y_pos = _ii % size[1] - _ii = _ii // size[1] - this_img = this_img_trimmed[ii].resize((max_width, max_height)) + this_img = this_img_trimmed[jj].resize((max_width, max_height)) curr_img.paste(this_img, (x_pos * max_width, y_pos * max_height)) _save_file(curr_img) diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 9400509e..9cb70d23 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -759,6 +759,10 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ATR_R_start.nii.gz", "ATR_L_end.nii.gz", 
"ATR_L_start.nii.gz", + "pARC_xroi1_L.nii.gz", + "pARC_xroi1_R.nii.gz", + "Cerebellar_Hemi_L.nii.gz", + "Cerebellar_Hemi_R.nii.gz", ] @@ -861,6 +865,10 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "40944074", "40944077", "40944080", + "61737616", + "61737619", + "61970155", + "61970158", ] @@ -964,6 +972,10 @@ def read_resample_roi(roi, resample_to=None, threshold=False): "ffc157e9f73a43eff23821f2cfca614a", "a8d308a93b26242c04b878c733cb252f", "1c0b570bb2d622718b01ee2c429a5d15", + "51c8a6b5fbb0834b03986093b9ee4fa3", + "7cf5800a4efa6bac7e70d84095bc259b", + "f65b3f9133820921d023517a68d4ea41", + "4476935f5aadfcdd633b9a23779625ef", ] fetch_templates = _make_reusable_fetcher( @@ -1089,6 +1101,174 @@ def read_oton_templates(as_img=True, resample_to=False): return template_dict +org800_fnames = [ + "ORG_atlas_tracks_reoriented.trx", + "ORG800_atlas_centroids.npy", + "ORG800_atlas_e_val.npy", + "ORG800_atlas_e_vec_norm.npy", + "ORG800_atlas_e_vec.npy", + "ORG800_atlas_number_of_eigenvectors.npy", + "ORG800_atlas_row_sum_1.npy", + "ORG800_atlas_row_sum_matrix.npy", + "ORG800_atlas_sigma.npy", +] + + +org800_remote_fnames = [ + "61762231", + "61762267", + "61762270", + "61762273", + "61762276", + "61762279", + "61762282", + "61762285", + "61762288", +] + + +org800_md5_hashes = [ + "9022799a73359209080ea832b22ec09b", + "09bfa384f5c44801dfa382d31392a979", + "bab61eb26cb21035e38b5f68b5fdad3e", + "9325f4cb168624d4f275785b18c9f859", + "12d426c5a6fcfbe3b8146bc335bdac96", + "db74e055c47c5b6354c3cb7bbf165f2c", + "e7e51f53b30764f104b93f50d94b6c3c", + "7e894c57a820cd7604a1db6b7ab8cce6", + "1e195b7055e98eb473bbb5af05d48f7d", +] + +fetch_org800_templates = _make_reusable_fetcher( + "fetch_org800_templates", + op.join(afq_home, "org800_templates"), + baseurl, + org800_remote_fnames, + org800_fnames, + md5_list=org800_md5_hashes, + doc="Download AFQ org800 templates", +) + + +def read_org800_templates(load_npy=True, load_trx=True): + """ + Load O'Donnell Research Group 
(ORG) Fiber Clustering White + Matter Atlas 800 modified for pyAFQ templates from file + + Parameters + ---------- + load_npy : bool, optional + If True, values are loaded as numpy arrays. Otherwise, values are + paths to npy files. Default: True + load_trx : bool, optional + If True, the tractogram is loaded as a StatefulTractogram. Otherwise, + the value is the path to the trx file. Default: True + + Returns + ------- + dict with: keys: names of atlas info + values: Floats, arrays, and StatefulTractogram for the atlas. + Any unloaded will instead be paths. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading org800 templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template(fetch_org800_templates) + + if load_trx: + template_dict["tracks_reoriented"] = load_tractogram( + template_dict.pop("ORG_atlas_tracks_reoriented"), + "same", + ) + if load_npy: + template_dict["centroids"] = np.load( + template_dict.pop("ORG800_atlas_centroids") + ) + template_dict["e_val"] = np.load(template_dict.pop("ORG800_atlas_e_val")) + template_dict["e_vec_norm"] = np.load( + template_dict.pop("ORG800_atlas_e_vec_norm") + ) + template_dict["e_vec"] = np.load(template_dict.pop("ORG800_atlas_e_vec")) + template_dict["number_of_eigenvectors"] = float( + np.load(template_dict.pop("ORG800_atlas_number_of_eigenvectors")) + ) + template_dict["row_sum_1"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_1") + ) + template_dict["row_sum_matrix"] = np.load( + template_dict.pop("ORG800_atlas_row_sum_matrix") + ) + template_dict["sigma"] = float(np.load(template_dict.pop("ORG800_atlas_sigma"))) + + toc = time.perf_counter() + logger.debug( + f"O'Donnell Research Group 800 templates loaded in {toc - tic:0.4f} seconds" + ) + + return template_dict + + +massp_fnames = [ + "left_VTA.nii.gz", + "right_VTA.nii.gz", +] + +massp_remote_fnames = [ + "34892325", + "34892319", +] + +massp_md5_hashes = [ + "03d65d85abb161ea25501c343c136e40", + 
"440874b899d2c1057e5fd77b8b350bc4", +] + +fetch_massp_templates = _make_reusable_fetcher( + "fetch_massp_templates", + op.join(afq_home, "massp_templates"), + baseurl, + massp_remote_fnames, + massp_fnames, + md5_list=massp_md5_hashes, + doc="Download AFQ MassP templates", +) + + +def read_massp_templates(as_img=True, resample_to=False): + """Load AFQ MASSP templates from file + + Parameters + ---------- + as_img : bool, optional + If True, values are `Nifti1Image`. Otherwise, values are + paths to Nifti files. Default: True + resample_to : str or nibabel image class instance, optional + A template image to resample to. Typically, this should be the + template to which individual-level data are registered. Defaults to + the MNI template. Default: False + + Returns + ------- + dict with: keys: names of template ROIs and values: nibabel Nifti1Image + objects from each of the ROI nifti files. + """ + logger = logging.getLogger("AFQ") + + logger.debug("loading MASSP templates") + tic = time.perf_counter() + + template_dict = _fetcher_to_template( + fetch_massp_templates, as_img=as_img, resample_to=resample_to + ) + + toc = time.perf_counter() + logger.debug(f"MASSP templates loaded in {toc - tic:0.4f} seconds") + + return template_dict + + cp_fnames = [ "ICP_L_inferior_prob.nii.gz", "ICP_L_superior_prob.nii.gz", diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 4208056c..4d68fb8b 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -3,8 +3,10 @@ import nibabel as nib import numpy as np from dipy.align import resample +from scipy.ndimage import distance_transform_edt from AFQ.definitions.utils import Definition, find_file, name_from_path +from AFQ.recognition.utils import tolerance_mm_to_vox from AFQ.tasks.utils import get_tp __all__ = [ @@ -324,6 +326,13 @@ class RoiImage(ImageDefinition): use_endpoints : bool Whether to use the endpoints ("start" and "end") to generate the image. 
+ only_wmgmi : bool + Whether to only include portion of ROIs in the WM-GM interface. + only_wm : bool + Whether to only include portion of ROIs in the white matter. + dilate : bool + Whether to dilate the ROIs before combining them, according to the + tolerance that will be used during bundle recognition. tissue_property : str or None Tissue property from `scalars` to multiply the ROI image with. Can be useful to limit seed mask to the core white matter. @@ -350,14 +359,20 @@ def __init__( use_presegment=False, use_endpoints=False, only_wmgmi=False, + only_wm=False, + dilate=True, tissue_property=None, tissue_property_n_voxel=None, tissue_property_threshold=None, ): + if only_wmgmi and only_wm: + raise ValueError("only_wmgmi and only_wm cannot both be True") self.use_waypoints = use_waypoints self.use_presegment = use_presegment self.use_endpoints = use_endpoints self.only_wmgmi = only_wmgmi + self.only_wm = only_wm + self.dilate = dilate self.tissue_property = tissue_property self.tissue_property_n_voxel = tissue_property_n_voxel self.tissue_property_threshold = tissue_property_threshold @@ -384,23 +399,38 @@ def _image_getter_helper( for bundle_name in bundle_dict: bundle_entry = bundle_dict.transform_rois( - bundle_name, mapping_imap["mapping"], data_imap["dwi_affine"] + bundle_name, mapping_imap["mapping"], data_imap["dwi"] ) - rois = [] + rois = {} if self.use_endpoints: - rois.extend( - [ - bundle_entry[end_type] + rois.update( + { + bundle_entry[end_type]: end_type for end_type in ["start", "end"] if end_type in bundle_entry - ] + } ) if self.use_waypoints: - rois.extend(bundle_entry.get("include", [])) - for roi in rois: + rois.update( + dict.fromkeys(bundle_entry.get("include", []), "waypoint") + ) + + dist_to_waypoint, dist_to_atlas, _ = tolerance_mm_to_vox( + data_imap["dwi"], + segmentation_params["dist_to_waypoint"], + segmentation_params["dist_to_atlas"], + ) + for roi, roi_type in rois.items(): warped_roi = roi.get_fdata() if image_data is None: 
image_data = np.zeros(warped_roi.shape) + if self.dilate: + edt = distance_transform_edt(np.where(warped_roi == 0, 1, 0)) + if roi_type == "waypoint": + warped_roi = edt <= dist_to_waypoint + else: + warped_roi = edt <= dist_to_atlas + image_data = np.logical_or(image_data, warped_roi.astype(bool)) if self.tissue_property is not None: tp = nib.load( @@ -456,6 +486,10 @@ def _image_getter_helper( ) ) + if self.only_wm: + wm = nib.load(tissue_imap["pve_internal"]).get_fdata()[..., 2] >= 0.5 + image_data = np.logical_and(image_data, wm) + return nib.Nifti1Image( image_data.astype(np.float32), data_imap["dwi_affine"] ), dict(source="ROIs") @@ -898,7 +932,7 @@ def _image_getter_helper(mapping, reg_template, reg_subject): static_affine=reg_template.affine, ).get_fdata() - scalar_data = mapping.transform_inverse(img_data, interpolation="nearest") + scalar_data = mapping.transform(img_data, interpolation="nearest") return nib.Nifti1Image( scalar_data.astype(np.float32), reg_subject.affine ), dict(source=self.path) diff --git a/AFQ/definitions/mapping.py b/AFQ/definitions/mapping.py index dfcedef7..ed5ca043 100644 --- a/AFQ/definitions/mapping.py +++ b/AFQ/definitions/mapping.py @@ -5,9 +5,10 @@ import nibabel as nib import numpy as np from dipy.align import affine_registration, syn_registration -from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr import AFQ.registration as reg +from AFQ._fixes import get_simplified_transform from AFQ.definitions.utils import Definition, find_file from AFQ.tasks.utils import get_fname from AFQ.utils.path import space_from_fname, write_json @@ -177,11 +178,11 @@ def __init__(self, warp, ref_affine): self.ref_affine = ref_affine self.warp = warp - def transform_inverse(self, data, **kwargs): + def transform(self, data, **kwargs): data_img = Image(nib.Nifti1Image(data.astype(np.float32), self.ref_affine)) return np.asarray(applyDeformation(data_img, self.warp).data) - def transform_inverse_pts(self, 
pts): + def transform_pts(self, pts): # This should only be used for curvature analysis, # Because I think the results still need to be shifted pts = nib.affines.apply_affine(self.warp.src.getAffine("voxel", "world"), pts) @@ -189,39 +190,13 @@ def transform_inverse_pts(self, pts): pts = self.warp.transform(pts, "fsl", "world") return pts - def transform(self, data, **kwargs): + def transform_inverse(self, data, **kwargs): raise NotImplementedError( "Fnirt based mappings can currently" + " only transform from template to subject space" ) -class IdentityMap(Definition): - """ - Does not perform any transformations from MNI to subject where - pyAFQ normally would. - - Examples - -------- - my_example_mapping = IdentityMap() - api.GroupAFQ(mapping=my_example_mapping) - """ - - def __init__(self): - pass - - def get_for_subses( - self, base_fname, dwi, dwi_data_file, reg_subject, reg_template, tmpl_name - ): - return ConformedAffineMapping( - np.identity(4), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) - - class GeneratedMapMixin(object): """ Helper Class @@ -236,24 +211,17 @@ def get_fnames(self, extension, base_fname, sub_name, tmpl_name): mapping_file = mapping_file + extension return mapping_file, meta_fname - def prealign( - self, base_fname, sub_name, tmpl_name, reg_subject, reg_template, save=True - ): - prealign_file_desc = f"_desc-prealign_from-{sub_name}_to-{tmpl_name}_xform" - prealign_file = get_fname(base_fname, f"{prealign_file_desc}.npy") - if not op.exists(prealign_file): - start_time = time() - _, aff = affine_registration( - reg_subject, reg_template, **self.affine_kwargs - ) - meta = dict(type="rigid", dependent="dwi", timing=time() - start_time) - if not save: - return aff - logger.info(f"Saving {prealign_file}") - np.save(prealign_file, aff) - meta_fname = get_fname(base_fname, 
f"{prealign_file_desc}.json") - write_json(meta_fname, meta) - return prealign_file if save else np.load(prealign_file) + def prealign(self, reg_subject, reg_template): + logger.info("Calculating affine pre-alignment...") + _, aff = affine_registration(reg_subject, reg_template, **self.affine_kwargs) + return aff + + +class AffineMapMixin(GeneratedMapMixin): + """ + Helper Class + Useful for maps that are generated by pyAFQ + """ def get_for_subses( self, @@ -268,34 +236,22 @@ def get_for_subses( ): sub_space = space_from_fname(dwi_data_file) mapping_file, meta_fname = self.get_fnames( - self.extension, base_fname, sub_space, tmpl_name + ".npy", base_fname, sub_space, tmpl_name ) - if self.use_prealign: - reg_prealign = np.load( - self.prealign( - base_fname, sub_space, tmpl_name, reg_subject, reg_template - ) - ) - else: - reg_prealign = None if not op.exists(mapping_file): start_time = time() - mapping = self.gen_mapping( - base_fname, - sub_space, - tmpl_name, + affine = self.gen_mapping( reg_subject, reg_template, subject_sls, template_sls, - reg_prealign, ) total_time = time() - start_time logger.info(f"Saving {mapping_file}") - reg.write_mapping(mapping, mapping_file) - meta = dict(type="displacementfield", timing=total_time) + np.save(mapping_file, affine) + meta = dict(type="affine", timing=total_time) if subject_sls is None: meta["dependent"] = "dwi" else: @@ -305,10 +261,7 @@ def get_for_subses( if isinstance(reg_template, str): meta["reg_template"] = reg_template write_json(meta_fname, meta) - reg_prealign_inv = np.linalg.inv(reg_prealign) if self.use_prealign else None - mapping = reg.read_mapping( - mapping_file, dwi, reg_template, prealign=reg_prealign_inv - ) + mapping = reg.read_affine_mapping(mapping_file, dwi, reg_template) return mapping @@ -353,33 +306,71 @@ def __init__(self, use_prealign=True, affine_kwargs=None, syn_kwargs=None): self.use_prealign = use_prealign self.affine_kwargs = affine_kwargs self.syn_kwargs = syn_kwargs - self.extension 
= ".nii.gz" - def gen_mapping( + def get_for_subses( self, base_fname, - sub_space, - tmpl_name, + dwi, + dwi_data_file, reg_subject, reg_template, - subject_sls, - template_sls, - reg_prealign, + tmpl_name, + subject_sls=None, + template_sls=None, ): - _, mapping = syn_registration( - reg_subject.get_fdata(), - reg_template.get_fdata(), - moving_affine=reg_subject.affine, - static_affine=reg_template.affine, - prealign=reg_prealign, - **self.syn_kwargs, + sub_space = space_from_fname(dwi_data_file) + mapping_file_forward, meta_forward_fname = self.get_fnames( + ".nii.gz", base_fname, sub_space, tmpl_name + ) + mapping_file_backward, meta_backward_fname = self.get_fnames( + ".nii.gz", base_fname, tmpl_name, sub_space ) - if self.use_prealign: - mapping.codomain_world2grid = np.linalg.inv(reg_prealign) + + if not op.exists(mapping_file_forward) or not op.exists(mapping_file_backward): + meta = dict(type="displacementfield") + meta["dependent"] = "dwi" + if isinstance(reg_subject, str): + meta["reg_subject"] = reg_subject + if isinstance(reg_template, str): + meta["reg_template"] = reg_template + + start_time = time() + if self.use_prealign: + reg_prealign = self.prealign(reg_subject, reg_template) + else: + reg_prealign = None + + logger.info("Calculating SyN registration...") + _, mapping = syn_registration( + reg_subject.get_fdata(), + reg_template.get_fdata(), + moving_affine=reg_subject.affine, + static_affine=reg_template.affine, + prealign=reg_prealign, + **self.syn_kwargs, + ) + mapping = get_simplified_transform(mapping) + + total_time = time() - start_time + meta["total_time"] = total_time + + logger.info(f"Saving {mapping_file_forward}") + nib.save( + nib.Nifti1Image(mapping.forward, reg_subject.affine), + mapping_file_forward, + ) + write_json(meta_forward_fname, meta) + logger.info(f"Saving {mapping_file_backward}") + nib.save( + nib.Nifti1Image(mapping.backward, reg_template.affine), + mapping_file_backward, + ) + write_json(meta_backward_fname, meta) + 
mapping = reg.read_syn_mapping(mapping_file_forward, mapping_file_backward) return mapping -class SlrMap(GeneratedMapMixin, Definition): +class SlrMap(AffineMapMixin, Definition): """ Calculate a SLR registration for each subject/session using reg_subject and reg_template. @@ -407,33 +398,23 @@ class SlrMap(GeneratedMapMixin, Definition): def __init__(self, slr_kwargs=None): if slr_kwargs is None: slr_kwargs = {} - self.slr_kwargs = {} - self.use_prealign = False - self.extension = ".npy" + self.slr_kwargs = slr_kwargs def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, - reg_template, reg_subject, + reg_template, subject_sls, template_sls, - reg_prealign, ): - return reg.slr_registration( - subject_sls, - template_sls, - moving_affine=reg_subject.affine, - moving_shape=reg_subject.shape, - static_affine=reg_template.affine, - static_shape=reg_template.shape, - **self.slr_kwargs, + _, transform, _, _ = whole_brain_slr( + subject_sls, template_sls, x0="affine", verbose=False, **self.slr_kwargs ) + return transform + -class AffMap(GeneratedMapMixin, Definition): +class AffMap(AffineMapMixin, Definition): """ Calculate an affine registration for each subject/session using reg_subject and reg_template. 
@@ -457,47 +438,37 @@ class AffMap(GeneratedMapMixin, Definition): def __init__(self, affine_kwargs=None): if affine_kwargs is None: affine_kwargs = {} - self.use_prealign = False self.affine_kwargs = affine_kwargs - self.extension = ".npy" def gen_mapping( self, - base_fname, - sub_space, - tmpl_name, reg_subject, reg_template, subject_sls, template_sls, - reg_prealign, ): - return ConformedAffineMapping( - np.linalg.inv( - self.prealign( - base_fname, - sub_space, - tmpl_name, - reg_subject, - reg_template, - save=False, - ) - ), - domain_grid_shape=reg.reduce_shape(reg_subject.shape), - domain_grid2world=reg_subject.affine, - codomain_grid_shape=reg.reduce_shape(reg_template.shape), - codomain_grid2world=reg_template.affine, - ) + return np.linalg.inv(self.prealign(reg_subject, reg_template)) -class ConformedAffineMapping(AffineMap): +class IdentityMap(AffineMapMixin, Definition): """ - Modifies AffineMap API to match DiffeomorphicMap API. - Important for SLR maps API to be indistinguishable from SYN maps API. + Does not perform any transformations from MNI to subject where + pyAFQ normally would. + + Examples + -------- + my_example_mapping = IdentityMap() + api.GroupAFQ(mapping=my_example_mapping) """ - def transform(self, *args, **kwargs): - return super().transform_inverse(*args, **kwargs) + def __init__(self): + pass - def transform_inverse(self, *args, **kwargs): - return super().transform(*args, **kwargs) + def gen_mapping( + self, + reg_subject, + reg_template, + subject_sls, + template_sls, + ): + return np.identity(4) diff --git a/AFQ/nn/synthseg.py b/AFQ/nn/synthseg.py index 45f80f52..6ea9d0c2 100644 --- a/AFQ/nn/synthseg.py +++ b/AFQ/nn/synthseg.py @@ -86,8 +86,8 @@ def pve_from_synthseg(synthseg_data): PVE data with CSF, GM, and WM segmentations. 
""" - CSF_labels = [0, 3, 4, 11, 12, 21, 22, 17] - GM_labels = [2, 7, 8, 9, 10, 14, 15, 16, 20, 25, 26, 27, 28, 29, 30, 31] + CSF_labels = [0, 3, 4, 11, 12, 21, 22, 16] + GM_labels = [2, 7, 8, 9, 10, 14, 15, 17, 20, 25, 26, 27, 28, 29, 30, 31] WM_labels = [1, 5, 19, 23] mixed_labels = [13, 18, 32] diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index db741fe0..eaaf0cf2 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -50,7 +50,7 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): along_accepted_idx = orientation_along == primary_axis if tol is not None: percentage_primary = ( - 100 * axis_diff[:, primary_axis] / np.sum(axis_diff, axis=1) + 100 * endpoint_diff[:, primary_axis] / np.sum(endpoint_diff, axis=1) ) logger.debug( (f"Maximum primary percentage found: {np.max(percentage_primary)}") @@ -73,42 +73,52 @@ def clean_by_orientation_mahalanobis( core_only=0.6, min_sl=20, distance_threshold=3, + length_threshold=4, clean_rounds=5, ): + if length_threshold == 0: + length_threshold = np.inf fgarray = np.array(abu.resample_tg(streamlines, n_points)) if core_only != 0: crop_edge = (1.0 - core_only) / 2 fgarray = fgarray[ :, int(n_points * crop_edge) : int(n_points * (1 - crop_edge)), : - ] # Crop to middle 60% + ] fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :] + lengths = np.array([sl.shape[0] for sl in streamlines]) idx = np.arange(len(fgarray)) rounds_elapsed = 0 while rounds_elapsed < clean_rounds: - # This calculates the Mahalanobis for each streamline/node: m_dist = gaussian_weights( fgarray_dists, return_mahalnobis=True, n_points=None, stat=np.mean ) + length_z = zscore(lengths) + logger.debug(f"Shape of fgarray: {np.asarray(fgarray_dists).shape}") logger.debug((f"Maximum m_dist for each fiber: {np.max(m_dist, axis=1)}")) - if not (np.any(m_dist >= distance_threshold)): + if not ( + np.any(m_dist >= distance_threshold) or np.any(length_z >= length_threshold) + ): break + idx_dist = 
np.all(m_dist < distance_threshold, axis=-1) + idx_len = length_z < length_threshold + idx_belong = np.logical_and(idx_dist, idx_len) - if np.sum(idx_dist) < min_sl: - # need to sort and return exactly min_sl: + if np.sum(idx_belong) < min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) break else: - # Update by selection: - idx = idx[idx_dist] - fgarray_dists = fgarray_dists[idx_dist] + idx = idx[idx_belong] + fgarray_dists = fgarray_dists[idx_belong] + lengths = lengths[idx_belong] rounds_elapsed += 1 logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}")) logger.debug(f"Kept indices: {idx}") @@ -189,6 +199,9 @@ def clean_bundle( else: return tg + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) if core_only != 0: @@ -231,6 +244,7 @@ def clean_bundle( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) @@ -318,6 +332,9 @@ def clean_by_isolation_forest( ) return np.ones(len(streamlines), dtype=bool) + if length_threshold == 0: + length_threshold = np.inf + # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) fgarray_dists = np.zeros_like(fgarray) @@ -360,6 +377,7 @@ def clean_by_isolation_forest( if np.sum(idx_belong) < min_sl: # need to sort and return exactly min_sl: idx = idx[np.argsort(-sl_outliers)[:min_sl].astype(int)] + idx = np.sort(idx) logger.debug( (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached") ) diff --git a/AFQ/recognition/clustering.py b/AFQ/recognition/clustering.py new file mode 100644 index 00000000..2d33b6c3 --- /dev/null +++ b/AFQ/recognition/clustering.py @@ 
-0,0 +1,195 @@ +# Original source: github.com/SlicerDMRI/whitematteranalysis +# Copyright 2026 BWH and 3D Slicer contributors +# Licensed under 3D Slicer license (BSD style; https://github.com/SlicerDMRI/whitematteranalysis/blob/master/License.txt) # noqa +# Modified by John Kruper for pyAFQ +# Modifications: +# 1. Only mean distance included, and mean distance replaced with numba version. +# 2. Uses atlas data from dictionary and numpy files rather than pickled files, +# to avoid additional dependencies. +# 3. Added function to move template streamlines +# to subject space to calculate distances. + +import numpy as np +import scipy +from dipy.io.stateful_tractogram import Space +from numba import njit, prange + +import AFQ.data.fetch as afd +import AFQ.recognition.utils as abu +import AFQ.utils.streamlines as aus + + +@njit(parallel=True) +def _compute_mean_euclidean_matrix(group_n, group_m): + len_n = group_n.shape[0] + len_m = group_m.shape[0] + num_points = group_n.shape[1] + + dist_matrix = np.empty((len_n, len_m), dtype=np.float64) + + for i in prange(len_n): + for j in range(len_m): + sum_dist = 0.0 + sum_dist_ref = 0.0 + + for k in range(num_points): + dx = group_n[i, k, 0] - group_m[j, k, 0] + dx_ref = group_n[i, k, 0] + group_m[j, k, 0] + dy = group_n[i, k, 1] - group_m[j, k, 1] + dz = group_n[i, k, 2] - group_m[j, k, 2] + + sum_dist += np.sqrt(dx * dx + dy * dy + dz * dz) + sum_dist_ref += np.sqrt(dx_ref * dx_ref + dy * dy + dz * dz) + + mean_d = sum_dist / num_points + mean_d_ref = sum_dist_ref / num_points + + final_d = min(mean_d, mean_d_ref) + dist_matrix[i, j] = final_d * final_d + + return dist_matrix.T + + +def _distance_to_similarity(distance, sigmasq): + similarities = np.exp(-distance / (sigmasq)) + + return similarities + + +def _rectangular_similarity_matrix(fgarray_sub, fgarray_atlas, sigma): + distances = _compute_mean_euclidean_matrix(fgarray_sub, fgarray_atlas) + + sigmasq = sigma * sigma + similarity_matrix = 
_distance_to_similarity(distances, sigmasq) + + return similarity_matrix + + +def spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=None, + sigma_multiplier=1.0, + cluster_indices=None, +): + """ + Use an existing atlas to label a new streamlines. + + Parameters + ---------- + sub_fgarray : ndarray + Resampled fiber group to be labeled. + atlas_fgarray : ndarray + Resampled atlas to use for labelling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. + sigma_multiplier : float, optional + Multiplier for the sigma value used in computing the similarity + matrix. Default is 1.0. + cluster_indices : list of int, optional + If provided, only these cluster indices from the atlas will be used + for labeling. Default is None, which uses all clusters. + + Returns + ------- + tuple of (ndarray, ndarray) + Cluster indices for all the fibers and their embedding + """ + if atlas_data is None: + atlas_data = afd.read_org800_templates(load_trx=False) + + number_fibers = sub_fgarray.shape[0] + sz = atlas_fgarray.shape[0] + + # Compute fiber similarities. + B = _rectangular_similarity_matrix( + sub_fgarray, atlas_fgarray, sigma=atlas_data["sigma"] * sigma_multiplier + ) + + # Do Normalized Cuts transform of similarity matrix. + # row sum estimate for current B part of the matrix + row_sum_2 = np.sum(B, axis=0) + np.dot(atlas_data["row_sum_matrix"], B) + + # This happens plenty in our cases. Why? + # Maybe a probabilistic vs UKF thing? + # In practice, this is not an issue since we just set to a small value. 
+ if any(row_sum_2 <= 0): + row_sum_2[row_sum_2 < 0] = 1e-4 + + # Normalized cuts normalization + row_sum = np.concatenate((atlas_data["row_sum_1"], row_sum_2)) + dhat = np.sqrt(np.divide(1, row_sum)) + B = np.multiply(B, np.outer(dhat[0:sz], dhat[sz:].T)) + + # Compute embedding using eigenvectors + V = np.dot( + np.dot(B.T, atlas_data["e_vec"]), np.diag(np.divide(1.0, atlas_data["e_val"])) + ) + V = np.divide(V, atlas_data["e_vec_norm"]) + n_eigen = int(atlas_data["number_of_eigenvectors"]) + embed = np.zeros((number_fibers, n_eigen)) + for i in range(0, n_eigen): + embed[:, i] = np.divide(V[:, -(i + 2)], V[:, -1]) + + # Label streamlines using centroids from atlas + if cluster_indices is not None: + centroids = atlas_data["centroids"][cluster_indices, :] + cluster_idx, _ = scipy.cluster.vq.vq(embed, centroids) + cluster_idx = np.array([cluster_indices[i] for i in cluster_idx]) + else: + cluster_idx, _ = scipy.cluster.vq.vq(embed, atlas_data["centroids"]) + + return cluster_idx, embed + + +def subcluster_by_atlas( + sub_trk, mapping, dwi_ref, cluster_indices, atlas_data=None, n_points=20 +): + """ + Use an existing atlas to label a new set of streamlines, and return the + cluster indices for each streamline. + + Parameters + ---------- + sub_fgarray : ndarray + Resampled fiber group in VOX to be labeled. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + dwi_ref : Nifti1Image + Image defining reference for where the atlas streamlines move to. + cluster_indices : list of int + Cluster indices from the atlas to use for labeling. + atlas_data : dict, optional + Precomputed atlas data formatted as a dictionary of arrays and floats. + See `afd.read_org800_templates` as a reference. + n_points : int, optional + Number of points to resample streamlines to for labeling. Default is 20. 
+ """ + + if atlas_data is None: + atlas_data = afd.read_org800_templates() + atlas_sft = atlas_data["tracks_reoriented"] + + moved_atlas_sft = aus.move_streamlines( + atlas_sft, "subject", mapping, dwi_ref, to_space=Space.RASMM + ) + atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points)) + + # Note: if we need more efficiency, + # we could modify the code to consider: + # voxel size, midline axis, and midline location + # then we should be able to do these calculations in + # voxel space without having to move the subject streamlines + # to rasmm (but this is not a bottleneck right now) + sub_trk.to_rasmm() + sub_fgarray = np.array(abu.resample_tg(sub_trk.streamlines, n_points)) + + cluster_idxs, _ = spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=atlas_data, + cluster_indices=cluster_indices, + ) + + return cluster_idxs diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index e68e6251..89bffa7f 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -1,10 +1,10 @@ import logging from time import time -import dipy.tracking.streamline as dts import nibabel as nib import numpy as np import ray +from dipy.core.interpolation import interpolate_scalar_3d from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import load_tractogram from dipy.segment.bundles import RecoBundles @@ -19,14 +19,16 @@ import AFQ.recognition.roi as abr import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict +from AFQ.recognition.clustering import subcluster_by_atlas from AFQ.utils.stats import chunk_indices +from AFQ.utils.streamlines import move_streamlines criteria_order_pre_other_bundles = [ "prob_map", "cross_midline", + "length", "start", "end", - "length", "primary_axis", "include", "exclude", @@ -44,6 +46,8 @@ "primary_axis_percentage", "inc_addtol", "exc_addtol", + "ORG_spectral_subbundles", + "cluster_ID", ] @@ -52,10 +56,9 @@ def 
prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs): b_sls.initiate_selection("Prob. Map") - # using entire fgarray here only because it is the first step - fiber_probabilities = dts.values_from_volume( - bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"], np.eye(4) - ) + fiber_probabilities = interpolate_scalar_3d( + bundle_def["prob_map"].get_fdata(), preproc_imap["fgarray"].reshape(-1, 3) + )[0].reshape(-1, 20) fiber_probabilities = np.mean(fiber_probabilities, -1) b_sls.select(fiber_probabilities > prob_threshold, "Prob. Map") @@ -69,18 +72,17 @@ def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): def start(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("Startpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("Startpoint") + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], 0, tol=preproc_imap["dist_to_atlas"], flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["start"], -1, tol=preproc_imap["dist_to_atlas"], @@ -91,18 +93,17 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): def end(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("endpoint") - abr.clean_by_endpoints( - b_sls.get_selected_sls(), + b_sls.initiate_selection("endpoint") + accept_idx = abr.clean_by_endpoints( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], -1, tol=preproc_imap["dist_to_atlas"], flip_sls=b_sls.sls_flipped, - accepted_idxs=accept_idx, ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - b_sls.get_selected_sls(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], bundle_def["end"], 0, tol=preproc_imap["dist_to_atlas"], @@ -113,13 +114,18 @@ def 
end(b_sls, bundle_def, preproc_imap, **kwargs): def length(b_sls, bundle_def, preproc_imap, **kwargs): - accept_idx = b_sls.initiate_selection("length") + b_sls.initiate_selection("length") min_len = bundle_def["length"].get("min_len", 0) / preproc_imap["vox_dim"] max_len = bundle_def["length"].get("max_len", np.inf) / preproc_imap["vox_dim"] - for idx, sl in enumerate(b_sls.get_selected_sls()): - sl_len = np.sum(np.linalg.norm(np.diff(sl, axis=0), axis=1)) - if sl_len >= min_len and sl_len <= max_len: - accept_idx[idx] = 1 + + # Using resampled fgarray biases lengths to be lower. However, + # this is not meant to be a precise selection requirement, and + # is more meant for efficiency. + segments = np.diff(preproc_imap["fgarray"][b_sls.selected_fiber_idxs], axis=1) + segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) + sl_lens = np.sum(segment_lengths, axis=1) + + accept_idx = (sl_lens >= min_len) & (sl_lens <= max_len) b_sls.select(accept_idx, "length") @@ -134,7 +140,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): b_sls.select(accept_idx, "orientation") -def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): +def include(b_sls, bundle_def, preproc_imap, n_cpus, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet @@ -147,6 +153,15 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): else: include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"]) + # For now I am turning ray parallelization here off. + # It is never worthwhile considering other changes we + # have made to speed up this step, + # so spinning up ray and transferring data back + # and forth is not worth it. 
+ # In the future, I think we should redo this with numba and + # use multithreading + n_cpus = 1 + # with parallel segmentation, the first for loop will # only collect streamlines and does not need tqdm if n_cpus > 1: @@ -172,15 +187,18 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols ) - roi_closest = -np.ones((max_includes, len(b_sls)), dtype=np.int32) + n_inc = len(bundle_def["include"]) + roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32) + roi_dists = np.zeros((n_inc, len(b_sls)), dtype=np.float32) if flip_using_include: to_flip = np.ones_like(accept_idx, dtype=np.bool_) for sl_idx, inc_result in enumerate(inc_results): - sl_accepted, sl_closest = inc_result + sl_accepted, sl_closest, sl_dists = inc_result if sl_accepted: + roi_closest[:, sl_idx] = sl_closest + roi_dists[:, sl_idx] = sl_dists if len(sl_closest) > 1: - roi_closest[: len(sl_closest), sl_idx] = sl_closest # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: # Flip sl if it is close to second ROI @@ -188,12 +206,14 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, n_cpus, **kwargs): if flip_using_include: to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: - roi_closest[: len(sl_closest), sl_idx] = np.flip(sl_closest) + roi_closest[:, sl_idx] = np.flip(sl_closest) + roi_dists[:, sl_idx] = np.flip(sl_dists) accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 b_sls.roi_closest = roi_closest.T + b_sls.roi_dists = roi_dists.T if flip_using_include: b_sls.reorient(to_flip) b_sls.select(accept_idx, "include") @@ -211,7 +231,7 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): ref_sl = load_tractogram( bundle_def["curvature"]["path"], "same", bbox_valid_check=False ) - moved_ref_sl = abu.move_streamlines( + moved_ref_sl = move_streamlines( ref_sl, "subject", mapping, img, 
save_intermediates=save_intermediates ) moved_ref_sl.to_vox() @@ -257,11 +277,12 @@ def recobundles( **kwargs, ): b_sls.initiate_selection("Recobundles") - moved_sl = abu.move_streamlines( + moved_sl = move_streamlines( StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), "template", mapping, reg_template, + to_space=Space.RASMM, save_intermediates=save_intermediates, ).streamlines moved_sl_resampled = abu.resample_tg(moved_sl, 100) @@ -293,36 +314,39 @@ def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs): def clean_by_other_bundle( - b_sls, bundle_def, img, preproc_imap, other_bundle_name, other_bundle_sls, **kwargs + b_sls, bundle_def, img, other_bundle_name, other_bundle_sls, **kwargs ): cleaned_idx = b_sls.initiate_selection(other_bundle_name) cleaned_idx = 1 + flipped_sls = b_sls.get_selected_sls(flip=True) if "overlap" in bundle_def[other_bundle_name]: cleaned_idx_overlap = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["overlap"], img, - False, + remove=False, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_overlap) if "node_thresh" in bundle_def[other_bundle_name]: cleaned_idx_node_thresh = abo.clean_by_overlap( - b_sls.get_selected_sls(), + flipped_sls, other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], img, - True, + remove=True, + project=bundle_def[other_bundle_name].get("project", None), ) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh) if "core" in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(flipped_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, False, ) @@ -331,8 +355,8 @@ def clean_by_other_bundle( if "entire_core" 
in bundle_def[other_bundle_name]: cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["entire_core"].lower(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], - np.array(abu.resample_tg(other_bundle_sls, 20)), + np.array(abu.resample_tg(flipped_sls, 100)), + np.array(abu.resample_tg(other_bundle_sls, 100)), img.affine, True, ) @@ -381,10 +405,7 @@ def run_bundle_rec_plan( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_decisions, + recognized_bundles_dict, **segmentation_params, ): # Warp ROIs @@ -392,9 +413,7 @@ def run_bundle_rec_plan( start_time = time() bundle_def = dict(bundle_dict.get_b_info(bundle_name)) bundle_def.update( - bundle_dict.transform_rois( - bundle_name, mapping, img.affine, apply_to_recobundles=True - ) + bundle_dict.transform_rois(bundle_name, mapping, img, apply_to_recobundles=True) ) def check_space(roi): @@ -420,20 +439,25 @@ def check_space(roi): ) logger.info(f"Time to prep ROIs: {time() - start_time}s") - b_sls = abu.SlsBeingRecognized( - tg.streamlines, - logger, - segmentation_params["save_intermediates"], - bundle_name, - img, - len(bundle_def.get("include", [])), - ) + if isinstance(tg, abu.SlsBeingRecognized): + # This only occurs when you're inside a subbundle, + # in which case we want to keep the same SlsBeingRecognized object so that + # we can keep track of the same streamlines and their orientations + b_sls = tg + else: + b_sls = abu.SlsBeingRecognized( + tg.streamlines, + logger, + segmentation_params["save_intermediates"], + bundle_name, + img, + len(bundle_def.get("include", [])), + ) inputs = {} inputs["b_sls"] = b_sls inputs["preproc_imap"] = preproc_imap inputs["bundle_def"] = bundle_def - inputs["max_includes"] = bundle_dict.max_includes inputs["mapping"] = mapping inputs["img"] = img inputs["reg_template"] = reg_template @@ -444,7 +468,7 @@ def check_space(roi): if ( (potential_criterion not in criteria_order_post_other_bundles) 
and (potential_criterion not in criteria_order_pre_other_bundles) - and (potential_criterion not in bundle_dict.bundle_names) + and (potential_criterion not in recognized_bundles_dict.keys()) and (potential_criterion not in valid_noncriterion) ): raise ValueError( @@ -454,7 +478,7 @@ def check_space(roi): "Valid criteria are:\n" f"{criteria_order_pre_other_bundles}\n" f"{criteria_order_post_other_bundles}\n" - f"{bundle_dict.bundle_names}\n" + f"{recognized_bundles_dict.keys()}\n" f"{valid_noncriterion}\n" ) ) @@ -463,13 +487,14 @@ def check_space(roi): if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) if b_sls: - for ii, bundle_name in enumerate(bundle_dict.bundle_names): - if bundle_name in bundle_def.keys(): - idx = np.where(bundle_decisions[:, ii])[0] + for o_bundle_name in recognized_bundles_dict.keys(): + if o_bundle_name in bundle_def.keys(): clean_by_other_bundle( **inputs, - other_bundle_name=bundle_name, - other_bundle_sls=tg.streamlines[idx], + other_bundle_name=o_bundle_name, + other_bundle_sls=recognized_bundles_dict[ + o_bundle_name + ].get_selected_sls(flip=True), ) for criterion in criteria_order_post_other_bundles: if b_sls and criterion in bundle_def: @@ -480,6 +505,22 @@ def check_space(roi): ): mahalanobis(**inputs) + # If you don't cross the midline, we remove streamliens + # entirely on the wrong side of the midline here after filtering + if b_sls and "cross_midline" in bundle_def and not bundle_def["cross_midline"]: + b_sls.initiate_selection("Wrong side of mid.") + zero_coord = preproc_imap["zero_coord"] + lr_axis = preproc_imap["lr_axis"] + avg_side = np.sign( + np.mean( + preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, lr_axis] + - zero_coord, + axis=1, + ) + ) + majority_side = np.sign(np.sum(avg_side)) + b_sls.select(avg_side == majority_side, "Wrong side of mid.") + if b_sls and not b_sls.oriented_yet: raise ValueError( "pyAFQ was unable to consistently orient streamlines " @@ -490,9 +531,40 
@@ def check_space(roi): ) if b_sls: - bundle_to_flip[b_sls.selected_fiber_idxs, bundle_idx] = b_sls.sls_flipped.copy() - bundle_decisions[b_sls.selected_fiber_idxs, bundle_idx] = 1 - if hasattr(b_sls, "roi_closest"): - bundle_roi_closest[b_sls.selected_fiber_idxs, bundle_idx, :] = ( - b_sls.roi_closest.copy() + if "ORG_spectral_subbundles" in bundle_def: + subdict = bundle_def["ORG_spectral_subbundles"] + c_ids = subdict.cluster_IDs + b_sls.initiate_selection( + (f"ORG spectral clustering, {len(c_ids)} subbundles being recognized") + ) + + sub_sft = StatefulTractogram( + b_sls.get_selected_sls(flip=True), img, Space.VOX + ) + cluster_labels = subcluster_by_atlas( + sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 ) + clusters_being_recognized = [] + for c_id in c_ids: + bundle_name = subdict.get_subbundle_name(c_id) + n_roi = len(subdict[bundle_name].get("include", [])) + cluster_b_sls = b_sls.copy(bundle_name, n_roi) + cluster_b_sls.select(cluster_labels == c_id, f"Cluster {c_id}") + clusters_being_recognized.append(cluster_b_sls) + + for ii, c_id in enumerate(c_ids): + bundle_name = subdict.get_subbundle_name(c_id) + run_bundle_rec_plan( + bundle_def["ORG_spectral_subbundles"], + clusters_being_recognized[ii], + mapping, + img, + reg_template, + preproc_imap, + bundle_name, + recognized_bundles_dict, + **segmentation_params, + ) + else: + b_sls.bundle_def = bundle_def + recognized_bundles_dict[bundle_name] = b_sls diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index 1e94106b..618daf45 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -9,7 +9,15 @@ logger = logging.getLogger("AFQ") -def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=False): +def clean_by_overlap( + this_bundle_sls, + other_bundle_sls, + overlap, + img, + remove=False, + project=None, + other_bundle_min_density=0.05, +): """ Cleans a set of streamlines by only keeping (or removing) those 
with significant overlap with another set of streamlines. @@ -32,6 +40,16 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal removed. If False, streamlines that overlap in more than `overlap` nodes are removed. Default: False. + project : {'A/P', 'I/S', 'L/R', None}, optional + If specified, the overlap calculation is projected along the given axis + before cleaning. For example, 'A/P' projects the streamlines along the + anterior-posterior axis. + Default: None. + other_bundle_min_density : float, optional + A threshold to binarize the density map of `other_bundle_sls`. Voxels + with density values above this threshold (as a fraction of the maximum + density) are considered occupied. + Default: 0.05. Returns ------- @@ -56,6 +74,30 @@ def clean_by_overlap(this_bundle_sls, other_bundle_sls, overlap, img, remove=Fal other_bundle_density_map = dtu.density_map( other_bundle_sls, np.eye(4), img.shape[:3] ) + + if remove: + max_val = other_bundle_density_map.max() + if max_val > 0: + other_bundle_density_map = ( + other_bundle_density_map / max_val + ) > other_bundle_min_density + else: + other_bundle_density_map = np.zeros_like( + other_bundle_density_map, dtype=bool + ) + + if project is not None: + orientation = nib.orientations.aff2axcodes(img.affine) + core_axis = next( + idx for idx, label in enumerate(orientation) if label in project.upper() + ) + + projection = np.sum(other_bundle_density_map, axis=core_axis) + + other_bundle_density_map = np.broadcast_to( + np.expand_dims(projection, axis=core_axis), other_bundle_density_map.shape + ) + fiber_probabilities = dts.values_from_volume( other_bundle_density_map, this_bundle_sls, np.eye(4) ) diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 8bb41e67..52531045 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -1,7 +1,6 @@ import logging from time import time -import dipy.tracking.streamline as dts import immlib import 
nibabel as nib import numpy as np @@ -13,20 +12,7 @@ @immlib.calc("tol", "dist_to_atlas", "vox_dim") def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): - # We need to calculate the size of a voxel, so we can transform - # from mm to voxel units: - R = img.affine[0:3, 0:3] - vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) - - # Tolerance is set to the square of the distance to the corner - # because we are using the squared Euclidean distance in calls to - # `cdist` to make those calls faster. - if dist_to_waypoint is None: - tol = dts.dist_to_corner(img.affine) - else: - tol = dist_to_waypoint / vox_dim - dist_to_atlas = int(input_dist_to_atlas / vox_dim) - return tol, dist_to_atlas, vox_dim + return abu.tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas) @immlib.calc("fgarray") @@ -41,7 +27,7 @@ def fgarray(tg): return fg_array -@immlib.calc("crosses") +@immlib.calc("crosses", "lr_axis", "zero_coord") def crosses(fgarray, img): """ Classify the streamlines by whether they cross the midline. 
@@ -59,9 +45,13 @@ def crosses(fgarray, img): lr_axis = idx break - return np.logical_and( - np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), - np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + return ( + np.logical_and( + np.any(fgarray[:, :, lr_axis] > zero_coord[lr_axis], axis=1), + np.any(fgarray[:, :, lr_axis] < zero_coord[lr_axis], axis=1), + ), + lr_axis, + zero_coord[lr_axis], ) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 1727bd0f..73602909 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -6,10 +6,12 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram +import AFQ.recognition.sparse_decisions as ars import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import BundleDict from AFQ.recognition.criteria import run_bundle_rec_plan from AFQ.recognition.preprocess import get_preproc_plan +from AFQ.utils.path import write_json logger = logging.getLogger("AFQ") @@ -155,11 +157,7 @@ def recognize( tg.to_vox() n_streamlines = len(tg) - bundle_decisions = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) - bundle_to_flip = np.zeros((n_streamlines, len(bundle_dict)), dtype=np.bool_) - bundle_roi_closest = -np.ones( - (n_streamlines, len(bundle_dict), bundle_dict.max_includes), dtype=np.uint32 - ) + recognized_bundles_dict = {} fiber_groups = {} meta = {} @@ -167,7 +165,7 @@ def recognize( preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas) logger.info("Assigning Streamlines to Bundles") - for bundle_idx, bundle_name in enumerate(bundle_dict.bundle_names): + for bundle_name in bundle_dict.bundle_names: logger.info(f"Finding Streamlines for {bundle_name}") run_bundle_rec_plan( bundle_dict, @@ -177,10 +175,7 @@ def recognize( reg_template, preproc_imap, bundle_name, - bundle_idx, - bundle_to_flip, - bundle_roi_closest, - bundle_decisions, + recognized_bundles_dict, clip_edges=clip_edges, n_cpus=n_cpus, 
rb_recognize_params=rb_recognize_params, @@ -195,68 +190,62 @@ def recognize( if save_intermediates is not None: os.makedirs(save_intermediates, exist_ok=True) - bc_path = op.join(save_intermediates, "sls_bundle_decisions.npy") - np.save(bc_path, bundle_decisions) + bc_path = op.join(save_intermediates, "sls_bundle_decisions.json") + write_json( + bc_path, + { + b_name: b_sls.selected_fiber_idxs.tolist() + for b_name, b_sls in recognized_bundles_dict.items() + }, + ) + + sparse_dists = ars.compute_sparse_decisions(recognized_bundles_dict, n_streamlines) - conflicts = np.sum(np.sum(bundle_decisions, axis=1) > 1) + conflicts = ars.get_conflict_count(sparse_dists) if conflicts > 0: logger.info( ( "Conflicts in bundle assignment detected. " f"{conflicts} conflicts detected in total out of " f"{n_streamlines} total streamlines. " - "Defaulting to whichever bundle appears first " + "Defaulting to whichever bundle is closest to the include ROI, " + "followed by whichever appears first " "in the bundle_dict." ) ) - bundle_decisions = np.concatenate( - (bundle_decisions, np.ones((n_streamlines, 1))), axis=1 - ) - bundle_decisions = np.argmax(bundle_decisions, -1) + + ars.remove_conflicts(sparse_dists, recognized_bundles_dict) # We do another round through, so that we can: # 1. Clip streamlines according to ROIs # 2. 
Re-orient streamlines logger.info("Re-orienting streamlines to consistent directions") - for bundle_idx, bundle in enumerate(bundle_dict.bundle_names): - logger.info(f"Processing {bundle}") + for b_name, r_bd in recognized_bundles_dict.items(): + logger.info(f"Processing {b_name}") - select_idx = np.where(bundle_decisions == bundle_idx)[0] - - if len(select_idx) == 0: + if len(r_bd.selected_fiber_idxs) == 0: # There's nothing here, set and move to the next bundle: - if "bundlesection" in bundle_dict.get_b_info(bundle): - for sb_name in bundle_dict.get_b_info(bundle)["bundlesection"]: + if "bundlesection" in bundle_dict.get_b_info(b_name): + for sb_name in bundle_dict.get_b_info(b_name)["bundlesection"]: _return_empty(sb_name, return_idx, fiber_groups, img) else: - _return_empty(bundle, return_idx, fiber_groups, img) + _return_empty(b_name, return_idx, fiber_groups, img) continue - # Use a list here, because ArraySequence doesn't support item - # assignment: - select_sl = list(tg.streamlines[select_idx]) - roi_closest = bundle_roi_closest[select_idx, bundle_idx, :] - n_includes = len(bundle_dict.get_b_info(bundle).get("include", [])) - if clip_edges and n_includes > 1: - logger.info("Clipping Streamlines by ROI") - select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, (0, n_includes - 1), in_place=True - ) - - to_flip = bundle_to_flip[select_idx, bundle_idx] - b_def = dict(bundle_dict.get_b_info(bundle_name)) + b_def = r_bd.bundle_def if "bundlesection" in b_def: - for sb_name, sb_include_cuts in bundle_dict.get_b_info(bundle)[ - "bundlesection" - ].items(): + for sb_name, sb_include_cuts in b_def["bundlesection"].items(): bundlesection_select_sl = abu.cut_sls_by_closest( - select_sl, roi_closest, sb_include_cuts, in_place=False + r_bd.get_selected_sls(), + r_bd.roi_closest, + sb_include_cuts, + in_place=False, ) _add_bundle_to_fiber_group( sb_name, bundlesection_select_sl, - select_idx, - to_flip, + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, return_idx, 
fiber_groups, img, @@ -264,9 +253,15 @@ def recognize( _add_bundle_to_meta(sb_name, b_def, meta) else: _add_bundle_to_fiber_group( - bundle, select_sl, select_idx, to_flip, return_idx, fiber_groups, img + b_name, + r_bd.get_selected_sls(cut=clip_edges), + r_bd.selected_fiber_idxs, + r_bd.sls_flipped, + return_idx, + fiber_groups, + img, ) - _add_bundle_to_meta(bundle, b_def, meta) + _add_bundle_to_meta(b_name, b_def, meta) return fiber_groups, meta diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index d87062f5..4ae02310 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -6,33 +6,33 @@ def _interp3d(roi, sl): return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0] -def check_sls_with_inclusion( - sls, include_rois, include_roi_tols): +def check_sls_with_inclusion(sls, include_rois, include_roi_tols): inc_results = np.zeros(len(sls), dtype=tuple) include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] for jj, sl in enumerate(sls): closest = np.zeros(len(include_rois), dtype=np.int32) + dists = np.zeros(len(include_rois), dtype=np.float32) sl = np.asarray(sl) valid = True for ii, roi in enumerate(include_rois): dist = interpolate_scalar_3d(roi, sl)[0] closest[ii] = np.argmin(dist) + dists[ii] = dist[closest[ii]] if dist[closest[ii]] > include_roi_tols[ii]: # Too far from one of them: - inc_results[jj] = (False, []) + inc_results[jj] = (False, [], []) valid = False break # Checked all the ROIs and it was close to all of them if valid: - inc_results[jj] = (True, closest) + inc_results[jj] = (True, closest, dists) return inc_results -def check_sl_with_exclusion(sl, exclude_rois, - exclude_roi_tols): - """ Helper function to check that a streamline is not too close to a +def check_sl_with_exclusion(sl, exclude_rois, exclude_roi_tols): + """Helper function to check that a streamline is not too close to a list of exclusion ROIs. 
""" for ii, roi in enumerate(exclude_rois): @@ -44,17 +44,15 @@ def check_sl_with_exclusion(sl, exclude_rois, return True -def clean_by_endpoints(streamlines, target, target_idx, tol=0, - flip_sls=None, accepted_idxs=None): +def clean_by_endpoints(fgarray, target, target_idx, tol=0, flip_sls=None): """ Clean a collection of streamlines based on an endpoint ROI. Filters down to only include items that have their start or end points close to the targets. Parameters ---------- - streamlines : sequence of N by 3 arrays - Where N is number of nodes in the array, the collection of - streamlines to filter down to. + fgarray : ndarray of shape (N, M, 3) + Where N is number of streamlines, M is number of nodes. target: Nifti1Image Nifti1Image containing a distance transform of the ROI. target_idx: int. @@ -67,24 +65,30 @@ def clean_by_endpoints(streamlines, target, target_idx, tol=0, the endpoint is exactly in the coordinate of the target ROI. flip_sls : 1d array, optional Length is len(streamlines), whether to flip the streamline. - accepted_idxs : 1d array, optional - Boolean array, where entries correspond to eachs streamline, - and streamlines that pass cleaning will be set to 1. Yields ------- boolean array of streamlines that survive cleaning. 
""" - if accepted_idxs is None: - accepted_idxs = np.zeros(len(streamlines), dtype=np.bool_) + if not isinstance(fgarray, np.ndarray): + raise ValueError( + ( + "fgarray must be a numpy ndarray, you can resample " + "your streamlines using resample_tg in AFQ.recognition.utils" + ) + ) - if flip_sls is None: - flip_sls = np.zeros(len(streamlines)) - flip_sls = flip_sls.astype(int) + n_sls, n_nodes, _ = fgarray.shape - for ii, sl in enumerate(streamlines): - this_idx = target_idx - if flip_sls[ii]: - this_idx = (len(sl) - this_idx - 1) % len(sl) - accepted_idxs[ii] = _interp3d(target, [sl[this_idx]])[0] <= tol + # handle target_idx negative values as wrapping around + effective_idx = target_idx if target_idx >= 0 else (n_nodes + target_idx) + indices = np.full(n_sls, effective_idx) - return accepted_idxs + if flip_sls is not None: + flipped_indices = n_nodes - 1 - effective_idx + indices = np.where(flip_sls.astype(bool), flipped_indices, indices) + + distances = interpolate_scalar_3d( + target.get_fdata(), fgarray[np.arange(n_sls), indices] + )[0] + + return distances <= tol diff --git a/AFQ/recognition/sparse_decisions.py b/AFQ/recognition/sparse_decisions.py new file mode 100644 index 00000000..114cc8ee --- /dev/null +++ b/AFQ/recognition/sparse_decisions.py @@ -0,0 +1,116 @@ +import numpy as np +from scipy.sparse import csr_matrix + + +def compute_sparse_decisions(bundles_being_recognized, n_streamlines): + """ + Compute a sparse matrix of distances to ROIs for the streamlines that are + currently being recognized. This can be used to weight decisions by distance + to ROIs, without having to create a dense matrix of distances for all + streamlines and all bundles. + + Parameters + ---------- + bundles_being_recognized : dict + A dictionary of SlsBeingRecognized objects, keyed by bundle name. + n_streamlines : int + The total number of streamlines in the original tractogram. 
+ + Returns + ------- + csr_matrix + A sparse matrix of shape (number of bundles being recognized, n_streamlines), + where the entry (i, j) is a score: + bundles with ROIs result in weights [2.0 to 3.0] with higher scores + for streamlines closer to ROIs + Non-ROI bundles result in weight 1.0 + Everything else is 0.0 (implicit in sparse matrices) + """ + rows, cols, data = [], [], [] + epsilon = 1e-6 + + global_max_dist = 0.0 + for b in bundles_being_recognized.values(): + if hasattr(b, "roi_dists"): + global_max_dist = max(global_max_dist, np.sum(b.roi_dists, axis=-1).max()) + + norm_factor = global_max_dist + 1.0 + + for b_idx, name in enumerate(bundles_being_recognized.keys()): + bundle = bundles_being_recognized[name] + indices = bundle.selected_fiber_idxs + + if hasattr(bundle, "roi_dists"): + dists = np.sum(bundle.roi_dists, axis=-1) + dists = np.maximum(dists, epsilon) + bundle_weights = dists / norm_factor + else: + bundle_weights = np.full(len(indices), 2.0, dtype=np.float32) + + rows.extend([b_idx] * len(indices)) + cols.extend(indices) + data.extend(bundle_weights) + + sparse_scores = csr_matrix( + (data, (rows, cols)), shape=(len(bundles_being_recognized), n_streamlines) + ) + + # Final Decision: 3.0 - Score + # ROI bundles result in weights [2.0 to 3.0] + # No-ROI bundles result in weight 1.0 + sparse_scores.data = 3.0 - sparse_scores.data + + return sparse_scores + + +def get_conflict_count(sparse_scores): + """ + Count how many streamlines are being considered for more than one bundle + """ + sorted_indices = np.sort(sparse_scores.indices) + is_duplicate = np.diff(sorted_indices) == 0 + num_conflicts = np.sum(is_duplicate) + + return num_conflicts + + +def remove_conflicts(sparse_scores, bundles_being_recognized): + """ + Resolve conflicts in place: keep each conflicted streamline only in its + highest-scoring bundle, updating (and dropping emptied) entries of + bundles_being_recognized accordingly. Returns None. + """ + coo = sparse_scores.tocoo() + + order = np.lexsort((-coo.data, coo.col)) + + mask = np.concatenate(([True], np.diff(coo.col[order]) != 0)) + winner_rows = 
coo.row[order][mask] + winner_cols = coo.col[order][mask] + + row_sort = np.argsort(winner_rows) + winner_rows = winner_rows[row_sort] + winner_cols = winner_cols[row_sort] + + num_bundles = len(bundles_being_recognized) + split_indices = np.searchsorted(winner_rows, np.arange(num_bundles + 1)) + + for i, b_name in enumerate(bundles_being_recognized.keys()): + b_sls = bundles_being_recognized[b_name] + if np.any(b_sls.selected_fiber_idxs[:-1] > b_sls.selected_fiber_idxs[1:]): + raise NotImplementedError( + f"Bundle '{b_name}' has unsorted selected_fiber_idxs. " + "The searchsorted optimization requires sorted indices. " + "This is a bug in the implementation of the bundle " + "recognition procedure, please report it to the developers." + ) + + accept_idx = b_sls.initiate_selection(f"{b_name} conflicts") + start, end = split_indices[i], split_indices[i + 1] + bundle_winners = winner_cols[start:end] + + if len(bundle_winners) > 0: + local_positions = np.searchsorted(b_sls.selected_fiber_idxs, bundle_winners) + accept_idx[local_positions] = True + b_sls.select(local_positions, "conflicts") + else: + b_sls.select(accept_idx, "conflicts") + bundles_being_recognized.pop(b_name) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index ef513f14..6e4812de 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -9,6 +9,7 @@ from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.stats.analysis import afq_profile +import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd import AFQ.recognition.cleaning as abc import AFQ.registration as reg @@ -22,7 +23,7 @@ hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec") file_dict = afd.read_stanford_hardi_tractography() reg_template = afd.read_mni_template() -mapping = reg.read_mapping(file_dict["mapping.nii.gz"], hardi_img, reg_template) +mapping = reg.read_old_mapping(file_dict["mapping.nii.gz"], hardi_img, 
reg_template) streamlines = file_dict["tractography_subsampled.trk"] tg = StatefulTractogram(streamlines, hardi_img, Space.RASMM) tg.to_vox() @@ -83,6 +84,47 @@ def test_segment(): npt.assert_equal(len(clean_sl), len(CST_R_sl)) +def test_segment_mixed_roi(): + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + bundle_info = { + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed", + } + } + + with pytest.raises( + ValueError, + match=( + "When using mixed ROI bundle definitions, and subject space ROIs, " + "resample_subject_to cannot be False." + ), + ): + fiber_groups, _ = recognize( + tg, nib.load(hardi_fdata), mapping, bundle_info, reg_template, 2 + ) + + bundle_info = abd.BundleDict(bundle_info, resample_subject_to=hardi_fdata) + fiber_groups, _ = recognize( + tg, + nib.load(hardi_fdata), + mapping, + bundle_info, + reg_template, + 2, + dist_to_atlas=10, + ) + + # We asked for 1 fiber group: + npt.assert_equal(len(fiber_groups), 1) + OR_LV1_sl = fiber_groups["OR LV1"] + npt.assert_(len(OR_LV1_sl) == 2) + + @pytest.mark.nightly def test_segment_no_prob(): # What if you don't have probability maps? 
@@ -134,7 +176,7 @@ def test_segment_clip_edges_api(): def test_segment_reco(): # get bundles for reco method bundles_reco = afd.read_hcp_atlas(16) - bundle_names = ["CST_R", "CST_L"] + bundle_names = ["MCP"] for key in list(bundles_reco): if key not in bundle_names: bundles_reco.pop(key, None) @@ -151,8 +193,8 @@ def test_segment_reco(): ) # This condition should still hold - npt.assert_equal(len(fiber_groups), 2) - npt.assert_(len(fiber_groups["CST_R"]) > 0) + npt.assert_equal(len(fiber_groups), 1) + npt.assert_(len(fiber_groups["MCP"]) > 0) def test_exclusion_ROI(): diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index f5140660..67ef384e 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -4,6 +4,7 @@ from scipy.ndimage import distance_transform_edt import AFQ.recognition.roi as abr +import AFQ.recognition.utils as abu from AFQ.recognition.roi import check_sl_with_exclusion, check_sls_with_inclusion shape = (15, 15, 15) @@ -17,6 +18,7 @@ np.array([[1, 1, 1], [2, 1, 1], [3, 1, 1]]), np.array([[1, 1, 1], [2, 1, 1]]), ] +fgarray = np.array(abu.resample_tg(streamlines, 20)) roi1 = np.ones(shape, dtype=np.float32) roi1[1, 2, 3] = 0 @@ -43,15 +45,15 @@ def test_clean_by_endpoints(): - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0)) + clean_idx_end = list(abr.clean_by_endpoints(fgarray, end_roi, -1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 0, 0]) ) # If tol=1, the third streamline also gets included - clean_idx_start = list(abr.clean_by_endpoints(streamlines, start_roi, 0, tol=1)) - clean_idx_end = list(abr.clean_by_endpoints(streamlines, end_roi, -1, tol=1)) + clean_idx_start = list(abr.clean_by_endpoints(fgarray, start_roi, 0, tol=1)) + clean_idx_end = 
list(abr.clean_by_endpoints(fgarray, end_roi, -1, tol=1)) npt.assert_array_equal( np.logical_and(clean_idx_start, clean_idx_end), np.array([1, 1, 1, 0]) ) @@ -63,22 +65,27 @@ def test_check_sls_with_inclusion(): assert result[0][0] is True assert np.allclose(result[0][1][0], 0) assert np.allclose(result[0][1][1], 2) + assert np.allclose(result[0][2][0], 0) + assert np.allclose(result[0][2][1], 0) assert result[1][0] is False def test_check_sl_with_inclusion_pass(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline1], include_rois, include_roi_tols )[0] assert result is True assert len(dists) == 2 + assert np.allclose(dist_idxs[0], 0) + assert np.allclose(dist_idxs[1], 2) def test_check_sl_with_inclusion_fail(): - result, dists = check_sls_with_inclusion( + result, dist_idxs, dists = check_sls_with_inclusion( [streamline2], include_rois, include_roi_tols )[0] assert result is False + assert dist_idxs == [] assert dists == [] diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index fd34aa18..678c4eed 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -1,3 +1,4 @@ +import copy import logging import os.path as op from time import time @@ -9,11 +10,26 @@ from dipy.io.streamline import save_tractogram from dipy.tracking.distances import bundles_distances_mdf -from AFQ.definitions.mapping import ConformedFnirtMapping - logger = logging.getLogger("AFQ") +def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): + # We need to calculate the size of a voxel, so we can transform + # from mm to voxel units: + R = img.affine[0:3, 0:3] + vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + + # Tolerance is set to the square of the distance to the corner + # because we are using the squared Euclidean distance in calls to + # `cdist` to make those calls faster. 
+ if dist_to_waypoint is None: + tol = dts.dist_to_corner(img.affine) + else: + tol = dist_to_waypoint / vox_dim + dist_to_atlas = int(input_dist_to_atlas / vox_dim) + return tol, dist_to_atlas, vox_dim + + def flip_sls(select_sl, idx_to_flip, in_place=False): """ Helper function to flip streamlines @@ -91,47 +107,6 @@ def orient_by_streamline(sls, template_sl): return DM[:, 0] > DM[:, 1] -def move_streamlines(tg, to, mapping, img, save_intermediates=None): - """Move streamlines to or from template space. - - to : str - Either "template" or "subject". - mapping : ConformedMapping - Mapping to use to move streamlines. - img : Nifti1Image - Space to move streamlines to. - """ - tg_og_space = tg.space - if isinstance(mapping, ConformedFnirtMapping): - if to != "subject": - raise ValueError( - "Attempted to transform streamlines to template using " - "unsupported mapping. " - "Use something other than Fnirt." - ) - tg.to_vox() - moved_sl = [] - for sl in tg.streamlines: - moved_sl.append(mapping.transform_inverse_pts(sl)) - else: - tg.to_rasmm() - if to == "template": - volume = mapping.forward - else: - volume = mapping.backward - delta = dts.values_from_volume(volume, tg.streamlines, np.eye(4)) - moved_sl = dts.Streamlines([d + s for d, s in zip(delta, tg.streamlines)]) - moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) - if save_intermediates is not None: - save_tractogram( - moved_sft, - op.join(save_intermediates, f"sls_in_{to}.trk"), - bbox_valid_check=False, - ) - tg.to_space(tg_og_space) - return moved_sft - - def resample_tg(tg, n_points): # reformat for dipy's set_number_of_points if isinstance(tg, np.ndarray): @@ -169,6 +144,8 @@ def select(self, idx, clean_name, cut=False): self.sls_flipped = self.sls_flipped[idx] if hasattr(self, "roi_closest"): self.roi_closest = self.roi_closest[idx] + if hasattr(self, "roi_dists"): + self.roi_dists = self.roi_dists[idx] time_taken = time() - self.start_time self.logger.info( f"After filtering by {clean_name} 
(time: {time_taken}s), " + f"{len(self.selected_fiber_idxs)} streamlines remain." ) @@ -212,3 +189,27 @@ def __bool__(self): def __len__(self): return len(self.selected_fiber_idxs) + + def copy(self, new_name, n_roi): + new_copy = copy.copy(self) + new_copy.b_name = new_name + if n_roi > 0: + if self.n_roi > 0: + raise NotImplementedError( + ( + "You cannot have includes in the original bundle and" + " subbundles; only one or the other." + ) + ) + else: + new_copy.n_roi = n_roi + + new_copy.selected_fiber_idxs = self.selected_fiber_idxs.copy() + new_copy.sls_flipped = self.sls_flipped.copy() + + if hasattr(self, "roi_closest"): + new_copy.roi_closest = self.roi_closest.copy() + if hasattr(self, "roi_dists"): + new_copy.roi_dists = self.roi_dists.copy() + + return new_copy diff --git a/AFQ/registration.py b/AFQ/registration.py index f63e784b..28b72799 100644 --- a/AFQ/registration.py +++ b/AFQ/registration.py @@ -4,11 +4,14 @@ import nibabel as nib import numpy as np -from dipy.align import syn_registration +from dipy.align.imaffine import AffineMap from dipy.align.imwarp import DiffeomorphicMap -from dipy.align.streamlinear import whole_brain_slr -__all__ = ["syn_register_dwi", "write_mapping", "read_mapping", "slr_registration"] +__all__ = [ + "read_affine_mapping", + "read_syn_mapping", + "read_old_mapping", +] def reduce_shape(shape): @@ -21,78 +24,55 @@ def reduce_shape(shape): return shape -def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs): +def read_syn_mapping(disp, codisp): """ - Register DWI data to a template. + Read a syn registration mapping from a nifti file Parameters - ----------- - dwi : nifti image or str - Image containing DWI data, or full path to a nifti file with DWI. - gtab : GradientTable - The gradients associated with the DWI data - template : nifti image or str, optional + ---------- + disp : str or Nifti1Image + If string, file must be of an image or ndarray.
+ If image, contains the mapping displacement field in each voxel + from subject to template - syn_kwargs : key-word arguments for :func:`syn_registration` + codisp : str or Nifti1Image + If string, file must of an image or ndarray. + If image, contains the mapping displacement field in each voxel + from template to subject Returns ------- - DiffeomorphicMap object + A :class:`DiffeomorphicMap` object """ - if template is None: - import AFQ.data.fetch as afd - - template = afd.read_mni_template() - if isinstance(template, str): - template = nib.load(template) - - template_data = template.get_fdata() - template_affine = template.affine - - if isinstance(dwi, str): - dwi = nib.load(dwi) - - dwi_affine = dwi.affine - dwi_data = dwi.get_fdata() - mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1) - warped_b0, mapping = syn_registration( - mean_b0, - template_data, - moving_affine=dwi_affine, - static_affine=template_affine, - **syn_kwargs, + if isinstance(disp, str): + disp = nib.load(disp) + + if isinstance(codisp, str): + codisp = nib.load(codisp) + + mapping = DiffeomorphicMap( + dim=3, + disp_shape=codisp.get_fdata().shape[:3], + disp_grid2world=None, + domain_shape=disp.get_fdata().shape[:3], + domain_grid2world=None, + codomain_shape=codisp.get_fdata().shape[:3], + codomain_grid2world=None, ) - return warped_b0, mapping - - -def write_mapping(mapping, fname): - """ - Write out a syn registration mapping to file + mapping.forward = disp.get_fdata().astype(np.float32) + mapping.backward = codisp.get_fdata().astype(np.float32) - Parameters - ---------- - mapping : a DiffeomorphicMap object derived from :func:`syn_registration` - fname : str - Full path to the nifti file storing the mapping - - """ - if isinstance(mapping, DiffeomorphicMap): - mapping_imap = np.array([mapping.forward.T, mapping.backward.T]).T - nib.save(nib.Nifti1Image(mapping_imap, mapping.codomain_world2grid), fname) - else: - np.save(fname, mapping.affine) + return mapping -def read_mapping(disp, 
domain_img, codomain_img, prealign=None): +def read_affine_mapping(affine, domain_img, codomain_img): """ Read a syn registration mapping from a nifti file Parameters ---------- - disp : str, Nifti1Image, or ndarray - If string, file must of an image or ndarray. - If image, contains the mapping displacement field in each voxel - Shape (x, y, z, 3, 2) + affine : str or ndarray + If string, file must of an ndarray. If ndarray, contains affine transformation used for mapping domain_img : str or Nifti1Image @@ -101,13 +81,10 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): Returns ------- - A :class:`DiffeomorphicMap` object + A :class:`AffineMap` object """ - if isinstance(disp, str): - if "nii.gz" in disp: - disp = nib.load(disp) - else: - disp = np.load(disp) + if isinstance(affine, str): + affine = np.load(affine) if isinstance(domain_img, str): domain_img = nib.load(domain_img) @@ -115,79 +92,60 @@ def read_mapping(disp, domain_img, codomain_img, prealign=None): if isinstance(codomain_img, str): codomain_img = nib.load(codomain_img) - if isinstance(disp, nib.Nifti1Image): - mapping = DiffeomorphicMap( - 3, - disp.shape[:3], - disp_grid2world=np.linalg.inv(disp.affine), - domain_shape=domain_img.shape[:3], - domain_grid2world=domain_img.affine, - codomain_shape=codomain_img.shape, - codomain_grid2world=codomain_img.affine, - prealign=prealign, - ) - - disp_data = disp.get_fdata().astype(np.float32) - mapping.forward = disp_data[..., 0] - mapping.backward = disp_data[..., 1] - mapping.is_inverse = True - else: - from AFQ.definitions.mapping import ConformedAffineMapping - - mapping = ConformedAffineMapping( - disp, - domain_grid_shape=reduce_shape(domain_img.shape), - domain_grid2world=domain_img.affine, - codomain_grid_shape=reduce_shape(codomain_img.shape), - codomain_grid2world=codomain_img.affine, - ) + mapping = AffineMap( + affine, + domain_grid_shape=reduce_shape(domain_img.shape), + domain_grid2world=domain_img.affine, + 
codomain_grid_shape=reduce_shape(codomain_img.shape), + codomain_grid2world=codomain_img.affine, + ) return mapping -def slr_registration( - moving_data, - static_data, - moving_affine=None, - static_affine=None, - moving_shape=None, - static_shape=None, - **kwargs, -): - """Register a source image (moving) to a target image (static). +def read_old_mapping(disp, domain_img, codomain_img, prealign=None): + """ + Warning: This is only used for pyAFQ tests and backwards compatibility. + Read old-style registration mapping from a nifti file. Parameters ---------- - moving : ndarray - The source tractography data to be registered - moving_affine : ndarray - The affine associated with the moving (source) data. - moving_shape : ndarray - The shape of the space associated with the static (target) data. - static : ndarray - The target tractography data for registration - static_affine : ndarray - The affine associated with the static (target) data. - static_shape : ndarray - The shape of the space associated with the static (target) data. - - **kwargs: - kwargs are passed into whole_brain_slr + disp : str or Nifti1Image + If string, file must be of an image or ndarray.
+ If image, contains the mapping displacement field in each voxel + Shape (x, y, z, 3, 2) + + domain_img : str or Nifti1Image + + codomain_img : str or Nifti1Image Returns ------- - AffineMap + A :class:`DiffeomorphicMap` object """ - from AFQ.definitions.mapping import ConformedAffineMapping + if isinstance(disp, str): + disp = nib.load(disp) - _, transform, _, _ = whole_brain_slr( - static_data, moving_data, x0="affine", verbose=False, **kwargs - ) + if isinstance(domain_img, str): + domain_img = nib.load(domain_img) + + if isinstance(codomain_img, str): + codomain_img = nib.load(codomain_img) - return ConformedAffineMapping( - transform, - codomain_grid_shape=reduce_shape(static_shape), - codomain_grid2world=static_affine, - domain_grid_shape=reduce_shape(moving_shape), - domain_grid2world=moving_affine, + mapping = DiffeomorphicMap( + 3, + disp.shape[:3], + disp_grid2world=np.linalg.inv(disp.affine), + domain_shape=domain_img.shape[:3], + domain_grid2world=domain_img.affine, + codomain_shape=codomain_img.shape, + codomain_grid2world=codomain_img.affine, + prealign=prealign, ) + + disp_data = disp.get_fdata().astype(np.float32) + mapping.forward = disp_data[..., 0] + mapping.backward = disp_data[..., 1] + mapping.is_inverse = False + + return mapping diff --git a/AFQ/tasks/decorators.py b/AFQ/tasks/decorators.py index bb1d750c..3c041287 100644 --- a/AFQ/tasks/decorators.py +++ b/AFQ/tasks/decorators.py @@ -26,7 +26,6 @@ logger = logging.getLogger("AFQ") -logger.setLevel(logging.INFO) def get_new_signature(og_func, needed_args): diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 33006030..724711ea 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -30,7 +30,7 @@ def export_registered_b0(base_fname, data_imap, mapping): ) if not op.exists(warped_b0_fname): mean_b0 = nib.load(data_imap["b0"]).get_fdata() - warped_b0 = mapping.transform(mean_b0) + warped_b0 = mapping.transform_inverse(mean_b0) warped_b0 = nib.Nifti1Image(warped_b0, 
data_imap["reg_template"].affine) logger.info(f"Saving {warped_b0_fname}") nib.save(warped_b0, warped_b0_fname) @@ -54,9 +54,7 @@ def template_xform(base_fname, dwi_data_file, data_imap, mapping): base_fname, f"_space-{subject_space}_desc-template_anat.nii.gz" ) if not op.exists(template_xform_fname): - template_xform = mapping.transform_inverse( - data_imap["reg_template"].get_fdata() - ) + template_xform = mapping.transform(data_imap["reg_template"].get_fdata()) template_xform = nib.Nifti1Image(template_xform, data_imap["dwi_affine"]) logger.info(f"Saving {template_xform_fname}") nib.save(template_xform, template_xform_fname) @@ -85,7 +83,7 @@ def export_rois(base_fname, output_dir, dwi_data_file, data_imap, mapping): *bundle_dict.transform_rois( bundle_name, mapping, - data_imap["dwi_affine"], + data_imap["dwi"], base_fname=base_roi_fname, to_space=to_space, ) diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 2781778e..79cd64cb 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -113,7 +113,7 @@ def segment(data_imap, mapping_imap, tractography_imap, segmentation_params): **segmentation_params, ) - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) if len(seg_sft.sft) < 1: raise ValueError("Fatal: No bundles recognized.") diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index 128fd157..e57eb393 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -29,7 +29,12 @@ def get_base_fname(output_dir, dwi_data_file): key = key_val_pair.split("-")[0] if key not in used_key_list: fname = fname + key_val_pair + "_" - fname = fname[:-1] + if fname[-1] == "_": + fname = fname[:-1] + else: + # if no key value pairs found, + # have some default base file name + fname = fname + "subject" return fname diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 065f8799..a03b92b6 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -21,7 +21,7 @@ logger = logging.getLogger("AFQ") -def 
_viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): +def _viz_prepare_vol(vol, scalar_dict, ref): if vol in scalar_dict.keys(): vol = scalar_dict[vol] @@ -31,8 +31,6 @@ def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): vol = resample(vol, ref) vol = vol.get_fdata() - if xform: - vol = mapping.transform_inverse(vol) vol[np.isnan(vol)] = 0 return vol @@ -81,15 +79,12 @@ def viz_bundles( """ if sbv_lims_bundles is None: sbv_lims_bundles = [None, None] - mapping = mapping_imap["mapping"] scalar_dict = segmentation_imap["scalar_dict"] profiles_file = segmentation_imap["profiles"] t1_img = nib.load(structural_imap["t1_masked"]) shade_by_volume = get_tp(best_scalar, structural_imap, data_imap, tissue_imap) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, t1_img - ) - volume = _viz_prepare_vol(t1_img, False, mapping, scalar_dict, t1_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, t1_img) + volume = _viz_prepare_vol(t1_img, scalar_dict, t1_img) flip_axes = [False, False, False] for i in range(3): @@ -183,7 +178,6 @@ def viz_indivBundle( """ if sbv_lims_indiv is None: sbv_lims_indiv = [None, None] - mapping = mapping_imap["mapping"] bundle_dict = data_imap["bundle_dict"] scalar_dict = segmentation_imap["scalar_dict"] volume_img = nib.load(structural_imap["t1_masked"]) @@ -191,10 +185,8 @@ def viz_indivBundle( profiles = pd.read_csv(segmentation_imap["profiles"]) start_time = time() - volume = _viz_prepare_vol(volume_img, False, mapping, scalar_dict, volume_img) - shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict, volume_img - ) + volume = _viz_prepare_vol(volume_img, scalar_dict, volume_img) + shade_by_volume = _viz_prepare_vol(shade_by_volume, scalar_dict, volume_img) flip_axes = [False, False, False] for i in range(3): @@ -211,6 +203,9 @@ def viz_indivBundle( if "bundlesection" in b_info: for sb_name in b_info["bundlesection"]: segmented_bname_to_roi_bname[sb_name] 
= b_name + elif "ORG_spectral_subbundles" in b_info: + for sb_name in b_info["ORG_spectral_subbundles"]: + segmented_bname_to_roi_bname[sb_name] = b_name else: segmented_bname_to_roi_bname[b_name] = b_name diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 9771ea68..5c5b1f8b 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -791,11 +791,6 @@ def test_AFQ_data_waypoint(): "sub-01_ses-01_desc-mapping_from-subject_to-mni_xform.nii.gz", ) nib.save(mapping, mapping_file) - reg_prealign_file = op.join( - myafq.export("output_dir"), - "sub-01_ses-01_desc-prealign_from-subject_to-mni_xform.npy", - ) - np.save(reg_prealign_file, np.eye(4)) # Test ROI exporting: myafq.export("rois") diff --git a/AFQ/tests/test_registration.py b/AFQ/tests/test_registration.py index c1b394c1..d1b3144c 100644 --- a/AFQ/tests/test_registration.py +++ b/AFQ/tests/test_registration.py @@ -5,16 +5,12 @@ import nibabel.tmpdirs as nbtmp import numpy as np import numpy.testing as npt -from dipy.align.imwarp import DiffeomorphicMap +from dipy.align.imaffine import AffineMap +from dipy.align.streamlinear import whole_brain_slr from dipy.io.streamline import load_tractogram import AFQ.data.fetch as afd -from AFQ.registration import ( - read_mapping, - slr_registration, - syn_register_dwi, - write_mapping, -) +from AFQ.registration import read_affine_mapping, reduce_shape MNI_T2 = afd.read_mni_template() hardi_img, gtab = dpd.read_stanford_hardi() @@ -50,27 +46,34 @@ def test_slr_registration(): hcp_atlas = load_tractogram(atlas_fname, "same", bbox_valid_check=False) with nbtmp.InTemporaryDirectory() as tmpdir: - mapping = slr_registration( + _, transform, _, _ = whole_brain_slr( streamlines, hcp_atlas.streamlines, - moving_affine=subset_b0_img.affine, - static_affine=subset_t2_img.affine, - moving_shape=subset_b0_img.shape, - static_shape=subset_t2_img.shape, + x0="affine", + verbose=False, progressive=False, greater_than=10, rm_small_clusters=1, 
rng=np.random.RandomState(seed=8), ) - warped_moving = mapping.transform(subset_b0) + + mapping = AffineMap( + transform, + domain_grid_shape=reduce_shape(subset_b0_img.shape), + domain_grid2world=subset_b0_img.affine, + codomain_grid_shape=reduce_shape(subset_t2_img.shape), + codomain_grid2world=subset_t2_img.affine, + ) + + warped_moving = mapping.transform_inverse(subset_b0) npt.assert_equal(warped_moving.shape, subset_t2.shape) mapping_fname = op.join(tmpdir, "mapping.npy") - write_mapping(mapping, mapping_fname) - file_mapping = read_mapping(mapping_fname, subset_b0_img, subset_t2_img) + np.save(mapping_fname, transform) + file_mapping = read_affine_mapping(mapping_fname, subset_b0_img, subset_t2_img) # Test that it has the same effect on the data: - warped_from_file = file_mapping.transform(subset_b0) + warped_from_file = file_mapping.transform_inverse(subset_b0) npt.assert_equal(warped_from_file, warped_moving) # Test that it is, attribute by attribute, identical: @@ -78,11 +81,3 @@ def test_slr_registration(): assert np.all( mapping.__getattribute__(k) == file_mapping.__getattribute__(k) ) - - -def test_syn_register_dwi(): - warped_b0, mapping = syn_register_dwi( - subset_dwi_data, gtab, template=subset_t2_img, radius=1 - ) - npt.assert_equal(isinstance(mapping, DiffeomorphicMap), True) - npt.assert_equal(warped_b0.shape, subset_t2_img.shape) diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 1624986f..de99c63d 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -30,7 +30,7 @@ def track( seed_mask=None, seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=2000000, + n_seeds=5000000, random_seeds=True, rng_seed=None, step_size=0.5, @@ -75,7 +75,7 @@ def track( voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D array, these are the coordinates of the seeds. 
Unless random_seeds is set to True, in which case this is the total number of random seeds - to generate within the mask. Default: 2000000 + to generate within the mask. Default: 5000000 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. Default: True diff --git a/AFQ/utils/streamlines.py b/AFQ/utils/streamlines.py index 8c37801c..6d64003f 100644 --- a/AFQ/utils/streamlines.py +++ b/AFQ/utils/streamlines.py @@ -2,7 +2,7 @@ import numpy as np from dipy.io.stateful_tractogram import Space, StatefulTractogram -from dipy.io.streamline import load_tractogram +from dipy.io.streamline import load_tractogram, save_tractogram try: from trx.io import load as load_trx @@ -11,14 +11,14 @@ except ModuleNotFoundError: has_trx = False +from AFQ.definitions.mapping import ConformedFnirtMapping from AFQ.utils.path import drop_extension, read_json class SegmentedSFT: - def __init__(self, bundles, space, sidecar_info=None): + def __init__(self, bundles, sidecar_info=None): if sidecar_info is None: sidecar_info = {} - reference = None self.bundle_names = [] sls = [] idxs = {} @@ -26,20 +26,17 @@ def __init__(self, bundles, space, sidecar_info=None): idx_count = 0 for b_name in bundles: if isinstance(bundles[b_name], dict): - this_sls = bundles[b_name]["sl"] + this_sft = bundles[b_name]["sl"] this_tracking_idxs[b_name] = bundles[b_name]["idx"] else: - this_sls = bundles[b_name] - if reference is None: - reference = this_sls - this_sls = list(this_sls.streamlines) + this_sft = bundles[b_name] + this_sls = list(this_sft.streamlines) sls.extend(this_sls) new_idx_count = idx_count + len(this_sls) idxs[b_name] = np.arange(idx_count, new_idx_count, dtype=np.uint32) idx_count = new_idx_count self.bundle_names.append(b_name) - self.sft = StatefulTractogram(sls, reference, space) self.bundle_idxs = idxs if len(this_tracking_idxs) > 1: self.this_tracking_idxs = this_tracking_idxs @@ -48,12 +45,13 @@ def __init__(self, bundles, space, sidecar_info=None): 
self.sidecar_info = sidecar_info self.sidecar_info["bundle_ids"] = {} - dps = np.zeros(len(self.sft.streamlines)) + dps = np.zeros(len(sls)) for ii, bundle_name in enumerate(self.bundle_names): self.sidecar_info["bundle_ids"][f"{bundle_name}"] = ii + 1 dps[self.bundle_idxs[bundle_name]] = ii + 1 - dps = {"bundle": dps} - self.sft.data_per_streamline = dps + self.sft = StatefulTractogram.from_sft( + sls, this_sft, data_per_streamline={"bundle": dps} + ) if self.this_tracking_idxs is not None: for kk, _vv in self.this_tracking_idxs.items(): self.this_tracking_idxs[kk] = ( @@ -108,7 +106,7 @@ def fromfile(cls, trk_or_trx_file, reference="same", sidecar_file=None): else: bundles["whole_brain"] = sft - return cls(bundles, Space.RASMM, sidecar_info) + return cls(bundles, sidecar_info) def split_streamline(streamlines, sl_to_split, split_idx): @@ -140,3 +138,57 @@ def split_streamline(streamlines, sl_to_split, split_idx): ) return streamlines + + +def move_streamlines(tg, to, mapping, img, to_space=None, save_intermediates=None): + """Move streamlines to or from template space. + + to : str + Either "template" or "subject". This determines + whether we will use the forward or backwards displacement field. + mapping : DIPY or pyAFQ mapping + Mapping to use to move streamlines. + img : Nifti1Image + Image defining reference for where the streamlines move to. + to_space : Space or None + If not None, space to move streamlines to after moving them to the + template or subject space. If None, streamlines will be moved back to + their original space. + Default: None. + save_intermediates : str or None + If not None, path to save intermediate tractogram after moving to template + or subject space. + Default: None. + """ + tg_og_space = tg.space + if isinstance(mapping, ConformedFnirtMapping): + if to != "subject": + raise ValueError( + "Attempted to transform streamlines to template using " + "unsupported mapping. " + "Use something other than Fnirt." 
+ ) + tg.to_vox() + moved_sl = [] + for sl in tg.streamlines: + moved_sl.append(mapping.transform_pts(sl)) + moved_sft = StatefulTractogram(moved_sl, img, Space.RASMM) + else: + tg.to_vox() + if to == "template": + moved_sl = mapping.transform_points(tg.streamlines) + else: + moved_sl = mapping.transform_points_inverse(tg.streamlines) + moved_sft = StatefulTractogram(moved_sl, img, Space.VOX) + + if save_intermediates is not None: + save_tractogram( + moved_sft, + op.join(save_intermediates, f"sls_in_{to}.trk"), + bbox_valid_check=False, + ) + if to_space is None: + moved_sft.to_space(tg_og_space) + else: + moved_sft.to_space(to_space) + return moved_sft diff --git a/AFQ/utils/tests/test_streamlines.py b/AFQ/utils/tests/test_streamlines.py index 39c624c7..8931670f 100644 --- a/AFQ/utils/tests/test_streamlines.py +++ b/AFQ/utils/tests/test_streamlines.py @@ -37,7 +37,7 @@ def test_SegmentedSFT(): ), } - seg_sft = aus.SegmentedSFT(bundles, Space.VOX) + seg_sft = aus.SegmentedSFT(bundles) for k1 in bundles.keys(): for sl1, sl2 in zip( bundles[k1].streamlines, seg_sft.get_bundle(k1).streamlines diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index cea03655..a8269811 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -12,7 +12,7 @@ logger = logging.getLogger("AFQ") -def transform_inverse_roi(roi, mapping, bundle_name="ROI"): +def transform_roi(roi, mapping, bundle_name="ROI"): """ After being non-linearly transformed, ROIs tend to have holes in them. 
We perform a couple of computational geometry operations on the ROI to @@ -40,12 +40,20 @@ def transform_inverse_roi(roi, mapping, bundle_name="ROI"): if isinstance(roi, nib.Nifti1Image): roi = roi.get_fdata() - _roi = mapping.transform_inverse(roi, interpolation="linear") + # dilate binary images to avoid losing small ROIs + if np.unique(roi).size < 3: + scale_factor = max( + np.asarray(mapping.codomain_shape) / np.asarray(mapping.domain_shape) + ) + for _ in range(max(np.ceil(scale_factor) - 1, 0).astype(int)): + roi = binary_dilation(roi) + + _roi = mapping.transform((roi.astype(float)), interpolation="linear") if np.sum(_roi) == 0: logger.warning(f"Lost ROI {bundle_name}, performing automatic binary dilation") _roi = binary_dilation(roi) - _roi = mapping.transform_inverse(_roi, interpolation="linear") + _roi = mapping.transform(_roi.astype(float), interpolation="linear") _roi = patch_up_roi(_roi > 0, bundle_name=bundle_name).astype(np.int32) diff --git a/AFQ/viz/fury_backend.py b/AFQ/viz/fury_backend.py index d8771f20..6ee4d0cf 100644 --- a/AFQ/viz/fury_backend.py +++ b/AFQ/viz/fury_backend.py @@ -203,11 +203,7 @@ def create_gif( def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, color=None, @@ -224,22 +220,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. 
Default: None name: str, optional @@ -275,9 +257,7 @@ def visualize_roi( """ if color is None: color = np.array([1, 0, 0]) - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) for i, flip in enumerate(flip_axes): if flip: roi = np.flip(roi, axis=i) diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 0076c1b9..8b2f8745 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -40,8 +40,6 @@ def _inline_interact(figure, show, show_inline): def _to_color_range(num): - if num < 0: - num = 0 if num >= 0.999: num = 0.999 if num <= 0.001: @@ -232,9 +230,10 @@ def _draw_streamlines( def _plot_profiles(profiles, bundle_name, color, fig, scalar): if isinstance(profiles, pd.DataFrame): - sc_max = np.max(profiles[scalar].to_numpy()) - sc_90 = np.percentile(profiles[scalar].to_numpy(), 10) - sc_1 = np.percentile(profiles[scalar].to_numpy(), 99) + all_tp = profiles[scalar].to_numpy() + all_tp = np.max(all_tp) - all_tp + lim_0 = np.percentile(all_tp, 1) + lim_1 = np.percentile(all_tp, 90) profiles = profiles[profiles.tractID == bundle_name] x = profiles["nodeID"] @@ -242,10 +241,14 @@ def _plot_profiles(profiles, bundle_name, color, fig, scalar): line_color = [] for scalar_val in profiles[scalar].to_numpy(): - xformed_scalar = np.minimum( - (sc_max - scalar_val) / (sc_1 - sc_90) + sc_90 + 0.1, 0.999 + brightness = np.minimum( + np.maximum( + scalar_val - lim_0, + 0, + ), + lim_1, ) - line_color.append(_color_arr2str(xformed_scalar * color)) + line_color.append(_color_arr2str(brightness * color)) else: x = np.arange(len(profiles)) y = profiles @@ -512,7 +515,7 @@ def create_gif(figure, file_name, n_frames=30, zoom=2.5, z_offset=0.5, size=(600 def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): - roi = np.where(roi == 1) + roi = np.where(roi > 0) pts = [] for i, flip in enumerate(flip_axes): if flip: @@ -535,11 +538,7 @@ def 
_draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes): def visualize_roi( roi, - affine_or_mapping=None, - static_img=None, - roi_affine=None, - static_affine=None, - reg_template=None, + resample_to=None, name="ROI", figure=None, flip_axes=None, @@ -556,22 +555,8 @@ def visualize_roi( roi : str or Nifti1Image The ROI information - affine_or_mapping : ndarray, Nifti1Image, or str, optional - An affine transformation or mapping to apply to the ROIs before - visualization. Default: no transform. - - static_img: str or Nifti1Image, optional - Template to resample roi to. - Default: None - - roi_affine: ndarray, optional - Default: None - - static_affine: ndarray, optional - Default: None - - reg_template: str or Nifti1Image, optional - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. Default: None name: str, optional @@ -612,9 +597,7 @@ def visualize_roi( color = np.array([0.9999, 0, 0]) if flip_axes is None: flip_axes = [False, False, False] - roi = vut.prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template - ) + roi = vut.prepare_roi(roi, resample_to) if figure is None: figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]]) diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index def69151..e3d2c6c9 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -1,3 +1,4 @@ +import colorsys import logging import os.path as op from collections import OrderedDict @@ -13,12 +14,27 @@ from dipy.tracking.streamline import transform_streamlines from PIL import Image, ImageChops -import AFQ.registration as reg import AFQ.utils.streamlines as aus -import AFQ.utils.volume as auv __all__ = ["Viz"] + +def get_distinct_shades(base_rgb, n_steps, hue_shift): + """ + Creates distinct shades by shifting Hue + """ + hh, ll, ss = colorsys.rgb_to_hls(*base_rgb) + shades = [] + + for i in range(n_steps): + offset = i - (n_steps - 1) / 2 + + 
new_h = (hh + (offset * hue_shift)) % 1.0 + + shades.append(colorsys.hls_to_rgb(new_h, ll, ss)) + return shades + + viz_logger = logging.getLogger("AFQ") tableau_20 = [ (0.12156862745098039, 0.4666666666666667, 0.7058823529411765), @@ -53,6 +69,18 @@ small_font = 20 marker_size = 200 +slf_l_base = tableau_extension[0] +slf_r_base = tableau_extension[1] + +vof_l_base = tableau_20[6] +vof_r_base = tableau_20[7] + +slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1) +slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1) + +vof_l_shades = get_distinct_shades(vof_l_base, 3, hue_shift=0.15) +vof_r_shades = get_distinct_shades(vof_r_base, 3, hue_shift=0.15) + COLOR_DICT = OrderedDict( { "Left Anterior Thalamic": tableau_20[0], @@ -77,8 +105,14 @@ "F_L": tableau_20[12], "Right Inferior Longitudinal": tableau_20[13], "F_R": tableau_20[13], - "Left Superior Longitudinal": tableau_20[14], - "Right Superior Longitudinal": tableau_20[15], + "Left Superior Longitudinal": slf_l_base, + "Right Superior Longitudinal": slf_r_base, + "Left Superior Longitudinal I": slf_l_shades[0], + "Left Superior Longitudinal II": slf_l_shades[1], + "Left Superior Longitudinal III": slf_l_shades[2], + "Right Superior Longitudinal I": slf_r_shades[0], + "Right Superior Longitudinal II": slf_r_shades[1], + "Right Superior Longitudinal III": slf_r_shades[2], "Left Uncinate": tableau_20[16], "UF_L": tableau_20[16], "Right Uncinate": tableau_20[17], @@ -87,10 +121,16 @@ "AF_L": tableau_20[18], "Right Arcuate": tableau_20[19], "AF_R": tableau_20[19], - "Left Posterior Arcuate": tableau_20[6], - "Right Posterior Arcuate": tableau_20[7], - "Left Vertical Occipital": tableau_extension[0], - "Right Vertical Occipital": tableau_extension[1], + "Left Posterior Arcuate": tableau_20[14], + "Right Posterior Arcuate": tableau_20[15], + "Left Vertical Occipital": vof_l_base, + "Right Vertical Occipital": vof_r_base, + "Left Vertical Occipital I": vof_l_shades[0], + "Left Vertical Occipital 
II": vof_l_shades[1], + "Left Vertical Occipital III": vof_l_shades[2], + "Right Vertical Occipital I": vof_r_shades[0], + "Right Vertical Occipital II": vof_r_shades[1], + "Right Vertical Occipital III": vof_r_shades[2], "median": tableau_20[6], # Paul Tol's palette for callosal bundles "Callosum Orbital": (0.2, 0.13, 0.53), @@ -150,28 +190,28 @@ RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"] BEST_BUNDLE_ORIENTATIONS = { - "Left Anterior Thalamic": ("Sagittal", "Left"), - "Right Anterior Thalamic": ("Sagittal", "Right"), - "Left Corticospinal": ("Sagittal", "Left"), - "Right Corticospinal": ("Sagittal", "Right"), - "Left Cingulum Cingulate": ("Sagittal", "Left"), - "Right Cingulum Cingulate": ("Sagittal", "Right"), - "Forceps Minor": ("Axial", "Top"), - "Forceps Major": ("Axial", "Top"), - "Left Inferior Fronto-occipital": ("Sagittal", "Left"), - "Right Inferior Fronto-occipital": ("Sagittal", "Right"), - "Left Inferior Longitudinal": ("Sagittal", "Left"), - "Right Inferior Longitudinal": ("Sagittal", "Right"), - "Left Superior Longitudinal": ("Axial", "Top"), - "Right Superior Longitudinal": ("Axial", "Top"), - "Left Uncinate": ("Axial", "Bottom"), - "Right Uncinate": ("Axial", "Bottom"), - "Left Arcuate": ("Sagittal", "Left"), - "Right Arcuate": ("Sagittal", "Right"), - "Left Vertical Occipital": ("Coronal", "Back"), - "Right Vertical Occipital": ("Coronal", "Back"), - "Left Posterior Arcuate": ("Coronal", "Back"), - "Right Posterior Arcuate": ("Coronal", "Back"), + "Left Anterior Thalamic": ("Left", "Front", "Top"), + "Right Anterior Thalamic": ("Right", "Front", "Top"), + "Left Corticospinal": ("Left", "Front", "Top"), + "Right Corticospinal": ("Right", "Front", "Top"), + "Left Cingulum Cingulate": ("Left", "Front", "Top"), + "Right Cingulum Cingulate": ("Right", "Front", "Top"), + "Forceps Minor": ("Left", "Front", "Top"), + "Forceps Major": ("Left", "Back", "Top"), + "Left Inferior Fronto-occipital": ("Left", "Front", "Bottom"), + "Right Inferior 
Fronto-occipital": ("Right", "Front", "Bottom"), + "Left Inferior Longitudinal": ("Left", "Front", "Bottom"), + "Right Inferior Longitudinal": ("Right", "Front", "Bottom"), + "Left Superior Longitudinal": ("Left", "Front", "Top"), + "Right Superior Longitudinal": ("Right", "Front", "Top"), + "Left Uncinate": ("Left", "Front", "Bottom"), + "Right Uncinate": ("Right", "Front", "Bottom"), + "Left Arcuate": ("Left", "Front", "Top"), + "Right Arcuate": ("Right", "Front", "Top"), + "Left Vertical Occipital": ("Left", "Back", "Top"), + "Right Vertical Occipital": ("Right", "Back", "Top"), + "Left Posterior Arcuate": ("Left", "Back", "Top"), + "Right Posterior Arcuate": ("Right", "Back", "Top"), } @@ -512,7 +552,7 @@ def tract_generator( else: if bundle is None: # No selection: visualize all of them: - for bundle_name in seg_sft.bundle_names: + for bundle_name in sorted(seg_sft.bundle_names): idx = seg_sft.bundle_idxs[bundle_name] if len(idx) == 0: continue @@ -591,9 +631,7 @@ def gif_from_pngs(tdir, gif_fname, n_frames, png_fname="tgif", add_zeros=False): io.mimsave(gif_fname, angles) -def prepare_roi( - roi, affine_or_mapping, static_img, roi_affine, static_affine, reg_template -): +def prepare_roi(roi, resample_to=None): """ Load the ROI Possibly perform a transformation on an ROI @@ -605,60 +643,27 @@ def prepare_roi( The ROI information. If str, ROI will be loaded using the str as a path. - affine_or_mapping : ndarray, Nifti1Image, or str - An affine transformation or mapping to apply to the ROI before - visualization. Default: no transform. - - static_img: str or Nifti1Image - Template to resample roi to. - - roi_affine: ndarray - - static_affine: ndarray - - reg_template: str or Nifti1Image - Template to use for registration. + resample_to : Nifti1Image, optional + If not None, the ROI will be resampled to the space of this image. 
Returns ------- ndarray """ viz_logger.info("Preparing ROI...") + if isinstance(roi, str): + roi = nib.load(roi) + + if resample_to is not None: + if not isinstance(roi, nib.Nifti1Image): + raise ValueError( + ("If resampling, roi must be a Nifti1Image or a path to one.") + ) + roi = resample(roi, resample_to) + if not isinstance(roi, np.ndarray): - if isinstance(roi, str): - roi = nib.load(roi).get_fdata() - else: - roi = roi.get_fdata() - - if affine_or_mapping is not None: - if isinstance(affine_or_mapping, np.ndarray): - # This is an affine: - if static_img is None or roi_affine is None or static_affine is None: - raise ValueError( - "If using an affine to transform an ROI, " - "need to also specify all of the following", - "inputs: `static_img`, `roi_affine`, ", - "`static_affine`", - ) - roi = resample( - roi, static_img, moving_affine=roi_affine, static_affine=static_affine - ).get_fdata() - else: - # Assume it is a mapping: - if isinstance(affine_or_mapping, str) or isinstance( - affine_or_mapping, nib.Nifti1Image - ): - if reg_template is None or static_img is None: - raise ValueError( - "If using a mapping to transform an ROI, need to ", - "also specify all of the following inputs: ", - "`reg_template`, `static_img`", - ) - affine_or_mapping = reg.read_mapping( - affine_or_mapping, static_img, reg_template - ) + roi = roi.get_fdata() - roi = auv.transform_inverse_roi(roi, affine_or_mapping).astype(bool) return roi diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index 5927e818..581e662b 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -86,6 +86,46 @@ relation to the Left Arcuate and Inferior Longitudinal fasciculi: 'Left Inferior Longitudinal': {'core': 'Left'}, } + +Mixed space ROIs +================ +Everywhere in the bundle dictionary where an ROI is specified as a path, +be it start, end, include, exclude, or probability map, you can in fact input +a dictionary 
instead. This dictionary should have two keys: +- 'roi' : path to the ROI Nifti file +- 'space' : either 'template' or 'subject', describing the space the ROI + is currently in. + +Then, for the whole bundle, set "space" to 'mixed'. This allows you to +specify some ROIs in template space and some in subject space for the same +bundle. For example: + +.. code-block:: python + + import os.path as op + import AFQ.api.bundle_dict as abd + import AFQ.data.fetch as afd + + # First, organize the data + afd.organize_stanford_data() + bids_path = op.join(op.expanduser('~'), 'AFQ_data', 'stanford_hardi') + sub_path = op.join(bids_path, 'derivatives', 'vistasoft', 'sub-01', 'ses-01') + dwi_path = op.join(sub_path, 'dwi', 'sub-01_ses-01_dwi.nii.gz') + + lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() + ar_rois = afd.read_ar_templates() + lv1_fname = op.join(lv1_folder, list(lv1_files.keys())[0]) + + # Then, prepare the bundle dictionary + bundle_info = abd.BundleDict({ + "OR LV1": { + "start": {"roi": ar_rois["AAL_Thal_L"], "space": "template"}, + "end": {"roi": lv1_fname, "space": "subject"}, + "space": "mixed" + } + }, resample_subject_to=dwi_path) + + Filtering Order =============== When doing bundle recognition, streamlines are filtered out from the whole diff --git a/docs/source/references.bib b/docs/source/references.bib index 3dba804f..0588b036 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -9,6 +9,27 @@ @article{Grotheer2022 publisher={Nature Publishing Group} } +@article{leong2016white, + title={White-matter tract connecting anterior insula to nucleus accumbens correlates with reduced preference for positively skewed gambles}, + author={Leong, Josiah K and Pestilli, Franco and Wu, Charlene C and Samanez-Larkin, Gregory R and Knutson, Brian}, + journal={Neuron}, + volume={89}, + number={1}, + pages={63--69}, + year={2016}, + publisher={Elsevier} +} + +@article{alkemade2020amsterdam, + title={The Amsterdam Ultra-high field adult 
lifespan database (AHEAD): A freely available multimodal 7 Tesla submillimeter magnetic resonance imaging database}, + author={Alkemade, Anneke and Mulder, Martijn J and Groot, Josephine M and Isaacs, Bethany R and van Berendonk, Nikita and Lute, Nicky and Isherwood, Scott JS and Bazin, Pierre-Louis and Forstmann, Birte U}, + journal={NeuroImage}, + volume={221}, + pages={117200}, + year={2020}, + publisher={Elsevier} +} + @article{grotheer2023human, title={Human white matter myelinates faster in utero than ex utero}, author={Grotheer, Mareike and Bloom, David and Kruper, John and Richie-Halford, Adam and Zika, Stephanie and Aguilera González, Vicente A and Yeatman, Jason D and Grill-Spector, Kalanit and Rokem, Ariel}, @@ -517,6 +538,16 @@ @InProceedings{Kruper2023 isbn="978-3-031-47292-3" } +@article{zhang2018anatomically, + title={An anatomically curated fiber clustering white matter atlas for consistent white matter tract parcellation across the lifespan}, + author={Zhang, Fan and Wu, Ye and Norton, Isaiah and Rigolo, Laura and Rathi, Yogesh and Makris, Nikos and O'Donnell, Lauren J}, + journal={Neuroimage}, + volume={179}, + pages={429--447}, + year={2018}, + publisher={Elsevier} +} + @ARTICLE{Hua2008, title = "Tract probability maps in stereotaxic spaces: analyses of white matter anatomy and tract-specific quantification", diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index 22fc7542..3f2dbd00 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -269,8 +269,10 @@ "NDARAA948VFH"]["HBNsiteRU"], index_col=[0]) for ind in bundle_counts.index: if ind == "Total Recognized": - threshold = 1000 - elif "Fronto-occipital" in ind or "Orbital" in ind: + threshold = 3000 + elif "Fronto-occipital" in ind: + threshold = 10 + elif "Vertical Occipital" in ind: threshold = 5 else: threshold = 15