From adf8c9eeddfe3302595764eba2da038c44f2c126 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 11 Dec 2025 13:14:02 +0100 Subject: [PATCH 1/2] wip: support zarr v3 --- pyproject.toml | 3 +- src/spikeinterface/core/sortinganalyzer.py | 94 ++++++++++--------- src/spikeinterface/core/template.py | 12 ++- .../tests/test_analyzer_extension_core.py | 2 +- .../core/tests/test_baserecording.py | 4 +- .../core/tests/test_sortinganalyzer.py | 45 ++++----- .../core/tests/test_zarrextractors.py | 35 ++++--- src/spikeinterface/core/zarrextractors.py | 91 ++++++++++++------ 8 files changed, 169 insertions(+), 117 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c3a3cf3b1..a753911fa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,11 @@ dependencies = [ "numpy>=2.0.0;python_version>='3.13'", "threadpoolctl>=3.0.0", "tqdm", - "zarr>=2.18,<3", + "zarr>=3,<4", "neo>=0.14.3", "probeinterface>=0.3.1", "packaging", "pydantic", - "numcodecs<0.16.0", # For supporting zarr < 3 ] [build-system] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1870c24e7a..31cb317565 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Literal, Optional, Any +from typing import Literal, Optional, Any, Iterable from pathlib import Path from itertools import chain @@ -621,6 +621,7 @@ def _get_zarr_root(self, mode="r+"): assert mode in ("r+", "a", "r"), "mode must be 'r+', 'a' or 'r'" storage_options = self._backend_options.get("storage_options", {}) + zarr_root = super_zarr_open(self.folder, mode=mode, storage_options=storage_options) return zarr_root @@ -644,7 +645,12 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att storage_options = backend_options.get("storage_options", {}) saving_options = backend_options.get("saving_options", {}) - zarr_root = zarr.open(folder, mode="w", storage_options=storage_options) + if not is_path_remote(str(folder)): + storage_options_kwargs = {} + else: + storage_options_kwargs = storage_options + + zarr_root = zarr.open(folder, mode="w", **storage_options_kwargs) info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -657,13 +663,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att if recording is not None: rec_dict = recording.to_dict(relative_to=relative_to, recursive=True) if recording.check_serializability("json"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) - zarr_rec = np.array([check_json(rec_dict)], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) - elif recording.check_serializability("pickle"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) - zarr_rec = np.array([rec_dict], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + # In zarr v3, store JSON-serializable data in attributes instead of using object_codec + zarr_root.attrs["recording"] = check_json(rec_dict) else: warnings.warn("The Recording is not serializable! 
The recording link will be lost for future load") else: @@ -673,11 +674,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att # sorting provenance sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True) if sorting.check_serializability("json"): - zarr_sort = np.array([check_json(sort_dict)], dtype=object) - zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) - elif sorting.check_serializability("pickle"): - zarr_sort = np.array([sort_dict], dtype=object) - zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) + # In zarr v3, store JSON-serializable data in attributes instead of using object_codec + zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) else: warnings.warn( "The sorting provenance is not serializable! The sorting provenance link will be lost for future load" @@ -698,12 +696,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) if sparsity is not None: - zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options) + zarr_root.create_array("sparsity_mask", data=sparsity.mask, **saving_options) add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options) recording_info = zarr_root.create_group("extensions") + # consolidate metadata for zarr v3 zarr.consolidate_metadata(zarr_root.store) return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) @@ -715,6 +714,10 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) + if not is_path_remote(str(folder)): + storage_options_kwargs = {} + else: + storage_options_kwargs = storage_options zarr_root = super_zarr_open(str(folder), mode="r", storage_options=storage_options) @@ -723,7 +726,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): # v0.101.0 did not have a consolidate metadata step after computing extensions. # Here we try to consolidate the metadata and throw a warning if it fails. try: - zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options) + zarr_root_a = zarr.open(str(folder), mode="a", **storage_options_kwargs) zarr.consolidate_metadata(zarr_root_a.store) except Exception as e: warnings.warn( @@ -741,9 +744,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): # load recording if possible if recording is None: - rec_field = zarr_root.get("recording") - if rec_field is not None: - rec_dict = rec_field[0] + # In zarr v3, recording is stored in attributes + rec_dict = zarr_root.attrs.get("recording", None) + if rec_dict is not None: try: recording = load(rec_dict, base_folder=folder) except: @@ -859,7 +862,7 @@ def set_sorting_property( if key in zarr_root["sorting"]["properties"]: zarr_root["sorting"]["properties"][key][:] = prop_values else: - zarr_root["sorting"]["properties"].create_dataset(name=key, data=prop_values, compressor=None) + zarr_root["sorting"]["properties"].create_array(name=key, data=prop_values, compressors=None) # IMPORTANT: we need to re-consolidate the zarr store! 
zarr.consolidate_metadata(zarr_root.store) @@ -1531,12 +1534,13 @@ def get_sorting_provenance(self): elif self.format == "zarr": zarr_root = self._get_zarr_root(mode="r") sorting_provenance = None - if "sorting_provenance" in zarr_root.keys(): + # In zarr v3, sorting_provenance is stored in attributes + sort_dict = zarr_root.attrs.get("sorting_provenance", None) + if sort_dict is not None: # try-except here is because it's not required to be able # to load the sorting provenance, as the user might have deleted # the original sorting folder try: - sort_dict = zarr_root["sorting_provenance"][0] sorting_provenance = load(sort_dict, base_folder=self.folder) except: pass @@ -2479,8 +2483,9 @@ def load_data(self): extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): ext_data_ = extension_group[ext_data_name] - if "dict" in ext_data_.attrs: - ext_data = ext_data_[0] + # In zarr v3, check if it's a group with dict_data attribute + if "dict_data" in ext_data_.attrs: + ext_data = ext_data_.attrs["dict_data"] elif "dataframe" in ext_data_.attrs: import pandas as pd @@ -2565,9 +2570,10 @@ def run(self, save=True, **kwargs): if self.format == "zarr": import zarr - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store) def save(self): + self._reset_extension_folder() self._save_params() self._save_importing_provenance() self._save_run_info() @@ -2576,7 +2582,7 @@ def save(self): if self.format == "zarr": import zarr - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store) def _save_data(self): if self.format == "memory": @@ -2623,40 +2629,44 @@ def _save_data(self): extension_group = self._get_zarr_extension_group(mode="r+") # if compression is not externally given, we use the default - if "compressor" not in saving_options: - saving_options["compressor"] = get_default_zarr_compressor() + if "compressors" not in saving_options and "compressor" not in saving_options: + saving_options["compressors"] = get_default_zarr_compressor() + if "compressor" in saving_options: + saving_options["compressors"] = [saving_options["compressor"]] + del saving_options["compressor"] for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] if isinstance(ext_data, dict): - extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() - ) + # In zarr v3, store dict in a subgroup with attributes + dict_group = extension_group.create_group(ext_data_name) + dict_group.attrs["dict_data"] = check_json(ext_data) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) + extension_group.create_array(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index indices = ext_data.index.to_numpy() if indices.dtype.kind == "O": indices = indices.astype(str) - df_group.create_dataset(name="index", data=indices) + df_group.create_array(name="index", data=indices) for col in ext_data.columns: col_data = ext_data[col].to_numpy() if col_data.dtype.kind == "O": col_data = col_data.astype(str) - df_group.create_dataset(name=col, data=col_data) + df_group.create_array(name=col, 
data=col_data) df_group.attrs["dataframe"] = True else: # any object - try: - extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle() - ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - extension_group[ext_data_name].attrs["object"] = True + # try: + # extension_group.create_array( + # name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle() + # ) + # except: + # raise Exception(f"Could not save {ext_data_name} as extension data") + # extension_group[ext_data_name].attrs["object"] = True + warnings.warn(f"Data type of {ext_data_name} not supported for zarr saving, skipping.") def _reset_extension_folder(self): """ @@ -2734,8 +2744,6 @@ def set_params(self, save=True, **params): def _save_params(self): params_to_save = self.params.copy() - self._reset_extension_folder() - # TODO make sparsity local Result specific # if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: # assert isinstance( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 91d25bece6..50ae6cbfdf 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -321,17 +321,19 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: """ # Saves one chunk per unit - arrays_chunk = (1, None, None) - zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk) - zarr_group.create_dataset("channel_ids", data=self.channel_ids) - zarr_group.create_dataset("unit_ids", data=self.unit_ids) + # In zarr v3, chunks must be a full tuple with actual dimensions + num_units, num_samples, num_channels = self.templates_array.shape + arrays_chunk = (1, num_samples, num_channels) + zarr_group.create_array("templates_array", data=self.templates_array, chunks=arrays_chunk) + zarr_group.create_array("channel_ids", data=self.channel_ids) + zarr_group.create_array("unit_ids", data=self.unit_ids) zarr_group.attrs["sampling_frequency"] = self.sampling_frequency zarr_group.attrs["nbefore"] = self.nbefore zarr_group.attrs["is_in_uV"] = self.is_in_uV if self.sparsity_mask is not None: - zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask) + zarr_group.create_array("sparsity_mask", data=self.sparsity_mask) if self.probe is not None: probe_group = zarr_group.create_group("probe") diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 6f5bef3c6c..574cc89e10 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -91,7 +91,7 @@ def test_ComputeRandomSpikes(format, sparse, create_cache_folder): print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) - print("Delering extension") + print("Deleting extension") sorting_analyzer.delete_extension("random_spikes") print("Re-computing random spikes") diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 9de800b33d..186767c026 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -324,7 +324,7 @@ def test_BaseRecording(create_cache_folder): # test save to zarr compressor = get_default_zarr_compressor() - rec_zarr = rec2.save(format="zarr", folder=cache_folder / 
"recording", compressor=compressor) + rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressors=compressor) rec_zarr_loaded = load(cache_folder / "recording.zarr") # annotations is False because Zarr adds compression ratios check_recordings_equal(rec2, rec_zarr, return_in_uV=False, check_annotations=False, check_properties=True) @@ -336,7 +336,7 @@ def test_BaseRecording(create_cache_folder): assert rec2.get_annotation(annotation_name) == rec_zarr_loaded.get_annotation(annotation_name) rec_zarr2 = rec2.save( - format="zarr", folder=cache_folder / "recording_channel_chunk", compressor=compressor, channel_chunk_size=2 + format="zarr", folder=cache_folder / "recording_channel_chunk", compressors=compressor, channel_chunk_size=2 ) rec_zarr2_loaded = load(cache_folder / "recording_channel_chunk.zarr") diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index ab0b071df4..3a1c73d746 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -17,6 +17,7 @@ AnalyzerExtension, _sort_extensions_by_dependency, ) +from spikeinterface.core.zarrextractors import check_compressors_match import numpy as np @@ -38,6 +39,8 @@ def get_dataset(): integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + # make sure the recording is serializable + recording = recording.save() sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting @@ -133,13 +136,12 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) # check that compression is applied - assert ( - sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id - == default_compressor.codec_id + check_compressors_match( + default_compressor, + sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0], ) - assert ( - sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id - == default_compressor.codec_id + check_compressors_match( + default_compressor, sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressors[0] ) # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 @@ -160,35 +162,34 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): sparsity=None, return_in_uV=False, overwrite=True, - backend_options={"saving_options": {"compressor": None}}, + backend_options={"saving_options": {"compressors": None}}, ) print(sorting_analyzer_no_compression._backend_options) sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) assert ( - sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ - "random_spikes_indices" - ].compressor - is None + len( + sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressors + ) + == 0 ) - assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None + assert len(sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressors) == 0 # test a different compressor - from numcodecs import LZMA + from zarr.codecs.numcodecs import LZMA lzma_compressor = LZMA() folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" 
sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( - format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}} + format="zarr", folder=folder, backend_options={"saving_options": {"compressors": lzma_compressor}} ) - assert ( - sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ - "random_spikes_indices" - ].compressor.codec_id - == LZMA.codec_id + check_compressors_match( + lzma_compressor, + sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0], ) - assert ( - sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id - == LZMA.codec_id + check_compressors_match( + lzma_compressor, sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressors[0] ) # test set_sorting_property diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py index cc0c60721e..a52d456594 100644 --- a/src/spikeinterface/core/tests/test_zarrextractors.py +++ b/src/spikeinterface/core/tests/test_zarrextractors.py @@ -10,50 +10,55 @@ generate_sorting, load, ) -from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor +from spikeinterface.core.zarrextractors import ( + add_sorting_to_zarr_group, + get_default_zarr_compressor, + check_compressors_match, +) def test_zarr_compression_options(tmp_path): - from numcodecs import Blosc, Delta, FixedScaleOffset + from zarr.codecs.numcodecs import Delta, FixedScaleOffset + from zarr.codecs import BloscCodec, BloscShuffle recording = generate_recording(durations=[2]) recording.set_times(recording.get_times() + 100) # store in root standard normal way # default compressor - defaut_compressor = get_default_zarr_compressor() + default_compressor = get_default_zarr_compressor() # other compressor - other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE) - other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE) + other_compressor1 = BloscCodec(cname="zlib", clevel=3, shuffle=BloscShuffle.noshuffle) + other_compressor2 = BloscCodec(cname="blosclz", clevel=8, shuffle=BloscShuffle.shuffle) # timestamps compressors / filters default_filters = None - other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())] + other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype().str)] other_filters2 = [Delta(dtype="float64")] # default ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr") rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr") - assert rec_default._root["traces_seg0"].compressor == defaut_compressor - assert rec_default._root["traces_seg0"].filters == default_filters - assert rec_default._root["times_seg0"].compressor == defaut_compressor - assert rec_default._root["times_seg0"].filters == default_filters + check_compressors_match(rec_default._root["traces_seg0"].compressors[0], default_compressor) + check_compressors_match(rec_default._root["times_seg0"].compressors[0], default_compressor) + check_compressors_match(rec_default._root["traces_seg0"].filters, default_filters) + check_compressors_match(rec_default._root["times_seg0"].filters, default_filters) # now with other compressor ZarrRecordingExtractor.write_recording( recording, tmp_path / "rec_other.zarr", - compressor=defaut_compressor, + compressors=default_compressor, filters=default_filters, 
compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2}, filters_by_dataset={"traces": other_filters1, "times": other_filters2}, ) rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr") - assert rec_other._root["traces_seg0"].compressor == other_compressor1 - assert rec_other._root["traces_seg0"].filters == other_filters1 - assert rec_other._root["times_seg0"].compressor == other_compressor2 - assert rec_other._root["times_seg0"].filters == other_filters2 + check_compressors_match(rec_other._root["traces_seg0"].compressors[0], other_compressor1) + check_compressors_match(rec_other._root["traces_seg0"].filters, other_filters1) + check_compressors_match(rec_other._root["times_seg0"].compressors[0], other_compressor2) + check_compressors_match(rec_other._root["times_seg0"].filters, other_filters2) def test_ZarrSortingExtractor(tmp_path): diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 162d67a458..6b20c6bbaa 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -48,11 +48,13 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d import zarr # if mode is append or read/write, we try to open the folder with zarr.open - # since zarr.open_consolidated does not support creating new groups/datasets + # In zarr v3, we use use_consolidated parameter instead of open_consolidated if mode in ("a", "r+"): open_funcs = (zarr.open,) + use_consolidated_options = (False,) else: - open_funcs = (zarr.open_consolidated, zarr.open) + open_funcs = (zarr.open,) + use_consolidated_options = (True, False) # if storage_options is None, we try to open the folder with and without anonymous access # if storage_options is not None, we try to open the folder with the given storage options @@ -64,12 +66,14 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d root = None exception = None if is_path_remote(str(folder_path)): - for open_func in open_funcs: + for use_consolidated in use_consolidated_options: if root is not None: break for storage_options in storage_options_to_test: try: - root = open_func(str(folder_path), mode=mode, storage_options=storage_options) + root = zarr.open( + str(folder_path), mode=mode, storage_options=storage_options, use_consolidated=use_consolidated + ) break except Exception as e: exception = e @@ -77,9 +81,9 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d else: if not Path(folder_path).is_dir(): raise ValueError(f"Folder {folder_path} does not exist") - for open_func in open_funcs: + for use_consolidated in use_consolidated_options: try: - root = open_func(str(folder_path), mode=mode, storage_options=storage_options) + root = zarr.open(str(folder_path), mode=mode, use_consolidated=use_consolidated) break except Exception as e: exception = e @@ -91,6 +95,34 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d return root +def check_compressors_match(comp1, comp2, skip_typesize=True): + """ + Check if two compressor objects match. + + Parameters + ---------- + comp1 : zarr.Codec | Tuple[zarr.Codec] + The first compressor object to compare. + comp2 : zarr.Codec | Tuple[zarr.Codec] + The second compressor object to compare. 
+ skip_typesize : bool, optional + Whether to skip the typesize check, default: True + """ + if not isinstance(comp1, (list, tuple)): + assert not isinstance(comp2, list) + comp1 = [comp1] + comp2 = [comp2] + for i in range(len(comp1)): + comp1_dict = comp1[i].to_dict() + comp2_dict = comp2[i].to_dict() + if skip_typesize: + if "typesize" in comp1_dict["configuration"]: + comp1_dict["configuration"].pop("typesize", None) + if "typesize" in comp2_dict["configuration"]: + comp2_dict["configuration"].pop("typesize", None) + assert comp1_dict == comp2_dict + + class ZarrRecordingExtractor(BaseRecording): """ RecordingExtractor for a zarr format @@ -289,7 +321,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, BaseSorting.__init__(self, sampling_frequency, unit_ids) - spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) + spikes = np.zeros(spikes_group["sample_index"].shape[0], dtype=minimum_spike_dtype) spikes["sample_index"] = spikes_group["sample_index"][:] spikes["unit_index"] = spikes_group["unit_index"][:] for i, (start, end) in enumerate(segment_slices_list): @@ -392,9 +424,9 @@ def get_default_zarr_compressor(clevel: int = 5): Blosc.compressor The compressor object that can be used with the save to zarr function """ - from numcodecs import Blosc + from zarr.codecs import BloscCodec, BloscShuffle - return Blosc(cname="zstd", clevel=clevel, shuffle=Blosc.BITSHUFFLE) + return BloscCodec(cname="zstd", clevel=clevel, shuffle=BloscShuffle.bitshuffle) def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_or_sorting: BaseRecording | BaseSorting): @@ -405,7 +437,7 @@ def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_o if values.dtype.kind == "O": warnings.warn(f"Property {key} not saved because it is a python Object type") continue - prop_group.create_dataset(name=key, data=values, compressor=None) + prop_group.create_array(name=key, data=values, compressors=None) # save annotations zarr_group.attrs["annotations"] = check_json(recording_or_sorting._annotations) @@ -424,12 +456,12 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G kwargs : dict Other arguments passed to the zarr compressor """ - from numcodecs import Delta + from zarr.codecs.numcodecs import Delta num_segments = sorting.get_num_segments() zarr_group.attrs["sampling_frequency"] = float(sorting.sampling_frequency) zarr_group.attrs["num_segments"] = int(num_segments) - zarr_group.create_dataset(name="unit_ids", data=sorting.unit_ids, compressor=None) + zarr_group.create_array(name="unit_ids", data=sorting.unit_ids, compressors=None) compressor = kwargs.get("compressor", get_default_zarr_compressor()) @@ -438,18 +470,21 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G spikes = sorting.to_spike_vector() for field in spikes.dtype.fields: if field != "segment_index": - spikes_group.create_dataset( + dtype = spikes[field].dtype + spikes_data = spikes[field] + spikes_group.create_array( name=field, - data=spikes[field], - compressor=compressor, - filters=[Delta(dtype=spikes[field].dtype)], + data=spikes_data, + compressors=compressor, + filters=[Delta(dtype=spikes[field].dtype.str)], ) else: segment_slices = [] for segment_index in range(num_segments): i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append([i0, i1]) - spikes_group.create_dataset(name="segment_slices", data=segment_slices, compressor=None) 
+ segment_slices = np.array(segment_slices, dtype="int64") + spikes_group.create_array(name="segment_slices", data=segment_slices, compressors=None) add_properties_and_annotations(zarr_group, sorting) @@ -468,7 +503,7 @@ def add_recording_to_zarr_group( # save data (done the subclass) zarr_group.attrs["sampling_frequency"] = float(recording.get_sampling_frequency()) zarr_group.attrs["num_segments"] = int(recording.get_num_segments()) - zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None) + zarr_group.create_array(name="channel_ids", data=recording.get_channel_ids(), compressors=None) dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] dtype = recording.get_dtype() if dtype is None else dtype @@ -484,7 +519,7 @@ def add_recording_to_zarr_group( recording=recording, zarr_group=zarr_group, dataset_paths=dataset_paths, - compressor=compressor_traces, + compressors=compressor_traces, filters=filters_traces, dtype=dtype, channel_chunk_size=channel_chunk_size, @@ -507,17 +542,17 @@ def add_recording_to_zarr_group( filters_times = filters_by_dataset.get("times", global_filters) if time_vector is not None: - _ = zarr_group.create_dataset( + _ = zarr_group.create_array( name=f"times_seg{segment_index}", data=time_vector, filters=filters_times, - compressor=compressor_times, + compressors=compressor_times, ) elif d["t_start"] is not None: t_starts[segment_index] = d["t_start"] if np.any(~np.isnan(t_starts)): - zarr_group.create_dataset(name="t_starts", data=t_starts, compressor=None) + zarr_group.create_array(name="t_starts", data=t_starts, compressors=None) add_properties_and_annotations(zarr_group, recording) @@ -528,7 +563,7 @@ def add_traces_to_zarr( dataset_paths, channel_chunk_size=None, dtype=None, - compressor=None, + compressors=None, filters=None, verbose=False, **job_kwargs, @@ -548,7 +583,7 @@ def add_traces_to_zarr( Channels per chunk dtype : dtype, default: None Type of the saved data - compressor : zarr compressor or None, default: None + compressors : zarr compressor or None, default: None Zarr compressor filters : list, default: None List of zarr filters @@ -581,13 +616,15 @@ def add_traces_to_zarr( num_channels = recording.get_num_channels() dset_name = dataset_paths[segment_index] shape = (num_frames, num_channels) - dset = zarr_group.create_dataset( + # In zarr v3, chunks must be a tuple of integers (no None allowed) + chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) + dset = zarr_group.create_array( name=dset_name, shape=shape, - chunks=(chunk_size, channel_chunk_size), + chunks=chunks, dtype=dtype, filters=filters, - compressor=compressor, + compressors=compressors, ) zarr_datasets.append(dset) # synchronizer=zarr.ThreadSynchronizer()) From 4dbb7b38fab00795aebb3644ba4bb325052607cf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Dec 2025 16:49:23 +0100 Subject: [PATCH 2/2] wip --- src/spikeinterface/core/zarrextractors.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 6b20c6bbaa..0ec28d544a 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -14,6 +14,9 @@ from .core_tools import is_path_remote +zarr.config.set({"default_zarr_version": 3}) + + def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: dict | None = None): """ Open a zarr folder 
with super powers. @@ -463,7 +466,9 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G zarr_group.attrs["num_segments"] = int(num_segments) zarr_group.create_array(name="unit_ids", data=sorting.unit_ids, compressors=None) - compressor = kwargs.get("compressor", get_default_zarr_compressor()) + compressor = kwargs.get("compressors") or kwargs.get("compressor") + if compressor is None: + compressor = get_default_zarr_compressor() # save sub fields spikes_group = zarr_group.create_group(name="spikes") @@ -508,7 +513,9 @@ def add_recording_to_zarr_group( dtype = recording.get_dtype() if dtype is None else dtype channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) - global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor()) + global_compressor = kwargs.get("compressors") or kwargs.get("compressor") + if global_compressor is None: + global_compressor = get_default_zarr_compressor() compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {}) global_filters = zarr_kwargs.pop("filters", None) filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {}) @@ -609,6 +616,9 @@ def add_traces_to_zarr( job_kwargs = fix_job_kwargs(job_kwargs) chunk_size = ensure_chunk_size(recording, **job_kwargs) + if not isinstance(compressors, (list, tuple)): + compressors = [compressors] + # create zarr datasets files zarr_datasets = [] for segment_index in range(recording.get_num_segments()): @@ -618,13 +628,8 @@ def add_traces_to_zarr( shape = (num_frames, num_channels) # In zarr v3, chunks must be a tuple of integers (no None allowed) chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) - dset = zarr_group.create_array( - name=dset_name, - shape=shape, - chunks=chunks, - dtype=dtype, - filters=filters, - compressors=compressors, + dset = zarr_group.create( + name=dset_name, shape=shape, chunks=chunks, dtype=dtype, filters=filters, codecs=compressors, zarr_format=3 ) zarr_datasets.append(dset) # synchronizer=zarr.ThreadSynchronizer())
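
Note (not part of the patch): a minimal sketch of the zarr v2 -> v3 API mapping this series relies on, assuming zarr-python >= 3 as pinned in the updated pyproject.toml; the store path "example.zarr" and the attribute payload below are illustrative only.

    import numpy as np
    import zarr
    from zarr.codecs import BloscCodec, BloscShuffle

    # zarr v2 wrote datasets with a single `compressor` codec; zarr v3 uses
    # `create_array` with a `compressors` argument instead of `create_dataset`.
    root = zarr.open("example.zarr", mode="w")
    codec = BloscCodec(cname="zstd", clevel=5, shuffle=BloscShuffle.bitshuffle)
    root.create_array("traces_seg0", data=np.zeros((1000, 4), dtype="float32"), compressors=codec)

    # JSON-serializable provenance is kept in attributes rather than
    # object-codec datasets, and metadata is re-consolidated after writing.
    root.attrs["recording"] = {"class": "ExampleRecording", "kwargs": {}}
    zarr.consolidate_metadata(root.store)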