Skip to content

Commit 90cb2b1

Browse files
committed
MOre docstrings
1 parent 19582f1 commit 90cb2b1

File tree

1 file changed

+87
-53
lines changed

1 file changed

+87
-53
lines changed

src/spikeinterface/core/generate.py

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ def generate_sorting(
103103
Parameters
104104
----------
105105
num_units : int, default: 5
106-
Number of units
106+
Number of units.
107107
sampling_frequency : float, default: 30000.0
108-
The sampling frequency
108+
The sampling frequency.
109109
durations : list, default: [10.325, 3.5]
110-
Duration of each segment in s
110+
Duration of each segment in s.
111111
firing_rates : float, default: 3.0
112112
The firing rate of each unit (in Hz).
113113
empty_units : list, default: None
@@ -123,12 +123,12 @@ def generate_sorting(
123123
border_size_samples : int, default: 20
124124
The size of the border in samples to add border spikes.
125125
seed : int, default: None
126-
The random seed
126+
The random seed.
127127
128128
Returns
129129
-------
130130
sorting : NumpySorting
131-
The sorting object
131+
The sorting object.
132132
"""
133133
seed = _ensure_seed(seed)
134134
rng = np.random.default_rng(seed)
@@ -187,19 +187,19 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
187187
Parameters
188188
----------
189189
sorting : BaseSorting
190-
The sorting object
190+
The sorting object.
191191
sync_event_ratio : float
192192
The ratio of added synchronous spikes with respect to the total number of spikes.
193193
E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra
194194
spikes are synchronous (same sample_index), but on different units (not duplicates).
195195
seed : int, default: None
196-
The random seed
196+
The random seed.
197197
198198
199199
Returns
200200
-------
201201
sorting : TransformSorting
202-
The sorting object, keeping track of added spikes
202+
The sorting object, keeping track of added spikes.
203203
204204
"""
205205
rng = np.random.default_rng(seed)
@@ -249,18 +249,18 @@ def generate_sorting_to_inject(
249249
Parameters
250250
----------
251251
sorting : BaseSorting
252-
The sorting object
252+
The sorting object.
253253
num_samples: list of size num_segments.
254254
The number of samples in all the segments of the sorting, to generate spike times
255-
covering entire the entire duration of the segments
255+
covering entire the entire duration of the segments.
256256
max_injected_per_unit: int, default 1000
257-
The maximal number of spikes injected per units
257+
The maximal number of spikes injected per units.
258258
injected_rate: float, default 0.05
259-
The rate at which spikes are injected
259+
The rate at which spikes are injected.
260260
refractory_period_ms: float, default 1.5
261-
The refractory period that should not be violated while injecting new spikes
261+
The refractory period that should not be violated while injecting new spikes.
262262
seed: int, default None
263-
The random seed
263+
The random seed.
264264
265265
Returns
266266
-------
@@ -312,22 +312,22 @@ class TransformSorting(BaseSorting):
312312
Parameters
313313
----------
314314
sorting : BaseSorting
315-
The sorting object
315+
The sorting object.
316316
added_spikes_existing_units : np.array (spike_vector)
317-
The spikes that should be added to the sorting object, for existing units
317+
The spikes that should be added to the sorting object, for existing units.
318318
added_spikes_new_units: np.array (spike_vector)
319-
The spikes that should be added to the sorting object, for new units
319+
The spikes that should be added to the sorting object, for new units.
320320
new_units_ids: list
321-
The unit_ids that should be added if spikes for new units are added
321+
The unit_ids that should be added if spikes for new units are added.
322322
refractory_period_ms : float, default None
323323
The refractory period violation to prevent duplicates and/or unphysiological addition
324324
of spikes. Any spike times in added_spikes violating the refractory period will be
325-
discarded
325+
discarded.
326326
327327
Returns
328328
-------
329329
sorting : TransformSorting
330-
The sorting object with the added spikes and/or units
330+
The sorting object with the added spikes and/or units.
331331
"""
332332

333333
def __init__(
@@ -428,12 +428,14 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe
428428
429429
Parameters
430430
----------
431-
sorting1: the first sorting
432-
sorting2: the second sorting
431+
sorting1: BaseSorting
432+
The first sorting.
433+
sorting2: BaseSorting
434+
The second sorting.
433435
refractory_period_ms : float, default None
434436
The refractory period violation to prevent duplicates and/or unphysiological addition
435437
of spikes. Any spike times in added_spikes violating the refractory period will be
436-
discarded
438+
discarded.
437439
"""
438440
assert (
439441
sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency()
@@ -492,12 +494,14 @@ def add_from_unit_dict(
492494
Parameters
493495
----------
494496
495-
sorting1: the first sorting
497+
sorting1: BaseSorting
498+
The first sorting
496499
dict_list: list of dict
500+
A list of dict with unit_ids as keys and spike times as values.
497501
refractory_period_ms : float, default None
498502
The refractory period violation to prevent duplicates and/or unphysiological addition
499503
of spikes. Any spike times in added_spikes violating the refractory period will be
500-
discarded
504+
discarded.
501505
"""
502506
sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency())
503507
sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms)
@@ -515,18 +519,19 @@ def from_times_labels(
515519
516520
Parameters
517521
----------
518-
sorting1: the first sorting
522+
sorting1: BaseSorting
523+
The first sorting
519524
times_list: list of array (or array)
520-
An array of spike times (in frames)
525+
An array of spike times (in frames).
521526
labels_list: list of array (or array)
522-
An array of spike labels corresponding to the given times
527+
An array of spike labels corresponding to the given times.
523528
unit_ids: list or None, default: None
524529
The explicit list of unit_ids that should be extracted from labels_list
525-
If None, then it will be np.unique(labels_list)
530+
If None, then it will be np.unique(labels_list).
526531
refractory_period_ms : float, default None
527532
The refractory period violation to prevent duplicates and/or unphysiological addition
528533
of spikes. Any spike times in added_spikes violating the refractory period will be
529-
discarded
534+
discarded.
530535
"""
531536

532537
sorting2 = NumpySorting.from_times_labels(times_list, labels_list, sampling_frequency, unit_ids)
@@ -556,6 +561,16 @@ def clean_refractory_period(self):
556561

557562

558563
def create_sorting_npz(num_seg, file_path):
564+
"""
565+
Create a NPZ sorting file.
566+
567+
Parameters
568+
----------
569+
num_seg : int
570+
The number of segments.
571+
file_path : str | Path
572+
The file path to save the NPZ file.
573+
"""
559574
# create a NPZ sorting file
560575
d = {}
561576
d["unit_ids"] = np.array([0, 1, 2], dtype="int64")
@@ -674,18 +689,18 @@ def synthesize_poisson_spike_vector(
674689
Parameters
675690
----------
676691
num_units : int, default: 20
677-
Number of neuronal units to simulate
692+
Number of neuronal units to simulate.
678693
sampling_frequency : float, default: 30000.0
679-
Sampling frequency in Hz
694+
Sampling frequency in Hz.
680695
duration : float, default: 60.0
681-
Duration of the simulation in seconds
696+
Duration of the simulation in seconds.
682697
refractory_period_ms : float, default: 4.0
683-
Refractory period between spikes in milliseconds
698+
Refractory period between spikes in milliseconds.
684699
firing_rates : float or array_like or tuple, default: 3.0
685700
Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with
686-
each element being the firing rate for one unit
701+
each element being the firing rate for one unit.
687702
seed : int, default: 0
688-
Seed for random number generator
703+
Seed for random number generator.
689704
690705
Returns
691706
-------
@@ -779,27 +794,27 @@ def synthesize_random_firings(
779794
Parameters
780795
----------
781796
num_units : int
782-
number of units
797+
Number of units.
783798
sampling_frequency : float
784-
sampling rate
799+
Sampling rate.
785800
duration : float
786-
duration of the segment in seconds
801+
Duration of the segment in seconds.
787802
refractory_period_ms: float
788-
refractory_period in ms
803+
Refractory period in ms.
789804
firing_rates: float or list[float]
790805
The firing rate of each unit (in Hz).
791806
If float, all units will have the same firing rate.
792807
add_shift_shuffle: bool, default: False
793808
Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat.
794809
seed: int, default: None
795-
seed for the generator
810+
Seed for the generator.
796811
797812
Returns
798813
-------
799-
times:
800-
Concatenated and sorted times vector
801-
labels:
802-
Concatenated and sorted label vector
814+
times: np.array
815+
Concatenated and sorted times vector.
816+
labels: np.array
817+
Concatenated and sorted label vector.
803818
804819
"""
805820

@@ -883,11 +898,11 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
883898
Parameters
884899
----------
885900
sorting :
886-
Original sorting
901+
Original sorting.
887902
num : int
888-
Number of injected units
903+
Number of injected units.
889904
max_shift : int
890-
range of the shift in sample
905+
range of the shift in sample.
891906
ratio: float
892907
Proportion of original spike in the injected units.
893908
@@ -938,8 +953,27 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
938953

939954

940955
def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None):
941-
""" """
956+
"""
957+
Inject some split units in a sorting.
942958
959+
Parameters
960+
----------
961+
sorting : BaseSorting
962+
Original sorting.
963+
split_ids : list
964+
List of unit_ids to split.
965+
num_split : int, default: 2
966+
Number of split units.
967+
output_ids : bool, default: False
968+
If True, return the new unit_ids.
969+
seed : int, default: None
970+
Random seed.
971+
972+
Returns
973+
-------
974+
sorting_with_split : NumpySorting
975+
A sorting with split units.
976+
"""
943977
unit_ids = sorting.unit_ids
944978
assert unit_ids.dtype.kind == "i"
945979

@@ -989,7 +1023,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
9891023
num_violations : int
9901024
Number of contaminating spikes.
9911025
violation_delta : float, default: 1e-5
992-
Temporal offset of contaminating spikes (in seconds)
1026+
Temporal offset of contaminating spikes (in seconds).
9931027
9941028
Returns
9951029
-------
@@ -1246,7 +1280,7 @@ def generate_recording_by_size(
12461280
num_channels: int
12471281
Number of channels.
12481282
seed : int, default: None
1249-
The seed for np.random.default_rng
1283+
The seed for np.random.default_rng.
12501284
12511285
Returns
12521286
-------
@@ -1646,7 +1680,7 @@ class InjectTemplatesRecording(BaseRecording):
16461680
* (num_units, num_samples, num_channels): standard case
16471681
* (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter.
16481682
nbefore: list[int] | int | None, default: None
1649-
Where is the center of the template for each unit?
1683+
The number of samples before the peak of the template to align the spike.
16501684
If None, will default to the highest peak.
16511685
amplitude_factor: list[float] | float | None, default: None
16521686
The amplitude of each spike for each unit.
@@ -1661,7 +1695,7 @@ class InjectTemplatesRecording(BaseRecording):
16611695
You can use int for mono-segment objects.
16621696
upsample_vector: np.array or None, default: None.
16631697
When templates is 4d we can simulate a jitter.
1664-
Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe[3]
1698+
Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3].
16651699
16661700
Returns
16671701
-------

0 commit comments

Comments
 (0)