@@ -52,7 +52,63 @@ def generate_session_displacement_recordings(
5252 extra_outputs = False ,
5353 seed = None ,
5454):
55- """ """
55+ """
56+ Generate a set of recordings simulating probe drift across recording
57+ sessions.
58+
59+ Rigid drift can be added in the (x, y) direction in `recording_shifts`.
60+ These drifts can be made non-rigid (scaled dependent on the unit location)
61+ with the `non_rigid_gradient` parameter. Amplitude of units can be scaled
62+ (e.g. template signal removed by scaling with zero) by specifying scaling
63+ factors in `recording_amplitude_scalings`.
64+
65+ Parameters
66+ ----------
67+
68+ num_units : int
69+ The number of units in the generated recordings.
70+ recording_durations : list
71+ An array of length (num_recordings,) specifying the
72+ duration that each created recording should be.
73+ recording_shifts : list
74+ An array of length (num_recordings,) in which each element
75+ is a 2-element array specifying the (x, y) shift for the recording.
76+ Typically, the first recording will have shift (0, 0) so all further
77+ recordings are shifted relative to it. e.g. to create two recordings,
78+ the second shifted by 50 um in the x-direction and 250 um in the y
79+ direction : ((0, 0), (50, 250)).
80+ non_rigid_gradient : float
81+ Factor which sets the level of non-rigidty in the displacement.
82+ See `calculate_displacement_unit_factor` for details.
83+ recording_amplitude_scalings : dict
84+ A dict with keys:
85+ "method" - order by which to apply the scalings.
86+ "by_passed_order" - scalings are applied to the unit templates
87+ in order passed
88+ "by_firing_rate" - scalings are applied to the units in order of
89+ maximum to minimum firing rate
90+ "by_amplitude_and_firing_rate" - scalings are applied to the units
91+ in order of amplitude * firing_rate (maximum to minimum)
92+ "scalings" - a list of numpy arrays, one for each recording, with
93+ each entry an array of length num_units holding the unit scalings.
94+ e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
95+
96+ All other parameters are used as in from `generate_drifting_recording()`.
97+
98+ Returns
99+ -------
100+ output_recordings : list
101+ A list of recordings with units shifted (i.e. replicated probe shift).
102+ output_sortings : list
103+ A list of corresponding sorting objects.
104+ extra_outputs_dict (options) : dict
105+ When `extra_outputs` is `True`, a dict containing variables used
106+ in the generation process.
107+ "unit_locations" : A list (length num records) of shifted unit locations
108+ "templates_array_moved" : list[np.array]
109+ A list (length num records) of (num_units, num_samples, num_channels)
110+ arrays of templates that have been shifted.
111+ """
56112 _check_generate_session_displacement_arguments (
57113 num_units , recording_durations , recording_shifts , recording_amplitude_scalings
58114 )
@@ -82,7 +138,7 @@ def generate_session_displacement_recordings(
82138
83139 for rec_idx , (shift , duration ) in enumerate (zip (recording_shifts , recording_durations )):
84140
85- displacement_vector , displacement_unit_factor = get_inter_session_displacements (
141+ displacement_vector , displacement_unit_factor = _get_inter_session_displacements (
86142 shift ,
87143 non_rigid_gradient ,
88144 num_units ,
@@ -114,7 +170,7 @@ def generate_session_displacement_recordings(
114170 )
115171
116172 # Generate the (possibly shifted, scaled) unit templates
117- templates_moved_array = generate_templates (
173+ template_array_moved = generate_templates (
118174 channel_locations ,
119175 unit_locations_moved ,
120176 sampling_frequency = sampling_frequency ,
@@ -124,8 +180,8 @@ def generate_session_displacement_recordings(
124180
125181 if recording_amplitude_scalings is not None :
126182
127- templates_moved_array = amplitude_scale_templates_in_place (
128- templates_moved_array , recording_amplitude_scalings , sorting_extra_outputs , rec_idx
183+ template_array_moved = _amplitude_scale_templates_in_place (
184+ template_array_moved , recording_amplitude_scalings , sorting_extra_outputs , rec_idx
129185 )
130186
131187 # Bring it all together in a `InjectTemplatesRecording` and
@@ -135,7 +191,7 @@ def generate_session_displacement_recordings(
135191
136192 recording = InjectTemplatesRecording (
137193 sorting = sorting ,
138- templates = templates_moved_array ,
194+ templates = template_array_moved ,
139195 nbefore = nbefore ,
140196 amplitude_factor = None ,
141197 parent_recording = noise ,
@@ -152,19 +208,46 @@ def generate_session_displacement_recordings(
152208 output_recordings .append (recording )
153209 output_sortings .append (sorting )
154210 extra_outputs_dict ["unit_locations" ].append (unit_locations_moved )
155- extra_outputs_dict ["template_array_moved" ].append (templates_moved_array )
211+ extra_outputs_dict ["template_array_moved" ].append (template_array_moved )
156212
157213 if extra_outputs :
158214 return output_recordings , output_sortings , extra_outputs_dict
159215 else :
160216 return output_recordings , output_sortings
161217
162218
163- def get_inter_session_displacements (shift , non_rigid_gradient , num_units , unit_locations ):
164- """ """
219+ def _get_inter_session_displacements (shift , non_rigid_gradient , num_units , unit_locations ):
220+ """
221+ Get the formatted `displacement_vector` and `displacement_unit_factor`
222+ used to shift the `unit_locations`..
223+
224+ Parameters
225+ ---------
226+ shift : np.array | list | tuple
227+ A 2-element array with the shift in the (x, y) direction.
228+ non_rigid_gradient : float
229+ Factor which sets the level of non-rigidty in the displacement.
230+ See `calculate_displacement_unit_factor` for details.
231+ num_units : int
232+ Number of units
233+ unit_locations : np.array
234+ (num_units, 3) array of unit locations (x, y, z).
235+
236+ Returns
237+ -------
238+ displacement_vector : np.array
239+ A (:, 2) array of (x, y) of displacements
240+ to add to (i.e. move) unit_locations.
241+ e.g. array([[1, 2]])
242+ displacement_unit_factor : np.array
243+ A (num_units, :) array of scaling values to apply to the
244+ displacement vector in order to add nonrigid shift to
245+ the displacement. Note the same scaling is applied to the
246+ x and y dimension.
247+ """
165248 displacement_vector = np .atleast_2d (shift )
166249
167- if non_rigid_gradient is None or shift == ( 0 , 0 ):
250+ if non_rigid_gradient is None or ( shift [ 0 ] == 0 and shift [ 1 ] == 0 ):
168251 displacement_unit_factor = np .ones ((num_units , 1 ))
169252 else :
170253 displacement_unit_factor = calculate_displacement_unit_factor (
@@ -178,8 +261,38 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l
178261 return displacement_vector , displacement_unit_factor
179262
180263
181- def amplitude_scale_templates_in_place (templates_array , recording_amplitude_scalings , sorting_extra_outputs , rec_idx ):
182- """ """
264+ def _amplitude_scale_templates_in_place (templates_array , recording_amplitude_scalings , sorting_extra_outputs , rec_idx ):
265+ """
266+ Scale a set of templates given a set of scaling values. The scaling
267+ values can be applied in the order passed, or instead in order of
268+ the unit firing range (max to min) or unit amplitude * firing rate (max to min).
269+ This will chang the `templates_array` in place.
270+
271+ Parameters
272+ ----------
273+ templates_array : np.array
274+ A (num_units, num_samples, num_channels) array of
275+ template waveforms for all units.
276+ recording_amplitude_scalings : dict
277+ see `generate_session_displacement_recordings()`.
278+ sorting_extra_outputs : dict
279+ Extra output of `generate_sorting` holding the firing frequency of all units.
280+ The unit order is assumed to match the templates.
281+ rec_idx : int
282+ The index of the recording for which the templates are being scaled.
283+
284+ Notes
285+ -----
286+ This method is used in the context of inter-session displacement. Often,
287+ units may drop out of the recording across sessions. This simulates this by
288+ directly scaling the template (e.g. if scaling by 0, the template is completely
289+ dropped out). The provided scalings can be applied in the order passed, or
290+ in the order of unit firing rate or firing rate * amplitude. The idea is
291+ that it may be desired to remove to downscale the most activate neurons
292+ that contribute most significantly to activity histograms. Similarly,
293+ if amplitude is included in activity histograms the amplitude may
294+ also want to be considered when ordering the units to downscale.
295+ """
183296 method = recording_amplitude_scalings ["method" ]
184297
185298 if method in ["by_amplitude_and_firing_rate" , "by_firing_rate" ]:
0 commit comments