7 changes: 3 additions & 4 deletions mne/io/snirf/_snirf.py
@@ -148,14 +148,13 @@ def __init__(
         # Extract wavelengths
         fnirs_wavelengths = np.array(dat.get("nirs/probe/wavelengths"))
         fnirs_wavelengths = [int(w) for w in fnirs_wavelengths]
-        if len(fnirs_wavelengths) != 2:
+        if len(fnirs_wavelengths) < 2:
             raise RuntimeError(
                 f"The data contains "
                 f"{len(fnirs_wavelengths)}"
                 f" wavelengths: {fnirs_wavelengths}. "
-                f"MNE only supports reading continuous"
-                " wave amplitude SNIRF files "
-                "with two wavelengths."
+                f"MNE requires at least two wavelengths for "
+                "continuous wave amplitude SNIRF files."
             )
 
         # Extract channels
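For reference, a minimal sketch of the relaxed validation behavior (not part of the diff; check_wavelengths is a hypothetical standalone helper mirroring the check above):

def check_wavelengths(fnirs_wavelengths):
    """Sketch of the relaxed check: two or more wavelengths now pass."""
    if len(fnirs_wavelengths) < 2:
        raise RuntimeError(
            f"The data contains {len(fnirs_wavelengths)} wavelengths: "
            f"{fnirs_wavelengths}. MNE requires at least two wavelengths "
            "for continuous wave amplitude SNIRF files."
        )

check_wavelengths([760, 850])       # passes, the classic two-wavelength case
check_wavelengths([735, 808, 850])  # now also passes
# check_wavelengths([850])          # raises RuntimeError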
92 changes: 69 additions & 23 deletions mne/preprocessing/nirs/_beer_lambert_law.py
@@ -11,7 +11,7 @@
 from ..._fiff.constants import FIFF
 from ...io import BaseRaw
 from ...utils import _validate_type, pinv, warn
-from ..nirs import _validate_nirs_info, source_detector_distances
+from ..nirs import _channel_frequencies, _validate_nirs_info, source_detector_distances
 
 
 def beer_lambert_law(raw, ppf=6.0):
@@ -36,23 +36,46 @@ def beer_lambert_law(raw, ppf=6.0):
     _validate_type(raw, BaseRaw, "raw")
     _validate_type(ppf, ("numeric", "array-like"), "ppf")
     ppf = np.array(ppf, float)
-    if ppf.ndim == 0:  # upcast single float to shape (2,)
-        ppf = np.array([ppf, ppf])
-    if ppf.shape != (2,):
+    picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
+
+    # Use nominal channel frequencies
+    #
+    # Notes on implementation:
+    # 1. Frequencies are calculated the same way as in nirs._validate_nirs_info().
+    # 2. Wavelength values in the info structure may contain actual frequencies,
+    #    which could be used for a more accurate calculation in the future.
+    # 3. nirs._channel_frequencies uses both cw_amplitude and OD data to determine
+    #    frequencies, whereas we only need those from OD here. Is there any chance
+    #    that they're different?
+    # 4. If actual frequencies were used, np.unique() as applied below would lead
+    #    to errors; instead, absorption coefficients would need to be calculated
+    #    for each individual frequency.
+    freqs = _channel_frequencies(raw.info)
+
+    # Get the unique wavelengths and their number
+    unique_freqs = np.unique(freqs)
+    n_wavelengths = len(unique_freqs)
+
+    # PPF validation for multiple wavelengths
+    if ppf.ndim == 0:  # single float
+        # same PPF for all wavelengths, shape (n_wavelengths, 1)
+        ppf = np.full((n_wavelengths, 1), ppf)
+    elif ppf.ndim == 1 and len(ppf) == n_wavelengths:
+        # separate PPF for each wavelength
+        ppf = ppf[:, np.newaxis]  # shape (n_wavelengths, 1)
+    else:
         raise ValueError(
-            f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}"
+            f"ppf must be a single float or an array-like of length {n_wavelengths} "
+            f"(number of wavelengths), got shape {ppf.shape}"
         )
-    ppf = ppf[:, np.newaxis]  # shape (2, 1)
-    picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
-    # This is the one place we *really* need the actual/accurate frequencies
-    freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
-    abs_coef = _load_absorption(freqs)
+
+    abs_coef = _load_absorption(unique_freqs)  # shape (n_wavelengths, 2)
     distances = source_detector_distances(raw.info, picks="all")
     bad = ~np.isfinite(distances[picks])
     bad |= distances[picks] <= 0
     if bad.any():
         warn(
-            "Source-detector distances are zero on NaN, some resulting "
+            "Source-detector distances are zero or NaN, some resulting "
             "concentrations will be zero. Consider setting a montage "
             "with raw.set_montage."
         )
Expand All @@ -64,20 +87,41 @@ def beer_lambert_law(raw, ppf=6.0):
"likely due to optode locations being stored in a "
" unit other than meters."
)

rename = dict()
for ii, jj in zip(picks[::2], picks[1::2]):
EL = abs_coef * distances[ii] * ppf
channels_to_drop_all = [] # Accumulate all channels to drop

# Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...])
for ii in range(0, len(picks), n_wavelengths):
group_picks = picks[ii : ii + n_wavelengths]
# Calculate Δc based on the system: ΔOD = E * L * PPF * Δc
# where E is (n_wavelengths, 2), Δc is (2, n_timepoints)
# using pseudo-inverse
EL = abs_coef * distances[group_picks[0]] * ppf
iEL = pinv(EL)
conc_data = iEL @ raw._data[group_picks] * 1e-3

raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3
# Replace the first two channels with HbO and HbR
raw._data[group_picks[:2]] = conc_data[:2] # HbO, HbR

# Update channel information
coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
for ki, kind in zip((ii, jj), ("hbo", "hbr")):
for ki, kind in zip(group_picks[:2], ("hbo", "hbr")):
ch = raw.info["chs"][ki]
ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
new_name = f"{ch['ch_name'].split(' ')[0]} {kind}"
rename[ch["ch_name"]] = new_name

# Accumulate extra wavelength channels to drop (keep only HbO and HbR)
if n_wavelengths > 2:
channels_to_drop = group_picks[2:]
channel_names_to_drop = [raw.ch_names[idx] for idx in channels_to_drop]
channels_to_drop_all.extend(channel_names_to_drop)

# Drop all accumulated extra wavelength channels after processing all groups
if channels_to_drop_all:
raw.drop_channels(channels_to_drop_all)

raw.rename_channels(rename)

# Validate the format of data after transformation is valid
Expand All @@ -95,7 +139,9 @@ def _load_absorption(freqs):
# save('extinction_coef.mat', 'extinct_coef')
#
# Returns data as [[HbO2(freq1), Hb(freq1)],
# [HbO2(freq2), Hb(freq2)]]
# [HbO2(freq2), Hb(freq2)],
# ...,
# [HbO2(freqN), Hb(freqN)]]
extinction_fname = op.join(
op.dirname(__file__), "..", "..", "data", "extinction_coef.mat"
)
@@ -104,12 +150,12 @@
     interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear")
     interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear")
 
-    ext_coef = np.array(
-        [
-            [interp_hbo(freqs[0]), interp_hb(freqs[0])],
-            [interp_hbo(freqs[1]), interp_hb(freqs[1])],
-        ]
-    )
-    abs_coef = ext_coef * 0.2303
+    # Build the coefficient matrix for all wavelengths
+    # Shape: (n_wavelengths, 2), where columns are [HbO2, Hb]
+    ext_coef = np.zeros((len(freqs), 2))
+    for i, freq in enumerate(freqs):
+        ext_coef[i, 0] = interp_hbo(freq)  # HbO2
+        ext_coef[i, 1] = interp_hb(freq)  # Hb
+
+    abs_coef = ext_coef * 0.2303
     return abs_coef
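To make the math above concrete, here is a minimal, self-contained sketch (not part of the diff) of the per-group step: the modified Beer-Lambert inversion ΔOD = E * L * PPF * Δc solved with a pseudo-inverse. The absorption coefficients are illustrative placeholders, not physiological values, and np.linalg.pinv stands in for MNE's internal pinv:

import numpy as np

n_wavelengths, n_times = 3, 5

# Placeholder [HbO2, Hb] absorption coefficients per wavelength,
# shape (n_wavelengths, 2), as returned by _load_absorption
abs_coef = np.array([[0.4, 1.8], [1.0, 0.8], [1.2, 0.7]])

distance = 0.03  # source-detector separation in meters
ppf = np.full((n_wavelengths, 1), 6.0)  # one PPF per wavelength, shape (n_wavelengths, 1)

EL = abs_coef * distance * ppf  # shape (n_wavelengths, 2)
iEL = np.linalg.pinv(EL)        # least-squares inverse, shape (2, n_wavelengths)

od = np.random.default_rng(0).normal(size=(n_wavelengths, n_times))  # stand-in OD data
conc = iEL @ od * 1e-3          # rows are [HbO, HbR], shape (2, n_times)
print(conc.shape)               # (2, 5)

With more than two wavelengths the system is overdetermined, so the pseudo-inverse gives the least-squares estimate of the two chromophore concentrations.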
36 changes: 28 additions & 8 deletions mne/preprocessing/nirs/_scalp_coupling_index.py
@@ -6,7 +6,7 @@
 from ...io import BaseRaw
 from ...utils import _validate_type, verbose
-from ..nirs import _validate_nirs_info
+from ..nirs import _channel_frequencies, _validate_nirs_info
 
 
 @verbose
@@ -56,14 +56,34 @@ def scalp_coupling_index(
         verbose=verbose,
     ).get_data()
 
+    # Determine the number of wavelengths per source-detector pair.
+    # We use nominal wavelengths, as the info structure may contain arbitrary data.
+    freqs = _channel_frequencies(raw.info)
+    n_wavelengths = len(np.unique(freqs))
+
     sci = np.zeros(picks.shape)
-    for ii in range(0, len(picks), 2):
-        with np.errstate(invalid="ignore"):
-            c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
-        if not np.isfinite(c):  # someone had std=0
-            c = 0
-        sci[ii] = c
-        sci[ii + 1] = c
+
+    # Calculate all pairwise correlations within each group and use the minimum as SCI
+    pair_indices = np.triu_indices(n_wavelengths, k=1)
+
+    for gg in range(0, len(picks), n_wavelengths):
+        group_data = filtered_data[gg : gg + n_wavelengths]
+
+        # Calculate pairwise correlations within the group
+        correlations = np.zeros(pair_indices[0].shape[0])
+
+        for n, (ii, jj) in enumerate(zip(*pair_indices)):
+            with np.errstate(invalid="ignore"):
+                c = np.corrcoef(group_data[ii], group_data[jj])[0][1]
+            if np.isfinite(c):  # leave 0 if a channel had std=0
+                correlations[n] = c
+
+        # Use the minimum correlation as the SCI
+        group_sci = correlations.min()
+
+        # Assign the same SCI value to all channels in the group
+        sci[gg : gg + n_wavelengths] = group_sci
+
     sci[zero_mask] = 0
     sci = sci[np.argsort(picks)]  # restore original order
     return sci
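And a minimal, self-contained sketch (not part of the diff) of the new SCI computation for a single hypothetical source-detector group with three wavelengths. The synthetic signals share a common component, so all pairwise correlations, and hence their minimum, come out high:

import numpy as np

n_wavelengths, n_times = 3, 1000
rng = np.random.default_rng(1)

# Synthetic band-passed signals: a shared component plus per-wavelength noise
common = rng.normal(size=n_times)
group_data = common + 0.1 * rng.normal(size=(n_wavelengths, n_times))

ii, jj = np.triu_indices(n_wavelengths, k=1)  # all within-group channel pairs
correlations = np.zeros(ii.size)
for n, (a, b) in enumerate(zip(ii, jj)):
    with np.errstate(invalid="ignore"):
        c = np.corrcoef(group_data[a], group_data[b])[0, 1]
    if np.isfinite(c):  # leave 0 for flat (std=0) channels
        correlations[n] = c

sci = correlations.min()  # one SCI value shared by all channels in the group
print(f"{sci:.3f}")       # close to 1.0 for well-coupled synthetic data

Taking the minimum rather than the mean keeps the metric conservative: one poorly coupled wavelength is enough to flag the whole group.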