1515from spikeinterface .preprocessing .inter_session_alignment import alignment_utils
1616from spikeinterface .preprocessing .motion import run_peak_detection_pipeline_node
1717import copy
18- import scipy
19- import matplotlib .pyplot as plt
20- from scipy .ndimage import gaussian_filter
21- import matplotlib .pyplot as plt
2218
2319
2420def get_estimate_histogram_kwargs () -> dict :
@@ -54,7 +50,8 @@ def get_estimate_histogram_kwargs() -> dict:
5450 "log_scale" : False ,
5551 "depth_smooth_um" : None ,
5652 "histogram_type" : "activity_1d" ,
57- "weight_with_amplitude" : True ,
53+ "weight_with_amplitude" : False ,
54+ "avg_in_bin" : False , # TODO
5855 }
5956
6057
@@ -111,8 +108,12 @@ def get_interpolate_motion_kwargs():
111108 Settings to pass to `InterpolateMotionRecording`,
112109 see that class for parameter descriptions.
113110 """
114- return {"border_mode" : "remove_channels" , "spatial_interpolation_method" : "kriging" , "sigma_um" : 20.0 , "p" : 2 }
115-
111+ return {
112+ "border_mode" : "force_zeros" , # fixed as this until can figure out probe
113+ "spatial_interpolation_method" : "kriging" ,
114+ "sigma_um" : 20.0 ,
115+ "p" : 2
116+ }
116117
117118###############################################################################
118119# Public Entry Level Functions
@@ -221,7 +222,7 @@ def align_sessions(
221222
222223 # Ensure list lengths match and all channel locations are the same across recordings.
223224 _check_align_sessions_inputs (
224- recordings_list , peaks_list , peak_locations_list , alignment_order , estimate_histogram_kwargs
225+ recordings_list , peaks_list , peak_locations_list , alignment_order , estimate_histogram_kwargs , interpolate_motion_kwargs
225226 )
226227
227228 print ("Computing a single activity histogram from each session..." )
@@ -400,6 +401,7 @@ def _compute_session_histograms(
400401 depth_smooth_um : float ,
401402 log_scale : bool ,
402403 weight_with_amplitude : bool ,
404+ avg_in_bin : bool ,
403405) -> tuple [list [np .ndarray ], list [np .ndarray ], np .ndarray , np .ndarray , list [dict ]]:
404406 """
405407 Compute a 1d activity histogram for the session. As
@@ -464,6 +466,7 @@ def _compute_session_histograms(
464466 chunked_bin_size_s ,
465467 depth_smooth_um ,
466468 weight_with_amplitude ,
469+ avg_in_bin ,
467470 )
468471 temporal_bin_centers_list .append (temporal_bin_centers )
469472 session_histogram_list .append (session_hist )
@@ -489,6 +492,7 @@ def _get_single_session_activity_histogram(
489492 chunked_bin_size_s : float | "estimate" ,
490493 depth_smooth_um : float ,
491494 weight_with_amplitude : bool ,
495+ avg_in_bin : bool ,
492496) -> tuple [np .ndarray , np .ndarray , dict ]:
493497 """
494498 Compute an activity histogram for a single session.
@@ -544,11 +548,12 @@ def _get_single_session_activity_histogram(
544548 bin_s = None ,
545549 depth_smooth_um = None ,
546550 scale_to_hz = False ,
547- weight_with_amplitude = weight_with_amplitude ,
551+ weight_with_amplitude = False ,
552+ avg_in_bin = False ,
548553 )
549554
550555 # It is important that the passed histogram is scaled to firing rate in Hz
551- scaled_hist = one_bin_histogram / recording .get_duration ()
556+ scaled_hist = one_bin_histogram / recording .get_duration () # TODO: why is this done here when have a scale_to_hz arg??!?
552557 chunked_bin_size_s = alignment_utils .estimate_chunk_size (scaled_hist )
553558 chunked_bin_size_s = np .min ([chunked_bin_size_s , recording .get_duration ()])
554559
@@ -563,6 +568,7 @@ def _get_single_session_activity_histogram(
563568 bin_s = chunked_bin_size_s ,
564569 depth_smooth_um = depth_smooth_um ,
565570 weight_with_amplitude = weight_with_amplitude ,
571+ avg_in_bin = avg_in_bin ,
566572 scale_to_hz = True ,
567573 )
568574
@@ -645,7 +651,14 @@ def _create_motion_recordings(
645651
646652 corrected_recording = _add_displacement_to_interpolate_recording (recording , motion )
647653 else :
648- corrected_recording = InterpolateMotionRecording (recording , motion , ** interpolate_motion_kwargs )
654+ corrected_recording = InterpolateMotionRecording (
655+ recording ,
656+ motion ,
657+ interpolation_time_bin_centers_s = motion .temporal_bins_s ,
658+ interpolation_time_bin_edges_s = [np .array (recording .get_times ()[0 ], recording .get_times ()[- 1 ])],
659+ ** interpolate_motion_kwargs
660+ )
661+ corrected_recording = corrected_recording .set_probe (recording .get_probe ()) # TODO: if this works, might need to do above
649662
650663 corrected_recordings_list .append (corrected_recording )
651664
@@ -780,6 +793,7 @@ def _correct_session_displacement(
780793 estimate_histogram_kwargs ["chunked_bin_size_s" ],
781794 estimate_histogram_kwargs ["depth_smooth_um" ],
782795 estimate_histogram_kwargs ["weight_with_amplitude" ],
796+ estimate_histogram_kwargs ["avg_in_bin" ],
783797 )
784798 corrected_session_histogram_list .append (session_hist )
785799
@@ -927,6 +941,7 @@ def _check_align_sessions_inputs(
927941 peak_locations_list : list [np .ndarray ],
928942 alignment_order : str ,
929943 estimate_histogram_kwargs : dict ,
944+ interpolate_motion_kwargs : dict ,
930945):
931946 """
932947 Perform checks on the input of `align_sessions()`
@@ -946,13 +961,14 @@ def _check_align_sessions_inputs(
946961 )
947962
948963 channel_locs = [rec .get_channel_locations () for rec in recordings_list ]
949- if not all (np .array_equal (locs , channel_locs [0 ]) for locs in channel_locs ):
964+ if not all ([ np .array_equal (locs , channel_locs [0 ]) for locs in channel_locs ] ):
950965 raise ValueError (
951966 "The recordings in `recordings_list` do not all have "
952967 "the same channel locations. All recordings must be "
953968 "performed using the same probe."
954969 )
955970
971+
956972 accepted_hist_methods = [
957973 "entire_session" ,
958974 "chunked_mean" ,
@@ -981,3 +997,5 @@ def _check_align_sessions_inputs(
981997
982998 if ses_num == 0 :
983999 raise ValueError ("`alignment_order` required the session number, not session index." )
1000+
1001+ assert interpolate_motion_kwargs ["border_mode" ] == "force_zeros" , "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
0 commit comments