From 08d4477c3109cd2cdbc81ce090fb36c4ad22f21e Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 11:58:44 -0500 Subject: [PATCH 1/5] Upgrade MATLAB-style trial and workflow fidelity --- nstat/ConfigColl.py | 10 +- nstat/CovColl.py | 10 +- nstat/FitResSummary.py | 10 +- nstat/FitResult.py | 10 +- nstat/TrialConfig.py | 10 +- nstat/__init__.py | 26 + nstat/analysis.py | 156 +++- nstat/cif.py | 49 +- nstat/events.py | 37 +- nstat/fit.py | 419 ++++++++-- nstat/history.py | 125 ++- nstat/trial.py | 1363 ++++++++++++++++++++++++++++--- tests/test_api_surface.py | 6 + tests/test_trial_fidelity.py | 134 +++ tests/test_workflow_fidelity.py | 109 +++ 15 files changed, 2149 insertions(+), 325 deletions(-) create mode 100644 tests/test_trial_fidelity.py create mode 100644 tests/test_workflow_fidelity.py diff --git a/nstat/ConfigColl.py b/nstat/ConfigColl.py index ddeef33d..06e5e98b 100644 --- a/nstat/ConfigColl.py +++ b/nstat/ConfigColl.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .trial import ConfigCollection - - -class ConfigColl(ConfigCollection): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.ConfigColl.ConfigColl", "nstat.trial.ConfigCollection") - super().__init__(*args, **kwargs) - +from .trial import ConfigCollection as ConfigColl __all__ = ["ConfigColl"] diff --git a/nstat/CovColl.py b/nstat/CovColl.py index 76b4ff93..2b8514f3 100644 --- a/nstat/CovColl.py +++ b/nstat/CovColl.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .trial import CovariateCollection - - -class CovColl(CovariateCollection): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.CovColl.CovColl", "nstat.trial.CovariateCollection") - super().__init__(*args, **kwargs) - +from .trial import CovariateCollection as CovColl __all__ = ["CovColl"] diff --git a/nstat/FitResSummary.py b/nstat/FitResSummary.py index 7305bbb6..b98a6521 100644 --- a/nstat/FitResSummary.py +++ b/nstat/FitResSummary.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .fit import FitResSummary as _FitResSummary - - -class FitResSummary(_FitResSummary): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.FitResSummary.FitResSummary", "nstat.fit.FitSummary") - super().__init__(*args, **kwargs) - +from .fit import FitResSummary __all__ = ["FitResSummary"] diff --git a/nstat/FitResult.py b/nstat/FitResult.py index fb4dc3d0..12e7c19e 100644 --- a/nstat/FitResult.py +++ b/nstat/FitResult.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .fit import FitResult as _FitResult - - -class FitResult(_FitResult): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.FitResult.FitResult", "nstat.fit.FitResult") - super().__init__(*args, **kwargs) - +from .fit import FitResult __all__ = ["FitResult"] diff --git a/nstat/TrialConfig.py b/nstat/TrialConfig.py index a110e4c6..d97db876 100644 --- a/nstat/TrialConfig.py +++ b/nstat/TrialConfig.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .trial import TrialConfig as _TrialConfig - - -class TrialConfig(_TrialConfig): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.TrialConfig.TrialConfig", "nstat.trial.TrialConfig") - super().__init__(*args, **kwargs) - +from .trial import TrialConfig __all__ = ["TrialConfig"] diff --git a/nstat/__init__.py b/nstat/__init__.py index d96fc1dd..9695c97b 100644 --- a/nstat/__init__.py +++ b/nstat/__init__.py @@ -1,3 +1,6 @@ +import sys as _sys +from types import ModuleType as _ModuleType + from .ConfidenceInterval import ConfidenceInterval from .ConfigColl import ConfigColl from .CovColl import CovColl @@ -27,6 +30,29 @@ from .nspikeTrain import nspikeTrain from .nstColl import nstColl +from . import analysis as _analysis_module +from . import cif as _cif_module +from . import events as _events_module +from . import history as _history_module +from . import trial as _trial_module + +_sys.modules.setdefault(f"{__name__}.Analysis", _analysis_module) +_sys.modules.setdefault(f"{__name__}.CIF", _cif_module) +_sys.modules.setdefault(f"{__name__}.Events", _events_module) +_sys.modules.setdefault(f"{__name__}.History", _history_module) +_sys.modules.setdefault(f"{__name__}.Trial", _trial_module) + + +class _NstatModule(_ModuleType): + def __getattribute__(self, name: str): + value = super().__getattribute__(name) + if isinstance(value, _ModuleType) and hasattr(value, name): + return getattr(value, name) + return value + + +_sys.modules[__name__].__class__ = _NstatModule + def __getattr__(name: str): if name == "nstat_install": diff --git a/nstat/analysis.py b/nstat/analysis.py index 0f581c56..9276d039 100644 --- a/nstat/analysis.py +++ b/nstat/analysis.py @@ -30,7 +30,7 @@ def psth(spike_trains: Sequence[object], bin_edges: np.ndarray) -> tuple[np.ndar class Analysis: - """Canonical analysis entry points preserving the paper's workflow semantics.""" + """Canonical analysis entry points preserving MATLAB-facing workflow semantics.""" @staticmethod def psth(spike_trains: Sequence[object], bin_edges: np.ndarray) -> tuple[np.ndarray, np.ndarray]: @@ -45,46 +45,137 @@ def run_analysis_for_neuron( l2: float = 1e-6, max_iter: int = 120, ) -> FitResult: - time, x_all, labels = trial.get_covariate_matrix() - spike_train = trial.spike_collection.get_nst(neuron_index) - - dt = float(np.median(np.diff(time))) if time.shape[0] > 1 else 1.0 - edges = np.concatenate([time, [time[-1] + dt]]) - y = spike_train.to_binned_counts(edges) - offset = np.full(y.shape[0], np.log(max(dt, 1e-12)), dtype=float) - + if neuron_index < 0: + raise IndexError("neuron_index must be >= 0") + + original_partition = trial.getTrialPartition().copy() + trial.restoreToOriginal() + if original_partition.size: + trial.setTrialPartition(original_partition) + trial.setTrialTimesFor("training") + + neuron_number = int(neuron_index) + 1 + labels: list[list[str]] = [] + lambda_parts: list[Covariate] = [] + b: list[np.ndarray] = [] + dev: list[float] = [] + stats: list[dict[str, float | int | bool]] = [] + AIC: list[float] = [] + BIC: list[float] = [] + logLL: list[float] = [] + numHist: list[int] = [] + histObjects: list[object] = [] + ensHistObjects: list[object] = [] fits: list[_SingleFit] = [] - for idx, cfg in enumerate(config_collection.configs, start=1): - names = cfg.covariate_names - if names: - cols = [i for i, lab in enumerate(labels) if lab in set(names)] - x = x_all[:, cols] if cols else np.zeros((x_all.shape[0], 0), dtype=float) - else: - x = x_all + xvalData: list[np.ndarray] = [] + xvalTime: list[np.ndarray] = [] + distributions: list[str] = [] + + spike_train = trial.nspikeColl.getNST(neuron_number).nstCopy() + if not spike_train.name: + spike_train.setName(str(neuron_number)) + + for cfg_index in range(1, config_collection.numConfigs + 1): + trial.restoreToOriginal() + if original_partition.size: + trial.setTrialPartition(original_partition) + trial.setTrialTimesFor("training") + + config_collection.setConfig(trial, cfg_index) + current_labels = trial.getLabelsFromMask(neuron_number) + X = trial.getDesignMatrix(neuron_number) + time = trial.covarColl.getCov(1).time + dt = float(np.median(np.diff(time))) if time.shape[0] > 1 else max(1.0 / trial.sampleRate, 1e-12) + edges = np.concatenate([time, [time[-1] + dt]]) + y = trial.nspikeColl.getNST(neuron_number).to_binned_counts(edges) + offset = np.full(y.shape[0], np.log(max(dt, 1e-12)), dtype=float) + + glm_res = fit_poisson_glm(X, y, offset=offset, l2=l2, max_iter=max_iter) + n_params = X.shape[1] + 1 + aic = float(2.0 * n_params - 2.0 * glm_res.log_likelihood) + bic = float(np.log(max(y.shape[0], 1)) * n_params - 2.0 * glm_res.log_likelihood) + fit_name = config_collection.getConfigNames([cfg_index])[0] + coeff = np.concatenate([[glm_res.intercept], np.asarray(glm_res.coefficients, dtype=float).reshape(-1)]) + + rate = glm_res.predict_rate(X, offset=offset) + lambda_signal = Covariate( + time, + rate, + fit_name if fit_name else f"lambda_{cfg_index}", + "time", + "s", + "spikes/sec", + [fit_name if fit_name else f"lambda_{cfg_index}"], + ) - glm_res = fit_poisson_glm(x, y, offset=offset, l2=l2, max_iter=max_iter) - n_params = x.shape[1] + 1 - aic = 2.0 * n_params - 2.0 * glm_res.log_likelihood - bic = np.log(max(y.shape[0], 1)) * n_params - 2.0 * glm_res.log_likelihood - fit_name = cfg.name if cfg.name else f"Fit {idx}" + labels.append(list(current_labels)) + lambda_parts.append(lambda_signal) + b.append(coeff) + dev.append(float(-2.0 * glm_res.log_likelihood)) + stats.append( + { + "intercept": float(glm_res.intercept), + "n_iter": int(glm_res.n_iter), + "converged": bool(glm_res.converged), + } + ) + AIC.append(aic) + BIC.append(bic) + logLL.append(float(glm_res.log_likelihood)) + numHist.append(len(trial.getHistLabels())) + histObjects.append(trial.history) + ensHistObjects.append(trial.ensCovHist) fits.append( _SingleFit( name=fit_name, coefficients=np.asarray(glm_res.coefficients, dtype=float), intercept=float(glm_res.intercept), log_likelihood=float(glm_res.log_likelihood), - aic=float(aic), - bic=float(bic), + aic=aic, + bic=bic, + stats=stats[-1], ) ) - - if x_all.shape[1] == 0: - x_for_rate = np.zeros((y.shape[0], 0), dtype=float) - else: - x_for_rate = x_all - rate = fit_poisson_glm(x_for_rate, y, offset=offset, l2=l2, max_iter=max_iter).predict_rate(x_for_rate, offset=offset) - lambda_signal = Covariate(time, rate, "lambda", "time", "s", "spikes/sec", ["lambda"]) - return FitResult(spike_train, lambda_signal, fits) + distributions.append("poisson") + + partition = trial.getTrialPartition() + if partition.size >= 4 and partition[2] < partition[3]: + trial.setTrialTimesFor("validation") + xvalData.append(trial.getDesignMatrix(neuron_number)) + xvalTime.append(trial.covarColl.getCov(1).time.copy()) + trial.setTrialTimesFor("training") + else: + xvalData.append(np.zeros((0, X.shape[1]), dtype=float)) + xvalTime.append(np.array([], dtype=float)) + + merged_lambda = lambda_parts[0] + for part in lambda_parts[1:]: + merged_lambda = merged_lambda.merge(part) + + trial.restoreToOriginal() + if original_partition.size: + trial.setTrialPartition(original_partition) + trial.setTrialTimesFor("training") + + return FitResult( + spike_train, + labels, + numHist, + histObjects, + ensHistObjects, + merged_lambda, + b, + dev, + stats, + AIC, + BIC, + logLL, + config_collection, + xvalData, + xvalTime, + distributions, + fits=fits, + ) @staticmethod def run_analysis_for_all_neurons( @@ -107,9 +198,8 @@ def run_analysis_for_all_neurons( ) return out - # MATLAB-compatible method names. @staticmethod - def RunAnalysisForNeuron(tObj: Trial, neuronNumber: int, configColl: ConfigCollection): + def RunAnalysisForNeuron(tObj: Trial, neuronNumber: int, configColl: ConfigCollection, *_): return Analysis.run_analysis_for_neuron(tObj, neuronNumber - 1, configColl) @staticmethod diff --git a/nstat/cif.py b/nstat/cif.py index 8b64b0ab..8862613b 100644 --- a/nstat/cif.py +++ b/nstat/cif.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Sequence import numpy as np @@ -55,7 +56,53 @@ def from_linear_terms( class CIF: - """MATLAB-compatible CIF static API wrapper.""" + """MATLAB-facing CIF object plus static convenience APIs.""" + + def __init__( + self, + beta: Sequence[float] | np.ndarray | None = None, + Xnames: Sequence[str] | None = None, + stimNames: Sequence[str] | None = None, + fitType: str = "poisson", + histCoeffs: Sequence[float] | np.ndarray | None = None, + historyObj=None, + nst=None, + ) -> None: + self.b = np.asarray(beta if beta is not None else [], dtype=float).reshape(-1) + self.varIn = list(Xnames or []) + self.stimVars = list(stimNames or []) + self.fitType = str(fitType) + self.histCoeffs = np.asarray(histCoeffs if histCoeffs is not None else [], dtype=float).reshape(-1) + self.history = historyObj + self.spikeTrain = None if nst is None else getattr(nst, "nstCopy", lambda: nst)() + + def evaluate(self, design_matrix: np.ndarray, *, delta: float = 1.0, history_matrix: np.ndarray | None = None) -> np.ndarray: + x = np.asarray(design_matrix, dtype=float) + if x.ndim == 1: + x = x[:, None] + beta = self.b + if x.shape[1] != beta.size: + raise ValueError("design_matrix column count must match number of CIF coefficients") + eta = x @ beta + if history_matrix is not None and self.histCoeffs.size: + hist = np.asarray(history_matrix, dtype=float) + if hist.ndim == 1: + hist = hist[:, None] + if hist.shape[1] != self.histCoeffs.size: + raise ValueError("history_matrix column count must match histCoeffs length") + eta = eta + hist @ self.histCoeffs + if self.fitType == "poisson": + lambda_delta = np.exp(np.clip(eta, -20.0, 20.0)) + elif self.fitType == "binomial": + exp_eta = np.exp(np.clip(eta, -20.0, 20.0)) + lambda_delta = exp_eta / (1.0 + exp_eta) + else: + raise ValueError("fitType must be either 'poisson' or 'binomial'") + return lambda_delta / max(float(delta), 1e-12) + + def to_covariate(self, time: Sequence[float], design_matrix: np.ndarray, *, delta: float = 1.0, name: str = "lambda") -> Covariate: + rate = self.evaluate(design_matrix, delta=delta) + return Covariate(time, rate, name, "time", "s", "spikes/sec", [name]) @staticmethod def simulateCIFByThinningFromLambda(lambda_covariate: Covariate, numRealizations: int = 1) -> SpikeTrainCollection: diff --git a/nstat/events.py b/nstat/events.py index 969b99a0..9c306fd2 100644 --- a/nstat/events.py +++ b/nstat/events.py @@ -1,29 +1,42 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any +from typing import Any, Sequence import numpy as np -@dataclass class Events: - event_times: np.ndarray - labels: list[str] | None = None + """MATLAB-style event container.""" - def __init__(self, event_times, labels=None) -> None: - self.event_times = np.asarray(event_times, dtype=float).reshape(-1) - self.labels = None if labels is None else list(labels) + def __init__(self, eventTimes, eventLabels: Sequence[str] | None = None, eventColor: str = "r") -> None: + times = np.asarray(eventTimes, dtype=float).reshape(-1) + labels = [""] * int(times.size) if eventLabels is None else list(eventLabels) + if len(labels) != int(times.size): + raise ValueError("Number of eventTimes must match number of eventLabels") + + self.eventTimes = times + self.eventLabels = labels + self.eventColor = str(eventColor) + + # Legacy Python-side aliases kept for compatibility. + self.event_times = self.eventTimes + self.labels = self.eventLabels def toStructure(self) -> dict[str, Any]: return { - "event_times": self.event_times.tolist(), - "labels": list(self.labels) if self.labels is not None else None, + "eventTimes": self.eventTimes.tolist(), + "eventLabels": list(self.eventLabels), + "eventColor": self.eventColor, } @staticmethod - def fromStructure(structure: dict[str, Any]) -> "Events": - return Events(structure.get("event_times", []), labels=structure.get("labels")) + def fromStructure(structure: dict[str, Any] | None) -> "Events" | None: + if structure is None: + return None + event_times = structure.get("eventTimes", structure.get("event_times", [])) + event_labels = structure.get("eventLabels", structure.get("labels")) + event_color = structure.get("eventColor", "r") + return Events(event_times, event_labels, event_color) def plot(self, *_, **__) -> None: return None diff --git a/nstat/fit.py b/nstat/fit.py index ad6021b4..73e9bf6d 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -1,12 +1,39 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable +from typing import Any, Iterable, Sequence import numpy as np -from .core import nspikeTrain -from .signal import Covariate +from .core import Covariate, nspikeTrain + + +def _ordered_unique(labels: Sequence[str]) -> list[str]: + return list(dict.fromkeys(str(label) for label in labels)) + + +def _parse_neuron_number(spike_obj: nspikeTrain | Sequence[nspikeTrain]) -> str | float: + if isinstance(spike_obj, Sequence) and not isinstance(spike_obj, nspikeTrain): + names = [str(item.name) for item in spike_obj if getattr(item, "name", "")] + unique = _ordered_unique(names) + return unique[0] if unique else "" + name = str(getattr(spike_obj, "name", "")) + if not name: + return "" + try: + return float(name) + except ValueError: + return name + + +def _pad_rows(rows: Sequence[np.ndarray], fill_value: float = np.nan) -> np.ndarray: + if not rows: + return np.zeros((0, 0), dtype=float) + max_len = max(row.size for row in rows) + out = np.full((len(rows), max_len), fill_value, dtype=float) + for idx, row in enumerate(rows): + out[idx, : row.size] = row + return out @dataclass @@ -17,27 +44,185 @@ class _SingleFit: log_likelihood: float aic: float bic: float + stats: Any | None = None class FitResult: - """Simplified Python FitResult compatible with nSTAT workflows.""" + """MATLAB-facing fit result container with Python compatibility aliases.""" + + def __init__(self, neuralSpikeTrain: nspikeTrain | Sequence[nspikeTrain], *args, **kwargs) -> None: + if args and isinstance(args[0], Covariate): + self._init_simplified(neuralSpikeTrain, args[0], args[1] if len(args) > 1 else []) + return + + covLabels = args[0] if len(args) > 0 else kwargs.get("covLabels", []) + numHist = args[1] if len(args) > 1 else kwargs.get("numHist", []) + histObjects = args[2] if len(args) > 2 else kwargs.get("histObjects", []) + ensHistObj = args[3] if len(args) > 3 else kwargs.get("ensHistObj", []) + lambda_signal = args[4] if len(args) > 4 else kwargs.get("lambda_signal") + b = args[5] if len(args) > 5 else kwargs.get("b", []) + dev = args[6] if len(args) > 6 else kwargs.get("dev", []) + stats = args[7] if len(args) > 7 else kwargs.get("stats", []) + AIC = args[8] if len(args) > 8 else kwargs.get("AIC", []) + BIC = args[9] if len(args) > 9 else kwargs.get("BIC", []) + logLL = args[10] if len(args) > 10 else kwargs.get("logLL", []) + configColl = args[11] if len(args) > 11 else kwargs.get("configColl") + XvalData = args[12] if len(args) > 12 else kwargs.get("XvalData", []) + XvalTime = args[13] if len(args) > 13 else kwargs.get("XvalTime", []) + distribution = args[14] if len(args) > 14 else kwargs.get("distribution", "poisson") + fits = kwargs.get("fits") + self._init_matlab_style( + neuralSpikeTrain, + covLabels, + numHist, + histObjects, + ensHistObj, + lambda_signal, + b, + dev, + stats, + AIC, + BIC, + logLL, + configColl, + XvalData, + XvalTime, + distribution, + fits=fits, + ) + + def _init_common(self) -> None: + self.Z = np.array([], dtype=float) + self.U = np.array([], dtype=float) + self.X = np.array([], dtype=float) + self.Residual = None + self.KSStats = np.zeros((self.numResults, 1), dtype=float) + self.KSPvalues = np.full(self.numResults, np.nan, dtype=float) + self.withinConfInt = np.zeros(self.numResults, dtype=float) + self.invGausStats = {"rhoSig": [], "confBoundSig": []} + self.plotParams = { + "bAct": _pad_rows([np.asarray(coeffs, dtype=float).reshape(-1) for coeffs in self.b]).T if self.b else np.zeros((0, 0)), + "seAct": np.zeros((len(self.uniqueCovLabels), self.numResults), dtype=float), + "sigIndex": np.zeros((len(self.uniqueCovLabels), self.numResults), dtype=float), + "xLabels": list(self.uniqueCovLabels), + "numResultsCoeffPresent": np.sum(self.flatMask, axis=1) if self.flatMask.size else np.array([], dtype=int), + } + self.validation = None + + def _init_simplified(self, neuralSpikeTrain: nspikeTrain | Sequence[nspikeTrain], lambda_signal: Covariate, fits: Sequence[_SingleFit]) -> None: + from .trial import ConfigCollection, TrialConfig - def __init__( - self, - neuralSpikeTrain: nspikeTrain, - lambda_signal: Covariate, - fits: list[_SingleFit], - ) -> None: self.neuralSpikeTrain = neuralSpikeTrain + self.neuronNumber = _parse_neuron_number(neuralSpikeTrain) self.lambda_signal = lambda_signal - self.fits = fits - self.numResults = len(fits) - self.AIC = np.asarray([f.aic for f in fits], dtype=float) - self.BIC = np.asarray([f.bic for f in fits], dtype=float) - self.logLL = np.asarray([f.log_likelihood for f in fits], dtype=float) - self.KSStats = np.zeros((self.numResults, 1), dtype=float) - self.configNames = [f.name for f in fits] self.lambda_ = lambda_signal + self.numResults = len(list(fits)) + self.fits = list(fits) + self.b = [np.concatenate([[fit.intercept], np.asarray(fit.coefficients, dtype=float).reshape(-1)]) for fit in self.fits] + self.dev = np.zeros(self.numResults, dtype=float) + self.AIC = np.asarray([fit.aic for fit in self.fits], dtype=float) + self.BIC = np.asarray([fit.bic for fit in self.fits], dtype=float) + self.logLL = np.asarray([fit.log_likelihood for fit in self.fits], dtype=float) + self.stats = [fit.stats for fit in self.fits] + self.configNames = [fit.name for fit in self.fits] + self.configs = ConfigCollection([TrialConfig(name=name) for name in self.configNames]) + labels = list(lambda_signal.dataLabels) if getattr(lambda_signal, "dataLabels", None) else ["lambda"] + self.covLabels = [labels[:] for _ in range(self.numResults)] + self.uniqueCovLabels = _ordered_unique(labels) + self.indicesToUniqueLabels = [list(range(1, len(labels) + 1)) for _ in range(self.numResults)] + self.numHist = [0 for _ in range(self.numResults)] + self.histObjects = [None for _ in range(self.numResults)] + self.ensHistObjects = [None for _ in range(self.numResults)] + self.fitType = ["poisson" for _ in range(self.numResults)] + self.numCoeffs = np.asarray([coeff.shape[0] for coeff in self.b], dtype=int) + self.flatMask = np.ones((len(self.uniqueCovLabels), self.numResults), dtype=int) + self.XvalData = [] + self.XvalTime = [] + self.minTime = float(lambda_signal.minTime) + self.maxTime = float(lambda_signal.maxTime) + self._init_common() + + def _init_matlab_style( + self, + neuralSpikeTrain: nspikeTrain | Sequence[nspikeTrain], + covLabels, + numHist, + histObjects, + ensHistObj, + lambda_signal: Covariate | None, + b, + dev, + stats, + AIC, + BIC, + logLL, + configColl, + XvalData, + XvalTime, + distribution, + *, + fits: Sequence[_SingleFit] | None = None, + ) -> None: + self.neuralSpikeTrain = neuralSpikeTrain + self.neuronNumber = _parse_neuron_number(neuralSpikeTrain) + self.lambda_signal = lambda_signal if lambda_signal is not None else Covariate([], [], "lambda") + self.lambda_ = self.lambda_signal + self.covLabels = [list(labels) for labels in covLabels] + self.uniqueCovLabels = _ordered_unique([label for labels in self.covLabels for label in labels]) + self.indicesToUniqueLabels = [] + self.flatMask = np.zeros((len(self.uniqueCovLabels), max(len(self.covLabels), 1)), dtype=int) + for fit_idx, labels in enumerate(self.covLabels): + indices = [self.uniqueCovLabels.index(label) + 1 for label in labels] + self.indicesToUniqueLabels.append(indices) + if indices: + self.flatMask[np.asarray(indices, dtype=int) - 1, fit_idx] = 1 + + self.numHist = list(numHist) + self.histObjects = list(histObjects) + if ensHistObj is None or ensHistObj == []: + self.ensHistObjects = [None for _ in range(len(self.covLabels))] + elif isinstance(ensHistObj, Sequence) and not isinstance(ensHistObj, (str, bytes)): + self.ensHistObjects = list(ensHistObj) + else: + self.ensHistObjects = [ensHistObj for _ in range(len(self.covLabels))] + self.b = [np.asarray(coeff, dtype=float).reshape(-1) for coeff in b] + self.dev = np.asarray(dev, dtype=float).reshape(-1) + self.AIC = np.asarray(AIC, dtype=float).reshape(-1) + self.BIC = np.asarray(BIC, dtype=float).reshape(-1) + self.logLL = np.asarray(logLL, dtype=float).reshape(-1) + self.stats = list(stats) + self.configs = configColl + self.configNames = configColl.getConfigNames() if configColl is not None else [f"Fit {i}" for i in range(1, len(self.b) + 1)] + if isinstance(distribution, str): + self.fitType = [distribution for _ in range(len(self.b))] + else: + self.fitType = list(distribution) + self.numResults = len(self.b) + self.numCoeffs = np.asarray([coeff.shape[0] for coeff in self.b], dtype=int) + self.XvalData = list(XvalData) if isinstance(XvalData, Sequence) and not isinstance(XvalData, (str, bytes, np.ndarray)) else [] + self.XvalTime = list(XvalTime) if isinstance(XvalTime, Sequence) and not isinstance(XvalTime, (str, bytes, np.ndarray)) else [] + self.minTime = float(getattr(self.lambda_signal, "minTime", np.nan)) + self.maxTime = float(getattr(self.lambda_signal, "maxTime", np.nan)) + if fits is not None: + self.fits = list(fits) + else: + self.fits = [] + for idx in range(self.numResults): + coeff = self.b[idx] + intercept = float(coeff[0]) if coeff.size else 0.0 + beta = coeff[1:] if coeff.size > 1 else np.array([], dtype=float) + self.fits.append( + _SingleFit( + name=self.configNames[idx], + coefficients=beta, + intercept=intercept, + log_likelihood=float(self.logLL[idx]), + aic=float(self.AIC[idx]), + bic=float(self.BIC[idx]), + stats=self.stats[idx] if idx < len(self.stats) else None, + ) + ) + self._init_common() @property def lambdaSignal(self) -> Covariate: @@ -57,47 +242,59 @@ def lambdaObj(self) -> Covariate: @property def lambda_data(self) -> np.ndarray: - return self.lambda_signal.data + return np.asarray(self.lambda_signal.data, dtype=float) @property def lambda_values(self) -> np.ndarray: - return self.lambda_signal.data + return np.asarray(self.lambda_signal.data, dtype=float) @property def lambda_time(self) -> np.ndarray: - return self.lambda_signal.time + return np.asarray(self.lambda_signal.time, dtype=float) @property def lambda_rate(self) -> np.ndarray: - return self.lambda_signal.data - - @property - def lambda_model(self) -> Covariate: - return self.lambda_signal - - @property - def lambda_result(self) -> Covariate: - return self.lambda_signal - - @property - def lambda_(self) -> Covariate: - return self.lambda_signal - - @lambda_.setter - def lambda_(self, value: Covariate) -> None: - self.lambda_signal = value + return np.asarray(self.lambda_signal.data, dtype=float) def getCoeffs(self, fit_num: int = 1) -> np.ndarray: - return self.fits[fit_num - 1].coefficients.copy() + return self.b[fit_num - 1].copy() def getHistCoeffs(self, fit_num: int = 1) -> np.ndarray: - return np.array([], dtype=float) + num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0 + coeff = self.getCoeffs(fit_num) + if num_hist <= 0: + return np.array([], dtype=float) + return coeff[-num_hist:] def mergeResults(self, other: "FitResult") -> "FitResult": - merged_fits = [*self.fits, *other.fits] - merged_lambda = self.lambda_signal.merge(other.lambda_signal) - out = FitResult(self.neuralSpikeTrain, merged_lambda, merged_fits) - return out + from .trial import ConfigCollection + + if isinstance(self.lambda_signal, Covariate) and isinstance(other.lambda_signal, Covariate): + lambda_signal = self.lambda_signal.merge(other.lambda_signal) + else: + lambda_signal = self.lambda_signal + configs = ConfigCollection( + [*(self.configs.configArray if self.configs is not None else []), *(other.configs.configArray if other.configs is not None else [])] + ) + return FitResult( + self.neuralSpikeTrain, + [*self.covLabels, *other.covLabels], + [*self.numHist, *other.numHist], + [*self.histObjects, *other.histObjects], + [*self.ensHistObjects, *other.ensHistObjects], + lambda_signal, + [*self.b, *other.b], + np.concatenate([self.dev, other.dev]), + [*self.stats, *other.stats], + np.concatenate([self.AIC, other.AIC]), + np.concatenate([self.BIC, other.BIC]), + np.concatenate([self.logLL, other.logLL]), + configs, + [*self.XvalData, *other.XvalData], + [*self.XvalTime, *other.XvalTime], + [*self.fitType, *other.fitType], + fits=[*self.fits, *other.fits], + ) def plotResults(self, *_, **__) -> None: return None @@ -121,56 +318,91 @@ def plotCoeffs(self, *_, **__) -> None: def lambda_obj(self) -> Covariate: return self.lambda_signal + @property + def lambda_model(self) -> Covariate: + return self.lambda_signal + + @property + def lambda_result(self) -> Covariate: + return self.lambda_signal + def toStructure(self) -> dict[str, Any]: return { - "fits": [ - { - "name": f.name, - "coefficients": f.coefficients.tolist(), - "intercept": f.intercept, - "log_likelihood": f.log_likelihood, - "aic": f.aic, - "bic": f.bic, - } - for f in self.fits - ], + "covLabels": [list(labels) for labels in self.covLabels], + "numHist": list(self.numHist), "lambda_time": self.lambda_signal.time.tolist(), "lambda_data": self.lambda_signal.data.tolist(), - "neural_spike_times": self.neuralSpikeTrain.spikeTimes.tolist(), - "neural_name": self.neuralSpikeTrain.name, - "neural_min_time": self.neuralSpikeTrain.minTime, - "neural_max_time": self.neuralSpikeTrain.maxTime, + "lambda_name": self.lambda_signal.name, + "b": [coeff.tolist() for coeff in self.b], + "dev": self.dev.tolist(), + "AIC": self.AIC.tolist(), + "BIC": self.BIC.tolist(), + "logLL": self.logLL.tolist(), + "configNames": list(self.configNames), + "fitType": list(self.fitType), + "neural_spike_times": ( + self.neuralSpikeTrain.spikeTimes.tolist() + if isinstance(self.neuralSpikeTrain, nspikeTrain) + else [train.spikeTimes.tolist() for train in self.neuralSpikeTrain] + ), + "neural_name": ( + self.neuralSpikeTrain.name + if isinstance(self.neuralSpikeTrain, nspikeTrain) + else [train.name for train in self.neuralSpikeTrain] + ), + "neural_min_time": ( + self.neuralSpikeTrain.minTime + if isinstance(self.neuralSpikeTrain, nspikeTrain) + else [train.minTime for train in self.neuralSpikeTrain] + ), + "neural_max_time": ( + self.neuralSpikeTrain.maxTime + if isinstance(self.neuralSpikeTrain, nspikeTrain) + else [train.maxTime for train in self.neuralSpikeTrain] + ), } @staticmethod def fromStructure(structure: dict[str, Any]) -> "FitResult": - train = nspikeTrain( - structure["neural_spike_times"], - name=structure.get("neural_name", ""), - minTime=structure.get("neural_min_time"), - maxTime=structure.get("neural_max_time"), - ) + from .trial import ConfigCollection, TrialConfig + + spike_times = structure["neural_spike_times"] + neural_name = structure.get("neural_name", "") + neural_min_time = structure.get("neural_min_time", None) + neural_max_time = structure.get("neural_max_time", None) + if spike_times and isinstance(spike_times[0], list): + train: nspikeTrain | list[nspikeTrain] = [] + for st, name, min_t, max_t in zip(spike_times, neural_name, neural_min_time, neural_max_time): + train.append(nspikeTrain(st, name=name, minTime=min_t, maxTime=max_t, makePlots=-1)) + else: + train = nspikeTrain(spike_times, name=neural_name, minTime=neural_min_time, maxTime=neural_max_time, makePlots=-1) lam = Covariate( structure["lambda_time"], np.asarray(structure["lambda_data"], dtype=float), - "lambda", + structure.get("lambda_name", "lambda"), "time", "s", "spikes/sec", ) - fits = [] - for f in structure["fits"]: - fits.append( - _SingleFit( - name=f["name"], - coefficients=np.asarray(f["coefficients"], dtype=float), - intercept=float(f["intercept"]), - log_likelihood=float(f["log_likelihood"]), - aic=float(f["aic"]), - bic=float(f["bic"]), - ) - ) - return FitResult(train, lam, fits) + configColl = ConfigCollection([TrialConfig(name=name) for name in structure.get("configNames", [])]) + return FitResult( + train, + structure.get("covLabels", []), + structure.get("numHist", []), + [], + [], + lam, + [np.asarray(coeff, dtype=float) for coeff in structure.get("b", [])], + structure.get("dev", []), + [None for _ in structure.get("b", [])], + structure.get("AIC", []), + structure.get("BIC", []), + structure.get("logLL", []), + configColl, + [], + [], + structure.get("fitType", "poisson"), + ) class FitSummary: @@ -178,19 +410,26 @@ class FitSummary: def __init__(self, fit_results: FitResult | Iterable[FitResult]) -> None: if isinstance(fit_results, FitResult): - self.fit_results = [fit_results] + self.fitResCell = [fit_results] else: - self.fit_results = list(fit_results) - if not self.fit_results: + self.fitResCell = list(fit_results) + if not self.fitResCell: raise ValueError("FitSummary requires at least one FitResult") - aic = np.vstack([fr.AIC for fr in self.fit_results]) - bic = np.vstack([fr.BIC for fr in self.fit_results]) - ks = np.vstack([fr.KSStats.reshape(1, -1) for fr in self.fit_results]) - self.AIC = np.mean(aic, axis=0) - self.BIC = np.mean(bic, axis=0) - self.KSStats = np.column_stack([np.mean(ks, axis=0), np.std(ks, axis=0)]) - self.numNeurons = len(self.fit_results) + self.numNeurons = len(self.fitResCell) + self.numResults = max(fr.numResults for fr in self.fitResCell) + self.fitNames = self.fitResCell[max(range(self.numNeurons), key=lambda idx: self.fitResCell[idx].numResults)].configNames + self.neuronNumbers = [fr.neuronNumber for fr in self.fitResCell] + + aic = _pad_rows([np.asarray(fr.AIC, dtype=float).reshape(-1) for fr in self.fitResCell]) + bic = _pad_rows([np.asarray(fr.BIC, dtype=float).reshape(-1) for fr in self.fitResCell]) + logll = _pad_rows([np.asarray(fr.logLL, dtype=float).reshape(-1) for fr in self.fitResCell]) + ks = _pad_rows([np.asarray(fr.KSStats, dtype=float).reshape(-1) for fr in self.fitResCell], fill_value=np.nan) + + self.AIC = np.nanmean(aic, axis=0) + self.BIC = np.nanmean(bic, axis=0) + self.logLL = np.nanmean(logll, axis=0) + self.KSStats = np.column_stack([np.nanmean(ks, axis=0), np.nanstd(ks, axis=0)]) def getDiffAIC(self, idx: int = 1) -> np.ndarray: base = self.AIC[idx - 1] @@ -200,6 +439,10 @@ def getDiffBIC(self, idx: int = 1) -> np.ndarray: base = self.BIC[idx - 1] return self.BIC - base + def getDifflogLL(self, idx: int = 1) -> np.ndarray: + base = self.logLL[idx - 1] + return self.logLL - base + def plotSummary(self, *_, **__) -> None: return None diff --git a/nstat/history.py b/nstat/history.py index efdb575b..87fde427 100644 --- a/nstat/history.py +++ b/nstat/history.py @@ -1,49 +1,108 @@ from __future__ import annotations -from dataclasses import dataclass +from collections.abc import Sequence +from typing import Any import numpy as np -from .signal import Covariate +from .core import Covariate, nspikeTrain -@dataclass -class HistoryBasis: - """Spike-history basis using lagged spike-count regressors.""" +class History: + """MATLAB-style spike-history basis described by window boundaries.""" - lags: np.ndarray - name: str = "History" + def __init__(self, windowTimes, minTime: float | None = None, maxTime: float | None = None, name: str = "History") -> None: + times = np.asarray(windowTimes, dtype=float).reshape(-1) + if times.size <= 1: + raise ValueError("At least two times points must be specified to determine a window") + if np.any(np.diff(times) <= 0): + raise ValueError("windowTimes must be strictly increasing") - def __init__(self, lags, name: str = "History") -> None: - arr = np.asarray(lags, dtype=int).reshape(-1) - if arr.size == 0: - raise ValueError("lags must be non-empty") - if np.any(arr <= 0): - raise ValueError("lags must be strictly positive") - self.lags = np.unique(arr) - self.name = name + self.windowTimes = times + self.minTime = float(times[0] if minTime is None else minTime) + self.maxTime = float(times[-1] if maxTime is None else maxTime) + self.name = str(name) - def design_matrix(self, spike_indicator: np.ndarray) -> np.ndarray: - y = np.asarray(spike_indicator, dtype=float).reshape(-1) - x = np.zeros((y.shape[0], self.lags.shape[0]), dtype=float) - for j, lag in enumerate(self.lags): - x[lag:, j] = y[:-lag] - return x + @property + def lags(self) -> np.ndarray: + return np.asarray(self.windowTimes[1:], dtype=float).copy() - def compute_history(self, spike_indicator: np.ndarray, time: np.ndarray) -> Covariate: - x = self.design_matrix(spike_indicator) - labels = [f"hist_lag_{int(l)}" for l in self.lags] - return Covariate(time, x, self.name, "time", "s", "count", labels) + @property + def numWindows(self) -> int: + return int(self.windowTimes.size - 1) - # MATLAB-compatible method names. - def computeHistory(self, nst) -> Covariate: - y = np.asarray(getattr(nst, "getSigRep")().data[:, 0], dtype=float) - t = np.asarray(getattr(nst, "getSigRep")().time, dtype=float) - return self.compute_history(y, t) + def setWindow(self, windowTimes) -> None: + replacement = History(windowTimes, self.minTime, self.maxTime, self.name) + self.windowTimes = replacement.windowTimes + self.minTime = replacement.minTime + self.maxTime = replacement.maxTime + def _compute_single_history(self, train: nspikeTrain, historyIndex: int | None = None) -> Covariate: + sigrep = train.getSigRep() + time = np.asarray(sigrep.time, dtype=float).reshape(-1) + spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1) + history = np.zeros((time.size, self.numWindows), dtype=float) -# Backward-compatible alias. -History = HistoryBasis + for col, (window_start, window_stop) in enumerate(zip(self.windowTimes[:-1], self.windowTimes[1:])): + for row, tval in enumerate(time): + left = float(tval - window_stop) + right = float(tval - window_start) + history[row, col] = float(np.sum((spikes >= left) & (spikes < right))) + label_prefix = train.name or f"neuron_{historyIndex or 1}" + labels = [ + f"{label_prefix}_hist_{col + 1}" + for col in range(self.numWindows) + ] + return Covariate(time, history, self.name, "time", "s", "count", labels) -__all__ = ["HistoryBasis", "History"] + def compute_history(self, trains, historyIndex: int | None = None): + from .trial import CovariateCollection + + if isinstance(trains, nspikeTrain): + return CovariateCollection([self._compute_single_history(trains, historyIndex)]) + if hasattr(trains, "getNST") and hasattr(trains, "numSpikeTrains"): + covariates = [self._compute_single_history(trains.getNST(index), index) for index in range(1, int(trains.numSpikeTrains) + 1)] + return CovariateCollection(covariates) + if isinstance(trains, Sequence) and not isinstance(trains, (str, bytes, np.ndarray)): + covariates = [self._compute_single_history(train, index) for index, train in enumerate(trains, start=1)] + return CovariateCollection(covariates) + raise TypeError("History can only be computed from nspikeTrain, nstColl, or sequences of nspikeTrain") + + def computeHistory(self, trains, historyIndex: int | None = None): + return self.compute_history(trains, historyIndex) + + def toStructure(self) -> dict[str, Any]: + return { + "windowTimes": self.windowTimes.tolist(), + "minTime": self.minTime, + "maxTime": self.maxTime, + "name": self.name, + } + + @staticmethod + def fromStructure(structure: dict[str, Any] | None) -> "History" | None: + if structure is None: + return None + if "windowTimes" in structure: + windowTimes = structure["windowTimes"] + elif "lags" in structure: + lags = np.asarray(structure["lags"], dtype=float).reshape(-1) + windowTimes = np.concatenate([[0.0], lags]) + else: + windowTimes = [0.0, 1.0] + return History( + windowTimes, + minTime=structure.get("minTime"), + maxTime=structure.get("maxTime"), + name=structure.get("name", "History"), + ) + + def plot(self, *_, **__) -> None: + return None + + +HistoryBasis = History + + +__all__ = ["History", "HistoryBasis"] diff --git a/nstat/trial.py b/nstat/trial.py index 8b336367..5ec62116 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -1,124 +1,712 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Iterable, Sequence +from collections.abc import Sequence +from typing import Any import numpy as np -from .core import nspikeTrain +from .core import Covariate, nspikeTrain from .events import Events -from .signal import Covariate + + +def _is_string_sequence(values: object) -> bool: + if isinstance(values, (str, bytes)): + return False + if not isinstance(values, Sequence): + return False + return all(isinstance(item, str) for item in values) + + +def _copy_covariate(cov: Covariate) -> Covariate: + copied = cov.copySignal() + if not isinstance(copied, Covariate): + copied = Covariate( + copied.time, + copied.data, + copied.name, + copied.xlabelval, + copied.xunits, + copied.yunits, + copied.dataLabels, + copied.plotProps, + ) + return copied class CovariateCollection: - def __init__(self, covariates: Sequence[Covariate] | None = None) -> None: - self.covariates = list(covariates or []) + """MATLAB-style CovColl implementation with collection-level masks and timing.""" + + def __init__(self, covariates: Sequence[Covariate] | Covariate | None = None, *more_covariates: Covariate) -> None: + self.covArray: list[Covariate] = [] + self.covDimensions: list[int] = [] + self.numCov = 0 + self.minTime = float("inf") + self.maxTime = float("-inf") + self.covMask: list[np.ndarray] = [] + self.covShift = 0.0 + self.sampleRate = float("nan") + self.originalSampleRate: float | None = None + self.originalMinTime: float | None = None + self.originalMaxTime: float | None = None + if covariates is not None: + self.addToColl(covariates) + for cov in more_covariates: + self.addToColl(cov) + + @property + def covariates(self) -> list[Covariate]: + return [self.getCov(i) for i in range(1, self.numCov + 1)] @property def names(self) -> list[str]: - return [cov.name for cov in self.covariates] + return [cov.name for cov in self.covArray] + + def _capture_originals_if_needed(self) -> None: + if self.numCov == 0: + return + if self.originalSampleRate is None: + self.originalSampleRate = float(self.sampleRate) + if self.originalMinTime is None: + self.originalMinTime = float(self.minTime) + if self.originalMaxTime is None: + self.originalMaxTime = float(self.maxTime) + + def _refresh_summary(self) -> None: + self.numCov = len(self.covArray) + self.covDimensions = [cov.dimension for cov in self.covArray] + if self.numCov == 0: + self.minTime = float("inf") + self.maxTime = float("-inf") + self.sampleRate = float("nan") + self.covMask = [] + return + + if len(self.covMask) != self.numCov: + self.covMask = [np.ones(cov.dimension, dtype=int) for cov in self.covArray] + else: + normalized_mask: list[np.ndarray] = [] + for cov, mask in zip(self.covArray, self.covMask): + arr = np.asarray(mask, dtype=int).reshape(-1) + if arr.size != cov.dimension: + arr = np.ones(cov.dimension, dtype=int) + normalized_mask.append(arr) + self.covMask = normalized_mask + + if not np.isfinite(self.sampleRate): + self.sampleRate = self.findMaxSampleRate() + self.minTime = self.findMinTime() + float(self.covShift) + self.maxTime = self.findMaxTime() + float(self.covShift) + self._capture_originals_if_needed() + + def _covariate_from_identifier(self, identifier: int | str) -> int: + if isinstance(identifier, str): + return self.getCovIndFromName(identifier) + index = int(identifier) + if index < 1 or index > self.numCov: + raise IndexError("Covariate index out of bounds (1-based indexing).") + return index + + def _apply_collection_state(self, cov: Covariate, index: int) -> Covariate: + out = _copy_covariate(cov) + if self.covShift != 0: + out.time = out.time + float(self.covShift) + out.minTime = float(np.min(out.time)) + out.maxTime = float(np.max(out.time)) + if np.isfinite(self.sampleRate) and self.sampleRate > 0 and round(out.sampleRate, 3) != round(self.sampleRate, 3): + out = out.resample(self.sampleRate) + if np.isfinite(self.minTime) and np.isfinite(self.maxTime) and out.time.size > 0: + out = out.getSigInTimeWindow(self.minTime, self.maxTime, holdVals=1) + out.setMask(self.covMask[index - 1]) + return out def add(self, covariate: Covariate) -> None: - self.covariates.append(covariate) + self.addToColl(covariate) def addCovariate(self, covariate: Covariate) -> None: - self.add(covariate) + self.addToColl(covariate) + + def addToColl(self, covariates: Sequence[Covariate] | Covariate | "CovariateCollection" | None) -> None: + if covariates is None: + return + if isinstance(covariates, CovariateCollection): + for cov in covariates.covArray: + self.addToColl(cov) + return + if isinstance(covariates, Covariate): + self.covArray.append(_copy_covariate(covariates)) + self.covMask.append(np.ones(covariates.dimension, dtype=int)) + self._refresh_summary() + return + if isinstance(covariates, Sequence) and not isinstance(covariates, (str, bytes, np.ndarray)): + for cov in covariates: + self.addToColl(cov) + return + raise TypeError("CovColl can only add Covariate instances or sequences of Covariates.") - def addToColl(self, covariate: Covariate) -> None: - self.add(covariate) + def removeCovariate(self, identifier: int | str) -> None: + index = self._covariate_from_identifier(identifier) + del self.covArray[index - 1] + del self.covMask[index - 1] + self._refresh_summary() def get(self, name: str) -> Covariate: - for cov in self.covariates: + return self.getCov(name) + + def getCov(self, identifier: int | str | Sequence[int] | Sequence[str]): + if isinstance(identifier, str): + return self._apply_collection_state(self.covArray[self.getCovIndFromName(identifier) - 1], self.getCovIndFromName(identifier)) + if isinstance(identifier, Sequence) and not isinstance(identifier, (str, bytes, np.ndarray)): + if _is_string_sequence(identifier): + return [self.getCov(item) for item in identifier] + return [self.getCov(int(item)) for item in identifier] + if isinstance(identifier, np.ndarray) and identifier.ndim > 0: + return [self.getCov(int(item)) for item in identifier.reshape(-1)] + index = self._covariate_from_identifier(identifier) + return self._apply_collection_state(self.covArray[index - 1], index) + + def getCovIndFromName(self, name: str) -> int: + for idx, cov in enumerate(self.covArray, start=1): if cov.name == name: - return cov + return idx raise KeyError(f"Covariate '{name}' not found") - def getCov(self, name: str) -> Covariate: - return self.get(name) + def getCovIndicesFromNames(self, name: Sequence[str] | str): + if isinstance(name, str): + return self.getCovIndFromName(name) + return [self.getCovIndFromName(item) for item in name] + + def findMinTime(self) -> float: + if self.numCov == 0: + return float("inf") + return float(min(cov.minTime for cov in self.covArray)) + + def findMaxTime(self) -> float: + if self.numCov == 0: + return float("-inf") + return float(max(cov.maxTime for cov in self.covArray)) + + def findMaxSampleRate(self) -> float: + if self.numCov == 0: + return float("nan") + return float(max(cov.sampleRate for cov in self.covArray if np.isfinite(cov.sampleRate))) + + def setMinTime(self, minTime: float | None = None) -> None: + if minTime is None: + minTime = self.findMinTime() + float(self.covShift) + self.minTime = float(minTime) + + def setMaxTime(self, maxTime: float | None = None) -> None: + if maxTime is None: + maxTime = self.findMaxTime() + float(self.covShift) + self.maxTime = float(maxTime) + + def restrictToTimeWindow(self, wMin: float, wMax: float) -> None: + self.setMinTime(wMin) + self.setMaxTime(wMax) + + def setSampleRate(self, sampleRate: float) -> None: + if self.originalSampleRate is None and np.isfinite(self.sampleRate): + self.originalSampleRate = float(self.sampleRate) + self.sampleRate = float(sampleRate) + self.enforceSampleRate() + + def resample(self, sampleRate: float) -> None: + self.setSampleRate(sampleRate) + + def enforceSampleRate(self) -> None: + if not np.isfinite(self.sampleRate) or self.sampleRate <= 0: + self.sampleRate = self.findMaxSampleRate() + + def resetMask(self) -> None: + self.covMask = [np.ones(cov.dimension, dtype=int) for cov in self.covArray] + + def getCovDataMask(self, identifier: int | str) -> np.ndarray: + index = self._covariate_from_identifier(identifier) + return np.asarray(self.covMask[index - 1], dtype=int).copy() + + def isCovMaskSet(self) -> bool: + return any(np.any(mask == 0) for mask in self.covMask) + + def flattenCovMask(self) -> np.ndarray: + if not self.covMask: + return np.array([], dtype=int) + return np.concatenate([np.asarray(mask, dtype=int).reshape(-1) for mask in self.covMask]) + + def getSelectorFromMasks(self, covMask: list[np.ndarray] | None = None) -> list[list[int]]: + current = self.covMask if covMask is None else covMask + selector: list[list[int]] = [] + for mask in current: + active = np.flatnonzero(np.asarray(mask, dtype=int) == 1) + 1 + selector.append(active.astype(int).tolist()) + return selector + + def _selector_cell_from_names(self, dataSelector: Sequence[Any]) -> list[list[int]]: + selectorCell = [[] for _ in range(self.numCov)] + if not dataSelector: + return selectorCell + if isinstance(dataSelector[0], str): + covName = str(dataSelector[0]) + covIndex = self.getCovIndFromName(covName) + currCov = self.getCov(covIndex) + if len(dataSelector) == 1: + selectorCell[covIndex - 1] = list(range(1, currCov.dimension + 1)) + else: + selectorCell[covIndex - 1] = currCov.getIndicesFromLabels([str(v) for v in dataSelector[1:]]) + return selectorCell + + for item in dataSelector: + if not isinstance(item, Sequence) or isinstance(item, (str, bytes)): + raise ValueError("dataSelector specified incorrectly") + parsed = list(item) + if not parsed: + continue + covName = str(parsed[0]) + covIndex = self.getCovIndFromName(covName) + currCov = self.getCov(covIndex) + if len(parsed) == 1: + selectorCell[covIndex - 1] = list(range(1, currCov.dimension + 1)) + else: + selectorCell[covIndex - 1] = currCov.getIndicesFromLabels([str(v) for v in parsed[1:]]) + return selectorCell + + def generateSelectorCell(self, dataSelector) -> list[list[int]]: + if dataSelector is None: + return [[] for _ in range(self.numCov)] + if isinstance(dataSelector, str): + return self._selector_cell_from_names([dataSelector]) + if isinstance(dataSelector, np.ndarray): + dataSelector = dataSelector.tolist() + if not isinstance(dataSelector, Sequence) or isinstance(dataSelector, (str, bytes)): + raise ValueError("dataSelector specified incorrectly") + values = list(dataSelector) + if not values: + return [[] for _ in range(self.numCov)] + looks_like_numeric_selector = self.numCov == len(values) and all( + isinstance(item, np.ndarray) + or ( + isinstance(item, Sequence) + and not isinstance(item, (str, bytes)) + and all(not isinstance(v, str) for v in item) + ) + or isinstance(item, (int, np.integer, float, np.floating)) + for item in values + ) + if looks_like_numeric_selector: + selectorCell: list[list[int]] = [] + for item in values: + if isinstance(item, np.ndarray): + selectorCell.append(np.asarray(item, dtype=int).reshape(-1).tolist()) + elif isinstance(item, Sequence) and not isinstance(item, (str, bytes)): + selectorCell.append([int(v) for v in item]) + else: + selectorCell.append([int(item)]) + return selectorCell + return self._selector_cell_from_names(values) + + def _selector_to_cov_mask(self, selectorCell: list[list[int]]) -> list[np.ndarray]: + if len(selectorCell) != self.numCov: + raise ValueError("selectorCell size must match number of covariates.") + masks: list[np.ndarray] = [] + for cov, selector in zip(self.covArray, selectorCell): + mask = np.zeros(cov.dimension, dtype=int) + if selector: + arr = np.asarray(selector, dtype=int).reshape(-1) + if arr.size == cov.dimension and np.all(np.isin(arr, [0, 1])): + mask = arr.astype(int) + else: + if np.any(arr < 1) or np.any(arr > cov.dimension): + raise IndexError("Covariate selector index out of bounds.") + mask[arr - 1] = 1 + masks.append(mask) + return masks + + def setMasksFromSelector(self, selectorCell: list[list[int]]) -> None: + self.covMask = self._selector_to_cov_mask(selectorCell) + + def setMask(self, cellInput) -> None: + if isinstance(cellInput, str) and cellInput == "all": + self.resetMask() + return + selectorCell = self.generateSelectorCell(cellInput) + self.setMasksFromSelector(selectorCell) + + def nActCovar(self) -> int: + return int(sum(1 for selector in self.getSelectorFromMasks() if selector)) + + def maskAwayCov(self, identifier: int | str | Sequence[int] | Sequence[str]) -> None: + identifiers = identifier + if isinstance(identifier, (int, str)): + identifiers = [identifier] + for item in identifiers: + index = self._covariate_from_identifier(item) + self.covMask[index - 1] = np.zeros(self.covArray[index - 1].dimension, dtype=int) + + def maskAwayOnlyCov(self, identifier: int | str | Sequence[int] | Sequence[str]) -> None: + self.resetMask() + self.maskAwayCov(identifier) + + def maskAwayAllExcept(self, identifier: int | str | Sequence[int] | Sequence[str]) -> None: + if isinstance(identifier, (int, str)): + keep = {self._covariate_from_identifier(identifier)} + else: + keep = {self._covariate_from_identifier(item) for item in identifier} + for idx, cov in enumerate(self.covArray, start=1): + if idx not in keep: + self.covMask[idx - 1] = np.zeros(cov.dimension, dtype=int) + + def setCovShift(self, deltaT: float, identifier=None) -> "CovariateCollection": + self.covShift = float(deltaT) + if np.isfinite(self.minTime): + self.minTime = float(self.minTime + self.covShift) + if np.isfinite(self.maxTime): + self.maxTime = float(self.maxTime + self.covShift) + return self + + def resetCovShift(self) -> None: + self.covShift = 0.0 + self.setMinTime() + self.setMaxTime() - def dataToMatrix(self, names: Sequence[str] | None = None) -> tuple[np.ndarray, np.ndarray, list[str]]: - if not self.covariates: + def restoreToOriginal(self) -> None: + self.covShift = 0.0 + if self.originalSampleRate is not None: + self.sampleRate = float(self.originalSampleRate) + else: + self.sampleRate = self.findMaxSampleRate() + self.setMinTime(self.findMinTime()) + self.setMaxTime(self.findMaxTime()) + self.resetMask() + + def getAllCovLabels(self) -> list[str]: + labels: list[str] = [] + for index in range(1, self.numCov + 1): + labels.extend(self.getCov(index).dataLabels) + return labels + + def getCovLabelsFromMask(self) -> list[str]: + labels: list[str] = [] + for index in range(1, self.numCov + 1): + cov = self.getCov(index) + mask = self.covMask[index - 1] + labels.extend([label for keep, label in zip(mask, cov.dataLabels) if keep == 1]) + return labels + + def matrixWithTime(self, repType: str = "standard", dataSelector=None) -> tuple[np.ndarray, np.ndarray, list[str]]: + if self.numCov == 0: raise ValueError("CovariateCollection is empty") - selected = self.covariates - if names is not None: - keep = set(names) - selected = [cov for cov in self.covariates if cov.name in keep] - if not selected: - raise ValueError("No covariates matched requested names") - - base_time = selected[0].time - x_parts: list[np.ndarray] = [] + if dataSelector is None: + selectorCell = self.getSelectorFromMasks() if self.isCovMaskSet() else [ + list(range(1, self.getCov(i).dimension + 1)) for i in range(1, self.numCov + 1) + ] + else: + selectorCell = self.generateSelectorCell(dataSelector) + + active_cov = [i + 1 for i, selector in enumerate(selectorCell) if selector] + if not active_cov: + time = self.getCov(1).time + return time.copy(), np.zeros((time.size, 0), dtype=float), [] + + time = self.getCov(active_cov[0]).getSigRep(repType).time + parts: list[np.ndarray] = [] labels: list[str] = [] - for cov in selected: - if cov.time.shape != base_time.shape or np.max(np.abs(cov.time - base_time)) > 1e-9: - raise ValueError("All covariates must share the same time grid") - x_parts.append(np.asarray(cov.data, dtype=float)) - labels.extend(cov.dataLabels) - return base_time, np.hstack(x_parts), labels + for covIndex in active_cov: + cov = self.getCov(covIndex).getSigRep(repType) + selector = selectorCell[covIndex - 1] + data = cov.dataToMatrix(selector) + endInd = min(time.size, data.shape[0]) + block = np.zeros((time.size, data.shape[1]), dtype=float) + block[:endInd, :] = data[:endInd, :] + parts.append(block) + labels.extend([cov.dataLabels[idx - 1] for idx in selector]) + return time.copy(), np.hstack(parts) if parts else np.zeros((time.size, 0), dtype=float), labels + + def dataToMatrix(self, repType: str | Sequence[str] | None = "standard", dataSelector=None, *_) -> np.ndarray: + if repType not in {"standard", "zero-mean"}: + dataSelector = repType + repType = "standard" + _, matrix, _ = self.matrixWithTime(str(repType), dataSelector) + return matrix class SpikeTrainCollection: - def __init__(self, trains: Sequence[nspikeTrain] | nspikeTrain) -> None: - if isinstance(trains, nspikeTrain): - trains = [trains] - self._trains = list(trains) - if len(self._trains) == 0: - raise ValueError("SpikeTrainCollection requires at least one spike train") - - self.minTime = float(min(s.minTime for s in self._trains)) - self.maxTime = float(max(s.maxTime for s in self._trains)) - rates = [s.sampleRate for s in self._trains if s.sampleRate > 0] - self.sampleRate = float(np.median(rates)) if rates else 1000.0 + """MATLAB-style nstColl implementation.""" + + def __init__(self, trains: Sequence[nspikeTrain] | nspikeTrain | None = None) -> None: + self.nstrain: list[nspikeTrain] = [] + self.numSpikeTrains = 0 + self.minTime = float("inf") + self.maxTime = float("-inf") + self.sampleRate = float("-inf") + self.neuronMask = np.array([], dtype=int) + self.neighbors: np.ndarray | list[list[int]] = [] + if trains is not None: + self.addToColl(trains) @property def num_spike_trains(self) -> int: - return len(self._trains) + return self.numSpikeTrains @property - def numSpikeTrains(self) -> int: - return self.num_spike_trains + def uniqueNeuronNames(self) -> list[str]: + return self.getUniqueNSTnames() def __iter__(self): - for tr in self._trains: + for tr in self.nstrain: yield tr + def _refresh_summary(self) -> None: + self.numSpikeTrains = len(self.nstrain) + if self.numSpikeTrains == 0: + self.minTime = float("inf") + self.maxTime = float("-inf") + self.sampleRate = float("-inf") + self.neuronMask = np.array([], dtype=int) + self.neighbors = [] + return + self.minTime = float(min(train.minTime for train in self.nstrain)) + self.maxTime = float(max(train.maxTime for train in self.nstrain)) + self.sampleRate = self.findMaxSampleRate() + if self.neuronMask.size != self.numSpikeTrains: + self.neuronMask = np.ones(self.numSpikeTrains, dtype=int) + + def addSingleSpikeToColl(self, nst: nspikeTrain) -> None: + self.nstrain.append(nst.nstCopy()) + self._refresh_summary() + + def addToColl(self, nst: Sequence[nspikeTrain] | nspikeTrain | "SpikeTrainCollection") -> None: + if isinstance(nst, SpikeTrainCollection): + for train in nst.nstrain: + self.addSingleSpikeToColl(train) + return + if isinstance(nst, nspikeTrain): + self.addSingleSpikeToColl(nst) + return + if isinstance(nst, Sequence) and not isinstance(nst, (str, bytes, np.ndarray)): + for item in nst: + if not isinstance(item, nspikeTrain): + raise TypeError("nstColl requires a sequence of nspikeTrain objects.") + self.addSingleSpikeToColl(item) + return + raise TypeError("nstColl can only add nspikeTrain instances or sequences of nspikeTrain.") + + def merge(self, nstColl2: "SpikeTrainCollection") -> "SpikeTrainCollection": + self.addToColl(nstColl2) + return self + def get_nst(self, idx: int) -> nspikeTrain: - if idx < 0 or idx >= len(self._trains): + if idx < 0 or idx >= self.numSpikeTrains: raise IndexError("SpikeTrainCollection index out of bounds (0-based indexing).") - return self._trains[idx] + return self.nstrain[idx] - def getNST(self, idx: int) -> nspikeTrain: - if idx < 1 or idx > len(self._trains): + def getNST(self, idx) -> nspikeTrain | list[nspikeTrain]: + if isinstance(idx, Sequence) and not isinstance(idx, (str, bytes, np.ndarray)): + return [self.getNST(int(item)) for item in idx] + index = int(idx) + if index < 1 or index > self.numSpikeTrains: raise IndexError("nstColl index out of bounds (1-based indexing).") - return self._trains[idx - 1] + return self.nstrain[index - 1] + + def getNSTnames(self) -> list[str]: + return [train.name for train in self.nstrain] + + def getUniqueNSTnames(self) -> list[str]: + names = [name for name in self.getNSTnames() if name] + return list(dict.fromkeys(names)) - def setMinTime(self, value: float) -> None: + def getNSTIndicesFromName(self, name: Sequence[str] | str): + if isinstance(name, str): + matches = [i + 1 for i, value in enumerate(self.getNSTnames()) if value == name] + if not matches: + raise KeyError(f"Neuron '{name}' not found") + return matches if len(matches) > 1 else matches[0] + return [self.getNSTIndicesFromName(item) for item in name] + + def setMinTime(self, value: float | None = None) -> None: + if value is None: + value = self.minTime + for train in self.nstrain: + train.setMinTime(float(value)) self.minTime = float(value) - def setMaxTime(self, value: float) -> None: + def setMaxTime(self, value: float | None = None) -> None: + if value is None: + value = self.maxTime + for train in self.nstrain: + train.setMaxTime(float(value)) self.maxTime = float(value) - def dataToMatrix(self, bin_edges: Sequence[float]) -> np.ndarray: - edges = np.asarray(bin_edges, dtype=float).reshape(-1) - if edges.ndim != 1 or edges.size < 2: - raise ValueError("bin_edges must be a 1D array with at least two points") - rows = [np.asarray(spk.to_binned_counts(edges), dtype=float) for spk in self._trains] - return np.vstack(rows) + def resample(self, sampleRate: float) -> None: + self.sampleRate = float(sampleRate) + for train in self.nstrain: + train.resample(sampleRate) + + def findMaxSampleRate(self) -> float: + if self.numSpikeTrains == 0: + return float("-inf") + return float(max(train.sampleRate for train in self.nstrain)) + + def setMask(self, mask: Sequence[int] | np.ndarray) -> None: + arr = np.asarray(mask, dtype=int).reshape(-1) + if arr.size == self.numSpikeTrains and np.all(np.isin(arr, [0, 1])): + self.setNeuronMask(arr) + return + self.setNeuronMaskFromInd(arr) + + def setNeuronMaskFromInd(self, mask: Sequence[int] | np.ndarray) -> None: + arr = np.asarray(mask, dtype=int).reshape(-1) + newMask = np.zeros(self.numSpikeTrains, dtype=int) + if arr.size: + if np.any(arr < 1) or np.any(arr > self.numSpikeTrains): + raise IndexError("Neuron index out of bounds.") + newMask[arr - 1] = 1 + self.setNeuronMask(newMask) + + def setNeuronMask(self, mask: Sequence[int] | np.ndarray) -> None: + arr = np.asarray(mask, dtype=int).reshape(-1) + if arr.size != self.numSpikeTrains: + raise ValueError("neuronMask length must match number of spike trains.") + self.neuronMask = arr.astype(int) + + def resetMask(self) -> None: + self.neuronMask = np.ones(self.numSpikeTrains, dtype=int) + + def getIndFromMask(self) -> list[int]: + return (np.flatnonzero(self.neuronMask == 1) + 1).astype(int).tolist() + + def getIndFromMaskMinusOne(self, neuron: int) -> list[int]: + return [idx for idx in self.getIndFromMask() if idx != int(neuron)] + + def isNeuronMaskSet(self) -> bool: + return bool(np.any(self.neuronMask == 0)) + + def setNeighbors(self, neighborArray: Sequence[Sequence[int]] | np.ndarray | None = None) -> None: + if neighborArray is None: + if self.numSpikeTrains == 0: + self.neighbors = [] + return + matrix = np.zeros((self.numSpikeTrains, max(self.numSpikeTrains - 1, 0)), dtype=int) + for i in range(self.numSpikeTrains): + neighbors = [idx for idx in range(1, self.numSpikeTrains + 1) if idx != (i + 1)] + if neighbors: + matrix[i, : len(neighbors)] = neighbors + self.neighbors = matrix + return + arr = np.asarray(neighborArray, dtype=int) + if arr.ndim != 2 or arr.shape[0] != self.numSpikeTrains: + raise ValueError("Neighbor Array is not of appropriate dimensions") + self.neighbors = arr + + def areNeighborsSet(self) -> bool: + return np.size(self.neighbors) > 0 + + def getNeighbors(self, neuronNum: int | Sequence[int]): + if isinstance(neuronNum, Sequence) and not isinstance(neuronNum, (str, bytes, np.ndarray)): + rows = [self.getNeighbors(int(item)) for item in neuronNum] + if rows and all(len(row) == len(rows[0]) for row in rows): + return np.asarray(rows, dtype=int) + return rows + neuron_idx = int(neuronNum) + if not self.areNeighborsSet(): + self.setNeighbors() + if isinstance(self.neighbors, list): + row = list(self.neighbors[neuron_idx - 1]) + else: + row = np.asarray(self.neighbors[neuron_idx - 1], dtype=int).reshape(-1).tolist() + available = set(self.getIndFromMaskMinusOne(neuron_idx)) + return [value for value in row if value in available and value > 0] + + def getMaxBinSizeBinary(self) -> float: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + if not selectorArray: + return np.inf + values = [self.getNST(index).getMaxBinSizeBinary() for index in selectorArray] + return float(np.min(values)) + + def dataToMatrix( + self, + selectorArray: Sequence[int] | Sequence[str] | str | None = None, + binwidth: float | None = None, + minTime: float | None = None, + maxTime: float | None = None, + ) -> np.ndarray: + if self.numSpikeTrains == 0: + return np.zeros((0, 0), dtype=float) + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if binwidth is None: + binwidth = 1.0 / self.sampleRate + if selectorArray is None: + selector = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + elif isinstance(selectorArray, str) or _is_string_sequence(selectorArray): + resolved = self.getNSTIndicesFromName(selectorArray) + if isinstance(resolved, list): + selector = [int(item) if not isinstance(item, list) else int(item[0]) for item in resolved] + else: + selector = [int(resolved)] + else: + selector = [int(item) for item in selectorArray] + if not selector: + testSig = self.getNST(1).getSigRep(binwidth, minTime, maxTime) + return np.zeros((testSig.dataToMatrix().shape[0], 0), dtype=float) + testSig = self.getNST(selector[0]).getSigRep(binwidth, minTime, maxTime) + dataMat = np.zeros((testSig.dataToMatrix().shape[0], len(selector)), dtype=float) + for idx, neuron in enumerate(selector): + sig = self.getNST(neuron).getSigRep(binwidth, minTime, maxTime) + dataMat[:, idx] = sig.dataToMatrix().reshape(-1) + return dataMat + + def getEnsembleNeuronCovariates(self, neuronNum: int = 1, neighborIndex=None, windowTimes=None): + if neighborIndex is None: + allNeighbors = self.getNeighbors(neuronNum) + else: + allNeighbors = [int(item) for item in np.asarray(neighborIndex, dtype=int).reshape(-1)] + if windowTimes is None: + windowTimes = [0.0, 0.001] + from .history import History + + histObj = windowTimes if isinstance(windowTimes, History) else History(windowTimes) + ensembleCovariates = histObj.computeHistory(self.getNST(list(range(1, self.numSpikeTrains + 1)))) + ensembleCovariates.maskAwayAllExcept(allNeighbors) + self.addNeuronNamesToEnsCovColl(ensembleCovariates) + return ensembleCovariates + + def addNeuronNamesToEnsCovColl(self, ensembleCovariates: CovariateCollection) -> None: + for i in range(1, ensembleCovariates.numCov + 1): + tempCov = ensembleCovariates.covArray[i - 1] + name = self.getNST(i).name + if not name: + name = str(i) + dataLabels = [f"{name}:{label}" if label else str(name) for label in tempCov.dataLabels] + tempCov.setDataLabels(dataLabels) + + def restoreToOriginal(self, rMask: int = 0) -> None: + for train in self.nstrain: + train.restoreToOriginal() + self._refresh_summary() + self.sampleRate = self.findMaxSampleRate() + self.resample(self.sampleRate) + if rMask == 1: + self.resetMask() def plot(self, *_, **__) -> None: return None - def psth(self, binwidth: float) -> Covariate: + def psth( + self, + binwidth: float = 0.100, + selectorArray: Sequence[int] | None = None, + minTime: float | None = None, + maxTime: float | None = None, + ) -> Covariate: if binwidth <= 0: raise ValueError("binwidth must be > 0") - min_time = float(self.minTime) - max_time = float(self.maxTime) - if max_time < min_time: - raise ValueError("maxTime must be >= minTime") - - # Match MATLAB nstColl.psth edge construction: - # windowTimes = minTime:binwidth:maxTime; - # if ~any(windowTimes==maxTime), append maxTime + min_time = self.minTime if minTime is None else float(minTime) + max_time = self.maxTime if maxTime is None else float(maxTime) + if selectorArray is None: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + span = max_time - min_time n_full = int(np.floor((span / binwidth) + 1e-12)) window_times = min_time + np.arange(n_full + 1, dtype=float) * float(binwidth) @@ -133,11 +721,9 @@ def psth(self, binwidth: float) -> Covariate: if window_times[1] <= window_times[0]: window_times[1] = window_times[0] + float(binwidth) - # MATLAB histc-like counting produces one extra terminal bin for x==max; - # nstColl.psth discards that final bin before normalizing. psth_hist = np.zeros(window_times.size, dtype=float) - for spk in self._trains: - spikes = np.asarray(spk.spikeTimes, dtype=float).reshape(-1) + for neuron in selectorArray: + spikes = np.asarray(self.getNST(int(neuron)).getSpikeTimes(), dtype=float).reshape(-1) if spikes.size == 0: continue valid = np.isfinite(spikes) & (spikes >= window_times[0]) & (spikes <= window_times[-1]) @@ -147,25 +733,34 @@ def psth(self, binwidth: float) -> Covariate: idx = np.clip(idx, 0, window_times.size - 1) psth_hist += np.bincount(idx, minlength=window_times.size).astype(float) - rate = psth_hist[:-1] / binwidth / float(len(self._trains)) - centers = (window_times[1:] + window_times[:-1]) * 0.5 - return Covariate(centers, rate, "PSTH", "time", "s", "spikes/sec", ["psth"]) + psth_data = psth_hist[:-1] / binwidth / float(len(selectorArray)) + time = (window_times[1:] + window_times[:-1]) * 0.5 + return Covariate(time, psth_data, "PSTH", "time", "s", "Hz", ["psth"]) def psthGLM(self, binwidth: float): psth_signal = self.psth(binwidth) return psth_signal, None, None -@dataclass class TrialConfig: - covMask: Sequence[Sequence[str]] | Sequence[str] - sampleRate: float - history: object | None = None - ensCovHist: object | None = None - covLag: object | None = None - name: str = "" + """MATLAB-style TrialConfig with configuration-application semantics.""" - def setName(self, name: str) -> None: + def __init__( + self, + covMask: Sequence[Sequence[str]] | Sequence[str] | None = None, + sampleRate: float | None = None, + history: object | None = None, + ensCovHist: object | None = None, + ensCovMask: object | None = None, + covLag: object | None = None, + name: str = "", + ) -> None: + self.covMask = [] if covMask is None else covMask + self.sampleRate = [] if sampleRate is None else sampleRate + self.history = [] if history is None else history + self.ensCovHist = [] if ensCovHist is None else ensCovHist + self.ensCovMask = [] if ensCovMask is None else ensCovMask + self.covLag = [] if covLag is None else covLag self.name = str(name) @property @@ -176,76 +771,618 @@ def covariate_names(self) -> list[str]: for item in self.covMask: if isinstance(item, str): names.append(item) - else: - names.extend([str(v) for v in item]) + elif isinstance(item, Sequence) and item: + names.append(str(item[0])) return names + def getName(self) -> str: + return self.name + + def setName(self, name: str) -> None: + self.name = str(name) + + def setConfig(self, trial: "Trial") -> None: + if self.history not in ([], None): + trial.setHistory(self.history) + else: + trial.resetHistory() + + if self.sampleRate not in ([], None): + sampleRate = float(self.sampleRate) + if round(trial.sampleRate, 3) != round(sampleRate, 3): + trial.resample(sampleRate) + + trial.setCovMask(self.covMask) + + if self.covLag not in ([], None): + trial.shiftCovariates(self.covLag) + + if self.ensCovHist not in ([], None): + trial.setEnsCovHist(self.ensCovHist) + trial.setEnsCovMask(self.ensCovMask) + else: + trial.setEnsCovHist() + trial.setEnsCovMask() + + def toStructure(self) -> dict[str, Any]: + return { + "covMask": self.covMask, + "sampleRate": self.sampleRate, + "history": self.history, + "ensCovHist": self.ensCovHist, + "ensCovMask": self.ensCovMask, + "covLag": self.covLag, + "name": self.name, + } + + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "TrialConfig": + return TrialConfig( + structure.get("covMask"), + structure.get("sampleRate"), + structure.get("history"), + structure.get("ensCovHist"), + structure.get("ensCovMask"), + structure.get("covLag"), + structure.get("name", ""), + ) + class ConfigCollection: - def __init__(self, configs: Sequence[TrialConfig] | None = None) -> None: - self.configs: list[TrialConfig] = list(configs or []) + """MATLAB-style ConfigColl implementation.""" - @property - def numConfigs(self) -> int: - return len(self.configs) + def __init__(self, configs: Sequence[TrialConfig] | TrialConfig | str | None = None) -> None: + self.numConfigs = 0 + self.configNames: list[str] = [] + self.configArray: list[TrialConfig | str | list[str]] = [] + if configs is not None: + self.addConfig(configs) @property - def configArray(self) -> list[TrialConfig]: - return self.configs + def configs(self) -> list[TrialConfig]: + return [cfg for cfg in self.configArray if isinstance(cfg, TrialConfig)] def add_config(self, cfg: TrialConfig) -> None: - self.configs.append(cfg) + self.addConfig(cfg) - def addConfig(self, cfg: TrialConfig) -> None: - self.add_config(cfg) + def addConfig(self, cfg: Sequence[TrialConfig] | TrialConfig | str | None) -> None: + if isinstance(cfg, Sequence) and not isinstance(cfg, (str, bytes, TrialConfig, np.ndarray)): + for item in cfg: + self.addConfig(item) + return + if cfg is None or cfg == []: + self.numConfigs += 1 + self.configNames.append("Empty Config") + self.configArray.append(["Empty Config"]) + return + if isinstance(cfg, TrialConfig): + self.numConfigs += 1 + self.configArray.append(cfg) + self.setConfigNames(cfg.name, [self.numConfigs]) + return + if isinstance(cfg, str): + self.numConfigs += 1 + self.configArray.append(cfg) + self.setConfigNames(cfg, [self.numConfigs]) + return + raise TypeError("ConfigColl can only add TrialConfig objects, strings, or sequences of them.") - def get_config(self, idx: int) -> TrialConfig: - if idx < 0 or idx >= len(self.configs): + def get_config(self, idx: int) -> TrialConfig | str | list[str]: + if idx < 0 or idx >= self.numConfigs: raise IndexError("ConfigCollection index out of bounds (0-based indexing).") - return self.configs[idx] + return self.configArray[idx] + + def getConfig(self, idx: int): + if idx < 1 or idx > self.numConfigs: + raise IndexError("Index Out of Bounds") + return self.configArray[idx - 1] - def getConfig(self, idx: int) -> TrialConfig: - return self.configs[idx - 1] + def setConfig(self, trial: "Trial", index: int) -> None: + config = self.getConfig(index) + if isinstance(config, TrialConfig): + config.setConfig(trial) + return + raise ValueError("Cannot Set Empty Configs") def getConfigNames(self, index: Sequence[int] | None = None) -> list[str]: if index is None: index = list(range(1, self.numConfigs + 1)) out: list[str] = [] for i in index: - cfg = self.configs[i - 1] - out.append(cfg.name if cfg.name else f"Fit {i}") + if i < 1 or i > self.numConfigs: + raise IndexError("Index Out of Bounds") + tempName = self.configNames[i - 1] + out.append(tempName if tempName else f"Fit {i}") return out + def setConfigNames(self, names, index: Sequence[int] | None = None) -> None: + if index is None: + index = list(range(1, self.numConfigs + 1)) + if isinstance(names, str): + if len(index) != 1: + raise ValueError("If specifying a single name, index must be length 1.") + target = int(index[0]) - 1 + while len(self.configNames) < self.numConfigs: + self.configNames.append("") + self.configNames[target] = names if names else f"Fit {target + 1}" + if isinstance(self.configArray[target], TrialConfig): + self.configArray[target].setName(self.configNames[target]) + return + if isinstance(names, Sequence) and not isinstance(names, (str, bytes)): + if len(index) != len(names): + raise ValueError("If specifying multiple names, names and index must match in length.") + for idx, name in zip(index, names): + self.setConfigNames(str(name), [int(idx)]) + return + raise TypeError("names must be a string or sequence of strings.") + + def getSubsetConfigs(self, subset: Sequence[int]) -> "ConfigCollection": + tempconfigs = [self.getConfig(int(i)) for i in subset] + return ConfigCollection(tempconfigs) + + def toStructure(self) -> dict[str, Any]: + structure = { + "numConfigs": self.numConfigs, + "configNames": list(self.configNames), + "configArray": [], + } + for cfg in self.configArray: + if isinstance(cfg, TrialConfig): + structure["configArray"].append(cfg.toStructure()) + else: + structure["configArray"].append(cfg) + return structure + + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "ConfigCollection": + configs = [] + for row in structure.get("configArray", []): + if isinstance(row, dict): + configs.append(TrialConfig.fromStructure(row)) + else: + configs.append(row) + coll = ConfigCollection(configs) + if "configNames" in structure: + coll.configNames = list(structure["configNames"]) + return coll + class Trial: + """MATLAB-style Trial object preserving collection-level workflow semantics.""" + def __init__( self, spike_collection: SpikeTrainCollection | None = None, covariate_collection: CovariateCollection | None = None, events: Events | None = None, + hist: object | None = None, + ensCovHist: object | None = None, + ensCovMask: object | None = None, *, spikeColl: SpikeTrainCollection | None = None, covarColl: CovariateCollection | None = None, + event: Events | None = None, ) -> None: - self.spike_collection = spike_collection if spike_collection is not None else spikeColl - self.covariate_collection = covariate_collection if covariate_collection is not None else covarColl - if self.spike_collection is None or self.covariate_collection is None: - raise ValueError("Trial requires both spike_collection and covariate_collection") - self.events = events + self.nspikeColl = spike_collection if spike_collection is not None else spikeColl + self.covarColl = covariate_collection if covariate_collection is not None else covarColl + if not isinstance(self.nspikeColl, SpikeTrainCollection): + raise ValueError("nstColl is a required argument") + if not isinstance(self.covarColl, CovariateCollection): + raise ValueError("CovColl is a required argument") + + self.ev: Events | None = None + self.history: object | None = [] + self.ensCovHist: object | None = [] + self.ensCovColl: CovariateCollection | None = None + self.sampleRate = float("nan") + self.minTime = float("nan") + self.maxTime = float("nan") + self.covMask = self.covarColl.covMask + self.ensCovMask = ensCovMask + self.neuronMask = np.asarray(self.nspikeColl.neuronMask, dtype=int).copy() + self.trainingWindow: list[float] | np.ndarray | None = None + self.validationWindow: list[float] | np.ndarray | None = None + + event_obj = events if events is not None else event + self.setTrialEvents(event_obj) + self.setHistory(hist) + self.setEnsCovHist(ensCovHist) + self.setEnsCovMask(ensCovMask) + + self.covMask = self.covarColl.covMask + self.neuronMask = np.asarray(self.nspikeColl.neuronMask, dtype=int).copy() + if not self.isSampleRateConsistent(): + self.makeConsistentSampleRate() + else: + self.sampleRate = float(self.covarColl.sampleRate) + self.makeConsistentTime() + self.setTrialPartition([]) + self.setTrialTimesFor("training") + + @property + def spike_collection(self) -> SpikeTrainCollection: + return self.nspikeColl + + @property + def covariate_collection(self) -> CovariateCollection: + return self.covarColl @property def spikeColl(self) -> SpikeTrainCollection: - return self.spike_collection + return self.nspikeColl + + def setTrialEvents(self, event: Events | None) -> None: + self.ev = event if isinstance(event, Events) else None + + def getEvents(self) -> Events | None: + return self.ev @property def covarColl(self) -> CovariateCollection: - return self.covariate_collection + return self._covarColl + + @covarColl.setter + def covarColl(self, value: CovariateCollection) -> None: + self._covarColl = value + + def getTrialPartition(self) -> np.ndarray: + training = [] if self.trainingWindow is None else list(self.trainingWindow) + validation = [] if self.validationWindow is None else list(self.validationWindow) + p = training + validation + if not p: + return np.asarray([self.minTime, self.maxTime, self.maxTime, self.maxTime], dtype=float) + return np.asarray(p, dtype=float) + + def setTrialPartition(self, partitionTimes) -> None: + if partitionTimes is None or len(partitionTimes) == 0: + partitionTimes = self.getTrialPartition() + values = np.asarray(partitionTimes, dtype=float).reshape(-1) + if values.size == 4: + trainingWindow = values[:2] + validationWindow = values[2:] + elif values.size == 3: + trainingWindow = values[:2] + validationWindow = values[1:] + else: + raise ValueError("partitionTimes must be length 3 or 4") + self.trainingWindow = trainingWindow + self.validationWindow = validationWindow + self.setMinTime(trainingWindow[0]) + self.setMaxTime(trainingWindow[1]) + + def setTrialTimesFor(self, partitionName: str = "training") -> None: + p = self.getTrialPartition() + if partitionName == "training": + timeWindow = p[:2] + elif partitionName == "validation": + timeWindow = p[2:4] + else: + raise ValueError("partitionName must be either training or validation") + self.setMinTime(float(timeWindow[0])) + self.setMaxTime(float(timeWindow[1])) + + def setMinTime(self, minTime: float | None = None) -> None: + if minTime is None: + minTime = self.findMinTime() + self.nspikeColl.setMinTime(float(minTime)) + self.covarColl.setMinTime(float(minTime)) + if self.ensCovColl is not None: + self.ensCovColl.setMinTime(float(minTime)) + self.minTime = float(minTime) + + def setMaxTime(self, maxTime: float | None = None) -> None: + if maxTime is None: + maxTime = self.findMaxTime() + self.nspikeColl.setMaxTime(float(maxTime)) + self.covarColl.setMaxTime(float(maxTime)) + if self.ensCovColl is not None: + self.ensCovColl.setMaxTime(float(maxTime)) + self.maxTime = float(maxTime) + + def updateTimePartitions(self) -> None: + if not (np.isfinite(self.minTime) and np.isfinite(self.maxTime)): + return + p = self.getTrialPartition() + training = p[:2] + validation = p[2:4] + newTrainMin = max(self.minTime, training[0]) + newTrainMax = min(self.maxTime, training[1]) + newValMin = max(self.minTime, validation[0]) + newValMax = min(self.maxTime, validation[1]) + self.setTrialPartition([newTrainMin, newTrainMax, newValMin, newValMax]) + + def setSampleRate(self, sampleRate: float) -> None: + self.sampleRate = float(sampleRate) + self.nspikeColl.resample(sampleRate) + self.covarColl.resample(sampleRate) + self.resampleEnsColl() + + def resample(self, sampleRate: float) -> None: + self.setSampleRate(sampleRate) + + def setEnsCovMask(self, mask=None) -> None: + if mask is None or mask == []: + nSpikes = self.nspikeColl.numSpikeTrains + mask = np.ones((nSpikes, nSpikes), dtype=int) - np.eye(nSpikes, dtype=int) + self.ensCovMask = np.asarray(mask, dtype=int) + + def setCovMask(self, mask) -> None: + if isinstance(mask, str) and mask == "all": + self.covarColl.resetMask() + else: + self.covarColl.setMask(mask) + self.covMask = self.covarColl.covMask + + def resetCovMask(self) -> None: + self.covarColl.resetMask() + self.covMask = self.covarColl.covMask + + def setNeuronMask(self, mask) -> None: + self.nspikeColl.setMask(mask) + self.neuronMask = np.asarray(self.nspikeColl.neuronMask, dtype=int).copy() + + def resetNeuronMask(self) -> None: + self.nspikeColl.resetMask() + self.neuronMask = np.asarray(self.nspikeColl.neuronMask, dtype=int).copy() + + def setNeighbors(self, *args) -> None: + self.nspikeColl.setNeighbors(*args) + + def setHistory(self, hist) -> None: + if hist is None or hist == []: + self.history = [] + return + from .history import History + + if isinstance(hist, History): + self.history = hist + return + if isinstance(hist, Sequence) and not isinstance(hist, (str, bytes)): + if hist and all(isinstance(item, History) for item in hist): + self.history = list(hist) + return + arr = np.asarray(hist, dtype=float).reshape(-1) + if arr.size <= 1: + raise ValueError("At least two times points must be specified to determine a window") + self.history = History(arr) + return + raise TypeError("Can only set trial history by using History objects or windowTimes") + + def resetHistory(self) -> None: + self.history = [] + + def setEnsCovHist(self, hist=None) -> None: + if hist is None or hist == []: + self.ensCovHist = [] + self.ensCovColl = None + return + from .history import History + + if isinstance(hist, History): + self.ensCovHist = hist + elif isinstance(hist, Sequence) and not isinstance(hist, (str, bytes)): + arr = np.asarray(hist, dtype=float).reshape(-1) + if arr.size <= 1: + raise ValueError("At least two times points must be specified to determine a window") + self.ensCovHist = History(arr) + else: + raise TypeError("Can only set trial ensCovHist by using History objects or windowTimes") + self.ensCovColl = self.getEnsembleNeuronCovariates(1, [], self.ensCovHist) + + def isNeuronMaskSet(self) -> bool: + return self.nspikeColl.isNeuronMaskSet() + + def isCovMaskSet(self) -> bool: + return self.covarColl.isCovMaskSet() + + def isMaskSet(self) -> bool: + return self.isNeuronMaskSet() or self.isCovMaskSet() + + def isHistSet(self) -> bool: + if self.history in (None, []): + return False + from .history import History + + if isinstance(self.history, History): + return True + return isinstance(self.history, list) and bool(self.history) and all(isinstance(item, History) for item in self.history) + + def isEnsCovHistSet(self) -> bool: + from .history import History + + return isinstance(self.ensCovHist, History) + + def addCov(self, cov: Covariate) -> None: + self.covarColl.addToColl(cov) + self.covMask = self.covarColl.covMask + if not self.isSampleRateConsistent(): + self.makeConsistentSampleRate() + self.makeConsistentTime() + + def removeCov(self, identifier: int | str) -> None: + self.covarColl.removeCovariate(identifier) + self.covMask = self.covarColl.covMask + if not self.isSampleRateConsistent(): + self.makeConsistentSampleRate() + self.makeConsistentTime() + + def getSpikeVector(self, *args, neuron_index: int = 1) -> np.ndarray: + if not args: + return self.nspikeColl.dataToMatrix() + first = args[0] + if isinstance(first, Sequence) and not isinstance(first, (str, bytes, np.ndarray)): + bin_edges = np.asarray(first, dtype=float).reshape(-1) + return self.nspikeColl.getNST(neuron_index).to_binned_counts(bin_edges) + return self.nspikeColl.dataToMatrix(*args) def get_covariate_matrix(self, selected_covariates: Sequence[str] | None = None) -> tuple[np.ndarray, np.ndarray, list[str]]: - return self.covariate_collection.dataToMatrix(selected_covariates) + return self.covarColl.matrixWithTime("standard", selected_covariates) + + def getDesignMatrix(self, neuronNum: int, dataSelector=None) -> np.ndarray: + X = self.covarColl.dataToMatrix("standard", dataSelector) + if self.isHistSet(): + H = self.getHistMatrices(neuronNum) + X = H if X.size == 0 else np.column_stack([X, H]) + if self.isEnsCovHistSet(): + E = self.getEnsCovMatrix(neuronNum) + X = E if X.size == 0 else np.column_stack([X, E]) + return X + + def getHistForNeurons(self, neuronIndex) -> CovariateCollection: + if not self.isHistSet(): + raise ValueError("Set Trial history and retry") + nst = self.nspikeColl.getNST(neuronIndex) + if isinstance(self.history, list): + histCovColl: CovariateCollection | None = None + for i, hist in enumerate(self.history, start=1): + temp = hist.computeHistory(nst, i) + histCovColl = temp if histCovColl is None else CovariateCollection([*histCovColl.covArray, *temp.covArray]) + assert histCovColl is not None + return histCovColl + return self.history.computeHistory(nst) + + def getHistMatrices(self, neuronIndex: int) -> np.ndarray: + if not self.isHistSet(): + time = self.nspikeColl.getNST(neuronIndex).getSigRep().time + return np.zeros((time.size, 0), dtype=float) + histCovColl = self.getHistForNeurons(neuronIndex) + return histCovColl.dataToMatrix("standard") + + def getEnsembleNeuronCovariates(self, *args): + return self.nspikeColl.getEnsembleNeuronCovariates(*args) + + def getEnsCovMatrix(self, neuronNum: int, includedNeurons=None) -> np.ndarray: + if not self.isEnsCovHistSet() or self.ensCovColl is None: + return np.zeros((self.nspikeColl.getNST(neuronNum).getSigRep().time.size, 0), dtype=float) + if includedNeurons is None: + includedNeurons = np.flatnonzero(self.ensCovMask[:, neuronNum - 1] == 1) + 1 + ensCovCollTemp = CovariateCollection(self.ensCovColl.covArray) + ensCovCollTemp.covMask = [mask.copy() for mask in self.ensCovColl.covMask] + ensCovCollTemp.maskAwayAllExcept(includedNeurons) + return ensCovCollTemp.dataToMatrix("standard") + + def getNeuronIndFromMask(self) -> list[int]: + return self.nspikeColl.getIndFromMask() + + def getNumUniqueNeurons(self) -> int: + return len(self.nspikeColl.uniqueNeuronNames) + + def getNeuronNames(self) -> list[str]: + return self.nspikeColl.getNSTnames() + + def getUniqueNeuronNames(self) -> list[str]: + return self.nspikeColl.getUniqueNSTnames() + + def getNeuronIndFromName(self, neuronName: str): + tempInd = self.nspikeColl.getNSTIndicesFromName(neuronName) + currMask = set(self.neuronMask_indices()) + if isinstance(tempInd, list): + return [idx for idx in tempInd if idx in currMask] + return [tempInd] if tempInd in currMask else [] + + def neuronMask_indices(self) -> list[int]: + return self.nspikeColl.getIndFromMask() + + def getNeuronNeighbors(self, neuronNum=None): + if neuronNum is None: + neuronNum = self.getNeuronIndFromMask() + return self.nspikeColl.getNeighbors(neuronNum) + + def getCovSelectorFromMask(self): + return self.covarColl.getSelectorFromMasks() + + def getCov(self, identifier): + return self.covarColl.getCov(identifier) + + def getNeuron(self, identifier): + return self.nspikeColl.getNST(identifier) + + def getAllCovLabels(self) -> list[str]: + return self.covarColl.getAllCovLabels() + + def getCovLabelsFromMask(self) -> list[str]: + return self.covarColl.getCovLabelsFromMask() + + def getHistLabels(self) -> list[str]: + if not self.isHistSet(): + return [] + return self.getHistForNeurons(1).getAllCovLabels() + + def getEnsCovLabels(self) -> list[str]: + if not self.isEnsCovHistSet() or self.ensCovColl is None: + return [] + return self.ensCovColl.getAllCovLabels() + + def getEnsCovLabelsFromMask(self, neuronNum: int) -> list[str]: + if not self.isEnsCovHistSet() or self.ensCovColl is None: + return [] + included = np.flatnonzero(self.ensCovMask[:, neuronNum - 1] == 1) + 1 + ensCovCollTemp = CovariateCollection(self.ensCovColl.covArray) + ensCovCollTemp.covMask = [mask.copy() for mask in self.ensCovColl.covMask] + ensCovCollTemp.maskAwayAllExcept(included) + return ensCovCollTemp.getCovLabelsFromMask() + + def getLabelsFromMask(self, neuronNum: int) -> list[str]: + labels = list(self.getCovLabelsFromMask()) + labels.extend(self.getHistLabels()) + labels.extend(self.getEnsCovLabelsFromMask(neuronNum)) + return labels + + def flattenCovMask(self) -> np.ndarray: + return self.covarColl.flattenCovMask() + + def flattenMask(self) -> np.ndarray: + flat = self.flattenCovMask() + if self.isHistSet(): + flat = np.concatenate([flat, np.ones(len(self.getHistLabels()), dtype=int)]) + if self.isEnsCovHistSet(): + flat = np.concatenate([flat, np.ones(len(self.getEnsCovLabels()), dtype=int)]) + return flat + + def shiftCovariates(self, *args) -> None: + self.covarColl.setCovShift(*args) + self.makeConsistentTime() + + def resetEnsCovMask(self) -> None: + self.setEnsCovMask() + + def resampleEnsColl(self) -> None: + if self.ensCovColl is not None and self.ensCovHist not in (None, []): + self.ensCovColl = self.getEnsembleNeuronCovariates(1, [], self.ensCovHist) + else: + self.setEnsCovHist() + + def restoreToOriginal(self) -> None: + self.nspikeColl.restoreToOriginal() + self.covarColl.restoreToOriginal() + if not self.isSampleRateConsistent(): + self.makeConsistentSampleRate() + self.resampleEnsColl() + self.makeConsistentTime() + + def makeConsistentSampleRate(self) -> None: + self.resample(self.findMaxSampleRate()) + + def makeConsistentTime(self) -> None: + self.setMinTime(self.findMinTime()) + self.setMaxTime(self.findMaxTime()) + + def isSampleRateConsistent(self) -> bool: + if self.nspikeColl.numSpikeTrains == 0 or self.covarColl.numCov == 0: + return True + target = round(float(self.findMaxSampleRate()), 3) + values = [round(float(self.nspikeColl.sampleRate), 3), round(float(self.covarColl.sampleRate), 3)] + return all(value == target for value in values) + + def findMaxSampleRate(self) -> float: + values = [value for value in [self.nspikeColl.findMaxSampleRate(), self.covarColl.findMaxSampleRate()] if np.isfinite(value)] + return float(max(values)) if values else float("nan") + + def findMinTime(self) -> float: + return float(min(self.nspikeColl.minTime, self.covarColl.minTime)) - def getSpikeVector(self, bin_edges: Sequence[float], neuron_index: int = 1) -> np.ndarray: - return self.spike_collection.getNST(neuron_index).to_binned_counts(bin_edges) + def findMaxTime(self) -> float: + return float(max(self.nspikeColl.maxTime, self.covarColl.maxTime)) # Backward-compatible MATLAB-style aliases. diff --git a/tests/test_api_surface.py b/tests/test_api_surface.py index 8ddfd8d8..30bcbffe 100644 --- a/tests/test_api_surface.py +++ b/tests/test_api_surface.py @@ -24,11 +24,17 @@ def test_canonical_api_imports() -> None: def test_matlab_facing_class_imports_are_canonical() -> None: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") + from nstat.ConfigColl import ConfigColl + from nstat.CovColl import CovColl from nstat.SignalObj import SignalObj + from nstat.TrialConfig import TrialConfig from nstat.Covariate import Covariate from nstat.nspikeTrain import nspikeTrain _ = SignalObj([0.0, 1.0], [1.0, 2.0]) _ = Covariate([0.0, 1.0], [1.0, 2.0]) _ = nspikeTrain([0.25, 0.5], makePlots=-1) + _ = CovColl([]) + _ = ConfigColl([]) + _ = TrialConfig() assert not any("deprecated" in str(item.message).lower() for item in w) diff --git a/tests/test_trial_fidelity.py b/tests/test_trial_fidelity.py new file mode 100644 index 00000000..4bcf886d --- /dev/null +++ b/tests/test_trial_fidelity.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from nstat import Covariate, Events, History, Trial, TrialConfig, nspikeTrain +from nstat.ConfigColl import ConfigColl +from nstat.CovColl import CovColl +from nstat.nstColl import nstColl + + +def _make_covariates() -> tuple[Covariate, Covariate]: + time = np.array([0.0, 0.5, 1.0], dtype=float) + position = Covariate(time, np.column_stack([[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]), "Position", "time", "s", "", ["x", "y"]) + stimulus = Covariate(time, [5.0, 6.0, 7.0], "Stimulus", "time", "s", "a.u.", ["stim"]) + return position, stimulus + + +def _make_spikes() -> tuple[nspikeTrain, nspikeTrain]: + n1 = nspikeTrain([0.0, 0.5, 1.0], "n1", 0.5, 0.0, 1.0, makePlots=-1) + n2 = nspikeTrain([0.25, 0.75], "n2", 0.5, 0.0, 1.0, makePlots=-1) + return n1, n2 + + +def test_covcoll_masking_selector_and_time_matrix() -> None: + position, stimulus = _make_covariates() + coll = CovColl([position, stimulus]) + + assert coll.numCov == 2 + assert coll.getCovIndFromName("Position") == 1 + assert coll.getCovIndFromName("Stimulus") == 2 + + coll.setMask([["Position", "x"], ["Stimulus"]]) + time, matrix, labels = coll.matrixWithTime() + + np.testing.assert_allclose(time, [0.0, 0.5, 1.0]) + np.testing.assert_allclose(matrix, [[0.0, 5.0], [1.0, 6.0], [2.0, 7.0]]) + assert labels == ["x", "stim"] + assert coll.getCovLabelsFromMask() == ["x", "stim"] + + coll.setCovShift(0.5) + shifted = coll.getCov("Stimulus") + np.testing.assert_allclose(shifted.time, [0.5, 1.0, 1.5]) + + +def test_nstcoll_neighbors_mask_and_data_matrix() -> None: + train1, train2 = _make_spikes() + coll = nstColl() + coll.addToColl([train1, train2]) + + assert coll.numSpikeTrains == 2 + assert coll.getNSTnames() == ["n1", "n2"] + assert coll.getUniqueNSTnames() == ["n1", "n2"] + + coll.setNeighbors() + assert coll.getNeighbors(1) == [2] + assert coll.getNeighbors(2) == [1] + + coll.setMask([1]) + assert coll.getIndFromMask() == [1] + np.testing.assert_allclose(coll.getMaxBinSizeBinary(), 0.5) + + matrix = coll.dataToMatrix([1, 2], 0.5, 0.0, 1.0) + np.testing.assert_allclose(matrix, [[1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]) + + +def test_trialconfig_and_configcoll_apply_and_roundtrip() -> None: + position, stimulus = _make_covariates() + train1, train2 = _make_spikes() + trial = Trial(nstColl([train1, train2]), CovColl([position, stimulus])) + + cfg = TrialConfig( + covMask=[["Position", "x"], ["Stimulus"]], + sampleRate=2.0, + history=[0.0, 0.5, 1.0], + covLag=0.5, + name="stim_pos", + ) + cfg.setConfig(trial) + + assert round(trial.sampleRate, 3) == 2.0 + assert trial.isHistSet() + assert trial.getCovLabelsFromMask() == ["x", "stim"] + + roundtrip = TrialConfig.fromStructure(cfg.toStructure()) + assert roundtrip.name == "stim_pos" + assert roundtrip.covariate_names == ["Position", "Stimulus"] + + configs = ConfigColl([cfg, "manual", None]) + assert configs.numConfigs == 3 + assert configs.getConfigNames() == ["stim_pos", "manual", "Empty Config"] + subset = configs.getSubsetConfigs([1, 2]) + assert subset.numConfigs == 2 + rebuilt = ConfigColl.fromStructure(configs.toStructure()) + assert rebuilt.getConfigNames() == ["stim_pos", "manual", "Empty Config"] + + +def test_trial_partition_history_design_matrix_and_spike_vector() -> None: + position, stimulus = _make_covariates() + train1, train2 = _make_spikes() + events = Events([0.25, 0.75], ["cue", "reward"], "g") + hist = History([0.0, 0.5, 1.0]) + trial = Trial(nstColl([train1, train2]), CovColl([position, stimulus]), events, hist) + + assert trial.getEvents() is events + assert trial.isHistSet() + assert len(trial.getHistLabels()) == 2 + + trial.setTrialPartition([0.0, 0.5, 1.0]) + np.testing.assert_allclose(trial.getTrialPartition(), [0.0, 0.5, 0.5, 1.0]) + trial.setTrialTimesFor("validation") + np.testing.assert_allclose([trial.minTime, trial.maxTime], [0.5, 1.0]) + + design = trial.getDesignMatrix(1) + assert design.shape[1] == 5 + spikes = trial.getSpikeVector() + assert spikes.shape[1] == 2 + + +def test_events_validation_and_history_collection_output() -> None: + with pytest.raises(ValueError, match="Number of eventTimes"): + Events([0.1, 0.2], ["one"]) + + events = Events([0.1], ["cue"], "b") + rebuilt = Events.fromStructure(events.toStructure()) + assert rebuilt is not None + assert rebuilt.eventColor == "b" + assert rebuilt.eventLabels == ["cue"] + + history = History([0.0, 0.5, 1.0]) + train, _ = _make_spikes() + hist_cov = history.computeHistory(train) + assert hist_cov.numCov == 1 + np.testing.assert_allclose(hist_cov.dataToMatrix().shape, (3, 2)) diff --git a/tests/test_workflow_fidelity.py b/tests/test_workflow_fidelity.py new file mode 100644 index 00000000..ef0f2e62 --- /dev/null +++ b/tests/test_workflow_fidelity.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import numpy as np + +from nstat import Analysis, CIF, CIFModel, DecodingAlgorithms, FitResSummary, Trial, TrialConfig +from nstat.ConfigColl import ConfigColl +from nstat.CovColl import CovColl +from nstat.Covariate import Covariate +from nstat.Events import Events +from nstat.History import History +from nstat.FitResult import FitResult +from nstat.nstColl import nstColl +from nstat.nspikeTrain import nspikeTrain + + +def _build_trial() -> Trial: + time = np.arange(0.0, 1.0, 0.1) + stim = Covariate(time, np.sin(2 * np.pi * time), "Stimulus", "time", "s", "", ["stim"]) + vel = Covariate(time, np.cos(2 * np.pi * time), "Velocity", "time", "s", "", ["vel"]) + spikes = nstColl( + [ + nspikeTrain([0.1, 0.3, 0.7], "1", 0.1, 0.0, 0.9, makePlots=-1), + nspikeTrain([0.2, 0.5, 0.8], "2", 0.1, 0.0, 0.9, makePlots=-1), + ] + ) + return Trial(spikes, CovColl([stim, vel]), Events([0.2], ["cue"]), History([0.0, 0.1, 0.2])) + + +def test_analysis_returns_matlab_style_fitresult_surface() -> None: + trial = _build_trial() + configs = ConfigColl( + [ + TrialConfig(covMask=[["Stimulus"]], sampleRate=10.0, history=[0.0, 0.1, 0.2], name="stim_hist"), + TrialConfig(covMask=[["Velocity"]], sampleRate=10.0, name="vel_only"), + ] + ) + + fit = Analysis.RunAnalysisForNeuron(trial, 1, configs) + + assert isinstance(fit, FitResult) + assert fit.numResults == 2 + assert fit.configNames == ["stim_hist", "vel_only"] + assert fit.lambdaSignal.dimension == 2 + assert fit.neuronNumber == 1.0 + assert len(fit.covLabels) == 2 + assert "stim" in fit.uniqueCovLabels + assert fit.getCoeffs(1).shape[0] >= 2 + assert fit.getHistCoeffs(1).shape[0] == 2 + + +def test_fitresult_roundtrip_and_summary_preserve_core_metadata() -> None: + trial = _build_trial() + configs = ConfigColl([TrialConfig(covMask=[["Stimulus"]], sampleRate=10.0, name="stim_only")]) + fits = Analysis.RunAnalysisForAllNeurons(trial, configs) + + rebuilt = FitResult.fromStructure(fits[0].toStructure()) + assert rebuilt.numResults == fits[0].numResults + assert rebuilt.configNames == fits[0].configNames + np.testing.assert_allclose(rebuilt.AIC, fits[0].AIC) + + summary = FitResSummary(fits) + assert summary.numNeurons == 2 + assert summary.numResults == 1 + assert summary.fitNames == ["stim_only"] + assert summary.AIC.shape == (1,) + + +def test_cif_instantiation_evaluation_and_simulate_from_lambda() -> None: + cif = CIF([0.2, -0.1], ["stim", "vel"], ["stim"], fitType="poisson") + design = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]], dtype=float) + rate = cif.evaluate(design, delta=0.1) + assert rate.shape == (3,) + assert np.all(rate > 0) + + time = np.array([0.0, 0.1, 0.2], dtype=float) + cov = CIF.from_linear_terms(time, 0.1, np.array([0.2, -0.1]), design, 0.1, "lambda") + sim = CIF.simulateCIFByThinningFromLambda(cov, numRealizations=2) + assert sim.numSpikeTrains == 2 + + model = CIFModel(time, np.array([5.0, 6.0, 7.0]), name="lambda") + sim2 = model.simulate(num_realizations=1, seed=1) + assert sim2.numSpikeTrains == 1 + + +def test_decoding_aliases_produce_state_and_covariance_outputs() -> None: + obs = np.array([[1.0], [0.5], [0.2]], dtype=float) + a = np.array([[1.0]], dtype=float) + h = np.array([[1.0]], dtype=float) + q = np.array([[0.01]], dtype=float) + r = np.array([[0.04]], dtype=float) + x0 = np.array([0.0], dtype=float) + p0 = np.array([[1.0]], dtype=float) + + out = DecodingAlgorithms.PPDecodeFilterLinear(obs, a, h, q, r, x0, p0) + assert out["state"].shape == (3, 1) + assert out["cov"].shape == (3, 1, 1) + + +def test_history_and_events_roundtrip_in_workflow_context() -> None: + history = History([0.0, 0.2, 0.4], minTime=0.0, maxTime=1.0) + rebuilt_history = History.fromStructure(history.toStructure()) + assert rebuilt_history is not None + np.testing.assert_allclose(rebuilt_history.windowTimes, history.windowTimes) + + events = Events([0.1, 0.4], ["start", "stop"], "m") + rebuilt_events = Events.fromStructure(events.toStructure()) + assert rebuilt_events is not None + assert rebuilt_events.eventColor == "m" + assert rebuilt_events.eventLabels == ["start", "stop"] From efbc8c0bdd83f172751b553ee37965f824ce8c09 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 11:58:53 -0500 Subject: [PATCH 2/5] Audit class fidelity and sync notebook parity notes --- notebooks/AnalysisExamples.ipynb | 12 ++ notebooks/DecodingExample.ipynb | 12 ++ notebooks/DecodingExampleWithHist.ipynb | 12 ++ notebooks/ExplicitStimulusWhiskerData.ipynb | 12 ++ notebooks/HippocampalPlaceCellExample.ipynb | 12 ++ notebooks/HybridFilterExample.ipynb | 12 ++ notebooks/PPSimExample.ipynb | 12 ++ notebooks/StimulusDecode2D.ipynb | 12 ++ notebooks/TrialExamples.ipynb | 12 ++ notebooks/ValidationDataSet.ipynb | 12 ++ notebooks/nSTATPaperExamples.ipynb | 12 ++ parity/class_fidelity.yml | 207 ++++++++++---------- parity/report.md | 15 +- tests/test_notebook_ci_groups.py | 2 + tests/test_notebook_parity_notes.py | 38 ++++ tools/notebooks/parity_notes.yml | 57 ++++++ tools/notebooks/sync_parity_notes.py | 55 ++++++ tools/notebooks/topic_groups.yml | 1 + 18 files changed, 393 insertions(+), 114 deletions(-) create mode 100644 tests/test_notebook_parity_notes.py create mode 100644 tools/notebooks/parity_notes.yml create mode 100644 tools/notebooks/sync_parity_notes.py diff --git a/notebooks/AnalysisExamples.ipynb b/notebooks/AnalysisExamples.ipynb index 23162719..1c90b5bb 100644 --- a/notebooks/AnalysisExamples.ipynb +++ b/notebooks/AnalysisExamples.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "db89e36d", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `AnalysisExamples.mlx`\n", + "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: Advanced MATLAB algorithm-selection branches and some report plots are still lighter in Python." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/DecodingExample.ipynb b/notebooks/DecodingExample.ipynb index 619ca0b0..39b36c6f 100644 --- a/notebooks/DecodingExample.ipynb +++ b/notebooks/DecodingExample.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "c488b5fa", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `DecodingExample.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/DecodingExampleWithHist.ipynb b/notebooks/DecodingExampleWithHist.ipynb index 80200c11..ae6b9c6a 100644 --- a/notebooks/DecodingExampleWithHist.ipynb +++ b/notebooks/DecodingExampleWithHist.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "739c56fe", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `DecodingExampleWithHist.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: History-aware decoding is available, but the MATLAB workflow still has richer option handling and reference outputs." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/ExplicitStimulusWhiskerData.ipynb b/notebooks/ExplicitStimulusWhiskerData.ipynb index ceca5c10..aaf09b7c 100644 --- a/notebooks/ExplicitStimulusWhiskerData.ipynb +++ b/notebooks/ExplicitStimulusWhiskerData.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "8bf801b2", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `ExplicitStimulusWhiskerData.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: Dataset-backed workflow is present, but figure-level and narrative parity with MATLAB are still incomplete." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/HippocampalPlaceCellExample.ipynb b/notebooks/HippocampalPlaceCellExample.ipynb index 1859a108..75575d5b 100644 --- a/notebooks/HippocampalPlaceCellExample.ipynb +++ b/notebooks/HippocampalPlaceCellExample.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "cd1a2218", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `HippocampalPlaceCellExample.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: Core place-cell workflow is ported, but MATLAB figure sequencing and summary outputs are not yet exact." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/HybridFilterExample.ipynb b/notebooks/HybridFilterExample.ipynb index 8654bf01..21899728 100644 --- a/notebooks/HybridFilterExample.ipynb +++ b/notebooks/HybridFilterExample.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "0e36ffa9", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `HybridFilterExample.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: Hybrid filtering workflow executes, but MATLAB-specific output details and downstream validation remain incomplete." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/PPSimExample.ipynb b/notebooks/PPSimExample.ipynb index 8173e71b..4dcb3554 100644 --- a/notebooks/PPSimExample.ipynb +++ b/notebooks/PPSimExample.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "8968c9f0", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `PPSimExample.mlx`\n", + "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: MATLAB plotting/report formatting remains lighter, but the core point-process simulation workflow is closely aligned." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/StimulusDecode2D.ipynb b/notebooks/StimulusDecode2D.ipynb index e8b3d80d..3ced1d63 100644 --- a/notebooks/StimulusDecode2D.ipynb +++ b/notebooks/StimulusDecode2D.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "59333bc7", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `StimulusDecode2D.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: The 2D stimulus decoding workflow runs, but MATLAB-equivalent outputs and tolerance-backed parity checks still need expansion." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/TrialExamples.ipynb b/notebooks/TrialExamples.ipynb index 49c93b10..ccdf0c95 100644 --- a/notebooks/TrialExamples.ipynb +++ b/notebooks/TrialExamples.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "93f72e2f", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `TrialExamples.mlx`\n", + "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: Some MATLAB plotting/display details remain simplified, but the core Trial object workflow now follows the MATLAB semantics closely." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/ValidationDataSet.ipynb b/notebooks/ValidationDataSet.ipynb index f0d06f55..2511325c 100644 --- a/notebooks/ValidationDataSet.ipynb +++ b/notebooks/ValidationDataSet.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "575b2a91", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `ValidationDataSet.mlx`\n", + "- Fidelity status: `partial`\n", + "- Remaining justified differences: Validation dataset coverage exists, but MATLAB reference summaries and figure parity are not yet complete." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/notebooks/nSTATPaperExamples.ipynb b/notebooks/nSTATPaperExamples.ipynb index 309ac5a0..b08f0a69 100644 --- a/notebooks/nSTATPaperExamples.ipynb +++ b/notebooks/nSTATPaperExamples.ipynb @@ -1,5 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "21e0c2ce", + "metadata": {}, + "source": [ + "\n", + "## MATLAB Parity Note\n", + "- Source MATLAB helpfile: `nSTATPaperExamples.mlx`\n", + "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: Python uses standalone figshare-backed data access and generated gallery assets rather than MATLAB path-based setup." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 6e2707b8..c8837c66 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -73,137 +73,142 @@ items: matlab_path: nstColl.m python_symbol: nstat.nstColl python_path: nstat/trial.py - status: partial - constructor_parity: Basic collection construction exists, but MATLAB supports richer empty-init and object-state patterns. - property_parity: numSpikeTrains, minTime, maxTime, and sampleRate are present. - method_parity: getNST, dataToMatrix, psth, and psthGLM exist; masking, config utilities, neighborhood operations, and richer analysis helpers are still missing. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: Empty construction, direct sequence construction, and MATLAB-style collection state initialization now match MATLAB much more closely. + property_parity: Core MATLAB-visible fields exist, including nstrain, numSpikeTrains, minTime, maxTime, sampleRate, neuronMask, and neighbors. + method_parity: MATLAB-facing collection methods are now first-class, including addToColl, addSingleSpikeToColl, merge, getNST, name/index lookup, masking, neighborhood management, dataToMatrix, ensemble-covariate helpers, restoreToOriginal, psth, and psthGLM. + default_value_parity: Defaults for masks, sample rate, and min/max time now track MATLAB collection semantics closely. shape_and_indexing_parity: MATLAB-facing one-based getNST is preserved. - error_warning_parity: Simplified. + error_warning_parity: Core validation is present, though MATLAB warning text and some edge-case messages still differ. output_type_parity: PSTH returns Covariate. known_semantic_differences: - - No MATLAB-equivalent mask state or richer analysis utilities yet. + - Some plotting/statistics helpers and lower-level utility methods from MATLAB are still absent. recommended_remediation: - - Port the remaining collection methods from MATLAB and move the class into a canonical MATLAB-facing implementation file. + - Add MATLAB-derived fixtures for neighbor masks, ensemble covariates, and PSTH outputs. + - Port any remaining collection utilities that surface in MATLAB helpfiles. - matlab_name: Trial kind: class matlab_path: Trial.m python_symbol: nstat.Trial python_path: nstat/trial.py - status: partial - constructor_parity: Supports core spike/covariate/event wiring, but not the full MATLAB constructor and object-state surface. - property_parity: spikeColl and covarColl are exposed; broader trial metadata/state is still missing. - method_parity: Limited to core matrix/vector access. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: The canonical Python Trial now accepts MATLAB-style spike, covariate, event, history, and ensemble-history inputs and normalizes trial state similarly to MATLAB. + property_parity: Core MATLAB-facing state is now present, including nspikeColl, covarColl, ev, history, ensCovHist, ensCovColl, sampleRate, minTime, maxTime, covMask, ensCovMask, neuronMask, trainingWindow, and validationWindow. + method_parity: The MATLAB trial workflow is much richer now, covering event/history setup, partitioning, sample-rate and time consistency, neuron/covariate masking, design-matrix generation, history/ensemble covariates, label extraction, and restore/reset helpers. + default_value_parity: Default object state and partition behavior are much closer to MATLAB than the earlier thin implementation. shape_and_indexing_parity: Core one-based neuron selection is preserved via getSpikeVector. - error_warning_parity: Simplified. - output_type_parity: Returns NumPy arrays rather than richer MATLAB-style objects in several workflows. + error_warning_parity: Core validation is present, but some MATLAB warning and edge-case pathways still differ. + output_type_parity: Matrix-producing methods intentionally return NumPy arrays, while MATLAB-facing object-producing workflows return Trial/CovColl/nstColl-compatible objects where expected. known_semantic_differences: - - Trial workflow semantics remain much thinner than MATLAB. + - Some MATLAB plotting, partition-serialization, and specialized workflow helpers remain unported. recommended_remediation: - - Port richer trial state, consistency checks, and MATLAB workflow helpers. + - Add dataset-backed fixtures for trial partitioning, ensemble-history construction, and design-matrix parity. + - Port the remaining specialized Trial helpers used only in MATLAB helpfiles. - matlab_name: TrialConfig kind: class matlab_path: TrialConfig.m python_symbol: nstat.TrialConfig python_path: nstat/trial.py - status: partial - constructor_parity: Current dataclass captures only a subset of MATLAB configuration fields. - property_parity: covMask, sampleRate, history, ensCovHist, covLag, and name exist, but MATLAB exposes richer behavior. - method_parity: Only naming and covariate-name extraction are currently implemented. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: The constructor now matches MATLAB intent much more closely, including covMask, sampleRate, history, ensCovHist, ensCovMask, covLag, and name handling. + property_parity: Core configuration fields and normalized metadata are now exposed in the canonical implementation rather than a dataclass shim. + method_parity: MATLAB-facing methods now include naming, structure round-trip, and setConfig application against Trial state. + default_value_parity: Defaults for empty masks/configs and name handling are close to MATLAB. shape_and_indexing_parity: N/A for this class. - error_warning_parity: Simplified. - output_type_parity: Python dataclass rather than a richer MATLAB handle-style object. + error_warning_parity: Validation is still lighter than MATLAB in some malformed-configuration paths. + output_type_parity: Returns and mutates canonical TrialConfig/Trial objects as expected. known_semantic_differences: - - Configuration validation and selection behavior are incomplete. + - Some MATLAB normalization and validation branches remain looser in Python. recommended_remediation: - - Port MATLAB configuration validation, normalization, and selection workflows. + - Add malformed-config fixtures from MATLAB to tighten validation and default coercion behavior. - matlab_name: ConfigColl kind: class matlab_path: ConfigColl.m python_symbol: nstat.ConfigColl python_path: nstat/trial.py - status: partial - constructor_parity: Basic collection support exists. - property_parity: numConfigs and configArray exist. - method_parity: addConfig, getConfig, and getConfigNames exist; MATLAB collection utilities are broader. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: Supports MATLAB-style collections of TrialConfig objects, string-named configs, and empty config placeholders. + property_parity: numConfigs, configNames, and configArray are exposed with MATLAB-style semantics. + method_parity: addConfig, getConfig, setConfig, getConfigNames, setConfigNames, getSubsetConfigs, and structure round-trip are now implemented in the canonical collection. + default_value_parity: Empty-config and naming defaults now align closely with MATLAB behavior. shape_and_indexing_parity: One-based getConfig behavior is preserved. - error_warning_parity: Simplified. + error_warning_parity: Basic validation exists, though some MATLAB collection-coercion edge cases are still looser. output_type_parity: Returns TrialConfig instances. known_semantic_differences: - - Richer MATLAB config-management behavior is still missing. + - Some MATLAB-specific collection manipulation helpers remain unported. recommended_remediation: - - Port the remaining ConfigColl helpers and name/selection semantics from MATLAB. + - Add fixture-backed tests for edge-case config coercion and selection semantics. - matlab_name: Analysis kind: class matlab_path: Analysis.m python_symbol: nstat.Analysis python_path: nstat/analysis.py - status: partial - constructor_parity: Python analysis setup exists, but MATLAB option surface and workflow selection semantics are richer. - property_parity: Partial only. - method_parity: Core fitting helpers exist; RunAnalysisForNeuron and RunAnalysisForAllNeurons are still simplified relative to MATLAB. - default_value_parity: Partial only. - shape_and_indexing_parity: Partial only. - error_warning_parity: Simplified. - output_type_parity: Returns FitSummary/FitResult equivalents, but not with full MATLAB metadata fidelity. + status: high_fidelity + constructor_parity: Analysis remains a static-workflow class in Python, but the MATLAB-facing entry points are now aligned around RunAnalysisForNeuron and RunAnalysisForAllNeurons semantics. + property_parity: N/A for the static workflow surface. + method_parity: Canonical analysis now restores trial state, applies ConfigColl entries, builds MATLAB-style design matrices and labels, and returns richer FitResult metadata for per-neuron and all-neuron workflows. + default_value_parity: Default fitting behavior and Poisson-GLM selection are much closer to the MATLAB workflow defaults. + shape_and_indexing_parity: MATLAB-facing one-based neuron numbering remains available through the public entry points. + error_warning_parity: Core validation is present, though algorithm-selection and advanced option warnings remain thinner than MATLAB. + output_type_parity: Returns MATLAB-facing FitResult/FitResSummary-compatible objects with richer metadata than the previous simplified implementation. known_semantic_differences: - - Algorithm selection and analysis-option semantics are still thinner than MATLAB. + - Advanced MATLAB algorithm-selection, cross-validation, and plotting/reporting branches are still incomplete. recommended_remediation: - - Port MATLAB analysis options and representative workflow outputs into dataset-backed tests. + - Add dataset-backed numerical parity fixtures for canonical analysis workflows. + - Port remaining algorithm-selection and validation-option branches from MATLAB. - matlab_name: FitResult kind: class matlab_path: FitResult.m python_symbol: nstat.FitResult python_path: nstat/fit.py - status: partial - constructor_parity: Partial. - property_parity: Core lambda/spike-train references exist, but MATLAB surface is richer. - method_parity: Summary/reporting methods are only partially ported. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: The canonical constructor now supports both the legacy simplified Python path and a MATLAB-style metadata-rich construction path. + property_parity: Core MATLAB-facing result fields are now present, including lambda aliases, config metadata, coefficient arrays, history metadata, AIC/BIC/logLL, validation placeholders, and plotParams scaffolding. + method_parity: getCoeffs, getHistCoeffs, mergeResults, and structure round-trip now operate on the richer MATLAB-style result surface. + default_value_parity: Default result metadata and placeholder fields are much closer to MATLAB than the earlier lightweight container. shape_and_indexing_parity: N/A for this class. - error_warning_parity: Simplified. - output_type_parity: Partial. + error_warning_parity: Validation is still lighter than MATLAB in malformed-structure and reporting edge cases. + output_type_parity: Returns canonical FitResult objects with MATLAB-style aliases and list/array fields. known_semantic_differences: - - Fit metadata and reporting behavior remain thinner than MATLAB. + - Plotting, KS/inverse-Gaussian reporting detail, and some summary utilities remain stubbed. recommended_remediation: - - Port MATLAB result-summary and reporting APIs with golden fixtures. + - Add MATLAB-derived golden fixtures for coefficient metadata and validation/report payloads. + - Port the remaining plotting/report helpers used by the MATLAB toolbox. - matlab_name: FitResSummary kind: class matlab_path: FitResSummary.m python_symbol: nstat.FitResSummary python_path: nstat/fit.py - status: partial - constructor_parity: Partial. - property_parity: Partial. - method_parity: Partial. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: Summary objects now aggregate MATLAB-style FitResult collections directly. + property_parity: Core summary fields exist, including fitResCell, numNeurons, numResults, fitNames, neuronNumbers, AIC, BIC, logLL, and KSStats. + method_parity: MATLAB-style difference helpers are implemented through getDiffAIC, getDiffBIC, and getDifflogLL. + default_value_parity: Summary initialization is close for the implemented metadata surface. shape_and_indexing_parity: N/A for this class. - error_warning_parity: Simplified. - output_type_parity: Partial. + error_warning_parity: Still lighter than MATLAB for mismatched summary inputs. + output_type_parity: Returns canonical FitResSummary/FitSummary objects. known_semantic_differences: - - Summary-table and figure/report behavior is not yet MATLAB-equivalent. + - Summary plotting and richer report/table exports are still not MATLAB-equivalent. recommended_remediation: - - Port summary aggregation and reporting semantics from MATLAB. + - Add golden fixtures for multi-neuron summary aggregation and remaining report outputs. - matlab_name: CIF kind: class matlab_path: CIF.m python_symbol: nstat.CIF python_path: nstat/cif.py - status: partial - constructor_parity: Partial. - property_parity: Partial. - method_parity: Simulation and conversion helpers exist, but the full MATLAB object model is broader. - default_value_parity: Partial only. - shape_and_indexing_parity: Partial. - error_warning_parity: Simplified. - output_type_parity: Partial. + status: high_fidelity + constructor_parity: The canonical CIF object now accepts MATLAB-style beta, name, fitType, history, and spike-train metadata. + property_parity: Core modeling metadata is present for fitting and simulation workflows. + method_parity: evaluate, to_covariate, simulateCIFByThinningFromLambda, and from_linear_terms provide the MATLAB-facing simulation and conversion surface used by current workflows. + default_value_parity: Default fitType and basic constructor normalization are close to MATLAB for the implemented workflow subset. + shape_and_indexing_parity: Vector/matrix handling is aligned to MATLAB-style time-by-feature design matrices. + error_warning_parity: Validation is present, though advanced MATLAB error paths remain thinner. + output_type_parity: Returns rate arrays, Covariates, and spike-train collections in the expected workflow positions. known_semantic_differences: - - History-aware and decoding-relevant CIF workflows remain thinner than MATLAB. + - Some history-aware, decoding-specific, and reporting helpers remain unported. recommended_remediation: - - Port MATLAB CIF behaviors used by simulation, fitting, and decoding workflows. + - Add MATLAB-derived fixtures for CIF evaluation and thinning outputs. + - Port the remaining decoding-oriented CIF helpers. - matlab_name: DecodingAlgorithms kind: class matlab_path: DecodingAlgorithms.m @@ -226,35 +231,35 @@ items: matlab_path: History.m python_symbol: nstat.History python_path: nstat/history.py - status: partial - constructor_parity: Partial. - property_parity: Partial. - method_parity: Basic history basis construction exists; richer MATLAB history workflows and outputs are missing. - default_value_parity: Partial only. - shape_and_indexing_parity: Partial. - error_warning_parity: Simplified. - output_type_parity: Partial. + status: high_fidelity + constructor_parity: History now uses MATLAB-style windowTimes construction with optional min/max metadata. + property_parity: windowTimes, minTime, maxTime, and lags-compatible access are exposed. + method_parity: setWindow, computeHistory/compute_history, structure round-trip, and CovColl-producing history workflows are now implemented for single trains, train collections, and trial history use. + default_value_parity: Window-boundary defaults are close to MATLAB for the implemented history workflows. + shape_and_indexing_parity: WindowTimes are interpreted as MATLAB-style consecutive lag boundaries. + error_warning_parity: Core validation is present, though MATLAB warning text and some malformed-input branches remain thinner. + output_type_parity: Returns CovariateCollection outputs in the MATLAB-facing workflows that consume History objects. known_semantic_differences: - - MATLAB returns richer covariate collections and configuration behavior. + - Plotting and some specialized history-basis utilities remain unported. recommended_remediation: - - Port full History object workflows and fixture-backed outputs. + - Add MATLAB-derived fixtures for history-window outputs and multi-neuron history collections. - matlab_name: Events kind: class matlab_path: Events.m python_symbol: nstat.Events python_path: nstat/events.py - status: partial - constructor_parity: Partial. - property_parity: Event times/labels support exists, but color and full validation parity are incomplete. - method_parity: Partial. - default_value_parity: Partial only. - shape_and_indexing_parity: Partial. - error_warning_parity: Simplified. - output_type_parity: Partial. + status: high_fidelity + constructor_parity: Constructor now tracks MATLAB eventTimes, eventLabels, and eventColor semantics, including label-count validation. + property_parity: eventTimes, eventLabels, and eventColor are canonical public fields, with legacy Python aliases preserved. + method_parity: Structure round-trip and notebook/workflow-facing access patterns are implemented. + default_value_parity: Empty-label and default-color behavior are close to MATLAB for the implemented workflow subset. + shape_and_indexing_parity: Event vectors are stored in MATLAB-style flat time/label arrays. + error_warning_parity: Core validation now matches MATLAB intent, though plotting-related behaviors remain absent. + output_type_parity: Returns canonical Events objects. known_semantic_differences: - - MATLAB validation and plotting semantics are not fully ported. + - Plotting and some MATLAB-specific display behaviors are still unported. recommended_remediation: - - Port event validation, color handling, and notebook-backed workflows. + - Add notebook-backed fixtures for event serialization and display workflows. - matlab_name: ConfidenceInterval kind: class matlab_path: ConfidenceInterval.m @@ -277,18 +282,18 @@ items: matlab_path: CovColl.m python_symbol: nstat.CovColl python_path: nstat/trial.py - status: partial - constructor_parity: Basic collection support exists. - property_parity: Partial. - method_parity: add/get/dataToMatrix exist; MATLAB collection behavior is broader. - default_value_parity: Partial only. + status: high_fidelity + constructor_parity: CovColl now supports MATLAB-style direct construction, empty initialization, and nested collection ingestion. + property_parity: Core collection state exists, including covArray, covDimensions, numCov, minTime, maxTime, covMask, covShift, sampleRate, and original timing metadata. + method_parity: MATLAB-facing collection methods are now first-class, covering add/remove, name/index lookup, mask selectors, time-window restriction, resampling, matrixWithTime, dataToMatrix, shift/reset, label extraction, and restoreToOriginal. + default_value_parity: Default mask, shift, sample-rate, and timing behavior now track MATLAB collection semantics closely. shape_and_indexing_parity: Shared-time enforcement is implemented. - error_warning_parity: Simplified. - output_type_parity: Partial. + error_warning_parity: Core validation is present, though some MATLAB warning text and malformed-selector branches are still thinner. + output_type_parity: Returns Covariate and CovariateCollection-compatible outputs across MATLAB-facing workflows. known_semantic_differences: - - Richer selection and covariate management helpers are missing. + - Some structure serialization and rarely used helper methods remain unported. recommended_remediation: - - Port remaining CovColl behaviors and helpfile workflows. + - Add MATLAB-derived fixtures for selector masks, time-window coercion, and serialized collection state. - matlab_name: getPaperDataDirs kind: function matlab_path: getPaperDataDirs.m diff --git a/parity/report.md b/parity/report.md index f7dbd449..446fcb3a 100644 --- a/parity/report.md +++ b/parity/report.md @@ -23,8 +23,8 @@ Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. | Status | Count | |---|---:| | `exact` | 0 | -| `high_fidelity` | 1 | -| `partial` | 17 | +| `high_fidelity` | 12 | +| `partial` | 6 | | `shim_only` | 0 | | `missing` | 0 | | `not_applicable` | 1 | @@ -45,19 +45,8 @@ No partial or missing items remain in the mapping inventory. - `SignalObj` -> `nstat.SignalObj` [partial]: Port arithmetic, filtering, plotting, and structure round-trip methods from MATLAB. - `Covariate` -> `nstat.Covariate` [partial]: Port arithmetic overloads and CI plotting semantics from MATLAB. - `nspikeTrain` -> `nstat.nspikeTrain` [partial]: Port burst/statistics helpers and plotting routines. -- `nstColl` -> `nstat.nstColl` [partial]: Port the remaining collection methods from MATLAB and move the class into a canonical MATLAB-facing implementation file. -- `Trial` -> `nstat.Trial` [partial]: Port richer trial state, consistency checks, and MATLAB workflow helpers. -- `TrialConfig` -> `nstat.TrialConfig` [partial]: Port MATLAB configuration validation, normalization, and selection workflows. -- `ConfigColl` -> `nstat.ConfigColl` [partial]: Port the remaining ConfigColl helpers and name/selection semantics from MATLAB. -- `Analysis` -> `nstat.Analysis` [partial]: Port MATLAB analysis options and representative workflow outputs into dataset-backed tests. -- `FitResult` -> `nstat.FitResult` [partial]: Port MATLAB result-summary and reporting APIs with golden fixtures. -- `FitResSummary` -> `nstat.FitResSummary` [partial]: Port summary aggregation and reporting semantics from MATLAB. -- `CIF` -> `nstat.CIF` [partial]: Port MATLAB CIF behaviors used by simulation, fitting, and decoding workflows. - `DecodingAlgorithms` -> `nstat.DecodingAlgorithms` [partial]: Port canonical decoding workflows and validate them against MATLAB-derived outputs. -- `History` -> `nstat.History` [partial]: Port full History object workflows and fixture-backed outputs. -- `Events` -> `nstat.Events` [partial]: Port event validation, color handling, and notebook-backed workflows. - `ConfidenceInterval` -> `nstat.ConfidenceInterval` [partial]: Port MATLAB plotting and serialization semantics. -- `CovColl` -> `nstat.CovColl` [partial]: Port remaining CovColl behaviors and helpfile workflows. - `nSTAT_Install` -> `nstat.nSTAT_Install` [partial]: Keep documenting the no-op compatibility behavior and test installer status outputs. ## Justified Non-Applicable Items diff --git a/tests/test_notebook_ci_groups.py b/tests/test_notebook_ci_groups.py index 21301f45..85ead1ac 100644 --- a/tests/test_notebook_ci_groups.py +++ b/tests/test_notebook_ci_groups.py @@ -20,7 +20,9 @@ "ExplicitStimulusWhiskerData", "HippocampalPlaceCellExample", "HybridFilterExample", + "PPSimExample", "SignalObjExamples", + "StimulusDecode2D", "TrialExamples", "ValidationDataSet", "nSTATPaperExamples", diff --git a/tests/test_notebook_parity_notes.py b/tests/test_notebook_parity_notes.py new file mode 100644 index 00000000..4e0a1596 --- /dev/null +++ b/tests/test_notebook_parity_notes.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from pathlib import Path + +import nbformat +import yaml + + +REPO_ROOT = Path(__file__).resolve().parents[1] +NOTES_PATH = REPO_ROOT / "tools" / "notebooks" / "parity_notes.yml" +TOPIC_GROUPS_PATH = REPO_ROOT / "tools" / "notebooks" / "topic_groups.yml" +MARKER = "" + + +def _load_notes() -> list[dict[str, str]]: + payload = yaml.safe_load(NOTES_PATH.read_text(encoding="utf-8")) or {} + return list(payload.get("notes", [])) + + +def test_parity_notes_topics_are_covered_by_parity_core_group() -> None: + topic_groups = yaml.safe_load(TOPIC_GROUPS_PATH.read_text(encoding="utf-8")) or {} + parity_core = set(topic_groups.get("groups", {}).get("parity_core", [])) + note_topics = {row["topic"] for row in _load_notes()} + assert note_topics <= parity_core + + +def test_target_notebooks_start_with_machine_readable_parity_note() -> None: + for row in _load_notes(): + notebook_path = REPO_ROOT / row["file"] + notebook = nbformat.read(notebook_path, as_version=4) + first_cell = notebook.cells[0] + source = "".join(first_cell.get("source", "")) + + assert first_cell.cell_type == "markdown", f"{notebook_path} must start with a markdown parity note" + assert MARKER in source, f"{notebook_path} is missing the parity note marker" + assert row["source_matlab"] in source + assert row["fidelity_status"] in source + assert row["remaining_differences"] in source diff --git a/tools/notebooks/parity_notes.yml b/tools/notebooks/parity_notes.yml new file mode 100644 index 00000000..f53501b6 --- /dev/null +++ b/tools/notebooks/parity_notes.yml @@ -0,0 +1,57 @@ +version: 1 +notes: + - topic: nSTATPaperExamples + file: notebooks/nSTATPaperExamples.ipynb + source_matlab: nSTATPaperExamples.mlx + fidelity_status: high_fidelity + remaining_differences: Python uses standalone figshare-backed data access and generated gallery assets rather than MATLAB path-based setup. + - topic: TrialExamples + file: notebooks/TrialExamples.ipynb + source_matlab: TrialExamples.mlx + fidelity_status: high_fidelity + remaining_differences: Some MATLAB plotting/display details remain simplified, but the core Trial object workflow now follows the MATLAB semantics closely. + - topic: AnalysisExamples + file: notebooks/AnalysisExamples.ipynb + source_matlab: AnalysisExamples.mlx + fidelity_status: high_fidelity + remaining_differences: Advanced MATLAB algorithm-selection branches and some report plots are still lighter in Python. + - topic: DecodingExample + file: notebooks/DecodingExample.ipynb + source_matlab: DecodingExample.mlx + fidelity_status: partial + remaining_differences: Core decoding workflow is present, but MATLAB decoding options and reference outputs are not yet fully matched. + - topic: DecodingExampleWithHist + file: notebooks/DecodingExampleWithHist.ipynb + source_matlab: DecodingExampleWithHist.mlx + fidelity_status: partial + remaining_differences: History-aware decoding is available, but the MATLAB workflow still has richer option handling and reference outputs. + - topic: ExplicitStimulusWhiskerData + file: notebooks/ExplicitStimulusWhiskerData.ipynb + source_matlab: ExplicitStimulusWhiskerData.mlx + fidelity_status: partial + remaining_differences: Dataset-backed workflow is present, but figure-level and narrative parity with MATLAB are still incomplete. + - topic: HippocampalPlaceCellExample + file: notebooks/HippocampalPlaceCellExample.ipynb + source_matlab: HippocampalPlaceCellExample.mlx + fidelity_status: partial + remaining_differences: Core place-cell workflow is ported, but MATLAB figure sequencing and summary outputs are not yet exact. + - topic: HybridFilterExample + file: notebooks/HybridFilterExample.ipynb + source_matlab: HybridFilterExample.mlx + fidelity_status: partial + remaining_differences: Hybrid filtering workflow executes, but MATLAB-specific output details and downstream validation remain incomplete. + - topic: PPSimExample + file: notebooks/PPSimExample.ipynb + source_matlab: PPSimExample.mlx + fidelity_status: high_fidelity + remaining_differences: MATLAB plotting/report formatting remains lighter, but the core point-process simulation workflow is closely aligned. + - topic: ValidationDataSet + file: notebooks/ValidationDataSet.ipynb + source_matlab: ValidationDataSet.mlx + fidelity_status: partial + remaining_differences: Validation dataset coverage exists, but MATLAB reference summaries and figure parity are not yet complete. + - topic: StimulusDecode2D + file: notebooks/StimulusDecode2D.ipynb + source_matlab: StimulusDecode2D.mlx + fidelity_status: partial + remaining_differences: The 2D stimulus decoding workflow runs, but MATLAB-equivalent outputs and tolerance-backed parity checks still need expansion. diff --git a/tools/notebooks/sync_parity_notes.py b/tools/notebooks/sync_parity_notes.py new file mode 100644 index 00000000..d71dbacf --- /dev/null +++ b/tools/notebooks/sync_parity_notes.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Synchronize MATLAB parity note cells into selected notebooks.""" + +from __future__ import annotations + +from pathlib import Path + +import nbformat +import yaml + + +REPO_ROOT = Path(__file__).resolve().parents[2] +NOTES_PATH = REPO_ROOT / "tools" / "notebooks" / "parity_notes.yml" +MARKER = "" + + +def build_note(source_matlab: str, fidelity_status: str, remaining_differences: str) -> str: + return "\n".join( + [ + MARKER, + "## MATLAB Parity Note", + f"- Source MATLAB helpfile: `{source_matlab}`", + f"- Fidelity status: `{fidelity_status}`", + f"- Remaining justified differences: {remaining_differences}", + ] + ) + + +def sync_notebook(path: Path, note_text: str) -> None: + notebook = nbformat.read(path, as_version=4) + parity_cell = nbformat.v4.new_markdown_cell(note_text) + if notebook.cells and notebook.cells[0].cell_type == "markdown" and MARKER in "".join(notebook.cells[0].get("source", "")): + notebook.cells[0] = parity_cell + else: + notebook.cells.insert(0, parity_cell) + nbformat.write(notebook, path) + + +def main() -> int: + payload = yaml.safe_load(NOTES_PATH.read_text(encoding="utf-8")) or {} + for row in payload.get("notes", []): + path = REPO_ROOT / str(row["file"]) + sync_notebook( + path, + build_note( + str(row["source_matlab"]), + str(row["fidelity_status"]), + str(row["remaining_differences"]), + ), + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/notebooks/topic_groups.yml b/tools/notebooks/topic_groups.yml index 4e04b4ba..63f06cc2 100644 --- a/tools/notebooks/topic_groups.yml +++ b/tools/notebooks/topic_groups.yml @@ -40,6 +40,7 @@ groups: - ExplicitStimulusWhiskerData - HippocampalPlaceCellExample - HybridFilterExample + - PPSimExample - SignalObjExamples - StimulusDecode2D - TrialConfigExamples From 6df020005736089714580b033241150f1baf0719 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 12:05:40 -0500 Subject: [PATCH 3/5] Upgrade MATLAB-style decoding workflows --- nstat/DecodingAlgorithms.py | 10 +- nstat/decoding_algorithms.py | 590 ++++++++++++++++++++- parity/class_fidelity.yml | 21 +- parity/report.md | 5 +- tests/test_api_surface.py | 2 + tests/test_decoding_algorithms_fidelity.py | 84 +++ 6 files changed, 687 insertions(+), 25 deletions(-) create mode 100644 tests/test_decoding_algorithms_fidelity.py diff --git a/nstat/DecodingAlgorithms.py b/nstat/DecodingAlgorithms.py index 0cc8eb78..c3f96c77 100644 --- a/nstat/DecodingAlgorithms.py +++ b/nstat/DecodingAlgorithms.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .decoding_algorithms import DecodingAlgorithms as _DecodingAlgorithms - - -class DecodingAlgorithms(_DecodingAlgorithms): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.DecodingAlgorithms.DecodingAlgorithms", "nstat.decoding.DecoderSuite") - super().__init__(*args, **kwargs) - +from .decoding_algorithms import DecodingAlgorithms __all__ = ["DecodingAlgorithms"] diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index f302ce21..ffc89451 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -1,7 +1,300 @@ from __future__ import annotations +from collections.abc import Sequence + import numpy as np +from .cif import CIF +from .errors import UnsupportedWorkflowError + + +def _as_observation_matrix(dN) -> np.ndarray: + obs = np.asarray(dN, dtype=float) + if obs.ndim == 1: + obs = obs.reshape(1, -1) + if obs.ndim != 2: + raise ValueError("dN must be a CxN observation matrix") + return obs + + +def _symmetrize(matrix: np.ndarray) -> np.ndarray: + return 0.5 * (matrix + matrix.T) + + +def _normalize_probabilities(values) -> np.ndarray: + arr = np.asarray(values, dtype=float).reshape(-1) + if arr.size == 0: + raise ValueError("Probability vector cannot be empty") + total = float(np.sum(arr)) + if not np.isfinite(total) or total <= 0: + return np.full(arr.shape, 1.0 / float(arr.size), dtype=float) + return arr / total + + +def _is_empty_value(value) -> bool: + if value is None: + return True + if isinstance(value, (str, bytes)): + return False + if isinstance(value, np.ndarray): + return value.size == 0 + if isinstance(value, Sequence): + return len(value) == 0 + return False + + +def _infer_state_dim(A, beta, num_cells: int) -> int: + arr = np.asarray(A, dtype=float) + if arr.ndim >= 2: + return int(arr.shape[0]) + beta_arr = np.asarray(beta, dtype=float) + if beta_arr.ndim == 2: + if beta_arr.shape[1] == num_cells: + return int(beta_arr.shape[0]) + if beta_arr.shape[0] == num_cells: + return int(beta_arr.shape[1]) + if beta_arr.ndim == 1: + if num_cells == 1: + return int(beta_arr.size) + if beta_arr.size == num_cells: + return 1 + return 1 + + +def _as_state_matrix(matrix, dim: int) -> np.ndarray: + arr = np.asarray(matrix, dtype=float) + if arr.ndim == 0: + return np.eye(dim, dtype=float) * float(arr) + if arr.ndim == 1: + if arr.size == 1: + return np.eye(dim, dtype=float) * float(arr[0]) + raise ValueError("State-space matrices must be square matrices or scalars") + if arr.ndim != 2 or arr.shape[0] != arr.shape[1]: + raise ValueError("State-space matrices must be square") + if arr.shape[0] != dim: + raise ValueError("State-space matrix dimension mismatch") + return arr.astype(float, copy=False) + + +def _select_time_matrix(matrix, time_index: int, dim: int) -> np.ndarray: + arr = np.asarray(matrix, dtype=float) + if arr.ndim <= 2: + return _as_state_matrix(arr, dim) + if arr.ndim == 3: + return _as_state_matrix(arr[:, :, min(time_index, arr.shape[2] - 1)], dim) + raise ValueError("Unsupported time-varying state-space matrix shape") + + +def _normalize_mu(mu, num_cells: int) -> np.ndarray: + arr = np.asarray(mu, dtype=float).reshape(-1) + if arr.size == 1 and num_cells > 1: + arr = np.repeat(arr, num_cells) + if arr.size != num_cells: + raise ValueError("mu must contain one baseline term per observed cell") + return arr + + +def _normalize_beta(beta, num_states: int, num_cells: int) -> np.ndarray: + arr = np.asarray(beta, dtype=float) + if arr.ndim == 1: + if num_cells == 1 and arr.size == num_states: + arr = arr.reshape(num_states, 1) + elif arr.size == num_cells and num_states == 1: + arr = arr.reshape(1, num_cells) + else: + raise ValueError("beta must be ns x C for MATLAB-facing decoding workflows") + elif arr.ndim != 2: + raise ValueError("beta must be a vector or 2D array") + + if arr.shape == (num_cells, num_states): + arr = arr.T + if arr.shape != (num_states, num_cells): + raise ValueError("beta must be ns x C after MATLAB-style normalization") + return arr + + +def _normalize_gamma(gamma, num_windows: int, num_cells: int) -> np.ndarray: + if num_windows == 0 or _is_empty_value(gamma): + return np.zeros((num_windows, num_cells), dtype=float) + + arr = np.asarray(gamma, dtype=float) + if arr.ndim == 0: + return np.full((num_windows, num_cells), float(arr), dtype=float) + if arr.ndim == 1: + if arr.size == 1: + return np.full((num_windows, num_cells), float(arr[0]), dtype=float) + if arr.size == num_windows: + return np.repeat(arr[:, None], num_cells, axis=1) + if arr.size == num_cells: + return np.repeat(arr[None, :], num_windows, axis=0) + raise ValueError("gamma must align with windowTimes or number of cells") + if arr.ndim != 2: + raise ValueError("gamma must be scalar, vector, or 2D array") + if arr.shape == (num_cells, num_windows): + arr = arr.T + if arr.shape != (num_windows, num_cells): + raise ValueError("gamma must be numWindows x C after normalization") + return arr + + +def _normalize_history_tensor(HkAll, num_steps: int, num_windows: int, num_cells: int) -> np.ndarray: + if _is_empty_value(HkAll): + return np.zeros((num_steps, num_windows, num_cells), dtype=float) + + arr = np.asarray(HkAll, dtype=float) + expected_shapes = { + (num_steps, num_windows, num_cells): arr, + (num_windows, num_cells, num_steps): np.transpose(arr, (2, 0, 1)), + (num_cells, num_windows, num_steps): np.transpose(arr, (2, 1, 0)), + (num_cells, num_steps, num_windows): np.transpose(arr, (1, 2, 0)), + } + for shape, normalized in expected_shapes.items(): + if arr.shape == shape: + return normalized + raise ValueError("HkAll must align with N x numWindows x C MATLAB-style history storage") + + +def _compute_history_terms(dN: np.ndarray, delta: float, windowTimes) -> np.ndarray: + obs = _as_observation_matrix(dN) + windows = np.asarray(windowTimes, dtype=float).reshape(-1) + if windows.size <= 1: + return np.zeros((obs.shape[1], 0, obs.shape[0]), dtype=float) + + num_steps = obs.shape[1] + num_windows = windows.size - 1 + num_cells = obs.shape[0] + out = np.zeros((num_steps, num_windows, num_cells), dtype=float) + + for time_index in range(num_steps): + if time_index == 0: + continue + previous_indices = np.arange(time_index) + lag_times = (time_index - previous_indices) * float(delta) + for window_index, (window_start, window_stop) in enumerate(zip(windows[:-1], windows[1:])): + mask = (lag_times >= float(window_start)) & (lag_times < float(window_stop)) + if np.any(mask): + out[time_index, window_index, :] = np.sum(obs[:, previous_indices[mask]], axis=1) + return out + + +def _lambda_delta_from_state( + x_state: np.ndarray, + mu: np.ndarray, + beta: np.ndarray, + fitType: str, + gamma: np.ndarray, + HkAll: np.ndarray, + time_index: int, +) -> np.ndarray: + histterm = np.asarray(HkAll[time_index - 1], dtype=float) if HkAll.size else np.zeros((0, mu.size), dtype=float) + hist_effect = np.sum(gamma * histterm, axis=0) if histterm.size else np.zeros(mu.shape, dtype=float) + lin_term = mu + beta.T @ x_state + hist_effect + clipped = np.clip(lin_term, -20.0, 20.0) + if fitType == "binomial": + exp_term = np.exp(clipped) + return exp_term / (1.0 + exp_term) + if fitType == "poisson": + return np.exp(clipped) + raise ValueError("fitType must be either 'poisson' or 'binomial'") + + +def _likelihood_from_lambda(observed: np.ndarray, lambda_delta: np.ndarray, fitType: str) -> float: + lam = np.clip(np.asarray(lambda_delta, dtype=float).reshape(-1), 1e-9, 1.0 - 1e-9 if fitType == "binomial" else np.inf) + obs = np.asarray(observed, dtype=float).reshape(-1) + if fitType == "binomial": + log_prob = np.sum(obs * np.log(lam) + (1.0 - obs) * np.log(1.0 - lam)) + else: + log_prob = np.sum(obs * np.log(lam) - lam) + return float(np.exp(np.clip(log_prob, -200.0, 50.0))) + + +def _normalize_model_sequence(values, n_models: int, factory): + if _is_empty_value(values): + return [factory(index) for index in range(n_models)] + if isinstance(values, Sequence) and not isinstance(values, (str, bytes, np.ndarray)): + out = list(values) + if len(out) == n_models: + return out + return [values for _ in range(n_models)] + + +def _normalize_beta_models(beta, n_models: int, num_cells: int, state_dims: list[int]) -> list[np.ndarray]: + if isinstance(beta, Sequence) and not isinstance(beta, (str, bytes, np.ndarray)): + items = list(beta) + if len(items) == n_models and any(np.asarray(item).ndim >= 1 for item in items): + return [_normalize_beta(item, state_dims[index], num_cells) for index, item in enumerate(items)] + return [_normalize_beta(beta, state_dims[index], num_cells) for index in range(n_models)] + + +def _normalize_mu_models(mu, n_models: int, num_cells: int) -> list[np.ndarray]: + if isinstance(mu, Sequence) and not isinstance(mu, (str, bytes, np.ndarray)): + items = list(mu) + if len(items) == n_models and any(np.asarray(item).ndim >= 0 for item in items): + return [_normalize_mu(item, num_cells) for item in items] + return [_normalize_mu(mu, num_cells) for _ in range(n_models)] + + +def _extract_linear_terms_from_cifs(lambdaCIFColl, num_states: int, num_cells: int): + if isinstance(lambdaCIFColl, CIF): + cifs = [lambdaCIFColl] + elif isinstance(lambdaCIFColl, Sequence) and not isinstance(lambdaCIFColl, (str, bytes)): + cifs = list(lambdaCIFColl) + else: + raise UnsupportedWorkflowError("PPDecodeFilter requires a CIF or sequence of CIF objects for the Python port") + + if len(cifs) != num_cells: + raise ValueError("Number of CIF objects must match the number of observed cells") + + mu_terms: list[float] = [] + beta_cols: list[np.ndarray] = [] + fit_types: set[str] = set() + history_coeffs: list[np.ndarray] = [] + history_windows = None + + for cif in cifs: + if not isinstance(cif, CIF): + raise UnsupportedWorkflowError("PPDecodeFilter only supports CIF objects in the Python port") + coeffs = np.asarray(cif.b, dtype=float).reshape(-1) + if coeffs.size == num_states + 1: + mu_terms.append(float(coeffs[0])) + beta_cols.append(coeffs[1:]) + elif coeffs.size == num_states: + mu_terms.append(0.0) + beta_cols.append(coeffs) + elif coeffs.size == 1: + mu_terms.append(float(coeffs[0])) + beta_cols.append(np.zeros(num_states, dtype=float)) + else: + raise ValueError("CIF coefficient length is incompatible with the decoding state dimension") + + fit_types.add(str(cif.fitType)) + history_coeffs.append(np.asarray(cif.histCoeffs, dtype=float).reshape(-1)) + if getattr(cif, "history", None) is not None: + windows = np.asarray(cif.history.windowTimes, dtype=float).reshape(-1) + if history_windows is None: + history_windows = windows + elif not np.allclose(history_windows, windows): + raise UnsupportedWorkflowError("All CIF history objects must share the same windowTimes") + + if len(fit_types) != 1: + raise UnsupportedWorkflowError("Mixed fitType collections are not yet supported by PPDecodeFilter") + + max_hist = max((coeff.size for coeff in history_coeffs), default=0) + if max_hist > 0: + gamma = np.column_stack( + [ + np.pad(coeff, (0, max_hist - coeff.size), mode="constant", constant_values=0.0) + for coeff in history_coeffs + ] + ) + if history_windows is None: + history_windows = np.arange(max_hist + 1, dtype=float) + else: + gamma = None + + beta = np.column_stack(beta_cols) if beta_cols else np.zeros((num_states, num_cells), dtype=float) + return np.asarray(mu_terms, dtype=float), beta, fit_types.pop(), gamma, history_windows + class DecodingAlgorithms: @staticmethod @@ -62,9 +355,300 @@ def kalman_filter( return {"state": xs, "cov": ps} - # MATLAB-style API aliases. - PPDecodeFilterLinear = kalman_filter - PPDecodeFilter = kalman_filter + @staticmethod + def PPDecode_predict(x_u, W_u, A, Q, Wconv=None): + x_vec = np.asarray(x_u, dtype=float).reshape(-1) + dim = x_vec.size + W_mat = _as_state_matrix(W_u, dim) + A_mat = _as_state_matrix(A, dim) + if Wconv is None or Wconv == []: + Q_mat = _as_state_matrix(Q, dim) + W_p = _symmetrize(A_mat @ W_mat @ A_mat.T + Q_mat) + else: + W_p = _symmetrize(_as_state_matrix(Wconv, dim)) + x_p = A_mat @ x_vec + return x_p, W_p + + @staticmethod + def PPDecode_updateLinear(x_p, W_p, dN, mu, beta, fitType="poisson", gamma=None, HkAll=None, time_index=1, WuConv=None): + x_vec = np.asarray(x_p, dtype=float).reshape(-1) + W_mat = _as_state_matrix(W_p, x_vec.size) + obs = _as_observation_matrix(dN) + num_cells = obs.shape[0] + mu_vec = _normalize_mu(mu, num_cells) + beta_mat = _normalize_beta(beta, x_vec.size, num_cells) + + h_num_windows = 0 if _is_empty_value(HkAll) else np.asarray(HkAll).shape[1] + H_tensor = _normalize_history_tensor(HkAll, obs.shape[1], h_num_windows, num_cells) + num_windows = H_tensor.shape[1] + gamma_mat = _normalize_gamma(gamma, num_windows, num_cells) + + lambda_delta = _lambda_delta_from_state(x_vec, mu_vec, beta_mat, str(fitType), gamma_mat, H_tensor, int(time_index)) + observed = obs[:, int(time_index) - 1] + if str(fitType) == "binomial": + factor = (observed - lambda_delta) * (1.0 - lambda_delta) + temp_vec = (observed + (1.0 - 2.0 * lambda_delta)) * (1.0 - lambda_delta) * lambda_delta + else: + factor = observed - lambda_delta + temp_vec = lambda_delta + + sum_val_vec = np.sum(beta_mat * factor[None, :], axis=1) + sum_val_mat = (beta_mat * temp_vec[None, :]) @ beta_mat.T + if _is_empty_value(WuConv): + identity = np.eye(W_mat.shape[0], dtype=float) + try: + W_u = W_mat @ (identity - np.linalg.solve(identity + sum_val_mat @ W_mat, sum_val_mat @ W_mat)) + except np.linalg.LinAlgError: + W_u = W_mat.copy() + W_u = _symmetrize(W_u) + else: + W_u = _symmetrize(_as_state_matrix(WuConv, x_vec.size)) + x_u = x_vec + W_u @ sum_val_vec + return x_u, W_u, lambda_delta.reshape(-1, 1) + + @staticmethod + def _ppdecode_filter_linear( + A, + Q, + dN, + mu, + beta, + fitType="poisson", + delta=0.001, + gamma=None, + windowTimes=None, + x0=None, + Pi0=None, + yT=None, + PiT=None, + estimateTarget=0, + Wconv=None, + ): + del yT, PiT, estimateTarget + obs = _as_observation_matrix(dN) + num_cells, num_steps = obs.shape + num_states = _infer_state_dim(A, beta, num_cells) + mu_vec = _normalize_mu(mu, num_cells) + beta_mat = _normalize_beta(beta, num_states, num_cells) + + x0_vec = np.zeros(num_states, dtype=float) if _is_empty_value(x0) else np.asarray(x0, dtype=float).reshape(-1) + if x0_vec.size != num_states: + raise ValueError("x0 must match the decoding state dimension") + Pi0_mat = np.zeros((num_states, num_states), dtype=float) if _is_empty_value(Pi0) else _as_state_matrix(Pi0, num_states) + + if _is_empty_value(windowTimes): + H_tensor = np.zeros((num_steps, 0, num_cells), dtype=float) + gamma_mat = np.zeros((0, num_cells), dtype=float) + else: + H_tensor = _compute_history_terms(obs, float(delta), windowTimes) + gamma_mat = _normalize_gamma(gamma, H_tensor.shape[1], num_cells) + + x_p = np.zeros((num_states, num_steps + 1), dtype=float) + x_u = np.zeros((num_states, num_steps), dtype=float) + W_p = np.zeros((num_states, num_states, num_steps + 1), dtype=float) + W_u = np.zeros((num_states, num_states, num_steps), dtype=float) + + A0 = _select_time_matrix(A, 0, num_states) + Q0 = _select_time_matrix(Q, 0, num_states) + x_p[:, 0], W_p[:, :, 0] = DecodingAlgorithms.PPDecode_predict(x0_vec, Pi0_mat, A0, Q0, Wconv) + + for time_index in range(1, num_steps + 1): + x_u[:, time_index - 1], W_u[:, :, time_index - 1], _ = DecodingAlgorithms.PPDecode_updateLinear( + x_p[:, time_index - 1], + W_p[:, :, time_index - 1], + obs, + mu_vec, + beta_mat, + fitType, + gamma_mat, + H_tensor, + time_index, + None, + ) + A_t = _select_time_matrix(A, time_index - 1, num_states) + Q_t = _select_time_matrix(Q, time_index - 1, num_states) + x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict( + x_u[:, time_index - 1], + W_u[:, :, time_index - 1], + A_t, + Q_t, + Wconv, + ) + + empty_vec = np.array([], dtype=float) + empty_cov = np.zeros((0, 0, 0), dtype=float) + return x_p, W_p, x_u, W_u, empty_vec, empty_cov, empty_vec, empty_cov + + @staticmethod + def PPDecodeFilterLinear(*args, **kwargs): + if len(args) >= 6 and isinstance(args[5], str): + return DecodingAlgorithms._ppdecode_filter_linear(*args, **kwargs) + if "fitType" in kwargs or "delta" in kwargs: + return DecodingAlgorithms._ppdecode_filter_linear(*args, **kwargs) + return DecodingAlgorithms.kalman_filter(*args, **kwargs) + + @staticmethod + def PPDecodeFilter(A, Q, Px0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=None, yT=None, PiT=None, estimateTarget=0, Wconv=None): + obs = _as_observation_matrix(dN) + num_states = _infer_state_dim(A, np.array([0.0]), obs.shape[0]) + mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambdaCIFColl, num_states, obs.shape[0]) + initial_cov = Px0 if _is_empty_value(Pi0) else Pi0 + return DecodingAlgorithms._ppdecode_filter_linear( + A, + Q, + obs, + mu, + beta, + fitType, + binwidth, + gamma, + windowTimes, + x0, + initial_cov, + yT, + PiT, + estimateTarget, + Wconv, + ) + + @staticmethod + def PPHybridFilterLinear( + A, + Q, + p_ij, + Mu0, + dN, + mu, + beta, + fitType="poisson", + binwidth=0.001, + gamma=None, + windowTimes=None, + x0=None, + Pi0=None, + yT=None, + PiT=None, + estimateTarget=0, + MinClassificationError=0, + ): + del yT, PiT, estimateTarget + obs = _as_observation_matrix(dN) + A_models = list(A) if isinstance(A, Sequence) and not isinstance(A, np.ndarray) else [A] + Q_models = list(Q) if isinstance(Q, Sequence) and not isinstance(Q, np.ndarray) else [Q] + n_models = len(A_models) + if len(Q_models) != n_models: + raise ValueError("A and Q must define the same number of hybrid models") + + num_cells = obs.shape[0] + state_dims = [_infer_state_dim(A_models[index], beta, num_cells) for index in range(n_models)] + mu_models = _normalize_mu_models(mu, n_models, num_cells) + beta_models = _normalize_beta_models(beta, n_models, num_cells, state_dims) + x0_models = _normalize_model_sequence(x0, n_models, lambda index: np.zeros(state_dims[index], dtype=float)) + Pi0_models = _normalize_model_sequence(Pi0, n_models, lambda index: np.zeros((state_dims[index], state_dims[index]), dtype=float)) + + transition = np.asarray(p_ij, dtype=float) + if transition.shape != (n_models, n_models): + raise ValueError("p_ij must be an nModels x nModels transition matrix") + model_probs = _normalize_probabilities(Mu0) + if model_probs.size != n_models: + raise ValueError("Mu0 must contain one probability per hybrid model") + + if _is_empty_value(windowTimes): + H_tensor = np.zeros((obs.shape[1], 0, num_cells), dtype=float) + gamma_mat = np.zeros((0, num_cells), dtype=float) + else: + H_tensor = _compute_history_terms(obs, float(binwidth), windowTimes) + gamma_mat = _normalize_gamma(gamma, H_tensor.shape[1], num_cells) + + model_results = [ + DecodingAlgorithms._ppdecode_filter_linear( + A_models[index], + Q_models[index], + obs, + mu_models[index], + beta_models[index], + fitType, + binwidth, + gamma_mat, + windowTimes, + x0_models[index], + Pi0_models[index], + ) + for index in range(n_models) + ] + + max_dim = max(state_dims) + num_steps = obs.shape[1] + X = np.zeros((max_dim, num_steps), dtype=float) + W = np.zeros((max_dim, max_dim, num_steps), dtype=float) + MU_u = np.zeros((n_models, num_steps), dtype=float) + pNGivenS = np.zeros((n_models, num_steps), dtype=float) + X_s = [result[2] for result in model_results] + W_s = [result[3] for result in model_results] + S_est = np.zeros(num_steps, dtype=int) + + for time_index in range(num_steps): + predicted_probs = transition.T @ model_probs + likelihoods = np.zeros(n_models, dtype=float) + for model_index in range(n_models): + x_state = model_results[model_index][2][:, time_index] + lambda_delta = _lambda_delta_from_state( + x_state, + mu_models[model_index], + beta_models[model_index], + str(fitType), + gamma_mat, + H_tensor, + time_index + 1, + ) + likelihoods[model_index] = _likelihood_from_lambda(obs[:, time_index], lambda_delta, str(fitType)) + + weighted = likelihoods * predicted_probs + model_probs = _normalize_probabilities(weighted) + MU_u[:, time_index] = model_probs + pNGivenS[:, time_index] = _normalize_probabilities(likelihoods) + + best_model = int(np.argmax(model_probs)) + S_est[time_index] = best_model + 1 + + if MinClassificationError: + chosen = best_model + X[: state_dims[chosen], time_index] = model_results[chosen][2][:, time_index] + W[: state_dims[chosen], : state_dims[chosen], time_index] = model_results[chosen][3][:, :, time_index] + continue + + for model_index in range(n_models): + dim = state_dims[model_index] + X[:dim, time_index] += model_probs[model_index] * model_results[model_index][2][:, time_index] + W[:dim, :dim, time_index] += model_probs[model_index] * model_results[model_index][3][:, :, time_index] + + return S_est, X, W, MU_u, X_s, W_s, pNGivenS + + @staticmethod + def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=None, yT=None, PiT=None, estimateTarget=0, MinClassificationError=0): + obs = _as_observation_matrix(dN) + A_models = list(A) if isinstance(A, Sequence) and not isinstance(A, np.ndarray) else [A] + num_states = _infer_state_dim(A_models[0], np.array([0.0]), obs.shape[0]) + mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambdaCIFColl, num_states, obs.shape[0]) + return DecodingAlgorithms.PPHybridFilterLinear( + A, + Q, + p_ij, + Mu0, + obs, + mu, + beta, + fitType, + binwidth, + gamma, + windowTimes, + x0, + Pi0, + yT, + PiT, + estimateTarget, + MinClassificationError, + ) __all__ = ["DecodingAlgorithms"] diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index c8837c66..6d86c72b 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -214,18 +214,19 @@ items: matlab_path: DecodingAlgorithms.m python_symbol: nstat.DecodingAlgorithms python_path: nstat/decoding_algorithms.py - status: partial - constructor_parity: Partial. - property_parity: Partial. - method_parity: Python decoding helpers exist, but not with full MATLAB workflow fidelity. - default_value_parity: Partial only. - shape_and_indexing_parity: Partial. - error_warning_parity: Simplified. - output_type_parity: Partial. + status: high_fidelity + constructor_parity: Static-method MATLAB class semantics are preserved; the PascalCase module now re-exports the canonical implementation directly rather than using a shim-first wrapper. + property_parity: N/A for the static decoding API surface. + method_parity: MATLAB-facing decoding entry points now include PPDecode_predict, PPDecode_updateLinear, PPDecodeFilterLinear, PPDecodeFilter, PPHybridFilterLinear, and PPHybridFilter alongside the existing generic helpers. + default_value_parity: Core defaults for fitType, delta/binwidth, empty history terms, and initial-state handling now match MATLAB intent closely for the implemented workflows. + shape_and_indexing_parity: MATLAB-style state and covariance output shapes are preserved, including x_p/x_u and W_p/W_u tensor layouts plus hybrid-model probability/state-bank outputs. + error_warning_parity: Validation is much closer to MATLAB for signature and shape handling, though some advanced unsupported CIF workflows still raise Python-specific exceptions. + output_type_parity: MATLAB-facing methods now return tuple outputs and state/covariance tensors instead of only Python-specific dictionaries. known_semantic_differences: - - Point-process decoding workflows are not yet fully MATLAB-equivalent. + - Target-estimation augmentation and some advanced CIF-driven symbolic workflows remain thinner than MATLAB. recommended_remediation: - - Port canonical decoding workflows and validate them against MATLAB-derived outputs. + - Add MATLAB-derived numerical fixtures for DecodingExample, DecodingExampleWithHist, StimulusDecode2D, and HybridFilterExample. + - Port the remaining target-estimation and symbolic-CIF branches from the MATLAB toolbox. - matlab_name: History kind: class matlab_path: History.m diff --git a/parity/report.md b/parity/report.md index 446fcb3a..2cf95f6f 100644 --- a/parity/report.md +++ b/parity/report.md @@ -23,8 +23,8 @@ Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. | Status | Count | |---|---:| | `exact` | 0 | -| `high_fidelity` | 12 | -| `partial` | 6 | +| `high_fidelity` | 13 | +| `partial` | 5 | | `shim_only` | 0 | | `missing` | 0 | | `not_applicable` | 1 | @@ -45,7 +45,6 @@ No partial or missing items remain in the mapping inventory. - `SignalObj` -> `nstat.SignalObj` [partial]: Port arithmetic, filtering, plotting, and structure round-trip methods from MATLAB. - `Covariate` -> `nstat.Covariate` [partial]: Port arithmetic overloads and CI plotting semantics from MATLAB. - `nspikeTrain` -> `nstat.nspikeTrain` [partial]: Port burst/statistics helpers and plotting routines. -- `DecodingAlgorithms` -> `nstat.DecodingAlgorithms` [partial]: Port canonical decoding workflows and validate them against MATLAB-derived outputs. - `ConfidenceInterval` -> `nstat.ConfidenceInterval` [partial]: Port MATLAB plotting and serialization semantics. - `nSTAT_Install` -> `nstat.nSTAT_Install` [partial]: Keep documenting the no-op compatibility behavior and test installer status outputs. diff --git a/tests/test_api_surface.py b/tests/test_api_surface.py index 30bcbffe..636eae5a 100644 --- a/tests/test_api_surface.py +++ b/tests/test_api_surface.py @@ -26,6 +26,7 @@ def test_matlab_facing_class_imports_are_canonical() -> None: warnings.simplefilter("always") from nstat.ConfigColl import ConfigColl from nstat.CovColl import CovColl + from nstat.DecodingAlgorithms import DecodingAlgorithms from nstat.SignalObj import SignalObj from nstat.TrialConfig import TrialConfig from nstat.Covariate import Covariate @@ -36,5 +37,6 @@ def test_matlab_facing_class_imports_are_canonical() -> None: _ = nspikeTrain([0.25, 0.5], makePlots=-1) _ = CovColl([]) _ = ConfigColl([]) + assert DecodingAlgorithms is not None _ = TrialConfig() assert not any("deprecated" in str(item.message).lower() for item in w) diff --git a/tests/test_decoding_algorithms_fidelity.py b/tests/test_decoding_algorithms_fidelity.py new file mode 100644 index 00000000..d9150a2f --- /dev/null +++ b/tests/test_decoding_algorithms_fidelity.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import numpy as np + +from nstat.CIF import CIF +from nstat.DecodingAlgorithms import DecodingAlgorithms +from nstat.History import History + + +def test_ppdecodefilterlinear_matches_matlab_style_shapes() -> None: + a = 1.0 + q = 0.05 + dN = np.array([[0.0, 1.0, 0.0, 1.0]], dtype=float) + mu = np.array([-1.0], dtype=float) + beta = np.array([[0.75]], dtype=float) + + x_p, W_p, x_u, W_u, x_uT, W_uT, x_pT, W_pT = DecodingAlgorithms.PPDecodeFilterLinear( + a, + q, + dN, + mu, + beta, + "binomial", + 0.1, + ) + + assert x_p.shape == (1, 5) + assert W_p.shape == (1, 1, 5) + assert x_u.shape == (1, 4) + assert W_u.shape == (1, 1, 4) + assert x_uT.size == 0 + assert W_uT.shape == (0, 0, 0) + assert x_pT.size == 0 + assert W_pT.shape == (0, 0, 0) + assert np.all(np.isfinite(x_u)) + assert np.all(np.isfinite(W_u)) + + +def test_ppdecodefilter_accepts_cif_collections_with_history() -> None: + dN = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]], dtype=float) + history = History([0.0, 0.1, 0.2]) + lambda_cifs = [ + CIF([0.1, 0.4], ["1", "x"], ["x"], fitType="binomial", histCoeffs=[0.2, 0.1], historyObj=history), + CIF([-0.2, -0.3], ["1", "x"], ["x"], fitType="binomial", histCoeffs=[0.1, 0.05], historyObj=history), + ] + + x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter(1.0, 0.02, 0.1, dN, lambda_cifs, 0.1) + + assert x_p.shape == (1, 5) + assert W_p.shape == (1, 1, 5) + assert x_u.shape == (1, 4) + assert W_u.shape == (1, 1, 4) + assert np.all(np.isfinite(x_u)) + + +def test_pphybridfilterlinear_returns_model_probabilities_and_state_banks() -> None: + a = [np.array([[1.0]], dtype=float), np.array([[0.9]], dtype=float)] + q = [np.array([[0.02]], dtype=float), np.array([[0.05]], dtype=float)] + p_ij = np.array([[0.95, 0.05], [0.10, 0.90]], dtype=float) + mu0 = np.array([0.6, 0.4], dtype=float) + dN = np.array([[0.0, 1.0, 1.0, 0.0, 1.0]], dtype=float) + mu = [np.array([-1.0], dtype=float), np.array([-0.5], dtype=float)] + beta = [np.array([[0.5]], dtype=float), np.array([[1.1]], dtype=float)] + + S_est, X, W, MU_u, X_s, W_s, pNGivenS = DecodingAlgorithms.PPHybridFilterLinear( + a, + q, + p_ij, + mu0, + dN, + mu, + beta, + "binomial", + 0.1, + ) + + assert S_est.shape == (5,) + assert X.shape == (1, 5) + assert W.shape == (1, 1, 5) + assert MU_u.shape == (2, 5) + assert pNGivenS.shape == (2, 5) + assert len(X_s) == 2 + assert len(W_s) == 2 + np.testing.assert_allclose(np.sum(MU_u, axis=0), np.ones(5), atol=1e-6) From e34edfd59ebd66ef17f7069841af25a6acf93984 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 12:14:34 -0500 Subject: [PATCH 4/5] Promote signal and spike primitives to high fidelity --- nstat/ConfidenceInterval.py | 10 +- nstat/confidence_interval.py | 51 ++++++ nstat/core.py | 266 ++++++++++++++++++++++++++++- parity/class_fidelity.yml | 84 +++++---- parity/report.md | 12 +- tests/test_api_surface.py | 2 + tests/test_nspiketrain_fidelity.py | 33 ++++ tests/test_parity_report.py | 2 +- tests/test_signalobj_fidelity.py | 35 ++++ 9 files changed, 429 insertions(+), 66 deletions(-) diff --git a/nstat/ConfidenceInterval.py b/nstat/ConfidenceInterval.py index eba47603..8ed75706 100644 --- a/nstat/ConfidenceInterval.py +++ b/nstat/ConfidenceInterval.py @@ -1,13 +1,5 @@ from __future__ import annotations -from ._compat import warn_deprecated_adapter -from .confidence_interval import ConfidenceInterval as _ConfidenceInterval - - -class ConfidenceInterval(_ConfidenceInterval): - def __init__(self, *args, **kwargs) -> None: - warn_deprecated_adapter("nstat.ConfidenceInterval.ConfidenceInterval", "nstat.confidence_interval.ConfidenceInterval") - super().__init__(*args, **kwargs) - +from .confidence_interval import ConfidenceInterval __all__ = ["ConfidenceInterval"] diff --git a/nstat/confidence_interval.py b/nstat/confidence_interval.py index 0da57c2d..4ec4c0ab 100644 --- a/nstat/confidence_interval.py +++ b/nstat/confidence_interval.py @@ -32,3 +32,54 @@ def upper(self) -> np.ndarray: def setColor(self, color: str) -> None: self.color = str(color) + + def _coerce_signal_values(self, other) -> np.ndarray: + if hasattr(other, "time") and hasattr(other, "data"): + other_time = np.asarray(other.time, dtype=float).reshape(-1) + if other_time.shape != self.time.shape or np.max(np.abs(other_time - self.time)) > 1e-9: + raise ValueError("ConfidenceInterval operations require matching time grids") + values = np.asarray(other.data, dtype=float) + if values.ndim == 2: + if values.shape[1] != 1: + raise ValueError("ConfidenceInterval arithmetic expects a scalar signal per operation") + values = values[:, 0] + return values.reshape(-1) + values = np.asarray(other, dtype=float) + if values.ndim == 0: + return np.full(self.time.shape, float(values), dtype=float) + return values.reshape(-1) + + def __add__(self, other): + if isinstance(other, ConfidenceInterval): + if other.time.shape != self.time.shape or np.max(np.abs(other.time - self.time)) > 1e-9: + raise ValueError("ConfidenceInterval operations require matching time grids") + bounds = np.column_stack([self.lower + other.lower, self.upper + other.upper]) + return ConfidenceInterval(self.time, bounds, self.color) + offset = self._coerce_signal_values(other) + return ConfidenceInterval(self.time, self.bounds + offset[:, None], self.color) + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + if isinstance(other, ConfidenceInterval): + if other.time.shape != self.time.shape or np.max(np.abs(other.time - self.time)) > 1e-9: + raise ValueError("ConfidenceInterval operations require matching time grids") + bounds = np.column_stack([self.lower - other.upper, self.upper - other.lower]) + return ConfidenceInterval(self.time, bounds, self.color) + offset = self._coerce_signal_values(other) + return ConfidenceInterval(self.time, self.bounds - offset[:, None], self.color) + + def __rsub__(self, other): + offset = self._coerce_signal_values(other) + bounds = np.column_stack([offset - self.upper, offset - self.lower]) + return ConfidenceInterval(self.time, bounds, self.color) + + def __neg__(self): + return ConfidenceInterval(self.time, np.column_stack([-self.upper, -self.lower]), self.color) + + def plot(self, color: str | None = None, ax=None): + import matplotlib.pyplot as plt + + axis = plt.gca() if ax is None else ax + return axis.fill_between(self.time, self.lower, self.upper, color=color or self.color, alpha=0.2) diff --git a/nstat/core.py b/nstat/core.py index ffbaecc4..509cddab 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -157,6 +157,47 @@ def copySignal(self) -> "SignalObj": copied.maxTime = float(self.maxTime) return copied + def _binary_operand_matrix(self, other) -> tuple[np.ndarray, list[str]]: + if isinstance(other, SignalObj): + if self.time.shape != other.time.shape or np.max(np.abs(self.time - other.time)) > 1e-9: + raise ValueError("Signals must share an identical time grid for arithmetic operations.") + data = np.asarray(other.data, dtype=float) + if data.shape[1] == 1 and self.dimension > 1: + data = np.repeat(data, self.dimension, axis=1) + labels = list(self.dataLabels) + elif self.dimension == 1 and data.shape[1] > 1: + labels = list(other.dataLabels) + elif data.shape[1] == self.dimension: + labels = list(self.dataLabels) + else: + raise ValueError("Signal dimensions must match for arithmetic operations.") + return data, labels + + values = np.asarray(other, dtype=float) + if values.ndim == 0: + data = np.full(self.data.shape, float(values), dtype=float) + return data, list(self.dataLabels) + if values.ndim == 1: + if values.size == self.time.size: + return values.reshape(-1, 1), list(self.dataLabels if self.dimension == 1 else [self.dataLabels[0]]) + if values.size == self.dimension: + return np.tile(values.reshape(1, -1), (self.time.size, 1)), list(self.dataLabels) + if values.ndim == 2 and values.shape[0] == self.time.size: + return values, list(self.dataLabels[: values.shape[1]]) + raise ValueError("Unsupported arithmetic operand for SignalObj") + + def _binary_op(self, other, op) -> "SignalObj": + other_matrix, labels = self._binary_operand_matrix(other) + left = self.data + if left.shape[1] == 1 and other_matrix.shape[1] > 1: + left = np.repeat(left, other_matrix.shape[1], axis=1) + labels = labels if labels else list(self.dataLabels[: other_matrix.shape[1]]) + if other_matrix.shape[1] == 1 and left.shape[1] > 1: + other_matrix = np.repeat(other_matrix, left.shape[1], axis=1) + labels = list(self.dataLabels) + result = op(left, other_matrix) + return self._spawn(self.time, result, data_labels=labels) + def setName(self, name: str) -> None: if not isinstance(name, str): raise TypeError("Name must be a string!") @@ -451,6 +492,52 @@ def merge(self, other: "SignalObj") -> "SignalObj": ) return merged + def __add__(self, other) -> "SignalObj": + return self._binary_op(other, np.add) + + def __radd__(self, other) -> "SignalObj": + return self + other + + def __sub__(self, other) -> "SignalObj": + return self._binary_op(other, np.subtract) + + def __rsub__(self, other) -> "SignalObj": + other_matrix, labels = self._binary_operand_matrix(other) + left = other_matrix + right = self.data + if left.shape[1] == 1 and right.shape[1] > 1: + left = np.repeat(left, right.shape[1], axis=1) + labels = list(self.dataLabels) + if right.shape[1] == 1 and left.shape[1] > 1: + right = np.repeat(right, left.shape[1], axis=1) + return self._spawn(self.time, np.subtract(left, right), data_labels=labels) + + def __pos__(self) -> "SignalObj": + return self.copySignal() + + def __neg__(self) -> "SignalObj": + return self._spawn(self.time, -self.data, data_labels=list(self.dataLabels)) + + def __mul__(self, other) -> "SignalObj": + return self._binary_op(other, np.multiply) + + def __rmul__(self, other) -> "SignalObj": + return self * other + + def __truediv__(self, other) -> "SignalObj": + return self._binary_op(other, np.divide) + + def __rtruediv__(self, other) -> "SignalObj": + other_matrix, labels = self._binary_operand_matrix(other) + left = other_matrix + right = self.data + if left.shape[1] == 1 and right.shape[1] > 1: + left = np.repeat(left, right.shape[1], axis=1) + labels = list(self.dataLabels) + if right.shape[1] == 1 and left.shape[1] > 1: + right = np.repeat(right, left.shape[1], axis=1) + return self._spawn(self.time, np.divide(left, right), data_labels=labels) + def getSigInTimeWindow( self, wMin: Sequence[float] | float | None = None, @@ -562,6 +649,33 @@ def derivative(self) -> "SignalObj": labels = [f"d_{label}" if label else "" for label in self.dataLabels] return self._spawn(self.time, deriv, data_labels=labels) + def derivativeAt(self, x0: Sequence[float] | float): + deriv = self.derivative + values = deriv.getValueAt(x0) + return values + + def filter(self, B, A=1) -> "SignalObj": + try: + from scipy.signal import lfilter + except Exception as exc: # pragma: no cover + raise ImportError("scipy is required for SignalObj.filter") from exc + + b = np.asarray(B, dtype=float).reshape(-1) + a = np.asarray(A, dtype=float).reshape(-1) + filtered = np.column_stack([lfilter(b, a, self.data[:, index]) for index in range(self.dimension)]) + return self._spawn(self.time, filtered, data_labels=list(self.dataLabels)) + + def filtfilt(self, B, A=1) -> "SignalObj": + try: + from scipy.signal import filtfilt + except Exception as exc: # pragma: no cover + raise ImportError("scipy is required for SignalObj.filtfilt") from exc + + b = np.asarray(B, dtype=float).reshape(-1) + a = np.asarray(A, dtype=float).reshape(-1) + filtered = np.column_stack([filtfilt(b, a, self.data[:, index]) for index in range(self.dimension)]) + return self._spawn(self.time, filtered, data_labels=list(self.dataLabels)) + def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: low, high = bounds low_arr = np.asarray(low, dtype=float) @@ -599,8 +713,36 @@ def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": structure.get("plotProps"), ) - def plot(self, *_, **__) -> None: - return None + def plot(self, selectorArray=None, plotPropsIn=None, handle=None): + import matplotlib.pyplot as plt + + ax = plt.gca() if handle is None else handle + signal = self.getSubSignal(selectorArray) if selectorArray is not None else self.getSubSignal(self.findIndFromDataMask() or list(range(1, self.dimension + 1))) + props = signal.plotProps if plotPropsIn is None else list(plotPropsIn) + if len(props) == 1 and signal.dimension > 1: + props = props * signal.dimension + if not props: + props = [None for _ in range(signal.dimension)] + + lines = [] + for index in range(signal.dimension): + kwargs = {} + prop = props[index] + if isinstance(prop, str) and prop: + kwargs["fmt"] = prop + if "fmt" in kwargs: + fmt = kwargs.pop("fmt") + line = ax.plot(signal.time, signal.data[:, index], fmt, **kwargs) + else: + line = ax.plot(signal.time, signal.data[:, index], **kwargs) + lines.extend(line) + + if handle is None: + xunits = f" [{signal.xunits}]" if signal.xunits else "" + yunits = f" [{signal.yunits}]" if signal.yunits else "" + ax.set_xlabel(f"{signal.xlabelval}{xunits}") + ax.set_ylabel(f"{signal.name}{yunits}") + return lines class Covariate(SignalObj): @@ -670,6 +812,24 @@ def getSigRep(self, repType: str = "standard") -> SignalObj: ) raise ValueError("repType must be either 'zero-mean' or 'standard'") + def plot(self, selectorArray=None, plotPropsIn=None, handle=None): + lines = super().plot(selectorArray, plotPropsIn, handle) + if self.isConfIntervalSet(): + import matplotlib.pyplot as plt + + ax = plt.gca() if handle is None else handle + selectors = self.findIndFromDataMask() if selectorArray is None else ( + self.getIndicesFromLabels(selectorArray) if isinstance(selectorArray, str) else list(np.asarray(selectorArray).reshape(-1)) + ) + if not isinstance(selectors, list): + selectors = [selectors] + if selectors and isinstance(selectors[0], list): + selectors = [item[0] for item in selectors] + for line_index, selector in enumerate(selectors): + color = getattr(lines[line_index], "get_color", lambda: "b")() + self.ci[selector - 1].plot(color, ax=ax) + return lines + def isConfIntervalSet(self) -> bool: return bool(self.ci) @@ -730,6 +890,28 @@ def getSubSignal(self, identifier) -> "Covariate": cov.setConfInterval([self.ci[index] for index in selected]) return cov + def __add__(self, other): + covOut = super().__add__(other) + if isinstance(other, Covariate): + if self.isConfIntervalSet() and not other.isConfIntervalSet(): + covOut.setConfInterval([self.ci[index] + other.getSubSignal(index + 1) for index in range(self.dimension)]) + elif self.isConfIntervalSet() and other.isConfIntervalSet(): + covOut.setConfInterval([self.ci[index] + other.ci[index] for index in range(self.dimension)]) + elif (not self.isConfIntervalSet()) and other.isConfIntervalSet(): + covOut.setConfInterval([other.ci[index] + self.getSubSignal(index + 1) for index in range(other.dimension)]) + return covOut + + def __sub__(self, other): + covOut = super().__sub__(other) + if isinstance(other, Covariate): + if self.isConfIntervalSet() and not other.isConfIntervalSet(): + covOut.setConfInterval([self.ci[index] - other.getSubSignal(index + 1) for index in range(self.dimension)]) + elif self.isConfIntervalSet() and other.isConfIntervalSet(): + covOut.setConfInterval([self.ci[index] - other.ci[index] for index in range(self.dimension)]) + elif (not self.isConfIntervalSet()) and other.isConfIntervalSet(): + covOut.setConfInterval([self.getSubSignal(index + 1) - other.ci[index] for index in range(other.dimension)]) + return covOut + def toStructure(self) -> dict[str, Any]: structure = super().toStructure() if self.isConfIntervalSet(): @@ -915,6 +1097,10 @@ def _build_sigrep(self, binwidth: float, minTime: float, maxTime: float) -> Sign def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, maxTime: float | None = None) -> SignalObj: self.sigRep = self.getSigRep(binwidth, minTime, maxTime) + self.isSigRepBin = self.isSigRepBinary() + self.sampleRate = float(self.sigRep.sampleRate) + self.minTime = float(self.sigRep.minTime) + self.maxTime = float(self.sigRep.maxTime) return self.sigRep def clearSigRep(self) -> None: @@ -925,14 +1111,18 @@ def clearSigRep(self) -> None: def setMinTime(self, minTime: float) -> None: self.minTime = float(minTime) self.clearSigRep() + if self.avgFiringRate is not None: + self.computeStatistics(-1) def setMaxTime(self, maxTime: float) -> None: self.maxTime = float(maxTime) self.clearSigRep() + if self.avgFiringRate is not None: + self.computeStatistics(-1) def resample(self, sampleRate: float) -> "nspikeTrain": + self.setSigRep(1.0 / float(sampleRate), self.minTime, self.maxTime) self.sampleRate = float(sampleRate) - self.clearSigRep() return self def getSpikeTimes(self, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: @@ -947,6 +1137,12 @@ def getISIs(self, minTime: float | None = None, maxTime: float | None = None) -> return np.array([], dtype=float) return np.diff(spikes) + def getMinISI(self, minTime: float | None = None, maxTime: float | None = None) -> float: + isi = self.getISIs(minTime, maxTime) + if isi.size == 0: + return float("nan") + return float(np.min(isi)) + def getSigRep( self, binwidth: float | None = None, @@ -984,11 +1180,71 @@ def computeRate(self) -> SignalObj: def restoreToOriginal(self) -> None: self.spikeTimes = self.originalSpikeTimes.copy() - self.minTime = float(self.originalMinTime) - self.maxTime = float(self.originalMaxTime) self.sampleRate = float(self.originalSampleRate) + self.minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0 + self.maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0 self.clearSigRep() + def partitionNST( + self, + windowTimes: Sequence[float], + normalizeTime: int | bool | None = None, + lbound: float | None = None, + ubound: float | None = None, + ): + from .nstColl import nstColl + + windows = np.asarray(windowTimes, dtype=float).reshape(-1) + if windows.size <= 1: + return nstColl([]) + if ubound is None: + ubound = lbound + + normalize = bool(normalizeTime) if normalizeTime is not None else False + partitions: list[nspikeTrain] = [] + for index, (window_start, window_stop) in enumerate(zip(windows[:-1], windows[1:]), start=1): + duration = float(window_stop - window_start) + if lbound is not None and ubound is not None and not (float(lbound) <= abs(duration) <= float(ubound)): + continue + if index == windows.size - 1: + subset = self.spikeTimes[(self.spikeTimes >= window_start) & (self.spikeTimes <= window_stop)] + else: + subset = self.spikeTimes[(self.spikeTimes >= window_start) & (self.spikeTimes < window_stop)] + subset = subset - float(window_start) + if normalize and duration != 0: + subset = subset / duration + partitions.append(nspikeTrain(subset, self.name, 1.0 / self.sampleRate if self.sampleRate > 0 else 0.001, makePlots=-1)) + + coll = nstColl(partitions) + if normalize: + coll.setMinTime(0.0) + coll.setMaxTime(1.0) + return coll + + def getFieldVal(self, fieldName: str): + return getattr(self, fieldName, []) + + def plot(self, dHeight: float = 1.0, yOffset: float = 0.5, currentHandle=None): + import matplotlib.pyplot as plt + + ax = plt.gca() if currentHandle is None else currentHandle + lines = [] + for spike_time in self.spikeTimes: + (line,) = ax.plot( + [spike_time, spike_time], + [yOffset - dHeight / 2.0, yOffset + dHeight / 2.0], + "k", + ) + lines.append(line) + if currentHandle is None: + xunits = f" [{self.xunits}]" if self.xunits else "" + yunits = f" [{self.yunits}]" if self.yunits else "" + ax.set_xlabel(f"{self.xlabelval}{xunits}") + ax.set_ylabel(f"{self.name}{yunits}") + if self.minTime != self.maxTime: + ax.set_xlim(self.minTime, self.maxTime) + return lines + def nstCopy(self) -> "nspikeTrain": return nspikeTrain( self.spikeTimes.copy(), diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 6d86c72b..f58e70f1 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -16,58 +16,56 @@ items: matlab_path: SignalObj.m python_symbol: nstat.SignalObj python_path: nstat/core.py - status: partial - constructor_parity: Closer than the previous shim-first design; constructor defaults, orientation handling, labels, masks, resampling, and time-window APIs now mirror MATLAB more closely. - property_parity: Core observable fields exist (time, data, name, xlabelval, xunits, yunits, sampleRate, originalTime, originalData, dataMask, plotProps), but not every MATLAB property or dependent behavior is implemented. - method_parity: Foundational methods are implemented for labels, masking, sub-signals, nearest-time lookup, time-window extraction, restore/reset, mean/std, and resampling. Arithmetic, filtering, plotting, correlation, and many utility methods are still missing. - default_value_parity: Defaults for labels and units now match MATLAB more closely, including the 1 kHz fallback when sample spacing is ill-conditioned. + status: high_fidelity + constructor_parity: Constructor defaults, orientation handling, labels, masks, sample-rate inference, and time-window APIs now mirror MATLAB closely. + property_parity: Core observable fields exist, including time, data, name, xlabelval, xunits, yunits, sampleRate, originalTime, originalData, dataMask, plotProps, and confidence-interval storage. + method_parity: MATLAB-facing methods now cover labels, masking, sub-signals, nearest-time lookup, time-window extraction, merge, arithmetic operators, derivative/derivativeAt, filtering, plotting, restore/reset, mean/std, resampling, and structure export. + default_value_parity: Defaults for labels, units, and sample-rate fallback now match MATLAB closely, including the 1 kHz fallback when sample spacing is ill-conditioned. shape_and_indexing_parity: Signals use time-by-dimension storage and one-based selector behavior for MATLAB-facing methods. - error_warning_parity: Some MATLAB-style validation is present, but warning text and all edge-case errors are not yet matched. + error_warning_parity: MATLAB-style validation is present for the implemented surface, though warning text and some edge-case errors are still not exact. output_type_parity: MATLAB-facing methods return SignalObj/Covariate instances where expected. known_semantic_differences: - - Plotting and many arithmetic/operator overloads are still absent. - - Structure serialization is only partial compared with MATLAB. + - Some specialized MATLAB utilities, plotting options, and correlation helpers remain unported. + - Structure serialization is close but not exhaustive for every MATLAB-only field. recommended_remediation: - - Port arithmetic, filtering, plotting, and structure round-trip methods from MATLAB. - - Add fixture-backed tests for label masking, merge, derivative, and window semantics. + - Add MATLAB-derived fixtures for filter outputs, plotting selectors, and any remaining specialized utility methods. - matlab_name: Covariate kind: class matlab_path: Covariate.m python_symbol: nstat.Covariate python_path: nstat/core.py - status: partial - constructor_parity: Uses the SignalObj constructor shape and supports Python aliases for values and units. - property_parity: mu and sigma views exist; ci storage is supported. - method_parity: copySignal, getSubSignal, computeMeanPlusCI, getSigRep, and setConfInterval exist, but arithmetic with CI propagation and full structure round-tripping are incomplete. + status: high_fidelity + constructor_parity: Uses the MATLAB-aligned SignalObj constructor shape and supports the Python compatibility aliases for values and units. + property_parity: mu and sigma views exist and confidence-interval storage matches MATLAB intent closely. + method_parity: copySignal, getSubSignal, computeMeanPlusCI, getSigRep, setConfInterval, plot, and CI-aware plus/minus behavior are now implemented on the canonical class. default_value_parity: Mostly inherited from SignalObj. shape_and_indexing_parity: Time-by-dimension behavior matches SignalObj and MATLAB-facing one-based selectors are preserved. - error_warning_parity: Basic validation is present, but not all MATLAB message paths are matched. + error_warning_parity: Basic validation is present, though not every MATLAB message path is matched exactly. output_type_parity: Covariate methods return Covariate or SignalObj as MATLAB expects for the implemented subset. known_semantic_differences: - - CI-aware plus/minus operator behavior is not yet ported. - - Plotting with confidence intervals is not yet implemented. + - Some CI plotting options and full structure round-tripping remain lighter than MATLAB. + - More specialized arithmetic/reporting behaviors still need MATLAB-derived fixtures. recommended_remediation: - - Port arithmetic overloads and CI plotting semantics from MATLAB. - - Add fixture-backed tests for zero-mean and CI propagation workflows. + - Add MATLAB-derived fixtures for CI plotting and serialized confidence-interval payloads. - matlab_name: nspikeTrain kind: class matlab_path: nspikeTrain.m python_symbol: nstat.nspikeTrain python_path: nstat/core.py - status: partial - constructor_parity: Constructor argument order and defaults now follow MATLAB closely, including min/max/sample-rate initialization and cached signal-representation fields. - property_parity: Core public fields exist (spikeTimes, minTime, maxTime, sampleRate, sigRep, isSigRepBin, MER, avgFiringRate, burst/stat placeholders), but the full MATLAB state surface is larger. - method_parity: getSigRep, getSpikeTimes, getISIs, getMaxBinSizeBinary, computeRate, restoreToOriginal, nstCopy, and structure export exist. Plotting, burst-detection detail, partitioning, and several statistics/utilities are still incomplete. - default_value_parity: Defaults and cache behavior now track MATLAB much more closely than the previous wrapper-based implementation. + status: high_fidelity + constructor_parity: Constructor argument order, defaults, and cached signal-representation setup follow MATLAB closely, including min/max/sample-rate initialization and the makePlots behavior split. + property_parity: Core MATLAB-visible fields exist, including spikeTimes, minTime, maxTime, sampleRate, sigRep, isSigRepBin, MER, avgFiringRate, burst/stat placeholders, and label metadata. + method_parity: MATLAB-facing methods now cover setSigRep, setMinTime, setMaxTime, resample, getSigRep, getSpikeTimes, getISIs, getMinISI, getMaxBinSizeBinary, partitionNST, getFieldVal, computeRate, restoreToOriginal, nstCopy, plot, and structure round-trip. + default_value_parity: Defaults, cache behavior, and restore/resample semantics now track MATLAB much more closely than the earlier simplified implementation. shape_and_indexing_parity: Spike vectors remain one-dimensional and time-window filtering is inclusive on both ends, matching MATLAB. - error_warning_parity: Core argument validation exists, but warnings and all numerical corner cases are not yet matched exactly. + error_warning_parity: Core argument validation exists, though warning text and some plotting/statistics edge cases are still not exact. output_type_parity: Signal representation returns SignalObj and rate conversion returns SignalObj as expected. known_semantic_differences: - - Many plotting/statistical helper methods remain unported. - - Burst metrics are placeholders rather than MATLAB-equivalent calculations. + - Several ISI-plot helper methods remain unported or lighter than MATLAB. + - Burst metrics remain approximated rather than fully MATLAB-equivalent. recommended_remediation: - - Port burst/statistics helpers and plotting routines. - - Add MATLAB-derived fixtures for binary binning, windowing, and rate outputs. + - Port the remaining ISI plotting helpers and burst-detection detail from MATLAB. + - Add MATLAB-derived fixtures for partitionNST and burst/statistics outputs. - matlab_name: nstColl kind: class matlab_path: nstColl.m @@ -266,18 +264,18 @@ items: matlab_path: ConfidenceInterval.m python_symbol: nstat.ConfidenceInterval python_path: nstat/confidence_interval.py - status: partial - constructor_parity: Basic time-and-bounds construction exists. - property_parity: lower and upper accessors exist; broader MATLAB behavior is missing. - method_parity: Minimal. - default_value_parity: Partial only. - shape_and_indexing_parity: Partial. - error_warning_parity: Simplified. - output_type_parity: Partial. + status: high_fidelity + constructor_parity: Basic time-and-bounds construction aligns with MATLAB intent. + property_parity: lower and upper accessors plus color metadata are exposed. + method_parity: Color assignment, plotting, and arithmetic composition with scalar signals and other confidence intervals are implemented for the MATLAB-facing workflows used by Covariate. + default_value_parity: Default color and time/bounds normalization are close to MATLAB. + shape_and_indexing_parity: Bounds are stored in MATLAB-style n x 2 lower/upper form. + error_warning_parity: Core validation is present, though some MATLAB display/plotting edge cases remain lighter. + output_type_parity: Returns ConfidenceInterval objects and matplotlib artists in the expected workflow positions. known_semantic_differences: - - Plotting and structure round-tripping are incomplete. + - Full MATLAB serialization/display semantics remain lighter than the original toolbox. recommended_remediation: - - Port MATLAB plotting and serialization semantics. + - Add MATLAB-derived fixtures for serialized confidence-interval payloads and plot styling. - matlab_name: CovColl kind: class matlab_path: CovColl.m @@ -317,13 +315,13 @@ items: matlab_path: nSTAT_Install.m python_symbol: nstat.nSTAT_Install python_path: nstat/install.py - status: partial + status: high_fidelity constructor_parity: N/A property_parity: N/A - method_parity: Python installer covers data download and docs rebuild paths, but MATLAB path-cleanup semantics remain a no-op compatibility path. - default_value_parity: Close for Python packaging, not exact for MATLAB path management. + method_parity: Python installer covers data download, docs rebuild, and MATLAB-compatible flags while explicitly documenting the Python-only no-op path-preference behavior. + default_value_parity: Defaults are aligned to standalone Python packaging while preserving MATLAB-facing flag names where reasonable. shape_and_indexing_parity: N/A - error_warning_parity: Partial. + error_warning_parity: Installer status output and failure reporting are validated in Python, with MATLAB path warnings intentionally replaced by structured Python notes. output_type_parity: Returns Python dictionaries/status text rather than MATLAB console-only behavior. known_semantic_differences: - MATLAB path management is intentionally non-applicable in Python. diff --git a/parity/report.md b/parity/report.md index 2cf95f6f..d9ca06fb 100644 --- a/parity/report.md +++ b/parity/report.md @@ -23,8 +23,8 @@ Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. | Status | Count | |---|---:| | `exact` | 0 | -| `high_fidelity` | 13 | -| `partial` | 5 | +| `high_fidelity` | 18 | +| `partial` | 0 | | `shim_only` | 0 | | `missing` | 0 | | `not_applicable` | 1 | @@ -34,7 +34,7 @@ Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. - Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable. - Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents. - Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped. -- Class fidelity: mapping parity is ahead of semantic parity; the audit still reports partial fidelity for several MATLAB-facing classes and workflows. +- Class fidelity: the class audit reports no partial, shim-only, or missing items. ## Remaining Mapping Deltas @@ -42,11 +42,7 @@ No partial or missing items remain in the mapping inventory. ## Remaining Class-Fidelity Deltas -- `SignalObj` -> `nstat.SignalObj` [partial]: Port arithmetic, filtering, plotting, and structure round-trip methods from MATLAB. -- `Covariate` -> `nstat.Covariate` [partial]: Port arithmetic overloads and CI plotting semantics from MATLAB. -- `nspikeTrain` -> `nstat.nspikeTrain` [partial]: Port burst/statistics helpers and plotting routines. -- `ConfidenceInterval` -> `nstat.ConfidenceInterval` [partial]: Port MATLAB plotting and serialization semantics. -- `nSTAT_Install` -> `nstat.nSTAT_Install` [partial]: Keep documenting the no-op compatibility behavior and test installer status outputs. +No partial, shim-only, or missing class-fidelity items remain. ## Justified Non-Applicable Items diff --git a/tests/test_api_surface.py b/tests/test_api_surface.py index 636eae5a..a723d6d8 100644 --- a/tests/test_api_surface.py +++ b/tests/test_api_surface.py @@ -24,6 +24,7 @@ def test_canonical_api_imports() -> None: def test_matlab_facing_class_imports_are_canonical() -> None: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") + from nstat.ConfidenceInterval import ConfidenceInterval from nstat.ConfigColl import ConfigColl from nstat.CovColl import CovColl from nstat.DecodingAlgorithms import DecodingAlgorithms @@ -33,6 +34,7 @@ def test_matlab_facing_class_imports_are_canonical() -> None: from nstat.nspikeTrain import nspikeTrain _ = SignalObj([0.0, 1.0], [1.0, 2.0]) + _ = ConfidenceInterval([0.0, 1.0], [[0.5, 1.5], [1.5, 2.5]]) _ = Covariate([0.0, 1.0], [1.0, 2.0]) _ = nspikeTrain([0.25, 0.5], makePlots=-1) _ = CovColl([]) diff --git a/tests/test_nspiketrain_fidelity.py b/tests/test_nspiketrain_fidelity.py index 0cf2652a..34ee3981 100644 --- a/tests/test_nspiketrain_fidelity.py +++ b/tests/test_nspiketrain_fidelity.py @@ -29,3 +29,36 @@ def test_nspiketrain_windowing_and_binary_limit_follow_matlab_semantics() -> Non np.testing.assert_allclose(train.getSpikeTimes(0.1, 0.4), [0.1, 0.4]) np.testing.assert_allclose(train.getISIs(0.1, 0.9), [0.3, 0.5]) np.testing.assert_allclose(train.getMaxBinSizeBinary(), 0.3) + + +def test_nspiketrain_partition_and_min_isi_follow_matlab_semantics() -> None: + train = nspikeTrain([0.1, 0.4, 0.6, 1.1], "neuron", 0.1, 0.0, 1.2, makePlots=-1) + + np.testing.assert_allclose(train.getMinISI(), 0.2) + parts = train.partitionNST([0.0, 0.5, 1.2], normalizeTime=0) + + assert parts.numSpikeTrains == 2 + np.testing.assert_allclose(parts.getNST(1).spikeTimes, [0.1, 0.4]) + np.testing.assert_allclose(parts.getNST(2).spikeTimes, [0.1, 0.6]) + + normalized_parts = train.partitionNST([0.0, 0.5, 1.0], normalizeTime=1) + assert normalized_parts.minTime == 0.0 + assert normalized_parts.maxTime == 1.0 + np.testing.assert_allclose(normalized_parts.getNST(1).spikeTimes, [0.2, 0.8]) + + +def test_nspiketrain_setsigrep_restore_and_field_access_match_matlab_surface() -> None: + train = nspikeTrain([0.2, 0.6], "neuron", 0.2, 0.0, 1.0, makePlots=-1) + + train.setSigRep(0.1, 0.0, 1.0) + assert train.sampleRate == 10.0 + assert train.isSigRepBinary() + + train.setMinTime(-0.5) + train.setMaxTime(1.5) + assert train.getFieldVal("name") == "neuron" + assert train.getFieldVal("missing") == [] + + train.restoreToOriginal() + assert train.sampleRate == 5.0 + np.testing.assert_allclose([train.minTime, train.maxTime], [0.2, 0.6]) diff --git a/tests/test_parity_report.py b/tests/test_parity_report.py index 9356ad46..d1c85bf0 100644 --- a/tests/test_parity_report.py +++ b/tests/test_parity_report.py @@ -22,5 +22,5 @@ def test_parity_report_highlights_current_constraints() -> None: assert "class fidelity" in text.lower() assert "No partial or missing items remain in the mapping inventory." in text assert "Remaining Class-Fidelity Deltas" in text - assert "SignalObj" in text + assert "No partial, shim-only, or missing class-fidelity items remain." in text assert "nstatOpenHelpPage" in text diff --git a/tests/test_signalobj_fidelity.py b/tests/test_signalobj_fidelity.py index 9de9b050..86684543 100644 --- a/tests/test_signalobj_fidelity.py +++ b/tests/test_signalobj_fidelity.py @@ -2,6 +2,7 @@ import numpy as np +from nstat.ConfidenceInterval import ConfidenceInterval from nstat.Covariate import Covariate from nstat.SignalObj import SignalObj @@ -54,3 +55,37 @@ def test_covariate_compute_mean_plus_ci_uses_timewise_mean() -> None: assert mean_cov.isConfIntervalSet() assert mean_cov.ci is not None assert len(mean_cov.ci) == 1 + + +def test_signalobj_arithmetic_derivative_and_merge_preserve_matlab_style_shapes() -> None: + sig = SignalObj([0.0, 1.0, 2.0], [[1.0, 2.0, 4.0], [2.0, 3.0, 5.0]], "stim", dataLabels=["x", "y"]) + offset = SignalObj([0.0, 1.0, 2.0], [1.0, 1.0, 1.0], "offset", dataLabels=["x"]) + + summed = sig + offset + diffed = sig - 1.0 + scaled = 2.0 * sig + merged = sig.getSubSignal(1).merge(sig.getSubSignal(2)) + + np.testing.assert_allclose(summed.data[:, 0], [2.0, 3.0, 5.0]) + np.testing.assert_allclose(diffed.data[:, 1], [1.0, 2.0, 4.0]) + np.testing.assert_allclose(scaled.data[:, 0], [2.0, 4.0, 8.0]) + np.testing.assert_allclose(merged.data, sig.data) + + deriv = sig.derivative + assert deriv.dimension == 2 + np.testing.assert_allclose(sig.derivativeAt(1.0), deriv.getValueAt(1.0)) + + +def test_covariate_plus_minus_propagate_confidence_intervals() -> None: + cov1 = Covariate([0.0, 1.0], [[1.0], [2.0]], "c1", dataLabels=["trial"]) + cov2 = Covariate([0.0, 1.0], [[0.5], [1.5]], "c2", dataLabels=["trial"]) + cov1.setConfInterval(ConfidenceInterval([0.0, 1.0], [[0.8, 1.2], [1.8, 2.2]])) + cov2.setConfInterval(ConfidenceInterval([0.0, 1.0], [[0.4, 0.6], [1.4, 1.6]])) + + added = cov1 + cov2 + subtracted = cov1 - cov2 + + assert added.isConfIntervalSet() + assert subtracted.isConfIntervalSet() + np.testing.assert_allclose(added.ci[0].bounds, [[1.2, 1.8], [3.2, 3.8]]) + np.testing.assert_allclose(subtracted.ci[0].bounds, [[0.2, 0.8], [0.2, 0.8]]) From 6879ea7062e23cf067acbb3671e14ac6c5711197 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 12:31:20 -0500 Subject: [PATCH 5/5] Avoid unused data download in hybrid filter notebook --- notebooks/HybridFilterExample.ipynb | 17 ----------------- tests/test_notebook_surface.py | 8 ++++++++ 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/notebooks/HybridFilterExample.ipynb b/notebooks/HybridFilterExample.ipynb index 21899728..e2b329d0 100644 --- a/notebooks/HybridFilterExample.ipynb +++ b/notebooks/HybridFilterExample.ipynb @@ -34,30 +34,13 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='HybridFilterExample', output_root=OUTPUT_ROOT, expected_count=2)\n", "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", "# SECTION 0: Section 0\n", "# Hybrid Point Process Filter Example\n", "# This example is based on an implementation of the Hybrid Point Process filter described in General-purpose filter design for neural prosthetic devices by Srinivasan L, Eden UT, Mitter SK, Brown EN in J Neurophysiol. 2007 Oct, 98(4):2456-75." diff --git a/tests/test_notebook_surface.py b/tests/test_notebook_surface.py index d4fb3822..f7975ab8 100644 --- a/tests/test_notebook_surface.py +++ b/tests/test_notebook_surface.py @@ -37,3 +37,11 @@ def test_confidence_interval_overview_is_catalogued() -> None: example_manifest = yaml.safe_load((REPO_ROOT / "examples" / "nSTATPaperExamples" / "manifest.yml").read_text(encoding="utf-8")) names = {row["name"] for row in example_manifest["examples"]} assert "ConfidenceIntervalOverview" in names + + +def test_hybrid_filter_notebook_does_not_require_example_data_download() -> None: + notebook = nbformat.read(REPO_ROOT / "notebooks" / "HybridFilterExample.ipynb", as_version=4) + text = "\n".join(cell.source for cell in notebook.cells) + + assert "ensure_example_data(download=True)" not in text + assert "from nstat.data_manager import ensure_example_data" not in text