Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions src/spikeinterface/benchmark/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ def __init__(self, study_folder):
self.labels_by_levels = {}
self.scan_folder()

@classmethod
def _check_cases(cls, cases, levels=None, reference=None):
if reference is None:
reference = list(cases.keys())[0]
for key in cases.keys():
if isinstance(reference, str):
assert isinstance(key, str), f"Case key {key} for cases is not homogeneous"
if levels is None:
levels = "level0"
else:
assert isinstance(levels, str)
elif isinstance(reference, tuple):
assert isinstance(key, tuple), f"Case key {key} for cases is not homogeneous"
num_levels = len(reference)
assert len(key) == num_levels, f"Case key {key} for cases is not homogeneous, tuple negth differ"
else:
raise ValueError("Keys for cases must str or tuple")
return levels

@classmethod
def create(cls, study_folder, datasets={}, cases={}, levels=None):
"""
Expand All @@ -76,27 +95,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
study : BenchmarkStudy
The created study.
"""
# check that cases keys are homogeneous
key0 = list(cases.keys())[0]
if isinstance(key0, str):
assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous"
if levels is None:
levels = "level0"
else:
assert isinstance(levels, str)
elif isinstance(key0, tuple):
assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous"
num_levels = len(key0)
assert all(
len(key) == num_levels for key in cases.keys()
), "Keys for cases are not homogeneous, tuple negth differ"
if levels is None:
levels = [f"level{i}" for i in range(num_levels)]
else:
levels = list(levels)
assert len(levels) == num_levels
else:
raise ValueError("Keys for cases must str or tuple")
levels = cls._check_cases(cases, levels)

study_folder = Path(study_folder)
study_folder.mkdir(exist_ok=False, parents=True)
Expand Down Expand Up @@ -272,6 +271,25 @@ def remove_benchmark(self, key):
f.unlink()
self.benchmarks[key] = None

def add_cases(self, cases):

_ = self._check_cases(cases, reference=list(self.cases.keys())[0])
for case in cases.values():
dataset = case["dataset"]
assert dataset in list(self.datasets.keys()), f"Unknown dataset {dataset} for the Study"
self.cases.update(cases)
for key in cases.keys():
benchmark = self.create_benchmark(key=key)
self.benchmarks[key] = benchmark
(self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases))

def remove_cases(self, case_keys):
for key in case_keys:
assert key in list(self.cases.keys()), f"Case {key} is not in the cases of the Study"
self.cases.pop(key)
self.remove_benchmark(key)
(self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases))

def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
if case_keys is None:
case_keys = list(self.cases.keys())
Expand Down
63 changes: 62 additions & 1 deletion src/spikeinterface/benchmark/tests/test_benchmark_sorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,39 @@ def _create_simple_study(study_folder):
# print(study)


def _create_very_simple_study(study_folder):
rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42)

datasets = {
"toy_tetrode": (rec0, gt_sorting0),
}

# cases can also be generated via simple loops
cases = {
#
("tdc2", "no-preprocess", "tetrode"): {
"label": "tridesclous2 without preprocessing and standard params",
"dataset": "toy_tetrode",
"params": {
"sorter_name": "tridesclous2",
},
},
#
("tdc2", "with-preprocess", "probe32"): {
"label": "tridesclous2 with preprocessing standar params",
"dataset": "toy_tetrode",
"params": {
"sorter_name": "tridesclous2",
},
},
}

study = SorterStudy.create(
study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"]
)
# print(study)


def _create_complex_study(study_folder):
rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42)
rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=91)
Expand Down Expand Up @@ -258,19 +291,47 @@ def test_get_grouped_keys_mapping(create_complex_study):
assert len(keys) == 16


def test_add_remove_cases(create_simple_study):
# job_kwargs = dict(n_jobs=2, chunk_duration="1s")

study_folder = create_simple_study
study = SorterStudy(study_folder)

# # this run the sorters
# study.run()
# this is from the base class
# rt = study.get_run_times()

case_key = list(study.cases.keys())[0]
case = study.cases[case_key].copy()
study.remove_cases([case_key])
study.add_cases({case_key: case})
# study.run()


if __name__ == "__main__":
study_folder_simple = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy"
if study_folder_simple.exists():
shutil.rmtree(study_folder_simple)
_create_simple_study(study_folder_simple)
_create_simple_study(_create_very_simple_study)
test_SorterStudy(study_folder_simple)

study_folder_very_simple = (
Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_AddRemoveCases"
)
if study_folder_very_simple.exists():
shutil.rmtree(study_folder_very_simple)
_create_very_simple_study(_create_very_simple_study)
test_add_remove_cases(_create_very_simple_study)

study_folder_complex = (
Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy_complex"
)
if study_folder_complex.exists():
shutil.rmtree(study_folder_complex)
_create_complex_study(study_folder_complex)
test_get_grouped_keys_mapping(study_folder_complex)
test_add_remove_cases(create_simple_study)

# # test out all plots and levels
# import matplotlib.pyplot as plt
Expand Down