diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 63f4da09b4..c5251ccfe2 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -50,6 +50,25 @@ def __init__(self, study_folder): self.labels_by_levels = {} self.scan_folder() + @classmethod + def _check_cases(cls, cases, levels=None, reference=None): + if reference is None: + reference = list(cases.keys())[0] + for key in cases.keys(): + if isinstance(reference, str): + assert isinstance(key, str), f"Case key {key} for cases is not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(reference, tuple): + assert isinstance(key, tuple), f"Case key {key} for cases is not homogeneous" + num_levels = len(reference) + assert len(key) == num_levels, f"Case key {key} for cases is not homogeneous, tuple negth differ" + else: + raise ValueError("Keys for cases must str or tuple") + return levels + @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): """ @@ -76,27 +95,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): study : BenchmarkStudy The created study. """ - # check that cases keys are homogeneous - key0 = list(cases.keys())[0] - if isinstance(key0, str): - assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" - if levels is None: - levels = "level0" - else: - assert isinstance(levels, str) - elif isinstance(key0, tuple): - assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" - num_levels = len(key0) - assert all( - len(key) == num_levels for key in cases.keys() - ), "Keys for cases are not homogeneous, tuple negth differ" - if levels is None: - levels = [f"level{i}" for i in range(num_levels)] - else: - levels = list(levels) - assert len(levels) == num_levels - else: - raise ValueError("Keys for cases must str or tuple") + levels = cls._check_cases(cases, levels) study_folder = Path(study_folder) study_folder.mkdir(exist_ok=False, parents=True) @@ -272,6 +271,25 @@ def remove_benchmark(self, key): f.unlink() self.benchmarks[key] = None + def add_cases(self, cases): + + _ = self._check_cases(cases, reference=list(self.cases.keys())[0]) + for case in cases.values(): + dataset = case["dataset"] + assert dataset in list(self.datasets.keys()), f"Unknown dataset {dataset} for the Study" + self.cases.update(cases) + for key in cases.keys(): + benchmark = self.create_benchmark(key=key) + self.benchmarks[key] = benchmark + (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) + + def remove_cases(self, case_keys): + for key in case_keys: + assert key in list(self.cases.keys()), f"Case {key} is not in the cases of the Study" + self.cases.pop(key) + self.remove_benchmark(key) + (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) + def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): if case_keys is None: case_keys = list(self.cases.keys()) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index 10e34b2f38..ceaa9331cc 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -65,6 +65,39 @@ def _create_simple_study(study_folder): # print(study) +def _create_very_simple_study(study_folder): + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) + + datasets = { + "toy_tetrode": (rec0, gt_sorting0), + } + + # cases can also be generated via simple loops + cases = { + # + ("tdc2", "no-preprocess", "tetrode"): { + "label": "tridesclous2 without preprocessing and standard params", + "dataset": "toy_tetrode", + "params": { + "sorter_name": "tridesclous2", + }, + }, + # + ("tdc2", "with-preprocess", "probe32"): { + "label": "tridesclous2 with preprocessing standar params", + "dataset": "toy_tetrode", + "params": { + "sorter_name": "tridesclous2", + }, + }, + } + + study = SorterStudy.create( + study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] + ) + # print(study) + + def _create_complex_study(study_folder): rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=91) @@ -258,12 +291,39 @@ def test_get_grouped_keys_mapping(create_complex_study): assert len(keys) == 16 +def test_add_remove_cases(create_simple_study): + # job_kwargs = dict(n_jobs=2, chunk_duration="1s") + + study_folder = create_simple_study + study = SorterStudy(study_folder) + + # # this run the sorters + # study.run() + # this is from the base class + # rt = study.get_run_times() + + case_key = list(study.cases.keys())[0] + case = study.cases[case_key].copy() + study.remove_cases([case_key]) + study.add_cases({case_key: case}) + # study.run() + + if __name__ == "__main__": study_folder_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" if study_folder_simple.exists(): shutil.rmtree(study_folder_simple) - _create_simple_study(study_folder_simple) + _create_simple_study(_create_very_simple_study) test_SorterStudy(study_folder_simple) + + study_folder_very_simple = ( + Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_AddRemoveCases" + ) + if study_folder_very_simple.exists(): + shutil.rmtree(study_folder_very_simple) + _create_very_simple_study(_create_very_simple_study) + test_add_remove_cases(_create_very_simple_study) + study_folder_complex = ( Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy_complex" ) @@ -271,6 +331,7 @@ def test_get_grouped_keys_mapping(create_complex_study): shutil.rmtree(study_folder_complex) _create_complex_study(study_folder_complex) test_get_grouped_keys_mapping(study_folder_complex) + test_add_remove_cases(create_simple_study) # # test out all plots and levels # import matplotlib.pyplot as plt