66from spikeinterface .preprocessing .basepreprocessor import BasePreprocessor , BasePreprocessorSegment
77from spikeinterface .preprocessing .filter import fix_dtype
88
9+ from .motion_utils import ensure_time_bin_edges , ensure_time_bins
10+
911
1012def correct_motion_on_peaks (peaks , peak_locations , motion , recording ) -> np .ndarray :
1113 """
@@ -54,14 +56,19 @@ def interpolate_motion_on_traces(
5456 segment_index = None ,
5557 channel_inds = None ,
5658 interpolation_time_bin_centers_s = None ,
59+ interpolation_time_bin_edges_s = None ,
5760 spatial_interpolation_method = "kriging" ,
5861 spatial_interpolation_kwargs = {},
5962 dtype = None ,
6063):
6164 """
6265 Apply inverse motion with spatial interpolation on traces.
6366
64- Traces can be full traces, but also waveforms snippets.
67+ Traces can be full traces, but also waveforms snippets. Times used for looking up
68+ displacements are controlled by interpolation_time_bin_edges_s or
69+ interpolation_time_bin_centers_s, or fall back to the Motion object's time bins
70+ by default; times in the recording outside these time bins use the closest edge
71+ bin's displacement value during interpolation.
6572
6673 Parameters
6774 ----------
@@ -80,6 +87,9 @@ def interpolate_motion_on_traces(
8087 interpolation_time_bin_centers_s : None or np.array
8188 Manually specify the time bins which the interpolation happens
8289 in for this segment. If None, these are the motion estimate's time bins.
90+ interpolation_time_bin_edges_s : None or np.array
91+ If present, interpolation chunks will be the time bins defined by these edges
92+ rather than interpolation_time_bin_centers_s or the motion's bins.
8393 spatial_interpolation_method : "idw" | "kriging", default: "kriging"
8494 The spatial interpolation method used to interpolate the channel locations:
8595 * idw : Inverse Distance Weighing
@@ -119,26 +129,33 @@ def interpolate_motion_on_traces(
119129 total_num_chans = channel_locations .shape [0 ]
120130
121131 # -- determine the blocks of frames that will land in the same interpolation time bin
122- time_bins = interpolation_time_bin_centers_s
123- if time_bins is None :
124- time_bins = motion .temporal_bins_s [segment_index ]
125- bin_s = time_bins [1 ] - time_bins [0 ]
126- bins_start = time_bins [0 ] - 0.5 * bin_s
127- # nearest bin center for each frame?
128- bin_inds = (times - bins_start ) // bin_s
129- bin_inds = bin_inds .astype (int )
132+ if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None :
133+ interpolation_time_bin_centers_s = motion .temporal_bins_s [segment_index ]
134+ interpolation_time_bin_edges_s = motion .temporal_bin_edges_s [segment_index ]
135+ else :
136+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s = ensure_time_bins (
137+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s
138+ )
139+
140+ # bin the frame times according to the interpolation time bins.
141+ # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
142+ # hence the -1. doing it with "left" is not as nice -- we want t==b[0]
143+ # to lead to i=1 (rounding down).
144+ interpolation_bin_inds = np .searchsorted (interpolation_time_bin_edges_s , times , side = "right" ) - 1
145+
130146 # the time bins may not cover the whole set of times in the recording,
131147 # so we need to clip these indices to the valid range
132- np .clip (bin_inds , 0 , time_bins .size , out = bin_inds )
148+ n_bins = interpolation_time_bin_edges_s .shape [0 ] - 1
149+ np .clip (interpolation_bin_inds , 0 , n_bins - 1 , out = interpolation_bin_inds )
133150
134151 # -- what are the possibilities here anyway?
135- bins_here = np .arange (bin_inds [0 ], bin_inds [- 1 ] + 1 )
152+ interpolation_bins_here = np .arange (interpolation_bin_inds [0 ], interpolation_bin_inds [- 1 ] + 1 )
136153
137154 # inperpolation kernel will be the same per temporal bin
138155 interp_times = np .empty (total_num_chans )
139156 current_start_index = 0
140- for bin_ind in bins_here :
141- bin_time = time_bins [ bin_ind ]
157+ for interp_bin_ind in interpolation_bins_here :
158+ bin_time = interpolation_time_bin_centers_s [ interp_bin_ind ]
142159 interp_times .fill (bin_time )
143160 channel_motions = motion .get_displacement_at_time_and_depth (
144161 interp_times ,
@@ -166,16 +183,17 @@ def interpolate_motion_on_traces(
166183 # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}")
167184 # plt.show()
168185
186+ # quick search logic to find frames corresponding to this interpolation bin in the recording
169187 # quickly find the end of this bin, which is also the start of the next
170188 next_start_index = current_start_index + np .searchsorted (
171- bin_inds [current_start_index :], bin_ind + 1 , side = "left"
189+ interpolation_bin_inds [current_start_index :], interp_bin_ind + 1 , side = "left"
172190 )
173- in_bin = slice (current_start_index , next_start_index )
191+ frames_in_bin = slice (current_start_index , next_start_index )
174192
175193 # here we use a simple np.matmul even if dirft_kernel can be super sparse.
176194 # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing
177195 # in ChunkRecordingExecutor)
178- np .matmul (traces [in_bin ], drift_kernel , out = traces_corrected [in_bin ])
196+ np .matmul (traces [frames_in_bin ], drift_kernel , out = traces_corrected [frames_in_bin ])
179197 current_start_index = next_start_index
180198
181199 return traces_corrected
@@ -297,6 +315,7 @@ def __init__(
297315 p = 1 ,
298316 num_closest = 3 ,
299317 interpolation_time_bin_centers_s = None ,
318+ interpolation_time_bin_edges_s = None ,
300319 interpolation_time_bin_size_s = None ,
301320 dtype = None ,
302321 ** spatial_interpolation_kwargs ,
@@ -363,9 +382,14 @@ def __init__(
363382
364383 # handle manual interpolation_time_bin_centers_s
365384 # the case where interpolation_time_bin_size_s is set is handled per-segment below
366- if interpolation_time_bin_centers_s is None :
385+ if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None :
367386 if interpolation_time_bin_size_s is None :
368387 interpolation_time_bin_centers_s = motion .temporal_bins_s
388+ interpolation_time_bin_edges_s = motion .temporal_bin_edges_s
389+ else :
390+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s = ensure_time_bins (
391+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s
392+ )
369393
370394 for segment_index , parent_segment in enumerate (recording ._recording_segments ):
371395 # finish the per-segment part of the time bin logic
@@ -375,8 +399,13 @@ def __init__(
375399 t_start , t_end = parent_segment .sample_index_to_time (np .array ([0 , s_end ]))
376400 halfbin = interpolation_time_bin_size_s / 2.0
377401 segment_interpolation_time_bins_s = np .arange (t_start + halfbin , t_end , interpolation_time_bin_size_s )
402+ segment_interpolation_time_bin_edges_s = np .arange (
403+ t_start , t_end + halfbin , interpolation_time_bin_size_s
404+ )
405+ assert segment_interpolation_time_bin_edges_s .shape == (segment_interpolation_time_bins_s .shape [0 ] + 1 ,)
378406 else :
379407 segment_interpolation_time_bins_s = interpolation_time_bin_centers_s [segment_index ]
408+ segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s [segment_index ]
380409
381410 rec_segment = InterpolateMotionRecordingSegment (
382411 parent_segment ,
@@ -387,6 +416,7 @@ def __init__(
387416 channel_inds ,
388417 segment_index ,
389418 segment_interpolation_time_bins_s ,
419+ segment_interpolation_time_bin_edges_s ,
390420 dtype = dtype_ ,
391421 )
392422 self .add_recording_segment (rec_segment )
@@ -420,6 +450,7 @@ def __init__(
420450 channel_inds ,
421451 segment_index ,
422452 interpolation_time_bin_centers_s ,
453+ interpolation_time_bin_edges_s ,
423454 dtype = "float32" ,
424455 ):
425456 BasePreprocessorSegment .__init__ (self , parent_recording_segment )
@@ -429,13 +460,11 @@ def __init__(
429460 self .channel_inds = channel_inds
430461 self .segment_index = segment_index
431462 self .interpolation_time_bin_centers_s = interpolation_time_bin_centers_s
463+ self .interpolation_time_bin_edges_s = interpolation_time_bin_edges_s
432464 self .dtype = dtype
433465 self .motion = motion
434466
435467 def get_traces (self , start_frame , end_frame , channel_indices ):
436- if self .time_vector is not None :
437- raise NotImplementedError ("InterpolateMotionRecording does not yet support recordings with time_vectors." )
438-
439468 if start_frame is None :
440469 start_frame = 0
441470 if end_frame is None :
@@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
453482 channel_inds = self .channel_inds ,
454483 spatial_interpolation_method = self .spatial_interpolation_method ,
455484 spatial_interpolation_kwargs = self .spatial_interpolation_kwargs ,
456- interpolation_time_bin_centers_s = self .interpolation_time_bin_centers_s ,
485+ interpolation_time_bin_edges_s = self .interpolation_time_bin_edges_s ,
457486 )
458487
459488 if channel_indices is not None :
0 commit comments