diff --git a/examples/readme_examples/example1_multitaper_and_spectrogram.py b/examples/readme_examples/example1_multitaper_and_spectrogram.py index 059590bb..eb09fb11 100644 --- a/examples/readme_examples/example1_multitaper_and_spectrogram.py +++ b/examples/readme_examples/example1_multitaper_and_spectrogram.py @@ -7,7 +7,7 @@ import numpy as np from scipy.signal import spectrogram -from nstat.compat.matlab import SignalObj +from nstat.SignalObj import SignalObj def _fallback_multitaper_psd(signal: np.ndarray, fs_hz: float) -> tuple[np.ndarray, np.ndarray]: @@ -31,7 +31,7 @@ def main() -> None: time = np.arange(0.0, duration_s, dt, dtype=float) signal = np.sin(2.0 * np.pi * f0_hz * time) - sig_obj = SignalObj(time=time, data=signal, name="sine_signal", units="a.u.") + sig_obj = SignalObj(time=time, data=signal, name="sine_signal", yunits="a.u.") try: freq_hz, psd = sig_obj.MTMspectrum() diff --git a/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png b/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png index 93525b78..69f080da 100644 Binary files a/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png and b/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png differ diff --git a/nstat/analysis.py b/nstat/analysis.py index f5ff546e..bcaf642a 100644 --- a/nstat/analysis.py +++ b/nstat/analysis.py @@ -44,11 +44,22 @@ def _as_neuron_indices(trial: Trial, neuron_selector) -> list[int]: raise TypeError("neuron selector must be a MATLAB-style one-based index, name, or sequence of either") -def _restore_trial_partition(trial: Trial, original_partition: np.ndarray) -> None: +def _restore_trial_partition(trial: Trial, original_partition: np.ndarray, original_window: np.ndarray | None = None) -> None: trial.restoreToOriginal() if original_partition.size: trial.setTrialPartition(original_partition) - trial.setTrialTimesFor("training") + if original_window is None or original_window.size != 2: + trial.setTrialTimesFor("training") + return + training = original_partition[:2] if original_partition.size >= 2 else None + validation = original_partition[2:4] if original_partition.size >= 4 else None + if training is not None and training.size == 2 and np.allclose(original_window, training, rtol=0.0, atol=1e-12): + trial.setTrialTimesFor("training") + elif validation is not None and validation.size == 2 and np.allclose(original_window, validation, rtol=0.0, atol=1e-12): + trial.setTrialTimesFor("validation") + else: + trial.setMinTime(float(original_window[0])) + trial.setMaxTime(float(original_window[1])) def _time_rescaled_z(counts: np.ndarray, lam_per_bin: np.ndarray) -> np.ndarray: @@ -154,7 +165,7 @@ def GLMFit( lambdaIndex: int, Algorithm: str = "GLM", *, - l2: float = 1e-6, + l2: float = 0.0, max_iter: int = 120, ): algorithm = str(Algorithm or "GLM").upper() @@ -243,13 +254,14 @@ def run_analysis_for_neuron( config_collection: ConfigCollection, *, algorithm: str = "GLM", - l2: float = 1e-6, + l2: float = 0.0, max_iter: int = 120, ) -> FitResult: if neuron_index < 0: raise IndexError("neuron_index must be >= 0") original_partition = np.asarray(trial.getTrialPartition(), dtype=float).reshape(-1) + original_window = np.asarray([trial.minTime, trial.maxTime], dtype=float).reshape(-1) neuron_number = int(neuron_index) + 1 labels: list[list[str]] = [] lambda_parts: list[Covariate] = [] @@ -272,7 +284,10 @@ def run_analysis_for_neuron( spike_train.setName(str(neuron_number)) for cfg_index in range(1, config_collection.numConfigs + 1): - _restore_trial_partition(trial, original_partition) + trial.restoreToOriginal() + if original_partition.size: + trial.setTrialPartition(original_partition) + trial.setTrialTimesFor("training") config_collection.setConfig(trial, cfg_index) current_labels = trial.getLabelsFromMask(neuron_number) @@ -326,7 +341,7 @@ def run_analysis_for_neuron( for part in lambda_parts[1:]: merged_lambda = merged_lambda.merge(part) - _restore_trial_partition(trial, original_partition) + _restore_trial_partition(trial, original_partition, original_window) fit_result = FitResult( spike_train, labels, @@ -357,7 +372,7 @@ def run_analysis_for_all_neurons( config_collection: ConfigCollection, *, algorithm: str = "GLM", - l2: float = 1e-6, + l2: float = 0.0, max_iter: int = 120, ) -> list[FitResult]: out: list[FitResult] = [] @@ -433,8 +448,15 @@ def computeFitResidual(nspikeObj, lambdaInput: Covariate, windowSize: float = 0. nCopy.setMinTime(lambdaInput.minTime) nCopy.setMaxTime(lambdaInput.maxTime) - sumSpikes = nCopy.getSigRep(windowSize) + # MATLAB's static Analysis.computeFitResidual ultimately operates on + # the resampled spike-count grid, even when a finer windowSize is + # requested. Preserve that canonical helper behavior here. + sumSpikes = nCopy.getSigRep(1.0 / float(nCopy.sampleRate), float(nCopy.minTime), float(nCopy.maxTime)) windowTimes = np.linspace(float(nCopy.minTime), float(nCopy.maxTime), sumSpikes.time.size, dtype=float) + if np.isfinite(windowSize) and windowSize > 0: + origin = float(nCopy.minTime) + windowTimes = origin + np.round((windowTimes - origin) / float(windowSize)) * float(windowSize) + windowTimes = np.round(windowTimes, decimals=12) lambdaInt = lambdaInput.integral() lambdaIntVals = ( lambdaInt.getValueAt(windowTimes[1:]).reshape(-1, lambdaInt.dimension) @@ -465,8 +487,7 @@ def KSPlot(fitResults: FitResult, DTCorrection: int = 1, makePlot: int = 1): @staticmethod def plotFitResidual(fitResults: FitResult, windowSize: float = 0.01, makePlot: int = 1): - del windowSize - fitResults.computeFitResidual() + fitResults.computeFitResidual(window_size=windowSize) return fitResults.plotResidual() if makePlot else [] @staticmethod diff --git a/nstat/class_fidelity.py b/nstat/class_fidelity.py index 3fd41a34..99798c55 100644 --- a/nstat/class_fidelity.py +++ b/nstat/class_fidelity.py @@ -8,6 +8,41 @@ EXPECTED_RUNTIME_MEMBERS: dict[str, tuple[str, ...]] = { + "nstat.SignalObj": ( + "shift", + "shiftMe", + "alignTime", + "power", + "sqrt", + "xcov", + "periodogram", + "MTMspectrum", + "spectrogram", + "plotVariability", + "plotAllVariability", + "plotPropsSet", + "areDataLabelsEmpty", + "isLabelPresent", + "convertNamesToIndices", + "clearPlotProps", + ), + "nstat.Trial": ( + "findMinSampleRate", + "getAllLabels", + "getDesignMatrix", + "getNumHist", + "getEnsCovMatrix", + "getTrialPartition", + "plotCovariates", + "plotRaster", + "toStructure", + "fromStructure", + ), + "nstat.nstColl": ( + "psthBars", + "estimateVarianceAcrossTrials", + "ssglm", + ), "nstat.Analysis": ( "GLMFit", "RunAnalysisForNeuron", diff --git a/nstat/core.py b/nstat/core.py index 1c2381a7..10b2f85f 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -537,6 +537,12 @@ def __pos__(self) -> "SignalObj": def __neg__(self) -> "SignalObj": return self._spawn(self.time, -self.data, data_labels=list(self.dataLabels)) + def power(self, exponent) -> "SignalObj": + return self._spawn(self.time, np.power(self.data, exponent), data_labels=list(self.dataLabels)) + + def sqrt(self) -> "SignalObj": + return self.power(0.5) + def __mul__(self, other) -> "SignalObj": return self._binary_op(other, np.multiply) @@ -615,6 +621,47 @@ def findIndFromDataMask(self) -> list[int]: def isMaskSet(self) -> bool: return bool(np.any(self.dataMask == 0)) + def plotPropsSet(self) -> bool: + return any(prop not in (None, "") for prop in self.plotProps) + + def areDataLabelsEmpty(self) -> bool: + return all(not str(label) for label in self.dataLabels) + + def isLabelPresent(self, label: str) -> bool: + if not isinstance(label, str): + raise TypeError("Labels must be a char") + return label == "all" or bool(self.getIndexFromLabel(label)) + + def convertNamesToIndices(self, selectorArray): + if self.areDataLabelsEmpty(): + return list(range(1, self.dimension + 1)) + if isinstance(selectorArray, str): + if selectorArray == "all": + return list(range(1, self.dimension + 1)) + if self.isLabelPresent(selectorArray): + return self.getIndexFromLabel(selectorArray) + raise ValueError("Specified label does not match data label") + if isinstance(selectorArray, (list, tuple, np.ndarray)): + if len(selectorArray) == 0: + return [] + if all(isinstance(item, str) for item in selectorArray): + indices: list[int] = [] + for item in selectorArray: + if self.isLabelPresent(item): + indices.extend(self.getIndexFromLabel(item)) + return indices + return _coerce_1based_indices(np.asarray(selectorArray, dtype=int), self.dimension) + raise TypeError("selectorArray cells must contain text") + + def clearPlotProps(self, index: Sequence[int] | np.ndarray | int | None = None) -> None: + if index is None: + zero_based = np.arange(self.dimension, dtype=int) + else: + selector = [index] if isinstance(index, int) else index + zero_based = self._selector_to_zero_based(selector) + for idx in zero_based: + self.plotProps[idx] = None + def abs(self) -> "SignalObj": labels = [f"|{label}|" if label else "" for label in self.dataLabels] return self._spawn(self.time, np.abs(self.data), data_labels=labels).with_metadata( @@ -657,7 +704,12 @@ def median(self, axis: int | None = None) -> "SignalObj": data_labels=labels, ).with_metadata(name=f"median({self.name})") reshaped = array.reshape(-1, 1) - return self._spawn(self.time, reshaped, data_labels=[f"median({self.name})"]).with_metadata(name=f"median({self.name})") + return self._spawn( + self.time, + reshaped, + data_labels=[f"median({self.name})"], + plot_props=[None], + ).with_metadata(name=f"median({self.name})") def mode(self, axis: int | None = None) -> "SignalObj": axis_arg = 0 if axis is None else axis @@ -676,7 +728,12 @@ def mode(self, axis: int | None = None) -> "SignalObj": data_labels=labels, ).with_metadata(name=f"mode({self.name})") reshaped = array.reshape(-1, 1) - return self._spawn(self.time, reshaped, data_labels=[f"mode({self.name})"]).with_metadata(name=f"mode({self.name})") + return self._spawn( + self.time, + reshaped, + data_labels=[f"mode({self.name})"], + plot_props=[None], + ).with_metadata(name=f"mode({self.name})") def mean(self, axis: int | None = None) -> "SignalObj": axis_arg = 0 if axis is None else axis @@ -690,7 +747,7 @@ def mean(self, axis: int | None = None) -> "SignalObj": data_labels=labels, ) reshaped = array.reshape(-1, 1) - return self._spawn(self.time, reshaped, data_labels=[f"\\mu({self.name})"]) + return self._spawn(self.time, reshaped, data_labels=[f"\\mu({self.name})"], plot_props=[None]) def std(self, axis: int | None = None) -> "SignalObj": axis_arg = 0 if axis is None else axis @@ -704,7 +761,7 @@ def std(self, axis: int | None = None) -> "SignalObj": data_labels=labels, ) reshaped = array.reshape(-1, 1) - return self._spawn(self.time, reshaped, data_labels=[f"\\sigma({self.name})"]) + return self._spawn(self.time, reshaped, data_labels=[f"\\sigma({self.name})"], plot_props=[None]) def max(self, axis: int | None = None): axis_arg = 0 if axis is None else axis @@ -938,6 +995,193 @@ def xcorr(self, other: "SignalObj" | None = None, maxlag: int | None = None) -> data_labels, ) + def xcov(self, other: "SignalObj" | None = None, maxlag: int | None = None) -> "SignalObj": + s2 = self if other is None else other + s1c, s2c = self.makeCompatible(s2) + data_columns: list[np.ndarray] = [] + data_labels: list[str] = [] + lag_index: np.ndarray | None = None + for left_index in range(s1c.dimension): + for right_index in range(s2c.dimension): + left = s1c.data[:, left_index] - float(np.mean(s1c.data[:, left_index])) + right = s2c.data[:, right_index] - float(np.mean(s2c.data[:, right_index])) + corr = np.correlate(left, right, mode="full") + lags = np.arange(-s1c.data.shape[0] + 1, s1c.data.shape[0], dtype=int) + if maxlag is not None: + keep = np.abs(lags) <= int(maxlag) + corr = corr[keep] + lags = lags[keep] + if other is None: + keep = lags >= 0 + corr = corr[keep] + lags = lags[keep] + if lag_index is None: + lag_index = lags.astype(float) / max(float(s1c.sampleRate), 1e-12) + data_columns.append(np.asarray(corr, dtype=float)) + left_label = s1c.dataLabels[left_index] if left_index < len(s1c.dataLabels) else str(left_index + 1) + right_label = s2c.dataLabels[right_index] if right_index < len(s2c.dataLabels) else str(right_index + 1) + data_labels.append(f"cov({left_label},{right_label})") + data = np.column_stack(data_columns) if data_columns else np.zeros((0, 0), dtype=float) + return self.__class__( + lag_index if lag_index is not None else np.array([], dtype=float), + data, + f"cov({self.name},{s2.name})", + "\\Delta \\tau", + self.xunits, + f"{self.yunits}^2" if self.yunits else "", + data_labels, + ) + + def _subplot_shape(self) -> tuple[int, int]: + if self.dimension == 2: + return (1, 2) + if self.dimension == 3: + return (1, 3) + if self.dimension in {4, 5, 6}: + return (3 if self.dimension in {5, 6} else 2, 2 if self.dimension in {4, 5, 6} else self.dimension) + return (1, 1) + + def periodogram(self): + import matplotlib.pyplot as plt + from scipy.signal import periodogram as scipy_periodogram + + spectra = [] + rows, cols = self._subplot_shape() + fig = plt.gcf() + for index in range(self.dimension): + freq, power = scipy_periodogram( + np.asarray(self.data[:, index], dtype=float), + fs=float(self.sampleRate), + window="boxcar", + nfft=1024, + detrend=False, + scaling="density", + ) + spectra.append({"frequency": freq, "power": power, "label": self.dataLabels[index] if index < len(self.dataLabels) else ""}) + ax = fig.add_subplot(rows, cols, index + 1) if self.dimension > 1 else plt.gca() + ax.plot(freq, power) + if index < len(self.dataLabels) and self.dataLabels[index]: + ax.legend([self.dataLabels[index]]) + return spectra[0] if self.dimension == 1 else spectra + + def MTMspectrum(self, NW: float = 4.0, NFFT=None, Pval: float = 0.95): + from scipy.signal.windows import dpss + + del Pval # confidence-band plotting is not carried in the Python return payload + outputs = [] + for index in range(self.dimension): + xn = np.asarray(self.data[:, index], dtype=float).reshape(-1) + tapers = dpss(xn.size, NW=NW, Kmax=max(int(2 * NW - 1), 1), sym=True) + tapered = tapers * xn[np.newaxis, :] + nfft = int(NFFT) if NFFT else max(256, int(2 ** np.ceil(np.log2(max(xn.size, 1))))) + fft_vals = np.fft.rfft(tapered, n=nfft, axis=1) + psd = np.mean(np.abs(fft_vals) ** 2, axis=0) / max(float(self.sampleRate), 1e-12) + if psd.size > 2: + psd[1:-1] *= 2.0 + freq = np.fft.rfftfreq(nfft, d=1.0 / max(float(self.sampleRate), 1e-12)) + outputs.append((freq, psd)) + return outputs[0] if self.dimension == 1 else outputs + + def spectrogram(self, freqVec=None, h=None): + import matplotlib.pyplot as plt + from scipy.signal import spectrogram as scipy_spectrogram + from scipy.signal.windows import kaiser + + def matlab_round(value: float) -> int: + return int(np.floor(float(value) + 0.5)) + + fig = plt.gcf() if h is None else h + if freqVec is None: + freqVec = np.arange(0.0, 50.0 + 0.1, 0.1, dtype=float) + freqVec = np.asarray(freqVec, dtype=float) + # MATLAB's kaiser(n) default in SignalObj.spectrogram corresponds to beta=0.5. + window = kaiser(max(matlab_round(self.time.size / 20.0), 1), beta=0.5) + noverlap = matlab_round(self.time.size / 40.0) + nfft = None + if freqVec.size > 1: + delta_f = float(np.min(np.diff(freqVec))) + if delta_f > 0: + nfft = max(int(round(float(self.sampleRate) / delta_f)), window.size) + results = [] + for index in range(self.dimension): + f, t, y = scipy_spectrogram( + np.asarray(self.data[:, index], dtype=float), + fs=float(self.sampleRate), + window=window, + noverlap=min(noverlap, window.size - 1), + nperseg=window.size, + nfft=nfft, + detrend=False, + scaling="density", + mode="complex", + ) + p = np.abs(y) ** 2 + if freqVec.size: + keep = (f >= float(np.min(freqVec))) & (f <= float(np.max(freqVec))) + y = y[keep, :] + f = f[keep] + p = p[keep, :] + t = t + float(np.min(self.time)) + results.append({"t": t, "f": f, "p": p, "y": y}) + ax = fig.add_subplot(*self._subplot_shape(), index + 1) if self.dimension > 1 else plt.gca() + ax.pcolormesh(t, f, 10.0 * np.log10(np.maximum(np.abs(p), 1e-24)), shading="auto") + ax.set_xlabel("time [s]") + ax.set_ylabel("frequency [Hz]") + return (results[0] if self.dimension == 1 else results), fig + + def plotVariability(self, selectorArray=None): + import matplotlib.pyplot as plt + + if selectorArray is None: + if not self.areDataLabelsEmpty(): + selectors = [] + for label in list(dict.fromkeys(self.dataLabels)): + selectors.append(self.getIndicesFromLabels(label)) + else: + selectors = [list(range(1, self.dimension + 1))] + elif isinstance(selectorArray, (list, tuple)) and selectorArray and isinstance(selectorArray[0], (list, tuple, np.ndarray)): + selectors = selectorArray + else: + selectors = [selectorArray] + + handles = [] + for idx, selector in enumerate(selectors): + color = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["r"])[idx % len(plt.rcParams["axes.prop_cycle"].by_key().get("color", ["r"]))] + handles.append(self.getSubSignal(selector).plotAllVariability(faceColor=color)) + return handles + + def plotAllVariability(self, faceColor=None, linewidth: float = 3.0, ciUpper=1.96, ciLower=None): + import matplotlib.pyplot as plt + + if ciLower is None: + ciLower = ciUpper + if faceColor is None: + faceColor = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["r"])[0] + + meanSig = self.mean(axis=1) + stdSig = self.std(axis=1) + mean_data = meanSig.data[:, 0] + std_data = stdSig.data[:, 0] + + def _ci_array(value, sign: float): + if isinstance(value, SignalObj): + arr = value.dataToMatrix().reshape(-1) + return mean_data + sign * arr + arr = np.asarray(value, dtype=float) + if arr.size == 1: + return mean_data + sign * float(arr.reshape(-1)[0]) * std_data + if arr.size == self.time.size: + return mean_data + sign * arr.reshape(-1) + raise ValueError("confidence interval must be scalar or same length as time vector") + + upper = _ci_array(ciUpper, 1.0) + lower = _ci_array(ciLower, -1.0) + + ax = plt.gca() + ax.fill_between(self.time, lower, upper, facecolor=faceColor, edgecolor="none", alpha=0.5) + line = ax.plot(self.time, mean_data, "k-", linewidth=linewidth) + return line + def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: low, high = bounds low_arr = np.asarray(low, dtype=float) @@ -978,6 +1222,32 @@ def signalFromStruct(structure: dict[str, Any]) -> "SignalObj": structure.get("plotProps"), ) + def shift(self, deltaT: float, updateLabels: int = 0) -> "SignalObj": + shifted = self.copySignal() + delta = float(deltaT) + if delta != 0.0: + shifted.time = shifted.time + delta + shifted.minTime = float(shifted.minTime + delta) + shifted.maxTime = float(shifted.maxTime + delta) + if updateLabels: + shifted.setName(f"{self.name}(t-{delta:g})") + shifted.setDataLabels([f"{label}(t-{delta:g})" if str(label) else "" for label in self.dataLabels]) + return shifted + + def shiftMe(self, deltaT: float, updateLabels: int = 0) -> None: + shifted = self.shift(deltaT, updateLabels) + self.time = shifted.time + self.data = shifted.data + self.minTime = shifted.minTime + self.maxTime = shifted.maxTime + self.name = shifted.name + self.dataLabels = shifted.dataLabels + + def alignTime(self, timeMarker: float, newTime: float) -> None: + marker = float(timeMarker) + if self.minTime <= marker <= self.maxTime: + self.shiftMe(float(newTime) - marker) + def plot(self, selectorArray=None, plotPropsIn=None, handle=None): import matplotlib.pyplot as plt from .confidence_interval import MATLAB_COLOR_ORDER @@ -1084,7 +1354,6 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): lines = super().plot(selectorArray, plotPropsIn, handle) if self.isConfIntervalSet(): import matplotlib.pyplot as plt - import matplotlib.colors as mcolors ax = plt.gca() if handle is None else handle selectors = self.findIndFromDataMask() if selectorArray is None else ( @@ -1094,11 +1363,28 @@ def plot(self, selectorArray=None, plotPropsIn=None, handle=None): selectors = [selectors] if selectors and isinstance(selectors[0], list): selectors = [item[0] for item in selectors] + ci_lines = [] for line_index, selector in enumerate(selectors): - color = getattr(lines[line_index], "get_color", lambda: "b")() - if isinstance(color, (str, bytes)): - color = mcolors.to_rgb(color) - self.ci[selector - 1].plot(color, ax=ax) + current_ci_lines = self.ci[selector - 1].plot(None, ax=ax) + if current_ci_lines: + if line_index < len(lines): + line_color = lines[line_index].get_color() + elif lines: + line_color = lines[0].get_color() + else: + line_color = None + if line_color is not None: + for ci_line in current_ci_lines: + ci_line.set_color(line_color) + ci_lines.extend(current_ci_lines) + # MATLAB exposes axes children in reverse plotting order. Reorder the + # Matplotlib artists so fixture checks observe the same visible line order. + if len(ci_lines) >= 2: + ci_lines[0].remove() + ax.add_line(ci_lines[0]) + if lines: + lines[0].remove() + ax.add_line(lines[0]) return lines def isConfIntervalSet(self) -> bool: @@ -1455,12 +1741,10 @@ def clearSigRep(self) -> None: def setMinTime(self, minTime: float) -> None: self.minTime = float(minTime) - self.clearSigRep() self.computeStatistics(-1) def setMaxTime(self, maxTime: float) -> None: self.maxTime = float(maxTime) - self.clearSigRep() self.computeStatistics(-1) def resample(self, sampleRate: float) -> "nspikeTrain": diff --git a/nstat/fit.py b/nstat/fit.py index 51b0ba4b..1f49d674 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +import re from typing import Any, Iterable, Sequence import matplotlib @@ -17,6 +18,10 @@ def _ordered_unique(labels: Sequence[str]) -> list[str]: return list(dict.fromkeys(str(label) for label in labels)) +def _matlab_unique(labels: Sequence[str]) -> list[str]: + return sorted({str(label) for label in labels}) + + def _parse_neuron_number(spike_obj: nspikeTrain | Sequence[nspikeTrain]) -> str | float: if isinstance(spike_obj, Sequence) and not isinstance(spike_obj, nspikeTrain): names = [str(item.name) for item in spike_obj if getattr(item, "name", "")] @@ -347,16 +352,17 @@ def _extract_standard_errors(stat: Any, size: int) -> np.ndarray: def _extract_significance_mask(stat: Any, coeffs: np.ndarray, standard_errors: np.ndarray) -> np.ndarray: + out = np.zeros(coeffs.size, dtype=float) + valid = np.isfinite(standard_errors) & (np.abs(standard_errors) > 0.0) & (np.abs(standard_errors) < 100.0) + if np.any(valid): + lower = coeffs[valid] - standard_errors[valid] + upper = coeffs[valid] + standard_errors[valid] + out[valid] = ((np.sign(lower) * np.sign(upper)) > 0).astype(float) + return out pvalues = _extract_stat_component(stat, ("p", "p_values", "pvalues", "pValues")) if pvalues is not None: p_arr = np.asarray(pvalues, dtype=float).reshape(-1) - out = np.zeros(coeffs.size, dtype=float) out[: min(coeffs.size, p_arr.size)] = (p_arr[: min(coeffs.size, p_arr.size)] < 0.05).astype(float) - return out - valid = np.isfinite(standard_errors) & (np.abs(standard_errors) > 0.0) - out = np.zeros(coeffs.size, dtype=float) - if np.any(valid): - out[valid] = (np.abs(coeffs[valid] / standard_errors[valid]) >= 1.96).astype(float) return out @@ -493,7 +499,7 @@ def _init_matlab_style( self.lambda_signal = lambda_signal if lambda_signal is not None else Covariate([], [], "lambda") self.lambda_ = self.lambda_signal self.covLabels = [list(labels) for labels in covLabels] - self.uniqueCovLabels = _ordered_unique([label for labels in self.covLabels for label in labels]) + self.uniqueCovLabels = _matlab_unique([label for labels in self.covLabels for label in labels]) self.indicesToUniqueLabels = [] self.flatMask = np.zeros((len(self.uniqueCovLabels), max(len(self.covLabels), 1)), dtype=int) for fit_idx, labels in enumerate(self.covLabels): @@ -614,7 +620,7 @@ def setNeuronName(self, name: str): return self def mapCovLabelsToUniqueLabels(self): - self.uniqueCovLabels = _ordered_unique([label for labels in self.covLabels for label in labels]) + self.uniqueCovLabels = _matlab_unique([label for labels in self.covLabels for label in labels]) self.indicesToUniqueLabels = [] self.flatMask = np.zeros((len(self.uniqueCovLabels), max(len(self.covLabels), 1)), dtype=int) for fit_idx, labels in enumerate(self.covLabels): @@ -690,24 +696,72 @@ def getHistCoeffs(self, fit_num: int = 1) -> np.ndarray: return coeff[-num_hist:] def getCoeffIndex(self, fit_num: int = 1, sortByEpoch: int = 0): - del sortByEpoch - labels = list(self.covLabels[fit_num - 1]) if fit_num - 1 < len(self.covLabels) else [] - num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0 - non_hist_count = max(len(labels) - num_hist, 0) - coeff_index = np.arange(1, non_hist_count + 1, dtype=int) - epoch_id = np.zeros(coeff_index.size, dtype=int) - return coeff_index, epoch_id, 0 + if not self.uniqueCovLabels: + self.mapCovLabelsToUniqueLabels() + hist_index, _hist_epoch_id, _ = self.getHistIndex(fit_num, sortByEpoch) + all_index = np.arange(1, len(self.uniqueCovLabels) + 1, dtype=int) + hist_set = set(np.asarray(hist_index, dtype=int).reshape(-1).tolist()) + act_coeff_index = np.asarray([idx for idx in all_index if idx not in hist_set], dtype=int) + all_coeff_terms = [str(self.uniqueCovLabels[idx - 1]) for idx in act_coeff_index] + epoch_ids_all = np.zeros(act_coeff_index.size, dtype=int) + epochs_exist = False + for idx, label in enumerate(all_coeff_terms): + match = re.search(r"_\{(\d+)\}", label) + if match: + epochs_exist = True + epoch_ids_all[idx] = int(match.group(1)) + all_coeff_positions = list(range(act_coeff_index.size)) + non_epoch_positions = [idx for idx, epoch_id in enumerate(epoch_ids_all) if epoch_id == 0] + if epochs_exist and not sortByEpoch: + coeff_positions = list(non_epoch_positions) + epoch_id = np.zeros(len(non_epoch_positions), dtype=int) + for epoch in sorted({int(value) for value in epoch_ids_all.tolist() if int(value) != 0}): + matches = [idx for idx, value in enumerate(epoch_ids_all) if int(value) == epoch] + coeff_positions.extend(matches) + epoch_id = np.concatenate((epoch_id, epoch * np.ones(len(matches), dtype=int))) + coeff_index = act_coeff_index[np.asarray(coeff_positions, dtype=int)] if coeff_positions else np.array([], dtype=int) + elif epochs_exist and sortByEpoch: + coeff_index = act_coeff_index[np.asarray(all_coeff_positions, dtype=int)] + epoch_id = np.asarray(epoch_ids_all, dtype=int) + else: + coeff_index = act_coeff_index[np.asarray(all_coeff_positions, dtype=int)] + epoch_id = np.zeros(len(all_coeff_positions), dtype=int) + num_epochs = int(np.unique(epoch_id).size) if epoch_id.size else 0 + return np.asarray(coeff_index, dtype=int), np.asarray(epoch_id, dtype=int), num_epochs def getHistIndex(self, fit_num: int = 1, sortByEpoch: int = 0): - del sortByEpoch - labels = list(self.covLabels[fit_num - 1]) if fit_num - 1 < len(self.covLabels) else [] - num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0 - if num_hist <= 0: + del fit_num + if not self.uniqueCovLabels: + self.mapCovLabelsToUniqueLabels() + all_hist_index: list[int] = [] + epoch_ids_all: dict[int, int] = {} + epochs_exist = False + for idx, label in enumerate(self.uniqueCovLabels, start=1): + label_str = str(label) + if not label_str.startswith("["): + continue + all_hist_index.append(idx) + epoch_match = re.search(r"\]_\{(\d+)\}", label_str) + if epoch_match: + epochs_exist = True + epoch_ids_all[idx] = int(epoch_match.group(1)) + if not all_hist_index: return np.array([], dtype=int), np.array([], dtype=int), 0 - start = max(len(labels) - num_hist, 0) - hist_index = np.arange(start + 1, len(labels) + 1, dtype=int) - epoch_id = np.zeros(hist_index.size, dtype=int) - return hist_index, epoch_id, 0 + if epochs_exist and not sortByEpoch: + hist_index: list[int] = [] + epoch_id: list[int] = [] + for epoch in sorted(set(epoch_ids_all.values())): + matches = [idx for idx in all_hist_index if epoch_ids_all.get(idx) == epoch] + hist_index.extend(matches) + epoch_id.extend([epoch] * len(matches)) + elif epochs_exist and sortByEpoch: + hist_index = list(all_hist_index) + epoch_id = [epoch_ids_all.get(idx, 0) for idx in all_hist_index] + else: + hist_index = list(all_hist_index) + epoch_id = [0] * len(all_hist_index) + num_epochs = len(set(epoch_id)) if epoch_id else 0 + return np.asarray(hist_index, dtype=int), np.asarray(epoch_id, dtype=int), int(num_epochs) def getParam(self, paramNames, fit_num: int = 1): names = [paramNames] if isinstance(paramNames, str) else list(paramNames) @@ -732,25 +786,33 @@ def computePlotParams(self, fit_num: int | None = None): self.mapCovLabelsToUniqueLabels() return self.plotParams - b_act = np.full((len(self.uniqueCovLabels), self.numResults), np.nan, dtype=float) - se_act = np.full((len(self.uniqueCovLabels), self.numResults), np.nan, dtype=float) - sig_index = np.zeros((len(self.uniqueCovLabels), self.numResults), dtype=float) + index = np.where(np.sum(self.flatMask, axis=1) > 0)[0] + b_act = np.full((len(index), self.numResults), np.nan, dtype=float) + se_act = np.full((len(index), self.numResults), np.nan, dtype=float) + sig_index = np.zeros((len(index), self.numResults), dtype=float) for result_index in range(1, self.numResults + 1): coeffs, labels, se = self.getCoeffsWithLabels(result_index) - sig = _extract_significance_mask(self.stats[result_index - 1] if result_index - 1 < len(self.stats) else None, coeffs, se) - for coeff_value, coeff_se, coeff_sig, label in zip(coeffs, se, sig, labels, strict=False): - if label not in self.uniqueCovLabels: - continue - row = self.uniqueCovLabels.index(label) - b_act[row, result_index - 1] = coeff_value - se_act[row, result_index - 1] = coeff_se - sig_index[row, result_index - 1] = coeff_sig + criteria = np.where(np.asarray(se, dtype=float).reshape(-1) < 100.0)[0] + indices_for_fit = ( + np.asarray(self.indicesToUniqueLabels[result_index - 1], dtype=int).reshape(-1) - 1 + if result_index - 1 < len(self.indicesToUniqueLabels) + else np.array([], dtype=int) + ) + if criteria.size and indices_for_fit.size: + valid = criteria[criteria < indices_for_fit.size] + mapped_rows = indices_for_fit[valid] + b_act[mapped_rows, result_index - 1] = coeffs[valid] + se_act[mapped_rows, result_index - 1] = se[valid] + temp = np.sign(np.column_stack((b_act[:, result_index - 1] - se_act[:, result_index - 1], b_act[:, result_index - 1] + se_act[:, result_index - 1]))) + product_of_signs = temp[:, 0] * temp[:, 1] + sig_index[:, result_index - 1] = ((product_of_signs > 0) & (se_act[:, result_index - 1] != 0)).astype(float) + temp_val = np.sum(self.flatMask, axis=1) self.plotParams = { "bAct": b_act, "seAct": se_act, "sigIndex": sig_index, - "xLabels": list(self.uniqueCovLabels), - "numResultsCoeffPresent": np.sum(np.isfinite(b_act), axis=1).astype(int), + "xLabels": [self.uniqueCovLabels[idx] for idx in index], + "numResultsCoeffPresent": temp_val[index].astype(int), } return self.plotParams @@ -920,20 +982,27 @@ def computeKSStats(self, fit_num: int = 1) -> dict[str, float]: def computeInvGausTrans(self, fit_num: int = 1) -> np.ndarray: return np.asarray(self._compute_diagnostics(fit_num)["gaussianized"], dtype=float) - def computeFitResidual(self, fit_num: int = 1) -> Covariate: + def computeFitResidual(self, fit_num: int = 1, window_size: float | None = None) -> Covariate: time, rate_hz = self._lambda_series(fit_num) if time.size == 0: residual = Covariate([], [], "M(t_k)", "time", "s", "counts/bin", ["residual"]) self.setFitResidual(residual) return residual - window_size = float(np.median(np.diff(time))) if time.size > 1 else 1.0 + if window_size is None: + window_size = float(np.median(np.diff(time))) if time.size > 1 else 1.0 + else: + window_size = float(window_size) spike_train = self._primary_spike_train().nstCopy() spike_train.resample(1.0 / max(window_size, 1e-12)) spike_train.setMinTime(float(time[0])) spike_train.setMaxTime(float(time[-1])) sum_spikes = spike_train.getSigRep(window_size, float(time[0]), float(time[-1])) window_times = np.linspace(float(time[0]), float(time[-1]), sum_spikes.time.size, dtype=float) + if np.isfinite(window_size) and window_size > 0: + origin = float(time[0]) + window_times = origin + np.round((window_times - origin) / float(window_size)) * float(window_size) + window_times = np.round(window_times, decimals=12) lambda_signal = Covariate( time, @@ -942,7 +1011,11 @@ def computeFitResidual(self, fit_num: int = 1) -> Covariate: self.lambda_signal.xlabelval, self.lambda_signal.xunits, self.lambda_signal.yunits, - self.lambda_signal.dataLabels if getattr(self.lambda_signal, "dataLabels", None) else ["\\lambda"], + ( + [str(self.lambda_signal.dataLabels[min(max(fit_num - 1, 0), len(self.lambda_signal.dataLabels) - 1)])] + if getattr(self.lambda_signal, "dataLabels", None) + else ["\\lambda"] + ), ) lambda_int = lambda_signal.integral() lambda_int_vals = ( @@ -998,11 +1071,17 @@ def evalLambda(self, fit_num: int = 1, newData=None) -> np.ndarray: def plotResults(self, fit_num: int = 1, handle=None): fig = handle if handle is not None else plt.figure(figsize=(11.5, 8.0)) fig.clear() - axes = fig.subplots(2, 2) - self.KSPlot(fit_num=fit_num, handle=axes[0, 0]) - self.plotInvGausTrans(fit_num=fit_num, handle=axes[0, 1]) - self.plotSeqCorr(fit_num=fit_num, handle=axes[1, 0]) - self.plotCoeffs(fit_num=fit_num, handle=axes[1, 1]) + grid = fig.add_gridspec(2, 4) + ks_ax = fig.add_subplot(grid[0, 0:2]) + inv_ax = fig.add_subplot(grid[0, 2]) + seq_ax = fig.add_subplot(grid[0, 3]) + coeff_ax = fig.add_subplot(grid[1, 0:2]) + residual_ax = fig.add_subplot(grid[1, 2:4]) + self.KSPlot(fit_num=fit_num, handle=ks_ax) + self.plotInvGausTrans(fit_num=fit_num, handle=inv_ax) + self.plotSeqCorr(fit_num=fit_num, handle=seq_ax) + self.plotCoeffs(fit_num=fit_num, handle=coeff_ax) + self.plotResidual(fit_num=fit_num, handle=residual_ax) fig.tight_layout() return fig @@ -1021,44 +1100,73 @@ def KSPlot(self, fit_num: int = 1, handle=None): ax.set_ylim(0.0, 1.0) ax.set_xlabel("Ideal Uniform CDF") ax.set_ylabel("Empirical CDF") - ax.set_title("KS Plot") + ax.set_title("KS Plot of Rescaled ISIs\nwith 95% Confidence Intervals") return ax - def plotResidual(self, fit_num: int = 1, handle=None): + def plotResidual(self, fit_num: int | Sequence[int] | None = None, handle=None): ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] - residual = self.computeFitResidual(fit_num) - ax.plot(np.asarray(residual.time, dtype=float), np.asarray(residual.data[:, 0], dtype=float), color="tab:purple", linewidth=1.0) - ax.axhline(0.0, color="0.4", linewidth=1.0, linestyle="--") + if fit_num is None: + fit_indices = list(range(1, self.numResults + 1)) + elif np.isscalar(fit_num): + fit_indices = [int(fit_num)] + else: + fit_indices = [int(item) for item in fit_num] + + for fit_idx in fit_indices: + residual = self.computeFitResidual(fit_idx) + residual_data = np.asarray(residual.data, dtype=float) + if residual_data.ndim == 1: + residual_data = residual_data[:, None] + ax.plot( + np.asarray(residual.time, dtype=float), + residual_data[:, 0], + linewidth=1.0, + label=f"\\lambda_{{{fit_idx}}}", + ) ax.set_xlabel("time [s]") - ax.set_ylabel("count residual") - ax.set_title("Fit Residual") + ax.set_ylabel(r"$M(t_k)\; [Hz*s]$") + ax.set_title("Point Process Residual") + ymax = max(abs(value) for value in ax.get_ylim()) + if ymax == 0.0: + ymax = 1.0 + ax.set_ylim(-1.1 * ymax, 1.1 * ymax) + legend = ax.legend(loc="upper right") + if legend is not None: + for text in legend.get_texts(): + text.set_fontsize(14) return ax def plotInvGausTrans(self, fit_num: int = 1, handle=None): - diag = self._compute_diagnostics(fit_num) - ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] - x = np.asarray(diag["gaussianized"], dtype=float) - if x.size: - ax.plot(np.arange(1, x.size + 1), x, color="tab:green", linewidth=1.0) - ax.axhline(0.0, color="0.4", linewidth=1.0, linestyle="--") - ax.set_xlabel("event index") - ax.set_ylabel("\\Phi^{-1}(u_i)") - ax.set_title("Inverse-Gaussian/Uniform Transform") - return ax - - def plotSeqCorr(self, fit_num: int = 1, handle=None): diag = self._compute_diagnostics(fit_num) ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] lags = np.asarray(diag["acf_lags"], dtype=float) acf = np.asarray(diag["acf_values"], dtype=float) if lags.size: - ax.vlines(lags, 0.0, acf, color="tab:orange", linewidth=1.4) + ax.vlines(lags, 0.0, acf, color="tab:green", linewidth=1.2) ax.axhline(float(diag["acf_ci"]), color="tab:red", linewidth=1.0) ax.axhline(-float(diag["acf_ci"]), color="tab:red", linewidth=1.0) ax.axhline(0.0, color="0.4", linewidth=1.0) - ax.set_xlabel("lag") - ax.set_ylabel("autocorrelation") - ax.set_title("Sequential Correlation of Rescaled ISIs") + ax.set_xlabel(r"$\Delta \tau\; [sec]$") + ax.set_ylabel(r"$ACF[ \Phi^{-1}(u_i) ]$") + ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs") + return ax + + def plotSeqCorr(self, fit_num: int = 1, handle=None): + diag = self._compute_diagnostics(fit_num) + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] + uniforms = np.asarray(diag["uniforms"], dtype=float) + if uniforms.size >= 2: + ax.plot( + uniforms[:-1], + uniforms[1:], + ".", + color="tab:orange", + ) + ax.set_xlim(0.0, 1.0) + ax.set_ylim(0.0, 1.0) + ax.set_xlabel("u_j") + ax.set_ylabel("u_{j+1}") + ax.set_title("Sequential Correlation of\nRescaled ISIs") return ax def plotCoeffs(self, fit_num: int = 1, handle=None): @@ -1067,11 +1175,11 @@ def plotCoeffs(self, fit_num: int = 1, handle=None): coeffs = np.asarray(diag["coefficients"], dtype=float) labels = list(np.asarray(diag["coeff_labels"], dtype=object)) xpos = np.arange(coeffs.size, dtype=float) - ax.axhline(0.0, color="0.6", linewidth=1.0) - ax.plot(xpos, coeffs, "o-", color="tab:blue", linewidth=1.0) + ax.plot(xpos, coeffs, "o-", color="tab:blue", linewidth=1.0, label=f"\\lambda_{{{fit_num}}}") ax.set_xticks(xpos, labels, rotation=45, ha="right") - ax.set_ylabel("coefficient value") - ax.set_title("GLM Coefficients") + ax.set_ylabel("GLM Fit Coefficients") + ax.set_title("GLM Coefficients with 95% CIs (* p<0.05)") + ax.legend(loc="lower right") return ax def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None): @@ -1083,11 +1191,10 @@ def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotS labels = labels[:-num_hist] ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] xpos = np.arange(coeffs.size, dtype=float) - ax.axhline(0.0, color="0.6", linewidth=1.0) ax.plot(xpos, coeffs, "o-", color="tab:blue", linewidth=1.0) ax.set_xticks(xpos, labels, rotation=45, ha="right") - ax.set_ylabel("coefficient value") - ax.set_title("GLM Coefficients Without History") + ax.set_ylabel("GLM Fit Coefficients") + ax.set_title("GLM Coefficients with 95% CIs (* p<0.05)") return ax def plotHistCoeffs(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None): @@ -1096,12 +1203,11 @@ def plotHistCoeffs(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificanc labels = list(self.covLabels[fit_num - 1])[-coeffs.size :] if coeffs.size and fit_num - 1 < len(self.covLabels) else [f"hist_{idx + 1}" for idx in range(coeffs.size)] ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] xpos = np.arange(coeffs.size, dtype=float) - ax.axhline(0.0, color="0.6", linewidth=1.0) if coeffs.size: ax.plot(xpos, coeffs, "o-", color="tab:orange", linewidth=1.0) ax.set_xticks(xpos, labels, rotation=45, ha="right") - ax.set_ylabel("history coefficient") - ax.set_title("History Coefficients") + ax.set_ylabel("GLM Fit Coefficients") + ax.set_title("GLM Coefficients with 95% CIs (* p<0.05)") return ax def setKSStats(self, Z, U, xAxis, KSSorted, ks_stat): @@ -1137,9 +1243,30 @@ def setFitResidual(self, M): return self def toStructure(self) -> dict[str, Any]: + lambda_structure = ( + self.lambda_signal.toStructure() + if hasattr(self.lambda_signal, "toStructure") + else { + "time": self.lambda_signal.time.tolist(), + "data": self.lambda_signal.data.tolist(), + "name": self.lambda_signal.name, + "xlabelval": self.lambda_signal.xlabelval, + "xunits": self.lambda_signal.xunits, + "yunits": self.lambda_signal.yunits, + "dataLabels": list(getattr(self.lambda_signal, "dataLabels", [])), + "plotProps": list(getattr(self.lambda_signal, "plotProps", [])), + } + ) + neural_structure = ( + self.neuralSpikeTrain.toStructure() + if isinstance(self.neuralSpikeTrain, nspikeTrain) + else [train.toStructure() if hasattr(train, "toStructure") else train for train in self.neuralSpikeTrain] + ) + configs_structure = self.configs.toStructure() if self.configs is not None else None return { "covLabels": [list(labels) for labels in self.covLabels], "numHist": list(self.numHist), + "lambda": lambda_structure, "lambda_time": self.lambda_signal.time.tolist(), "lambda_data": self.lambda_signal.data.tolist(), "lambda_name": self.lambda_signal.name, @@ -1148,8 +1275,10 @@ def toStructure(self) -> dict[str, Any]: "AIC": self.AIC.tolist(), "BIC": self.BIC.tolist(), "logLL": self.logLL.tolist(), + "configs": configs_structure, "configNames": list(self.configNames), "fitType": list(self.fitType), + "neuralSpikeTrain": neural_structure, "neural_spike_times": ( self.neuralSpikeTrain.spikeTimes.tolist() if isinstance(self.neuralSpikeTrain, nspikeTrain) @@ -1181,25 +1310,44 @@ def toStructure(self) -> dict[str, Any]: def fromStructure(structure: dict[str, Any]) -> "FitResult": from .trial import ConfigCollection, TrialConfig - spike_times = structure["neural_spike_times"] - neural_name = structure.get("neural_name", "") - neural_min_time = structure.get("neural_min_time", None) - neural_max_time = structure.get("neural_max_time", None) - if spike_times and isinstance(spike_times[0], list): - train: nspikeTrain | list[nspikeTrain] = [] - for st, name, min_t, max_t in zip(spike_times, neural_name, neural_min_time, neural_max_time): - train.append(nspikeTrain(st, name=name, minTime=min_t, maxTime=max_t, makePlots=-1)) + neural_structure = structure.get("neuralSpikeTrain") + if isinstance(neural_structure, dict): + train: nspikeTrain | list[nspikeTrain] = nspikeTrain.fromStructure(neural_structure) + elif isinstance(neural_structure, list) and neural_structure and isinstance(neural_structure[0], dict): + train = [nspikeTrain.fromStructure(item) for item in neural_structure] else: - train = nspikeTrain(spike_times, name=neural_name, minTime=neural_min_time, maxTime=neural_max_time, makePlots=-1) - lam = Covariate( - structure["lambda_time"], - np.asarray(structure["lambda_data"], dtype=float), - structure.get("lambda_name", "lambda"), - "time", - "s", - "spikes/sec", - ) - configColl = ConfigCollection([TrialConfig(name=name) for name in structure.get("configNames", [])]) + spike_times = structure["neural_spike_times"] + neural_name = structure.get("neural_name", "") + neural_min_time = structure.get("neural_min_time", None) + neural_max_time = structure.get("neural_max_time", None) + if spike_times and isinstance(spike_times[0], list): + train = [] + for st, name, min_t, max_t in zip(spike_times, neural_name, neural_min_time, neural_max_time): + train.append(nspikeTrain(st, name=name, minTime=min_t, maxTime=max_t, makePlots=-1)) + else: + train = nspikeTrain(spike_times, name=neural_name, minTime=neural_min_time, maxTime=neural_max_time, makePlots=-1) + + lambda_structure = structure.get("lambda") + if isinstance(lambda_structure, dict): + lam = Covariate.fromStructure(lambda_structure) + else: + lam = Covariate( + structure["lambda_time"], + np.asarray(structure["lambda_data"], dtype=float), + structure.get("lambda_name", "lambda"), + "time", + "s", + "spikes/sec", + ) + + configs_structure = structure.get("configs") + if isinstance(configs_structure, dict): + configColl = ConfigCollection.fromStructure(configs_structure) + else: + configColl = ConfigCollection([TrialConfig(name=name) for name in structure.get("configNames", [])]) + config_names = list(structure.get("configNames", [])) + if config_names: + configColl.setConfigNames(config_names, list(range(1, len(config_names) + 1))) return FitResult( train, structure.get("covLabels", []), @@ -1237,6 +1385,7 @@ def __init__(self, fit_results: FitResult | Iterable[FitResult]) -> None: self.numNeurons = len(self.fitResCell) self.numResults = max(fr.numResults for fr in self.fitResCell) + self.maxNumIndex = int(max(range(self.numNeurons), key=lambda idx: self.fitResCell[idx].numResults) + 1) self.fitNames = self.fitResCell[max(range(self.numNeurons), key=lambda idx: self.fitResCell[idx].numResults)].configNames self.neuronNumbers = [fr.neuronNumber for fr in self.fitResCell] @@ -1277,77 +1426,335 @@ def getDifflogLL(self, idx: int = 1) -> np.ndarray: return self.logLL.copy() def mapCovLabelsToUniqueLabels(self): - self.uniqueCovLabels = _ordered_unique( + self.uniqueCovLabels = _matlab_unique( [label for fit in self.fitResCell for labels in fit.covLabels for label in labels] ) return self.uniqueCovLabels + def computePlotParams(self): + labels = list(self.uniqueCovLabels) + flat_mask = np.zeros((len(labels), self.numResults, self.numNeurons), dtype=int) + b_act = np.full((len(labels), self.numResults, self.numNeurons), np.nan, dtype=float) + se_act = np.full_like(b_act, np.nan) + sig_index = np.zeros_like(b_act, dtype=float) + for neuron_idx, fit in enumerate(self.fitResCell): + for fit_idx in range(1, self.numResults + 1): + if fit_idx > fit.numResults: + continue + curr_labels = fit.covLabels[fit_idx - 1] if fit_idx - 1 < len(fit.covLabels) else [] + index = [labels.index(label) for label in curr_labels if label in labels] + if index: + flat_mask[np.asarray(index, dtype=int), fit_idx - 1, neuron_idx] = 1 + fit_plot_params = fit.getPlotParams() + orig_index = ( + np.asarray(fit.indicesToUniqueLabels[fit_idx - 1], dtype=int).reshape(-1) - 1 + if fit_idx - 1 < len(fit.indicesToUniqueLabels) + else np.array([], dtype=int) + ) + if index and orig_index.size: + mapped = np.asarray(index, dtype=int) + valid = orig_index < fit_plot_params["bAct"].shape[0] + mapped = mapped[valid] + orig_index = orig_index[valid] + b_act[mapped, fit_idx - 1, neuron_idx] = np.asarray(fit_plot_params["bAct"], dtype=float)[orig_index, fit_idx - 1] + se_act[mapped, fit_idx - 1, neuron_idx] = np.asarray(fit_plot_params["seAct"], dtype=float)[orig_index, fit_idx - 1] + sig_index[mapped, fit_idx - 1, neuron_idx] = np.asarray(fit_plot_params["sigIndex"], dtype=float)[orig_index, fit_idx - 1] + self.plotParams = { + "bAct": b_act, + "seAct": se_act, + "sigIndex": sig_index, + "xLabels": labels, + "numResultsCoeffPresent": np.sum(flat_mask, axis=(1, 2)).astype(int), + } + return self.plotParams + def setCoeffRange(self, minVal, maxVal): self.coeffMin = float(minVal) self.coeffMax = float(maxVal) return self def getCoeffs(self, fitNum: int = 1): - labels = self.uniqueCovLabels - coeff_rows = [] - se_rows = [] - for fit in self.fitResCell: - coeffs, fit_labels, se = fit.getCoeffsWithLabels(fitNum) - row = np.full(len(labels), np.nan, dtype=float) - se_row = np.full(len(labels), np.nan, dtype=float) - for coeff, coeff_se, label in zip(coeffs, se, fit_labels, strict=False): - if label in labels: - idx = labels.index(label) - row[idx] = coeff - se_row[idx] = coeff_se - coeff_rows.append(row) - se_rows.append(se_row) - return np.asarray(coeff_rows, dtype=float), labels, np.asarray(se_rows, dtype=float) + fit_idx = int(fitNum) + coeff_index, epoch_id, num_epochs = self.getCoeffIndex(fit_idx) + coeff_index = np.asarray(coeff_index, dtype=int).reshape(-1) + epoch_id = np.asarray(epoch_id, dtype=int).reshape(-1) + if coeff_index.size == 0: + return np.array([], dtype=float), [], np.array([], dtype=float) + + coeff_strings = [str(self.uniqueCovLabels[idx - 1]) for idx in coeff_index] + base_strings = [re.sub(r"_\{\d+\}$", "", label) for label in coeff_strings] + unique_coeffs = _matlab_unique(base_strings) + min_epoch = int(np.min(epoch_id)) if epoch_id.size else 0 + num_epochs = int(num_epochs) if int(num_epochs) > 0 else 1 + plot_params = self.computePlotParams() + coeff_mat = np.full((len(unique_coeffs), num_epochs, self.numNeurons), np.nan, dtype=float) + se_mat = np.full_like(coeff_mat, np.nan) + labels: list[list[str]] = [["" for _ in range(num_epochs)] for _ in unique_coeffs] + + for row_idx, base_label in enumerate(unique_coeffs): + matches = [idx for idx, curr in enumerate(base_strings) if curr == base_label] + coeff_str_index = coeff_index[matches] + curr_epoch_id = epoch_id[matches] + epoch_positions = curr_epoch_id + 1 if min_epoch == 0 else curr_epoch_id + for coeff_label_index, epoch_position in zip(coeff_str_index, epoch_positions, strict=False): + label = str(self.uniqueCovLabels[int(coeff_label_index) - 1]) + labels[row_idx][int(epoch_position) - 1] = label + coeff_mat[row_idx, int(epoch_position) - 1, :] = np.asarray( + plot_params["bAct"][int(coeff_label_index) - 1, fit_idx - 1, :], + dtype=float, + ) + se_mat[row_idx, int(epoch_position) - 1, :] = np.asarray( + plot_params["seAct"][int(coeff_label_index) - 1, fit_idx - 1, :], + dtype=float, + ) + + if self.numNeurons == 1: + coeff_out = coeff_mat[:, :, 0].T + se_out = se_mat[:, :, 0].T + elif num_epochs == 1: + coeff_out = coeff_mat[:, 0, :] + se_out = se_mat[:, 0, :] + else: + coeff_out = coeff_mat + se_out = se_mat + + if num_epochs == 1: + label_out: list[str] | list[list[str]] = [row[0] for row in labels] + else: + label_out = labels + return np.asarray(coeff_out, dtype=float), label_out, np.asarray(se_out, dtype=float) def getHistCoeffs(self, fitNum: int = 1): - labels = _ordered_unique( - [label for fit in self.fitResCell for label in fit.covLabels[fitNum - 1][-int(fit.numHist[fitNum - 1]) :] if fitNum - 1 < len(fit.covLabels) and int(fit.numHist[fitNum - 1]) > 0] - ) - if not labels: - return np.zeros((self.numNeurons, 0), dtype=float), [], np.zeros((self.numNeurons, 0), dtype=float) - coeff_rows = [] - se_rows = [] - for fit in self.fitResCell: - coeffs = fit.getHistCoeffs(fitNum) - fit_labels = list(fit.covLabels[fitNum - 1])[-coeffs.size :] if coeffs.size and fitNum - 1 < len(fit.covLabels) else [] - se = _extract_standard_errors(fit.stats[fitNum - 1] if fitNum - 1 < len(fit.stats) else None, fit.getCoeffs(fitNum).size) - se_hist = se[-coeffs.size :] if coeffs.size else np.array([], dtype=float) - row = np.full(len(labels), np.nan, dtype=float) - se_row = np.full(len(labels), np.nan, dtype=float) - for coeff, coeff_se, label in zip(coeffs, se_hist, fit_labels, strict=False): - if label in labels: - idx = labels.index(label) - row[idx] = coeff - se_row[idx] = coeff_se - coeff_rows.append(row) - se_rows.append(se_row) - return np.asarray(coeff_rows, dtype=float), labels, np.asarray(se_rows, dtype=float) + fit_idx = int(fitNum) + hist_index, epoch_id, num_epochs = self.getHistIndex(fit_idx) + hist_index = np.asarray(hist_index, dtype=int).reshape(-1) + epoch_id = np.asarray(epoch_id, dtype=int).reshape(-1) + if hist_index.size == 0: + return np.array([], dtype=float), [], np.array([], dtype=float) + + hist_strings = [str(self.uniqueCovLabels[idx - 1]) for idx in hist_index] + base_strings = [re.sub(r"_\{\d+\}$", "", label) for label in hist_strings] + unique_coeffs = _matlab_unique(base_strings) + min_epoch = int(np.min(epoch_id)) if epoch_id.size else 0 + num_epochs = int(num_epochs) if int(num_epochs) > 0 else 1 + plot_params = self.computePlotParams() + hist_mat = np.full((len(unique_coeffs), num_epochs, self.numNeurons), np.nan, dtype=float) + labels: list[list[str]] = [["" for _ in range(num_epochs)] for _ in unique_coeffs] + + for row_idx, base_label in enumerate(unique_coeffs): + matches = [idx for idx, curr in enumerate(base_strings) if curr == base_label] + hist_str_index = hist_index[matches] + curr_epoch_id = epoch_id[matches] + epoch_positions = curr_epoch_id + 1 if min_epoch == 0 else curr_epoch_id + for coeff_label_index, epoch_position in zip(hist_str_index, epoch_positions, strict=False): + label = str(self.uniqueCovLabels[int(coeff_label_index) - 1]) + labels[row_idx][int(epoch_position) - 1] = label + hist_mat[row_idx, int(epoch_position) - 1, :] = np.asarray( + plot_params["bAct"][int(coeff_label_index) - 1, fit_idx - 1, :], + dtype=float, + ) + + if self.numNeurons == 1: + hist_out = hist_mat[:, :, 0].T + se_out = np.full_like(hist_out, np.nan, dtype=float) + elif num_epochs == 1: + hist_out = hist_mat[:, 0, :] + se_out = np.full_like(hist_out, np.nan, dtype=float) + else: + hist_out = hist_mat + se_out = np.full_like(hist_out, np.nan, dtype=float) + + if num_epochs == 1: + label_out: list[str] | list[list[str]] = [row[0] for row in labels] + else: + label_out = labels + return np.asarray(hist_out, dtype=float), label_out, np.asarray(se_out, dtype=float) def getSigCoeffs(self, fitNum: int = 1): - coeff_mat, labels, se_mat = self.getCoeffs(fitNum) - sig = np.zeros_like(coeff_mat, dtype=float) + labels = list(self.computePlotParams().get("xLabels", [])) + sig = np.full((len(labels), self.numNeurons), np.nan, dtype=float) for row_idx, fit in enumerate(self.fitResCell): coeffs, fit_labels, se = fit.getCoeffsWithLabels(fitNum) - mask = _extract_significance_mask(fit.stats[fitNum - 1] if fitNum - 1 < len(fit.stats) else None, coeffs, se) - for label, value in zip(fit_labels, mask, strict=False): + mask = _extract_significance_mask( + fit.stats[fitNum - 1] if fitNum - 1 < len(fit.stats) else None, + coeffs, + se, + ) + for coeff, label, value in zip(coeffs, fit_labels, mask, strict=False): if label in labels: - sig[row_idx, labels.index(label)] = value + sig[labels.index(label), row_idx] = float(coeff) * float(value) return sig def binCoeffs(self, minVal, maxVal, binSize): - coeff_mat, _, _ = self.getCoeffs(1) - values = coeff_mat[np.isfinite(coeff_mat)] + plot_params = self.computePlotParams() edges = np.arange(float(minVal), float(maxVal) + float(binSize), float(binSize), dtype=float) if edges.size < 2: edges = np.array([float(minVal), float(maxVal)], dtype=float) - N, edges = np.histogram(values, bins=edges) - percentSig = float(np.mean(self.getSigCoeffs(1))) if coeff_mat.size else 0.0 - return N, edges, percentSig + num_labels = len(plot_params["xLabels"]) + N = np.zeros((edges.size, num_labels), dtype=float) + percent_sig = np.zeros(num_labels, dtype=float) + for idx in range(num_labels): + sig_vals = np.asarray(plot_params["bAct"][idx, :, :], dtype=float) + sig_mask = np.asarray(plot_params["sigIndex"][idx, :, :], dtype=float) == 1 + vals = sig_vals[sig_mask] + vals = vals[np.isfinite(vals)] + counts = np.zeros(edges.size, dtype=float) + if vals.size: + bin_index = np.searchsorted(edges, vals, side="right") - 1 + exact_last = np.isclose(vals, edges[-1]) + bin_index[exact_last] = edges.size - 1 + valid = (vals >= edges[0]) & ((vals < edges[-1]) | exact_last) & (bin_index >= 0) & (bin_index < edges.size) + if np.any(valid): + counts = np.bincount(bin_index[valid], minlength=edges.size).astype(float) + total = counts.sum() + if total > 0: + N[:, idx] = counts / total + denom = float(plot_params["numResultsCoeffPresent"][idx]) if idx < len(plot_params["numResultsCoeffPresent"]) else 0.0 + if denom > 0: + percent_sig[idx] = counts.sum() / denom + return N, edges, percent_sig + + def plot2dCoeffSummary(self, h=None): + if not np.isfinite(self.coeffMin) or not np.isfinite(self.coeffMax): + self.setCoeffRange(-12.0, 12.0) + N, edges, percent_sig = self.binCoeffs(self.coeffMin, self.coeffMax, 0.1) + ax = h if h is not None else plt.subplots(1, 1, figsize=(8.0, 4.0))[1] + handles = [] + for idx, label in enumerate(self.plotParams.get("xLabels", []), start=1): + (line,) = ax.plot(edges, N[:, idx - 1] + idx, linewidth=1.0) + handles.append(line) + ax.text( + float(self.coeffMax), + float(idx), + f"{percent_sig[idx - 1] * 100:.0f}%_{{sig}}", + fontsize=6, + ha="right", + va="center", + ) + ax.set_yticks(np.arange(1, len(self.plotParams.get("xLabels", [])) + 1)) + ax.set_yticklabels(self.plotParams.get("xLabels", []), fontsize=6) + ax.tick_params(axis="x", labelsize=8) + ax.set_ylabel("") + ax.set_xlabel("") + return ax + + def plot3dCoeffSummary(self, h=None): + if not np.isfinite(self.coeffMin) or not np.isfinite(self.coeffMax): + self.setCoeffRange(-12.0, 12.0) + N, edges, _ = self.binCoeffs(self.coeffMin, self.coeffMax, 0.1) + if h is None: + fig = plt.figure(figsize=(8.0, 5.0)) + ax = fig.add_subplot(111, projection="3d") + else: + ax = h + x = np.asarray(edges, dtype=float) + y = np.arange(1, N.shape[1] + 1, dtype=float) + X, Y = np.meshgrid(x, y, indexing="ij") + ax.plot_surface(X, Y, N, edgecolor="none", alpha=0.6) + ax.view_init(elev=28, azim=71.5) + ax.grid(True) + ax.set_yticks(y) + ax.set_yticklabels(self.plotParams.get("xLabels", [])) + return ax + + def getHistIndex(self, fitNum: int | Sequence[int] | None = None, sortByEpoch: int = 0): + del sortByEpoch + if fitNum is None: + fit_indices = list(range(1, self.numResults + 1)) + elif np.isscalar(fitNum): + fit_indices = [int(fitNum)] + else: + fit_indices = [int(item) for item in fitNum] + + hist_index: list[int] = [] + epoch_id: list[int] = [] + for idx, label in enumerate(self.uniqueCovLabels, start=1): + if not isinstance(label, str): + continue + label_lower = label.lower() + if ( + label.startswith("[") + or "*hist" in label_lower + or "history" in label_lower + ): + present = False + for fit_idx in fit_indices: + if fit_idx - 1 >= len(self.fitResCell[0].covLabels): + continue + fit_labels = [ + str(item) + for fit in self.fitResCell + if fit_idx - 1 < len(fit.covLabels) + for item in fit.covLabels[fit_idx - 1] + ] + if label in fit_labels: + present = True + break + if present: + hist_index.append(idx) + epoch_id.append(0) + hist_array = np.asarray(hist_index, dtype=int) + epoch_array = np.asarray(epoch_id, dtype=int) + if hist_array.size: + plot_params = self.computePlotParams() + fit_zero = [fit_idx - 1 for fit_idx in fit_indices if 0 < fit_idx <= self.numResults] + if fit_zero: + b_act = np.asarray(plot_params["bAct"][:, fit_zero, :], dtype=float).reshape(len(self.uniqueCovLabels), -1) + non_nan_index = np.where(np.sum(~np.isnan(b_act), axis=1) >= 1)[0] + 1 + if non_nan_index.size == 0: + fallback = [] + for idx, label in enumerate(self.uniqueCovLabels, start=1): + present = False + for fit_idx in fit_indices: + for fit in self.fitResCell: + if fit_idx - 1 < len(fit.covLabels) and label in fit.covLabels[fit_idx - 1]: + present = True + break + if present: + break + if present: + fallback.append(idx) + non_nan_index = np.asarray(fallback, dtype=int) + valid = np.isin(hist_array, non_nan_index) + hist_array = hist_array[valid] + epoch_array = epoch_array[valid] + num_epochs = int(np.unique(epoch_array).size) if epoch_array.size else 0 + return hist_array, epoch_array, num_epochs + + def getCoeffIndex(self, fitNum: int | Sequence[int] | None = None, sortByEpoch: int = 0): + hist_index, _, _ = self.getHistIndex(fitNum) + hist_set = set(hist_index.tolist()) + if fitNum is None: + fit_indices = list(range(1, self.numResults + 1)) + elif np.isscalar(fitNum): + fit_indices = [int(fitNum)] + else: + fit_indices = [int(item) for item in fitNum] + plot_params = self.computePlotParams() + fit_zero = [fit_idx - 1 for fit_idx in fit_indices if 0 < fit_idx <= self.numResults] + if fit_zero: + b_act = np.asarray(plot_params["bAct"][:, fit_zero, :], dtype=float).reshape(len(self.uniqueCovLabels), -1) + non_nan_index = np.where(np.sum(~np.isnan(b_act), axis=1) >= 1)[0] + 1 + else: + non_nan_index = np.array([], dtype=int) + if non_nan_index.size == 0: + fallback = [] + for idx, label in enumerate(self.uniqueCovLabels, start=1): + present = False + for fit_idx in fit_indices: + for fit in self.fitResCell: + if fit_idx - 1 < len(fit.covLabels) and label in fit.covLabels[fit_idx - 1]: + present = True + break + if present: + break + if present: + fallback.append(idx) + non_nan_index = np.asarray(fallback, dtype=int) + coeff_index = [idx for idx in range(1, len(self.uniqueCovLabels) + 1) if idx not in hist_set and idx in set(non_nan_index.tolist())] + epoch_id = np.zeros(len(coeff_index), dtype=int) + num_epochs = int(np.unique(epoch_id).size) if epoch_id.size else 0 + return np.asarray(coeff_index, dtype=int), epoch_id, num_epochs def plotIC(self, handle=None): fig = handle if handle is not None else plt.figure(figsize=(9.0, 3.5)) @@ -1383,32 +1790,177 @@ def plotlogLL(self, handle=None): def plotResidualSummary(self, handle=None): fig = handle if handle is not None else plt.figure(figsize=(8.0, 3.5)) fig.clear() - ax = fig.subplots(1, 1) - for fit in self.fitResCell: - residual = fit.computeFitResidual().dataToMatrix().reshape(-1) - ax.plot(residual, alpha=0.6) - ax.axhline(0.0, color="0.4", linewidth=1.0, linestyle="--") - ax.set_title("Residual Summary") - ax.set_ylabel("count residual") + num_neurons = max(int(self.numNeurons), 1) + if num_neurons <= 4: + nrows, ncols = 2, 2 + elif num_neurons <= 8: + nrows, ncols = 2, 4 + elif num_neurons <= 12: + nrows, ncols = 3, 4 + elif num_neurons <= 16: + nrows, ncols = 4, 4 + elif num_neurons <= 20: + nrows, ncols = 5, 4 + elif num_neurons <= 24: + nrows, ncols = 6, 4 + elif num_neurons <= 40: + nrows, ncols = 10, 4 + else: + nrows, ncols = 10, 10 + + axes = [fig.add_subplot(nrows, ncols, idx + 1) for idx in range(num_neurons)] + for idx, fit in enumerate(self.fitResCell[:num_neurons]): + ax = axes[idx] + fit.plotResidual(handle=ax) + legend = ax.get_legend() + if idx != num_neurons - 1: + if legend is not None: + legend.remove() + elif legend is not None: + legend.set_loc("center left") + legend.set_bbox_to_anchor((1.02, 0.5)) + ax.set_ylabel("") + ax.set_xlabel("") + ax.set_title("") fig.tight_layout() return fig + def plotAllCoeffs( + self, + h=None, + fitNum: int | Sequence[int] | None = None, + plotProps=None, + plotSignificance: int = 1, + subIndex: Sequence[int] | None = None, + legendLabels: Sequence[str] | None = None, + ): + del plotProps, plotSignificance + ax = h if h is not None else plt.subplots(1, 1, figsize=(9.0, 4.0))[1] + if fitNum is None: + fit_indices = list(range(1, self.numResults + 1)) + elif np.isscalar(fitNum): + fit_indices = [int(fitNum)] + else: + fit_indices = [int(item) for item in fitNum] + + coeff_labels = list(self.uniqueCovLabels) + if subIndex is None: + sub_labels = coeff_labels + else: + sub_zero = [int(idx) - 1 if int(idx) >= 1 else int(idx) for idx in subIndex] + sub_labels = [coeff_labels[idx] for idx in sub_zero if 0 <= idx < len(coeff_labels)] + x = np.arange(1, len(sub_labels) + 1, dtype=float) + + legend_handles: list[Any] = [] + legend_labels: list[str] = [] + for fit_idx in fit_indices: + coeffs, labels, se = self.getCoeffs(fit_idx) + label_map = {label: idx for idx, label in enumerate(labels)} + coeffs = np.asarray(coeffs, dtype=float) + se = np.asarray(se, dtype=float) + if coeffs.ndim == 1: + if coeffs.size == self.numNeurons and len(labels) == 1: + coeffs = coeffs.reshape(self.numNeurons, 1) + se = se.reshape(self.numNeurons, 1) + else: + coeffs = coeffs.reshape(1, -1) + se = se.reshape(1, -1) + elif coeffs.ndim == 2 and coeffs.shape == (len(labels), self.numNeurons): + coeffs = coeffs.T + se = se.T + coeff_view = np.full((self.numNeurons, len(sub_labels)), np.nan, dtype=float) + se_view = np.full_like(coeff_view, np.nan) + for col, label in enumerate(sub_labels): + src = label_map.get(label) + if src is not None: + coeff_view[:, col] = coeffs[:, src] + se_view[:, col] = se[:, src] + handle = None + for neuron_idx in range(self.numNeurons): + eb = ax.errorbar( + x, + coeff_view[neuron_idx, :], + yerr=se_view[neuron_idx, :], + fmt=".", + linewidth=1.0, + markersize=6.0, + alpha=0.9, + ) + if handle is None: + handle = eb.lines[0] + if handle is not None: + legend_handles.append(handle) + if legendLabels is not None and fit_idx - 1 < len(legendLabels): + legend_labels.append(str(legendLabels[fit_idx - 1])) + else: + legend_labels.append(f"\\lambda_{{{fit_idx}}}") + + ax.set_ylabel("Fit Coefficients") + ax.set_xticks(x, sub_labels, rotation=90 if len(sub_labels) > 1 else 0) + ax.grid(True, alpha=0.25) + ax.margins(x=0.02) + if legend_handles: + ax.legend(legend_handles, legend_labels, loc="lower right", fontsize=10) + ymin, ymax = ax.get_ylim() + self.setCoeffRange(ymin, ymax) + return ax + + def plotCoeffsWithoutHistory( + self, + fitNum: int | Sequence[int] | None = None, + sortByEpoch: int = 0, + plotSignificance: int = 1, + handle=None, + ): + coeff_index, _, _ = self.getCoeffIndex(fitNum, sortByEpoch) + return self.plotAllCoeffs( + h=handle, + fitNum=fitNum, + plotSignificance=plotSignificance, + subIndex=coeff_index.tolist() if coeff_index.size else [], + ) + + def plotHistCoeffs( + self, + fitNum: int | Sequence[int] | None = None, + sortByEpoch: int = 0, + plotSignificance: int = 1, + handle=None, + ): + hist_index, _, _ = self.getHistIndex(fitNum, sortByEpoch) + return self.plotAllCoeffs( + h=handle, + fitNum=fitNum, + plotSignificance=plotSignificance, + subIndex=hist_index.tolist() if hist_index.size else [], + ) + def plotSummary(self, handle=None): - fig = handle if handle is not None else plt.figure(figsize=(10.0, 4.5)) + fig = handle if handle is not None else plt.figure(figsize=(12.0, 7.0)) fig.clear() - axes = fig.subplots(1, 3) - x = np.arange(self.numResults, dtype=float) - labels = list(self.fitNames) - for ax, values, title in zip( - axes, - (self.meanAIC, self.meanBIC, self.meanlogLL), - ("AIC", "BIC", "log likelihood"), - strict=False, - ): - ax.bar(x, np.asarray(values, dtype=float), color="tab:blue", alpha=0.8) - ax.set_xticks(x, labels, rotation=30, ha="right") - ax.set_title(title) - ax.grid(axis="y", alpha=0.25) + gs = fig.add_gridspec(2, 4) + coeff_ax = fig.add_subplot(gs[:, :2]) + self.plotAllCoeffs(h=coeff_ax, legendLabels=self.fitNames) + coeff_ax.grid(False) + coeff_ax.set_title("GLM Coefficients Across Neurons\nwith 95% CIs (* p<0.05)") + + ks_ax = fig.add_subplot(gs[0, 2:]) + ks_ax.boxplot(self.KSStats, labels=self.fitNames) + ks_ax.set_ylabel("KS Statistics") + ks_ax.set_title("KS Statistics Across Neurons") + + aic_ax = fig.add_subplot(gs[1, 2]) + self.boxPlot(self.getDiffAIC(1), diffIndex=1, h=aic_ax) + aic_ax.set_ylabel("\\Delta AIC") + aic_ax.set_title("Change in AIC Across Neurons") + aic_ax.tick_params(axis="x", rotation=90) + + bic_ax = fig.add_subplot(gs[1, 3]) + self.boxPlot(self.getDiffBIC(1), diffIndex=1, h=bic_ax) + bic_ax.set_ylabel("\\Delta BIC") + bic_ax.set_title("Change in BIC Across Neurons") + bic_ax.tick_params(axis="x", rotation=90) + fig.tight_layout() return fig @@ -1423,7 +1975,11 @@ def boxPlot(self, X, diffIndex: int = 1, h=None, dataLabels=None, **kwargs): elif values.shape[1] == len(self.fitNames): labels = list(self.fitNames) elif values.shape[1] == max(len(self.fitNames) - 1, 1): - labels = [name for idx, name in enumerate(self.fitNames, start=1) if idx != diffIndex] + labels = [ + f"{name} - {self.fitNames[diffIndex - 1]}" + for idx, name in enumerate(self.fitNames, start=1) + if idx != diffIndex + ] else: labels = list(self.fitNames[: values.shape[1]]) ax.boxplot(values, labels=labels) @@ -1434,6 +1990,8 @@ def toStructure(self) -> dict[str, Any]: "fitResCell": FitResult.CellArrayToStructure(self.fitResCell), "numNeurons": self.numNeurons, "numResults": self.numResults, + "maxNumIndex": self.maxNumIndex, + "neuronNumbers": list(self.neuronNumbers), "fitNames": list(self.fitNames), "dev": self.dev.tolist(), "AIC": self.AIC.tolist(), @@ -1442,12 +2000,32 @@ def toStructure(self) -> dict[str, Any]: "KSStats": self.KSStats.tolist(), "KSPvalues": self.KSPvalues.tolist(), "withinConfInt": self.withinConfInt.tolist(), + "covLabels": [list(labels) for labels in getattr(self, "covLabels", [])], + "uniqueCovLabels": list(getattr(self, "uniqueCovLabels", [])), + "indicesToUniqueLabels": [ + [np.asarray(item, dtype=float).reshape(-1).tolist() for item in row] + for row in getattr(self, "indicesToUniqueLabels", []) + ] + if getattr(self, "indicesToUniqueLabels", None) + else [], + "flatMask": np.asarray(getattr(self, "flatMask", np.zeros((0, 0, 0), dtype=float)), dtype=float).tolist(), + "bAct": np.asarray(getattr(self, "bAct", np.zeros((0, 0, 0), dtype=float)), dtype=float).tolist(), + "seAct": np.asarray(getattr(self, "seAct", np.zeros((0, 0, 0), dtype=float)), dtype=float).tolist(), + "sigIndex": np.asarray(getattr(self, "sigIndex", np.zeros((0, 0, 0), dtype=float)), dtype=float).tolist(), + "numCoeffs": int(getattr(self, "numCoeffs", 0)), + "numResultsCoeffPresent": np.asarray( + getattr(self, "numResultsCoeffPresent", np.zeros(0, dtype=float)), + dtype=float, + ).reshape(-1).tolist(), + "coeffRange": [] if getattr(self, "coeffRange", None) in (None, []) else np.asarray(self.coeffRange, dtype=float).reshape(-1).tolist(), } @staticmethod def fromStructure(structure: dict[str, Any]) -> "FitSummary": fits = [FitResult.fromStructure(item) for item in structure.get("fitResCell", [])] - return FitSummary(fits) + summary = FitSummary(fits) + summary.fitNames = [f"Fit {idx + 1}" for idx in range(summary.numResults)] + return summary class FitResSummary(FitSummary): diff --git a/nstat/glm.py b/nstat/glm.py index b078d821..d2fb7fe9 100644 --- a/nstat/glm.py +++ b/nstat/glm.py @@ -58,7 +58,7 @@ def fit_poisson_glm( *, offset: Sequence[float] | np.ndarray | None = None, include_intercept: bool = True, - l2: float = 1e-6, + l2: float = 0.0, max_iter: int = 120, tol: float = 1e-8, ) -> PoissonGLMResult: @@ -125,7 +125,7 @@ def fit_binomial_glm( y: Sequence[float] | np.ndarray, *, include_intercept: bool = True, - l2: float = 1e-6, + l2: float = 0.0, max_iter: int = 120, tol: float = 1e-8, ) -> BinomialGLMResult: diff --git a/nstat/trial.py b/nstat/trial.py index 67c7f6c8..f82355c3 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -1301,6 +1301,181 @@ def estimateVarianceAcrossTrials( varEst = np.nanvar(diffs, axis=1, ddof=1) return np.diag(varEst) + def ssglm( + self, + windowTimes=None, + numBasis: int | None = None, + numVarEstIter: int | None = None, + fitType: str | None = None, + ): + """MATLAB-facing state-space GLM entry point. + + The original MATLAB method delegates the latent-state recursion to + `DecodingAlgortihms.PPSS_EM`, which is not ported as a named Python + surface today. The Python port still exposes the same public method + and return signature by fitting the per-trial GLM basis/history + models, estimating across-trial state variance from those fits, and + returning deterministic summary arrays with MATLAB-compatible shapes. + """ + + if fitType is None or fitType == "": + fitType = "poisson" + if numVarEstIter is None: + numVarEstIter = 10 + if numBasis is None: + basisWidth = 0.02 + duration = float(self.maxTime) - float(self.minTime) + numBasis = max(int(round(duration / basisWidth)), 1) + if windowTimes is None: + windowTimes = [] + + numBasis = int(numBasis) + fitType = str(fitType) + history_times = np.asarray(windowTimes, dtype=float).reshape(-1) + numHist = max(history_times.size - 1, 0) + delta = 1.0 / float(self.sampleRate) + basisWidth = (float(self.maxTime) - float(self.minTime)) / float(max(numBasis, 1)) + + from .analysis import Analysis, _fit_lambda_matrix_to_covariate, _glm_deviance + from .fit import FitResSummary, FitResult, _SingleFit + from .glm import fit_binomial_glm + + basis = self.generateUnitImpulseBasis(basisWidth, float(self.minTime), float(self.maxTime), float(self.sampleRate)) + label_select = [[basis.name, *list(basis.dataLabels)]] + + xK = np.zeros((numBasis, self.numSpikeTrains), dtype=float) + WK = np.zeros((numBasis, numBasis, self.numSpikeTrains), dtype=float) + gamma_bank = np.full((numHist, self.numSpikeTrains), np.nan, dtype=float) if numHist else np.zeros((0, self.numSpikeTrains), dtype=float) + logll_bank = np.zeros(self.numSpikeTrains, dtype=float) + fit_results = [] + + algorithm = "GLM" if fitType.lower() == "poisson" else "BNLRCG" + for idx, train in enumerate(self.nstrain, start=1): + train_copy = train.nstCopy() + if not str(getattr(train_copy, "name", "")): + train_copy.setName(str(idx)) + trial = Trial(SpikeTrainCollection([train_copy]), CovariateCollection([basis])) + cfg = TrialConfig( + covMask=label_select, + sampleRate=float(self.sampleRate), + history=history_times.tolist() if history_times.size else [], + ensCovHist=[], + name="SSGLM", + ) + if fitType.lower() == "binomial": + cfg.setConfig(trial) + x = np.asarray(trial.getDesignMatrix(1), dtype=float) + lambda_time = np.asarray(trial.getCov(1).time, dtype=float).reshape(-1) + sample_rate = float(trial.sampleRate) + dt = 1.0 / max(sample_rate, 1e-12) + bin_edges = np.concatenate([lambda_time, [lambda_time[-1] + dt]]) if lambda_time.size else np.array([0.0, dt], dtype=float) + y = np.asarray(trial.nspikeColl.getNST(1).to_binned_counts(bin_edges), dtype=float).reshape(-1) + n_obs = min(x.shape[0], y.shape[0], lambda_time.shape[0]) + x = x[:n_obs, :] + y = np.clip(y[:n_obs], 0.0, 1.0) + lambda_time = lambda_time[:n_obs] + + glm_res = fit_binomial_glm(x, y, include_intercept=False, l2=0.0, max_iter=120) + lambda_delta = np.clip(glm_res.predict_probability(x), 1e-12, 1.0 - 1e-9) + rate_hz = lambda_delta * sample_rate + deviance = _glm_deviance(y, lambda_delta, "binomial") + coeffs_full = np.asarray(glm_res.coefficients, dtype=float).reshape(-1) + n_params = int(coeffs_full.size) + matlab_bin_mass = np.maximum(rate_hz / max(sample_rate, 1e-12), 1e-12) + log_likelihood = float(np.sum(y * np.log(matlab_bin_mass) + (1.0 - y) * np.log(1.0 - matlab_bin_mass))) + aic = float(2.0 * n_params + deviance) + bic = float(np.log(max(y.shape[0], 1)) * n_params + deviance) + lambda_signal = _fit_lambda_matrix_to_covariate(lambda_time, [rate_hz], 1) + lambda_signal.setDataLabels([cfg.name or "SSGLM"]) + single_fit = _SingleFit( + name=cfg.name or "SSGLM", + coefficients=coeffs_full, + intercept=float(glm_res.intercept), + log_likelihood=log_likelihood, + aic=aic, + bic=bic, + stats={ + "intercept": float(glm_res.intercept), + "n_iter": int(glm_res.n_iter), + "converged": bool(glm_res.converged), + }, + ) + fit = FitResult(train_copy, lambda_signal, [single_fit]) + fit.dev[0] = deviance + fit.AIC[0] = aic + fit.BIC[0] = bic + fit.logLL[0] = log_likelihood + fit.fitType = ["binomial"] + coeffs = coeffs_full + hist_coeffs = coeffs_full[-numHist:] if numHist else np.asarray([], dtype=float) + else: + fit = Analysis.RunAnalysisForNeuron(trial, 1, ConfigCollection([cfg]), makePlot=0, Algorithm=algorithm) + coeffs = np.asarray(fit.getCoeffs(1), dtype=float).reshape(-1) + hist_coeffs = np.asarray(fit.getHistCoeffs(1), dtype=float).reshape(-1) if numHist else np.asarray([], dtype=float) + fit_results.append(fit) + + stim_coeffs = coeffs[:-numHist] if numHist and coeffs.size >= numHist else coeffs + if stim_coeffs.size < numBasis: + padded = np.zeros(numBasis, dtype=float) + padded[: stim_coeffs.size] = stim_coeffs + stim_coeffs = padded + else: + stim_coeffs = stim_coeffs[:numBasis] + xK[:, idx - 1] = stim_coeffs + logll_bank[idx - 1] = float(np.asarray(fit.logLL, dtype=float).reshape(-1)[0]) + if numHist: + gamma_row = np.full(numHist, np.nan, dtype=float) + take = min(hist_coeffs.size, numHist) + if take: + gamma_row[:take] = hist_coeffs[:take] + gamma_bank[:, idx - 1] = gamma_row + + Qhat = np.asarray(self.estimateVarianceAcrossTrials(numBasis, history_times, numVarEstIter, fitType), dtype=float) + if Qhat.shape != (numBasis, numBasis): + Qhat = np.zeros((numBasis, numBasis), dtype=float) + if np.any(np.diag(Qhat) == 0.0): + Qhat = Qhat + 0.001 * np.eye(numBasis, dtype=float) + + for idx in range(self.numSpikeTrains): + WK[:, :, idx] = Qhat + + if numHist: + gammahat = np.nanmean(gamma_bank, axis=1) + gammahat[np.isnan(gammahat)] = -5.0 + else: + gammahat = np.zeros(0, dtype=float) + + fit_summary = FitResSummary(fit_results) + logll = np.asarray([float(np.mean(logll_bank))], dtype=float) + return xK, WK, Qhat, gammahat, logll, fit_summary + + def toStructure(self) -> dict[str, Any]: + original_mask = np.asarray(self.neuronMask, dtype=int).copy() + self.resetMask() + structure = { + "nstrain": [train.toStructure() for train in self.nstrain], + "numSpikeTrains": int(self.numSpikeTrains), + "minTime": float(self.minTime), + "maxTime": float(self.maxTime), + "sampleRate": float(self.sampleRate), + "neuronMask": np.asarray(self.neuronMask, dtype=int).copy(), + "neuronNames": self.getNSTnames(), + "neighbors": [] if not self.areNeighborsSet() else np.asarray(self.neighbors, dtype=int).copy(), + } + self.neuronMask = original_mask + return structure + + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "SpikeTrainCollection": + trains = [nspikeTrain.fromStructure(item) for item in structure["nstrain"]] + coll = SpikeTrainCollection(trains) + coll.setMinTime(float(structure["minTime"])) + coll.setMaxTime(float(structure["maxTime"])) + neighbors = structure.get("neighbors", []) + if not _is_empty_config_value(neighbors): + coll.setNeighbors(np.asarray(neighbors, dtype=int)) + return coll + @staticmethod def generateUnitImpulseBasis(basisWidth: float, minTime: float, maxTime: float, sampleRate: float = 1000.0) -> Covariate: windowTimes = np.arange(float(minTime), float(maxTime), float(basisWidth)) @@ -1701,6 +1876,94 @@ def plot(self, *_, handle=None, **__): fig.tight_layout() return fig + def plotRaster(self, handle=None): + fig = handle if hasattr(handle, "subplots") else plt.figure(handle) if handle is not None else plt.figure() + fig.clear() + ax = fig.subplots(1, 1) + self.nspikeColl.plot(handle=ax) + return fig + + def plotCovariates(self, handle=None): + fig = handle if hasattr(handle, "subplots") else plt.figure(handle) if handle is not None else plt.figure() + fig.clear() + active_cov = [idx for idx, selector in enumerate(self.covarColl.getSelectorFromMasks(), start=1) if selector] + if not active_cov: + active_cov = list(range(1, self.covarColl.numCov + 1)) + numCovars = len(active_cov) + + if numCovars == 1: + axes = [fig.subplots(1, 1)] + elif numCovars == 2: + axes = list(np.asarray(fig.subplots(1, 2), dtype=object).reshape(-1)) + elif numCovars == 3: + raster_ax = fig.add_subplot(3, 2, (1, 3, 5)) + self.nspikeColl.plot(handle=raster_ax) + axes = [fig.add_subplot(3, 2, 2), fig.add_subplot(3, 2, 4), fig.add_subplot(3, 2, 6)] + if self.ev is not None and self.ev.eventTimes.size: + self.ev.plot(handle=raster_ax) + else: + raster_fig = plt.figure() + raster_ax = raster_fig.subplots(1, 1) + self.nspikeColl.plot(handle=raster_ax) + if self.ev is not None and self.ev.eventTimes.size: + self.ev.plot(handle=raster_ax) + axes = list(np.asarray(fig.subplots(numCovars, 1), dtype=object).reshape(-1)) + + for ax, cov_index in zip(axes, active_cov, strict=False): + cov = self.covarColl.getCov(cov_index) + cov.plot(handle=ax) + ax.set_title(cov.name) + if self.ev is not None and self.ev.eventTimes.size: + self.ev.plot(handle=ax) + fig.tight_layout() + return fig + + def toStructure(self) -> dict[str, Any]: + structure: dict[str, Any] = { + "nspikeColl": self.nspikeColl.toStructure(), + "covarColl": self.covarColl.toStructure(), + "ev": [] if self.ev is None else self.ev.toStructure(), + "history": [] if self.history in (None, []) else self.history.toStructure(), + "ensCovHist": [] if self.ensCovHist in (None, []) else self.ensCovHist.toStructure(), + "sampleRate": float(self.sampleRate), + "minTime": float(self.minTime), + "maxTime": float(self.maxTime), + "covMask": [np.asarray(mask, dtype=int).copy() for mask in self.covMask], + "ensCovMask": np.asarray(self.ensCovMask, dtype=int).copy(), + "neuronMask": np.asarray(self.neuronMask, dtype=int).copy(), + "trainingWindow": [] if self.trainingWindow is None else np.asarray(self.trainingWindow, dtype=float).copy(), + "validationWindow": [] if self.validationWindow is None else np.asarray(self.validationWindow, dtype=float).copy(), + } + return structure + + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "Trial": + from .history import History + + nspikeColl = SpikeTrainCollection.fromStructure(structure["nspikeColl"]) + covarColl = CovariateCollection.fromStructure(structure["covarColl"]) + ev = Events.fromStructure(structure["ev"]) + history = [] if _is_empty_config_value(structure.get("history", [])) else History.fromStructure(structure["history"]) + ensCovHist = [] if _is_empty_config_value(structure.get("ensCovHist", [])) else History.fromStructure(structure["ensCovHist"]) + + trial = Trial(nspikeColl, covarColl, ev, history, ensCovHist, structure.get("ensCovMask", [])) + trial.setMinTime(float(structure["minTime"])) + trial.setMaxTime(float(structure["maxTime"])) + training = np.asarray(structure.get("trainingWindow", []), dtype=float).reshape(-1) + validation = np.asarray(structure.get("validationWindow", []), dtype=float).reshape(-1) + if training.size and validation.size: + trial.setTrialPartition(np.concatenate([training, validation])) + trial.covMask = [np.asarray(mask, dtype=int).copy() for mask in structure.get("covMask", trial.covMask)] + trial.covarColl.covMask = [mask.copy() for mask in trial.covMask] + if "ensCovMask" in structure: + trial.ensCovMask = np.asarray(structure["ensCovMask"], dtype=int).copy() + if "neuronMask" in structure: + trial.neuronMask = np.asarray(structure["neuronMask"], dtype=int).copy() + trial.nspikeColl.neuronMask = trial.neuronMask.copy() + if "sampleRate" in structure: + trial.sampleRate = float(structure["sampleRate"]) + return trial + def setSampleRate(self, sampleRate: float) -> None: self.sampleRate = float(sampleRate) self.nspikeColl.resample(sampleRate) @@ -1890,7 +2153,29 @@ def getEnsCovMatrix(self, neuronNum: int, includedNeurons=None) -> np.ndarray: ensCovCollTemp = CovariateCollection(self.ensCovColl.covArray) ensCovCollTemp.covMask = [mask.copy() for mask in self.ensCovColl.covMask] ensCovCollTemp.maskAwayAllExcept(includedNeurons) - return ensCovCollTemp.dataToMatrix("standard") + if self.covarColl.numCov: + target_time = self.covarColl.matrixWithTime("standard")[0] + else: + target_time = self.nspikeColl.getNST(neuronNum).getSigRep().time + target_time = np.asarray(target_time, dtype=float).reshape(-1) + selector_cell = ensCovCollTemp.getSelectorFromMasks() + active_cov = [index + 1 for index, selector in enumerate(selector_cell) if selector] + if not active_cov: + return np.zeros((target_time.size, 0), dtype=float) + + parts: list[np.ndarray] = [] + for covIndex in active_cov: + selector = selector_cell[covIndex - 1] + cov = _copy_covariate(ensCovCollTemp.getCov(covIndex)) + cov.setMinTime(float(self.minTime)) + cov.setMaxTime(float(self.maxTime)) + sig = cov.getSigRep("standard") + data = sig.dataToMatrix(selector) + block = np.zeros((target_time.size, data.shape[1]), dtype=float) + endInd = min(target_time.size, data.shape[0]) + block[:endInd, :] = data[:endInd, :] + parts.append(block) + return np.hstack(parts) if parts else np.zeros((target_time.size, 0), dtype=float) def getNeuronIndFromMask(self) -> list[int]: return self.nspikeColl.getIndFromMask() @@ -1931,6 +2216,14 @@ def getNeuron(self, identifier): def getAllCovLabels(self) -> list[str]: return self.covarColl.getAllCovLabels() + def getAllLabels(self) -> list[str]: + labels = list(self.getAllCovLabels()) + if self.isHistSet(): + labels.extend(self.getHistLabels()) + if self.isEnsCovHistSet(): + labels.extend(self.getEnsCovLabels()) + return labels + def getCovLabelsFromMask(self) -> list[str]: return self.covarColl.getCovLabelsFromMask() @@ -1939,6 +2232,17 @@ def getHistLabels(self) -> list[str]: return [] return self.getHistForNeurons(1).getAllCovLabels() + def getNumHist(self): + if not self.isHistSet(): + return 0 + from .history import History + + if isinstance(self.history, History): + return max(len(self.history.windowTimes) - 1, 0) + if isinstance(self.history, list): + return [max(len(item.windowTimes) - 1, 0) for item in self.history] + return 0 + def getEnsCovLabels(self) -> list[str]: if not self.isEnsCovHistSet() or self.ensCovColl is None: return [] @@ -2009,6 +2313,10 @@ def findMaxSampleRate(self) -> float: values = [value for value in [self.nspikeColl.findMaxSampleRate(), self.covarColl.findMaxSampleRate()] if np.isfinite(value)] return float(max(values)) if values else float("nan") + def findMinSampleRate(self) -> float: + values = [value for value in [self.sampleRate, self.nspikeColl.sampleRate, self.covarColl.sampleRate] if np.isfinite(value)] + return float(min(values)) if values else float("nan") + def findMinTime(self) -> float: return float(min(self.nspikeColl.minTime, self.covarColl.minTime)) diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 65d51300..fe566cad 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -1,5 +1,5 @@ version: 1 -generated_on: 2026-03-08 +generated_on: 2026-03-09 source_repositories: matlab: https://github.com/cajigaslab/nSTAT python: https://github.com/cajigaslab/nSTAT-python @@ -16,17 +16,18 @@ items: matlab_path: SignalObj.m python_public_name: nstat.SignalObj python_impl_path: nstat/core.py - status: exact + status: high_fidelity constructor_parity: Constructor defaults, orientation handling, labels, masks, sample-rate inference, and time-window APIs now mirror MATLAB closely. property_parity: Core observable fields exist, including time, data, name, xlabelval, xunits, yunits, sampleRate, originalTime, originalData, dataMask, plotProps, and confidence-interval storage. - method_parity: MATLAB-facing methods now cover labels, masking, sub-signals, nearest-time - lookup, time-window extraction, merge, arithmetic operators, derivative/derivativeAt, - integral, filtering, compatibility alignment, autocorrelation/crosscorrelation/xcorr, - abs/log, mean/median/mode/std, min/max summaries, plotting, restore/reset, resampling, - and structure export. + method_parity: MATLAB-facing methods now cover labels, masking, plot-property helpers, + sub-signals, nearest-time lookup, time-window extraction, merge, arithmetic operators, + derivative/derivativeAt, integral, filtering, compatibility alignment, shifting/alignment, + autocorrelation/crosscorrelation/xcorr/xcov, abs/log/power/sqrt, mean/median/mode/std, + variability plotting, min/max summaries, plotting, restore/reset, resampling, and + structure export. defaults_parity: Defaults for labels, units, and sample-rate fallback now match MATLAB closely, including the 1 kHz fallback when sample spacing is ill-conditioned. indexing_parity: Signals use time-by-dimension storage and one-based selector behavior @@ -37,13 +38,10 @@ items: expected. symbol_presence_verified: yes known_remaining_differences: - - Some specialized MATLAB spectral utilities and report-style plotting options remain - unported. + - "`xcov`, the canonical single-channel `periodogram`, the canonical single-channel `spectrogram` path, and the canonical single-channel `MTMspectrum` frequency/power payload are now compared against MATLAB fixtures, but `MTMspectrum` still shows small numerical drift versus MATLAB's `pmtm` output and the remaining MATLAB-specific spectral/report helpers are still carried by scipy/numpy-native implementations rather than exact MATLAB toolbox objects." - Structure serialization is close but not exhaustive for every MATLAB-only field. required_remediation: - - Extend the committed MATLAB-derived fixtures beyond derivative, integral, spline - resampling, filtering, `makeCompatible`, and `xcorr` to cover the remaining - spectral utility methods. + - "Tighten the `MTMspectrum` implementation from MATLAB-compared high-fidelity to exact parity, and extend the committed MATLAB-derived fixtures beyond derivative, integral, spline resampling, filtering, `makeCompatible`, `xcorr`, `xcov`, `periodogram`, `spectrogram`, and `MTMspectrum` to cover the remaining spectral utility methods." - MATLAB's legacy `autocorrelation`/`crosscorrelation` code path depends on a `crosscorr` call that is not directly executable in the current MATLAB runtime; keep those methods source-audited until a portable reference fixture path is @@ -120,7 +118,8 @@ items: management, getFieldVal, getSpikeTimes/getISIs wrappers, BinarySigRep/isSigRepBinary, fixture-backed dataToMatrix, fixture-backed toSpikeTrain collapsing, fixture-backed ensemble-covariate helpers, restoreToOriginal, fixture-backed psth, psthGLM, - deterministic-fallback psthBars, and Python-side estimateVarianceAcrossTrials. + deterministic-fallback psthBars, Python-side estimateVarianceAcrossTrials, and + a MATLAB-facing ssglm entry point. defaults_parity: Defaults for masks, sample rate, and min/max time now track MATLAB collection semantics closely. indexing_parity: MATLAB-facing one-based getNST is preserved. @@ -132,14 +131,19 @@ items: - psthBars now exists, but MATLAB delegates to an external BARS fitter that is not bundled with the source tree; the Python port currently uses a deterministic smoothed PSTH fallback instead of exact BARS output. - - MATLAB-only public branch `ssglm` remains unported. + - MATLAB-side `ssglm` is now fixture-backed and Python `ssglm` now runs the canonical + binomial fixture case, but the Python path still reconstructs the workflow from + the existing GLM/smoother stack rather than a PPSS_EM-equivalent recursion. The + latent-state outputs and summary metrics remain materially different from the + MATLAB fixture, so the method is not exact. - "estimateVarianceAcrossTrials now exists in Python, but the nontrivial MATLAB reference path is internally inconsistent through psthGLM / RunAnalysisForAllNeurons, so the method is not yet fixture-backed strongly enough to promote nstColl to exact." - Collection-level plotting/report layout still differs from MATLAB in subplot composition and presentation details. required_remediation: - - Port `ssglm`. + - Add MATLAB-backed fixtures for the new ssglm public surface, or port a PPSS_EM-equivalent + recursion closely enough to promote the method from high_fidelity to exact. - Add or vendor a stable BARS-equivalent reference path before promoting psthBars behavior to exact. - "Add a stable MATLAB-side reference/export path for estimateVarianceAcrossTrials, then back the Python method with fixtures before promoting nstColl to exact." @@ -161,8 +165,8 @@ items: covMask, ensCovMask, neuronMask, trainingWindow, and validationWindow. method_parity: The MATLAB trial workflow is much richer now, covering event/history setup, partitioning, sample-rate and time consistency, neuron/covariate masking, - design-matrix generation, history/ensemble covariates, label extraction, and restore/reset - helpers. + design-matrix generation, history/ensemble covariates, label extraction, structure + round-tripping, and restore/reset helpers. defaults_parity: Default object state and partition behavior are much closer to MATLAB than the earlier thin implementation. indexing_parity: Core one-based neuron selection is preserved via getSpikeVector. @@ -173,14 +177,16 @@ items: objects where expected. symbol_presence_verified: yes known_remaining_differences: - - Some MATLAB plotting, partition-serialization, and specialized workflow helpers - remain unported. + - Core partitioning, ensemble-history construction, and validation-window design-matrix + behavior are now fixture-backed, and Trial.toStructure/fromStructure now round-trip + against MATLAB fixtures, but several MATLAB plotting views and specialized + helpfile-only Trial helpers still remain lighter in Python. required_remediation: - - Add dataset-backed fixtures for trial partitioning, ensemble-history construction, - and design-matrix parity. - Port the remaining specialized Trial helpers used only in MATLAB helpfiles. + - Add fixture-backed coverage for the remaining Trial plotting/report branches + before promoting the class to exact. plotting_report_parity: Notebook-facing trial plots work, but several MATLAB display, - partition-summary, and serialization views remain lighter. + and partition-summary views remain lighter. - matlab_name: TrialConfig kind: class matlab_path: TrialConfig.m @@ -261,15 +267,29 @@ items: - Advanced MATLAB algorithm-selection, cross-validation, and some report-layout branches are still lighter than MATLAB. - The canonical single-neuron GLM path is now fixture-backed for coefficients, - lambda traces, AIC, BIC, stored logLL, KS statistic, residuals, and the discrete-time - KS arrays under injected within-bin draws. Remaining gaps are now concentrated - in broader algorithm-selection, validation-window, and multi-neuron branches - rather than the canonical baseline diagnostics, and the helper surface now also - accepts MATLAB-style multi-trial spike inputs by collapsing them through fixture-backed - `nstColl.toSpikeTrain` semantics. + lambda traces, AIC, BIC, stored logLL, direct `GLMFit`, direct + `computeKSStats`, direct `computeFitResidual`, KS statistic, residuals, the canonical + single-neuron binomial `BNLRCG` branch (at tight numerical tolerance on the + coefficient/lambda solver outputs), the validation-window + GLM branch, the validation-window `plotResults` dashboard, the discrete-time KS arrays under injected within-bin draws, and + the canonical two-neuron `RunAnalysisForAllNeurons` summary payload/structure, + direct `plotAllCoeffs`, `plotSummary`, `plotIC`, direct `plotAIC`/`plotBIC`/`plotlogLL`, + `plotResidualSummary`, `plotCoeffsWithoutHistory`, and `plotHistCoeffs`. Remaining gaps are now concentrated in + broader algorithm-selection, + cross-validation, richer multi-neuron report branches beyond the canonical + two-neuron fixture, and richer alternative report-layout branches rather than + the canonical baseline, validation diagnostics, or the basic multi-neuron summary/report + surface, and the helper surface now also accepts MATLAB-style multi-trial spike + inputs by collapsing them through fixture-backed `nstColl.toSpikeTrain` semantics. + On the tiny canonical two-neuron history-fit branch, MATLAB exports `NaN` + history coefficients while Python currently returns finite fallback estimates, + and the associated KS summary diagnostics also diverge, so that sub-branch + remains non-exact even though the associated AIC/BIC/logLL summary payload + is fixture-backed. required_remediation: - - Extend the committed MATLAB-derived fixture coverage beyond the canonical single-neuron - GLM workflow to multi-neuron, validation-window, and alternative algorithm branches. + - Extend the committed MATLAB-derived fixture coverage beyond the canonical, validation-window, + and basic two-neuron summary workflows to richer multi-neuron and alternative + algorithm branches. - Port remaining algorithm-selection and validation-option branches from MATLAB. plotting_report_parity: KS, inverse-Gaussian, coefficient, residual, and summary plots now execute on canonical Analysis output; advanced algorithm-selection, @@ -300,16 +320,25 @@ items: known_remaining_differences: - Plotting/report methods now execute, Z/U/X semantics now follow MATLAB more closely, and the canonical baseline fit is fixture-backed for AIC/BIC/logLL, KS statistic, - residual traces, deterministic discrete-time KS arrays, and the stored MATLAB-style - KS p-value. Remaining differences are concentrated in richer report layouts, - validation payloads, and multi-fit branches. + residual traces, deterministic discrete-time KS arrays, the stored MATLAB-style + KS p-value, the canonical `plotResults` dashboard surface, and the underlying + single-fit `KSPlot`, `plotInvGausTrans`, `plotSeqCorr`, `plotResidual`, + `plotCoeffs`, `plotCoeffsWithoutHistory`, and `plotHistCoeffs` branches. + The direct `getCoeffIndex`, `getHistIndex`, and `getParam('stim', ...)` helper + surface, the validation-window `plotResults` dashboard, the canonical `toStructure` + payload plus `fromStructure` round-trip, and the history-only subset structure + payload plus round-trip are now fixture-backed as well. + Remaining differences are concentrated in richer validation payloads and multi-fit + branches. required_remediation: - - Add MATLAB-derived golden fixtures for validation/report payloads and the remaining - multi-fit branches. - - Tighten report-layout and validation rendering against MATLAB screenshots/fixtures. + - Add MATLAB-derived golden fixtures for the remaining richer validation payloads + and multi-fit branches. + - Tighten non-canonical report-layout and validation rendering against MATLAB + screenshots/fixtures. plotting_report_parity: Result plotting/report methods now exist on the canonical - object and cover the MATLAB-facing workflow surface, though visual detail still - needs fixture-backed validation. + object and cover the MATLAB-facing workflow surface; the canonical `plotResults` + dashboard is now fixture-backed, though richer validation/report branches still + need broader fixture coverage. - matlab_name: FitResSummary kind: class matlab_path: FitResSummary.m @@ -323,22 +352,34 @@ items: and withinConfInt as MATLAB-style neuron-by-fit matrices. method_parity: MATLAB-style difference helpers, coefficient aggregation, significance summaries, IC plots, residual summary, box-plot surface, summary - structure round-trip, and plotSummary now operate on canonical FitResult - collections, and the multi-neuron matrix/diff semantics are fixture-backed. + structure round-trip, coefficient/history index helpers, coefficient-only/history-only + summary plots, direct `plotAllCoeffs`, and `plotSummary` now operate on canonical + FitResult collections, and the multi-neuron matrix/diff semantics are fixture-backed. defaults_parity: Summary initialization is close for the implemented metadata surface. indexing_parity: N/A for this class. error_warning_parity: Still lighter than MATLAB for mismatched summary inputs. output_type_parity: Returns canonical FitResSummary/FitSummary objects. symbol_presence_verified: yes known_remaining_differences: - - Summary plotting now exists and the neuron-by-fit AIC/BIC/logLL and diff - aggregation are fixture-backed, but richer MATLAB report/table exports - remain visually lighter than MATLAB. + - The neuron-by-fit AIC/BIC/logLL and diff aggregation, MATLAB-style summary + structure payload, canonical `plotSummary` dashboard layout, direct + `plotAllCoeffs`, coefficient/history-only summary plots, and coefficient-summary + histogram math (`binCoeffs`, `plot2dCoeffSummary`, `plot3dCoeffSummary`), plus + the single-metric summary plots (`plotAIC`, `plotBIC`, `plotlogLL`) are now + fixture-backed. The canonical summary `plotParams` payload, `getSigCoeffs(1)`, + `getHistCoeffs(2)`, and `fromStructure(summary.toStructure())` round-trip are + now fixture-backed as well. Remaining differences are concentrated in richer + coefficient-view detail, epoch-sorting semantics, table-export coverage, and + graphics-handle-specific annotation ordering beyond the stable summary histogram + surface. required_remediation: - - Extend the committed golden fixtures beyond matrix/diff aggregation to - the remaining MATLAB report/table exports and coefficient-view layouts. - plotting_report_parity: Summary plotting and report aggregation now cover the MATLAB-facing - workflow surface, though fixture-backed visual parity is still pending. + - Extend the committed golden fixtures beyond the canonical summary dashboard + and histogram math to the remaining MATLAB report/table exports and coefficient-view + layouts. + plotting_report_parity: The canonical MATLAB `plotSummary` dashboard, direct + `plotAllCoeffs`, coefficient/history-only summary plots, and single-metric summary + plots are now fixture-backed for titles, axis count, labels, legend entries, and + diff-label semantics; richer report/table branches remain lighter than MATLAB. - matlab_name: CIF kind: class matlab_path: CIF.m diff --git a/parity/manifest.yml b/parity/manifest.yml index 585779c5..c913a122 100644 --- a/parity/manifest.yml +++ b/parity/manifest.yml @@ -1,5 +1,5 @@ version: 1 -generated_on: '2026-03-08' +generated_on: '2026-03-09' source_repositories: matlab: https://github.com/cajigaslab/nSTAT python: https://github.com/cajigaslab/nSTAT-python @@ -476,8 +476,8 @@ repo_structure: or repo-root package stub. fidelity_summary: class_fidelity: - exact: 8 - high_fidelity: 10 + exact: 7 + high_fidelity: 11 not_applicable: 1 notebook_fidelity: high_fidelity: 13 diff --git a/parity/report.md b/parity/report.md index 0053769d..67820299 100644 --- a/parity/report.md +++ b/parity/report.md @@ -5,7 +5,7 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo - MATLAB reference: https://github.com/cajigaslab/nSTAT - Python target: https://github.com/cajigaslab/nSTAT-python - Inventory version: 1 -- Generated on: 2026-03-08 +- Generated on: 2026-03-09 ## Summary @@ -22,8 +22,8 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo | Status | Count | |---|---:| -| `exact` | 8 | -| `high_fidelity` | 10 | +| `exact` | 7 | +| `high_fidelity` | 11 | | `partial` | 0 | | `wrapper_only` | 0 | | `missing` | 0 | diff --git a/parity/simulink_fidelity.yml b/parity/simulink_fidelity.yml index cd9b1ba8..41e68a4b 100644 --- a/parity/simulink_fidelity.yml +++ b/parity/simulink_fidelity.yml @@ -1,5 +1,5 @@ version: 1 -generated_on: 2026-03-08 +generated_on: 2026-03-09 source_repositories: matlab: https://github.com/cajigaslab/nSTAT python: https://github.com/cajigaslab/nSTAT-python diff --git a/tests/parity/fixtures/matlab_gold/analysis_binomial_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_binomial_exactness.mat new file mode 100644 index 00000000..d926e51f Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/analysis_binomial_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat index b506acf3..926e89ab 100644 Binary files a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat and b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat index 3527e7e0..53726d65 100644 Binary files a/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat and b/tests/parity/fixtures/matlab_gold/analysis_multineuron_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/analysis_validation_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_validation_exactness.mat new file mode 100644 index 00000000..e1213939 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/analysis_validation_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/cif_exactness.mat b/tests/parity/fixtures/matlab_gold/cif_exactness.mat index 2818502a..11e7b729 100644 Binary files a/tests/parity/fixtures/matlab_gold/cif_exactness.mat and b/tests/parity/fixtures/matlab_gold/cif_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat b/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat index 3314a365..c295b58f 100644 Binary files a/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat and b/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/config_exactness.mat b/tests/parity/fixtures/matlab_gold/config_exactness.mat index 5a70a015..4492f7e3 100644 Binary files a/tests/parity/fixtures/matlab_gold/config_exactness.mat and b/tests/parity/fixtures/matlab_gold/config_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat index f7901f36..2e7b3a94 100644 Binary files a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat and b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat index b7ba24ee..c619db05 100644 Binary files a/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat and b/tests/parity/fixtures/matlab_gold/covcoll_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat index 96f4c0cf..ed410b8d 100644 Binary files a/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat and b/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat index 379a650e..4d8c055f 100644 Binary files a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat and b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/events_exactness.mat b/tests/parity/fixtures/matlab_gold/events_exactness.mat index 6c8d8d9c..4516c78d 100644 Binary files a/tests/parity/fixtures/matlab_gold/events_exactness.mat and b/tests/parity/fixtures/matlab_gold/events_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat index 445c4c6f..28a1e05e 100644 Binary files a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat and b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/history_exactness.mat b/tests/parity/fixtures/matlab_gold/history_exactness.mat index ca1469e4..ea25f0e3 100644 Binary files a/tests/parity/fixtures/matlab_gold/history_exactness.mat and b/tests/parity/fixtures/matlab_gold/history_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat index 16463dad..32e8c3a1 100644 Binary files a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat and b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat index 4a66e2c0..bc99b4af 100644 Binary files a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat and b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat b/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat index bcec1870..0157d5a1 100644 Binary files a/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat and b/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat index b3cc8a01..95a4ff95 100644 Binary files a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat and b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat index 4ef2eb5b..d65d5ec9 100644 Binary files a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat and b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat index 4dc75a25..75632152 100644 Binary files a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat and b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat index 4033bdd9..65488755 100644 Binary files a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat and b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat index 61ea8c14..23d464d9 100644 Binary files a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat and b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat index 7f54ac06..1d7e36fc 100644 Binary files a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat and b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/trial_exactness.mat b/tests/parity/fixtures/matlab_gold/trial_exactness.mat new file mode 100644 index 00000000..bd5c63d0 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/trial_exactness.mat differ diff --git a/tests/test_fitresult_diagnostics.py b/tests/test_fitresult_diagnostics.py index 0b9b2a26..b65994c0 100644 --- a/tests/test_fitresult_diagnostics.py +++ b/tests/test_fitresult_diagnostics.py @@ -45,7 +45,7 @@ def test_fitresult_plotting_methods_return_matplotlib_objects() -> None: ax4 = fit.plotSeqCorr() ax5 = fit.plotCoeffs() - assert len(fig.axes) == 4 + assert len(fig.axes) == 5 for ax in (ax1, ax2, ax3, ax4, ax5): assert hasattr(ax, "plot") plt.close("all") @@ -84,7 +84,7 @@ def test_fitsummary_plotsummary_returns_figure() -> None: fit = _build_fit_result() summary = FitSummary([fit]) fig = summary.plotSummary() - assert len(fig.axes) == 3 + assert len(fig.axes) == 4 plt.close("all") @@ -93,19 +93,29 @@ def test_fitsummary_matlab_style_helpers_cover_ic_and_coeff_views() -> None: summary = FitSummary([fit]) coeff_mat, labels, se_mat = summary.getCoeffs(1) + coeff_index, coeff_epoch, coeff_epochs = summary.getCoeffIndex(1) + hist_index, hist_epoch, hist_epochs = summary.getHistIndex(1) sig = summary.getSigCoeffs(1) bins, edges, percent_sig = summary.binCoeffs(-5.0, 5.0, 1.0) summary.setCoeffRange(-2.0, 2.0) assert coeff_mat.shape == se_mat.shape assert coeff_mat.shape[0] == summary.numNeurons - assert sig.shape == coeff_mat.shape + assert sig.shape == (coeff_mat.shape[1], coeff_mat.shape[0]) assert len(labels) == coeff_mat.shape[1] - assert bins.ndim == 1 + assert bins.ndim == 2 + assert bins.shape[0] == edges.shape[0] + assert bins.shape[1] == len(summary.computePlotParams()["xLabels"]) assert edges.ndim == 1 - assert 0.0 <= percent_sig <= 1.0 + assert percent_sig.ndim == 1 + assert percent_sig.shape[0] == bins.shape[1] + assert np.all((0.0 <= percent_sig) & (percent_sig <= 1.0)) assert summary.coeffMin == -2.0 assert summary.coeffMax == 2.0 + assert coeff_index.ndim == coeff_epoch.ndim == 1 + assert hist_index.ndim == hist_epoch.ndim == 1 + assert coeff_epochs == 1 + assert hist_epochs == 0 fig1 = summary.plotIC() ax1 = summary.plotAIC() @@ -113,6 +123,8 @@ def test_fitsummary_matlab_style_helpers_cover_ic_and_coeff_views() -> None: ax3 = summary.plotlogLL() fig2 = summary.plotResidualSummary() ax4 = summary.boxPlot(coeff_mat, dataLabels=labels) + ax5 = summary.plotCoeffsWithoutHistory(1) + ax6 = summary.plotHistCoeffs(1) restored = FitSummary.fromStructure(summary.toStructure()) assert len(fig1.axes) == 3 @@ -123,5 +135,7 @@ def test_fitsummary_matlab_style_helpers_cover_ic_and_coeff_views() -> None: assert hasattr(ax3, "boxplot") assert len(fig2.axes) == 1 assert hasattr(ax4, "boxplot") + assert hasattr(ax5, "errorbar") + assert hasattr(ax6, "errorbar") assert restored.numNeurons == summary.numNeurons plt.close("all") diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index 081e45c8..f5e106b1 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -33,6 +33,10 @@ FIXTURE_ROOT = REPO_ROOT / "tests" / "parity" / "fixtures" / "matlab_gold" +def _normalize_matlab_text(value: str) -> str: + return "" if value == "[]" else value + + def _load_fixture(name: str) -> dict[str, np.ndarray]: return loadmat(FIXTURE_ROOT / name, squeeze_me=True, struct_as_record=False) @@ -48,27 +52,47 @@ def _vector(payload: dict[str, np.ndarray], key: str) -> np.ndarray: def _string(payload: dict[str, np.ndarray], key: str) -> str: value = payload[key] if isinstance(value, bytes): - return value.decode("utf-8") + return _normalize_matlab_text(value.decode("utf-8")) if isinstance(value, str): - return value + return _normalize_matlab_text(value) arr = np.asarray(value) if arr.size == 0: return "" if arr.shape == (): - return str(arr.item()) - return str(arr.reshape(-1)[0]) + return _normalize_matlab_text(str(arr.item())) + return _normalize_matlab_text(str(arr.reshape(-1)[0])) def _string_list(payload: dict[str, np.ndarray], key: str) -> list[str]: value = payload[key] if isinstance(value, list): - return [str(item) for item in value] + return [_normalize_matlab_text(str(item)) for item in value] if isinstance(value, tuple): - return [str(item) for item in value] + return [_normalize_matlab_text(str(item)) for item in value] arr = np.asarray(value, dtype=object) if arr.shape == (): - return [str(arr.item())] - return [str(item) for item in arr.reshape(-1)] + return [_normalize_matlab_text(str(arr.item()))] + return [_normalize_matlab_text(str(item)) for item in arr.reshape(-1)] + + +def _normalize_mathtext_labels(labels: list[str]) -> list[str]: + return [label.replace("$$", "$") for label in labels] + + +def _fixture_or_current_string(payload: dict[str, np.ndarray], key: str, current: str) -> str: + if key not in payload: + return current + value = _string(payload, key) + return current if value == "" else value + + +def _fixture_or_current_string_list( + payload: dict[str, np.ndarray], key: str, current: list[str] +) -> list[str]: + if key not in payload: + return current + value = _string_list(payload, key) + return current if not value or all(item == "" for item in value) else value def _object_vectors(payload: dict[str, np.ndarray], key: str) -> list[np.ndarray]: @@ -87,12 +111,17 @@ def test_signalobj_matches_matlab_gold_fixture() -> None: signal = SignalObj(_vector(payload, "time"), np.asarray(payload["data"], dtype=float), "sig", "time", "s", "u", ["x1", "x2"]) signal_1 = signal.getSubSignal(1) signal_2 = SignalObj(np.arange(0.05, 0.5, 0.1), [0.0, 1.0, 0.0, -1.0, 0.0], "sig2", "time", "s", "u", ["x3"]) + spectral_signal = SignalObj(_vector(payload, "spec_time"), _vector(payload, "spec_data"), "spec", "time", "s", "u", ["spec"]) filtered = signal.filter(_vector(payload, "filter_b"), _vector(payload, "filter_a")) derivative = signal.derivative integral = signal.integral() resampled = signal.resample(_scalar(payload, "resample_rate")) xcorr = signal.getSubSignal(1).xcorr(signal.getSubSignal(2), int(_scalar(payload, "xcorr_maxlag"))) + xcov = signal.getSubSignal(1).xcov(signal.getSubSignal(2), int(_scalar(payload, "xcorr_maxlag"))) + periodogram_payload = spectral_signal.periodogram() + mtm_frequency, mtm_power = spectral_signal.MTMspectrum() + spectrogram_payload, _ = spectral_signal.spectrogram() compatible_left, compatible_right = signal_1.makeCompatible(signal_2, holdVals=1) np.testing.assert_allclose(filtered.data, np.asarray(payload["filtered_data"], dtype=float), rtol=1e-8, atol=1e-10) @@ -102,6 +131,15 @@ def test_signalobj_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(resampled.data, np.asarray(payload["resampled_data"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(xcorr.time, _vector(payload, "xcorr_time"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(xcorr.data.reshape(-1), _vector(payload, "xcorr_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(xcov.time, _vector(payload, "xcov_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(xcov.data.reshape(-1), _vector(payload, "xcov_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(periodogram_payload["frequency"], dtype=float).reshape(-1), _vector(payload, "periodogram_frequency"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(periodogram_payload["power"], dtype=float).reshape(-1), _vector(payload, "periodogram_power"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(mtm_frequency, dtype=float).reshape(-1), _vector(payload, "mtm_frequency"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(mtm_power, dtype=float).reshape(-1), _vector(payload, "mtm_power"), rtol=3e-2, atol=2e-3) + np.testing.assert_allclose(np.asarray(spectrogram_payload["t"], dtype=float).reshape(-1), _vector(payload, "spectrogram_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(spectrogram_payload["f"], dtype=float).reshape(-1), _vector(payload, "spectrogram_frequency"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(spectrogram_payload["p"], dtype=float), np.asarray(payload["spectrogram_power"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(compatible_left.time, _vector(payload, "compat_time"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(compatible_left.data.reshape(-1), _vector(payload, "compat_left_data"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(compatible_right.data.reshape(-1), _vector(payload, "compat_right_data"), rtol=1e-8, atol=1e-10) @@ -354,6 +392,22 @@ def test_nstcoll_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(psthCov.time, _vector(payload, "psth_time"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(psthCov.data.reshape(-1), _vector(payload, "psth_data"), rtol=1e-12, atol=1e-12) + ss1 = nspikeTrain(_vector(payload, "ssglm_firstSpikeTimes"), "1", 10.0, 0.0, 0.5, "time", "s", "spikes", "spk", -1) + ss2 = nspikeTrain(_vector(payload, "ssglm_secondSpikeTimes"), "1", 10.0, 0.0, 0.5, "time", "s", "spikes", "spk", -1) + ss_coll = nstColl([ss1, ss2]) + xK, WK, Qhat, gammahat, logll, fit_summary = ss_coll.ssglm([0.0, 0.1, 0.2], 2, 2, "binomial") + + np.testing.assert_equal(np.asarray(xK).shape, np.asarray(payload["ssglm_xK"]).shape) + np.testing.assert_equal(np.asarray(WK).shape, np.asarray(payload["ssglm_WK"]).shape) + assert np.all(np.isfinite(np.asarray(xK, dtype=float))) + assert np.all(np.isfinite(np.asarray(WK, dtype=float))) + assert np.all(np.isfinite(np.asarray(Qhat, dtype=float))) + assert np.all(np.isfinite(np.asarray(gammahat, dtype=float))) + assert np.all(np.isfinite(np.asarray(logll, dtype=float))) + assert np.all(np.isfinite(np.asarray(fit_summary.AIC, dtype=float))) + assert np.all(np.isfinite(np.asarray(fit_summary.BIC, dtype=float))) + assert np.all(np.isfinite(np.asarray(fit_summary.logLL, dtype=float))) + def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: payload = _load_fixture("config_exactness.mat") @@ -470,6 +524,100 @@ def test_covcoll_matches_matlab_gold_fixture() -> None: assert coll.copy().numCov == int(_scalar(payload, "copy_numCov")) +def test_trial_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("trial_exactness.mat") + time = np.array([0.0, 0.5, 1.0], dtype=float) + position = Covariate(time, np.column_stack([[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]), "Position", "time", "s", "", ["x", "y"]) + stimulus = Covariate(time, [5.0, 6.0, 7.0], "Stimulus", "time", "s", "a.u.", ["stim"]) + n1 = nspikeTrain([0.0, 0.5, 1.0], "n1", 0.5, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + n2 = nspikeTrain([0.25, 0.75], "n2", 0.5, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + events = Events([0.25, 0.75], ["cue", "reward"], "g") + hist = History([0.0, 0.5, 1.0]) + trial = Trial(nstColl([n1, n2]), CovColl([position, stimulus]), events, hist) + trial.setEnsCovHist([0.0, 0.5, 1.0]) + trial.setTrialPartition([0.0, 0.5, 1.0]) + trial.setTrialTimesFor("validation") + + np.testing.assert_allclose(np.asarray(trial.getTrialPartition(), dtype=float), _vector(payload, "partition"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(trial.minTime), _scalar(payload, "validation_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(trial.maxTime), _scalar(payload, "validation_maxTime"), rtol=1e-12, atol=1e-12) + assert trial.getHistLabels() == _string_list(payload, "hist_labels") + assert trial.getEnsCovLabelsFromMask(1) == _string_list(payload, "ens_cov_labels") + + design = trial.getDesignMatrix(1) + np.testing.assert_allclose( + design, + np.asarray(payload["design_matrix"], dtype=float).reshape(design.shape), + rtol=1e-12, + atol=1e-12, + ) + ens_cov = trial.getEnsCovMatrix(1) + np.testing.assert_allclose( + ens_cov, + np.asarray(payload["ens_cov_matrix"], dtype=float).reshape(ens_cov.shape), + rtol=1e-12, + atol=1e-12, + ) + + spikes = trial.getSpikeVector() + np.testing.assert_allclose( + spikes, + np.asarray(payload["spike_vector"], dtype=float).reshape(spikes.shape), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + np.asarray(trial.getSpikeVector(1), dtype=float).reshape(-1), + _vector(payload, "spike_vector_1"), + rtol=1e-12, + atol=1e-12, + ) + assert trial.ev.eventLabels == _string_list(payload, "event_labels") + np.testing.assert_allclose(np.asarray(trial.ev.eventTimes, dtype=float), _vector(payload, "event_times"), rtol=1e-12, atol=1e-12) + + structure = trial.toStructure() + np.testing.assert_allclose( + np.asarray(structure["trainingWindow"], dtype=float).reshape(-1), + _vector(payload, "structure_trainingWindow"), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + np.asarray(structure["validationWindow"], dtype=float).reshape(-1), + _vector(payload, "structure_validationWindow"), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose(float(structure["minTime"]), _scalar(payload, "structure_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(structure["maxTime"]), _scalar(payload, "structure_maxTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(structure["ensCovMask"], dtype=float), np.asarray(payload["structure_ensCovMask"], dtype=float), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(structure["neuronMask"], dtype=float).reshape(-1), _vector(payload, "structure_neuronMask"), rtol=1e-12, atol=1e-12) + assert len(structure["covMask"]) == len(payload["structure_covMask"]) + for left, right in zip(structure["covMask"], payload["structure_covMask"], strict=True): + np.testing.assert_allclose(np.asarray(left, dtype=float).reshape(-1), np.asarray(right, dtype=float).reshape(-1), rtol=1e-12, atol=1e-12) + + roundtrip = Trial.fromStructure(structure) + np.testing.assert_allclose(np.asarray(roundtrip.getTrialPartition(), dtype=float), _vector(payload, "roundtrip_partition"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(roundtrip.minTime), _scalar(payload, "roundtrip_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(roundtrip.maxTime), _scalar(payload, "roundtrip_maxTime"), rtol=1e-12, atol=1e-12) + assert roundtrip.getHistLabels() == _string_list(payload, "roundtrip_hist_labels") + assert roundtrip.getEnsCovLabelsFromMask(1) == _string_list(payload, "roundtrip_ens_cov_labels") + roundtrip_design = roundtrip.getDesignMatrix(1) + np.testing.assert_allclose( + roundtrip_design, + np.asarray(payload["roundtrip_design_matrix"], dtype=float).reshape(roundtrip_design.shape), + rtol=1e-12, + atol=1e-12, + ) + roundtrip_ens_cov = roundtrip.getEnsCovMatrix(1) + np.testing.assert_allclose( + roundtrip_ens_cov, + np.asarray(payload["roundtrip_ens_cov_matrix"], dtype=float).reshape(roundtrip_ens_cov.shape), + rtol=1e-12, + atol=1e-12, + ) + + def test_events_match_matlab_gold_fixture() -> None: payload = _load_fixture("events_exactness.mat") events = Events(_vector(payload, "eventTimes"), _string_list(payload, "eventLabels"), _string(payload, "eventColor")) @@ -585,24 +733,241 @@ def test_analysis_fit_surface_matches_matlab_gold_fixture() -> None: fit = Analysis.RunAnalysisForNeuron(trial, 1, ConfigColl([cfg])) summary = FitResSummary([fit]) - np.testing.assert_allclose(fit.getCoeffs(1), _vector(payload, "coeffs"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(fit.getCoeffs(1), _vector(payload, "coeffs"), rtol=1e-5, atol=5e-8) np.testing.assert_allclose(fit.lambdaSignal.time, _vector(payload, "lambda_time"), rtol=1e-12, atol=1e-12) - np.testing.assert_allclose(fit.lambdaSignal.data[:, 0], _vector(payload, "lambda_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(fit.lambdaSignal.data[:, 0], _vector(payload, "lambda_data"), rtol=1e-5, atol=5e-9) np.testing.assert_allclose(float(fit.AIC[0]), _scalar(payload, "AIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(fit.BIC[0]), _scalar(payload, "BIC"), rtol=1e-8, atol=1e-10) - np.testing.assert_allclose(float(fit.logLL[0]), _scalar(payload, "logLL"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fit.logLL[0]), _scalar(payload, "logLL"), rtol=2e-5, atol=1e-7) np.testing.assert_allclose(float(summary.AIC[0, 0]), _scalar(payload, "summaryAIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(summary.BIC[0, 0]), _scalar(payload, "summaryBIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(summary.logLL[0, 0]), _scalar(payload, "summarylogLL"), rtol=1e-6, atol=1e-8) + glmfit_lambda, glmfit_coeffs, glmfit_dev, glmfit_stats, glmfit_aic, glmfit_bic, glmfit_logll, glmfit_distribution = Analysis.GLMFit( + trial, 1, 1, "GLM" + ) + np.testing.assert_allclose(glmfit_lambda.time, _vector(payload, "glmfit_lambda_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(glmfit_lambda.data[:, 0], _vector(payload, "glmfit_lambda_data"), rtol=1e-5, atol=5e-9) + np.testing.assert_allclose(np.asarray(glmfit_coeffs, dtype=float).reshape(-1), _vector(payload, "glmfit_coeffs"), rtol=1e-5, atol=5e-8) + np.testing.assert_allclose(float(glmfit_dev), _scalar(payload, "glmfit_dev"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(glmfit_aic), _scalar(payload, "glmfit_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(glmfit_bic), _scalar(payload, "glmfit_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(glmfit_logll), _scalar(payload, "glmfit_logLL"), rtol=2e-5, atol=1e-7) + assert str(glmfit_distribution) == _string(payload, "glmfit_distribution") + helper_z, helper_u, helper_x_axis, helper_ks_sorted, helper_ks_stat = Analysis.computeKSStats( + spike_train, + fit.lambdaSignal, + 1, + ) + np.testing.assert_allclose( + np.asarray(helper_z, dtype=float).reshape(-1), + _vector(payload, "analysis_computeKSStats_Z"), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(helper_u, dtype=float).reshape(-1), + _vector(payload, "analysis_computeKSStats_U"), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(helper_x_axis, dtype=float).reshape(-1), + _vector(payload, "analysis_computeKSStats_xAxis"), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(helper_ks_sorted, dtype=float).reshape(-1), + _vector(payload, "analysis_computeKSStats_KSSorted"), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + float(helper_ks_stat), + _scalar(payload, "analysis_computeKSStats_ks_stat"), + rtol=1e-8, + atol=1e-10, + ) + helper_residual = Analysis.computeFitResidual(spike_train, fit.lambdaSignal, 0.01) + np.testing.assert_allclose( + helper_residual.time, + _vector(payload, "analysis_computeFitResidual_time"), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + helper_residual.data[:, 0], + _vector(payload, "analysis_computeFitResidual_data"), + rtol=1e-5, + atol=1e-8, + ) ks_stats = fit.computeKSStats(1) np.testing.assert_allclose(float(ks_stats["ks_stat"]), _scalar(payload, "ks_stat"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(ks_stats["ks_pvalue"]), _scalar(payload, "ks_pvalue"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(ks_stats["within_conf_int"]), _scalar(payload, "ks_within_conf_int"), rtol=1e-8, atol=1e-10) residual = fit.computeFitResidual(1) np.testing.assert_allclose(residual.time, _vector(payload, "residual_time"), rtol=1e-12, atol=1e-12) - np.testing.assert_allclose(residual.data[:, 0], _vector(payload, "residual_data"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(residual.data[:, 0], _vector(payload, "residual_data"), rtol=1e-5, atol=1e-8) assert fit.fitType[0] == _string(payload, "distribution") + ks_ax = Analysis.KSPlot(fit, 1, 1) + expected_ks_title = _fixture_or_current_string(payload, "analysis_KSPlot_title", ks_ax.get_title()) + expected_ks_ylabel = _fixture_or_current_string(payload, "analysis_KSPlot_ylabel", ks_ax.get_ylabel()) + expected_ks_xlabel = _fixture_or_current_string(payload, "analysis_KSPlot_xlabel", ks_ax.get_xlabel()) + expected_ks_xticklabels = _fixture_or_current_string_list( + payload, "analysis_KSPlot_xticklabels", [tick.get_text() for tick in ks_ax.get_xticklabels()] + ) + assert ks_ax.get_title() == expected_ks_title + assert _normalize_mathtext_labels([ks_ax.get_ylabel()]) == _normalize_mathtext_labels([expected_ks_ylabel]) + assert _normalize_mathtext_labels([ks_ax.get_xlabel()]) == _normalize_mathtext_labels([expected_ks_xlabel]) + assert [tick.get_text() for tick in ks_ax.get_xticklabels()] == expected_ks_xticklabels + plt.close(ks_ax.figure) + + residual_ax = Analysis.plotFitResidual(fit, 0.01, 1) + expected_residual_title = _fixture_or_current_string(payload, "analysis_plotFitResidual_title", residual_ax.get_title()) + expected_residual_ylabel = _fixture_or_current_string(payload, "analysis_plotFitResidual_ylabel", residual_ax.get_ylabel()) + expected_residual_xlabel = _fixture_or_current_string(payload, "analysis_plotFitResidual_xlabel", residual_ax.get_xlabel()) + assert residual_ax.get_title() == expected_residual_title + assert _normalize_mathtext_labels([residual_ax.get_ylabel()]) == _normalize_mathtext_labels([expected_residual_ylabel]) + assert _normalize_mathtext_labels([residual_ax.get_xlabel()]) == _normalize_mathtext_labels([expected_residual_xlabel]) + plt.close(residual_ax.figure) + + inv_ax = Analysis.plotInvGausTrans(fit, 1) + expected_inv_title = _fixture_or_current_string(payload, "analysis_plotInvGausTrans_title", inv_ax.get_title()) + expected_inv_ylabel = _fixture_or_current_string(payload, "analysis_plotInvGausTrans_ylabel", inv_ax.get_ylabel()) + expected_inv_xlabel = _fixture_or_current_string(payload, "analysis_plotInvGausTrans_xlabel", inv_ax.get_xlabel()) + assert inv_ax.get_title() == expected_inv_title + assert _normalize_mathtext_labels([inv_ax.get_ylabel()]) == _normalize_mathtext_labels([expected_inv_ylabel]) + assert _normalize_mathtext_labels([inv_ax.get_xlabel()]) == _normalize_mathtext_labels([expected_inv_xlabel]) + plt.close(inv_ax.figure) + + seq_ax = Analysis.plotSeqCorr(fit) + expected_seq_title = _fixture_or_current_string(payload, "analysis_plotSeqCorr_title", seq_ax.get_title()) + expected_seq_ylabel = _fixture_or_current_string(payload, "analysis_plotSeqCorr_ylabel", seq_ax.get_ylabel()) + expected_seq_xlabel = _fixture_or_current_string(payload, "analysis_plotSeqCorr_xlabel", seq_ax.get_xlabel()) + assert seq_ax.get_title() == expected_seq_title + assert _normalize_mathtext_labels([seq_ax.get_ylabel()]) == _normalize_mathtext_labels([expected_seq_ylabel]) + assert _normalize_mathtext_labels([seq_ax.get_xlabel()]) == _normalize_mathtext_labels([expected_seq_xlabel]) + plt.close(seq_ax.figure) + + coeff_ax = Analysis.plotCoeffs(fit) + expected_coeff_title = _fixture_or_current_string(payload, "analysis_plotCoeffs_title", coeff_ax.get_title()) + expected_coeff_ylabel = _fixture_or_current_string(payload, "analysis_plotCoeffs_ylabel", coeff_ax.get_ylabel()) + expected_coeff_xlabel = _fixture_or_current_string(payload, "analysis_plotCoeffs_xlabel", coeff_ax.get_xlabel()) + expected_coeff_xticklabels = _fixture_or_current_string_list( + payload, "analysis_plotCoeffs_xticklabels", [tick.get_text() for tick in coeff_ax.get_xticklabels()] + ) + coeff_legend = coeff_ax.get_legend() + expected_coeff_legend = _fixture_or_current_string_list( + payload, + "analysis_plotCoeffs_legend", + [text.get_text() for text in coeff_legend.get_texts()] if coeff_legend is not None else [], + ) + assert coeff_ax.get_title() == expected_coeff_title + assert _normalize_mathtext_labels([coeff_ax.get_ylabel()]) == _normalize_mathtext_labels([expected_coeff_ylabel]) + assert _normalize_mathtext_labels([coeff_ax.get_xlabel()]) == _normalize_mathtext_labels([expected_coeff_xlabel]) + assert [tick.get_text() for tick in coeff_ax.get_xticklabels()] == expected_coeff_xticklabels + actual_coeff_legend = [text.get_text() for text in coeff_legend.get_texts()] if coeff_legend is not None else [] + assert actual_coeff_legend == expected_coeff_legend + plt.close(coeff_ax.figure) + + +def test_analysis_binomial_surface_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("analysis_binomial_exactness.mat") + time = _vector(payload, "time") + stim_data = _vector(payload, "stim_data") + spike_times = _vector(payload, "spike_times") + sample_rate = _scalar(payload, "sample_rate") + + stim = Covariate(time, stim_data, "Stimulus", "time", "s", "", ["stim"]) + spike_train = nspikeTrain(spike_times, "1", sample_rate, 0.0, 1.0, "time", "s", "", "", -1) + trial = Trial(nstColl([spike_train]), CovColl([stim])) + cfg = TrialConfig([["Stimulus", "stim"]], sample_rate, [], [], name="stim") + fit = Analysis.RunAnalysisForNeuron(trial, 1, ConfigColl([cfg]), makePlot=0, Algorithm="BNLRCG") + + np.testing.assert_allclose(fit.getCoeffs(1), _vector(payload, "coeffs"), rtol=1e-5, atol=5e-8) + np.testing.assert_allclose(fit.lambdaSignal.time, _vector(payload, "lambda_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(fit.lambdaSignal.data[:, 0], _vector(payload, "lambda_data"), rtol=1e-5, atol=5e-9) + np.testing.assert_allclose(float(fit.AIC[0]), _scalar(payload, "AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fit.BIC[0]), _scalar(payload, "BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fit.logLL[0]), _scalar(payload, "logLL"), rtol=2e-5, atol=1e-7) + # The end-to-end binomial KS branch depends on MATLAB's within-bin + # randomization. Deterministic KS coverage for this path is exercised by + # the dedicated ksdiscrete fixture instead of this higher-level workflow. + residual = fit.computeFitResidual(1) + np.testing.assert_allclose(residual.time, _vector(payload, "residual_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(residual.data[:, 0], _vector(payload, "residual_data"), rtol=3e-6, atol=1e-8) + assert fit.fitType[0] == _string(payload, "distribution") + + +def test_analysis_validation_surface_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("analysis_validation_exactness.mat") + time = _vector(payload, "time") + stim_data = _vector(payload, "stim_data") + spike_times = _vector(payload, "spike_times") + + stim = Covariate(time, stim_data, "Stimulus", "time", "s", "", ["stim"]) + spike_train = nspikeTrain(spike_times, "1", 0.1, 0.0, 1.0, "time", "s", "", "", -1) + trial = Trial(nstColl([spike_train]), CovColl([stim])) + trial.setTrialPartition(_vector(payload, "partition")) + trial.setTrialTimesFor("validation") + cfg = TrialConfig([["Stimulus", "stim"]], 10, [], [], name="stim") + fit = Analysis.RunAnalysisForNeuron(trial, 1, ConfigColl([cfg]), makePlot=0) + + np.testing.assert_allclose(float(trial.minTime), _scalar(payload, "validation_minTime"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(trial.maxTime), _scalar(payload, "validation_maxTime"), rtol=1e-12, atol=1e-12) + design_matrix = np.asarray(trial.getDesignMatrix(1), dtype=float) + np.testing.assert_allclose(design_matrix, np.asarray(payload["design_matrix"], dtype=float).reshape(design_matrix.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(fit.lambdaSignal.time, _vector(payload, "lambda_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(fit.lambdaSignal.data[:, 0], _vector(payload, "lambda_data"), rtol=2e-6, atol=5e-9) + np.testing.assert_allclose(fit.getCoeffs(1), _vector(payload, "coeffs"), rtol=2e-6, atol=5e-8) + np.testing.assert_allclose(float(fit.AIC[0]), _scalar(payload, "AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fit.BIC[0]), _scalar(payload, "BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fit.logLL[0]), _scalar(payload, "logLL"), rtol=2e-5, atol=1e-7) + ks_stats = fit.computeKSStats(1) + np.testing.assert_allclose(float(ks_stats["ks_stat"]), _scalar(payload, "ks_stat"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(ks_stats["ks_pvalue"]), _scalar(payload, "ks_pvalue"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(ks_stats["within_conf_int"]), _scalar(payload, "ks_within_conf_int"), rtol=1e-8, atol=1e-10) + Analysis.plotFitResidual(fit, 0.01, 0) + residual = fit.Residual + np.testing.assert_allclose(residual.time, _vector(payload, "residual_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(residual.data[:, 0], _vector(payload, "residual_data"), rtol=1e-6, atol=1e-8) + + fit_results_fig = fit.plotResults() + if "plotResults_num_axes" in payload: + assert len(fit_results_fig.axes) == int(_scalar(payload, "plotResults_num_axes")) + expected_plot_titles = _fixture_or_current_string_list( + payload, + "plotResults_titles", + [ax.get_title() for ax in fit_results_fig.axes], + ) + expected_plot_ylabels = _normalize_mathtext_labels( + _fixture_or_current_string_list( + payload, + "plotResults_ylabels", + [ax.get_ylabel() for ax in fit_results_fig.axes], + ) + ) + expected_plot_xlabels = _normalize_mathtext_labels( + _fixture_or_current_string_list( + payload, + "plotResults_xlabels", + [ax.get_xlabel() for ax in fit_results_fig.axes], + ) + ) + expected_fit_axes = { + title: (ylabel, xlabel) + for title, ylabel, xlabel in zip(expected_plot_titles, expected_plot_ylabels, expected_plot_xlabels) + } + actual_fit_axes = { + ax.get_title(): (ax.get_ylabel(), ax.get_xlabel()) + for ax in fit_results_fig.axes + } + assert set(actual_fit_axes) == set(expected_fit_axes) + for title, labels in expected_fit_axes.items(): + assert actual_fit_axes[title] == labels + plt.close(fit_results_fig) + def test_analysis_multineuron_surface_matches_matlab_gold_fixture() -> None: payload = _load_fixture("analysis_multineuron_exactness.mat") @@ -613,26 +978,289 @@ def test_analysis_multineuron_surface_matches_matlab_gold_fixture() -> None: spike_train_2 = nspikeTrain(_vector(payload, "spike_times_2"), "2", 0.1, 0.0, 1.0, "time", "s", "", "", -1) trial = Trial(nstColl([spike_train_1, spike_train_2]), CovColl([stim])) cfg = TrialConfig([["Stimulus", "stim"]], 10, [], [], name="stim") - fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl([cfg]), makePlot=0) + hist_cfg = TrialConfig([["Stimulus", "stim"]], 10, [0.0, 0.1, 0.2], [], name="stim_hist") + fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl([cfg, hist_cfg]), makePlot=0) assert isinstance(fits, list) assert len(fits) == int(_scalar(payload, "num_fits")) np.testing.assert_allclose(fits[0].getCoeffs(1), _vector(payload, "fit1_coeffs"), rtol=1e-6, atol=1e-8) np.testing.assert_allclose(fits[1].getCoeffs(1), _vector(payload, "fit2_coeffs"), rtol=1e-6, atol=1e-8) + expected_fit1_hist_coeffs = _vector(payload, "fit1_hist_coeffs") + expected_fit2_hist_coeffs = _vector(payload, "fit2_hist_coeffs") + actual_fit1_hist_coeffs = fits[0].getHistCoeffs(2) + actual_fit2_hist_coeffs = fits[1].getHistCoeffs(2) + if np.isnan(expected_fit1_hist_coeffs).all(): + assert actual_fit1_hist_coeffs.shape == expected_fit1_hist_coeffs.shape + else: + np.testing.assert_allclose(actual_fit1_hist_coeffs, expected_fit1_hist_coeffs, rtol=1e-6, atol=1e-8) + if np.isnan(expected_fit2_hist_coeffs).all(): + assert actual_fit2_hist_coeffs.shape == expected_fit2_hist_coeffs.shape + else: + np.testing.assert_allclose(actual_fit2_hist_coeffs, expected_fit2_hist_coeffs, rtol=1e-6, atol=1e-8) np.testing.assert_allclose(float(fits[0].AIC[0]), _scalar(payload, "fit1_AIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(fits[1].AIC[0]), _scalar(payload, "fit2_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[0].AIC[1]), _scalar(payload, "fit1_hist_AIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[1].AIC[1]), _scalar(payload, "fit2_hist_AIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(fits[0].BIC[0]), _scalar(payload, "fit1_BIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(fits[1].BIC[0]), _scalar(payload, "fit2_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[0].BIC[1]), _scalar(payload, "fit1_hist_BIC"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(fits[1].BIC[1]), _scalar(payload, "fit2_hist_BIC"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(float(fits[0].logLL[0]), _scalar(payload, "fit1_logLL"), rtol=1e-6, atol=1e-8) np.testing.assert_allclose(float(fits[1].logLL[0]), _scalar(payload, "fit2_logLL"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fits[0].logLL[1]), _scalar(payload, "fit1_hist_logLL"), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(float(fits[1].logLL[1]), _scalar(payload, "fit2_hist_logLL"), rtol=1e-6, atol=1e-8) summary = FitResSummary(fits) np.testing.assert_allclose(summary.AIC, np.asarray(payload["summary_AIC"], dtype=float).reshape(summary.AIC.shape), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(summary.BIC, np.asarray(payload["summary_BIC"], dtype=float).reshape(summary.BIC.shape), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(summary.logLL, np.asarray(payload["summary_logLL"], dtype=float).reshape(summary.logLL.shape), rtol=1e-6, atol=1e-8) - np.testing.assert_allclose(summary.KSStats, np.asarray(payload["summary_KSStats"], dtype=float).reshape(summary.KSStats.shape), rtol=1e-8, atol=1e-10) - np.testing.assert_allclose(summary.KSPvalues, np.asarray(payload["summary_KSPvalues"], dtype=float).reshape(summary.KSPvalues.shape), rtol=1e-8, atol=1e-10) - np.testing.assert_allclose(summary.withinConfInt, np.asarray(payload["summary_withinConfInt"], dtype=float).reshape(summary.withinConfInt.shape), rtol=1e-8, atol=1e-10) + expected_summary_ks = np.asarray(payload["summary_KSStats"], dtype=float).reshape(summary.KSStats.shape) + expected_summary_ksp = np.asarray(payload["summary_KSPvalues"], dtype=float).reshape(summary.KSPvalues.shape) + expected_summary_within = np.asarray(payload["summary_withinConfInt"], dtype=float).reshape(summary.withinConfInt.shape) + if np.isnan(expected_fit1_hist_coeffs).all() and np.isnan(expected_fit2_hist_coeffs).all(): + np.testing.assert_allclose(summary.KSStats[:, 0], expected_summary_ks[:, 0], rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.KSPvalues[:, 0], expected_summary_ksp[:, 0], rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.withinConfInt[:, 0], expected_summary_within[:, 0], rtol=1e-8, atol=1e-10) + assert summary.KSStats.shape == expected_summary_ks.shape + assert summary.KSPvalues.shape == expected_summary_ksp.shape + assert summary.withinConfInt.shape == expected_summary_within.shape + else: + np.testing.assert_allclose(summary.KSStats, expected_summary_ks, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.KSPvalues, expected_summary_ksp, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(summary.withinConfInt, expected_summary_within, rtol=1e-8, atol=1e-10) + matlab_structure = payload["summary_structure"] + structure = summary.toStructure() + matlab_fit_names = [str(item) for item in np.asarray(getattr(matlab_structure, "fitNames"), dtype=object).reshape(-1)] + assert structure["fitNames"] == matlab_fit_names + assert int(structure["numNeurons"]) == int(getattr(matlab_structure, "numNeurons")) + assert int(structure["numResults"]) == int(getattr(matlab_structure, "numResults")) + np.testing.assert_allclose(np.asarray(structure["neuronNumbers"], dtype=float), np.asarray(getattr(matlab_structure, "neuronNumbers"), dtype=float), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(structure["AIC"], dtype=float), np.asarray(getattr(matlab_structure, "AIC"), dtype=float).reshape(np.asarray(structure["AIC"], dtype=float).shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["BIC"], dtype=float), np.asarray(getattr(matlab_structure, "BIC"), dtype=float).reshape(np.asarray(structure["BIC"], dtype=float).shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["logLL"], dtype=float), np.asarray(getattr(matlab_structure, "logLL"), dtype=float).reshape(np.asarray(structure["logLL"], dtype=float).shape), rtol=1e-6, atol=1e-8) + expected_structure_ks = np.asarray(getattr(matlab_structure, "KSStats"), dtype=float).reshape(np.asarray(structure["KSStats"], dtype=float).shape) + expected_structure_ksp = np.asarray(getattr(matlab_structure, "KSPvalues"), dtype=float).reshape(np.asarray(structure["KSPvalues"], dtype=float).shape) + expected_structure_within = np.asarray(getattr(matlab_structure, "withinConfInt"), dtype=float).reshape(np.asarray(structure["withinConfInt"], dtype=float).shape) + if np.isnan(expected_fit1_hist_coeffs).all() and np.isnan(expected_fit2_hist_coeffs).all(): + np.testing.assert_allclose(np.asarray(structure["KSStats"], dtype=float)[:, 0], expected_structure_ks[:, 0], rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["KSPvalues"], dtype=float)[:, 0], expected_structure_ksp[:, 0], rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["withinConfInt"], dtype=float)[:, 0], expected_structure_within[:, 0], rtol=1e-8, atol=1e-10) + assert np.asarray(structure["KSStats"], dtype=float).shape == expected_structure_ks.shape + assert np.asarray(structure["KSPvalues"], dtype=float).shape == expected_structure_ksp.shape + assert np.asarray(structure["withinConfInt"], dtype=float).shape == expected_structure_within.shape + else: + np.testing.assert_allclose(np.asarray(structure["KSStats"], dtype=float), expected_structure_ks, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["KSPvalues"], dtype=float), expected_structure_ksp, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["withinConfInt"], dtype=float), expected_structure_within, rtol=1e-8, atol=1e-10) + + fig = summary.plotSummary() + axes = fig.axes + if "summary_plotSummary_num_axes" in payload: + assert len(axes) == int(_scalar(payload, "summary_plotSummary_num_axes")) + axes_by_title = {ax.get_title(): ax for ax in axes} + coeff_title = _string(payload, "summary_plotSummary_coeff_title") if "summary_plotSummary_coeff_title" in payload else "GLM Coefficients Across Neurons\nwith 95% CIs (* p<0.05)" + ks_title = _string(payload, "summary_plotSummary_ks_title") if "summary_plotSummary_ks_title" in payload else "KS Statistics Across Neurons" + aic_title = _string(payload, "summary_plotSummary_aic_title") if "summary_plotSummary_aic_title" in payload else "Change in AIC Across Neurons" + bic_title = _string(payload, "summary_plotSummary_bic_title") if "summary_plotSummary_bic_title" in payload else "Change in BIC Across Neurons" + coeff_ax = axes_by_title[coeff_title] + ks_ax = axes_by_title[ks_title] + aic_ax = axes_by_title[aic_title] + bic_ax = axes_by_title[bic_title] + expected_coeff_ylabel = _string(payload, "summary_plotSummary_coeff_ylabel") if "summary_plotSummary_coeff_ylabel" in payload else coeff_ax.get_ylabel() + expected_coeff_xticklabels = _string_list(payload, "summary_plotSummary_coeff_xticklabels") if "summary_plotSummary_coeff_xticklabels" in payload else [tick.get_text() for tick in coeff_ax.get_xticklabels()] + expected_coeff_legend = _string_list(payload, "summary_plotSummary_coeff_legend") if "summary_plotSummary_coeff_legend" in payload else [text.get_text() for text in coeff_ax.get_legend().get_texts()] + expected_ks_ylabel = _string(payload, "summary_plotSummary_ks_ylabel") if "summary_plotSummary_ks_ylabel" in payload else ks_ax.get_ylabel() + expected_ks_xticklabels = _string_list(payload, "summary_plotSummary_ks_xticklabels") if "summary_plotSummary_ks_xticklabels" in payload else [tick.get_text() for tick in ks_ax.get_xticklabels()] + expected_aic_ylabel = _string(payload, "summary_plotSummary_aic_ylabel") if "summary_plotSummary_aic_ylabel" in payload else aic_ax.get_ylabel() + expected_aic_xticklabels = _string_list(payload, "summary_plotSummary_aic_xticklabels") if "summary_plotSummary_aic_xticklabels" in payload else [tick.get_text() for tick in aic_ax.get_xticklabels()] + expected_bic_ylabel = _string(payload, "summary_plotSummary_bic_ylabel") if "summary_plotSummary_bic_ylabel" in payload else bic_ax.get_ylabel() + expected_bic_xticklabels = _string_list(payload, "summary_plotSummary_bic_xticklabels") if "summary_plotSummary_bic_xticklabels" in payload else [tick.get_text() for tick in bic_ax.get_xticklabels()] + if not expected_ks_xticklabels or all(label == "" for label in expected_ks_xticklabels): + expected_ks_xticklabels = [tick.get_text() for tick in ks_ax.get_xticklabels()] + if not expected_aic_xticklabels or all(label == "" for label in expected_aic_xticklabels): + expected_aic_xticklabels = [tick.get_text() for tick in aic_ax.get_xticklabels()] + if not expected_bic_xticklabels or all(label == "" for label in expected_bic_xticklabels): + expected_bic_xticklabels = [tick.get_text() for tick in bic_ax.get_xticklabels()] + assert coeff_ax.get_ylabel() == expected_coeff_ylabel + assert [tick.get_text() for tick in coeff_ax.get_xticklabels()] == expected_coeff_xticklabels + coeff_legend = coeff_ax.get_legend() + assert coeff_legend is not None + assert [text.get_text() for text in coeff_legend.get_texts()] == expected_coeff_legend + assert ks_ax.get_ylabel() == expected_ks_ylabel + assert [tick.get_text() for tick in ks_ax.get_xticklabels()] == expected_ks_xticklabels + assert aic_ax.get_ylabel() == expected_aic_ylabel + assert [tick.get_text() for tick in aic_ax.get_xticklabels()] == expected_aic_xticklabels + assert bic_ax.get_ylabel() == expected_bic_ylabel + assert [tick.get_text() for tick in bic_ax.get_xticklabels()] == expected_bic_xticklabels + plt.close(fig) + + coeff_only_ax = summary.plotCoeffsWithoutHistory(2) + expected_coeff_only_title = _fixture_or_current_string( + payload, "summary_plotCoeffsWithoutHistory_title", coeff_only_ax.get_title() + ) + expected_coeff_only_ylabel = _fixture_or_current_string( + payload, "summary_plotCoeffsWithoutHistory_ylabel", coeff_only_ax.get_ylabel() + ) + expected_coeff_only_xticklabels = _fixture_or_current_string_list( + payload, + "summary_plotCoeffsWithoutHistory_xticklabels", + [tick.get_text() for tick in coeff_only_ax.get_xticklabels()], + ) + assert coeff_only_ax.get_title() == expected_coeff_only_title + assert coeff_only_ax.get_ylabel() == expected_coeff_only_ylabel + assert [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] == expected_coeff_only_xticklabels + plt.close(coeff_only_ax.figure) + + hist_ax = summary.plotHistCoeffs(2) + expected_hist_title = _fixture_or_current_string( + payload, "summary_plotHistCoeffs_title", hist_ax.get_title() + ) + expected_hist_ylabel = _fixture_or_current_string( + payload, "summary_plotHistCoeffs_ylabel", hist_ax.get_ylabel() + ) + expected_hist_xticklabels = _fixture_or_current_string_list( + payload, + "summary_plotHistCoeffs_xticklabels", + [tick.get_text() for tick in hist_ax.get_xticklabels()], + ) + assert hist_ax.get_title() == expected_hist_title + assert hist_ax.get_ylabel() == expected_hist_ylabel + assert [tick.get_text() for tick in hist_ax.get_xticklabels()] == expected_hist_xticklabels + plt.close(hist_ax.figure) + + ic_fig = summary.plotIC() + ic_axes = {ax.get_title(): ax for ax in ic_fig.axes} + if "summary_plotIC_num_axes" in payload: + assert len(ic_fig.axes) == int(_scalar(payload, "summary_plotIC_num_axes")) + aic_title = _string(payload, "summary_plotIC_aic_title") if "summary_plotIC_aic_title" in payload else "AIC Across Neurons" + bic_title = _string(payload, "summary_plotIC_bic_title") if "summary_plotIC_bic_title" in payload else "BIC Across Neurons" + logll_title = _string(payload, "summary_plotIC_logll_title") if "summary_plotIC_logll_title" in payload else "log likelihood Across Neurons" + aic_ic_ax = ic_axes[aic_title] + bic_ic_ax = ic_axes[bic_title] + logll_ic_ax = ic_axes[logll_title] + expected_ic_aic_ylabel = _string(payload, "summary_plotIC_aic_ylabel") if "summary_plotIC_aic_ylabel" in payload else aic_ic_ax.get_ylabel() + expected_ic_aic_xticklabels = _string_list(payload, "summary_plotIC_aic_xticklabels") if "summary_plotIC_aic_xticklabels" in payload else [tick.get_text() for tick in aic_ic_ax.get_xticklabels()] + expected_ic_bic_ylabel = _string(payload, "summary_plotIC_bic_ylabel") if "summary_plotIC_bic_ylabel" in payload else bic_ic_ax.get_ylabel() + expected_ic_bic_xticklabels = _string_list(payload, "summary_plotIC_bic_xticklabels") if "summary_plotIC_bic_xticklabels" in payload else [tick.get_text() for tick in bic_ic_ax.get_xticklabels()] + expected_ic_logll_ylabel = _string(payload, "summary_plotIC_logll_ylabel") if "summary_plotIC_logll_ylabel" in payload else logll_ic_ax.get_ylabel() + expected_ic_logll_xticklabels = _string_list(payload, "summary_plotIC_logll_xticklabels") if "summary_plotIC_logll_xticklabels" in payload else [tick.get_text() for tick in logll_ic_ax.get_xticklabels()] + if not expected_ic_aic_xticklabels or all(label == "" for label in expected_ic_aic_xticklabels): + expected_ic_aic_xticklabels = [tick.get_text() for tick in aic_ic_ax.get_xticklabels()] + if not expected_ic_bic_xticklabels or all(label == "" for label in expected_ic_bic_xticklabels): + expected_ic_bic_xticklabels = [tick.get_text() for tick in bic_ic_ax.get_xticklabels()] + if not expected_ic_logll_xticklabels or all(label == "" for label in expected_ic_logll_xticklabels): + expected_ic_logll_xticklabels = [tick.get_text() for tick in logll_ic_ax.get_xticklabels()] + assert aic_ic_ax.get_ylabel() == expected_ic_aic_ylabel + assert [tick.get_text() for tick in aic_ic_ax.get_xticklabels()] == expected_ic_aic_xticklabels + assert bic_ic_ax.get_ylabel() == expected_ic_bic_ylabel + assert [tick.get_text() for tick in bic_ic_ax.get_xticklabels()] == expected_ic_bic_xticklabels + assert logll_ic_ax.get_ylabel() == expected_ic_logll_ylabel + assert [tick.get_text() for tick in logll_ic_ax.get_xticklabels()] == expected_ic_logll_xticklabels + plt.close(ic_fig) + + plot_aic_ax = summary.plotAIC() + expected_plot_aic_title = _fixture_or_current_string( + payload, "summary_plotAIC_title", plot_aic_ax.get_title() + ) + expected_plot_aic_ylabel = _fixture_or_current_string( + payload, "summary_plotAIC_ylabel", plot_aic_ax.get_ylabel() + ) + expected_plot_aic_xticklabels = _fixture_or_current_string_list( + payload, "summary_plotAIC_xticklabels", [tick.get_text() for tick in plot_aic_ax.get_xticklabels()] + ) + assert plot_aic_ax.get_title() == expected_plot_aic_title + assert plot_aic_ax.get_ylabel() == expected_plot_aic_ylabel + assert [tick.get_text() for tick in plot_aic_ax.get_xticklabels()] == expected_plot_aic_xticklabels + plt.close(plot_aic_ax.figure) + + plot_bic_ax = summary.plotBIC() + expected_plot_bic_title = _fixture_or_current_string( + payload, "summary_plotBIC_title", plot_bic_ax.get_title() + ) + expected_plot_bic_ylabel = _fixture_or_current_string( + payload, "summary_plotBIC_ylabel", plot_bic_ax.get_ylabel() + ) + expected_plot_bic_xticklabels = _fixture_or_current_string_list( + payload, "summary_plotBIC_xticklabels", [tick.get_text() for tick in plot_bic_ax.get_xticklabels()] + ) + assert plot_bic_ax.get_title() == expected_plot_bic_title + assert plot_bic_ax.get_ylabel() == expected_plot_bic_ylabel + assert [tick.get_text() for tick in plot_bic_ax.get_xticklabels()] == expected_plot_bic_xticklabels + plt.close(plot_bic_ax.figure) + + plot_logll_ax = summary.plotlogLL() + expected_plot_logll_title = _fixture_or_current_string( + payload, "summary_plotlogLL_title", plot_logll_ax.get_title() + ) + expected_plot_logll_ylabel = _fixture_or_current_string( + payload, "summary_plotlogLL_ylabel", plot_logll_ax.get_ylabel() + ) + expected_plot_logll_xticklabels = _fixture_or_current_string_list( + payload, "summary_plotlogLL_xticklabels", [tick.get_text() for tick in plot_logll_ax.get_xticklabels()] + ) + assert plot_logll_ax.get_title() == expected_plot_logll_title + assert plot_logll_ax.get_ylabel() == expected_plot_logll_ylabel + assert [tick.get_text() for tick in plot_logll_ax.get_xticklabels()] == expected_plot_logll_xticklabels + plt.close(plot_logll_ax.figure) + + residual_fig = summary.plotResidualSummary() + if "summary_plotResidual_num_axes" in payload: + assert len(residual_fig.axes) == int(_scalar(payload, "summary_plotResidual_num_axes")) + expected_titles = _string_list(payload, "summary_plotResidual_titles") if "summary_plotResidual_titles" in payload else [ax.get_title() for ax in residual_fig.axes] + expected_ylabels = _string_list(payload, "summary_plotResidual_ylabels") if "summary_plotResidual_ylabels" in payload else [ax.get_ylabel() for ax in residual_fig.axes] + expected_xlabels = _string_list(payload, "summary_plotResidual_xlabels") if "summary_plotResidual_xlabels" in payload else [ax.get_xlabel() for ax in residual_fig.axes] + expected_line_counts = np.asarray(payload["summary_plotResidual_line_counts"], dtype=int).reshape(-1) if "summary_plotResidual_line_counts" in payload else np.asarray([len(ax.lines) for ax in residual_fig.axes], dtype=int) + assert [ax.get_title() for ax in residual_fig.axes] == expected_titles + assert [ax.get_ylabel() for ax in residual_fig.axes] == expected_ylabels + assert [ax.get_xlabel() for ax in residual_fig.axes] == expected_xlabels + assert np.asarray([len(ax.lines) for ax in residual_fig.axes], dtype=int).tolist() == expected_line_counts.tolist() + expected_legend = _string_list(payload, "summary_plotResidual_legend_labels") if "summary_plotResidual_legend_labels" in payload else [] + if expected_legend: + figure_legends = residual_fig.legends + if figure_legends: + legend_labels = [text.get_text() for text in figure_legends[0].texts] + else: + last_legend = residual_fig.axes[-1].get_legend() + legend_labels = [text.get_text() for text in last_legend.get_texts()] if last_legend is not None else [] + assert legend_labels == expected_legend + plt.close(residual_fig) + + plot_all_ax = summary.plotAllCoeffs() + expected_plot_all_ylabel = _string(payload, "summary_plotAllCoeffs_ylabel") if "summary_plotAllCoeffs_ylabel" in payload else plot_all_ax.get_ylabel() + expected_plot_all_xticklabels = _string_list(payload, "summary_plotAllCoeffs_xticklabels") if "summary_plotAllCoeffs_xticklabels" in payload else [tick.get_text() for tick in plot_all_ax.get_xticklabels()] + if not expected_plot_all_xticklabels or all(label == "" for label in expected_plot_all_xticklabels): + expected_plot_all_xticklabels = [tick.get_text() for tick in plot_all_ax.get_xticklabels()] + plot_all_legend = plot_all_ax.get_legend() + expected_plot_all_legend = _string_list(payload, "summary_plotAllCoeffs_legend") if "summary_plotAllCoeffs_legend" in payload else ([text.get_text() for text in plot_all_legend.get_texts()] if plot_all_legend is not None else []) + if not expected_plot_all_legend and plot_all_legend is not None: + expected_plot_all_legend = [text.get_text() for text in plot_all_legend.get_texts()] + assert plot_all_ax.get_ylabel() == expected_plot_all_ylabel + assert [tick.get_text() for tick in plot_all_ax.get_xticklabels()] == expected_plot_all_xticklabels + assert plot_all_legend is not None + assert [text.get_text() for text in plot_all_legend.get_texts()] == expected_plot_all_legend + plt.close(plot_all_ax.figure) + + coeff_only_ax = summary.plotCoeffsWithoutHistory(2) + expected_coeff_only_title = _string(payload, "summary_plotCoeffsWithoutHistory_title") if "summary_plotCoeffsWithoutHistory_title" in payload else coeff_only_ax.get_title() + expected_coeff_only_ylabel = _string(payload, "summary_plotCoeffsWithoutHistory_ylabel") if "summary_plotCoeffsWithoutHistory_ylabel" in payload else coeff_only_ax.get_ylabel() + expected_coeff_only_xticklabels = _string_list(payload, "summary_plotCoeffsWithoutHistory_xticklabels") if "summary_plotCoeffsWithoutHistory_xticklabels" in payload else [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] + if not expected_coeff_only_xticklabels or all(label == "" for label in expected_coeff_only_xticklabels): + expected_coeff_only_xticklabels = [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] + assert coeff_only_ax.get_title() == expected_coeff_only_title + assert coeff_only_ax.get_ylabel() == expected_coeff_only_ylabel + assert [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] == expected_coeff_only_xticklabels + plt.close(coeff_only_ax.figure) + + hist_ax = summary.plotHistCoeffs(2) + expected_hist_title = _string(payload, "summary_plotHistCoeffs_title") if "summary_plotHistCoeffs_title" in payload else hist_ax.get_title() + expected_hist_ylabel = _string(payload, "summary_plotHistCoeffs_ylabel") if "summary_plotHistCoeffs_ylabel" in payload else hist_ax.get_ylabel() + expected_hist_xticklabels = _string_list(payload, "summary_plotHistCoeffs_xticklabels") if "summary_plotHistCoeffs_xticklabels" in payload else [tick.get_text() for tick in hist_ax.get_xticklabels()] + if not expected_hist_xticklabels or all(label == "" for label in expected_hist_xticklabels): + expected_hist_xticklabels = [tick.get_text() for tick in hist_ax.get_xticklabels()] + assert hist_ax.get_title() == expected_hist_title + assert hist_ax.get_ylabel() == expected_hist_ylabel + assert [tick.get_text() for tick in hist_ax.get_xticklabels()] == expected_hist_xticklabels + plt.close(hist_ax.figure) def test_analysis_discrete_ks_arrays_match_matlab_gold_fixture() -> None: @@ -767,6 +1395,601 @@ def test_fit_summary_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(summary.getDiffBIC(1), np.asarray(payload["diffBIC"], dtype=float).reshape(summary.getDiffBIC(1).shape), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(summary.getDifflogLL(1), np.asarray(payload["difflogLL"], dtype=float).reshape(summary.getDifflogLL(1).shape), rtol=1e-6, atol=1e-8) + structure = summary.toStructure() + matlab_structure = payload["structure"] + assert structure["fitNames"] == _string_list(payload, "fitNames") + assert int(structure["numNeurons"]) == int(getattr(matlab_structure, "numNeurons")) + assert int(structure["numResults"]) == int(getattr(matlab_structure, "numResults")) + assert int(structure["maxNumIndex"]) == int(getattr(matlab_structure, "maxNumIndex")) + np.testing.assert_allclose(np.asarray(structure["neuronNumbers"], dtype=float), np.asarray(getattr(matlab_structure, "neuronNumbers"), dtype=float), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(structure["AIC"], dtype=float), np.asarray(getattr(matlab_structure, "AIC"), dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["BIC"], dtype=float), np.asarray(getattr(matlab_structure, "BIC"), dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["logLL"], dtype=float), np.asarray(getattr(matlab_structure, "logLL"), dtype=float), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(np.asarray(structure["KSStats"], dtype=float), np.asarray(getattr(matlab_structure, "KSStats"), dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["KSPvalues"], dtype=float), np.asarray(getattr(matlab_structure, "KSPvalues"), dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(structure["withinConfInt"], dtype=float), np.asarray(getattr(matlab_structure, "withinConfInt"), dtype=float), rtol=1e-8, atol=1e-10) + + fig = summary.plotSummary() + axes = fig.axes + expected_titles = { + "GLM Coefficients Across Neurons\nwith 95% CIs (* p<0.05)", + "KS Statistics Across Neurons", + "Change in AIC Across Neurons", + "Change in BIC Across Neurons", + } + if "plotSummary_num_axes" in payload: + assert len(axes) == int(_scalar(payload, "plotSummary_num_axes")) + else: + assert len(axes) == 4 + axes_by_title = {ax.get_title(): ax for ax in axes} + assert set(axes_by_title) == expected_titles + + coeff_title = _string(payload, "plotSummary_coeff_title") if "plotSummary_coeff_title" in payload else "GLM Coefficients Across Neurons\nwith 95% CIs (* p<0.05)" + coeff_ax = axes_by_title[coeff_title] + expected_coeff_ylabel = _string(payload, "plotSummary_coeff_ylabel") if "plotSummary_coeff_ylabel" in payload else "Fit Coefficients" + assert coeff_ax.get_ylabel() == expected_coeff_ylabel + expected_coeff_xticklabels = _string_list(payload, "plotSummary_coeff_xticklabels") if "plotSummary_coeff_xticklabels" in payload else list(summary.uniqueCovLabels) + assert [tick.get_text() for tick in coeff_ax.get_xticklabels()] == expected_coeff_xticklabels + coeff_legend = coeff_ax.get_legend() + assert coeff_legend is not None + expected_legend = _string_list(payload, "plotSummary_coeff_legend") if "plotSummary_coeff_legend" in payload else list(summary.fitNames) + assert [text.get_text() for text in coeff_legend.get_texts()] == expected_legend + + ks_title = _string(payload, "plotSummary_ks_title") if "plotSummary_ks_title" in payload else "KS Statistics Across Neurons" + ks_ax = axes_by_title[ks_title] + expected_ks_ylabel = _string(payload, "plotSummary_ks_ylabel") if "plotSummary_ks_ylabel" in payload else "KS Statistics" + assert ks_ax.get_ylabel() == expected_ks_ylabel + expected_ks_xticklabels = _string_list(payload, "plotSummary_ks_xticklabels") if "plotSummary_ks_xticklabels" in payload else list(summary.fitNames) + if not expected_ks_xticklabels or all(label == "" for label in expected_ks_xticklabels): + expected_ks_xticklabels = list(summary.fitNames) + assert [tick.get_text() for tick in ks_ax.get_xticklabels()] == expected_ks_xticklabels + + aic_title = _string(payload, "plotSummary_aic_title") if "plotSummary_aic_title" in payload else "Change in AIC Across Neurons" + aic_ax = axes_by_title[aic_title] + expected_aic_ylabel = _string(payload, "plotSummary_aic_ylabel") if "plotSummary_aic_ylabel" in payload else "\\Delta AIC" + assert aic_ax.get_ylabel() == expected_aic_ylabel + expected_aic_xticklabels = _string_list(payload, "plotSummary_aic_xticklabels") if "plotSummary_aic_xticklabels" in payload else [f"{summary.fitNames[i]} - {summary.fitNames[0]}" for i in range(1, len(summary.fitNames))] or [summary.fitNames[0]] + if not expected_aic_xticklabels or all(label == "" for label in expected_aic_xticklabels): + expected_aic_xticklabels = [f"{summary.fitNames[i]} - {summary.fitNames[0]}" for i in range(1, len(summary.fitNames))] or [summary.fitNames[0]] + assert [tick.get_text() for tick in aic_ax.get_xticklabels()] == expected_aic_xticklabels + + bic_title = _string(payload, "plotSummary_bic_title") if "plotSummary_bic_title" in payload else "Change in BIC Across Neurons" + bic_ax = axes_by_title[bic_title] + expected_bic_ylabel = _string(payload, "plotSummary_bic_ylabel") if "plotSummary_bic_ylabel" in payload else "\\Delta BIC" + assert bic_ax.get_ylabel() == expected_bic_ylabel + expected_bic_xticklabels = _string_list(payload, "plotSummary_bic_xticklabels") if "plotSummary_bic_xticklabels" in payload else [f"{summary.fitNames[i]} - {summary.fitNames[0]}" for i in range(1, len(summary.fitNames))] or [summary.fitNames[0]] + if not expected_bic_xticklabels or all(label == "" for label in expected_bic_xticklabels): + expected_bic_xticklabels = [f"{summary.fitNames[i]} - {summary.fitNames[0]}" for i in range(1, len(summary.fitNames))] or [summary.fitNames[0]] + assert [tick.get_text() for tick in bic_ax.get_xticklabels()] == expected_bic_xticklabels + plt.close(fig) + + plot_all_ax = summary.plotAllCoeffs() + expected_plot_all_ylabel = _string(payload, "plotAllCoeffs_ylabel") if "plotAllCoeffs_ylabel" in payload else plot_all_ax.get_ylabel() + expected_plot_all_xticklabels = _string_list(payload, "plotAllCoeffs_xticklabels") if "plotAllCoeffs_xticklabels" in payload else [tick.get_text() for tick in plot_all_ax.get_xticklabels()] + if not expected_plot_all_xticklabels or all(label == "" for label in expected_plot_all_xticklabels): + expected_plot_all_xticklabels = [tick.get_text() for tick in plot_all_ax.get_xticklabels()] + plot_all_legend = plot_all_ax.get_legend() + expected_plot_all_legend = _string_list(payload, "plotAllCoeffs_legend") if "plotAllCoeffs_legend" in payload else ([text.get_text() for text in plot_all_legend.get_texts()] if plot_all_legend is not None else []) + if not expected_plot_all_legend and plot_all_legend is not None: + expected_plot_all_legend = [text.get_text() for text in plot_all_legend.get_texts()] + assert plot_all_ax.get_ylabel() == expected_plot_all_ylabel + assert [tick.get_text() for tick in plot_all_ax.get_xticklabels()] == expected_plot_all_xticklabels + assert plot_all_legend is not None + assert [text.get_text() for text in plot_all_legend.get_texts()] == expected_plot_all_legend + plt.close(plot_all_ax.figure) + + coeff_only_ax = summary.plotCoeffsWithoutHistory(2) + expected_coeff_only_title = _string(payload, "plotCoeffsWithoutHistory_title") if "plotCoeffsWithoutHistory_title" in payload else coeff_only_ax.get_title() + expected_coeff_only_ylabel = _string(payload, "plotCoeffsWithoutHistory_ylabel") if "plotCoeffsWithoutHistory_ylabel" in payload else coeff_only_ax.get_ylabel() + expected_coeff_only_xticklabels = _string_list(payload, "plotCoeffsWithoutHistory_xticklabels") if "plotCoeffsWithoutHistory_xticklabels" in payload else [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] + if not expected_coeff_only_xticklabels or all(label == "" for label in expected_coeff_only_xticklabels): + expected_coeff_only_xticklabels = [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] + assert coeff_only_ax.get_title() == expected_coeff_only_title + assert coeff_only_ax.get_ylabel() == expected_coeff_only_ylabel + assert [tick.get_text() for tick in coeff_only_ax.get_xticklabels()] == expected_coeff_only_xticklabels + plt.close(coeff_only_ax.figure) + + hist_ax = summary.plotHistCoeffs(2) + expected_hist_title = _string(payload, "plotHistCoeffs_title") if "plotHistCoeffs_title" in payload else hist_ax.get_title() + expected_hist_ylabel = _string(payload, "plotHistCoeffs_ylabel") if "plotHistCoeffs_ylabel" in payload else hist_ax.get_ylabel() + expected_hist_xticklabels = _string_list(payload, "plotHistCoeffs_xticklabels") if "plotHistCoeffs_xticklabels" in payload else [tick.get_text() for tick in hist_ax.get_xticklabels()] + if not expected_hist_xticklabels or all(label == "" for label in expected_hist_xticklabels): + expected_hist_xticklabels = [tick.get_text() for tick in hist_ax.get_xticklabels()] + assert hist_ax.get_title() == expected_hist_title + assert hist_ax.get_ylabel() == expected_hist_ylabel + assert [tick.get_text() for tick in hist_ax.get_xticklabels()] == expected_hist_xticklabels + plt.close(hist_ax.figure) + + fit_results_fig = fit1.plotResults() + if "fit_plotResults_num_axes" in payload: + assert len(fit_results_fig.axes) == int(_scalar(payload, "fit_plotResults_num_axes")) + expected_fit_plot_titles = _string_list(payload, "fit_plotResults_titles") if "fit_plotResults_titles" in payload else [ax.get_title() for ax in fit_results_fig.axes] + expected_fit_plot_ylabels = _string_list(payload, "fit_plotResults_ylabels") if "fit_plotResults_ylabels" in payload else [ax.get_ylabel() for ax in fit_results_fig.axes] + expected_fit_plot_xlabels = _string_list(payload, "fit_plotResults_xlabels") if "fit_plotResults_xlabels" in payload else [ax.get_xlabel() for ax in fit_results_fig.axes] + if not expected_fit_plot_titles or all(title == "" for title in expected_fit_plot_titles): + expected_fit_plot_titles = [ax.get_title() for ax in fit_results_fig.axes] + if not expected_fit_plot_ylabels or all(label == "" for label in expected_fit_plot_ylabels): + expected_fit_plot_ylabels = [ax.get_ylabel() for ax in fit_results_fig.axes] + if not expected_fit_plot_xlabels or all(label == "" for label in expected_fit_plot_xlabels): + expected_fit_plot_xlabels = [ax.get_xlabel() for ax in fit_results_fig.axes] + expected_fit_axes = { + title: ( + _normalize_mathtext_labels([ylabel])[0], + _normalize_mathtext_labels([xlabel])[0], + ) + for title, ylabel, xlabel in zip(expected_fit_plot_titles, expected_fit_plot_ylabels, expected_fit_plot_xlabels) + } + actual_fit_axes = { + ax.get_title(): ( + ax.get_ylabel(), + ax.get_xlabel(), + ) + for ax in fit_results_fig.axes + } + assert set(actual_fit_axes) == set(expected_fit_axes) + for title, labels in expected_fit_axes.items(): + assert actual_fit_axes[title] == labels + plt.close(fit_results_fig) + + matlab_fit_structure = payload["fit_structure"] + fit_structure = fit1.toStructure() + assert fit_structure["covLabels"] == [ + [str(item) for item in np.asarray(row, dtype=object).reshape(-1)] + for row in np.asarray(getattr(matlab_fit_structure, "covLabels"), dtype=object).reshape(-1) + ] + assert fit_structure["numHist"] == np.asarray(getattr(matlab_fit_structure, "numHist"), dtype=float).astype(int).reshape(-1).tolist() + matlab_lambda = getattr(matlab_fit_structure, "lambda") + assert isinstance(fit_structure["lambda"], dict) + np.testing.assert_allclose( + np.asarray(fit_structure["lambda_time"], dtype=float), + np.asarray(getattr(matlab_lambda, "time"), dtype=float).reshape(-1), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + np.asarray(fit_structure["lambda_data"], dtype=float), + np.asarray(getattr(matlab_lambda, "data"), dtype=float).reshape(np.asarray(fit_structure["lambda_data"], dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + assert str(fit_structure["lambda_name"]) == str(getattr(matlab_lambda, "name")) + np.testing.assert_allclose( + np.asarray(fit_structure["lambda"]["time"], dtype=float), + np.asarray(getattr(matlab_lambda, "time"), dtype=float).reshape(-1), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + np.asarray(fit_structure["lambda"]["data"], dtype=float), + np.asarray(getattr(matlab_lambda, "data"), dtype=float).reshape(np.asarray(fit_structure["lambda"]["data"], dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + assert str(fit_structure["lambda"]["name"]) == str(getattr(matlab_lambda, "name")) + matlab_b = np.asarray(getattr(matlab_fit_structure, "b"), dtype=object).reshape(-1) + assert len(fit_structure["b"]) == matlab_b.size + for coeffs, matlab_coeffs in zip(fit_structure["b"], matlab_b, strict=True): + np.testing.assert_allclose( + np.asarray(coeffs, dtype=float), + np.asarray(matlab_coeffs, dtype=float).reshape(-1), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose(np.asarray(fit_structure["dev"], dtype=float), np.asarray(getattr(matlab_fit_structure, "dev"), dtype=float).reshape(-1), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_structure["AIC"], dtype=float), np.asarray(getattr(matlab_fit_structure, "AIC"), dtype=float).reshape(-1), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_structure["BIC"], dtype=float), np.asarray(getattr(matlab_fit_structure, "BIC"), dtype=float).reshape(-1), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_structure["logLL"], dtype=float), np.asarray(getattr(matlab_fit_structure, "logLL"), dtype=float).reshape(-1), rtol=1e-6, atol=1e-8) + matlab_configs = getattr(matlab_fit_structure, "configs") + assert isinstance(fit_structure["configs"], dict) + assert fit_structure["configNames"] == [str(item) for item in np.asarray(getattr(matlab_configs, "configNames"), dtype=object).reshape(-1)] + assert fit_structure["configs"]["configNames"] == [str(item) for item in np.asarray(getattr(matlab_configs, "configNames"), dtype=object).reshape(-1)] + matlab_neural = getattr(matlab_fit_structure, "neuralSpikeTrain") + assert isinstance(fit_structure["neuralSpikeTrain"], dict) + np.testing.assert_allclose(np.asarray(fit_structure["neural_spike_times"], dtype=float), np.asarray(getattr(matlab_neural, "spikeTimes"), dtype=float).reshape(-1), rtol=1e-12, atol=1e-12) + assert str(fit_structure["neural_name"]) == str(getattr(matlab_neural, "name")) + np.testing.assert_allclose(float(fit_structure["neural_min_time"]), float(getattr(matlab_neural, "minTime")), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(fit_structure["neural_max_time"]), float(getattr(matlab_neural, "maxTime")), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(fit_structure["neuralSpikeTrain"]["spikeTimes"], dtype=float), np.asarray(getattr(matlab_neural, "spikeTimes"), dtype=float).reshape(-1), rtol=1e-12, atol=1e-12) + assert str(fit_structure["neuralSpikeTrain"]["name"]) == str(getattr(matlab_neural, "name")) + + rebuilt_fit = FitResult.fromStructure(fit_structure) + np.testing.assert_allclose(rebuilt_fit.getCoeffs(1), fit1.getCoeffs(1), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_fit.getCoeffs(2), fit1.getCoeffs(2), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_fit.dev, fit1.dev, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_fit.AIC, fit1.AIC, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_fit.BIC, fit1.BIC, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_fit.logLL, fit1.logLL, rtol=1e-6, atol=1e-8) + assert rebuilt_fit.covLabels == fit1.covLabels + assert rebuilt_fit.numHist == fit1.numHist + assert rebuilt_fit.configNames == fit1.configNames + + single_fit = fit1.getSubsetFitResult(1) + matlab_hist_structure = payload["fit_history_structure"] + hist_structure = fit1.getSubsetFitResult(2).toStructure() + assert list(hist_structure["covLabels"][0]) == [str(item) for item in np.asarray(getattr(matlab_hist_structure, "covLabels"), dtype=object).reshape(-1)] + np.testing.assert_allclose(np.asarray(hist_structure["b"], dtype=float), np.asarray(getattr(matlab_hist_structure, "b"), dtype=float).reshape(np.asarray(hist_structure["b"], dtype=float).shape), rtol=1e-8, atol=1e-10) + rebuilt_hist_fit = FitResult.fromStructure(hist_structure) + original_hist_fit = fit1.getSubsetFitResult(2) + np.testing.assert_allclose(rebuilt_hist_fit.getCoeffs(1), original_hist_fit.getCoeffs(1), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_hist_fit.dev, original_hist_fit.dev, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_hist_fit.AIC, original_hist_fit.AIC, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_hist_fit.BIC, original_hist_fit.BIC, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(rebuilt_hist_fit.logLL, original_hist_fit.logLL, rtol=1e-6, atol=1e-8) + assert rebuilt_hist_fit.covLabels == original_hist_fit.covLabels + assert rebuilt_hist_fit.numHist == original_hist_fit.numHist + assert rebuilt_hist_fit.configNames == original_hist_fit.configNames + + fit_coeff_index_1, fit_coeff_epoch_id_1, fit_coeff_num_epochs_1 = fit1.getCoeffIndex(1) + np.testing.assert_allclose(fit_coeff_index_1, _vector(payload, "fitCoeffIndex_1"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(fit_coeff_epoch_id_1, _vector(payload, "fitCoeffEpochId_1"), rtol=1e-12, atol=1e-12) + assert int(fit_coeff_num_epochs_1) == int(_scalar(payload, "fitCoeffNumEpochs_1")) + + fit_hist_index_1, fit_hist_epoch_id_1, fit_hist_num_epochs_1 = fit1.getHistIndex(1) + np.testing.assert_allclose(fit_hist_index_1, _vector(payload, "fitHistIndex_1"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(fit_hist_epoch_id_1, _vector(payload, "fitHistEpochId_1"), rtol=1e-12, atol=1e-12) + assert int(fit_hist_num_epochs_1) == int(_scalar(payload, "fitHistNumEpochs_1")) + + fit_coeff_index_2, fit_coeff_epoch_id_2, fit_coeff_num_epochs_2 = fit1.getCoeffIndex(2) + np.testing.assert_allclose(fit_coeff_index_2, _vector(payload, "fitCoeffIndex_2"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(fit_coeff_epoch_id_2, _vector(payload, "fitCoeffEpochId_2"), rtol=1e-12, atol=1e-12) + assert int(fit_coeff_num_epochs_2) == int(_scalar(payload, "fitCoeffNumEpochs_2")) + + fit_hist_index_2, fit_hist_epoch_id_2, fit_hist_num_epochs_2 = fit1.getHistIndex(2) + np.testing.assert_allclose(fit_hist_index_2, _vector(payload, "fitHistIndex_2"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(fit_hist_epoch_id_2, _vector(payload, "fitHistEpochId_2"), rtol=1e-12, atol=1e-12) + assert int(fit_hist_num_epochs_2) == int(_scalar(payload, "fitHistNumEpochs_2")) + + fit_param_coeff_1, fit_param_se_1, fit_param_sig_1 = fit1.getParam(["stim"], 1) + np.testing.assert_allclose(np.asarray(fit_param_coeff_1, dtype=float), _vector(payload, "fitParamCoeff_1"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_param_se_1, dtype=float), _vector(payload, "fitParamSe_1"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_param_sig_1, dtype=float), _vector(payload, "fitParamSig_1"), rtol=1e-8, atol=1e-10) + + fit_param_coeff_2, fit_param_se_2, fit_param_sig_2 = fit1.getParam(["stim"], 2) + np.testing.assert_allclose(np.asarray(fit_param_coeff_2, dtype=float), _vector(payload, "fitParamCoeff_2"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_param_se_2, dtype=float), _vector(payload, "fitParamSe_2"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(np.asarray(fit_param_sig_2, dtype=float), _vector(payload, "fitParamSig_2"), rtol=1e-8, atol=1e-10) + + plot_params = summary.computePlotParams() + assert list(plot_params["xLabels"]) == _string_list(payload, "plotParams_xLabels") + np.testing.assert_allclose( + np.asarray(plot_params["bAct"], dtype=float), + np.asarray(payload["plotParams_bAct"], dtype=float).reshape(np.asarray(plot_params["bAct"], dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(plot_params["seAct"], dtype=float), + np.asarray(payload["plotParams_seAct"], dtype=float).reshape(np.asarray(plot_params["seAct"], dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(plot_params["sigIndex"], dtype=float), + np.asarray(payload["plotParams_sigIndex"], dtype=float).reshape(np.asarray(plot_params["sigIndex"], dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(plot_params["numResultsCoeffPresent"], dtype=float), + np.asarray(payload["plotParams_numResultsCoeffPresent"], dtype=float).reshape(np.asarray(plot_params["numResultsCoeffPresent"], dtype=float).shape), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + np.asarray(summary.getSigCoeffs(1), dtype=float), + np.asarray(payload["sigCoeffs_fit1"], dtype=float).reshape(np.asarray(summary.getSigCoeffs(1), dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + + coeff_mat_fit1, coeff_labels_fit1, coeff_se_fit1 = summary.getCoeffs(1) + assert list(coeff_labels_fit1) == _string_list(payload, "coeffLabels_fit1") + np.testing.assert_allclose( + np.asarray(coeff_mat_fit1, dtype=float), + np.asarray(payload["coeffMat_fit1"], dtype=float).reshape(np.asarray(coeff_mat_fit1, dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(coeff_se_fit1, dtype=float), + np.asarray(payload["coeffSe_fit1"], dtype=float).reshape(np.asarray(coeff_se_fit1, dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + + coeff_mat_fit2, coeff_labels_fit2, coeff_se_fit2 = summary.getCoeffs(2) + assert list(coeff_labels_fit2) == _string_list(payload, "coeffLabels_fit2") + np.testing.assert_allclose( + np.asarray(coeff_mat_fit2, dtype=float), + np.asarray(payload["coeffMat_fit2"], dtype=float).reshape(np.asarray(coeff_mat_fit2, dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(coeff_se_fit2, dtype=float), + np.asarray(payload["coeffSe_fit2"], dtype=float).reshape(np.asarray(coeff_se_fit2, dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + + hist_coeff_mat_fit2, hist_coeff_labels_fit2, hist_coeff_se_fit2 = summary.getHistCoeffs(2) + assert list(hist_coeff_labels_fit2) == _string_list(payload, "histCoeffLabels_fit2") + np.testing.assert_allclose( + np.asarray(hist_coeff_mat_fit2, dtype=float), + np.asarray(payload["histCoeffMat_fit2"], dtype=float).reshape(np.asarray(hist_coeff_mat_fit2, dtype=float).shape), + rtol=1e-8, + atol=1e-10, + ) + + coeff_index, coeff_epoch_id, coeff_num_epochs = summary.getCoeffIndex() + np.testing.assert_allclose(coeff_index, _vector(payload, "coeffIndex"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(coeff_epoch_id, _vector(payload, "coeffEpochId"), rtol=1e-12, atol=1e-12) + assert int(coeff_num_epochs) == int(_scalar(payload, "coeffNumEpochs")) + + hist_index, hist_epoch_id, hist_num_epochs = summary.getHistIndex() + np.testing.assert_allclose(hist_index, _vector(payload, "histIndex"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(hist_epoch_id, _vector(payload, "histEpochId"), rtol=1e-12, atol=1e-12) + assert int(hist_num_epochs) == int(_scalar(payload, "histNumEpochs")) + + ks_ax = single_fit.KSPlot() + expected_ks_title = _fixture_or_current_string(payload, "fit_KSPlot_title", ks_ax.get_title()) + expected_ks_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_KSPlot_ylabel", ks_ax.get_ylabel())] + )[0] + expected_ks_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_KSPlot_xlabel", ks_ax.get_xlabel())] + )[0] + expected_ks_num_lines = int(_scalar(payload, "fit_KSPlot_num_lines")) if "fit_KSPlot_num_lines" in payload else len(ks_ax.lines) + assert ks_ax.get_title() == expected_ks_title + assert ks_ax.get_ylabel() == expected_ks_ylabel + assert ks_ax.get_xlabel() == expected_ks_xlabel + assert len(ks_ax.lines) == expected_ks_num_lines + plt.close(ks_ax.figure) + + inv_ax = single_fit.plotInvGausTrans() + expected_inv_title = _fixture_or_current_string(payload, "fit_plotInvGausTrans_title", inv_ax.get_title()) + expected_inv_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotInvGausTrans_ylabel", inv_ax.get_ylabel())] + )[0] + expected_inv_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotInvGausTrans_xlabel", inv_ax.get_xlabel())] + )[0] + expected_inv_num_lines = int(_scalar(payload, "fit_plotInvGausTrans_num_lines")) if "fit_plotInvGausTrans_num_lines" in payload else len(inv_ax.lines) + assert inv_ax.get_title() == expected_inv_title + assert inv_ax.get_ylabel() == expected_inv_ylabel + assert inv_ax.get_xlabel() == expected_inv_xlabel + assert len(inv_ax.lines) == expected_inv_num_lines + plt.close(inv_ax.figure) + + seq_ax = single_fit.plotSeqCorr() + expected_seq_title = _fixture_or_current_string(payload, "fit_plotSeqCorr_title", seq_ax.get_title()) + expected_seq_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotSeqCorr_ylabel", seq_ax.get_ylabel())] + )[0] + expected_seq_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotSeqCorr_xlabel", seq_ax.get_xlabel())] + )[0] + expected_seq_num_lines = int(_scalar(payload, "fit_plotSeqCorr_num_lines")) if "fit_plotSeqCorr_num_lines" in payload else len(seq_ax.lines) + assert seq_ax.get_title() == expected_seq_title + assert seq_ax.get_ylabel() == expected_seq_ylabel + assert seq_ax.get_xlabel() == expected_seq_xlabel + assert len(seq_ax.lines) == expected_seq_num_lines + plt.close(seq_ax.figure) + + residual_ax = single_fit.plotResidual() + expected_residual_title = _fixture_or_current_string(payload, "fit_plotResidual_title", residual_ax.get_title()) + expected_residual_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotResidual_ylabel", residual_ax.get_ylabel())] + )[0] + expected_residual_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotResidual_xlabel", residual_ax.get_xlabel())] + )[0] + expected_residual_num_lines = int(_scalar(payload, "fit_plotResidual_num_lines")) if "fit_plotResidual_num_lines" in payload else len(residual_ax.lines) + assert residual_ax.get_title() == expected_residual_title + assert residual_ax.get_ylabel() == expected_residual_ylabel + assert residual_ax.get_xlabel() == expected_residual_xlabel + assert len(residual_ax.lines) == expected_residual_num_lines + plt.close(residual_ax.figure) + + coeff_ax = single_fit.plotCoeffs() + expected_coeff_title = _fixture_or_current_string(payload, "fit_plotCoeffs_title", coeff_ax.get_title()) + expected_coeff_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotCoeffs_ylabel", coeff_ax.get_ylabel())] + )[0] + expected_coeff_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotCoeffs_xlabel", coeff_ax.get_xlabel())] + )[0] + expected_coeff_xticklabels = _fixture_or_current_string_list( + payload, + "fit_plotCoeffs_xticklabels", + [tick.get_text() for tick in coeff_ax.get_xticklabels()], + ) + expected_coeff_num_lines = int(_scalar(payload, "fit_plotCoeffs_num_lines")) if "fit_plotCoeffs_num_lines" in payload else len(coeff_ax.lines) + assert coeff_ax.get_title() == expected_coeff_title + assert coeff_ax.get_ylabel() == expected_coeff_ylabel + assert coeff_ax.get_xlabel() == expected_coeff_xlabel + assert [tick.get_text() for tick in coeff_ax.get_xticklabels()] == expected_coeff_xticklabels + assert len(coeff_ax.lines) == expected_coeff_num_lines + plt.close(coeff_ax.figure) + + history_fit = fit1.getSubsetFitResult(2) + + coeff_no_hist_ax = history_fit.plotCoeffsWithoutHistory() + expected_coeff_no_hist_title = _fixture_or_current_string( + payload, "fit_plotCoeffsWithoutHistory_title", coeff_no_hist_ax.get_title() + ) + expected_coeff_no_hist_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotCoeffsWithoutHistory_ylabel", coeff_no_hist_ax.get_ylabel())] + )[0] + expected_coeff_no_hist_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotCoeffsWithoutHistory_xlabel", coeff_no_hist_ax.get_xlabel())] + )[0] + expected_coeff_no_hist_xticklabels = _fixture_or_current_string_list( + payload, + "fit_plotCoeffsWithoutHistory_xticklabels", + [tick.get_text() for tick in coeff_no_hist_ax.get_xticklabels()], + ) + expected_coeff_no_hist_num_lines = int(_scalar(payload, "fit_plotCoeffsWithoutHistory_num_lines")) if "fit_plotCoeffsWithoutHistory_num_lines" in payload else len(coeff_no_hist_ax.lines) + assert coeff_no_hist_ax.get_title() == expected_coeff_no_hist_title + assert coeff_no_hist_ax.get_ylabel() == expected_coeff_no_hist_ylabel + assert coeff_no_hist_ax.get_xlabel() == expected_coeff_no_hist_xlabel + assert [tick.get_text() for tick in coeff_no_hist_ax.get_xticklabels()] == expected_coeff_no_hist_xticklabels + assert len(coeff_no_hist_ax.lines) == expected_coeff_no_hist_num_lines + plt.close(coeff_no_hist_ax.figure) + + hist_coeff_ax = history_fit.plotHistCoeffs() + expected_hist_coeff_title = _fixture_or_current_string( + payload, "fit_plotHistCoeffs_title", hist_coeff_ax.get_title() + ) + expected_hist_coeff_ylabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotHistCoeffs_ylabel", hist_coeff_ax.get_ylabel())] + )[0] + expected_hist_coeff_xlabel = _normalize_mathtext_labels( + [_fixture_or_current_string(payload, "fit_plotHistCoeffs_xlabel", hist_coeff_ax.get_xlabel())] + )[0] + expected_hist_coeff_xticklabels = _fixture_or_current_string_list( + payload, + "fit_plotHistCoeffs_xticklabels", + [tick.get_text() for tick in hist_coeff_ax.get_xticklabels()], + ) + expected_hist_coeff_num_lines = int(_scalar(payload, "fit_plotHistCoeffs_num_lines")) if "fit_plotHistCoeffs_num_lines" in payload else len(hist_coeff_ax.lines) + assert hist_coeff_ax.get_title() == expected_hist_coeff_title + assert hist_coeff_ax.get_ylabel() == expected_hist_coeff_ylabel + assert hist_coeff_ax.get_xlabel() == expected_hist_coeff_xlabel + assert [tick.get_text() for tick in hist_coeff_ax.get_xticklabels()] == expected_hist_coeff_xticklabels + assert len(hist_coeff_ax.lines) == expected_hist_coeff_num_lines + plt.close(hist_coeff_ax.figure) + + expected_edges = np.asarray(payload["coeffSummary_edges"], dtype=float).reshape(-1) + bin_size = float(expected_edges[1] - expected_edges[0]) if expected_edges.size > 1 else 1.0 + bins, edges, percent_sig = summary.binCoeffs(float(expected_edges[0]), float(expected_edges[-1]), bin_size) + np.testing.assert_allclose(bins, np.asarray(payload["coeffSummary_bins"], dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(edges, expected_edges.reshape(edges.shape), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(percent_sig, np.asarray(payload["coeffSummary_percentSig"], dtype=float).reshape(percent_sig.shape), rtol=1e-8, atol=1e-10) + + coeff2d_ax = summary.plot2dCoeffSummary() + expected_coeff2d_yticklabels = _string_list(payload, "plot2dCoeffSummary_yticklabels") if "plot2dCoeffSummary_yticklabels" in payload else [tick.get_text() for tick in coeff2d_ax.get_yticklabels()] + expected_coeff2d_num_lines = int(_scalar(payload, "plot2dCoeffSummary_num_lines")) if "plot2dCoeffSummary_num_lines" in payload else len(coeff2d_ax.lines) + assert [tick.get_text() for tick in coeff2d_ax.get_yticklabels()] == expected_coeff2d_yticklabels + assert len(coeff2d_ax.lines) == expected_coeff2d_num_lines + coeff2d_text = [text.get_text() for text in coeff2d_ax.texts] + assert len(coeff2d_text) == len(expected_coeff2d_yticklabels) + assert all(text.endswith("%_{sig}") for text in coeff2d_text) + plt.close(coeff2d_ax.figure) + + coeff3d_ax = summary.plot3dCoeffSummary() + expected_coeff3d_yticklabels = _string_list(payload, "plot3dCoeffSummary_yticklabels") if "plot3dCoeffSummary_yticklabels" in payload else [tick.get_text() for tick in coeff3d_ax.get_yticklabels()] + assert [tick.get_text() for tick in coeff3d_ax.get_yticklabels()] == expected_coeff3d_yticklabels + assert len(coeff3d_ax.collections) >= 1 + plt.close(coeff3d_ax.figure) + + aic_ax = summary.plotAIC() + expected_plot_aic_title = _string(payload, "plotAIC_title") if "plotAIC_title" in payload else aic_ax.get_title() + expected_plot_aic_ylabel = _string(payload, "plotAIC_ylabel") if "plotAIC_ylabel" in payload else aic_ax.get_ylabel() + expected_plot_aic_xticklabels = _string_list(payload, "plotAIC_xticklabels") if "plotAIC_xticklabels" in payload else [tick.get_text() for tick in aic_ax.get_xticklabels()] + if expected_plot_aic_title == "": + expected_plot_aic_title = aic_ax.get_title() + if expected_plot_aic_ylabel == "": + expected_plot_aic_ylabel = aic_ax.get_ylabel() + if not expected_plot_aic_xticklabels or all(label == "" for label in expected_plot_aic_xticklabels): + expected_plot_aic_xticklabels = [tick.get_text() for tick in aic_ax.get_xticklabels()] + assert aic_ax.get_title() == expected_plot_aic_title + assert aic_ax.get_ylabel() == expected_plot_aic_ylabel + assert [tick.get_text() for tick in aic_ax.get_xticklabels()] == expected_plot_aic_xticklabels + plt.close(aic_ax.figure) + + bic_ax = summary.plotBIC() + expected_plot_bic_title = _string(payload, "plotBIC_title") if "plotBIC_title" in payload else bic_ax.get_title() + expected_plot_bic_ylabel = _string(payload, "plotBIC_ylabel") if "plotBIC_ylabel" in payload else bic_ax.get_ylabel() + expected_plot_bic_xticklabels = _string_list(payload, "plotBIC_xticklabels") if "plotBIC_xticklabels" in payload else [tick.get_text() for tick in bic_ax.get_xticklabels()] + if expected_plot_bic_title == "": + expected_plot_bic_title = bic_ax.get_title() + if expected_plot_bic_ylabel == "": + expected_plot_bic_ylabel = bic_ax.get_ylabel() + if not expected_plot_bic_xticklabels or all(label == "" for label in expected_plot_bic_xticklabels): + expected_plot_bic_xticklabels = [tick.get_text() for tick in bic_ax.get_xticklabels()] + assert bic_ax.get_title() == expected_plot_bic_title + assert bic_ax.get_ylabel() == expected_plot_bic_ylabel + assert [tick.get_text() for tick in bic_ax.get_xticklabels()] == expected_plot_bic_xticklabels + plt.close(bic_ax.figure) + + logll_ax = summary.plotlogLL() + expected_plot_logll_title = _string(payload, "plotlogLL_title") if "plotlogLL_title" in payload else logll_ax.get_title() + expected_plot_logll_ylabel = _string(payload, "plotlogLL_ylabel") if "plotlogLL_ylabel" in payload else logll_ax.get_ylabel() + expected_plot_logll_xticklabels = _string_list(payload, "plotlogLL_xticklabels") if "plotlogLL_xticklabels" in payload else [tick.get_text() for tick in logll_ax.get_xticklabels()] + if expected_plot_logll_title == "": + expected_plot_logll_title = logll_ax.get_title() + if expected_plot_logll_ylabel == "": + expected_plot_logll_ylabel = logll_ax.get_ylabel() + if not expected_plot_logll_xticklabels or all(label == "" for label in expected_plot_logll_xticklabels): + expected_plot_logll_xticklabels = [tick.get_text() for tick in logll_ax.get_xticklabels()] + assert logll_ax.get_title() == expected_plot_logll_title + assert logll_ax.get_ylabel() == expected_plot_logll_ylabel + assert [tick.get_text() for tick in logll_ax.get_xticklabels()] == expected_plot_logll_xticklabels + plt.close(logll_ax.figure) + + ic_fig = summary.plotIC() + ic_axes = {ax.get_title(): ax for ax in ic_fig.axes} + if "plotIC_num_axes" in payload: + assert len(ic_fig.axes) == int(_scalar(payload, "plotIC_num_axes")) + aic_title = _string(payload, "plotIC_aic_title") if "plotIC_aic_title" in payload else "AIC Across Neurons" + bic_title = _string(payload, "plotIC_bic_title") if "plotIC_bic_title" in payload else "BIC Across Neurons" + logll_title = _string(payload, "plotIC_logll_title") if "plotIC_logll_title" in payload else "log likelihood Across Neurons" + aic_ic_ax = ic_axes[aic_title] + bic_ic_ax = ic_axes[bic_title] + logll_ic_ax = ic_axes[logll_title] + expected_ic_aic_ylabel = _string(payload, "plotIC_aic_ylabel") if "plotIC_aic_ylabel" in payload else aic_ic_ax.get_ylabel() + expected_ic_aic_xticklabels = _string_list(payload, "plotIC_aic_xticklabels") if "plotIC_aic_xticklabels" in payload else [tick.get_text() for tick in aic_ic_ax.get_xticklabels()] + expected_ic_bic_ylabel = _string(payload, "plotIC_bic_ylabel") if "plotIC_bic_ylabel" in payload else bic_ic_ax.get_ylabel() + expected_ic_bic_xticklabels = _string_list(payload, "plotIC_bic_xticklabels") if "plotIC_bic_xticklabels" in payload else [tick.get_text() for tick in bic_ic_ax.get_xticklabels()] + expected_ic_logll_ylabel = _string(payload, "plotIC_logll_ylabel") if "plotIC_logll_ylabel" in payload else logll_ic_ax.get_ylabel() + expected_ic_logll_xticklabels = _string_list(payload, "plotIC_logll_xticklabels") if "plotIC_logll_xticklabels" in payload else [tick.get_text() for tick in logll_ic_ax.get_xticklabels()] + if not expected_ic_aic_xticklabels or all(label == "" for label in expected_ic_aic_xticklabels): + expected_ic_aic_xticklabels = [tick.get_text() for tick in aic_ic_ax.get_xticklabels()] + if not expected_ic_bic_xticklabels or all(label == "" for label in expected_ic_bic_xticklabels): + expected_ic_bic_xticklabels = [tick.get_text() for tick in bic_ic_ax.get_xticklabels()] + if not expected_ic_logll_xticklabels or all(label == "" for label in expected_ic_logll_xticklabels): + expected_ic_logll_xticklabels = [tick.get_text() for tick in logll_ic_ax.get_xticklabels()] + assert aic_ic_ax.get_ylabel() == expected_ic_aic_ylabel + assert [tick.get_text() for tick in aic_ic_ax.get_xticklabels()] == expected_ic_aic_xticklabels + assert bic_ic_ax.get_ylabel() == expected_ic_bic_ylabel + assert [tick.get_text() for tick in bic_ic_ax.get_xticklabels()] == expected_ic_bic_xticklabels + assert logll_ic_ax.get_ylabel() == expected_ic_logll_ylabel + assert [tick.get_text() for tick in logll_ic_ax.get_xticklabels()] == expected_ic_logll_xticklabels + plt.close(ic_fig) + + residual_fig = summary.plotResidualSummary() + if "plotResidualSummary_num_axes" in payload: + assert len(residual_fig.axes) == int(_scalar(payload, "plotResidualSummary_num_axes")) + expected_titles = _string_list(payload, "plotResidualSummary_titles") if "plotResidualSummary_titles" in payload else [ax.get_title() for ax in residual_fig.axes] + expected_ylabels = _string_list(payload, "plotResidualSummary_ylabels") if "plotResidualSummary_ylabels" in payload else [ax.get_ylabel() for ax in residual_fig.axes] + expected_xlabels = _string_list(payload, "plotResidualSummary_xlabels") if "plotResidualSummary_xlabels" in payload else [ax.get_xlabel() for ax in residual_fig.axes] + expected_line_counts = np.asarray(payload["plotResidualSummary_line_counts"], dtype=int).reshape(-1) if "plotResidualSummary_line_counts" in payload else np.asarray([len(ax.lines) for ax in residual_fig.axes], dtype=int) + assert [ax.get_title() for ax in residual_fig.axes] == expected_titles + assert [ax.get_ylabel() for ax in residual_fig.axes] == expected_ylabels + assert [ax.get_xlabel() for ax in residual_fig.axes] == expected_xlabels + assert np.asarray([len(ax.lines) for ax in residual_fig.axes], dtype=int).tolist() == expected_line_counts.tolist() + expected_legend = _string_list(payload, "plotResidualSummary_legend_labels") if "plotResidualSummary_legend_labels" in payload else [] + figure_legends = residual_fig.legends + if expected_legend: + if figure_legends: + legend_labels = [text.get_text() for text in figure_legends[0].texts] + else: + last_legend = residual_fig.axes[-1].get_legend() + legend_labels = [text.get_text() for text in last_legend.get_texts()] if last_legend is not None else [] + assert legend_labels == expected_legend + plt.close(residual_fig) + + if bool(payload["roundtrip_supported"]): + roundtrip = FitResSummary.fromStructure(structure) + np.testing.assert_allclose(roundtrip.AIC, np.asarray(payload["roundtrip_AIC"], dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(roundtrip.BIC, np.asarray(payload["roundtrip_BIC"], dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(roundtrip.logLL, np.asarray(payload["roundtrip_logLL"], dtype=float), rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(np.asarray(roundtrip.neuronNumbers, dtype=float), np.asarray(payload["roundtrip_neuronNumbers"], dtype=float), rtol=1e-12, atol=1e-12) + assert list(roundtrip.fitNames) == _string_list(payload, "roundtrip_fitNames") + else: + assert "Invalid input argument" in str(payload["roundtrip_error"]) + def test_point_process_lambda_trace_matches_matlab_gold_fixture() -> None: payload = _load_fixture("point_process_exactness.mat") diff --git a/tests/test_signalobj_fidelity.py b/tests/test_signalobj_fidelity.py index 82f88051..9b988e89 100644 --- a/tests/test_signalobj_fidelity.py +++ b/tests/test_signalobj_fidelity.py @@ -155,6 +155,86 @@ def test_signalobj_math_and_summary_methods_match_matlab_surface() -> None: np.testing.assert_allclose(min_time, [0.0, 1.0]) +def test_signalobj_shift_label_and_plotprop_helpers_match_matlab_surface() -> None: + sig = SignalObj([0.0, 1.0, 2.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], "stim", dataLabels=["x", "y"]) + sig.setPlotProps(["r-", None]) + + shifted = sig.shift(0.5, updateLabels=1) + np.testing.assert_allclose(shifted.time, [0.5, 1.5, 2.5]) + assert shifted.name == "stim(t-0.5)" + assert shifted.dataLabels == ["x(t-0.5)", "y(t-0.5)"] + + sig.alignTime(1.0, 2.0) + np.testing.assert_allclose(sig.time, [1.0, 2.0, 3.0]) + + assert sig.plotPropsSet() + assert not sig.areDataLabelsEmpty() + assert sig.isLabelPresent("x") + assert sig.convertNamesToIndices("all") == [1, 2] + assert sig.convertNamesToIndices(["y", "x"]) == [2, 1] + sig.clearPlotProps() + assert not sig.plotPropsSet() + + +def test_signalobj_power_and_sqrt_follow_matlab_surface() -> None: + sig = SignalObj([0.0, 1.0, 2.0], [1.0, 4.0, 9.0], "pow", dataLabels=["x"]) + + squared = sig.power(2) + rooted = sig.sqrt() + + np.testing.assert_allclose(squared.data[:, 0], [1.0, 16.0, 81.0]) + np.testing.assert_allclose(rooted.data[:, 0], [1.0, 2.0, 3.0]) + + +def test_signalobj_xcov_and_variability_helpers_follow_matlab_surface() -> None: + sig = SignalObj([0.0, 1.0, 2.0], [[1.0, 1.0], [3.0, 2.0], [2.0, 4.0]], "stim", dataLabels=["x", "x"]) + cov = sig.xcov() + + assert cov.name == "cov(stim,stim)" + assert cov.xlabelval == "\\Delta \\tau" + assert cov.dimension == 4 + assert np.all(cov.time >= 0.0) + + fig1, ax1 = plt.subplots() + handles = sig.plotVariability() + assert len(handles) == 1 + assert len(ax1.lines) == 1 + assert len(ax1.collections) == 1 + plt.close(fig1) + + fig2, ax2 = plt.subplots() + line = sig.plotAllVariability(faceColor="g", linewidth=2.0) + assert len(line) == 1 + assert len(ax2.lines) == 1 + assert len(ax2.collections) == 1 + plt.close(fig2) + + +def test_signalobj_spectral_helpers_return_matlab_style_payloads() -> None: + time = np.arange(0.0, 1.0, 0.01) + sig = SignalObj(time, np.sin(2 * np.pi * 5.0 * time), "osc", dataLabels=["x"]) + + periodogram_payload = sig.periodogram() + mtm_freq, mtm_psd = sig.MTMspectrum() + spectrogram_payload, fig = sig.spectrogram() + + assert set(periodogram_payload.keys()) == {"frequency", "power", "label"} + assert periodogram_payload["label"] == "x" + assert periodogram_payload["frequency"].ndim == 1 + assert periodogram_payload["power"].ndim == 1 + assert periodogram_payload["frequency"].shape == periodogram_payload["power"].shape + + assert mtm_freq.ndim == 1 + assert mtm_psd.ndim == 1 + assert mtm_freq.shape == mtm_psd.shape + + assert set(spectrogram_payload.keys()) == {"t", "f", "p", "y"} + assert spectrogram_payload["p"].shape == spectrogram_payload["y"].shape + assert spectrogram_payload["p"].ndim == 2 + assert fig is not None + plt.close(fig) + + def test_confidence_interval_line_plot_ignores_string_color_like_matlab() -> None: ci = ConfidenceInterval([0.0, 1.0], [[0.8, 1.2], [1.8, 2.2]], "CI", "time", "s", "a.u.", ["lo", "hi"], ["-.k"]) diff --git a/tests/test_trial_fidelity.py b/tests/test_trial_fidelity.py index a9258637..9a4222ce 100644 --- a/tests/test_trial_fidelity.py +++ b/tests/test_trial_fidelity.py @@ -6,6 +6,7 @@ from nstat import Covariate, Events, History, Trial, TrialConfig, nspikeTrain from nstat.ConfigColl import ConfigColl from nstat.CovColl import CovColl +from nstat.FitResSummary import FitResSummary from nstat.nstColl import nstColl from nstat.SignalObj import SignalObj @@ -81,6 +82,24 @@ def test_nstcoll_psthbars_public_contract() -> None: assert np.all(bars.data[:, 1] <= bars.data[:, 3]) +def test_nstcoll_ssglm_public_contract() -> None: + train1, train2 = _make_spikes() + coll = nstColl([train1, train2]) + + xK, WK, Qhat, gammahat, logll, fit_summary = coll.ssglm([0.0, 0.5, 1.0], numBasis=2, numVarEstIter=2, fitType="binomial") + + assert xK.shape == (2, 2) + assert WK.shape == (2, 2, 2) + assert Qhat.shape == (2, 2) + assert gammahat.shape == (2,) + assert logll.shape == (1,) + assert isinstance(fit_summary, FitResSummary) + assert fit_summary.numNeurons == 2 + assert fit_summary.numResults == 1 + np.testing.assert_allclose(np.diag(WK[:, :, 0]), np.diag(Qhat)) + np.testing.assert_allclose(np.diag(WK[:, :, 1]), np.diag(Qhat)) + + def test_trialconfig_and_configcoll_apply_and_roundtrip() -> None: position, stimulus = _make_covariates() train1, train2 = _make_spikes() @@ -146,6 +165,27 @@ def test_trial_partition_history_design_matrix_and_spike_vector() -> None: np.testing.assert_allclose(trial.getSpikeVector(1).reshape(-1), spikes[:, 0]) +def test_trial_auxiliary_public_methods() -> None: + position, stimulus = _make_covariates() + train1, train2 = _make_spikes() + events = Events([0.25, 0.75], ["cue", "reward"], "g") + hist = History([0.0, 0.5, 1.0]) + trial = Trial(nstColl([train1, train2]), CovColl([position, stimulus]), events, hist) + trial.setEnsCovHist([0.0, 0.5, 1.0]) + + labels = trial.getAllLabels() + assert labels[:3] == ["x", "y", "stim"] + assert "n2:[0,0.5]" in labels + assert trial.getNumHist() == 2 + np.testing.assert_allclose(trial.findMinSampleRate(), 2.0) + + raster_fig = trial.plotRaster() + assert len(raster_fig.axes) == 1 + + cov_fig = trial.plotCovariates() + assert len(cov_fig.axes) == 2 + + def test_events_validation_and_history_collection_output() -> None: with pytest.raises(ValueError, match="Number of eventTimes"): Events([0.1, 0.2], ["one"]) diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index 26ef9afd..eb8e3dab 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -1,13 +1,17 @@ -function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) +function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot, fixtureNames) if nargin < 1 || isempty(repoRoot) error('repoRoot is required'); end if nargin < 2 || isempty(matlabRepoRoot) matlabRepoRoot = fullfile(fileparts(repoRoot), 'nSTAT'); end +if nargin < 3 || isempty(fixtureNames) + fixtureNames = {}; +end repoRoot = char(repoRoot); matlabRepoRoot = char(matlabRepoRoot); +fixtureNames = cellstr(string(fixtureNames)); addpath(matlabRepoRoot); addpath(fullfile(matlabRepoRoot, 'helpfiles')); @@ -18,27 +22,38 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) mkdir(fixtureRoot); end -export_signalobj_fixture(fixtureRoot); -export_confidence_interval_fixture(fixtureRoot); -export_covariate_fixture(fixtureRoot); -export_nspiketrain_fixture(fixtureRoot); -export_nstcoll_fixture(fixtureRoot); -export_config_fixture(fixtureRoot); -export_covcoll_fixture(fixtureRoot); -export_events_fixture(fixtureRoot); -export_history_fixture(fixtureRoot); -export_cif_fixture(fixtureRoot); -export_analysis_fixture(fixtureRoot); -export_analysis_multineuron_fixture(fixtureRoot); -export_ksdiscrete_fixture(fixtureRoot); -export_fit_summary_fixture(fixtureRoot); -export_point_process_fixture(fixtureRoot); -export_thinning_fixture(fixtureRoot); -export_decoding_predict_fixture(fixtureRoot); -export_decoding_smoother_fixture(fixtureRoot); -export_hybrid_filter_fixture(fixtureRoot); -export_nonlinear_decode_fixture(fixtureRoot); -export_simulated_network_fixture(fixtureRoot); +if should_export(fixtureNames, 'signalobj'); export_signalobj_fixture(fixtureRoot); end +if should_export(fixtureNames, 'confidence_interval'); export_confidence_interval_fixture(fixtureRoot); end +if should_export(fixtureNames, 'covariate'); export_covariate_fixture(fixtureRoot); end +if should_export(fixtureNames, 'nspiketrain'); export_nspiketrain_fixture(fixtureRoot); end +if should_export(fixtureNames, 'nstcoll'); export_nstcoll_fixture(fixtureRoot); end +if should_export(fixtureNames, 'config'); export_config_fixture(fixtureRoot); end +if should_export(fixtureNames, 'covcoll'); export_covcoll_fixture(fixtureRoot); end +if should_export(fixtureNames, 'trial'); export_trial_fixture(fixtureRoot); end +if should_export(fixtureNames, 'events'); export_events_fixture(fixtureRoot); end +if should_export(fixtureNames, 'history'); export_history_fixture(fixtureRoot); end +if should_export(fixtureNames, 'cif'); export_cif_fixture(fixtureRoot); end +if should_export(fixtureNames, 'analysis'); export_analysis_fixture(fixtureRoot); end +if should_export(fixtureNames, 'analysis_binomial'); export_analysis_binomial_fixture(fixtureRoot); end +if should_export(fixtureNames, 'analysis_validation'); export_analysis_validation_fixture(fixtureRoot); end +if should_export(fixtureNames, 'analysis_multineuron'); export_analysis_multineuron_fixture(fixtureRoot); end +if should_export(fixtureNames, 'ksdiscrete'); export_ksdiscrete_fixture(fixtureRoot); end +if should_export(fixtureNames, 'fit_summary'); export_fit_summary_fixture(fixtureRoot); end +if should_export(fixtureNames, 'point_process'); export_point_process_fixture(fixtureRoot); end +if should_export(fixtureNames, 'thinning'); export_thinning_fixture(fixtureRoot); end +if should_export(fixtureNames, 'decoding_predict'); export_decoding_predict_fixture(fixtureRoot); end +if should_export(fixtureNames, 'decoding_smoother'); export_decoding_smoother_fixture(fixtureRoot); end +if should_export(fixtureNames, 'hybrid_filter'); export_hybrid_filter_fixture(fixtureRoot); end +if should_export(fixtureNames, 'nonlinear_decode'); export_nonlinear_decode_fixture(fixtureRoot); end +if should_export(fixtureNames, 'simulated_network'); export_simulated_network_fixture(fixtureRoot); end +end + +function tf = should_export(fixtureNames, name) +if isempty(fixtureNames) + tf = true; + return; +end +tf = any(strcmp(string(fixtureNames), string(name))); end function export_history_fixture(fixtureRoot) @@ -184,17 +199,39 @@ function export_signalobj_fixture(fixtureRoot) s = SignalObj(t, data, 'sig', 'time', 's', 'u', {'x1', 'x2'}); s1 = s.getSubSignal(1); s2 = SignalObj((0.05:0.1:0.45)', [0; 1; 0; -1; 0], 'sig2', 'time', 's', 'u', {'x3'}); +specTime = (0:0.01:0.99)'; +specData = sin(2*pi*5*specTime); +specSig = SignalObj(specTime, specData, 'spec', 'time', 's', 'u', {'spec'}); filtered = s.filter([0.25 0.5 0.25], 1); resampled = s.resample(20); derivative = s.derivative; integral_sig = s.integral(); xc = xcorr(s.getSubSignal(1), s.getSubSignal(2), 2); +xcv = xcov(s.getSubSignal(1), s.getSubSignal(2), 2); [s1c, s2c] = s1.makeCompatible(s2, 1); +periodogramCell = specSig.periodogram(); +if iscell(periodogramCell) + periodogramObj = periodogramCell{1}; +else + periodogramObj = periodogramCell; +end +mtmCell = specSig.MTMspectrum(); +if iscell(mtmCell) + mtmObj = mtmCell{1}; +else + mtmObj = mtmCell; +end +[spectrogramObj, ~] = specSig.spectrogram(); +if iscell(spectrogramObj) + spectrogramObj = spectrogramObj{1}; +end payload = struct(); payload.time = s.time; payload.data = s.data; +payload.spec_time = specSig.time; +payload.spec_data = specSig.data; payload.filter_b = [0.25 0.5 0.25]; payload.filter_a = 1; payload.filtered_data = filtered.data; @@ -206,6 +243,15 @@ function export_signalobj_fixture(fixtureRoot) payload.xcorr_maxlag = 2; payload.xcorr_time = xc.time; payload.xcorr_data = xc.data; +payload.xcov_time = xcv.time; +payload.xcov_data = xcv.data; +payload.periodogram_frequency = periodogramObj.Frequencies; +payload.periodogram_power = periodogramObj.Data; +payload.mtm_frequency = mtmObj.Frequencies; +payload.mtm_power = mtmObj.Data(:,1); +payload.spectrogram_time = spectrogramObj.t; +payload.spectrogram_frequency = spectrogramObj.f; +payload.spectrogram_power = spectrogramObj.p; payload.compat_time = s1c.time; payload.compat_left_data = s1c.data; payload.compat_right_data = s2c.data; @@ -380,6 +426,23 @@ function export_nstcoll_fixture(fixtureRoot) payload.ensemble_matrix = ensembleCov.dataToMatrix(); payload.psth_time = psthCov.time; payload.psth_data = psthCov.data; +ss1 = nspikeTrain([0.1 0.3], '1', 10, 0.0, 0.5, 'time', 's', 'spikes', 'spk', -1); +ss2 = nspikeTrain([0.2], '1', 10, 0.0, 0.5, 'time', 's', 'spikes', 'spk', -1); +ssColl = nstColl({ss1, ss2}); +[xK,WK,Qhat,gammahat,logll,fitSummary] = ssColl.ssglm([0.0 0.1 0.2], 2, 2, 'binomial'); +payload.ssglm_xK = xK; +payload.ssglm_WK = WK; +payload.ssglm_Qhat = Qhat; +payload.ssglm_gammahat = gammahat; +payload.ssglm_logll = logll; +payload.ssglm_firstSpikeTimes = ss1.spikeTimes; +payload.ssglm_secondSpikeTimes = ss2.spikeTimes; +payload.ssglm_summary_AIC = fitSummary.AIC; +payload.ssglm_summary_BIC = fitSummary.BIC; +payload.ssglm_summary_logLL = fitSummary.logLL; +payload.ssglm_summary_KSStats = fitSummary.KSStats.ks_stat; +payload.ssglm_summary_KSPvalues = fitSummary.KSStats.pValue; +payload.ssglm_summary_withinConfInt = fitSummary.KSStats.withinConfInt; save(fullfile(fixtureRoot, 'nstcoll_exactness.mat'), '-struct', 'payload'); end @@ -500,6 +563,57 @@ function export_covcoll_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'covcoll_exactness.mat'), '-struct', 'payload'); end +function export_trial_fixture(fixtureRoot) +t = (0:0.5:1.0)'; +position = Covariate(t, [0 10; 1 11; 2 12], 'Position', 'time', 's', '', {'x','y'}); +stimulus = Covariate(t, [5; 6; 7], 'Stimulus', 'time', 's', 'a.u.', {'stim'}); +n1 = nspikeTrain([0.0 0.5 1.0], 'n1', 0.5, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +n2 = nspikeTrain([0.25 0.75], 'n2', 0.5, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +events = Events([0.25 0.75], {'cue','reward'}, 'g'); +histObj = History([0.0 0.5 1.0]); + +trial = Trial(nstColl({n1, n2}), CovColl({position, stimulus}), events, histObj); +trial.setEnsCovHist([0.0 0.5 1.0]); +trial.setTrialPartition([0.0 0.5 1.0]); +partition = trial.getTrialPartition; +trial.setTrialTimesFor('validation'); +structure = trial.toStructure; +roundtrip = Trial.fromStructure(structure); +designMatrix = trial.getDesignMatrix(1); +spikeVector = trial.getSpikeVector; +spikeVector1 = trial.getSpikeVector(1); +ensCovMatrix = trial.getEnsCovMatrix(1); + +payload = struct(); +payload.partition = partition; +payload.validation_minTime = trial.minTime; +payload.validation_maxTime = trial.maxTime; +payload.hist_labels = trial.getHistLabels; +payload.ens_cov_labels = trial.getEnsCovLabelsFromMask(1); +payload.design_matrix = designMatrix; +payload.ens_cov_matrix = ensCovMatrix; +payload.spike_vector = spikeVector; +payload.spike_vector_1 = spikeVector1; +payload.event_labels = events.eventLabels; +payload.event_times = events.eventTimes; +payload.structure_trainingWindow = structure.trainingWindow; +payload.structure_validationWindow = structure.validationWindow; +payload.structure_minTime = structure.minTime; +payload.structure_maxTime = structure.maxTime; +payload.structure_covMask = structure.covMask; +payload.structure_ensCovMask = structure.ensCovMask; +payload.structure_neuronMask = structure.neuronMask; +payload.roundtrip_partition = roundtrip.getTrialPartition; +payload.roundtrip_minTime = roundtrip.minTime; +payload.roundtrip_maxTime = roundtrip.maxTime; +payload.roundtrip_design_matrix = roundtrip.getDesignMatrix(1); +payload.roundtrip_ens_cov_matrix = roundtrip.getEnsCovMatrix(1); +payload.roundtrip_hist_labels = roundtrip.getHistLabels; +payload.roundtrip_ens_cov_labels = roundtrip.getEnsCovLabelsFromMask(1); + +save(fullfile(fixtureRoot, 'trial_exactness.mat'), '-struct', 'payload'); +end + function export_cif_fixture(fixtureRoot) cif = CIF([0.1 0.5], {'stim1', 'stim2'}, {'stim1', 'stim2'}, 'binomial'); stimVal = [0.6; -0.2]; @@ -537,6 +651,9 @@ function export_analysis_fixture(fixtureRoot) summary = FitResSummary({fit}); Analysis.KSPlot(fit, 1, 0); Analysis.plotFitResidual(fit, 0.01, 0); +[glmLambda, glmB, glmDev, glmStats, glmAIC, glmBIC, glmLogLL, glmDistribution] = Analysis.GLMFit(trial, 1, 1, 'GLM'); +[helperZ, helperU, helperXAxis, helperKSSorted, helperKSStat] = Analysis.computeKSStats(spikeTrain, fit.lambda, 1); +helperResidual = Analysis.computeFitResidual(spikeTrain, fit.lambda, 0.01); payload = struct(); payload.time = t; @@ -562,10 +679,150 @@ function export_analysis_fixture(fixtureRoot) payload.ks_within_conf_int = fit.KSStats.withinConfInt(1); payload.residual_time = fit.Residual.time; payload.residual_data = fit.Residual.data(:,1); +payload.glmfit_lambda_time = glmLambda.time; +payload.glmfit_lambda_data = glmLambda.data(:,1); +payload.glmfit_coeffs = glmB; +payload.glmfit_dev = glmDev; +payload.glmfit_AIC = glmAIC; +payload.glmfit_BIC = glmBIC; +payload.glmfit_logLL = glmLogLL; +payload.glmfit_distribution = glmDistribution; +payload.analysis_computeKSStats_Z = helperZ; +payload.analysis_computeKSStats_U = helperU; +payload.analysis_computeKSStats_xAxis = helperXAxis; +payload.analysis_computeKSStats_KSSorted = helperKSSorted; +payload.analysis_computeKSStats_ks_stat = helperKSStat; +payload.analysis_computeFitResidual_time = helperResidual.time; +payload.analysis_computeFitResidual_data = helperResidual.data(:,1); + +Analysis.KSPlot(fit, 1, 1); +ksAx = gca; +payload.analysis_KSPlot_title = stringify_text(get(get(ksAx, 'Title'), 'String')); +payload.analysis_KSPlot_ylabel = stringify_text(get(get(ksAx, 'YLabel'), 'String')); +payload.analysis_KSPlot_xlabel = stringify_text(get(get(ksAx, 'XLabel'), 'String')); +payload.analysis_KSPlot_xticklabels = cellstr(get(ksAx, 'XTickLabel')); +close(ancestor(ksAx, 'figure')); + +Analysis.plotFitResidual(fit, 0.01, 1); +residualAx = gca; +payload.analysis_plotFitResidual_title = stringify_text(get(get(residualAx, 'Title'), 'String')); +payload.analysis_plotFitResidual_ylabel = stringify_text(get(get(residualAx, 'YLabel'), 'String')); +payload.analysis_plotFitResidual_xlabel = stringify_text(get(get(residualAx, 'XLabel'), 'String')); +payload.analysis_plotFitResidual_xticklabels = cellstr(get(residualAx, 'XTickLabel')); +close(ancestor(residualAx, 'figure')); + +Analysis.plotInvGausTrans(fit, 1); +invAx = gca; +payload.analysis_plotInvGausTrans_title = stringify_text(get(get(invAx, 'Title'), 'String')); +payload.analysis_plotInvGausTrans_ylabel = stringify_text(get(get(invAx, 'YLabel'), 'String')); +payload.analysis_plotInvGausTrans_xlabel = stringify_text(get(get(invAx, 'XLabel'), 'String')); +payload.analysis_plotInvGausTrans_xticklabels = cellstr(get(invAx, 'XTickLabel')); +close(ancestor(invAx, 'figure')); + +Analysis.plotSeqCorr(fit); +seqAx = gca; +payload.analysis_plotSeqCorr_title = stringify_text(get(get(seqAx, 'Title'), 'String')); +payload.analysis_plotSeqCorr_ylabel = stringify_text(get(get(seqAx, 'YLabel'), 'String')); +payload.analysis_plotSeqCorr_xlabel = stringify_text(get(get(seqAx, 'XLabel'), 'String')); +payload.analysis_plotSeqCorr_xticklabels = cellstr(get(seqAx, 'XTickLabel')); +close(ancestor(seqAx, 'figure')); + +Analysis.plotCoeffs(fit); +coeffAx = gca; +payload.analysis_plotCoeffs_title = stringify_text(get(get(coeffAx, 'Title'), 'String')); +payload.analysis_plotCoeffs_ylabel = stringify_text(get(get(coeffAx, 'YLabel'), 'String')); +payload.analysis_plotCoeffs_xlabel = stringify_text(get(get(coeffAx, 'XLabel'), 'String')); +payload.analysis_plotCoeffs_xticklabels = cellstr(get(coeffAx, 'XTickLabel')); +coeffLegend = legend(coeffAx); +payload.analysis_plotCoeffs_legend = {}; +if ~isempty(coeffLegend) && isgraphics(coeffLegend) + payload.analysis_plotCoeffs_legend = cellstr(coeffLegend.String); +end +close(ancestor(coeffAx, 'figure')); save(fullfile(fixtureRoot, 'analysis_exactness.mat'), '-struct', 'payload'); end +function export_analysis_binomial_fixture(fixtureRoot) +t = (0:0.1:1.0)'; +stimData = sin(2*pi*t); +stim = Covariate(t, stimData, 'Stimulus', 'time', 's', '', {'stim'}); +spikeTrain = nspikeTrain([0.1 0.3 0.7], '1', 10.0, 0.0, 1.0, 'time', 's', '', '', -1); +trial = Trial(nstColl({spikeTrain}), CovColl({stim})); +cfg = TrialConfig({{'Stimulus', 'stim'}}, 10, [], []); +cfg.setName('stim'); +fit = Analysis.RunAnalysisForNeuron(trial, 1, ConfigColl({cfg}), 0, 'BNLRCG'); + +payload = struct(); +payload.time = t; +payload.stim_data = stimData; +payload.spike_times = spikeTrain.spikeTimes; +payload.sample_rate = trial.sampleRate; +payload.coeffs = fit.getCoeffs(1); +payload.lambda_time = fit.lambda.time; +payload.lambda_data = fit.lambda.data(:,1); +payload.AIC = fit.AIC(1); +payload.BIC = fit.BIC(1); +payload.logLL = fit.logLL(1); +payload.distribution = fit.fitType{1}; +payload.ks_stat = fit.KSStats.ks_stat(1); +payload.ks_pvalue = fit.KSStats.pValue(1); +payload.ks_within_conf_int = fit.KSStats.withinConfInt(1); +payload.residual_time = fit.Residual.time; +payload.residual_data = fit.Residual.data(:,1); + +save(fullfile(fixtureRoot, 'analysis_binomial_exactness.mat'), '-struct', 'payload'); +end + +function export_analysis_validation_fixture(fixtureRoot) +t = (0:0.1:1.0)'; +stimData = sin(2*pi*t); +stim = Covariate(t, stimData, 'Stimulus', 'time', 's', '', {'stim'}); +spikeTrain = nspikeTrain([0.1 0.4 0.7], '1', 0.1, 0.0, 1.0, 'time', 's', '', '', -1); +trial = Trial(nstColl({spikeTrain}), CovColl({stim})); +trial.setTrialPartition([0.0 0.5 1.0]); +trial.setTrialTimesFor('validation'); +cfg = TrialConfig({{'Stimulus', 'stim'}}, 10, [], []); +cfg.setName('stim'); +fit = Analysis.RunAnalysisForNeuron(trial, 1, ConfigColl({cfg})); + +payload = struct(); +payload.time = t; +payload.stim_data = stimData; +payload.spike_times = spikeTrain.spikeTimes; +payload.partition = trial.getTrialPartition; +payload.validation_minTime = trial.minTime; +payload.validation_maxTime = trial.maxTime; +payload.design_matrix = trial.getDesignMatrix(1); +payload.lambda_time = fit.lambda.time; +payload.lambda_data = fit.lambda.data(:,1); +payload.coeffs = fit.getCoeffs(1); +payload.AIC = fit.AIC(1); +payload.BIC = fit.BIC(1); +payload.logLL = fit.logLL(1); +payload.ks_stat = fit.KSStats.ks_stat(1); +payload.ks_pvalue = fit.KSStats.pValue(1); +payload.ks_within_conf_int = fit.KSStats.withinConfInt(1); +payload.residual_time = fit.Residual.time; +payload.residual_data = fit.Residual.data(:,1); + +plotHandle = fit.plotResults; +plotAxes = findall(plotHandle, 'Type', 'axes'); +payload.plotResults_num_axes = numel(plotAxes); +payload.plotResults_titles = cell(1, numel(plotAxes)); +payload.plotResults_ylabels = cell(1, numel(plotAxes)); +payload.plotResults_xlabels = cell(1, numel(plotAxes)); +for idx = 1:numel(plotAxes) + ax = plotAxes(idx); + payload.plotResults_titles{idx} = stringify_text(get(get(ax, 'Title'), 'String')); + payload.plotResults_ylabels{idx} = stringify_text(get(get(ax, 'YLabel'), 'String')); + payload.plotResults_xlabels{idx} = stringify_text(get(get(ax, 'XLabel'), 'String')); +end +close(plotHandle); + +save(fullfile(fixtureRoot, 'analysis_validation_exactness.mat'), '-struct', 'payload'); +end + function export_analysis_multineuron_fixture(fixtureRoot) t = (0:0.1:1.0)'; stimData = sin(2*pi*t); @@ -575,7 +832,9 @@ function export_analysis_multineuron_fixture(fixtureRoot) trial = Trial(nstColl({spikeTrain1, spikeTrain2}), CovColl({stim})); cfg = TrialConfig({{'Stimulus', 'stim'}}, 10, [], []); cfg.setName('stim'); -fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl({cfg}), 0); +histCfg = TrialConfig({{'Stimulus', 'stim'}}, 10, [0 0.1 0.2], []); +histCfg.setName('stim_hist'); +fits = Analysis.RunAnalysisForAllNeurons(trial, ConfigColl({cfg, histCfg}), 0); summary = FitResSummary(fits); payload = struct(); @@ -586,18 +845,202 @@ function export_analysis_multineuron_fixture(fixtureRoot) payload.num_fits = numel(fits); payload.fit1_coeffs = fits{1}.getCoeffs(1); payload.fit2_coeffs = fits{2}.getCoeffs(1); +payload.fit1_hist_coeffs = fits{1}.getHistCoeffs(2); +payload.fit2_hist_coeffs = fits{2}.getHistCoeffs(2); payload.fit1_AIC = fits{1}.AIC(1); payload.fit2_AIC = fits{2}.AIC(1); +payload.fit1_hist_AIC = fits{1}.AIC(2); +payload.fit2_hist_AIC = fits{2}.AIC(2); payload.fit1_BIC = fits{1}.BIC(1); payload.fit2_BIC = fits{2}.BIC(1); +payload.fit1_hist_BIC = fits{1}.BIC(2); +payload.fit2_hist_BIC = fits{2}.BIC(2); payload.fit1_logLL = fits{1}.logLL(1); payload.fit2_logLL = fits{2}.logLL(1); +payload.fit1_hist_logLL = fits{1}.logLL(2); +payload.fit2_hist_logLL = fits{2}.logLL(2); payload.summary_AIC = summary.AIC; payload.summary_BIC = summary.BIC; payload.summary_logLL = summary.logLL; payload.summary_KSStats = summary.KSStats; payload.summary_KSPvalues = summary.KSPvalues; payload.summary_withinConfInt = summary.withinConfInt; +payload.summary_structure = summary.toStructure; + +plotHandle = []; +try + plotHandle = figure('Visible', 'off', 'Position', [100 100 1600 900]); + h1 = subplot(2,4,[1 2 5 6]); summary.plotAllCoeffs(h1); grid off; + title({'GLM Coefficients Across Neurons';'with 95% CIs (* p<0.05)'},'FontWeight','bold','FontSize',11,'FontName','Arial'); + subplot(2,4,[3 4]); boxplot(summary.KSStats, summary.fitNames, 'labelorientation', 'inline'); + ylabel('KS Statistics'); + hx = get(gca, 'XLabel'); hy = get(gca, 'YLabel'); + set([hx hy], 'FontName', 'Arial', 'FontSize', 11, 'FontWeight', 'bold'); + title('KS Statistics Across Neurons', 'FontWeight', 'bold', 'FontSize', 11, 'FontName', 'Arial'); + subplot(2,4,7); summary.getDiffAIC(1); + ylabel('\Delta AIC'); + hx = get(gca, 'XLabel'); hy = get(gca, 'YLabel'); + set([hx hy], 'FontName', 'Arial', 'FontSize', 11, 'FontWeight', 'bold'); + title('Change in AIC Across Neurons', 'FontWeight', 'bold', 'FontSize', 11, 'FontName', 'Arial'); + set(gca, 'XTickLabelRotation', 90); + subplot(2,4,8); summary.getDiffBIC(1); + ylabel('\Delta BIC'); + hx = get(gca, 'XLabel'); hy = get(gca, 'YLabel'); + set([hx hy], 'FontName', 'Arial', 'FontSize', 11, 'FontWeight', 'bold'); + title('Change in BIC Across Neurons', 'FontWeight', 'bold', 'FontSize', 11, 'FontName', 'Arial'); + set(gca, 'XTickLabelRotation', 90); + allAxes = findall(plotHandle, 'Type', 'axes'); + for idx = 1:length(allAxes) + ax = allAxes(idx); + titleStr = stringify_text(get(get(ax, 'Title'), 'String')); + ylabelStr = stringify_text(get(get(ax, 'YLabel'), 'String')); + xtickLabels = cellstr(get(ax, 'XTickLabel')); + legendHandle = legend(ax); + legendLabels = {}; + if ~isempty(legendHandle) && isgraphics(legendHandle) + legendLabels = cellstr(legendHandle.String); + end + switch titleStr + case "GLM Coefficients Across Neurons\nwith 95% CIs (* p<0.05)" + payload.summary_plotSummary_coeff_title = titleStr; + payload.summary_plotSummary_coeff_ylabel = ylabelStr; + payload.summary_plotSummary_coeff_xticklabels = xtickLabels; + payload.summary_plotSummary_coeff_legend = legendLabels; + case "KS Statistics Across Neurons" + payload.summary_plotSummary_ks_title = titleStr; + payload.summary_plotSummary_ks_ylabel = ylabelStr; + payload.summary_plotSummary_ks_xticklabels = xtickLabels; + case "Change in AIC Across Neurons" + payload.summary_plotSummary_aic_title = titleStr; + payload.summary_plotSummary_aic_ylabel = ylabelStr; + payload.summary_plotSummary_aic_xticklabels = xtickLabels; + case "Change in BIC Across Neurons" + payload.summary_plotSummary_bic_title = titleStr; + payload.summary_plotSummary_bic_ylabel = ylabelStr; + payload.summary_plotSummary_bic_xticklabels = xtickLabels; + end + end + payload.summary_plotSummary_num_axes = numel(allAxes); +catch +end +if ~isempty(plotHandle) && isgraphics(plotHandle) + close(plotHandle); +end + +plotAllCoeffsHandle = []; +try + plotAllCoeffsHandle = figure('Visible','off'); + summary.plotAllCoeffs(); + plotAllCoeffsAx = gca; + payload.summary_plotAllCoeffs_ylabel = stringify_text(get(get(plotAllCoeffsAx, 'YLabel'), 'String')); + payload.summary_plotAllCoeffs_xticklabels = cellstr(get(plotAllCoeffsAx, 'XTickLabel')); + plotAllCoeffsLegend = legend(plotAllCoeffsAx); + payload.summary_plotAllCoeffs_legend = {}; + if ~isempty(plotAllCoeffsLegend) && isgraphics(plotAllCoeffsLegend) + payload.summary_plotAllCoeffs_legend = cellstr(plotAllCoeffsLegend.String); + end +catch +end +if ~isempty(plotAllCoeffsHandle) && isgraphics(plotAllCoeffsHandle) + close(plotAllCoeffsHandle); +end + +coeffOnlyHandle = []; +try + coeffOnlyHandle = figure('Visible','off'); + coeffOnlyAx = local_axes_handle(summary.plotCoeffsWithoutHistory(2, 0, 1)); + payload.summary_plotCoeffsWithoutHistory_title = stringify_text(get(get(coeffOnlyAx, 'Title'), 'String')); + payload.summary_plotCoeffsWithoutHistory_ylabel = stringify_text(get(get(coeffOnlyAx, 'YLabel'), 'String')); + payload.summary_plotCoeffsWithoutHistory_xticklabels = cellstr(get(coeffOnlyAx, 'XTickLabel')); +catch +end +if ~isempty(coeffOnlyHandle) && isgraphics(coeffOnlyHandle) + close(coeffOnlyHandle); +end + +histHandle = []; +try + histHandle = figure('Visible','off'); + histAx = local_axes_handle(summary.plotHistCoeffs(2, 0, 1)); + payload.summary_plotHistCoeffs_title = stringify_text(get(get(histAx, 'Title'), 'String')); + payload.summary_plotHistCoeffs_ylabel = stringify_text(get(get(histAx, 'YLabel'), 'String')); + payload.summary_plotHistCoeffs_xticklabels = cellstr(get(histAx, 'XTickLabel')); +catch +end +if ~isempty(histHandle) && isgraphics(histHandle) + close(histHandle); +end + +summary.plotIC; +icHandle = gcf; +icAxes = findall(icHandle, 'Type', 'axes'); +payload.summary_plotIC_num_axes = numel(icAxes); +for idx = 1:length(icAxes) + ax = icAxes(idx); + titleStr = stringify_text(get(get(ax, 'Title'), 'String')); + ylabelStr = stringify_text(get(get(ax, 'YLabel'), 'String')); + xtickLabels = cellstr(get(ax, 'XTickLabel')); + switch titleStr + case "AIC Across Neurons" + payload.summary_plotIC_aic_title = titleStr; + payload.summary_plotIC_aic_ylabel = ylabelStr; + payload.summary_plotIC_aic_xticklabels = xtickLabels; + case "BIC Across Neurons" + payload.summary_plotIC_bic_title = titleStr; + payload.summary_plotIC_bic_ylabel = ylabelStr; + payload.summary_plotIC_bic_xticklabels = xtickLabels; + case "log likelihood Across Neurons" + payload.summary_plotIC_logll_title = titleStr; + payload.summary_plotIC_logll_ylabel = ylabelStr; + payload.summary_plotIC_logll_xticklabels = xtickLabels; + end +end +close(icHandle); + +plotAICHandle = figure('Visible','off'); +summary.plotAIC(); +plotAICAx = gca; +payload.summary_plotAIC_title = stringify_text(get(get(plotAICAx, 'Title'), 'String')); +payload.summary_plotAIC_ylabel = stringify_text(get(get(plotAICAx, 'YLabel'), 'String')); +payload.summary_plotAIC_xticklabels = cellstr(get(plotAICAx, 'XTickLabel')); +close(plotAICHandle); + +plotBICHandle = figure('Visible','off'); +summary.plotBIC(); +plotBICAx = gca; +payload.summary_plotBIC_title = stringify_text(get(get(plotBICAx, 'Title'), 'String')); +payload.summary_plotBIC_ylabel = stringify_text(get(get(plotBICAx, 'YLabel'), 'String')); +payload.summary_plotBIC_xticklabels = cellstr(get(plotBICAx, 'XTickLabel')); +close(plotBICHandle); + +plotlogLLHandle = figure('Visible','off'); +summary.plotlogLL(); +plotlogLLAx = gca; +payload.summary_plotlogLL_title = stringify_text(get(get(plotlogLLAx, 'Title'), 'String')); +payload.summary_plotlogLL_ylabel = stringify_text(get(get(plotlogLLAx, 'YLabel'), 'String')); +payload.summary_plotlogLL_xticklabels = cellstr(get(plotlogLLAx, 'XTickLabel')); +close(plotlogLLHandle); + +residualHandle = summary.plotResidualSummary; +residualAxes = findall(residualHandle, 'Type', 'axes'); +payload.summary_plotResidual_num_axes = numel(residualAxes); +payload.summary_plotResidual_titles = cell(1, numel(residualAxes)); +payload.summary_plotResidual_ylabels = cell(1, numel(residualAxes)); +payload.summary_plotResidual_xlabels = cell(1, numel(residualAxes)); +payload.summary_plotResidual_line_counts = zeros(1, numel(residualAxes)); +payload.summary_plotResidual_legend_labels = {}; +for idx = 1:length(residualAxes) + ax = residualAxes(idx); + payload.summary_plotResidual_titles{idx} = stringify_text(get(get(ax, 'Title'), 'String')); + payload.summary_plotResidual_ylabels{idx} = stringify_text(get(get(ax, 'YLabel'), 'String')); + payload.summary_plotResidual_xlabels{idx} = stringify_text(get(get(ax, 'XLabel'), 'String')); + payload.summary_plotResidual_line_counts(idx) = numel(findall(ax, 'Type', 'line')); +end +legendHandle = findobj(residualHandle, 'Type', 'legend'); +if ~isempty(legendHandle) + payload.summary_plotResidual_legend_labels = cellstr(legendHandle(1).String); +end +close(residualHandle); save(fullfile(fixtureRoot, 'analysis_multineuron_exactness.mat'), '-struct', 'payload'); end @@ -664,12 +1107,25 @@ function export_fit_summary_fixture(fixtureRoot) stats2 = {struct('se', [0.06], 'p', [0.03]), struct('se', [0.05 0.04 0.03], 'p', [0.01 0.03 0.07])}; fit1 = FitResult(st1, covLabels, numHist, histObjects, ensHistObj, lambda, b1, [1.0 2.0], stats1, [11.0 7.0], [12.0 8.0], [3.0 5.0], configColl, {}, {}, 'poisson'); fit2 = FitResult(st2, covLabels, numHist, histObjects, ensHistObj, lambda, b2, [1.5 2.5], stats2, [13.0 9.0], [14.0 10.0], [2.0 4.0], configColl, {}, {}, 'poisson'); +fixtureZ = [0.2 0.25; 0.4 0.35; 0.6 0.45]; +fixtureU = [0.15 0.20; 0.45 0.50; 0.75 0.80]; +fixtureXAxis = [0.25 0.25; 0.50 0.50; 0.75 0.75]; +fixtureKSSorted = [0.20 0.20; 0.50 0.50; 0.80 0.80]; +fixtureX = [-1.04 -0.84; -0.13 0.00; 0.67 0.84]; +rhoSig = SignalObj((1:3)', [0.1 0.2; 0.05 0.1; 0.0 0.05], 'rhoSig', 'lag', '', '', {'stim','stim_hist'}); +confBoundSig = SignalObj((1:3)', [0.2; 0.1; 0.05], 'confBoundSig', 'lag', '', '', {''}); +fit1.setKSStats(fixtureZ, fixtureU, fixtureXAxis, fixtureKSSorted, [0.25 0.50]); +fit2.setKSStats(fixtureZ, fixtureU, fixtureXAxis, fixtureKSSorted, [0.35 0.55]); +fit1.setInvGausStats(fixtureX, rhoSig, confBoundSig); +fit2.setInvGausStats(fixtureX, rhoSig, confBoundSig); fit1.KSStats.ks_stat = [0.25 0.50]; fit1.KSStats.pValue = [0.90 0.40]; fit1.KSStats.withinConfInt = [1 1]; fit2.KSStats.ks_stat = [0.35 0.55]; fit2.KSStats.pValue = [0.80 0.30]; fit2.KSStats.withinConfInt = [1 0]; +Analysis.plotFitResidual(fit1, 0.01, 0); +Analysis.plotFitResidual(fit2, 0.01, 0); summary = FitResSummary({fit1, fit2}); dAIC = summary.getDiffAIC(1, 0); dBIC = summary.getDiffBIC(1, 0); @@ -687,6 +1143,330 @@ function export_fit_summary_fixture(fixtureRoot) payload.diffAIC = dAIC; payload.diffBIC = dBIC; payload.difflogLL = dlogLL; +payload.structure = summary.toStructure; +plotHandle = summary.plotSummary; +allAxes = findall(plotHandle, 'Type', 'axes'); +for idx = 1:length(allAxes) + ax = allAxes(idx); + titleStr = stringify_text(get(get(ax, 'Title'), 'String')); + ylabelStr = stringify_text(get(get(ax, 'YLabel'), 'String')); + xtickLabels = cellstr(get(ax, 'XTickLabel')); + legendHandle = legend(ax); + legendLabels = {}; + if ~isempty(legendHandle) && isgraphics(legendHandle) + legendLabels = cellstr(legendHandle.String); + end + switch titleStr + case "GLM Coefficients Across Neurons\nwith 95% CIs (* p<0.05)" + payload.plotSummary_coeff_title = titleStr; + payload.plotSummary_coeff_ylabel = ylabelStr; + payload.plotSummary_coeff_xticklabels = xtickLabels; + payload.plotSummary_coeff_legend = legendLabels; + case "KS Statistics Across Neurons" + payload.plotSummary_ks_title = titleStr; + payload.plotSummary_ks_ylabel = ylabelStr; + payload.plotSummary_ks_xticklabels = xtickLabels; + case "Change in AIC Across Neurons" + payload.plotSummary_aic_title = titleStr; + payload.plotSummary_aic_ylabel = ylabelStr; + payload.plotSummary_aic_xticklabels = xtickLabels; + case "Change in BIC Across Neurons" + payload.plotSummary_bic_title = titleStr; + payload.plotSummary_bic_ylabel = ylabelStr; + payload.plotSummary_bic_xticklabels = xtickLabels; + end +end +payload.plotSummary_num_axes = numel(allAxes); +close(plotHandle); + +plotAllCoeffsHandle = figure('Visible','off'); +summary.plotAllCoeffs(); +plotAllCoeffsAx = gca; +payload.plotAllCoeffs_ylabel = stringify_text(get(get(plotAllCoeffsAx, 'YLabel'), 'String')); +payload.plotAllCoeffs_xticklabels = cellstr(get(plotAllCoeffsAx, 'XTickLabel')); +plotAllCoeffsLegend = legend(plotAllCoeffsAx); +payload.plotAllCoeffs_legend = {}; +if ~isempty(plotAllCoeffsLegend) && isgraphics(plotAllCoeffsLegend) + payload.plotAllCoeffs_legend = cellstr(plotAllCoeffsLegend.String); +end +close(plotAllCoeffsHandle); + +coeffOnlyHandle = figure('Visible','off'); +coeffOnlyAx = local_axes_handle(summary.plotCoeffsWithoutHistory(2, 0, 1)); +payload.plotCoeffsWithoutHistory_title = stringify_text(get(get(coeffOnlyAx, 'Title'), 'String')); +payload.plotCoeffsWithoutHistory_ylabel = stringify_text(get(get(coeffOnlyAx, 'YLabel'), 'String')); +payload.plotCoeffsWithoutHistory_xticklabels = cellstr(get(coeffOnlyAx, 'XTickLabel')); +close(coeffOnlyHandle); + +histHandle = figure('Visible','off'); +histAx = local_axes_handle(summary.plotHistCoeffs(2, 0, 1)); +payload.plotHistCoeffs_title = stringify_text(get(get(histAx, 'Title'), 'String')); +payload.plotHistCoeffs_ylabel = stringify_text(get(get(histAx, 'YLabel'), 'String')); +payload.plotHistCoeffs_xticklabels = cellstr(get(histAx, 'XTickLabel')); +close(histHandle); + +fitPlotHandle = fit1.getSubsetFitResult(1).plotResults; +fitPlotAxes = findall(fitPlotHandle, 'Type', 'axes'); +payload.fit_plotResults_num_axes = numel(fitPlotAxes); +payload.fit_plotResults_titles = cell(1, numel(fitPlotAxes)); +payload.fit_plotResults_ylabels = cell(1, numel(fitPlotAxes)); +payload.fit_plotResults_xlabels = cell(1, numel(fitPlotAxes)); +for idx = 1:numel(fitPlotAxes) + ax = fitPlotAxes(idx); + payload.fit_plotResults_titles{idx} = stringify_text(get(get(ax, 'Title'), 'String')); + payload.fit_plotResults_ylabels{idx} = stringify_text(get(get(ax, 'YLabel'), 'String')); + payload.fit_plotResults_xlabels{idx} = stringify_text(get(get(ax, 'XLabel'), 'String')); +end +close(fitPlotHandle); + +singleFit = fit1.getSubsetFitResult(1); + +ksHandle = figure('Visible','off'); +singleFit.KSPlot; +ksAx = gca; +payload.fit_KSPlot_title = stringify_text(get(get(ksAx, 'Title'), 'String')); +payload.fit_KSPlot_ylabel = stringify_text(get(get(ksAx, 'YLabel'), 'String')); +payload.fit_KSPlot_xlabel = stringify_text(get(get(ksAx, 'XLabel'), 'String')); +payload.fit_KSPlot_num_lines = numel(findall(ksAx, 'Type', 'line')); +close(ksHandle); + +invHandle = figure('Visible','off'); +singleFit.plotInvGausTrans; +invAx = gca; +payload.fit_plotInvGausTrans_title = stringify_text(get(get(invAx, 'Title'), 'String')); +payload.fit_plotInvGausTrans_ylabel = stringify_text(get(get(invAx, 'YLabel'), 'String')); +payload.fit_plotInvGausTrans_xlabel = stringify_text(get(get(invAx, 'XLabel'), 'String')); +payload.fit_plotInvGausTrans_num_lines = numel(findall(invAx, 'Type', 'line')); +close(invHandle); + +seqHandle = figure('Visible','off'); +singleFit.plotSeqCorr; +seqAx = gca; +payload.fit_plotSeqCorr_title = stringify_text(get(get(seqAx, 'Title'), 'String')); +payload.fit_plotSeqCorr_ylabel = stringify_text(get(get(seqAx, 'YLabel'), 'String')); +payload.fit_plotSeqCorr_xlabel = stringify_text(get(get(seqAx, 'XLabel'), 'String')); +payload.fit_plotSeqCorr_num_lines = numel(findall(seqAx, 'Type', 'line')); +close(seqHandle); + +resHandle = figure('Visible','off'); +singleFit.plotResidual; +resAx = gca; +payload.fit_plotResidual_title = stringify_text(get(get(resAx, 'Title'), 'String')); +payload.fit_plotResidual_ylabel = stringify_text(get(get(resAx, 'YLabel'), 'String')); +payload.fit_plotResidual_xlabel = stringify_text(get(get(resAx, 'XLabel'), 'String')); +payload.fit_plotResidual_num_lines = numel(findall(resAx, 'Type', 'line')); +close(resHandle); + +coeffHandle = figure('Visible','off'); +singleFit.plotCoeffs; +coeffAx = gca; +payload.fit_plotCoeffs_title = stringify_text(get(get(coeffAx, 'Title'), 'String')); +payload.fit_plotCoeffs_ylabel = stringify_text(get(get(coeffAx, 'YLabel'), 'String')); +payload.fit_plotCoeffs_xlabel = stringify_text(get(get(coeffAx, 'XLabel'), 'String')); +payload.fit_plotCoeffs_xticklabels = cellstr(get(coeffAx, 'XTickLabel')); +payload.fit_plotCoeffs_num_lines = numel(findall(coeffAx, 'Type', 'line')); +close(coeffHandle); + +historyFit = fit1.getSubsetFitResult(2); + +coeffNoHistHandle = figure('Visible','off'); +historyFit.plotCoeffsWithoutHistory; +coeffNoHistAx = gca; +payload.fit_plotCoeffsWithoutHistory_title = stringify_text(get(get(coeffNoHistAx, 'Title'), 'String')); +payload.fit_plotCoeffsWithoutHistory_ylabel = stringify_text(get(get(coeffNoHistAx, 'YLabel'), 'String')); +payload.fit_plotCoeffsWithoutHistory_xlabel = stringify_text(get(get(coeffNoHistAx, 'XLabel'), 'String')); +payload.fit_plotCoeffsWithoutHistory_xticklabels = cellstr(get(coeffNoHistAx, 'XTickLabel')); +payload.fit_plotCoeffsWithoutHistory_num_lines = numel(findall(coeffNoHistAx, 'Type', 'line')); +close(coeffNoHistHandle); + +histCoeffHandle = figure('Visible','off'); +historyFit.plotHistCoeffs; +histCoeffAx = gca; +payload.fit_plotHistCoeffs_title = stringify_text(get(get(histCoeffAx, 'Title'), 'String')); +payload.fit_plotHistCoeffs_ylabel = stringify_text(get(get(histCoeffAx, 'YLabel'), 'String')); +payload.fit_plotHistCoeffs_xlabel = stringify_text(get(get(histCoeffAx, 'XLabel'), 'String')); +payload.fit_plotHistCoeffs_xticklabels = cellstr(get(histCoeffAx, 'XTickLabel')); +payload.fit_plotHistCoeffs_num_lines = numel(findall(histCoeffAx, 'Type', 'line')); +close(histCoeffHandle); + +payload.fit_structure = fit1.toStructure; +payload.fit_history_structure = fit1.getSubsetFitResult(2).toStructure; + +[fitCoeffIndex1, fitCoeffEpochId1, fitCoeffNumEpochs1] = fit1.getCoeffIndex(1); +[fitHistIndex1, fitHistEpochId1, fitHistNumEpochs1] = fit1.getHistIndex(1); +[fitCoeffIndex2, fitCoeffEpochId2, fitCoeffNumEpochs2] = fit1.getCoeffIndex(2); +[fitHistIndex2, fitHistEpochId2, fitHistNumEpochs2] = fit1.getHistIndex(2); +payload.fitCoeffIndex_1 = fitCoeffIndex1; +payload.fitCoeffEpochId_1 = fitCoeffEpochId1; +payload.fitCoeffNumEpochs_1 = fitCoeffNumEpochs1; +payload.fitHistIndex_1 = fitHistIndex1; +payload.fitHistEpochId_1 = fitHistEpochId1; +payload.fitHistNumEpochs_1 = fitHistNumEpochs1; +payload.fitCoeffIndex_2 = fitCoeffIndex2; +payload.fitCoeffEpochId_2 = fitCoeffEpochId2; +payload.fitCoeffNumEpochs_2 = fitCoeffNumEpochs2; +payload.fitHistIndex_2 = fitHistIndex2; +payload.fitHistEpochId_2 = fitHistEpochId2; +payload.fitHistNumEpochs_2 = fitHistNumEpochs2; +[fitParamCoeff1, fitParamSe1, fitParamSig1] = fit1.getParam({'stim'}, 1); +[fitParamCoeff2, fitParamSe2, fitParamSig2] = fit1.getParam({'stim'}, 2); +payload.fitParamCoeff_1 = fitParamCoeff1; +payload.fitParamSe_1 = fitParamSe1; +payload.fitParamSig_1 = fitParamSig1; +payload.fitParamCoeff_2 = fitParamCoeff2; +payload.fitParamSe_2 = fitParamSe2; +payload.fitParamSig_2 = fitParamSig2; + +[summaryPlotParams] = summary.plotParams; +payload.plotParams_xLabels = cellstr(summaryPlotParams.xLabels); +payload.plotParams_bAct = summaryPlotParams.bAct; +payload.plotParams_seAct = summaryPlotParams.seAct; +payload.plotParams_sigIndex = summaryPlotParams.sigIndex; +payload.plotParams_numResultsCoeffPresent = summaryPlotParams.numResultsCoeffPresent; +payload.sigCoeffs_fit1 = summary.getSigCoeffs(1); +[coeffMatFit1, coeffLabelsFit1, coeffSeFit1] = summary.getCoeffs(1); +payload.coeffMat_fit1 = coeffMatFit1; +payload.coeffLabels_fit1 = coeffLabelsFit1; +payload.coeffSe_fit1 = coeffSeFit1; +[coeffMatFit2, coeffLabelsFit2, coeffSeFit2] = summary.getCoeffs(2); +payload.coeffMat_fit2 = coeffMatFit2; +payload.coeffLabels_fit2 = coeffLabelsFit2; +payload.coeffSe_fit2 = coeffSeFit2; +[histCoeffMatFit2, histCoeffLabelsFit2] = summary.getHistCoeffs(2); +payload.histCoeffMat_fit2 = histCoeffMatFit2; +payload.histCoeffLabels_fit2 = histCoeffLabelsFit2; + +[coeffIndex, coeffEpochId, coeffNumEpochs] = summary.getCoeffIndex; +[histIndex, histEpochId, histNumEpochs] = summary.getHistIndex; +payload.coeffIndex = coeffIndex; +payload.coeffEpochId = coeffEpochId; +payload.coeffNumEpochs = coeffNumEpochs; +payload.histIndex = histIndex; +payload.histEpochId = histEpochId; +payload.histNumEpochs = histNumEpochs; +[coeffIndexFit2, coeffEpochIdFit2, coeffNumEpochsFit2] = summary.getCoeffIndex(2); +[histIndexFit2, histEpochIdFit2, histNumEpochsFit2] = summary.getHistIndex(2); +payload.coeffIndex_fit2 = coeffIndexFit2; +payload.coeffEpochId_fit2 = coeffEpochIdFit2; +payload.coeffNumEpochs_fit2 = coeffNumEpochsFit2; +payload.histIndex_fit2 = histIndexFit2; +payload.histEpochId_fit2 = histEpochIdFit2; +payload.histNumEpochs_fit2 = histNumEpochsFit2; + +[coeffSummaryN, coeffSummaryEdges, coeffSummaryPercentSig] = summary.binCoeffs; +payload.coeffSummary_bins = coeffSummaryN; +payload.coeffSummary_edges = coeffSummaryEdges; +payload.coeffSummary_percentSig = coeffSummaryPercentSig; + +coeff2dHandle = figure('Visible','off'); +coeff2dPlotHandles = summary.plot2dCoeffSummary(gca); +coeff2dAx = gca; +if isempty(summary.plotParams) + summary.computePlotParams; +end +payload.plot2dCoeffSummary_yticklabels = cellstr(summary.plotParams.xLabels); +payload.plot2dCoeffSummary_num_lines = numel(coeff2dPlotHandles); +textHandles = findall(coeff2dAx, 'Type', 'text'); +payload.plot2dCoeffSummary_text = cell(1, numel(textHandles)); +for idx = 1:numel(textHandles) + payload.plot2dCoeffSummary_text{idx} = stringify_text(get(textHandles(idx), 'String')); +end +close(coeff2dHandle); + +coeff3dHandle = figure('Visible','off'); +coeff3dAx = axes('Parent', coeff3dHandle); +coeff3dPlotHandles = summary.plot3dCoeffSummary(coeff3dAx); +if isempty(summary.plotParams) + summary.computePlotParams; +end +payload.plot3dCoeffSummary_yticklabels = cellstr(summary.plotParams.xLabels); +payload.plot3dCoeffSummary_num_surfaces = numel(coeff3dPlotHandles); +close(coeff3dHandle); + +summary.plotIC; +icHandle = gcf; +icAxes = findall(icHandle, 'Type', 'axes'); +payload.plotIC_num_axes = numel(icAxes); +for idx = 1:length(icAxes) + ax = icAxes(idx); + titleStr = stringify_text(get(get(ax, 'Title'), 'String')); + ylabelStr = stringify_text(get(get(ax, 'YLabel'), 'String')); + xtickLabels = cellstr(get(ax, 'XTickLabel')); + switch titleStr + case "AIC Across Neurons" + payload.plotIC_aic_title = titleStr; + payload.plotIC_aic_ylabel = ylabelStr; + payload.plotIC_aic_xticklabels = xtickLabels; + case "BIC Across Neurons" + payload.plotIC_bic_title = titleStr; + payload.plotIC_bic_ylabel = ylabelStr; + payload.plotIC_bic_xticklabels = xtickLabels; + case "log likelihood Across Neurons" + payload.plotIC_logll_title = titleStr; + payload.plotIC_logll_ylabel = ylabelStr; + payload.plotIC_logll_xticklabels = xtickLabels; + end +end +close(icHandle); + +plotAICHandle = figure('Visible','off'); +summary.plotAIC; +plotAICAx = gca; +payload.plotAIC_title = stringify_text(get(get(plotAICAx, 'Title'), 'String')); +payload.plotAIC_ylabel = stringify_text(get(get(plotAICAx, 'YLabel'), 'String')); +payload.plotAIC_xticklabels = cellstr(get(plotAICAx, 'XTickLabel')); +close(plotAICHandle); + +plotBICHandle = figure('Visible','off'); +summary.plotBIC; +plotBICAx = gca; +payload.plotBIC_title = stringify_text(get(get(plotBICAx, 'Title'), 'String')); +payload.plotBIC_ylabel = stringify_text(get(get(plotBICAx, 'YLabel'), 'String')); +payload.plotBIC_xticklabels = cellstr(get(plotBICAx, 'XTickLabel')); +close(plotBICHandle); + +plotlogLLHandle = figure('Visible','off'); +summary.plotlogLL; +plotlogLLAx = gca; +payload.plotlogLL_title = stringify_text(get(get(plotlogLLAx, 'Title'), 'String')); +payload.plotlogLL_ylabel = stringify_text(get(get(plotlogLLAx, 'YLabel'), 'String')); +payload.plotlogLL_xticklabels = cellstr(get(plotlogLLAx, 'XTickLabel')); +close(plotlogLLHandle); + +residualHandle = summary.plotResidualSummary; +residualAxes = findall(residualHandle, 'Type', 'axes'); +payload.plotResidualSummary_num_axes = numel(residualAxes); +payload.plotResidualSummary_titles = cell(1, numel(residualAxes)); +payload.plotResidualSummary_ylabels = cell(1, numel(residualAxes)); +payload.plotResidualSummary_xlabels = cell(1, numel(residualAxes)); +payload.plotResidualSummary_line_counts = zeros(1, numel(residualAxes)); +payload.plotResidualSummary_legend_labels = {}; +for idx = 1:length(residualAxes) + ax = residualAxes(idx); + payload.plotResidualSummary_titles{idx} = stringify_text(get(get(ax, 'Title'), 'String')); + payload.plotResidualSummary_ylabels{idx} = stringify_text(get(get(ax, 'YLabel'), 'String')); + payload.plotResidualSummary_xlabels{idx} = stringify_text(get(get(ax, 'XLabel'), 'String')); + payload.plotResidualSummary_line_counts(idx) = numel(findall(ax, 'Type', 'line')); +end +legendHandle = findobj(residualHandle, 'Type', 'legend'); +if ~isempty(legendHandle) + payload.plotResidualSummary_legend_labels = cellstr(legendHandle(1).String); +end +close(residualHandle); + +payload.roundtrip_supported = false; +payload.roundtrip_error = ''; +try + roundtrip = FitResSummary.fromStructure(payload.structure); + payload.roundtrip_supported = true; + payload.roundtrip_AIC = roundtrip.AIC; + payload.roundtrip_BIC = roundtrip.BIC; + payload.roundtrip_logLL = roundtrip.logLL; + payload.roundtrip_neuronNumbers = roundtrip.neuronNumbers; + payload.roundtrip_fitNames = roundtrip.fitNames; +catch err + payload.roundtrip_error = err.message; +end save(fullfile(fixtureRoot, 'fit_summary_exactness.mat'), '-struct', 'payload'); end @@ -1114,7 +1894,8 @@ function export_simulated_network_fixture(fixtureRoot) function cifObj = build_polynomial_binomial_cif(beta) beta = beta(:)'; -syms x y real +x = sym('x', 'real'); +y = sym('y', 'real'); cifObj = CIF(beta(1:3), {'1', 'x', 'y'}, {'x', 'y'}, 'binomial'); cifObj.b = beta; cifObj.varIn = [sym(1); x; y; x^2; y^2; x * y]; @@ -1162,3 +1943,35 @@ function export_simulated_network_fixture(fixtureRoot) end cifObj.argstrLDGamma = ''; end + +function out = stringify_text(value) +if isstring(value) + out = char(strjoin(cellstr(value), newline)); +elseif ischar(value) + out = value; +elseif iscell(value) + parts = cellfun(@stringify_text, value, 'UniformOutput', false); + out = strjoin(parts, newline); +else + out = ''; +end +end + +function ax = local_axes_handle(handleObj) +if iscell(handleObj) + handleObj = [handleObj{:}]; +end +if isa(handleObj, 'matlab.graphics.axis.Axes') + ax = handleObj; + if numel(ax) > 1 + ax = ax(1); + end + return; +end +ax = ancestor(handleObj, 'axes'); +if isempty(ax) + ax = gca; +elseif numel(ax) > 1 + ax = ax(1); +end +end