Skip to content

Commit 610ec2a

Browse files
fmamashlijasmainaklarsoner
authored
Additional option to flatten TMS artifact (#6915)
Co-authored-by: Mainak Jas <jasmainak@users.noreply.github.com> Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent f3a7fde commit 610ec2a

File tree

6 files changed

+104
-16
lines changed

6 files changed

+104
-16
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ jobs:
120120
- run: ./tools/github_actions_dependencies.sh
121121
# Minimal commands on Linux (macOS stalls)
122122
- run: ./tools/get_minimal_commands.sh
123-
if: ${{ startswith(matrix.os, 'ubuntu') }}
123+
if: startswith(matrix.os, 'ubuntu') && matrix.kind != 'minimal' && matrix.kind != 'old'
124124
- run: ./tools/github_actions_infos.sh
125125
# Check Qt
126126
- run: ./tools/check_qt_import.sh $MNE_QT_BACKEND
127-
if: ${{ env.MNE_QT_BACKEND != '' }}
127+
if: env.MNE_QT_BACKEND != ''
128128
- name: Run tests with no testing data
129129
run: MNE_SKIP_TESTING_DATASET_TESTS=true pytest -m "not (ultraslowtest or pgtest)" --tb=short --cov=mne --cov-report xml -vv -rfE mne/
130130
if: matrix.kind == 'minimal'
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add option to :func:`mne.preprocessing.fix_stim_artifact` to use baseline average to flatten TMS pulse artifact by `Fahimeh Mamashli`_ and `Padma Sundaram`_ and `Mohammad Daneshzand`_.

examples/datasets/brainstorm_data.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
=====================================
77
88
Here we compute the evoked from raw for the Brainstorm
9-
tutorial dataset. For comparison, see :footcite:`TadelEtAl2011` and:
10-
11-
https://neuroimage.usc.edu/brainstorm/Tutorials/MedianNerveCtf
9+
tutorial dataset. For comparison, see :footcite:`TadelEtAl2011` and
10+
https://neuroimage.usc.edu/brainstorm/Tutorials/MedianNerveCtf.
1211
"""
1312

1413
# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>

mne/preprocessing/stim.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..event import find_events
1212
from ..evoked import Evoked
1313
from ..io import BaseRaw
14-
from ..utils import _check_option, _check_preload, fill_doc
14+
from ..utils import _check_option, _check_preload, _validate_type, fill_doc
1515

1616

1717
def _get_window(start, end):
@@ -20,7 +20,9 @@ def _get_window(start, end):
2020
return window
2121

2222

23-
def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
23+
def _fix_artifact(
24+
data, window, picks, first_samp, last_samp, base_tmin, base_tmax, mode
25+
):
2426
"""Modify original data by using parameter data."""
2527
if mode == "linear":
2628
x = np.array([first_samp, last_samp])
@@ -32,6 +34,10 @@ def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
3234
data[picks, first_samp:last_samp] = (
3335
data[picks, first_samp:last_samp] * window[np.newaxis, :]
3436
)
37+
if mode == "constant":
38+
data[picks, first_samp:last_samp] = data[picks, base_tmin:base_tmax].mean(
39+
axis=1
40+
)[:, None]
3541

3642

3743
@fill_doc
@@ -41,6 +47,8 @@ def fix_stim_artifact(
4147
event_id=None,
4248
tmin=0.0,
4349
tmax=0.01,
50+
*,
51+
baseline=None,
4452
mode="linear",
4553
stim_channel=None,
4654
picks=None,
@@ -63,10 +71,23 @@ def fix_stim_artifact(
6371
Start time of the interpolation window in seconds.
6472
tmax : float
6573
End time of the interpolation window in seconds.
66-
mode : 'linear' | 'window'
74+
baseline : None | tuple, shape (2,)
75+
The baseline to use when ``mode='constant'``, in which case it
76+
must be non-None.
77+
78+
.. versionadded:: 1.8
79+
mode : 'linear' | 'window' | 'constant'
6780
Way to fill the artifacted time interval.
68-
'linear' does linear interpolation
69-
'window' applies a (1 - hanning) window.
81+
82+
``"linear"``
83+
Does linear interpolation.
84+
``"window"``
85+
Applies a ``(1 - hanning)`` window.
86+
``"constant"``
87+
Uses baseline average. baseline parameter must be provided.
88+
89+
.. versionchanged:: 1.8
90+
Added the ``"constant"`` mode.
7091
stim_channel : str | None
7192
Stim channel to use.
7293
%(picks_all_data)s
@@ -76,9 +97,22 @@ def fix_stim_artifact(
7697
inst : instance of Raw or Evoked or Epochs
7798
Instance with modified data.
7899
"""
79-
_check_option("mode", mode, ["linear", "window"])
100+
_check_option("mode", mode, ["linear", "window", "constant"])
80101
s_start = int(np.ceil(inst.info["sfreq"] * tmin))
81102
s_end = int(np.ceil(inst.info["sfreq"] * tmax))
103+
if mode == "constant":
104+
_validate_type(
105+
baseline, (tuple, list), "baseline", extra="when mode='constant'"
106+
)
107+
_check_option("len(baseline)", len(baseline), [2])
108+
for bi, b in enumerate(baseline):
109+
_validate_type(
110+
b, "numeric", f"baseline[{bi}]", extra="when mode='constant'"
111+
)
112+
b_start = int(np.ceil(inst.info["sfreq"] * baseline[0]))
113+
b_end = int(np.ceil(inst.info["sfreq"] * baseline[1]))
114+
else:
115+
b_start = b_end = np.nan
82116
if (mode == "window") and (s_end - s_start) < 4:
83117
raise ValueError(
84118
'Time range is too short. Use a larger interval or set mode to "linear".'
@@ -104,7 +138,11 @@ def fix_stim_artifact(
104138
for event_idx in event_start:
105139
first_samp = int(event_idx) - inst.first_samp + s_start
106140
last_samp = int(event_idx) - inst.first_samp + s_end
107-
_fix_artifact(data, window, picks, first_samp, last_samp, mode)
141+
base_t1 = int(event_idx) - inst.first_samp + b_start
142+
base_t2 = int(event_idx) - inst.first_samp + b_end
143+
_fix_artifact(
144+
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
145+
)
108146
elif isinstance(inst, BaseEpochs):
109147
if inst.reject is not None:
110148
raise RuntimeError(
@@ -114,14 +152,23 @@ def fix_stim_artifact(
114152
first_samp = s_start - e_start
115153
last_samp = s_end - e_start
116154
data = inst._data
155+
base_t1 = b_start - e_start
156+
base_t2 = b_end - e_start
117157
for epoch in data:
118-
_fix_artifact(epoch, window, picks, first_samp, last_samp, mode)
158+
_fix_artifact(
159+
epoch, window, picks, first_samp, last_samp, base_t1, base_t2, mode
160+
)
119161

120162
elif isinstance(inst, Evoked):
121163
first_samp = s_start - inst.first
122164
last_samp = s_end - inst.first
123165
data = inst.data
124-
_fix_artifact(data, window, picks, first_samp, last_samp, mode)
166+
base_t1 = b_start - inst.first
167+
base_t2 = b_end - inst.first
168+
169+
_fix_artifact(
170+
data, window, picks, first_samp, last_samp, base_t1, base_t2, mode
171+
)
125172

126173
else:
127174
raise TypeError(f"Not a Raw or Epochs or Evoked (got {type(inst)}).")

mne/preprocessing/tests/test_stim.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ def test_fix_stim_artifact():
5555
data_from_epochs_fix = epochs.get_data(copy=False)[:, :, tmin_samp:tmax_samp]
5656
assert not np.all(data_from_epochs_fix != 0)
5757

58+
baseline = (-0.1, -0.05)
59+
epochs = fix_stim_artifact(
60+
epochs, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
61+
)
62+
b_start = int(np.ceil(epochs.info["sfreq"] * baseline[0]))
63+
b_end = int(np.ceil(epochs.info["sfreq"] * baseline[1]))
64+
base_t1 = b_start - e_start
65+
base_t2 = b_end - e_start
66+
baseline_mean = epochs.get_data()[:, :, base_t1:base_t2].mean(axis=2)[0][0]
67+
data = epochs.get_data()[:, :, tmin_samp:tmax_samp]
68+
assert data[0][0][0] == baseline_mean
69+
5870
# use window before stimulus in raw
5971
event_idx = np.where(events[:, 2] == 1)[0][0]
6072
tmin, tmax = -0.045, -0.015
@@ -81,8 +93,22 @@ def test_fix_stim_artifact():
8193
raw, events, event_id=1, tmin=tmin, tmax=tmax, mode="window"
8294
)
8395
data, times = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]
96+
8497
assert np.all(data) == 0.0
8598

99+
raw = fix_stim_artifact(
100+
raw,
101+
events,
102+
event_id=1,
103+
tmin=tmin,
104+
tmax=tmax,
105+
baseline=baseline,
106+
mode="constant",
107+
)
108+
data, times = raw[:, (tidx + tmin_samp) : (tidx + tmax_samp)]
109+
baseline_mean, _ = raw[:, (tidx + b_start) : (tidx + b_end)]
110+
assert baseline_mean.mean(axis=1)[0] == data[0][0]
111+
86112
# get epochs from raw with fixed data
87113
tmin, tmax, event_id = -0.2, 0.5, 1
88114
epochs = Epochs(
@@ -117,3 +143,12 @@ def test_fix_stim_artifact():
117143
evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode="window")
118144
data = evoked.data[:, tmin_samp:tmax_samp]
119145
assert np.all(data) == 0.0
146+
147+
evoked = fix_stim_artifact(
148+
evoked, tmin=tmin, tmax=tmax, baseline=baseline, mode="constant"
149+
)
150+
base_t1 = int(baseline[0] * evoked.info["sfreq"]) - evoked.first
151+
base_t2 = int(baseline[1] * evoked.info["sfreq"]) - evoked.first
152+
data = evoked.data[:, tmin_samp:tmax_samp]
153+
baseline_mean = evoked.data[:, base_t1:base_t2].mean(axis=1)[0]
154+
assert data[0][0] == baseline_mean

tools/install_pre_requirements.sh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@ echo "PyQt6 and scientific-python-nightly-wheels dependencies"
1515
python -m pip install $STD_ARGS pip setuptools packaging \
1616
threadpoolctl cycler fonttools kiwisolver pyparsing pillow python-dateutil \
1717
patsy pytz tzdata nibabel tqdm trx-python joblib numexpr "$QT_BINDING" \
18-
py-cpuinfo blosc2
18+
py-cpuinfo blosc2 hatchling
1919
echo "NumPy/SciPy/pandas etc."
2020
python -m pip uninstall -yq numpy
2121
python -m pip install $STD_ARGS --only-binary ":all:" --default-timeout=60 \
2222
--index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \
2323
"numpy>=2.1.0.dev0" "scikit-learn>=1.6.dev0" "scipy>=1.15.0.dev0" \
24-
"statsmodels>=0.15.0.dev0" "pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \
24+
"pandas>=3.0.0.dev0" "matplotlib>=3.10.0.dev0" \
2525
"h5py>=3.12.1" "dipy>=1.10.0.dev0" "pyarrow>=19.0.0.dev0" "tables>=3.10.2.dev0"
2626

27+
# statsmodels requires formulaic@main so we need to use --extra-index-url
28+
echo "statsmodels"
29+
python -m pip install $STD_ARGS --only-binary ":all:" \
30+
--extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" \
31+
"statsmodels>=0.15.0.dev0"
32+
2733
# No Numba because it forces an old NumPy version
2834

2935
echo "pymatreader"

0 commit comments

Comments
 (0)