1111from ..event import find_events
1212from ..evoked import Evoked
1313from ..io import BaseRaw
14- from ..utils import _check_option , _check_preload , fill_doc
14+ from ..utils import _check_option , _check_preload , _validate_type , fill_doc
1515
1616
1717def _get_window (start , end ):
@@ -20,7 +20,9 @@ def _get_window(start, end):
2020 return window
2121
2222
23- def _fix_artifact (data , window , picks , first_samp , last_samp , mode ):
23+ def _fix_artifact (
24+ data , window , picks , first_samp , last_samp , base_tmin , base_tmax , mode
25+ ):
2426 """Modify original data by using parameter data."""
2527 if mode == "linear" :
2628 x = np .array ([first_samp , last_samp ])
@@ -32,6 +34,10 @@ def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
3234 data [picks , first_samp :last_samp ] = (
3335 data [picks , first_samp :last_samp ] * window [np .newaxis , :]
3436 )
37+ if mode == "constant" :
38+ data [picks , first_samp :last_samp ] = data [picks , base_tmin :base_tmax ].mean (
39+ axis = 1
40+ )[:, None ]
3541
3642
3743@fill_doc
@@ -41,6 +47,8 @@ def fix_stim_artifact(
4147 event_id = None ,
4248 tmin = 0.0 ,
4349 tmax = 0.01 ,
50+ * ,
51+ baseline = None ,
4452 mode = "linear" ,
4553 stim_channel = None ,
4654 picks = None ,
@@ -63,10 +71,23 @@ def fix_stim_artifact(
6371 Start time of the interpolation window in seconds.
6472 tmax : float
6573 End time of the interpolation window in seconds.
66- mode : 'linear' | 'window'
74+ baseline : None | tuple, shape (2,)
75+ The baseline to use when ``mode='constant'``, in which case it
76+ must be non-None.
77+
78+ .. versionadded:: 1.8
79+ mode : 'linear' | 'window' | 'constant'
6780 Way to fill the artifacted time interval.
68- 'linear' does linear interpolation
69- 'window' applies a (1 - hanning) window.
81+
82+ ``"linear"``
83+ Does linear interpolation.
84+ ``"window"``
85+ Applies a ``(1 - hanning)`` window.
86+ ``"constant"``
87+ Uses baseline average. baseline parameter must be provided.
88+
89+ .. versionchanged:: 1.8
90+ Added the ``"constant"`` mode.
7091 stim_channel : str | None
7192 Stim channel to use.
7293 %(picks_all_data)s
@@ -76,9 +97,22 @@ def fix_stim_artifact(
7697 inst : instance of Raw or Evoked or Epochs
7798 Instance with modified data.
7899 """
79- _check_option ("mode" , mode , ["linear" , "window" ])
100+ _check_option ("mode" , mode , ["linear" , "window" , "constant" ])
80101 s_start = int (np .ceil (inst .info ["sfreq" ] * tmin ))
81102 s_end = int (np .ceil (inst .info ["sfreq" ] * tmax ))
103+ if mode == "constant" :
104+ _validate_type (
105+ baseline , (tuple , list ), "baseline" , extra = "when mode='constant'"
106+ )
107+ _check_option ("len(baseline)" , len (baseline ), [2 ])
108+ for bi , b in enumerate (baseline ):
109+ _validate_type (
110+ b , "numeric" , f"baseline[{ bi } ]" , extra = "when mode='constant'"
111+ )
112+ b_start = int (np .ceil (inst .info ["sfreq" ] * baseline [0 ]))
113+ b_end = int (np .ceil (inst .info ["sfreq" ] * baseline [1 ]))
114+ else :
115+ b_start = b_end = np .nan
82116 if (mode == "window" ) and (s_end - s_start ) < 4 :
83117 raise ValueError (
84118 'Time range is too short. Use a larger interval or set mode to "linear".'
@@ -104,7 +138,11 @@ def fix_stim_artifact(
104138 for event_idx in event_start :
105139 first_samp = int (event_idx ) - inst .first_samp + s_start
106140 last_samp = int (event_idx ) - inst .first_samp + s_end
107- _fix_artifact (data , window , picks , first_samp , last_samp , mode )
141+ base_t1 = int (event_idx ) - inst .first_samp + b_start
142+ base_t2 = int (event_idx ) - inst .first_samp + b_end
143+ _fix_artifact (
144+ data , window , picks , first_samp , last_samp , base_t1 , base_t2 , mode
145+ )
108146 elif isinstance (inst , BaseEpochs ):
109147 if inst .reject is not None :
110148 raise RuntimeError (
@@ -114,14 +152,23 @@ def fix_stim_artifact(
114152 first_samp = s_start - e_start
115153 last_samp = s_end - e_start
116154 data = inst ._data
155+ base_t1 = b_start - e_start
156+ base_t2 = b_end - e_start
117157 for epoch in data :
118- _fix_artifact (epoch , window , picks , first_samp , last_samp , mode )
158+ _fix_artifact (
159+ epoch , window , picks , first_samp , last_samp , base_t1 , base_t2 , mode
160+ )
119161
120162 elif isinstance (inst , Evoked ):
121163 first_samp = s_start - inst .first
122164 last_samp = s_end - inst .first
123165 data = inst .data
124- _fix_artifact (data , window , picks , first_samp , last_samp , mode )
166+ base_t1 = b_start - inst .first
167+ base_t2 = b_end - inst .first
168+
169+ _fix_artifact (
170+ data , window , picks , first_samp , last_samp , base_t1 , base_t2 , mode
171+ )
125172
126173 else :
127174 raise TypeError (f"Not a Raw or Epochs or Evoked (got { type (inst )} )." )
0 commit comments