Skip to content

Commit 0657d8e

Browse files
committed
Fix t_starts not propagated to save memory.
1 parent 5c06804 commit 0657d8e

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,11 +566,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
566566
if kwargs.get("sharedmem", True):
567567
from .numpyextractors import SharedMemoryRecording
568568

569-
cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
569+
cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
570570
else:
571571
from spikeinterface.core import NumpyRecording
572572

573-
cached = NumpyRecording.from_recording(self, **job_kwargs)
573+
cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
574574

575575
elif format == "zarr":
576576
from .zarrextractors import ZarrRecordingExtractor

src/spikeinterface/core/numpyextractors.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
8585
}
8686

8787
@staticmethod
88-
def from_recording(source_recording, **job_kwargs):
88+
def from_recording(source_recording, t_starts=None, **job_kwargs):
8989
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
9090
if shms[0] is not None:
9191
# if the computation was done in parallel then traces_list is shared array
@@ -99,9 +99,10 @@ def from_recording(source_recording, **job_kwargs):
9999
recording = NumpyRecording(
100100
traces_list,
101101
source_recording.get_sampling_frequency(),
102-
t_starts=None,
102+
t_starts=t_starts,
103103
channel_ids=source_recording.channel_ids,
104104
)
105+
return recording
105106

106107

107108
class NumpyRecordingSegment(BaseRecordingSegment):
@@ -211,7 +212,7 @@ def __del__(self):
211212
shm.unlink()
212213

213214
@staticmethod
214-
def from_recording(source_recording, **job_kwargs):
215+
def from_recording(source_recording, t_starts=None, **job_kwargs):
215216
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)
216217

217218
# TODO later : propagte t_starts ?
@@ -222,7 +223,7 @@ def from_recording(source_recording, **job_kwargs):
222223
dtype=source_recording.dtype,
223224
sampling_frequency=source_recording.sampling_frequency,
224225
channel_ids=source_recording.channel_ids,
225-
t_starts=None,
226+
t_starts=t_starts,
226227
main_shm_owner=True,
227228
)
228229

0 commit comments

Comments
 (0)