3 changes: 1 addition & 2 deletions pyproject.toml
@@ -24,12 +24,11 @@ dependencies = [
"numpy>=2.0.0;python_version>='3.13'",
"threadpoolctl>=3.0.0",
"tqdm",
"zarr>=2.18,<3",
"zarr>=3,<4",
"neo>=0.14.3",
"probeinterface>=0.3.1",
"packaging",
"pydantic",
"numcodecs<0.16.0", # For supporting zarr < 3
]

[build-system]
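The dependency change pins zarr v3 (previously v2) and drops the `numcodecs<0.16.0` cap, which existed only to keep zarr v2 working. A minimal environment check, sketched under the assumption that `packaging` (already a dependency) is installed:

    import zarr
    from packaging.version import Version

    # The new pin requires 3 <= zarr < 4; the numcodecs upper bound is no longer needed.
    assert Version("3") <= Version(zarr.__version__) < Version("4")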
94 changes: 51 additions & 43 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Literal, Optional, Any
from typing import Literal, Optional, Any, Iterable

from pathlib import Path
from itertools import chain
@@ -621,6 +621,7 @@ def _get_zarr_root(self, mode="r+"):
assert mode in ("r+", "a", "r"), "mode must be 'r+', 'a' or 'r'"

storage_options = self._backend_options.get("storage_options", {})

zarr_root = super_zarr_open(self.folder, mode=mode, storage_options=storage_options)
return zarr_root

@@ -644,7 +645,12 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
storage_options = backend_options.get("storage_options", {})
saving_options = backend_options.get("saving_options", {})

zarr_root = zarr.open(folder, mode="w", storage_options=storage_options)
if not is_path_remote(str(folder)):
storage_options_kwargs = {}
else:
storage_options_kwargs = {"storage_options": storage_options}

zarr_root = zarr.open(folder, mode="w", **storage_options_kwargs)

info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer")
zarr_root.attrs["spikeinterface_info"] = check_json(info)
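zarr v3 accepts `storage_options` only for fsspec-backed (remote) stores, which is presumably why the open call now forwards them only when the target path is remote. A minimal sketch of the same guard, assuming `is_path_remote` from `spikeinterface.core.core_tools`:

    import zarr
    from spikeinterface.core.core_tools import is_path_remote

    folder = "s3://bucket/analyzer.zarr"  # hypothetical remote target
    storage_options = {"anon": True}  # hypothetical fsspec options
    kwargs = {"storage_options": storage_options} if is_path_remote(str(folder)) else {}
    zarr_root = zarr.open(folder, mode="w", **kwargs)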
@@ -657,13 +663,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
if recording is not None:
rec_dict = recording.to_dict(relative_to=relative_to, recursive=True)
if recording.check_serializability("json"):
# zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON())
zarr_rec = np.array([check_json(rec_dict)], dtype=object)
zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON())
elif recording.check_serializability("pickle"):
# zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle())
zarr_rec = np.array([rec_dict], dtype=object)
zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle())
# In zarr v3, store JSON-serializable data in attributes instead of using object_codec
zarr_root.attrs["recording"] = check_json(rec_dict)
else:
warnings.warn("The Recording is not serializable! The recording link will be lost for future load")
else:
@@ -673,11 +674,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
# sorting provenance
sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True)
if sorting.check_serializability("json"):
zarr_sort = np.array([check_json(sort_dict)], dtype=object)
zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON())
elif sorting.check_serializability("pickle"):
zarr_sort = np.array([sort_dict], dtype=object)
zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle())
# In zarr v3, store JSON-serializable data in attributes instead of using object_codec
zarr_root.attrs["sorting_provenance"] = check_json(sort_dict)
else:
warnings.warn(
"The sorting provenance is not serializable! The sorting provenance link will be lost for future load"
@@ -698,12 +696,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
recording_info.attrs["probegroup"] = check_json(probegroup.to_dict())

if sparsity is not None:
zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options)
zarr_root.create_array("sparsity_mask", data=sparsity.mask, **saving_options)

add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options)

recording_info = zarr_root.create_group("extensions")

# consolidate metadata for zarr v3
zarr.consolidate_metadata(zarr_root.store)

return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options)
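With `object_codec` gone in zarr v3, JSON-serializable provenance dicts are stored in group attributes instead of single-element object arrays. A minimal round-trip sketch, with a made-up dict standing in for `recording.to_dict()`:

    import zarr

    root = zarr.open("analyzer.zarr", mode="w")
    # Hypothetical minimal provenance dict; the real one comes from recording.to_dict()
    root.attrs["recording"] = {"class": "BinaryRecordingExtractor", "kwargs": {"file_path": "traces.bin"}}

    rec_dict = zarr.open("analyzer.zarr", mode="r").attrs.get("recording", None)
    assert rec_dict is not None

Note that the pickle fallback for non-JSON-serializable recordings is dropped with this change, so those now only trigger the existing warning.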
@@ -715,6 +714,10 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):

backend_options = {} if backend_options is None else backend_options
storage_options = backend_options.get("storage_options", {})
if not is_path_remote(str(folder)):
storage_options_kwargs = {}
else:
storage_options_kwargs = {"storage_options": storage_options}

zarr_root = super_zarr_open(str(folder), mode="r", storage_options=storage_options)

@@ -723,7 +726,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
# v0.101.0 did not have a consolidate metadata step after computing extensions.
# Here we try to consolidate the metadata and throw a warning if it fails.
try:
zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options)
zarr_root_a = zarr.open(str(folder), mode="a", **storage_options_kwargs)
zarr.consolidate_metadata(zarr_root_a.store)
except Exception as e:
warnings.warn(
@@ -741,9 +744,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):

# load recording if possible
if recording is None:
rec_field = zarr_root.get("recording")
if rec_field is not None:
rec_dict = rec_field[0]
# In zarr v3, recording is stored in attributes
rec_dict = zarr_root.attrs.get("recording", None)
if rec_dict is not None:
try:
recording = load(rec_dict, base_folder=folder)
except:
@@ -859,7 +862,7 @@ def set_sorting_property(
if key in zarr_root["sorting"]["properties"]:
zarr_root["sorting"]["properties"][key][:] = prop_values
else:
zarr_root["sorting"]["properties"].create_dataset(name=key, data=prop_values, compressor=None)
zarr_root["sorting"]["properties"].create_array(name=key, data=prop_values, compressors=None)
# IMPORTANT: we need to re-consolidate the zarr store!
zarr.consolidate_metadata(zarr_root.store)
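Consolidated metadata is a point-in-time snapshot, so every write after the initial consolidation must re-consolidate, or readers that rely on it will miss the new nodes. A minimal write-then-reconsolidate sketch, with a hypothetical property array:

    import numpy as np
    import zarr

    root = zarr.open("analyzer.zarr", mode="a")
    props = root.require_group("sorting").require_group("properties")
    props.create_array(name="firing_rate", data=np.array([1.5, 2.0]), compressors=None)
    # Re-consolidate so consolidated-metadata readers see the new array
    zarr.consolidate_metadata(root.store)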

@@ -1531,12 +1534,13 @@ def get_sorting_provenance(self):
elif self.format == "zarr":
zarr_root = self._get_zarr_root(mode="r")
sorting_provenance = None
if "sorting_provenance" in zarr_root.keys():
# In zarr v3, sorting_provenance is stored in attributes
sort_dict = zarr_root.attrs.get("sorting_provenance", None)
if sort_dict is not None:
# try-except here is because it's not required to be able
# to load the sorting provenance, as the user might have deleted
# the original sorting folder
try:
sort_dict = zarr_root["sorting_provenance"][0]
sorting_provenance = load(sort_dict, base_folder=self.folder)
except:
pass
@@ -2479,8 +2483,9 @@ def load_data(self):
extension_group = self._get_zarr_extension_group(mode="r")
for ext_data_name in extension_group.keys():
ext_data_ = extension_group[ext_data_name]
if "dict" in ext_data_.attrs:
ext_data = ext_data_[0]
# In zarr v3, check if it's a group with dict_data attribute
if "dict_data" in ext_data_.attrs:
ext_data = ext_data_.attrs["dict_data"]
elif "dataframe" in ext_data_.attrs:
import pandas as pd

@@ -2565,9 +2570,10 @@ def run(self, save=True, **kwargs):
if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store)

def save(self):
self._reset_extension_folder()
self._save_params()
self._save_importing_provenance()
self._save_run_info()
@@ -2576,7 +2582,7 @@ def save(self):
if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store)

def _save_data(self):
if self.format == "memory":
@@ -2623,40 +2629,44 @@ def _save_data(self):
extension_group = self._get_zarr_extension_group(mode="r+")

# if compression is not externally given, we use the default
if "compressor" not in saving_options:
saving_options["compressor"] = get_default_zarr_compressor()
if "compressors" not in saving_options and "compressor" not in saving_options:
saving_options["compressors"] = get_default_zarr_compressor()
if "compressor" in saving_options:
saving_options["compressors"] = [saving_options["compressor"]]
del saving_options["compressor"]

for ext_data_name, ext_data in self.data.items():
if ext_data_name in extension_group:
del extension_group[ext_data_name]
if isinstance(ext_data, dict):
extension_group.create_dataset(
name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON()
)
# In zarr v3, store dict in a subgroup with attributes
dict_group = extension_group.create_group(ext_data_name)
dict_group.attrs["dict_data"] = check_json(ext_data)
elif isinstance(ext_data, np.ndarray):
extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options)
extension_group.create_array(name=ext_data_name, data=ext_data, **saving_options)
elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
df_group = extension_group.create_group(ext_data_name)
# first we save the index
indices = ext_data.index.to_numpy()
if indices.dtype.kind == "O":
indices = indices.astype(str)
df_group.create_dataset(name="index", data=indices)
df_group.create_array(name="index", data=indices)
for col in ext_data.columns:
col_data = ext_data[col].to_numpy()
if col_data.dtype.kind == "O":
col_data = col_data.astype(str)
df_group.create_dataset(name=col, data=col_data)
df_group.create_array(name=col, data=col_data)
df_group.attrs["dataframe"] = True
else:
# any object
try:
extension_group.create_dataset(
name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle()
)
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
extension_group[ext_data_name].attrs["object"] = True
# try:
# extension_group.create_array(
# name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle()
# )
# except:
# raise Exception(f"Could not save {ext_data_name} as extension data")
# extension_group[ext_data_name].attrs["object"] = True
warnings.warn(f"Data type of {ext_data_name} not supported for zarr saving, skipping.")

def _reset_extension_folder(self):
"""
@@ -2734,8 +2744,6 @@ def set_params(self, save=True, **params):
def _save_params(self):
params_to_save = self.params.copy()

self._reset_extension_folder()

# TODO make sparsity local Result specific
# if "sparsity" in params_to_save and params_to_save["sparsity"] is not None:
# assert isinstance(
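Two conventions change in `_save_data`: the v2 `compressor` keyword is normalized into zarr v3's plural `compressors`, and dict extension data becomes a subgroup carrying a `dict_data` attribute instead of an object-codec dataset. A minimal save/load sketch of the dict layout, with a hypothetical extension name and payload:

    import zarr

    root = zarr.open("analyzer.zarr", mode="a")
    ext = root.require_group("extensions")

    # Save: the dict lives in group attributes ("dict_data"), not in an array
    grp = ext.create_group("my_extension")  # hypothetical name
    grp.attrs["dict_data"] = {"n_components": 5}

    # Load mirrors the new check in load_data()
    node = ext["my_extension"]
    payload = node.attrs["dict_data"] if "dict_data" in node.attrs else None

Arbitrary pickle-only objects are no longer written at all: the old object-codec path is left commented out and replaced by a warning that skips them.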
12 changes: 7 additions & 5 deletions src/spikeinterface/core/template.py
@@ -321,17 +321,19 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:
"""

# Saves one chunk per unit
arrays_chunk = (1, None, None)
zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk)
zarr_group.create_dataset("channel_ids", data=self.channel_ids)
zarr_group.create_dataset("unit_ids", data=self.unit_ids)
# In zarr v3, chunks must be a full tuple with actual dimensions
num_units, num_samples, num_channels = self.templates_array.shape
arrays_chunk = (1, num_samples, num_channels)
zarr_group.create_array("templates_array", data=self.templates_array, chunks=arrays_chunk)
zarr_group.create_array("channel_ids", data=self.channel_ids)
zarr_group.create_array("unit_ids", data=self.unit_ids)

zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
zarr_group.attrs["nbefore"] = self.nbefore
zarr_group.attrs["is_in_uV"] = self.is_in_uV

if self.sparsity_mask is not None:
zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
zarr_group.create_array("sparsity_mask", data=self.sparsity_mask)

if self.probe is not None:
probe_group = zarr_group.create_group("probe")
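zarr v3 no longer accepts `None` placeholders in a chunk shape, so chunking one unit per chunk now spells out every dimension explicitly. A minimal sketch with made-up sizes:

    import numpy as np
    import zarr

    templates_array = np.zeros((10, 90, 4), dtype="float32")  # (units, samples, channels), hypothetical
    num_units, num_samples, num_channels = templates_array.shape
    root = zarr.open("templates.zarr", mode="w")
    # One chunk per unit; the trailing dimensions are given in full
    root.create_array("templates_array", data=templates_array, chunks=(1, num_samples, num_channels))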
@@ -91,7 +91,7 @@ def test_ComputeRandomSpikes(format, sparse, create_cache_folder):

print("Checking results")
_check_result_extension(sorting_analyzer, "random_spikes", cache_folder)
print("Delering extension")
print("Deleting extension")
sorting_analyzer.delete_extension("random_spikes")

print("Re-computing random spikes")
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_baserecording.py
@@ -324,7 +324,7 @@ def test_BaseRecording(create_cache_folder):

# test save to zarr
compressor = get_default_zarr_compressor()
rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressor=compressor)
rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressors=compressor)
rec_zarr_loaded = load(cache_folder / "recording.zarr")
# annotations is False because Zarr adds compression ratios
check_recordings_equal(rec2, rec_zarr, return_in_uV=False, check_annotations=False, check_properties=True)
@@ -336,7 +336,7 @@ def test_BaseRecording(create_cache_folder):
assert rec2.get_annotation(annotation_name) == rec_zarr_loaded.get_annotation(annotation_name)

rec_zarr2 = rec2.save(
format="zarr", folder=cache_folder / "recording_channel_chunk", compressor=compressor, channel_chunk_size=2
format="zarr", folder=cache_folder / "recording_channel_chunk", compressors=compressor, channel_chunk_size=2
)
rec_zarr2_loaded = load(cache_folder / "recording_channel_chunk.zarr")

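The `save()` keyword follows the same v2-to-v3 rename, `compressor` becoming `compressors`. A hedged usage sketch with a generated recording:

    from spikeinterface.core import generate_recording, get_default_zarr_compressor

    rec = generate_recording(num_channels=4, durations=[2.0])
    rec_zarr = rec.save(format="zarr", folder="my_recording", compressors=get_default_zarr_compressor())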
45 changes: 23 additions & 22 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -17,6 +17,7 @@
AnalyzerExtension,
_sort_extensions_by_dependency,
)
from spikeinterface.core.zarrextractors import check_compressors_match

import numpy as np

@@ -38,6 +39,8 @@ def get_dataset():
integer_unit_ids = [int(id) for id in sorting.get_unit_ids()]

recording = recording.rename_channels(new_channel_ids=integer_channel_ids)
# make sure the recording is serializable
recording = recording.save()
sorting = sorting.rename_units(new_unit_ids=integer_unit_ids)
return recording, sorting

Expand Down Expand Up @@ -133,13 +136,12 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)

# check that compression is applied
assert (
sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id
== default_compressor.codec_id
check_compressors_match(
default_compressor,
sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0],
)
assert (
sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id
== default_compressor.codec_id
check_compressors_match(
default_compressor, sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressors[0]
)

# test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041
@@ -160,35 +162,34 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
sparsity=None,
return_in_uV=False,
overwrite=True,
backend_options={"saving_options": {"compressor": None}},
backend_options={"saving_options": {"compressors": None}},
)
print(sorting_analyzer_no_compression._backend_options)
sorting_analyzer_no_compression.compute(["random_spikes", "templates"])
assert (
sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][
"random_spikes_indices"
].compressor
is None
len(
sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][
"random_spikes_indices"
].compressors
)
== 0
)
assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None
assert len(sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressors) == 0

# test a different compressor
from numcodecs import LZMA
from zarr.codecs.numcodecs import LZMA

lzma_compressor = LZMA()
folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr"
sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as(
format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}}
format="zarr", folder=folder, backend_options={"saving_options": {"compressors": lzma_compressor}}
)
assert (
sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][
"random_spikes_indices"
].compressor.codec_id
== LZMA.codec_id
check_compressors_match(
lzma_compressor,
sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0],
)
assert (
sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id
== LZMA.codec_id
check_compressors_match(
lzma_compressor, sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressors[0]
)

# test set_sorting_property
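In zarr v3 an array exposes a `compressors` tuple rather than a single `compressor`, so "no compression" shows up as an empty tuple and codec identity is compared per element (here via the new `check_compressors_match` helper). numcodecs codecs are likewise used through their v3 wrappers in `zarr.codecs.numcodecs`. A minimal inspection sketch on a local store:

    import numpy as np
    import zarr

    root = zarr.open("check.zarr", mode="w")
    arr = root.create_array("x", data=np.arange(10), compressors=None)
    assert len(arr.compressors) == 0  # v2 equivalent: arr.compressor is None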