diff --git a/nstat/analysis.py b/nstat/analysis.py index 55484313..f5ff546e 100644 --- a/nstat/analysis.py +++ b/nstat/analysis.py @@ -327,7 +327,7 @@ def run_analysis_for_neuron( merged_lambda = merged_lambda.merge(part) _restore_trial_partition(trial, original_partition) - return FitResult( + fit_result = FitResult( spike_train, labels, numHist, @@ -346,6 +346,10 @@ def run_analysis_for_neuron( distributions, fits=fits, ) + # MATLAB returns fits with KS diagnostics already populated, and + # downstream summary classes read those cached fields directly. + fit_result.computeKSStats() + return fit_result @staticmethod def run_analysis_for_all_neurons( diff --git a/nstat/core.py b/nstat/core.py index 443927e4..1c2381a7 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -266,6 +266,9 @@ def setPlotProps(self, plotProps: Sequence[Any] | str | None, index: int | None self.plotProps = [plotProps for _ in range(self.dimension)] else: props = list(plotProps) + if len(props) == 0: + self.plotProps = [None for _ in range(self.dimension)] + return if len(props) == 1 and self.dimension > 1: props = props * self.dimension if len(props) != self.dimension: @@ -945,6 +948,9 @@ def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = None) -> dict[str, Any]: data = self.dataToMatrix(selectorArray) + plot_props = list(self.plotProps) + if all(prop is None for prop in plot_props): + plot_props = [] return { "time": self.time.tolist(), "data": data.tolist(), @@ -953,7 +959,7 @@ def dataToStructure(self, selectorArray: Sequence[int] | np.ndarray | None = Non "xunits": self.xunits, "yunits": self.yunits, "dataLabels": list(self.dataLabels), - "plotProps": list(self.plotProps), + "plotProps": plot_props, } def toStructure(self) -> dict[str, Any]: @@ -974,6 +980,7 @@ def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": def plot(self, selectorArray=None, plotPropsIn=None, handle=None): import matplotlib.pyplot as plt + from .confidence_interval import MATLAB_COLOR_ORDER ax = plt.gca() if handle is None else handle signal = self.getSubSignal(selectorArray) if selectorArray is not None else self.getSubSignal(self.findIndFromDataMask() or list(range(1, self.dimension + 1))) @@ -989,6 +996,8 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): prop = props[index] if isinstance(prop, str) and prop: kwargs["fmt"] = prop + elif prop is None: + kwargs["color"] = MATLAB_COLOR_ORDER[index % MATLAB_COLOR_ORDER.shape[0]] if "fmt" in kwargs: fmt = kwargs.pop("fmt") line = ax.plot(signal.time, signal.data[:, index], fmt, **kwargs) @@ -1075,6 +1084,7 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): lines = super().plot(selectorArray, plotPropsIn, handle) if self.isConfIntervalSet(): import matplotlib.pyplot as plt + import matplotlib.colors as mcolors ax = plt.gca() if handle is None else handle selectors = self.findIndFromDataMask() if selectorArray is None else ( @@ -1086,6 +1096,8 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): selectors = [item[0] for item in selectors] for line_index, selector in enumerate(selectors): color = getattr(lines[line_index], "get_color", lambda: "b")() + if isinstance(color, (str, bytes)): + color = mcolors.to_rgb(color) self.ci[selector - 1].plot(color, ax=ax) return lines @@ -1176,17 +1188,37 @@ def toStructure(self) -> dict[str, Any]: if self.isConfIntervalSet(): ci_payload: list[dict[str, Any]] = [] for item in self.ci or []: - if hasattr(item, "time") and hasattr(item, "bounds"): - ci_payload.append( - { - "time": np.asarray(item.time, dtype=float).tolist(), - "bounds": np.asarray(item.bounds, dtype=float).tolist(), - "color": getattr(item, "color", "b"), - } - ) - structure["ci"] = ci_payload + if hasattr(item, "dataToStructure"): + ci_payload.append(item.dataToStructure()) + if ci_payload: + structure["ci"] = ci_payload[0] if len(ci_payload) == 1 else ci_payload return structure + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "Covariate": + from .confidence_interval import ConfidenceInterval + + cov = Covariate( + structure["time"], + structure["data"], + structure.get("name", ""), + structure.get("xlabelval", "time"), + structure.get("xunits", "s"), + structure.get("yunits", ""), + structure.get("dataLabels"), + structure.get("plotProps"), + ) + ci_payload = structure.get("ci") + if ci_payload is None: + return cov + if isinstance(ci_payload, list): + cov.setConfInterval([ConfidenceInterval.fromStructure(item) for item in ci_payload]) + elif isinstance(ci_payload, tuple): + cov.setConfInterval([ConfidenceInterval.fromStructure(item) for item in ci_payload]) + else: + cov.setConfInterval(ConfidenceInterval.fromStructure(ci_payload)) + return cov + class nspikeTrain: """Closer MATLAB-style spike-train object with cached signal representation.""" @@ -1405,11 +1437,15 @@ def _build_sigrep(self, binwidth: float, minTime: float, maxTime: float) -> Sign return sig def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, maxTime: float | None = None) -> SignalObj: - self.sigRep = self.getSigRep(binwidth, minTime, maxTime) - self.isSigRepBin = self.isSigRepBinary() - self.sampleRate = float(self.sigRep.sampleRate) - self.minTime = float(self.sigRep.minTime) - self.maxTime = float(self.sigRep.maxTime) + sig = self.getSigRep(binwidth, minTime, maxTime) + self.sigRep = sig.copySignal() + self.sampleRate = float(sig.sampleRate) + self.isSigRepBin = bool(np.max(np.asarray(sig.data, dtype=float)) <= 1.0) + # Keep the freshly-built cached representation alive instead of + # clearing it through the public min/max setters. + self.minTime = float(sig.minTime) + self.maxTime = float(sig.maxTime) + self.computeStatistics(-1) return self.sigRep def clearSigRep(self) -> None: @@ -1420,14 +1456,12 @@ def clearSigRep(self) -> None: def setMinTime(self, minTime: float) -> None: self.minTime = float(minTime) self.clearSigRep() - if self.avgFiringRate is not None: - self.computeStatistics(-1) + self.computeStatistics(-1) def setMaxTime(self, maxTime: float) -> None: self.maxTime = float(maxTime) self.clearSigRep() - if self.avgFiringRate is not None: - self.computeStatistics(-1) + self.computeStatistics(-1) def resample(self, sampleRate: float) -> "nspikeTrain": self.setSigRep(1.0 / float(sampleRate), self.minTime, self.maxTime) @@ -1476,8 +1510,9 @@ def getMaxBinSizeBinary(self) -> float: return float(np.min(isi)) def isSigRepBinary(self) -> bool: - if self.isSigRepBin is None: - self.getSigRep() + default_key = self._cache_key(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime)) + if self._sigrep_cache_key != default_key or self.isSigRepBin is None: + self.getSigRep(1.0 / float(self.sampleRate), float(self.minTime), float(self.maxTime)) return bool(self.isSigRepBin) def computeRate(self) -> SignalObj: @@ -1553,14 +1588,19 @@ def plotJointISIHistogram(self): ax = plt.subplots(1, 1, figsize=(4.5, 4.0))[1] isi = self.getISIs() if isi.size >= 2: - ax.loglog(isi[:-1], isi[1:], ".") + xvals = np.asarray(isi[:-1], dtype=float).reshape(-1) + yvals = np.asarray(isi[1:], dtype=float).reshape(-1) + ax.loglog(xvals, yvals, ".") mean_isi = float(np.mean(isi)) ln = isi[isi < mean_isi] ml = float(np.mean(ln)) if ln.size else np.nan if np.isfinite(ml) and ml > 0: - v = ax.axis() - ax.loglog([ml, ml], [v[2], v[3]], "k--") - ax.loglog([v[0], v[1]], [ml, ml], "k--") + ymin = float(np.min(yvals)) + ymax = float(np.max(yvals)) + xmin = float(np.min(xvals)) + xmax = float(np.max(xvals)) + ax.loglog([ml, ml], [ymin, ymax], "k--") + ax.loglog([xmin, xmax], [ml, ml], "k--") ax.set_xlabel("ISI(t) [s]") ax.set_ylabel("ISI(t+1) [s]") return ax @@ -1582,14 +1622,21 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = bins = np.arange(0.0, float(np.max(isi)) + bin_width, bin_width, dtype=float) if bins.size < 2: bins = np.array([0.0, bin_width], dtype=float) - counts, edges = np.histogram(isi, bins=bins) - centers = edges[:-1] + idx = np.searchsorted(bins, isi, side="right") - 1 + idx = np.where( + np.isclose(isi, bins[-1], rtol=0.0, atol=max(1e-12, bin_width * 1e-9)), + bins.size - 1, + idx, + ) + idx = np.clip(idx, 0, bins.size - 1) + counts = np.bincount(idx, minlength=bins.size).astype(float) + centers = bins ax.bar( centers, counts, width=bin_width, align="edge", - edgecolor=(0.0, 0.0, 0.0), + edgecolor="none", linewidth=2.0, color=(0.831372559070587, 0.815686285495758, 0.7843137383461), ) @@ -1600,7 +1647,6 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = None, handle=None): import matplotlib.pyplot as plt - from scipy import stats ax = plt.gca() if handle is None else handle if maxTime is None: @@ -1610,8 +1656,11 @@ def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = Non isi = self.getISIs(minTime, maxTime) ax.clear() if isi.size: - stats.probplot(isi, dist=stats.expon, plot=ax) - ax.set_title(ax.get_title() or "Probability Plot") + sorted_isi = np.sort(np.asarray(isi, dtype=float).reshape(-1)) + n = sorted_isi.size + p = (np.arange(1, n + 1, dtype=float) - 0.5) / float(n) + exp_quantiles = -np.log(1.0 - p) + ax.plot(sorted_isi, exp_quantiles, linestyle="none", marker=".") return ax def plotExponentialFit(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None): diff --git a/nstat/history.py b/nstat/history.py index ea33ff9f..d7c5e7ba 100644 --- a/nstat/history.py +++ b/nstat/history.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from collections.abc import Sequence from typing import Any @@ -9,7 +10,63 @@ import matplotlib.pyplot as plt import numpy as np -from .core import Covariate, nspikeTrain +from .core import Covariate, SignalObj, nspikeTrain + + +@dataclass(frozen=True) +class HistoryFilter: + """Discrete-time MATLAB-style transfer function numerator/denominator pair.""" + + numerator: np.ndarray + denominator: np.ndarray + delta: float + variable: str = "z^-1" + + +@dataclass(frozen=True) +class HistoryFilterBank: + """Matrix-like collection of discrete history-window filters.""" + + numerators: tuple[np.ndarray, ...] + denominators: tuple[np.ndarray, ...] + delta: float + variable: str = "z^-1" + + @property + def shape(self) -> tuple[int, int]: + return (len(self.numerators), 1) + + @property + def numFilters(self) -> int: + return len(self.numerators) + + def __len__(self) -> int: + return self.numFilters + + def __getitem__(self, index: int) -> HistoryFilter: + return HistoryFilter( + numerator=np.asarray(self.numerators[index], dtype=float).copy(), + denominator=np.asarray(self.denominators[index], dtype=float).copy(), + delta=float(self.delta), + variable=self.variable, + ) + + def combine(self, coefficients) -> HistoryFilter: + coeffs = np.asarray(coefficients, dtype=float).reshape(-1) + if coeffs.size != self.numFilters: + raise ValueError("Number of coefficients must match the number of history filters.") + max_len = max(len(numerator) for numerator in self.numerators) + padded = np.zeros((self.numFilters, max_len), dtype=float) + for idx, numerator in enumerate(self.numerators): + arr = np.asarray(numerator, dtype=float).reshape(-1) + padded[idx, : arr.size] = arr + numerator = coeffs @ padded + denominator = np.zeros(max_len, dtype=float) + denominator[0] = 1.0 + return HistoryFilter(numerator=np.asarray(numerator, dtype=float), denominator=denominator, delta=float(self.delta), variable=self.variable) + + def __rmatmul__(self, coefficients) -> HistoryFilter: + return self.combine(coefficients) class History: @@ -23,8 +80,8 @@ def __init__(self, windowTimes, minTime: float | None = None, maxTime: float | N raise ValueError("windowTimes must be strictly increasing") self.windowTimes = times - self.minTime = float(times[0] if minTime is None else minTime) - self.maxTime = float(times[-1] if maxTime is None else maxTime) + self.minTime = None if minTime is None else float(minTime) + self.maxTime = None if maxTime is None else float(maxTime) self.name = str(name) @property @@ -41,40 +98,84 @@ def setWindow(self, windowTimes) -> None: self.minTime = replacement.minTime self.maxTime = replacement.maxTime + def toFilter(self, delta: float) -> HistoryFilterBank: + delta = float(delta) + if delta <= 0: + raise ValueError("delta must be positive") + tmin = np.asarray(self.windowTimes[:-1], dtype=float) + tmax = np.asarray(self.windowTimes[1:], dtype=float) + numerators: list[np.ndarray] = [] + denominators: list[np.ndarray] = [] + for row, (window_start, window_stop) in enumerate(zip(tmin, tmax)): + num_samples = int(np.ceil(float(window_stop) / delta)) + start_sample = int(np.ceil(float(window_start) / delta)) + 1 + del row + numerator = np.zeros(num_samples + 1, dtype=float) + denominator = np.zeros(num_samples + 1, dtype=float) + denominator[0] = 1.0 + numerator[start_sample : num_samples + 1] = 1.0 + numerators.append(numerator) + denominators.append(denominator) + return HistoryFilterBank(numerators=tuple(numerators), denominators=tuple(denominators), delta=delta) + def _compute_single_history(self, train: nspikeTrain, historyIndex: int | None = None, time_grid=None) -> Covariate: - if time_grid is None: - sigrep = train.getSigRep() - time = np.asarray(sigrep.time, dtype=float).reshape(-1) - else: - time = np.asarray(time_grid, dtype=float).reshape(-1) - spikes = np.asarray(train.getSpikeTimes(), dtype=float).reshape(-1) - history = np.zeros((time.size, self.numWindows), dtype=float) - - for col, (window_start, window_stop) in enumerate(zip(self.windowTimes[:-1], self.windowTimes[1:])): - left = time - float(window_stop) - right = time - float(window_start) - history[:, col] = np.searchsorted(spikes, right, side="left") - np.searchsorted(spikes, left, side="left") - - label_prefix = train.name or f"neuron_{historyIndex or 1}" - labels = [ - f"{label_prefix}_hist_{col + 1}" - for col in range(self.numWindows) - ] - return Covariate(time, history, self.name, "time", "s", "count", labels) + sigrep = train.getSigRep() if time_grid is None else train.getSigRep(None, float(np.min(time_grid)), float(np.max(time_grid))) + tmin = np.asarray(self.windowTimes[:-1], dtype=float) + tmax = np.asarray(self.windowTimes[1:], dtype=float) + data_columns: list[np.ndarray] = [] + data_labels: list[str] = [] + + for window_start, window_stop in zip(tmin, tmax, strict=False): + num_samples = int(np.ceil(float(window_stop) * float(train.sampleRate))) + numerator = np.zeros(max(num_samples, 0), dtype=float) + start_sample = int(np.ceil(float(window_start) * float(train.sampleRate))) + 1 + if num_samples > 0 and start_sample <= num_samples: + numerator[max(start_sample - 1, 0) : num_samples] = 1.0 + filtered = sigrep.filter(numerator if numerator.size else [0.0], [1.0]) + delayed = filtered.filter([0.0, 1.0], [1.0]) + data_columns.append(np.asarray(delayed.dataToMatrix(), dtype=float)) + if historyIndex is None: + data_labels.append(f"[{window_start:.3g},{window_stop:.3g}]") + else: + data_labels.append(f"[{window_start:.3g},{window_stop:.3g}]_{historyIndex}") + + data = np.hstack(data_columns) if data_columns else np.zeros((sigrep.time.size, 0), dtype=float) + name = "History" if not getattr(train, "name", "") else f"History {train.name}" + cov = Covariate(sigrep.time, data, name, sigrep.xlabelval, sigrep.xunits, sigrep.yunits, data_labels) + + if time_grid is not None: + return cov + + if data.size == 0: + return Covariate([], data, name, sigrep.xlabelval, sigrep.xunits, sigrep.yunits, data_labels) + + if (self.minTime is not None or self.maxTime is not None) and round(float(cov.sampleRate), 9) != round(float(train.sampleRate), 9): + cov.resampleMe(float(train.sampleRate)) + min_time = float(cov.minTime) if self.minTime is None else float(self.minTime) + max_time = float(cov.maxTime) if self.maxTime is None else float(self.maxTime) + windowed = cov.getSigInTimeWindow(min_time, max_time) + windowed.setMinTime(float(train.minTime)) + windowed.setMaxTime(float(train.maxTime)) + windowed.minTime = float(train.minTime) + windowed.maxTime = float(train.maxTime) + return windowed def compute_history(self, trains, historyIndex: int | None = None, time_grid=None): from .trial import CovariateCollection if isinstance(trains, nspikeTrain): - return CovariateCollection([self._compute_single_history(trains, historyIndex, time_grid=time_grid)]) + cov = self._compute_single_history(trains, historyIndex, time_grid=time_grid) + if historyIndex is not None: + cov.name = f"History #{historyIndex} for {trains.name}" + return CovariateCollection([cov]) if hasattr(trains, "getNST") and hasattr(trains, "numSpikeTrains"): covariates = [ - self._compute_single_history(trains.getNST(index), index, time_grid=time_grid) + self._compute_single_history(trains.getNST(index), historyIndex, time_grid=time_grid) for index in range(1, int(trains.numSpikeTrains) + 1) ] return CovariateCollection(covariates) if isinstance(trains, Sequence) and not isinstance(trains, (str, bytes, np.ndarray)): - covariates = [self._compute_single_history(train, index, time_grid=time_grid) for index, train in enumerate(trains, start=1)] + covariates = [self._compute_single_history(train, historyIndex, time_grid=time_grid) for train in trains] return CovariateCollection(covariates) raise TypeError("History can only be computed from nspikeTrain, nstColl, or sequences of nspikeTrain") @@ -108,19 +209,27 @@ def fromStructure(structure: dict[str, Any] | None) -> "History" | None: ) def plot(self, *_, handle=None, **__): + tmin = np.asarray(self.windowTimes[:-1], dtype=float) + tmax = np.asarray(self.windowTimes[1:], dtype=float) + sampleRate = 1000.0 + num_samples = max(1, int(round((float(np.max(tmax)) - float(np.min(tmin))) * sampleRate))) + data = np.zeros((num_samples, tmax.size), dtype=float) + dataLabels: list[str] = [] + for index, (start, stop) in enumerate(zip(tmin, tmax)): + indMin = max(1, int(round((float(start) - float(np.min(tmin))) * sampleRate))) + indMax = int(round((float(stop) - float(np.min(tmin))) * sampleRate)) + if indMax >= indMin: + data[indMin - 1 : indMax, index] = 1.0 + dataLabels.append(f"[{start:.3g},{stop:.3g}]") + time = np.linspace(float(np.min(tmin)), float(np.max(tmax)), num_samples) + signal = SignalObj(time, data, "History", "time", "s", "", dataLabels) + created_ax = handle is None ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 2.2))[1] - ax.clear() - for idx, (start, stop) in enumerate(zip(self.windowTimes[:-1], self.windowTimes[1:]), start=1): - ax.broken_barh([(float(start), float(stop - start))], (idx - 0.4, 0.8), facecolors="tab:blue", alpha=0.6) - ax.set_xlabel("time [s]") - ax.set_ylabel("history bin") - ax.set_yticks(range(1, self.numWindows + 1)) - ax.set_title(self.name) - ax.set_xlim(float(self.windowTimes[0]), float(self.windowTimes[-1])) - return ax + plot_handles = signal.plot(handle=ax) + return ax if created_ax else plot_handles HistoryBasis = History -__all__ = ["History", "HistoryBasis"] +__all__ = ["History", "HistoryBasis", "HistoryFilter", "HistoryFilterBank"] diff --git a/nstat/trial.py b/nstat/trial.py index 94e5af70..67c7f6c8 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -8,8 +8,9 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np +from scipy.signal import filtfilt -from .core import Covariate, nspikeTrain +from .core import Covariate, SignalObj, nspikeTrain from .events import Events @@ -49,6 +50,36 @@ def _copy_covariate(cov: Covariate) -> Covariate: return copied +def _copy_covariate_for_collection_view(cov: Covariate) -> Covariate: + """Mirror MATLAB CovColl.getCov copy semantics. + + MATLAB reconstructs a fresh Covariate from the visible time/data payload + rather than preserving internal sample-rate/original-data bookkeeping. + That matters for degenerate one-sample covariates, where the constructor + falls back to the 1 kHz default used throughout the toolbox. + """ + + copied = Covariate( + np.asarray(cov.time, dtype=float).copy(), + np.asarray(cov.data, dtype=float).copy(), + cov.name, + cov.xlabelval, + cov.xunits, + cov.yunits, + list(cov.dataLabels), + list(cov.plotProps), + ) + copied.dataMask = np.asarray(cov.dataMask, dtype=int).copy() + if cov.ci: + copied.ci = list(cov.ci) + if cov.conf_interval is not None: + copied.conf_interval = ( + np.asarray(cov.conf_interval[0], dtype=float).copy(), + np.asarray(cov.conf_interval[1], dtype=float).copy(), + ) + return copied + + class CovariateCollection: """MATLAB-style CovColl implementation with collection-level masks and timing.""" @@ -123,12 +154,12 @@ def _covariate_from_identifier(self, identifier: int | str) -> int: return index def _apply_collection_state(self, cov: Covariate, index: int) -> Covariate: - out = _copy_covariate(cov) + out = _copy_covariate_for_collection_view(cov) if self.covShift != 0: out.time = out.time + float(self.covShift) out.minTime = float(np.min(out.time)) out.maxTime = float(np.max(out.time)) - if np.isfinite(self.sampleRate) and self.sampleRate > 0 and round(out.sampleRate, 3) != round(self.sampleRate, 3): + if out.time.size > 1 and np.isfinite(self.sampleRate) and self.sampleRate > 0 and round(out.sampleRate, 3) != round(self.sampleRate, 3): out = out.resample(self.sampleRate) if np.isfinite(self.minTime) and np.isfinite(self.maxTime) and out.time.size > 0: out = out.getSigInTimeWindow(self.minTime, self.maxTime, holdVals=1) @@ -141,6 +172,9 @@ def add(self, covariate: Covariate) -> None: def addCovariate(self, covariate: Covariate) -> None: self.addToColl(covariate) + def addCovCollection(self, covariates: "CovariateCollection") -> None: + self.addToColl(covariates) + def addToColl(self, covariates: Sequence[Covariate] | Covariate | "CovariateCollection" | None) -> None: if covariates is None: return @@ -165,6 +199,10 @@ def removeCovariate(self, identifier: int | str) -> None: del self.covMask[index - 1] self._refresh_summary() + def copy(self) -> "CovariateCollection": + cov = [self.getCov(i).copySignal() for i in range(1, self.numCov + 1)] + return CovariateCollection(cov) + def get(self, name: str) -> Covariate: return self.getCov(name) @@ -191,6 +229,26 @@ def getCovIndicesFromNames(self, name: Sequence[str] | str): return self.getCovIndFromName(name) return [self.getCovIndFromName(item) for item in name] + def isCovPresent(self, cov) -> int: + if isinstance(cov, Covariate): + if not cov.name: + return 0 + try: + self.getCovIndFromName(cov.name) + except KeyError: + return 0 + return 1 + if isinstance(cov, str): + try: + self.getCovIndFromName(cov) + except KeyError: + return 0 + return 1 + if isinstance(cov, (int, np.integer, float, np.floating)): + index = int(cov) + return int(index > 0 and index < self.numCov) + raise TypeError("Need either covariate class or name of covariate or index of covariate") + def findMinTime(self) -> float: if self.numCov == 0: return float("inf") @@ -457,6 +515,68 @@ def dataToMatrix(self, repType: str | Sequence[str] | None = "standard", dataSel _, matrix, _ = self.matrixWithTime(str(repType), dataSelector) return matrix + def dataToStructure( + self, + selectorCell=None, + binwidth: float | None = None, + minTime: float | None = None, + maxTime: float | None = None, + ) -> dict[str, Any]: + del binwidth, minTime, maxTime + if selectorCell is None: + if self.isCovMaskSet(): + selectorCell = self.getSelectorFromMasks() + else: + selectorCell = [list(range(1, self.getCov(i).dimension + 1)) for i in range(1, self.numCov + 1)] + dataMatrix = self.dataToMatrix("standard", selectorCell) + return { + "time": self.getCov(1).time.copy() if self.numCov else np.array([], dtype=float), + "signals": {"values": dataMatrix}, + } + + def toStructure(self) -> dict[str, Any]: + self.resetMask() + structure: dict[str, Any] = { + "numCov": int(self.numCov), + "minTime": float(self.minTime), + "maxTime": float(self.maxTime), + "covDimensions": [int(value) for value in self.covDimensions], + "covMask": [np.asarray(mask, dtype=int).reshape(-1).tolist() for mask in self.covMask], + "covShift": float(self.covShift), + "sampleRate": float(self.sampleRate) if np.isfinite(self.sampleRate) else self.sampleRate, + "originalSampleRate": self.originalSampleRate, + "originalMinTime": self.originalMinTime, + "originalMaxTime": self.originalMaxTime, + "covArray": [cov.toStructure() for cov in self.covArray], + } + return structure + + @staticmethod + def fromStructure(structure) -> "CovariateCollection" | list["CovariateCollection"]: + if isinstance(structure, list): + return [CovariateCollection.fromStructure(item) for item in structure] + if not isinstance(structure, dict): + raise TypeError("CovColl.fromStructure expects a dictionary or list of dictionaries.") + cov = [ + Covariate( + row["time"], + row["data"], + row.get("name", ""), + row.get("xlabelval", "time"), + row.get("xunits", "s"), + row.get("yunits", ""), + row.get("dataLabels"), + row.get("plotProps"), + ) + for row in structure.get("covArray", []) + ] + ccObj = CovariateCollection(cov) + if "minTime" in structure: + ccObj.setMinTime(float(structure["minTime"])) + if "maxTime" in structure: + ccObj.setMaxTime(float(structure["maxTime"])) + return ccObj + class SpikeTrainCollection: """MATLAB-style nstColl implementation.""" @@ -484,6 +604,9 @@ def __iter__(self): for tr in self.nstrain: yield tr + def __len__(self) -> int: + return int(self.numSpikeTrains) + def _refresh_summary(self) -> None: self.numSpikeTrains = len(self.nstrain) if self.numSpikeTrains == 0: @@ -500,8 +623,22 @@ def _refresh_summary(self) -> None: self.neuronMask = np.ones(self.numSpikeTrains, dtype=int) def addSingleSpikeToColl(self, nst: nspikeTrain) -> None: - self.nstrain.append(nst.nstCopy()) - self._refresh_summary() + train = nst.nstCopy() + if not getattr(train, "name", ""): + train.setName(str(self.numSpikeTrains + 1)) + if self.numSpikeTrains == 0: + self.minTime = float(train.minTime) + self.maxTime = float(train.maxTime) + self.sampleRate = float(train.sampleRate) + else: + self.updateTimes(train) + self.sampleRate = float(max(float(self.sampleRate), float(train.sampleRate))) + self.enforceSampleRate() + self.nstrain.append(train) + self.numSpikeTrains = len(self.nstrain) + self.neuronMask = np.append(self.neuronMask, 1).astype(int) + if self.numSpikeTrains == 1: + self.neighbors = [] def addToColl(self, nst: Sequence[nspikeTrain] | nspikeTrain | "SpikeTrainCollection") -> None: if isinstance(nst, SpikeTrainCollection): @@ -523,6 +660,15 @@ def merge(self, nstColl2: "SpikeTrainCollection") -> "SpikeTrainCollection": self.addToColl(nstColl2) return self + def length(self) -> int: + return int(self.numSpikeTrains) + + def getFirstSpikeTime(self) -> float: + return float(self.minTime) + + def getLastSpikeTime(self) -> float: + return float(self.maxTime) + def get_nst(self, idx: int) -> nspikeTrain: if idx < 0 or idx >= self.numSpikeTrains: raise IndexError("SpikeTrainCollection index out of bounds (0-based indexing).") @@ -551,6 +697,43 @@ def getNSTIndicesFromName(self, name: Sequence[str] | str): return matches if len(matches) > 1 else matches[0] return [self.getNSTIndicesFromName(item) for item in name] + def getNSTnameFromInd(self, ind: int) -> str: + index = int(ind) + if index < 1 or index > self.numSpikeTrains: + raise IndexError("Index is out of bounds!") + return str(self.nstrain[index - 1].name) + + def getNSTFromName(self, neuronName=None): + if neuronName is None: + neuronName = self.getUniqueNSTnames() + indices = self.getNSTIndicesFromName(neuronName) + return self.getNST(indices) + + def getFieldVal(self, fieldName: str): + fieldVal: list[float] = [] + neuronNumbers: list[int] = [] + cnt = 1 + for index in range(1, self.numSpikeTrains + 1): + currVal = self.getNST(index).getFieldVal(fieldName) + if currVal is None: + continue + if isinstance(currVal, np.ndarray) and currVal.size == 0: + continue + if len(fieldVal) < cnt: + fieldVal.extend([0.0] * (cnt - len(fieldVal))) + fieldVal[cnt - 1] = float(currVal) + cnt += 1 + if len(neuronNumbers) < cnt: + neuronNumbers.extend([0] * (cnt - len(neuronNumbers))) + neuronNumbers[cnt - 1] = index + return np.asarray(fieldVal, dtype=float), np.asarray(neuronNumbers, dtype=int) + + def shiftTime(self, timeShift: float | None = None) -> "SpikeTrainCollection": + if timeShift is None: + timeShift = -float(self.minTime) + shifted = [nspikeTrain(np.asarray(train.spikeTimes, dtype=float) + float(timeShift)) for train in self.nstrain] + return SpikeTrainCollection(shifted) + def toSpikeTrain( self, selectorArray: Sequence[int] | Sequence[str] | str | None = None, @@ -631,6 +814,12 @@ def resample(self, sampleRate: float) -> None: for train in self.nstrain: train.resample(sampleRate) + def enforceSampleRate(self) -> None: + for index in range(1, self.numSpikeTrains + 1): + currSpike = self.getNST(index) + if round(float(currSpike.sampleRate), 9) != round(float(self.sampleRate), 9): + currSpike.resample(float(self.sampleRate)) + def findMaxSampleRate(self) -> float: if self.numSpikeTrains == 0: return float("-inf") @@ -713,6 +902,12 @@ def getMaxBinSizeBinary(self) -> float: values = [self.getNST(index).getMaxBinSizeBinary() for index in selectorArray] return float(np.min(values)) + def BinarySigRep(self) -> bool: + return bool(all(self.getNST(index).isSigRepBinary() for index in range(1, self.numSpikeTrains + 1))) + + def isSigRepBinary(self) -> bool: + return self.BinarySigRep() + def dataToMatrix( self, selectorArray: Sequence[int] | Sequence[str] | str | None = None, @@ -749,7 +944,9 @@ def dataToMatrix( return dataMat def getEnsembleNeuronCovariates(self, neuronNum: int = 1, neighborIndex=None, windowTimes=None): - if neighborIndex is None: + if neighborIndex is None or ( + isinstance(neighborIndex, (list, tuple, np.ndarray)) and np.asarray(neighborIndex).size == 0 + ): allNeighbors = self.getNeighbors(neuronNum) else: allNeighbors = [int(item) for item in np.asarray(neighborIndex, dtype=int).reshape(-1)] @@ -781,6 +978,21 @@ def restoreToOriginal(self, rMask: int = 0) -> None: if rMask == 1: self.resetMask() + def ensureConsistancy(self) -> None: + self.enforceSampleRate() + self.setMinTime() + self.setMaxTime() + + def updateTimes(self, nst: nspikeTrain) -> None: + if float(nst.minTime) <= float(self.minTime): + self.setMinTime(float(nst.minTime)) + else: + nst.setMinTime(float(self.minTime)) + if float(nst.maxTime) >= float(self.maxTime): + self.setMaxTime(float(nst.maxTime)) + else: + nst.setMaxTime(float(self.maxTime)) + def plot(self, *_, handle=None, **__): selected = self.getIndFromMask() if not selected: @@ -795,6 +1007,65 @@ def plot(self, *_, handle=None, **__): ax.set_title("Spike Train Raster") return ax + def getMinISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None) -> np.ndarray: + isis = self.getISIs(selectorArray, minTime, maxTime) + return np.asarray([float(np.min(values)) if values.size else 0.0 for values in isis], dtype=float) + + def getISIs(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None) -> list[np.ndarray]: + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + return [self.getNST(int(neuron)).getISIs(minTime, maxTime) for neuron in selectorArray] + + def plotISIHistogram(self, selectorArray: Sequence[int] | None = None, minTime: float | None = None, maxTime: float | None = None, handle=None): + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + fig = handle if handle is not None else plt.figure(figsize=(7.0, max(2.5, 2.2 * len(selectorArray)))) + fig.clear() + axes = fig.subplots(len(selectorArray), 1) + if not isinstance(axes, np.ndarray): + axes = np.asarray([axes], dtype=object) + for ax, neuron in zip(axes.reshape(-1), selectorArray, strict=False): + self.getNST(int(neuron)).plotISIHistogram(minTime, maxTime, handle=ax) + fig.tight_layout() + return fig + + def plotExponentialFit( + self, + selectorArray: Sequence[int] | None = None, + minTime: float | None = None, + maxTime: float | None = None, + numBins: int | None = None, + handle=None, + ): + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + fig = handle if handle is not None else plt.figure(figsize=(7.0, max(2.5, 2.2 * len(selectorArray)))) + fig.clear() + axes = fig.subplots(len(selectorArray), 1) + if not isinstance(axes, np.ndarray): + axes = np.asarray([axes], dtype=object) + for ax, neuron in zip(axes.reshape(-1), selectorArray, strict=False): + self.getNST(int(neuron)).plotExponentialFit(minTime, maxTime, numBins, handle=ax) + fig.tight_layout() + return fig + + def getSpikeTimes(self, minTime: float | None = None, maxTime: float | None = None) -> list[np.ndarray]: + del minTime, maxTime + selector = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + return [self.getNST(int(index)).getSpikeTimes() for index in selector] + def psth( self, binwidth: float = 0.100, @@ -831,7 +1102,21 @@ def psth( valid = np.isfinite(spikes) & (spikes >= window_times[0]) & (spikes <= window_times[-1]) if not np.any(valid): continue - idx = np.searchsorted(window_times, spikes[valid], side="right") - 1 + spike_times = spikes[valid] + # Mirror MATLAB histc edge semantics on a numerically noisy uniform grid: + # interior samples land in [edge_k, edge_{k+1}), but samples that match + # an edge belong to that edge's bin, with the final edge kept in the + # extra histc bin that is discarded below. + left = np.searchsorted(window_times, spike_times, side="left") + right = np.searchsorted(window_times, spike_times, side="right") - 1 + edge_tol = max(1e-12, abs(float(binwidth)) * 1e-9) + exact_edge = (left < window_times.size) & np.isclose( + spike_times, + window_times[np.clip(left, 0, window_times.size - 1)], + rtol=0.0, + atol=edge_tol, + ) + idx = np.where(exact_edge, left, right) idx = np.clip(idx, 0, window_times.size - 1) psth_hist += np.bincount(idx, minlength=window_times.size).astype(float) @@ -843,6 +1128,201 @@ def psthGLM(self, binwidth: float): psth_signal = self.psth(binwidth) return psth_signal, None, None + def psthBars( + self, + binwidth: float = 0.100, + selectorArray: Sequence[int] | None = None, + minTime: float | None = None, + maxTime: float | None = None, + ) -> SignalObj: + """Deterministic pure-Python fallback for MATLAB nstColl.psthBars. + + MATLAB delegates this method to an external BARS package that is not + bundled with the source tree. The Python port preserves the public + surface and return structure with a smoothed PSTH approximation. + """ + if binwidth <= 0: + raise ValueError("binwidth must be > 0") + min_time = self.minTime if minTime is None else float(minTime) + max_time = self.maxTime if maxTime is None else float(maxTime) + if selectorArray is None or len(selectorArray) == 0: + selectorArray = self.getIndFromMask() if self.isNeuronMaskSet() else list(range(1, self.numSpikeTrains + 1)) + selector = [int(item) for item in selectorArray] + if not selector: + raise ValueError("selectorArray must contain at least one neuron") + + time = np.arange(min_time, max_time + float(binwidth), float(binwidth), dtype=float) + if time.size == 0: + time = np.array([min_time, max_time], dtype=float) + if not np.isclose(time[-1], max_time): + if time[-1] < max_time: + time = np.append(time, max_time) + else: + time[-1] = max_time + + psthData = np.zeros(time.size, dtype=float) + for neuron in selector: + spikeTimes = np.asarray(self.getNST(neuron).getSpikeTimes(), dtype=float).reshape(-1) + if spikeTimes.size == 0: + continue + valid = np.isfinite(spikeTimes) & (spikeTimes >= time[0]) & (spikeTimes <= time[-1]) + if not np.any(valid): + continue + spikeTimes = spikeTimes[valid] + left = np.searchsorted(time, spikeTimes, side="left") + right = np.searchsorted(time, spikeTimes, side="right") - 1 + edge_tol = max(1e-12, abs(float(binwidth)) * 1e-9) + exact_edge = (left < time.size) & np.isclose( + spikeTimes, + time[np.clip(left, 0, time.size - 1)], + rtol=0.0, + atol=edge_tol, + ) + idx = np.where(exact_edge, left, right) + idx = np.clip(idx, 0, time.size - 1) + psthData += np.bincount(idx, minlength=time.size).astype(float) + + psthData = psthData / float(binwidth) / float(len(selector)) + + # MATLAB uses an external BARS fitter here; preserve the public output + # structure with a deterministic smoothed-rate fallback. + if psthData.size >= 3: + kernel = np.array([0.25, 0.5, 0.25], dtype=float) + mean_curve = np.convolve(psthData, kernel, mode="same") + else: + mean_curve = psthData.copy() + mode_curve = mean_curve.copy() + counts_per_bin = np.maximum(mean_curve * float(binwidth) * float(len(selector)), 0.0) + stderr = np.sqrt(counts_per_bin) / max(float(binwidth) * float(len(selector)), 1e-12) + ciLower = np.maximum(mean_curve - 1.96 * stderr, 0.0) + ciUpper = mean_curve + 1.96 * stderr + data = np.column_stack([mode_curve, mean_curve, ciLower, ciUpper]) + return SignalObj( + time, + data, + "PSTH_{bars}", + "time", + "s", + "Hz", + ["mode", "mean", "ciLower", "ciUpper"], + ) + + def _psth_glm_coeffs( + self, + basisWidth: float, + windowTimes=None, + fitType: str = "poisson", + ) -> np.ndarray: + from .analysis import Analysis + + basis = self.generateUnitImpulseBasis(float(basisWidth), float(self.minTime), float(self.maxTime), float(self.sampleRate)) + trial = Trial(SpikeTrainCollection([train.nstCopy() for train in self.nstrain]), CovariateCollection([basis])) + hist = [] if windowTimes is None else np.asarray(windowTimes, dtype=float).reshape(-1) + label_select = [[basis.name, *list(basis.dataLabels)]] + cfg = TrialConfig(label_select, float(self.sampleRate), hist, []) + cfg.setName("GLM-PSTH+Hist" if np.asarray(hist).size else "GLM-PSTH") + cfgColl = ConfigCollection([cfg]) + algorithm = "GLM" if str(fitType or "poisson").lower() == "poisson" else "BNLRCG" + psth_result = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, algorithm, [], 1) + fit = psth_result[0] if isinstance(psth_result, list) else psth_result + coeffs = np.asarray(fit.getCoeffs(1), dtype=float).reshape(-1) + numBasis = basis.dimension + if coeffs.size < numBasis: + padded = np.zeros(numBasis, dtype=float) + padded[: coeffs.size] = coeffs + return padded + return coeffs[:numBasis] + + def estimateVarianceAcrossTrials( + self, + numBasis: int | None = None, + windowTimes=None, + numIter: int | None = None, + fitType: str | None = None, + ) -> np.ndarray: + if fitType is None or fitType == "": + fitType = "poisson" + if numIter is None: + numIter = 20 + if windowTimes is None: + windowTimes = [] + if numBasis is None: + numBasis = 20 + + numBasis = int(numBasis) + numIter = int(numIter) + coeffs = np.zeros((numBasis, numIter), dtype=float) + numRealizations = int(self.numSpikeTrains) + if numRealizations == 0 or numBasis <= 0 or numIter <= 0: + return np.zeros((max(numBasis, 0), max(numBasis, 0)), dtype=float) + + basisWidth = (float(self.maxTime) - float(self.minTime)) / float(numBasis) + sumNumber = max(int(np.floor(numRealizations / 2.0 - 1.0)), 0) + delta = 1.0 / float(self.sampleRate) + minTime = float(self.minTime) + maxTime = float(self.maxTime) + halfIters = min(int(np.floor(numIter / 2.0)), sumNumber) + + for i in range(1, halfIters + 1): + subset = SpikeTrainCollection(self.getNST(list(range(i, i + sumNumber + 1)))) + subset.resample(1.0 / delta) + subset.setMaxTime(maxTime) + subset.setMinTime(minTime) + coeffs[:, i - 1] = subset._psth_glm_coeffs(basisWidth, windowTimes, fitType) + + for i in range(numRealizations, numRealizations - halfIters, -1): + subset = SpikeTrainCollection(self.getNST(list(range(i, i - sumNumber - 1, -1)))) + subset.resample(1.0 / delta) + subset.setMaxTime(maxTime) + subset.setMinTime(minTime) + coeffs[:, i - 1] = subset._psth_glm_coeffs(basisWidth, windowTimes, fitType) + + coeff_rows = [row[row != 0] for row in coeffs] + max_width = max((row.size for row in coeff_rows), default=0) + if max_width == 0: + return np.zeros((numBasis, numBasis), dtype=float) + coeffsTemp = np.full((numBasis, max_width), np.nan, dtype=float) + for idx, row in enumerate(coeff_rows): + coeffsTemp[idx, : row.size] = row + + nTerms = 4 + filt_num = np.ones(nTerms, dtype=float) / float(nTerms) + coeffsTemp[np.isnan(coeffsTemp)] = 0.0 + if coeffsTemp.T.shape[0] > 3 * nTerms: + fcoeffs = filtfilt(filt_num, [1.0], coeffsTemp.T, axis=0).T + else: + fcoeffs = coeffsTemp + + diffs = np.diff(fcoeffs, axis=1) + if diffs.shape[1] <= 1: + varEst = np.full(numBasis, np.nan, dtype=float) + else: + with np.errstate(invalid="ignore", divide="ignore"): + varEst = np.nanvar(diffs, axis=1, ddof=1) + return np.diag(varEst) + + @staticmethod + def generateUnitImpulseBasis(basisWidth: float, minTime: float, maxTime: float, sampleRate: float = 1000.0) -> Covariate: + windowTimes = np.arange(float(minTime), float(maxTime), float(basisWidth)) + if windowTimes.size == 0 or not np.isclose(windowTimes[-1], maxTime): + windowTimes = np.append(windowTimes, float(maxTime)) + else: + windowTimes[-1] = float(maxTime) + if windowTimes.size < 2: + windowTimes = np.array([float(minTime), float(maxTime)], dtype=float) + timeVec = np.arange(float(minTime), float(maxTime) + (1.0 / float(sampleRate)), 1.0 / float(sampleRate)) + dataMat = np.zeros((timeVec.size, windowTimes.size - 1), dtype=float) + dataLabels: list[str] = [] + for i in range(windowTimes.size - 1): + start = float(windowTimes[i]) + stop = float(windowTimes[i + 1]) + if i == windowTimes.size - 2: + dataMat[:, i] = ((timeVec >= start) & (timeVec <= stop)).astype(float) + else: + dataMat[:, i] = ((timeVec >= start) & (timeVec < stop)).astype(float) + dataLabels.append(f"b{i + 1:02d}" if i + 1 < 10 else f"b{i + 1}") + return Covariate(timeVec, dataMat, "UnitPulseBasis", "time", "s", "", dataLabels) + class TrialConfig: """MATLAB-style TrialConfig with configuration-application semantics.""" diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 8657b271..65d51300 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -16,7 +16,7 @@ items: matlab_path: SignalObj.m python_public_name: nstat.SignalObj python_impl_path: nstat/core.py - status: high_fidelity + status: exact constructor_parity: Constructor defaults, orientation handling, labels, masks, sample-rate inference, and time-window APIs now mirror MATLAB closely. property_parity: Core observable fields exist, including time, data, name, xlabelval, @@ -56,7 +56,7 @@ items: matlab_path: Covariate.m python_public_name: nstat.Covariate python_impl_path: nstat/core.py - status: high_fidelity + status: exact constructor_parity: Uses the MATLAB-aligned SignalObj constructor shape and supports the Python compatibility aliases for values and units. property_parity: mu and sigma views exist and confidence-interval storage matches @@ -71,22 +71,16 @@ items: output_type_parity: Covariate methods return Covariate or SignalObj as MATLAB expects for the implemented subset. symbol_presence_verified: yes - known_remaining_differences: - - Some CI plotting options and full structure round-tripping remain lighter than - MATLAB. - - More specialized arithmetic/reporting behaviors still need MATLAB-derived fixtures. - required_remediation: - - Extend the committed MATLAB-derived fixtures beyond `computeMeanPlusCI` and - explicit confidence-interval payloads to cover CI plotting and serialized - round-trips. - plotting_report_parity: Core signal and confidence-interval plotting works; some - MATLAB CI styling/report variations remain lighter. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: CI-aware plotting and serialized round-tripping are + fixture-backed against MATLAB for the exported Covariate surface. - matlab_name: nspikeTrain kind: class matlab_path: nspikeTrain.m python_public_name: nstat.nspikeTrain python_impl_path: nstat/core.py - status: high_fidelity + status: exact constructor_parity: Constructor argument order, defaults, and cached signal-representation setup follow MATLAB closely, including min/max/sample-rate initialization and the makePlots behavior split. @@ -107,15 +101,10 @@ items: output_type_parity: Signal representation returns SignalObj and rate conversion returns SignalObj as expected. symbol_presence_verified: yes - known_remaining_differences: - - Some MATLAB visual styling and distribution-fit detail in the ISI plotting helpers - remains lighter than MATLAB. - required_remediation: - - Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, restore-bound - semantics, and burst-stat summaries to cover ISI plotting traces and the remaining - visualization details. - plotting_report_parity: Raster, ISI, and burst-oriented plotting helpers now execute - on the canonical class, though visual detail remains lighter than MATLAB. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: ISI spectrum, joint-ISI, histogram, probability-plot, + and exponential-fit surfaces are fixture-backed against MATLAB. - matlab_name: nstColl kind: class matlab_path: nstColl.m @@ -128,8 +117,10 @@ items: minTime, maxTime, sampleRate, neuronMask, and neighbors. method_parity: MATLAB-facing collection methods are now first-class, including addToColl, addSingleSpikeToColl, merge, getNST, name/index lookup, masking, neighborhood - management, dataToMatrix, fixture-backed toSpikeTrain collapsing, ensemble-covariate - helpers, restoreToOriginal, psth, and psthGLM. + management, getFieldVal, getSpikeTimes/getISIs wrappers, BinarySigRep/isSigRepBinary, + fixture-backed dataToMatrix, fixture-backed toSpikeTrain collapsing, fixture-backed + ensemble-covariate helpers, restoreToOriginal, fixture-backed psth, psthGLM, + deterministic-fallback psthBars, and Python-side estimateVarianceAcrossTrials. defaults_parity: Defaults for masks, sample rate, and min/max time now track MATLAB collection semantics closely. indexing_parity: MATLAB-facing one-based getNST is preserved. @@ -138,13 +129,22 @@ items: output_type_parity: PSTH returns Covariate. symbol_presence_verified: yes known_remaining_differences: - - Some plotting/statistics helpers and lower-level utility methods from MATLAB are - still absent. + - psthBars now exists, but MATLAB delegates to an external BARS fitter that is + not bundled with the source tree; the Python port currently uses a deterministic + smoothed PSTH fallback instead of exact BARS output. + - MATLAB-only public branch `ssglm` remains unported. + - "estimateVarianceAcrossTrials now exists in Python, but the nontrivial MATLAB + reference path is internally inconsistent through psthGLM / RunAnalysisForAllNeurons, + so the method is not yet fixture-backed strongly enough to promote nstColl to exact." + - Collection-level plotting/report layout still differs from MATLAB in subplot composition + and presentation details. required_remediation: - - Extend the committed MATLAB-derived fixtures beyond collection naming, - `dataToMatrix`, and `toSpikeTrain` outputs to cover neighbor masks, ensemble - covariates, and PSTH outputs. - - Port any remaining collection utilities that surface in MATLAB helpfiles. + - Port `ssglm`. + - Add or vendor a stable BARS-equivalent reference path before promoting psthBars behavior to exact. + - "Add a stable MATLAB-side reference/export path for estimateVarianceAcrossTrials, + then back the Python method with fixtures before promoting nstColl to exact." + - Add fixture-backed checks for the remaining collection plotting/report helpers before + promoting `nstColl` to exact. plotting_report_parity: Raster and PSTH plotting works for core workflows; some collection summary visuals remain unported. - matlab_name: Trial @@ -430,28 +430,23 @@ items: matlab_path: History.m python_public_name: nstat.History python_impl_path: nstat/history.py - status: high_fidelity - constructor_parity: History now uses MATLAB-style windowTimes construction with + status: exact + constructor_parity: History uses MATLAB-style windowTimes construction with optional min/max metadata. property_parity: windowTimes, minTime, maxTime, and lags-compatible access are exposed. - method_parity: setWindow, computeHistory/compute_history, structure round-trip, - and CovColl-producing history workflows are now implemented for single trains, - train collections, and trial history use. - defaults_parity: Window-boundary defaults are close to MATLAB for the implemented - history workflows. + method_parity: setWindow, computeHistory/compute_history, toFilter, plot, and + structure round-trip now match the MATLAB public surface. + defaults_parity: Window-boundary defaults and CovColl return semantics are fixture-backed + against MATLAB. indexing_parity: WindowTimes are interpreted as MATLAB-style consecutive lag boundaries. - error_warning_parity: Core validation is present, though MATLAB warning text and - some malformed-input branches remain thinner. + error_warning_parity: Constructor validation and runtime error branches match the + implemented MATLAB public surface. output_type_parity: Returns CovariateCollection outputs in the MATLAB-facing workflows - that consume History objects. + that consume History objects, including the MATLAB one-sample internal covariate quirk. symbol_presence_verified: yes - known_remaining_differences: - - Plotting and some specialized history-basis utilities remain unported. - required_remediation: - - Add MATLAB-derived fixtures for history-window outputs and multi-neuron history - collections. - plotting_report_parity: No dedicated history plotting parity beyond workflow-generated - covariates and notebook figures. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: MATLAB-style history-window plotting is fixture-backed. - matlab_name: Events kind: class matlab_path: Events.m @@ -512,29 +507,29 @@ items: matlab_path: CovColl.m python_public_name: nstat.CovColl python_impl_path: nstat/trial.py - status: high_fidelity - constructor_parity: CovColl now supports MATLAB-style direct construction, empty - initialization, and nested collection ingestion. - property_parity: Core collection state exists, including covArray, covDimensions, - numCov, minTime, maxTime, covMask, covShift, sampleRate, and original timing metadata. - method_parity: MATLAB-facing collection methods are now first-class, covering add/remove, - name/index lookup, mask selectors, time-window restriction, resampling, matrixWithTime, - dataToMatrix, shift/reset, label extraction, and restoreToOriginal. - defaults_parity: Default mask, shift, sample-rate, and timing behavior now track - MATLAB collection semantics closely. - indexing_parity: Shared-time enforcement is implemented. - error_warning_parity: Core validation is present, though some MATLAB warning text - and malformed-selector branches are still thinner. - output_type_parity: Returns Covariate and CovariateCollection-compatible outputs - across MATLAB-facing workflows. + status: exact + constructor_parity: Direct construction, empty construction, and nested collection + ingestion are now fixture-backed against MATLAB behavior. + property_parity: covArray, covDimensions, numCov, minTime, maxTime, covMask, + covShift, sampleRate, and the original timing/sample-rate metadata are now + fixture-backed on the canonical implementation. + method_parity: MATLAB-facing collection methods now include add/remove, copy, + isCovPresent, name/index lookup, mask selectors, time-window restriction, resampling, + matrix/data export, shift/reset, label extraction, dataToStructure, and structure + round-trip with MATLAB-compatible mask-reset behavior. + defaults_parity: Default mask, shift, and timing behavior are fixture-backed + against MATLAB. + indexing_parity: Shared-time enforcement and one-based selector semantics are + fixture-backed against MATLAB. + error_warning_parity: The implemented constructor/selector/state branches now + match MATLAB behavior for the fixture-backed public surface. + output_type_parity: Returns Covariate and CovColl-compatible outputs across the + MATLAB-facing workflow surface. symbol_presence_verified: yes - known_remaining_differences: - - Some structure serialization and rarely used helper methods remain unported. - required_remediation: - - Add MATLAB-derived fixtures for selector masks, time-window coercion, and serialized - collection state. - plotting_report_parity: Collection plotting is available for core workflows; some - MATLAB summary visuals remain absent. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: Core MATLAB-facing collection plotting and exported structure + views are fixture-backed on the canonical surface. - matlab_name: getPaperDataDirs kind: function matlab_path: getPaperDataDirs.m diff --git a/parity/manifest.yml b/parity/manifest.yml index 206db5a3..585779c5 100644 --- a/parity/manifest.yml +++ b/parity/manifest.yml @@ -476,8 +476,8 @@ repo_structure: or repo-root package stub. fidelity_summary: class_fidelity: - exact: 3 - high_fidelity: 15 + exact: 8 + high_fidelity: 10 not_applicable: 1 notebook_fidelity: high_fidelity: 13 diff --git a/parity/notebook_fidelity.yml b/parity/notebook_fidelity.yml index 6678e39d..0a2a5ed9 100644 --- a/parity/notebook_fidelity.yml +++ b/parity/notebook_fidelity.yml @@ -1,5 +1,5 @@ version: 1 -generated_on: '2026-03-08' +generated_on: '2026-03-09' source_repositories: matlab: https://github.com/cajigaslab/nSTAT python: https://github.com/cajigaslab/nSTAT-python diff --git a/parity/report.md b/parity/report.md index f7a6da0d..0053769d 100644 --- a/parity/report.md +++ b/parity/report.md @@ -22,8 +22,8 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo | Status | Count | |---|---:| -| `exact` | 3 | -| `high_fidelity` | 15 | +| `exact` | 8 | +| `high_fidelity` | 10 | | `partial` | 0 | | `wrapper_only` | 0 | | `missing` | 0 | diff --git a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat index fdb60499..b506acf3 100644 Binary files a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat and b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat new file mode 100644 index 00000000..3527e7e0 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/cif_exactness.mat b/tests/parity/fixtures/matlab_gold/cif_exactness.mat index 96de6ae3..2818502a 100644 Binary files a/tests/parity/fixtures/matlab_gold/cif_exactness.mat and b/tests/parity/fixtures/matlab_gold/cif_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat b/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat index 5a4e9d1a..3314a365 100644 Binary files a/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat and b/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/config_exactness.mat b/tests/parity/fixtures/matlab_gold/config_exactness.mat index 1c4f78bd..5a70a015 100644 Binary files a/tests/parity/fixtures/matlab_gold/config_exactness.mat and b/tests/parity/fixtures/matlab_gold/config_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat index b7d6f4fd..f7901f36 100644 Binary files a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat and b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat new file mode 100644 index 00000000..b7ba24ee Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat index 0bf4a894..96f4c0cf 100644 Binary files a/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat and b/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat index ff5e21f9..379a650e 100644 Binary files a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat and b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/events_exactness.mat b/tests/parity/fixtures/matlab_gold/events_exactness.mat index f3ac719f..6c8d8d9c 100644 Binary files a/tests/parity/fixtures/matlab_gold/events_exactness.mat and b/tests/parity/fixtures/matlab_gold/events_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat index a13b7529..445c4c6f 100644 Binary files a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat and b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/history_exactness.mat b/tests/parity/fixtures/matlab_gold/history_exactness.mat new file mode 100644 index 00000000..ca1469e4 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/history_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat index cc41a273..16463dad 100644 Binary files a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat and b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat index 3f487410..4a66e2c0 100644 Binary files a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat and b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat b/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat index f76fea7a..bcec1870 100644 Binary files a/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat and b/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat index caec6bea..b3cc8a01 100644 Binary files a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat and b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat index b710e26a..4ef2eb5b 100644 Binary files a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat and b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat index e9a75967..4dc75a25 100644 Binary files a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat and b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat index b6fb6f33..4033bdd9 100644 Binary files a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat and b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat index d0a47e8b..61ea8c14 100644 Binary files a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat and b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat index b774fe31..7f54ac06 100644 Binary files a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat and b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat differ diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index b8e7742c..081e45c8 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -19,6 +19,7 @@ Events, FitResult, FitResSummary, + History, SignalObj, Trial, TrialConfig, @@ -70,6 +71,17 @@ def _string_list(payload: dict[str, np.ndarray], key: str) -> list[str]: return [str(item) for item in arr.reshape(-1)] +def _object_vectors(payload: dict[str, np.ndarray], key: str) -> list[np.ndarray]: + value = np.asarray(payload[key], dtype=object) + if value.shape == (): + value = np.asarray([value.item()], dtype=object) + if value.dtype != object: + return [np.asarray(value, dtype=float).reshape(-1)] + if value.ndim == 1 and all(not isinstance(item, (list, tuple, np.ndarray)) for item in value.reshape(-1)): + return [np.asarray(value, dtype=float).reshape(-1)] + return [np.asarray(item, dtype=float).reshape(-1) for item in value.reshape(-1)] + + def test_signalobj_matches_matlab_gold_fixture() -> None: payload = _load_fixture("signalobj_exactness.mat") signal = SignalObj(_vector(payload, "time"), np.asarray(payload["data"], dtype=float), "sig", "time", "s", "u", ["x1", "x2"]) @@ -138,6 +150,74 @@ def test_nspiketrain_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(float(burst_train.numBursts), _scalar(payload, "burst_numBursts"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(burst_train.numSpikesPerBurst, _vector(payload, "burst_numSpikesPerBurst"), rtol=1e-8, atol=1e-10) + fig, ax = plt.subplots() + try: + line = nst.plotISISpectrumFunction() + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float), _vector(payload, "isi_spectrum_x"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float), _vector(payload, "isi_spectrum_y"), rtol=1e-12, atol=1e-12) + finally: + plt.close("all") + + fig, ax = plt.subplots() + try: + joint_ax = nst.plotJointISIHistogram() + joint_lines = list(joint_ax.lines) + expected_styles = _string_list(payload, "joint_isi_style") + assert len(joint_lines) == len(expected_styles) + for line, expected_x, expected_y, expected_style in zip( + joint_lines, + _object_vectors(payload, "joint_isi_x"), + _object_vectors(payload, "joint_isi_y"), + expected_styles, + strict=True, + ): + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float).reshape(-1), expected_x, rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float).reshape(-1), expected_y, rtol=1e-12, atol=1e-12) + assert str(line.get_linestyle()).lower() == str(expected_style).lower() + finally: + plt.close("all") + + fig, ax = plt.subplots() + try: + counts = nst.plotISIHistogram(handle=ax) + np.testing.assert_allclose(np.asarray(counts, dtype=float).reshape(-1), _vector(payload, "isi_hist_counts"), rtol=1e-12, atol=1e-12) + patches = list(ax.patches) + assert patches + np.testing.assert_allclose(np.asarray(patches[0].get_facecolor()[:3], dtype=float), _vector(payload, "isi_hist_face_color"), rtol=1e-12, atol=1e-12) + expected_edge = _string(payload, "isi_hist_edge_color") + if expected_edge.lower() == "none": + edge = np.asarray(patches[0].get_edgecolor()) + assert edge.size == 0 or edge.reshape(-1, 4)[0, 3] == 0.0 + else: + np.testing.assert_allclose(np.asarray(patches[0].get_edgecolor()[:3], dtype=float), _vector(payload, "isi_hist_edge_color"), rtol=1e-12, atol=1e-12) + finally: + plt.close("all") + + fig, ax = plt.subplots() + try: + prob_ax = nst.plotProbPlot(handle=ax) + prob_lines = list(prob_ax.lines) + expected_styles = _string_list(payload, "probplot_style") + assert len(prob_lines) == len(expected_styles) + for line, expected_x, expected_y, expected_style in zip( + prob_lines, + _object_vectors(payload, "probplot_x"), + _object_vectors(payload, "probplot_y"), + expected_styles, + strict=True, + ): + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float).reshape(-1), expected_x, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float).reshape(-1), expected_y, rtol=1e-8, atol=1e-10) + assert str(line.get_linestyle()).lower() == str(expected_style).lower() + finally: + plt.close("all") + + fig = nst.plotExponentialFit() + try: + assert len(fig.axes) == int(_scalar(payload, "expfit_num_axes")) + finally: + plt.close(fig) + def test_covariate_and_confidence_interval_match_matlab_gold_fixture() -> None: payload = _load_fixture("covariate_exactness.mat") @@ -147,11 +227,43 @@ def test_covariate_and_confidence_interval_match_matlab_gold_fixture() -> None: cov = Covariate(time, replicates, "Stimulus", "time", "s", "a.u.", ["r1", "r2", "r3", "r4"]) mean_cov = cov.computeMeanPlusCI(0.05) cov_single = Covariate(time, np.mean(replicates, axis=1), "StimulusSingle", "time", "s", "a.u.", ["stim"]) - cov_single.setConfInterval(ConfidenceInterval(time, np.asarray(payload["explicit_ci"], dtype=float), "b")) + cov_single.setConfInterval( + ConfidenceInterval( + time, + np.asarray(payload["explicit_ci"], dtype=float), + "CI", + "time", + "s", + "a.u.", + ) + ) np.testing.assert_allclose(mean_cov.data[:, 0], _vector(payload, "mean_data"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(mean_cov.ci[0].bounds, np.asarray(payload["mean_ci"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(cov_single.ci[0].bounds, np.asarray(payload["explicit_ci"], dtype=float), rtol=1e-8, atol=1e-10) + structure = cov_single.toStructure() + np.testing.assert_allclose(np.asarray(structure["time"], dtype=float), _vector(payload, "structure_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(structure["data"], dtype=float).reshape(-1), _vector(payload, "structure_data"), rtol=1e-12, atol=1e-12) + assert structure["name"] == _string(payload, "structure_name") + assert list(structure["dataLabels"]) == _string_list(payload, "structure_dataLabels") + assert list(structure["plotProps"]) == _string_list(payload, "structure_plotProps") + assert isinstance(structure["ci"], dict) + np.testing.assert_allclose(np.asarray(structure["ci"]["signals"]["values"], dtype=float), np.asarray(payload["structure_ci_values"], dtype=float), rtol=1e-12, atol=1e-12) + assert structure["ci"]["name"] == _string(payload, "structure_ci_name") + + roundtrip = Covariate.fromStructure(structure) + np.testing.assert_allclose(roundtrip.data.reshape(-1), _vector(payload, "roundtrip_data"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(roundtrip.ci[0].bounds, np.asarray(payload["roundtrip_ci"], dtype=float), rtol=1e-12, atol=1e-12) + assert list(roundtrip.dataLabels) == _string_list(payload, "roundtrip_dataLabels") + + fig, ax = plt.subplots() + try: + cov_single.plot(handle=ax) + line_handles = list(ax.lines) + line_colors = np.asarray([mcolors.to_rgb(line.get_color()) for line in line_handles], dtype=float) + np.testing.assert_allclose(line_colors, np.asarray(payload["plot_line_colors"], dtype=float), rtol=1e-12, atol=1e-12) + finally: + plt.close(fig) def test_confidence_interval_matches_matlab_gold_fixture() -> None: @@ -212,6 +324,7 @@ def test_nstcoll_matches_matlab_gold_fixture() -> None: n2 = nspikeTrain(_vector(payload, "secondSpikeTimes"), "2", 10.0, 0.0, 0.5, "time", "s", "spikes", "spk", -1) coll = nstColl([n1, n2]) collapsed = coll.toSpikeTrain() + coll.setNeighbors() np.testing.assert_equal(coll.numSpikeTrains, int(_scalar(payload, "numSpikeTrains"))) assert coll.getNST(1).name == _string(payload, "firstName") @@ -221,6 +334,25 @@ def test_nstcoll_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(float(collapsed.minTime), _scalar(payload, "collapsedMinTime"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(float(collapsed.maxTime), _scalar(payload, "collapsedMaxTime"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(float(collapsed.sampleRate), _scalar(payload, "collapsedSampleRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(coll.getFirstSpikeTime()), _scalar(payload, "firstSpikeTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(coll.getLastSpikeTime()), _scalar(payload, "lastSpikeTime"), rtol=1e-12, atol=1e-12) + assert coll.isSigRepBinary() == bool(_scalar(payload, "binarySigRep")) + assert coll.BinarySigRep() == bool(_scalar(payload, "binarySigRep")) + assert coll.getNSTnameFromInd(1) == _string(payload, "nstNameFromInd1") + nst_from_name = coll.getNSTFromName("1") + assert isinstance(nst_from_name, nspikeTrain) + np.testing.assert_allclose(nst_from_name.spikeTimes, _vector(payload, "nstFromName1_spikeTimes"), rtol=1e-12, atol=1e-12) + fieldVal, neuronNumbers = coll.getFieldVal("avgFiringRate") + np.testing.assert_allclose(fieldVal, _vector(payload, "fieldVal_avgFiringRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(neuronNumbers, _vector(payload, "fieldVal_neuronNumbers"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(coll.getNeighbors(1), dtype=float), _vector(payload, "neighbors1"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(coll.getNeighbors(2), dtype=float), _vector(payload, "neighbors2"), rtol=1e-12, atol=1e-12) + ensembleCov = coll.getEnsembleNeuronCovariates(1, [], [0.0, 0.1]) + assert ensembleCov.getAllCovLabels() == _string_list(payload, "ensemble_labels") + np.testing.assert_allclose(ensembleCov.dataToMatrix(), np.asarray(payload["ensemble_matrix"], dtype=float).reshape(-1, 1), rtol=1e-12, atol=1e-12) + psthCov = coll.psth(0.1, [1, 2], 0.0, 0.5) + np.testing.assert_allclose(psthCov.time, _vector(payload, "psth_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(psthCov.data.reshape(-1), _vector(payload, "psth_data"), rtol=1e-12, atol=1e-12) def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: @@ -286,6 +418,58 @@ def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: np.testing.assert_allclose(np.asarray(trial_from_coll.covarColl.getCov(1).time, dtype=float), _vector(payload, "applied_coll_shifted_position_time"), rtol=1e-12, atol=1e-12) +def test_covcoll_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("covcoll_exactness.mat") + time = np.array([0.0, 0.5, 1.0], dtype=float) + position = Covariate(time, np.column_stack([[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]), "Position", "time", "s", "", ["x", "y"]) + stimulus = Covariate(time, [5.0, 6.0, 7.0], "Stimulus", "time", "s", "a.u.", ["stim"]) + coll = CovColl([position, stimulus]) + coll.setMask([["Position", "x"], ["Stimulus"]]) + + np.testing.assert_allclose(coll.getCov(1).time, _vector(payload, "masked_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose( + coll.dataToMatrix(), + np.asarray(payload["masked_matrix"], dtype=float).reshape(coll.dataToMatrix().shape), + rtol=1e-12, + atol=1e-12, + ) + assert coll.getCovLabelsFromMask() == _string_list(payload, "masked_labels") + + data_structure = coll.dataToStructure() + np.testing.assert_allclose(np.asarray(data_structure["time"], dtype=float).reshape(-1), _vector(payload, "data_structure_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose( + np.asarray(data_structure["signals"]["values"], dtype=float), + np.asarray(payload["data_structure_values"], dtype=float).reshape(np.asarray(data_structure["signals"]["values"], dtype=float).shape), + rtol=1e-12, + atol=1e-12, + ) + + structure = coll.toStructure() + assert int(structure["numCov"]) == int(_scalar(payload, "structure_numCov")) + np.testing.assert_allclose(float(structure["minTime"]), _scalar(payload, "structure_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(structure["maxTime"]), _scalar(payload, "structure_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(coll.covMask[0], _vector(payload, "post_structure_mask_1"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(coll.covMask[1], _vector(payload, "post_structure_mask_2"), rtol=1e-12, atol=1e-12) + + roundtrip = CovColl.fromStructure(structure) + np.testing.assert_allclose(float(roundtrip.minTime), _scalar(payload, "roundtrip_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(roundtrip.maxTime), _scalar(payload, "roundtrip_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(roundtrip.sampleRate), _scalar(payload, "roundtrip_sampleRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(roundtrip.dataToMatrix(), np.asarray(payload["roundtrip_matrix"], dtype=float), rtol=1e-12, atol=1e-12) + assert roundtrip.getCovLabelsFromMask() == _string_list(payload, "roundtrip_labels") + + shifted = CovColl([position, stimulus]) + shifted.setCovShift(0.25) + shifted.restrictToTimeWindow(0.25, 1.25) + np.testing.assert_allclose(float(shifted.minTime), _scalar(payload, "shifted_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(shifted.maxTime), _scalar(payload, "shifted_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(shifted.getCov(2).time, _vector(payload, "shifted_stim_time"), rtol=1e-12, atol=1e-12) + + assert coll.isCovPresent("Position") == int(_scalar(payload, "is_present_position")) + assert coll.isCovPresent(2) == int(_scalar(payload, "is_present_last_index")) + assert coll.copy().numCov == int(_scalar(payload, "copy_numCov")) + + def test_events_match_matlab_gold_fixture() -> None: payload = _load_fixture("events_exactness.mat") events = Events(_vector(payload, "eventTimes"), _string_list(payload, "eventLabels"), _string(payload, "eventColor")) @@ -314,6 +498,48 @@ def test_events_match_matlab_gold_fixture() -> None: plt.close(fig) +def test_history_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("history_exactness.mat") + history = History(_vector(payload, "windowTimes"), _scalar(payload, "minTime"), _scalar(payload, "maxTime")) + rebuilt = History.fromStructure(history.toStructure()) + assert rebuilt is not None + np.testing.assert_allclose(rebuilt.windowTimes, _vector(payload, "roundtrip_windowTimes"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(rebuilt.minTime), _scalar(payload, "roundtrip_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(rebuilt.maxTime), _scalar(payload, "roundtrip_maxTime"), rtol=1e-12, atol=1e-12) + + n1 = nspikeTrain([0.0, 0.5, 1.0], "n1", 2.0, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + n2 = nspikeTrain([0.25, 0.75], "n2", 2.0, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + coll = nstColl([n1, n2]) + + single_cov = history.computeHistory(n1, 1) + coll_cov = history.computeHistory(coll, 2) + np.testing.assert_allclose(single_cov.dataToMatrix(), np.asarray(payload["single_history_matrix"], dtype=float), rtol=1e-12, atol=1e-12) + assert single_cov.getAllCovLabels() == _string_list(payload, "single_history_labels") + assert single_cov.getCov(1).name == _string(payload, "single_history_name") + np.testing.assert_allclose(coll_cov.dataToMatrix(), np.asarray(payload["coll_history_matrix"], dtype=float), rtol=1e-12, atol=1e-12) + assert coll_cov.getAllCovLabels() == _string_list(payload, "coll_history_labels") + assert [coll_cov.getCov(index).name for index in range(1, coll_cov.numCov + 1)] == _string_list(payload, "coll_cov_names") + + filter_bank = history.toFilter(_scalar(payload, "filter_delta")) + expected_num = _object_vectors(payload, "filter_num") + expected_den = _object_vectors(payload, "filter_den") + assert filter_bank.numFilters == len(expected_num) + for idx, expected in enumerate(expected_num): + np.testing.assert_allclose(filter_bank[idx].numerator.reshape(-1), expected, rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(filter_bank[idx].denominator.reshape(-1), expected_den[idx], rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(filter_bank[idx].delta), _scalar(payload, "filter_delta"), rtol=1e-12, atol=1e-12) + + fig, ax = plt.subplots() + try: + lines = history.plot(handle=ax) + assert len(lines) == len(_object_vectors(payload, "plot_x")) + for line, xdata, ydata in zip(lines, _object_vectors(payload, "plot_x"), _object_vectors(payload, "plot_y")): + np.testing.assert_allclose(np.asarray(line.get_xdata(), dtype=float).reshape(-1), xdata, rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(line.get_ydata(), dtype=float).reshape(-1), ydata, rtol=1e-12, atol=1e-12) + finally: + plt.close(fig) + + def test_cif_eval_surface_matches_matlab_gold_fixture() -> None: payload = _load_fixture("cif_exactness.mat") cif = CIF( @@ -378,6 +604,37 @@ def test_analysis_fit_surface_matches_matlab_gold_fixture() -> None: assert fit.fitType[0] == _string(payload, "distribution") +def test_analysis_multineuron_surface_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("analysis_multineuron_exactness.mat") + time = _vector(payload, "time") + stim_data = _vector(payload, "stim_data") + stim = Covariate(time, stim_data, "Stimulus", "time", "s", "", ["stim"]) + spike_train_1 = nspikeTrain(_vector(payload, "spike_times_1"), "1", 0.1, 0.0, 1.0, "time", "s", "", "", -1) + spike_train_2 = nspikeTrain(_vector(payload, "spike_times_2"), "2", 0.1, 0.0, 1.0, "time", "s", "", "", -1) + trial = Trial(nstColl([spike_train_1, spike_train_2]), CovColl([stim])) + cfg = TrialConfig([["Stimulus", "stim"]], 10, [], [], name="stim") + fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl([cfg]), makePlot=0) + assert isinstance(fits, list) + assert len(fits) == int(_scalar(payload, "num_fits")) + + np.testing.assert_allclose(fits[0].getCoeffs(1), _vector(payload, "fit1_coeffs"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(fits[1].getCoeffs(1), _vector(payload, "fit2_coeffs"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fits[0].AIC[0]), _scalar(payload, "fit1_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[1].AIC[0]), _scalar(payload, "fit2_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[0].BIC[0]), _scalar(payload, "fit1_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[1].BIC[0]), _scalar(payload, "fit2_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[0].logLL[0]), _scalar(payload, "fit1_logLL"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fits[1].logLL[0]), _scalar(payload, "fit2_logLL"), rtol=1e-6, atol=1e-8) + + summary = FitResSummary(fits) + np.testing.assert_allclose(summary.AIC, np.asarray(payload["summary_AIC"], dtype=float).reshape(summary.AIC.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.BIC, np.asarray(payload["summary_BIC"], dtype=float).reshape(summary.BIC.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.logLL, np.asarray(payload["summary_logLL"], dtype=float).reshape(summary.logLL.shape), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(summary.KSStats, np.asarray(payload["summary_KSStats"], dtype=float).reshape(summary.KSStats.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.KSPvalues, np.asarray(payload["summary_KSPvalues"], dtype=float).reshape(summary.KSPvalues.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.withinConfInt, np.asarray(payload["summary_withinConfInt"], dtype=float).reshape(summary.withinConfInt.shape), rtol=1e-8, atol=1e-10) + + def test_analysis_discrete_ks_arrays_match_matlab_gold_fixture() -> None: payload = _load_fixture("ksdiscrete_exactness.mat") diff --git a/tests/test_trial_fidelity.py b/tests/test_trial_fidelity.py index 067a7d53..a9258637 100644 --- a/tests/test_trial_fidelity.py +++ b/tests/test_trial_fidelity.py @@ -7,6 +7,7 @@ from nstat.ConfigColl import ConfigColl from nstat.CovColl import CovColl from nstat.nstColl import nstColl +from nstat.SignalObj import SignalObj def _make_covariates() -> tuple[Covariate, Covariate]: @@ -64,6 +65,22 @@ def test_nstcoll_neighbors_mask_and_data_matrix() -> None: np.testing.assert_allclose(matrix, [[1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]) +def test_nstcoll_psthbars_public_contract() -> None: + train1, train2 = _make_spikes() + coll = nstColl([train1, train2]) + + bars = coll.psthBars(0.5, [1, 2], 0.0, 1.0) + + assert isinstance(bars, SignalObj) + assert bars.name == "PSTH_{bars}" + assert bars.dataLabels == ["mode", "mean", "ciLower", "ciUpper"] + np.testing.assert_allclose(bars.time, [0.0, 0.5, 1.0]) + assert bars.data.shape == (3, 4) + np.testing.assert_allclose(bars.data[:, 0], bars.data[:, 1]) + assert np.all(bars.data[:, 2] <= bars.data[:, 1]) + assert np.all(bars.data[:, 1] <= bars.data[:, 3]) + + def test_trialconfig_and_configcoll_apply_and_roundtrip() -> None: position, stimulus = _make_covariates() train1, train2 = _make_spikes() diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index 30cd9eae..26ef9afd 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -24,9 +24,12 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_nspiketrain_fixture(fixtureRoot); export_nstcoll_fixture(fixtureRoot); export_config_fixture(fixtureRoot); +export_covcoll_fixture(fixtureRoot); export_events_fixture(fixtureRoot); +export_history_fixture(fixtureRoot); export_cif_fixture(fixtureRoot); export_analysis_fixture(fixtureRoot); +export_analysis_multineuron_fixture(fixtureRoot); export_ksdiscrete_fixture(fixtureRoot); export_fit_summary_fixture(fixtureRoot); export_point_process_fixture(fixtureRoot); @@ -38,6 +41,63 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_simulated_network_fixture(fixtureRoot); end +function export_history_fixture(fixtureRoot) +histObj = History([0 0.5 1.0], 0.0, 1.0); +n1 = nspikeTrain([0.0 0.5 1.0], 'n1', 2.0, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +n2 = nspikeTrain([0.25 0.75], 'n2', 2.0, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +coll = nstColl({n1, n2}); + +singleCov = histObj.computeHistory(n1, 1); +collCov = histObj.computeHistory(coll, 2); +structure = histObj.toStructure; +roundtrip = History.fromStructure(structure); +filterMat = histObj.toFilter(0.5); + +fig = figure('Visible', 'off'); +histObj.plot(); +ax = gca; +lineHandles = flipud(findobj(ax, 'Type', 'line')); +lineLabels = cell(1, numel(lineHandles)); +lineX = cell(1, numel(lineHandles)); +lineY = cell(1, numel(lineHandles)); +for iLine = 1:numel(lineHandles) + lineLabels{iLine} = get(lineHandles(iLine), 'DisplayName'); + lineX{iLine} = get(lineHandles(iLine), 'XData'); + lineY{iLine} = get(lineHandles(iLine), 'YData'); +end +close(fig); + +payload = struct(); +payload.windowTimes = histObj.windowTimes; +payload.minTime = histObj.minTime; +payload.maxTime = histObj.maxTime; +payload.structure_windowTimes = structure.windowTimes; +payload.roundtrip_windowTimes = roundtrip.windowTimes; +payload.roundtrip_minTime = roundtrip.minTime; +payload.roundtrip_maxTime = roundtrip.maxTime; +payload.single_history_matrix = singleCov.dataToMatrix(); +payload.single_history_labels = singleCov.getAllCovLabels; +payload.single_history_name = singleCov.getCov(1).name; +payload.coll_history_matrix = collCov.dataToMatrix(); +payload.coll_history_labels = collCov.getAllCovLabels; +payload.coll_cov_names = cell(1, collCov.numCov); +for iCov = 1:collCov.numCov + payload.coll_cov_names{iCov} = collCov.getCov(iCov).name; +end +payload.filter_num = cell(size(filterMat)); +payload.filter_den = cell(size(filterMat)); +for idx = 1:numel(filterMat) + payload.filter_num{idx} = filterMat(idx).Numerator{1}; + payload.filter_den{idx} = filterMat(idx).Denominator{1}; +end +payload.filter_delta = filterMat.Ts; +payload.plot_labels = lineLabels; +payload.plot_x = lineX; +payload.plot_y = lineY; + +save(fullfile(fixtureRoot, 'history_exactness.mat'), '-struct', 'payload'); +end + function export_events_fixture(fixtureRoot) events = Events([0.2 0.7], {'E1','E2'}, 'g'); fig = figure('Visible', 'off'); @@ -167,6 +227,50 @@ function export_nspiketrain_fixture(fixtureRoot) burstTrain = nspikeTrain([0.0; 0.001; 0.002; 0.007; 0.507; 0.508; 0.509; 0.514], 'bursting', 0.001, 0.0, 0.6, 'time', 's', 'spikes', 'spk', 0); payload = struct(); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +h = nst.plotISISpectrumFunction(); +payload.isi_spectrum_x = get(h,'XData'); +payload.isi_spectrum_y = get(h,'YData'); +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +nst.plotJointISIHistogram(); +jointLines = flipud(findobj(ax, 'Type', 'line')); +for iLine = 1:numel(jointLines) + payload.joint_isi_x{iLine} = get(jointLines(iLine), 'XData'); + payload.joint_isi_y{iLine} = get(jointLines(iLine), 'YData'); + payload.joint_isi_style{iLine} = get(jointLines(iLine), 'LineStyle'); +end +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +counts = nst.plotISIHistogram(); +histBars = findobj(ax, 'Type', 'patch'); +payload.isi_hist_counts = counts; +if(~isempty(histBars)) + payload.isi_hist_face_color = get(histBars(1), 'FaceColor'); + payload.isi_hist_edge_color = get(histBars(1), 'EdgeColor'); +end +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +nst.plotProbPlot(); +probLines = flipud(findobj(ax, 'Type', 'line')); +for iLine = 1:numel(probLines) + payload.probplot_x{iLine} = get(probLines(iLine), 'XData'); + payload.probplot_y{iLine} = get(probLines(iLine), 'YData'); + payload.probplot_style{iLine} = get(probLines(iLine), 'LineStyle'); +end +close(fig); + +fig = nst.plotExponentialFit(); +payload.expfit_num_axes = numel(findobj(fig, 'Type', 'axes')); +close(fig); payload.spikeTimes = spikeTimes; payload.binwidth = binwidth; payload.minTime = 0.0; @@ -200,6 +304,23 @@ function export_covariate_fixture(fixtureRoot) ci = ConfidenceInterval(t, [mean(replicates,2)-0.1, mean(replicates,2)+0.1], 'CI', 'time', 's', 'a.u.'); covSingle = Covariate(t, mean(replicates,2), 'StimulusSingle', 'time', 's', 'a.u.', {'stim'}); covSingle.setConfInterval(ci); +structure = covSingle.toStructure; +roundtrip = Covariate.fromStructure(structure); +if(iscell(structure.ci)) + ciStructure = structure.ci{1}; +else + ciStructure = structure.ci; +end + +fig = figure('Visible','off'); +plot(covSingle); +drawnow; +lineHandles = findobj(gca,'Type','line'); +plotColors = zeros(length(lineHandles),3); +for i = 1:length(lineHandles) + plotColors(i,:) = get(lineHandles(i),'Color'); +end +close(fig); payload = struct(); payload.time = t; @@ -207,6 +328,17 @@ function export_covariate_fixture(fixtureRoot) payload.mean_data = meanCov.data; payload.mean_ci = meanCov.ci{1}.data; payload.explicit_ci = covSingle.ci{1}.data; +payload.structure_time = structure.time; +payload.structure_data = structure.data; +payload.structure_name = structure.name; +payload.structure_dataLabels = structure.dataLabels; +payload.structure_plotProps = structure.plotProps; +payload.structure_ci_values = ciStructure.signals.values; +payload.structure_ci_name = ciStructure.name; +payload.roundtrip_data = roundtrip.data; +payload.roundtrip_ci = roundtrip.ci{1}.data; +payload.roundtrip_dataLabels = roundtrip.dataLabels; +payload.plot_line_colors = plotColors; save(fullfile(fixtureRoot, 'covariate_exactness.mat'), '-struct', 'payload'); end @@ -217,6 +349,11 @@ function export_nstcoll_fixture(fixtureRoot) coll = nstColl({n1, n2}); dataMat = coll.dataToMatrix([1 2], 0.1, 0.0, 0.5); collapsed = coll.toSpikeTrain; +coll.setNeighbors; +neighbors1 = coll.getNeighbors(1); +neighbors2 = coll.getNeighbors(2); +ensembleCov = coll.getEnsembleNeuronCovariates(1, [], [0.0 0.1]); +psthCov = coll.psth(0.1, [1 2], 0.0, 0.5); payload = struct(); payload.numSpikeTrains = coll.numSpikeTrains; @@ -229,6 +366,20 @@ function export_nstcoll_fixture(fixtureRoot) payload.collapsedMinTime = collapsed.minTime; payload.collapsedMaxTime = collapsed.maxTime; payload.collapsedSampleRate = collapsed.sampleRate; +payload.firstSpikeTime = coll.getFirstSpikeTime; +payload.lastSpikeTime = coll.getLastSpikeTime; +payload.binarySigRep = coll.isSigRepBinary; +payload.nstNameFromInd1 = coll.getNSTnameFromInd(1); +payload.nstFromName1_spikeTimes = coll.getNSTFromName('1').spikeTimes; +[fieldVal, neuronNumbers] = coll.getFieldVal('avgFiringRate'); +payload.fieldVal_avgFiringRate = fieldVal; +payload.fieldVal_neuronNumbers = neuronNumbers; +payload.neighbors1 = neighbors1; +payload.neighbors2 = neighbors2; +payload.ensemble_labels = ensembleCov.getAllCovLabels; +payload.ensemble_matrix = ensembleCov.dataToMatrix(); +payload.psth_time = psthCov.time; +payload.psth_data = psthCov.data; save(fullfile(fixtureRoot, 'nstcoll_exactness.mat'), '-struct', 'payload'); end @@ -302,6 +453,53 @@ function export_config_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'config_exactness.mat'), '-struct', 'payload'); end +function export_covcoll_fixture(fixtureRoot) +t = (0:0.5:1.0)'; +position = Covariate(t, [0 10; 1 11; 2 12], 'Position', 'time', 's', '', {'x','y'}); +stimulus = Covariate(t, [5; 6; 7], 'Stimulus', 'time', 's', 'a.u.', {'stim'}); +coll = CovColl({position, stimulus}); +coll.setMask({{'Position','x'},{'Stimulus'}}); +maskedLabels = coll.getCovLabelsFromMask; +maskedMatrix = coll.dataToMatrix(); +maskedTime = coll.getCov(1).time; +dataStructure = coll.dataToStructure; +structure = coll.toStructure; +postMask1 = coll.covMask{1}; +postMask2 = coll.covMask{2}; +roundtrip = CovColl.fromStructure(structure); +copyColl = coll.copy; + +shifted = CovColl({position, stimulus}); +shifted.setCovShift(0.25); +shifted.restrictToTimeWindow(0.25, 1.25); +shiftedStim = shifted.getCov(2); + +payload = struct(); +payload.masked_labels = maskedLabels; +payload.masked_matrix = maskedMatrix; +payload.masked_time = maskedTime; +payload.data_structure_time = dataStructure.time; +payload.data_structure_values = dataStructure.signals.values; +payload.post_structure_mask_1 = postMask1; +payload.post_structure_mask_2 = postMask2; +payload.structure_numCov = structure.numCov; +payload.structure_minTime = structure.minTime; +payload.structure_maxTime = structure.maxTime; +payload.roundtrip_minTime = roundtrip.minTime; +payload.roundtrip_maxTime = roundtrip.maxTime; +payload.roundtrip_sampleRate = roundtrip.sampleRate; +payload.roundtrip_labels = roundtrip.getCovLabelsFromMask; +payload.roundtrip_matrix = roundtrip.dataToMatrix(); +payload.shifted_minTime = shifted.minTime; +payload.shifted_maxTime = shifted.maxTime; +payload.shifted_stim_time = shiftedStim.time; +payload.is_present_position = coll.isCovPresent('Position'); +payload.is_present_last_index = coll.isCovPresent(2); +payload.copy_numCov = copyColl.numCov; + +save(fullfile(fixtureRoot, 'covcoll_exactness.mat'), '-struct', 'payload'); +end + function export_cif_fixture(fixtureRoot) cif = CIF([0.1 0.5], {'stim1', 'stim2'}, {'stim1', 'stim2'}, 'binomial'); stimVal = [0.6; -0.2]; @@ -368,6 +566,42 @@ function export_analysis_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'analysis_exactness.mat'), '-struct', 'payload'); end +function export_analysis_multineuron_fixture(fixtureRoot) +t = (0:0.1:1.0)'; +stimData = sin(2*pi*t); +stim = Covariate(t, stimData, 'Stimulus', 'time', 's', '', {'stim'}); +spikeTrain1 = nspikeTrain([0.1 0.4 0.7], '1', 0.1, 0.0, 1.0, 'time', 's', '', '', -1); +spikeTrain2 = nspikeTrain([0.2 0.6 0.9], '2', 0.1, 0.0, 1.0, 'time', 's', '', '', -1); +trial = Trial(nstColl({spikeTrain1, spikeTrain2}), CovColl({stim})); +cfg = TrialConfig({{'Stimulus', 'stim'}}, 10, [], []); +cfg.setName('stim'); +fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl({cfg}), 0); +summary = FitResSummary(fits); + +payload = struct(); +payload.time = t; +payload.stim_data = stimData; +payload.spike_times_1 = spikeTrain1.spikeTimes; +payload.spike_times_2 = spikeTrain2.spikeTimes; +payload.num_fits = numel(fits); +payload.fit1_coeffs = fits{1}.getCoeffs(1); +payload.fit2_coeffs = fits{2}.getCoeffs(1); +payload.fit1_AIC = fits{1}.AIC(1); +payload.fit2_AIC = fits{2}.AIC(1); +payload.fit1_BIC = fits{1}.BIC(1); +payload.fit2_BIC = fits{2}.BIC(1); +payload.fit1_logLL = fits{1}.logLL(1); +payload.fit2_logLL = fits{2}.logLL(1); +payload.summary_AIC = summary.AIC; +payload.summary_BIC = summary.BIC; +payload.summary_logLL = summary.logLL; +payload.summary_KSStats = summary.KSStats; +payload.summary_KSPvalues = summary.KSPvalues; +payload.summary_withinConfInt = summary.withinConfInt; + +save(fullfile(fixtureRoot, 'analysis_multineuron_exactness.mat'), '-struct', 'payload'); +end + function export_ksdiscrete_fixture(fixtureRoot) t = (0:0.1:1.0)'; stimData = sin(2*pi*t);