From df5e826e9dffaf06cc4462f9ec473e8a4a5c299c Mon Sep 17 00:00:00 2001
From: Iahn Cajigas
Date: Sat, 7 Mar 2026 11:21:23 -0500
Subject: [PATCH 1/3] Promote canonical MATLAB-facing core classes

---
 nstat/Covariate.py                 |   10 +-
 nstat/SignalObj.py                 |   10 +-
 nstat/core.py                      | 1075 +++++++++++++++++++++++-----
 nstat/nspikeTrain.py               |   10 +-
 nstat/nstColl.py                   |    9 +-
 parity/class_fidelity.yml          |  342 +++++++++
 tests/test_api_surface.py          |    8 +-
 tests/test_class_fidelity_audit.py |   60 ++
 tests/test_nspiketrain_fidelity.py |   31 +
 tests/test_signalobj_fidelity.py   |   56 ++
 10 files changed, 1413 insertions(+), 198 deletions(-)
 create mode 100644 parity/class_fidelity.yml
 create mode 100644 tests/test_class_fidelity_audit.py
 create mode 100644 tests/test_nspiketrain_fidelity.py
 create mode 100644 tests/test_signalobj_fidelity.py

diff --git a/nstat/Covariate.py b/nstat/Covariate.py
index e04d6825..3ce9ab1a 100644
--- a/nstat/Covariate.py
+++ b/nstat/Covariate.py
@@ -1,13 +1,5 @@
 from __future__ import annotations
 
-from ._compat import warn_deprecated_adapter
-from .signal import Covariate as _Covariate
-
-
-class Covariate(_Covariate):
-    def __init__(self, *args, **kwargs) -> None:
-        warn_deprecated_adapter("nstat.Covariate.Covariate", "nstat.signal.Covariate")
-        super().__init__(*args, **kwargs)
-
+from .core import Covariate
 
 __all__ = ["Covariate"]
diff --git a/nstat/SignalObj.py b/nstat/SignalObj.py
index 32e333a3..b2c335bb 100644
--- a/nstat/SignalObj.py
+++ b/nstat/SignalObj.py
@@ -1,13 +1,5 @@
 from __future__ import annotations
 
-from ._compat import warn_deprecated_adapter
-from .signal import Signal
-
-
-class SignalObj(Signal):
-    def __init__(self, *args, **kwargs) -> None:
-        warn_deprecated_adapter("nstat.SignalObj.SignalObj", "nstat.signal.Signal")
-        super().__init__(*args, **kwargs)
-
+from .core import SignalObj
 
 __all__ = ["SignalObj"]
diff --git a/nstat/core.py b/nstat/core.py
index a1836b1a..ffbaecc4 100644
--- a/nstat/core.py
+++ b/nstat/core.py
@@ -1,58 +1,105 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-from typing import Iterable, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import numpy as np
 
 
 def _as_1d_float(values: Sequence[float] | np.ndarray, name: str) -> np.ndarray:
+    array = np.asarray(values, dtype=float)
+    if array.ndim == 0:
+        raise ValueError(f"{name} must be array-like.")
+    if array.ndim > 2 or (array.ndim == 2 and min(array.shape) != 1):
+        raise ValueError(f"{name} can only have one dimension.")
+    return array.reshape(-1)
+
+
+def _normalize_signal_matrix(data: Sequence[float] | Sequence[Sequence[float]] | np.ndarray, n_time: int) -> np.ndarray:
+    matrix = np.asarray(data, dtype=float)
+    if matrix.ndim == 0:
+        raise ValueError("Data must be array-like.")
+    if matrix.ndim == 1:
+        matrix = matrix.reshape(-1, 1)
+    elif matrix.ndim != 2:
+        raise ValueError("Data must be one- or two-dimensional.")
+
+    if matrix.shape[0] == n_time:
+        return matrix.astype(float, copy=True)
+    if matrix.shape[1] == n_time:
+        return matrix.T.astype(float, copy=True)
+    raise ValueError("Data dimensions do not match the time vector specified.")
+
+
+def _coerce_1based_indices(values: Sequence[int] | np.ndarray, upper: int) -> list[int]:
+    out: list[int] = []
+    for raw in np.asarray(values).reshape(-1):
+        index = int(raw)
+        if index < 1 or index > upper:
+            raise IndexError("Signal index out of range. 
Indexing is 1-based.") + out.append(index) + return out + + +def _roundn(values: Sequence[float] | np.ndarray, decimals: int) -> np.ndarray: + return np.round(np.asarray(values, dtype=float), decimals=max(int(decimals), 0)) + + +def _matlab_mode_1d(values: Sequence[float] | np.ndarray) -> float: array = np.asarray(values, dtype=float).reshape(-1) if array.size == 0: - raise ValueError(f"{name} must be non-empty.") - return array + return np.nan + unique, counts = np.unique(array, return_counts=True) + best = np.flatnonzero(counts == np.max(counts)) + return float(unique[int(best[0])]) class SignalObj: - """Python approximation of nSTAT SignalObj. - - The class stores a time vector and one or more aligned signal channels. - """ + """Closer MATLAB-style signal abstraction used throughout the Python port.""" def __init__( self, time: Sequence[float], data: Sequence[float] | Sequence[Sequence[float]] | np.ndarray, - name: str = "signal", - xlabel: str = "time", + name: str = "", + xlabelval: str = "time", xunits: str = "s", yunits: str = "", - data_labels: Sequence[str] | None = None, + dataLabels: Sequence[str] | str | None = None, + plotProps: Sequence[Any] | str | None = None, + **kwargs, ) -> None: - t = _as_1d_float(time, "time") - if np.any(np.diff(t) <= 0): - raise ValueError("time must be strictly increasing.") - - x = np.asarray(data, dtype=float) - if x.ndim == 1: - x = x[:, None] - if x.shape[0] != t.shape[0]: - raise ValueError("data must have same first dimension as time.") - - self.time = t - self.data = x - self.name = name - self.xlabelval = xlabel - self.xunits = xunits - self.yunits = yunits + if "xlabel" in kwargs and "xlabelval" not in kwargs: + xlabelval = kwargs.pop("xlabel") + if "data_labels" in kwargs and dataLabels is None: + dataLabels = kwargs.pop("data_labels") + if kwargs: + unexpected = ", ".join(sorted(kwargs)) + raise TypeError(f"Unexpected keyword arguments: {unexpected}") + + self.time = _as_1d_float(time, "Time vector") + self.data = _normalize_signal_matrix(data, self.time.size) + self.name = str(name) + self.xlabelval = str(xlabelval) + self.xunits = str(xunits) + self.yunits = str(yunits) + self.minTime = float(np.min(self.time)) if self.time.size else 0.0 + self.maxTime = float(np.max(self.time)) if self.time.size else 0.0 - if data_labels is None: - labels = [f"{name}_{k+1}" for k in range(self.data.shape[1])] + if self.time.size > 1: + delta_t = float(np.mean(np.diff(self.time))) else: - labels = list(data_labels) - if len(labels) != self.data.shape[1]: - raise ValueError("data_labels length must match signal dimension.") - self.dataLabels = labels + delta_t = np.nan + if not np.isfinite(delta_t) or delta_t <= 0: + delta_t = 0.001 + self.sampleRate = float(1.0 / delta_t) + self.origSampleRate = float(self.sampleRate) + self.originalTime = self.time.copy() + self.originalData = self.data.copy() + self.dataMask = np.ones(self.dimension, dtype=int) + self.plotProps: list[Any] = [] + self.setPlotProps(plotProps) + self.setDataLabels(dataLabels if dataLabels is not None else "") self.conf_interval: tuple[np.ndarray, np.ndarray] | None = None @property @@ -71,132 +118,493 @@ def units(self) -> str: @property def sample_rate(self) -> float: - if self.time.shape[0] < 2: - return 0.0 - dt = np.median(np.diff(self.time)) - if dt <= 0: - return 0.0 - return float(1.0 / dt) + return float(self.sampleRate) - def copySignal(self) -> "SignalObj": - out = SignalObj( - self.time.copy(), - self.data.copy(), + def _spawn( + self, + time: np.ndarray, + data: np.ndarray, + 
*, + data_labels: Sequence[str] | None = None, + plot_props: Sequence[Any] | None = None, + ) -> "SignalObj": + labels = list(self.dataLabels) if data_labels is None else list(data_labels) + props = list(self.plotProps) if plot_props is None else list(plot_props) + return self.__class__( + np.asarray(time, dtype=float).copy(), + np.asarray(data, dtype=float).copy(), self.name, self.xlabelval, self.xunits, self.yunits, - self.dataLabels, + labels, + props, ) - out.conf_interval = None if self.conf_interval is None else ( - self.conf_interval[0].copy(), - self.conf_interval[1].copy(), - ) - return out + + def copySignal(self) -> "SignalObj": + copied = self._spawn(self.time, self.data) + if self.conf_interval is not None: + copied.conf_interval = ( + np.asarray(self.conf_interval[0], dtype=float).copy(), + np.asarray(self.conf_interval[1], dtype=float).copy(), + ) + copied.dataMask = np.asarray(self.dataMask, dtype=int).copy() + copied.originalTime = self.originalTime.copy() + copied.originalData = self.originalData.copy() + copied.sampleRate = float(self.sampleRate) + copied.origSampleRate = float(self.origSampleRate) + copied.minTime = float(self.minTime) + copied.maxTime = float(self.maxTime) + return copied def setName(self, name: str) -> None: - self.name = str(name) + if not isinstance(name, str): + raise TypeError("Name must be a string!") + self.name = name + + def setXlabel(self, name: str) -> None: + self.xlabelval = str(name) + + def setYLabel(self, name: str) -> None: + self.setName(name) + + def setUnits(self, xUnits: str, yUnits: str | None = None) -> None: + if yUnits is not None: + self.setYUnits(yUnits) + self.setXUnits(xUnits) + + def setXUnits(self, units: str) -> None: + if isinstance(units, str): + self.xunits = units + + def setYUnits(self, units: str) -> None: + if isinstance(units, str): + self.yunits = units + + def setSampleRate(self, sampleRate: float) -> None: + requested = float(sampleRate) + current = float(self.sampleRate) + if abs(round(requested, 3) - round(current, 3)) > 0: + self.resampleMe(requested) - def setDataLabels(self, labels: Sequence[str]) -> None: - labels = list(labels) + def setDataLabels(self, dataLabels: Sequence[str] | str | None) -> None: + if dataLabels is None or (isinstance(dataLabels, str) and dataLabels == ""): + self.dataLabels = ["" for _ in range(self.dimension)] + return + + if isinstance(dataLabels, str): + self.dataLabels = [dataLabels for _ in range(self.dimension)] + return + + labels = [str(label) for label in dataLabels] if len(labels) != self.dimension: - raise ValueError("labels length must equal number of signal channels.") + raise ValueError("Need the number of labels to match the number of dimensions of the SignalObj") self.dataLabels = labels - def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: - low, high = bounds - low = np.asarray(low, dtype=float) - high = np.asarray(high, dtype=float) - if low.shape[0] != self.time.shape[0] or high.shape[0] != self.time.shape[0]: - raise ValueError("confidence interval bounds must align with time.") - self.conf_interval = (low, high) + def setPlotProps(self, plotProps: Sequence[Any] | str | None, index: int | None = None) -> None: + if index is None: + if plotProps is None: + self.plotProps = [None for _ in range(self.dimension)] + elif isinstance(plotProps, str): + self.plotProps = [plotProps for _ in range(self.dimension)] + else: + props = list(plotProps) + if len(props) == 1 and self.dimension > 1: + props = props * self.dimension + if len(props) != 
self.dimension: + raise ValueError("plotProps length must match signal dimension.") + self.plotProps = props + return + + indices = _coerce_1based_indices([index], self.dimension) + target = indices[0] - 1 + if not self.plotProps: + self.plotProps = [None for _ in range(self.dimension)] + if isinstance(plotProps, Sequence) and not isinstance(plotProps, str): + props = list(plotProps) + self.plotProps[target] = props[0] if props else None + else: + self.plotProps[target] = plotProps + + def setDataMask(self, dataMask: Sequence[int] | np.ndarray) -> None: + mask = np.asarray(dataMask, dtype=int).reshape(-1) + if mask.size != self.dimension: + raise ValueError("dataMask must match the number of signal dimensions.") + if np.any((mask != 0) & (mask != 1)): + raise ValueError("dataMask must be binary.") + self.dataMask = mask + + def setMaskByInd(self, index: Sequence[int] | np.ndarray) -> None: + selected = _coerce_1based_indices(index, self.dimension) + mask = np.zeros(self.dimension, dtype=int) + mask[np.asarray(selected, dtype=int) - 1] = 1 + self.setDataMask(mask) + + def setMaskByLabels(self, labels: Sequence[str] | str) -> None: + indices = self.getIndicesFromLabels(labels) + if isinstance(indices, list) and indices and isinstance(indices[0], list): + flat = [item for sub in indices for item in sub] + elif isinstance(indices, list): + flat = indices + else: + flat = [indices] + self.setMaskByInd(flat) + + def setMask(self, mask: Sequence[int] | Sequence[str] | np.ndarray | None = None) -> None: + if mask is None: + self.setDataMask(np.zeros(self.dimension, dtype=int)) + return - def getSubSignal(self, idx: int) -> "SignalObj": - if idx < 1 or idx > self.dimension: + if isinstance(mask, str): + self.setMaskByLabels(mask) + return + + values = list(mask) + if not values: + self.setDataMask(np.zeros(self.dimension, dtype=int)) + return + + first = values[0] + if isinstance(first, str): + self.setMaskByLabels(values) + return + + arr = np.asarray(values) + if arr.size == self.dimension and np.all(np.isin(arr, [0, 1])): + self.setDataMask(arr.astype(int)) + return + self.setMaskByInd(arr.astype(int)) + + def getTime(self) -> np.ndarray: + return self.time.copy() + + def getData(self) -> np.ndarray: + return self.dataToMatrix() + + def getOriginalData(self) -> tuple[np.ndarray, np.ndarray]: + return self.originalTime.copy(), self.originalData.copy() + + def getOrigDataSig(self) -> "SignalObj": + return self._spawn(self.originalTime, self.originalData) + + def getPlotProps(self, index: int) -> Any: + idx = _coerce_1based_indices([index], self.dimension)[0] - 1 + return self.plotProps[idx] + + def getIndexFromLabel(self, label: str) -> list[int]: + matches = [i + 1 for i, value in enumerate(self.dataLabels) if value == label] + if not matches: + raise ValueError("Label does not exist!") + return matches + + def getIndicesFromLabels(self, label: Sequence[str] | str): + if isinstance(label, str): + matches = self.getIndexFromLabel(label) + return matches[0] if len(matches) == 1 else matches + + out = [self.getIndexFromLabel(str(item)) for item in label] + counts = [len(item) for item in out] + if counts and max(counts) == 1: + return [item[0] for item in out] + return out + + def getValueAt(self, x: Sequence[float] | float) -> np.ndarray: + query = np.asarray(x, dtype=float).reshape(-1) + out = np.zeros((query.size, self.dimension), dtype=float) + valid = (query >= self.minTime) & (query <= self.maxTime) + if np.any(valid): + q_valid = query[valid] + right = np.searchsorted(self.time, q_valid, 
side="left") + right = np.clip(right, 0, self.time.size - 1) + left = np.clip(right - 1, 0, self.time.size - 1) + choose_right = np.abs(self.time[right] - q_valid) <= np.abs(self.time[left] - q_valid) + indices = np.where(choose_right, right, left) + out[valid] = self.data[indices] + return out[0] if np.isscalar(x) else out + + def _selector_to_zero_based(self, selectorArray: Sequence[int] | np.ndarray | None) -> np.ndarray: + if selectorArray is None: + if self.isMaskSet(): + selected = self.findIndFromDataMask() + else: + selected = list(range(1, self.dimension + 1)) + else: + if isinstance(selectorArray, str): + selected = self.getIndicesFromLabels(selectorArray) + else: + selected = selectorArray + indices = np.asarray(selected, dtype=int).reshape(-1) + if indices.size == 0: + return np.array([], dtype=int) + if np.min(indices) < 1 or np.max(indices) > self.dimension: raise IndexError("Signal index out of range. Indexing is 1-based.") - j = idx - 1 - return SignalObj( + return indices - 1 + + def dataToMatrix(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> np.ndarray: + indices = self._selector_to_zero_based(selectorArray) + if indices.size == 0: + return np.zeros((self.time.size, 0), dtype=float) + return self.data[:, indices] + + def _labels_for_indices(self, zero_based: np.ndarray) -> list[str]: + return [self.dataLabels[int(i)] for i in zero_based] + + def _plot_props_for_indices(self, zero_based: np.ndarray) -> list[Any]: + if not self.plotProps: + return [None for _ in zero_based] + return [self.plotProps[int(i)] for i in zero_based] + + def getSubSignalFromInd(self, selectorArray: Sequence[int] | np.ndarray) -> "SignalObj": + indices = self._selector_to_zero_based(selectorArray) + return self._spawn( self.time, - self.data[:, j], - self.name, - self.xlabelval, - self.xunits, - self.yunits, - [self.dataLabels[j]], + self.data[:, indices], + data_labels=self._labels_for_indices(indices), + plot_props=self._plot_props_for_indices(indices), ) - def getSigInTimeWindow(self, t0: float, t1: float) -> "SignalObj": - mask = (self.time >= t0) & (self.time <= t1) - if not np.any(mask): - raise ValueError("Requested time window has no samples.") - return SignalObj( - self.time[mask], - self.data[mask, :], - self.name, - self.xlabelval, - self.xunits, - self.yunits, - self.dataLabels, - ) + def getSubSignalFromNames(self, labels: Sequence[str] | str) -> "SignalObj": + indices = self.getIndicesFromLabels(labels) + return self.getSubSignalFromInd(indices if isinstance(indices, list) else [indices]) + + def getSubSignal(self, identifier) -> "SignalObj": + if isinstance(identifier, str): + return self.getSubSignalFromNames(identifier) + if isinstance(identifier, np.ndarray): + values = identifier.reshape(-1).tolist() + elif isinstance(identifier, Sequence): + values = list(identifier) + else: + values = [identifier] + if values and isinstance(values[0], str): + return self.getSubSignalFromNames(values) + return self.getSubSignalFromInd(values) + + def findNearestTimeIndex(self, time: float) -> int: + value = float(time) + if value < self.minTime: + return 1 + if value > self.maxTime: + return self.time.size + right = int(np.searchsorted(self.time, value, side="left")) + if right <= 0: + return 1 + if right >= self.time.size: + return self.time.size + left = right - 1 + if abs(self.time[right] - value) <= abs(self.time[left] - value): + return right + 1 + return left + 1 + + def findNearestTimeIndices(self, times: Sequence[float] | np.ndarray) -> np.ndarray: + return 
np.asarray([self.findNearestTimeIndex(value) for value in np.asarray(times, dtype=float).reshape(-1)], dtype=int) + + def setMinTime(self, minTime: float | None = None, holdVals: int = 0) -> None: + target = self.time[0] if minTime is None else float(minTime) + timeVec = self.getTime() + if target < float(np.min(timeVec)): + maxTime = float(np.max(timeVec)) + dt = 1.0 / self.sampleRate + newTime = np.arange(target, maxTime + 0.5 * dt, dt, dtype=float) + numSamples = int(newTime.size - timeVec.size) + if holdVals == 1: + pad = np.tile(self.data[0:1, :], (numSamples, 1)) + else: + pad = np.zeros((numSamples, self.dimension), dtype=float) + self.data = np.vstack([pad, self.data]) + self.time = newTime + elif target > float(np.min(timeVec)): + startIndex = self.findNearestTimeIndex(target) - 1 + self.time = self.time[startIndex:] + self.data = self.data[startIndex:, :] + self.minTime = float(np.min(self.time)) + + def setMaxTime(self, maxTime: float | None = None, holdVals: int = 0) -> None: + target = self.time[-1] if maxTime is None else float(maxTime) + timeVec = self.getTime() + if float(np.max(timeVec)) < target: + minTime = float(np.min(timeVec)) + n_samples = int(round(self.sampleRate * (target - minTime))) + 1 + n_samples = max(n_samples, timeVec.size) + newTime = np.linspace(minTime, target, n_samples, dtype=float) + numSamples = int(newTime.size - timeVec.size) + if holdVals == 1: + pad = np.tile(self.data[-1:, :], (numSamples, 1)) + else: + pad = np.zeros((numSamples, self.dimension), dtype=float) + self.data = np.vstack([self.data, pad]) + self.time = newTime + elif float(np.max(timeVec)) > target: + endIndex = self.findNearestTimeIndex(target) + self.time = self.time[:endIndex] + self.data = self.data[:endIndex, :] + self.maxTime = float(np.max(self.time)) def merge(self, other: "SignalObj") -> "SignalObj": if self.time.shape != other.time.shape or np.max(np.abs(self.time - other.time)) > 1e-9: raise ValueError("Signals must share an identical time grid to merge.") - return SignalObj( + merged = self._spawn( self.time, np.column_stack([self.data, other.data]), - self.name, - self.xlabelval, - self.xunits, - self.yunits, - [*self.dataLabels, *other.dataLabels], + data_labels=[*self.dataLabels, *list(other.dataLabels)], + plot_props=[*self.plotProps, *getattr(other, "plotProps", [None for _ in range(other.dimension)])], ) + return merged + + def getSigInTimeWindow( + self, + wMin: Sequence[float] | float | None = None, + wMax: Sequence[float] | float | None = None, + holdVals: int = 0, + ) -> "SignalObj": + if wMax is None: + wMax = self.maxTime + if wMin is None: + wMin = self.minTime + + min_values = np.asarray([wMin] if np.isscalar(wMin) else wMin, dtype=float).reshape(-1) + max_values = np.asarray([wMax] if np.isscalar(wMax) else wMax, dtype=float).reshape(-1) + if min_values.size != max_values.size: + raise ValueError("Window minTimes must contain the same number of elements as window maxTimes") + + if min_values.size == 1 and self.minTime == float(min_values[0]) and self.maxTime == float(max_values[0]): + return self.copySignal() + + windowed: SignalObj | None = None + for idx, (left, right) in enumerate(zip(min_values, max_values), start=1): + current = self.copySignal() + if left < current.minTime: + current.setMinTime(left, holdVals) + if right > current.maxTime: + current.setMaxTime(right, holdVals) + + start = current.findNearestTimeIndex(left) - 1 + stop = current.findNearestTimeIndex(right) + current.time = current.time[start:stop] + current.data = 
current.data[start:stop, :] + labels = list(current.dataLabels) + if min_values.size > 1: + labels = [f"{label}_{{{idx}}}" for label in labels] + current.setDataLabels(labels) + current.setMinTime() + current.setMaxTime() + windowed = current if windowed is None else windowed.merge(current) + return windowed if windowed is not None else self.copySignal() + + def restoreToOriginal(self, rMask: int = 0) -> None: + self.time = self.originalTime.copy() + self.data = self.originalData.copy() + self.minTime = float(np.min(self.time)) + self.maxTime = float(np.max(self.time)) + self.sampleRate = float(self.origSampleRate) + if rMask == 1: + self.resetMask() + + def resetMask(self) -> None: + self.dataMask = np.ones(self.dimension, dtype=int) + + def findIndFromDataMask(self) -> list[int]: + return [int(index) + 1 for index in np.flatnonzero(self.dataMask == 1)] + + def isMaskSet(self) -> bool: + return bool(np.any(self.dataMask == 0)) + + def mean(self, axis: int | None = None) -> "SignalObj": + axis_arg = 0 if axis is None else axis + mean_data = np.mean(self.data, axis=axis_arg) + array = np.asarray(mean_data, dtype=float) + if array.ndim == 1 and array.size == self.dimension: + labels = [f"\\mu({label})" if label else "" for label in self.dataLabels] + return self._spawn( + np.asarray([self.time[0], self.time[-1]], dtype=float), + np.vstack([array, array]), + data_labels=labels, + ) + reshaped = array.reshape(-1, 1) + return self._spawn(self.time, reshaped, data_labels=[f"\\mu({self.name})"]) + + def std(self, axis: int | None = None) -> "SignalObj": + axis_arg = 0 if axis is None else axis + std_data = np.std(self.data, axis=axis_arg) + array = np.asarray(std_data, dtype=float) + if array.ndim == 1 and array.size == self.dimension: + labels = [f"\\sigma({label})" if label else "" for label in self.dataLabels] + return self._spawn( + np.asarray([self.time[0], self.time[-1]], dtype=float), + np.vstack([array, array]), + data_labels=labels, + ) + reshaped = array.reshape(-1, 1) + return self._spawn(self.time, reshaped, data_labels=[f"\\sigma({self.name})"]) def resample(self, sample_rate: float) -> "SignalObj": - if sample_rate <= 0: - raise ValueError("sample_rate must be > 0.") - dt = 1.0 / float(sample_rate) - t_new = np.arange(self.time[0], self.time[-1] + 0.5 * dt, dt) - x_new = np.column_stack( - [np.interp(t_new, self.time, self.data[:, i]) for i in range(self.dimension)] - ) - return SignalObj( - t_new, - x_new, - self.name, - self.xlabelval, - self.xunits, - self.yunits, - self.dataLabels, - ) + copied = self.copySignal() + copied.resampleMe(sample_rate) + return copied + + def resampleMe(self, newSampleRate: float) -> None: + rate = float(newSampleRate) + if rate <= 0: + raise ValueError("sampleRate must be > 0.") + dt = 1.0 / rate + newTime = np.arange(self.time[0], self.time[-1] + 0.5 * dt, dt, dtype=float) + newData = np.column_stack([np.interp(newTime, self.time, self.data[:, i]) for i in range(self.dimension)]) + self.time = newTime + self.data = newData + self.sampleRate = rate + self.minTime = float(np.min(newTime)) + self.maxTime = float(np.max(newTime)) @property def derivative(self) -> "SignalObj": - dt = np.gradient(self.time) deriv = np.column_stack([np.gradient(self.data[:, i], self.time) for i in range(self.dimension)]) - # Avoid numerical noise spikes where dt is near 0. 
deriv[~np.isfinite(deriv)] = 0.0 + labels = [f"d_{label}" if label else "" for label in self.dataLabels] + return self._spawn(self.time, deriv, data_labels=labels) + + def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: + low, high = bounds + low_arr = np.asarray(low, dtype=float) + high_arr = np.asarray(high, dtype=float) + if low_arr.shape[0] != self.time.shape[0] or high_arr.shape[0] != self.time.shape[0]: + raise ValueError("confidence interval bounds must align with time.") + self.conf_interval = (low_arr, high_arr) + + def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> dict[str, Any]: + data = self.dataToMatrix(selectorArray) + return { + "time": self.time.tolist(), + "data": data.tolist(), + "name": self.name, + "xlabelval": self.xlabelval, + "xunits": self.xunits, + "yunits": self.yunits, + "dataLabels": list(self.dataLabels), + "plotProps": list(self.plotProps), + } + + def toStructure(self) -> dict[str, Any]: + return self.dataToStructure() + + @staticmethod + def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": return SignalObj( - self.time, - deriv, - f"d/dt({self.name})", - self.xlabelval, - self.xunits, - self.yunits, - [f"d_{lbl}" for lbl in self.dataLabels], + structure["time"], + structure["data"], + structure.get("name", ""), + structure.get("xlabelval", "time"), + structure.get("xunits", "s"), + structure.get("yunits", ""), + structure.get("dataLabels"), + structure.get("plotProps"), ) def plot(self, *_, **__) -> None: - # Intentionally lightweight: plotting is handled in examples where needed. return None class Covariate(SignalObj): - """MATLAB-compatible alias for SignalObj. - - Accepts both MATLAB-style positional arguments and Pythonic keywords: - `Covariate(time=t, values=x, name='stim', units='a.u.')`. 
- """ + """MATLAB-style covariate signal with CI and zero-mean views.""" def __init__(self, *args, **kwargs) -> None: if "values" in kwargs and "data" not in kwargs: @@ -204,31 +612,197 @@ def __init__(self, *args, **kwargs) -> None: if "units" in kwargs and "yunits" not in kwargs: kwargs["yunits"] = kwargs.pop("units") super().__init__(*args, **kwargs) + self.ci: list[Any] | None = None + @property + def mu(self) -> SignalObj: + return self.mean() -@dataclass -class nspikeTrain: - """Python approximation of MATLAB nspikeTrain.""" + @property + def sigma(self) -> SignalObj: + return self.std() + + def computeMeanPlusCI(self, alphaVal: float = 0.05) -> "Covariate": + from .confidence_interval import ConfidenceInterval + + sorted_data = np.sort(self.data, axis=1) + n_rep = sorted_data.shape[1] + if n_rep == 0: + raise ValueError("Covariate must contain at least one column to compute confidence intervals.") + ecdf = np.arange(1, n_rep + 1, dtype=float) / float(n_rep) + lower = np.empty(sorted_data.shape[0], dtype=float) + upper = np.empty(sorted_data.shape[0], dtype=float) + for row_idx in range(sorted_data.shape[0]): + row = sorted_data[row_idx] + lower_idx = np.flatnonzero(ecdf < (alphaVal / 2.0)) + upper_idx = np.flatnonzero(ecdf > (1.0 - alphaVal / 2.0)) + lower[row_idx] = row[int(lower_idx[-1])] if lower_idx.size else row[0] + upper[row_idx] = row[int(upper_idx[0])] if upper_idx.size else row[-1] + confInt = ConfidenceInterval(self.time, np.column_stack([lower, upper])) + mean_signal = np.mean(self.data, axis=1) + newCov = Covariate( + self.time.copy(), + mean_signal, + self.name, + self.xlabelval, + self.xunits, + self.yunits, + [f"\\mu({self.name})" if self.name else "\\mu"], + ) + newCov.setConfInterval(confInt) + return newCov + + def getSigRep(self, repType: str = "standard") -> SignalObj: + rep = str(repType).strip().lower() + if rep == "standard": + return self + if rep == "zero-mean": + centered = self.data - np.mean(self.data, axis=0, keepdims=True) + return Covariate( + self.time, + centered, + self.name, + self.xlabelval, + self.xunits, + self.yunits, + list(self.dataLabels), + list(self.plotProps), + ) + raise ValueError("repType must be either 'zero-mean' or 'standard'") + + def isConfIntervalSet(self) -> bool: + return bool(self.ci) + + def setConfInterval(self, ciObj) -> None: + if isinstance(ciObj, list): + self.ci = list(ciObj) + else: + self.ci = [ciObj] + + def copySignal(self) -> "Covariate": + copied = Covariate( + self.time.copy(), + self.data.copy(), + self.name, + self.xlabelval, + self.xunits, + self.yunits, + list(self.dataLabels), + list(self.plotProps), + ) + copied.dataMask = np.asarray(self.dataMask, dtype=int).copy() + copied.originalTime = self.originalTime.copy() + copied.originalData = self.originalData.copy() + copied.sampleRate = float(self.sampleRate) + copied.origSampleRate = float(self.origSampleRate) + copied.minTime = float(self.minTime) + copied.maxTime = float(self.maxTime) + copied.ci = None if not self.ci else list(self.ci) + if self.conf_interval is not None: + copied.conf_interval = ( + np.asarray(self.conf_interval[0], dtype=float).copy(), + np.asarray(self.conf_interval[1], dtype=float).copy(), + ) + return copied + + def getSubSignal(self, identifier) -> "Covariate": + sub = super().getSubSignal(identifier) + cov = Covariate( + sub.time, + sub.data, + sub.name, + sub.xlabelval, + sub.xunits, + sub.yunits, + list(sub.dataLabels), + list(sub.plotProps), + ) + if self.isConfIntervalSet(): + selected: list[int] = [] + for label in 
cov.dataLabels: + if label: + match = next((i for i, original in enumerate(self.dataLabels) if original == label), None) + if match is None: + raise ValueError("Unable to align Covariate confidence interval with sub-signal labels.") + selected.append(match) + else: + selected.append(len(selected)) + cov.setConfInterval([self.ci[index] for index in selected]) + return cov - spikeTimes: np.ndarray - name: str = "" - binwidth: float = 0.001 - minTime: float | None = None - maxTime: float | None = None + def toStructure(self) -> dict[str, Any]: + structure = super().toStructure() + if self.isConfIntervalSet(): + ci_payload: list[dict[str, Any]] = [] + for item in self.ci or []: + if hasattr(item, "time") and hasattr(item, "bounds"): + ci_payload.append( + { + "time": np.asarray(item.time, dtype=float).tolist(), + "bounds": np.asarray(item.bounds, dtype=float).tolist(), + "color": getattr(item, "color", "b"), + } + ) + structure["ci"] = ci_payload + return structure - def __post_init__(self) -> None: - spikes = np.asarray(self.spikeTimes, dtype=float).reshape(-1) - spikes = np.sort(spikes) - self.spikeTimes = spikes - if self.minTime is None: - self.minTime = float(spikes[0]) if spikes.size else 0.0 - if self.maxTime is None: - self.maxTime = float(spikes[-1]) if spikes.size else self.minTime +class nspikeTrain: + """Closer MATLAB-style spike-train object with cached signal representation.""" - self.minTime = float(self.minTime) - self.maxTime = float(self.maxTime) - self.sampleRate = float(1.0 / self.binwidth) + def __init__( + self, + spikeTimes, + name: str = "", + binwidth: float = 0.001, + minTime: float | None = None, + maxTime: float | None = None, + xlabelval: str = "time", + xunits: str = "s", + yunits: str = "", + dataLabels: str | Sequence[str] | None = "", + makePlots: int = 0, + ) -> None: + if spikeTimes is None: + raise ValueError("nspikeTrain requires a spikeTimes array as input to create an object") + spikes = np.asarray(spikeTimes, dtype=float).reshape(-1) + self.spikeTimes = np.sort(spikes) + self.originalSpikeTimes = self.spikeTimes.copy() + self.name = str(name) + self.sampleRate = float(1.0 / float(binwidth)) + self.originalSampleRate = float(self.sampleRate) + if minTime is None: + minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0 + if maxTime is None: + maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0 + self.minTime = float(minTime) + self.maxTime = float(maxTime) + self.originalMinTime = float(self.minTime) + self.originalMaxTime = float(self.maxTime) + self.xlabelval = str(xlabelval) + self.xunits = str(xunits) + self.yunits = str(yunits) + self.dataLabels = dataLabels if dataLabels is not None else "" + self.sigRep: SignalObj | None = None + self.isSigRepBin: bool | None = None + self._sigrep_cache_key: tuple[float, float, float] | None = None + self.MER = None + if makePlots >= 0: + self.computeStatistics(makePlots) + else: + self.avgFiringRate = None + self.B = None + self.An = None + self.burstTimes = None + self.burstRate = None + self.burstDuration = None + self.burstSig = None + self.burstIndex = None + self.numBursts = None + self.numSpikesPerBurst = None + self.avgSpikesPerBurst = None + self.stdSpikesPerBurst = None + self.Lstatistic = None @property def times(self) -> np.ndarray: @@ -236,7 +810,7 @@ def times(self) -> np.ndarray: @property def n_spikes(self) -> int: - return int(self.spikeTimes.shape[0]) + return int(self.spikeTimes.size) @property def duration(self) -> float: @@ -244,24 +818,134 @@ def 
duration(self) -> float: @property def firing_rate_hz(self) -> float: - d = self.duration - if d <= 0: + if self.duration <= 0: return 0.0 - return float(self.n_spikes / d) + return float(self.n_spikes / self.duration) + + def setMER(self, MERSig: SignalObj) -> None: + if isinstance(MERSig, SignalObj): + self.MER = MERSig def setName(self, name: str) -> None: self.name = str(name) - def setMinTime(self, value: float) -> None: - self.minTime = float(value) + def computeStatistics(self, makePlots: int = 0) -> None: + self.avgFiringRate = self.firing_rate_hz + isi = self.getISIs() + mode_isi = _matlab_mode_1d(isi) + self.burstIndex = float(1.0 / mode_isi / self.avgFiringRate) if np.isfinite(mode_isi) and self.avgFiringRate > 0 else np.nan + self.B = np.nan + self.An = np.nan + self.burstTimes = np.array([], dtype=float) + self.burstRate = np.array([], dtype=float) + self.burstDuration = np.array([], dtype=float) + self.burstSig = None + self.numBursts = 0 + self.numSpikesPerBurst = np.array([], dtype=float) + self.avgSpikesPerBurst = np.nan + self.stdSpikesPerBurst = np.nan + self.Lstatistic = self.getLStatistic() + + def getLStatistic(self) -> float: + isi = self.getISIs() + if isi.size == 0: + return np.nan + mean_isi = float(np.mean(isi)) + if not np.isfinite(mean_isi) or mean_isi <= 0: + return np.nan + duration = self.maxTime - self.minTime + if not np.isfinite(duration) or duration <= 0: + return np.nan + approx = self.getSigRep(mean_isi) + return float(np.unique(approx.data[:, 0]).size) + + def _cache_key(self, binwidth: float, minTime: float, maxTime: float) -> tuple[float, float, float]: + return (round(float(binwidth), 12), round(float(minTime), 12), round(float(maxTime), 12)) + + def _build_sigrep(self, binwidth: float, minTime: float, maxTime: float) -> SignalObj: + if binwidth <= 0: + raise ValueError("binwidth must be > 0") + if maxTime < minTime: + raise ValueError("maxTime must be >= minTime") + + max_bins = int(1e6) + precision = max(0, int(2 * np.ceil(np.log10(1.0 / binwidth)))) + bw = float(_roundn([binwidth], precision)[0]) + duration = float(maxTime - minTime) + if np.isfinite(duration) and duration > 0 and np.isfinite(bw) and bw > 0: + est_bins = duration / bw + 1.0 + if not np.isfinite(est_bins) or est_bins > max_bins: + bw = duration / float(max_bins - 1) + precision = max(0, int(2 * np.ceil(np.log10(1.0 / bw)))) + bw = float(_roundn([bw], precision)[0]) + if not np.isfinite(bw) or bw <= 0: + bw = duration / float(max_bins - 1) if np.isfinite(duration) and duration > 0 else 1.0 / max(self.sampleRate, 1.0) + + numBins = int(np.floor(duration / bw + 1.0)) if np.isfinite(duration) else 2 + if numBins < 2: + numBins = 2 + if numBins > max_bins: + numBins = max_bins + timeVec = np.linspace(minTime, maxTime, numBins, dtype=float) + if timeVec.size > 1: + bw = float(np.mean(np.diff(timeVec))) + windowTimes = np.concatenate([[minTime - bw / 2.0], timeVec + bw / 2.0]) - def setMaxTime(self, value: float) -> None: - self.maxTime = float(value) + spikeTimes = _roundn(self.spikeTimes, precision) + rounded_windows = _roundn(windowTimes, precision + 1) + counts = np.zeros(timeVec.size, dtype=float) + split_index = int(np.floor(rounded_windows.size / 2.0)) + for idx in range(timeVec.size): + left = rounded_windows[idx] + right = rounded_windows[idx + 1] + if idx == rounded_windows.size - 2: + temp = spikeTimes[spikeTimes >= left] + counts[idx] = float(np.sum(temp <= right)) + elif idx + 1 > split_index: + temp = spikeTimes[spikeTimes >= left] + counts[idx] = float(np.sum(temp < 
right)) + else: + temp = spikeTimes[spikeTimes < right] + counts[idx] = float(np.sum(temp >= left)) - def getISIs(self) -> np.ndarray: - if self.n_spikes < 2: + label = self.dataLabels if isinstance(self.dataLabels, str) else "" + sig = SignalObj(timeVec, counts.astype(float), self.name, self.xlabelval, self.xunits, self.yunits, label) + self.isSigRepBin = bool(np.all(counts <= 1)) + return sig + + def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, maxTime: float | None = None) -> SignalObj: + self.sigRep = self.getSigRep(binwidth, minTime, maxTime) + return self.sigRep + + def clearSigRep(self) -> None: + self.sigRep = None + self._sigrep_cache_key = None + self.isSigRepBin = None + + def setMinTime(self, minTime: float) -> None: + self.minTime = float(minTime) + self.clearSigRep() + + def setMaxTime(self, maxTime: float) -> None: + self.maxTime = float(maxTime) + self.clearSigRep() + + def resample(self, sampleRate: float) -> "nspikeTrain": + self.sampleRate = float(sampleRate) + self.clearSigRep() + return self + + def getSpikeTimes(self, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + start = self.minTime if minTime is None else float(minTime) + stop = self.maxTime if maxTime is None else float(maxTime) + spikes = self.spikeTimes[(self.spikeTimes >= start) & (self.spikeTimes <= stop)] + return spikes.copy() + + def getISIs(self, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + spikes = self.getSpikeTimes(minTime, maxTime) + if spikes.size < 2: return np.array([], dtype=float) - return np.diff(self.spikeTimes) + return np.diff(spikes) def getSigRep( self, @@ -269,26 +953,91 @@ def getSigRep( minTime: float | None = None, maxTime: float | None = None, ) -> SignalObj: - bw = self.binwidth if binwidth is None else float(binwidth) - t0 = self.minTime if minTime is None else float(minTime) - t1 = self.maxTime if maxTime is None else float(maxTime) - if bw <= 0: - raise ValueError("binwidth must be > 0") - if t1 < t0: - raise ValueError("maxTime must be >= minTime") + bw = (1.0 / self.sampleRate) if binwidth is None else float(binwidth) + start = self.minTime if minTime is None else float(minTime) + stop = self.maxTime if maxTime is None else float(maxTime) + key = self._cache_key(bw, start, stop) + if self.sigRep is not None and self._sigrep_cache_key == key: + return self.sigRep.copySignal() + sig = self._build_sigrep(bw, start, stop) + self.sigRep = sig.copySignal() + self._sigrep_cache_key = key + return sig - edges = np.arange(t0, t1 + 1.5 * bw, bw) - if edges.shape[0] < 2: - edges = np.array([t0, t0 + bw], dtype=float) - counts, _ = np.histogram(self.spikeTimes, bins=edges) - centers = edges[:-1] + 0.5 * bw - return SignalObj(centers, counts.astype(float), self.name or "spikes", "time", "s", "count", ["counts"]) + def getMaxBinSizeBinary(self) -> float: + isi = self.getISIs() + if isi.size == 0: + return np.inf + return float(np.min(isi)) + + def isSigRepBinary(self) -> bool: + if self.isSigRepBin is None: + self.getSigRep() + return bool(self.isSigRepBin) + + def computeRate(self) -> SignalObj: + sig = self.getSigRep() + if self.sampleRate <= 0: + return sig + rate = np.asarray(sig.data[:, 0], dtype=float) * float(self.sampleRate) + return SignalObj(sig.time, rate, self.name, sig.xlabelval, sig.xunits, "spikes/sec", sig.dataLabels) + + def restoreToOriginal(self) -> None: + self.spikeTimes = self.originalSpikeTimes.copy() + self.minTime = float(self.originalMinTime) + self.maxTime = 
float(self.originalMaxTime)
+        self.sampleRate = float(self.originalSampleRate)
+        self.clearSigRep()
+
+    def nstCopy(self) -> "nspikeTrain":
+        return nspikeTrain(
+            self.spikeTimes.copy(),
+            self.name,
+            1.0 / self.sampleRate if self.sampleRate > 0 else 0.001,
+            self.minTime,
+            self.maxTime,
+            self.xlabelval,
+            self.xunits,
+            self.yunits,
+            self.dataLabels,
+            -1,
+        )
 
     def to_binned_counts(self, bin_edges: Sequence[float]) -> np.ndarray:
         edges = np.asarray(bin_edges, dtype=float).reshape(-1)
         counts, _ = np.histogram(self.spikeTimes, bins=edges)
         return counts.astype(float)
 
+    def toStructure(self) -> dict[str, Any]:
+        return {
+            "spikeTimes": self.spikeTimes.tolist(),
+            "name": self.name,
+            "sampleRate": self.sampleRate,
+            "minTime": self.minTime,
+            "maxTime": self.maxTime,
+            "xlabelval": self.xlabelval,
+            "xunits": self.xunits,
+            "yunits": self.yunits,
+            "dataLabels": self.dataLabels,
+        }
+
+    @staticmethod
+    def fromStructure(structure: dict[str, Any]) -> "nspikeTrain":
+        sampleRate = float(structure.get("sampleRate", 1000.0))
+        binwidth = 1.0 / sampleRate if sampleRate > 0 else 0.001
+        return nspikeTrain(
+            structure.get("spikeTimes", []),
+            structure.get("name", ""),
+            binwidth,
+            structure.get("minTime"),
+            structure.get("maxTime"),
+            structure.get("xlabelval", "time"),
+            structure.get("xunits", "s"),
+            structure.get("yunits", ""),
+            structure.get("dataLabels", ""),
+            -1,
+        )
+
 
 # Backward-compatible alias used by earlier Python scaffolding.
 SpikeTrain = nspikeTrain
diff --git a/nstat/nspikeTrain.py b/nstat/nspikeTrain.py
index f6e54d8a..92662bdb 100644
--- a/nstat/nspikeTrain.py
+++ b/nstat/nspikeTrain.py
@@ -1,13 +1,5 @@
 from __future__ import annotations
 
-from ._compat import warn_deprecated_adapter
-from .spikes import SpikeTrain
-
-
-class nspikeTrain(SpikeTrain):
-    def __init__(self, *args, **kwargs) -> None:
-        warn_deprecated_adapter("nstat.nspikeTrain.nspikeTrain", "nstat.spikes.SpikeTrain")
-        super().__init__(*args, **kwargs)
-
+from .core import nspikeTrain
 
 __all__ = ["nspikeTrain"]
diff --git a/nstat/nstColl.py b/nstat/nstColl.py
index 1e1cb96a..e7af8063 100644
--- a/nstat/nstColl.py
+++ b/nstat/nstColl.py
@@ -1,13 +1,10 @@
 from __future__ import annotations
 
-from ._compat import warn_deprecated_adapter
-from .spikes import SpikeTrainCollection
+from .trial import SpikeTrainCollection as _SpikeTrainCollection
 
 
-class nstColl(SpikeTrainCollection):
-    def __init__(self, *args, **kwargs) -> None:
-        warn_deprecated_adapter("nstat.nstColl.nstColl", "nstat.spikes.SpikeTrainCollection")
-        super().__init__(*args, **kwargs)
+class nstColl(_SpikeTrainCollection):
+    """MATLAB-facing spike-train collection class."""
 
 
 __all__ = ["nstColl"]
diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml
new file mode 100644
index 00000000..6e2707b8
--- /dev/null
+++ b/parity/class_fidelity.yml
@@ -0,0 +1,342 @@
+version: 1
+generated_on: 2026-03-07
+source_repositories:
+  matlab: https://github.com/cajigaslab/nSTAT
+  python: https://github.com/cajigaslab/nSTAT-python
+status_legend:
+  - exact
+  - high_fidelity
+  - partial
+  - shim_only
+  - missing
+  - not_applicable
+items:
+  - matlab_name: SignalObj
+    kind: class
+    matlab_path: SignalObj.m
+    python_symbol: nstat.SignalObj
+    python_path: nstat/core.py
+    status: partial
+    constructor_parity: Closer than the previous shim-first design; constructor defaults, orientation handling, labels, masks, resampling, and time-window APIs now mirror MATLAB more closely.
+    property_parity: Core observable fields exist (time, data, name, xlabelval, xunits, yunits, sampleRate, originalTime, originalData, dataMask, plotProps), but not every MATLAB property or dependent behavior is implemented.
+    method_parity: Foundational methods are implemented for labels, masking, sub-signals, nearest-time lookup, time-window extraction, restore/reset, mean/std, and resampling. Arithmetic, filtering, plotting, correlation, and many utility methods are still missing.
+    default_value_parity: Defaults for labels and units now match MATLAB more closely, including the 1 kHz fallback when sample spacing is ill-conditioned.
+    shape_and_indexing_parity: Signals use time-by-dimension storage and one-based selector behavior for MATLAB-facing methods.
+    error_warning_parity: Some MATLAB-style validation is present, but warning text and all edge-case errors are not yet matched.
+    output_type_parity: MATLAB-facing methods return SignalObj/Covariate instances where expected.
+    known_semantic_differences:
+      - Plotting and many arithmetic/operator overloads are still absent.
+      - Structure serialization is only partial compared with MATLAB.
+    recommended_remediation:
+      - Port arithmetic, filtering, plotting, and structure round-trip methods from MATLAB.
+      - Add fixture-backed tests for label masking, merge, derivative, and window semantics.
+  - matlab_name: Covariate
+    kind: class
+    matlab_path: Covariate.m
+    python_symbol: nstat.Covariate
+    python_path: nstat/core.py
+    status: partial
+    constructor_parity: Uses the SignalObj constructor shape and supports Python aliases for values and units.
+    property_parity: mu and sigma views exist; ci storage is supported.
+    method_parity: copySignal, getSubSignal, computeMeanPlusCI, getSigRep, and setConfInterval exist, but arithmetic with CI propagation and full structure round-tripping are incomplete.
+    default_value_parity: Mostly inherited from SignalObj.
+    shape_and_indexing_parity: Time-by-dimension behavior matches SignalObj and MATLAB-facing one-based selectors are preserved.
+    error_warning_parity: Basic validation is present, but not all MATLAB message paths are matched.
+    output_type_parity: Covariate methods return Covariate or SignalObj as MATLAB expects for the implemented subset.
+    known_semantic_differences:
+      - CI-aware plus/minus operator behavior is not yet ported.
+      - Plotting with confidence intervals is not yet implemented.
+    recommended_remediation:
+      - Port arithmetic overloads and CI plotting semantics from MATLAB.
+      - Add fixture-backed tests for zero-mean and CI propagation workflows.
+  - matlab_name: nspikeTrain
+    kind: class
+    matlab_path: nspikeTrain.m
+    python_symbol: nstat.nspikeTrain
+    python_path: nstat/core.py
+    status: partial
+    constructor_parity: Constructor argument order and defaults now follow MATLAB closely, including min/max/sample-rate initialization and cached signal-representation fields.
+    property_parity: Core public fields exist (spikeTimes, minTime, maxTime, sampleRate, sigRep, isSigRepBin, MER, avgFiringRate, burst/stat placeholders), but the full MATLAB state surface is larger.
+    method_parity: getSigRep, getSpikeTimes, getISIs, getMaxBinSizeBinary, computeRate, restoreToOriginal, nstCopy, and structure export exist. Plotting, burst-detection detail, partitioning, and several statistics/utilities are still incomplete.
+    default_value_parity: Defaults and cache behavior now track MATLAB much more closely than the previous wrapper-based implementation.
+    shape_and_indexing_parity: Spike vectors remain one-dimensional and time-window filtering is inclusive on both ends, matching MATLAB.
+    error_warning_parity: Core argument validation exists, but warnings and all numerical corner cases are not yet matched exactly.
+    output_type_parity: Signal representation returns SignalObj and rate conversion returns SignalObj as expected.
+    known_semantic_differences:
+      - Many plotting/statistical helper methods remain unported.
+      - Burst metrics are placeholders rather than MATLAB-equivalent calculations.
+    recommended_remediation:
+      - Port burst/statistics helpers and plotting routines.
+      - Add MATLAB-derived fixtures for binary binning, windowing, and rate outputs.
+  - matlab_name: nstColl
+    kind: class
+    matlab_path: nstColl.m
+    python_symbol: nstat.nstColl
+    python_path: nstat/trial.py
+    status: partial
+    constructor_parity: Basic collection construction exists, but MATLAB supports richer empty-init and object-state patterns.
+    property_parity: numSpikeTrains, minTime, maxTime, and sampleRate are present.
+    method_parity: getNST, dataToMatrix, psth, and psthGLM exist; masking, config utilities, neighborhood operations, and richer analysis helpers are still missing.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: MATLAB-facing one-based getNST is preserved.
+    error_warning_parity: Simplified.
+    output_type_parity: PSTH returns Covariate.
+    known_semantic_differences:
+      - No MATLAB-equivalent mask state or richer analysis utilities yet.
+    recommended_remediation:
+      - Port the remaining collection methods from MATLAB and move the class into a canonical MATLAB-facing implementation file.
+  - matlab_name: Trial
+    kind: class
+    matlab_path: Trial.m
+    python_symbol: nstat.Trial
+    python_path: nstat/trial.py
+    status: partial
+    constructor_parity: Supports core spike/covariate/event wiring, but not the full MATLAB constructor and object-state surface.
+    property_parity: spikeColl and covarColl are exposed; broader trial metadata/state is still missing.
+    method_parity: Limited to core matrix/vector access.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Core one-based neuron selection is preserved via getSpikeVector.
+    error_warning_parity: Simplified.
+    output_type_parity: Returns NumPy arrays rather than richer MATLAB-style objects in several workflows.
+    known_semantic_differences:
+      - Trial workflow semantics remain much thinner than MATLAB.
+    recommended_remediation:
+      - Port richer trial state, consistency checks, and MATLAB workflow helpers.
+  - matlab_name: TrialConfig
+    kind: class
+    matlab_path: TrialConfig.m
+    python_symbol: nstat.TrialConfig
+    python_path: nstat/trial.py
+    status: partial
+    constructor_parity: Current dataclass captures only a subset of MATLAB configuration fields.
+    property_parity: covMask, sampleRate, history, ensCovHist, covLag, and name exist, but MATLAB exposes richer behavior.
+    method_parity: Only naming and covariate-name extraction are currently implemented.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: N/A for this class.
+    error_warning_parity: Simplified.
+    output_type_parity: Python dataclass rather than a richer MATLAB handle-style object.
+    known_semantic_differences:
+      - Configuration validation and selection behavior are incomplete.
+    recommended_remediation:
+      - Port MATLAB configuration validation, normalization, and selection workflows.
+  - matlab_name: ConfigColl
+    kind: class
+    matlab_path: ConfigColl.m
+    python_symbol: nstat.ConfigColl
+    python_path: nstat/trial.py
+    status: partial
+    constructor_parity: Basic collection support exists.
+    property_parity: numConfigs and configArray exist.
+    method_parity: addConfig, getConfig, and getConfigNames exist; MATLAB collection utilities are broader.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: One-based getConfig behavior is preserved.
+    error_warning_parity: Simplified.
+    output_type_parity: Returns TrialConfig instances.
+    known_semantic_differences:
+      - Richer MATLAB config-management behavior is still missing.
+    recommended_remediation:
+      - Port the remaining ConfigColl helpers and name/selection semantics from MATLAB.
+  - matlab_name: Analysis
+    kind: class
+    matlab_path: Analysis.m
+    python_symbol: nstat.Analysis
+    python_path: nstat/analysis.py
+    status: partial
+    constructor_parity: Python analysis setup exists, but MATLAB option surface and workflow selection semantics are richer.
+    property_parity: Partial only.
+    method_parity: Core fitting helpers exist; RunAnalysisForNeuron and RunAnalysisForAllNeurons are still simplified relative to MATLAB.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Partial only.
+    error_warning_parity: Simplified.
+    output_type_parity: Returns FitSummary/FitResult equivalents, but not with full MATLAB metadata fidelity.
+    known_semantic_differences:
+      - Algorithm selection and analysis-option semantics are still thinner than MATLAB.
+    recommended_remediation:
+      - Port MATLAB analysis options and representative workflow outputs into dataset-backed tests.
+  - matlab_name: FitResult
+    kind: class
+    matlab_path: FitResult.m
+    python_symbol: nstat.FitResult
+    python_path: nstat/fit.py
+    status: partial
+    constructor_parity: Partial.
+    property_parity: Core lambda/spike-train references exist, but MATLAB surface is richer.
+    method_parity: Summary/reporting methods are only partially ported.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: N/A for this class.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - Fit metadata and reporting behavior remain thinner than MATLAB.
+    recommended_remediation:
+      - Port MATLAB result-summary and reporting APIs with golden fixtures.
+  - matlab_name: FitResSummary
+    kind: class
+    matlab_path: FitResSummary.m
+    python_symbol: nstat.FitResSummary
+    python_path: nstat/fit.py
+    status: partial
+    constructor_parity: Partial.
+    property_parity: Partial.
+    method_parity: Partial.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: N/A for this class.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - Summary-table and figure/report behavior is not yet MATLAB-equivalent.
+    recommended_remediation:
+      - Port summary aggregation and reporting semantics from MATLAB.
+  - matlab_name: CIF
+    kind: class
+    matlab_path: CIF.m
+    python_symbol: nstat.CIF
+    python_path: nstat/cif.py
+    status: partial
+    constructor_parity: Partial.
+    property_parity: Partial.
+    method_parity: Simulation and conversion helpers exist, but the full MATLAB object model is broader.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Partial.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - History-aware and decoding-relevant CIF workflows remain thinner than MATLAB.
+    recommended_remediation:
+      - Port MATLAB CIF behaviors used by simulation, fitting, and decoding workflows.
+  - matlab_name: DecodingAlgorithms
+    kind: class
+    matlab_path: DecodingAlgorithms.m
+    python_symbol: nstat.DecodingAlgorithms
+    python_path: nstat/decoding_algorithms.py
+    status: partial
+    constructor_parity: Partial.
+    property_parity: Partial.
+    method_parity: Python decoding helpers exist, but not with full MATLAB workflow fidelity.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Partial.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - Point-process decoding workflows are not yet fully MATLAB-equivalent.
+    recommended_remediation:
+      - Port canonical decoding workflows and validate them against MATLAB-derived outputs.
+  - matlab_name: History
+    kind: class
+    matlab_path: History.m
+    python_symbol: nstat.History
+    python_path: nstat/history.py
+    status: partial
+    constructor_parity: Partial.
+    property_parity: Partial.
+    method_parity: Basic history basis construction exists; richer MATLAB history workflows and outputs are missing.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Partial.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - MATLAB returns richer covariate collections and configuration behavior.
+    recommended_remediation:
+      - Port full History object workflows and fixture-backed outputs.
+  - matlab_name: Events
+    kind: class
+    matlab_path: Events.m
+    python_symbol: nstat.Events
+    python_path: nstat/events.py
+    status: partial
+    constructor_parity: Partial.
+    property_parity: Event times/labels support exists, but color and full validation parity are incomplete.
+    method_parity: Partial.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Partial.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - MATLAB validation and plotting semantics are not fully ported.
+    recommended_remediation:
+      - Port event validation, color handling, and notebook-backed workflows.
+  - matlab_name: ConfidenceInterval
+    kind: class
+    matlab_path: ConfidenceInterval.m
+    python_symbol: nstat.ConfidenceInterval
+    python_path: nstat/confidence_interval.py
+    status: partial
+    constructor_parity: Basic time-and-bounds construction exists.
+    property_parity: lower and upper accessors exist; broader MATLAB behavior is missing.
+    method_parity: Minimal.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Partial.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - Plotting and structure round-tripping are incomplete.
+    recommended_remediation:
+      - Port MATLAB plotting and serialization semantics.
+  - matlab_name: CovColl
+    kind: class
+    matlab_path: CovColl.m
+    python_symbol: nstat.CovColl
+    python_path: nstat/trial.py
+    status: partial
+    constructor_parity: Basic collection support exists.
+    property_parity: Partial.
+    method_parity: add/get/dataToMatrix exist; MATLAB collection behavior is broader.
+    default_value_parity: Partial only.
+    shape_and_indexing_parity: Shared-time enforcement is implemented.
+    error_warning_parity: Simplified.
+    output_type_parity: Partial.
+    known_semantic_differences:
+      - Richer selection and covariate management helpers are missing.
+    recommended_remediation:
+      - Port remaining CovColl behaviors and helpfile workflows.
+ - matlab_name: getPaperDataDirs + kind: function + matlab_path: getPaperDataDirs.m + python_symbol: nstat.getPaperDataDirs + python_path: nstat/data_manager.py + status: high_fidelity + constructor_parity: N/A + property_parity: N/A + method_parity: Python helper exposes MATLAB-style name and standalone repo semantics. + default_value_parity: Defaults to the Python repo's independent example-data cache instead of a MATLAB checkout path. + shape_and_indexing_parity: N/A + error_warning_parity: Close for the Python use case. + output_type_parity: Returns directory paths as a Python tuple/list structure rather than MATLAB cell arrays. + known_semantic_differences: + - Python returns native path types/strings rather than MATLAB cells. + recommended_remediation: + - Add a MATLAB-reference fixture for the directory tuple shape if stricter parity is needed. + - matlab_name: nSTAT_Install + kind: function + matlab_path: nSTAT_Install.m + python_symbol: nstat.nSTAT_Install + python_path: nstat/install.py + status: partial + constructor_parity: N/A + property_parity: N/A + method_parity: Python installer covers data download and docs rebuild paths, but MATLAB path-cleanup semantics remain a no-op compatibility path. + default_value_parity: Close for Python packaging, not exact for MATLAB path management. + shape_and_indexing_parity: N/A + error_warning_parity: Partial. + output_type_parity: Returns Python dictionaries/status text rather than MATLAB console-only behavior. + known_semantic_differences: + - MATLAB path management is intentionally non-applicable in Python. + recommended_remediation: + - Keep documenting the no-op compatibility behavior and test installer status outputs. + - matlab_name: nstatOpenHelpPage + kind: function + matlab_path: nstatOpenHelpPage.m + python_symbol: null + python_path: null + status: not_applicable + constructor_parity: N/A + property_parity: N/A + method_parity: MATLAB help-browser integration has no direct standalone Python equivalent. + default_value_parity: N/A + shape_and_indexing_parity: N/A + error_warning_parity: N/A + output_type_parity: N/A + known_semantic_differences: + - Python uses Sphinx docs pages instead of the MATLAB help browser. + recommended_remediation: + - None. 
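Each audit entry above records one MATLAB symbol with a `status`, per-dimension parity notes, and a `recommended_remediation` list. A minimal standalone sketch of tallying those statuses is shown below; the `items` and `status` keys follow the file introduced in this patch, the `status_legend` ordering is treated as optional with a fallback, and the helper name plus the `__main__` wrapper are illustrative only, not part of the shipped package.

```python
# Sketch: tally class-fidelity statuses from parity/class_fidelity.yml.
# Assumes a repo checkout with the audit file added by this patch; the
# function name and CLI wrapper are hypothetical conveniences.
from collections import Counter
from pathlib import Path

import yaml


def summarize_class_fidelity(audit_path: Path) -> dict[str, int]:
    """Count audit rows per status, keeping the legend order where present."""
    payload = yaml.safe_load(audit_path.read_text(encoding="utf-8")) or {}
    counts = Counter(
        str(row.get("status", "")).strip() for row in payload.get("items", [])
    )
    # Start from the declared legend (if any) and append unlisted statuses so
    # nothing in the audit is silently dropped from the summary.
    legend = list(payload.get("status_legend", []))
    for status in counts:
        if status not in legend:
            legend.append(status)
    return {status: counts.get(status, 0) for status in legend}


if __name__ == "__main__":
    for status, count in summarize_class_fidelity(Path("parity/class_fidelity.yml")).items():
        print(f"{status}: {count}")
```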
diff --git a/tests/test_api_surface.py b/tests/test_api_surface.py index 6879c6e0..8ddfd8d8 100644 --- a/tests/test_api_surface.py +++ b/tests/test_api_surface.py @@ -21,10 +21,14 @@ def test_canonical_api_imports() -> None: assert nstat.nSTAT_Install is not None -def test_compatibility_adapters_emit_deprecation() -> None: +def test_matlab_facing_class_imports_are_canonical() -> None: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") from nstat.SignalObj import SignalObj + from nstat.Covariate import Covariate + from nstat.nspikeTrain import nspikeTrain _ = SignalObj([0.0, 1.0], [1.0, 2.0]) - assert any("deprecated" in str(item.message).lower() for item in w) + _ = Covariate([0.0, 1.0], [1.0, 2.0]) + _ = nspikeTrain([0.25, 0.5], makePlots=-1) + assert not any("deprecated" in str(item.message).lower() for item in w) diff --git a/tests/test_class_fidelity_audit.py b/tests/test_class_fidelity_audit.py new file mode 100644 index 00000000..c8dffa60 --- /dev/null +++ b/tests/test_class_fidelity_audit.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from pathlib import Path + +import yaml + + +REPO_ROOT = Path(__file__).resolve().parents[1] +AUDIT_PATH = REPO_ROOT / "parity" / "class_fidelity.yml" +VALID_STATUSES = {"exact", "high_fidelity", "partial", "shim_only", "missing", "not_applicable"} +PRIORITY_CLASSES = { + "SignalObj", + "Covariate", + "Trial", + "TrialConfig", + "ConfigColl", + "nspikeTrain", + "nstColl", + "Analysis", + "FitResult", + "FitResSummary", + "CIF", + "DecodingAlgorithms", + "History", + "Events", +} + + +def _load_audit() -> dict: + payload = yaml.safe_load(AUDIT_PATH.read_text(encoding="utf-8")) or {} + assert payload.get("items"), "parity/class_fidelity.yml is empty" + return payload + + +def test_class_fidelity_audit_covers_priority_classes() -> None: + payload = _load_audit() + names = {str(item.get("matlab_name", "")).strip() for item in payload["items"]} + assert PRIORITY_CLASSES.issubset(names) + + +def test_class_fidelity_audit_uses_known_status_values() -> None: + payload = _load_audit() + for item in payload["items"]: + assert item["status"] in VALID_STATUSES + + +def test_core_matlab_facing_classes_are_not_shim_only() -> None: + payload = _load_audit() + audit_by_name = {str(item["matlab_name"]): item for item in payload["items"]} + + for name in ("SignalObj", "Covariate", "nspikeTrain"): + row = audit_by_name[name] + assert row["status"] not in {"shim_only", "missing"} + assert row["python_path"] == "nstat/core.py" + + +def test_class_fidelity_audit_has_unique_matlab_names() -> None: + payload = _load_audit() + names = [str(item.get("matlab_name", "")).strip() for item in payload["items"]] + assert len(names) == len(set(names)) diff --git a/tests/test_nspiketrain_fidelity.py b/tests/test_nspiketrain_fidelity.py new file mode 100644 index 00000000..0cf2652a --- /dev/null +++ b/tests/test_nspiketrain_fidelity.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import numpy as np + +from nstat.nspikeTrain import nspikeTrain + + +def test_nspiketrain_constructor_runs_statistics_without_numpy_mode_error() -> None: + train = nspikeTrain([0.0, 0.5, 1.0], "neuron") + + assert np.isfinite(train.avgFiringRate) + assert np.isfinite(train.burstIndex) + assert train.Lstatistic is not None + + +def test_nspiketrain_sigrep_uses_matlab_style_centers_and_inclusive_last_bin() -> None: + train = nspikeTrain([0.0, 0.5, 1.0], "neuron", 0.5, 0.0, 1.0, makePlots=-1) + + sig = train.getSigRep() + + np.testing.assert_allclose(sig.time, 
[0.0, 0.5, 1.0]) + np.testing.assert_allclose(sig.data[:, 0], [1.0, 1.0, 1.0]) + assert train.isSigRepBinary() + + +def test_nspiketrain_windowing_and_binary_limit_follow_matlab_semantics() -> None: + train = nspikeTrain([0.1, 0.4, 0.9], "neuron", makePlots=-1) + + np.testing.assert_allclose(train.getSpikeTimes(0.1, 0.4), [0.1, 0.4]) + np.testing.assert_allclose(train.getISIs(0.1, 0.9), [0.3, 0.5]) + np.testing.assert_allclose(train.getMaxBinSizeBinary(), 0.3) diff --git a/tests/test_signalobj_fidelity.py b/tests/test_signalobj_fidelity.py new file mode 100644 index 00000000..9de9b050 --- /dev/null +++ b/tests/test_signalobj_fidelity.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import numpy as np + +from nstat.Covariate import Covariate +from nstat.SignalObj import SignalObj + + +def test_signalobj_normalizes_channel_orientation_and_uses_one_based_selection() -> None: + sig = SignalObj( + [0.0, 0.5, 1.0], + [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]], + "stim", + "time", + "s", + "a.u.", + ["x", "y"], + ) + + assert sig.data.shape == (3, 2) + np.testing.assert_allclose(sig.data[:, 0], [1.0, 2.0, 3.0]) + np.testing.assert_allclose(sig.data[:, 1], [10.0, 20.0, 30.0]) + + sub = sig.getSubSignal(2) + assert sub.dimension == 1 + assert sub.dataLabels == ["y"] + np.testing.assert_allclose(sub.data[:, 0], [10.0, 20.0, 30.0]) + + +def test_signalobj_time_window_padding_matches_matlab_style_hold_values() -> None: + sig = SignalObj([0.0, 1.0], [5.0, 7.0], "stim") + sig.setMinTime(-1.0, holdVals=1) + sig.setMaxTime(2.0, holdVals=1) + + np.testing.assert_allclose(sig.time, [-1.0, 0.0, 1.0, 2.0]) + np.testing.assert_allclose(sig.data[:, 0], [5.0, 5.0, 7.0, 7.0]) + + +def test_covariate_compute_mean_plus_ci_uses_timewise_mean() -> None: + cov = Covariate( + [0.0, 1.0, 2.0], + [[1.0, 3.0], [2.0, 4.0], [6.0, 8.0]], + "lambda", + "time", + "s", + "spikes/sec", + ["trial1", "trial2"], + ) + + mean_cov = cov.computeMeanPlusCI(alphaVal=0.5) + + np.testing.assert_allclose(mean_cov.time, [0.0, 1.0, 2.0]) + np.testing.assert_allclose(mean_cov.data[:, 0], [2.0, 3.0, 7.0]) + assert mean_cov.isConfIntervalSet() + assert mean_cov.ci is not None + assert len(mean_cov.ci) == 1 From b167b1118015c432d8d6b2bfa859509197de2f36 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 11:23:22 -0500 Subject: [PATCH 2/3] Add parity-core notebook execution tier --- .github/workflows/ci.yml | 17 ++++++++++++ tests/test_notebook_ci_groups.py | 44 ++++++++++++++++++++++++++++++++ tools/notebooks/topic_groups.yml | 17 ++++++++++++ 3 files changed, 78 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 77ad3fae..e5006160 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -134,6 +134,23 @@ jobs: - name: Execute Python notebook smoke group run: python tools/notebooks/run_notebooks.py --group ci_smoke --timeout 600 + notebook-parity-core: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[dev] + python -m pip install ipykernel + python -m ipykernel install --user --name python3 --display-name "Python 3" + - name: Execute Python notebook parity-core group + run: python tools/notebooks/run_notebooks.py --group parity_core --timeout 900 + cleanroom-compliance: runs-on: ubuntu-latest diff --git a/tests/test_notebook_ci_groups.py b/tests/test_notebook_ci_groups.py index 
b7b67096..21301f45 100644 --- a/tests/test_notebook_ci_groups.py +++ b/tests/test_notebook_ci_groups.py @@ -13,6 +13,19 @@ "ConfidenceIntervalOverview", "nSTATPaperExamples", } +REQUIRED_PARITY_CORE_TOPICS = { + "AnalysisExamples", + "DecodingExample", + "DecodingExampleWithHist", + "ExplicitStimulusWhiskerData", + "HippocampalPlaceCellExample", + "HybridFilterExample", + "SignalObjExamples", + "TrialExamples", + "ValidationDataSet", + "nSTATPaperExamples", + "nSpikeTrainExamples", +} def test_ci_smoke_group_covers_required_parity_notebooks() -> None: @@ -35,3 +48,34 @@ def test_ci_smoke_group_topics_exist_in_notebook_manifest() -> None: missing = [topic for topic in ci_smoke if topic not in notebook_topics] assert not missing, f"CI smoke group references unknown notebook topics: {missing}" + + +def test_parity_core_group_covers_required_helpfile_parity_notebooks() -> None: + notebook_manifest = yaml.safe_load(NOTEBOOK_MANIFEST_PATH.read_text(encoding="utf-8")) or {} + notebook_topics = {row["topic"] for row in notebook_manifest.get("notebooks", [])} + + groups_payload = yaml.safe_load(TOPIC_GROUPS_PATH.read_text(encoding="utf-8")) or {} + parity_core = set(groups_payload.get("groups", {}).get("parity_core", [])) + + assert REQUIRED_PARITY_CORE_TOPICS <= notebook_topics + assert REQUIRED_PARITY_CORE_TOPICS <= parity_core + + +def test_parity_core_group_extends_ci_smoke_coverage() -> None: + groups_payload = yaml.safe_load(TOPIC_GROUPS_PATH.read_text(encoding="utf-8")) or {} + groups = groups_payload.get("groups", {}) + ci_smoke = set(groups.get("ci_smoke", [])) + parity_core = set(groups.get("parity_core", [])) + + assert ci_smoke < parity_core + + +def test_parity_core_group_topics_exist_in_notebook_manifest() -> None: + notebook_manifest = yaml.safe_load(NOTEBOOK_MANIFEST_PATH.read_text(encoding="utf-8")) or {} + notebook_topics = {row["topic"] for row in notebook_manifest.get("notebooks", [])} + + groups_payload = yaml.safe_load(TOPIC_GROUPS_PATH.read_text(encoding="utf-8")) or {} + parity_core = groups_payload.get("groups", {}).get("parity_core", []) + + missing = [topic for topic in parity_core if topic not in notebook_topics] + assert not missing, f"parity_core group references unknown notebook topics: {missing}" diff --git a/tools/notebooks/topic_groups.yml b/tools/notebooks/topic_groups.yml index 5b863106..4e04b4ba 100644 --- a/tools/notebooks/topic_groups.yml +++ b/tools/notebooks/topic_groups.yml @@ -30,3 +30,20 @@ groups: - mEPSCAnalysis - nSTATPaperExamples - nSpikeTrainExamples + parity_core: + - AnalysisExamples + - ConfidenceIntervalOverview + - ConfigCollExamples + - CovariateExamples + - DecodingExample + - DecodingExampleWithHist + - ExplicitStimulusWhiskerData + - HippocampalPlaceCellExample + - HybridFilterExample + - SignalObjExamples + - StimulusDecode2D + - TrialConfigExamples + - TrialExamples + - ValidationDataSet + - nSTATPaperExamples + - nSpikeTrainExamples From d037bb9454ca804e40820526574e1f2dbbeb1fd9 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sat, 7 Mar 2026 11:25:08 -0500 Subject: [PATCH 3/3] Report class-fidelity deltas separately --- nstat/parity_report.py | 76 +++++++++++++++++++++++++++++++++++-- parity/README.md | 2 + parity/report.md | 39 +++++++++++++++++-- tests/test_parity_report.py | 5 ++- 4 files changed, 114 insertions(+), 8 deletions(-) diff --git a/nstat/parity_report.py b/nstat/parity_report.py index 6020e41d..f61841fd 100644 --- a/nstat/parity_report.py +++ b/nstat/parity_report.py @@ -31,6 +31,26 @@ def load_parity_manifest(repo_root: 
Path | None = None) -> dict[str, Any]: return yaml.safe_load(path.read_text(encoding="utf-8")) +def load_class_fidelity_audit(repo_root: Path | None = None) -> dict[str, Any]: + base = _repo_root() if repo_root is None else repo_root.resolve() + path = base / "parity" / "class_fidelity.yml" + return yaml.safe_load(path.read_text(encoding="utf-8")) + + +def _summarize_class_fidelity(payload: dict[str, Any]) -> dict[str, int]: + counts = {status: 0 for status in payload.get("status_legend", [])} + for row in payload.get("items", []): + status = str(row.get("status", "")).strip() + if status not in counts: + counts[status] = 0 + counts[status] += 1 + return counts + + +def _iter_class_fidelity_rows(payload: dict[str, Any], statuses: set[str]) -> list[dict[str, Any]]: + return [row for row in payload.get("items", []) if row.get("status") in statuses] + + def _iter_outstanding_rows(payload: dict[str, Any]) -> list[tuple[str, dict[str, Any]]]: rows: list[tuple[str, dict[str, Any]]] = [] for section_name in SUMMARY_SECTIONS: @@ -51,10 +71,12 @@ def _iter_non_applicable_rows(payload: dict[str, Any]) -> list[tuple[str, dict[s def render_parity_report(repo_root: Path | None = None) -> str: payload = load_parity_manifest(repo_root) + class_fidelity = load_class_fidelity_audit(repo_root) + class_counts = _summarize_class_fidelity(class_fidelity) lines = [ "# nSTAT Python Parity Report", "", - "Generated from `parity/manifest.yml`.", + "Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`.", "", f"- MATLAB reference: {payload['source_repositories']['matlab']}", f"- Python target: {payload['source_repositories']['python']}", @@ -73,6 +95,18 @@ def render_parity_report(repo_root: Path | None = None) -> str: f"| `{label}` | {counts['mapped']} | {counts['partial']} | {counts['missing']} | {counts['not_applicable']} |" ) + lines.extend( + [ + "", + "## Class Fidelity Summary", + "", + "| Status | Count |", + "|---|---:|", + ] + ) + for status in class_fidelity.get("status_legend", []): + lines.append(f"| `{status}` | {class_counts.get(status, 0)} |") + lines.extend(["", "## Coverage Notes", ""]) lines.append( "- Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable." @@ -86,11 +120,18 @@ def render_parity_report(repo_root: Path | None = None) -> str: lines.append( "- Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped." ) - lines.extend(["", "## Remaining Deltas", ""]) + priority_remaining = _iter_class_fidelity_rows(class_fidelity, {"partial", "shim_only", "missing"}) + if not priority_remaining: + lines.append("- Class fidelity: the class audit reports no partial, shim-only, or missing items.") + else: + lines.append( + "- Class fidelity: mapping parity is ahead of semantic parity; the audit still reports partial fidelity for several MATLAB-facing classes and workflows." 
+ ) + lines.extend(["", "## Remaining Mapping Deltas", ""]) outstanding = _iter_outstanding_rows(payload) if not outstanding: - lines.append("No partial or missing items remain.") + lines.append("No partial or missing items remain in the mapping inventory.") else: current_section = "" for section_name, row in outstanding: @@ -105,15 +146,41 @@ def render_parity_report(repo_root: Path | None = None) -> str: notes = row.get("notes", "") lines.append(f"- `{label}` -> `{python_target}`: {notes}") + lines.extend(["", "## Remaining Class-Fidelity Deltas", ""]) + if not priority_remaining: + lines.append("No partial, shim-only, or missing class-fidelity items remain.") + else: + for row in priority_remaining: + label = row.get("matlab_name") or row.get("python_symbol") or row.get("matlab_path") + python_target = row.get("python_symbol") or row.get("python_path") + recommendation = row.get("recommended_remediation", []) + if isinstance(recommendation, list): + recommendation_text = recommendation[0] if recommendation else "" + else: + recommendation_text = str(recommendation) + note = row.get("method_parity", "") + detail = recommendation_text or note + lines.append(f"- `{label}` -> `{python_target}` [{row['status']}]: {detail}") + lines.extend(["", "## Justified Non-Applicable Items", ""]) non_applicable = _iter_non_applicable_rows(payload) + class_non_applicable = _iter_class_fidelity_rows(class_fidelity, {"not_applicable"}) if not non_applicable: - lines.append("None.") + if not class_non_applicable: + lines.append("None.") else: for section_name, row in non_applicable: label = row.get("matlab") or row.get("path") or row.get("name") notes = row.get("notes", "") lines.append(f"- `{section_name}`: `{label}`. {notes}") + for row in class_non_applicable: + label = row.get("matlab_name") or row.get("matlab_path") + notes = row.get("known_semantic_differences", []) + if isinstance(notes, list): + note_text = notes[0] if notes else "" + else: + note_text = str(notes) + lines.append(f"- `class_fidelity`: `{label}`. {note_text}") lines.append("") return "\n".join(lines) @@ -127,6 +194,7 @@ def write_parity_report(repo_root: Path | None = None) -> Path: __all__ = [ + "load_class_fidelity_audit", "load_parity_manifest", "render_parity_report", "write_parity_report", diff --git a/parity/README.md b/parity/README.md index 8eb3b0fc..0dd1c0ef 100644 --- a/parity/README.md +++ b/parity/README.md @@ -4,6 +4,7 @@ This directory tracks MATLAB-to-Python parity for the standalone port. Current inventory source: - [`manifest.yml`](./manifest.yml) +- [`class_fidelity.yml`](./class_fidelity.yml) - [`report.md`](./report.md) Generated report: @@ -14,6 +15,7 @@ python tools/parity/build_report.py Current headline status: - Public API coverage matches the MATLAB inventory except for the explicitly non-applicable `nstatOpenHelpPage`. +- Class-fidelity auditing is tracked separately from name-mapping parity in `class_fidelity.yml`, and it remains intentionally stricter and more conservative than the mapping manifest. - Help/notebook parity covers the inventoried MATLAB help workflow surface, including the top-level `NeuralSpikeAnalysis_top`, `PaperOverview`, `Examples`, and `ClassDefinitions` navigation pages. - Canonical paper examples, gallery structure, and README/docs presentation are committed and mapped in Python. - CI now validates API surface, dataset integrity, notebook smoke execution, paper gallery drift, and docs builds. 
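To see how the audit feeds the generated report, a minimal usage sketch is below. It only calls helpers that this patch exports from `nstat.parity_report`, and the section headers it checks are the ones the renderer emits; running it assumes a repo checkout with both `parity/manifest.yml` and `parity/class_fidelity.yml` present.

```python
# Sketch: regenerate parity/report.md and confirm the class-fidelity sections
# appear. Mirrors tests/test_parity_report.py rather than adding new checks.
from nstat.parity_report import (
    load_class_fidelity_audit,
    render_parity_report,
    write_parity_report,
)

audit = load_class_fidelity_audit()
print(f"audited symbols: {len(audit.get('items', []))}")

text = render_parity_report()
assert "## Class Fidelity Summary" in text
assert "## Remaining Class-Fidelity Deltas" in text

report_path = write_parity_report()
print(f"wrote {report_path}")
```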
diff --git a/parity/report.md b/parity/report.md index 2a48e9ee..f7dbd449 100644 --- a/parity/report.md +++ b/parity/report.md @@ -1,6 +1,6 @@ # nSTAT Python Parity Report -Generated from `parity/manifest.yml`. +Generated from `parity/manifest.yml` and `parity/class_fidelity.yml`. - MATLAB reference: https://github.com/cajigaslab/nSTAT - Python target: https://github.com/cajigaslab/nSTAT-python @@ -18,15 +18,47 @@ Generated from `parity/manifest.yml`. | `installer setup` | 4 | 0 | 0 | 3 | | `repo structure` | 1 | 0 | 0 | 0 | +## Class Fidelity Summary + +| Status | Count | +|---|---:| +| `exact` | 0 | +| `high_fidelity` | 1 | +| `partial` | 17 | +| `shim_only` | 0 | +| `missing` | 0 | +| `not_applicable` | 1 | + ## Coverage Notes - Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable. - Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents. - Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped. +- Class fidelity: mapping parity is ahead of semantic parity; the audit still reports partial fidelity for several MATLAB-facing classes and workflows. + +## Remaining Mapping Deltas + +No partial or missing items remain in the mapping inventory. -## Remaining Deltas +## Remaining Class-Fidelity Deltas -No partial or missing items remain. +- `SignalObj` -> `nstat.SignalObj` [partial]: Port arithmetic, filtering, plotting, and structure round-trip methods from MATLAB. +- `Covariate` -> `nstat.Covariate` [partial]: Port arithmetic overloads and CI plotting semantics from MATLAB. +- `nspikeTrain` -> `nstat.nspikeTrain` [partial]: Port burst/statistics helpers and plotting routines. +- `nstColl` -> `nstat.nstColl` [partial]: Port the remaining collection methods from MATLAB and move the class into a canonical MATLAB-facing implementation file. +- `Trial` -> `nstat.Trial` [partial]: Port richer trial state, consistency checks, and MATLAB workflow helpers. +- `TrialConfig` -> `nstat.TrialConfig` [partial]: Port MATLAB configuration validation, normalization, and selection workflows. +- `ConfigColl` -> `nstat.ConfigColl` [partial]: Port the remaining ConfigColl helpers and name/selection semantics from MATLAB. +- `Analysis` -> `nstat.Analysis` [partial]: Port MATLAB analysis options and representative workflow outputs into dataset-backed tests. +- `FitResult` -> `nstat.FitResult` [partial]: Port MATLAB result-summary and reporting APIs with golden fixtures. +- `FitResSummary` -> `nstat.FitResSummary` [partial]: Port summary aggregation and reporting semantics from MATLAB. +- `CIF` -> `nstat.CIF` [partial]: Port MATLAB CIF behaviors used by simulation, fitting, and decoding workflows. +- `DecodingAlgorithms` -> `nstat.DecodingAlgorithms` [partial]: Port canonical decoding workflows and validate them against MATLAB-derived outputs. +- `History` -> `nstat.History` [partial]: Port full History object workflows and fixture-backed outputs. +- `Events` -> `nstat.Events` [partial]: Port event validation, color handling, and notebook-backed workflows. +- `ConfidenceInterval` -> `nstat.ConfidenceInterval` [partial]: Port MATLAB plotting and serialization semantics. +- `CovColl` -> `nstat.CovColl` [partial]: Port remaining CovColl behaviors and helpfile workflows. +- `nSTAT_Install` -> `nstat.nSTAT_Install` [partial]: Keep documenting the no-op compatibility behavior and test installer status outputs. 
## Justified Non-Applicable Items @@ -34,3 +66,4 @@ No partial or missing items remain. - `installer_setup`: `CleanUserPathPrefs option`. Accepted as a compatibility no-op because Python does not use MATLAB-style saved user path preferences. - `installer_setup`: `MATLAB runtime path pruning`. Python packaging/import resolution replaces MATLAB path management. - `installer_setup`: `MATLAB toolbox cache refresh and savepath`. There is no Python equivalent to MATLAB toolbox cache refresh or savepath persistence. +- `class_fidelity`: `nstatOpenHelpPage`. Python uses Sphinx docs pages instead of the MATLAB help browser. diff --git a/tests/test_parity_report.py b/tests/test_parity_report.py index d67d49ca..9356ad46 100644 --- a/tests/test_parity_report.py +++ b/tests/test_parity_report.py @@ -19,5 +19,8 @@ def test_parity_report_highlights_current_constraints() -> None: assert "no missing MATLAB public APIs remain" in text assert "paper examples and docs gallery" in text.lower() assert "all canonical paper examples and committed gallery directories are mapped" in text - assert "No partial or missing items remain." in text + assert "class fidelity" in text.lower() + assert "No partial or missing items remain in the mapping inventory." in text + assert "Remaining Class-Fidelity Deltas" in text + assert "SignalObj" in text assert "nstatOpenHelpPage" in text
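For completeness, a small sketch of extracting the open class-fidelity deltas directly from the audit file follows. The status filter matches `_iter_class_fidelity_rows` in this patch and the fallback to `method_parity` mirrors the report renderer; the standalone script itself is illustrative and is not part of the package or its tests.

```python
# Sketch: print the same rows the "Remaining Class-Fidelity Deltas" section
# reports, one line per open item with its first remediation step.
from pathlib import Path

import yaml

OPEN_STATUSES = {"partial", "shim_only", "missing"}

payload = yaml.safe_load(
    Path("parity/class_fidelity.yml").read_text(encoding="utf-8")
) or {}

for row in payload.get("items", []):
    if row.get("status") not in OPEN_STATUSES:
        continue
    steps = row.get("recommended_remediation", [])
    if isinstance(steps, str):
        steps = [steps]
    # Fall back to the method-parity note when no remediation step is listed,
    # matching the behavior of the report renderer in this patch.
    first_step = steps[0] if steps else row.get("method_parity", "")
    print(f"{row.get('matlab_name')} [{row.get('status')}]: {first_step}")
```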