@@ -103,11 +103,11 @@ def generate_sorting(
103103 Parameters
104104 ----------
105105 num_units : int, default: 5
106- Number of units
106+ Number of units.
107107 sampling_frequency : float, default: 30000.0
108- The sampling frequency
108+ The sampling frequency.
109109 durations : list, default: [10.325, 3.5]
110- Duration of each segment in s
110+ Duration of each segment in s.
111111 firing_rates : float, default: 3.0
112112 The firing rate of each unit (in Hz).
113113 empty_units : list, default: None
@@ -123,12 +123,12 @@ def generate_sorting(
123123 border_size_samples : int, default: 20
124124 The size of the border in samples to add border spikes.
125125 seed : int, default: None
126- The random seed
126+ The random seed.
127127
128128 Returns
129129 -------
130130 sorting : NumpySorting
131- The sorting object
131+ The sorting object.
132132 """
133133 seed = _ensure_seed (seed )
134134 rng = np .random .default_rng (seed )
@@ -187,19 +187,19 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
187187 Parameters
188188 ----------
189189 sorting : BaseSorting
190- The sorting object
190+ The sorting object.
191191 sync_event_ratio : float
192192 The ratio of added synchronous spikes with respect to the total number of spikes.
193193 E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra
194194 spikes are synchronous (same sample_index), but on different units (not duplicates).
195195 seed : int, default: None
196- The random seed
196+ The random seed.
197197
198198
199199 Returns
200200 -------
201201 sorting : TransformSorting
202- The sorting object, keeping track of added spikes
202+ The sorting object, keeping track of added spikes.
203203
204204 """
205205 rng = np .random .default_rng (seed )
@@ -249,18 +249,18 @@ def generate_sorting_to_inject(
249249 Parameters
250250 ----------
251251 sorting : BaseSorting
252- The sorting object
252+ The sorting object.
253253 num_samples: list of size num_segments.
254254 The number of samples in all the segments of the sorting, to generate spike times
255- covering entire the entire duration of the segments
255+ covering entire the entire duration of the segments.
256256 max_injected_per_unit: int, default 1000
257- The maximal number of spikes injected per units
257+ The maximal number of spikes injected per units.
258258 injected_rate: float, default 0.05
259- The rate at which spikes are injected
259+ The rate at which spikes are injected.
260260 refractory_period_ms: float, default 1.5
261- The refractory period that should not be violated while injecting new spikes
261+ The refractory period that should not be violated while injecting new spikes.
262262 seed: int, default None
263- The random seed
263+ The random seed.
264264
265265 Returns
266266 -------
@@ -312,22 +312,22 @@ class TransformSorting(BaseSorting):
312312 Parameters
313313 ----------
314314 sorting : BaseSorting
315- The sorting object
315+ The sorting object.
316316 added_spikes_existing_units : np.array (spike_vector)
317- The spikes that should be added to the sorting object, for existing units
317+ The spikes that should be added to the sorting object, for existing units.
318318 added_spikes_new_units: np.array (spike_vector)
319- The spikes that should be added to the sorting object, for new units
319+ The spikes that should be added to the sorting object, for new units.
320320 new_units_ids: list
321- The unit_ids that should be added if spikes for new units are added
321+ The unit_ids that should be added if spikes for new units are added.
322322 refractory_period_ms : float, default None
323323 The refractory period violation to prevent duplicates and/or unphysiological addition
324324 of spikes. Any spike times in added_spikes violating the refractory period will be
325- discarded
325+ discarded.
326326
327327 Returns
328328 -------
329329 sorting : TransformSorting
330- The sorting object with the added spikes and/or units
330+ The sorting object with the added spikes and/or units.
331331 """
332332
333333 def __init__ (
@@ -428,12 +428,14 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe
428428
429429 Parameters
430430 ----------
431- sorting1: the first sorting
432- sorting2: the second sorting
431+ sorting1: BaseSorting
432+ The first sorting.
433+ sorting2: BaseSorting
434+ The second sorting.
433435 refractory_period_ms : float, default None
434436 The refractory period violation to prevent duplicates and/or unphysiological addition
435437 of spikes. Any spike times in added_spikes violating the refractory period will be
436- discarded
438+ discarded.
437439 """
438440 assert (
439441 sorting1 .get_sampling_frequency () == sorting2 .get_sampling_frequency ()
@@ -492,12 +494,14 @@ def add_from_unit_dict(
492494 Parameters
493495 ----------
494496
495- sorting1: the first sorting
497+ sorting1: BaseSorting
498+ The first sorting
496499 dict_list: list of dict
500+ A list of dict with unit_ids as keys and spike times as values.
497501 refractory_period_ms : float, default None
498502 The refractory period violation to prevent duplicates and/or unphysiological addition
499503 of spikes. Any spike times in added_spikes violating the refractory period will be
500- discarded
504+ discarded.
501505 """
502506 sorting2 = NumpySorting .from_unit_dict (units_dict_list , sorting1 .get_sampling_frequency ())
503507 sorting = TransformSorting .add_from_sorting (sorting1 , sorting2 , refractory_period_ms )
@@ -515,18 +519,19 @@ def from_times_labels(
515519
516520 Parameters
517521 ----------
518- sorting1: the first sorting
522+ sorting1: BaseSorting
523+ The first sorting
519524 times_list: list of array (or array)
520- An array of spike times (in frames)
525+ An array of spike times (in frames).
521526 labels_list: list of array (or array)
522- An array of spike labels corresponding to the given times
527+ An array of spike labels corresponding to the given times.
523528 unit_ids: list or None, default: None
524529 The explicit list of unit_ids that should be extracted from labels_list
525- If None, then it will be np.unique(labels_list)
530+ If None, then it will be np.unique(labels_list).
526531 refractory_period_ms : float, default None
527532 The refractory period violation to prevent duplicates and/or unphysiological addition
528533 of spikes. Any spike times in added_spikes violating the refractory period will be
529- discarded
534+ discarded.
530535 """
531536
532537 sorting2 = NumpySorting .from_times_labels (times_list , labels_list , sampling_frequency , unit_ids )
@@ -556,6 +561,16 @@ def clean_refractory_period(self):
556561
557562
558563def create_sorting_npz (num_seg , file_path ):
564+ """
565+ Create a NPZ sorting file.
566+
567+ Parameters
568+ ----------
569+ num_seg : int
570+ The number of segments.
571+ file_path : str | Path
572+ The file path to save the NPZ file.
573+ """
559574 # create a NPZ sorting file
560575 d = {}
561576 d ["unit_ids" ] = np .array ([0 , 1 , 2 ], dtype = "int64" )
@@ -674,18 +689,18 @@ def synthesize_poisson_spike_vector(
674689 Parameters
675690 ----------
676691 num_units : int, default: 20
677- Number of neuronal units to simulate
692+ Number of neuronal units to simulate.
678693 sampling_frequency : float, default: 30000.0
679- Sampling frequency in Hz
694+ Sampling frequency in Hz.
680695 duration : float, default: 60.0
681- Duration of the simulation in seconds
696+ Duration of the simulation in seconds.
682697 refractory_period_ms : float, default: 4.0
683- Refractory period between spikes in milliseconds
698+ Refractory period between spikes in milliseconds.
684699 firing_rates : float or array_like or tuple, default: 3.0
685700 Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with
686- each element being the firing rate for one unit
701+ each element being the firing rate for one unit.
687702 seed : int, default: 0
688- Seed for random number generator
703+ Seed for random number generator.
689704
690705 Returns
691706 -------
@@ -779,27 +794,27 @@ def synthesize_random_firings(
779794 Parameters
780795 ----------
781796 num_units : int
782- number of units
797+ Number of units.
783798 sampling_frequency : float
784- sampling rate
799+ Sampling rate.
785800 duration : float
786- duration of the segment in seconds
801+ Duration of the segment in seconds.
787802 refractory_period_ms: float
788- refractory_period in ms
803+ Refractory period in ms.
789804 firing_rates: float or list[float]
790805 The firing rate of each unit (in Hz).
791806 If float, all units will have the same firing rate.
792807 add_shift_shuffle: bool, default: False
793808 Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat.
794809 seed: int, default: None
795- seed for the generator
810+ Seed for the generator.
796811
797812 Returns
798813 -------
799- times:
800- Concatenated and sorted times vector
801- labels:
802- Concatenated and sorted label vector
814+ times: np.array
815+ Concatenated and sorted times vector.
816+ labels: np.array
817+ Concatenated and sorted label vector.
803818
804819 """
805820
@@ -883,11 +898,11 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
883898 Parameters
884899 ----------
885900 sorting :
886- Original sorting
901+ Original sorting.
887902 num : int
888- Number of injected units
903+ Number of injected units.
889904 max_shift : int
890- range of the shift in sample
905+ range of the shift in sample.
891906 ratio: float
892907 Proportion of original spike in the injected units.
893908
@@ -938,8 +953,27 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
938953
939954
940955def inject_some_split_units (sorting , split_ids : list , num_split = 2 , output_ids = False , seed = None ):
941- """ """
956+ """
957+ Inject some split units in a sorting.
942958
959+ Parameters
960+ ----------
961+ sorting : BaseSorting
962+ Original sorting.
963+ split_ids : list
964+ List of unit_ids to split.
965+ num_split : int, default: 2
966+ Number of split units.
967+ output_ids : bool, default: False
968+ If True, return the new unit_ids.
969+ seed : int, default: None
970+ Random seed.
971+
972+ Returns
973+ -------
974+ sorting_with_split : NumpySorting
975+ A sorting with split units.
976+ """
943977 unit_ids = sorting .unit_ids
944978 assert unit_ids .dtype .kind == "i"
945979
@@ -989,7 +1023,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
9891023 num_violations : int
9901024 Number of contaminating spikes.
9911025 violation_delta : float, default: 1e-5
992- Temporal offset of contaminating spikes (in seconds)
1026+ Temporal offset of contaminating spikes (in seconds).
9931027
9941028 Returns
9951029 -------
@@ -1246,7 +1280,7 @@ def generate_recording_by_size(
12461280 num_channels: int
12471281 Number of channels.
12481282 seed : int, default: None
1249- The seed for np.random.default_rng
1283+ The seed for np.random.default_rng.
12501284
12511285 Returns
12521286 -------
@@ -1646,7 +1680,7 @@ class InjectTemplatesRecording(BaseRecording):
16461680 * (num_units, num_samples, num_channels): standard case
16471681 * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter.
16481682 nbefore: list[int] | int | None, default: None
1649- Where is the center of the template for each unit?
1683+ The number of samples before the peak of the template to align the spike.
16501684 If None, will default to the highest peak.
16511685 amplitude_factor: list[float] | float | None, default: None
16521686 The amplitude of each spike for each unit.
@@ -1661,7 +1695,7 @@ class InjectTemplatesRecording(BaseRecording):
16611695 You can use int for mono-segment objects.
16621696 upsample_vector: np.array or None, default: None.
16631697 When templates is 4d we can simulate a jitter.
1664- Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe [3]
1698+ Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape [3].
16651699
16661700 Returns
16671701 -------
0 commit comments