Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Low-level
.. automodule:: spikeinterface.core
:noindex:

.. autoclass:: ChunkRecordingExecutor
.. autoclass:: ChunkExecutor


Back-compatibility with ``WaveformExtractor`` (version > 0.100.0)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/multicomparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __init__(

for segment_index in range(multisortingcomparison._num_segments):
sorting_segment = AgreementSortingSegment(multisortingcomparison._spiketrains[segment_index])
self.add_sorting_segment(sorting_segment)
self.add_segment(sorting_segment)

self._kwargs = dict(
sampling_frequency=sampling_frequency,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@
get_best_job_kwargs,
ensure_n_jobs,
ensure_chunk_size,
ChunkRecordingExecutor,
ChunkExecutor,
split_job_kwargs,
fix_job_kwargs,
)
from .chunkable_tools import write_binary, write_memory
from .recording_tools import (
write_binary_recording,
write_to_h5_dataset_format,
get_random_data_chunks,
get_channel_distances,
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
* ComputeNoiseLevels which is very convenient to have
"""

from __future__ import annotations

import warnings
import numpy as np
from collections import namedtuple
Expand Down
79 changes: 68 additions & 11 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
import shutil
from typing import Any, Iterable, List, Optional, Sequence, Union
Expand Down Expand Up @@ -61,6 +62,9 @@ def __init__(self, main_ids: Sequence) -> None:
self._main_ids.dtype.kind in "uiSU"
), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}"

# segments
self._segments: List[BaseSegment] = []

# dict at object level
self._annotations = {}

Expand Down Expand Up @@ -116,9 +120,16 @@ def name(self, value):
# we remove the annotation if it exists
_ = self._annotations.pop("name", None)

@property
def segments(self) -> list:
return self._segments

def add_segment(self, segment: BaseSegment) -> None:
self._segments.append(segment)
segment.set_parent_extractor(self)

def get_num_segments(self) -> int:
# This is implemented in BaseRecording or BaseSorting
raise NotImplementedError
return len(self._segments)

def get_parent(self) -> Optional[BaseExtractor]:
"""Returns parent object if it exists, otherwise None"""
Expand Down Expand Up @@ -210,13 +221,6 @@ def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> No
else:
raise ValueError(f"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it")

def get_preferred_mp_context(self):
"""
Get the preferred context for multiprocessing.
If None, the context is set by the multiprocessing package.
"""
return self._preferred_mp_context

def get_annotation(self, key: str, copy: bool = True) -> Any:
"""
Get a annotation.
Expand Down Expand Up @@ -405,8 +409,9 @@ def copy_metadata(

other.extra_requirements.extend(self.extra_requirements)

if self._preferred_mp_context is not None:
other._preferred_mp_context = self._preferred_mp_context
# call all extra copy metadata if it exists (e.g., with chunkable mixin)
if hasattr(self, "_extra_copy_metadata"):
self._extra_copy_metadata(other, only_main=only_main, ids=ids, skip_properties=skip_properties)

def to_dict(
self,
Expand Down Expand Up @@ -1165,3 +1170,55 @@ def parent_extractor(self) -> Union[BaseExtractor, None]:

def set_parent_extractor(self, parent_extractor: BaseExtractor) -> None:
self._parent_extractor = weakref.ref(parent_extractor)


class ChunkableMixin(ABC):
"""
Abstract mixin class for chunkable objects.
Provides methods to handle chunked data access, that can be used for parallelization.

The Mixin is abstract since all methods need to be implemented in the child class in order
for it to function properly.
"""

_preferred_mp_context = None

@abstractmethod
def get_sampling_frequency(self) -> float:
raise NotImplementedError

@abstractmethod
def get_num_samples(self, segment_index: int | None = None) -> int:
raise NotImplementedError

@abstractmethod
def get_sample_size_in_bytes(self) -> int:
raise NotImplementedError

@abstractmethod
def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]:
raise NotImplementedError

@abstractmethod
def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
raise NotImplementedError

def _extra_copy_metadata(self, other: "ChunkableMixin", **kwargs) -> None:
"""
Copy metadata from another ChunkableMixin object.

Parameters
----------
other : ChunkableMixin
The object from which to copy metadata.
"""
# inherit preferred mp context if any
if self.__class__._preferred_mp_context is not None:
other.__class__._preferred_mp_context = self.__class__._preferred_mp_context

def get_preferred_mp_context(self):
"""
Get the preferred context for multiprocessing.
If None, the context is set by the multiprocessing package.
"""
return self.__class__._preferred_mp_context
75 changes: 37 additions & 38 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import numpy as np
from probeinterface import read_probeinterface, write_probeinterface

from .base import BaseSegment
from .base import BaseSegment, ChunkableMixin
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs
from .recording_tools import write_binary_recording


class BaseRecording(BaseRecordingSnippets):
class BaseRecording(BaseRecordingSnippets, ChunkableMixin):
"""
Abstract class representing several a multichannel timeseries (or block of raw ephys traces).
Internally handle list of RecordingSegment
Expand Down Expand Up @@ -44,8 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype
)

self._recording_segments: list[BaseRecordingSegment] = []

# initialize main annotation and properties
self.annotate(is_filtered=False)

Expand Down Expand Up @@ -171,28 +168,19 @@ def __sub__(self, other):

return SubtractRecordings(self, other)

def get_num_segments(self) -> int:
def get_sample_size_in_bytes(self):
"""
Returns the number of segments.
Returns the size of a single sample across all channels in bytes.

Returns
-------
int
Number of segments in the recording
The size of a single sample in bytes
"""
return len(self._recording_segments)

def add_recording_segment(self, recording_segment):
"""Adds a recording segment.

Parameters
----------
recording_segment : BaseRecordingSegment
The recording segment to add
"""
# todo: check channel count and sampling frequency
self._recording_segments.append(recording_segment)
recording_segment.set_parent_extractor(self)
num_channels = self.get_num_channels()
dtype_size_bytes = self.get_dtype().itemsize
sample_size = num_channels * dtype_size_bytes
return sample_size

def get_num_samples(self, segment_index: int | None = None) -> int:
"""
Expand All @@ -211,7 +199,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int:
The number of samples
"""
segment_index = self._check_segment_index(segment_index)
return int(self._recording_segments[segment_index].get_num_samples())
return int(self.segments[segment_index].get_num_samples())

get_num_frames = get_num_samples

Expand Down Expand Up @@ -343,7 +331,7 @@ def get_traces(
"""
segment_index = self._check_segment_index(segment_index)
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
start_frame = int(start_frame) if start_frame is not None else 0
num_samples = rs.get_num_samples()
end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples
Expand Down Expand Up @@ -401,7 +389,7 @@ def get_time_info(self, segment_index=None) -> dict:
"""

segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
time_kwargs = rs.get_times_kwargs()

return time_kwargs
Expand All @@ -425,7 +413,7 @@ def get_times(self, segment_index=None) -> np.ndarray:
The 1d times array
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
times = rs.get_times()
return times

Expand All @@ -443,7 +431,7 @@ def get_start_time(self, segment_index=None) -> float:
The start time in seconds
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.get_start_time()

def get_end_time(self, segment_index=None) -> float:
Expand All @@ -460,7 +448,7 @@ def get_end_time(self, segment_index=None) -> float:
The stop time in seconds
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.get_end_time()

def has_time_vector(self, segment_index: Optional[int] = None):
Expand All @@ -477,7 +465,7 @@ def has_time_vector(self, segment_index: Optional[int] = None):
True if the recording has time vectors, False otherwise
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
d = rs.get_times_kwargs()
return d["time_vector"] is not None

Expand All @@ -494,7 +482,7 @@ def set_times(self, times, segment_index=None, with_warning=True):
If True, a warning is printed
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]

assert times.ndim == 1, "Time must have ndim=1"
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"
Expand All @@ -517,7 +505,7 @@ def reset_times(self):
segment's sampling frequency is set to the recording's sampling frequency.
"""
for segment_index in range(self.get_num_segments()):
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
if self.has_time_vector(segment_index):
rs.time_vector = None
rs.t_start = None
Expand Down Expand Up @@ -545,7 +533,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
segments_to_shift = (segment_index,)

for segment_index in segments_to_shift:
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]

if self.has_time_vector(segment_index=segment_index):
rs.time_vector += shift
Expand All @@ -558,19 +546,28 @@ def sample_index_to_time(self, sample_ind, segment_index=None):
Transform sample index into time in seconds
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.sample_index_to_time(sample_ind)

def time_to_sample_index(self, time_s, segment_index=None):
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.time_to_sample_index(time_s)

def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
"""
General retrieval function for chunkable objects
"""
return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs)

def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]:
return (self.get_num_samples(segment_index=segment_index), self.get_num_channels())

