From bcc07f1b05af2590ebc949b527b0d9d30786835f Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Tue, 15 Oct 2024 12:04:11 +0200
Subject: [PATCH 01/25] example description update.

---
 examples/forward/source_space_custom_atlas.py | 216 ++++++++++++++++++
 1 file changed, 216 insertions(+)
 create mode 100644 examples/forward/source_space_custom_atlas.py

diff --git a/examples/forward/source_space_custom_atlas.py b/examples/forward/source_space_custom_atlas.py
new file mode 100644
index 00000000000..d32eb25f5ed
--- /dev/null
+++ b/examples/forward/source_space_custom_atlas.py
@@ -0,0 +1,216 @@
+"""
+.. _ex-source-space-custom-atlas:
+
+=========================================
+Source reconstruction with a custom atlas
+=========================================
+
+This example shows how to use a custom atlas when performing source
+reconstruction. We showcase on the sample dataset how to apply the Yeo atlas
+during source reconstruction.
+You should replace the atlas with your own atlas and your own subject.
+
+Any atlas can be used instead of Yeo, provided each region contains a single
+label (i.e., no probabilistic atlas).
+
+.. warning:: This tutorial uses FSL and FreeSurfer to perform MRI
+             coregistrations. If you use different software, replace the
+             coregistration calls appropriately.
+"""
+
+# Authors: Fabrice Guibert
+#
+# License: BSD-3-Clause
+# Copyright the MNE-Python contributors.
+
+# %%
+
+import subprocess
+from pathlib import Path
+
+import nilearn.datasets
+
+import mne
+import mne.datasets
+from mne._freesurfer import read_freesurfer_lut
+from mne.minimum_norm import apply_inverse, make_inverse_operator
+
+# The atlas is in a template space. We download here as an example the Yeo
+# 2011 atlas, which is in the MNI152 1mm template space.
+# Replace this part with your atlas and the template space you used.
+
+nilearn.datasets.fetch_atlas_yeo_2011()  # Download Yeo 2011
+yeo_path = Path(
+    nilearn.datasets.get_data_dirs()[0], "yeo_2011", "Yeo_JNeurophysiol11_MNI152"
+)
+atlas_path = Path(yeo_path, "Yeo2011_7Networks_MNI152_FreeSurferConformed1mm.nii.gz")
+atlas_template_T1_path = Path(yeo_path, "FSL_MNI152_FreeSurferConformed_1mm.nii.gz")
+
+# The participant's T1 data. Here, we consider the sample dataset.
+# The brain should be skull stripped. After FreeSurfer preprocessing,
+# you can use either brain.mgz or antsdn.brain.mgz.
+data_path = mne.datasets.sample.data_path()
+subjects_mri_dir = Path(data_path, "subjects")
+subject_mri_path = Path(subjects_mri_dir, "sample")
+mri_path = Path(subject_mri_path, "mri")
+T1_participant_path = Path(mri_path, "brain.mgz")
+
+assert atlas_path.is_file()
+assert atlas_template_T1_path.is_file()
+assert T1_participant_path.is_file()
+
+# %%
+# The first step is to put the atlas in subject space.
+# We show this step with FSL and FreeSurfer with linear coregistration.
+# If your atlas is already in participant space,
+# you can skip this step. Coregistration is done in two steps:
+# compute the atlas template to subject T1 transform and apply this transform
+# to the atlas file with nearest neighbour interpolation.
+
+# FSL does not know how to read .mgz, so we need to convert the T1 to NIfTI
+# format with FreeSurfer:
+T1_participant_nifti = Path(str(T1_participant_path).replace("mgz", "nii.gz"))
+subprocess.run(["mri_convert", T1_participant_path, T1_participant_nifti])
+
+# Compute template to subject anatomical transform using flirt.
+# If you wish to use other tools such as ANTs, replace these commands
+# appropriately.
+template_to_anat_transform_path = Path(mri_path, "template_to_anat.mat")
+subprocess.run(
+    [
+        "flirt",
+        "-in",
+        atlas_template_T1_path,
+        "-ref",
+        T1_participant_nifti,
+        "-out",
+        Path(mri_path, "T1_atlas_coreg"),
+        "-omat",
+        template_to_anat_transform_path,
+    ]
+)
+
+# Apply the transform to the atlas
+atlas_participant = Path(mri_path, "yeo_atlas.nii.gz")
+
+subprocess.run(
+    [
+        "flirt",
+        "-in",
+        atlas_path,
+        "-ref",
+        T1_participant_nifti,
+        "-out",
+        atlas_participant,
+        "-applyxfm",
+        "-init",
+        template_to_anat_transform_path,
+        "-interp",
+        "nearestneighbour",
+    ]
+)
+
+# Convert resulting atlas from NIfTI to mgz.
+# The filename must end with aseg, to indicate to MNE that it is
+# a proper atlas segmentation.
+atlas_converted = Path(str(atlas_participant).replace(".nii.gz", "aseg.mgz"))
+subprocess.run(["mri_convert", atlas_participant, atlas_converted])
+
+assert T1_participant_nifti.is_file()
+assert template_to_anat_transform_path.is_file()
+assert atlas_participant.is_file()
+assert atlas_converted.is_file()
+
+# %%
+# With the atlas in participant space, we're still missing one ingredient:
+# a dictionary mapping each label to its region ID / value in the volume.
+# In FreeSurfer and atlases, these typically take the form of lookup tables.
+# You can also build the dictionary by hand.
+
+atlas_labels = read_freesurfer_lut(Path(yeo_path, "Yeo2011_7Networks_ColorLUT.txt"))[0]
+print(atlas_labels)
+
+# Drop the key corresponding to the outer region
+del atlas_labels["NONE"]
+
+# %%
+# For the purpose of source reconstruction, let's create a volumetric
+# source space and perform source reconstruction with e.g. eLORETA.
+vol_src = mne.setup_volume_source_space(
+    "sample",
+    subjects_dir=subjects_mri_dir,
+    surface=Path(subject_mri_path, "bem", "inner_skull.surf"),
+)
+
+fif_path = Path(data_path, "MEG", "sample")
+fname_trans = Path(fif_path, "sample_audvis_raw-trans.fif")
+raw_fname = Path(fif_path, "sample_audvis_filt-0-40_raw.fif")
+
+model = mne.make_bem_model(
+    subject="sample", subjects_dir=subjects_mri_dir, ico=4, conductivity=(0.33,)
+)
+bem_sol = mne.make_bem_solution(model)
+
+info = mne.io.read_info(raw_fname)
+info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=False, exclude=[]))
+
+# Build the forward model with our custom source space
+fwd = mne.make_forward_solution(info, trans=fname_trans, src=vol_src, bem=bem_sol)
+
+
+# Now perform typical source reconstruction steps
+raw = mne.io.read_raw_fif(raw_fname)  # already has an average reference
+events = mne.find_events(raw, stim_channel="STI 014")
+
+event_id = dict(aud_l=1)  # event trigger and conditions
+tmin = -0.2  # start of each epoch (200 ms before the trigger)
+tmax = 0.5  # end of each epoch (500 ms after the trigger)
+raw.info["bads"] = ["MEG 2443", "EEG 053"]
+baseline = (None, 0)  # means from the first instant to t = 0
+reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
+
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id,
+    tmin,
+    tmax,
+    proj=True,
+    picks=("meg", "eog"),
+    baseline=baseline,
+    reject=reject,
+)
+
+# Compute noise covariance
+noise_cov = mne.compute_covariance(
+    epochs, tmax=0.0, method=["shrunk", "empirical"], rank=None, verbose=True
+)
+
+# Compute evoked response
+evoked = epochs.average().pick("meg")
+
+# Make inverse operator
+inverse_operator = make_inverse_operator(
+    evoked.info, fwd, noise_cov, loose=1, depth=0.8
+)
+
+# Compute source time courses
+method = "eLORETA"
+snr = 3.0
+lambda2 = 1.0 / snr**2
+stc, residual = apply_inverse(
+    evoked,
+    inverse_operator,
+    lambda2,
+    method=method,
+    pick_ori=None,
+    return_residual=True,
+    verbose=True,
+)
+
+# %%
+# Then, we can finally use our atlas!
+label_tcs = stc.extract_label_time_course(
+    labels=(atlas_converted, atlas_labels), src=vol_src
+)
+label_tcs.shape

From 99b593aa7608e7f66add8d1d3c15f255bf6684ba Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Wed, 16 Oct 2024 11:04:43 +0200
Subject: [PATCH 02/25] pca_flip allowed for volumetric

---
 mne/source_estimate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 64c2d588f57..14c0f9d3226 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3640,7 +3640,7 @@ def _get_default_label_modes():
 
 def _get_allowed_label_modes(stc):
     if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate):
-        return ("mean", "max", "auto")
+        return ("mean", "pca_flip", "max", "auto")
     else:
         return _get_default_label_modes()
From 00e04872b1a9cbf6545a73ae175ae89365f2dd44 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 30 Jan 2025 20:49:48 +0100
Subject: [PATCH 03/25] [FEAT] Volumetric PCA flip implementation

Extension of PCA flip to the volumetric setting. The underlying logic is
essentially identical to the cortical case with respect to flip vector
creation and the PCA itself.

Added a check in PCA flip for the cases where the flip is None or the
number of vertices is below 2, in which case PCA returns a trivial
estimate of the signal.

Complete with its own unit test case.
---
 mne/label.py                      |  37 ++++++--
 mne/source_estimate.py            |  37 ++++++--
 mne/tests/test_source_estimate.py | 152 +++++++++++++++++++++++++++++-
 3 files changed, 210 insertions(+), 16 deletions(-)

diff --git a/mne/label.py b/mne/label.py
index 7c15bd026e9..e8dd96dfbd4 100644
--- a/mne/label.py
+++ b/mne/label.py
@@ -1460,22 +1460,47 @@ def label_sign_flip(label, src):
     flip : array
         Sign flip vector (contains 1 or -1).
     """
-    if len(src) != 2:
-        raise ValueError("Only source spaces with 2 hemisphers are accepted")
+    if len(src) > 2 or len(src) == 0:
+        raise ValueError(
+            "Only source spaces with between one and two "
+            f"hemispheres are accepted, was {len(src)}"
+        )
 
-    lh_vertno = src[0]["vertno"]
-    rh_vertno = src[1]["vertno"]
+    if len(src) == 1 and label.hemi == "both":
+        raise ValueError(
+            'Cannot use hemisphere label "both" when source '
+            + "space contains a single hemisphere."
+        )
+
+    isbi_hemi = len(src) == 2
+    lh_vertno = None
+    rh_vertno = None
+
+    lh_id = -1
+    rh_id = -1
+    if isbi_hemi:
+        lh_id = 0
+        rh_id = 1
+        lh_vertno = src[0]["vertno"]
+        rh_vertno = src[1]["vertno"]
+    elif label.hemi == "lh":
+        lh_vertno = src[0]["vertno"]
+    elif label.hemi == "rh":
+        rh_id = 0
+        rh_vertno = src[0]["vertno"]
+    else:
+        raise Exception(f'Unknown hemisphere type "{label.hemi}"')
 
     # get source orientations
     ori = list()
     if label.hemi in ("lh", "both"):
         vertices = label.vertices if label.hemi == "lh" else label.lh.vertices
         vertno_sel = np.intersect1d(lh_vertno, vertices)
-        ori.append(src[0]["nn"][vertno_sel])
+        ori.append(src[lh_id]["nn"][vertno_sel])
     if label.hemi in ("rh", "both"):
         vertices = label.vertices if label.hemi == "rh" else label.rh.vertices
         vertno_sel = np.intersect1d(rh_vertno, vertices)
-        ori.append(src[1]["nn"][vertno_sel])
+        ori.append(src[rh_id]["nn"][vertno_sel])
     if len(ori) == 0:
         raise Exception(f'Unknown hemisphere type "{label.hemi}"')
     ori = np.concatenate(ori, axis=0)
diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 14c0f9d3226..75b73fdb299 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3370,12 +3370,19 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):
 
 
 def _pca_flip(flip, data):
-    U, s, V = _safe_svd(data, full_matrices=False)
-    # determine sign-flip
-    sign = np.sign(np.dot(U[:, 0], flip))
-    # use average power in label for scaling
-    scale = np.linalg.norm(s) / np.sqrt(len(data))
-    return sign * scale * V[0]
+    result = None
+    if flip is None:
+        result = 0
+    elif data.shape[0] < 2:
+        result = data.mean(axis=0)  # Trivial accumulator
+    else:
+        U, s, V = _safe_svd(data, full_matrices=False)
+        # determine sign-flip
+        sign = np.sign(np.dot(U[:, 0], flip))
+        # use average power in label for scaling
+        scale = np.linalg.norm(s) / np.sqrt(len(data))
+        result = sign * scale * V[0]
+    return result
 
 
 _label_funcs = {
@@ -3428,6 +3435,10 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
     # only computes vertex indices and label_flip will be list of None.
     from .label import BiHemiLabel, Label, label_sign_flip
 
+    # logger.info("Selected mode: " + mode)
+    # print("Entering _prepare_label_extraction")
+    # print("Selected mode: " + mode)
+
     # if source estimate provided in stc, get vertices from source space and
     # check that they are the same as in the stcs
     _check_stc_src(stc, src)
@@ -3440,6 +3451,7 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
     bad_labels = list()
     for li, label in enumerate(labels):
+        # print("Mode: " + mode + " li: " + str(li) + " label: " + str(label))
         subject = label["subject"] if use_sparse else label.subject
         # stc and src can each be None
         _check_subject(
@@ -3505,6 +3517,9 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
             # So if we override vertno with the stc vertices, it will pick
             # the correct normals.
             with _temporary_vertices(src, stc.vertices):
+                # print(f"src: {src[:2]}")
+                # print(f"len(src): {len(src[:2])}")
+
                 this_flip = label_sign_flip(label, src[:2])[:, None]
 
         label_vertidx.append(this_vertidx)
@@ -3639,8 +3654,10 @@ def _get_default_label_modes():
 
 
 def _get_allowed_label_modes(stc):
-    if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate):
-        return ("mean", "pca_flip", "max", "auto")
+    if isinstance(stc, _BaseVectorSourceEstimate):
+        return ("mean", "max", "auto")
+    elif isinstance(stc, _BaseVolSourceEstimate):
+        return ("mean", "pca_flip", "max", "auto")
     else:
         return _get_default_label_modes()
@@ -3729,6 +3746,10 @@ def _gen_extract_label_time_course(
                 this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
             else:
                 this_data = stc.data[vertidx]
+            # if flip is None:  # Happens if fewer than 2 vertices in the label
+            #     if this_data.shape[]
+            #     label_tc[i] = 0
+            # else:
             label_tc[i] = func(flip, this_data)
 
     if mode is not None:
diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index 7eafd2517b2..9ed0786388f 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -2,6 +2,9 @@
 # Authors: The MNE-Python contributors.
 # License: BSD-3-Clause
 # Copyright the MNE-Python contributors.
+import os
+
+os.environ["MNE_DATASETS_TESTING_PATH"] = "/home/guibertf/mne_data/MNE-testing-data"
 
 import os
 import re
@@ -679,6 +682,150 @@ def test_center_of_mass():
     assert_equal(np.round(t, 2), 0.12)
 
 
+@testing.requires_testing_data
+@pytest.mark.parametrize(
+    "label_type, mri_res, test_label, cf, call",
+    [
+        (str, False, False, "head", "meth"),  # head frame
+        (str, False, str, "mri", "func"),  # fastest, default for testing
+        (str, True, str, "mri", "func"),  # fastest, default for testing
+        (str, True, False, "mri", "func"),  # mri_resolution
+        (list, True, False, "mri", "func"),  # volume label as list
+        (dict, True, False, "mri", "func"),  # volume label as dict
+    ],
+)
+def test_extract_label_time_course_volume_pca_flip(
+    src_volume_labels, label_type, mri_res, test_label, cf, call
+):
+    """Test extraction of label timecourses on VolumetricSourceEstimate with PCA."""
+    # Setup of data
+    src_labels, volume_labels, lut = src_volume_labels
+    n_tot = 46
+    assert n_tot == len(src_labels)
+    inv = read_inverse_operator(fname_inv_vol)
+    if cf == "head":
+        src = inv["src"]
+    else:
+        src = read_source_spaces(fname_src_vol)
+    klass = VolVectorSourceEstimate._scalar_class
+    vertices = [src[0]["vertno"]]
+    n_verts = len(src[0]["vertno"])
+    n_times = 50
+    data = np.arange(1, n_verts + 1)
+    end_shape = (n_times,)
+    data = np.repeat(data[..., np.newaxis], n_times, -1)
+    stcs = [klass(data.astype(float), vertices, 0, 1)]
+
+    def eltc(*args, **kwargs):
+        if call == "func":
+            return extract_label_time_course(stcs, *args, **kwargs)
+        else:
+            return [stcs[0].extract_label_time_course(*args, **kwargs)]
+
+    # triage "labels" argument
+    if mri_res:
+        # All should be there
+        missing = []
+    else:
+        # Nearest misses these
+        missing = [
+            "Left-vessel",
+            "Right-vessel",
+            "5th-Ventricle",
+            "non-WM-hypointensities",
+        ]
+    n_want = len(src_labels)
+    if label_type is str:
+        labels = fname_aseg
+    elif label_type is list:
+        labels = (fname_aseg, volume_labels)
+    else:
+        assert label_type is dict
+        labels = (fname_aseg, {k: lut[k] for k in volume_labels})
+        assert mri_res
+        assert len(missing) == 0
+        # we're going to add one that won't exist
+        missing = ["intentionally_bad"]
+        labels[1][missing[0]] = 10000
+        n_want += 1
+        n_tot += 1
+    n_want -= len(missing)
+
+    # _volume_labels(src, labels, mri_resolution)
+    # actually do the testing
+    from mne.source_estimate import _pca_flip, _prepare_label_extraction, _volume_labels
+
+    labels_expanded = _volume_labels(src, labels, mri_res)
+    _, src_flip = _prepare_label_extraction(
+        stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res)
+    )
+
+    mode = "pca_flip"
+    with catch_logging() as log:
+        label_tc = eltc(
+            labels,
+            src,
+            mode=mode,
+            allow_empty="ignore",
+            mri_resolution=mri_res,
+            verbose=True,
+        )
+    log = log.getvalue()
+    assert re.search("^Reading atlas.*aseg\\.mgz\n", log) is not None
+    if len(missing):
+        # assert that the missing ones get logged
+        assert "does not contain" in log
+        assert repr(missing) in log
+    else:
+        assert "does not contain" not in log
+    assert f"\n{n_want}/{n_tot} atlas regions had at least" in log
+    assert len(label_tc) == 1
+    label_tc = label_tc[0]
+    assert label_tc.shape == (n_tot,) + end_shape
+    assert label_tc.shape == (n_tot, n_times)
+    # let's test some actual values by trusting the masks provided by
+    # setup_volume_source_space. mri_resolution=True does some
+    # interpolation so we should not expect equivalence, False does
+    # nearest so we should.
+    if mri_res:
+        rtol = 0.8  # max much more sensitive
+    else:
+        rtol = 0.0
+    for si, s in enumerate(src_labels):
+        func = _pca_flip
+        these = data[np.isin(src[0]["vertno"], s["vertno"])]
+        print(these.shape)
+        assert len(these) == s["nuse"]
+        if si == 0 and s["seg_name"] == "Unknown":
+            continue  # unknown is crappy
+        if s["nuse"] == 0:
+            want = 0.0
+            if mri_res:
+                # this one is totally due to interpolation, so no easy
+                # test here
+                continue
+        else:
+            if src_flip[si] is None:
+                want = None
+            else:
+                want = func(src_flip[si], these)
+        if want is not None:
+            assert_allclose(label_tc[si], want, atol=1e-6, rtol=rtol)
+        # compare with in_label, only on every fourth for speed
+        if test_label is not False and si % 4 == 0:
+            label = s["seg_name"]
+            if test_label is int:
+                label = lut[label]
+            in_label = stcs[0].in_label(label, fname_aseg, src).data
+            assert in_label.shape == (s["nuse"],) + end_shape
+            if np.all(want == 0):
+                assert in_label.shape[0] == 0
+            else:
+                if src_flip[si] is not None:
+                    in_label = func(src_flip[si], in_label)
+                assert_allclose(in_label, want, atol=1e-6, rtol=rtol)
+
+
 @testing.requires_testing_data
 @pytest.mark.parametrize("kind", ("surface", "mixed"))
 @pytest.mark.parametrize("vector", (False, True))
@@ -943,7 +1090,8 @@ def eltc(*args, **kwargs):
     if cf == "head" and not mri_res:  # some missing
         with pytest.warns(RuntimeWarning, match="any vertices"):
             eltc(labels, src, allow_empty=True, mri_resolution=mri_res)
-    for mode in ("mean", "max"):
+    modes = ("mean", "max") if vector else ("mean", "max")
+    for mode in modes:
         with catch_logging() as log:
             label_tc = eltc(
                 labels,
@@ -1469,7 +1617,7 @@ def objective(x):
     assert_allclose(directions, want_nn, atol=2e-6)
 
 
-@testing.requires_testing_data
+# @testing.requires_testing_data
 def test_source_estime_project_label():
     """Test projecting a source estimate onto direction of max power."""
     fwd = read_forward_solution(fname_fwd)
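A minimal sketch of how the new mode is meant to be exercised, for reviewers following along. `fname_aseg`, `fname_src_vol`, and `stc` are placeholders (e.g. the aseg atlas, a matching volume source space, and a `VolSourceEstimate` from the MNE testing dataset)::

    from mne import extract_label_time_course, read_source_spaces

    src = read_source_spaces(fname_src_vol)
    # pca_flip pools each region onto its first principal component, scaled
    # by the average power in the region; for volumes the orientation-based
    # sign flip degenerates to a constant
    label_tc = extract_label_time_course(
        stc, labels=fname_aseg, src=src, mode="pca_flip", allow_empty="ignore"
    )
    print(label_tc.shape)  # (n_regions, n_times)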
From a677854ac92e9e8138fbcad99f3f2911136fac02 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 31 Jan 2025 10:42:37 +0100
Subject: [PATCH 04/25] Removed source_space_custom_atlas example - should be
 the object of a separate PR

---
 examples/forward/source_space_custom_atlas.py | 216 ------------------
 1 file changed, 216 deletions(-)
 delete mode 100644 examples/forward/source_space_custom_atlas.py

diff --git a/examples/forward/source_space_custom_atlas.py b/examples/forward/source_space_custom_atlas.py
deleted file mode 100644
index d32eb25f5ed..00000000000
--- a/examples/forward/source_space_custom_atlas.py
+++ /dev/null
@@ -1,216 +0,0 @@
-"""
-.. _ex-source-space-custom-atlas:
-
-=========================================
-Source reconstruction with a custom atlas
-=========================================
-
-This example shows how to use a custom atlas when performing source
-reconstruction. We showcase on the sample dataset how to apply the Yeo atlas
-during source reconstruction.
-You should replace the atlas with your own atlas and your own subject.
-
-Any atlas can be used instead of Yeo, provided each region contains a single
-label (i.e., no probabilistic atlas).
-
-.. warning:: This tutorial uses FSL and FreeSurfer to perform MRI
-             coregistrations. If you use different software, replace the
-             coregistration calls appropriately.
-"""
-
-# Authors: Fabrice Guibert
-#
-# License: BSD-3-Clause
-# Copyright the MNE-Python contributors.
-
-# %%
-
-import subprocess
-from pathlib import Path
-
-import nilearn.datasets
-
-import mne
-import mne.datasets
-from mne._freesurfer import read_freesurfer_lut
-from mne.minimum_norm import apply_inverse, make_inverse_operator
-
-# The atlas is in a template space. We download here as an example the Yeo
-# 2011 atlas, which is in the MNI152 1mm template space.
-# Replace this part with your atlas and the template space you used.
-
-nilearn.datasets.fetch_atlas_yeo_2011()  # Download Yeo 2011
-yeo_path = Path(
-    nilearn.datasets.get_data_dirs()[0], "yeo_2011", "Yeo_JNeurophysiol11_MNI152"
-)
-atlas_path = Path(yeo_path, "Yeo2011_7Networks_MNI152_FreeSurferConformed1mm.nii.gz")
-atlas_template_T1_path = Path(yeo_path, "FSL_MNI152_FreeSurferConformed_1mm.nii.gz")
-
-# The participant's T1 data. Here, we consider the sample dataset.
-# The brain should be skull stripped. After FreeSurfer preprocessing,
-# you can use either brain.mgz or antsdn.brain.mgz.
-data_path = mne.datasets.sample.data_path()
-subjects_mri_dir = Path(data_path, "subjects")
-subject_mri_path = Path(subjects_mri_dir, "sample")
-mri_path = Path(subject_mri_path, "mri")
-T1_participant_path = Path(mri_path, "brain.mgz")
-
-assert atlas_path.is_file()
-assert atlas_template_T1_path.is_file()
-assert T1_participant_path.is_file()
-
-# %%
-# The first step is to put the atlas in subject space.
-# We show this step with FSL and FreeSurfer with linear coregistration.
-# If your atlas is already in participant space,
-# you can skip this step. Coregistration is done in two steps:
-# compute the atlas template to subject T1 transform and apply this transform
-# to the atlas file with nearest neighbour interpolation.
-
-# FSL does not know how to read .mgz, so we need to convert the T1 to NIfTI
-# format with FreeSurfer:
-T1_participant_nifti = Path(str(T1_participant_path).replace("mgz", "nii.gz"))
-subprocess.run(["mri_convert", T1_participant_path, T1_participant_nifti])
-
-# Compute template to subject anatomical transform using flirt.
-# If you wish to use other tools such as ANTs, replace these commands
-# appropriately.
-template_to_anat_transform_path = Path(mri_path, "template_to_anat.mat")
-subprocess.run(
-    [
-        "flirt",
-        "-in",
-        atlas_template_T1_path,
-        "-ref",
-        T1_participant_nifti,
-        "-out",
-        Path(mri_path, "T1_atlas_coreg"),
-        "-omat",
-        template_to_anat_transform_path,
-    ]
-)
-
-# Apply the transform to the atlas
-atlas_participant = Path(mri_path, "yeo_atlas.nii.gz")
-
-subprocess.run(
-    [
-        "flirt",
-        "-in",
-        atlas_path,
-        "-ref",
-        T1_participant_nifti,
-        "-out",
-        atlas_participant,
-        "-applyxfm",
-        "-init",
-        template_to_anat_transform_path,
-        "-interp",
-        "nearestneighbour",
-    ]
-)
-
-# Convert resulting atlas from NIfTI to mgz.
-# The filename must end with aseg, to indicate to MNE that it is
-# a proper atlas segmentation.
-atlas_converted = Path(str(atlas_participant).replace(".nii.gz", "aseg.mgz"))
-subprocess.run(["mri_convert", atlas_participant, atlas_converted])
-
-assert T1_participant_nifti.is_file()
-assert template_to_anat_transform_path.is_file()
-assert atlas_participant.is_file()
-assert atlas_converted.is_file()
-
-# %%
-# With the atlas in participant space, we're still missing one ingredient:
-# a dictionary mapping each label to its region ID / value in the volume.
-# In FreeSurfer and atlases, these typically take the form of lookup tables.
-# You can also build the dictionary by hand.
-
-atlas_labels = read_freesurfer_lut(Path(yeo_path, "Yeo2011_7Networks_ColorLUT.txt"))[0]
-print(atlas_labels)
-
-# Drop the key corresponding to the outer region
-del atlas_labels["NONE"]
-
-# %%
-# For the purpose of source reconstruction, let's create a volumetric
-# source space and perform source reconstruction with e.g. eLORETA.
-vol_src = mne.setup_volume_source_space(
-    "sample",
-    subjects_dir=subjects_mri_dir,
-    surface=Path(subject_mri_path, "bem", "inner_skull.surf"),
-)
-
-fif_path = Path(data_path, "MEG", "sample")
-fname_trans = Path(fif_path, "sample_audvis_raw-trans.fif")
-raw_fname = Path(fif_path, "sample_audvis_filt-0-40_raw.fif")
-
-model = mne.make_bem_model(
-    subject="sample", subjects_dir=subjects_mri_dir, ico=4, conductivity=(0.33,)
-)
-bem_sol = mne.make_bem_solution(model)
-
-info = mne.io.read_info(raw_fname)
-info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=False, exclude=[]))
-
-# Build the forward model with our custom source space
-fwd = mne.make_forward_solution(info, trans=fname_trans, src=vol_src, bem=bem_sol)
-
-
-# Now perform typical source reconstruction steps
-raw = mne.io.read_raw_fif(raw_fname)  # already has an average reference
-events = mne.find_events(raw, stim_channel="STI 014")
-
-event_id = dict(aud_l=1)  # event trigger and conditions
-tmin = -0.2  # start of each epoch (200 ms before the trigger)
-tmax = 0.5  # end of each epoch (500 ms after the trigger)
-raw.info["bads"] = ["MEG 2443", "EEG 053"]
-baseline = (None, 0)  # means from the first instant to t = 0
-reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)
-
-epochs = mne.Epochs(
-    raw,
-    events,
-    event_id,
-    tmin,
-    tmax,
-    proj=True,
-    picks=("meg", "eog"),
-    baseline=baseline,
-    reject=reject,
-)
-
-# Compute noise covariance
-noise_cov = mne.compute_covariance(
-    epochs, tmax=0.0, method=["shrunk", "empirical"], rank=None, verbose=True
-)
-
-# Compute evoked response
-evoked = epochs.average().pick("meg")
-
-# Make inverse operator
-inverse_operator = make_inverse_operator(
-    evoked.info, fwd, noise_cov, loose=1, depth=0.8
-)
-
-# Compute source time courses
-method = "eLORETA"
-snr = 3.0
-lambda2 = 1.0 / snr**2
-stc, residual = apply_inverse(
-    evoked,
-    inverse_operator,
-    lambda2,
-    method=method,
-    pick_ori=None,
-    return_residual=True,
-    verbose=True,
-)
-
-# %%
-# Then, we can finally use our atlas!
-label_tcs = stc.extract_label_time_course(
-    labels=(atlas_converted, atlas_labels), src=vol_src
-)
-label_tcs.shape

From 2e56dcb438679302c7dc11ab494fa495fe6d9e89 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 31 Jan 2025 10:56:38 +0100
Subject: [PATCH 05/25] [MISC] Changelog update

---
 doc/changes/devel/13092.newfeature.rst | 1 +
 doc/changes/names.inc                  | 1 +
 2 files changed, 2 insertions(+)
 create mode 100644 doc/changes/devel/13092.newfeature.rst

diff --git a/doc/changes/devel/13092.newfeature.rst b/doc/changes/devel/13092.newfeature.rst
new file mode 100644
index 00000000000..96bf4de6d8b
--- /dev/null
+++ b/doc/changes/devel/13092.newfeature.rst
@@ -0,0 +1 @@
+Add PCA-flip to pool sources in source reconstruction in :func:`mne.source_estimate.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
diff --git a/doc/changes/names.inc b/doc/changes/names.inc
index 5a58ac0fa34..8f5efec753d 100644
--- a/doc/changes/names.inc
+++ b/doc/changes/names.inc
@@ -85,6 +85,7 @@
 .. _Evgeny Goldstein: https://github.com/evgenygoldstein
 .. _Ezequiel Mikulan: https://github.com/ezemikulan
 .. _Ezequiel Mikulan: https://github.com/ezemikulan
+.. _Fabrice Guibert: https://github.com/Shrecki
 .. _Fahimeh Mamashli: https://github.com/fmamashli
 .. _Farzin Negahbani: https://github.com/Farzin-Negahbani
 .. _Federico Raimondo: https://github.com/fraimondo

From 8a74ffe1622c117113abd70258b98eadff2f5213 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 31 Jan 2025 11:20:08 +0100
Subject: [PATCH 06/25] [MISC] Fixed changelog

---
 doc/changes/devel/13092.newfeature.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/changes/devel/13092.newfeature.rst b/doc/changes/devel/13092.newfeature.rst
index 96bf4de6d8b..90f8c60b4f7 100644
--- a/doc/changes/devel/13092.newfeature.rst
+++ b/doc/changes/devel/13092.newfeature.rst
@@ -1 +1 @@
-Add PCA-flip to pool sources in source reconstruction in :func:`mne.source_estimate.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
+Add PCA-flip to pool sources in source reconstruction in :func:`mne.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
From dd15522d2de99872fe1eded7d081629f35588c67 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 31 Jan 2025 14:22:07 +0100
Subject: [PATCH 07/25] [FIX] Removed erroneous path from test case

---
 mne/tests/test_source_estimate.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index 791a6b740f8..6111811dc12 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -2,10 +2,6 @@
 # Authors: The MNE-Python contributors.
 # License: BSD-3-Clause
 # Copyright the MNE-Python contributors.
-import os
-
-os.environ["MNE_DATASETS_TESTING_PATH"] = "/home/guibertf/mne_data/MNE-testing-data"
-
 import os
 import re
 from contextlib import nullcontext

From 6839d736d916deaab95ad056d1293edc7be290d6 Mon Sep 17 00:00:00 2001
From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com>
Date: Fri, 31 Jan 2025 20:03:47 +0000
Subject: [PATCH 08/25] [autofix.ci] apply automated fixes

---
 mne/tests/test_source_estimate.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index 6111811dc12..001b4c49c20 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -2,6 +2,7 @@
 # Authors: The MNE-Python contributors.
 # License: BSD-3-Clause
 # Copyright the MNE-Python contributors.
+
 import os
 import re
 from contextlib import nullcontext
From 379614eb61ea04b128a8860f47e47b514d191fa2 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 13 Mar 2025 14:44:35 +0100
Subject: [PATCH 09/25] [FEAT] Simplify label code and remove cruft code

---
 mne/label.py                      | 42 +++++++++++++------------------
 mne/source_estimate.py            | 13 +---------
 mne/tests/test_source_estimate.py |  2 +-
 3 files changed, 20 insertions(+), 37 deletions(-)

diff --git a/mne/label.py b/mne/label.py
index fe3ab04b13f..68f3d89c98f 100644
--- a/mne/label.py
+++ b/mne/label.py
@@ -1472,35 +1472,29 @@ def label_sign_flip(label, src):
             + "space contains a single hemisphere."
         )
 
-    isbi_hemi = len(src) == 2
-    lh_vertno = None
-    rh_vertno = None
-
-    lh_id = -1
-    rh_id = -1
-    if isbi_hemi:
-        lh_id = 0
-        rh_id = 1
-        lh_vertno = src[0]["vertno"]
-        rh_vertno = src[1]["vertno"]
-    elif label.hemi == "lh":
-        lh_vertno = src[0]["vertno"]
-    elif label.hemi == "rh":
-        rh_id = 0
-        rh_vertno = src[0]["vertno"]
+    hemis = {}
+
+    # Build hemisphere info dictionary
+    if label.hemi == "both":
+        hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
+        hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
+    elif label.hemi in ("lh", "rh"):
+        hemis[label.hemi] = {"id": 0, "vertno": src[0]["vertno"]}
     else:
         raise Exception(f'Unknown hemisphere type "{label.hemi}"')
 
     # get source orientations
     ori = list()
-    if label.hemi in ("lh", "both"):
-        vertices = label.vertices if label.hemi == "lh" else label.lh.vertices
-        vertno_sel = np.intersect1d(lh_vertno, vertices)
-        ori.append(src[lh_id]["nn"][vertno_sel])
-    if label.hemi in ("rh", "both"):
-        vertices = label.vertices if label.hemi == "rh" else label.rh.vertices
-        vertno_sel = np.intersect1d(rh_vertno, vertices)
-        ori.append(src[rh_id]["nn"][vertno_sel])
+    for hemi, hemi_infos in hemis.items():
+        # When the label is lh or rh, get vertices directly
+        if label.hemi == hemi:
+            vertices = label.vertices
+        # In the case where label is "both", get label.hemi.vertices
+        # (so either label.lh.vertices or label.rh.vertices)
+        else:
+            vertices = getattr(label, hemi).vertices
+        vertno_sel = np.intersect1d(hemi_infos["vertno"], vertices)
+        ori.append(src[hemi_infos["id"]["nn"][vertno_sel]])
     if len(ori) == 0:
         raise Exception(f'Unknown hemisphere type "{label.hemi}"')
     ori = np.concatenate(ori, axis=0)
diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 7124a348e71..aacf8484a66 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3440,9 +3440,7 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
     # only computes vertex indices and label_flip will be list of None.
     from .label import BiHemiLabel, Label, label_sign_flip
 
-    # logger.info("Selected mode: " + mode)
-    # print("Entering _prepare_label_extraction")
-    # print("Selected mode: " + mode)
+    logger.debug(f"Selected mode: {mode}")
 
     # if source estimate provided in stc, get vertices from source space and
     # check that they are the same as in the stcs
@@ -3456,7 +3454,6 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
     bad_labels = list()
     for li, label in enumerate(labels):
-        # print("Mode: " + mode + " li: " + str(li) + " label: " + str(label))
         subject = label["subject"] if use_sparse else label.subject
         # stc and src can each be None
         _check_subject(
@@ -3522,9 +3519,6 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
             # So if we override vertno with the stc vertices, it will pick
             # the correct normals.
             with _temporary_vertices(src, stc.vertices):
-                # print(f"src: {src[:2]}")
-                # print(f"len(src): {len(src[:2])}")
-
                 this_flip = label_sign_flip(label, src[:2])[:, None]
 
         label_vertidx.append(this_vertidx)
@@ -3749,12 +3743,7 @@ def _gen_extract_label_time_course(
                 this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
             else:
                 this_data = stc.data[vertidx]
-            # if flip is None:  # Happens if fewer than 2 vertices in the label
-            #     if this_data.shape[]
-            #     label_tc[i] = 0
-            # else:
             label_tc[i] = func(flip, this_data)
-
     if mode is not None:
         offset = nvert[:-n_mean].sum()  # effectively :2 or :0
         for i, nv in enumerate(nvert[2:]):
diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index 6111811dc12..c5500fa28e8 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -1613,7 +1613,7 @@ def objective(x):
     assert_allclose(directions, want_nn, atol=2e-6)
 
 
-# @testing.requires_testing_data
+@testing.requires_testing_data
 def test_source_estime_project_label():
     """Test projecting a source estimate onto direction of max power."""
     fwd = read_forward_solution(fname_fwd)

From 33f911d382cd58bc716c4cd3a3982905373812de Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 13 Mar 2025 14:48:57 +0100
Subject: [PATCH 10/25] [FIX] Removed trivial branch

---
 mne/tests/test_source_estimate.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index c9d6ab3b15b..8df802aae59 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -1087,8 +1087,7 @@ def eltc(*args, **kwargs):
     if cf == "head" and not mri_res:  # some missing
         with pytest.warns(RuntimeWarning, match="any vertices"):
             eltc(labels, src, allow_empty=True, mri_resolution=mri_res)
-    modes = ("mean", "max") if vector else ("mean", "max")
-    for mode in modes:
+    for mode in ("mean", "max"):
         with catch_logging() as log:
             label_tc = eltc(
                 labels,
From ec829869fbf17e3b887836b455b71582a24d9172 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 14 Mar 2025 11:42:59 +0100
Subject: [PATCH 11/25] [FIX] label_sign_flip incorrectly handled hemispheres

---
 mne/label.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mne/label.py b/mne/label.py
index 68f3d89c98f..dc9c9c83251 100644
--- a/mne/label.py
+++ b/mne/label.py
@@ -1479,7 +1479,10 @@ def label_sign_flip(label, src):
         hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
         hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
     elif label.hemi in ("lh", "rh"):
-        hemis[label.hemi] = {"id": 0, "vertno": src[0]["vertno"]}
+        # If two sources available, the hemisphere's ID must be looked up.
+        # If only a single source, the ID is zero.
+        index_ = ("lh", "rh").index(label.hemi) if len(src) == 2 else 0
+        hemis[label.hemi] = {"id": index_, "vertno": src[index_]["vertno"]}
     else:
         raise Exception(f'Unknown hemisphere type "{label.hemi}"')
 
@@ -1494,7 +1497,7 @@ def label_sign_flip(label, src):
         else:
             vertices = getattr(label, hemi).vertices
         vertno_sel = np.intersect1d(hemi_infos["vertno"], vertices)
-        ori.append(src[hemi_infos["id"]["nn"][vertno_sel]])
+        ori.append(src[hemi_infos["id"]]["nn"][vertno_sel])
     if len(ori) == 0:
         raise Exception(f'Unknown hemisphere type "{label.hemi}"')
     ori = np.concatenate(ori, axis=0)
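With the lookup fixed, a single-hemisphere source space now works end to end. A small sketch, assuming `subjects_dir` points at the sample subjects directory and `src` is a standard two-hemisphere surface source space (both names are placeholders)::

    import numpy as np
    import mne
    from mne.label import label_sign_flip

    label = mne.read_labels_from_annot(
        "sample", "aparc", hemi="lh", subjects_dir=subjects_dir
    )[0]
    flip = label_sign_flip(label, src[:1])  # lh-only source space: no error
    assert set(np.unique(flip)) <= {-1.0, 1.0}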
From 0ca43cf3dddabe180d8705b2cce16cd81fb19fba Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Tue, 25 Mar 2025 13:13:45 +0100
Subject: [PATCH 12/25] Imports moved up top

---
 mne/tests/test_source_estimate.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index 8df802aae59..b8af56be21c 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -73,7 +73,14 @@
 from mne.morph_map import _make_morph_map_hemi
-from mne.source_estimate import _get_vol_mask, _make_stc, grade_to_tris
+from mne.source_estimate import (
+    _get_vol_mask,
+    _make_stc,
+    _pca_flip,
+    _prepare_label_extraction,
+    _volume_labels,
+    grade_to_tris,
+)
 from mne.source_space._source_space import _get_src_nn
 from mne.transforms import apply_trans, invert_transform, transform_surface_to
 from mne.utils import (
@@ -748,10 +755,7 @@ def eltc(*args, **kwargs):
         n_tot += 1
     n_want -= len(missing)
 
-    # _volume_labels(src, labels, mri_resolution)
     # actually do the testing
-    from mne.source_estimate import _pca_flip, _prepare_label_extraction, _volume_labels
-
     labels_expanded = _volume_labels(src, labels, mri_res)
     _, src_flip = _prepare_label_extraction(
         stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res)

From ee4a174e33634521fb5c15e2ad9f7a6cc718686b Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 15 May 2025 13:05:06 +0200
Subject: [PATCH 13/25] Updating mri_name to save volumetric source

---
 mne/source_space/_source_space.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mne/source_space/_source_space.py b/mne/source_space/_source_space.py
index d64989961cf..90741d62bb6 100644
--- a/mne/source_space/_source_space.py
+++ b/mne/source_space/_source_space.py
@@ -642,6 +642,7 @@ def export_volume(
 
     # Get shape, inuse array and interpolation matrix from volume sources
     src = src_types["volume"][0]
+    src["mri_file"] = src["mri_volume_name"]
     aseg_data = None
     if mri_resolution:
         # read the mri file used to generate volumes

From 3ce6b38f5fb073f22924b7a3171dc924933ea687 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 15 May 2025 17:30:30 +0200
Subject: [PATCH 14/25] Fix of PCA flip in volume: use a constant flip, as
 flips are meaningless in volumes

---
 mne/source_estimate.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index aacf8484a66..cdb18e8b16a 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3377,8 +3377,8 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):
 
 def _pca_flip(flip, data):
     result = None
-    if flip is None:
-        result = 0
+    if flip is None:  # Case of volumetric data: flip is meaningless
+        flip = 1
     elif data.shape[0] < 2:
         result = data.mean(axis=0)  # Trivial accumulator
     else:

From 7ae37e5d5708172c5336c87c4cbc11259c7912f5 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 15 May 2025 17:47:02 +0200
Subject: [PATCH 15/25] Fixed pca flip branch

---
 mne/source_estimate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index cdb18e8b16a..8d7c07c61ee 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3379,7 +3379,7 @@ def _pca_flip(flip, data):
     result = None
     if flip is None:  # Case of volumetric data: flip is meaningless
         flip = 1
-    elif data.shape[0] < 2:
+    if data.shape[0] < 2:
         result = data.mean(axis=0)  # Trivial accumulator
     else:
         U, s, V = _safe_svd(data, full_matrices=False)

From d9580daca3cd9fde7e831e3935f0a62a4266999a Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 15 May 2025 17:59:42 +0200
Subject: [PATCH 16/25] Handling of flip being an int

---
 mne/source_estimate.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 8d7c07c61ee..1a97f7ac92b 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3383,8 +3383,12 @@ def _pca_flip(flip, data):
         result = data.mean(axis=0)  # Trivial accumulator
     else:
         U, s, V = _safe_svd(data, full_matrices=False)
-        # determine sign-flip
-        sign = np.sign(np.dot(U[:, 0], flip))
+        # determine sign-flip.
+        # if flip is a mere int, multiply U and sum
+        if isinstance(flip, int):
+            sign = np.sign((flip * U[:, 0]).sum())
+        else:
+            sign = np.sign(np.dot(U[:, 0], flip))
         # use average power in label for scaling
         scale = np.linalg.norm(s) / np.sqrt(len(data))
         result = sign * scale * V[0]
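The integer branch above is equivalent to flipping against an all-ones vector; a quick self-contained numpy check::

    import numpy as np

    rng = np.random.default_rng(0)
    data = rng.standard_normal((10, 50))
    U, s, V = np.linalg.svd(data, full_matrices=False)

    flip = 1  # volumetric case: no orientations to flip against
    ones = np.ones(data.shape[0])
    assert np.sign((flip * U[:, 0]).sum()) == np.sign(np.dot(U[:, 0], ones))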
From fd717799944bc17054019d3d73a7b4b5b30c063d Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Thu, 15 May 2025 20:55:20 +0200
Subject: [PATCH 17/25] Using numpy svd instead of scipy

---
 mne/source_estimate.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 1a97f7ac92b..3ffe05161ae 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3382,12 +3382,16 @@ def _pca_flip(flip, data):
     if data.shape[0] < 2:
         result = data.mean(axis=0)  # Trivial accumulator
     else:
-        U, s, V = _safe_svd(data, full_matrices=False)
         # determine sign-flip.
         # if flip is a mere int, multiply U and sum
        if isinstance(flip, int):
+            # We assume here that flip is thus denoting a volumetric.
+            # It means LAPACK is likely to overflow on big matrices => We use numpy
+            U, s, V = np.linalg.svd(data, full_matrices=False)
+
             sign = np.sign((flip * U[:, 0]).sum())
         else:
+            U, s, V = _safe_svd(data, full_matrices=False)
             sign = np.sign(np.dot(U[:, 0], flip))
         # use average power in label for scaling
         scale = np.linalg.norm(s) / np.sqrt(len(data))
         result = sign * scale * V[0]
From 69937d2c1877774971ca0948eaf0b85f9cf4b371 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 16 May 2025 11:33:39 +0200
Subject: [PATCH 18/25] PCA flip for volumetric data now uses randomized SVD so
 that the SVD can run at all

---
 mne/source_estimate.py | 65 ++++++++++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 21 deletions(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 3ffe05161ae..c49fcce2b4f 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3375,6 +3375,16 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):
     return ico
 
 
+def _compute_pca_quantities(U, s, V, flip):
+    if isinstance(flip, int):
+        sign = np.sign((flip * U[:, 0]).sum())
+    else:
+        sign = np.sign(np.dot(U[:, 0], flip))
+    scale = np.linalg.norm(s) / np.sqrt(len(U))
+    result = sign * scale * V[0]
+    return result
+
+
 def _pca_flip(flip, data):
@@ -3386,19 +3396,12 @@ def _pca_flip(flip, data):
     result = None
     if flip is None:  # Case of volumetric data: flip is meaningless
         flip = 1
     if data.shape[0] < 2:
         result = data.mean(axis=0)  # Trivial accumulator
     else:
-        # determine sign-flip.
-        # if flip is a mere int, multiply U and sum
-        if isinstance(flip, int):
-            # We assume here that flip is thus denoting a volumetric.
-            # It means LAPACK is likely to overflow on big matrices => We use numpy
-            U, s, V = np.linalg.svd(data, full_matrices=False)
-
-            sign = np.sign((flip * U[:, 0]).sum())
-        else:
-            U, s, V = _safe_svd(data, full_matrices=False)
-            sign = np.sign(np.dot(U[:, 0], flip))
-        # use average power in label for scaling
-        scale = np.linalg.norm(s) / np.sqrt(len(data))
-        result = sign * scale * V[0]
+        U, s, V = np.linalg.svd(data, full_matrices=False)
+        # determine sign-flip.
+        # if flip is a mere int, multiply U and sum
+        result = _compute_pca_quantities(U, s, V, flip)
     return result
 
 
@@ -3678,6 +3678,7 @@ def _gen_extract_label_time_course(
     allow_empty=False,
     mri_resolution=True,
     verbose=None,
+    max_channels=400,
 ):
     # loop through source estimates and extract time series
     if src is None and mode in ["mean", "max"]:
@@ -3741,17 +3742,39 @@ def _gen_extract_label_time_course(
     else:
         # For other modes, initialize the label_tc array
         label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
+        pca_volumetric = kind == "volume" and mode == "pca_flip"
+        if pca_volumetric:
+            # Precompute randomized SVD on data
+            # Components are restricted to max_channels, which is the highest
+            # possible rank and is much smaller than the number of sources
+            from sklearn.utils.extmath import randomized_svd
+
+            u_data, s_data, vh_data = randomized_svd(
+                stc.data, n_components=max_channels
+            )
         for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
             if vertidx is not None:
-                if isinstance(vertidx, sparse.csr_array):
-                    assert mri_resolution
-                    assert vertidx.shape[1] == stc.data.shape[0]
-                    this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
-                    this_data = vertidx @ this_data
-                    this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
-                else:
-                    this_data = stc.data[vertidx]
-                label_tc[i] = func(flip, this_data)
+                if pca_volumetric:
+                    # Compute SVD of vertices
+                    # We will use it to compute vertidx @ data implicitly,
+                    u_vert, s_vert, vh_Vert = np.linalg.svd(vertidx.todense())
+                    center_prod = np.diag(s_vert) @ vh_Vert @ u_data @ np.diag(s_data)
+                    u_s, s_s, vh_s = np.linalg.svd(center_prod)
+                    U = u_vert @ u_s
+                    s = s_s
+                    V = vh_s @ vh_data
+                    label_tc[i] = _compute_pca_quantities(U, s, V, flip)
+                else:
+                    if isinstance(vertidx, sparse.csr_array):
+                        assert mri_resolution
+                        assert vertidx.shape[1] == stc.data.shape[0]
+                        this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
+                        this_data = vertidx @ this_data
+                        this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
+                    else:
+                        this_data = stc.data[vertidx]
+                    label_tc[i] = func(flip, this_data)
 
     if mode is not None:
         offset = nvert[:-n_mean].sum()  # effectively :2 or :0
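For context on the randomized SVD choice: the label data cannot have rank larger than the number of sensors, so truncating at `max_channels` components is essentially lossless while avoiding a dense decomposition over tens of thousands of sources. A toy check with synthetic low-rank data (shapes are illustrative only)::

    import numpy as np
    from sklearn.utils.extmath import randomized_svd

    rng = np.random.default_rng(0)
    data = rng.standard_normal((5000, 50)) @ rng.standard_normal((50, 400))

    # Truncated, randomized factorization: only the top-k triplet is computed
    U, s, Vh = randomized_svd(data, n_components=50, random_state=0)
    err = np.linalg.norm(U @ np.diag(s) @ Vh - data) / np.linalg.norm(data)
    print(f"relative error: {err:.1e}")  # tiny: rank-50 data is fully captured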
From 8888c136678b4056b7fcee0f8dc3bfd5b5727122 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 16 May 2025 11:49:20 +0200
Subject: [PATCH 19/25] Simplification of PCA flip

---
 mne/source_estimate.py | 46 +++++++++++++-----------------------------
 1 file changed, 14 insertions(+), 32 deletions(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index c49fcce2b4f..e9b832fb7bf 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3385,16 +3385,17 @@ def _compute_pca_quantities(U, s, V, flip):
 
 
-def _pca_flip(flip, data):
+def _pca_flip(flip, data, max_rank):
     result = None
     if flip is None:  # Case of volumetric data: flip is meaningless
         flip = 1
     if data.shape[0] < 2:
         result = data.mean(axis=0)  # Trivial accumulator
     else:
-        U, s, V = np.linalg.svd(data, full_matrices=False)
+        from sklearn.utils.extmath import randomized_svd
+
+        U, s, V = randomized_svd(data, n_components=max_rank)
         # determine sign-flip.
-        # if flip is a mere int, multiply U and sum
         result = _compute_pca_quantities(U, s, V, flip)
     return result
 
@@ -3743,38 +3743,19 @@ def _gen_extract_label_time_course(
     else:
         # For other modes, initialize the label_tc array
         label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
-        pca_volumetric = kind == "volume" and mode == "pca_flip"
-        if pca_volumetric:
-            # Precompute randomized SVD on data
-            # Components are restricted to max_channels, which is the highest
-            # possible rank and is much smaller than the number of sources
-            from sklearn.utils.extmath import randomized_svd
-
-            u_data, s_data, vh_data = randomized_svd(
-                stc.data, n_components=max_channels
-            )
         for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
             if vertidx is not None:
-                if pca_volumetric:
-                    # Compute SVD of vertices
-                    # We will use it to compute vertidx @ data implicitly,
-                    u_vert, s_vert, vh_Vert = np.linalg.svd(vertidx.todense())
-                    center_prod = np.diag(s_vert) @ vh_Vert @ u_data @ np.diag(s_data)
-                    u_s, s_s, vh_s = np.linalg.svd(center_prod)
-                    U = u_vert @ u_s
-                    s = s_s
-                    V = vh_s @ vh_data
-                    label_tc[i] = _compute_pca_quantities(U, s, V, flip)
+                if isinstance(vertidx, sparse.csr_array):
+                    assert mri_resolution
+                    assert vertidx.shape[1] == stc.data.shape[0]
+                    this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
+                    this_data = vertidx @ this_data
+                    this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
+                else:
+                    this_data = stc.data[vertidx]
+                if mode == "pca_flip":
+                    label_tc[i] = func(flip, this_data, max_channels)
                 else:
-                    if isinstance(vertidx, sparse.csr_array):
-                        assert mri_resolution
-                        assert vertidx.shape[1] == stc.data.shape[0]
-                        this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
-                        this_data = vertidx @ this_data
-                        this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
-                    else:
-                        this_data = stc.data[vertidx]
                     label_tc[i] = func(flip, this_data)
 
     if mode is not None:
         offset = nvert[:-n_mean].sum()  # effectively :2 or :0

From 775ec8063a2b04868d1320c472b87c4b150f49b6 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 16 May 2025 12:55:30 +0200
Subject: [PATCH 20/25] Logging

---
 mne/source_estimate.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index e9b832fb7bf..9ade3a4f77a 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3757,6 +3757,7 @@ def _gen_extract_label_time_course(
                 else:
                     label_tc[i] = func(flip, this_data)
+            logger.debug(f"Done with label {i}")
 
     if mode is not None:
         offset = nvert[:-n_mean].sum()  # effectively :2 or :0
From 0fd4be5049452d8b952d8d2e7cb549cf10979a88 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 16 May 2025 13:28:42 +0200
Subject: [PATCH 21/25] Found a trick to make everything much faster with only
 two SVDs

---
 mne/source_estimate.py | 49 ++++++++++++++++++++++++++++----------
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 9ade3a4f77a..4aa283514bd 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3743,20 +3743,47 @@ def _gen_extract_label_time_course(
     else:
         # For other modes, initialize the label_tc array
         label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
+
+        pca_volume = mode == "pca_flip" and kind == "volume"
+        if pca_volume:
+            from sklearn.utils.extmath import randomized_svd
+
+            logger.debug("First SVD for PCA volume on stc data")
+            u_b, s_b, vh_b = randomized_svd(stc.data, max_channels)
         for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
             if vertidx is not None:
-                if isinstance(vertidx, sparse.csr_array):
-                    assert mri_resolution
-                    assert vertidx.shape[1] == stc.data.shape[0]
-                    this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
-                    this_data = vertidx @ this_data
-                    this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
-                else:
-                    this_data = stc.data[vertidx]
-                if mode == "pca_flip":
-                    label_tc[i] = func(flip, this_data, max_channels)
+                if pca_volume:
+                    # Use a trick for efficiency:
+                    # stc = Ub Sb Vhb
+                    # full_data = vertidx @ stc
+                    #           = vertidx @ Ub @ Sb @ Vhb
+                    # Consider U_f, s_f, Vh_f = SVD(vertidx @ Ub @ Sb)
+                    # Then U, S, V = svd(full_data) is such that
+                    # U = U_f, S = s_f and V = Vh_f @ Vhb
+                    # This trick is more efficient, because:
+                    # - We compute a first SVD once on stc, restricted to
+                    #   only the first max_channels singular vals/vecs (quite fast)
+                    # - We project vertidx from Nvertex x Nsources
+                    #   to Nvertex x rank.
+                    # - We compute the SVD on Nvertex x rank
+                    # As rank << Nsources, we end up saving a lot of computations.
+                    tmp_array = vertidx @ u_b @ np.diag(s_b)
+                    U, S, v_tmp = np.linalg.svd(tmp_array, full_matrices=False)
+                    V = v_tmp @ vh_b
+                    label_tc[i] = _compute_pca_quantities(U, S, V, flip)
                 else:
-                    label_tc[i] = func(flip, this_data)
+                    if isinstance(vertidx, sparse.csr_array):
+                        assert mri_resolution
+                        assert vertidx.shape[1] == stc.data.shape[0]
+                        this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
+                        this_data = vertidx @ this_data
+                        this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
+                    else:
+                        this_data = stc.data[vertidx]
+                    if mode == "pca_flip":
+                        label_tc[i] = func(flip, this_data, max_channels)
+                    else:
+                        label_tc[i] = func(flip, this_data)
             logger.debug(f"Done with label {i}")
 
     if mode is not None:

From 1a89099d0684a00e143540be56b4830ee68c4a32 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 16 May 2025 13:36:25 +0200
Subject: [PATCH 22/25] Flip handling in _compute_pca_quantities

---
 mne/source_estimate.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mne/source_estimate.py b/mne/source_estimate.py
index 4aa283514bd..5717178d3de 100644
--- a/mne/source_estimate.py
+++ b/mne/source_estimate.py
@@ -3376,6 +3376,8 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):
 
 
 def _compute_pca_quantities(U, s, V, flip):
+    if flip is None:  # Case of volumetric data: flip is meaningless
+        flip = 1
     if isinstance(flip, int):
         sign = np.sign((flip * U[:, 0]).sum())
     else:
@@ -3389,8 +3389,6 @@ def _compute_pca_quantities(U, s, V, flip):
 
 def _pca_flip(flip, data, max_rank):
     result = None
-    if flip is None:  # Case of volumetric data: flip is meaningless
-        flip = 1
     if data.shape[0] < 2:
         result = data.mean(axis=0)  # Trivial accumulator
     else:
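The identity the comments above rely on can be checked numerically in isolation. A self-contained sketch with synthetic sizes, where `vertidx` stands in for the (dense) label selection matrix::

    import numpy as np

    rng = np.random.default_rng(0)
    vertidx = rng.standard_normal((40, 300))
    stc_data = rng.standard_normal((300, 200))

    # Factor the data once, then SVD only the small projected matrix
    u_b, s_b, vh_b = np.linalg.svd(stc_data, full_matrices=False)
    u_s, s_s, vh_s = np.linalg.svd(vertidx @ u_b @ np.diag(s_b), full_matrices=False)
    U, S, V = u_s, s_s, vh_s @ vh_b  # valid SVD factors of vertidx @ stc_data

    print(np.allclose(U @ np.diag(S) @ V, vertidx @ stc_data))  # True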
From b02c14c25c945ca8161afe28db931ea6a780fe7b Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 23 May 2025 11:39:51 +0200
Subject: [PATCH 23/25] Feat: montage now supports .pos information file

---
 mne/channels/montage.py | 46 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/mne/channels/montage.py b/mne/channels/montage.py
index 15cef38dec7..43630e187ea 100644
--- a/mne/channels/montage.py
+++ b/mne/channels/montage.py
@@ -1383,6 +1383,48 @@ def _read_isotrak_elp_points(fname):
     }
 
 
+def _read_isotrak_pos_points(fname):
+    """Read Polhemus Isotrak digitizer data from a ``.pos`` file.
+
+    Parameters
+    ----------
+    fname : path-like
+        The filepath of the .pos Polhemus Isotrak file.
+
+    Returns
+    -------
+    out : dict of arrays
+        The dictionary containing locations for 'nasion', 'lpa', 'rpa'
+        and 'points'.
+    """
+    with open(fname) as fid:
+        file_str = fid.read()
+
+    # Get all lines which are points
+    int_pat = r"[+-]?\d+"
+    float_pat = r"[+-]?(?:\d+\.\d*|\d*\.\d+)(?:[eE][+-]?\d+)?"
+    pattern_points = re.compile(
+        rf"^\s*({int_pat})\s+({float_pat})\s+({float_pat})\s+({float_pat})",
+        re.MULTILINE,
+    )
+    points = pattern_points.findall(file_str)
+
+    # Get nasion, left and right
+    label_pat = r"[A-Za-z]+"
+    pattern_labels = re.compile(
+        rf"^\s*({label_pat})\s+({float_pat})\s+({float_pat})\s+({float_pat})",
+        re.MULTILINE,
+    )
+    labels = pattern_labels.findall(file_str)
+
+    return {
+        "nasion": [x[1:] for x in labels if x[0] == "nasion"][0],
+        "lpa": [x[1:] for x in labels if x[0] == "left"][0],
+        "rpa": [x[1:] for x in labels if x[0] == "right"][0],
+        "points": [x[1:] for x in points],
+    }
+
+
 def _read_isotrak_hsp_points(fname):
     """Read Polhemus Isotrak digitizer data from a ``.hsp`` file.
 
@@ -1459,7 +1501,7 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"):
     read_dig_fif
     read_dig_localite
     """
-    VALID_FILE_EXT = (".hsp", ".elp", ".eeg")
+    VALID_FILE_EXT = (".hsp", ".elp", ".eeg", ".pos")
 
     fname = str(_check_fname(fname, overwrite="read", must_exist=True))
     _scale = _check_unit_and_get_scaling(unit)
@@ -1468,6 +1510,8 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"):
 
     if ext == ".elp":
         data = _read_isotrak_elp_points(fname)
+    elif ext == ".pos":
+        data = _read_isotrak_pos_points(fname)
     else:
         # Default case we read points as hsp since is the most likely scenario
         data = _read_isotrak_hsp_points(fname)

From 8a96feceb1a40ed75403e908aaed60d11eabb345 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 23 May 2025 12:34:47 +0200
Subject: [PATCH 24/25] Float convert in digitization

---
 mne/channels/montage.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mne/channels/montage.py b/mne/channels/montage.py
index 43630e187ea..303aa6f1b2e 100644
--- a/mne/channels/montage.py
+++ b/mne/channels/montage.py
@@ -1418,10 +1418,10 @@ def _read_isotrak_pos_points(fname):
     labels = pattern_labels.findall(file_str)
 
     return {
-        "nasion": [x[1:] for x in labels if x[0] == "nasion"][0],
-        "lpa": [x[1:] for x in labels if x[0] == "left"][0],
-        "rpa": [x[1:] for x in labels if x[0] == "right"][0],
-        "points": [x[1:] for x in points],
+        "nasion": [tuple(map(float, x[1:])) for x in labels if x[0] == "nasion"][0],
+        "lpa": [tuple(map(float, x[1:])) for x in labels if x[0] == "left"][0],
+        "rpa": [tuple(map(float, x[1:])) for x in labels if x[0] == "right"][0],
+        "points": [tuple(map(float, x[1:])) for x in points],
     }
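The two patterns accept whitespace-separated rows: one integer plus three floats for digitized points, one alphabetic label plus three floats for fiducials. A made-up two-point example run through the same regexes (the file contents here are an assumption for illustration, not a documented .pos spec)::

    import re

    int_pat = r"[+-]?\d+"
    float_pat = r"[+-]?(?:\d+\.\d*|\d*\.\d+)(?:[eE][+-]?\d+)?"
    pos_text = (
        "1   0.1200  -0.0500   0.0900\n"
        "2   0.1180   0.0510   0.0885\n"
        "nasion   0.1150   0.0000   0.0400\n"
        "left    -0.0720   0.0010   0.0105\n"
        "right    0.0730   0.0020   0.0112\n"
    )
    points = re.findall(
        rf"^\s*({int_pat})\s+({float_pat})\s+({float_pat})\s+({float_pat})",
        pos_text,
        re.MULTILINE,
    )
    labels = re.findall(
        rf"^\s*([A-Za-z]+)\s+({float_pat})\s+({float_pat})\s+({float_pat})",
        pos_text,
        re.MULTILINE,
    )
    print(len(points), [x[0] for x in labels])  # 2 ['nasion', 'left', 'right']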
From 1074c1a69de69ee00826abb618578ae5fa809b31 Mon Sep 17 00:00:00 2001
From: Shrecki
Date: Fri, 23 May 2025 13:31:06 +0200
Subject: [PATCH 25/25] Convert dig points to numpy array

---
 mne/channels/montage.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mne/channels/montage.py b/mne/channels/montage.py
index 303aa6f1b2e..b2444c94b90 100644
--- a/mne/channels/montage.py
+++ b/mne/channels/montage.py
@@ -1418,10 +1418,16 @@ def _read_isotrak_pos_points(fname):
     labels = pattern_labels.findall(file_str)
 
     return {
-        "nasion": [tuple(map(float, x[1:])) for x in labels if x[0] == "nasion"][0],
-        "lpa": [tuple(map(float, x[1:])) for x in labels if x[0] == "left"][0],
-        "rpa": [tuple(map(float, x[1:])) for x in labels if x[0] == "right"][0],
-        "points": [tuple(map(float, x[1:])) for x in points],
+        "nasion": np.array(
+            [tuple(map(float, x[1:])) for x in labels if x[0] == "nasion"][0]
+        ),
+        "lpa": np.array(
+            [tuple(map(float, x[1:])) for x in labels if x[0] == "left"][0]
+        ),
+        "rpa": np.array(
+            [tuple(map(float, x[1:])) for x in labels if x[0] == "right"][0]
+        ),
+        "points": np.array([tuple(map(float, x[1:])) for x in points]),
     }
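End to end, the extension is then used like the other Isotrak formats; `sample01.pos` is a hypothetical filename::

    import mne

    montage = mne.channels.read_dig_polhemus_isotrak("sample01.pos", unit="m")
    print(montage)  # DigMontage with nasion/lpa/rpa and the digitized points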