@@ -71,6 +71,8 @@ def get_compute_alignment_kwargs() -> dict:
7171 windows along the probe depth. See `get_spatial_windows`.
7272 """
7373 return {
74+ "num_shifts_global" : None ,
75+ "num_shifts_block" : 20 ,
7476 "interpolate" : False ,
7577 "interp_factor" : 10 ,
7678 "kriging_sigma" : 1 ,
@@ -93,8 +95,6 @@ def get_non_rigid_window_kwargs():
9395 """
9496 return {
9597 "rigid" : True ,
96- "num_shifts_global" : None ,
97- "num_shifts_block" : 20 ,
9898 "win_shape" : "gaussian" ,
9999 "win_step_um" : 50 ,
100100 "win_scale_um" : 50 ,
@@ -109,12 +109,13 @@ def get_interpolate_motion_kwargs():
109109 see that class for parameter descriptions.
110110 """
111111 return {
112- "border_mode" : "force_zeros" , # fixed as this until can figure out probe
112+ "border_mode" : "force_zeros" , # fixed as this until can figure out probe
113113 "spatial_interpolation_method" : "kriging" ,
114114 "sigma_um" : 20.0 ,
115- "p" : 2
115+ "p" : 2 ,
116116 }
117117
118+
118119###############################################################################
119120# Public Entry Level Functions
120121###############################################################################
@@ -222,7 +223,12 @@ def align_sessions(
222223
223224 # Ensure list lengths match and all channel locations are the same across recordings.
224225 _check_align_sessions_inputs (
225- recordings_list , peaks_list , peak_locations_list , alignment_order , estimate_histogram_kwargs , interpolate_motion_kwargs
226+ recordings_list ,
227+ peaks_list ,
228+ peak_locations_list ,
229+ alignment_order ,
230+ estimate_histogram_kwargs ,
231+ interpolate_motion_kwargs ,
226232 )
227233
228234 print ("Computing a single activity histogram from each session..." )
@@ -311,7 +317,10 @@ def align_sessions_after_motion_correction(
311317 )
312318
313319 motion_window_kwargs = copy .deepcopy (motion_kwargs_list [0 ])
314- if motion_window_kwargs ["direction" ] != "y" :
320+
321+ if (
322+ "direction" in motion_window_kwargs and motion_window_kwargs ["direction" ] != "y"
323+ ): # TODO: why is this not in all?
315324 raise ValueError ("motion correct must have been performed along the 'y' dimension." )
316325
317326 if align_sessions_kwargs is None :
@@ -322,24 +331,37 @@ def align_sessions_after_motion_correction(
322331 # shifts together.
323332 if (
324333 "non_rigid_window_kwargs" in align_sessions_kwargs
325- and "nonrigid" in align_sessions_kwargs ["non_rigid_window_kwargs" ]["rigid_mode " ]
334+ and not align_sessions_kwargs ["non_rigid_window_kwargs" ]["rigid " ]
326335 ):
327-
336+ # TODO: carefully walk through this function! and test all assumptions...
328337 if not motion_window_kwargs ["rigid" ]:
329- print (
338+ print ( # TODO: make a warning
330339 "Nonrigid inter-session alignment must use the motion correct "
331340 "nonrigid settings. Overwriting any passed `non_rigid_window_kwargs` "
332341 "with the motion object non_rigid_window_kwargs."
333342 )
334- motion_window_kwargs .pop ("method" )
335- motion_window_kwargs .pop ("direction" )
343+ non_rigid_window_kwargs = get_non_rigid_window_kwargs ()
344+
345+ # TODO: generate function for replacing one dict into another?
346+ for (
347+ k ,
348+ v ,
349+ ) in motion_window_kwargs .items (): # TODO: can get tighter alignment here with original implementation?
350+ if k in non_rigid_window_kwargs :
351+ non_rigid_window_kwargs [k ] = v
352+
336353 align_sessions_kwargs = copy .deepcopy (align_sessions_kwargs )
337- align_sessions_kwargs ["non_rigid_window_kwargs" ] = motion_window_kwargs
354+ align_sessions_kwargs ["non_rigid_window_kwargs" ] = non_rigid_window_kwargs
355+
356+ corrected_peak_locations = [
357+ correct_motion_on_peaks (info ["peaks" ], info ["peak_locations" ], info ["motion" ], recording )
358+ for info , recording in zip (motion_info_list , recordings_list )
359+ ]
338360
339361 return align_sessions (
340362 recordings_list ,
341363 [info ["peaks" ] for info in motion_info_list ],
342- [ info [ "peak_locations" ] for info in motion_info_list ] ,
364+ corrected_peak_locations ,
343365 ** align_sessions_kwargs ,
344366 )
345367
@@ -459,14 +481,14 @@ def _compute_session_histograms(
459481 recording ,
460482 peaks ,
461483 peak_locations ,
462- histogram_type ,
463- spatial_bin_edges ,
464- method ,
465- log_scale ,
466- chunked_bin_size_s ,
467- depth_smooth_um ,
468- weight_with_amplitude ,
469- avg_in_bin ,
484+ histogram_type = histogram_type ,
485+ spatial_bin_edges = spatial_bin_edges ,
486+ method = method ,
487+ log_scale = log_scale ,
488+ chunked_bin_size_s = chunked_bin_size_s ,
489+ depth_smooth_um = depth_smooth_um ,
490+ weight_with_amplitude = weight_with_amplitude ,
491+ avg_in_bin = avg_in_bin ,
470492 )
471493 temporal_bin_centers_list .append (temporal_bin_centers )
472494 session_histogram_list .append (session_hist )
@@ -539,32 +561,31 @@ def _get_single_session_activity_histogram(
539561 # full estimation for chunked bin size
540562 if chunked_bin_size_s == "estimate" :
541563
542- one_bin_histogram , _ , _ = alignment_utils .get_activity_histogram (
564+ scaled_hist , _ , _ = alignment_utils .get_2d_activity_histogram (
543565 recording ,
544566 peaks ,
545567 peak_locations ,
546568 spatial_bin_edges ,
547569 log_scale = False ,
548570 bin_s = None ,
549571 depth_smooth_um = None ,
550- scale_to_hz = False ,
572+ scale_to_hz = True ,
551573 weight_with_amplitude = False ,
552574 avg_in_bin = False ,
553575 )
554576
555577 # It is important that the passed histogram is scaled to firing rate in Hz
556- scaled_hist = one_bin_histogram / recording .get_duration () # TODO: why is this done here when have a scale_to_hz arg??!?
557578 chunked_bin_size_s = alignment_utils .estimate_chunk_size (scaled_hist )
558579 chunked_bin_size_s = np .min ([chunked_bin_size_s , recording .get_duration ()])
559580
560581 if histogram_type == "activity_1d" :
561582
562- chunked_histograms , chunked_temporal_bin_centers , _ = alignment_utils .get_activity_histogram (
583+ chunked_histograms , chunked_temporal_bin_centers , _ = alignment_utils .get_2d_activity_histogram (
563584 recording ,
564585 peaks ,
565586 peak_locations ,
566587 spatial_bin_edges ,
567- log_scale ,
588+ log_scale = log_scale ,
568589 bin_s = chunked_bin_size_s ,
569590 depth_smooth_um = depth_smooth_um ,
570591 weight_with_amplitude = weight_with_amplitude ,
@@ -656,9 +677,11 @@ def _create_motion_recordings(
656677 motion ,
657678 interpolation_time_bin_centers_s = motion .temporal_bins_s ,
658679 interpolation_time_bin_edges_s = [np .array (recording .get_times ()[0 ], recording .get_times ()[- 1 ])],
659- ** interpolate_motion_kwargs
680+ ** interpolate_motion_kwargs ,
660681 )
661- corrected_recording = corrected_recording .set_probe (recording .get_probe ()) # TODO: if this works, might need to do above
682+ corrected_recording = corrected_recording .set_probe (
683+ recording .get_probe ()
684+ ) # TODO: if this works, might need to do above
662685
663686 corrected_recordings_list .append (corrected_recording )
664687
@@ -840,8 +863,8 @@ def _compute_session_alignment(
840863 session_histogram_array = np .array (session_histogram_list )
841864
842865 akima_interp_nonrigid = compute_alignment_kwargs .pop ("akima_interp_nonrigid" )
843- num_shifts_global = non_rigid_window_kwargs .pop ("num_shifts_global" )
844- num_shifts_block = non_rigid_window_kwargs .pop ("num_shifts_block" )
866+ num_shifts_global = compute_alignment_kwargs .pop ("num_shifts_global" )
867+ num_shifts_block = compute_alignment_kwargs .pop ("num_shifts_block" )
845868
846869 non_rigid_windows , non_rigid_window_centers = get_spatial_windows (
847870 contact_depths , spatial_bin_centers , ** non_rigid_window_kwargs
@@ -870,7 +893,7 @@ def _compute_session_alignment(
870893
871894 # Then compute the nonrigid shifts
872895 nonrigid_session_offsets_matrix = alignment_utils .compute_histogram_crosscorrelation (
873- shifted_histograms , non_rigid_windows , num_shifts_block , ** compute_alignment_kwargs
896+ shifted_histograms , non_rigid_windows , num_shifts = num_shifts_block , ** compute_alignment_kwargs
874897 )
875898 non_rigid_shifts = alignment_utils .get_shifts_from_session_matrix (alignment_order , nonrigid_session_offsets_matrix )
876899
@@ -920,7 +943,7 @@ def _estimate_rigid_alignment(
920943 rigid_session_offsets_matrix = alignment_utils .compute_histogram_crosscorrelation (
921944 session_histogram_array ,
922945 rigid_window ,
923- num_shifts ,
946+ num_shifts = num_shifts ,
924947 ** compute_alignment_kwargs , # TODO: remove the copy above and pass directly. Consider removing this function...
925948 )
926949 optimal_shift_indices = alignment_utils .get_shifts_from_session_matrix (
@@ -968,7 +991,6 @@ def _check_align_sessions_inputs(
968991 "performed using the same probe."
969992 )
970993
971-
972994 accepted_hist_methods = [
973995 "entire_session" ,
974996 "chunked_mean" ,
@@ -998,4 +1020,6 @@ def _check_align_sessions_inputs(
9981020 if ses_num == 0 :
9991021 raise ValueError ("`alignment_order` required the session number, not session index." )
10001022
1001- assert interpolate_motion_kwargs ["border_mode" ] == "force_zeros" , "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
1023+ assert (
1024+ interpolate_motion_kwargs ["border_mode" ] == "force_zeros"
1025+ ), "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
0 commit comments