def _get_t_starts(self):
# handle t_starts
t_starts = []
has_time_vectors = []
for rs in self._recording_segments:
for rs in self.segments:
d = rs.get_times_kwargs()
t_starts.append(d["t_start"])

Expand All @@ -580,7 +577,7 @@ def _get_t_starts(self):

def _get_time_vectors(self):
time_vectors = []
for rs in self._recording_segments:
for rs in self.segments:
d = rs.get_times_kwargs()
time_vectors.append(d["time_vector"])
if all(time_vector is None for time_vector in time_vectors):
Expand All @@ -591,12 +588,14 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
kwargs, job_kwargs = split_job_kwargs(save_kwargs)

if format == "binary":
from .chunkable_tools import write_binary

folder = kwargs["folder"]
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
dtype = kwargs.get("dtype", None) or self.get_dtype()
t_starts = self._get_t_starts()

write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)
write_binary(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)

from .binaryrecordingextractor import BinaryRecordingExtractor

Expand Down Expand Up @@ -668,7 +667,7 @@ def _extra_metadata_from_folder(self, folder):
self.set_probegroup(probegroup, in_place=True)

# load time vector if any
for segment_index, rs in enumerate(self._recording_segments):
for segment_index, rs in enumerate(self.segments):
time_file = folder / f"times_cached_seg{segment_index}.npy"
if time_file.is_file():
time_vector = np.load(time_file)
Expand All @@ -681,7 +680,7 @@ def _extra_metadata_to_folder(self, folder):
write_probeinterface(folder / "probe.json", probegroup)

# save time vector if any
for segment_index, rs in enumerate(self._recording_segments):
for segment_index, rs in enumerate(self.segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
if time_vector is not None:
Expand Down
Loading