Skip to content

Commit 1bf0487

Browse files
committed
Start adding tests.
1 parent 0b829d9 commit 1bf0487

File tree

7 files changed

+402
-18
lines changed

7 files changed

+402
-18
lines changed

debugging/playing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
# Load / generate some recordings
3131
# --------------------------------------------------------------------------------------
3232

33+
# try num units 5 and 65
34+
3335
recordings_list, _ = generate_session_displacement_recordings(
34-
num_units=65,
36+
num_units=5,
3537
recording_durations=[200, 200, 200],
3638
recording_shifts=((0, 0), (0, -200), (0, 150)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
3739
non_rigid_gradient=None, # 0.1, # 0.1,

playing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from spikeinterface.generation import generate_drifting_recording
2+
from spikeinterface.preprocessing.motion import correct_motion
3+
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording
4+
5+
rec = generate_drifting_recording(duration=100)[0]
6+
7+
proc_rec = correct_motion(rec)
8+
9+
rec.set_probe(rec.get_probe())
10+

src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def get_activity_histogram(
2626
depth_smooth_um: float | None,
2727
scale_to_hz: bool = False,
2828
weight_with_amplitude: bool = False,
29+
avg_in_bin: bool = True,
2930
):
3031
"""
3132
Generate a 2D activity histogram for the session. Wraps the underlying
@@ -69,6 +70,7 @@ def get_activity_histogram(
6970
hist_margin_um=None,
7071
spatial_bin_edges=spatial_bin_edges,
7172
depth_smooth_um=depth_smooth_um,
73+
avg_in_bin=avg_in_bin,
7274
)
7375
assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing"
7476

@@ -88,7 +90,6 @@ def get_activity_histogram(
8890

8991
return activity_histogram, temporal_bin_centers, spatial_bin_centers
9092

91-
9293
def get_bin_centers(bin_edges):
9394
return (bin_edges[1:] + bin_edges[:-1]) / 2
9495

@@ -310,6 +311,14 @@ def compute_histogram_crosscorrelation(
310311
windowed_histogram_j - np.mean(windowed_histogram_i),
311312
mode="full",
312313
)
314+
import os
315+
if "hello_world" in os.environ:
316+
plt.plot(windowed_histogram_i)
317+
plt.plot(windowed_histogram_j)
318+
plt.show()
319+
320+
plt.plot(xcorr)
321+
plt.show()
313322

314323
if num_shifts:
315324
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts)

src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
from spikeinterface.preprocessing.inter_session_alignment import alignment_utils
1616
from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node
1717
import copy
18-
import scipy
19-
import matplotlib.pyplot as plt
20-
from scipy.ndimage import gaussian_filter
21-
import matplotlib.pyplot as plt
2218

2319

2420
def get_estimate_histogram_kwargs() -> dict:
@@ -54,7 +50,8 @@ def get_estimate_histogram_kwargs() -> dict:
5450
"log_scale": False,
5551
"depth_smooth_um": None,
5652
"histogram_type": "activity_1d",
57-
"weight_with_amplitude": True,
53+
"weight_with_amplitude": False,
54+
"avg_in_bin": False, # TODO
5855
}
5956

6057

@@ -111,8 +108,12 @@ def get_interpolate_motion_kwargs():
111108
Settings to pass to `InterpolateMotionRecording`,
112109
see that class for parameter descriptions.
113110
"""
114-
return {"border_mode": "remove_channels", "spatial_interpolation_method": "kriging", "sigma_um": 20.0, "p": 2}
115-
111+
return {
112+
"border_mode": "force_zeros", # fixed as this until can figure out probe
113+
"spatial_interpolation_method": "kriging",
114+
"sigma_um": 20.0,
115+
"p": 2
116+
}
116117

117118
###############################################################################
118119
# Public Entry Level Functions
@@ -221,7 +222,7 @@ def align_sessions(
221222

222223
# Ensure list lengths match and all channel locations are the same across recordings.
223224
_check_align_sessions_inputs(
224-
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs
225+
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs, interpolate_motion_kwargs
225226
)
226227

227228
print("Computing a single activity histogram from each session...")
@@ -400,6 +401,7 @@ def _compute_session_histograms(
400401
depth_smooth_um: float,
401402
log_scale: bool,
402403
weight_with_amplitude: bool,
404+
avg_in_bin: bool,
403405
) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]:
404406
"""
405407
Compute a 1d activity histogram for the session. As
@@ -464,6 +466,7 @@ def _compute_session_histograms(
464466
chunked_bin_size_s,
465467
depth_smooth_um,
466468
weight_with_amplitude,
469+
avg_in_bin,
467470
)
468471
temporal_bin_centers_list.append(temporal_bin_centers)
469472
session_histogram_list.append(session_hist)
@@ -489,6 +492,7 @@ def _get_single_session_activity_histogram(
489492
chunked_bin_size_s: float | "estimate",
490493
depth_smooth_um: float,
491494
weight_with_amplitude: bool,
495+
avg_in_bin: bool,
492496
) -> tuple[np.ndarray, np.ndarray, dict]:
493497
"""
494498
Compute an activity histogram for a single session.
@@ -544,11 +548,12 @@ def _get_single_session_activity_histogram(
544548
bin_s=None,
545549
depth_smooth_um=None,
546550
scale_to_hz=False,
547-
weight_with_amplitude=weight_with_amplitude,
551+
weight_with_amplitude=False,
552+
avg_in_bin=False,
548553
)
549554

550555
# It is important that the passed histogram is scaled to firing rate in Hz
551-
scaled_hist = one_bin_histogram / recording.get_duration()
556+
scaled_hist = one_bin_histogram / recording.get_duration() # TODO: why is this done here when have a scale_to_hz arg??!?
552557
chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist)
553558
chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()])
554559

