Skip to content

Commit 5a7d890

Browse files
authored
Merge pull request #3053 from h-mayorquin/add_scale_to_uV_preprocessing
Add `scale_to_uV` preprocessing
2 parents 203f778 + 0b5c635 commit 5a7d890

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

src/spikeinterface/preprocessing/preprocessinglist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
CenterRecording,
2525
center,
2626
)
27+
from .scale import scale_to_uV
28+
2729
from .whiten import WhitenRecording, whiten, compute_whitening_matrix
2830
from .rectify import RectifyRecording, rectify
2931
from .clip import BlankSaturationRecording, blank_staturation, ClipRecording, clip
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
from spikeinterface.core import BaseRecording
6+
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor
7+
8+
9+
def scale_to_uV(recording: BasePreprocessor) -> BasePreprocessor:
10+
"""
11+
Scale raw traces to microvolts (µV).
12+
13+
This preprocessor uses the channel-specific gain and offset information
14+
stored in the recording extractor to convert the raw traces to µV units.
15+
16+
Parameters
17+
----------
18+
recording : BaseRecording
19+
The recording extractor to be scaled. The recording extractor must
20+
have gains and offsets otherwise an error will be raised.
21+
22+
Raises
23+
------
24+
AssertionError
25+
If the recording extractor does not have scaleable traces.
26+
"""
27+
# To avoid a circular import
28+
from spikeinterface.preprocessing import ScaleRecording
29+
30+
if not recording.has_scaleable_traces():
31+
error_msg = "Recording must have gains and offsets set to be scaled to µV"
32+
raise RuntimeError(error_msg)
33+
34+
gain = recording.get_channel_gains()
35+
offset = recording.get_channel_offsets()
36+
37+
scaled_to_uV_recording = ScaleRecording(recording, gain=gain, offset=offset, dtype="float32")
38+
39+
# We do this so when get_traces(return_scaled=True) is called, the return is the same.
40+
scaled_to_uV_recording.set_channel_gains(gains=1.0)
41+
scaled_to_uV_recording.set_channel_offsets(offsets=0.0)
42+
43+
return scaled_to_uV_recording
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
import numpy as np
3+
from spikeinterface.core.testing_tools import generate_recording
4+
from spikeinterface.preprocessing import scale_to_uV, CenterRecording
5+
6+
7+
def test_scale_to_uV():
8+
# Create a sample recording extractor with fake gains and offsets
9+
num_channels = 4
10+
sampling_frequency = 30_000.0
11+
durations = [1.0, 1.0] # seconds
12+
recording = generate_recording(
13+
num_channels=num_channels,
14+
durations=durations,
15+
sampling_frequency=sampling_frequency,
16+
)
17+
18+
rng = np.random.default_rng(0)
19+
gains = rng.random(size=(num_channels)).astype(np.float32)
20+
offsets = rng.random(size=(num_channels)).astype(np.float32)
21+
recording.set_channel_gains(gains)
22+
recording.set_channel_offsets(offsets)
23+
24+
# Apply the preprocessor
25+
scaled_recording = scale_to_uV(recording=recording)
26+
27+
# Check if the traces are indeed scaled
28+
expected_traces = recording.get_traces(return_scaled=True, segment_index=0)
29+
scaled_traces = scaled_recording.get_traces(segment_index=0)
30+
31+
np.testing.assert_allclose(scaled_traces, expected_traces)
32+
33+
# Test for the error when recording doesn't have scaleable traces
34+
recording.set_channel_gains(None) # Remove gains to make traces unscaleable
35+
with pytest.raises(RuntimeError):
36+
scale_to_uV(recording)
37+
38+
39+
def test_scaling_in_preprocessing_chain():
40+
41+
# Create a sample recording extractor with fake gains and offsets
42+
num_channels = 4
43+
sampling_frequency = 30_000.0
44+
durations = [1.0] # seconds
45+
recording = generate_recording(
46+
num_channels=num_channels,
47+
durations=durations,
48+
sampling_frequency=sampling_frequency,
49+
)
50+
51+
rng = np.random.default_rng(0)
52+
gains = rng.random(size=(num_channels)).astype(np.float32)
53+
offsets = rng.random(size=(num_channels)).astype(np.float32)
54+
55+
recording.set_channel_gains(gains)
56+
recording.set_channel_offsets(offsets)
57+
58+
centered_recording = CenterRecording(scale_to_uV(recording=recording))
59+
traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True)
60+
61+
# Chain preprocessors
62+
centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording))
63+
traces_scaled_with_preprocessor = centered_recording_scaled.get_traces()
64+
65+
np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor)
66+
67+
# Test if the scaling is not done twice
68+
traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True)
69+
70+
np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument)

0 commit comments

Comments
 (0)