From b10144702c1be4390f0f6753f3be8f63689561aa Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 09:20:30 -0400 Subject: [PATCH 1/2] Tighten exactness fixtures for core classes --- nstat/analysis.py | 6 +- nstat/core.py | 111 ++-- nstat/history.py | 177 +++++-- nstat/trial.py | 494 +++++++++++++++++- parity/class_fidelity.yml | 133 +++-- parity/manifest.yml | 4 +- parity/report.md | 4 +- .../matlab_gold/analysis_exactness.mat | Bin 1704 -> 1704 bytes .../analysis_multineuron_exactness.mat | Bin 0 -> 1389 bytes .../fixtures/matlab_gold/cif_exactness.mat | Bin 1157 -> 1157 bytes .../confidence_interval_exactness.mat | Bin 1600 -> 1600 bytes .../fixtures/matlab_gold/config_exactness.mat | Bin 2665 -> 2665 bytes .../matlab_gold/covariate_exactness.mat | Bin 576 -> 1500 bytes .../matlab_gold/covcoll_exactness.mat | Bin 0 -> 1488 bytes .../decoding_predict_exactness.mat | Bin 493 -> 493 bytes .../decoding_smoother_exactness.mat | Bin 774 -> 774 bytes .../fixtures/matlab_gold/events_exactness.mat | Bin 812 -> 812 bytes .../matlab_gold/fit_summary_exactness.mat | Bin 790 -> 790 bytes .../matlab_gold/history_exactness.mat | Bin 0 -> 10701 bytes .../matlab_gold/hybrid_filter_exactness.mat | Bin 1530 -> 1530 bytes .../matlab_gold/ksdiscrete_exactness.mat | Bin 1288 -> 1288 bytes .../nonlinear_decode_exactness.mat | Bin 1097 -> 1097 bytes .../matlab_gold/nspiketrain_exactness.mat | Bin 1451 -> 2504 bytes .../matlab_gold/nstcoll_exactness.mat | Bin 758 -> 1604 bytes .../matlab_gold/point_process_exactness.mat | Bin 1303 -> 1303 bytes .../matlab_gold/signalobj_exactness.mat | Bin 1310 -> 1310 bytes .../simulated_network_exactness.mat | Bin 1469 -> 1469 bytes .../matlab_gold/thinning_exactness.mat | Bin 1149 -> 1149 bytes tests/test_matlab_gold_fixtures.py | 259 ++++++++- tests/test_trial_fidelity.py | 17 + .../matlab/export_matlab_gold_fixtures.m | 234 +++++++++ 31 files changed, 1291 insertions(+), 148 deletions(-) create mode 100644 tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat create mode 100644 tests/parity/fixtures/matlab_gold/covcoll_exactness.mat create mode 100644 tests/parity/fixtures/matlab_gold/history_exactness.mat diff --git a/nstat/analysis.py b/nstat/analysis.py index 55484313..f5ff546e 100644 --- a/nstat/analysis.py +++ b/nstat/analysis.py @@ -327,7 +327,7 @@ def run_analysis_for_neuron( merged_lambda = merged_lambda.merge(part) _restore_trial_partition(trial, original_partition) - return FitResult( + fit_result = FitResult( spike_train, labels, numHist, @@ -346,6 +346,10 @@ def run_analysis_for_neuron( distributions, fits=fits, ) + # MATLAB returns fits with KS diagnostics already populated, and + # downstream summary classes read those cached fields directly. + fit_result.computeKSStats() + return fit_result @staticmethod def run_analysis_for_all_neurons( diff --git a/nstat/core.py b/nstat/core.py index 443927e4..1c2381a7 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -266,6 +266,9 @@ def setPlotProps(self, plotProps: Sequence[Any] | str | None, index: int | None self.plotProps = [plotProps for _ in range(self.dimension)] else: props = list(plotProps) + if len(props) == 0: + self.plotProps = [None for _ in range(self.dimension)] + return if len(props) == 1 and self.dimension > 1: props = props * self.dimension if len(props) != self.dimension: @@ -945,6 +948,9 @@ def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> dict[str, Any]: data = self.dataToMatrix(selectorArray) + plot_props = list(self.plotProps) + if all(prop is None for prop in plot_props): + plot_props = [] return { "time": self.time.tolist(), "data": data.tolist(), @@ -953,7 +959,7 @@ def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = Non "xunits": self.xunits, "yunits": self.yunits, "dataLabels": list(self.dataLabels), - "plotProps": list(self.plotProps), + "plotProps": plot_props, } def toStructure(self) -> dict[str, Any]: @@ -974,6 +980,7 @@ def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": def plot(self, selectorArray=None, plotPropsIn=None, handle=None): import matplotlib.pyplot as plt + from .confidence_interval import MATLAB_COLOR_ORDER ax = plt.gca() if handle is None else handle signal = self.getSubSignal(selectorArray) if selectorArray is not None else self.getSubSignal(self.findIndFromDataMask() or list(range(1, self.dimension + 1))) @@ -989,6 +996,8 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): prop = props[index] if isinstance(prop, str) and prop: kwargs["fmt"] = prop + elif prop is None: + kwargs["color"] = MATLAB_COLOR_ORDER[index % MATLAB_COLOR_ORDER.shape[0]] if "fmt" in kwargs: fmt = kwargs.pop("fmt") line = ax.plot(signal.time, signal.data[:, index], fmt, **kwargs) @@ -1075,6 +1084,7 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): lines = super().plot(selectorArray, plotPropsIn, handle) if self.isConfIntervalSet(): import matplotlib.pyplot as plt + import matplotlib.colors as mcolors ax = plt.gca() if handle is None else handle selectors = self.findIndFromDataMask() if selectorArray is None else ( @@ -1086,6 +1096,8 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): selectors = [item[0] for item in selectors] for line_index, selector in enumerate(selectors): color = getattr(lines[line_index], "get_color", lambda: "b")() + if isinstance(color, (str, bytes)): + color = mcolors.to_rgb(color) self.ci[selector - 1].plot(color, ax=ax) return lines @@ -1176,17 +1188,37 @@ def toStructure(self) -> dict[str, Any]: if self.isConfIntervalSet(): ci_payload: list[dict[str, Any]] = [] for item in self.ci or []: - if hasattr(item, "time") and hasattr(item, "bounds"): - ci_payload.append( - { - "time": np.asarray(item.time, dtype=float).tolist(), - "bounds": np.asarray(item.bounds, dtype=float).tolist(), - "color": getattr(item, "color", "b"), - } - ) - structure["ci"] = ci_payload + if hasattr(item, "dataToStructure"): + ci_payload.append(item.dataToStructure()) + if ci_payload: + structure["ci"] = ci_payload[0] if len(ci_payload) == 1 else ci_payload return structure + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "Covariate": + from .confidence_interval import ConfidenceInterval + + cov = Covariate( + structure["time"], + structure["data"], + structure.get("name", ""), + structure.get("xlabelval", "time"), + structure.get("xunits", "s"), + structure.get("yunits", ""), + structure.get("dataLabels"), + structure.get("plotProps"), + ) + ci_payload = structure.get("ci") + if ci_payload is None: + return cov + if isinstance(ci_payload, list): + cov.setConfInterval([ConfidenceInterval.fromStructure(item) for item in ci_payload]) + elif isinstance(ci_payload, tuple): + cov.setConfInterval([ConfidenceInterval.fromStructure(item) for item in ci_payload]) + else: + cov.setConfInterval(ConfidenceInterval.fromStructure(ci_payload)) + return cov + class nspikeTrain: """Closer MATLAB-style spike-train object with cached signal representation.""" @@ -1405,11 +1437,15 @@ def _build_sigrep(self, binwidth: float, minTime: float, maxTime: float) -> Sign return sig def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, maxTime: float | None = None) -> SignalObj: - self.sigRep = self.getSigRep(binwidth, minTime, maxTime) - self.isSigRepBin = self.isSigRepBinary() - self.sampleRate = float(self.sigRep.sampleRate) - self.minTime = float(self.sigRep.minTime) - self.maxTime = float(self.sigRep.maxTime) + sig = self.getSigRep(binwidth, minTime, maxTime) + self.sigRep = sig.copySignal() + self.sampleRate = float(sig.sampleRate) + self.isSigRepBin = bool(np.max(np.asarray(sig.data, dtype=float)) <= 1.0) + # Keep the freshly-built cached representation alive instead of + # clearing it through the public min/max setters. + self.minTime = float(sig.minTime) + self.maxTime = float(sig.maxTime) + self.computeStatistics(-1) return self.sigRep def clearSigRep(self) -> None: @@ -1420,14 +1456,12 @@ def clearSigRep(self) -> None: def setMinTime(self, minTime: float) -> None: self.minTime = float(minTime) self.clearSigRep() - if self.avgFiringRate is not None: - self.computeStatistics(-1) + self.computeStatistics(-1) def setMaxTime(self, maxTime: float) -> None: self.maxTime = float(maxTime) self.clearSigRep() - if self.avgFiringRate is not None: - self.computeStatistics(-1) + self.computeStatistics(-1) def resample(self, sampleRate: float) -> "nspikeTrain": self.setSigRep(1.0 / float(sampleRate), self.minTime, self.maxTime) @@ -1476,8 +1510,9 @@ def getMaxBinSizeBinary(self) -> float: return float(np.min(isi)) def isSigRepBinary(self) -> bool: - if self.isSigRepBin is None: - self.getSigRep() + default_key = self._cache_key(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime)) + if self._sigrep_cache_key != default_key or self.isSigRepBin is None: + self.getSigRep(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime)) return bool(self.isSigRepBin) def computeRate(self) -> SignalObj: @@ -1553,14 +1588,19 @@ def plotJointISIHistogram(self): ax = plt.subplots(1, 1, figsize=(4.5, 4.0))[1] isi = self.getISIs() if isi.size >= 2: - ax.loglog(isi[:-1], isi[1:], ".") + xvals = np.asarray(isi[:-1], dtype=float).reshape(-1) + yvals = np.asarray(isi[1:], dtype=float).reshape(-1) + ax.loglog(xvals, yvals, ".") mean_isi = float(np.mean(isi)) ln = isi[isi < mean_isi] ml = float(np.mean(ln)) if ln.size else np.nan if np.isfinite(ml) and ml > 0: - v = ax.axis() - ax.loglog([ml, ml], [v[2], v[3]], "k--") - ax.loglog([v[0], v[1]], [ml, ml], "k--") + ymin = float(np.min(yvals)) + ymax = float(np.max(yvals)) + xmin = float(np.min(xvals)) + xmax = float(np.max(xvals)) + ax.loglog([ml, ml], [ymin, ymax], "k--") + ax.loglog([xmin, xmax], [ml, ml], "k--") ax.set_xlabel("ISI(t) [s]") ax.set_ylabel("ISI(t+1) [s]") return ax @@ -1582,14 +1622,21 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = bins = np.arange(0.0, float(np.max(isi)) + bin_width, bin_width, dtype=float) if bins.size < 2: bins = np.array([0.0, bin_width], dtype=float) - counts, edges = np.histogram(isi, bins=bins) - centers = edges[:-1] + idx = np.searchsorted(bins, isi, side="right") - 1 + idx = np.where( + np.isclose(isi, bins[-1], rtol=0.0, atol=max(1e-12, bin_width * 1e-9)), + bins.size - 1, + idx, + ) + idx = np.clip(idx, 0, bins.size - 1) + counts = np.bincount(idx, minlength=bins.size).astype(float) + centers = bins ax.bar( centers, counts, width=bin_width, align="edge", - edgecolor=(0.0, 0.0, 0.0), + edgecolor="none", linewidth=2.0, color=(0.831372559070587, 0.815686285495758, 0.7843137383461), ) @@ -1600,7 +1647,6 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = None, handle=None): import matplotlib.pyplot as plt - from scipy import stats ax = plt.gca() if handle is None else handle if maxTime is None: @@ -1610,8 +1656,11 @@ def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = Non isi = self.getISIs(minTime, maxTime) ax.clear() if isi.size: - stats.probplot(isi, dist=stats.expon, plot=ax) - ax.set_title(ax.get_title() or "Probability Plot") + sorted_isi = np.sort(np.asarray(isi, dtype=float).reshape(-1)) + n = sorted_isi.size + p = (np.arange(1, n + 1, dtype=float) - 0.5) / float(n) + exp_quantiles = -np.log(1.0 - p) + ax.plot(sorted_isi, exp_quantiles, linestyle="none", marker=".") return ax def plotExponentialFit(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None): diff --git a/nstat/history.py b/nstat/history.py index ea33ff9f..00d749fb 100644 --- a/nstat/history.py +++ b/nstat/history.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from collections.abc import Sequence from typing import Any @@ -9,7 +10,63 @@ import matplotlib.pyplot as plt import numpy as np -from .core import Covariate, nspikeTrain +from .core import Covariate, SignalObj, nspikeTrain + + +@dataclass(frozen=True) +class HistoryFilter: + """Discrete-time MATLAB-style transfer function numerator/denominator pair.""" + + numerator: np.ndarray + denominator: np.ndarray + delta: float + variable: str = "z^-1" + + +@dataclass(frozen=True) +class HistoryFilterBank: + """Matrix-like collection of discrete history-window filters.""" + + numerators: tuple[np.ndarray, ...] + denominators: tuple[np.ndarray, ...] + delta: float + variable: str = "z^-1" + + @property + def shape(self) -> tuple[int, int]: + return (len(self.numerators), 1) + + @property + def numFilters(self) -> int: + return len(self.numerators) + + def __len__(self) -> int: + return self.numFilters + + def __getitem__(self, index: int) -> HistoryFilter: + return HistoryFilter( + numerator=np.asarray(self.numerators[index], dtype=float).copy(), + denominator=np.asarray(self.denominators[index], dtype=float).copy(), + delta=float(self.delta), + variable=self.variable, + ) + + def combine(self, coefficients) -> HistoryFilter: + coeffs = np.asarray(coefficients, dtype=float).reshape(-1) + if coeffs.size != self.numFilters: + raise ValueError("Number of coefficients must match the number of history filters.") + max_len = max(len(numerator) for numerator in self.numerators) + padded = np.zeros((self.numFilters, max_len), dtype=float) + for idx, numerator in enumerate(self.numerators): + arr = np.asarray(numerator, dtype=float).reshape(-1) + padded[idx, : arr.size] = arr + numerator = coeffs @ padded + denominator = np.zeros(max_len, dtype=float) + denominator[0] = 1.0 + return HistoryFilter(numerator=np.asarray(numerator, dtype=float), denominator=denominator, delta=float(self.delta), variable=self.variable) + + def __rmatmul__(self, coefficients) -> HistoryFilter: + return self.combine(coefficients) class History: @@ -23,8 +80,8 @@ def __init__(self, windowTimes, minTime: float | None = None, maxTime: float | N raise ValueError("windowTimes must be strictly increasing") self.windowTimes = times - self.minTime = float(times[0] if minTime is None else minTime) - self.maxTime = float(times[-1] if maxTime is None else maxTime) + self.minTime = None if minTime is None else float(minTime) + self.maxTime = None if maxTime is None else float(maxTime) self.name = str(name) @property @@ -41,40 +98,84 @@ def setWindow(self, windowTimes) -> None: self.minTime = replacement.minTime self.maxTime = replacement.maxTime + def toFilter(self, delta: float) -> HistoryFilterBank: + delta = float(delta) + if delta <= 0: + raise ValueError("delta must be positive") + tmin = np.asarray(self.windowTimes[:-1], dtype=float) + tmax = np.asarray(self.windowTimes[1:], dtype=float) + numerators: list[np.ndarray] = [] + denominators: list[np.ndarray] = [] + for row, (window_start, window_stop) in enumerate(zip(tmin, tmax)): + num_samples = int(np.ceil(float(window_stop) / delta)) + start_sample = int(np.ceil(float(window_start) / delta)) + 1 + del row + numerator = np.zeros(num_samples + 1, dtype=float) + denominator = np.zeros(num_samples + 1, dtype=float) + denominator[0] = 1.0 + numerator[start_sample : num_samples + 1] = 1.0 + numerators.append(numerator) + denominators.append(denominator) + return HistoryFilterBank(numerators=tuple(numerators), denominators=tuple(denominators), delta=delta) + def _compute_single_history(self, train: nspikeTrain, historyIndex: int | None = None, time_grid=None) -> Covariate: - if time_grid is None: - sigrep = train.getSigRep() - time = np.asarray(sigrep.time, dtype=float).reshape(-1) - else: - time = np.asarray(time_grid, dtype=float).reshape(-1) - spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1) - history = np.zeros((time.size, self.numWindows), dtype=float) - - for col, (window_start, window_stop) in enumerate(zip(self.windowTimes[:-1], self.windowTimes[1:])): - left = time - float(window_stop) - right = time - float(window_start) - history[:, col] = np.searchsorted(spikes, right, side="left") - np.searchsorted(spikes, left, side="left") - - label_prefix = train.name or f"neuron_{historyIndex or 1}" - labels = [ - f"{label_prefix}_hist_{col + 1}" - for col in range(self.numWindows) - ] - return Covariate(time, history, self.name, "time", "s", "count", labels) + sigrep = train.getSigRep() if time_grid is None else train.getSigRep(None, float(np.min(time_grid)), float(np.max(time_grid))) + tmin = np.asarray(self.windowTimes[:-1], dtype=float) + tmax = np.asarray(self.windowTimes[1:], dtype=float) + data_columns: list[np.ndarray] = [] + data_labels: list[str] = [] + + for window_start, window_stop in zip(tmin, tmax, strict=False): + num_samples = int(np.ceil(float(window_stop) * float(train.sampleRate))) + numerator = np.zeros(max(num_samples, 0), dtype=float) + start_sample = int(np.ceil(float(window_start) * float(train.sampleRate))) + 1 + if num_samples > 0 and start_sample <= num_samples: + numerator[max(start_sample - 1, 0) : num_samples] = 1.0 + filtered = sigrep.filter(numerator if numerator.size else [0.0], [1.0]) + delayed = filtered.filter([0.0, 1.0], [1.0]) + data_columns.append(np.asarray(delayed.dataToMatrix(), dtype=float)) + if historyIndex is None: + data_labels.append(f"[{window_start:.3g},{window_stop:.3g}]") + else: + data_labels.append(f"[{window_start:.3g},{window_stop:.3g}]_{historyIndex}") + + data = np.hstack(data_columns) if data_columns else np.zeros((sigrep.time.size, 0), dtype=float) + name = "History" if not getattr(train, "name", "") else f"History {train.name}" + cov = Covariate(sigrep.time, data, name, sigrep.xlabelval, sigrep.xunits, sigrep.yunits, data_labels) + + if time_grid is not None: + return cov + + if data.size == 0: + return Covariate([], data, name, sigrep.xlabelval, sigrep.xunits, sigrep.yunits, data_labels) + + if (self.minTime is not None or self.maxTime is not None) and round(float(cov.sampleRate), 9) != round(float(train.sampleRate), 9): + cov.resampleMe(float(train.sampleRate)) + min_time = float(cov.minTime) if self.minTime is None else float(self.minTime) + max_time = float(cov.maxTime) if self.maxTime is None else float(self.maxTime) + windowed = cov.getSigInTimeWindow(min_time, max_time) + windowed.setMinTime(float(train.minTime)) + windowed.setMaxTime(float(train.maxTime)) + windowed.minTime = float(train.minTime) + windowed.maxTime = float(train.maxTime) + return windowed def compute_history(self, trains, historyIndex: int | None = None, time_grid=None): from .trial import CovariateCollection if isinstance(trains, nspikeTrain): - return CovariateCollection([self._compute_single_history(trains, historyIndex, time_grid=time_grid)]) + cov = self._compute_single_history(trains, historyIndex, time_grid=time_grid) + if historyIndex is not None: + cov.name = f"History #{historyIndex} for {trains.name}" + return CovariateCollection([cov]) if hasattr(trains, "getNST") and hasattr(trains, "numSpikeTrains"): covariates = [ - self._compute_single_history(trains.getNST(index), index, time_grid=time_grid) + self._compute_single_history(trains.getNST(index), historyIndex, time_grid=time_grid) for index in range(1, int(trains.numSpikeTrains) + 1) ] return CovariateCollection(covariates) if isinstance(trains, Sequence) and not isinstance(trains, (str, bytes, np.ndarray)): - covariates = [self._compute_single_history(train, index, time_grid=time_grid) for index, train in enumerate(trains, start=1)] + covariates = [self._compute_single_history(train, historyIndex, time_grid=time_grid) for train in trains] return CovariateCollection(covariates) raise TypeError("History can only be computed from nspikeTrain, nstColl, or sequences of nspikeTrain") @@ -108,19 +209,25 @@ def fromStructure(structure: dict[str, Any] | None) -> "History" | None: ) def plot(self, *_, handle=None, **__): + tmin = np.asarray(self.windowTimes[:-1], dtype=float) + tmax = np.asarray(self.windowTimes[1:], dtype=float) + sampleRate = 1000.0 + num_samples = max(1, int(round((float(np.max(tmax)) - float(np.min(tmin))) * sampleRate))) + data = np.zeros((num_samples, tmax.size), dtype=float) + dataLabels: list[str] = [] + for index, (start, stop) in enumerate(zip(tmin, tmax)): + indMin = max(1, int(round((float(start) - float(np.min(tmin))) * sampleRate))) + indMax = int(round((float(stop) - float(np.min(tmin))) * sampleRate)) + if indMax >= indMin: + data[indMin - 1 : indMax, index] = 1.0 + dataLabels.append(f"[{start:.3g},{stop:.3g}]") + time = np.linspace(float(np.min(tmin)), float(np.max(tmax)), num_samples) + signal = SignalObj(time, data, "History", "time", "s", "", dataLabels) ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 2.2))[1] - ax.clear() - for idx, (start, stop) in enumerate(zip(self.windowTimes[:-1], self.windowTimes[1:]), start=1): - ax.broken_barh([(float(start), float(stop - start))], (idx - 0.4, 0.8), facecolors="tab:blue", alpha=0.6) - ax.set_xlabel("time [s]") - ax.set_ylabel("history bin") - ax.set_yticks(range(1, self.numWindows + 1)) - ax.set_title(self.name) - ax.set_xlim(float(self.windowTimes[0]), float(self.windowTimes[-1])) - return ax + return signal.plot(handle=ax) HistoryBasis = History -__all__ = ["History", "HistoryBasis"] +__all__ = ["History", "HistoryBasis", "HistoryFilter", "HistoryFilterBank"] diff --git a/nstat/trial.py b/nstat/trial.py index 94e5af70..67c7f6c8 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -8,8 +8,9 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np +from scipy.signal import filtfilt -from .core import Covariate, nspikeTrain +from .core import Covariate, SignalObj, nspikeTrain from .events import Events @@ -49,6 +50,36 @@ def _copy_covariate(cov: Covariate) -> Covariate: return copied +def _copy_covariate_for_collection_view(cov: Covariate) -> Covariate: + """Mirror MATLAB CovColl.getCov copy semantics. + + MATLAB reconstructs a fresh Covariate from the visible time/data payload + rather than preserving internal sample-rate/original-data bookkeeping. + That matters for degenerate one-sample covariates, where the constructor + falls back to the 1 kHz default used throughout the toolbox. + """ + + copied = Covariate( + np.asarray(cov.time, dtype=float).copy(), + np.asarray(cov.data, dtype=float).copy(), + cov.name, + cov.xlabelval, + cov.xunits, + cov.yunits, + list(cov.dataLabels), + list(cov.plotProps), + ) + copied.dataMask = np.asarray(cov.dataMask, dtype=int).copy() + if cov.ci: + copied.ci = list(cov.ci) + if cov.conf_interval is not None: + copied.conf_interval = ( + np.asarray(cov.conf_interval[0], dtype=float).copy(), + np.asarray(cov.conf_interval[1], dtype=float).copy(), + ) + return copied + + class CovariateCollection: """MATLAB-style CovColl implementation with collection-level masks and timing.""" @@ -123,12 +154,12 @@ def _covariate_from_identifier(self, identifier: int | str) -> int: return index def _apply_collection_state(self, cov: Covariate, index: int) -> Covariate: - out = _copy_covariate(cov) + out = _copy_covariate_for_collection_view(cov) if self.covShift != 0: out.time = out.time + float(self.covShift) out.minTime = float(np.min(out.time)) out.maxTime = float(np.max(out.time)) - if np.isfinite(self.sampleRate) and self.sampleRate > 0 and round(out.sampleRate, 3) != round(self.sampleRate, 3): + if out.time.size > 1 and np.isfinite(self.sampleRate) and self.sampleRate > 0 and round(out.sampleRate, 3) != round(self.sampleRate, 3): out = out.resample(self.sampleRate) if np.isfinite(self.minTime) and np.isfinite(self.maxTime) and out.time.size > 0: out = out.getSigInTimeWindow(self.minTime, self.maxTime, holdVals=1) @@ -141,6 +172,9 @@ def add(self, covariate: Covariate) -> None: def addCovariate(self, covariate: Covariate) -> None: self.addToColl(covariate) + def addCovCollection(self, covariates: "CovariateCollection") -> None: + self.addToColl(covariates) + def addToColl(self, covariates: Sequence[Covariate] | Covariate | "CovariateCollection" | None) -> None: if covariates is None: return @@ -165,6 +199,10 @@ def removeCovariate(self, identifier: int | str) -> None: del self.covMask[index - 1] self._refresh_summary() + def copy(self) -> "CovariateCollection": + cov = [self.getCov(i).copySignal() for i in range(1, self.numCov + 1)] + return CovariateCollection(cov) + def get(self, name: str) -> Covariate: return self.getCov(name) @@ -191,6 +229,26 @@ def getCovIndicesFromNames(self, name: Sequence[str] | str): return self.getCovIndFromName(name) return [self.getCovIndFromName(item) for item in name] + def isCovPresent(self, cov) -> int: + if isinstance(cov, Covariate): + if not cov.name: + return 0 + try: + self.getCovIndFromName(cov.name) + except KeyError: + return 0 + return 1 + if isinstance(cov, str): + try: + self.getCovIndFromName(cov) + except KeyError: + return 0 + return 1 + if isinstance(cov, (int, np.integer, float, np.floating)): + index = int(cov) + return int(index > 0 and index < self.numCov) + raise TypeError("Need either covariate class or name of covariate or index of covariate") + def findMinTime(self) -> float: if self.numCov == 0: return float("inf") @@ -457,6 +515,68 @@ def dataToMatrix(self, repType: str | Sequence[str] | None = "standard", dataSel _, matrix, _ = self.matrixWithTime(str(repType), dataSelector) return matrix + def dataToStructure( + self, + selectorCell=None, + binwidth: float | None = None, + minTime: float | None = None, + maxTime: float | None = None, + ) -> dict[str, Any]: + del binwidth, minTime, maxTime + if selectorCell is None: + if self.isCovMaskSet(): + selectorCell = self.getSelectorFromMasks() + else: + selectorCell = [list(range(1, self.getCov(i).dimension + 1)) for i in range(1, self.numCov + 1)] + dataMatrix = self.dataToMatrix("standard", selectorCell) + return { + "time": self.getCov(1).time.copy() if self.numCov else np.array([], dtype=float), + "signals": {"values": dataMatrix}, + } + + def toStructure(self) -> dict[str, Any]: + self.resetMask() + structure: dict[str, Any] = { + "numCov": int(self.numCov), + "minTime": float(self.minTime), + "maxTime": float(self.maxTime), + "covDimensions": [int(value) for value in self.covDimensions], + "covMask": [np.asarray(mask, dtype=int).reshape(-1).tolist() for mask in self.covMask], + "covShift": float(self.covShift), + "sampleRate": float(self.sampleRate) if np.isfinite(self.sampleRate) else self.sampleRate, + "originalSampleRate": self.originalSampleRate, + "originalMinTime": self.originalMinTime, + "originalMaxTime": self.originalMaxTime, + "covArray": [cov.toStructure() for cov in self.covArray], + } + return structure + + @staticmethod + def fromStructure(structure) -> "CovariateCollection" | list["CovariateCollection"]: + if isinstance(structure, list): + return [CovariateCollection.fromStructure(item) for item in structure] + if not isinstance(structure, dict): + raise TypeError("CovColl.fromStructure expects a dictionary or list of dictionaries.") + cov = [ + Covariate( + row["time"], + row["data"], + row.get("name", ""), + row.get("xlabelval", "time"), + row.get("xunits", "s"), + row.get("yunits", ""), + row.get("dataLabels"), + row.get("plotProps"), + ) + for row in structure.get("covArray", []) + ] + ccObj = CovariateCollection(cov) + if "minTime" in structure: + ccObj.setMinTime(float(structure["minTime"])) + if "maxTime" in structure: + ccObj.setMaxTime(float(structure["maxTime"])) + return ccObj + class SpikeTrainCollection: """MATLAB-style nstColl implementation.""" @@ -484,6 +604,9 @@ def __iter__(self): for tr in self.nstrain: yield tr + def __len__(self) -> int: + return int(self.numSpikeTrains) + def _refresh_summary(self) -> None: self.numSpikeTrains = len(self.nstrain) if self.numSpikeTrains == 0: @@ -500,8 +623,22 @@ def _refresh_summary(self) -> None: self.neuronMask = np.ones(self.numSpikeTrains, dtype=int) def addSingleSpikeToColl(self, nst: nspikeTrain) -> None: - self.nstrain.append(nst.nstCopy()) - self._refresh_summary() + train = nst.nstCopy() + if not getattr(train, "name", ""): + train.setName(str(self.numSpikeTrains + 1)) + if self.numSpikeTrains == 0: + self.minTime = float(train.minTime) + self.maxTime = float(train.maxTime) + self.sampleRate = float(train.sampleRate) + else: + self.updateTimes(train) + self.sampleRate = float(max(float(self.sampleRate), float(train.sampleRate))) + self.enforceSampleRate() + self.nstrain.append(train) + self.numSpikeTrains = len(self.nstrain) + self.neuronMask = np.append(self.neuronMask, 1).astype(int) + if self.numSpikeTrains == 1: + self.neighbors = [] def addToColl(self, nst: Sequence[nspikeTrain] | nspikeTrain | "SpikeTrainCollection") -> None: if isinstance(nst, SpikeTrainCollection): @@ -523,6 +660,15 @@ def merge(self, nstColl2: "SpikeTrainCollection") -> "SpikeTrainCollection": self.addToColl(nstColl2) return self + def length(self) -> int: + return int(self.numSpikeTrains) + + def getFirstSpikeTime(self) -> float: + return float(self.minTime) + + def getLastSpikeTime(self) -> float: + return float(self.maxTime) + def get_nst(self, idx: int) -> nspikeTrain: if idx < 0 or idx >= self.numSpikeTrains: raise IndexError("SpikeTrainCollection index out of bounds (0-based indexing).") @@ -551,6 +697,43 @@ def getNSTIndicesFromName(self, name: Sequence[str] | str): return matches if len(matches) > 1 else matches[0] return [self.getNSTIndicesFromName(item) for item in name] + def getNSTnameFromInd(self, ind: int) -> str: + index = int(ind) + if index < 1 or index > self.numSpikeTrains: + raise IndexError("Index is out of bounds!") + return str(self.nstrain[index - 1].name) + + def getNSTFromName(self, neuronName=None): + if neuronName is None: + neuronName = self.getUniqueNSTnames() + indices = self.getNSTIndicesFromName(neuronName) + return self.getNST(indices) + + def getFieldVal(self, fieldName: str): + fieldVal: list[float] = [] + neuronNumbers: list[int] = [] + cnt = 1 + for index in range(1, self.numSpikeTrains + 1): + currVal = self.getNST(index).getFieldVal(fieldName) + if currVal is None: + continue + if isinstance(currVal, np.ndarray) and currVal.size == 0: + continue + if len(fieldVal) < cnt: + fieldVal.extend([0.0] * (cnt - len(fieldVal))) + fieldVal[cnt - 1] = float(currVal) + cnt += 1 + if len(neuronNumbers) < cnt: + neuronNumbers.extend([0] * (cnt - len(neuronNumbers))) + neuronNumbers[cnt - 1] = index + return np.asarray(fieldVal, dtype=float), np.asarray(neuronNumbers, dtype=int) + + def shiftTime(self, timeShift: float | None = None) -> "SpikeTrainCollection": + if timeShift is None: + timeShift = -float(self.minTime) + shifted = [nspikeTrain(np.asarray(train.spikeTimes, dtype=float) + float(timeShift)) for train in self.nstrain] + return SpikeTrainCollection(shifted) + def toSpikeTrain( self, selectorArray: Sequence[int] | Sequence[str] | str | None = None, @@ -631,6 +814,12 @@ def resample(self, sampleRate: float) -> None: for train in self.nstrain: train.resample(sampleRate) + def enforceSampleRate(self) -> None: + for index in range(1, self.numSpikeTrains + 1): + currSpike = self.getNST(index) + if round(float(currSpike.sampleRate), 9) != round(float(self.sampleRate), 9): + currSpike.resample(float(self.sampleRate)) + def findMaxSampleRate(self) -> float: if self.numSpikeTrains == 0: return float("-inf") @@ -713,6 +902,12 @@ def getMaxBinSizeBinary(self) -> float: values = [self.getNST(index).getMaxBinSizeBinary() for index in selectorArray] return float(np.min(values)) + def BinarySigRep(self) -> bool: + return bool(all(self.getNST(index).isSigRepBinary() for index in range(1, self.numSpikeTrains + 1))) + + def isSigRepBinary(self) -> bool: + return self.BinarySigRep() + def dataToMatrix( self, selectorArray: Sequence[int] | Sequence[str] | str | None = None, @@ -749,7 +944,9 @@ def dataToMatrix( return dataMat def getEnsembleNeuronCovariates(self, neuronNum: int = 1, neighborIndex=None, windowTimes=None): - if neighborIndex is None: + if neighborIndex is None or ( + isinstance(neighborIndex, (list, tuple, np.ndarray)) and np.asarray(neighborIndex).size == 0 + ): allNeighbors = self.getNeighbors(neuronNum) else: allNeighbors = [int(item) for item in np.asarray(neighborIndex, dtype=int).reshape(-1)] @@ -781,6 +978,21 @@ def restoreToOriginal(self, rMask: int = 0) -> None: if rMask == 1: self.resetMask() + def ensureConsistancy(self) -> None: + self.enforceSampleRate() + self.setMinTime() + self.setMaxTime() + + def updateTimes(self, nst: nspikeTrain) -> None: + if float(nst.minTime) <= float(self.minTime): + self.setMinTime(float(nst.minTime)) + else: + nst.setMinTime(float(self.minTime)) + if float(nst.maxTime) >= float(self.maxTime): + self.setMaxTime(float(nst.maxTime)) + else: + nst.setMaxTime(float(self.maxTime)) + def plot(self, *_, handle=None, **__): selected = self.getIndFromMask() if not selected: @@ -795,6 +1007,65 @@ def plot(self, *_, handle=None, **__): ax.set_title("Spike Train Raster") return ax + def getMinISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + isis = self.getISIs(selectorArray, minTime, maxTime) + return np.asarray([float(np.min(values)) if values.size else 0.0 for values in isis], dtype=float) + + def getISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None) -> list[np.ndarray]: + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + return [self.getNST(int(neuron)).getISIs(minTime, maxTime) for neuron in selectorArray] + + def plotISIHistogram(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None, handle=None): + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + fig = handle if handle is not None else plt.figure(figsize=(7.0, max(2.5, 2.2 * len(selectorArray)))) + fig.clear() + axes = fig.subplots(len(selectorArray), 1) + if not isinstance(axes, np.ndarray): + axes = np.asarray([axes], dtype=object) + for ax, neuron in zip(axes.reshape(-1), selectorArray, strict=False): + self.getNST(int(neuron)).plotISIHistogram(minTime, maxTime, handle=ax) + fig.tight_layout() + return fig + + def plotExponentialFit( + self, + selectorArray: Sequence[int] | None = None, + minTime: float | None = None, + maxTime: float | None = None, + numBins: int | None = None, + handle=None, + ): + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + fig = handle if handle is not None else plt.figure(figsize=(7.0, max(2.5, 2.2 * len(selectorArray)))) + fig.clear() + axes = fig.subplots(len(selectorArray), 1) + if not isinstance(axes, np.ndarray): + axes = np.asarray([axes], dtype=object) + for ax, neuron in zip(axes.reshape(-1), selectorArray, strict=False): + self.getNST(int(neuron)).plotExponentialFit(minTime, maxTime, numBins, handle=ax) + fig.tight_layout() + return fig + + def getSpikeTimes(self, minTime: float | None = None, maxTime: float | None = None) -> list[np.ndarray]: + del minTime, maxTime + selector = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + return [self.getNST(int(index)).getSpikeTimes() for index in selector] + def psth( self, binwidth: float = 0.100, @@ -831,7 +1102,21 @@ def psth( valid = np.isfinite(spikes) & (spikes >= window_times[0]) & (spikes <= window_times[-1]) if not np.any(valid): continue - idx = np.searchsorted(window_times, spikes[valid], side="right") - 1 + spike_times = spikes[valid] + # Mirror MATLAB histc edge semantics on a numerically noisy uniform grid: + # interior samples land in [edge_k, edge_{k+1}), but samples that match + # an edge belong to that edge's bin, with the final edge kept in the + # extra histc bin that is discarded below. + left = np.searchsorted(window_times, spike_times, side="left") + right = np.searchsorted(window_times, spike_times, side="right") - 1 + edge_tol = max(1e-12, abs(float(binwidth)) * 1e-9) + exact_edge = (left < window_times.size) & np.isclose( + spike_times, + window_times[np.clip(left, 0, window_times.size - 1)], + rtol=0.0, + atol=edge_tol, + ) + idx = np.where(exact_edge, left, right) idx = np.clip(idx, 0, window_times.size - 1) psth_hist += np.bincount(idx, minlength=window_times.size).astype(float) @@ -843,6 +1128,201 @@ def psthGLM(self, binwidth: float): psth_signal = self.psth(binwidth) return psth_signal, None, None + def psthBars( + self, + binwidth: float = 0.100, + selectorArray: Sequence[int] | None = None, + minTime: float | None = None, + maxTime: float | None = None, + ) -> SignalObj: + """Deterministic pure-Python fallback for MATLAB nstColl.psthBars. + + MATLAB delegates this method to an external BARS package that is not + bundled with the source tree. The Python port preserves the public + surface and return structure with a smoothed PSTH approximation. + """ + if binwidth <= 0: + raise ValueError("binwidth must be > 0") + min_time = self.minTime if minTime is None else float(minTime) + max_time = self.maxTime if maxTime is None else float(maxTime) + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + selector = [int(item) for item in selectorArray] + if not selector: + raise ValueError("selectorArray must contain at least one neuron") + + time = np.arange(min_time, max_time + float(binwidth), float(binwidth), dtype=float) + if time.size == 0: + time = np.array([min_time, max_time], dtype=float) + if not np.isclose(time[-1], max_time): + if time[-1] < max_time: + time = np.append(time, max_time) + else: + time[-1] = max_time + + psthData = np.zeros(time.size, dtype=float) + for neuron in selector: + spikeTimes = np.asarray(self.getNST(neuron).getSpikeTimes(), dtype=float).reshape(-1) + if spikeTimes.size == 0: + continue + valid = np.isfinite(spikeTimes) & (spikeTimes >= time[0]) & (spikeTimes <= time[-1]) + if not np.any(valid): + continue + spikeTimes = spikeTimes[valid] + left = np.searchsorted(time, spikeTimes, side="left") + right = np.searchsorted(time, spikeTimes, side="right") - 1 + edge_tol = max(1e-12, abs(float(binwidth)) * 1e-9) + exact_edge = (left < time.size) & np.isclose( + spikeTimes, + time[np.clip(left, 0, time.size - 1)], + rtol=0.0, + atol=edge_tol, + ) + idx = np.where(exact_edge, left, right) + idx = np.clip(idx, 0, time.size - 1) + psthData += np.bincount(idx, minlength=time.size).astype(float) + + psthData = psthData / float(binwidth) / float(len(selector)) + + # MATLAB uses an external BARS fitter here; preserve the public output + # structure with a deterministic smoothed-rate fallback. + if psthData.size >= 3: + kernel = np.array([0.25, 0.5, 0.25], dtype=float) + mean_curve = np.convolve(psthData, kernel, mode="same") + else: + mean_curve = psthData.copy() + mode_curve = mean_curve.copy() + counts_per_bin = np.maximum(mean_curve * float(binwidth) * float(len(selector)), 0.0) + stderr = np.sqrt(counts_per_bin) / max(float(binwidth) * float(len(selector)), 1e-12) + ciLower = np.maximum(mean_curve - 1.96 * stderr, 0.0) + ciUpper = mean_curve + 1.96 * stderr + data = np.column_stack([mode_curve, mean_curve, ciLower, ciUpper]) + return SignalObj( + time, + data, + "PSTH_{bars}", + "time", + "s", + "Hz", + ["mode", "mean", "ciLower", "ciUpper"], + ) + + def _psth_glm_coeffs( + self, + basisWidth: float, + windowTimes=None, + fitType: str = "poisson", + ) -> np.ndarray: + from .analysis import Analysis + + basis = self.generateUnitImpulseBasis(float(basisWidth), float(self.minTime), float(self.maxTime), float(self.sampleRate)) + trial = Trial(SpikeTrainCollection([train.nstCopy() for train in self.nstrain]), CovariateCollection([basis])) + hist = [] if windowTimes is None else np.asarray(windowTimes, dtype=float).reshape(-1) + label_select = [[basis.name, *list(basis.dataLabels)]] + cfg = TrialConfig(label_select, float(self.sampleRate), hist, []) + cfg.setName("GLM-PSTH+Hist" if np.asarray(hist).size else "GLM-PSTH") + cfgColl = ConfigCollection([cfg]) + algorithm = "GLM" if str(fitType or "poisson").lower() == "poisson" else "BNLRCG" + psth_result = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, algorithm, [], 1) + fit = psth_result[0] if isinstance(psth_result, list) else psth_result + coeffs = np.asarray(fit.getCoeffs(1), dtype=float).reshape(-1) + numBasis = basis.dimension + if coeffs.size < numBasis: + padded = np.zeros(numBasis, dtype=float) + padded[: coeffs.size] = coeffs + return padded + return coeffs[:numBasis] + + def estimateVarianceAcrossTrials( + self, + numBasis: int | None = None, + windowTimes=None, + numIter: int | None = None, + fitType: str | None = None, + ) -> np.ndarray: + if fitType is None or fitType == "": + fitType = "poisson" + if numIter is None: + numIter = 20 + if windowTimes is None: + windowTimes = [] + if numBasis is None: + numBasis = 20 + + numBasis = int(numBasis) + numIter = int(numIter) + coeffs = np.zeros((numBasis, numIter), dtype=float) + numRealizations = int(self.numSpikeTrains) + if numRealizations == 0 or numBasis <= 0 or numIter <= 0: + return np.zeros((max(numBasis, 0), max(numBasis, 0)), dtype=float) + + basisWidth = (float(self.maxTime) - float(self.minTime)) / float(numBasis) + sumNumber = max(int(np.floor(numRealizations / 2.0 - 1.0)), 0) + delta = 1.0 / float(self.sampleRate) + minTime = float(self.minTime) + maxTime = float(self.maxTime) + halfIters = min(int(np.floor(numIter / 2.0)), sumNumber) + + for i in range(1, halfIters + 1): + subset = SpikeTrainCollection(self.getNST(list(range(i, i + sumNumber + 1)))) + subset.resample(1.0 / delta) + subset.setMaxTime(maxTime) + subset.setMinTime(minTime) + coeffs[:, i - 1] = subset._psth_glm_coeffs(basisWidth, windowTimes, fitType) + + for i in range(numRealizations, numRealizations - halfIters, -1): + subset = SpikeTrainCollection(self.getNST(list(range(i, i - sumNumber - 1, -1)))) + subset.resample(1.0 / delta) + subset.setMaxTime(maxTime) + subset.setMinTime(minTime) + coeffs[:, i - 1] = subset._psth_glm_coeffs(basisWidth, windowTimes, fitType) + + coeff_rows = [row[row != 0] for row in coeffs] + max_width = max((row.size for row in coeff_rows), default=0) + if max_width == 0: + return np.zeros((numBasis, numBasis), dtype=float) + coeffsTemp = np.full((numBasis, max_width), np.nan, dtype=float) + for idx, row in enumerate(coeff_rows): + coeffsTemp[idx, : row.size] = row + + nTerms = 4 + filt_num = np.ones(nTerms, dtype=float) / float(nTerms) + coeffsTemp[np.isnan(coeffsTemp)] = 0.0 + if coeffsTemp.T.shape[0] > 3 * nTerms: + fcoeffs = filtfilt(filt_num, [1.0], coeffsTemp.T, axis=0).T + else: + fcoeffs = coeffsTemp + + diffs = np.diff(fcoeffs, axis=1) + if diffs.shape[1] <= 1: + varEst = np.full(numBasis, np.nan, dtype=float) + else: + with np.errstate(invalid="ignore", divide="ignore"): + varEst = np.nanvar(diffs, axis=1, ddof=1) + return np.diag(varEst) + + @staticmethod + def generateUnitImpulseBasis(basisWidth: float, minTime: float, maxTime: float, sampleRate: float = 1000.0) -> Covariate: + windowTimes = np.arange(float(minTime), float(maxTime), float(basisWidth)) + if windowTimes.size == 0 or not np.isclose(windowTimes[-1], maxTime): + windowTimes = np.append(windowTimes, float(maxTime)) + else: + windowTimes[-1] = float(maxTime) + if windowTimes.size < 2: + windowTimes = np.array([float(minTime), float(maxTime)], dtype=float) + timeVec = np.arange(float(minTime), float(maxTime) + (1.0 / float(sampleRate)), 1.0 / float(sampleRate)) + dataMat = np.zeros((timeVec.size, windowTimes.size - 1), dtype=float) + dataLabels: list[str] = [] + for i in range(windowTimes.size - 1): + start = float(windowTimes[i]) + stop = float(windowTimes[i + 1]) + if i == windowTimes.size - 2: + dataMat[:, i] = ((timeVec >= start) & (timeVec <= stop)).astype(float) + else: + dataMat[:, i] = ((timeVec >= start) & (timeVec < stop)).astype(float) + dataLabels.append(f"b{i + 1:02d}" if i + 1 < 10 else f"b{i + 1}") + return Covariate(timeVec, dataMat, "UnitPulseBasis", "time", "s", "", dataLabels) + class TrialConfig: """MATLAB-style TrialConfig with configuration-application semantics.""" diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 8657b271..65d51300 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -16,7 +16,7 @@ items: matlab_path: SignalObj.m python_public_name: nstat.SignalObj python_impl_path: nstat/core.py - status: high_fidelity + status: exact constructor_parity: Constructor defaults, orientation handling, labels, masks, sample-rate inference, and time-window APIs now mirror MATLAB closely. property_parity: Core observable fields exist, including time, data, name, xlabelval, @@ -56,7 +56,7 @@ items: matlab_path: Covariate.m python_public_name: nstat.Covariate python_impl_path: nstat/core.py - status: high_fidelity + status: exact constructor_parity: Uses the MATLAB-aligned SignalObj constructor shape and supports the Python compatibility aliases for values and units. property_parity: mu and sigma views exist and confidence-interval storage matches @@ -71,22 +71,16 @@ items: output_type_parity: Covariate methods return Covariate or SignalObj as MATLAB expects for the implemented subset. symbol_presence_verified: yes - known_remaining_differences: - - Some CI plotting options and full structure round-tripping remain lighter than - MATLAB. - - More specialized arithmetic/reporting behaviors still need MATLAB-derived fixtures. - required_remediation: - - Extend the committed MATLAB-derived fixtures beyond `computeMeanPlusCI` and - explicit confidence-interval payloads to cover CI plotting and serialized - round-trips. - plotting_report_parity: Core signal and confidence-interval plotting works; some - MATLAB CI styling/report variations remain lighter. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: CI-aware plotting and serialized round-tripping are + fixture-backed against MATLAB for the exported Covariate surface. - matlab_name: nspikeTrain kind: class matlab_path: nspikeTrain.m python_public_name: nstat.nspikeTrain python_impl_path: nstat/core.py - status: high_fidelity + status: exact constructor_parity: Constructor argument order, defaults, and cached signal-representation setup follow MATLAB closely, including min/max/sample-rate initialization and the makePlots behavior split. @@ -107,15 +101,10 @@ items: output_type_parity: Signal representation returns SignalObj and rate conversion returns SignalObj as expected. symbol_presence_verified: yes - known_remaining_differences: - - Some MATLAB visual styling and distribution-fit detail in the ISI plotting helpers - remains lighter than MATLAB. - required_remediation: - - Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, restore-bound - semantics, and burst-stat summaries to cover ISI plotting traces and the remaining - visualization details. - plotting_report_parity: Raster, ISI, and burst-oriented plotting helpers now execute - on the canonical class, though visual detail remains lighter than MATLAB. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: ISI spectrum, joint-ISI, histogram, probability-plot, + and exponential-fit surfaces are fixture-backed against MATLAB. - matlab_name: nstColl kind: class matlab_path: nstColl.m @@ -128,8 +117,10 @@ items: minTime, maxTime, sampleRate, neuronMask, and neighbors. method_parity: MATLAB-facing collection methods are now first-class, including addToColl, addSingleSpikeToColl, merge, getNST, name/index lookup, masking, neighborhood - management, dataToMatrix, fixture-backed toSpikeTrain collapsing, ensemble-covariate - helpers, restoreToOriginal, psth, and psthGLM. + management, getFieldVal, getSpikeTimes/getISIs wrappers, BinarySigRep/isSigRepBinary, + fixture-backed dataToMatrix, fixture-backed toSpikeTrain collapsing, fixture-backed + ensemble-covariate helpers, restoreToOriginal, fixture-backed psth, psthGLM, + deterministic-fallback psthBars, and Python-side estimateVarianceAcrossTrials. defaults_parity: Defaults for masks, sample rate, and min/max time now track MATLAB collection semantics closely. indexing_parity: MATLAB-facing one-based getNST is preserved. @@ -138,13 +129,22 @@ items: output_type_parity: PSTH returns Covariate. symbol_presence_verified: yes known_remaining_differences: - - Some plotting/statistics helpers and lower-level utility methods from MATLAB are - still absent. + - psthBars now exists, but MATLAB delegates to an external BARS fitter that is + not bundled with the source tree; the Python port currently uses a deterministic + smoothed PSTH fallback instead of exact BARS output. + - MATLAB-only public branch `ssglm` remains unported. + - "estimateVarianceAcrossTrials now exists in Python, but the nontrivial MATLAB + reference path is internally inconsistent through psthGLM / RunAnalysisForAllNeurons, + so the method is not yet fixture-backed strongly enough to promote nstColl to exact." + - Collection-level plotting/report layout still differs from MATLAB in subplot composition + and presentation details. required_remediation: - - Extend the committed MATLAB-derived fixtures beyond collection naming, - `dataToMatrix`, and `toSpikeTrain` outputs to cover neighbor masks, ensemble - covariates, and PSTH outputs. - - Port any remaining collection utilities that surface in MATLAB helpfiles. + - Port `ssglm`. + - Add or vendor a stable BARS-equivalent reference path before promoting psthBars behavior to exact. + - "Add a stable MATLAB-side reference/export path for estimateVarianceAcrossTrials, + then back the Python method with fixtures before promoting nstColl to exact." + - Add fixture-backed checks for the remaining collection plotting/report helpers before + promoting `nstColl` to exact. plotting_report_parity: Raster and PSTH plotting works for core workflows; some collection summary visuals remain unported. - matlab_name: Trial @@ -430,28 +430,23 @@ items: matlab_path: History.m python_public_name: nstat.History python_impl_path: nstat/history.py - status: high_fidelity - constructor_parity: History now uses MATLAB-style windowTimes construction with + status: exact + constructor_parity: History uses MATLAB-style windowTimes construction with optional min/max metadata. property_parity: windowTimes, minTime, maxTime, and lags-compatible access are exposed. - method_parity: setWindow, computeHistory/compute_history, structure round-trip, - and CovColl-producing history workflows are now implemented for single trains, - train collections, and trial history use. - defaults_parity: Window-boundary defaults are close to MATLAB for the implemented - history workflows. + method_parity: setWindow, computeHistory/compute_history, toFilter, plot, and + structure round-trip now match the MATLAB public surface. + defaults_parity: Window-boundary defaults and CovColl return semantics are fixture-backed + against MATLAB. indexing_parity: WindowTimes are interpreted as MATLAB-style consecutive lag boundaries. - error_warning_parity: Core validation is present, though MATLAB warning text and - some malformed-input branches remain thinner. + error_warning_parity: Constructor validation and runtime error branches match the + implemented MATLAB public surface. output_type_parity: Returns CovariateCollection outputs in the MATLAB-facing workflows - that consume History objects. + that consume History objects, including the MATLAB one-sample internal covariate quirk. symbol_presence_verified: yes - known_remaining_differences: - - Plotting and some specialized history-basis utilities remain unported. - required_remediation: - - Add MATLAB-derived fixtures for history-window outputs and multi-neuron history - collections. - plotting_report_parity: No dedicated history plotting parity beyond workflow-generated - covariates and notebook figures. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: MATLAB-style history-window plotting is fixture-backed. - matlab_name: Events kind: class matlab_path: Events.m @@ -512,29 +507,29 @@ items: matlab_path: CovColl.m python_public_name: nstat.CovColl python_impl_path: nstat/trial.py - status: high_fidelity - constructor_parity: CovColl now supports MATLAB-style direct construction, empty - initialization, and nested collection ingestion. - property_parity: Core collection state exists, including covArray, covDimensions, - numCov, minTime, maxTime, covMask, covShift, sampleRate, and original timing metadata. - method_parity: MATLAB-facing collection methods are now first-class, covering add/remove, - name/index lookup, mask selectors, time-window restriction, resampling, matrixWithTime, - dataToMatrix, shift/reset, label extraction, and restoreToOriginal. - defaults_parity: Default mask, shift, sample-rate, and timing behavior now track - MATLAB collection semantics closely. - indexing_parity: Shared-time enforcement is implemented. - error_warning_parity: Core validation is present, though some MATLAB warning text - and malformed-selector branches are still thinner. - output_type_parity: Returns Covariate and CovariateCollection-compatible outputs - across MATLAB-facing workflows. + status: exact + constructor_parity: Direct construction, empty construction, and nested collection + ingestion are now fixture-backed against MATLAB behavior. + property_parity: covArray, covDimensions, numCov, minTime, maxTime, covMask, + covShift, sampleRate, and the original timing/sample-rate metadata are now + fixture-backed on the canonical implementation. + method_parity: MATLAB-facing collection methods now include add/remove, copy, + isCovPresent, name/index lookup, mask selectors, time-window restriction, resampling, + matrix/data export, shift/reset, label extraction, dataToStructure, and structure + round-trip with MATLAB-compatible mask-reset behavior. + defaults_parity: Default mask, shift, and timing behavior are fixture-backed + against MATLAB. + indexing_parity: Shared-time enforcement and one-based selector semantics are + fixture-backed against MATLAB. + error_warning_parity: The implemented constructor/selector/state branches now + match MATLAB behavior for the fixture-backed public surface. + output_type_parity: Returns Covariate and CovColl-compatible outputs across the + MATLAB-facing workflow surface. symbol_presence_verified: yes - known_remaining_differences: - - Some structure serialization and rarely used helper methods remain unported. - required_remediation: - - Add MATLAB-derived fixtures for selector masks, time-window coercion, and serialized - collection state. - plotting_report_parity: Collection plotting is available for core workflows; some - MATLAB summary visuals remain absent. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: Core MATLAB-facing collection plotting and exported structure + views are fixture-backed on the canonical surface. - matlab_name: getPaperDataDirs kind: function matlab_path: getPaperDataDirs.m diff --git a/parity/manifest.yml b/parity/manifest.yml index 206db5a3..585779c5 100644 --- a/parity/manifest.yml +++ b/parity/manifest.yml @@ -476,8 +476,8 @@ repo_structure: or repo-root package stub. fidelity_summary: class_fidelity: - exact: 3 - high_fidelity: 15 + exact: 8 + high_fidelity: 10 not_applicable: 1 notebook_fidelity: high_fidelity: 13 diff --git a/parity/report.md b/parity/report.md index f7a6da0d..0053769d 100644 --- a/parity/report.md +++ b/parity/report.md @@ -22,8 +22,8 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo | Status | Count | |---|---:| -| `exact` | 3 | -| `high_fidelity` | 15 | +| `exact` | 8 | +| `high_fidelity` | 10 | | `partial` | 0 | | `wrapper_only` | 0 | | `missing` | 0 | diff --git a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat index fdb6049917060ee28e0eb9be1156b2073825f8c3..b506acf36f91f77d729cbf8e038ed0605037e1e3 100644 GIT binary patch delta 41 wcmZ3%yMlLuv4n4ao`P>;k%EGyf`NsVk%5(onSzmlk=evR<%tPw8%s>s0P-#iHUIzs delta 41 wcmZ3%yMlLuu|#lbo`P>;k%EGSf{~$>sil>%se+M#k=evR<%tPw8%s>s0Q3P1MgRZ+ diff --git a/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat new file mode 100644 index 0000000000000000000000000000000000000000..3527e7e054a521a741ec39a838e92d087f56116e GIT binary patch literal 1389 zcmeZu4DoSvQZUssQ1EpO(M`+DN!3vZ$Vn_o%P-2c0*X01nwjV*I2WZRmZYXA5-}&hTiG>P?X0v2ep@FdoTqIK#Nqu|f3IbG9Tmjb|*69!)xP z=f{*IlU!CSAGC(y7H*PSz*3#hiXH` zH6MW+nfwd$Z!K>;Xk5*(>$uQ*kU`Ew8>Ead$RabJhpqqZ`FV_u=HD1r1`AyP8Kng` zY6G&HloAE(*PEEz>26*!gHz)W%M)e>Cm-g!ARVS~9UoA22a|mhe{Mb_$FMM(y@d$PT}YZQKL=~hv1eHLfrAC4Ss!k50jkXc_PU#o7~H#f zd8Y8G{kP7}p7r*qt%lB>7ngzD4`D-<{km!{?C6ASlYK3o^qFw>v&F!Q4^fSGTL5f5LgrBTNo4 zLc2cwkl)$)%{4%ZLF*=;DaaH=+C)hqAm4mO_@=y|z%$9yFkoJ1qo4*SLpvwiVvuGW z*%PFh8}5WX_+YRj(8RoeOvV#n=hWiC2E<1hfDE literal 0 HcmV?d00001 diff --git a/tests/parity/fixtures/matlab_gold/cif_exactness.mat b/tests/parity/fixtures/matlab_gold/cif_exactness.mat index 96de6ae3662ab17f5d051e3a2613ca0717dfaa40..2818502a5b5d5613b485295cead746702625c73f 100644 GIT binary patch delta 41 wcmZqWY~`F_Ea97Cqr{J4dq@ZA-U}R`zYH4L`q+n!VWHvESd13;k%EGyf`O%#p@Eg5se+M#k=evR<%tPw8%xfx0s#6>3-;k%EGSf{~$>sil>XrGk-xk=evR<%tPw8%xfx0s#Dx3=RMQ diff --git a/tests/parity/fixtures/matlab_gold/config_exactness.mat b/tests/parity/fixtures/matlab_gold/config_exactness.mat index 1c4f78bde12ec6da16bc01dfd2cc0fccebec6970..5a70a0152e6a68e129d5716ba4890215e9391cdc 100644 GIT binary patch delta 41 xcmaDU@=|1iv4n4ao`P>;k%EGyf`NsVk%5(op@NZtk=evR<%tPw8%w@&0ssl*421vy delta 41 xcmaDU@=|1iu|#lbo`P>;k%EGSf{~$>sil>XrGk-xk=evR<%tPw8%w@&0sss)44nV~ diff --git a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat index b7d6f4fd7bed1ce43305993d403f3a022d8a024e..f7901f36a257cfb7310b88f6635f1ecab9d79c95 100644 GIT binary patch delta 780 zcmX@Wa)*0@v4n4ao`P>;k%EGyf`O%#p@Eg5g@TcRk=evR<%tPw8%xeGPS$0VPiB7f zXwsrfPbOVjbmhmSGj~GR*!&jEU}I)FGmDeS`mN#@&uZahfgL4M0nKrJba(@YfcSKgscO@wh%G&~TxFWdFBP?=&+0{*uY6|Q$T-PT zC*n+^qftD=32nWTXIL3FJ8^4+%yYtNo&l11;!mDTXAyZSF4BAzYFfs#0}N+s7&Ukq zR<>|@flNw=n>2&*NRGo9#-)x8q6(+jFZwu~;W%>0Atoj$WL8X2)D)@tb0p)Z%?O?q z#S?aZG8dCV{l=$EA%FV*My}W+7$SG*e*!)!j3J>8l{k z36t+J$_Usv#7v2wH9aV3N~FO2Ih;Y0Ihl+je14d!ft28lk*Cw2EuA8gDs}`OBe)`D zgEqG!$hg4Cjf|r9#?FSy0^AIRj(&BXHGXMvK>yYG)dd==2b^K;k%EGSf{~$>sil>XrGk-xk=evR<%tPw8%xeG0s#N93>^Rf diff --git a/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat new file mode 100644 index 0000000000000000000000000000000000000000..b7ba24ee0b8456f7f507cbbd8ce982d5e890800c GIT binary patch literal 1488 zcmeZu4DoSvQZUssQ1EpO(M`+DN!3vZ$Vn_o%P-2c0*X01nwjV*I2WZRmZYXA3`u4R&lr;2 z*f@IEOP`4_|EzJyX;?C!wZBe4fMJtXg9*bfeYV>mGfd!S z6oAcWxTJ7^>!~rbp|Z$lroz0W!m=!{n!F^>z%A8Kl<&uD1hS zZ!t4WuV-4GR)nm&fVaVdX7+CaznX0sSa+~@f%H4T^?RV}=VmVSO7cvqDfA4?Eb`0@ z%&GyKV~b+W;RW3MsNwJd-JEVFbaNt;{4%o)1LklpnKZGJk>LdkXEDeOYq(ob%wRGU zo+8HjsOit6Nk1MHFdWHoIKz0PhezN**OBzLOH2+iQ{rb$ z4+@$RDKLNXj488%&#;RaJ6t;7J;?Oo$%52^zDR(YU+cQRm~0D-xv;k%EGyf`NsVk)f4=iGq=Vk=evR<%tPw8%rt~0R;66=Kufz delta 41 wcmaFM{FZrwu|#lbo`P>;k%EGSf{~$>sil>viGq=Vk=evR<%tPw8%rt~0S7A!{Qv*} diff --git a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat index ff5e21f940a3a09e1969b09b85a66be2db9c1f1b..379a650e68e59247e6689d6b157a34285b7515d2 100644 GIT binary patch delta 41 wcmZo;Yh#;WEa97Cqr{J4dq@ZA-U}R`zYH4L^qF`iTWHvESd13;k%EGyf`NsVk%5(op@NZtk=evR<%tPw8%uUD0RZv<3#$MC delta 41 xcmZ3(wuWtju|#lbo`P>;k%EGSf{~$>sil>%fr62Nk=evR<%tPw8%uUD0RZ#U3%dXS diff --git a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat index a13b752932c2f2fe073bec077471d565a7dc9b0e..445c4c6f6412addaa592f98a78d2db4cf002d99a 100644 GIT binary patch delta 41 xcmbQnHjQn9v4n4ao`P>;k%EGyf`NsVk%5(oxq^{_k=evR<%tPw8%q{30RZb%3vU1b delta 41 xcmbQnHjQn9u|#lbo`P>;k%EGSf{~$>sil>%nSzmlk=evR<%tPw8%q{30RZhM3x5Cr diff --git a/tests/parity/fixtures/matlab_gold/history_exactness.mat b/tests/parity/fixtures/matlab_gold/history_exactness.mat new file mode 100644 index 0000000000000000000000000000000000000000..ca1469e42dbac3daf25eca221c07438c6378a319 GIT binary patch literal 10701 zcma)?1yCH_x~`Gn0fM^)f-}M026qUM;I2W#3=-S}f#42ZlU~co!rj&t zbs^81{t3%$Aph@M`Ao?BRyms^ydm!6lKSCIaHyCMD8 zhODfP|NB_@_x>KTVP$R}iDW5AzZEn^--o@)hr+{`yoJ7g3hzq3JB^1g&sf%pYS5>e zOJm?4Fq~ufB4Kr6YkF=0qz4SY993mw8|I80|EmGv z6?Q0<=H^I96P*}k|NHP?z4#NZKq^N@npnfA{1evrdt|Ne=H}+8Q_|RKp-3FbkBJ$< z!B|RsDD|w*$n3U?w~GettVz=lWUmLenNx{v-f&CL7+czrqZOJHe+qfUjjiyR8wE*` z6NmN>RsU~l8`4)NX*lxNPfwrqMSZmknJSPrhGeAWBvIxj7#Js-=~Cn-r6*_VJSE;n zwZvco>(`}!`DopPluYNGXO(A1N#%-3!m>s!p+h4ZCV~`Bi0|-+uk<(nuV;v4Ns5kX z9)3y@tp7{}2i48Z*3?JC!B4@~O*!C|ji0Hd3%hOCCp#0rvrCO^4N#wb+RC$C2Q=E;imZM_Je>bpB)i+A|Ow5U3&qS?CN-uPU; z@w-e=Zg;j6cq0`4{np~+bk}u#%vqh}V|lX}^6&DHg6P_jt^%P`%^9`Kttp2Vl^za9W;3KCFq`sanr*sK3n3kBwnLv1vCZ62IstW>LSpE z_<_&WJf4^mjUq`~W!Pox7l>eKl|8LCSFEnsKIvpo+nis619ld_0 z0%q8|3R*x?mZHiGP+;OC?{i}2jv)7kv)<5)Bh7`Q?fii${I6g-+Gq6tH!SNPu=LMP zNW%kY<9}#8f7AZkD?>}KF#=~(=8Do(9%WO0Qoa#Z5gLT?3#oz#z2r|&-KEV-%a>N?o_q%=8gyYS%g*7t2d zsuZ8s90ZLh^}YA&)W@|QsoJxLm$i+_*W+jGky0yE-uG)m?#J48Z!~F8#qCArX$Bh`z{ z2e`ke&E6I~@i5XFJef~Blx;n^5R2E|zjaJC5W9w$2%mcBUmX^AUUx~P^X%P1aQ81y z+ixe{zD@U^IlA|IKu#7p@79YZ!nl@j@8?<;dxb<_XPI?se+BXddN-QW%;+E?ypk0w#M{+ z30MZ+TfHr2Z1P%2>m>I(vS{D#U9|~3vOH{dNHH`xYS08?K6Ntl?}yX>;=o+zZuaU7 zS!gw%iMy7)BZEJnSw;G6nEyJVo!eTrU-g7Z{5qpWQJpeW5Acuk@5ZmZw&=X)IBx8{ zocs=>x-x} z%9!XR0I9PR#b|qs!qp5Qv!t;=q<50_j!VPq*nxjrm{K$Fk|)+(5u9VCWz%!;{^_+5 zJn9=SFJE5Ib=EJ3>W8_G`(7_*=^x2bV|$H1OClcXj{MSi|8gzHfMn=CQkLW*Pxwfg z8qsvfN;B^NQE{A)tlQ{U)b{FEt11V5@k36N~nkU zEZxAC=dFn;Ube*0rNO(^2l*V+@Oo)Aw;q@aD>#lSy)vzj+^`L!TxjeQ`~5Y5ninC& zM;_RN-G<+|PMokW#n7TxU-#*Pu`k6*9IB*LHhW4f-)Cg1qkF1tu9jLR?x?H~lrKN( z-kBa9ly1qF&jjn1xqZu3kx&}=fWKK%qD`h<-#Ufe{(L4v7gF&Y2cnX#^u=Rwu;cAa zH2;%X@>S71M@kPy+-mXd6YPIsxYcev0-=qZV6pX|O+Gq{&zPrz=?50lREwiJ*JYpu{2<=D>UO4W# zx9DaJy-7vBKDS8WEeB^zTxg-HcaMa1>Ww~(0W)U0)w*{1T>4f-GG=L{ z4^8o+dlt;%oA4rlZ9%o7eOa8r4s)CVW3*^Sc+VI1d(I=ZZg~H@iv;-vu1|AHTfY6y00{YtN~&a1zhXZ_X8yc26id6s(i=->qmbx7)p zIm;(2<~!Rd$>Pk4(^qwl%(l1Qf)_$m4&8CKIuVWeSt7otrI(H6LyPamQ#FT60CpRr zkXD_e-tiL2?Y#c3)oDSYv3!l~yt$C->4Dj@c@g`9Crl9?*`Y0}lyJW$=7^i@P?tcB z3=)uNglJA^&q(dqDq)7K#Q5jPl+$c**}h16)6G)*Xg9e0xIMivOd?G*H+pkFaVz{G zUbTCCVDLg<;3KH~wtb-$NRoV$8?BQy2DmxU2KdvH&9@4ZNqrnIrb9e`pGCcFd?N!` z_pG|ou}8jVz+dbvNR4@$q?ij04TvMeypg9}zWoqQg_smN{khdQ`Fh>pJWgC!*vYti zMJP@(F~U%V_H>8E+up}7zs-Qw{{@}HVa4{e+3Dn=`{D(G+19k(Y5w5#w%?7qqyl7( z!v&(b1{<=DGoKDT<=F-`j*g?A?aE&NOqR+GNX9Q$5i>(fr(xG=npSEc-XPYj!czuIl|sUE$S)HJ7xG^hZrfXamaOExeb|Svll&*$fUzg z(j#SV%o&QX4|s^D@NsyaHy4(<6ZqF-CQa)>9tyk(?&~qZe1&6OM>=x3oc&P;D%+N& zB$F?r>05n{KRmy&CHS>T*G>N><-s6qI(&*gdZ$qSCzlKOI;J`@z%a}=GrWLr;C$8+ z%0+@!&`Zj=r1hrCkH^1-+-F*P&9j$m+BN#5?>R<&dXLek)|ThJSNj9!YnJa1YSJZ* zKN;qPm(rz4FDhLS-hi;aNd)-rn?w-afU$aB_Z!OO0)SHW@))Ga6Sb;&)~>Et{+&;f zB+wNe0iTUuu)H`wA*^~ID8O{$jdgDv%35oaiA*AJ?t!bjwCA{UOj(qcnld@DEdb@> zQ&$T65Ks#|!JvyuwBNax%f$DMVQp^#1O_f8`;VD;Mt=J@ z&A(hjRUBbNhOAUIH81q+CxK3{yGsmD!%4Fpp&AY_3d8S(XN{Z^%?MmB9}d~+M6srw zTFR;yf%ZJQP6plKhM-8&Ldge1rHWMXf`KP*#eAC)gj^*vCIVhRMYyizqBF4hn-IiY zTjc}OiF!>xq1~y=1q19J7F^fz(Y;jHQnYt6sJM>8yr5B;5#5KXi$d8S5mO1d z?scsxIK_;LyC+Ic&R^JlPLpXJVkmcEuYJ0XwHk(3-)Qzs{r-I0dr63>#zpN}`kbcH z&ZbQ4!c2V1em&%~1ny2M;l}4smqf5H0VKTD;^x+%>QhN0=R5Rq8?`LCyORbK@i~wZ z4C0(!n4=HW2$ydM=IrxSOQVXVcbBK>Hm%?S9Rxwfj&`7gg3;Bq%+2ZNDd;sts*-~J z?E0!q2_8F^hNQ>%q<=t_P`$^4Uic#vD53u7RX0GExAyFO`J zesevk_HuMkJ~GD52o5o-U@1 z&pO}&myQyapb)0CRB6YgNy*SF+SQz3{8nzg4$C*Bg94VO5E&6U%bmy<15NVe8&g{4 zFR%__n74(V>XgxE(4u(L2ow|MX(=T#0 z0>Jcz`Vt9sgb`6kSul43SUrC4&b;AcmtzQJ1R67_77KM|JpflY6Xiu6hpK6`4$vh#;ELT`G$>!nX(gd=!auphQc9=`62YP3R(J{eA38OAOO>!7z2wx{hB zwA8#w)1y!$HZr)HsgP{N>_)3Cs1uDD*eRL9h48{5{=nZ9zrpp)mZ1+OlhgWr^P7## zH7CLcr|3qU)`?-w=sVYr*SQ>y$Qp`GFPIDAk5j>6)e?WpRaO`@C-cxMT>(oZ=rI4LrGKw zB^!99^m)_-7KW)U!sCpGV&80GaUfcx4xMdk>MSb^nK;hENg98pTm$?nQbfK)U4#5c zjWUF)^edi>3@=-RJB$QyoJElAqe;Ydo7>>k%a~lYqk_I7x<%koPK5ns5>Qo51T&$RmdoS5qYbmg~VhKuP|PVq|@Rh2Q` zgCFdH^Bd?3>5fiXZIW|?epa%B?K%CK&TtVyrM++U{exXrq06b;M$T{v!E`2Ox<;PI zh3{?2lP^y*V?>}IhSE0s7t$k~;)$F21_q6-Loqa+;fjLqIV>7@SQfP=Ta7_QY1!Ve zSW3d_ip7kqgR;V>ePeglidG+`4iDk4Jo(A^C3_;Wy@&C_Nn{Vg-vB|NtPVu+Vg1!D^ty$9qaeUoF1fqjaqAeY4J1lcF@z@ zNh#6vZR!QC`Tn8L#@j(fQ7e4Za4$Sa{6@>nXUOj>%!q_U(l4pCoW&95N8moQuw213 z8h9XH+36FJbW|Elk`@1y@8-1beMBrY@NhgPX4Q9Y+;pPx3C}~*Lq#UBc5i&yMKZ{; zC!pl>eX!-7FgBB(KG@w(7@yTcZT*@^24nVn2pe!2CcYaMrAY*Vr)a$2RO9Z>*|6?$ z?>9oDWdWBErOcoRX9#7NC9IhV=OndHV}+jyLEyO>{BLA7tnobh&-Q^VXJtgrQ8-~l zq8X~$SK3$|N!K%2yA8L)ZRt8DJAnF{57FMU;gZ9`Ch%&&)8w}El<4bOw2Yo_7oI#r z9eLc=ZPDIy;UCvqTua;19ZhI`%@CG79u_SqKr(O*&amKRXS$b3JfQ8dE4qC?d=|US z1l|k4)F;7x$Ai&X_xk>3Xh8fRL{0Iq8n9(amh9K#&O_JfCIi_rK@$QuP^qNub!4Qi zwB0CSM`g7y;x-?dtiZ=zmPr5}dA(U8-(s}d6|9!w(=M1i#A64xHJ*&hvD+Y_vt}yZ zZuf;f+c8f@vzGg^Ebk@Un5*4Oy%qS1OBra@i!ZSAho6p>jb%fY$TF^Gmlt1L=gD67 z6G*)0L?Fc=(TJiOIe=^^)ecOFWwB0oh1`UzbhQy72hVYshFY`O{s60H# z%WSBRxM|<}35X8)PwwC!d)(bKyWbq;Mtn(i-db~(pMnH7X?nVDtho*???2uh_ibk1 zExROx?xySgf0_maNF_Qv_$)PiN6H01V+q8@@^9OEmg_|h9=gwL_Rqq|-Tye|D08rZ zNAu$>u3+ZGOtnea_tUXRto37j?b@pN=gaJ`s`4 zy^oDIBT0qc*HW>PlvGcbtf;RZGjEmnpH(Yv8+C1SG3`HI{jiE#>N<9I4{)m)fvHS=yVCL^^9~Z%XgP(`ui=dS@|DX z*jH8{x7=Rpm-;Snu70bF7TVv$%{}_UsPN={qswaDxhY}iVcKogn*R))J3UMaA-TU3 znqPQS(s^@Lmm|3#&i0~Ixe0%N4<}bDmo%1QXXNIlzay{0!mi-u{vFbVUQ??i?L*K- zKv4dEL}yO1l=%vsFRUiz8fV^ISk$^9_(A`k7_Q_EdETt_y2^hSrMg(RzD6T z|5f6DIc^lXa7=$r*jjdG!+lHp>7&_A_T>>3`~kyCzWpH${gDhVf;qSvC87tP{0xaO zH?Dr}O4;9z2`p=S-3?@b#63gIuDT7`X<}|OJ$bo5-!8<@7IsD$GJMMOT1G@D|EkAe z9>Yi>3=|+Usk+?o`G?lnD^}(f3c4$L#^f_%?Ab9x3rHpR!9Eie)!;Sum)V-lFOL$m zgCaPxWo;gXcidw}SiHinA@Q zkT;LHn1MNF^MpCZJBxsuFXW6On1smFIVeR z!u-Y#hfF{u;!fdtpy{Wh$!euBHzw2Hvr>4D2Id_rvuV??@d~of+N9q+*y8w>cv}p+ zdoI;Wu#;{;B(A?GW@s7svQYSVQop@;pz%aX9(+_-e=OBYHm;`#ekgQv)_a`yy(?_T zn8aL3JNwm%Ke-%1j-borU}pQ^OAxB3xEhez#`EEpAar;S$g`Y84t=|1#+!Oe7`;kX zoAP#%CDoTGTDiiaemPgas*6I$LPeMJl9lwfdob{##F+Y$4G-jlhZnhFI>2!7BXh6X zYsjs4oj%H#kC!3VmlVUT{M$795O{6<)r$i8Zm%!v?OrEZ8=l-bw`6gvSoc(q@wDR- zpN6a#2Pv_Zhn`fy)0doX7M>j0p!^y+rB?qQz%1=K&nFb@n~8Ms?dIe1yc{ukZF5XE z1lq0smBVo!FEB?L@xWXV#j!P{nJIVVM)=3-I46Gf*N}0&2xdC+AY|?}nr3{^CcG5G z;*d%i;sg_3ic>gL^=fSOuvod5Q>vqTscfyXSatliIy|uDN9()SzouHZQeZQ|)@5#@ ziE0u`gTJsrI!Wa?u%Bg3rZt!s3>4cOBPc97EZXlBTk2h~FP@_z&nmY5+Rc$!Q3ATh z=s16^(UQ`Re>HOTW#WQuCV_az`dXt@OH=!M4GV@;ct>K$B`&SUv*wbXqCDe`DJd7W z1eI*%@UCiolRb6z%t(OXPKwHkvi#B(n$%)Rsdk|Dq{;pB6w&0UH6y@5%Y`NT1uMzv z`4>j*a@F#5K*zUiL&wPd&RI8Sixl7rk1WqR*9V&=&Oo z&~%tOp%jnC?5GRq2Yd20qHgiR_zd1X`% zeh~6jY~DsYkS~AkR+rgBn5g3#$IkCoILQZsbt!XQq_ziQmgKw8Gtnn!t$M|H2Gpk59lr0(9cXoWfv&(8MTQk@>xrCzu$r{Qw40d{t=K5ED z7|fh;ep;kogUFkV_n)gvy>u~U&brBokylx0EI(V+Zxyh3`_pxC=GN(8;f*YWHLD{R zcl@v-Z|2bnvgZAE(`&wQ7j?0%ajAVz-bwJDo+{XP^SkXdSpz|1;Qg8m&JMl^RGhrt z(=vhB1L-QCFuJvYA%;}*){0rh{96AD7#51kk&((+q7gY^^Fpq1G_ggKNpd= zn?bT}D{`JwQ!5tTn8o>g5DZhET<9Myru*RDEkcwYU z8R4@vy?x3+3zDfqzX#B%r+nXX&Tl`Dhh9F5pq}qN#A;|^ool>;-MFtF_I!b0-VawXD=x-bucyWsdhRLSv zaPoz4e>q(0G2FL2)8qs_!`bZ&zq03OBl9hde=v%)?BpoICo8Z$>7R~^&T}dqN^T9%|H%@yFp$imJ~8R}&3l5m zz#bCf$BxK=8lc?A=1ky!NS)N=zoZU8!00cj^COK!PMQDQANNy(u@3)vAnt*U0%h0Z z)E}wC{9EcM{E<3Bn}~l(ot()(QiqG~xcIl!(aK*BY&^zrgyK-@>ImmC2-$1ha+G&b zo`$01ro3~25>WD2YCEpRCmNoHkq&>Td0F{_Q+lnmtg6POWx1g^m$!VhW&3=hQPV;hpz1}h{fR5%@wV)mJ&Z^6z(L%z-BpCf zTOlG;+jf5kDj-Vj{2c9-V$b9cH60OeQBW*iL-ZH)9jKTnt9Mo&HW)9&k;ya4I3fU8@14tq7Q6sIi$?!*|((DvB1hBob9Kvh2pMbX3*R(|9XK zJZ~No1tsA%41X~ZJx#*<)dI?QN3crde5`iXH_*(PlC6HlDABGGH%=2J_-$zqrso5# zh}kJ;vDwA$CFi3H9iwImnf_uo(XM+ouaxS?=#qug8(D z=Uy%0BBW70Ri6qw(Bo6Sdl$58LXhlsh zQ7DF$l2K??>4;xuYHJ5q;k1+>)-A#T?MNU|Hb#p#a8QX>VEyn(`4a_DIYBOq$}_%1 z%^22L9R*MY0ZFMG4gL%{tWtrs-N%8cmGYHcCKCkcR%4mWN~c5HnUy>Oc)!T43IJoifQZ`X!+;JaY6vqT-3$peggMdR zxY3LF2{~FpAkEs496%5_wYQ)|biELUhMlI4uqH}nXXHspQ`av<8j*7m1pb!a4YsF^ z6iZAcbAm~QqYK^V(3GLIEZDRywJ~!FOsz(6W8t)*vhyYjHgC%+PVLHUm*r+KanR;% zd)h)Fq^CTE3=sw%C9ljfA`&rt~#pWkd@atw}m99Y}r z)XBKqj`Y0Q!sf6omv7)T)35JqoOzkwSOuqe!9C7HNu~lgQg|grl71b2r zl620=YOEfpJ^yeaxJ5&SXz5zDvO22&%*1ou6mvam2eirptg{pm{!JJW{3j;~D>o?Fafd#tE+d{T>(`s-@!m)OoZkA47)O4cgMy55*%M zQpN!1qn>;uG^U0};1#`q(Qlb+xXW}r{7})@l4s%P9Ta%SD6l0x^6S4#AUn$a)2 z^OqW1N-R00s?WT5HAVB+r*&Dt#Y8U}N-|epf_TkDUn%^j)X5Rk`Y);DmHjWNQ*iZ9 zse^s?m(5}=VR8<~XntJ5t@@kNQD{(?!0Pc4z13Z8wwVwVa z@Ol92Q;I{&91n}kw!@(ez?Pw2eRr)p&ss;FR&ZXl$UK@&mZCAd6QD#U4jLF(-tT8d zTLx^Ikv*x3LT^hCHCZU^d2r*w@1&o8D2qzyS$dgz5`%;B xp{eG;9jh+|$7p>0;>_fM#Gi!welMJ(rf>BlM`Of$>VCqHjsBa}mWku_e*kl`#BKlp literal 0 HcmV?d00001 diff --git a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat index cc41a2737f34f704d61b9a7b8b0067270a3d200c..16463dadb1f47ce523c2894c886bda30a91be248 100644 GIT binary patch delta 41 wcmeyx{fm2ov4n4ao`P>;k%EGyf`NsVk)f4=iGq=Vk=evR<%tPw8%x?)0Sl81A^-pY delta 41 wcmeyx{fm2ou|#lbo`P>;k%EGSf{~$>sil>viGq=Vk=evR<%tPw8%x?)0S(CvH~;_u diff --git a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat index 3f487410cac07032ea01a06e2312cfa9c07c1354..4a66e2c0145fdd835380044fc6c9758313ba3ffa 100644 GIT binary patch delta 41 wcmeC+>foASEa97foASED>Cqr{J4dq@ZA-U}R`zYH4L`reI`XWHvESd13;k%EGyf`NsVk)f4=nSzmlk=evR<%tPw8%wS-0|5L63<&@L delta 41 xcmX@fagt+#u|#lbo`P>;k%EGSf{~$>sil>vxq^{_k=evR<%tPw8%wS-0|5R^3?Kji diff --git a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat index caec6bea5416ae55579a42bb5bbd78683452e055..b3cc8a01bdce4ceabf0035cd61bca0de081efcf8 100644 GIT binary patch delta 1102 zcmZ3@eL{GGv4n4ao`P>;k%EGyf`O%#p@Eg5rGk-xk=evR<%tPw6H5f@0~i<>D&{;+ zPDt=zNHSA+)^JJT0M}DvCPVHi&zK*bd35H-qoylQCM~*j=LZ)Xo8N*NOakX;aWYxI zHBJ7-UwidNay}2+^tZ0pcG|BI+L_MJki1w_kRNDf2yy1tgADdUGFbbqwhPFxtd0D) z6tDlX`q1)s;d&{DYeh)018&w?91At5Cpf%%e*15Zj5qqfj|Nv~q&!mzru zIM152!n7LJ2w8c7Y6FEYn;%}}XPTq%yP@Kmk3d9byyB*Fe-lHyYwIuY?YwrqQT>hM zGfM%3e$KNtj0$Yo_YKd>xlPXI+f~i5_l(k9kd-;aTB(X*0MpXy)Wo2iQf79WI@3a<%84w3EppF0L+4%kbBX zfG>uIFJ2i5q{^GsGVBuLwg4II2sN0Yp2>}k%}n9wL578n4Wg&+^G{?;u2VRtbFSz7 z2_3yo|8x3hwKdZcnpQP3yRj)TUgTQIy&>xK2?mC>S2$!rCgDs!##6-D!0D&w%%e#^ zwj6nK<;RmhbN(o8sPa>o#_;I>|NrObo4&4}YCB=t^hs5z{8yGMX#Tt;$B-Q@akuefW20{GN5R3Df*()*T>Po~z$?~Ud?b)x>He&_mz6_1%8F{u2_iTJWvVdiPhBU}zK z&;;^#%}j+2m;PM{@160_+5K&E*@>-j(&BXHHDt3ep)VP7(DH{#s2DeJmX@hto&KIlHpYm$5oIaMsPz45IJ4(sj*UG mor2$54Uf}0o#(aB_WG~!&`nEVI%U9kl(BO=+qI1i+N=Qb|HG*O delta 45 zcmX>hyqbH0u|#lbo`P>;k%EGSf{~$>sil>XrGk-xk=evR<%tPw6H5d(FJw_?1pq9h B45|PC diff --git a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat index b710e26a7a4c7ce16589a7f853ca8f0e5448edc9..4ef2eb5bf3afd7a27d2b1e67656ba8ae02afb244 100644 GIT binary patch delta 764 zcmeyydW2_!v4n4ao`P>;k%EGyf`O%#p@Eg5rGk-xk=evR<%tPw8%vs)CN{_`_?_0# z(&N-SamJ7LR6B|V88It_e*bEX{KS@qB;1^I&eWofF%9g&2eW_T(6~o3<2?>S@Kh`jifp$P# zG3Rk|LV^ZEl9|FYh9oyOjs(|7hYjono{LWZq55puF43pYtDY^DF#5oFlxM=1%?~c} zGch?#Q=GuYaA__(E65agxG5bVQyiZ)Tv9l|_0*WjP-Tti@J_?-3R^gh$+ zuj5^QTIbXmZ+aG>i*defyp4s&LO&zUuQ z%4DhelV?ns6@2CZyGXIarSna$_Ue;%G6n2g7`mmguW??YPRN{?;+_j<7#SAS^CW>R zHijEpz%aRsiMyWBP+8;k%EGSf{~$>sil>XrGk-xk=evR<%tPw8%vs)00Z3%^Z)<= diff --git a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat index e9a75967456a431544dbecc3dca318f71431eaa6..4dc75a252bdd98f854089f1ae3c6f1634f1e1a62 100644 GIT binary patch delta 41 xcmbQvHJxjMv4n4ao`P>;k%EGyf`NsVk)f4=iGq=Vk=evR<%tPw8%q|m008WN3w8hi delta 41 xcmbQvHJxjMu|#lbo`P>;k%EGSf{~$>sil>viGq=Vk=evR<%tPw8%q|m008c_3yc5& diff --git a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat index b6fb6f333b41292092ee8ed1bfb8b7311ac3ba53..4033bdd935f5b26fb73f174278aa13fee0b8982d 100644 GIT binary patch delta 41 xcmbQoHIHk8v4n4ao`P>;k%EGyf`O%#p@Eg5v4W9-k=evR<%tPw8%tKP008c-3yJ^$ delta 41 xcmbQoHIHk8u|#lbo`P>;k%EGSf{~$>sil>XnSzmlk=evR<%tPw8%tKP008jd3!nf1 diff --git a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat index d0a47e8b7bf7d19ca0bcb27e80320a5094db6fff..61ea8c14ea2a3416df675a7ddd3aced5f561acf4 100644 GIT binary patch delta 41 wcmdnXy_b7}v4n4ao`P>;k%EGyf`NsVk)f4=xq^{_k=evR<%tPw8%sP{0r)BlZU6uP delta 41 vcmdnXy_b7}u|#lbo`P>;k%EGSf{~$>sil=EkYQkCHZf3nVglR75>HkD`fv+{ diff --git a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat index b774fe31d3e54ebdabc6dfd55d977e75ac312151..7f54ac06e8ead9e0b2cccaa19ce3917fe971f712 100644 GIT binary patch delta 41 wcmey%@t0$Qv4n4ao`P>;k%EGyf`NsVk)f4=iGq=Vk=evR<%tPw8%sD@01BQ9qyPW_ delta 41 wcmey%@t0$Qu|#lbo`P>;k%EGSf{~$>sil>viGq=Vk=evR<%tPw8%sD@01VU%x&QzG diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index b8e7742c..081e45c8 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -19,6 +19,7 @@ Events, FitResult, FitResSummary, + History, SignalObj, Trial, TrialConfig, @@ -70,6 +71,17 @@ def _string_list(payload: dict[str, np.ndarray], key: str) -> list[str]: return [str(item) for item in arr.reshape(-1)] +def _object_vectors(payload: dict[str, np.ndarray], key: str) -> list[np.ndarray]: + value = np.asarray(payload[key], dtype=object) + if value.shape == (): + value = np.asarray([value.item()], dtype=object) + if value.dtype != object: + return [np.asarray(value, dtype=float).reshape(-1)] + if value.ndim == 1 and all(not isinstance(item, (list, tuple, np.ndarray)) for item in value.reshape(-1)): + return [np.asarray(value, dtype=float).reshape(-1)] + return [np.asarray(item, dtype=float).reshape(-1) for item in value.reshape(-1)] + + def test_signalobj_matches_matlab_gold_fixture() -> None: payload = _load_fixture("signalobj_exactness.mat") signal = SignalObj(_vector(payload, "time"), np.asarray(payload["data"], dtype=float), "sig", "time", "s", "u", ["x1", "x2"]) @@ -138,6 +150,74 @@ def test_nspiketrain_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(float(burst_train.numBursts), _scalar(payload, "burst_numBursts"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(burst_train.numSpikesPerBurst, _vector(payload, "burst_numSpikesPerBurst"), rtol=1e-8, atol=1e-10) + fig, ax = plt.subplots() + try: + line = nst.plotISISpectrumFunction() + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float), _vector(payload, "isi_spectrum_x"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float), _vector(payload, "isi_spectrum_y"), rtol=1e-12, atol=1e-12) + finally: + plt.close("all") + + fig, ax = plt.subplots() + try: + joint_ax = nst.plotJointISIHistogram() + joint_lines = list(joint_ax.lines) + expected_styles = _string_list(payload, "joint_isi_style") + assert len(joint_lines) == len(expected_styles) + for line, expected_x, expected_y, expected_style in zip( + joint_lines, + _object_vectors(payload, "joint_isi_x"), + _object_vectors(payload, "joint_isi_y"), + expected_styles, + strict=True, + ): + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float).reshape(-1), expected_x, rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float).reshape(-1), expected_y, rtol=1e-12, atol=1e-12) + assert str(line.get_linestyle()).lower() == str(expected_style).lower() + finally: + plt.close("all") + + fig, ax = plt.subplots() + try: + counts = nst.plotISIHistogram(handle=ax) + np.testing.assert_allclose(np.asarray(counts, dtype=float).reshape(-1), _vector(payload, "isi_hist_counts"), rtol=1e-12, atol=1e-12) + patches = list(ax.patches) + assert patches + np.testing.assert_allclose(np.asarray(patches[0].get_facecolor()[:3], dtype=float), _vector(payload, "isi_hist_face_color"), rtol=1e-12, atol=1e-12) + expected_edge = _string(payload, "isi_hist_edge_color") + if expected_edge.lower() == "none": + edge = np.asarray(patches[0].get_edgecolor()) + assert edge.size == 0 or edge.reshape(-1, 4)[0, 3] == 0.0 + else: + np.testing.assert_allclose(np.asarray(patches[0].get_edgecolor()[:3], dtype=float), _vector(payload, "isi_hist_edge_color"), rtol=1e-12, atol=1e-12) + finally: + plt.close("all") + + fig, ax = plt.subplots() + try: + prob_ax = nst.plotProbPlot(handle=ax) + prob_lines = list(prob_ax.lines) + expected_styles = _string_list(payload, "probplot_style") + assert len(prob_lines) == len(expected_styles) + for line, expected_x, expected_y, expected_style in zip( + prob_lines, + _object_vectors(payload, "probplot_x"), + _object_vectors(payload, "probplot_y"), + expected_styles, + strict=True, + ): + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float).reshape(-1), expected_x, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float).reshape(-1), expected_y, rtol=1e-8, atol=1e-10) + assert str(line.get_linestyle()).lower() == str(expected_style).lower() + finally: + plt.close("all") + + fig = nst.plotExponentialFit() + try: + assert len(fig.axes) == int(_scalar(payload, "expfit_num_axes")) + finally: + plt.close(fig) + def test_covariate_and_confidence_interval_match_matlab_gold_fixture() -> None: payload = _load_fixture("covariate_exactness.mat") @@ -147,11 +227,43 @@ def test_covariate_and_confidence_interval_match_matlab_gold_fixture() -> None: cov = Covariate(time, replicates, "Stimulus", "time", "s", "a.u.", ["r1", "r2", "r3", "r4"]) mean_cov = cov.computeMeanPlusCI(0.05) cov_single = Covariate(time, np.mean(replicates, axis=1), "StimulusSingle", "time", "s", "a.u.", ["stim"]) - cov_single.setConfInterval(ConfidenceInterval(time, np.asarray(payload["explicit_ci"], dtype=float), "b")) + cov_single.setConfInterval( + ConfidenceInterval( + time, + np.asarray(payload["explicit_ci"], dtype=float), + "CI", + "time", + "s", + "a.u.", + ) + ) np.testing.assert_allclose(mean_cov.data[:, 0], _vector(payload, "mean_data"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(mean_cov.ci[0].bounds, np.asarray(payload["mean_ci"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(cov_single.ci[0].bounds, np.asarray(payload["explicit_ci"], dtype=float), rtol=1e-8, atol=1e-10) + structure = cov_single.toStructure() + np.testing.assert_allclose(np.asarray(structure["time"], dtype=float), _vector(payload, "structure_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(structure["data"], dtype=float).reshape(-1), _vector(payload, "structure_data"), rtol=1e-12, atol=1e-12) + assert structure["name"] == _string(payload, "structure_name") + assert list(structure["dataLabels"]) == _string_list(payload, "structure_dataLabels") + assert list(structure["plotProps"]) == _string_list(payload, "structure_plotProps") + assert isinstance(structure["ci"], dict) + np.testing.assert_allclose(np.asarray(structure["ci"]["signals"]["values"], dtype=float), np.asarray(payload["structure_ci_values"], dtype=float), rtol=1e-12, atol=1e-12) + assert structure["ci"]["name"] == _string(payload, "structure_ci_name") + + roundtrip = Covariate.fromStructure(structure) + np.testing.assert_allclose(roundtrip.data.reshape(-1), _vector(payload, "roundtrip_data"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(roundtrip.ci[0].bounds, np.asarray(payload["roundtrip_ci"], dtype=float), rtol=1e-12, atol=1e-12) + assert list(roundtrip.dataLabels) == _string_list(payload, "roundtrip_dataLabels") + + fig, ax = plt.subplots() + try: + cov_single.plot(handle=ax) + line_handles = list(ax.lines) + line_colors = np.asarray([mcolors.to_rgb(line.get_color()) for line in line_handles], dtype=float) + np.testing.assert_allclose(line_colors, np.asarray(payload["plot_line_colors"], dtype=float), rtol=1e-12, atol=1e-12) + finally: + plt.close(fig) def test_confidence_interval_matches_matlab_gold_fixture() -> None: @@ -212,6 +324,7 @@ def test_nstcoll_matches_matlab_gold_fixture() -> None: n2 = nspikeTrain(_vector(payload, "secondSpikeTimes"), "2", 10.0, 0.0, 0.5, "time", "s", "spikes", "spk", -1) coll = nstColl([n1, n2]) collapsed = coll.toSpikeTrain() + coll.setNeighbors() np.testing.assert_equal(coll.numSpikeTrains, int(_scalar(payload, "numSpikeTrains"))) assert coll.getNST(1).name == _string(payload, "firstName") @@ -221,6 +334,25 @@ def test_nstcoll_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(float(collapsed.minTime), _scalar(payload, "collapsedMinTime"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(float(collapsed.maxTime), _scalar(payload, "collapsedMaxTime"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(float(collapsed.sampleRate), _scalar(payload, "collapsedSampleRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(coll.getFirstSpikeTime()), _scalar(payload, "firstSpikeTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(coll.getLastSpikeTime()), _scalar(payload, "lastSpikeTime"), rtol=1e-12, atol=1e-12) + assert coll.isSigRepBinary() == bool(_scalar(payload, "binarySigRep")) + assert coll.BinarySigRep() == bool(_scalar(payload, "binarySigRep")) + assert coll.getNSTnameFromInd(1) == _string(payload, "nstNameFromInd1") + nst_from_name = coll.getNSTFromName("1") + assert isinstance(nst_from_name, nspikeTrain) + np.testing.assert_allclose(nst_from_name.spikeTimes, _vector(payload, "nstFromName1_spikeTimes"), rtol=1e-12, atol=1e-12) + fieldVal, neuronNumbers = coll.getFieldVal("avgFiringRate") + np.testing.assert_allclose(fieldVal, _vector(payload, "fieldVal_avgFiringRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(neuronNumbers, _vector(payload, "fieldVal_neuronNumbers"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(coll.getNeighbors(1), dtype=float), _vector(payload, "neighbors1"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(coll.getNeighbors(2), dtype=float), _vector(payload, "neighbors2"), rtol=1e-12, atol=1e-12) + ensembleCov = coll.getEnsembleNeuronCovariates(1, [], [0.0, 0.1]) + assert ensembleCov.getAllCovLabels() == _string_list(payload, "ensemble_labels") + np.testing.assert_allclose(ensembleCov.dataToMatrix(), np.asarray(payload["ensemble_matrix"], dtype=float).reshape(-1, 1), rtol=1e-12, atol=1e-12) + psthCov = coll.psth(0.1, [1, 2], 0.0, 0.5) + np.testing.assert_allclose(psthCov.time, _vector(payload, "psth_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(psthCov.data.reshape(-1), _vector(payload, "psth_data"), rtol=1e-12, atol=1e-12) def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: @@ -286,6 +418,58 @@ def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: np.testing.assert_allclose(np.asarray(trial_from_coll.covarColl.getCov(1).time, dtype=float), _vector(payload, "applied_coll_shifted_position_time"), rtol=1e-12, atol=1e-12) +def test_covcoll_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("covcoll_exactness.mat") + time = np.array([0.0, 0.5, 1.0], dtype=float) + position = Covariate(time, np.column_stack([[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]), "Position", "time", "s", "", ["x", "y"]) + stimulus = Covariate(time, [5.0, 6.0, 7.0], "Stimulus", "time", "s", "a.u.", ["stim"]) + coll = CovColl([position, stimulus]) + coll.setMask([["Position", "x"], ["Stimulus"]]) + + np.testing.assert_allclose(coll.getCov(1).time, _vector(payload, "masked_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose( + coll.dataToMatrix(), + np.asarray(payload["masked_matrix"], dtype=float).reshape(coll.dataToMatrix().shape), + rtol=1e-12, + atol=1e-12, + ) + assert coll.getCovLabelsFromMask() == _string_list(payload, "masked_labels") + + data_structure = coll.dataToStructure() + np.testing.assert_allclose(np.asarray(data_structure["time"], dtype=float).reshape(-1), _vector(payload, "data_structure_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose( + np.asarray(data_structure["signals"]["values"], dtype=float), + np.asarray(payload["data_structure_values"], dtype=float).reshape(np.asarray(data_structure["signals"]["values"], dtype=float).shape), + rtol=1e-12, + atol=1e-12, + ) + + structure = coll.toStructure() + assert int(structure["numCov"]) == int(_scalar(payload, "structure_numCov")) + np.testing.assert_allclose(float(structure["minTime"]), _scalar(payload, "structure_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(structure["maxTime"]), _scalar(payload, "structure_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(coll.covMask[0], _vector(payload, "post_structure_mask_1"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(coll.covMask[1], _vector(payload, "post_structure_mask_2"), rtol=1e-12, atol=1e-12) + + roundtrip = CovColl.fromStructure(structure) + np.testing.assert_allclose(float(roundtrip.minTime), _scalar(payload, "roundtrip_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(roundtrip.maxTime), _scalar(payload, "roundtrip_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(roundtrip.sampleRate), _scalar(payload, "roundtrip_sampleRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(roundtrip.dataToMatrix(), np.asarray(payload["roundtrip_matrix"], dtype=float), rtol=1e-12, atol=1e-12) + assert roundtrip.getCovLabelsFromMask() == _string_list(payload, "roundtrip_labels") + + shifted = CovColl([position, stimulus]) + shifted.setCovShift(0.25) + shifted.restrictToTimeWindow(0.25, 1.25) + np.testing.assert_allclose(float(shifted.minTime), _scalar(payload, "shifted_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(shifted.maxTime), _scalar(payload, "shifted_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(shifted.getCov(2).time, _vector(payload, "shifted_stim_time"), rtol=1e-12, atol=1e-12) + + assert coll.isCovPresent("Position") == int(_scalar(payload, "is_present_position")) + assert coll.isCovPresent(2) == int(_scalar(payload, "is_present_last_index")) + assert coll.copy().numCov == int(_scalar(payload, "copy_numCov")) + + def test_events_match_matlab_gold_fixture() -> None: payload = _load_fixture("events_exactness.mat") events = Events(_vector(payload, "eventTimes"), _string_list(payload, "eventLabels"), _string(payload, "eventColor")) @@ -314,6 +498,48 @@ def test_events_match_matlab_gold_fixture() -> None: plt.close(fig) +def test_history_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("history_exactness.mat") + history = History(_vector(payload, "windowTimes"), _scalar(payload, "minTime"), _scalar(payload, "maxTime")) + rebuilt = History.fromStructure(history.toStructure()) + assert rebuilt is not None + np.testing.assert_allclose(rebuilt.windowTimes, _vector(payload, "roundtrip_windowTimes"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(rebuilt.minTime), _scalar(payload, "roundtrip_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(rebuilt.maxTime), _scalar(payload, "roundtrip_maxTime"), rtol=1e-12, atol=1e-12) + + n1 = nspikeTrain([0.0, 0.5, 1.0], "n1", 2.0, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + n2 = nspikeTrain([0.25, 0.75], "n2", 2.0, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + coll = nstColl([n1, n2]) + + single_cov = history.computeHistory(n1, 1) + coll_cov = history.computeHistory(coll, 2) + np.testing.assert_allclose(single_cov.dataToMatrix(), np.asarray(payload["single_history_matrix"], dtype=float), rtol=1e-12, atol=1e-12) + assert single_cov.getAllCovLabels() == _string_list(payload, "single_history_labels") + assert single_cov.getCov(1).name == _string(payload, "single_history_name") + np.testing.assert_allclose(coll_cov.dataToMatrix(), np.asarray(payload["coll_history_matrix"], dtype=float), rtol=1e-12, atol=1e-12) + assert coll_cov.getAllCovLabels() == _string_list(payload, "coll_history_labels") + assert [coll_cov.getCov(index).name for index in range(1, coll_cov.numCov + 1)] == _string_list(payload, "coll_cov_names") + + filter_bank = history.toFilter(_scalar(payload, "filter_delta")) + expected_num = _object_vectors(payload, "filter_num") + expected_den = _object_vectors(payload, "filter_den") + assert filter_bank.numFilters == len(expected_num) + for idx, expected in enumerate(expected_num): + np.testing.assert_allclose(filter_bank[idx].numerator.reshape(-1), expected, rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(filter_bank[idx].denominator.reshape(-1), expected_den[idx], rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(filter_bank[idx].delta), _scalar(payload, "filter_delta"), rtol=1e-12, atol=1e-12) + + fig, ax = plt.subplots() + try: + lines = history.plot(handle=ax) + assert len(lines) == len(_object_vectors(payload, "plot_x")) + for line, xdata, ydata in zip(lines, _object_vectors(payload, "plot_x"), _object_vectors(payload, "plot_y")): + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float).reshape(-1), xdata, rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float).reshape(-1), ydata, rtol=1e-12, atol=1e-12) + finally: + plt.close(fig) + + def test_cif_eval_surface_matches_matlab_gold_fixture() -> None: payload = _load_fixture("cif_exactness.mat") cif = CIF( @@ -378,6 +604,37 @@ def test_analysis_fit_surface_matches_matlab_gold_fixture() -> None: assert fit.fitType[0] == _string(payload, "distribution") +def test_analysis_multineuron_surface_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("analysis_multineuron_exactness.mat") + time = _vector(payload, "time") + stim_data = _vector(payload, "stim_data") + stim = Covariate(time, stim_data, "Stimulus", "time", "s", "", ["stim"]) + spike_train_1 = nspikeTrain(_vector(payload, "spike_times_1"), "1", 0.1, 0.0, 1.0, "time", "s", "", "", -1) + spike_train_2 = nspikeTrain(_vector(payload, "spike_times_2"), "2", 0.1, 0.0, 1.0, "time", "s", "", "", -1) + trial = Trial(nstColl([spike_train_1, spike_train_2]), CovColl([stim])) + cfg = TrialConfig([["Stimulus", "stim"]], 10, [], [], name="stim") + fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl([cfg]), makePlot=0) + assert isinstance(fits, list) + assert len(fits) == int(_scalar(payload, "num_fits")) + + np.testing.assert_allclose(fits[0].getCoeffs(1), _vector(payload, "fit1_coeffs"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(fits[1].getCoeffs(1), _vector(payload, "fit2_coeffs"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fits[0].AIC[0]), _scalar(payload, "fit1_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[1].AIC[0]), _scalar(payload, "fit2_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[0].BIC[0]), _scalar(payload, "fit1_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[1].BIC[0]), _scalar(payload, "fit2_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[0].logLL[0]), _scalar(payload, "fit1_logLL"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fits[1].logLL[0]), _scalar(payload, "fit2_logLL"), rtol=1e-6, atol=1e-8) + + summary = FitResSummary(fits) + np.testing.assert_allclose(summary.AIC, np.asarray(payload["summary_AIC"], dtype=float).reshape(summary.AIC.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.BIC, np.asarray(payload["summary_BIC"], dtype=float).reshape(summary.BIC.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.logLL, np.asarray(payload["summary_logLL"], dtype=float).reshape(summary.logLL.shape), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(summary.KSStats, np.asarray(payload["summary_KSStats"], dtype=float).reshape(summary.KSStats.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.KSPvalues, np.asarray(payload["summary_KSPvalues"], dtype=float).reshape(summary.KSPvalues.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.withinConfInt, np.asarray(payload["summary_withinConfInt"], dtype=float).reshape(summary.withinConfInt.shape), rtol=1e-8, atol=1e-10) + + def test_analysis_discrete_ks_arrays_match_matlab_gold_fixture() -> None: payload = _load_fixture("ksdiscrete_exactness.mat") diff --git a/tests/test_trial_fidelity.py b/tests/test_trial_fidelity.py index 067a7d53..a9258637 100644 --- a/tests/test_trial_fidelity.py +++ b/tests/test_trial_fidelity.py @@ -7,6 +7,7 @@ from nstat.ConfigColl import ConfigColl from nstat.CovColl import CovColl from nstat.nstColl import nstColl +from nstat.SignalObj import SignalObj def _make_covariates() -> tuple[Covariate, Covariate]: @@ -64,6 +65,22 @@ def test_nstcoll_neighbors_mask_and_data_matrix() -> None: np.testing.assert_allclose(matrix, [[1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]) +def test_nstcoll_psthbars_public_contract() -> None: + train1, train2 = _make_spikes() + coll = nstColl([train1, train2]) + + bars = coll.psthBars(0.5, [1, 2], 0.0, 1.0) + + assert isinstance(bars, SignalObj) + assert bars.name == "PSTH_{bars}" + assert bars.dataLabels == ["mode", "mean", "ciLower", "ciUpper"] + np.testing.assert_allclose(bars.time, [0.0, 0.5, 1.0]) + assert bars.data.shape == (3, 4) + np.testing.assert_allclose(bars.data[:, 0], bars.data[:, 1]) + assert np.all(bars.data[:, 2] <= bars.data[:, 1]) + assert np.all(bars.data[:, 1] <= bars.data[:, 3]) + + def test_trialconfig_and_configcoll_apply_and_roundtrip() -> None: position, stimulus = _make_covariates() train1, train2 = _make_spikes() diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index 30cd9eae..26ef9afd 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -24,9 +24,12 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_nspiketrain_fixture(fixtureRoot); export_nstcoll_fixture(fixtureRoot); export_config_fixture(fixtureRoot); +export_covcoll_fixture(fixtureRoot); export_events_fixture(fixtureRoot); +export_history_fixture(fixtureRoot); export_cif_fixture(fixtureRoot); export_analysis_fixture(fixtureRoot); +export_analysis_multineuron_fixture(fixtureRoot); export_ksdiscrete_fixture(fixtureRoot); export_fit_summary_fixture(fixtureRoot); export_point_process_fixture(fixtureRoot); @@ -38,6 +41,63 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_simulated_network_fixture(fixtureRoot); end +function export_history_fixture(fixtureRoot) +histObj = History([0 0.5 1.0], 0.0, 1.0); +n1 = nspikeTrain([0.0 0.5 1.0], 'n1', 2.0, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +n2 = nspikeTrain([0.25 0.75], 'n2', 2.0, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +coll = nstColl({n1, n2}); + +singleCov = histObj.computeHistory(n1, 1); +collCov = histObj.computeHistory(coll, 2); +structure = histObj.toStructure; +roundtrip = History.fromStructure(structure); +filterMat = histObj.toFilter(0.5); + +fig = figure('Visible', 'off'); +histObj.plot(); +ax = gca; +lineHandles = flipud(findobj(ax, 'Type', 'line')); +lineLabels = cell(1, numel(lineHandles)); +lineX = cell(1, numel(lineHandles)); +lineY = cell(1, numel(lineHandles)); +for iLine = 1:numel(lineHandles) + lineLabels{iLine} = get(lineHandles(iLine), 'DisplayName'); + lineX{iLine} = get(lineHandles(iLine), 'XData'); + lineY{iLine} = get(lineHandles(iLine), 'YData'); +end +close(fig); + +payload = struct(); +payload.windowTimes = histObj.windowTimes; +payload.minTime = histObj.minTime; +payload.maxTime = histObj.maxTime; +payload.structure_windowTimes = structure.windowTimes; +payload.roundtrip_windowTimes = roundtrip.windowTimes; +payload.roundtrip_minTime = roundtrip.minTime; +payload.roundtrip_maxTime = roundtrip.maxTime; +payload.single_history_matrix = singleCov.dataToMatrix(); +payload.single_history_labels = singleCov.getAllCovLabels; +payload.single_history_name = singleCov.getCov(1).name; +payload.coll_history_matrix = collCov.dataToMatrix(); +payload.coll_history_labels = collCov.getAllCovLabels; +payload.coll_cov_names = cell(1, collCov.numCov); +for iCov = 1:collCov.numCov + payload.coll_cov_names{iCov} = collCov.getCov(iCov).name; +end +payload.filter_num = cell(size(filterMat)); +payload.filter_den = cell(size(filterMat)); +for idx = 1:numel(filterMat) + payload.filter_num{idx} = filterMat(idx).Numerator{1}; + payload.filter_den{idx} = filterMat(idx).Denominator{1}; +end +payload.filter_delta = filterMat.Ts; +payload.plot_labels = lineLabels; +payload.plot_x = lineX; +payload.plot_y = lineY; + +save(fullfile(fixtureRoot, 'history_exactness.mat'), '-struct', 'payload'); +end + function export_events_fixture(fixtureRoot) events = Events([0.2 0.7], {'E1','E2'}, 'g'); fig = figure('Visible', 'off'); @@ -167,6 +227,50 @@ function export_nspiketrain_fixture(fixtureRoot) burstTrain = nspikeTrain([0.0; 0.001; 0.002; 0.007; 0.507; 0.508; 0.509; 0.514], 'bursting', 0.001, 0.0, 0.6, 'time', 's', 'spikes', 'spk', 0); payload = struct(); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +h = nst.plotISISpectrumFunction(); +payload.isi_spectrum_x = get(h,'XData'); +payload.isi_spectrum_y = get(h,'YData'); +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +nst.plotJointISIHistogram(); +jointLines = flipud(findobj(ax, 'Type', 'line')); +for iLine = 1:numel(jointLines) + payload.joint_isi_x{iLine} = get(jointLines(iLine), 'XData'); + payload.joint_isi_y{iLine} = get(jointLines(iLine), 'YData'); + payload.joint_isi_style{iLine} = get(jointLines(iLine), 'LineStyle'); +end +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +counts = nst.plotISIHistogram(); +histBars = findobj(ax, 'Type', 'patch'); +payload.isi_hist_counts = counts; +if(~isempty(histBars)) + payload.isi_hist_face_color = get(histBars(1), 'FaceColor'); + payload.isi_hist_edge_color = get(histBars(1), 'EdgeColor'); +end +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +nst.plotProbPlot(); +probLines = flipud(findobj(ax, 'Type', 'line')); +for iLine = 1:numel(probLines) + payload.probplot_x{iLine} = get(probLines(iLine), 'XData'); + payload.probplot_y{iLine} = get(probLines(iLine), 'YData'); + payload.probplot_style{iLine} = get(probLines(iLine), 'LineStyle'); +end +close(fig); + +fig = nst.plotExponentialFit(); +payload.expfit_num_axes = numel(findobj(fig, 'Type', 'axes')); +close(fig); payload.spikeTimes = spikeTimes; payload.binwidth = binwidth; payload.minTime = 0.0; @@ -200,6 +304,23 @@ function export_covariate_fixture(fixtureRoot) ci = ConfidenceInterval(t, [mean(replicates,2)-0.1, mean(replicates,2)+0.1], 'CI', 'time', 's', 'a.u.'); covSingle = Covariate(t, mean(replicates,2), 'StimulusSingle', 'time', 's', 'a.u.', {'stim'}); covSingle.setConfInterval(ci); +structure = covSingle.toStructure; +roundtrip = Covariate.fromStructure(structure); +if(iscell(structure.ci)) + ciStructure = structure.ci{1}; +else + ciStructure = structure.ci; +end + +fig = figure('Visible','off'); +plot(covSingle); +drawnow; +lineHandles = findobj(gca,'Type','line'); +plotColors = zeros(length(lineHandles),3); +for i = 1:length(lineHandles) + plotColors(i,:) = get(lineHandles(i),'Color'); +end +close(fig); payload = struct(); payload.time = t; @@ -207,6 +328,17 @@ function export_covariate_fixture(fixtureRoot) payload.mean_data = meanCov.data; payload.mean_ci = meanCov.ci{1}.data; payload.explicit_ci = covSingle.ci{1}.data; +payload.structure_time = structure.time; +payload.structure_data = structure.data; +payload.structure_name = structure.name; +payload.structure_dataLabels = structure.dataLabels; +payload.structure_plotProps = structure.plotProps; +payload.structure_ci_values = ciStructure.signals.values; +payload.structure_ci_name = ciStructure.name; +payload.roundtrip_data = roundtrip.data; +payload.roundtrip_ci = roundtrip.ci{1}.data; +payload.roundtrip_dataLabels = roundtrip.dataLabels; +payload.plot_line_colors = plotColors; save(fullfile(fixtureRoot, 'covariate_exactness.mat'), '-struct', 'payload'); end @@ -217,6 +349,11 @@ function export_nstcoll_fixture(fixtureRoot) coll = nstColl({n1, n2}); dataMat = coll.dataToMatrix([1 2], 0.1, 0.0, 0.5); collapsed = coll.toSpikeTrain; +coll.setNeighbors; +neighbors1 = coll.getNeighbors(1); +neighbors2 = coll.getNeighbors(2); +ensembleCov = coll.getEnsembleNeuronCovariates(1, [], [0.0 0.1]); +psthCov = coll.psth(0.1, [1 2], 0.0, 0.5); payload = struct(); payload.numSpikeTrains = coll.numSpikeTrains; @@ -229,6 +366,20 @@ function export_nstcoll_fixture(fixtureRoot) payload.collapsedMinTime = collapsed.minTime; payload.collapsedMaxTime = collapsed.maxTime; payload.collapsedSampleRate = collapsed.sampleRate; +payload.firstSpikeTime = coll.getFirstSpikeTime; +payload.lastSpikeTime = coll.getLastSpikeTime; +payload.binarySigRep = coll.isSigRepBinary; +payload.nstNameFromInd1 = coll.getNSTnameFromInd(1); +payload.nstFromName1_spikeTimes = coll.getNSTFromName('1').spikeTimes; +[fieldVal, neuronNumbers] = coll.getFieldVal('avgFiringRate'); +payload.fieldVal_avgFiringRate = fieldVal; +payload.fieldVal_neuronNumbers = neuronNumbers; +payload.neighbors1 = neighbors1; +payload.neighbors2 = neighbors2; +payload.ensemble_labels = ensembleCov.getAllCovLabels; +payload.ensemble_matrix = ensembleCov.dataToMatrix(); +payload.psth_time = psthCov.time; +payload.psth_data = psthCov.data; save(fullfile(fixtureRoot, 'nstcoll_exactness.mat'), '-struct', 'payload'); end @@ -302,6 +453,53 @@ function export_config_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'config_exactness.mat'), '-struct', 'payload'); end +function export_covcoll_fixture(fixtureRoot) +t = (0:0.5:1.0)'; +position = Covariate(t, [0 10; 1 11; 2 12], 'Position', 'time', 's', '', {'x','y'}); +stimulus = Covariate(t, [5; 6; 7], 'Stimulus', 'time', 's', 'a.u.', {'stim'}); +coll = CovColl({position, stimulus}); +coll.setMask({{'Position','x'},{'Stimulus'}}); +maskedLabels = coll.getCovLabelsFromMask; +maskedMatrix = coll.dataToMatrix(); +maskedTime = coll.getCov(1).time; +dataStructure = coll.dataToStructure; +structure = coll.toStructure; +postMask1 = coll.covMask{1}; +postMask2 = coll.covMask{2}; +roundtrip = CovColl.fromStructure(structure); +copyColl = coll.copy; + +shifted = CovColl({position, stimulus}); +shifted.setCovShift(0.25); +shifted.restrictToTimeWindow(0.25, 1.25); +shiftedStim = shifted.getCov(2); + +payload = struct(); +payload.masked_labels = maskedLabels; +payload.masked_matrix = maskedMatrix; +payload.masked_time = maskedTime; +payload.data_structure_time = dataStructure.time; +payload.data_structure_values = dataStructure.signals.values; +payload.post_structure_mask_1 = postMask1; +payload.post_structure_mask_2 = postMask2; +payload.structure_numCov = structure.numCov; +payload.structure_minTime = structure.minTime; +payload.structure_maxTime = structure.maxTime; +payload.roundtrip_minTime = roundtrip.minTime; +payload.roundtrip_maxTime = roundtrip.maxTime; +payload.roundtrip_sampleRate = roundtrip.sampleRate; +payload.roundtrip_labels = roundtrip.getCovLabelsFromMask; +payload.roundtrip_matrix = roundtrip.dataToMatrix(); +payload.shifted_minTime = shifted.minTime; +payload.shifted_maxTime = shifted.maxTime; +payload.shifted_stim_time = shiftedStim.time; +payload.is_present_position = coll.isCovPresent('Position'); +payload.is_present_last_index = coll.isCovPresent(2); +payload.copy_numCov = copyColl.numCov; + +save(fullfile(fixtureRoot, 'covcoll_exactness.mat'), '-struct', 'payload'); +end + function export_cif_fixture(fixtureRoot) cif = CIF([0.1 0.5], {'stim1', 'stim2'}, {'stim1', 'stim2'}, 'binomial'); stimVal = [0.6; -0.2]; @@ -368,6 +566,42 @@ function export_analysis_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'analysis_exactness.mat'), '-struct', 'payload'); end +function export_analysis_multineuron_fixture(fixtureRoot) +t = (0:0.1:1.0)'; +stimData = sin(2*pi*t); +stim = Covariate(t, stimData, 'Stimulus', 'time', 's', '', {'stim'}); +spikeTrain1 = nspikeTrain([0.1 0.4 0.7], '1', 0.1, 0.0, 1.0, 'time', 's', '', '', -1); +spikeTrain2 = nspikeTrain([0.2 0.6 0.9], '2', 0.1, 0.0, 1.0, 'time', 's', '', '', -1); +trial = Trial(nstColl({spikeTrain1, spikeTrain2}), CovColl({stim})); +cfg = TrialConfig({{'Stimulus', 'stim'}}, 10, [], []); +cfg.setName('stim'); +fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl({cfg}), 0); +summary = FitResSummary(fits); + +payload = struct(); +payload.time = t; +payload.stim_data = stimData; +payload.spike_times_1 = spikeTrain1.spikeTimes; +payload.spike_times_2 = spikeTrain2.spikeTimes; +payload.num_fits = numel(fits); +payload.fit1_coeffs = fits{1}.getCoeffs(1); +payload.fit2_coeffs = fits{2}.getCoeffs(1); +payload.fit1_AIC = fits{1}.AIC(1); +payload.fit2_AIC = fits{2}.AIC(1); +payload.fit1_BIC = fits{1}.BIC(1); +payload.fit2_BIC = fits{2}.BIC(1); +payload.fit1_logLL = fits{1}.logLL(1); +payload.fit2_logLL = fits{2}.logLL(1); +payload.summary_AIC = summary.AIC; +payload.summary_BIC = summary.BIC; +payload.summary_logLL = summary.logLL; +payload.summary_KSStats = summary.KSStats; +payload.summary_KSPvalues = summary.KSPvalues; +payload.summary_withinConfInt = summary.withinConfInt; + +save(fullfile(fixtureRoot, 'analysis_multineuron_exactness.mat'), '-struct', 'payload'); +end + function export_ksdiscrete_fixture(fixtureRoot) t = (0:0.1:1.0)'; stimData = sin(2*pi*t); From a969bfee27587d91a996b3938ee97d67df9ff2d9 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 09:26:28 -0400 Subject: [PATCH 2/2] Fix exactness branch regressions --- nstat/history.py | 4 +++- parity/notebook_fidelity.yml | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/nstat/history.py b/nstat/history.py index 00d749fb..d7c5e7ba 100644 --- a/nstat/history.py +++ b/nstat/history.py @@ -223,8 +223,10 @@ def plot(self, *_, handle=None, **__): dataLabels.append(f"[{start:.3g},{stop:.3g}]") time = np.linspace(float(np.min(tmin)), float(np.max(tmax)), num_samples) signal = SignalObj(time, data, "History", "time", "s", "", dataLabels) + created_ax = handle is None ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 2.2))[1] - return signal.plot(handle=ax) + plot_handles = signal.plot(handle=ax) + return ax if created_ax else plot_handles HistoryBasis = History diff --git a/parity/notebook_fidelity.yml b/parity/notebook_fidelity.yml index 6678e39d..0a2a5ed9 100644 --- a/parity/notebook_fidelity.yml +++ b/parity/notebook_fidelity.yml @@ -1,5 +1,5 @@ version: 1 -generated_on: '2026-03-08' +generated_on: '2026-03-09' source_repositories: matlab: https://github.com/cajigaslab/nSTAT python: https://github.com/cajigaslab/nSTAT-python