From cecffac07886da8d3e4466e82f81179aca8a4011 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 11 Dec 2025 11:33:54 +0100 Subject: [PATCH 1/7] WIP --- .../benchmark/benchmark_base.py | 33 ++++++++++ .../benchmark/tests/test_benchmark_sorter.py | 62 ++++++++++++++++++- 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 63f4da09b4..67fc3f3e92 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -272,6 +272,39 @@ def remove_benchmark(self, key): f.unlink() self.benchmarks[key] = None + def add_cases(self, cases): + # check that cases keys are homogeneous + key0 = list(self.cases.keys())[0] + for key in cases.keys(): + print(key) + if isinstance(key0, str): + assert isinstance(key, str), f"Key {key} for cases is not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(key0, tuple): + assert isinstance(key, tuple), f"Key {key} for cases is not homogeneous" + num_levels = len(key0) + assert len(key) == num_levels, f"Key {key} for cases is not homogeneous, tuple negth differ" + else: + raise ValueError("Keys for cases must str or tuple") + + for case in cases.values(): + assert case["dataset"] in self.datasets.keys(), f"Unknown dataset {case["dataset"]}" + + self.cases.update(cases) + benchmark = self.create_benchmark(key=key) + self.benchmarks[key] = benchmark + (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) + + def remove_cases(self, case_keys): + for key in case_keys: + assert key in list(self.cases.keys()), f"Key {key} is not in the cases of the Study" + self.cases.pop(key) + self.remove_benchmark(key) + (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) + def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): if case_keys is None: case_keys = list(self.cases.keys()) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index 10e34b2f38..ecfa6693a4 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -64,6 +64,38 @@ def _create_simple_study(study_folder): ) # print(study) +def _create_very_simple_study(study_folder): + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) + + datasets = { + "toy_tetrode": (rec0, gt_sorting0), + } + + # cases can also be generated via simple loops + cases = { + # + ("tdc2", "no-preprocess", "tetrode"): { + "label": "tridesclous2 without preprocessing and standard params", + "dataset": "toy_tetrode", + "params": { + "sorter_name": "tridesclous2", + }, + }, + # + ("tdc2", "with-preprocess", "probe32"): { + "label": "tridesclous2 with preprocessing standar params", + "dataset": "toy_tetrode", + "params": { + "sorter_name": "tridesclous2", + }, + }, + } + + study = SorterStudy.create( + study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] + ) + # print(study) + def _create_complex_study(study_folder): rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) @@ -258,12 +290,39 @@ def test_get_grouped_keys_mapping(create_complex_study): assert len(keys) == 16 +def test_add_remove_cases(create_simple_study): + # job_kwargs = dict(n_jobs=2, chunk_duration="1s") + + study_folder = create_simple_study + study = SorterStudy(study_folder) + print(study) + + # # this run the sorters + study.run() + # this is from the base class + rt = study.get_run_times() + + case_key = list(study.cases.keys())[0] + case = study.cases[case_key].copy() + study.remove_cases([case_key]) + study.add_cases({case_key : case}) + study.run() + + + if __name__ == "__main__": study_folder_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" if study_folder_simple.exists(): shutil.rmtree(study_folder_simple) - _create_simple_study(study_folder_simple) + _create_simple_study(_create_very_simple_study) test_SorterStudy(study_folder_simple) + + study_folder_very_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_AddRemoveCases" + if study_folder_very_simple.exists(): + shutil.rmtree(study_folder_very_simple) + _create_simple_study(_create_very_simple_study) + test_add_remove_cases(_create_very_simple_study) + study_folder_complex = ( Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy_complex" ) @@ -271,6 +330,7 @@ def test_get_grouped_keys_mapping(create_complex_study): shutil.rmtree(study_folder_complex) _create_complex_study(study_folder_complex) test_get_grouped_keys_mapping(study_folder_complex) + test_add_remove_cases(create_simple_study) # # test out all plots and levels # import matplotlib.pyplot as plt From 10e289fbb951ddcb8eb5db0ffa3e52434513f02b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:49:32 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/benchmark/benchmark_base.py | 4 ++-- .../benchmark/tests/test_benchmark_sorter.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 67fc3f3e92..2ddf0a2bc0 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -289,7 +289,7 @@ def add_cases(self, cases): assert len(key) == num_levels, f"Key {key} for cases is not homogeneous, tuple negth differ" else: raise ValueError("Keys for cases must str or tuple") - + for case in cases.values(): assert case["dataset"] in self.datasets.keys(), f"Unknown dataset {case["dataset"]}" @@ -297,7 +297,7 @@ def add_cases(self, cases): benchmark = self.create_benchmark(key=key) self.benchmarks[key] = benchmark (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) - + def remove_cases(self, case_keys): for key in case_keys: assert key in list(self.cases.keys()), f"Key {key} is not in the cases of the Study" diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index ecfa6693a4..94aa28ed2f 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -64,6 +64,7 @@ def _create_simple_study(study_folder): ) # print(study) + def _create_very_simple_study(study_folder): rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) @@ -305,10 +306,9 @@ def test_add_remove_cases(create_simple_study): case_key = list(study.cases.keys())[0] case = study.cases[case_key].copy() study.remove_cases([case_key]) - study.add_cases({case_key : case}) + study.add_cases({case_key: case}) study.run() - if __name__ == "__main__": study_folder_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" @@ -317,7 +317,9 @@ def test_add_remove_cases(create_simple_study): _create_simple_study(_create_very_simple_study) test_SorterStudy(study_folder_simple) - study_folder_very_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_AddRemoveCases" + study_folder_very_simple = ( + Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_AddRemoveCases" + ) if study_folder_very_simple.exists(): shutil.rmtree(study_folder_very_simple) _create_simple_study(_create_very_simple_study) From b242a41d66c33fa5a5ad82f82b46c2819f247343 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 11 Dec 2025 11:50:19 +0100 Subject: [PATCH 3/7] Tests --- src/spikeinterface/benchmark/tests/test_benchmark_sorter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index ecfa6693a4..99834da053 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -320,7 +320,7 @@ def test_add_remove_cases(create_simple_study): study_folder_very_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_AddRemoveCases" if study_folder_very_simple.exists(): shutil.rmtree(study_folder_very_simple) - _create_simple_study(_create_very_simple_study) + _create_very_simple_study(_create_very_simple_study) test_add_remove_cases(_create_very_simple_study) study_folder_complex = ( From ec90f13032abe7b003176d0f43b8c6a1694bae27 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 11 Dec 2025 13:42:44 +0100 Subject: [PATCH 4/7] Formatting --- src/spikeinterface/benchmark/benchmark_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 2ddf0a2bc0..6ebd3217d7 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -291,7 +291,8 @@ def add_cases(self, cases): raise ValueError("Keys for cases must str or tuple") for case in cases.values(): - assert case["dataset"] in self.datasets.keys(), f"Unknown dataset {case["dataset"]}" + dataset = case["dataset"] + assert dataset in list(self.datasets.keys()), f"Unknown dataset {dataset}" self.cases.update(cases) benchmark = self.create_benchmark(key=key) From f09e3d7caf2aa5c8f12d95deb20fee9ff3d4cfee Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 11 Dec 2025 14:04:37 +0100 Subject: [PATCH 5/7] Fixes --- .../benchmark/benchmark_base.py | 69 +++++++------------ .../benchmark/tests/test_benchmark_sorter.py | 7 +- 2 files changed, 29 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 6ebd3217d7..1cb3abb520 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -50,6 +50,24 @@ def __init__(self, study_folder): self.labels_by_levels = {} self.scan_folder() + @classmethod + def _check_cases(cls, cases, levels=None, reference=None): + if reference is None: + reference = list(cases.keys())[0] + for key in cases.keys(): + if isinstance(reference, str): + assert isinstance(key, str), f"Case key {key} for cases is not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(reference, tuple): + assert isinstance(key, tuple), f"Case key {key} for cases is not homogeneous" + num_levels = len(reference) + assert len(key) == num_levels, f"Case key {key} for cases is not homogeneous, tuple negth differ" + else: + raise ValueError("Keys for cases must str or tuple") + @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): """ @@ -76,27 +94,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): study : BenchmarkStudy The created study. """ - # check that cases keys are homogeneous - key0 = list(cases.keys())[0] - if isinstance(key0, str): - assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" - if levels is None: - levels = "level0" - else: - assert isinstance(levels, str) - elif isinstance(key0, tuple): - assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" - num_levels = len(key0) - assert all( - len(key) == num_levels for key in cases.keys() - ), "Keys for cases are not homogeneous, tuple negth differ" - if levels is None: - levels = [f"level{i}" for i in range(num_levels)] - else: - levels = list(levels) - assert len(levels) == num_levels - else: - raise ValueError("Keys for cases must str or tuple") + cls._check_cases(cases, levels) study_folder = Path(study_folder) study_folder.mkdir(exist_ok=False, parents=True) @@ -273,35 +271,20 @@ def remove_benchmark(self, key): self.benchmarks[key] = None def add_cases(self, cases): - # check that cases keys are homogeneous - key0 = list(self.cases.keys())[0] - for key in cases.keys(): - print(key) - if isinstance(key0, str): - assert isinstance(key, str), f"Key {key} for cases is not homogeneous" - if levels is None: - levels = "level0" - else: - assert isinstance(levels, str) - elif isinstance(key0, tuple): - assert isinstance(key, tuple), f"Key {key} for cases is not homogeneous" - num_levels = len(key0) - assert len(key) == num_levels, f"Key {key} for cases is not homogeneous, tuple negth differ" - else: - raise ValueError("Keys for cases must str or tuple") - + + self._check_cases(cases, reference=list(self.cases.keys())[0]) for case in cases.values(): dataset = case["dataset"] - assert dataset in list(self.datasets.keys()), f"Unknown dataset {dataset}" - + assert dataset in list(self.datasets.keys()), f"Unknown dataset {dataset} for the Study" self.cases.update(cases) - benchmark = self.create_benchmark(key=key) - self.benchmarks[key] = benchmark + for key in cases.keys(): + benchmark = self.create_benchmark(key=key) + self.benchmarks[key] = benchmark (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) def remove_cases(self, case_keys): for key in case_keys: - assert key in list(self.cases.keys()), f"Key {key} is not in the cases of the Study" + assert key in list(self.cases.keys()), f"Case {key} is not in the cases of the Study" self.cases.pop(key) self.remove_benchmark(key) (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index 3ecf142d9b..077bcf3d64 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -296,18 +296,17 @@ def test_add_remove_cases(create_simple_study): study_folder = create_simple_study study = SorterStudy(study_folder) - print(study) # # this run the sorters - study.run() + #study.run() # this is from the base class - rt = study.get_run_times() + #rt = study.get_run_times() case_key = list(study.cases.keys())[0] case = study.cases[case_key].copy() study.remove_cases([case_key]) study.add_cases({case_key: case}) - study.run() + #study.run() if __name__ == "__main__": From e02f6f0d23c23d8817f90eafb124c3d6b819efbe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 13:05:27 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/benchmark/benchmark_base.py | 2 +- src/spikeinterface/benchmark/tests/test_benchmark_sorter.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 1cb3abb520..bad6ee41de 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -271,7 +271,7 @@ def remove_benchmark(self, key): self.benchmarks[key] = None def add_cases(self, cases): - + self._check_cases(cases, reference=list(self.cases.keys())[0]) for case in cases.values(): dataset = case["dataset"] diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index 077bcf3d64..ceaa9331cc 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -298,15 +298,15 @@ def test_add_remove_cases(create_simple_study): study = SorterStudy(study_folder) # # this run the sorters - #study.run() + # study.run() # this is from the base class - #rt = study.get_run_times() + # rt = study.get_run_times() case_key = list(study.cases.keys())[0] case = study.cases[case_key].copy() study.remove_cases([case_key]) study.add_cases({case_key: case}) - #study.run() + # study.run() if __name__ == "__main__": From c5d88fc1c772cc635536c00ca8fe254769f746ef Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 12 Dec 2025 16:31:11 +0100 Subject: [PATCH 7/7] Return levels --- src/spikeinterface/benchmark/benchmark_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index bad6ee41de..c5251ccfe2 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -67,6 +67,7 @@ def _check_cases(cls, cases, levels=None, reference=None): assert len(key) == num_levels, f"Case key {key} for cases is not homogeneous, tuple negth differ" else: raise ValueError("Keys for cases must str or tuple") + return levels @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): @@ -94,7 +95,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): study : BenchmarkStudy The created study. """ - cls._check_cases(cases, levels) + levels = cls._check_cases(cases, levels) study_folder = Path(study_folder) study_folder.mkdir(exist_ok=False, parents=True) @@ -272,7 +273,7 @@ def remove_benchmark(self, key): def add_cases(self, cases): - self._check_cases(cases, reference=list(self.cases.keys())[0]) + _ = self._check_cases(cases, reference=list(self.cases.keys())[0]) for case in cases.values(): dataset = case["dataset"] assert dataset in list(self.datasets.keys()), f"Unknown dataset {dataset} for the Study"