@@ -563,6 +568,7 @@ def _get_single_session_activity_histogram(
563568
bin_s=chunked_bin_size_s,
564569
depth_smooth_um=depth_smooth_um,
565570
weight_with_amplitude=weight_with_amplitude,
571+
avg_in_bin=avg_in_bin,
566572
scale_to_hz=True,
567573
)
568574

@@ -645,7 +651,14 @@ def _create_motion_recordings(
645651

646652
corrected_recording = _add_displacement_to_interpolate_recording(recording, motion)
647653
else:
648-
corrected_recording = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
654+
corrected_recording = InterpolateMotionRecording(
655+
recording,
656+
motion,
657+
interpolation_time_bin_centers_s=motion.temporal_bins_s,
658+
interpolation_time_bin_edges_s=[np.array(recording.get_times()[0], recording.get_times()[-1])],
659+
**interpolate_motion_kwargs
660+
)
661+
corrected_recording = corrected_recording.set_probe(recording.get_probe()) # TODO: if this works, might need to do above
649662

650663
corrected_recordings_list.append(corrected_recording)
651664

@@ -780,6 +793,7 @@ def _correct_session_displacement(
780793
estimate_histogram_kwargs["chunked_bin_size_s"],
781794
estimate_histogram_kwargs["depth_smooth_um"],
782795
estimate_histogram_kwargs["weight_with_amplitude"],
796+
estimate_histogram_kwargs["avg_in_bin"],
783797
)
784798
corrected_session_histogram_list.append(session_hist)
785799

@@ -927,6 +941,7 @@ def _check_align_sessions_inputs(
927941
peak_locations_list: list[np.ndarray],
928942
alignment_order: str,
929943
estimate_histogram_kwargs: dict,
944+
interpolate_motion_kwargs: dict,
930945
):
931946
"""
932947
Perform checks on the input of `align_sessions()`
@@ -946,13 +961,14 @@ def _check_align_sessions_inputs(
946961
)
947962

948963
channel_locs = [rec.get_channel_locations() for rec in recordings_list]
949-
if not all(np.array_equal(locs, channel_locs[0]) for locs in channel_locs):
964+
if not all([np.array_equal(locs, channel_locs[0]) for locs in channel_locs]):
950965
raise ValueError(
951966
"The recordings in `recordings_list` do not all have "
952967
"the same channel locations. All recordings must be "
953968
"performed using the same probe."
954969
)
955970

971+
956972
accepted_hist_methods = [
957973
"entire_session",
958974
"chunked_mean",
@@ -981,3 +997,5 @@ def _check_align_sessions_inputs(
981997

982998
if ses_num == 0:
983999
raise ValueError("`alignment_order` required the session number, not session index.")
1000+
1001+
assert interpolate_motion_kwargs["border_mode"] == "force_zeros", "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam

0 commit comments

Comments
 (0)