diff --git a/README.md b/README.md index ba1ddbd7..ab29d068 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,9 @@ One of nSTAT's key strengths is point-process generalized linear models for spik train signals that provide a formal statistical framework for processing signals recorded from ensembles of single neurons. It also has extensive support for model fitting, model-order analysis, and adaptive decoding — including state-space GLM -(SSGLM) estimation via EM, unscented Kalman filtering (UKF), and hybrid -discrete/continuous point-process filters. +(SSGLM) estimation via EM, unscented Kalman filtering (UKF), goal-directed +point-process adaptive filters (PPAF), and hybrid discrete/continuous +point-process filters (PPHF). Although created with neural signal processing in mind, nSTAT can be used as a generic tool for analyzing any types of discrete and continuous signals, and thus diff --git a/docs/ClassDefinitions.md b/docs/ClassDefinitions.md index 7da9db12..14ab03ff 100644 --- a/docs/ClassDefinitions.md +++ b/docs/ClassDefinitions.md @@ -102,12 +102,21 @@ Primary notebook: `notebooks/nstCollExamples.ipynb` **Time operations**: `shiftTime`, `setMinTime`, `setMaxTime` +**Data export**: +`dataToMatrix`, `resample` + +**PSTH**: +`psth`, `psthGLM`, `estimateVarianceAcrossTrials`, `psthBars` + **SSGLM (state-space GLM)**: `ssglm`, `ssglmFB` **Basis generation**: `generateUnitImpulseBasis` +**Plotting**: +`plot` + ### `History` (`nstat.History`) Primary notebook: `notebooks/HistoryExamples.ipynb` @@ -149,7 +158,8 @@ Primary notebook: `notebooks/TrialExamples.ipynb` `flattenMask` **Utilities**: -`shiftCovariates`, `resampleEnsColl`, `restoreToOriginal`, `plot` +`shiftCovariates`, `resampleEnsColl`, `restoreToOriginal`, `getAllLabels`, +`plot` ### `TrialConfig` (`nstat.TrialConfig`) @@ -230,7 +240,8 @@ Alias of `FitSummary`. Aggregates multiple `FitResult` objects. 
**Plotting**: `plotIC`, `plotAIC`, `plotBIC`, `plotlogLL`, `plotResidualSummary`, -`plotSummary`, `boxPlot` +`plotSummary`, `boxPlot`, `plotAllCoeffs`, `plot3dCoeffSummary`, +`plot2dCoeffSummary`, `plotKSSummary` ### `DecodingAlgorithms` (`nstat.DecodingAlgorithms`) @@ -238,11 +249,14 @@ Primary notebook: `notebooks/DecodingExample.ipynb` **Point-process decode filters**: `PPDecodeFilterLinear`, `PPDecodeFilter`, `PPHybridFilterLinear`, -`ComputeStimulusCIs` +`ComputeStimulusCIs`, `computeSpikeRateCIs` **Kalman and unscented Kalman filters**: `kalman_filter`, `PP_fixedIntervalSmoother`, `ukf` +**Helper methods**: +`PPDecode_predict`, `PPDecode_updateLinear` + **State-space GLM (SSGLM) — EM algorithm**: `PPSS_EStep`, `PPSS_MStep`, `PPSS_EM`, `PPSS_EMFB` diff --git a/docs/PaperOverview.md b/docs/PaperOverview.md index 92423dc9..9ecde011 100644 --- a/docs/PaperOverview.md +++ b/docs/PaperOverview.md @@ -88,10 +88,16 @@ map to: **Point-process adaptive filters (Section 2.5)**: - `DecodingAlgorithms.PPDecodeFilterLinear` — linear-CIF point-process - adaptive filter for continuous stimulus decoding. + adaptive filter for continuous stimulus decoding. Supports goal-directed + decoding via backward information filter (Srinivasan et al. 2006) when + `yT` and `PiT` target parameters are provided. - `DecodingAlgorithms.PPDecodeFilter` — general CIF version using symbolic gradients/Jacobians. -- `DecodingAlgorithms.ComputeStimulusCIs` — stimulus confidence intervals. +- `DecodingAlgorithms.ComputeStimulusCIs` — stimulus confidence intervals + via Monte Carlo sampling (dual-path: 4-D SSGLM cross-trial + 3-D smoother + z-score). +- `DecodingAlgorithms.computeSpikeRateCIs` — spike rate confidence intervals + and pairwise significance testing across trials. - `DecodingAlgorithms.PP_fixedIntervalSmoother` — fixed-interval smoother for off-line smoothing of decode estimates. 
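Reviewer note on the Section 2.5 entries above: in one dimension, the predict/correct cycle that a linear-CIF point-process adaptive filter performs reduces to a Gaussian approximation of the posterior with a log-linear intensity. A minimal scalar sketch in plain NumPy; `pp_filter_step` and its argument names are illustrative only, not the `PPDecodeFilterLinear` signature:

```python
import numpy as np

def pp_filter_step(x_pred, w_pred, beta, mu, dN, dt):
    """One update of a scalar point-process adaptive filter.

    Gaussian approximation with log-linear CIF lam = exp(mu + beta*x),
    the form used by linear-CIF decoders. Illustrative helper only.
    """
    lam = np.exp(mu + beta * x_pred)                    # conditional intensity
    w_post = 1.0 / (1.0 / w_pred + beta**2 * lam * dt)  # information-form variance update
    x_post = x_pred + w_post * beta * (dN - lam * dt)   # innovation on the spike count
    return x_post, w_post

# A spike (dN=1) arriving when the predicted rate is low pulls the state up
# and shrinks the posterior variance:
x1, w1 = pp_filter_step(x_pred=0.0, w_pred=1.0, beta=1.0,
                        mu=np.log(5.0), dN=1, dt=0.001)
```

The goal-directed (`yT`, `PiT`) variants combine this forward pass with a backward information filter anchored at the target, as the bullet above describes.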
@@ -100,6 +106,8 @@ map to: - `DecodingAlgorithms.PPHybridFilterLinear` — joint discrete/continuous state estimation combining point-process observations with a hidden Markov model over discrete states and Kalman filtering over continuous kinematics. + Supports goal-directed decoding via per-model backward information filters + when `yT` and `PiT` are provided. **Kalman and UKF filters**: diff --git a/examples/README.md b/examples/README.md index cfffe48b..648eb9ce 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,17 +1,44 @@ -# Python nSTAT Examples +# nSTAT Python Examples -## Basic examples +## Paper Examples (Self-Contained) + +Five self-contained scripts mirroring the MATLAB paper examples. Each +generates publication-quality figures and supports `--export-figures`. ```bash -python3 examples/basic_data_workflow.py -python3 examples/fit_poisson_glm.py -python3 examples/simulate_population_psth.py +python examples/paper/example01_mepsc_poisson.py --export-figures +python examples/paper/example02_whisker_stimulus_thalamus.py --export-figures +python examples/paper/example03_psth_and_ssglm.py --export-figures +python examples/paper/example04_place_cells_continuous_stimulus.py --export-figures +python examples/paper/example05_decoding_ppaf_pphf.py --export-figures ``` -## Paper-style example workflow +| Example | Focus | Paper Section | +|---|---|---| +| 01 | mEPSC Poisson models (constant vs piecewise baseline) | 2.3.1 | +| 02 | Whisker stimulus GLM with lag and history selection | 2.3.2 | +| 03 | PSTH and SSGLM across-trial dynamics | 2.3.3-2.3.4 | +| 04 | Place-cell receptive fields (Gaussian vs Zernike) | 2.3.5 | +| 05 | PPAF and hybrid filter decoding | 2.5-2.6 | + +## Basic Examples ```bash -python3 examples/nstat_paper_examples.py --repo-root .. 
+python examples/basic_data_workflow.py +python examples/fit_poisson_glm.py +python examples/simulate_population_psth.py ``` -This mirrors key analyses described in the nSTAT paper using the bundled Python APIs. +## README Examples (Quick Checks) + +```bash +python examples/readme_examples/example1_multitaper_and_spectrogram.py +python examples/readme_examples/example2_simulate_cif_spiketrain_10s.py +python examples/readme_examples/example3_nstcoll_raster_from_example2.py +``` + +## Jupyter Notebooks + +All 29 class-tutorial and data-analysis notebooks are in `notebooks/`. +They mirror the MATLAB helpfile examples one-to-one. See +[docs/Examples.md](../docs/Examples.md) for the full index. diff --git a/examples/paper/example01_mepsc_poisson.py b/examples/paper/example01_mepsc_poisson.py index e93c7263..3e104660 100644 --- a/examples/paper/example01_mepsc_poisson.py +++ b/examples/paper/example01_mepsc_poisson.py @@ -128,33 +128,14 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non print(f" AIC: {resultConst.AIC}") print(f" BIC: {resultConst.BIC}") - # --- Figure 1: Constant Mg2+ raster + diagnostics + lambda --- - fig1, axes1 = plt.subplots(2, 2, figsize=(14, 9)) - - # Subplot 1: Neural raster - ax = axes1[0, 0] - spikeCollConst.plot(handle=ax) - ax.set_title("Neural Raster with constant Mg$^{2+}$ Concentration", - fontweight="bold", fontsize=12) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") - ax.set_yticks([0, 1]) - - # Subplot 2: Inverse Gaussian Transform (autocorrelation of rescaled residuals) - ax = axes1[0, 1] - resultConst.plotInvGausTrans(handle=ax) - - # Subplot 3: KS plot - ax = axes1[1, 0] - resultConst.KSPlot(handle=ax) - - # Subplot 4: Lambda estimate - ax = axes1[1, 1] - resultConst.lambda_signal.plot(handle=ax) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.legend(["$\\lambda_{const}$"], loc="upper right") - - 
fig1.suptitle("Example 01 — Figure 1: Constant Mg$^{2+}$ Summary", fontsize=14, fontweight="bold") + # --- Figure 1: Constant Mg2+ diagnostics (Matlab-matching plotResults) --- + # Matlab calls resultConst.plotResults which creates a 2x4 grid: + # [KSPlot (2-wide)] [InvGausTrans] [SeqCorr] + # [plotCoeffs (2-wide)] [plotResidual (2-wide)] + fig1 = plt.figure(figsize=(14, 9)) + resultConst.plotResults(handle=fig1) + fig1.suptitle("Example 01 — Figure 1: Constant Mg$^{2+}$ Summary", + fontsize=14, fontweight="bold") fig1.tight_layout() figure_files.extend(_maybe_export(fig1, export_dir, "fig01_constant_mg_summary")) @@ -198,8 +179,12 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non print("\n=== Part 3: Piecewise Baseline Model Comparison ===") # Build piecewise indicator covariates - timeInd1 = np.searchsorted(timeWashout, 495.0) - timeInd2 = np.searchsorted(timeWashout, 765.0) + # Matlab builds the first epoch as 1:timeInd1, an inclusive range that + # keeps the boundary sample at 495 s. np.searchsorted(..., side='right') + # returns the first index past 495, so the exclusive Python slice + # [:timeInd1] likewise keeps the boundary sample. 
+ timeInd1 = np.searchsorted(timeWashout, 495.0, side="right") + timeInd2 = np.searchsorted(timeWashout, 765.0, side="right") N = len(timeWashout) constantRate = np.ones((N, 1)) @@ -233,34 +218,12 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non print(f" AIC: {resultWashout.AIC}") print(f" BIC: {resultWashout.BIC}") - # --- Figure 3: Piecewise model diagnostics + lambda comparison --- - fig3, axes3 = plt.subplots(2, 2, figsize=(14, 9)) - - # Subplot 1: Raster with epoch boundaries - ax = axes3[0, 0] - spikeCollWashout.plot(handle=ax) - ax.set_title("Neural Raster with decreasing Mg$^{2+}$ Concentration", - fontweight="bold", fontsize=12) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.axvline(495.0, color="r", linewidth=4) - ax.axvline(765.0, color="r", linewidth=4) - - # Subplot 2: Inverse Gaussian Transform - ax = axes3[0, 1] - resultWashout.plotInvGausTrans(handle=ax) - - # Subplot 3: KS plot - ax = axes3[1, 0] - resultWashout.KSPlot(handle=ax) - - # Subplot 4: Lambda comparison - ax = axes3[1, 1] - resultWashout.lambda_signal.plot(handle=ax) - ax.set_ylim(0, 5) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.legend(["$\\lambda_{const}$", "$\\lambda_{const-epoch}$"], loc="upper right") - - fig3.suptitle("Example 01 — Figure 3: Piecewise Baseline Comparison", fontsize=14, fontweight="bold") + # --- Figure 3: Piecewise model diagnostics (Matlab-matching plotResults) --- + # Matlab calls resultWashout.plotResults which creates the same 2x4 grid. 
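The boundary-index change above hinges on `np.searchsorted`'s `side` argument. A quick standalone check of the slicing convention, on a toy grid rather than the experiment's `timeWashout` vector:

```python
import numpy as np

# Epoch-boundary indexing: when a sample sits exactly on the boundary,
# side="left" leaves it out of the first epoch, side="right" keeps it in.
time = np.arange(0.0, 1000.0, 5.0)                    # toy grid containing 495.0 exactly
i_left = np.searchsorted(time, 495.0, side="left")    # first index >= 495
i_right = np.searchsorted(time, 495.0, side="right")  # first index > 495

assert time[i_left] == 495.0
assert i_right == i_left + 1
# Python slices are exclusive, so [:i_right] retains the boundary sample:
assert time[:i_right][-1] == 495.0
assert time[:i_left][-1] == 490.0
```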
+ fig3 = plt.figure(figsize=(14, 9)) + resultWashout.plotResults(handle=fig3) + fig3.suptitle("Example 01 — Figure 3: Piecewise Baseline Comparison", + fontsize=14, fontweight="bold") fig3.tight_layout() figure_files.extend(_maybe_export(fig3, export_dir, "fig03_piecewise_baseline_comparison")) diff --git a/examples/paper/example02_whisker_stimulus_thalamus.py b/examples/paper/example02_whisker_stimulus_thalamus.py index 8b4f1abd..7747037a 100644 --- a/examples/paper/example02_whisker_stimulus_thalamus.py +++ b/examples/paper/example02_whisker_stimulus_thalamus.py @@ -12,74 +12,406 @@ (whisker displacement ``t``, binary spike indicator ``y``, 1000 Hz). Expected outputs: - - Figure 1: Data overview (raster, stimulus, velocity). + - Figure 1: Data overview (raster, stimulus displacement, velocity). - Figure 2: Lag selection (CCF), history diagnostics, KS plot, coefficients. Paper mapping: - Section 2.3.2 (thalamic whisker-stimulus analysis). + Section 2.3.2 (thalamic whisker-stimulus analysis); Figs. 4 and 11. """ from __future__ import annotations import argparse -import json import sys from pathlib import Path +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from scipy.io import loadmat + +# --------------------------------------------------------------------------- +# Ensure nstat is importable when running from the examples/paper directory. 
+# --------------------------------------------------------------------------- THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +import nstat # noqa: E402 +from nstat import ( # noqa: E402 + Analysis, + ConfigColl, + CovColl, + nspikeTrain, + nstColl, + Trial, + TrialConfig, +) +from nstat.signal import Covariate # noqa: E402 from nstat.data_manager import ensure_example_data # noqa: E402 -from nstat.paper_examples_full import run_experiment2 # noqa: E402 -from nstat.paper_figures import export_named_paper_figures # noqa: E402 - - -def run_example02(*, export_figures: bool = False, export_dir: Path | None = None): - """Run Example 02: Whisker stimulus GLM. - - Analysis workflow (mirrors Matlab example02_whisker_stimulus_thalamus.m): - 1. Load trngdataBis.mat — stimulus displacement and spike indicator. - 2. Compute cross-covariance between residual spikes and stimulus. - 3. Identify peak lag; shift stimulus by optimal lag. - 4. Fit 3 nested GLMs: - (a) baseline only, - (b) baseline + stimulus + velocity, - (c) baseline + stimulus + velocity + spike history. - 5. Sweep history orders 1..28 via AIC/BIC to select optimal lag. - 6. Generate figures comparing models. 
+ + +# ========================================================================= +# Helper: export figure +# ========================================================================= +def _maybe_export(fig, export_dir: Path | None, name: str, dpi: int = 250): + """Save figure to disk if export_dir is set.""" + saved = [] + if export_dir is not None: + export_dir.mkdir(parents=True, exist_ok=True) + png_path = export_dir / f"{name}.png" + fig.savefig(png_path, dpi=dpi, bbox_inches="tight") + saved.append(png_path) + print(f" Saved {png_path}") + return saved + + +# ========================================================================= +# Main example function +# ========================================================================= +def run_example02(*, export_figures: bool = False, export_dir: Path | None = None, + visible: bool = True): + """Run Example 02: Whisker stimulus GLM with lag and history selection. + + Mirrors Matlab example02_whisker_stimulus_thalamus.m exactly: + 1. Load trngdataBis.mat (struct with fields t=stimulus, y=spike indicator). + 2. Construct nSTAT objects (nspikeTrain, Covariate, Trial). + 3. Fit baseline-only GLM; compute residual cross-covariance with stimulus. + 4. Identify optimal lag from peak xcov; shift stimulus by that lag. + 5. Sweep history windows via Analysis.computeHistLagForAll with logspace grid. + 6. Select optimal history order from min(AIC_idx, BIC_idx). + 7. Fit 3 nested models: baseline, baseline+stim, baseline+stim+hist. + 8. Generate 2 figures with Matlab-matching subplot layouts. 
""" + if not visible: + matplotlib.use("Agg") + data_dir = ensure_example_data(download=True) + figure_files: list[Path] = [] + + sampleRate = 1000 # Hz + + # ================================================================== + # Load data from trngdataBis.mat + # ================================================================== + print("=== Example 02: Whisker Stimulus GLM ===") + + mat_path = (data_dir / "Explicit Stimulus" / "Dir3" / "Neuron1" + / "Stim2" / "trngdataBis.mat") + d = loadmat(mat_path, squeeze_me=True, struct_as_record=False) + + # Extract stimulus signal and spike indicator from struct + # Matlab: data.t is stimulus, data.y is binary spike indicator + if hasattr(d.get("data", None), "t"): + stimData = np.asarray(d["data"].t, dtype=float).reshape(-1) + yData = np.asarray(d["data"].y, dtype=float).reshape(-1) + else: + # Fallback: try direct keys + stimData = np.asarray(d["t"], dtype=float).reshape(-1) + yData = np.asarray(d["y"], dtype=float).reshape(-1) + + # Construct time vector at 1 ms resolution + time = np.arange(0, len(stimData)) * (1.0 / sampleRate) + + # Extract spike times from binary indicator + spikeTimes = time[yData == 1] + print(f" Data length: {len(stimData)} samples ({time[-1]:.1f} s)") + print(f" Total spikes: {len(spikeTimes)}") + + # ================================================================== + # Create nSTAT objects + # ================================================================== + # Stimulus covariate (divided by 10, matching Matlab: stimData ./ 10) + stim = Covariate( + time, stimData / 10.0, + "Stimulus", "time", "s", "mm", + dataLabels=["stim"], + ) + # Constant baseline covariate + baseline = Covariate( + time, np.ones((len(time), 1)), + "Baseline", "time", "s", "", + dataLabels=["constant"], + ) + + nst = nspikeTrain(spikeTimes) + spikeColl = nstColl(nst) + trial = Trial(spikeColl, CovColl([stim, baseline])) + + # ================================================================== + # Figure 1: Data overview 
— raster, stimulus, velocity (3x1 layout) + # ================================================================== + fig1, axes1 = plt.subplots(3, 1, figsize=(14, 9)) + viewWindow = 21.0 # First 21 seconds, matching Matlab + + # Subplot 1: Neural raster (first 21 s) + ax = axes1[0] + nstView = nspikeTrain(spikeTimes) + nstView.setMaxTime(viewWindow) + nstView.plot(handle=ax) + ax.set_yticks([0, 1]) + ax.set_title("Neural Raster", fontweight="bold", fontsize=12) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_ylabel("Spikes", fontsize=12, fontweight="bold") + + # Subplot 2: Stimulus displacement (first 21 s) + ax = axes1[1] + stimView = stim.getSigInTimeWindow(0, viewWindow) + stimView.plot(handle=ax) + ax.set_ylabel("Displacement [mm]", fontsize=12, fontweight="bold") + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + + # Subplot 3: Stimulus velocity (derivative, first 21 s) + ax = axes1[2] + stimDeriv = stim.derivative + stimDerivView = stimDeriv.getSigInTimeWindow(0, viewWindow) + stimDerivView.plot(handle=ax) + ax.set_ylim(-80, 80) + ax.set_ylabel("Velocity", fontsize=12, fontweight="bold") + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + + fig1.suptitle("Example 02 — Figure 1: Data Overview", + fontsize=14, fontweight="bold") + fig1.tight_layout() + figure_files.extend(_maybe_export( + fig1, export_dir, "fig01_data_overview")) + + # ================================================================== + # Fit baseline-only model + # ================================================================== + print("\n--- Fitting baseline-only model ---") + cfgBase = TrialConfig([("Baseline", "constant")], sampleRate, [], []) + cfgBase.setName("Baseline") + baselineResults = Analysis.RunAnalysisForAllNeurons( + trial, ConfigColl([cfgBase]), 0) - # Run analysis (returns summary statistics and figure payload) - summary, payload = run_experiment2(data_dir, return_payload=True) + # 
================================================================== + # Compute residual cross-covariance with stimulus to find optimal lag + # ================================================================== + print("--- Computing residual cross-covariance ---") + residual = baselineResults.computeFitResidual() + xcovSig = residual.xcov(stim) - print(json.dumps(summary, indent=2)) + # Window to positive lags [0, 1] s (matching Matlab) + xcovWindowed = xcovSig.windowedSignal([0, 1]) + + # Find peak lag — findGlobalPeak returns (times, values) + peakTimes, peakVals = xcovWindowed.findGlobalPeak("maxima") + shiftTime = float(peakTimes[0]) + peakVal = float(peakVals[0]) + print(f" Peak xcov at lag = {shiftTime:.4f} s (value = {peakVal:.4f})") + + # ================================================================== + # Shift stimulus by optimal lag and build new Trial + # ================================================================== + # Matlab: stimShifted = Covariate(time, stimData, ...).shift(shiftTime) + # Note: Matlab uses raw stimData (not /10) with units 'V' for the shifted version + stimShifted = Covariate( + time, stimData, + "Stimulus", "time", "s", "V", + dataLabels=["stim"], + ) + stimShifted = stimShifted.shift(shiftTime) + + baselineMu = Covariate( + time, np.ones((len(time), 1)), + "Baseline", "time", "s", "", + dataLabels=["\\mu"], + ) + + trialShifted = Trial( + nstColl(nspikeTrain(spikeTimes)), + CovColl([stimShifted, baselineMu]), + ) - if export_figures: - if export_dir is None: - export_dir = THIS_DIR / "figures" / "example02" - saved = export_named_paper_figures( - "example02", summary=summary, payload=payload, export_dir=export_dir - ) - print(f"\nGenerated {len(saved)} figure(s):") - for p in saved: - print(f" {p}") + # ================================================================== + # History model-order search via computeHistLagForAll + # ================================================================== + print("\n--- Sweeping 
history windows ---") + delta = 1.0 / sampleRate + maxWindow = 1.0 + numWindows = 32 - return summary + # Construct log-spaced history window boundaries (matching Matlab) + logVals = np.logspace(np.log10(delta), np.log10(maxWindow), numWindows) + windowTimes = np.concatenate([[0.0], logVals]) + # Round to nearest ms and remove duplicates + windowTimes = np.unique(np.round(windowTimes * sampleRate) / sampleRate) + print(f" Window boundaries: {len(windowTimes)} unique values") + print(f" Range: [{windowTimes[0]:.4f}, {windowTimes[-1]:.4f}] s") + historySweep = Analysis.computeHistLagForAll( + trialShifted, windowTimes, + CovLabels=[("Baseline", "\\mu"), ("Stimulus", "stim")], + Algorithm="GLM", + batchMode=0, + sampleRate=sampleRate, + makePlot=0, + ) + + # ================================================================== + # Select optimal history order + # ================================================================== + # historySweep is a list of FitResult objects (one per neuron) + sweep = historySweep[0] + aicArr = np.asarray(sweep.AIC, dtype=float) + bicArr = np.asarray(sweep.BIC, dtype=float) + ksArr = np.asarray(sweep.KSStats, dtype=float).ravel() + + # Delta AIC/BIC relative to no-history model (index 0) + dAIC = aicArr[1:] - aicArr[0] + dBIC = bicArr[1:] - bicArr[0] + + # Find index of minimum delta (offset by +1 since we skipped index 0) + aicIdx = int(np.argmin(dAIC)) + 1 if dAIC.size > 0 else None + bicIdx = int(np.argmin(dBIC)) + 1 if dBIC.size > 0 else None + ksIdx = int(np.argmin(ksArr)) if ksArr.size > 0 else 0 + + # Take minimum of AIC and BIC optimal indices + candidates = [] + if aicIdx is not None and aicIdx > 0: + candidates.append(aicIdx) + if bicIdx is not None and bicIdx > 0: + candidates.append(bicIdx) + windowIndex = min(candidates) if candidates else ksIdx + + if windowIndex > len(windowTimes): + windowIndex = ksIdx + + # Extract selected history windows + if windowIndex > 1: + selectedHistory = list(windowTimes[:windowIndex]) + else: 
+ selectedHistory = [] + + print(f" AIC optimal index: {aicIdx}") + print(f" BIC optimal index: {bicIdx}") + print(f" KS optimal index: {ksIdx}") + print(f" Selected window index: {windowIndex}") + print(f" Selected history: {len(selectedHistory)} windows") + + # ================================================================== + # Final 3-model comparison + # ================================================================== + print("\n--- Fitting 3 nested models ---") + + cfg1 = TrialConfig([("Baseline", "\\mu")], sampleRate, [], []) + cfg1.setName("Baseline") + + cfg2 = TrialConfig( + [("Baseline", "\\mu"), ("Stimulus", "stim")], + sampleRate, [], [], + ) + cfg2.setName("Baseline+Stimulus") + + cfg3 = TrialConfig( + [("Baseline", "\\mu"), ("Stimulus", "stim")], + sampleRate, selectedHistory, [], + ) + cfg3.setName("Baseline+Stimulus+Hist") + + modelCompare = Analysis.RunAnalysisForAllNeurons( + trialShifted, ConfigColl([cfg1, cfg2, cfg3]), 0) + modelCompare.lambda_signal.setDataLabels([ + "\\lambda_{const}", + "\\lambda_{const+stim}", + "\\lambda_{const+stim+hist}", + ]) + + print(f" AIC: {modelCompare.AIC}") + print(f" BIC: {modelCompare.BIC}") + + # ================================================================== + # Figure 2: Lag selection, history diagnostics, KS, coefficients + # (Matlab uses subplot(7,2,...) 
layout) + # ================================================================== + fig2 = plt.figure(figsize=(14, 12)) + import matplotlib.gridspec as gridspec + gs = gridspec.GridSpec(7, 2, figure=fig2, hspace=0.5, wspace=0.3) + + # --- Left column, rows 1-3: Cross-correlation function --- + ax_xcov = fig2.add_subplot(gs[0:3, 0]) + xcovWindowed.plot(handle=ax_xcov) + ax_xcov.plot(shiftTime, peakVal, "ro", markersize=8, + markerfacecolor="r", markeredgecolor="r", linewidth=3) + ax_xcov.set_title("Residual Cross-Covariance", fontweight="bold") + ax_xcov.set_xlabel("Lag [s]") + ax_xcov.set_ylabel("Cross-covariance") + + # --- Right column, row 1: KS statistic vs Q --- + ax_ks_sweep = fig2.add_subplot(gs[0, 1]) + xvals = np.arange(len(ksArr)) + ax_ks_sweep.plot(xvals, ksArr, ".-") + if windowIndex < len(ksArr): + ax_ks_sweep.plot(xvals[windowIndex], ksArr[windowIndex], "r*", + markersize=10) + ax_ks_sweep.set_title("KS Statistic vs Q", fontweight="bold") + ax_ks_sweep.set_xlabel("Number of History Windows") + ax_ks_sweep.set_ylabel("KS Stat") + + # --- Right column, row 2: Delta AIC vs Q --- + ax_daic = fig2.add_subplot(gs[1, 1]) + dAIC_full = aicArr - aicArr[0] + ax_daic.plot(np.arange(len(dAIC_full)), dAIC_full, ".-") + if windowIndex < len(dAIC_full): + ax_daic.plot(windowIndex, dAIC_full[windowIndex], "r*", markersize=10) + ax_daic.set_title("$\\Delta$AIC vs Q", fontweight="bold") + ax_daic.set_xlabel("Number of History Windows") + ax_daic.set_ylabel("$\\Delta$AIC") + + # --- Right column, row 3: Delta BIC vs Q --- + ax_dbic = fig2.add_subplot(gs[2, 1]) + dBIC_full = bicArr - bicArr[0] + ax_dbic.plot(np.arange(len(dBIC_full)), dBIC_full, ".-") + if windowIndex < len(dBIC_full): + ax_dbic.plot(windowIndex, dBIC_full[windowIndex], "r*", markersize=10) + ax_dbic.set_title("$\\Delta$BIC vs Q", fontweight="bold") + ax_dbic.set_xlabel("Number of History Windows") + ax_dbic.set_ylabel("$\\Delta$BIC") + + # --- Left column, rows 5-7: KS plot (3 models) --- + ax_ks = 
fig2.add_subplot(gs[4:7, 0]) + modelCompare.KSPlot(handle=ax_ks) + + # --- Right column, rows 5-7: Coefficient comparison --- + ax_coeff = fig2.add_subplot(gs[4:7, 1]) + modelCompare.plotCoeffs(handle=ax_coeff) + + fig2.suptitle("Example 02 — Figure 2: Lag & History Selection", + fontsize=14, fontweight="bold") + figure_files.extend(_maybe_export( + fig2, export_dir, "fig02_lag_and_model_comparison")) + + if visible: + plt.show() + + print(f"\nExample 02 complete. Generated {len(figure_files)} figure(s).") + return figure_files + + +# ========================================================================= +# CLI entry point +# ========================================================================= if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Example 02: Whisker Stimulus GLM") - parser.add_argument("--repo-root", type=Path, default=REPO_ROOT) - parser.add_argument("--export-figures", action="store_true") - parser.add_argument("--export-dir", type=Path, default=None) - parser.add_argument("--output-json", type=Path, default=None) + parser = argparse.ArgumentParser( + description="Example 02: Whisker Stimulus GLM") + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT, + help="Repository root (default: auto-detected).") + parser.add_argument("--export-figures", action="store_true", + help="Export figures to disk.") + parser.add_argument("--export-dir", type=Path, default=None, + help="Directory for exported figures.") + parser.add_argument("--no-display", action="store_true", + help="Run without displaying figures (headless).") args = parser.parse_args() - result = run_example02( + export_dir = args.export_dir + if args.export_figures and export_dir is None: + export_dir = THIS_DIR / "figures" / "example02" + + run_example02( export_figures=args.export_figures, - export_dir=args.export_dir, + export_dir=export_dir if args.export_figures else None, + visible=not args.no_display, ) - if args.output_json: - 
args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") diff --git a/examples/paper/example03_psth_and_ssglm.py b/examples/paper/example03_psth_and_ssglm.py index a5f66c72..377d6e4b 100644 --- a/examples/paper/example03_psth_and_ssglm.py +++ b/examples/paper/example03_psth_and_ssglm.py @@ -8,20 +8,21 @@ 4) Across-trial learning dynamics and stimulus-effect surfaces. The example has two parts: - Part A (experiment3): PSTH analysis — simulate 20 trials from sinusoidal - CIF, load real data from ``data/PSTH/Results.mat``, compare histogram + Part A (PSTH): Simulate 20 trials from a sinusoidal CIF, + load real data from ``data/PSTH/Results.mat``, compare histogram PSTH vs GLM-PSTH. - Part B (experiment3b): SSGLM analysis — simulate 50-trial dataset with - across-trial gain modulation, fit SSGLM via EM, visualise learning - dynamics and 3-D stimulus-effect surfaces. + Part B (SSGLM): Simulate 50-trial dataset with across-trial gain + modulation, load precomputed SSGLM fit from + ``data/SSGLMExampleData.mat``, visualise learning dynamics and + 3-D stimulus-effect surfaces. Expected outputs: - - Figure 1: Simulated and real rasters. + - Figure 1: Simulated CIF + simulated/real raster examples. - Figure 2: PSTH comparison (histogram vs GLM). - Figure 3: SSGLM simulation summary. - - Figure 4: SSGLM fit diagnostics. + - Figure 4: SSGLM vs PSTH model diagnostics. - Figure 5: Stimulus-effect surfaces (3-D). - - Figure 6: Learning-trial comparison. + - Figure 6: Learning-trial comparison and significance matrix. Paper mapping: Section 2.3.3 (PSTH) and Section 2.4 (SSGLM). 
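Part A of example 03 simulates spike trains by CIF thinning. For reference, the underlying Lewis-Shedler scheme fits in a few lines; this is an illustrative helper under a sinusoidal rate, not the `CIF.simulateCIFByThinningFromLambda` signature:

```python
import numpy as np

def simulate_by_thinning(lam_fn, lam_max, t_max, rng):
    """Sample one spike train from intensity lam_fn on [0, t_max] by thinning.

    Candidate events from a homogeneous Poisson process at rate lam_max are
    kept with probability lam_fn(t) / lam_max (Lewis-Shedler thinning).
    Illustrative helper; requires lam_fn(t) <= lam_max everywhere.
    """
    spikes, t = [], 0.0
    while True:
        t += rng.exponential(1.0 / lam_max)   # next candidate event
        if t > t_max:
            return np.array(spikes)
        if rng.uniform() < lam_fn(t) / lam_max:
            spikes.append(t)                  # accept with thinning probability

rng = np.random.default_rng(0)
lam = lambda t: 10.0 * (1.0 + np.sin(2 * np.pi * 2 * t)) / 2.0  # 0..10 spikes/s
train = simulate_by_thinning(lam, lam_max=10.0, t_max=1.0, rng=rng)
```

Repeating the call gives the independent sample paths that the raster figures in Part A display.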
@@ -29,71 +30,608 @@ from __future__ import annotations import argparse -import json import sys from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +from scipy.io import loadmat + THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +from nstat import ( # noqa: E402 + Analysis, + Covariate, + CovariateCollection, + FitResult, + Trial, + TrialConfig, + ConfigCollection, +) +from nstat.cif import CIF # noqa: E402 +from nstat.confidence_interval import ConfidenceInterval # noqa: E402 +from nstat.core import nspikeTrain # noqa: E402 from nstat.data_manager import ensure_example_data # noqa: E402 -from nstat.paper_examples_full import run_experiment3, run_experiment3b # noqa: E402 -from nstat.paper_figures import export_named_paper_figures # noqa: E402 +from nstat.decoding_algorithms import DecodingAlgorithms # noqa: E402 +from nstat.trial import SpikeTrainCollection # noqa: E402 -def run_example03(*, export_figures: bool = False, export_dir: Path | None = None): - """Run Example 03: PSTH and SSGLM dynamics. +# ===================================================================== +# Helper: load Matlab FitResult struct into Python FitResult +# ===================================================================== +def _load_matlab_fitresult(mat_struct, spike_trains): + """Convert a Matlab FitResult structured array to a Python FitResult. - Analysis workflow (mirrors Matlab example03_psth_and_ssglm.m): + Parameters + ---------- + mat_struct : numpy structured array + The `fR` or `psthR` field from the .mat file. + spike_trains : list[nspikeTrain] + The spike trains corresponding to this FitResult (since Matlab + MCOS objects cannot be deserialized by scipy). 
+ """ + # Extract lambda signal + lam = mat_struct["lambda"].item() + lam_time = np.asarray(lam["time"].item(), dtype=float).ravel() + lam_data = np.asarray(lam["data"].item(), dtype=float) + lam_name = str(lam["name"].item()) if lam["name"].size else "\\lambda" - Part A — PSTH: - 1. Define sinusoidal CIF: lambda(t) = exp(b0 + b1*cos(2*pi*f*t)). - 2. Simulate 20 spike trains via CIF thinning. - 3. Load real multi-trial data from PSTH/Results.mat. - 4. Compute histogram PSTH and GLM-PSTH; compare. + lambda_cov = Covariate( + lam_time, lam_data, lam_name, "time", "s", "spikes/sec", + ) - Part B — SSGLM: - 5. Simulate 50-trial population with across-trial stimulus gain. - 6. Fit SSGLM via EM (forward-backward Kalman + Newton M-step). - 7. Plot per-trial coefficient trajectories and confidence bands. - 8. Generate 3-D stimulus-effect surface and learning-trial figure. - """ - data_dir = ensure_example_data(download=True) + # Extract scalar statistics + b_raw = np.asarray(mat_struct["b"].item(), dtype=float).reshape(-1) + AIC_val = float(np.asarray(mat_struct["AIC"].item(), dtype=float).ravel()[0]) + BIC_val = float(np.asarray(mat_struct["BIC"].item(), dtype=float).ravel()[0]) + logLL_val = float(np.asarray(mat_struct["logLL"].item(), dtype=float).ravel()[0]) + config_name = str(mat_struct["configNames"].item()) + + # Extract covariate labels + cov_labels_raw = mat_struct["covLabels"].item() + if isinstance(cov_labels_raw, np.ndarray): + cov_labels = [str(x) for x in cov_labels_raw.ravel()] + elif isinstance(cov_labels_raw, str): + cov_labels = [cov_labels_raw] + else: + cov_labels = list(cov_labels_raw) if cov_labels_raw is not None else [] + + num_hist_raw = mat_struct["numHist"].item() + num_hist = [int(num_hist_raw)] if np.isscalar(num_hist_raw) else [int(x) for x in np.asarray(num_hist_raw).ravel()] + + cfgs = ConfigCollection([TrialConfig(name=config_name)]) + + return FitResult( + spike_trains, + [cov_labels], # covLabels (list of lists) + num_hist, # numHist + 
[], # histObjects + [], # ensHistObjects + lambda_cov, # lambda_signal + [b_raw], # b + [0.0], # dev + [None], # stats + [AIC_val], # AIC + [BIC_val], # BIC + [logLL_val], # logLL + cfgs, # configColl + [], # XvalData + [], # XvalTime + "poisson", # distribution + ) + + +# ===================================================================== +# Part A: PSTH Analysis +# ===================================================================== +def run_part_a(data_dir, export_dir=None): + """Simulate and real PSTH + GLM-PSTH analysis.""" + print("=== Part A: PSTH Analysis ===") + + # ------------------------------------------------------------------ + # 1. Define sinusoidal CIF: lambda(t) = sigmoid(sin(2*pi*f*t) + mu) / dt + # ------------------------------------------------------------------ + delta = 0.001 + tmax = 1.0 + time = np.arange(0.0, tmax + delta, delta) + f = 2 + mu = -3 + + lambdaRaw = np.sin(2 * np.pi * f * time) + mu + lambdaData = np.exp(lambdaRaw) / (1 + np.exp(lambdaRaw)) * (1 / delta) + lambdaCov = Covariate( + time, lambdaData, "\\lambda(t)", "time", "s", "spikes/sec", + ["\\lambda_{1}"], + ) + + # ------------------------------------------------------------------ + # 2. Simulate 20 spike trains via CIF thinning + # ------------------------------------------------------------------ + numRealizations = 20 + spikeCollSim = CIF.simulateCIFByThinningFromLambda( + lambdaCov, numRealizations, seed=0, + ) + spikeCollSim.setMinTime(0.0) + spikeCollSim.setMaxTime(tmax) + print(f" Simulated {numRealizations} spike trains") + + # ------------------------------------------------------------------ + # 3. 
Load real PSTH data from Results.mat + # ------------------------------------------------------------------ + psth_path = data_dir / "PSTH" / "Results.mat" + psthData = loadmat(str(psth_path), squeeze_me=False) + Results = psthData["Results"][0, 0] + Data = Results["Data"][0, 0] + STC = Data["Spike_times_STC"][0, 0] + SUA = STC["balanced_SUA"][0, 0] + numTrials = int(SUA["Nr_trials"][0, 0]) + spikeTimesArr = SUA["spike_times"] # shape (16, numTrials, 8) + + # Cell 6 (Matlab 1-indexed) + trains6 = [] + for iTrial in range(numTrials): + st = spikeTimesArr[0, iTrial, 5].ravel() # cell index 5 = cell 6 + nst = nspikeTrain(st, name="6", minTime=0.0, maxTime=2.0, makePlots=-1) + trains6.append(nst) + spikeCollReal1 = SpikeTrainCollection(trains6) + spikeCollReal1.setMinTime(0.0) + spikeCollReal1.setMaxTime(2.0) + + # Cell 1 (Matlab 1-indexed) + trains1 = [] + for iTrial in range(numTrials): + st = spikeTimesArr[0, iTrial, 0].ravel() # cell index 0 = cell 1 + nst = nspikeTrain(st, name="1", minTime=0.0, maxTime=2.0, makePlots=-1) + trains1.append(nst) + spikeCollReal2 = SpikeTrainCollection(trains1) + spikeCollReal2.setMinTime(0.0) + spikeCollReal2.setMaxTime(2.0) + print(f" Loaded real data: {numTrials} trials, cells 6 and 1") + + # ------------------------------------------------------------------ + # Figure 1: Simulated CIF + simulated/real rasters (2x2) + # ------------------------------------------------------------------ + fig1, axes1 = plt.subplots(2, 2, figsize=(14, 9)) + + # Top-left: CIF + ax = axes1[0, 0] + ax.plot(time, lambdaData, "b", linewidth=2) + ax.set_title("Simulated CIF", fontweight="bold", fontsize=14) + ax.set_xlabel("time [s]") + ax.set_ylabel("spikes/sec") + + # Bottom-left: simulated raster + ax = axes1[1, 0] + spikeCollSim.plot(handle=ax) + ax.set_yticks(range(0, numRealizations + 1, 5)) + ax.set_title(f"{numRealizations} Simulated Sample Paths", + fontweight="bold", fontsize=14) + ax.set_xlabel("time [s]") + ax.set_ylabel("Trial [k]") + + # 
Top-right: real cell 6 raster + ax = axes1[0, 1] + spikeCollReal1.plot(handle=ax) + ax.set_yticks(range(0, numTrials + 1, 2)) + ax.set_title("Response to Moving Visual Stimulus (Neuron 6)", + fontweight="bold", fontsize=14) + ax.set_xlabel("time [s]") + ax.set_ylabel("Trial [k]") + + # Bottom-right: real cell 1 raster + ax = axes1[1, 1] + spikeCollReal2.plot(handle=ax) + ax.set_yticks(range(0, numTrials + 1, 2)) + ax.set_title("Response to Moving Visual Stimulus (Neuron 1)", + fontweight="bold", fontsize=14) + ax.set_xlabel("time [s]") + ax.set_ylabel("Trial [k]") + + fig1.tight_layout() + + # ------------------------------------------------------------------ + # 4. Compute PSTH and GLM-PSTH + # ------------------------------------------------------------------ + binsize = 0.05 + psthSim = spikeCollSim.psth(binsize) + psthGLMSim, _, _ = spikeCollSim.psthGLM(binsize) + + psthReal1 = spikeCollReal1.psth(binsize) + psthGLMReal1, _, _ = spikeCollReal1.psthGLM(binsize) + + psthReal2 = spikeCollReal2.psth(binsize) + psthGLMReal2, _, _ = spikeCollReal2.psthGLM(binsize) + print(" PSTH and GLM-PSTH computed for all 3 collections") + + # ------------------------------------------------------------------ + # Figure 2: PSTH comparison (2x3) + # ------------------------------------------------------------------ + fig2, axes2 = plt.subplots(2, 3, figsize=(14, 9)) + + # Top row: rasters + spikeCollSim.plot(handle=axes2[0, 0]) + axes2[0, 0].set_yticks(range(0, numRealizations + 1, 2)) + axes2[0, 0].set_xlabel("time [s]") + axes2[0, 0].set_ylabel("Trial [k]") + + spikeCollReal1.plot(handle=axes2[0, 1]) + axes2[0, 1].set_yticks(range(0, numTrials + 1, 2)) + axes2[0, 1].set_xlabel("time [s]") + axes2[0, 1].set_ylabel("Trial [k]") + + spikeCollReal2.plot(handle=axes2[0, 2]) + axes2[0, 2].set_yticks(range(0, numTrials + 1, 2)) + axes2[0, 2].set_xlabel("time [s]") + axes2[0, 2].set_ylabel("Trial [k]") + + # Bottom row: PSTH comparisons + ax = axes2[1, 0] + h_true, = ax.plot(time, 
lambdaData, "b", linewidth=4, label="true") + psth_time = np.asarray(psthSim.time, dtype=float).ravel() + psth_data = np.asarray(psthSim.data, dtype=float).ravel() + h_psth, = ax.plot(psth_time, psth_data, "rx", linewidth=4, label="PSTH") + glm_time = np.asarray(psthGLMSim.time, dtype=float).ravel() + glm_data = np.asarray(psthGLMSim.data, dtype=float).ravel() + h_glm, = ax.plot(glm_time, glm_data, "k", linewidth=4, label="PSTH_{glm}") + ax.legend(handles=[h_true, h_psth, h_glm]) + ax.set_xlabel("time [s]") + ax.set_ylabel("[spikes/sec]") + + ax = axes2[1, 1] + psth_t1 = np.asarray(psthReal1.time, dtype=float).ravel() + psth_d1 = np.asarray(psthReal1.data, dtype=float).ravel() + glm_t1 = np.asarray(psthGLMReal1.time, dtype=float).ravel() + glm_d1 = np.asarray(psthGLMReal1.data, dtype=float).ravel() + h2, = ax.plot(psth_t1, psth_d1, "rx", linewidth=4, label="PSTH") + h3, = ax.plot(glm_t1, glm_d1, "k", linewidth=4, label="PSTH_{glm}") + ax.legend(handles=[h2, h3]) + ax.set_xlabel("time [s]") + ax.set_ylabel("[spikes/sec]") + + ax = axes2[1, 2] + psth_t2 = np.asarray(psthReal2.time, dtype=float).ravel() + psth_d2 = np.asarray(psthReal2.data, dtype=float).ravel() + glm_t2 = np.asarray(psthGLMReal2.time, dtype=float).ravel() + glm_d2 = np.asarray(psthGLMReal2.data, dtype=float).ravel() + h2, = ax.plot(psth_t2, psth_d2, "rx", linewidth=4, label="PSTH") + h3, = ax.plot(glm_t2, glm_d2, "k", linewidth=4, label="PSTH_{glm}") + ax.legend(handles=[h2, h3]) + ax.set_xlabel("time [s]") + ax.set_ylabel("[spikes/sec]") + + fig2.tight_layout() + + figures = {"fig01_simulated_and_real_rasters": fig1, "fig02_psth_comparison": fig2} + return figures, spikeCollSim, lambdaCov + + +# ===================================================================== +# Part B: SSGLM Analysis +# ===================================================================== +def run_part_b(data_dir, export_dir=None): + """SSGLM simulation, diagnostics, stimulus surfaces, learning trial.""" + print("\n=== Part B: 
SSGLM Analysis ===")
+
+    # ------------------------------------------------------------------
+    # 1. Simulate 50-trial CIF with across-trial stimulus gain
+    # ------------------------------------------------------------------
+    delta = 0.001
+    tmax = 1.0
+    time = np.arange(0.0, tmax + delta, delta)
+    f = 2
+    numRealizations = 50
+    b0 = -3
+
+    # Linearly increasing stimulus gain across trials
+    b1 = 3 * np.arange(1, numRealizations + 1) / numRealizations
+
+    # Simulate each trial using CIF.simulateCIF
+    trains = []
+    for iTrial in range(numRealizations):
+        u = np.sin(2 * np.pi * f * time)
+        stim = Covariate(time, u, "Stimulus", "time", "s", "V", ["sin"])
+        ens = Covariate(time, np.zeros_like(time), "Ensemble", "time", "s",
+                        "Spikes", ["n1"])
+
+        histCoeffs = [-4, -1, -0.5]
+
+        sC, lambdaTemp = CIF.simulateCIF(
+            b0, histCoeffs, [b1[iTrial]], [0],
+            stim, ens, 1, "binomial",
+            seed=iTrial, return_lambda=True,
+        )
+        nst = sC.getNST(1)
+        nst = nst.nstCopy()
+        nst.resample(1 / delta)
+        trains.append(nst)
+
+    spikeColl = SpikeTrainCollection(trains)
+    spikeColl.setMinTime(0.0)
+    spikeColl.setMaxTime(tmax)
+    print(f"  Simulated {numRealizations} spike trains with CIF.simulateCIF")
 
-    # --- Part A: PSTH analysis ---
-    summary3, payload3 = run_experiment3(return_payload=True)
+    # Compute true CIF surface: sigma(b0 + b1[k]*u(t)) / delta
+    u = np.sin(2 * np.pi * f * time)
+    stimDataEta = np.outer(u, b1)                  # (T, K)
+    stimData = np.exp(stimDataEta + b0)
+    stimData = stimData / (1 + stimData) / delta   # binomial link
 
-    # --- Part B: SSGLM analysis ---
-    summary3b, payload3b = run_experiment3b(data_dir, return_payload=True)
+    # ------------------------------------------------------------------
+    # Figure 3: SSGLM simulation summary (3x2)
+    # ------------------------------------------------------------------
+    fig3, axes3 = plt.subplots(3, 2, figsize=(14, 9))
 
-    # Merge summaries for JSON output
-    combined_summary = {
-        "experiment3": summary3,
-        "experiment3b": summary3b,
+    # (1,1):
Within-trial stimulus + ax = axes3[0, 0] + ax.plot(time, u, "k", linewidth=3) + ax.set_xlabel("time [s]") + ax.set_ylabel("Stimulus") + ax.set_title("Within Trial Stimulus", fontweight="bold", fontsize=14) + + # (1,2): Across-trial gain + ax = axes3[0, 1] + ax.plot(np.arange(1, numRealizations + 1), b1, "k", linewidth=3) + ax.set_xlabel("Trial [k]") + ax.set_ylabel("Stimulus Gain") + ax.set_title("Across Trial Stimulus Gain", fontweight="bold", fontsize=14) + + # (2,1)+(2,2): Raster spanning both columns + axes3[1, 1].remove() + ax = axes3[1, 0] + ax.set_position( + [axes3[1, 0].get_position().x0, + axes3[1, 0].get_position().y0, + axes3[1, 1].get_position().x1 - axes3[1, 0].get_position().x0, + axes3[1, 0].get_position().height] + ) + spikeColl.plot(handle=ax) + ax.set_yticks(range(0, numRealizations + 1, 10)) + ax.set_xlabel("time [s]") + ax.set_ylabel("Trial [k]") + ax.set_title("Simulated Neural Raster", fontweight="bold", fontsize=14) + + # (3,1)+(3,2): True CIF heatmap spanning both columns + axes3[2, 1].remove() + ax = axes3[2, 0] + ax.set_position( + [axes3[2, 0].get_position().x0, + axes3[2, 0].get_position().y0, + ax.get_position().width * 2.1, + axes3[2, 0].get_position().height] + ) + ax.imshow(stimData.T, aspect="auto", origin="lower", + extent=[time[0], time[-1], 1, numRealizations]) + ax.set_xlabel("time [s]") + ax.set_ylabel("Trial [k]") + ax.set_title("True Conditional Intensity Function", fontweight="bold", + fontsize=14) + ax.set_yticks(range(0, numRealizations + 1, 10)) + + fig3.tight_layout() + + # ------------------------------------------------------------------ + # 2. 
Compute PSTH-GLM and prepare data matrices + # (Matlab: psthGLM + dN before loading precomputed SSGLM) + # ------------------------------------------------------------------ + numBasis = 25 + basisWidth = (tmax - 0.0) / numBasis + windowTimes = np.arange(0.0, 0.004, delta) + fitType = "poisson" + + spikeColl.resample(1 / delta) + spikeColl.setMaxTime(tmax) + + dN = spikeColl.dataToMatrix() + if dN.ndim == 1: + dN = dN.reshape(1, -1) + dN = np.asarray(dN, dtype=float) + dN[dN > 1] = 1 + + psthSig, _, _ = spikeColl.psthGLM(basisWidth, windowTimes, fitType) + print(" Computed psthGLM on 50-trial collection") + + # ------------------------------------------------------------------ + # 3. Load precomputed SSGLM data + # ------------------------------------------------------------------ + ssglm_path = data_dir / "SSGLMExampleData.mat" + ssglm = loadmat(str(ssglm_path), squeeze_me=True) + + xK = np.asarray(ssglm["xK"], dtype=float) # (25, 50) + WkuFinal = np.asarray(ssglm["WkuFinal"], dtype=float) # (25, 25, 50, 50) + stimulus_true = np.asarray(ssglm["stimulus"], dtype=float) # (25, 50) + stimCIs = np.asarray(ssglm["stimCIs"], dtype=float) # (25, 50, 2) + gammahat = np.asarray(ssglm["gammahat"], dtype=float) # (3,) + K = xK.shape[1] + print(f" Loaded precomputed SSGLM: {numBasis} basis x {K} trials") + + # ------------------------------------------------------------------ + # 4. 
Reconstruct FitResult objects from loaded data + # ------------------------------------------------------------------ + ssglm_fit = _load_matlab_fitresult(ssglm["fR"], trains) + psth_fit = _load_matlab_fitresult(ssglm["psthR"], trains) + + tCompare = psth_fit.mergeResults(ssglm_fit) + tCompare.lambda_signal.setDataLabels( + ["\\lambda_{PSTH}", "\\lambda_{SSGLM}"] + ) + + # ------------------------------------------------------------------ + # Figure 4: SSGLM vs PSTH diagnostics (2x2) + # ------------------------------------------------------------------ + fig4, axes4 = plt.subplots(2, 2, figsize=(14, 9)) + tCompare.KSPlot(handle=axes4[0, 0]) + tCompare.plotResidual(handle=axes4[0, 1]) + tCompare.plotInvGausTrans(handle=axes4[1, 0]) + tCompare.plotSeqCorr(handle=axes4[1, 1]) + fig4.tight_layout() + print(" Figure 4: SSGLM vs PSTH diagnostics") + + # ------------------------------------------------------------------ + # 5. Compute stimulus effect surfaces + # ------------------------------------------------------------------ + sampleRate = 1 / delta + + unitPulseBasis = SpikeTrainCollection.generateUnitImpulseBasis( + basisWidth, 0.0, tmax, sampleRate, + ) + basisMat = np.asarray(unitPulseBasis.data, dtype=float) # (T, numBasis) + basis_time = np.asarray(unitPulseBasis.time, dtype=float).ravel() + + # True stimulus effect (Poisson link, matching fitType for analysis) + u_basis = np.sin(2 * np.pi * f * basis_time) + actStimEffect = np.exp(np.outer(u_basis, b1) + b0) / delta # (T, K) + + # PSTH surface (constant across trials — replicate fresh psthGLM output) + psthSig_data = np.asarray(psthSig.data, dtype=float).ravel() + psthSurface2D = np.tile(psthSig_data[:, None], (1, numRealizations)) + + # SSGLM estimated CIF from basis coefficients + estStimEffect = np.exp(basisMat @ xK) / delta # (T, K) + + # ------------------------------------------------------------------ + # Figure 5: True/PSTH/SSGLM stimulus effect surfaces (3D mesh) + # 
------------------------------------------------------------------ + from mpl_toolkits.mplot3d import Axes3D # noqa: F401 + + fig5 = plt.figure(figsize=(10, 12)) + trial_axis = np.arange(1, numRealizations + 1) + T_act = min(actStimEffect.shape[0], len(basis_time)) + T_mesh, K_mesh = np.meshgrid( + basis_time[:T_act], trial_axis, indexing="ij" + ) + + ax = fig5.add_subplot(3, 1, 1, projection="3d") + ax.plot_surface(K_mesh, T_mesh, actStimEffect[:T_act, :], + cmap="viridis", edgecolor="none") + ax.view_init(elev=-90, azim=90) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title("True Stimulus Effect", fontweight="bold", fontsize=14) + + ax = fig5.add_subplot(3, 1, 2, projection="3d") + ax.plot_surface(K_mesh, T_mesh, psthSurface2D[:T_act, :], + cmap="viridis", edgecolor="none") + ax.view_init(elev=-90, azim=90) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title("PSTH Estimated Stimulus Effect", fontweight="bold", + fontsize=14) + + ax = fig5.add_subplot(3, 1, 3, projection="3d") + ax.plot_surface(K_mesh, T_mesh, estStimEffect[:T_act, :], + cmap="viridis", edgecolor="none") + ax.view_init(elev=-90, azim=90) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title("SSGLM Estimated Stimulus Effect", fontweight="bold", + fontsize=14) + + fig5.tight_layout() + print(" Figure 5: Stimulus effect surfaces (3D mesh)") + + # ------------------------------------------------------------------ + # 6. 
Learning-trial analysis: spike rate CIs + # ------------------------------------------------------------------ + tRate, probMat, sigMat = DecodingAlgorithms.computeSpikeRateCIs( + xK, WkuFinal, dN, 0, tmax, fitType, delta, gammahat, windowTimes, + ) + + # Find first learning trial (first column where significance appears) + sig_cols = np.where(sigMat[0, :] == 1)[0] + lt = int(sig_cols[0]) if sig_cols.size > 0 else 2 + if lt < 2: + lt = 2 + + # ------------------------------------------------------------------ + # Figure 6: Learning trial comparison + significance matrix (2x3) + # ------------------------------------------------------------------ + fig6 = plt.figure(figsize=(14, 9)) + + # (1,1): average spike rate with learning trial marker + ax1 = fig6.add_subplot(2, 3, 1) + rate_time = np.asarray(tRate.time, dtype=float).ravel() + rate_data = np.asarray(tRate.data, dtype=float).ravel() + ax1.plot(rate_time, rate_data, "k", linewidth=4) + ylims = ax1.get_ylim() + ax1.plot([lt, lt], ylims, "r", linewidth=2) + ax1.set_xlabel("Trial [k]") + ax1.set_ylabel("Average Firing Rate [spikes/sec]") + ax1.set_title(f"Learning Trial: {lt}", fontweight="bold", fontsize=12) + + # (1,2)+(1,3)+(2,2)+(2,3): significance matrix + ax2 = fig6.add_subplot(2, 3, (2, 6)) + ax2.imshow(probMat, cmap="gray_r", aspect="auto") + kTrials = sigMat.shape[0] + for k in range(kTrials): + for m in range(k + 1, kTrials): + if sigMat[k, m] == 1: + ax2.plot(m, k, "r*", markersize=3) + ax2.xaxis.set_ticks_position("top") + ax2.yaxis.set_ticks_position("right") + ax2.set_xlabel("Trial Number") + ax2.set_ylabel("Trial Number") + + # (2,1): CIF comparison for trial 1 vs learning trial + ax3 = fig6.add_subplot(2, 3, 4) + stim1_data = basisMat @ stimulus_true[:, 0] + stimlt_data = basisMat @ stimulus_true[:, lt - 1] + ci1_lo = basisMat @ stimCIs[:, 0, 0] + ci1_hi = basisMat @ stimCIs[:, 0, 1] + cilt_lo = basisMat @ stimCIs[:, lt - 1, 0] + cilt_hi = basisMat @ stimCIs[:, lt - 1, 1] + + 
ax3.fill_between(basis_time, ci1_lo, ci1_hi, alpha=0.3, color="gray")
+    ax3.fill_between(basis_time, cilt_lo, cilt_hi, alpha=0.3, color="red")
+    h1, = ax3.plot(basis_time, stim1_data, "k", linewidth=4,
+                   label="\\lambda_1(t)")
+    h2, = ax3.plot(basis_time, stimlt_data, "r", linewidth=4,
+                   label=f"\\lambda_{{{lt}}}(t)")
+    ax3.legend(handles=[h1, h2])
+    ax3.set_xlabel("time [s]")
+    ax3.set_ylabel("Firing Rate [spikes/sec]")
+    ax3.set_title("Learning Trial Vs. Baseline Trial\nwith 95% CIs",
+                  fontweight="bold", fontsize=12)
+
+    fig6.tight_layout()
+    print(f"  Figure 6: Learning trial = {lt}")
+
+    figures = {
+        "fig03_ssglm_simulation_summary": fig3,
+        "fig04_ssglm_fit_diagnostics": fig4,
+        "fig05_stimulus_effect_surfaces": fig5,
+        "fig06_learning_trial_comparison": fig6,
     }
-    print(json.dumps(combined_summary, indent=2))
+    return figures
+
+
+# =====================================================================
+# Main
+# =====================================================================
+def run_example03(*, export_figures: bool = False, export_dir: Path | None = None):
+    """Run Example 03: PSTH and SSGLM dynamics."""
+    data_dir = ensure_example_data(download=True)
+
+    if export_dir is None:
+        export_dir = THIS_DIR / "figures" / "example03"
+
+    figs_a, _, _ = run_part_a(data_dir, export_dir)
+    figs_b = run_part_b(data_dir, export_dir)
+
+    all_figs = {**figs_a, **figs_b}
 
     if export_figures:
-        if export_dir is None:
-            export_dir = THIS_DIR / "figures" / "example03"
-        # Figure generation needs the combined dicts (multi-section example)
-        combined_payload = {
-            "experiment3": payload3,
-            "experiment3b": payload3b,
-        }
-        saved = export_named_paper_figures(
-            "example03",
-            summary=combined_summary,
-            payload=combined_payload,
-            export_dir=export_dir,
-        )
-        print(f"\nGenerated {len(saved)} figure(s):")
-        for p in saved:
-            print(f"  {p}")
+        export_dir.mkdir(parents=True, exist_ok=True)
+        for name, fig in all_figs.items():
+            path = export_dir / f"{name}.png"
+            fig.savefig(str(path), dpi=150, bbox_inches="tight")
+            print(f"  Saved {path}")
 
-    return combined_summary
+    plt.show()
+    print(f"\nExample 03 complete. Generated {len(all_figs)} figure(s).")
+    return all_figs
 
 
 if __name__ == "__main__":
@@ -103,12 +641,9 @@ def run_example03(*, export_figures: bool = False, export_dir: Path | None = Non
     parser.add_argument("--repo-root", type=Path, default=REPO_ROOT)
     parser.add_argument("--export-figures", action="store_true")
     parser.add_argument("--export-dir", type=Path, default=None)
-    parser.add_argument("--output-json", type=Path, default=None)
     args = parser.parse_args()
 
-    result = run_example03(
+    run_example03(
         export_figures=args.export_figures,
         export_dir=args.export_dir,
     )
-    if args.output_json:
-        args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8")
diff --git a/examples/paper/example04_place_cells_continuous_stimulus.py b/examples/paper/example04_place_cells_continuous_stimulus.py
index f5a3cae0..96013b6d 100644
--- a/examples/paper/example04_place_cells_continuous_stimulus.py
+++ b/examples/paper/example04_place_cells_continuous_stimulus.py
@@ -4,17 +4,18 @@
 This example demonstrates:
     1) Loading hippocampal place-cell data from two animals.
     2) Visualising spike locations overlaid on the animal's path.
-    3) Fitting Gaussian and Zernike polynomial receptive-field models.
-    4) Comparing model families via KS, AIC, and BIC statistics.
-    5) Generating 2-D heatmaps and 3-D mesh plots of place fields.
+    3) Loading precomputed Gaussian and Zernike polynomial receptive-field fits.
+    4) Comparing model families via KS, AIC, and BIC statistics using FitSummary.
+    5) Generating 2-D heatmaps of place fields for all neurons.
+    6) Generating 3-D mesh comparison for selected example cells.
 
 Data provenance:
-    Uses ``data/PlaceCellDataAnimal1.mat`` and ``data/PlaceCellDataAnimal2.mat``
-    (position trajectories + multi-neuron spike times).
+ Uses ``data/PlaceCellDataAnimal{1,2}.mat`` (trajectories + spike times) + and ``PlaceCellAnimal{1,2}Results.mat`` (precomputed FitResult structures). Expected outputs: - Figure 1: Example cells — spike locations over path (4 cells per animal). - - Figure 2: Population model-comparison statistics (Delta-KS, Delta-AIC, Delta-BIC). + - Figure 2: Population model-comparison statistics (dKS, dAIC, dBIC). - Figure 3: Gaussian receptive-field heatmaps (Animal 1). - Figure 4: Zernike receptive-field heatmaps (Animal 1). - Figure 5: Gaussian receptive-field heatmaps (Animal 2). @@ -22,55 +23,408 @@ - Figure 7: 3-D mesh comparison for selected example cells. Paper mapping: - Section 2.3.4 (place-cell continuous-stimulus analysis). + Section 2.3.5 (place-cell continuous-stimulus analysis). """ from __future__ import annotations import argparse -import json +import math import sys from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +from scipy.io import loadmat + THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +from nstat import ( # noqa: E402 + Covariate, + FitResult, + FitSummary, + TrialConfig, + ConfigCollection, +) +from nstat.core import nspikeTrain # noqa: E402 from nstat.data_manager import ensure_example_data # noqa: E402 -from nstat.paper_examples_full import run_experiment4 # noqa: E402 -from nstat.paper_figures import export_named_paper_figures # noqa: E402 +from nstat.zernike import zernike_basis_from_cartesian # noqa: E402 + + +# ===================================================================== +# Helpers +# ===================================================================== +def _load_animal_data(path): + """Load place-cell trajectory and spike data from a .mat file.""" + d = loadmat(str(path), squeeze_me=True) + x = np.asarray(d["x"], dtype=float).ravel() + y = np.asarray(d["y"], dtype=float).ravel() + time = np.asarray(d["time"], 
dtype=float).ravel() + neurons = np.asarray(d["neuron"], dtype=object).ravel() + return x, y, time, neurons + + +def _load_animal_results(path, x, y, time, neurons): + """Load precomputed FitResult structures and reconstruct Python FitResults.""" + d = loadmat(str(path), squeeze_me=True) + res_structs = np.asarray(d["resStruct"], dtype=object).ravel() + fit_results = [] + + for i, rs in enumerate(res_structs): + # Extract lambda signal + lam = rs["lambda"].item() + lam_time = np.asarray(lam["time"].item(), dtype=float).ravel() + lam_data = np.asarray(lam["data"].item(), dtype=float) + lam_name = str(lam["name"].item()) if lam["name"].size else "\\lambda" + + lambda_cov = Covariate( + lam_time, lam_data, lam_name, "time", "s", "spikes/sec", + ) + + # Extract coefficients + b_raw = rs["b"].item() + if b_raw.dtype == object: + b_list = [np.asarray(b_raw[j], dtype=float).ravel() for j in range(b_raw.size)] + else: + b_list = [np.asarray(b_raw, dtype=float).ravel()] + + numResults = int(np.asarray(rs["numResults"].item()).ravel()[0]) + + # Extract AIC/BIC/logLL + AIC = np.asarray(rs["AIC"].item(), dtype=float).ravel() + BIC = np.asarray(rs["BIC"].item(), dtype=float).ravel() + logLL = np.asarray(rs["logLL"].item(), dtype=float).ravel() + + # Config names + cn_raw = rs["configNames"].item() + if isinstance(cn_raw, np.ndarray): + config_names = [str(c) for c in cn_raw.ravel()] + else: + config_names = [str(cn_raw)] + + # Covariate labels + if "covLabels" in rs.dtype.names: + cl_raw = rs["covLabels"].item() + cl = cl_raw + else: + cl = [] + if isinstance(cl, np.ndarray) and cl.dtype == object: + cov_labels = [] + for j in range(cl.size): + item = cl[j] + if isinstance(item, np.ndarray): + cov_labels.append([str(x) for x in item.ravel()]) + else: + cov_labels.append([str(item)]) + elif isinstance(cl, str): + cov_labels = [[cl]] * numResults + else: + cov_labels = [config_names] * numResults + + # Create spike train for this neuron + st = 
np.asarray(neurons[i]["spikeTimes"].item(), dtype=float).ravel()
+        nst = nspikeTrain(st, name=str(i + 1),
+                          minTime=float(time[0]), maxTime=float(time[-1]),
+                          makePlots=-1)
+        # numHist
+        if "numHist" in rs.dtype.names:
+            nh = rs["numHist"].item()
+            num_hist = list(np.asarray(nh, dtype=int).ravel())
+        else:
+            num_hist = [0] * numResults
+        cfgs = ConfigCollection([TrialConfig(name=n) for n in config_names])
+
+        fr = FitResult(
+            nst,
+            cov_labels,
+            num_hist,
+            [],                    # histObjects
+            [],                    # ensHistObjects
+            lambda_cov,
+            b_list,
+            [0.0] * numResults,    # dev
+            [None] * numResults,   # stats
+            AIC,
+            BIC,
+            logLL,
+            cfgs,
+            [],                    # XvalData
+            [],                    # XvalTime
+            "poisson",
+        )
+
+        # Load KS statistics if available
+        if "KSStats" in rs.dtype.names:
+            ks_struct = rs["KSStats"].item()
+            if hasattr(ks_struct, "dtype") and ks_struct.dtype.names:
+                ks_stat = np.asarray(ks_struct["ks_stat"].item(), dtype=float).ravel()
+                pval = np.asarray(ks_struct["pValue"].item(), dtype=float).ravel()
+                within = np.asarray(ks_struct["withinConfInt"].item(), dtype=float).ravel()
+                if ks_stat.size >= numResults:
+                    fr.KSStats = ks_stat[:numResults].reshape(numResults, 1)
+                    fr.KSPvalues = pval[:numResults]
+                    fr.withinConfInt = within[:numResults]
+
+        fit_results.append(fr)
+
+    return fit_results
+
+
+def _compute_place_field(coeffs, grid_design, grid_shape):
+    """Compute predicted firing rate on a spatial grid."""
+    eta = grid_design @ coeffs
+    rate = np.exp(eta)
+    return rate.reshape(grid_shape)
+
+
+# =====================================================================
+# Main example
+# =====================================================================
 def run_example04(*, export_figures: bool = False, export_dir: Path | None = None):
-    """Run Example 04: Place-cell receptive fields.
-
-    Analysis workflow (mirrors Matlab example04_place_cells_continuous_stimulus.m):
-    1. Load PlaceCellDataAnimal1.mat and PlaceCellDataAnimal2.mat.
-    2. For each animal, visualise 4 example neurons (spike locations on path).
-    3. Load or compute precomputed fit results for all neurons.
-    4. Compute per-neuron Delta-KS, Delta-AIC, Delta-BIC (Gaussian vs Zernike).
-    5. Generate Gaussian receptive-field heatmaps for all neurons (both animals).
-    6. Generate Zernike polynomial receptive-field heatmaps.
-    7. Generate 3-D mesh comparison for selected example cells.
-    """
+    """Run Example 04: Place-cell receptive fields."""
+    print("=== Example 04: Place-Cell Receptive Fields ===")
+
     data_dir = ensure_example_data(download=True)
+    if export_dir is None:
+        export_dir = THIS_DIR / "figures" / "example04"
+
+    # ==================================================================
+    # 1. Load data for both animals
+    # ==================================================================
+    x1, y1, t1, neurons1 = _load_animal_data(
+        data_dir / "Place Cells" / "PlaceCellDataAnimal1.mat")
+    x2, y2, t2, neurons2 = _load_animal_data(
+        data_dir / "Place Cells" / "PlaceCellDataAnimal2.mat")
+    nCells1 = len(neurons1)
+    nCells2 = len(neurons2)
+    print(f"  Animal 1: {nCells1} cells, {len(t1)} time points")
+    print(f"  Animal 2: {nCells2} cells, {len(t2)} time points")
+
+    # ==================================================================
+    # 2. Load precomputed FitResults
+    # ==================================================================
+    fitResults1 = _load_animal_results(
+        data_dir / "PlaceCellAnimal1Results.mat", x1, y1, t1, neurons1)
+    fitResults2 = _load_animal_results(
+        data_dir / "PlaceCellAnimal2Results.mat", x2, y2, t2, neurons2)
+    print(f"  Loaded {len(fitResults1)} + {len(fitResults2)} FitResult objects")
+
+    # ==================================================================
+    # 3.
Build FitSummary for each animal + # ================================================================== + summary1 = FitSummary(fitResults1) + summary2 = FitSummary(fitResults2) - # Run analysis (returns summary statistics and figure payload) - summary, payload = run_experiment4(data_dir, return_payload=True) + # Delta statistics: Gaussian (index 0) minus Zernike (index 1) + dAIC1 = summary1.AIC[:, 0] - summary1.AIC[:, 1] + dBIC1 = summary1.BIC[:, 0] - summary1.BIC[:, 1] + dKS1 = summary1.KSStats[:, 0] - summary1.KSStats[:, 1] - print(json.dumps(summary, indent=2)) + dAIC2 = summary2.AIC[:, 0] - summary2.AIC[:, 1] + dBIC2 = summary2.BIC[:, 0] - summary2.BIC[:, 1] + dKS2 = summary2.KSStats[:, 0] - summary2.KSStats[:, 1] + + dAIC_all = np.concatenate([dAIC1, dAIC2]) + dBIC_all = np.concatenate([dBIC1, dBIC2]) + dKS_all = np.concatenate([dKS1, dKS2]) + + print(f" Mean dAIC (Gauss-Zern): {np.nanmean(dAIC_all):.2f}") + print(f" Mean dBIC (Gauss-Zern): {np.nanmean(dBIC_all):.2f}") + print(f" Mean dKS (Gauss-Zern): {np.nanmean(dKS_all):.4f}") + + # ================================================================== + # Figure 1: Example cells — spike locations over path (2x2) + # ================================================================== + exampleCells = [1, 20, 24, 48] # 0-indexed + fig1, axes1 = plt.subplots(2, 2, figsize=(12, 10)) + for i, cidx in enumerate(exampleCells): + ax = axes1.flat[i] + ax.plot(x1, y1, "b-", linewidth=0.5, alpha=0.5) + n = neurons1[min(cidx, nCells1 - 1)] + xn = np.asarray(n["xN"].item(), dtype=float).ravel() + yn = np.asarray(n["yN"].item(), dtype=float).ravel() + ax.plot(xn, yn, "r.", markersize=7) + ax.set_title(f"Cell {cidx + 1}", fontweight="bold", fontsize=12) + ax.set_aspect("equal") + fig1.suptitle("Animal 1 — Example Place Cells", fontweight="bold", + fontsize=14) + fig1.tight_layout() + + # ================================================================== + # Figure 2: Population statistics (1x3 box plots) + # 
================================================================== + fig2, axes2 = plt.subplots(1, 3, figsize=(14, 5)) + + axes2[0].boxplot([dKS1[np.isfinite(dKS1)], dKS2[np.isfinite(dKS2)]], + tick_labels=["Animal 1", "Animal 2"]) + axes2[0].set_ylabel(r"$\Delta$KS (Gaussian - Zernike)") + axes2[0].set_title("KS Statistic Difference") + axes2[0].axhline(0, color="gray", linestyle="--", linewidth=0.5) + + axes2[1].boxplot([dAIC1[np.isfinite(dAIC1)], dAIC2[np.isfinite(dAIC2)]], + tick_labels=["Animal 1", "Animal 2"]) + axes2[1].set_ylabel(r"$\Delta$AIC (Gaussian - Zernike)") + axes2[1].set_title("AIC Difference") + axes2[1].axhline(0, color="gray", linestyle="--", linewidth=0.5) + + axes2[2].boxplot([dBIC1[np.isfinite(dBIC1)], dBIC2[np.isfinite(dBIC2)]], + tick_labels=["Animal 1", "Animal 2"]) + axes2[2].set_ylabel(r"$\Delta$BIC (Gaussian - Zernike)") + axes2[2].set_title("BIC Difference") + axes2[2].axhline(0, color="gray", linestyle="--", linewidth=0.5) + + fig2.tight_layout() + + # ================================================================== + # 4. 
Build spatial grids and design matrices for heatmaps + # ================================================================== + grid_res = 201 # Matlab: meshgrid(-1:0.01:1) → 201 points + xGrid = np.linspace(-1, 1, grid_res) + yGrid = np.linspace(-1, 1, grid_res) + xx, yy = np.meshgrid(xGrid, yGrid) + yy = np.flipud(yy) # Matlab: y increases bottom-to-top + xx = np.fliplr(xx) # Matlab: x increases right-to-left + xf, yf = xx.ravel(), yy.ravel() + + # Gaussian design: [1, x, y, x^2, y^2, xy] (intercept prepended) + gridDesignGauss = np.column_stack([ + np.ones(xf.size), xf, yf, xf**2, yf**2, xf * yf + ]) + + # Zernike design: [1, z1, z2, ..., z9] (intercept prepended) + zBasis = zernike_basis_from_cartesian(xf, yf, fill_value=0.0) + gridDesignZern = np.column_stack([np.ones(xf.size), zBasis]) + + # ================================================================== + # Figures 3-6: Place field heatmaps + # ================================================================== + def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss, + design_zern, grid_shape): + nRows = math.ceil(nCells / 7) + nCols = 7 + + figG, axesG = plt.subplots(nRows, nCols, figsize=(14, 2 * nRows)) + figZ, axesZ = plt.subplots(nRows, nCols, figsize=(14, 2 * nRows)) + if nRows == 1: + axesG = axesG[np.newaxis, :] + axesZ = axesZ[np.newaxis, :] + + for i in range(nCells): + row, col = divmod(i, nCols) + fr = fit_results[i] + coeffs_g = np.asarray(fr.b[0], dtype=float).ravel() + coeffs_z = np.asarray(fr.b[1], dtype=float).ravel() if fr.numResults > 1 else coeffs_g + + # Gaussian field + ax = axesG[row, col] + try: + field_g = _compute_place_field(coeffs_g, design_gauss[:, :coeffs_g.size], grid_shape) + ax.pcolormesh(xx, yy, field_g, shading="gouraud", cmap="jet") + except Exception: + pass + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(f"{i + 1}", fontsize=8) + + # Zernike field + ax = axesZ[row, col] + try: + field_z = _compute_place_field(coeffs_z, 
design_zern[:, :coeffs_z.size], grid_shape) + ax.pcolormesh(xx, yy, field_z, shading="gouraud", cmap="jet") + except Exception: + pass + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(f"{i + 1}", fontsize=8) + + # Hide unused subplots + for i in range(nCells, nRows * nCols): + row, col = divmod(i, nCols) + axesG[row, col].set_visible(False) + axesZ[row, col].set_visible(False) + + figG.suptitle(f"{title_prefix} — Gaussian Place Fields", + fontweight="bold", fontsize=14) + figZ.suptitle(f"{title_prefix} — Zernike Place Fields", + fontweight="bold", fontsize=14) + figG.tight_layout() + figZ.tight_layout() + return figG, figZ + + figG1, figZ1 = _plot_heatmaps( + fitResults1, nCells1, "Animal 1", + gridDesignGauss, gridDesignZern, xx.shape, + ) + figG2, figZ2 = _plot_heatmaps( + fitResults2, nCells2, "Animal 2", + gridDesignGauss, gridDesignZern, xx.shape, + ) + print(" Figures 3-6: Place field heatmaps") + + # ================================================================== + # Figure 7: 3-D mesh comparison for an example cell + # ================================================================== + exampleCell = min(24, nCells1 - 1) # 0-indexed → cell 25 in Matlab + fr_ex = fitResults1[exampleCell] + coeffs_g = np.asarray(fr_ex.b[0], dtype=float).ravel() + coeffs_z = np.asarray(fr_ex.b[1], dtype=float).ravel() + + field_g = _compute_place_field( + coeffs_g, gridDesignGauss[:, :coeffs_g.size], xx.shape) + field_z = _compute_place_field( + coeffs_z, gridDesignZern[:, :coeffs_z.size], xx.shape) + + fig7 = plt.figure(figsize=(12, 8)) + ax3d = fig7.add_subplot(111, projection="3d") + ax3d.plot_surface(xx, yy, field_g, alpha=0.3, color="blue", + label="Gaussian") + ax3d.plot_surface(xx, yy, field_z, alpha=0.3, color="green", + label="Zernike") + # Overlay animal path at z=0 + ax3d.plot(x1, y1, np.zeros_like(x1), "b-", linewidth=0.3, alpha=0.3) + # Overlay spike locations + n_ex = neurons1[exampleCell] + xn_ex = np.asarray(n_ex["xN"].item(), 
dtype=float).ravel() + yn_ex = np.asarray(n_ex["yN"].item(), dtype=float).ravel() + ax3d.scatter(xn_ex, yn_ex, np.zeros_like(xn_ex), c="r", s=5, + alpha=0.5) + ax3d.set_xlabel("x") + ax3d.set_ylabel("y") + ax3d.set_zlabel("Firing Rate") + ax3d.set_title(f"Cell {exampleCell + 1}: Gaussian (blue) vs Zernike (green)", + fontweight="bold", fontsize=14) + + print(f" Figure 7: 3D mesh for cell {exampleCell + 1}") + + # ================================================================== + # Save figures + # ================================================================== + all_figs = { + "fig01_example_cells_path_overlay": fig1, + "fig02_model_summary_statistics": fig2, + "fig03_gaussian_place_fields_animal1": figG1, + "fig04_zernike_place_fields_animal1": figZ1, + "fig05_gaussian_place_fields_animal2": figG2, + "fig06_zernike_place_fields_animal2": figZ2, + "fig07_example_cell_mesh_comparison": fig7, + } if export_figures: - if export_dir is None: - export_dir = THIS_DIR / "figures" / "example04" - saved = export_named_paper_figures( - "example04", summary=summary, payload=payload, export_dir=export_dir - ) - print(f"\nGenerated {len(saved)} figure(s):") - for p in saved: - print(f" {p}") + export_dir.mkdir(parents=True, exist_ok=True) + for name, fig in all_figs.items(): + path = export_dir / f"{name}.png" + fig.savefig(str(path), dpi=150, bbox_inches="tight") + print(f" Saved {path}") - return summary + plt.show() + print(f"\nExample 04 complete. 
Generated {len(all_figs)} figure(s).") + return all_figs if __name__ == "__main__": @@ -80,12 +434,9 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non parser.add_argument("--repo-root", type=Path, default=REPO_ROOT) parser.add_argument("--export-figures", action="store_true") parser.add_argument("--export-dir", type=Path, default=None) - parser.add_argument("--output-json", type=Path, default=None) args = parser.parse_args() - result = run_example04( + run_example04( export_figures=args.export_figures, export_dir=args.export_dir, ) - if args.output_json: - args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 60a37fad..89e41a67 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -1,31 +1,37 @@ #!/usr/bin/env python3 """Example 05 — Stimulus Decoding With PPAF and PPHF. -This example demonstrates: - 1) Univariate sinusoidal stimulus encoding and decoding via PPDecodeFilterLinear. - 2) 4-state arm-reach simulation with 20-cell population encoding. - 3) PPAF (Point-Process Adaptive Filter) decoding: free vs goal-informed. - 4) Hybrid filter (PPHybridFilterLinear) for joint discrete/continuous states. +This example demonstrates neural decoding using point-process adaptive filters +(PPAF) and point-process hybrid filters (PPHF) from the nSTAT toolbox. The example has three parts: - Part A (experiment5): Univariate sinusoidal stimulus — encode with 20 - neurons, decode with PPDecodeFilterLinear. - Part B (experiment5b): 4-state arm reaching — simulate 20-cell population, - compare PPAF vs PPAF+Goal across 20 simulations. - Part C (experiment6): Hybrid filter — simulate 40-cell population with - discrete reach states and continuous kinematics, decode with - PPHybridFilterLinear. 
-Expected outputs:
-  - Figure 1: Univariate stimulus setup (CIF tuning curves, simulated spikes).
-  - Figure 2: Univariate decoding results (decoded stimulus vs true).
-  - Figure 3: Reach setup and population encoding.
-  - Figure 4: PPAF comparison (free vs goal-informed).
-  - Figure 5: Hybrid filter setup.
-  - Figure 6: Hybrid decoding summary.
+Part A — Univariate Sinusoidal Stimulus (Figures 1–2):
+  1. Define 20-cell population with logistic (binomial) tuning to a 1-D
+     sinusoidal stimulus.
+  2. Simulate spike observations from the binomial CIF.
+  3. Decode the stimulus using ``PPDecodeFilterLinear`` (PPAF).
+
+Part B — 4-State Arm Reach with PPAF (Figures 3–4):
+  4. Simulate reaching trajectories (position + velocity, 4-D state).
+  5. Encode with 20-cell cosine-tuning population (binomial CIF).
+  6. Decode with PPAF (free) and PPAF + Goal; compare across 20 simulations.
+
+Part C — Hybrid Filter (Figures 5–6):
+  7. Simulate 40-cell population with 2 discrete behavioral states
+     (reach / hold) that modulate baseline firing rate, plus velocity-tuned
+     continuous state.
+  8. Decode joint discrete + continuous state via ``PPHybridFilterLinear``.
 
 Paper mapping: Section 2.5 (point-process adaptive filter) and Section 2.6
 (hybrid filter).
+
+Expected outputs:
+  - Figure 1: CIF tuning curves and simulated spike raster.
+  - Figure 2: Decoded stimulus vs true (with ±2σ confidence band).
+  - Figure 3: Reach trajectory and population spike raster.
+  - Figure 4: PPAF comparison (free vs goal-informed, 20 runs box plot).
+  - Figure 5: Hybrid filter setup (state sequence, spike raster).
+  - Figure 6: Hybrid decoding results (state probabilities, decoded kinematics).
""" from __future__ import annotations @@ -34,27 +40,490 @@ import sys from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np + THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from nstat.paper_examples_full import ( # noqa: E402 - run_experiment5, - run_experiment5b, - run_experiment6, -) -from nstat.paper_figures import export_named_paper_figures # noqa: E402 +from nstat import DecodingAlgorithms # noqa: E402 + + +# ────────────────────────────────────────────────────────────────────────────── +# Helper: simulate binomial spikes from linear-logistic CIF +# ────────────────────────────────────────────────────────────────────────────── + + +def _simulate_binomial_spikes(x, mu, beta, rng): + """Simulate spikes from binomial CIF: p_c = sigmoid(mu_c + beta_c @ x). + + Parameters + ---------- + x : (ns, T) array — stimulus/state trajectory + mu : (C,) array — baseline log-odds per cell + beta : (ns, C) array — tuning coefficients + rng : numpy Generator + + Returns + ------- + dN : (C, T) array — binary spike indicators + """ + ns, T = x.shape + C = mu.size + dN = np.zeros((C, T), dtype=float) + for t in range(T): + eta = mu + beta.T @ x[:, t] # (C,) + p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0))) + dN[:, t] = (rng.random(C) < p).astype(float) + return dN + + +# ────────────────────────────────────────────────────────────────────────────── +# Part A — Univariate sinusoidal stimulus +# ────────────────────────────────────────────────────────────────────────────── + + +def _run_part_a(seed=11, n_cells=20): + """Encode/decode a 1-D sinusoidal stimulus with 20-cell binomial CIF.""" + rng = np.random.default_rng(seed) + delta = 0.001 # 1 ms bins + time = np.arange(0.0, 1.0 + delta, delta) + T = len(time) + + # True stimulus: sinusoidal + x_true = np.sin(2.0 * np.pi * 2.0 * time) # (T,) + + # ── Encoding model: logistic CIF ── + # mu_c ~ 
log(10*delta) + N(0, 0.3) — baseline firing probability + # beta_c ~ N(1.0, 0.5) — stimulus gain + b0 = np.log(10.0 * delta) * np.ones(n_cells) + rng.normal(0.0, 0.3, n_cells) + b1 = rng.normal(1.0, 0.5, n_cells) + + # Simulate spikes + x_2d = x_true.reshape(1, -1) # (1, T) — scalar state + beta = b1.reshape(1, -1) # (1, C) — stimulus coefficients + dN = _simulate_binomial_spikes(x_2d, b0, beta, rng) + + # ── State-space model ── + # x(t+1) = A * x(t) + w, w ~ N(0, Q) + A = np.array([[1.0]]) + Q = np.array([[0.001]]) + x0 = np.array([0.0]) + Pi0 = 0.5 * np.eye(1) + + # ── Decode with PPDecodeFilterLinear ── + # dN is (C, T) — the API expects (num_cells, num_steps) + x_p, W_p, x_u, W_u, _, _, _, _ = DecodingAlgorithms.PPDecodeFilterLinear( + A, Q, dN, b0, beta, "binomial", delta, None, None, x0, Pi0 + ) + + # Extract decoded signal and ±2σ confidence band + x_decoded = x_u[0, :] # (T,) + sigma = np.sqrt(np.maximum(W_u[0, 0, :], 0.0)) + ci_low = x_decoded - 2.0 * sigma + ci_high = x_decoded + 2.0 * sigma + rmse = float(np.sqrt(np.mean((x_decoded - x_true) ** 2))) + + return { + "time": time, + "x_true": x_true, + "x_decoded": x_decoded, + "ci_low": ci_low, + "ci_high": ci_high, + "dN": dN, + "b0": b0, + "b1": b1, + "rmse": rmse, + "n_cells": n_cells, + } + + +# ────────────────────────────────────────────────────────────────────────────── +# Part B — 4-state arm reach with PPAF +# ────────────────────────────────────────────────────────────────────────────── + + +def _simulate_reach(delta, T_total, rng): + """Simulate a 2-D reaching trajectory with 4-D state [x, y, vx, vy]. + + Uses a simple sinusoidal trajectory to mimic a reaching task. 
+ """ + time = np.arange(0.0, T_total + delta, delta) + T = len(time) + + # Smooth trajectory + x_pos = 0.25 * np.sin(2.0 * np.pi * 0.15 * time) + y_pos = 0.20 * np.cos(2.0 * np.pi * 0.10 * time) + vx = np.gradient(x_pos, delta) + vy = np.gradient(y_pos, delta) + + state = np.vstack([x_pos, y_pos, vx, vy]) # (4, T) + return time, state + + +def _run_part_b(seed=19, n_cells=20, n_sims=20): + """Compare PPAF free vs goal-directed decoding for arm reach.""" + rng = np.random.default_rng(seed) + delta = 0.01 # 10 ms bins + ns = 4 # state dimension + + # State-space model (constant-velocity kinematic model) + A = np.array([ + [1, 0, delta, 0], + [0, 1, 0, delta], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], dtype=float) + Q = 0.001 * np.eye(ns, dtype=float) + + # Encoding model: cosine tuning to velocity + # mu_c ~ N(-3.0, 0.2) + # beta_c = [0, 0, w_vx, w_vy] — velocity tuned + b0 = rng.normal(-3.0, 0.2, n_cells) + beta = np.zeros((ns, n_cells), dtype=float) + for c in range(n_cells): + beta[2, c] = 3.0 * rng.normal(0.0, 1.0) # vx weight + beta[3, c] = 3.0 * rng.normal(0.0, 1.0) # vy weight + + # Run multiple simulations to compare free vs goal-directed + rmse_free = np.zeros((n_sims, ns), dtype=float) + rmse_goal = np.zeros((n_sims, ns), dtype=float) + + # Store one example run for plotting + example_run = None + + for sim_idx in range(n_sims): + sim_rng = np.random.default_rng(seed + sim_idx + 100) + time, state = _simulate_reach(delta, 10.0, sim_rng) + T = state.shape[1] + + # Simulate spikes + dN = _simulate_binomial_spikes(state, b0, beta, sim_rng) + + # Initial conditions + x0 = state[:, 0] + Pi0 = 0.1 * np.eye(ns) + + # --- Free decode (no goal) --- + x_p_free, _, x_u_free, W_u_free, _, _, _, _ = ( + DecodingAlgorithms.PPDecodeFilterLinear( + A, Q, dN, b0, beta, "binomial", delta, + None, None, x0, Pi0 + ) + ) + + # --- Goal-directed decode --- + yT = state[:, -1] # target = final state + PiT = 0.01 * np.eye(ns) # tight target uncertainty + x_p_goal, _, x_u_goal, 
W_u_goal, _, _, _, _ = ( + DecodingAlgorithms.PPDecodeFilterLinear( + A, Q, dN, b0, beta, "binomial", delta, + None, None, x0, Pi0, yT, PiT, 0 + ) + ) + + # Compute RMSE per state dimension + for d in range(ns): + rmse_free[sim_idx, d] = np.sqrt(np.mean((x_u_free[d, :] - state[d, :]) ** 2)) + rmse_goal[sim_idx, d] = np.sqrt(np.mean((x_u_goal[d, :] - state[d, :]) ** 2)) + + if sim_idx == 0: + example_run = { + "time": time, + "state": state, + "dN": dN, + "x_u_free": x_u_free, + "x_u_goal": x_u_goal, + "W_u_free": W_u_free, + "W_u_goal": W_u_goal, + } + + return { + "rmse_free": rmse_free, + "rmse_goal": rmse_goal, + "example": example_run, + "n_cells": n_cells, + "n_sims": n_sims, + "state_labels": ["x", "y", "vx", "vy"], + } + + +# ────────────────────────────────────────────────────────────────────────────── +# Part C — Hybrid filter +# ────────────────────────────────────────────────────────────────────────────── + +def _run_part_c(seed=37, n_cells=40): + """PPHybridFilterLinear: joint discrete/continuous state decoding.""" + rng = np.random.default_rng(seed) + delta = 0.01 # 10 ms bins + ns = 4 # continuous state dimension (x, y, vx, vy) -def run_example05(*, export_figures: bool = False, export_dir: Path | None = None): + # ── Simulate trajectory ── + time = np.arange(0.0, 10.0, delta, dtype=float) + T = len(time) + x_pos = 0.3 * np.sin(2.0 * np.pi * 0.15 * time) + y_pos = 0.25 * np.cos(2.0 * np.pi * 0.10 * time) + vx = np.gradient(x_pos, delta) + vy = np.gradient(y_pos, delta) + state = np.vstack([x_pos, y_pos, vx, vy]) # (4, T) + + # Discrete state: alternating reach / hold (period ~6s) + true_mode = np.where(np.sin(2.0 * np.pi * time / 6.0) > 0.0, 1, 2).astype(int) + # Add stochastic flips + flip = rng.random(T) < 0.01 + true_mode[flip] = 3 - true_mode[flip] + + # ── State-space models (one per mode) ── + A_reach = np.array([ + [1, 0, delta, 0], + [0, 1, 0, delta], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], dtype=float) + Q_reach = 0.001 * np.eye(ns) + + # Hold 
state: damped velocity + A_hold = np.array([ + [1, 0, delta, 0], + [0, 1, 0, delta], + [0, 0, 0.95, 0], + [0, 0, 0, 0.95], + ], dtype=float) + Q_hold = 0.0005 * np.eye(ns) + + # ── Encoding model ── + # Neurons tuned to ALL state dimensions (position + velocity). + # Mode-dependent baseline: mode 1 (reach) has different rate than mode 2 (hold). + b0_mode1 = rng.normal(-3.5, 0.2, n_cells) # reach baseline + b0_mode2 = rng.normal(-2.5, 0.2, n_cells) # hold baseline + + # Full state tuning: position + velocity + beta_mat = np.zeros((ns, n_cells), dtype=float) + beta_mat[0, :] = rng.normal(0.0, 2.0, n_cells) # x position + beta_mat[1, :] = rng.normal(0.0, 2.0, n_cells) # y position + beta_mat[2, :] = rng.normal(0.0, 3.0, n_cells) # vx + beta_mat[3, :] = rng.normal(0.0, 3.0, n_cells) # vy + + # Simulate spikes with mode-dependent baseline (binomial) + dN = np.zeros((n_cells, T), dtype=float) + for t in range(T): + b0 = b0_mode1 if true_mode[t] == 1 else b0_mode2 + eta = b0 + beta_mat.T @ state[:, t] + p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0))) + dN[:, t] = (rng.random(n_cells) < p).astype(float) + + # ── Transition matrix ── + p_ij = np.array([[0.985, 0.015], [0.02, 0.98]], dtype=float) + + # ── Decode with PPHybridFilterLinear ── + Mu0 = np.array([0.5, 0.5]) + x0 = [state[:, 0], state[:, 0]] + Pi0 = [0.5 * np.eye(ns), 0.5 * np.eye(ns)] + + S_est, X_est, W_est, MU_u, _, _, _ = DecodingAlgorithms.PPHybridFilterLinear( + [A_reach, A_hold], + [Q_reach, Q_hold], + p_ij, + Mu0, + dN, + [b0_mode1, b0_mode2], + [beta_mat, beta_mat], + "binomial", + delta, + None, # gamma + None, # windowTimes + x0, + Pi0, + ) + + # Classification accuracy + state_acc = float(np.mean(S_est == true_mode)) + + # Position RMSE + rmse_x = float(np.sqrt(np.mean((X_est[0, :] - x_pos) ** 2))) + rmse_y = float(np.sqrt(np.mean((X_est[1, :] - y_pos) ** 2))) + + return { + "time": time, + "state": state, + "true_mode": true_mode, + "S_est": S_est, + "X_est": X_est, + "MU_u": MU_u, + "dN": dN, + 
"state_acc": state_acc, + "rmse_x": rmse_x, + "rmse_y": rmse_y, + "n_cells": n_cells, + } + + +# ────────────────────────────────────────────────────────────────────────────── +# Plotting +# ────────────────────────────────────────────────────────────────────────────── + + +def _plot_part_a(result): + """Figure 1: CIF setup & raster. Figure 2: Decoded vs true stimulus.""" + time = result["time"] + x_true = result["x_true"] + dN = result["dN"] + + # ── Figure 1: CIF tuning and spike raster ── + fig1, axes1 = plt.subplots(2, 1, figsize=(10, 6), sharex=True) + + # Top: true stimulus + axes1[0].plot(time, x_true, "k-", linewidth=1.5) + axes1[0].set_ylabel("Stimulus x(t)") + axes1[0].set_title("Part A: Sinusoidal Stimulus Encoding") + + # Bottom: spike raster + n_cells = dN.shape[0] + for c in range(n_cells): + spike_times = time[dN[c, :] > 0] + axes1[1].plot(spike_times, np.full_like(spike_times, c + 1), "|", color="k", markersize=2) + axes1[1].set_ylabel("Neuron") + axes1[1].set_xlabel("Time (s)") + axes1[1].set_ylim(0.5, n_cells + 0.5) + fig1.tight_layout() + + # ── Figure 2: Decoding results ── + fig2, ax2 = plt.subplots(1, 1, figsize=(10, 4)) + ax2.plot(time, x_true, "k-", linewidth=1.5, label="True stimulus") + ax2.plot(time, result["x_decoded"], "r-", linewidth=1.0, label="PPAF decoded") + ax2.fill_between( + time, result["ci_low"], result["ci_high"], + color="red", alpha=0.15, label="±2σ CI" + ) + ax2.set_xlabel("Time (s)") + ax2.set_ylabel("x(t)") + ax2.set_title(f"PPDecodeFilterLinear — Decoded Stimulus (RMSE = {result['rmse']:.4f})") + ax2.legend(loc="upper right") + fig2.tight_layout() + + return fig1, fig2 + + +def _plot_part_b(result): + """Figure 3: Reach trajectory & encoding. 
Figure 4: RMSE comparison."""
+    ex = result["example"]
+    time = ex["time"]
+    state = ex["state"]
+
+    # ── Figure 3: Example reach with decoded trajectories ──
+    fig3, axes3 = plt.subplots(2, 2, figsize=(12, 8))
+    labels = result["state_labels"]
+    ylabels = ["x (m)", "y (m)", "vx (m/s)", "vy (m/s)"]
+
+    for d, (ax, lab, ylab) in enumerate(zip(axes3.ravel(), labels, ylabels)):
+        ax.plot(time, state[d, :], "k-", linewidth=1.0, label="True")
+        ax.plot(time, ex["x_u_free"][d, :], "b-", linewidth=0.7, alpha=0.8, label="PPAF free")
+        ax.plot(time, ex["x_u_goal"][d, :], "r-", linewidth=0.7, alpha=0.8, label="PPAF+Goal")
+        ax.set_ylabel(ylab)
+        ax.set_title(f"State: {lab}")
+        if d >= 2:
+            ax.set_xlabel("Time (s)")
+        if d == 0:
+            ax.legend(loc="upper right", fontsize=8)
+
+    fig3.suptitle("Part B: Arm Reach — PPAF Decoding (Example Run)", fontsize=12)
+    fig3.tight_layout()
+
+    # ── Figure 4: RMSE box plot (free vs goal) ──
+    fig4, axes4 = plt.subplots(1, 4, figsize=(14, 4))
+    for d, (ax, lab) in enumerate(zip(axes4, labels)):
+        data = [result["rmse_free"][:, d], result["rmse_goal"][:, d]]
+        ax.boxplot(data, tick_labels=["Free", "Goal"])
+        ax.set_title(f"RMSE: {lab}")
+        ax.set_ylabel("RMSE")
+
+    fig4.suptitle(
+        f"Part B: PPAF Free vs Goal ({result['n_sims']} simulations, {result['n_cells']} cells)",
+        fontsize=12,
+    )
+    fig4.tight_layout()
+
+    return fig3, fig4
+
+
+def _plot_part_c(result):
+    """Figure 5: Hybrid setup. 
Figure 6: Hybrid decoding results.""" + time = result["time"] + + # ── Figure 5: Setup — state sequence + raster ── + fig5, axes5 = plt.subplots(2, 1, figsize=(12, 5), sharex=True) + + # Top: discrete state + axes5[0].plot(time, result["true_mode"], "k-", linewidth=1.0, label="True mode") + axes5[0].set_ylabel("Discrete State") + axes5[0].set_yticks([1, 2]) + axes5[0].set_yticklabels(["Reach", "Hold"]) + axes5[0].set_title("Part C: Hybrid Filter Setup") + axes5[0].legend() + + # Bottom: spike raster (first 20 cells) + dN = result["dN"] + n_show = min(20, dN.shape[0]) + for c in range(n_show): + idx = np.where(dN[c, :] > 0)[0] + spike_t = time[idx] + axes5[1].plot(spike_t, np.full_like(spike_t, c + 1), "|", color="k", markersize=2) + axes5[1].set_ylabel("Neuron") + axes5[1].set_xlabel("Time (s)") + axes5[1].set_ylim(0.5, n_show + 0.5) + fig5.tight_layout() + + # ── Figure 6: Decoding results ── + fig6, axes6 = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + + # Top: model probabilities + axes6[0].plot(time, result["MU_u"][0, :], "b-", linewidth=0.5, label="P(Reach)") + axes6[0].plot(time, result["MU_u"][1, :], "r-", linewidth=0.5, label="P(Hold)") + axes6[0].axhline(0.5, color="gray", linestyle="--", linewidth=0.5) + axes6[0].set_ylabel("Model Prob") + axes6[0].set_title( + f"PPHybridFilterLinear — State Accuracy: {result['state_acc']:.1%}" + ) + axes6[0].legend(loc="upper right", fontsize=8) + + # Middle: decoded x-position + axes6[1].plot(time, result["state"][0, :], "k-", linewidth=1.0, label="True") + axes6[1].plot(time, result["X_est"][0, :], "b-", linewidth=0.7, alpha=0.8, label="Decoded") + axes6[1].set_ylabel("x (m)") + axes6[1].legend(loc="upper right", fontsize=8) + + # Bottom: decoded y-position + axes6[2].plot(time, result["state"][1, :], "k-", linewidth=1.0, label="True") + axes6[2].plot(time, result["X_est"][1, :], "r-", linewidth=0.7, alpha=0.8, label="Decoded") + axes6[2].set_ylabel("y (m)") + axes6[2].set_xlabel("Time (s)") + 
axes6[2].legend(loc="upper right", fontsize=8) + + fig6.suptitle( + f"Hybrid Decoding (RMSE: x={result['rmse_x']:.4f}, y={result['rmse_y']:.4f})", + fontsize=12, + ) + fig6.tight_layout() + + return fig5, fig6 + + +# ────────────────────────────────────────────────────────────────────────────── +# Main entry point +# ────────────────────────────────────────────────────────────────────────────── + + +def run_example05(*, export_figures=False, export_dir=None, show=False): """Run Example 05: PPAF and PPHF decoding. - Analysis workflow (mirrors Matlab example05_decoding_ppaf_pphf.m): + Analysis workflow (mirrors Matlab ``example05_decoding_ppaf_pphf.m``): Part A — Univariate stimulus decoding: 1. Define 20-cell population with sinusoidal tuning. - 2. Simulate spikes from sinusoidal stimulus CIF. + 2. Simulate spikes from binomial CIF. 3. Decode stimulus via PPDecodeFilterLinear. Part B — Arm-reach PPAF: @@ -63,45 +532,82 @@ def run_example05(*, export_figures: bool = False, export_dir: Path | None = Non 6. Decode with PPAF (free) and PPAF+Goal; compare across 20 runs. Part C — Hybrid filter: - 7. Simulate 40-cell population with discrete reach-state modulation. + 7. Simulate 40-cell population with discrete state modulation. 8. Decode joint discrete/continuous state via PPHybridFilterLinear. 
""" + print("=" * 70) + print("Example 05: Stimulus Decoding with PPAF and PPHF") + print("=" * 70) + # --- Part A: Univariate sinusoidal stimulus --- - summary5, payload5 = run_experiment5(return_payload=True) + print("\n--- Part A: Univariate Sinusoidal Stimulus ---") + result_a = _run_part_a() + print(f" {result_a['n_cells']} cells, decode RMSE = {result_a['rmse']:.4f}") # --- Part B: Arm-reach PPAF --- - summary5b, payload5b = run_experiment5b(return_payload=True) + print("\n--- Part B: Arm Reach PPAF (20 simulations) ---") + result_b = _run_part_b() + mean_free = result_b["rmse_free"].mean(axis=0) + mean_goal = result_b["rmse_goal"].mean(axis=0) + print(f" Mean RMSE (free): x={mean_free[0]:.4f}, y={mean_free[1]:.4f}, " + f"vx={mean_free[2]:.4f}, vy={mean_free[3]:.4f}") + print(f" Mean RMSE (goal): x={mean_goal[0]:.4f}, y={mean_goal[1]:.4f}, " + f"vx={mean_goal[2]:.4f}, vy={mean_goal[3]:.4f}") # --- Part C: Hybrid filter --- - summary6, payload6 = run_experiment6(REPO_ROOT, return_payload=True) + print("\n--- Part C: Hybrid Filter ---") + result_c = _run_part_c() + print(f" {result_c['n_cells']} cells, state accuracy = {result_c['state_acc']:.1%}") + print(f" Position RMSE: x={result_c['rmse_x']:.4f}, y={result_c['rmse_y']:.4f}") - # Merge summaries for JSON output - combined_summary = { - "experiment5": summary5, - "experiment5b": summary5b, - "experiment6": summary6, + # Summary + summary = { + "experiment5": { + "num_cells": float(result_a["n_cells"]), + "decode_rmse": result_a["rmse"], + }, + "experiment5b": { + "num_cells": float(result_b["n_cells"]), + "n_sims": float(result_b["n_sims"]), + "mean_rmse_free_x": float(mean_free[0]), + "mean_rmse_goal_x": float(mean_goal[0]), + }, + "experiment6": { + "num_cells": float(result_c["n_cells"]), + "state_accuracy": result_c["state_acc"], + "decode_rmse_x": result_c["rmse_x"], + "decode_rmse_y": result_c["rmse_y"], + }, } - print(json.dumps(combined_summary, indent=2)) + print("\n" + json.dumps(summary, 
indent=2)) + + # --- Figures --- + fig1, fig2 = _plot_part_a(result_a) + fig3, fig4 = _plot_part_b(result_b) + fig5, fig6 = _plot_part_c(result_c) + figures = [fig1, fig2, fig3, fig4, fig5, fig6] if export_figures: if export_dir is None: export_dir = THIS_DIR / "figures" / "example05" - combined_payload = { - "experiment5": payload5, - "experiment5b": payload5b, - "experiment6": payload6, - } - saved = export_named_paper_figures( - "example05", - summary=combined_summary, - payload=combined_payload, - export_dir=export_dir, - ) - print(f"\nGenerated {len(saved)} figure(s):") - for p in saved: - print(f" {p}") + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + for i, fig in enumerate(figures, 1): + fig_names = [ + "fig01_univariate_setup", "fig02_univariate_decoding", + "fig03_reach_and_population_setup", "fig04_ppaf_goal_vs_free", + "fig05_hybrid_setup", "fig06_hybrid_decoding_summary", + ] + path = export_dir / f"{fig_names[i - 1]}.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + print(f" Saved: {path}") + + if show: + plt.show() + else: + plt.close("all") - return combined_summary + return summary if __name__ == "__main__": @@ -112,11 +618,13 @@ def run_example05(*, export_figures: bool = False, export_dir: Path | None = Non parser.add_argument("--export-figures", action="store_true") parser.add_argument("--export-dir", type=Path, default=None) parser.add_argument("--output-json", type=Path, default=None) + parser.add_argument("--show", action="store_true", help="Display figures interactively") args = parser.parse_args() result = run_example05( export_figures=args.export_figures, export_dir=args.export_dir, + show=args.show, ) if args.output_json: args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") diff --git a/examples/readme_examples/example1_multitaper_and_spectrogram.py b/examples/readme_examples/example1_multitaper_and_spectrogram.py index 059590bb..89652c8d 100644 --- 
a/examples/readme_examples/example1_multitaper_and_spectrogram.py
+++ b/examples/readme_examples/example1_multitaper_and_spectrogram.py
@@ -31,10 +31,10 @@ def main() -> None:
     time = np.arange(0.0, duration_s, dt, dtype=float)
     signal = np.sin(2.0 * np.pi * f0_hz * time)
 
-    sig_obj = SignalObj(time=time, data=signal, name="sine_signal", units="a.u.")
+    sig_obj = SignalObj(time=time, data=signal, name="sine_signal", yunits="a.u.")
 
     try:
-        freq_hz, psd = sig_obj.MTMspectrum()
+        freq_hz, psd, _ci = sig_obj.MTMspectrum()
     except Exception:
         freq_hz, psd = _fallback_multitaper_psd(signal, fs_hz)
 
diff --git a/examples/readme_examples/example2_simulate_cif_spiketrain_10s.py b/examples/readme_examples/example2_simulate_cif_spiketrain_10s.py
index 575022a0..e7d30990 100644
--- a/examples/readme_examples/example2_simulate_cif_spiketrain_10s.py
+++ b/examples/readme_examples/example2_simulate_cif_spiketrain_10s.py
@@ -44,7 +44,7 @@ def main() -> None:
     amp_hz = 10.0
     lam = np.clip(baseline_hz + amp_hz * np.sin(2.0 * np.pi * f_hz * t), 0.2, None)
 
-    lambda_cov = Covariate(time=t, data=lam, name="Lambda", units="spikes/s", labels=["lambda"])
+    lambda_cov = Covariate(time=t, data=lam, name="Lambda", yunits="spikes/s", dataLabels=["lambda"])
     spikes = CIF.simulateCIFByThinningFromLambda(lambda_cov, 1, dt)
     spike_times = _extract_first_spike_times(spikes)
 
diff --git a/examples/readme_examples/example3_nstcoll_raster_from_example2.py b/examples/readme_examples/example3_nstcoll_raster_from_example2.py
index 1c818fc8..2d93a68e 100644
--- a/examples/readme_examples/example3_nstcoll_raster_from_example2.py
+++ b/examples/readme_examples/example3_nstcoll_raster_from_example2.py
@@ -49,7 +49,7 @@ def main() -> None:
     np.random.seed(0)
     t, lam, dt = _build_lambda()
 
-    lambda_cov = Covariate(time=t, data=lam, name="Lambda", units="spikes/s", labels=["lambda"])
+    lambda_cov = Covariate(time=t, data=lam, name="Lambda", yunits="spikes/s", dataLabels=["lambda"])
     n_units = 20
     spikes_coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, n_units, dt)
 
diff --git a/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png b/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png
index 93525b78..3901dd98 100644
Binary files a/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png and b/examples/readme_examples/images/readme_example1_multitaper_and_spectrogram.png differ
diff --git a/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain_10s.png b/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain_10s.png
index 035ca493..be25b2e9 100644
Binary files a/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain_10s.png and b/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain_10s.png differ
diff --git a/examples/readme_examples/images/readme_example3_nstcoll_raster.png b/examples/readme_examples/images/readme_example3_nstcoll_raster.png
index 6a32ae88..b76f86d9 100644
Binary files a/examples/readme_examples/images/readme_example3_nstcoll_raster.png and b/examples/readme_examples/images/readme_example3_nstcoll_raster.png differ
diff --git a/notebooks/AnalysisExamples.ipynb b/notebooks/AnalysisExamples.ipynb
index 89257604..1452de8e 100644
--- a/notebooks/AnalysisExamples.ipynb
+++ b/notebooks/AnalysisExamples.ipynb
@@ -2,20 +2,20 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "fc7be361",
+   "id": "98fc6fc8",
    "metadata": {},
    "source": [
     "\n",
     "## MATLAB Parity Note\n",
     "- Source MATLAB helpfile: `AnalysisExamples.mlx`\n",
     "- Fidelity status: `high_fidelity`\n",
-    "- Remaining justified differences: The notebook now follows the MATLAB standard-GLM workflow with the canonical `glm_data.mat` dataset and real KS/model-visualization figures; coefficient values and styling still vary modestly because the Python GLM backend and plotting defaults differ from MATLAB.\n"
+    "- Remaining justified differences: The notebook now follows the MATLAB standard-GLM workflow with the canonical `glm_data.mat` dataset and real KS/model-visualization figures; coefficient values and styling still vary modestly because the Python GLM backend and plotting defaults differ from MATLAB."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "45d93add",
+   "id": "7807842d",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,14 +44,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic=\"AnalysisExamples\", output_root=OUTPUT_ROOT, expected_count=4)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "def _poisson_standard_errors(design_matrix, result):\n",
     "    x = np.asarray(design_matrix, dtype=float)\n",
     "    if x.ndim == 1:\n",
@@ -62,7 +60,6 @@
     "    cov = np.linalg.pinv(x_aug.T @ (lam[:, None] * x_aug))\n",
     "    return np.sqrt(np.clip(np.diag(cov), 0.0, None))\n",
     "\n",
-    "\n",
     "T = np.asarray(GLM_DATA[\"T\"], dtype=float).reshape(-1)\n",
     "xN = np.asarray(GLM_DATA[\"xN\"], dtype=float).reshape(-1)\n",
     "yN = np.asarray(GLM_DATA[\"yN\"], dtype=float).reshape(-1)\n",
@@ -71,25 +68,25 @@
     "x_at_spiketimes = np.asarray(GLM_DATA[\"x_at_spiketimes\"], dtype=float).reshape(-1)\n",
     "y_at_spiketimes = np.asarray(GLM_DATA[\"y_at_spiketimes\"], dtype=float).reshape(-1)\n",
     "sample_rate = 1.0 / float(np.median(np.diff(T)))\n",
-    "nst = nspikeTrain(spiketimes, name=\"1\", minTime=float(T[0]), maxTime=float(T[-1]), makePlots=-1)\n"
+    "nst = nspikeTrain(spiketimes, name=\"1\", minTime=float(T[0]), maxTime=float(T[-1]), makePlots=-1)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3c621348",
+   "id": "37ac20c9",
    "metadata": {},
    "outputs": [],
    "source": [
     "# SECTION 1: Analysis Examples\n",
     "plt.close(\"all\")\n",
-    "print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"sample_rate_hz\": round(sample_rate, 3)})\n"
+    "print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"sample_rate_hz\": round(sample_rate, 3)})"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c1d9b5e4",
+   "id": "dbdc74f9",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -107,13 +104,13 @@
     "x_quadratic = np.column_stack([xN, yN, xN**2, yN**2, xN * yN])\n",
     "linear_fit = fit_poisson_glm(x_linear, spikes_binned)\n",
     "quadratic_fit = fit_poisson_glm(x_quadratic, spikes_binned)\n",
-    "centered_fit = fit_poisson_glm(x_quadratic_centered, spikes_binned)\n"
+    "centered_fit = fit_poisson_glm(x_quadratic_centered, spikes_binned)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b5f3a818",
+   "id": "5c38cff1",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -125,13 +122,13 @@
     "ax.set_aspect(\"equal\", adjustable=\"box\")\n",
     "ax.set_xlabel(\"x position (m)\")\n",
     "ax.set_ylabel(\"y position (m)\")\n",
-    "ax.set_title(\"Rat trajectory with spike locations\")\n"
+    "ax.set_title(\"Rat trajectory with spike locations\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "396cb183",
+   "id": "5af52914",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -144,13 +141,13 @@
     "ax.errorbar(xpos, centered_beta, yerr=centered_se, fmt=\".\", color=\"tab:blue\", capsize=3)\n",
     "ax.set_xticks(xpos, [\"baseline\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"])\n",
     "ax.set_ylabel(\"coefficient value\")\n",
-    "ax.set_title(\"Quadratic GLM coefficients\")\n"
+    "ax.set_title(\"Quadratic GLM coefficients\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "49d54a88",
+   "id": "5cac7309",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -169,13 +166,13 @@
     "ax.set_xlabel(\"x position (m)\")\n",
     "ax.set_ylabel(\"y position (m)\")\n",
     "ax.set_zlabel(\"lambda\")\n",
-    "ax.set_title(\"Quadratic GLM spatial intensity\")\n"
+    "ax.set_title(\"Quadratic GLM spatial intensity\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8b700118",
+   "id": "36dd8e70",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -189,13 +186,13 @@
     "        \"linear_mean_rate_hz\": round(float(np.mean(lambda_linear_hz)), 4),\n",
     "        \"quadratic_mean_rate_hz\": round(float(np.mean(lambda_quadratic_hz)), 4),\n",
     "    }\n",
-    ")\n"
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "9bd202c8",
+   "id": "2d8b81fd",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -220,7 +217,7 @@
     "ax.set_ylabel(\"Empirical CDF of Rescaled ISIs\")\n",
     "ax.set_title(\"KS Plot with 95% Confidence Intervals\")\n",
     "ax.legend(loc=\"lower right\", frameon=False)\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
@@ -237,4 +234,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/notebooks/AnalysisExamples2.ipynb b/notebooks/AnalysisExamples2.ipynb
index e0e11f4e..daa82ada 100644
--- a/notebooks/AnalysisExamples2.ipynb
+++ b/notebooks/AnalysisExamples2.ipynb
@@ -2,20 +2,20 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "2468a9e7",
+   "id": "66d56086",
    "metadata": {},
    "source": [
     "\n",
     "## MATLAB Parity Note\n",
     "- Source MATLAB helpfile: `AnalysisExamples2.mlx`\n",
     "- Fidelity status: `high_fidelity`\n",
-    "- Remaining justified differences: The notebook now follows the MATLAB toolbox workflow on the canonical `glm_data.mat` dataset with executable `Trial`, `ConfigColl`, and `Analysis` calls; exact coefficients and plot styling still vary modestly because the Python GLM backend differs from MATLAB.\n"
+    "- Remaining justified differences: The notebook now follows the MATLAB toolbox workflow on the canonical `glm_data.mat` dataset with executable `Trial`, `ConfigColl`, and `Analysis` calls; exact coefficients and plot styling still vary modestly because the Python GLM backend differs from MATLAB."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "5e1d1998",
+   "id": "62e21501",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,14 +44,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic=\"AnalysisExamples2\", output_root=OUTPUT_ROOT, expected_count=5)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "T = np.asarray(GLM_DATA[\"T\"], dtype=float).reshape(-1)\n",
     "xN = np.asarray(GLM_DATA[\"xN\"], dtype=float).reshape(-1)\n",
     "yN = np.asarray(GLM_DATA[\"yN\"], dtype=float).reshape(-1)\n",
@@ -67,36 +65,36 @@
     "velocity = Covariate(T, np.column_stack([vxN, vyN]), \"Velocity\", \"time\", \"s\", \"m/s\", [\"v_x\", \"v_y\"])\n",
     "radial = Covariate(T, np.column_stack([xN, yN, xN**2, yN**2, xN * yN]), \"Radial\", \"time\", \"s\", \"m\", [\"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"])\n",
     "values_at_spiketimes = position.getValueAt(spiketimes)\n",
-    "values_at_spiketimes_upsampled = position.resample(1.0 / np.min(np.diff(spiketimes))).getValueAt(spiketimes)\n"
+    "values_at_spiketimes_upsampled = position.resample(1.0 / np.min(np.diff(spiketimes))).getValueAt(spiketimes)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "45dc365a",
+   "id": "1836e297",
    "metadata": {},
    "outputs": [],
    "source": [
     "# SECTION 1: Analysis Examples 2\n",
     "plt.close(\"all\")\n",
-    "print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"analysis_sample_rate_hz\": sample_rate})\n"
+    "print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"analysis_sample_rate_hz\": sample_rate})"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "2a9182fe",
+   "id": "bf657cd0",
    "metadata": {},
    "outputs": [],
    "source": [
     "# SECTION 2: load the rat trajectory and spiking data\n",
-    "print({\"position_shape\": list(position.data.shape), \"velocity_shape\": list(velocity.data.shape), \"radial_shape\": list(radial.data.shape)})\n"
+    "print({\"position_shape\": list(position.data.shape), \"velocity_shape\": list(velocity.data.shape), \"radial_shape\": list(radial.data.shape)})"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "126391f1",
+   "id": "fe47aacc",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -106,13 +104,13 @@
     "        \"direct_spike_position_head\": np.asarray(values_at_spiketimes[:3], dtype=float).round(4).tolist(),\n",
     "        \"upsampled_spike_position_head\": np.asarray(values_at_spiketimes_upsampled[:3], dtype=float).round(4).tolist(),\n",
     "    }\n",
-    ")\n"
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8aaea418",
+   "id": "2f28e3be",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -124,13 +122,13 @@
     "ax.set_aspect(\"equal\", adjustable=\"box\")\n",
     "ax.set_xlabel(\"x position (m)\")\n",
     "ax.set_ylabel(\"y position (m)\")\n",
-    "ax.set_title(\"Trajectory and interpolated spike locations\")\n"
+    "ax.set_title(\"Trajectory and interpolated spike locations\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d17e023e",
+   "id": "d40c40c9",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -143,13 +141,13 @@
     "    TrialConfig([[\"Baseline\", \"mu\"], [\"Radial\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"]], sampleRate=sample_rate, history=[], name=\"Quadratic\"),\n",
     "    TrialConfig([[\"Baseline\", \"mu\"], [\"Radial\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"]], sampleRate=sample_rate, history=[0.0, 1.0 / sample_rate], name=\"Quadratic+Hist\"),\n",
     "]\n",
-    "tcc = ConfigColl(tc)\n"
+    "tcc = ConfigColl(tc)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "4ab39635",
+   "id": "f9160bdb",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -157,13 +155,13 @@
     "fitResults = Analysis.RunAnalysisForAllNeurons(trial, tcc, 0)\n",
     "fig = _prepare_figure(\"fitResults.plotResults\", figsize=(11.0, 8.0))\n",
     "fitResults.plotResults(handle=fig)\n",
-    "print({\"config_names\": fitResults.configNames, \"aic\": np.asarray(fitResults.AIC, dtype=float).round(3).tolist()})\n"
+    "print({\"config_names\": fitResults.configNames, \"aic\": np.asarray(fitResults.AIC, dtype=float).round(3).tolist()})"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "db6c7107",
+   "id": "a2fafcc8",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -180,13 +178,13 @@
     "ax.set_xlabel(\"x position (m)\")\n",
     "ax.set_ylabel(\"y position (m)\")\n",
     "ax.set_zlabel(\"lambda\")\n",
-    "ax.set_title(\"Toolbox-model spatial intensity comparison\")\n"
+    "ax.set_title(\"Toolbox-model spatial intensity comparison\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "5a1dbe4c",
+   "id": "64cdac30",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -201,13 +199,13 @@
     "ax.set_xticks(np.arange(coeff_diff.size), labels, rotation=20)\n",
     "ax.set_ylabel(\"standard minus toolbox\")\n",
     "ax.set_title(\"Coefficient agreement between workflows\")\n",
-    "print({\"quadratic_coeff_diff_max_abs\": round(float(np.max(np.abs(coeff_diff))), 6)})\n"
+    "print({\"quadratic_coeff_diff_max_abs\": round(float(np.max(np.abs(coeff_diff))), 6)})"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "85a9d741",
+   "id": "8782d383",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -223,7 +221,7 @@
     "ax.set_ylabel(\"AIC\")\n",
     "ax.set_title(\"History-lag model comparison\")\n",
     "print({\"history_config_names\": histConfigs.getConfigNames(), \"summary_fit_names\": histSummary.fitNames})\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
@@ -233,11 +231,11 @@
   },
   "nstat": {
    "expected_figures": 5,
-   "run_group": "smoke",
+   "run_group": "full",
    "style": "python-example",
    "topic": "AnalysisExamples2"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/notebooks/ConfidenceIntervalOverview.ipynb b/notebooks/ConfidenceIntervalOverview.ipynb
index bad3632b..ba0ac853 100644
--- a/notebooks/ConfidenceIntervalOverview.ipynb
+++ b/notebooks/ConfidenceIntervalOverview.ipynb
@@ -114,7 +114,12 @@
    "name": "python",
    "version": "3.12"
   },
-  "nstat": {}
+  "nstat": {
+   "expected_figures": 0,
+   "run_group": "smoke",
+   "style": "python-example",
+   "topic": "ConfidenceIntervalOverview"
+  }
  },
  "nbformat": 4,
  "nbformat_minor": 5
diff --git a/notebooks/ConfigCollExamples.ipynb b/notebooks/ConfigCollExamples.ipynb
index 56dffa58..00ba6d08 100644
--- a/notebooks/ConfigCollExamples.ipynb
+++ b/notebooks/ConfigCollExamples.ipynb
@@ -32,7 +32,7 @@
     "# SECTION 0: Section 0\n",
     "# ConfigColl Examples\n",
     "# tcObj=TrialConfig(covMask,sampleRate, history,minTime,maxTime)\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
diff --git a/notebooks/CovariateExamples.ipynb b/notebooks/CovariateExamples.ipynb
index edc5cc69..9bdcd952 100644
--- a/notebooks/CovariateExamples.ipynb
+++ b/notebooks/CovariateExamples.ipynb
@@ -31,7 +31,7 @@
     "\n",
     "# SECTION 0: Section 0\n",
     "# Test the Cov class\n",
-    "# Covariates are just like signals with a mean and a standard deviation They have two representations, the default (original representation) and a zero-mean representation\n"
+    "# Covariates are just like signals with a mean and a standard deviation They have two representations, the default (original representation) and a zero-mean representation"
    ]
   },
   {
diff --git a/notebooks/DecodingExample.ipynb b/notebooks/DecodingExample.ipynb
index 781daba8..0e367250 100644
--- a/notebooks/DecodingExample.ipynb
+++ b/notebooks/DecodingExample.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "e78ea1c1",
+   "id": "0813e1e0",
    "metadata": {},
    "source": [
     "\n",
@@ -15,7 +15,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b558e18d",
+   "id": "44d88ec9",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,14 +42,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic=\"DecodingExample\", output_root=OUTPUT_ROOT, expected_count=5)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "def _plot_raster(ax, spike_coll):\n",
     "    for row in range(1, spike_coll.numSpikeTrains + 1):\n",
     "        train = spike_coll.getNST(row)\n",
@@ -59,7 +57,6 @@
     "    ax.set_ylabel(\"Neuron\")\n",
     "    ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5)\n",
     "\n",
-    "\n",
     "def _plot_decoded_ci(ax, time, decoded, cov, stim, title):\n",
     "    center = np.asarray(decoded, dtype=float).reshape(-1)\n",
     "    variance = np.asarray(cov, dtype=float).reshape(-1)\n",
@@ -75,15 +72,14 @@
     "    ax.set_xlabel(\"time (s)\")\n",
     "    ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n",
     "\n",
-    "\n",
     "# SECTION 0: STIMULUS DECODING\n",
-    "# In this example we decode a univariate stimulus from simulated point-process observations by following the MATLAB DecodingExample workflow.\n"
+    "# In this example we decode a univariate stimulus from simulated point-process observations by following the MATLAB DecodingExample workflow."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e37eea70",
+   "id": "2f1ec431",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -110,13 +106,13 @@
     "axs[1].plot(time, lambda_cov.data[:, 0], color=\"b\", linewidth=2.0)\n",
     "axs[1].set_title(\"Conditional intensity λ(t)\")\n",
     "axs[1].set_xlabel(\"time (s)\")\n",
-    "axs[1].set_ylabel(\"Hz\")\n"
+    "axs[1].set_ylabel(\"Hz\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b8dd2913",
+   "id": "a7048402",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -174,13 +170,13 @@
     "axs[0].set_title(\"Mean AIC across neurons\")\n",
     "axs[1].bar(xloc, np.mean(logll_matrix, axis=0), color=[\"0.6\", \"0.3\"])\n",
     "axs[1].set_xticks(xloc, config_names, rotation=15)\n",
-    "axs[1].set_title(\"Mean log-likelihood across neurons\")\n"
+    "axs[1].set_title(\"Mean log-likelihood across neurons\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d7529413",
+   "id": "df079e59",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -198,7 +194,7 @@
     "fig = _prepare_figure(\"figure\", figsize=(8.0, 4.5))\n",
     "ax = fig.subplots(1, 1)\n",
     "_plot_decoded_ci(ax, time, x_u, W_u, stim.data[:, 0], f\"Decoded stimulus using {numRealizations} cells\")\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
diff --git a/notebooks/DecodingExampleWithHist.ipynb b/notebooks/DecodingExampleWithHist.ipynb
index 2fd02734..9cdea7ac 100644
--- a/notebooks/DecodingExampleWithHist.ipynb
+++ b/notebooks/DecodingExampleWithHist.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "04553c3e",
+   "id": "b9af2115",
    "metadata": {},
    "source": [
     "\n",
@@ -15,7 +15,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "cc21ece5",
+   "id": "a847f096",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,14 +42,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic=\"DecodingExampleWithHist\", output_root=OUTPUT_ROOT, expected_count=2)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "def _plot_raster(ax, spike_coll):\n",
     "    for row in range(1, spike_coll.numSpikeTrains + 1):\n",
     "        train = spike_coll.getNST(row)\n",
@@ -59,7 +57,6 @@
     "    ax.set_ylabel(\"Neuron\")\n",
     "    ax.set_ylim(0.5, spike_coll.numSpikeTrains + 0.5)\n",
     "\n",
-    "\n",
     "def _plot_decoded_ci(ax, time, decoded, cov, stim, title):\n",
     "    center = np.asarray(decoded, dtype=float).reshape(-1)\n",
     "    spread = np.asarray(cov, dtype=float).reshape(-1)\n",
@@ -74,7 +71,6 @@
     "    ax.set_xlabel(\"time (s)\")\n",
     "    ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n",
     "\n",
-    "\n",
     "def _simulate_history_spike_train(time, stim_data, baseline, hist_coeffs, window_times):\n",
     "    spikes = []\n",
     "    for idx in range(1, len(time)):\n",
@@ -93,15 +89,14 @@
     "        spikes.append(t)\n",
     "    return np.asarray(spikes, dtype=float)\n",
     "\n",
-    "\n",
     "# SECTION 0: 1-D Stimulus Decode with History Effect\n",
-    "# We simulate neurons with refractory-history effects and compare point-process decoding with and without the correct history terms.\n"
+    "# We simulate neurons with refractory-history effects and compare point-process decoding with and without the correct history terms."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "44a6c7e4",
+   "id": "b9bfa418",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -157,7 +152,7 @@
     "axs = fig.subplots(2, 1, sharex=True)\n",
     "_plot_decoded_ci(axs[0], time, x_u, W_u, stim.data[:, 0], f\"Decoded stimulus with history using {numRealizations} cells\")\n",
     "_plot_decoded_ci(axs[1], time, x_u_no_hist, W_u_no_hist, stim.data[:, 0], f\"Decoded stimulus without history using {numRealizations} cells\")\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
diff --git a/notebooks/ExplicitStimulusWhiskerData.ipynb b/notebooks/ExplicitStimulusWhiskerData.ipynb
index 436fab61..cd298616 100644
--- a/notebooks/ExplicitStimulusWhiskerData.ipynb
+++ b/notebooks/ExplicitStimulusWhiskerData.ipynb
@@ -2,20 +2,20 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "1039ac25",
+   "id": "199165d0",
    "metadata": {},
    "source": [
     "\n",
     "## MATLAB Parity Note\n",
     "- Source MATLAB helpfile: `ExplicitStimulusWhiskerData.mlx`\n",
     "- Fidelity status: `high_fidelity`\n",
-    "- Remaining justified differences: The notebook now reproduces the dataset-backed lag search, stimulus-effect, and history-effect workflow with real figures; exact KS traces and coefficient values still vary modestly from MATLAB because the Python GLM backend and plotting defaults are different.\n"
+    "- Remaining justified differences: The notebook now reproduces the dataset-backed lag search, stimulus-effect, and history-effect workflow with real figures; exact KS traces and coefficient values still vary modestly from MATLAB because the Python GLM backend and plotting defaults are different."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ea575816",
+   "id": "e102e8bb",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,14 +44,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic='ExplicitStimulusWhiskerData', output_root=OUTPUT_ROOT, expected_count=9)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "def _plot_spike_indicator(ax, time_s, spike_indicator):\n",
     "    spike_times = np.asarray(time_s, dtype=float)[np.asarray(spike_indicator, dtype=float) > 0.5]\n",
     "    if spike_times.size:\n",
@@ -59,7 +57,6 @@
     "    ax.set_ylim(0.0, 1.0)\n",
     "    ax.set_ylabel(\"spikes\")\n",
     "\n",
-    "\n",
     "def _plot_ks(ax, ideal, empirical, ci, *, label, color):\n",
     "    ideal_arr = np.asarray(ideal, dtype=float)\n",
     "    empirical_arr = np.asarray(empirical, dtype=float)\n",
@@ -77,13 +74,13 @@
     "    ax.set_xlabel(\"Theoretical quantiles\")\n",
     "    ax.set_ylabel(\"Empirical quantiles\")\n",
     "    ax.set_xlim(0.0, 1.0)\n",
-    "    ax.set_ylim(0.0, 1.0)\n"
+    "    ax.set_ylim(0.0, 1.0)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c32ee0d2",
+   "id": "45734023",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -100,13 +97,13 @@
     "        \"peak_lag_ms\": round(float(summary[\"peak_lag_seconds\"]) * 1000.0, 1),\n",
     "        \"best_history_window_bins\": best_history_window,\n",
     "    }\n",
-    ")\n"
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3e7f4cb5",
+   "id": "0e41ab95",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -128,13 +125,13 @@
     "axs[1].plot(payload[\"time_s\"], payload[\"velocity\"], color=\"tab:orange\", linewidth=1.2)\n",
     "axs[1].set_title(\"Stimulus derivative\")\n",
     "axs[1].set_ylabel(\"d(stimulus)/dt\")\n",
-    "axs[1].set_xlabel(\"time (s)\")\n"
+    "axs[1].set_xlabel(\"time (s)\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "79a0c40f",
+   "id": "6c2191fc",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -143,13 +140,13 @@
     "ax = fig.subplots(1, 1)\n",
     "_plot_ks(ax, payload[\"ks_ideal\"], payload[\"ks_const_empirical\"], payload[\"ks_ci\"], label=\"Baseline model\", color=\"tab:blue\")\n",
     "ax.set_title(\"Baseline model KS plot\")\n",
-    "ax.legend(loc=\"lower right\", frameon=False, fontsize=8)\n"
+    "ax.legend(loc=\"lower right\", frameon=False, fontsize=8)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "88ef749f",
+   "id": "da4b0be4",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -164,13 +161,13 @@
     "ax.scatter([lags_ms[peak_idx]], [xcorr_vals[peak_idx]], color=\"tab:red\", zorder=3)\n",
     "ax.set_title(\"Cross-covariance used to identify the stimulus lag\")\n",
     "ax.set_xlabel(\"lag (ms)\")\n",
-    "ax.set_ylabel(\"cross-covariance\")\n"
+    "ax.set_ylabel(\"cross-covariance\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8fb09e81",
+   "id": "334952df",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -192,13 +189,13 @@
     "_plot_ks(ax, payload[\"ks_ideal\"], payload[\"ks_const_empirical\"], payload[\"ks_ci\"], label=\"Baseline\", color=\"tab:blue\")\n",
     "ax.plot(np.asarray(payload[\"ks_ideal\"], dtype=float), np.asarray(payload[\"ks_stim_empirical\"], dtype=float), color=\"tab:orange\", linewidth=1.5, label=\"Baseline+Stimulus\")\n",
     "ax.set_title(\"Baseline vs stimulus-augmented model\")\n",
-    "ax.legend(loc=\"lower right\", frameon=False, fontsize=8)\n"
+    "ax.legend(loc=\"lower right\", frameon=False, fontsize=8)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1a7627c6",
+   "id": "eb6dc162",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -224,13 +221,13 @@
     "ax.axvline(history_windows[best_history_idx], color=\"tab:red\", linestyle=\"--\", linewidth=1.0)\n",
     "ax.set_title(\"BIC improvement across history-window choices\")\n",
     "ax.set_xlabel(\"history window count\")\n",
-    "ax.set_ylabel(\"ΔBIC relative to first history model\")\n"
+    "ax.set_ylabel(\"ΔBIC relative to first history model\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3097f64f",
+   "id": "09de73d5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -273,7 +270,7 @@
     "ax.plot(np.asarray(payload[\"ks_ideal\"], dtype=float), np.asarray(payload[\"ks_hist_empirical\"], dtype=float), color=\"tab:green\", linewidth=1.5, label=\"Baseline+Stimulus+History\")\n",
     "ax.set_title(\"Final KS comparison across the three models\")\n",
     "ax.legend(loc=\"lower right\", frameon=False, fontsize=8)\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
@@ -283,11 +280,11 @@
   },
   "nstat": {
    "expected_figures": 9,
-   "run_group": "smoke",
+   "run_group": "full",
    "style": "python-example",
    "topic": "ExplicitStimulusWhiskerData"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/notebooks/HippocampalPlaceCellExample.ipynb b/notebooks/HippocampalPlaceCellExample.ipynb
index 5b7a7816..d5be823a 100644
--- a/notebooks/HippocampalPlaceCellExample.ipynb
+++ b/notebooks/HippocampalPlaceCellExample.ipynb
@@ -2,20 +2,20 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "9cc6c119",
+   "id": "74d6cdfa",
    "metadata": {},
    "source": [
     "\n",
     "## MATLAB Parity Note\n",
     "- Source MATLAB helpfile: `HippocampalPlaceCellExample.mlx`\n",
     "- Fidelity status: `high_fidelity`\n",
-    "- Remaining justified differences: The notebook now reproduces the dataset-backed place-cell model-comparison and field-visualization workflow with the same normalized 10-term Zernike basis used by MATLAB; exact AIC/BIC values and surface styling still vary modestly because the Python GLM solver and plotting backend are not byte-identical to MATLAB.\n"
+    "- Remaining justified differences: The notebook now reproduces the dataset-backed place-cell model-comparison and field-visualization workflow with the same normalized 10-term Zernike basis used by MATLAB; exact AIC/BIC values and surface styling still vary modestly because the Python GLM solver and plotting backend are not byte-identical to MATLAB."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3743d52f",
+   "id": "15ef53db",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,14 +44,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic='HippocampalPlaceCellExample', output_root=OUTPUT_ROOT, expected_count=11)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "def _interp_spike_positions(time_s, x_pos, y_pos, spike_times):\n",
     "    spike_times = np.asarray(spike_times, dtype=float)\n",
     "    return (\n",
@@ -59,7 +57,6 @@
     "        np.interp(spike_times, np.asarray(time_s, dtype=float), np.asarray(y_pos, dtype=float)),\n",
     "    )\n",
     "\n",
-    "\n",
     "def _plot_field_grid(fig, animal_key, field_key, title):\n",
     "    animal = payload[animal_key]\n",
     "    grid_x = np.asarray(animal[\"grid_x\"], dtype=float)\n",
@@ -79,13 +76,13 @@
     "        ax.set_xticks([])\n",
     "        ax.set_yticks([])\n",
     "    fig.suptitle(title)\n",
-    "    fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)\n"
+    "    fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8804c316",
+   "id": "30094bfc",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -99,13 +96,13 @@
     "        \"mean_delta_aic\": round(float(summary[\"mean_delta_aic_gaussian_minus_zernike\"]), 3),\n",
     "        \"mean_delta_bic\": round(float(summary[\"mean_delta_bic_gaussian_minus_zernike\"]), 3),\n",
     "    }\n",
-    ")\n"
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f088de8e",
+   "id": "33671840",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -119,13 +116,13 @@
     "ax.set_title(f\"Animal 1, Cell {int(mesh['cell_index']) + 1}\")\n",
     "ax.set_xlabel(\"x\")\n",
     "ax.set_ylabel(\"y\")\n",
-    "ax.set_aspect(\"equal\", adjustable=\"box\")\n"
+    "ax.set_aspect(\"equal\", adjustable=\"box\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f42a4e6c",
+   "id": "081e6179",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -146,13 +143,13 @@
     "ax.axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
     "ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
     "ax.set_ylabel(\"Gaussian - Zernike BIC\")\n",
-    "ax.set_title(\"Animal 1 model comparison\")\n"
+    "ax.set_title(\"Animal 1 model comparison\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c75cfbaa",
+   "id": "aa269c6b",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -173,13 +170,13 @@
     "ax.axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
     "ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
     "ax.set_ylabel(\"Gaussian - Zernike BIC\")\n",
-    "ax.set_title(\"Animal 2 model comparison\")\n"
+    "ax.set_title(\"Animal 2 model comparison\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "366f0e8c",
+   "id": "26aafec5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -253,7 +250,7 @@
     "    family=\"monospace\",\n",
     "    fontsize=10,\n",
     ")\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
@@ -263,11 +260,11 @@
   },
   "nstat": {
    "expected_figures": 11,
-   "run_group": "smoke",
+   "run_group": "full",
    "style": "python-example",
    "topic": "HippocampalPlaceCellExample"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/notebooks/HistoryExamples.ipynb b/notebooks/HistoryExamples.ipynb
index e0bad1e1..173a3ff4 100644
--- a/notebooks/HistoryExamples.ipynb
+++ b/notebooks/HistoryExamples.ipynb
@@ -49,7 +49,7 @@
     "\n",
     "# SECTION 0: Section 0\n",
     "# Test History\n",
-    "# Generate a nspikeTrain and define a set of history windows of interest. We desire windows from 1-2ms, 2-3ms, 3-5ms, and 5-10ms, then compute the corresponding history covariates.\n"
+    "# Generate a nspikeTrain and define a set of history windows of interest. We desire windows from 1-2ms, 2-3ms, 3-5ms, and 5-10ms, then compute the corresponding history covariates."
    ]
   },
   {
@@ -76,7 +76,7 @@
     "ax1.set_title(\"History windows\")\n",
     "ax2.set_title(\"History covariate for Neuron1\")\n",
     "ax3.set_title(\"Neuron1 spike raster\")\n",
-    "fig.tight_layout()\n"
+    "fig.tight_layout()"
    ]
   },
   {
@@ -108,7 +108,7 @@
     "ax = fig.subplots(1, 1)\n",
     "coll.plot(handle=ax)\n",
     "fig.tight_layout()\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
diff --git a/notebooks/HybridFilterExample.ipynb b/notebooks/HybridFilterExample.ipynb
index 72346edf..e0a9614c 100644
--- a/notebooks/HybridFilterExample.ipynb
+++ b/notebooks/HybridFilterExample.ipynb
@@ -2,20 +2,20 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "51ab8762",
+   "id": "d60b95a7",
    "metadata": {},
    "source": [
     "\n",
     "## MATLAB Parity Note\n",
     "- Source MATLAB helpfile: `HybridFilterExample.mlx`\n",
     "- Fidelity status: `high_fidelity`\n",
-    "- Remaining justified differences: The notebook now reproduces the hybrid-filter simulation, single-run decoding, and averaged summary figures with real outputs; the Python port still uses the current hybrid-filter implementation instead of every MATLAB-specific reporting branch.\n"
+    "- Remaining justified differences: The notebook now reproduces the hybrid-filter simulation, single-run decoding, and averaged summary figures with real outputs; the Python port still uses the current hybrid-filter implementation instead of every MATLAB-specific reporting branch."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b9c53e65",
+   "id": "d22dd216",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -42,14 +42,12 @@
     "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
     "__tracker = FigureTracker(topic='HybridFilterExample', output_root=OUTPUT_ROOT, expected_count=3)\n",
     "\n",
-    "\n",
     "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
     "    fig = __tracker.new_figure(matlab_line)\n",
     "    fig.clear()\n",
     "    fig.set_size_inches(*figsize)\n",
     "    return fig\n",
     "\n",
-    "\n",
     "def _plot_raster(ax, time_s, spikes, *, max_cells=18):\n",
     "    n_cells = min(int(spikes.shape[1]), max_cells)\n",
     "    for row in range(n_cells):\n",
@@ -57,13 +55,13 @@
     "        if spike_times.size:\n",
     "            ax.vlines(spike_times, row + 0.6, row + 1.4, color=\"k\", linewidth=0.35)\n",
     "    ax.set_ylim(0.5, n_cells + 0.5)\n",
-    "    ax.set_ylabel(\"cell\")\n"
+    "    ax.set_ylabel(\"cell\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "2aeada10",
+   "id": "cacac4c3",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -81,35 +79,35 @@
     "        \"num_cells\": int(summary[\"num_cells\"]),\n",
     "        \"state_accuracy\": round(float(summary[\"state_accuracy\"]), 3),\n",
     "    }\n",
-    ")\n"
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1ca4154b",
+   "id": "b031dd85",
    "metadata": {},
    "outputs": [],
    "source": [
     "# SECTION 1: Problem Statement\n",
-    "# We infer both a discrete movement state and a continuous reach trajectory from point-process observations.\n"
+    "# We infer both a discrete movement state and a continuous reach trajectory from point-process observations."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "bc311d6b",
+   "id": "fd11f1d4",
    "metadata": {},
    "outputs": [],
    "source": [
     "# SECTION 2: Hybrid state-space setup\n",
-    "# The Python port keeps the same two-state problem structure as MATLAB: a low-motion state and a movement state.\n"
+    "# The Python port keeps the same two-state problem structure as MATLAB: a low-motion state and a movement state."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "9a25e97a",
+   "id": "fd690f04",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -155,24 +153,24 @@
     "    va=\"top\",\n",
     "    family=\"monospace\",\n",
     "    fontsize=9,\n",
-    ")\n"
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "78a4e6b5",
+   "id": "f9e4eb9d",
    "metadata": {},
    "outputs": [],
    "source": [
     "# SECTION 4: Simulate Neural Firing\n",
-    "# The simulated spike population depends on the latent state and the movement dynamics.\n"
+    "# The simulated spike population depends on the latent state and the movement dynamics."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "90fd8d80",
+   "id": "57220fb5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -231,7 +229,7 @@
     "    color=[\"tab:blue\", \"tab:orange\"],\n",
     ")\n",
     "axs[1, 1].set_title(\"Single-run decoding RMSE\")\n",
-    "__tracker.finalize()\n"
+    "__tracker.finalize()"
    ]
   }
  ],
@@ -241,11 +239,11 @@
   },
   "nstat": {
    "expected_figures": 3,
-   "run_group": "smoke",
+   "run_group": "full",
    "style": "python-example",
    "topic": "HybridFilterExample"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/notebooks/NetworkTutorial.ipynb b/notebooks/NetworkTutorial.ipynb
index 0f6a4622..cc460728 100644
--- a/notebooks/NetworkTutorial.ipynb
+++ b/notebooks/NetworkTutorial.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "81a6687d",
+   "id": "5f17a36a",
    "metadata": {},
    "source": [
     "\n",
@@ -15,7 +15,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1a559c47",
+   "id": "5f466245",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -160,7 +160,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1bb3afdf",
+   "id": "4a15a6be",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -172,7 +172,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "60c517c7",
+   "id": "2c6a36bd",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -185,7 +185,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "2e75a150",
+   "id": "8a5a2fc7",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -198,7 +198,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c1a5a339",
+   "id": "f7bdd433",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -217,7 +217,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d8518832",
+   "id": "65b83184",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -228,7 +228,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b58aea0f",
+   "id": "079aaae7",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -239,7 +239,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-
"id": "db9ba65f", + "id": "f2fab951", "metadata": {}, "outputs": [], "source": [ @@ -263,7 +263,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6c7098b3", + "id": "e50a0d22", "metadata": {}, "outputs": [], "source": [ @@ -274,7 +274,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a805e8c4", + "id": "b9331965", "metadata": {}, "outputs": [], "source": [ @@ -285,7 +285,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a1559df4", + "id": "e086cad1", "metadata": {}, "outputs": [], "source": [ @@ -299,7 +299,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b472fd0e", + "id": "ed5a15eb", "metadata": {}, "outputs": [], "source": [ @@ -310,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d71c27d", + "id": "41f7ea4d", "metadata": {}, "outputs": [], "source": [ @@ -323,7 +323,7 @@ { "cell_type": "code", "execution_count": null, - "id": "406603d4", + "id": "de315fea", "metadata": {}, "outputs": [], "source": [ @@ -336,7 +336,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f9878dce", + "id": "d5901b54", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +347,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ddb2550", + "id": "96ebc713", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1554e895", + "id": "2a0ee92c", "metadata": {}, "outputs": [], "source": [ @@ -373,7 +373,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fcb5f672", + "id": "76c17655", "metadata": {}, "outputs": [], "source": [ @@ -394,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37804e6b", + "id": "50c4e444", "metadata": {}, "outputs": [], "source": [ @@ -428,7 +428,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3455475c", + "id": "37dc5453", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84eb9f17", + "id": 
"18abcf5f", "metadata": {}, "outputs": [], "source": [ @@ -479,7 +479,7 @@ { "cell_type": "code", "execution_count": null, - "id": "165afc2b", + "id": "acffc016", "metadata": {}, "outputs": [], "source": [ @@ -509,4 +509,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/PPSimExample.ipynb b/notebooks/PPSimExample.ipynb index 820a2183..cd5e02cf 100644 --- a/notebooks/PPSimExample.ipynb +++ b/notebooks/PPSimExample.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "d32212af", + "id": "2eba515a", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `PPSimExample.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now follows the MATLAB recursive-CIF workflow with the native Python `CIF.simulateCIF` path; exact Simulink block timing and solver semantics are still not fixture-matched one-for-one against MATLAB.\n" + "- Remaining justified differences: The notebook now follows the MATLAB recursive-CIF workflow with the native Python `CIF.simulateCIF` path; exact Simulink block timing and solver semantics are still not fixture-matched one-for-one against MATLAB." 
] }, { "cell_type": "code", "execution_count": null, - "id": "10c55dde", + "id": "a6759a0c", "metadata": {}, "outputs": [], "source": [ @@ -42,14 +42,12 @@ "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='PPSimExample', output_root=OUTPUT_ROOT, expected_count=8)\n", "\n", - "\n", "def _figure(label: str, *, figsize=(8.5, 4.5)):\n", " fig = __tracker.new_figure(label)\n", " fig.clear()\n", " fig.set_size_inches(*figsize)\n", " return fig\n", "\n", - "\n", "Ts = 0.001\n", "tMin = 0.0\n", "tMax = 50.0\n", @@ -65,69 +63,69 @@ "sC, lambda_cov = CIF.simulateCIF(mu, H, S, E, stim, ens, 5, \"binomial\", seed=5, return_lambda=True)\n", "cc = CovColl([stim, baseline])\n", "trial = Trial(sC, cc)\n", - "print({\"duration_s\": tMax, \"num_realizations\": sC.numSpikeTrains, \"mean_rate_hz\": round(float(np.mean(lambda_cov.data[:, 0])), 3)})\n" + "print({\"duration_s\": tMax, \"num_realizations\": sC.numSpikeTrains, \"mean_rate_hz\": round(float(np.mean(lambda_cov.data[:, 0])), 3)})" ] }, { "cell_type": "code", "execution_count": null, - "id": "82e2e6d2", + "id": "fb22ed56", "metadata": {}, "outputs": [], "source": [ "# SECTION 1: General Point Process Simulation\n", - "plt.close(\"all\")\n" + "plt.close(\"all\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "bb8a8dca", + "id": "aefbf353", "metadata": {}, "outputs": [], "source": [ "# SECTION 2: Point Process Sample Path Generation\n", - "print(\"Using native Python CIF.simulateCIF to mirror the MATLAB recursive-CIF workflow.\")\n" + "print(\"Using native Python CIF.simulateCIF to mirror the MATLAB recursive-CIF workflow.\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "7aa83848", + "id": "ae867985", "metadata": {}, "outputs": [], "source": [ "# SECTION 3: History Effect\n", "selfHist = [0.0, 0.001, 0.002, 0.003]\n", - "print({\"history_windows_s\": selfHist})\n" + "print({\"history_windows_s\": selfHist})" ] }, { "cell_type": "code", 
"execution_count": null, - "id": "e8660e03", + "id": "9076f004", "metadata": {}, "outputs": [], "source": [ "# SECTION 4: Stimulus Effect\n", - "print({\"stimulus_frequency_hz\": 1.0, \"stimulus_amplitude\": 1.0})\n" + "print({\"stimulus_frequency_hz\": 1.0, \"stimulus_amplitude\": 1.0})" ] }, { "cell_type": "code", "execution_count": null, - "id": "7b3aa59d", + "id": "fa8120b8", "metadata": {}, "outputs": [], "source": [ "# SECTION 5: Ensemble Effect\n", - "print({\"ensemble_effect\": 0.0})\n" + "print({\"ensemble_effect\": 0.0})" ] }, { "cell_type": "code", "execution_count": null, - "id": "a2f76513", + "id": "c58f3108", "metadata": {}, "outputs": [], "source": [ @@ -137,13 +135,13 @@ "sC.plot(handle=axs[0])\n", "axs[0].set_xlim(0.0, tMax / 5.0)\n", "stim.plot(handle=axs[1])\n", - "axs[1].set_xlim(0.0, tMax / 5.0)\n" + "axs[1].set_xlim(0.0, tMax / 5.0)" ] }, { "cell_type": "code", "execution_count": null, - "id": "5dd18704", + "id": "a7b37585", "metadata": {}, "outputs": [], "source": [ @@ -151,13 +149,13 @@ "fig = _figure(\"figure; lambda.plot\", figsize=(10.0, 4.0))\n", "ax = fig.subplots(1, 1)\n", "lambda_cov.getSubSignal(1).plot(handle=ax)\n", - "ax.set_xlim(0.0, tMax / 5.0)\n" + "ax.set_xlim(0.0, tMax / 5.0)" ] }, { "cell_type": "code", "execution_count": null, - "id": "d1867d92", + "id": "bac3e6f1", "metadata": {}, "outputs": [], "source": [ @@ -167,84 +165,84 @@ " TrialConfig([[\"Baseline\", \"mu\"], [\"Stimulus\", \"sin\"]], sampleRate=1.0 / Ts, name=\"Stim\"),\n", " TrialConfig([[\"Baseline\", \"mu\"], [\"Stimulus\", \"sin\"]], sampleRate=1.0 / Ts, history=selfHist, name=\"Stim+Hist\"),\n", "]\n", - "cfgColl = ConfigColl(cfg)\n" + "cfgColl = ConfigColl(cfg)" ] }, { "cell_type": "code", "execution_count": null, - "id": "393f2a17", + "id": "278b16e6", "metadata": {}, "outputs": [], "source": [ "# SECTION 9: Choose the MATLAB-style fitting algorithm\n", "Algorithm = \"BNLRCG\"\n", - "print({\"algorithm\": Algorithm, \"binary_representation\": 
bool(sC.getNST(1).isSigRepBinary())})\n" + "print({\"algorithm\": Algorithm, \"binary_representation\": bool(sC.getNST(1).isSigRepBinary())})" ] }, { "cell_type": "code", "execution_count": null, - "id": "59d8878a", + "id": "1c9f83d1", "metadata": {}, "outputs": [], "source": [ "# SECTION 10: GLM Model Fitting and Results\n", - "results = Analysis.RunAnalysisForAllNeurons(trial, cfgColl)\n" + "results = Analysis.RunAnalysisForAllNeurons(trial, cfgColl)" ] }, { "cell_type": "code", "execution_count": null, - "id": "e13231cb", + "id": "c939f67c", "metadata": {}, "outputs": [], "source": [ "# SECTION 11: Results for sample neuron\n", "fig = _figure(\"results{1}.plotResults\", figsize=(11.0, 8.0))\n", - "results[0].plotResults(handle=fig)\n" + "results[0].plotResults(handle=fig)" ] }, { "cell_type": "code", "execution_count": null, - "id": "9298a78b", + "id": "21f39091", "metadata": {}, "outputs": [], "source": [ "# SECTION 12: Baseline-only diagnostic view\n", "fig = _figure(\"results{1}.plotResults baseline\", figsize=(11.0, 8.0))\n", - "results[0].plotResults(fit_num=1, handle=fig)\n" + "results[0].plotResults(fit_num=1, handle=fig)" ] }, { "cell_type": "code", "execution_count": null, - "id": "1bebdc33", + "id": "6faa0468", "metadata": {}, "outputs": [], "source": [ "# SECTION 13: Stimulus model diagnostic view\n", "fig = _figure(\"results{2}.plotResults stim\", figsize=(11.0, 8.0))\n", - "results[0].plotResults(fit_num=2, handle=fig)\n" + "results[0].plotResults(fit_num=2, handle=fig)" ] }, { "cell_type": "code", "execution_count": null, - "id": "37097000", + "id": "27d1e5af", "metadata": {}, "outputs": [], "source": [ "# SECTION 14: Stimulus-plus-history diagnostic view\n", "fig = _figure(\"results{3}.plotResults hist\", figsize=(11.0, 8.0))\n", - "results[0].plotResults(fit_num=3, handle=fig)\n" + "results[0].plotResults(fit_num=3, handle=fig)" ] }, { "cell_type": "code", "execution_count": null, - "id": "a6cdb847", + "id": "bd5e5ca2", "metadata": {}, "outputs": 
[], "source": [ @@ -252,13 +250,13 @@ "fig = _figure(\"results.lambda.plot\", figsize=(9.5, 4.5))\n", "ax = fig.subplots(1, 1)\n", "results[0].lambdaSignal.getSubSignal(3).plot(handle=ax)\n", - "ax.set_xlim(0.0, tMax / 5.0)\n" + "ax.set_xlim(0.0, tMax / 5.0)" ] }, { "cell_type": "code", "execution_count": null, - "id": "78d2200b", + "id": "9d323187", "metadata": {}, "outputs": [], "source": [ @@ -266,13 +264,13 @@ "summary = FitResSummary(results)\n", "fig = _figure(\"Summary.plotSummary\", figsize=(10.0, 4.5))\n", "summary.plotSummary(handle=fig)\n", - "print({\"fit_names\": summary.fitNames, \"mean_AIC\": np.asarray(summary.meanAIC, dtype=float).round(3).tolist()})\n" + "print({\"fit_names\": summary.fitNames, \"mean_AIC\": np.asarray(summary.meanAIC, dtype=float).round(3).tolist()})" ] }, { "cell_type": "code", "execution_count": null, - "id": "b920b8cb", + "id": "0c7da4fb", "metadata": {}, "outputs": [], "source": [ @@ -283,7 +281,7 @@ "ax.set_xticks(np.arange(len(summary.fitNames)), summary.fitNames, rotation=20)\n", "ax.set_ylabel(\"mean AIC\")\n", "ax.set_title(\"Model comparison across realizations\")\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] } ], @@ -300,4 +298,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/SignalObjExamples.ipynb b/notebooks/SignalObjExamples.ipynb index 76e0a497..5666e428 100644 --- a/notebooks/SignalObjExamples.ipynb +++ b/notebooks/SignalObjExamples.ipynb @@ -31,7 +31,7 @@ "\n", "# SECTION 0: Section 0\n", "# Using the SignalObj Class\n", - "# In this file we will give several examples of how the SignalObj can be used. A description of all of the properties of SignalObj can be found at: SignalObj Class Definition\n" + "# In this file we will give several examples of how the SignalObj can be used. 
A description of all of the properties of SignalObj can be found at: SignalObj Class Definition" ] }, { diff --git a/notebooks/StimulusDecode2D.ipynb b/notebooks/StimulusDecode2D.ipynb index e90ff61a..e4c5aa29 100644 --- a/notebooks/StimulusDecode2D.ipynb +++ b/notebooks/StimulusDecode2D.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "4f37c1b6", + "id": "4599bf89", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `StimulusDecode2D.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now follows the MATLAB nonlinear-CIF decoding workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented linear fallback branch as MATLAB. Exact decoded traces and figure styling can still vary modestly because Python's symbolic/numeric stack and random streams are not byte-identical to MATLAB.\n" + "- Remaining justified differences: The notebook now follows the MATLAB nonlinear-CIF decoding workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented linear fallback branch as MATLAB. Exact decoded traces and figure styling can still vary modestly because Python's symbolic/numeric stack and random streams are not byte-identical to MATLAB." 
] }, { "cell_type": "code", "execution_count": null, - "id": "f94aa3da", + "id": "6f7a27a5", "metadata": {}, "outputs": [], "source": [ @@ -42,20 +42,17 @@ "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='StimulusDecode2D', output_root=OUTPUT_ROOT, expected_count=6)\n", "\n", - "\n", "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n", " fig = __tracker.new_figure(matlab_line)\n", " fig.clear()\n", " fig.set_size_inches(*figsize)\n", " return fig\n", "\n", - "\n", "def _subplot_grid(count):\n", " rows = max(int(np.floor(np.sqrt(count))), 1)\n", " cols = int(np.ceil(count / rows))\n", " return rows, cols\n", "\n", - "\n", "def _simulate_decode(seed=0, *, num_realizations=80, delta=0.001, tmax=1.0):\n", " rng = np.random.default_rng(seed)\n", " time = np.arange(0.0, tmax + delta, delta)\n", @@ -149,7 +146,6 @@ " \"num_cells\": num_realizations,\n", " }\n", "\n", - "\n", "def _plot_raster(ax, time_s, spikes, *, max_cells=20):\n", " n_cells = min(int(spikes.shape[1]), max_cells)\n", " for row in range(n_cells):\n", @@ -157,13 +153,13 @@ " if spike_times.size:\n", " ax.vlines(spike_times, row + 0.6, row + 1.4, color=\"k\", linewidth=0.35)\n", " ax.set_ylim(0.5, n_cells + 0.5)\n", - " ax.set_ylabel(\"cell\")\n" + " ax.set_ylabel(\"cell\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "4c32de3a", + "id": "8809d377", "metadata": {}, "outputs": [], "source": [ @@ -178,13 +174,13 @@ " \"decode_rmse\": round(float(payload[\"decode_rmse\"]), 4),\n", " \"fallback_error\": payload[\"decode_error\"] or \"\",\n", " }\n", - ")\n" + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "25960ea2", + "id": "d87abf0b", "metadata": {}, "outputs": [], "source": [ @@ -224,13 +220,13 @@ " ax.set_xticks([])\n", " ax.set_yticks([])\n", "if image is not None:\n", - " fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)\n" + " fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)" ] }, { 
"cell_type": "code", "execution_count": null, - "id": "b0784e40", + "id": "ab3bc9c7", "metadata": {}, "outputs": [], "source": [ @@ -242,13 +238,13 @@ "axs[1].plot(payload[\"time_s\"], np.mean(payload[\"spikes\"], axis=1), color=\"tab:green\", linewidth=1.2)\n", "axs[1].set_title(\"Population firing fraction\")\n", "axs[1].set_xlabel(\"time (s)\")\n", - "axs[1].set_ylabel(\"mean spike/bin\")\n" + "axs[1].set_ylabel(\"mean spike/bin\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "714372bb", + "id": "5eddb87b", "metadata": {}, "outputs": [], "source": [ @@ -286,7 +282,7 @@ "ax.set_xlabel(\"time (s)\")\n", "ax.set_ylabel(\"Euclidean error\")\n", "ax.legend(loc=\"best\", frameon=False, fontsize=8)\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] } ], @@ -296,11 +292,11 @@ }, "nstat": { "expected_figures": 6, - "run_group": "smoke", + "run_group": "full", "style": "python-example", "topic": "StimulusDecode2D" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/TrialConfigExamples.ipynb b/notebooks/TrialConfigExamples.ipynb index 525a79cc..41cae891 100644 --- a/notebooks/TrialConfigExamples.ipynb +++ b/notebooks/TrialConfigExamples.ipynb @@ -32,7 +32,7 @@ "# SECTION 0: Section 0\n", "# TrialConfig Examples\n", "# tcObj=TrialConfig(covMask,sampleRate, history,minTime,maxTime)\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] } ], diff --git a/notebooks/TrialExamples.ipynb b/notebooks/TrialExamples.ipynb index ffe918b2..c56a6d48 100644 --- a/notebooks/TrialExamples.ipynb +++ b/notebooks/TrialExamples.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "914238fa", + "id": "d48a2677", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `TrialExamples.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now mirrors the MATLAB Trial workflow with executable object construction, masking, history 
extraction, and plotting; only minor Python plotting defaults differ from the published MATLAB help output.\n" + "- Remaining justified differences: The notebook now mirrors the MATLAB Trial workflow with executable object construction, masking, history extraction, and plotting; only minor Python plotting defaults differ from the published MATLAB help output." ] }, { "cell_type": "code", "execution_count": null, - "id": "8ffbef9e", + "id": "f23e6389", "metadata": {}, "outputs": [], "source": [ @@ -42,14 +42,12 @@ "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='TrialExamples', output_root=OUTPUT_ROOT, expected_count=6)\n", "\n", - "\n", "def _figure(label: str, *, figsize=(8.5, 3.5)):\n", " fig = __tracker.new_figure(label)\n", " fig.clear()\n", " fig.set_size_inches(*figsize)\n", " return fig\n", "\n", - "\n", "def _build_trial():\n", " length_trial = 1.0\n", " sample_rate = 1000.0\n", @@ -109,7 +107,6 @@ " \"trial\": trial,\n", " }\n", "\n", - "\n", "ctx = _build_trial()\n", "print(\n", " {\n", @@ -118,13 +115,13 @@ " \"covariates\": ctx[\"cov_coll\"].names,\n", " \"history_windows\": ctx[\"history\"].windowTimes.tolist(),\n", " }\n", - ")\n" + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "3ada46f2", + "id": "5f01a6dd", "metadata": {}, "outputs": [], "source": [ @@ -134,76 +131,76 @@ "spikeColl = ctx[\"spike_coll\"]\n", "cc = ctx[\"cov_coll\"]\n", "e = ctx[\"events\"]\n", - "h = ctx[\"history\"]\n" + "h = ctx[\"history\"]" ] }, { "cell_type": "code", "execution_count": null, - "id": "c5ad1d59", + "id": "947fce2d", "metadata": {}, "outputs": [], "source": [ "# SECTION 2: Create History windows of interest\n", "fig = _figure(\"figure; h.plot\", figsize=(8.0, 2.5))\n", "ax = fig.subplots(1, 1)\n", - "h.plot(handle=ax)\n" + "h.plot(handle=ax)" ] }, { "cell_type": "code", "execution_count": null, - "id": "87bae7f2", + "id": "948c1c5e", "metadata": {}, "outputs": [], "source": [ "# SECTION 3: Load 
Covariates\n", "fig = _figure(\"figure; cc.plot\", figsize=(8.5, 5.0))\n", - "cc.plot(handle=fig)\n" + "cc.plot(handle=fig)" ] }, { "cell_type": "code", "execution_count": null, - "id": "2bb06d98", + "id": "addc115e", "metadata": {}, "outputs": [], "source": [ "# SECTION 4: Create trial events\n", "fig = _figure(\"figure; e.plot\", figsize=(8.0, 2.3))\n", "ax = fig.subplots(1, 1)\n", - "e.plot(handle=ax)\n" + "e.plot(handle=ax)" ] }, { "cell_type": "code", "execution_count": null, - "id": "9ecba539", + "id": "8160b030", "metadata": {}, "outputs": [], "source": [ "# SECTION 5: Create neural Spike Train Data\n", "fig = _figure(\"figure; spikeColl.plot\", figsize=(8.5, 3.5))\n", "ax = fig.subplots(1, 1)\n", - "spikeColl.plot(handle=ax)\n" + "spikeColl.plot(handle=ax)" ] }, { "cell_type": "code", "execution_count": null, - "id": "828b38f9", + "id": "f15709a2", "metadata": {}, "outputs": [], "source": [ "# SECTION 6: Finally we have everything we need to create a Trial object.\n", "fig = _figure(\"figure; trial1.plot\", figsize=(9.0, 8.0))\n", - "trial1.plot(handle=fig)\n" + "trial1.plot(handle=fig)" ] }, { "cell_type": "code", "execution_count": null, - "id": "24ada846", + "id": "d25ae0c3", "metadata": {}, "outputs": [], "source": [ @@ -213,30 +210,30 @@ "trial1.plot(handle=fig)\n", "hist_cov = trial1.getHistForNeurons([1, 2])\n", "print({\"masked_labels\": trial1.getLabelsFromMask(1), \"history_covariates\": hist_cov.getAllCovLabels()[:4]})\n", - "trial1.resetCovMask()\n" + "trial1.resetCovMask()" ] }, { "cell_type": "code", "execution_count": null, - "id": "406f0e5e", + "id": "48633477", "metadata": {}, "outputs": [], "source": [ "# SECTION 8: Example 2: Analyzing Trial Data\n", - "print(\"Examples of neural spike analysis using AnalysisExamples2 (Neural Spike Analysis Toolbox) or AnalysisExamples (standard methods).\")\n" + "print(\"Examples of neural spike analysis using AnalysisExamples2 (Neural Spike Analysis Toolbox) or AnalysisExamples (standard methods).\")" ] 
}, { "cell_type": "code", "execution_count": null, - "id": "2a99d085", + "id": "ebaa9bdc", "metadata": {}, "outputs": [], "source": [ "# SECTION 9: Related analysis workflows\n", "print({\"recommended_next_notebooks\": [\"AnalysisExamples2\", \"AnalysisExamples\"]})\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] } ], @@ -253,4 +250,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/ValidationDataSet.ipynb b/notebooks/ValidationDataSet.ipynb index 120e63a0..5512e73c 100644 --- a/notebooks/ValidationDataSet.ipynb +++ b/notebooks/ValidationDataSet.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "d15a297a", + "id": "ff1426a4", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `ValidationDataSet.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now reproduces the constant-rate and piecewise-rate validation workflows with real `Trial`/`Analysis` objects and figure outputs; local execution uses the MATLAB-scale simulation sizes, while CI switches to a documented shorter deterministic fast path for stability.\n" + "- Remaining justified differences: The notebook now reproduces the constant-rate and piecewise-rate validation workflows with real `Trial`/`Analysis` objects and figure outputs; local execution uses the MATLAB-scale simulation sizes, while CI switches to a documented shorter deterministic fast path for stability." 
] }, { "cell_type": "code", "execution_count": null, - "id": "f2698163", + "id": "5a19211e", "metadata": {}, "outputs": [], "source": [ @@ -44,14 +44,12 @@ "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='ValidationDataSet', output_root=OUTPUT_ROOT, expected_count=10)\n", "\n", - "\n", "def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n", " fig = __tracker.new_figure(matlab_line)\n", " fig.clear()\n", " fig.set_size_inches(*figsize)\n", " return fig\n", "\n", - "\n", "def _lambda_columns(fit_result):\n", " time = np.asarray(fit_result.lambda_signal.time, dtype=float)\n", " data = np.asarray(fit_result.lambda_signal.data, dtype=float)\n", @@ -59,10 +57,8 @@ " data = data[:, None]\n", " return time, data\n", "\n", - "\n", "CI_FAST_PATH = os.environ.get(\"CI\", \"\").strip().lower() in {\"1\", \"true\", \"yes\"}\n", "\n", - "\n", "def _simulate_constant_case(seed=0, *, p=0.01, n_samples=None, delta=0.001):\n", " if n_samples is None:\n", " n_samples = 20001 if CI_FAST_PATH else 100001\n", @@ -91,7 +87,6 @@ " \"trains\": trains,\n", " }\n", "\n", - "\n", "def _simulate_piecewise_case(seed=1, *, p1=0.001, p2=0.01, n1=None, n2=None, delta=0.001):\n", " if n1 is None:\n", " n1 = 20000 if CI_FAST_PATH else 100000\n", @@ -139,7 +134,6 @@ " \"trains\": trains,\n", " }\n", "\n", - "\n", "def _plot_isi_hist(ax, train, lambda_hz, *, title):\n", " isi = np.asarray(train.getISIs(), dtype=float)\n", " if isi.size:\n", @@ -148,13 +142,13 @@ " ax.plot(x, lambda_hz * np.exp(-lambda_hz * x), color=\"tab:red\", linewidth=1.5)\n", " ax.set_title(title)\n", " ax.set_xlabel(\"ISI (s)\")\n", - " ax.set_ylabel(\"density\")\n" + " ax.set_ylabel(\"density\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "702ec99c", + "id": "de07a751", "metadata": {}, "outputs": [], "source": [ @@ -170,36 +164,36 @@ " \"piecewise_lambda1_hz\": round(float(piecewise_case[\"lambda1_hz\"]), 4),\n", " \"piecewise_lambda2_hz\": 
round(float(piecewise_case[\"lambda2_hz\"]), 4),\n", " }\n", - ")\n" + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "a65040f3", + "id": "4c326ba4", "metadata": {}, "outputs": [], "source": [ "# SECTION 1: Case #1: Constant Rate Poisson Process\n", - "# First we verify that the analysis recovers a constant Poisson rate from simulated spike trains.\n" + "# First we verify that the analysis recovers a constant Poisson rate from simulated spike trains." ] }, { "cell_type": "code", "execution_count": null, - "id": "1f04cbfd", + "id": "0348ff8b", "metadata": {}, "outputs": [], "source": [ "# SECTION 2: Generate constant-rate neural firing activity\n", "constant_time = np.asarray(constant_case[\"time_s\"], dtype=float)\n", - "constant_trains = list(constant_case[\"trains\"])\n" + "constant_trains = list(constant_case[\"trains\"])" ] }, { "cell_type": "code", "execution_count": null, - "id": "1a464c1b", + "id": "127d7e94", "metadata": {}, "outputs": [], "source": [ @@ -207,13 +201,13 @@ "fig = _prepare_figure(\"nst{1}.plotISIHistogram\", figsize=(10.0, 4.0))\n", "axs = fig.subplots(1, 2)\n", "_plot_isi_hist(axs[0], constant_trains[0], constant_case[\"lambda_hz\"], title=\"Neuron 1 ISI histogram\")\n", - "_plot_isi_hist(axs[1], constant_trains[1], constant_case[\"lambda_hz\"], title=\"Neuron 2 ISI histogram\")\n" + "_plot_isi_hist(axs[1], constant_trains[1], constant_case[\"lambda_hz\"], title=\"Neuron 2 ISI histogram\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "fb4f5ef8", + "id": "ce2f5297", "metadata": {}, "outputs": [], "source": [ @@ -229,13 +223,13 @@ "ax.set_xticks(xloc, [f\"Neuron {idx}\" for idx in xloc])\n", "ax.set_ylabel(\"μ coefficient\")\n", "ax.set_title(\"Estimated constant-rate coefficient\")\n", - "ax.legend(loc=\"best\", frameon=False)\n" + "ax.legend(loc=\"best\", frameon=False)" ] }, { "cell_type": "code", "execution_count": null, - "id": "c286b8c0", + "id": "e4a199c4", "metadata": {}, "outputs": [], "source": [ 
@@ -251,26 +245,26 @@ " ax.set_xlabel(\"time (s)\")\n", " ax.grid(alpha=0.25)\n", "axs[0].set_ylabel(\"rate (Hz)\")\n", - "axs[1].legend(loc=\"best\", frameon=False, fontsize=8)\n" + "axs[1].legend(loc=\"best\", frameon=False, fontsize=8)" ] }, { "cell_type": "code", "execution_count": null, - "id": "879b1951", + "id": "26380744", "metadata": {}, "outputs": [], "source": [ "# SECTION 6: Case #2: Piece-wise Constant Rate Poisson Process\n", "# Next we compare a single-rate model against a two-epoch rate model.\n", "piecewise_time = np.asarray(piecewise_case[\"time_s\"], dtype=float)\n", - "piecewise_trains = list(piecewise_case[\"trains\"])\n" + "piecewise_trains = list(piecewise_case[\"trains\"])" ] }, { "cell_type": "code", "execution_count": null, - "id": "58f09d75", + "id": "cf2c6402", "metadata": {}, "outputs": [], "source": [ @@ -294,24 +288,24 @@ "ax.set_title(\"Ground-truth rates for the two-epoch simulation\")\n", "ax.set_xlabel(\"time (s)\")\n", "ax.set_ylabel(\"rate (Hz)\")\n", - "ax.legend(loc=\"best\", frameon=False, fontsize=8)\n" + "ax.legend(loc=\"best\", frameon=False, fontsize=8)" ] }, { "cell_type": "code", "execution_count": null, - "id": "6928e1f4", + "id": "25dfb2e5", "metadata": {}, "outputs": [], "source": [ "# SECTION 8: Setup the piecewise-rate analysis\n", - "piecewise_results = Analysis.RunAnalysisForAllNeurons(piecewise_case[\"trial\"], piecewise_case[\"cfg\"], 0)\n" + "piecewise_results = Analysis.RunAnalysisForAllNeurons(piecewise_case[\"trial\"], piecewise_case[\"cfg\"], 0)" ] }, { "cell_type": "code", "execution_count": null, - "id": "3522b3b9", + "id": "113d3272", "metadata": {}, "outputs": [], "source": [ @@ -340,13 +334,13 @@ " ax.set_xlabel(\"time (s)\")\n", " ax.grid(alpha=0.25)\n", "axs[0].set_ylabel(\"rate (Hz)\")\n", - "axs[1].legend(loc=\"best\", frameon=False, fontsize=8)\n" + "axs[1].legend(loc=\"best\", frameon=False, fontsize=8)" ] }, { "cell_type": "code", "execution_count": null, - "id": "f576be75", + "id": "1d9d4cce", 
"metadata": {}, "outputs": [], "source": [ @@ -373,7 +367,7 @@ "ax.set_ylabel(\"log-likelihood\")\n", "ax.set_title(\"Per-neuron log-likelihood comparison\")\n", "ax.legend(loc=\"best\", frameon=False, fontsize=8)\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] } ], @@ -383,11 +377,11 @@ }, "nstat": { "expected_figures": 10, - "run_group": "smoke", + "run_group": "full", "style": "python-example", "topic": "ValidationDataSet" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/mEPSCAnalysis.ipynb b/notebooks/mEPSCAnalysis.ipynb index 4f1a9241..9dc5dec0 100644 --- a/notebooks/mEPSCAnalysis.ipynb +++ b/notebooks/mEPSCAnalysis.ipynb @@ -59,7 +59,7 @@ "\n", "# SECTION 0: Section 0\n", "# MINIATURE EXCITATORY POST-SYNAPTIC CURRENTS (mEPSCs)\n", - "# Data from Marnie Phillips; this notebook keeps the original analysis narrative but replaces the old placeholder cells with executable Python workflows.\n" + "# Data from Marnie Phillips; this notebook keeps the original analysis narrative but replaces the old placeholder cells with executable Python workflows." 
] }, { @@ -128,7 +128,7 @@ "\n", "fig = __tracker.new_figure(\"constant-magnesium-results\")\n", "const_results.plotResults(handle=fig)\n", - "print({\"constant_events\": int(const_spike_times.size), \"AIC\": const_results.AIC.tolist()})\n" + "print({\"constant_events\": int(const_spike_times.size), \"AIC\": const_results.AIC.tolist()})" ] }, { @@ -196,7 +196,7 @@ "ax.set_title(\"Washout event raster with selected segments\")\n", "for marker in (260.0, 400.0, 745.0):\n", " ax.axvline(marker, color=\"tab:red\", linestyle=\"--\", linewidth=1.0)\n", - "fig.tight_layout()\n" + "fig.tight_layout()" ] }, { @@ -253,7 +253,7 @@ "])\n", "results = Analysis.RunAnalysisForNeuron(washout_trial, 1, configs, 0)\n", "summary = FitResSummary([results])\n", - "print({\"washout_events\": int(washout_spikes.size), \"config_names\": results.configNames})\n" + "print({\"washout_events\": int(washout_spikes.size), \"config_names\": results.configNames})" ] }, { @@ -276,7 +276,7 @@ "\n", "fig = __tracker.new_figure(\"washout-summary\")\n", "summary.plotSummary(handle=fig)\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] }, { diff --git a/notebooks/nSTATPaperExamples.ipynb b/notebooks/nSTATPaperExamples.ipynb index 093de5b2..4ee710a6 100644 --- a/notebooks/nSTATPaperExamples.ipynb +++ b/notebooks/nSTATPaperExamples.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "989e2794", + "id": "e7957a0e", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `nSTATPaperExamples.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now executes the canonical paper-example workflows through the standalone Python implementations and real figshare-backed datasets; exact numerical traces and figure styling still vary modestly because the Python GLM/decoder stack and plotting defaults are not byte-identical to MATLAB.\n" + "- Remaining justified differences: The notebook now executes the canonical 
paper-example workflows through the standalone Python implementations and real figshare-backed datasets; exact numerical traces and figure styling still vary modestly because the Python GLM/decoder stack and plotting defaults are not byte-identical to MATLAB." ] }, { "cell_type": "code", "execution_count": null, - "id": "e3ae0b3f", + "id": "64f99156", "metadata": {}, "outputs": [], "source": [ @@ -52,14 +52,12 @@ "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic=\"nSTATPaperExamples\", output_root=OUTPUT_ROOT, expected_count=26)\n", "\n", - "\n", "def _fig(label: str, *, figsize=(8.5, 4.5)):\n", " fig = __tracker.new_figure(label)\n", " fig.clear()\n", " fig.set_size_inches(*figsize)\n", " return fig\n", "\n", - "\n", "plt.close(\"all\")\n", "exp1_summary, exp1 = run_experiment1(DATA_DIR, return_payload=True)\n", "exp2_summary, exp2 = run_experiment2(DATA_DIR, return_payload=True)\n", @@ -69,24 +67,24 @@ "exp5_summary, exp5 = run_experiment5(return_payload=True)\n", "exp5b_summary, exp5b = run_experiment5b(return_payload=True)\n", "exp6_summary, exp6 = run_experiment6(REPO_ROOT, return_payload=True)\n", - "print({\"dataset_root\": str(DATA_DIR), \"paper_examples_loaded\": 8})\n" + "print({\"dataset_root\": str(DATA_DIR), \"paper_examples_loaded\": 8})" ] }, { "cell_type": "code", "execution_count": null, - "id": "ad04d480", + "id": "cf03a39b", "metadata": {}, "outputs": [], "source": [ "# SECTION 1: Experiment 1\n", - "print(exp1_summary)\n" + "print(exp1_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "f5fd7ff0", + "id": "facb48a1", "metadata": {}, "outputs": [], "source": [ @@ -96,24 +94,24 @@ "ax.plot(exp1[\"constant_time_s\"], exp1[\"constant_rate_hz\"], color=\"tab:blue\", linewidth=1.4)\n", "ax.set_xlabel(\"time (s)\")\n", "ax.set_ylabel(\"rate (Hz)\")\n", - "ax.set_title(\"Constant Mg condition: homogeneous Poisson fit\")\n" + "ax.set_title(\"Constant Mg condition: homogeneous Poisson 
fit\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2b8b5cc6", + "id": "b3453ddd", "metadata": {}, "outputs": [], "source": [ "# SECTION 3: Varying Magnesium Concentration - Piecewise Constant rate poisson\n", - "print({\"decreasing_condition_spikes\": exp1_summary[\"decreasing_condition_spikes\"], \"piecewise_model_aic\": round(float(exp1_summary[\"piecewise_model_aic\"]), 3)})\n" + "print({\"decreasing_condition_spikes\": exp1_summary[\"decreasing_condition_spikes\"], \"piecewise_model_aic\": round(float(exp1_summary[\"piecewise_model_aic\"]), 3)})" ] }, { "cell_type": "code", "execution_count": null, - "id": "36d369c6", + "id": "6a1a4315", "metadata": {}, "outputs": [], "source": [ @@ -132,13 +130,13 @@ " axs[1].axvline(edge, color=\"tab:red\", linestyle=\"--\", linewidth=0.9)\n", "axs[1].set_xlabel(\"time (s)\")\n", "axs[1].set_ylabel(\"rate (Hz)\")\n", - "axs[1].legend(loc=\"upper left\", frameon=False, fontsize=8)\n" + "axs[1].legend(loc=\"upper left\", frameon=False, fontsize=8)" ] }, { "cell_type": "code", "execution_count": null, - "id": "780da944", + "id": "9bf5cb26", "metadata": {}, "outputs": [], "source": [ @@ -152,13 +150,13 @@ "ax.set_ylim(0.0, 1.0)\n", "ax.set_xlabel(\"theoretical CDF\")\n", "ax.set_ylabel(\"empirical CDF\")\n", - "ax.set_title(\"Constant-condition KS plot\")\n" + "ax.set_title(\"Constant-condition KS plot\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "3c220adb", + "id": "197eb8cd", "metadata": {}, "outputs": [], "source": [ @@ -170,13 +168,13 @@ "ax.axhline(-exp1_summary[\"constant_acf_ci\"], color=\"tab:red\", linewidth=1.0)\n", "ax.set_xlabel(\"lag\")\n", "ax.set_ylabel(\"autocorrelation\")\n", - "ax.set_title(\"Sequential correlation under constant Mg\")\n" + "ax.set_title(\"Sequential correlation under constant Mg\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "78a302f9", + "id": "53c12b4c", "metadata": {}, "outputs": [], "source": [ @@ -188,24 +186,24 @@ 
"ax.bar(np.arange(3), aics, color=[\"0.6\", \"tab:green\", \"tab:red\"])\n", "ax.set_xticks(np.arange(3), names)\n", "ax.set_ylabel(\"AIC\")\n", - "ax.set_title(\"Experiment 1 model comparison\")\n" + "ax.set_title(\"Experiment 1 model comparison\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2b09c5e3", + "id": "b6751400", "metadata": {}, "outputs": [], "source": [ "# SECTION 8: Experiment 2\n", - "print(exp2_summary)\n" + "print(exp2_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "03f25c28", + "id": "be05f22d", "metadata": {}, "outputs": [], "source": [ @@ -218,13 +216,13 @@ "axs[0].set_ylabel(\"spikes\")\n", "axs[1].plot(exp2[\"time_s\"], exp2[\"stimulus\"], color=\"tab:blue\", linewidth=1.2)\n", "axs[1].set_ylabel(\"stimulus\")\n", - "axs[1].set_xlabel(\"time (s)\")\n" + "axs[1].set_xlabel(\"time (s)\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "03ee9c55", + "id": "f704736c", "metadata": {}, "outputs": [], "source": [ @@ -234,13 +232,13 @@ "ax.plot(1000.0 * np.asarray(exp2[\"xcorr_lags_s\"], dtype=float), exp2[\"xcorr_values\"], color=\"tab:purple\", linewidth=1.3)\n", "ax.set_xlabel(\"lag (ms)\")\n", "ax.set_ylabel(\"cross-covariance\")\n", - "ax.set_title(\"Stimulus lag search\")\n" + "ax.set_title(\"Stimulus lag search\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "09f0d4f1", + "id": "620d4f28", "metadata": {}, "outputs": [], "source": [ @@ -253,13 +251,13 @@ "axs[0].set_title(\"AIC\")\n", "axs[1].bar(np.arange(3), [exp2_summary[\"model1_bic\"], exp2_summary[\"model2_bic\"], exp2_summary[\"model3_bic\"]], color=[\"0.65\", \"tab:blue\", \"tab:green\"])\n", "axs[1].set_xticks(np.arange(3), model_names, rotation=15)\n", - "axs[1].set_title(\"BIC\")\n" + "axs[1].set_title(\"BIC\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "b32115d9", + "id": "0803cc01", "metadata": {}, "outputs": [], "source": [ @@ -275,13 +273,13 @@ "ax.set_xlim(0.0, 1.0)\n", "ax.set_ylim(0.0, 
1.0)\n", "ax.legend(loc=\"lower right\", frameon=False, fontsize=8)\n", - "ax.set_title(\"Experiment 2 KS diagnostics\")\n" + "ax.set_title(\"Experiment 2 KS diagnostics\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "d209b1e8", + "id": "e8e2cecc", "metadata": {}, "outputs": [], "source": [ @@ -295,13 +293,13 @@ "axs[1].set_ylabel(\"Delta AIC\")\n", "axs[2].plot(windows, exp2[\"delta_bic\"], marker=\"o\", color=\"tab:brown\", linewidth=1.2)\n", "axs[2].set_ylabel(\"Delta BIC\")\n", - "axs[2].set_xlabel(\"history windows\")\n" + "axs[2].set_xlabel(\"history windows\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "570fa694", + "id": "86253034", "metadata": {}, "outputs": [], "source": [ @@ -315,24 +313,24 @@ "ax.errorbar(xpos, coef_values, yerr=np.vstack([coef_values - lower, upper - coef_values]), fmt=\"o\", color=\"tab:blue\", capsize=3)\n", "ax.set_xticks(xpos, exp2[\"coef_names\"], rotation=30)\n", "ax.set_ylabel(\"coefficient value\")\n", - "ax.set_title(\"Experiment 2 coefficient intervals\")\n" + "ax.set_title(\"Experiment 2 coefficient intervals\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2e5c472a", + "id": "214e5be7", "metadata": {}, "outputs": [], "source": [ "# SECTION 15: Experiment 3\n", - "print(exp3_summary)\n" + "print(exp3_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "ad2e93f8", + "id": "f39efc64", "metadata": {}, "outputs": [], "source": [ @@ -342,13 +340,13 @@ "ax.plot(exp3[\"time_s\"], exp3[\"true_rate_hz\"], color=\"tab:blue\", linewidth=1.3)\n", "ax.set_xlabel(\"time (s)\")\n", "ax.set_ylabel(\"rate (Hz)\")\n", - "ax.set_title(\"Experiment 3 true conditional intensity\")\n" + "ax.set_title(\"Experiment 3 true conditional intensity\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "03fe133c", + "id": "39340a3b", "metadata": {}, "outputs": [], "source": [ @@ -360,24 +358,24 @@ "axs[0].set_ylabel(\"trial\")\n", 
"axs[1].plot(exp3[\"psth_bin_centers_s\"], exp3[\"psth_rate_hz\"], color=\"tab:red\", linewidth=1.4)\n", "axs[1].set_ylabel(\"PSTH (Hz)\")\n", - "axs[1].set_xlabel(\"time (s)\")\n" + "axs[1].set_xlabel(\"time (s)\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2be08d3f", + "id": "e80bdf17", "metadata": {}, "outputs": [], "source": [ "# SECTION 18: Experiment 3b\n", - "print(exp3b_summary)\n" + "print(exp3b_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "4470cf1f", + "id": "d3f429ac", "metadata": {}, "outputs": [], "source": [ @@ -388,13 +386,13 @@ "axs[0].set_title(\"True stimulus\")\n", "axs[1].imshow(exp3b[\"xk\"], aspect=\"auto\", cmap=\"viridis\")\n", "axs[1].set_title(\"Decoded state\")\n", - "axs[1].set_xlabel(\"time bin\")\n" + "axs[1].set_xlabel(\"time bin\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2c42e63b", + "id": "9e327384", "metadata": {}, "outputs": [], "source": [ @@ -404,13 +402,13 @@ "axs[0].plot(np.mean(exp3b[\"ci_width\"], axis=0), color=\"tab:orange\", linewidth=1.3)\n", "axs[0].set_title(\"Mean CI width over time\")\n", "axs[1].plot(np.mean(exp3b[\"qhat_all\"], axis=0), marker=\"o\", color=\"tab:blue\", linewidth=1.2)\n", - "axs[1].set_title(\"Mean Qhat across models\")\n" + "axs[1].set_title(\"Mean Qhat across models\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "e193a719", + "id": "2fd1d209", "metadata": {}, "outputs": [], "source": [ @@ -420,24 +418,24 @@ "axs[0].bar(np.arange(len(exp3b[\"gammahat\"])), exp3b[\"gammahat\"], color=\"tab:green\")\n", "axs[0].set_title(\"gammahat\")\n", "axs[1].plot(np.asarray(exp3b[\"gammahat_all\"], dtype=float), marker=\"o\", color=\"tab:red\", linewidth=1.2)\n", - "axs[1].set_title(\"gammahatAll\")\n" + "axs[1].set_title(\"gammahatAll\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "3718c5d4", + "id": "8150177d", "metadata": {}, "outputs": [], "source": [ "# SECTION 22: Experiment 4\n", - 
"print(exp4_summary)\n" + "print(exp4_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "96e2af7c", + "id": "19dabcf1", "metadata": {}, "outputs": [], "source": [ @@ -447,13 +445,13 @@ "ax.bar(np.arange(len(exp4[\"animal1\"][\"selected_indices\"])), exp4[\"animal1\"][\"delta_aic\"], color=\"tab:blue\")\n", "ax.set_xticks(np.arange(len(exp4[\"animal1\"][\"selected_indices\"])), [str(int(v) + 1) for v in exp4[\"animal1\"][\"selected_indices\"]])\n", "ax.set_ylabel(\"Gaussian - Zernike AIC\")\n", - "ax.set_title(\"Animal 1 place-cell comparison\")\n" + "ax.set_title(\"Animal 1 place-cell comparison\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "ccd871af", + "id": "21ef2693", "metadata": {}, "outputs": [], "source": [ @@ -463,13 +461,13 @@ "ax.bar(np.arange(len(exp4[\"animal2\"][\"selected_indices\"])), exp4[\"animal2\"][\"delta_bic\"], color=\"tab:green\")\n", "ax.set_xticks(np.arange(len(exp4[\"animal2\"][\"selected_indices\"])), [str(int(v) + 1) for v in exp4[\"animal2\"][\"selected_indices\"]])\n", "ax.set_ylabel(\"Gaussian - Zernike BIC\")\n", - "ax.set_title(\"Animal 2 place-cell comparison\")\n" + "ax.set_title(\"Animal 2 place-cell comparison\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "113c1903", + "id": "549f0f6b", "metadata": {}, "outputs": [], "source": [ @@ -477,13 +475,13 @@ "fig = _fig(\"experiment4 gaussian mesh\", figsize=(9.0, 6.5))\n", "ax = fig.add_subplot(111, projection=\"3d\")\n", "ax.plot_surface(exp4[\"mesh\"][\"grid_x\"], exp4[\"mesh\"][\"grid_y\"], exp4[\"mesh\"][\"gaussian_field\"], cmap=\"Blues\", linewidth=0.0, antialiased=True)\n", - "ax.set_title(\"Gaussian place-field estimate\")\n" + "ax.set_title(\"Gaussian place-field estimate\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "64654717", + "id": "3e3bb9b7", "metadata": {}, "outputs": [], "source": [ @@ -491,24 +489,24 @@ "fig = _fig(\"experiment4 zernike mesh\", figsize=(9.0, 6.5))\n", "ax = 
fig.add_subplot(111, projection=\"3d\")\n", "ax.plot_surface(exp4[\"mesh\"][\"grid_x\"], exp4[\"mesh\"][\"grid_y\"], exp4[\"mesh\"][\"zernike_field\"], cmap=\"Greens\", linewidth=0.0, antialiased=True)\n", - "ax.set_title(\"Zernike place-field estimate\")\n" + "ax.set_title(\"Zernike place-field estimate\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "115929b8", + "id": "96c64e62", "metadata": {}, "outputs": [], "source": [ "# SECTION 27: Experiment 5\n", - "print(exp5_summary)\n" + "print(exp5_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "26af3ac2", + "id": "b88f5c9e", "metadata": {}, "outputs": [], "source": [ @@ -520,24 +518,24 @@ "ax.fill_between(exp5[\"time_s\"], exp5[\"ci_low\"], exp5[\"ci_high\"], color=\"0.85\")\n", "ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n", "ax.set_xlabel(\"time (s)\")\n", - "ax.set_title(\"Experiment 5 adaptive decoding\")\n" + "ax.set_title(\"Experiment 5 adaptive decoding\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "22ee7cab", + "id": "b751e28e", "metadata": {}, "outputs": [], "source": [ "# SECTION 29: Experiment 5b\n", - "print(exp5b_summary)\n" + "print(exp5b_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "f0312bf9", + "id": "48493368", "metadata": {}, "outputs": [], "source": [ @@ -550,13 +548,13 @@ "axs[1].plot(exp5b[\"time_s\"], exp5b[\"y_true\"], color=\"0.3\", linewidth=1.0, label=\"True y\")\n", "axs[1].plot(exp5b[\"time_s\"], exp5b[\"dy_goal\"], color=\"tab:orange\", linewidth=1.2, label=\"Decoded y\")\n", "axs[1].legend(loc=\"upper right\", frameon=False, fontsize=8)\n", - "axs[1].set_xlabel(\"time (s)\")\n" + "axs[1].set_xlabel(\"time (s)\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "216a748d", + "id": "fdeed99c", "metadata": {}, "outputs": [], "source": [ @@ -569,24 +567,24 @@ "axs[1].plot(exp5b[\"time_s\"], exp5b[\"y_true\"], color=\"0.3\", linewidth=1.0, label=\"True y\")\n", 
"axs[1].plot(exp5b[\"time_s\"], exp5b[\"dy_free\"], color=\"tab:red\", linewidth=1.2, label=\"Decoded y\")\n", "axs[1].legend(loc=\"upper right\", frameon=False, fontsize=8)\n", - "axs[1].set_xlabel(\"time (s)\")\n" + "axs[1].set_xlabel(\"time (s)\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "358d0d12", + "id": "98b8c650", "metadata": {}, "outputs": [], "source": [ "# SECTION 32: Experiment 6\n", - "print(exp6_summary)\n" + "print(exp6_summary)" ] }, { "cell_type": "code", "execution_count": null, - "id": "be288a10", + "id": "01efa943", "metadata": {}, "outputs": [], "source": [ @@ -598,13 +596,13 @@ "axs[1].plot(exp6[\"time_s\"], exp6[\"state_prob_1\"], color=\"tab:blue\", linewidth=1.2, label=\"P(state=1)\")\n", "axs[1].plot(exp6[\"time_s\"], exp6[\"state_prob_2\"], color=\"tab:orange\", linewidth=1.2, label=\"P(state=2)\")\n", "axs[1].legend(loc=\"upper right\", frameon=False, fontsize=8)\n", - "axs[1].set_xlabel(\"time (s)\")\n" + "axs[1].set_xlabel(\"time (s)\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "7e158b30", + "id": "fbcd95b6", "metadata": {}, "outputs": [], "source": [ @@ -617,13 +615,13 @@ "axs[1].plot(exp6[\"time_s\"], exp6[\"y_pos\"], color=\"0.3\", linewidth=1.0, label=\"True y\")\n", "axs[1].plot(exp6[\"time_s\"], exp6[\"decoded_y\"], color=\"tab:orange\", linewidth=1.2, label=\"Decoded y\")\n", "axs[1].legend(loc=\"upper right\", frameon=False, fontsize=8)\n", - "axs[1].set_xlabel(\"time (s)\")\n" + "axs[1].set_xlabel(\"time (s)\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2b0b149c", + "id": "2c1cf92e", "metadata": {}, "outputs": [], "source": [ @@ -635,13 +633,13 @@ "ax.bar(np.arange(len(labels)), rmses, color=[\"tab:blue\", \"tab:green\", \"tab:red\", \"tab:purple\", \"tab:orange\"])\n", "ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n", "ax.set_ylabel(\"RMSE\")\n", - "ax.set_title(\"Decoding summary across paper examples\")\n" + "ax.set_title(\"Decoding summary across 
paper examples\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "74cb3002", + "id": "bb79988e", "metadata": {}, "outputs": [], "source": [ @@ -658,13 +656,13 @@ "labels = [\"Exp1 spikes\", \"Exp2 samples\", \"Exp3 trials\", \"Exp4 cells\", \"Exp6 cells\"]\n", "ax.bar(np.arange(len(labels)), counts, color=\"0.65\")\n", "ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n", - "ax.set_title(\"Paper-example dataset scale\")\n" + "ax.set_title(\"Paper-example dataset scale\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "a52f8e5c", + "id": "9aa50b57", "metadata": {}, "outputs": [], "source": [ @@ -677,7 +675,7 @@ " \"experiment6_state_accuracy\": round(float(exp6_summary[\"state_accuracy\"]), 3),\n", " }\n", ")\n", - "__tracker.finalize()\n" + "__tracker.finalize()" ] } ], @@ -694,4 +692,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/nSpikeTrainExamples.ipynb b/notebooks/nSpikeTrainExamples.ipynb index 71075a0e..f6b44cae 100644 --- a/notebooks/nSpikeTrainExamples.ipynb +++ b/notebooks/nSpikeTrainExamples.ipynb @@ -30,7 +30,7 @@ "__tracker = FigureTracker(topic='nSpikeTrainExamples', output_root=OUTPUT_ROOT, expected_count=4)\n", "\n", "# SECTION 0: Section 0\n", - "# Test the nspikeTrain Class\n" + "# Test the nspikeTrain Class" ] }, { diff --git a/nstat/analysis.py b/nstat/analysis.py index f5ff546e..36269a30 100644 --- a/nstat/analysis.py +++ b/nstat/analysis.py @@ -229,10 +229,28 @@ def GLMFit( start = stop lambda_sig = _fit_lambda_matrix_to_covariate(lambda_time_full, lambda_segments, int(lambdaIndex)) + + # Compute standard errors from Fisher information (Hessian inverse) + # Poisson: W = diag(mu); Binomial: W = diag(mu*(1-mu)) + try: + if distribution == "binomial": + W = lambda_delta * (1.0 - lambda_delta) + else: + W = lambda_delta.copy() + W = np.maximum(W, 1e-12) + XtWX = X.T @ (X * W[:, None]) + l2 * np.eye(X.shape[1]) + covb = np.linalg.inv(XtWX) + se = 
np.sqrt(np.maximum(np.diag(covb), 0.0)) + except np.linalg.LinAlgError: + se = np.full(b.size, np.nan, dtype=float) + covb = None + stats = { "intercept": float(glm_res.intercept), "n_iter": int(glm_res.n_iter), "converged": bool(glm_res.converged), + "se": se, + "covb": covb, } return lambda_sig, b, dev, stats, AIC, BIC, logLL, distribution @@ -271,6 +289,9 @@ def run_analysis_for_neuron( if not spike_train.name: spike_train.setName(str(neuron_number)) + spike_validation = None + has_validation = False + for cfg_index in range(1, config_collection.numConfigs + 1): _restore_trial_partition(trial, original_partition) config_collection.setConfig(trial, cfg_index) @@ -314,9 +335,12 @@ def run_analysis_for_neuron( partition = np.asarray(trial.getTrialPartition(), dtype=float).reshape(-1) if partition.size >= 4 and partition[2] < partition[3]: + has_validation = True trial.setTrialTimesFor("validation") xvalData.append(np.asarray(trial.getDesignMatrix(neuron_number), dtype=float)) xvalTime.append(np.asarray(trial.covarColl.getCov(1).time, dtype=float).copy()) + spike_validation = trial.nspikeColl.getNST(neuron_number).nstCopy() + spike_validation.setName(str(neuron_number)) trial.setTrialTimesFor("training") else: xvalData.append(np.zeros((0, len(current_labels)), dtype=float)) @@ -349,6 +373,15 @@ def run_analysis_for_neuron( # MATLAB returns fits with KS diagnostics already populated, and # downstream summary classes read those cached fields directly. fit_result.computeKSStats() + + # Compute the conditional intensity on validation data when a + # validation partition is present (mirrors Matlab behaviour). + if has_validation: + try: + fit_result.computeValLambda() + except Exception: + pass # validation lambda is optional; don't fail the fit + return fit_result @staticmethod @@ -459,14 +492,12 @@ def computeFitResidual(nspikeObj, lambdaInput: Covariate, windowSize: float = 0. 
@staticmethod def KSPlot(fitResults: FitResult, DTCorrection: int = 1, makePlot: int = 1): - del DTCorrection - fitResults.computeKSStats() + fitResults.computeKSStats(dt_correction=DTCorrection) return fitResults.KSPlot() if makePlot else [] @staticmethod def plotFitResidual(fitResults: FitResult, windowSize: float = 0.01, makePlot: int = 1): - del windowSize - fitResults.computeFitResidual() + fitResults.computeFitResidual(windowSize=windowSize) return fitResults.plotResidual() if makePlot else [] @staticmethod @@ -485,7 +516,7 @@ def plotCoeffs(fitResults: FitResult): @staticmethod def computeHistLag(tObj: Trial, neuronNum=None, windowTimes=None, CovLabels=None, Algorithm="GLM", batchMode=0, sampleRate=None, makePlot=1, histMinTimes=None, histMaxTimes=None): - del batchMode, histMinTimes, histMaxTimes + del batchMode if windowTimes is None: raise ValueError("Must specify a vector of windowTimes") if neuronNum is None: @@ -497,12 +528,19 @@ def computeHistLag(tObj: Trial, neuronNum=None, windowTimes=None, CovLabels=None if windows.size < 2: raise ValueError("windowTimes must contain at least two entries") + use_history_obj = (histMinTimes is not None or histMaxTimes is not None) + configs = [] from .trial import TrialConfig configs.append(TrialConfig(cov_labels, sampleRate, [], [], name="Baseline")) for i in range(2, windows.size + 1): - cfg = TrialConfig(cov_labels, sampleRate, windows[:i], [], name=f"Window{i - 1}") + if use_history_obj: + from .history import History as _Hist + h_temp = _Hist(windows[:i], minTime=histMinTimes, maxTime=histMaxTimes) + cfg = TrialConfig(cov_labels, sampleRate, h_temp, [], name=f"Window{i - 1}") + else: + cfg = TrialConfig(cov_labels, sampleRate, windows[:i], [], name=f"Window{i - 1}") configs.append(cfg) tcc = ConfigCollection(configs) fitResults = Analysis.RunAnalysisForNeuron(tObj, neuronNum, tcc, makePlot, Algorithm) @@ -619,15 +657,26 @@ def computeGrangerCausalityMatrix(tObj: Trial, Algorithm="GLM", confidenceInterv p_val = 
float(chi2.sf(deviance, dim_diff)) p_vals.append(p_val) p_coords.append((neighbor - 1, neuron_index - 1)) - coeffs = fit.getHistCoeffs(2) if np.any(np.asarray(fit.numHist, dtype=int) > 0) else np.array([], dtype=float) - if coeffs.size: - phiMat[neighbor - 1, neuron_index - 1] = -float(np.sign(np.sum(coeffs))) * gamma + # Matlab: extract only the specific neighbor's ensemble + # coefficients from the BASELINE model (fit 1) for the sign. + if np.any(np.asarray(fit.numHist, dtype=int) > 0): + coeffs_all, labels_all, _ = fit.getCoeffsWithLabels(1) + neighbor_prefix = f"{neighbor}:[" + neighbor_mask = np.array([str(lbl).startswith(neighbor_prefix) for lbl in labels_all], dtype=bool) + neighbor_coeffs = coeffs_all[neighbor_mask] if np.any(neighbor_mask) else np.array([], dtype=float) + else: + neighbor_coeffs = np.array([], dtype=float) + if neighbor_coeffs.size: + phiMat[neighbor - 1, neuron_index - 1] = -float(np.sign(np.sum(neighbor_coeffs))) * gamma if p_vals: keep = _benjamini_hochberg(np.asarray(p_vals, dtype=float), alpha=max(alpha, 1e-6)) for include, (row, col) in zip(keep, p_coords, strict=False): sigMat[row, col] = int(include) + # Restore the ensemble covariate mask to its default state (Matlab parity). + tObj.resetEnsCovMask() + return fitResults, gammaMat, phiMat, devianceMat, sigMat @staticmethod diff --git a/nstat/core.py b/nstat/core.py index 2035151f..dc9bb895 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -368,6 +368,52 @@ def getIndicesFromLabels(self, label: Sequence[str] | str): return [item[0] for item in out] return out + def areDataLabelsEmpty(self) -> bool: + """Return ``True`` if all data labels are empty strings. + + Matches Matlab ``SignalObj.areDataLabelsEmpty()``. + """ + return all(not str(label) for label in self.dataLabels) + + def isLabelPresent(self, label: str) -> bool: + """Return ``True`` if *label* matches any data label or equals ``'all'``. + + Matches Matlab ``SignalObj.isLabelPresent()``. 
+ """ + if str(label).lower() == "all": + return True + try: + self.getIndexFromLabel(label) + return True + except ValueError: + return False + + def convertNamesToIndices(self, selectorArray) -> list[int] | np.ndarray: + """Convert label names (or mixed) to 1-based indices. + + Matches Matlab ``SignalObj.convertNamesToIndices()``. + """ + if isinstance(selectorArray, str): + if selectorArray == "all": + return list(range(1, self.dimension + 1)) + if self.isLabelPresent(selectorArray): + return self.getIndexFromLabel(selectorArray) + raise ValueError(f"Specified label '{selectorArray}' does not match data label") + if isinstance(selectorArray, (int, float, np.integer)): + return [int(selectorArray)] + if isinstance(selectorArray, np.ndarray): + return selectorArray.astype(int).ravel().tolist() + if isinstance(selectorArray, (list, tuple)): + result: list[int] = [] + for item in selectorArray: + if isinstance(item, str): + if self.isLabelPresent(item): + result.extend(self.getIndexFromLabel(item)) + else: + result.append(int(item)) + return result + return list(range(1, self.dimension + 1)) + def getValueAt(self, x: Sequence[float] | float) -> np.ndarray: query = np.asarray(x, dtype=float).reshape(-1) out = np.zeros((query.size, self.dimension), dtype=float) @@ -500,14 +546,28 @@ def setMaxTime(self, maxTime: float | None = None, holdVals: int = 0) -> None: self.data = self.data[:endIndex, :] self.maxTime = float(np.max(self.time)) - def merge(self, other: "SignalObj") -> "SignalObj": - if self.time.shape != other.time.shape or np.max(np.abs(self.time - other.time)) > 1e-9: - raise ValueError("Signals must share an identical time grid to merge.") - merged = self._spawn( - self.time, - np.column_stack([self.data, other.data]), - data_labels=[*self.dataLabels, *list(other.dataLabels)], - plot_props=[*self.plotProps, *getattr(other, "plotProps", [None for _ in range(other.dimension)])], + def merge(self, other: "SignalObj", holdVals: int = 0) -> "SignalObj": + 
"""Merge *other* signal columns into *self*. + + Matlab calls ``makeCompatible`` first so that signals with + different time grids are reconciled automatically. The Python + version now does the same. + + Parameters + ---------- + other : SignalObj + Signal whose data columns will be appended. + holdVals : int, optional + Passed to ``makeCompatible`` – ``1`` holds endpoint values + when the time range is extended; ``0`` (default) pads with + zeros. + """ + s1c, s2c = self.makeCompatible(other, holdVals) + merged = s1c._spawn( + s1c.time, + np.column_stack([s1c.data, s2c.data]), + data_labels=[*s1c.dataLabels, *list(s2c.dataLabels)], + plot_props=[*s1c.plotProps, *getattr(s2c, "plotProps", [None for _ in range(s2c.dimension)])], ) return merged @@ -557,6 +617,48 @@ def __rtruediv__(self, other) -> "SignalObj": right = np.repeat(right, left.shape[1], axis=1) return self._spawn(self.time, np.divide(left, right), data_labels=labels) + def __matmul__(self, other) -> "SignalObj": + """Matrix multiply (``@`` operator). Matches Matlab ``mtimes``.""" + if isinstance(other, SignalObj): + return self._spawn(self.time, self.data * other.data, data_labels=list(self.dataLabels)) + other_arr = np.asarray(other, dtype=float) + result = (self.data.T @ other_arr).T if other_arr.ndim <= 1 else self.data @ other_arr + return self._spawn(self.time[:result.shape[0]] if result.ndim == 2 else self.time, result) + + def ldivide(self, other) -> "SignalObj": + r"""Element-wise left division (Matlab ``.\``): ``other ./ self``. + + Matches Matlab ``SignalObj.ldivide()``. + """ + return self._binary_op(other, lambda a, b: np.divide(b, a)) + + @property + def T(self) -> "SignalObj": + """Transpose the data matrix. 
Matches Matlab ``ctranspose`` / ``transpose``.""" + new_data = self.data.T + new_time = self.time[:new_data.shape[0]] if new_data.shape[0] != self.time.size else self.time + return self._spawn(new_time, new_data) + + def clearPlotProps(self, index=None) -> None: + """Clear plot properties. Matches Matlab ``clearPlotProps``.""" + if index is None: + index = list(range(self.dimension)) + else: + index = [i - 1 for i in np.atleast_1d(index)] + for i in index: + if i < len(self.plotProps): + self.plotProps[i] = None + + def plotPropsSet(self) -> bool: + """Return ``True`` if any plot property is non-empty. + + Matches Matlab ``SignalObj.plotPropsSet()``. + """ + for prop in self.plotProps: + if prop is not None and str(prop) != "": + return True + return False + def getSigInTimeWindow( self, wMin: Sequence[float] | float | None = None, @@ -1072,14 +1174,193 @@ def findGlobalPeak( values = data[idx, np.arange(data.shape[1])] return np.atleast_1d(times), np.atleast_1d(values) + # ------------------------------------------------------------------ + # Alignment / windowing (match Matlab SignalObj) + # ------------------------------------------------------------------ + def alignToMax(self) -> tuple["SignalObj", float]: + """Align all dimensions so their peaks coincide at the mean peak time. + + Returns ``(aligned_signal, mean_peak_time)``. + Matches Matlab ``SignalObj.alignToMax()``. + """ + peak_times, _ = self.findGlobalPeak("maxima") + mean_time = float(np.mean(peak_times)) + delta_t = -(peak_times - mean_time) + aligned = self.getSubSignal(1).shift(float(delta_t[0])) + for i in range(1, self.dimension): + aligned = aligned.merge(self.getSubSignal(i + 1).shift(float(delta_t[i]))) + return aligned, mean_time + + def windowedSignal(self, windowTimes) -> "SignalObj": + """Extract and concatenate windowed segments. + + Matches Matlab ``SignalObj.windowedSignal()``. 
+ """ + windowTimes = np.asarray(windowTimes, dtype=float).ravel() + result = None + for i in range(len(windowTimes) - 1): + seg = self.getSigInTimeWindow(float(windowTimes[i]), float(windowTimes[i + 1])) + if i == 0: + result = seg + else: + seg = seg.shift(-float(windowTimes[i])) + result = result.merge(seg) + return result if result is not None else self.copySignal() + + def normWindowedSignal( + self, + windowTimes, + numPoints: int = 100, + lbound: float | None = None, + ubound: float | None = None, + ) -> "SignalObj": + """Normalize windowed signal segments to a common time axis. + + Matches Matlab ``SignalObj.normWindowedSignal()``. + """ + windowTimes = np.asarray(windowTimes, dtype=float).ravel() + columns: list[np.ndarray] = [] + for i in range(len(windowTimes) - 1): + minT = float(windowTimes[i]) + maxT = float(windowTimes[i + 1]) + dur = abs(maxT - minT) + if lbound is not None and ubound is not None: + if dur > ubound or dur < lbound: + continue + seg = self.getSigInTimeWindow(minT, maxT) + norm_time = np.linspace(minT, maxT, numPoints) + # Matlab uses interp1(..., 'nearest', 0) — nearest-neighbor with 0-fill + from scipy.interpolate import interp1d as _interp1d + _ifn = _interp1d(seg.time, seg.data[:, 0], kind="nearest", + bounds_error=False, fill_value=0.0) + interp_data = _ifn(norm_time) + columns.append(interp_data) + + if not columns: + return self.copySignal() + data = np.column_stack(columns) + act_time = np.arange(numPoints, dtype=float) / float(numPoints) + labels = list(self.dataLabels[:1]) * data.shape[1] + return self.__class__(act_time, data, self.name, self.xlabelval, "%", self.yunits, labels) + + def getSubSignalsWithinNStd(self, nStd: float = 2.0) -> tuple["SignalObj", np.ndarray]: + """Return sub-signals within *nStd* standard deviations of the mean. + + Returns ``(filtered_signal, selected_indices)``. + Matches Matlab ``SignalObj.getSubSignalsWithinNStd()``. 
+ """ + mean_sig = np.mean(self.data, axis=1) + std_sig = np.std(self.data, axis=1, ddof=1) + min_val = mean_sig - nStd * std_sig + max_val = mean_sig + nStd * std_sig + # A column passes if ALL rows are within [minVal, maxVal] + above_min = np.all(self.data >= min_val[:, None], axis=0) + below_max = np.all(self.data <= max_val[:, None], axis=0) + sig_index = np.flatnonzero(above_min & below_max) + if sig_index.size == 0: + return self.copySignal(), sig_index + # 1-based indices for getSubSignal + return self.getSubSignal((sig_index + 1).tolist()), sig_index + + # ------------------------------------------------------------------ + # Variability plots (match Matlab SignalObj) + # ------------------------------------------------------------------ + def plotAllVariability( + self, + faceColor=None, + linewidth: float = 3.0, + ciUpper: float | np.ndarray = 1.96, + ciLower: float | np.ndarray | None = None, + ax=None, + ): + """Plot mean ± CI shaded area. Matches Matlab ``plotAllVariability``. + + Parameters + ---------- + faceColor : color, optional + Fill colour (default: tab:blue). + linewidth : float + Width of mean line. + ciUpper, ciLower : float or array + Number of std-devs (scalar) or explicit bounds (array). 
+ ax : matplotlib Axes, optional + """ + import matplotlib.pyplot as plt + + if faceColor is None: + faceColor = "tab:blue" + if ciLower is None: + ciLower = ciUpper + if ax is None: + ax = plt.gca() + + mean_sig = np.mean(self.data, axis=1) + std_sig = np.std(self.data, axis=1, ddof=1) + + ciUpper_arr = np.atleast_1d(ciUpper).ravel() + ciLower_arr = np.atleast_1d(ciLower).ravel() + if ciUpper_arr.size == 1: + ci_top = mean_sig + float(ciUpper_arr[0]) * std_sig + else: + ci_top = mean_sig + ciUpper_arr[:len(mean_sig)] + if ciLower_arr.size == 1: + ci_bottom = mean_sig - float(ciLower_arr[0]) * std_sig + else: + ci_bottom = mean_sig - ciLower_arr[:len(mean_sig)] + + ax.fill_between(self.time, ci_bottom, ci_top, color=faceColor, alpha=0.5, edgecolor="none") + (h,) = ax.plot(self.time, mean_sig, "k-", linewidth=linewidth) + return h + + def plotVariability(self, selectorArray=None, ax=None): + """Plot mean ± CI for each label group. Matches Matlab ``plotVariability``. + + Parameters + ---------- + selectorArray : list of list[int] or list[int], optional + ax : matplotlib Axes, optional + """ + import matplotlib.pyplot as plt + + if ax is None: + ax = plt.gca() + if selectorArray is None: + if not self.areDataLabelsEmpty(): + unique_labels = list(dict.fromkeys(self.dataLabels)) + selectorArray = [self.getIndexFromLabel(lbl) for lbl in unique_labels] + else: + selectorArray = list(range(1, self.dimension + 1)) + + _TAB_COLORS = [ + "tab:blue", "tab:orange", "tab:green", "tab:red", + "tab:purple", "tab:brown", "tab:pink", "tab:gray", + ] + handles = [] + if isinstance(selectorArray, list) and selectorArray and isinstance(selectorArray[0], (list, tuple, np.ndarray)): + for i, sel in enumerate(selectorArray): + h = self.getSubSignal(sel).plotAllVariability( + faceColor=_TAB_COLORS[i % len(_TAB_COLORS)], ax=ax + ) + handles.append(h) + else: + h = self.getSubSignal(selectorArray).plotAllVariability(ax=ax) + handles.append(h) + return handles + # 
------------------------------------------------------------------ # Cross-covariance (match Matlab SignalObj.xcov) # ------------------------------------------------------------------ def xcov(self, other: "SignalObj | None" = None, maxlag: int | None = None, scaleOpt: str = "biased") -> "SignalObj": - """Cross-covariance (mean-removed xcorr). Matches Matlab ``xcov``.""" + """Cross-covariance (mean-removed xcorr). Matches Matlab ``xcov``. + + When called with no *other* argument (auto-covariance), only + non-negative lags are returned — matching Matlab behaviour where + ``data=tempC(M-1:end,index)`` and ``lags=tempLags(M-1:end)``. + """ + auto = other is None s1 = self - s2 = self if other is None else other + s2 = self if auto else other s1c, s2c = s1.makeCompatible(s2) data_columns: list[np.ndarray] = [] @@ -1109,6 +1390,12 @@ def xcov(self, other: "SignalObj | None" = None, maxlag: int | None = None, corr = corr[keep] lags = lags[keep] + # Matlab returns only non-negative lags for auto-covariance + if auto: + nonneg = lags >= 0 + corr = corr[nonneg] + lags = lags[nonneg] + if lag_index is None: lag_index = lags.astype(float) / max(float(s1c.sampleRate), 1e-12) data_columns.append(np.asarray(corr, dtype=float)) @@ -1133,40 +1420,63 @@ def xcov(self, other: "SignalObj | None" = None, maxlag: int | None = None, def periodogram(self, NFFT: int | None = None) -> tuple[np.ndarray, np.ndarray]: """Power spectral density via periodogram (Matlab ``periodogram``). - Returns ``(frequencies, psd)`` arrays. + Loops over all signal dimensions like the Matlab implementation. + + Returns ``(frequencies, psd)`` where *psd* has shape + ``(nfreqs,)`` for 1-D signals or ``(nfreqs, dimension)`` for + multi-dimensional signals. 
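The per-dimension loop reduces to one `scipy.signal.periodogram` call per column with a boxcar window; a quick self-contained check (the two sinusoid frequencies are made up for illustration):

```python
import numpy as np
from scipy.signal import periodogram

# Two-column signal: 5 Hz and 12 Hz sinusoids sampled at 100 Hz for 10 s.
fs = 100.0
t = np.arange(0, 10, 1 / fs)
data = np.column_stack([np.sin(2 * np.pi * 5 * t), np.sin(2 * np.pi * 12 * t)])

psd_cols = []
for i in range(data.shape[1]):
    f, Pxx = periodogram(data[:, i], fs=fs, window="boxcar", scaling="density")
    psd_cols.append(Pxx)
psd = np.column_stack(psd_cols)          # (nfreqs, dimension), as documented
peaks = f[np.argmax(psd, axis=0)]        # peak frequency per column
```

Both tones sit exactly on FFT bins (0.1 Hz resolution), so the per-column peaks land at 5 Hz and 12 Hz.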
""" from scipy.signal import periodogram as _periodogram fs = float(self.sampleRate) - x = self.data[:, 0] if self.data.ndim == 2 else self.data - f, Pxx = _periodogram(x, fs=fs, nfft=NFFT, window="boxcar", - scaling="density") - return f, Pxx - - def MTMspectrum(self, NW: float = 4.0, Kmax: int | None = None, - NFFT: int | None = None) -> tuple[np.ndarray, np.ndarray]: + psd_cols: list[np.ndarray] = [] + f_out: np.ndarray | None = None + ndim = self.dimension + for i in range(ndim): + x = self.data[:, i] if self.data.ndim == 2 else self.data + f, Pxx = _periodogram(x, fs=fs, nfft=NFFT, window="boxcar", + scaling="density") + if f_out is None: + f_out = f + psd_cols.append(Pxx) + if ndim == 1: + return f_out, psd_cols[0] + return f_out, np.column_stack(psd_cols) + + def MTMspectrum(self, NW: float = 4.0, NFFT: int | None = None, + Pval: float = 0.95, + Kmax: int | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]: """Multi-taper spectral estimate (Matlab ``MTMspectrum``). Uses discrete prolate spheroidal sequences (DPSS / Slepian tapers). + Loops over all signal dimensions like the Matlab implementation. Parameters ---------- NW : float Time-bandwidth product (default 4). - Kmax : int, optional - Number of tapers (default ``2*NW - 1``). NFFT : int, optional FFT length (default next power of 2 >= N). + Pval : float, optional + Confidence level for the chi-squared confidence interval + (default 0.95). Set to ``None`` to skip CI computation. + Kmax : int, optional + Number of tapers (default ``2*NW - 1``). Returns ------- frequencies : ndarray psd : ndarray + Shape ``(nfreqs,)`` for 1-D or ``(nfreqs, dimension)``. + psd_ci : ndarray or None + Shape ``(nfreqs, 2)`` for 1-D or ``(nfreqs, 2*dimension)`` + containing ``[lower, upper]`` columns per dimension. + ``None`` when *Pval* is ``None``. 
""" from scipy.signal.windows import dpss - x = self.data[:, 0] if self.data.ndim == 2 else self.data - N = len(x) + N = self.data.shape[0] fs = float(self.sampleRate) if Kmax is None: Kmax = int(2 * NW - 1) @@ -1174,24 +1484,49 @@ def MTMspectrum(self, NW: float = 4.0, Kmax: int | None = None, NFFT = int(2 ** np.ceil(np.log2(N))) tapers, eigenvalues = dpss(N, NW, Kmax, return_ratios=True) - # tapers shape: (Kmax, N) - # Compute tapered FFTs - Sk = np.zeros((Kmax, NFFT // 2 + 1)) - for k in range(Kmax): - xw = x * tapers[k] - Xf = np.fft.rfft(xw, n=NFFT) - Sk[k] = np.abs(Xf) ** 2 - - # Weighted average by eigenvalues - weights = eigenvalues / eigenvalues.sum() - psd = np.dot(weights, Sk) * (2.0 / fs) - # DC and Nyquist don't get doubled - psd[0] /= 2.0 - if NFFT % 2 == 0: - psd[-1] /= 2.0 - frequencies = np.fft.rfftfreq(NFFT, d=1.0 / fs) - return frequencies, psd + nfreqs = len(frequencies) + + # chi-squared CI bounds (degrees of freedom = 2*Kmax) + ci_lo_factor = ci_hi_factor = None + if Pval is not None: + from scipy.stats import chi2 + dof = 2 * Kmax + alpha = 1.0 - Pval + ci_lo_factor = dof / chi2.ppf(1.0 - alpha / 2.0, dof) + ci_hi_factor = dof / chi2.ppf(alpha / 2.0, dof) + + ndim = self.dimension + psd_cols: list[np.ndarray] = [] + ci_cols: list[np.ndarray] = [] + + for di in range(ndim): + x = self.data[:, di] if self.data.ndim == 2 else self.data + Sk = np.zeros((Kmax, nfreqs)) + for k in range(Kmax): + xw = x * tapers[k] + Xf = np.fft.rfft(xw, n=NFFT) + Sk[k] = np.abs(Xf) ** 2 + + weights = eigenvalues / eigenvalues.sum() + psd = np.dot(weights, Sk) * (2.0 / fs) + psd[0] /= 2.0 + if NFFT % 2 == 0: + psd[-1] /= 2.0 + psd_cols.append(psd) + + if Pval is not None: + ci_cols.append(psd * ci_lo_factor) + ci_cols.append(psd * ci_hi_factor) + + if ndim == 1: + psd_out = psd_cols[0] + ci_out = np.column_stack(ci_cols) if ci_cols else None + else: + psd_out = np.column_stack(psd_cols) + ci_out = np.column_stack(ci_cols) if ci_cols else None + + return frequencies, 
psd_out, ci_out def spectrogram(self, nperseg: int = 256, noverlap: int | None = None, NFFT: int | None = None, @@ -1339,15 +1674,28 @@ def computeMeanPlusCI(self, alphaVal: float = 0.05) -> "Covariate": newCov.setConfInterval(confInt) return newCov - def getSigRep(self, repType: str = "standard") -> SignalObj: + def getSigRep(self, repType: str = "standard") -> "Covariate": + """Return a signal representation of this covariate. + + Parameters + ---------- + repType : str + ``'standard'`` returns ``self`` unchanged. + ``'zero-mean'`` returns ``self - mean(self)`` with confidence + intervals propagated (Matlab parity: uses operator overload so + CIs shift by the same constant). + """ rep = str(repType).strip().lower() if rep == "standard": return self if rep == "zero-mean": - centered = self.data - np.mean(self.data, axis=0, keepdims=True) - return Covariate( - self.time, - centered, + # Build a constant Covariate holding the per-column mean so that + # the CI-propagating __sub__ is invoked (Matlab: ``self - self.mu``). + mu_vals = np.mean(self.data, axis=0, keepdims=True) + mu_broadcast = np.repeat(mu_vals, len(self.time), axis=0) + mu_cov = Covariate( + self.time.copy(), + mu_broadcast, self.name, self.xlabelval, self.xunits, @@ -1355,6 +1703,7 @@ def getSigRep(self, repType: str = "standard") -> SignalObj: list(self.dataLabels), list(self.plotProps), ) + return self - mu_cov raise ValueError("repType must be either 'zero-mean' or 'standard'") def plot(self, selectorArray=None, plotPropsIn=None, handle=None): @@ -1582,7 +1931,9 @@ def setName(self, name: str) -> None: def computeStatistics(self, makePlots: int = 0) -> None: self.avgFiringRate = self.firing_rate_hz isi = self.getISIs() - spike_times = self.spikeTimes + # Filter spike times to [minTime, maxTime] so burst statistics + # remain valid after setMinTime / setMaxTime (Matlab parity). 
+ spike_times = self.getSpikeTimes(self.minTime, self.maxTime) mode_isi = _matlab_mode_1d(isi) self.burstIndex = float(1.0 / mode_isi / self.avgFiringRate) if np.isfinite(mode_isi) and self.avgFiringRate > 0 else np.nan self.B = np.nan @@ -1722,7 +2073,7 @@ def setSigRep(self, binwidth: float | None = None, minTime: float | None = None, # clearing it through the public min/max setters. self.minTime = float(sig.minTime) self.maxTime = float(sig.maxTime) - self.computeStatistics(-1) + self.computeStatistics(0) return self.sigRep def clearSigRep(self) -> None: @@ -1733,12 +2084,12 @@ def clearSigRep(self) -> None: def setMinTime(self, minTime: float) -> None: self.minTime = float(minTime) self.clearSigRep() - self.computeStatistics(-1) + self.computeStatistics(0) def setMaxTime(self, maxTime: float) -> None: self.maxTime = float(maxTime) self.clearSigRep() - self.computeStatistics(-1) + self.computeStatistics(0) def resample(self, sampleRate: float) -> "nspikeTrain": self.setSigRep(1.0 / float(sampleRate), self.minTime, self.maxTime) @@ -1883,9 +2234,20 @@ def plotJointISIHistogram(self): return ax def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None): + """Plot ISI histogram (Matlab ``plotISIHistogram``). + + Parameters + ---------- + minTime, maxTime : float, optional + Time window for ISIs. Defaults to the spike train bounds. + numBins : int, optional + Number of histogram bins. When *None* the bin width defaults to + 1 ms (Matlab default behaviour). + handle : matplotlib Axes, optional + Axes to plot into. 
+ """ import matplotlib.pyplot as plt - del numBins ax = plt.gca() if handle is None else handle if maxTime is None: maxTime = self.maxTime @@ -1895,8 +2257,16 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = counts = np.array([], dtype=float) bins = np.array([], dtype=float) if isi.size: - bin_width = 0.001 - bins = np.arange(0.0, float(np.max(isi)) + bin_width, bin_width, dtype=float) + isi_max = float(np.max(isi)) + if numBins is not None and int(numBins) > 0: + # Linearly-spaced bins when numBins is specified (Matlab parity). + n = int(numBins) + bin_width = max(isi_max / n, 1e-12) + bins = np.linspace(0.0, isi_max, n + 1, dtype=float) + else: + # Default: 1 ms bin width. + bin_width = 0.001 + bins = np.arange(0.0, isi_max + bin_width, bin_width, dtype=float) if bins.size < 2: bins = np.array([0.0, bin_width], dtype=float) idx = np.searchsorted(bins, isi, side="right") - 1 @@ -1907,10 +2277,10 @@ def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = ) idx = np.clip(idx, 0, bins.size - 1) counts = np.bincount(idx, minlength=bins.size).astype(float) - centers = bins + centers = bins[:counts.size] if bins.size > counts.size else bins ax.bar( centers, - counts, + counts[:centers.size], width=bin_width, align="edge", edgecolor="none", @@ -1973,6 +2343,11 @@ def plot(self, dHeight: float = 1.0, yOffset: float = 0.5, currentHandle=None, h return lines def nstCopy(self) -> "nspikeTrain": + """Return a deep copy (Matlab ``nstCopy``). + + Matlab's ``nstCopy`` builds the copy's sigRep and calls + ``computeStatistics(0)`` so the copy has valid burst parameters. 
+ """ return nspikeTrain( self.spikeTimes.copy(), self.name, @@ -1983,7 +2358,7 @@ def nstCopy(self) -> "nspikeTrain": self.xunits, self.yunits, self.dataLabels, - -1, + 0, ) def to_binned_counts(self, bin_edges: Sequence[float]) -> np.ndarray: diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index 021a6137..c8544f26 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -312,6 +312,39 @@ def _extract_linear_terms_from_cifs(lambdaCIFColl, num_states: int, num_cells: i return np.asarray(mu_terms, dtype=float), beta, fit_types.pop(), gamma, history_windows + + +def _nearestSPD(A: np.ndarray) -> np.ndarray: + """Find the nearest symmetric positive-definite matrix to *A*. + + Uses the algorithm of Higham (1988) via polar decomposition plus + eigenvalue clamping, matching Matlab ``nearestSPD``. + """ + B = 0.5 * (A + A.T) + _, S, Vt = np.linalg.svd(B) + H = Vt.T @ np.diag(S) @ Vt + Ahat = 0.5 * (B + H) + Ahat = 0.5 * (Ahat + Ahat.T) + # Test for positive-definiteness; clamp eigenvalues if needed + try: + np.linalg.cholesky(Ahat) + return Ahat + except np.linalg.LinAlgError: + pass + eigvals, eigvecs = np.linalg.eigh(Ahat) + eigvals = np.maximum(eigvals, np.finfo(float).eps) + Ahat = eigvecs @ np.diag(eigvals) @ eigvecs.T + Ahat = 0.5 * (Ahat + Ahat.T) + return Ahat + + +def _ztest_pvalue(param: float, se: float) -> float: + """Two-tailed z-test p-value: H0 param == 0, matching Matlab ``ztest``.""" + if se <= 0 or not np.isfinite(se): + return 1.0 + z = param / se + return float(2.0 * norm.sf(np.abs(z))) + class DecodingAlgorithms: @staticmethod def linear_decode(spike_counts: np.ndarray, stimulus: np.ndarray) -> dict[str, np.ndarray]: @@ -340,6 +373,27 @@ def kalman_filter( x0: np.ndarray, p0: np.ndarray, ) -> dict[str, np.ndarray]: + """Discrete-time Kalman filter — public Python API. + + Runs a Kalman filter on time-major observations and returns a + dict with updated state estimates and covariances. 
+ + Parameters + ---------- + observations : (N, Dy) — observation time-series, one row per step. + transition : (Dx, Dx) — state-transition matrix A. + observation_matrix : (Dy, Dx) — observation matrix C (H). + q_cov : (Dx, Dx) — process-noise covariance. + r_cov : (Dy, Dy) — observation-noise covariance. + x0 : (Dx,) — initial state estimate. + p0 : (Dx, Dx) — initial error covariance. + + Returns + ------- + dict with keys: + ``state`` : (N, Dx) — updated (posterior) state estimates. + ``cov`` : (N, Dx, Dx) — updated covariances. + """ y = np.asarray(observations, dtype=float) a = np.asarray(transition, dtype=float) h = np.asarray(observation_matrix, dtype=float) @@ -362,7 +416,7 @@ def kalman_filter( k_gain = p_pred @ h.T @ np.linalg.pinv(s_cov) x_post = x_pred + k_gain @ innovation - p_post = (np.eye(n_x) - k_gain @ h) @ p_pred + p_post = _symmetrize((np.eye(n_x) - k_gain @ h) @ p_pred) xs[t] = x_post ps[t] = p_post @@ -371,6 +425,116 @@ def kalman_filter( return {"state": xs, "cov": ps} + @staticmethod + def _kalman_filter_matlab(A, C, Pv, Pw, Px0, x0, y, GnConv=None): + """Discrete-time Kalman filter matching the Matlab API (internal). + + Implements the DT Kalman filter for the system:: + + x(:, n+1) = A(:,:,n) x(:, n) + v(:, n) + y(:, n) = C(:,:,n) x(:, n) + w(:, n) + + where ``Pv(:,:,n)``, ``Pw(:,:,n)`` are the covariances of v(n) and + w(n), and ``Px0`` is the initial state covariance. + + Supports **time-varying** system matrices when supplied as 3-D arrays + (e.g. ``A.shape == (Dx, Dx, N)``). Time-invariant (2-D) matrices + are broadcast automatically. + + Parameters + ---------- + A : (Dx, Dx) or (Dx, Dx, N) — state transition. + C : (Dy, Dx) or (Dy, Dx, N) — observation matrix. + Pv : (Dx, Dx) or (Dx, Dx, N) — process noise covariance. + Pw : (Dy, Dy) or (Dy, Dy, N) — observation noise covariance. + Px0 : (Dx, Dx) — initial error covariance. + x0 : (Dx,) — initial state estimate. + y : (Dy, N) or (N, Dy) — observations (auto-detected layout). 
+ GnConv : array or None, optional + Pre-converged Kalman gain. When ``None``, gain convergence + is auto-detected during filtering. + + Returns + ------- + x_p : (Dx, N+1) — predicted states (``x_p[:, 0] == x0``). + Pe_p : (Dx, Dx, N+1) — predicted covariances. + x_u : (Dx, N) — updated states. + Pe_u : (Dx, Dx, N) — updated covariances. + Gn : (Dx, Dy, N) — Kalman gain history. + GnConvIter : int or None — iteration at which gain converged. + """ + A = np.asarray(A, dtype=float) + C = np.asarray(C, dtype=float) + Pv = np.asarray(Pv, dtype=float) + Pw = np.asarray(Pw, dtype=float) + Px0 = np.asarray(Px0, dtype=float) + x0_vec = np.asarray(x0, dtype=float).reshape(-1) + y = np.asarray(y, dtype=float) + if y.ndim == 1: + y = y[None, :] + + Dx = A.shape[0] + Dy = C.shape[0] + + # Auto-detect layout: Matlab expects (Dy, N) state-major. + # If y is (N, Dy) time-major, transpose. + if y.shape[0] != Dy and y.shape[1] == Dy: + y = y.T + N = y.shape[1] + + def _sel(M, n): + """Select time-varying slice M[:,:,n] or broadcast M[:,:] if 2-D.""" + if M.ndim == 3: + return M[:, :, min(n, M.shape[2] - 1)] + return M + + x_p = np.zeros((Dx, N + 1), dtype=float) + Pe_p = np.zeros((Dx, Dx, N + 1), dtype=float) + x_u = np.zeros((Dx, N), dtype=float) + Pe_u = np.zeros((Dx, Dx, N), dtype=float) + Gn = np.zeros((Dx, Dy, N), dtype=float) + + x_p[:, 0] = x0_vec + Pe_p[:, :, 0] = Px0 + + GnConvIter = None + _GnConv = None + if GnConv is not None and not _is_empty_value(GnConv): + _GnConv = np.asarray(GnConv, dtype=float) + + for n in range(N): + An = _sel(A, n) + Cn = _sel(C, n) + Pvn = _sel(Pv, n) + Pwn = _sel(Pw, n) + + # --- Update --- + if _GnConv is not None: + G = _GnConv + else: + S = Cn @ Pe_p[:, :, n] @ Cn.T + Pwn + G = Pe_p[:, :, n] @ Cn.T @ np.linalg.solve(S, np.eye(Dy)) + x_u[:, n] = x_p[:, n] + G @ (y[:, n] - Cn @ x_p[:, n]) + Pe_u[:, :, n] = Pe_p[:, :, n] - G @ Cn @ Pe_p[:, :, n] + Pe_u[:, :, n] = _symmetrize(Pe_u[:, :, n]) + Gn[:, :, n] = G + + # --- Predict --- + if 
_GnConv is not None: + Pe_p[:, :, n + 1] = _symmetrize(Pe_u[:, :, n]) + else: + Pe_p[:, :, n + 1] = _symmetrize(An @ Pe_u[:, :, n] @ An.T + Pvn) + x_p[:, n + 1] = An @ x_u[:, n] + + # --- Gain convergence detection --- + if n > 0 and _GnConv is None: + diffGn = np.abs(Gn[:, :, n] - Gn[:, :, n - 1]) + if np.max(diffGn) < 1e-6: + _GnConv = Gn[:, :, n] + GnConvIter = n + + return x_p, Pe_p, x_u, Pe_u, Gn, GnConvIter + @staticmethod def kalman_predict(x_u, Pe_u, A, Pv, GnConv=None): x_vec = np.asarray(x_u, dtype=float).reshape(-1) @@ -474,28 +638,158 @@ def kalman_smoother(A, C, Pv, Pw, Px0, x0, y): @staticmethod def kalman_fixedIntervalSmoother(A, C, Pv, Pw, Px0, x0, y, lags): - x_N, P_N, _, x_p, Pe_p, x_u, Pe_u = DecodingAlgorithms.kalman_smoother(A, C, Pv, Pw, Px0, x0, y) - x_p_tm, Pe_p_tm, _ = DecodingAlgorithms._state_history_time_major(x_p, Pe_p) - x_u_tm, Pe_u_tm, _ = DecodingAlgorithms._state_history_time_major(x_u, Pe_u) - x_N_tm, P_N_tm, _ = DecodingAlgorithms._state_history_time_major(x_N, P_N) - lag = max(int(lags), 1) - x_pLag = np.zeros_like(x_p_tm) - Pe_pLag = np.zeros_like(Pe_p_tm) - x_uLag = np.zeros_like(x_u_tm) - Pe_uLag = np.zeros_like(Pe_u_tm) - - for t in range(x_u_tm.shape[0]): - idx = max(t - lag + 1, 0) - x_uLag[t] = x_N_tm[idx] - Pe_uLag[t] = P_N_tm[idx] - x_pLag[t] = x_p_tm[idx] - Pe_pLag[t] = Pe_p_tm[idx] + """Kalman fixed-interval (fixed-lag) smoother via augmented state. + + Matches the Matlab implementation: builds an augmented state + of dimension ``(1 + lags) * n_x`` and runs ``kalman_filter`` + on the augmented system. The lagged portion of the augmented + state gives the exact smoothed estimate at lag *lags*. + + Parameters + ---------- + A, C, Pv, Pw : arrays + System matrices (may be time-varying 3-D arrays). + Px0 : (Dx, Dx) — initial covariance. + x0 : (Dx,) — initial state. + y : (N, Dy) or (Dy, N) — observations (auto-detected layout). + lags : int — number of smoothing lags. 
+ + Returns + ------- + x_pLag : (N+1, Dx) — predicted states at the lagged component (time-major). + Pe_pLag : (N+1, Dx, Dx) — predicted covariances at the lagged component. + x_uLag : (N, Dx) — updated states at the lagged component. + Pe_uLag : (N, Dx, Dx) — updated covariances at the lagged component. + """ + A = np.asarray(A, dtype=float) + C = np.asarray(C, dtype=float) + Pv = np.asarray(Pv, dtype=float) + Pw = np.asarray(Pw, dtype=float) + Px0 = np.asarray(Px0, dtype=float) + x0_vec = np.asarray(x0, dtype=float).reshape(-1) + y = np.asarray(y, dtype=float) + if y.ndim == 1: + y = y[None, :] + + nStates = A.shape[0] + nObs = C.shape[0] + + # Auto-detect layout: convert to state-major (Dy, N) for kalman_filter + if y.shape[0] != nObs and y.shape[1] == nObs: + y = y.T + N = y.shape[1] + lags = max(int(lags), 1) + aug_dim = (lags + 1) * nStates + + def _sel(M, n): + if M.ndim == 3: + return M[:, :, min(n, M.shape[2] - 1)] + return M + + # Build augmented time-varying matrices + Alag = np.zeros((aug_dim, aug_dim, N), dtype=float) + Pvlag = np.zeros((aug_dim, aug_dim, N), dtype=float) + Clag = np.zeros((nObs, aug_dim, N), dtype=float) + Pwlag = np.zeros((nObs, nObs, N), dtype=float) + + for n in range(N): + offset = 0 + for i in range(lags + 1): + sl = slice(offset, offset + nStates) + if i == 0: + Alag[sl, sl, n] = _sel(A, n) + Pvlag[sl, sl, n] = _sel(Pv, n) + Clag[:, sl, n] = _sel(C, n) + Pwlag[:, :, n] = _sel(Pw, n) + else: + prev_sl = slice(offset - nStates, offset) + Alag[sl, prev_sl, n] = np.eye(nStates) + offset += nStates + + # Augmented initial state and covariance + x0lag = np.zeros(aug_dim, dtype=float) + x0lag[:nStates] = x0_vec + Px0lag = np.zeros((aug_dim, aug_dim), dtype=float) + Px0lag[:nStates, :nStates] = Px0 + + # Run Kalman filter on augmented system (internal Matlab API) + x_p, Pe_p, x_u, Pe_u, _, _ = DecodingAlgorithms._kalman_filter_matlab( + Alag, Clag, Pvlag, Pwlag, Px0lag, x0lag, y + ) + + # Extract the lagged portion — state-major + 
lag_sl = slice(lags * nStates, (lags + 1) * nStates)
+
+        # x_p is (aug_dim, N+1), x_u is (aug_dim, N).  Slice objects are
+        # basic indexing, so this selects the lagged (Dx, Dx) diagonal
+        # block directly.  Drop the initial prediction at column 0 so the
+        # predicted and updated arrays both have N time steps, for
+        # backward compatibility with the Python API.
+        x_pLag_sm = x_p[lag_sl, 1:]            # (Dx, N)
+        Pe_pLag_sm = Pe_p[lag_sl, lag_sl, 1:]  # (Dx, Dx, N)
+        x_uLag_sm = x_u[lag_sl, :]             # (Dx, N)
+        Pe_uLag_sm = Pe_u[lag_sl, lag_sl, :]   # (Dx, Dx, N)
+
+        # Return time-major to match kalman_smoother convention: (N, Dx)
+        x_pLag = x_pLag_sm.T
+        Pe_pLag = np.transpose(Pe_pLag_sm, (2, 0, 1))
+        x_uLag = x_uLag_sm.T
+        Pe_uLag = np.transpose(Pe_uLag_sm, (2, 0, 1))
         return x_pLag, Pe_pLag, x_uLag, Pe_uLag

     @staticmethod
     def ComputeStimulusCIs(fitType, xK, Wku, delta, Mc=None, alphaVal=0.05):
-        del Mc, delta
+        """Confidence intervals for stimulus estimate.
+
+        When ``Wku`` is a 4-D array ``(numBasis, numBasis, K, K)`` (SSGLM
+        cross-trial covariance), uses Monte Carlo sampling matching Matlab's
+        ``DecodingAlgorithms.ComputeStimulusCIs``.
+
+        When ``Wku`` is a 3-D array ``(N, Dx, Dx)`` (e.g. smoother output),
+        falls back to z-score Gaussian CIs with inverse-link transform.
+
+        Parameters
+        ----------
+        fitType : str
+            ``'poisson'`` or ``'binomial'``.
+        xK : array
+            Smoothed state estimates. Shape ``(numBasis, K)`` for SSGLM
+            or ``(N, Dx)`` for smoother output.
+        Wku : array
+            Covariance. Shape ``(numBasis, numBasis, K, K)`` for SSGLM
+            or ``(N, Dx, Dx)`` for smoother output.
+        delta : float
+            Time-step size in seconds.
+        Mc : int, optional
+            Number of Monte Carlo draws (default 3000). Ignored when
+            using z-score fallback.
+ alphaVal : float, optional + Significance level for two-sided CIs (default 0.05). + + Returns + ------- + CIs : array, shape ``(..., 2)`` + Lower and upper confidence bounds. + stimulus : array + Point estimate (inverse-link of xK, divided by delta for MC mode). + """ + Wku_arr = np.asarray(Wku, dtype=float) + + # SSGLM cross-trial path: 4-D covariance (numBasis, numBasis, K, K) + if Wku_arr.ndim == 4: + if Mc is None: + Mc = 3000 + return DecodingAlgorithms._ComputeStimulusCIs_MC( + fitType, xK, Wku, delta, Mc=Mc, alphaVal=alphaVal + ) + + # Fallback: 3-D covariance (N, Dx, Dx) from smoother — z-score CIs x_tm, W_tm, transposed = DecodingAlgorithms._state_history_time_major(xK, Wku) variances = np.clip(np.diagonal(W_tm, axis1=1, axis2=2), 0.0, None) z = float(norm.ppf(1.0 - float(alphaVal) / 2.0)) @@ -520,6 +814,381 @@ def ComputeStimulusCIs(fitType, xK, Wku, delta, Mc=None, alphaVal=0.05): return np.transpose(ci, (1, 0, 2)), stimulus.T return ci, stimulus + @staticmethod + def computeSpikeRateCIs( + xK, + Wku, + dN, + t0: float, + tf: float, + fitType: str = "poisson", + delta: float = 0.001, + gamma=None, + windowTimes=None, + Mc: int = 500, + alphaVal: float = 0.05, + ): + """Monte Carlo spike-rate confidence intervals. + + Computes the average firing rate over ``[t0, tf]`` for each trial + by drawing ``Mc`` samples from the smoothing distribution and + evaluating the conditional intensity. + + Parameters + ---------- + xK : array, shape (numBasis, K) + Smoothed state estimates (basis coefficients × trials). + Wku : array, shape (numBasis, numBasis, K, K) or compatible + Smoothed state covariance. + dN : array, shape (C, N) + Observation (spike indicator) matrix. + t0, tf : float + Time window over which to compute the average rate. + fitType : str + ``'poisson'`` or ``'binomial'``. + delta : float + Time-step size in seconds. + gamma : array or None + History-effect coefficients. + windowTimes : array or None + History window boundaries. 
+ Mc : int + Number of Monte Carlo draws. + alphaVal : float + Significance level for CIs (one-sided). + + Returns + ------- + spikeRateSig : Covariate + Mean spike rate per trial with attached ConfidenceInterval. + ProbMat : ndarray, shape (K, K) + ``ProbMat[k, m]`` = P(rate_m > rate_k) estimated from MC draws. + sigMat : ndarray, shape (K, K) + Binary significance matrix at level ``1 - alphaVal``. + """ + from .confidence_interval import ConfidenceInterval + from .core import Covariate + from .history import History + from .nspikeTrain import nspikeTrain + from .trial import SpikeTrainCollection + + xK = np.asarray(xK, dtype=float) + dN = np.asarray(dN, dtype=float) + if dN.ndim == 1: + dN = dN.reshape(1, -1) + numBasis, K = xK.shape + minTime = 0.0 + maxTime = (dN.shape[1] - 1) * delta + + # Build unit-impulse basis matrix + basisWidth = (maxTime - minTime) / numBasis + sampleRate = 1.0 / delta + unitPulseBasis = SpikeTrainCollection.generateUnitImpulseBasis( + basisWidth, minTime, maxTime, sampleRate + ) + basisMat = unitPulseBasis.data # shape (T, numBasis) + + # Build history matrices if windowTimes provided + Hk = {} + if windowTimes is not None and len(windowTimes) > 0: + histObj = History(windowTimes, minTime, maxTime) + for k in range(K): + spike_idx = np.flatnonzero(dN[k, :] == 1) + spike_times = (spike_idx) * delta + nst_k = nspikeTrain(spike_times) + nst_k.setMinTime(minTime) + nst_k.setMaxTime(maxTime) + hist_cov = histObj.computeHistory(nst_k) + Hk[k] = hist_cov.dataToMatrix() + else: + for k in range(K): + Hk[k] = 0.0 + gamma = 0.0 + + if gamma is None: + gamma = 0.0 + gamma = np.asarray(gamma, dtype=float) + + # Monte Carlo draws from smoothing distribution + Wku = np.asarray(Wku, dtype=float) + xKDraw = np.zeros((numBasis, K, Mc), dtype=float) + for r in range(numBasis): + WkuTemp = Wku[r, r, :, :].squeeze() if Wku.ndim == 4 else Wku[r, r] + WkuTemp = np.atleast_2d(WkuTemp) + if WkuTemp.shape[0] != K: + WkuTemp = np.diag(np.full(K, 
float(WkuTemp.flat[0]))) + try: + chol_m = np.linalg.cholesky(WkuTemp) + except np.linalg.LinAlgError: + eigvals = np.linalg.eigvalsh(WkuTemp) + WkuTemp += np.eye(K) * (abs(min(eigvals.min(), 0.0)) + 1e-10) + chol_m = np.linalg.cholesky(WkuTemp) + for c in range(Mc): + z = np.random.randn(K) + xKDraw[r, :, c] = xK[r, :] + chol_m.T @ z + + # Compute lambda for each MC draw and each trial + time_vec = np.arange(minTime, maxTime + delta, delta) + T = basisMat.shape[0] + fit_type = str(fitType).lower() + spikeRate = np.zeros((Mc, K), dtype=float) + + for c in range(Mc): + for k in range(K): + stimK = basisMat @ xKDraw[:, k, c] + if fit_type == "poisson": + histEffect = np.exp(gamma @ Hk[k].T).ravel() if not np.isscalar(Hk[k]) else np.ones(T) + stimEffect = np.exp(np.clip(stimK, -20.0, 20.0)) + lambdaDelta_kc = stimEffect * histEffect[:T] + elif fit_type == "binomial": + if np.isscalar(Hk[k]): + eta = stimK + else: + eta = stimK + (gamma @ Hk[k].T).ravel()[:T] + eta = np.clip(eta, -20.0, 20.0) + lambdaDelta_kc = np.exp(eta) / (1.0 + np.exp(eta)) + else: + lambdaDelta_kc = np.exp(np.clip(stimK, -20.0, 20.0)) + + # Integrate via cumulative trapezoid + rate_per_sec = lambdaDelta_kc / delta + time_k = time_vec[:len(rate_per_sec)] + cum_integral = np.zeros(len(rate_per_sec)) + cum_integral[1:] = np.cumsum(rate_per_sec[:-1] * delta + 0.5 * np.diff(rate_per_sec) * delta) + + # Interpolate integral at t0 and tf + val_t0 = np.interp(t0, time_k, cum_integral) + val_tf = np.interp(tf, time_k, cum_integral) + spikeRate[c, k] = (1.0 / (tf - t0)) * (val_tf - val_t0) + + # Compute CIs from ECDF (one-sided) + CIs = np.zeros((K, 2), dtype=float) + for k in range(K): + sorted_rates = np.sort(spikeRate[:, k]) + ecdf = np.arange(1, Mc + 1, dtype=float) / float(Mc) + lower_idx = np.flatnonzero(ecdf < alphaVal) + upper_idx = np.flatnonzero(ecdf > (1.0 - alphaVal)) + CIs[k, 0] = sorted_rates[lower_idx[-1]] if lower_idx.size else sorted_rates[0] + CIs[k, 1] = sorted_rates[upper_idx[0]] if 
upper_idx.size else sorted_rates[-1] + + trial_axis = np.arange(1, K + 1, dtype=float) + mean_rate = np.mean(spikeRate, axis=0) + spikeRateSig = Covariate( + trial_axis, + mean_rate, + f"({tf:g}-{t0:g})^{{-1}} * \\Lambda({tf:g}-{t0:g})", + "Trial", + "k", + "Hz", + ) + ciSpikeRate = ConfidenceInterval( + trial_axis, CIs, "CI_{spikeRate}", "Trial", "k", "Hz" + ) + spikeRateSig.setConfInterval(ciSpikeRate) + + # Pairwise probability matrix + ProbMat = np.zeros((K, K), dtype=float) + for k in range(K): + for m in range(k + 1, K): + ProbMat[k, m] = np.sum(spikeRate[:, m] > spikeRate[:, k]) / float(Mc) + + sigMat = (ProbMat > (1.0 - alphaVal)).astype(float) + + return spikeRateSig, ProbMat, sigMat + + @staticmethod + def computeSpikeRateDiffCIs( + xK, + Wku, + dN, + time1, + time2, + fitType: str = "poisson", + delta: float = 0.001, + gamma=None, + windowTimes=None, + Mc: int = 500, + alphaVal: float = 0.05, + ): + """Monte Carlo CIs for the difference in spike rates between two time windows. + + Computes the difference of average firing rates + ``rate(time1) - rate(time2)`` for each trial by drawing ``Mc`` + samples from the smoothing distribution. + + Parameters + ---------- + xK : array, shape (numBasis, K) + Smoothed state estimates (basis coefficients × trials). + Wku : array, shape (numBasis, numBasis, K, K) or compatible + Smoothed state covariance. + dN : array, shape (C, N) + Observation (spike indicator) matrix. + time1 : array-like, length 2 + ``[t0_1, tf_1]`` — first time window. + time2 : array-like, length 2 + ``[t0_2, tf_2]`` — second time window. + fitType : str + ``'poisson'`` or ``'binomial'``. + delta : float + Time-step size in seconds. + gamma : array or None + History-effect coefficients. + windowTimes : array or None + History window boundaries. + Mc : int + Number of Monte Carlo draws. + alphaVal : float + Significance level for CIs (one-sided). 
+ + Returns + ------- + spikeRateSig : Covariate + Mean spike-rate difference per trial with attached CI. + ProbMat : ndarray, shape (K, K) + ``ProbMat[k, m]`` = P(diff_m > diff_k) from MC draws. + sigMat : ndarray, shape (K, K) + Binary significance matrix at level ``1 - alphaVal``. + """ + from .confidence_interval import ConfidenceInterval + from .core import Covariate + from .history import History + from .nspikeTrain import nspikeTrain + from .trial import SpikeTrainCollection + + xK = np.asarray(xK, dtype=float) + dN = np.asarray(dN, dtype=float) + if dN.ndim == 1: + dN = dN.reshape(1, -1) + numBasis, K = xK.shape + minTime = 0.0 + maxTime = (dN.shape[1] - 1) * delta + + time1 = np.asarray(time1, dtype=float).ravel() + time2 = np.asarray(time2, dtype=float).ravel() + + # Build unit-impulse basis matrix + basisWidth = (maxTime - minTime) / numBasis + sampleRate = 1.0 / delta + unitPulseBasis = SpikeTrainCollection.generateUnitImpulseBasis( + basisWidth, minTime, maxTime, sampleRate + ) + basisMat = unitPulseBasis.data + + # Build history matrices if windowTimes provided + Hk = {} + if windowTimes is not None and len(windowTimes) > 0: + histObj = History(windowTimes, minTime, maxTime) + for k in range(K): + spike_idx = np.flatnonzero(dN[k, :] == 1) + spike_times = (spike_idx) * delta + nst_k = nspikeTrain(spike_times) + nst_k.setMinTime(minTime) + nst_k.setMaxTime(maxTime) + hist_cov = histObj.computeHistory(nst_k) + Hk[k] = hist_cov.dataToMatrix() + else: + for k in range(K): + Hk[k] = 0.0 + gamma = 0.0 + + if gamma is None: + gamma = 0.0 + gamma = np.asarray(gamma, dtype=float) + + # Monte Carlo draws from smoothing distribution + Wku = np.asarray(Wku, dtype=float) + xKDraw = np.zeros((numBasis, K, Mc), dtype=float) + for r in range(numBasis): + WkuTemp = Wku[r, r, :, :].squeeze() if Wku.ndim == 4 else Wku[r, r] + WkuTemp = np.atleast_2d(WkuTemp) + if WkuTemp.shape[0] != K: + WkuTemp = np.diag(np.full(K, float(WkuTemp.flat[0]))) + try: + chol_m = 
np.linalg.cholesky(WkuTemp) + except np.linalg.LinAlgError: + eigvals = np.linalg.eigvalsh(WkuTemp) + WkuTemp += np.eye(K) * (abs(min(eigvals.min(), 0.0)) + 1e-10) + chol_m = np.linalg.cholesky(WkuTemp) + for c in range(Mc): + z = np.random.randn(K) + xKDraw[r, :, c] = xK[r, :] + chol_m.T @ z + + # Compute lambda and spike-rate difference for each MC draw + time_vec = np.arange(minTime, maxTime + delta, delta) + T = basisMat.shape[0] + fit_type = str(fitType).lower() + spikeRate = np.zeros((Mc, K), dtype=float) + + for c in range(Mc): + for k in range(K): + stimK = basisMat @ xKDraw[:, k, c] + if fit_type == "poisson": + histEffect = np.exp(gamma @ Hk[k].T).ravel() if not np.isscalar(Hk[k]) else np.ones(T) + stimEffect = np.exp(np.clip(stimK, -20.0, 20.0)) + lambdaDelta_kc = stimEffect * histEffect[:T] + elif fit_type == "binomial": + if np.isscalar(Hk[k]): + eta = stimK + else: + eta = stimK + (gamma @ Hk[k].T).ravel()[:T] + eta = np.clip(eta, -20.0, 20.0) + lambdaDelta_kc = np.exp(eta) / (1.0 + np.exp(eta)) + else: + lambdaDelta_kc = np.exp(np.clip(stimK, -20.0, 20.0)) + + # Integrate via cumulative sum + rate_per_sec = lambdaDelta_kc / delta + time_k = time_vec[:len(rate_per_sec)] + cum_integral = np.zeros(len(rate_per_sec)) + cum_integral[1:] = np.cumsum(rate_per_sec[:-1] * delta + 0.5 * np.diff(rate_per_sec) * delta) + + # Rate for time window 1 + t0_1, tf_1 = float(min(time1)), float(max(time1)) + val_t0_1 = np.interp(t0_1, time_k, cum_integral) + val_tf_1 = np.interp(tf_1, time_k, cum_integral) + rate1 = (1.0 / (tf_1 - t0_1)) * (val_tf_1 - val_t0_1) + + # Rate for time window 2 + t0_2, tf_2 = float(min(time2)), float(max(time2)) + val_t0_2 = np.interp(t0_2, time_k, cum_integral) + val_tf_2 = np.interp(tf_2, time_k, cum_integral) + rate2 = (1.0 / (tf_2 - t0_2)) * (val_tf_2 - val_t0_2) + + spikeRate[c, k] = rate1 - rate2 + + # Compute CIs from ECDF (one-sided) + CIs = np.zeros((K, 2), dtype=float) + for k in range(K): + sorted_rates = np.sort(spikeRate[:, k]) 
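+            # Read the bounds off the empirical CDF of the Mc draws: the
+            # largest draw with ECDF mass below alphaVal gives the lower
+            # bound, the smallest draw above 1 - alphaVal the upper bound.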
+ ecdf = np.arange(1, Mc + 1, dtype=float) / float(Mc) + lower_idx = np.flatnonzero(ecdf < alphaVal) + upper_idx = np.flatnonzero(ecdf > (1.0 - alphaVal)) + CIs[k, 0] = sorted_rates[lower_idx[-1]] if lower_idx.size else sorted_rates[0] + CIs[k, 1] = sorted_rates[upper_idx[0]] if upper_idx.size else sorted_rates[-1] + + trial_axis = np.arange(1, K + 1, dtype=float) + mean_rate = np.mean(spikeRate, axis=0) + label = ( + r"(t_{1f}-t_{1o})^{-1} \Lambda(t_{1f}-t_{1o})" + r" - (t_{2f}-t_{2o})^{-1} \Lambda(t_{2f}-t_{2o})" + ) + spikeRateSig = Covariate(trial_axis, mean_rate, label, "Trial", "k", "Hz") + ciSpikeRate = ConfidenceInterval( + trial_axis, CIs, "CI_{spikeRate}", "Trial", "k", "Hz" + ) + spikeRateSig.setConfInterval(ciSpikeRate) + + # Pairwise probability matrix + ProbMat = np.zeros((K, K), dtype=float) + for k in range(K): + for m in range(k + 1, K): + ProbMat[k, m] = np.sum(spikeRate[:, m] > spikeRate[:, k]) / float(Mc) + + sigMat = (ProbMat > (1.0 - alphaVal)).astype(float) + + return spikeRateSig, ProbMat, sigMat + @staticmethod def PPDecode_predict(x_u, W_u, A, Q, Wconv=None): x_vec = np.asarray(x_u, dtype=float).reshape(-1) @@ -528,7 +1197,12 @@ def PPDecode_predict(x_u, W_u, A, Q, Wconv=None): A_mat = _as_state_matrix(A, dim) if Wconv is None or Wconv == []: Q_mat = _as_state_matrix(Q, dim) - W_p = _symmetrize(A_mat @ W_mat @ A_mat.T + Q_mat) + W_p = A_mat @ W_mat @ A_mat.T + Q_mat + # Matlab: if rcond(W_p) < eps or NaN, fall back to W_u + cond_num = np.linalg.cond(W_p) + if not np.isfinite(cond_num) or (1.0 / cond_num) < np.finfo(float).eps: + W_p = W_mat.copy() + W_p = _symmetrize(W_p) else: W_p = _symmetrize(_as_state_matrix(Wconv, dim)) x_p = A_mat @ x_vec @@ -642,35 +1316,204 @@ def _ppdecode_filter_linear( estimateTarget=0, Wconv=None, ): - del yT, PiT, estimateTarget obs = _as_observation_matrix(dN) num_cells, num_steps = obs.shape - num_states = _infer_state_dim(A, beta, num_cells) + N = num_steps + ns = _infer_state_dim(A, beta, num_cells) mu_vec 
= _normalize_mu(mu, num_cells) - beta_mat = _normalize_beta(beta, num_states, num_cells) + beta_mat = _normalize_beta(beta, ns, num_cells) - x0_vec = np.zeros(num_states, dtype=float) if _is_empty_value(x0) else np.asarray(x0, dtype=float).reshape(-1) - if x0_vec.size != num_states: + x0_vec = np.zeros(ns, dtype=float) if _is_empty_value(x0) else np.asarray(x0, dtype=float).reshape(-1) + if x0_vec.size != ns: raise ValueError("x0 must match the decoding state dimension") - Pi0_mat = np.zeros((num_states, num_states), dtype=float) if _is_empty_value(Pi0) else _as_state_matrix(Pi0, num_states) + Pi0_mat = np.zeros((ns, ns), dtype=float) if _is_empty_value(Pi0) else _as_state_matrix(Pi0, ns) if _is_empty_value(windowTimes): - H_tensor = np.zeros((num_steps, 0, num_cells), dtype=float) + H_tensor = np.zeros((N, 0, num_cells), dtype=float) gamma_mat = np.zeros((0, num_cells), dtype=float) else: H_tensor = _compute_history_terms(obs, float(delta), windowTimes) gamma_mat = _normalize_gamma(gamma, H_tensor.shape[1], num_cells) - x_p = np.zeros((num_states, num_steps + 1), dtype=float) - x_u = np.zeros((num_states, num_steps), dtype=float) - W_p = np.zeros((num_states, num_states, num_steps + 1), dtype=float) - W_u = np.zeros((num_states, num_states, num_steps), dtype=float) + # ------------------------------------------------------------------ + # Target estimation branch (Srinivasan et al. 2006) + # ------------------------------------------------------------------ + has_target = not _is_empty_value(yT) + yT_vec = np.asarray(yT, dtype=float).reshape(-1) if has_target else np.array([], dtype=float) + estimateTarget = int(estimateTarget) + + if has_target: + PiT_mat = _as_state_matrix(PiT, ns) if not _is_empty_value(PiT) else np.zeros((ns, ns), dtype=float) + + # Backward information matrices (Srinivasan Eq. 
2.16) + PitT = np.zeros((ns, ns, N), dtype=float) + QT = np.zeros((ns, ns, N), dtype=float) + QN = _select_time_matrix(Q, N - 1, ns) + if estimateTarget == 1: + PitT[:, :, N - 1] = QN # Pi(T,T) = Q_T when PiT = 0 + else: + PitT[:, :, N - 1] = PiT_mat + QN + + # Backward transition matrices + PhitT = np.zeros((ns, ns, N), dtype=float) + PhitT[:, :, N - 1] = np.eye(ns, dtype=float) # phi(T,T) = I + B = np.zeros((ns, ns, N), dtype=float) + + for n in range(N - 1, 0, -1): + An = _select_time_matrix(A, n, ns) + Qn = _select_time_matrix(Q, n, ns) + invA = np.linalg.pinv(An) + PhitT[:, :, n - 1] = invA @ PhitT[:, :, n] + PitT[:, :, n - 1] = invA @ PitT[:, :, n] @ invA.T + Qn # Eq. 2.16 + invPitT_n = np.linalg.pinv(PitT[:, :, n]) + B[:, :, n] = An - (Qn @ invPitT_n) @ An # Eq. 2.21 + QT[:, :, n] = Qn - (Qn @ invPitT_n) @ Qn.T + + A1 = _select_time_matrix(A, 0, ns) + Q1 = _select_time_matrix(Q, 0, ns) + invPitT_0 = np.linalg.pinv(PitT[:, :, 0]) + B[:, :, 0] = A1 - (Q1 @ invPitT_0) @ A1 + QT[:, :, 0] = Q1 - (Q1 @ invPitT_0) @ Q1.T + + if estimateTarget == 1: + # Augmented state space [x_t; y_T] + beta_aug = np.vstack([beta_mat, np.zeros((ns, num_cells), dtype=float)]) + na = 2 * ns + Amat = np.zeros((na, na, N), dtype=float) + Qmat = np.zeros((na, na, N), dtype=float) + + for n in range(N): + An = _select_time_matrix(A, n, ns) + Qn = _select_time_matrix(Q, n, ns) + psi = B[:, :, n] + if n == N - 1: + gammaMat = np.eye(ns, dtype=float) + else: + invPitT_n = np.linalg.pinv(PitT[:, :, n]) + gammaMat = (Qn @ invPitT_n) @ PhitT[:, :, n] + Amat[:ns, :ns, n] = psi + Amat[:ns, ns:, n] = gammaMat + Amat[ns:, ns:, n] = np.eye(ns, dtype=float) + Qmat[:ns, :ns, n] = QT[:, :, n] + + # Augmented initial state + x0_aug = np.concatenate([x0_vec, yT_vec]) + x_p = np.zeros((na, N + 1), dtype=float) + x_u = np.zeros((na, N), dtype=float) + W_p = np.zeros((na, na, N + 1), dtype=float) + W_u = np.zeros((na, na, N), dtype=float) + + x_p[:, 0] = Amat[:, :, 0] @ x0_aug + W_p[:, :, 0] = Amat[:, :, 
0] @ np.zeros((na, na), dtype=float) @ Amat[:, :, 0].T + Qmat[:, :, 0] + + for time_index in range(1, N + 1): + x_u[:, time_index - 1], W_u[:, :, time_index - 1], _ = DecodingAlgorithms.PPDecode_updateLinear( + x_p[:, time_index - 1], + W_p[:, :, time_index - 1], + obs, + mu_vec, + beta_aug, + fitType, + gamma_mat, + H_tensor, + time_index, + None, + ) + A_t = Amat[:, :, min(time_index - 1, N - 1)] + Q_t = Qmat[:, :, min(time_index - 1, N - 1)] + x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict( + x_u[:, time_index - 1], + W_u[:, :, time_index - 1], + A_t, + Q_t, + Wconv, + ) - A0 = _select_time_matrix(A, 0, num_states) - Q0 = _select_time_matrix(Q, 0, num_states) + # Decompose augmented state into state + target + x_uT = x_u[ns:, :] + W_uT = W_u[ns:, ns:, :] + x_pT = x_p[ns:, :] + W_pT = W_p[ns:, ns:, :] + x_u = x_u[:ns, :] + W_u = W_u[:ns, :ns, :] + x_p = x_p[:ns, :] + W_p = W_p[:ns, :ns, :] + return x_p, W_p, x_u, W_u, x_pT, W_pT, x_uT, W_uT + + else: + # Non-augmented target branch: use B and ft feedforward term + Amat = B + Qmat_arr = QT + ft = np.zeros((ns, N), dtype=float) + ut = np.zeros((ns, N), dtype=float) + for n in range(N): + An = _select_time_matrix(A, n, ns) + Qn = _select_time_matrix(Q, n, ns) + invPitT_n = np.linalg.pinv(PitT[:, :, n]) + ft[:, n] = (Qn @ invPitT_n) @ PhitT[:, :, n] @ yT_vec + + x_p = np.zeros((ns, N + 1), dtype=float) + x_u = np.zeros((ns, N), dtype=float) + W_p = np.zeros((ns, ns, N + 1), dtype=float) + W_u = np.zeros((ns, ns, N), dtype=float) + + # Initial predict with target correction + invPitT_0 = np.linalg.pinv(PitT[:, :, 0]) + invA1 = np.linalg.pinv(A1) + invPhi0T = np.linalg.pinv(invA1 @ PhitT[:, :, 0]) + ut[:, 0] = (Q1 @ invPitT_0) @ PhitT[:, :, 0] @ (yT_vec - invPhi0T @ x0_vec) + x_p[:, 0] = Amat[:, :, 0] @ x0_vec + ut[:, 0] + W_p[:, :, 0] = Amat[:, :, 0] @ Pi0_mat @ Amat[:, :, 0].T + Qmat_arr[:, :, 0] + + for time_index in range(1, N + 1): + x_u[:, time_index - 1], W_u[:, :, time_index - 
1], _ = DecodingAlgorithms.PPDecode_updateLinear( + x_p[:, time_index - 1], + W_p[:, :, time_index - 1], + obs, + mu_vec, + beta_mat, + fitType, + gamma_mat, + H_tensor, + time_index, + None, + ) + if time_index < N: + An = _select_time_matrix(A, time_index - 1, ns) + Qn = _select_time_matrix(Q, time_index - 1, ns) + invPitT_n1 = np.linalg.pinv(PitT[:, :, time_index]) + invPhitm1T = np.linalg.pinv(PhitT[:, :, time_index - 1]) + ut[:, time_index] = (Qn @ invPitT_n1) @ PhitT[:, :, time_index] @ ( + yT_vec - invPhitm1T @ x_u[:, time_index - 1] + ) + A_t = Amat[:, :, min(time_index - 1, N - 1)] + Q_t = Qmat_arr[:, :, min(time_index - 1, N - 1)] + x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict( + x_u[:, time_index - 1], + W_u[:, :, time_index - 1], + A_t, + Q_t, + ) + x_p[:, time_index] += ut[:, time_index] + W_p[:, :, time_index] += (Qn @ invPitT_n1) @ An @ W_u[:, :, time_index - 1] @ An.T @ (Qn @ invPitT_n1).T + + empty_vec = np.array([], dtype=float) + empty_cov = np.zeros((0, 0, 0), dtype=float) + return x_p, W_p, x_u, W_u, empty_vec, empty_cov, empty_vec, empty_cov + + # ------------------------------------------------------------------ + # Standard filter (no target) + # ------------------------------------------------------------------ + x_p = np.zeros((ns, N + 1), dtype=float) + x_u = np.zeros((ns, N), dtype=float) + W_p = np.zeros((ns, ns, N + 1), dtype=float) + W_u = np.zeros((ns, ns, N), dtype=float) + + A0 = _select_time_matrix(A, 0, ns) + Q0 = _select_time_matrix(Q, 0, ns) x_p[:, 0], W_p[:, :, 0] = DecodingAlgorithms.PPDecode_predict(x0_vec, Pi0_mat, A0, Q0, Wconv) - for time_index in range(1, num_steps + 1): + for time_index in range(1, N + 1): x_u[:, time_index - 1], W_u[:, :, time_index - 1], _ = DecodingAlgorithms.PPDecode_updateLinear( x_p[:, time_index - 1], W_p[:, :, time_index - 1], @@ -683,8 +1526,8 @@ def _ppdecode_filter_linear( time_index, None, ) - A_t = _select_time_matrix(A, time_index - 1, num_states) - Q_t = 
_select_time_matrix(Q, time_index - 1, num_states) + A_t = _select_time_matrix(A, time_index - 1, ns) + Q_t = _select_time_matrix(Q, time_index - 1, ns) x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict( x_u[:, time_index - 1], W_u[:, :, time_index - 1], @@ -879,7 +1722,6 @@ def PPHybridFilterLinear( estimateTarget=0, MinClassificationError=0, ): - del yT, PiT, estimateTarget obs = _as_observation_matrix(dN) A_models = list(A) if isinstance(A, Sequence) and not isinstance(A, np.ndarray) else [A] Q_models = list(Q) if isinstance(Q, Sequence) and not isinstance(Q, np.ndarray) else [Q] @@ -918,6 +1760,85 @@ def PPHybridFilterLinear( H_tensor = _compute_history_terms(obs, float(binwidth), windowTimes) gamma_mat = _normalize_gamma(gamma, H_tensor.shape[1], num_cells) + # ------------------------------------------------------------------ + # Goal-directed target branch (Srinivasan et al. 2006) + # ------------------------------------------------------------------ + estimateTarget = int(estimateTarget) + + # Normalize yT, PiT as per-model lists + if _is_empty_value(yT): + yT_models = [None] * n_models + elif isinstance(yT, (list, tuple)) and not isinstance(yT, np.ndarray): + yT_models = [ + np.asarray(y, dtype=float).reshape(-1) if not _is_empty_value(y) else None + for y in yT + ] + else: + yT_vec = np.asarray(yT, dtype=float).reshape(-1) + yT_models = [yT_vec] * n_models + + if _is_empty_value(PiT): + PiT_models = [None] * n_models + elif isinstance(PiT, (list, tuple)) and not isinstance(PiT, np.ndarray): + PiT_models = [ + _as_state_matrix(p, state_dims[i]) if not _is_empty_value(p) else None + for i, p in enumerate(PiT) + ] + else: + PiT_models = [_as_state_matrix(PiT, state_dims[i]) for i in range(n_models)] + + _has_target = [yT_models[s] is not None for s in range(n_models)] + _any_target = any(_has_target) + + if estimateTarget == 1 and _any_target: + raise NotImplementedError( + "Augmented state-space target estimation 
(estimateTarget=1) is not yet " + "supported for PPHybridFilterLinear. Use estimateTarget=0 with fixed target, " + "or use PPDecodeFilterLinear which supports both modes." + ) + + # Backward information filter for each target model + PhitT_m = [None] * n_models + PitT_m = [None] * n_models + B_m = [None] * n_models + QT_m = [None] * n_models + + for s in range(n_models): + if not _has_target[s]: + continue + dim = state_dims[s] + PiT_s = PiT_models[s] if PiT_models[s] is not None else np.zeros((dim, dim), dtype=float) + + PitT = np.zeros((dim, dim, num_steps), dtype=float) + PhitT = np.zeros((dim, dim, num_steps), dtype=float) + B_arr = np.zeros((dim, dim, num_steps), dtype=float) + QT_arr = np.zeros((dim, dim, num_steps), dtype=float) + + QN = _select_time_matrix(Q_models[s], num_steps - 1, dim) + PitT[:, :, num_steps - 1] = PiT_s + QN + PhitT[:, :, num_steps - 1] = np.eye(dim, dtype=float) + + for n in range(num_steps - 1, 0, -1): + An = _select_time_matrix(A_models[s], n, dim) + Qn = _select_time_matrix(Q_models[s], n, dim) + invA = np.linalg.pinv(An) + PhitT[:, :, n - 1] = invA @ PhitT[:, :, n] + PitT[:, :, n - 1] = invA @ PitT[:, :, n] @ invA.T + Qn + invPitT_n = np.linalg.pinv(PitT[:, :, n]) + B_arr[:, :, n] = An - (Qn @ invPitT_n) @ An + QT_arr[:, :, n] = Qn - (Qn @ invPitT_n) @ Qn.T + + A1 = _select_time_matrix(A_models[s], 0, dim) + Q1 = _select_time_matrix(Q_models[s], 0, dim) + invPitT_0 = np.linalg.pinv(PitT[:, :, 0]) + B_arr[:, :, 0] = A1 - (Q1 @ invPitT_0) @ A1 + QT_arr[:, :, 0] = Q1 - (Q1 @ invPitT_0) @ Q1.T + + PhitT_m[s] = PhitT + PitT_m[s] = PitT + B_m[s] = B_arr + QT_m[s] = QT_arr + X = np.zeros((max_dim, num_steps), dtype=float) W = np.zeros((max_dim, max_dim, num_steps), dtype=float) X_s = [np.zeros((max_dim, num_steps), dtype=float) for _ in range(n_models)] @@ -967,14 +1888,41 @@ def PPHybridFilterLinear( likelihoods = np.zeros(n_models, dtype=float) for model_index in range(n_models): dim = state_dims[model_index] - A_t = 
_select_time_matrix(A_models[model_index], time_index, dim) - Q_t = _select_time_matrix(Q_models[model_index], time_index, dim) + if _has_target[model_index]: + A_t = B_m[model_index][:, :, time_index] + Q_t = QT_m[model_index][:, :, time_index] + else: + A_t = _select_time_matrix(A_models[model_index], time_index, dim) + Q_t = _select_time_matrix(Q_models[model_index], time_index, dim) pred_x, pred_W = DecodingAlgorithms.PPDecode_predict( X_s[model_index][:dim, time_index], W_s[model_index][:dim, :dim, time_index], A_t, Q_t, ) + # Goal-directed offset for fixed target (Srinivasan Eq. 2.21) + if _has_target[model_index] and estimateTarget == 0: + Qn_orig = _select_time_matrix(Q_models[model_index], time_index, dim) + invPitT_n = np.linalg.pinv(PitT_m[model_index][:, :, time_index]) + if time_index > 0: + invPhitm1 = np.linalg.pinv(PhitT_m[model_index][:, :, time_index - 1]) + ut = (Qn_orig @ invPitT_n) @ PhitT_m[model_index][:, :, time_index] @ ( + yT_models[model_index] - invPhitm1 @ X_s[model_index][:dim, time_index] + ) + else: + A1_orig = _select_time_matrix(A_models[model_index], 0, dim) + invA1 = np.linalg.pinv(A1_orig) + invPhi0 = np.linalg.pinv(invA1 @ PhitT_m[model_index][:, :, 0]) + ut = (Qn_orig @ invPitT_n) @ PhitT_m[model_index][:, :, time_index] @ ( + yT_models[model_index] - invPhi0 @ X_s[model_index][:dim, time_index] + ) + pred_x = pred_x + ut + An_orig = _select_time_matrix(A_models[model_index], time_index, dim) + pred_W = pred_W + ( + (Qn_orig @ invPitT_n) @ An_orig + @ W_s[model_index][:dim, :dim, time_index] + @ An_orig.T @ (Qn_orig @ invPitT_n).T + ) upd_x, upd_W, lambda_delta = DecodingAlgorithms.PPDecode_updateLinear( pred_x, pred_W, @@ -1045,60 +1993,203 @@ def PPHybridFilterLinear( @staticmethod def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=None, yT=None, PiT=None, estimateTarget=0, MinClassificationError=0): + """Hybrid point-process filter with CIF-object evaluation. 
+ + Unlike :meth:`PPHybridFilterLinear` which takes pre-extracted linear + parameters (mu, beta, gamma), this method evaluates CIF objects + directly via their ``evalLambdaDelta`` / ``evalGradient*`` / + ``evalJacobian*`` methods. This supports nonlinear conditional + intensity specifications. + + Falls back to the linear path when the target-estimation branch is + active (``yT`` / ``PiT`` / ``estimateTarget`` supplied), matching + Matlab behaviour. + """ + del yT, PiT, estimateTarget # reserved for future target-estimation branch obs = _as_observation_matrix(dN) + lambda_items = _normalize_cif_collection(lambdaCIFColl) A_models = list(A) if isinstance(A, Sequence) and not isinstance(A, np.ndarray) else [A] - num_states = _infer_state_dim(A_models[0], np.array([0.0]), obs.shape[0]) - mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambdaCIFColl, num_states, obs.shape[0]) - return DecodingAlgorithms.PPHybridFilterLinear( - A, - Q, - p_ij, - Mu0, - obs, - mu, - beta, - fitType, - binwidth, - gamma, - windowTimes, - x0, - Pi0, - yT, - PiT, - estimateTarget, - MinClassificationError, - ) - - # ------------------------------------------------------------------ - # Unscented Kalman Filter (UKF) - # Ported from Matlab DecodingAlgorithms.m - # ------------------------------------------------------------------ + Q_models = list(Q) if isinstance(Q, Sequence) and not isinstance(Q, np.ndarray) else [Q] + n_models = len(A_models) + if len(Q_models) != n_models: + raise ValueError("A and Q must define the same number of hybrid models") - @staticmethod - def ukf_sigmas(x: np.ndarray, P: np.ndarray, c: float) -> np.ndarray: - """Generate sigma points around reference point *x*. 
+ num_cells, num_steps = obs.shape + if len(lambda_items) != num_cells: + raise ValueError("Number of CIF objects must match the number of observed cells") + state_dims = [_infer_state_dim(A_models[index], np.array([0.0]), num_cells) for index in range(n_models)] + max_dim = max(state_dims) - Parameters - ---------- - x : (L,) state vector - P : (L, L) covariance - c : scaling coefficient + x0_models_raw = _normalize_model_sequence(x0, n_models, lambda index: np.zeros(state_dims[index], dtype=float)) + Pi0_models_raw = _normalize_model_sequence(Pi0, n_models, lambda index: np.zeros((state_dims[index], state_dims[index]), dtype=float)) + x0_models = [np.asarray(x0_models_raw[index], dtype=float).reshape(-1) for index in range(n_models)] + Pi0_models = [_as_state_matrix(Pi0_models_raw[index], state_dims[index]) for index in range(n_models)] - Returns - ------- - X : (L, 2L+1) sigma-point matrix - """ - x = np.asarray(x, dtype=float).reshape(-1) - P = np.asarray(P, dtype=float) - A = c * np.linalg.cholesky(P) # (L, L) - L = len(x) - Y = np.tile(x[:, None], (1, L)) - X = np.column_stack([x[:, None], Y + A, Y - A]) - return X + transition = np.asarray(p_ij, dtype=float) + if transition.shape != (n_models, n_models): + raise ValueError("p_ij must be an nModels x nModels transition matrix") + row_sums = np.sum(transition, axis=1) + if not np.allclose(row_sums, np.ones(n_models), atol=1e-8): + raise ValueError("State Transition probability matrix must sum to 1 along each row") - @staticmethod - def ukf_ut(f, X: np.ndarray, Wm: np.ndarray, Wc: np.ndarray, - n: int, R: np.ndarray): + if _is_empty_value(Mu0): + model_probs0 = np.full(n_models, 1.0 / float(n_models), dtype=float) + else: + model_probs0 = _normalize_probabilities(Mu0) + if model_probs0.size != n_models: + raise ValueError("Mu0 must contain one probability per hybrid model") + + X = np.zeros((max_dim, num_steps), dtype=float) + W = np.zeros((max_dim, max_dim, num_steps), dtype=float) + X_s = [np.zeros((max_dim, 
num_steps), dtype=float) for _ in range(n_models)] + W_s = [np.zeros((max_dim, max_dim, num_steps), dtype=float) for _ in range(n_models)] + X_u = [np.zeros((state_dims[index], num_steps), dtype=float) for index in range(n_models)] + W_u = [np.zeros((state_dims[index], state_dims[index], num_steps), dtype=float) for index in range(n_models)] + X_p = [np.zeros((state_dims[index], num_steps), dtype=float) for index in range(n_models)] + W_p = [np.zeros((state_dims[index], state_dims[index], num_steps), dtype=float) for index in range(n_models)] + MU_u = np.zeros((n_models, num_steps), dtype=float) + pNGivenS = np.zeros((n_models, num_steps), dtype=float) + S_est = np.zeros(num_steps, dtype=int) + + for time_index in range(num_steps): + if time_index == 0: + MU_p = transition.T @ model_probs0 + prev_probs = model_probs0 + else: + MU_p = transition.T @ MU_u[:, time_index - 1] + prev_probs = MU_u[:, time_index - 1] + + p_ij_s = transition * prev_probs[:, None] + column_norm = np.sum(p_ij_s, axis=0, keepdims=True) + column_norm[column_norm == 0.0] = 1.0 + p_ij_s = p_ij_s / column_norm + + for target_model in range(n_models): + mixed_state = np.zeros(max_dim, dtype=float) + for source_model in range(n_models): + dim_i = state_dims[source_model] + source_state = x0_models[source_model] if time_index == 0 else X_u[source_model][:, time_index - 1] + mixed_state[:dim_i] += source_state * p_ij_s[source_model, target_model] + X_s[target_model][:, time_index] = mixed_state + + mixed_cov = np.zeros((max_dim, max_dim), dtype=float) + for source_model in range(n_models): + dim_i = state_dims[source_model] + source_state = x0_models[source_model] if time_index == 0 else X_u[source_model][:, time_index - 1] + source_cov = Pi0_models[source_model] if time_index == 0 else W_u[source_model][:, :, time_index - 1] + diff = source_state - mixed_state[:dim_i] + mixed_cov[:dim_i, :dim_i] += ( + source_cov + np.outer(diff, diff) + ) * p_ij_s[source_model, target_model] + W_s[target_model][:, 
:, time_index] = _symmetrize(mixed_cov) + + likelihoods = np.zeros(n_models, dtype=float) + for model_index in range(n_models): + dim = state_dims[model_index] + A_t = _select_time_matrix(A_models[model_index], time_index, dim) + Q_t = _select_time_matrix(Q_models[model_index], time_index, dim) + pred_x, pred_W = DecodingAlgorithms.PPDecode_predict( + X_s[model_index][:dim, time_index], + W_s[model_index][:dim, :dim, time_index], + A_t, + Q_t, + ) + # Use CIF-based (nonlinear) update instead of linear + upd_x, upd_W, lambda_delta = DecodingAlgorithms.PPDecode_update( + pred_x, + pred_W, + obs, + lambda_items, + binwidth, + time_index + 1, + None, + ) + X_p[model_index][:, time_index] = pred_x + W_p[model_index][:, :, time_index] = pred_W + X_u[model_index][:, time_index] = upd_x + W_u[model_index][:, :, time_index] = upd_W + + det_ratio = np.sqrt(max(np.linalg.det(upd_W), 0.0)) / max(np.sqrt(max(np.linalg.det(pred_W), 0.0)), 1e-15) + log_term = np.sum(obs[:, time_index] * np.log(np.clip(lambda_delta.reshape(-1), 1e-12, np.inf)) - lambda_delta.reshape(-1)) + likelihoods[model_index] = float(det_ratio * np.exp(np.clip(log_term, -200.0, 50.0))) + + finite_likelihoods = likelihoods.copy() + finite_likelihoods[~np.isfinite(finite_likelihoods)] = 0.0 + pNGivenS[:, time_index] = finite_likelihoods + norm = np.sum(pNGivenS[:, time_index]) + if norm != 0.0 and np.isfinite(norm): + pNGivenS[:, time_index] /= norm + elif time_index > 0: + pNGivenS[:, time_index] = pNGivenS[:, time_index - 1] + else: + pNGivenS[:, time_index] = np.full(n_models, 0.5 if n_models == 2 else 1.0 / float(n_models), dtype=float) + + posterior = MU_p * pNGivenS[:, time_index] + posterior_norm = np.sum(posterior) + if posterior_norm != 0.0 and np.isfinite(posterior_norm): + MU_u[:, time_index] = posterior / posterior_norm + elif time_index > 0: + MU_u[:, time_index] = MU_u[:, time_index - 1] + else: + MU_u[:, time_index] = model_probs0 + + best_model = int(np.argmax(MU_u[:, time_index])) + 
S_est[time_index] = best_model + 1 + + if MinClassificationError: + chosen = best_model + dim = state_dims[chosen] + X[:dim, time_index] = X_u[chosen][:, time_index] + W[:dim, :dim, time_index] = W_u[chosen][:, :, time_index] + continue + + mixed_global_state = np.zeros(max_dim, dtype=float) + for model_index in range(n_models): + dim = state_dims[model_index] + mixed_global_state[:dim] += MU_u[model_index, time_index] * X_u[model_index][:, time_index] + X[:, time_index] = mixed_global_state + + mixed_global_cov = np.zeros((max_dim, max_dim), dtype=float) + for model_index in range(n_models): + dim = state_dims[model_index] + diff = X_u[model_index][:, time_index] - mixed_global_state[:dim] + mixed_global_cov[:dim, :dim] += MU_u[model_index, time_index] * ( + W_u[model_index][:, :, time_index] + np.outer(diff, diff) + ) + W[:, :, time_index] = _symmetrize(mixed_global_cov) + + return S_est, X, W, MU_u, X_s, W_s, pNGivenS + + # ------------------------------------------------------------------ + # Unscented Kalman Filter (UKF) + # Ported from Matlab DecodingAlgorithms.m + # ------------------------------------------------------------------ + + @staticmethod + def ukf_sigmas(x: np.ndarray, P: np.ndarray, c: float) -> np.ndarray: + """Generate sigma points around reference point *x*. + + Parameters + ---------- + x : (L,) state vector + P : (L, L) covariance + c : scaling coefficient + + Returns + ------- + X : (L, 2L+1) sigma-point matrix + """ + x = np.asarray(x, dtype=float).reshape(-1) + P = np.asarray(P, dtype=float) + A = c * np.linalg.cholesky(P) # (L, L) + L = len(x) + Y = np.tile(x[:, None], (1, L)) + X = np.column_stack([x[:, None], Y + A, Y - A]) + return X + + @staticmethod + def ukf_ut(f, X: np.ndarray, Wm: np.ndarray, Wc: np.ndarray, + n: int, R: np.ndarray): """Unscented transformation. 
Parameters @@ -1512,8 +2603,9 @@ def PPSS_MStep(dN, HkAll, fitType, x_K, W_K, gamma, delta, sumXkTerms, windowTim converged = True break - # Clamp gamma - gamma_new = np.clip(gamma_new, -1e2, 1e2) + # Clamp gamma — Matlab: gamma_new(gamma_new>1e2)=1e1 + # Only reduce excessively large positive values to 10 + gamma_new[gamma_new > 1e2] = 1e1 return Qhat, gamma_new @@ -1718,16 +2810,17 @@ def PPSS_EMFB(A, Q0, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, neur ) if not negLL: - # Backward EM - _, _, _, QnewR, gnewR, _, _, _, _, negLLR = DecodingAlgorithms.PPSS_EM( + # Backward EM (reversed trial order) + xKR, _, _, QnewR, gnewR, _, _, _, _, negLLR = DecodingAlgorithms.PPSS_EM( A, Qnew, xK[:, -1], np.flipud(dN), fitType, delta, gnew, windowTimes, numBasis, HkAllR ) if not negLLR: # Forward EM again with backward-updated parameters # Matlab: PPSS_EM(A, QhatR(:,cnt+1), xKR(:,end), dN, ...) + # Use backward EM's final state as initial state for forward pass xK2, WK2, Wku2, Qnew2, gnew2, ll2, _, _, _, negLL2 = DecodingAlgorithms.PPSS_EM( - A, QnewR, xK[:, -1], dN, fitType, delta, gnewR, + A, QnewR, xKR[:, -1], dN, fitType, delta, gnewR, windowTimes, numBasis, HkAll ) @@ -2206,6 +3299,4322 @@ def prepareEMResults(fitType, neuronNumber, dN, HkAll, xK, WK, Q, gamma, return fitResults + # ------------------------------------------------------------------ + # Kalman Filter EM (KF_EM) family + # Ported from Matlab DecodingAlgorithms.m lines 3295-4586 + # ------------------------------------------------------------------ + + @staticmethod + def KF_EMCreateConstraints( + EstimateA=1, + AhatDiag=0, + QhatDiag=1, + QhatIsotropic=0, + RhatDiag=1, + RhatIsotropic=0, + Estimatex0=1, + EstimatePx0=1, + Px0Isotropic=0, + mcIter=1000, + EnableIkeda=0, + ): + """Return a dict of EM constraint flags for :meth:`KF_EM`. + + Parameters + ---------- + EstimateA : int + Whether to estimate the state transition matrix *A*. + AhatDiag : int + Constrain *A* to be diagonal. 
+ QhatDiag : int + Constrain *Q* to be diagonal. + QhatIsotropic : int + Constrain *Q* to be isotropic (scalar times identity). + Only active when *QhatDiag* is also true. + RhatDiag : int + Constrain *R* to be diagonal. + RhatIsotropic : int + Constrain *R* to be isotropic. Only active when *RhatDiag* + is also true. + Estimatex0 : int + Whether to estimate the initial state *x0*. + EstimatePx0 : int + Whether to estimate the initial covariance *Px0*. + Px0Isotropic : int + Constrain *Px0* to be isotropic. Only active when + *EstimatePx0* is true. + mcIter : int + Number of Monte Carlo iterations for standard-error + estimation via the observed information matrix. + EnableIkeda : int + Enable Ikeda acceleration in the EM loop. + + Returns + ------- + dict + Constraint dictionary consumed by :meth:`KF_EM`, + :meth:`KF_MStep`, and :meth:`KF_ComputeParamStandardErrors`. + """ + C = {} + C["EstimateA"] = int(EstimateA) + C["AhatDiag"] = int(AhatDiag) + C["QhatDiag"] = int(QhatDiag) + # QhatIsotropic only valid if QhatDiag is true + C["QhatIsotropic"] = 1 if (QhatDiag and QhatIsotropic) else 0 + C["RhatDiag"] = int(RhatDiag) + # RhatIsotropic only valid if RhatDiag is true + C["RhatIsotropic"] = 1 if (RhatDiag and RhatIsotropic) else 0 + C["Estimatex0"] = int(Estimatex0) + C["EstimatePx0"] = int(EstimatePx0) + # Px0Isotropic only valid if EstimatePx0 is true + C["Px0Isotropic"] = 1 if (EstimatePx0 and Px0Isotropic) else 0 + C["mcIter"] = int(mcIter) + C["EnableIkeda"] = int(EnableIkeda) + return C + + # ---- internal Kalman filter matching Matlab (A, C, Q, R, Px0, x0, y) ---- + + @staticmethod + def _kf_filter_stateMajor(A, C, Q, R, Px0, x0, y): + """Run a Kalman filter with Matlab-compatible state-major layout. 
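+
+        Each iteration applies the measurement update
+        ``x_u = x_p + G (y_n - C x_p)`` with gain
+        ``G = P_p C' (C P_p C' + R)^+`` (``^+`` denoting the
+        pseudo-inverse), followed by the one-step prediction
+        ``x_p = A x_u`` and ``P_p = A P_u A' + Q``.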
+ + Parameters + ---------- + A : (Dx, Dx) state transition + C : (Dy, Dx) observation matrix + Q : (Dx, Dx) process noise covariance + R : (Dy, Dy) observation noise covariance + Px0 : (Dx, Dx) initial state covariance + x0 : (Dx,) or (Dx, 1) initial state + y : (Dy, K) observations (state-major, each column is one time step) + + Returns + ------- + x_p : (Dx, K+1) predicted states (x_p[:, 0] == x0) + Pe_p : (Dx, Dx, K+1) predicted covariances + x_u : (Dx, K) updated states + Pe_u : (Dx, Dx, K) updated covariances + """ + A = np.asarray(A, dtype=float) + C = np.asarray(C, dtype=float) + Q = np.asarray(Q, dtype=float) + R = np.asarray(R, dtype=float) + Px0 = np.asarray(Px0, dtype=float) + x0 = np.asarray(x0, dtype=float).reshape(-1) + y = np.asarray(y, dtype=float) + + Dx = A.shape[0] + K = y.shape[1] + + x_p = np.zeros((Dx, K + 1), dtype=float) + Pe_p = np.zeros((Dx, Dx, K + 1), dtype=float) + x_u = np.zeros((Dx, K), dtype=float) + Pe_u = np.zeros((Dx, Dx, K), dtype=float) + + x_p[:, 0] = x0 + Pe_p[:, :, 0] = Px0 + + for n in range(K): + # Update + S = C @ Pe_p[:, :, n] @ C.T + R + Gn = Pe_p[:, :, n] @ C.T @ np.linalg.pinv(S) + x_u[:, n] = x_p[:, n] + Gn @ (y[:, n] - C @ x_p[:, n]) + Pe_u[:, :, n] = Pe_p[:, :, n] - Gn @ C @ Pe_p[:, :, n] + # Predict + x_p[:, n + 1] = A @ x_u[:, n] + Pe_p[:, :, n + 1] = A @ Pe_u[:, :, n] @ A.T + Q + + return x_p, Pe_p, x_u, Pe_u + + @staticmethod + def _kf_smootherFromFiltered_stateMajor(A, x_p, Pe_p, x_u, Pe_u): + """RTS smoother with Matlab-compatible state-major layout. 
+ + Parameters + ---------- + A : (Dx, Dx) transition matrix + x_p : (Dx, K+1) predicted states + Pe_p : (Dx, Dx, K+1) predicted covariances + x_u : (Dx, K) updated states + Pe_u : (Dx, Dx, K) updated covariances + + Returns + ------- + x_K : (Dx, K) smoothed states + W_K : (Dx, Dx, K) smoothed covariances + Lk : (Dx, Dx, K-1) smoother gains + """ + K = x_u.shape[1] + Dx = x_u.shape[0] + x_K = np.copy(x_u) + W_K = np.copy(Pe_u) + Lk = np.zeros((Dx, Dx, max(K - 1, 0)), dtype=float) + + for t in range(K - 2, -1, -1): + gain = Pe_u[:, :, t] @ A.T @ np.linalg.pinv(Pe_p[:, :, t + 1]) + Lk[:, :, t] = gain + x_K[:, t] = x_u[:, t] + gain @ (x_K[:, t + 1] - x_p[:, t + 1]) + W_K[:, :, t] = _symmetrize( + Pe_u[:, :, t] + gain @ (W_K[:, :, t + 1] - Pe_p[:, :, t + 1]) @ gain.T + ) + + return x_K, W_K, Lk + + @staticmethod + def KF_EStep(A, Q, C, R, y, alpha, x0, Px0): + """E-step for the Kalman Filter EM algorithm. + + Runs the forward Kalman filter followed by the backward RTS smoother + and computes sufficient statistics (expectation sums) for the M-step. 
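+
+        A minimal one-iteration usage sketch (hedged; assumes the model
+        matrices and a ``(Dy, K)`` observation array ``y`` are already
+        defined by the caller)::
+
+            x_K, W_K, logll, sums = DecodingAlgorithms.KF_EStep(
+                A, Q, C, R, y, alpha, x0, Px0)
+            Ahat, Qhat, Chat, Rhat, alphahat, x0hat, Px0hat = (
+                DecodingAlgorithms.KF_MStep(y, x_K, x0, Px0, sums))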
+ + Parameters + ---------- + A : (Dx, Dx) state transition matrix + Q : (Dx, Dx) process noise covariance + C : (Dy, Dx) observation matrix + R : (Dy, Dy) observation noise covariance + y : (Dy, K) observations (state-major) + alpha : (Dy, 1) or (Dy,) observation offset + x0 : (Dx,) initial state + Px0 : (Dx, Dx) initial state covariance + + Returns + ------- + x_K : (Dx, K) smoothed states + W_K : (Dx, Dx, K) smoothed covariances + logll : float — complete-data log-likelihood + ExpectationSums : dict of sufficient statistics + """ + A = np.asarray(A, dtype=float) + Q = np.asarray(Q, dtype=float) + C = np.asarray(C, dtype=float) + R = np.asarray(R, dtype=float) + y = np.asarray(y, dtype=float) + alpha = np.asarray(alpha, dtype=float).reshape(-1, 1) + x0 = np.asarray(x0, dtype=float).reshape(-1) + Px0 = np.asarray(Px0, dtype=float) + + Dx = A.shape[1] + Dy = C.shape[0] + K = y.shape[1] + + # Forward filter with offset subtracted: y - alpha*ones(1,K) + y_centered = y - alpha @ np.ones((1, K)) + x_p, Pe_p, x_u, Pe_u = DecodingAlgorithms._kf_filter_stateMajor( + A, C, Q, R, Px0, x0, y_centered + ) + + # Backward RTS smoother + x_K, W_K, Lk = DecodingAlgorithms._kf_smootherFromFiltered_stateMajor( + A, x_p, Pe_p, x_u, Pe_u + ) + + # Best estimates of initial states given the data + # Matlab: W1G0 = A*Px0*A' + Q + W1G0 = A @ Px0 @ A.T + Q + L0 = Px0 @ A.T @ np.linalg.pinv(W1G0) + + # Ex0Gy = x0 + L0*(x_K(:,1) - x_p(:,1)) + Ex0Gy = x0 + L0 @ (x_K[:, 0] - x_p[:, 0]) + # Px0Gy = Px0 + L0*(inv(W_K(:,:,1)) - inv(W1G0))*L0' + Px0Gy = Px0 + L0 @ ( + np.linalg.pinv(W_K[:, :, 0]) - np.linalg.pinv(W1G0) + ) @ L0.T + Px0Gy = _symmetrize(Px0Gy) + + # Cross-covariance terms Wku(:,:,k,u) from de Jong and MacKinnon 1988 + # Only compute the elements actually needed for the sums: + # Wku(:,:,k,k) = W_K(:,:,k) and off-diagonal lags (k, k+1) + # Matlab: Dk(:,:,k) = W_u(:,:,k)*A'/(W_p(:,:,k+1)) + # Wku(:,:,k,u) = Dk(:,:,k)*Wku(:,:,k+1,u) + # We only need Wku(:,:,k-1,k) for the expectation 
sums.
+        Wku_lag1 = np.zeros((Dx, Dx, K), dtype=float)  # Wku_lag1[:,:,k] = Wku(:,:,k-1,k)
+        for k in range(K - 1, 0, -1):
+            # Dk = Pe_u[:,:,k-1] * A' / Pe_p[:,:,k]
+            Dk = Pe_u[:, :, k - 1] @ A.T @ np.linalg.pinv(Pe_p[:, :, k])
+            # Wku(:,:,k-1,k) = Dk * Wku(:,:,k,k) = Dk * W_K(:,:,k)
+            Wku_lag1[:, :, k] = Dk @ W_K[:, :, k]
+
+        # Index mapping for the k == 0 term below (Matlab is 1-indexed):
+        # the Matlab filter stores the pre-update prediction W_p(:,:,1) = Px0,
+        # which is Pe_p[:, :, 0] here, and Wku(:,:,1,1) = W_K(:,:,1) in
+        # Matlab is W_K[:, :, 0] here.
+
+        # Sufficient statistics (expectation sums)
+        Sxkm1xk = np.zeros((Dx, Dx), dtype=float)
+        Sxkm1xkm1 = np.zeros((Dx, Dx), dtype=float)
+        Sxkxk = np.zeros((Dx, Dx), dtype=float)
+        Sykyk = np.zeros((Dy, Dy), dtype=float)
+        Sxkyk = np.zeros((Dx, Dy), dtype=float)
+
+        for k in range(K):
+            if k == 0:
+                # Matlab: Sxkm1xk = Sxkm1xk + Px0*A'/W_p(:,:,1)*Wku(:,:,1,1)
+                Sxkm1xk += Px0 @ A.T @ np.linalg.pinv(Pe_p[:, :, 0]) @ W_K[:, :, 0]
+                Sxkm1xkm1 += Px0 + np.outer(x0, x0)
+            else:
+                # Wku(:,:,k-1,k) is Wku_lag1[:,:,k]
+                Sxkm1xk += Wku_lag1[:, :, k] + np.outer(x_K[:, k - 1], x_K[:, k])
+                Sxkm1xkm1 += W_K[:, :, k - 1] + np.outer(x_K[:, k - 1], x_K[:, k - 1])
+            Sxkxk += W_K[:, :, k] + np.outer(x_K[:, k], x_K[:, k])
+            Sykyk += np.outer(y[:, k] - alpha.ravel(), y[:, k] - alpha.ravel())
+            Sxkyk += 
np.outer(x_K[:, k], y[:, k] - alpha.ravel()) + + Sxkxk = _symmetrize(Sxkxk) + Sykyk = _symmetrize(Sykyk) + + sumXkTerms = Sxkxk - A @ Sxkm1xk - Sxkm1xk.T @ A.T + A @ Sxkm1xkm1 @ A.T + sumYkTerms = Sykyk - C @ Sxkyk - Sxkyk.T @ C.T + C @ Sxkxk @ C.T + Sxkxkm1 = Sxkm1xk.T + + sumXkTerms = _symmetrize(sumXkTerms) + sumYkTerms = _symmetrize(sumYkTerms) + + # Complete-data log-likelihood + # Matlab: logll = -Dx*K/2*log(2*pi) - K/2*log(det(Q)) + # - Dy*K/2*log(2*pi) - K/2*log(det(R)) + # - Dx/2*log(2*pi) - 1/2*log(det(Px0)) + # - 1/2*trace(inv(Q)*sumXkTerms) - 1/2*trace(inv(R)*sumYkTerms) + # - Dx/2 + sign_Q, logdet_Q = np.linalg.slogdet(Q) + sign_R, logdet_R = np.linalg.slogdet(R) + sign_P, logdet_P = np.linalg.slogdet(Px0) + logll = ( + -Dx * K / 2.0 * np.log(2.0 * np.pi) + - K / 2.0 * logdet_Q + - Dy * K / 2.0 * np.log(2.0 * np.pi) + - K / 2.0 * logdet_R + - Dx / 2.0 * np.log(2.0 * np.pi) + - 0.5 * logdet_P + - 0.5 * np.trace(np.linalg.solve(Q, sumXkTerms)) + - 0.5 * np.trace(np.linalg.solve(R, sumYkTerms)) + - Dx / 2.0 + ) + logll = float(logll) + print(f"logll: {logll}") + + ExpectationSums = { + "Sxkm1xkm1": Sxkm1xkm1, + "Sxkm1xk": Sxkm1xk, + "Sxkxkm1": Sxkxkm1, + "Sxkxk": Sxkxk, + "Sxkyk": Sxkyk, + "Sykyk": Sykyk, + "sumXkTerms": sumXkTerms, + "sumYkTerms": sumYkTerms, + "Sx0": Ex0Gy, + "Sx0x0": Px0Gy + np.outer(Ex0Gy, Ex0Gy), + } + + return x_K, W_K, logll, ExpectationSums + + @staticmethod + def KF_MStep(y, x_K, x0, Px0, ExpectationSums, KFEM_Constraints=None): + """M-step for the Kalman Filter EM algorithm. + + Updates all state-space model parameters given the sufficient + statistics from :meth:`KF_EStep`. 
+ + Parameters + ---------- + y : (Dy, K) observations + x_K : (Dx, K) smoothed states from E-step + x0 : (Dx,) current initial state estimate + Px0 : (Dx, Dx) current initial covariance estimate + ExpectationSums : dict from :meth:`KF_EStep` + KFEM_Constraints : dict from :meth:`KF_EMCreateConstraints`, or *None* + + Returns + ------- + Ahat, Qhat, Chat, Rhat, alphahat, x0hat, Px0hat + """ + if KFEM_Constraints is None: + KFEM_Constraints = DecodingAlgorithms.KF_EMCreateConstraints() + + Sxkm1xkm1 = ExpectationSums["Sxkm1xkm1"] + Sxkxkm1 = ExpectationSums["Sxkxkm1"] + Sxkxk = ExpectationSums["Sxkxk"] + Sxkyk = ExpectationSums["Sxkyk"] + sumXkTerms = ExpectationSums["sumXkTerms"] + sumYkTerms = ExpectationSums["sumYkTerms"] + Sx0 = ExpectationSums["Sx0"] + Sx0x0 = ExpectationSums["Sx0x0"] + + y = np.asarray(y, dtype=float) + x_K = np.asarray(x_K, dtype=float) + x0 = np.asarray(x0, dtype=float).reshape(-1) + Px0 = np.asarray(Px0, dtype=float) + + N, K = x_K.shape # N = Dx (num states), K = num time steps + + # Ahat + if KFEM_Constraints["AhatDiag"]: + I_N = np.eye(N) + Ahat = (Sxkxkm1 * I_N) @ np.linalg.pinv(Sxkm1xkm1 * I_N) + else: + Ahat = Sxkxkm1 @ np.linalg.pinv(Sxkm1xkm1) + + # Chat = Sxkyk' / Sxkxk (Matlab: Chat = Sxkyk'/Sxkxk) + Chat = Sxkyk.T @ np.linalg.pinv(Sxkxk) + + # alphahat = sum(y - Chat*x_K, 2) / K + alphahat = np.sum(y - Chat @ x_K, axis=1, keepdims=True) / K + + # Qhat + if KFEM_Constraints["QhatDiag"]: + if KFEM_Constraints["QhatIsotropic"]: + Qhat = (1.0 / (N * K)) * np.trace(sumXkTerms) * np.eye(N) + else: + I_N = np.eye(N) + Qhat = (1.0 / K) * (sumXkTerms * I_N) + Qhat = _symmetrize(Qhat) + else: + Qhat = (1.0 / K) * sumXkTerms + Qhat = _symmetrize(Qhat) + + # Rhat + dy = sumYkTerms.shape[0] + if KFEM_Constraints["RhatDiag"]: + if KFEM_Constraints["RhatIsotropic"]: + I_dy = np.eye(dy) + Rhat = (1.0 / (dy * K)) * np.trace(sumYkTerms) * I_dy + else: + I_dy = np.eye(dy) + Rhat = (1.0 / K) * (sumYkTerms * I_dy) + Rhat = _symmetrize(Rhat) + else: + 
Rhat = (1.0 / K) * sumYkTerms + Rhat = _symmetrize(Rhat) + + # x0hat — uses the newly computed Ahat and Qhat + if KFEM_Constraints["Estimatex0"]: + # Matlab: x0hat = (inv(Px0)+Ahat'/Qhat*Ahat)\(Ahat'/Qhat*x_K(:,1)+Px0\x0) + Px0_inv = np.linalg.pinv(Px0) + AQ = np.linalg.solve(Qhat, Ahat) # Qhat\Ahat + lhs = Px0_inv + Ahat.T @ AQ + rhs = Ahat.T @ np.linalg.solve(Qhat, x_K[:, 0]) + np.linalg.solve(Px0, x0) + x0hat = np.linalg.solve(lhs, rhs) + else: + x0hat = x0.copy() + + # Px0hat + if KFEM_Constraints["EstimatePx0"]: + if KFEM_Constraints["Px0Isotropic"]: + diff = x0hat - x0 + Px0hat = (np.trace(np.outer(diff, diff)) / (N * K)) * np.eye(N) + else: + I_N = np.eye(N) + diff = x0hat - x0 + Px0hat = ( + np.outer(x0hat, x0hat) + - np.outer(x0, x0hat) + - np.outer(x0hat, x0) + + np.outer(x0, x0) + ) * I_N + Px0hat = _symmetrize(Px0hat) + eigvals, eigvecs = np.linalg.eigh(Px0hat) + if np.min(eigvals) < np.finfo(float).eps: + eigvals[eigvals == np.min(eigvals)] = np.finfo(float).eps + Px0hat = eigvecs @ np.diag(eigvals) @ eigvecs.T + else: + Px0hat = Px0.copy() + + return Ahat, Qhat, Chat, Rhat, alphahat, x0hat, Px0hat + + @staticmethod + def KF_ComputeParamStandardErrors( + y, xKFinal, WKFinal, Ahat, Qhat, Chat, Rhat, alphahat, + x0hat, Px0hat, ExpectationSumsFinal, KFEM_Constraints=None, + ): + """Compute standard errors via the observed information matrix. + + Uses the complete information matrix and a Monte Carlo estimate of + the missing information matrix, following McLachlan and Krishnan + Eq. 4.7: ``Io(theta; y) = Ic(theta; y) - Im(theta; y)``. 
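+
+        A hedged usage sketch (argument names as in this signature; assumes
+        the estimates come from a completed :meth:`KF_EM` run)::
+
+            SE, Pvals = DecodingAlgorithms.KF_ComputeParamStandardErrors(
+                y, xK, WK, Ahat, Qhat, Chat, Rhat, alphahat,
+                x0hat, Px0hat, sums)
+            significant_A = Pvals["A"] < 0.05  # elementwise z-test decisions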
+ + Parameters + ---------- + y : (Dy, K) observations + xKFinal : (Dx, K) smoothed states + WKFinal : (Dx, Dx, K) smoothed covariances + Ahat, Qhat, Chat, Rhat : estimated model matrices + alphahat : (Dy, 1) observation offset + x0hat : (Dx,) initial state + Px0hat : (Dx, Dx) initial covariance + ExpectationSumsFinal : dict from :meth:`KF_EStep` + KFEM_Constraints : dict from :meth:`KF_EMCreateConstraints` + + Returns + ------- + SE : dict of standard-error matrices/vectors for each parameter + Pvals : dict of p-value matrices/vectors for each parameter + """ + if KFEM_Constraints is None: + KFEM_Constraints = DecodingAlgorithms.KF_EMCreateConstraints() + + Ahat = np.asarray(Ahat, dtype=float) + Qhat = np.asarray(Qhat, dtype=float) + Chat = np.asarray(Chat, dtype=float) + Rhat = np.asarray(Rhat, dtype=float) + alphahat = np.asarray(alphahat, dtype=float).reshape(-1, 1) + x0hat = np.asarray(x0hat, dtype=float).reshape(-1) + Px0hat = np.asarray(Px0hat, dtype=float) + y = np.asarray(y, dtype=float) + xKFinal = np.asarray(xKFinal, dtype=float) + WKFinal = np.asarray(WKFinal, dtype=float) + + dy, N = y.shape + dx = xKFinal.shape[0] + K = N + + # ---------------------------------------------------------------- + # Complete Information Matrices + # ---------------------------------------------------------------- + + # --- IAComp: information for A --- + n1_A, n2_A = Ahat.shape + el_A = np.eye(n1_A) + em_A = np.eye(n2_A) + if KFEM_Constraints["AhatDiag"]: + nA = n1_A + IAComp = np.zeros((nA, nA), dtype=float) + cnt = 0 + for l in range(n1_A): + m = l # diagonal only + # termMat = inv(Q) * el(:,l)*em(:,m)' * Sxkm1xkm1 .* I + termMat = np.linalg.solve(Qhat, np.outer(el_A[:, l], em_A[:, m])) @ ( + ExpectationSumsFinal["Sxkm1xkm1"] * np.eye(n1_A, n2_A) + ) + IAComp[:, cnt] = np.diag(termMat) + cnt += 1 + else: + nA = Ahat.size + IAComp = np.zeros((nA, nA), dtype=float) + cnt = 0 + Qinv = np.linalg.inv(Qhat) + for l in range(n1_A): + for m in range(n2_A): + termMat = Qinv @ 
np.outer(el_A[:, l], em_A[:, m]) @ ExpectationSumsFinal["Sxkm1xkm1"] + termvec = termMat.T.ravel() + IAComp[:, cnt] = termvec + cnt += 1 + + # --- ICComp: information for C --- + n1_C, n2_C = Chat.shape + el_C = np.eye(n1_C) + em_C = np.eye(n2_C) + nC = Chat.size + ICComp = np.zeros((nC, nC), dtype=float) + cnt = 0 + Rinv = np.linalg.inv(Rhat) + for l in range(n1_C): + for m in range(n2_C): + termMat = Rinv @ np.outer(el_C[:, l], em_C[:, m]) @ ExpectationSumsFinal["Sxkxk"] + termvec = termMat.T.ravel() + ICComp[:, cnt] = termvec + cnt += 1 + + # --- IRComp: information for R --- + n1_R, n2_R = Rhat.shape + el_R = np.eye(n1_R) + em_R = np.eye(n2_R) + if KFEM_Constraints["RhatDiag"]: + if KFEM_Constraints["RhatIsotropic"]: + IRComp = np.array([[0.5 * N * dy * Rhat[0, 0] ** (-2)]]) + nR = 1 + else: + nR = n1_R + IRComp = np.zeros((nR, nR), dtype=float) + cnt = 0 + for l in range(n1_R): + m = l + termMat = (N / 2.0) * np.linalg.solve(Rhat, np.outer(em_R[:, m], el_R[:, l])) @ np.linalg.inv(Rhat) + IRComp[:, cnt] = np.diag(termMat) + cnt += 1 + else: + nR = Rhat.size + IRComp = np.zeros((nR, nR), dtype=float) + cnt = 0 + for l in range(n1_R): + for m in range(n2_R): + termMat = (N / 2.0) * np.linalg.solve(Rhat, np.outer(em_R[:, m], el_R[:, l])) @ np.linalg.inv(Rhat) + termvec = termMat.T.ravel() + IRComp[:, cnt] = termvec + cnt += 1 + + # --- IQComp: information for Q --- + n1_Q, n2_Q = Qhat.shape + el_Q = np.eye(n1_Q) + em_Q = np.eye(n2_Q) + if KFEM_Constraints["QhatDiag"]: + if KFEM_Constraints["QhatIsotropic"]: + IQComp = np.array([[0.5 * N * dx * Qhat[0, 0] ** (-2)]]) + nQ = 1 + else: + nQ = n1_Q + IQComp = np.zeros((nQ, nQ), dtype=float) + cnt = 0 + for l in range(n1_Q): + m = l + termMat = (N / 2.0) * np.linalg.solve(Qhat, np.outer(em_Q[:, m], el_Q[:, l])) @ np.linalg.inv(Qhat) + IQComp[:, cnt] = np.diag(termMat) + cnt += 1 + else: + nQ = Qhat.size + IQComp = np.zeros((nQ, nQ), dtype=float) + cnt = 0 + for l in range(n1_Q): + for m in range(n2_Q): + termMat = (N / 
2.0) * np.linalg.solve(Qhat, np.outer(em_Q[:, m], el_Q[:, l])) @ np.linalg.inv(Qhat) + termvec = termMat.T.ravel() + IQComp[:, cnt] = termvec + cnt += 1 + + # --- ISComp: information for Px0 --- + if KFEM_Constraints["EstimatePx0"]: + if KFEM_Constraints["Px0Isotropic"]: + ISComp = np.array([[0.5 * dx * Px0hat[0, 0] ** (-2)]]) + nS = 1 + else: + nS = Px0hat.shape[0] + ISComp = np.zeros((nS, nS), dtype=float) + el_S = np.eye(nS) + em_S = np.eye(nS) + cnt = 0 + for l in range(nS): + m = l + termMat = 0.5 * np.linalg.solve(Px0hat, np.outer(em_S[:, m], el_S[:, l])) @ np.linalg.inv(Px0hat) + ISComp[:, cnt] = np.diag(termMat) + cnt += 1 + else: + nS = 0 + + # --- Ix0Comp: information for x0 --- + if KFEM_Constraints["Estimatex0"]: + Ix0Comp = np.linalg.inv(Px0hat) + Ahat.T @ np.linalg.solve(Qhat, Ahat) + nx0 = Ix0Comp.shape[0] + else: + nx0 = 0 + + # --- IAlphaComp --- + IAlphaComp = N * np.linalg.inv(Rhat) + nAlpha = IAlphaComp.shape[0] + + # Block sizes + # n1=A, n2=Q, n3=C, n4=R, n5=Px0, n6=x0, n7=alpha + if KFEM_Constraints["EstimateA"]: + n1 = IAComp.shape[0] + else: + n1 = 0 + n2 = IQComp.shape[0] + n3 = ICComp.shape[0] + n4 = IRComp.shape[0] + n5 = nS + n6 = nx0 + n7 = nAlpha + nTerms = n1 + n2 + n3 + n4 + n5 + n6 + n7 + + # Assemble block-diagonal complete information matrix + IComp = np.zeros((nTerms, nTerms), dtype=float) + if KFEM_Constraints["EstimateA"]: + IComp[:n1, :n1] = IAComp + off = n1 + IComp[off:off + n2, off:off + n2] = IQComp + off = n1 + n2 + IComp[off:off + n3, off:off + n3] = ICComp + off = n1 + n2 + n3 + IComp[off:off + n4, off:off + n4] = IRComp + off = n1 + n2 + n3 + n4 + if KFEM_Constraints["EstimatePx0"]: + IComp[off:off + n5, off:off + n5] = ISComp + off = n1 + n2 + n3 + n4 + n5 + if KFEM_Constraints["Estimatex0"]: + IComp[off:off + n6, off:off + n6] = Ix0Comp + off = n1 + n2 + n3 + n4 + n5 + n6 + IComp[off:off + n7, off:off + n7] = IAlphaComp + + # ---------------------------------------------------------------- + # Missing Information 
Matrix (Monte Carlo)
+        # ----------------------------------------------------------------
+        Mc = KFEM_Constraints["mcIter"]
+        xKDraw = np.zeros((dx, N, Mc), dtype=float)
+
+        for n in range(N):
+            WuTemp = WKFinal[:, :, n]
+            try:
+                # Lower Cholesky factor L with L @ L.T == WuTemp, so that
+                # chol_m @ z ~ N(0, WuTemp); equivalent to Matlab's chol(W)'*z
+                chol_m = np.linalg.cholesky(WuTemp)
+            except np.linalg.LinAlgError:
+                chol_m = np.linalg.cholesky(_nearestSPD(WuTemp))
+            z = np.random.randn(dx, Mc)
+            xKDraw[:, n, :] = xKFinal[:, n:n + 1] + chol_m @ z
+
+        if KFEM_Constraints["EstimatePx0"] or KFEM_Constraints["Estimatex0"]:
+            try:
+                chol_m = np.linalg.cholesky(Px0hat)  # lower factor, see above
+            except np.linalg.LinAlgError:
+                chol_m = np.linalg.cholesky(_nearestSPD(Px0hat))
+            z = np.random.randn(dx, Mc)
+            x0Draw = x0hat[:, None] + chol_m @ z
+        else:
+            x0Draw = np.tile(x0hat[:, None], (1, Mc))
+
+        IMc = np.zeros((nTerms, nTerms, Mc), dtype=float)
+        alpha_flat = alphahat.ravel()
+
+        for c in range(Mc):
+            x_K_c = xKDraw[:, :, c]
+            x_0_c = x0Draw[:, c]
+
+            Dx_c = x_K_c.shape[0]
+            Dy_c = y.shape[0]
+            Sxkm1xk_c = np.zeros((Dx_c, Dx_c))
+            Sxkm1xkm1_c = np.zeros((Dx_c, Dx_c))
+            Sxkxk_c = np.zeros((Dx_c, Dx_c))
+            Sykyk_c = np.zeros((Dy_c, Dy_c))
+            Sxkyk_c = np.zeros((Dx_c, Dy_c))
+
+            for k in range(K):
+                if k == 0:
+                    Sxkm1xk_c += np.outer(x_0_c, x_K_c[:, k])
+                    Sxkm1xkm1_c += np.outer(x_0_c, x_0_c)
+                else:
+                    Sxkm1xk_c += np.outer(x_K_c[:, k - 1], x_K_c[:, k])
+                    Sxkm1xkm1_c += np.outer(x_K_c[:, k - 1], x_K_c[:, k - 1])
+                Sxkxk_c += np.outer(x_K_c[:, k], x_K_c[:, k])
+                yk_centered = y[:, k] - alpha_flat
+                Sykyk_c += np.outer(yk_centered, yk_centered)
+                Sxkyk_c += np.outer(x_K_c[:, k], yk_centered)
+
+            Sxkxk_c = _symmetrize(Sxkxk_c)
+            Sykyk_c = _symmetrize(Sykyk_c)
+            sumXkTerms_c = Sxkxk_c - Ahat @ Sxkm1xk_c - Sxkm1xk_c.T @ Ahat.T + Ahat @ Sxkm1xkm1_c @ Ahat.T
+            sumYkTerms_c = Sykyk_c - Chat @ Sxkyk_c - Sxkyk_c.T @ Chat.T + Chat @ Sxkxk_c @ Chat.T
+            Sxkxkm1_c = Sxkm1xk_c.T
+            Sykxk_c = Sxkyk_c.T
+
+            sumXkTerms_c = _symmetrize(sumXkTerms_c)
+            sumYkTerms_c 
= _symmetrize(sumYkTerms_c) + + # Score for A + if KFEM_Constraints["EstimateA"]: + ScorA = np.linalg.solve(Qhat, Sxkxkm1_c - Ahat @ Sxkm1xkm1_c) + if KFEM_Constraints["AhatDiag"]: + ScoreAMc = np.diag(ScorA) + else: + ScoreAMc = ScorA.T.ravel() + else: + ScoreAMc = np.array([], dtype=float) + + # Score for C + ScorC = np.linalg.solve(Rhat, Sykxk_c - Chat @ Sxkxk_c) + ScoreCMc = ScorC.T.ravel() + + # Score for Q + Qinv_c = np.linalg.inv(Qhat) + I_Q = np.eye(Qhat.shape[0]) + if KFEM_Constraints["QhatDiag"]: + if KFEM_Constraints["QhatIsotropic"]: + ScoreQ = -0.5 * (K * Dx_c * Qhat[0, 0] ** (-1) - Qhat[0, 0] ** (-2) * np.trace(sumXkTerms_c)) + ScoreQMc = np.atleast_1d(ScoreQ) + else: + ScoreQ = -0.5 * np.linalg.solve(Qhat, K * I_Q - np.linalg.solve(Qhat, sumXkTerms_c).T) + ScoreQMc = np.diag(ScoreQ) + else: + ScoreQ = -0.5 * np.linalg.solve(Qhat, K * I_Q - np.linalg.solve(Qhat, sumXkTerms_c).T) + ScoreQMc = ScoreQ.T.ravel() + + # Score for alpha + ScoreAlphaMc = np.sum( + np.linalg.solve(Rhat, y - Chat @ x_K_c - alpha_flat[:, None] @ np.ones((1, N))), + axis=1, + ) + + # Score for R + I_R = np.eye(Rhat.shape[0]) + if KFEM_Constraints["RhatDiag"]: + if KFEM_Constraints["RhatIsotropic"]: + ScoreR = -0.5 * (K * Dy_c * Rhat[0, 0] ** (-1) - Rhat[0, 0] ** (-2) * np.trace(sumYkTerms_c)) + ScoreRMc = np.atleast_1d(ScoreR) + else: + ScoreR = -0.5 * np.linalg.solve(Rhat, K * I_R - np.linalg.solve(Rhat, sumYkTerms_c).T) + ScoreRMc = np.diag(ScoreR) + else: + ScoreR = -0.5 * np.linalg.solve(Rhat, K * I_R - np.linalg.solve(Rhat, sumYkTerms_c).T) + ScoreRMc = ScoreR.T.ravel() + + # Score for Px0 + diff_x0 = x_0_c - x0hat + if KFEM_Constraints["Px0Isotropic"]: + ScoreSMc = np.atleast_1d( + -0.5 * (Dx_c * Px0hat[0, 0] ** (-1) - Px0hat[0, 0] ** (-2) * np.trace(np.outer(diff_x0, diff_x0))) + ) + else: + ScorS = -0.5 * np.linalg.solve( + Px0hat, + np.eye(Px0hat.shape[0]) - np.linalg.solve(Px0hat, np.outer(diff_x0, diff_x0)).T, + ) + ScoreSMc = np.diag(ScorS) + + # Score for x0 + Scorx0 
= -np.linalg.solve(Px0hat, diff_x0) + Ahat.T @ np.linalg.solve(Qhat, x_K_c[:, 0] - Ahat @ x_0_c) + Scorex0Mc = Scorx0.ravel() + + # Assemble score vector + ScoreVec = ScoreAMc if KFEM_Constraints["EstimateA"] else np.array([], dtype=float) + ScoreVec = np.concatenate([ScoreVec, ScoreQMc, ScoreCMc, ScoreRMc]) + if KFEM_Constraints["EstimatePx0"]: + ScoreVec = np.concatenate([ScoreVec, ScoreSMc]) + if KFEM_Constraints["Estimatex0"]: + ScoreVec = np.concatenate([ScoreVec, Scorex0Mc]) + ScoreVec = np.concatenate([ScoreVec, ScoreAlphaMc]) + + IMc[:, :, c] = np.outer(ScoreVec, ScoreVec) + + # Observed information = Complete - Missing + IMissing = np.mean(IMc, axis=2) + IObs = IComp - IMissing + invIObs = np.linalg.pinv(IObs) + invIObs = _nearestSPD(invIObs) + + VarVec = np.diag(invIObs) + SEVec = np.sqrt(np.maximum(VarVec, 0.0)) + + # Unpack SE vector + off = 0 + SEAterms = SEVec[off:off + n1]; off += n1 + SEQterms = SEVec[off:off + n2]; off += n2 + SECterms = SEVec[off:off + n3]; off += n3 + SERterms = SEVec[off:off + n4]; off += n4 + SEPx0terms = SEVec[off:off + n5]; off += n5 + SEx0terms = SEVec[off:off + n6]; off += n6 + SEAlphaterms = SEVec[off:off + n7] + + # Reshape SEs into matrices matching parameter shapes + SE = {} + if KFEM_Constraints["EstimateA"]: + if KFEM_Constraints["AhatDiag"]: + SE["A"] = np.diag(SEAterms) + else: + SE["A"] = SEAterms.reshape(Ahat.shape[1], Ahat.shape[0]).T + SE["Q"] = np.diag(SEQterms) if KFEM_Constraints["QhatDiag"] else SEQterms.reshape(Qhat.shape[1], Qhat.shape[0]).T + SE["C"] = SECterms.reshape(Chat.shape[1], Chat.shape[0]).T + SE["R"] = np.diag(SERterms) if KFEM_Constraints["RhatDiag"] else SERterms.reshape(Rhat.shape[1], Rhat.shape[0]).T + SE["alpha"] = SEAlphaterms.reshape(alphahat.shape) + if KFEM_Constraints["EstimatePx0"]: + SE["Px0"] = np.diag(SEPx0terms) + if KFEM_Constraints["Estimatex0"]: + SE["x0"] = SEx0terms + + # Compute p-values via z-tests + Pvals = {} + if KFEM_Constraints["EstimateA"]: + if 
KFEM_Constraints["AhatDiag"]: + pA = np.diag([_ztest_pvalue(Ahat[i, i], SE["A"][i, i]) for i in range(Ahat.shape[0])]) + else: + pA_flat = [_ztest_pvalue(Ahat.ravel()[i], SE["A"].ravel()[i]) for i in range(Ahat.size)] + pA = np.array(pA_flat).reshape(Ahat.shape) + Pvals["A"] = pA + + # C p-values + pC_flat = [_ztest_pvalue(Chat.ravel()[i], SE["C"].ravel()[i]) for i in range(Chat.size)] + Pvals["C"] = np.array(pC_flat).reshape(Chat.shape) + + # R p-values + if KFEM_Constraints["RhatDiag"]: + if KFEM_Constraints["RhatIsotropic"]: + pR = np.diag([_ztest_pvalue(Rhat[0, 0], SE["R"][0, 0])]) + else: + pR = np.diag([_ztest_pvalue(Rhat[i, i], SE["R"][i, i]) for i in range(Rhat.shape[0])]) + else: + pR_flat = [_ztest_pvalue(Rhat.ravel()[i], SE["R"].ravel()[i]) for i in range(Rhat.size)] + pR = np.array(pR_flat).reshape(Rhat.shape) + Pvals["R"] = pR + + # Q p-values + if KFEM_Constraints["QhatDiag"]: + if KFEM_Constraints["QhatIsotropic"]: + pQ = np.diag([_ztest_pvalue(Qhat[0, 0], SE["Q"][0, 0])]) + else: + pQ = np.diag([_ztest_pvalue(Qhat[i, i], SE["Q"][i, i]) for i in range(Qhat.shape[0])]) + else: + pQ_flat = [_ztest_pvalue(Qhat.ravel()[i], SE["Q"].ravel()[i]) for i in range(Qhat.size)] + pQ = np.array(pQ_flat).reshape(Qhat.shape) + Pvals["Q"] = pQ + + # Px0 p-values + if KFEM_Constraints["EstimatePx0"]: + if KFEM_Constraints["Px0Isotropic"]: + pPx0 = np.diag([_ztest_pvalue(Px0hat[0, 0], SE["Px0"][0, 0])]) + else: + pPx0 = np.diag([_ztest_pvalue(Px0hat[i, i], SE["Px0"][i, i]) for i in range(Px0hat.shape[0])]) + Pvals["Px0"] = pPx0 + + # alpha p-values + alpha_flat_se = SE["alpha"].ravel() + pAlpha = np.array([_ztest_pvalue(alphahat.ravel()[i], alpha_flat_se[i]) for i in range(alphahat.size)]) + Pvals["alpha"] = pAlpha + + # x0 p-values + if KFEM_Constraints["Estimatex0"]: + pX0 = np.array([_ztest_pvalue(x0hat[i], SE["x0"][i]) for i in range(x0hat.size)]) + Pvals["x0"] = pX0 + + return SE, Pvals + + @staticmethod + def KF_EM( + y, + Ahat0, + Qhat0, + Chat0, + Rhat0, + 
alphahat0, + x0=None, + Px0=None, + KFEM_Constraints=None, + ): + """Kalman Filter EM algorithm with Cholesky-scaled system. + + Estimates the parameters of a linear-Gaussian state-space model:: + + x_{k+1} = A x_k + v_k, v ~ N(0, Q) + y_k = C x_k + alpha + w_k, w ~ N(0, R) + + using the Expectation-Maximisation algorithm (E-step: KF + RTS + smoother, M-step: closed-form updates). Optionally applies Ikeda + acceleration. + + Parameters + ---------- + y : (Dy, K) observation matrix (each column is one time step) + Ahat0 : (Dx, Dx) initial state transition + Qhat0 : (Dx, Dx) initial process noise covariance + Chat0 : (Dy, Dx) initial observation matrix + Rhat0 : (Dy, Dy) initial observation noise covariance + alphahat0 : (Dy, 1) initial observation offset + x0 : (Dx,) initial state (default zeros) + Px0 : (Dx, Dx) initial state covariance (default 1e-10 * I) + KFEM_Constraints : dict from :meth:`KF_EMCreateConstraints` + + Returns + ------- + xKFinal : (Dx, K) smoothed state estimates + WKFinal : (Dx, Dx, K) smoothed covariances + Ahat, Qhat, Chat, Rhat : estimated model matrices + alphahat : (Dy, 1) estimated observation offset + x0hat : (Dx,) estimated initial state + Px0hat : (Dx, Dx) estimated initial covariance + IC : dict of information criteria (AIC, AICc, BIC, llobs, llcomp) + SE : dict of standard errors (or empty dict if not computed) + Pvals : dict of p-values (or empty dict if not computed) + nIter : int — number of EM iterations + """ + Ahat0 = np.asarray(Ahat0, dtype=float) + Qhat0 = np.asarray(Qhat0, dtype=float) + Chat0 = np.asarray(Chat0, dtype=float) + Rhat0 = np.asarray(Rhat0, dtype=float) + alphahat0 = np.asarray(alphahat0, dtype=float).reshape(-1, 1) + y = np.asarray(y, dtype=float) + numStates = Ahat0.shape[0] + + if KFEM_Constraints is None: + KFEM_Constraints = DecodingAlgorithms.KF_EMCreateConstraints() + if Px0 is None: + Px0 = 1e-10 * np.eye(numStates) + else: + Px0 = np.asarray(Px0, dtype=float) + if x0 is None: + x0 = np.zeros(numStates) 
+ else: + x0 = np.asarray(x0, dtype=float).reshape(-1) + + tolAbs = 1e-3 + llTol = 1e-3 + maxIter = 100 + numToKeep = 10 + + # Save originals for un-scaling later + A0 = Ahat0.copy() + Q0 = Qhat0.copy() + C0 = Chat0.copy() + R0 = Rhat0.copy() + alpha0 = alphahat0.copy() + yOrig = y.copy() + + # Circular buffers (indexed by storeInd) + Ahat_buf = [None] * numToKeep + Qhat_buf = [None] * numToKeep + Chat_buf = [None] * numToKeep + Rhat_buf = [None] * numToKeep + x0hat_buf = [None] * numToKeep + Px0hat_buf = [None] * numToKeep + alphahat_buf = [None] * numToKeep + x_K_buf = [None] * numToKeep + W_K_buf = [None] * numToKeep + ExpSums_buf = [None] * numToKeep + + # Initialize slot 0 + Ahat_buf[0] = A0.copy() + Qhat_buf[0] = Q0.copy() + Chat_buf[0] = C0.copy() + Rhat_buf[0] = R0.copy() + x0hat_buf[0] = x0.copy() + Px0hat_buf[0] = Px0.copy() + alphahat_buf[0] = alpha0.copy() + + # Scale the system via Cholesky transforms + # Matlab: Tq = eye(size(Q))/(chol(Q)); Tr = eye(size(R))/(chol(R)) + scaledSystem = True + if scaledSystem: + try: + cholQ = np.linalg.cholesky(Qhat_buf[0]).T # upper Cholesky + except np.linalg.LinAlgError: + cholQ = np.linalg.cholesky(_nearestSPD(Qhat_buf[0])).T + try: + cholR = np.linalg.cholesky(Rhat_buf[0]).T # upper Cholesky + except np.linalg.LinAlgError: + cholR = np.linalg.cholesky(_nearestSPD(Rhat_buf[0])).T + Tq = np.linalg.solve(cholQ, np.eye(numStates)) + Tr = np.linalg.solve(cholR, np.eye(y.shape[0])) + + Ahat_buf[0] = Tq @ Ahat_buf[0] @ np.linalg.solve(Tq, np.eye(numStates)) + Chat_buf[0] = Tr @ Chat_buf[0] @ np.linalg.solve(Tq, np.eye(numStates)) + Qhat_buf[0] = Tq @ Qhat_buf[0] @ Tq.T + Rhat_buf[0] = Tr @ Rhat_buf[0] @ Tr.T + y = Tr @ y + x0hat_buf[0] = Tq @ x0 + Px0hat_buf[0] = Tq @ Px0 @ Tq.T + alphahat_buf[0] = Tr @ alphahat_buf[0] + + cnt = 0 # 0-based iteration counter + ll_list = [] + dLikelihood = [np.inf, np.inf] + IkedaAcc = KFEM_Constraints["EnableIkeda"] + stoppingCriteria = False + + print(" Kalman Filter/Gaussian 
Observation EM Algorithm ") + + while not stoppingCriteria and cnt < maxIter: + storeInd = cnt % numToKeep + storeIndP1 = (cnt + 1) % numToKeep + storeIndM1 = (cnt - 1) % numToKeep + + print("-" * 100) + print(f"Iteration #{cnt + 1}") + print("-" * 100) + + # E-step + x_K_buf[storeInd], W_K_buf[storeInd], ll_val, ExpSums_buf[storeInd] = ( + DecodingAlgorithms.KF_EStep( + Ahat_buf[storeInd], Qhat_buf[storeInd], + Chat_buf[storeInd], Rhat_buf[storeInd], + y, alphahat_buf[storeInd], + x0hat_buf[storeInd], Px0hat_buf[storeInd], + ) + ) + ll_list.append(ll_val) + + # M-step + ( + Ahat_buf[storeIndP1], Qhat_buf[storeIndP1], + Chat_buf[storeIndP1], Rhat_buf[storeIndP1], + alphahat_buf[storeIndP1], x0hat_buf[storeIndP1], + Px0hat_buf[storeIndP1], + ) = DecodingAlgorithms.KF_MStep( + y, x_K_buf[storeInd], + x0hat_buf[storeInd], Px0hat_buf[storeInd], + ExpSums_buf[storeInd], KFEM_Constraints, + ) + + # Ikeda acceleration + if IkedaAcc: + print("****Ikeda Acceleration Step****") + K_obs = x_K_buf[storeInd].shape[1] + mean_y = ( + Chat_buf[storeIndP1] @ x_K_buf[storeInd] + + alphahat_buf[storeIndP1] @ np.ones((1, K_obs)) + ) + ykNew = np.random.multivariate_normal( + np.zeros(Rhat_buf[storeIndP1].shape[0]), + Rhat_buf[storeIndP1], + size=K_obs, + ).T + mean_y + + x_KNew, W_KNew, llNew, ExpSumsNew = DecodingAlgorithms.KF_EStep( + Ahat_buf[storeInd], Qhat_buf[storeInd], + Chat_buf[storeInd], Rhat_buf[storeInd], + ykNew, alphahat_buf[storeInd], x0, Px0, + ) + ( + AhatNew, QhatNew, ChatNew, RhatNew, + alphahatNew, x0new, Px0new, + ) = DecodingAlgorithms.KF_MStep( + ykNew, x_KNew, x0hat_buf[storeInd], + Px0hat_buf[storeInd], ExpSumsNew, KFEM_Constraints, + ) + + Ahat_buf[storeIndP1] = 2 * Ahat_buf[storeIndP1] - AhatNew + Qhat_buf[storeIndP1] = 2 * Qhat_buf[storeIndP1] - QhatNew + Qhat_buf[storeIndP1] = _symmetrize(Qhat_buf[storeIndP1]) + Chat_buf[storeIndP1] = 2 * Chat_buf[storeIndP1] - ChatNew + Rhat_buf[storeIndP1] = 2 * Rhat_buf[storeIndP1] - RhatNew + Rhat_buf[storeIndP1] = 
_symmetrize(Rhat_buf[storeIndP1]) + alphahat_buf[storeIndP1] = 2 * alphahat_buf[storeIndP1] - alphahatNew + + # Override A if not estimating + if not KFEM_Constraints["EstimateA"]: + Ahat_buf[storeIndP1] = Ahat_buf[storeInd] + + # Likelihood change + if cnt == 0: + dLikelihood_val = np.inf + else: + dLikelihood_val = ll_list[cnt] - ll_list[cnt - 1] + + # Convergence check: max parameter change + if cnt == 0: + dMax = np.inf + else: + prev = storeIndM1 + dQvals = np.max(np.abs( + np.sqrt(np.abs(np.diag(Qhat_buf[storeInd]))) + - np.sqrt(np.abs(np.diag(Qhat_buf[prev]))) + )) + dRvals = np.max(np.abs( + np.sqrt(np.abs(np.diag(Rhat_buf[storeInd]))) + - np.sqrt(np.abs(np.diag(Rhat_buf[prev]))) + )) + dAvals = np.max(np.abs(Ahat_buf[storeInd] - Ahat_buf[prev])) + dCvals = np.max(np.abs(Chat_buf[storeInd] - Chat_buf[prev])) + dAlphavals = np.max(np.abs(alphahat_buf[storeInd] - alphahat_buf[prev])) + dMax = max(dQvals, dRvals, dAvals, dCvals, dAlphavals) + + if cnt == 0: + print("Max Parameter Change: N/A") + else: + print(f"Max Parameter Change: {dMax}") + + cnt += 1 + + if dMax < tolAbs: + stoppingCriteria = True + print(f" EM converged at iteration# {cnt} b/c change in params was within criteria") + + if abs(dLikelihood_val) < llTol or dLikelihood_val < 0: + stoppingCriteria = True + print(f" EM stopped at iteration# {cnt} b/c change in likelihood was negative") + + print("-" * 100) + + # Select best iteration by max log-likelihood + ll_arr = np.array(ll_list) + maxLLIndex = int(np.argmax(ll_arr)) + maxLLIndMod = maxLLIndex % numToKeep + nIter = cnt + + xKFinal = x_K_buf[maxLLIndMod] + WKFinal = W_K_buf[maxLLIndMod] + Ahat_final = Ahat_buf[maxLLIndMod] + Qhat_final = Qhat_buf[maxLLIndMod] + Chat_final = Chat_buf[maxLLIndMod] + Rhat_final = Rhat_buf[maxLLIndMod] + alphahat_final = alphahat_buf[maxLLIndMod] + x0hat_final = x0hat_buf[maxLLIndMod] + Px0hat_final = Px0hat_buf[maxLLIndMod] + + # Un-scale the system + if scaledSystem: + # Reconstruct Tq, Tr from original Q0, R0 
+ try: + cholQ0 = np.linalg.cholesky(Q0).T + except np.linalg.LinAlgError: + cholQ0 = np.linalg.cholesky(_nearestSPD(Q0)).T + try: + cholR0 = np.linalg.cholesky(R0).T + except np.linalg.LinAlgError: + cholR0 = np.linalg.cholesky(_nearestSPD(R0)).T + Tq = np.linalg.solve(cholQ0, np.eye(numStates)) + Tr = np.linalg.solve(cholR0, np.eye(y.shape[0])) + + # Matlab: Ahat = Tq\Ahat*Tq + Tq_inv = np.linalg.inv(Tq) + Tr_inv = np.linalg.inv(Tr) + Ahat_final = Tq_inv @ Ahat_final @ Tq + Qhat_final = Tq_inv @ Qhat_final @ Tq_inv.T + Chat_final = Tr_inv @ Chat_final @ Tq + Rhat_final = Tr_inv @ Rhat_final @ Tr_inv.T + alphahat_final = Tr_inv @ alphahat_final + xKFinal = Tq_inv @ xKFinal + x0hat_final = Tq_inv @ x0hat_final + Px0hat_final = Tq_inv @ Px0hat_final @ Tq_inv.T + K_steps = WKFinal.shape[2] + tempWK = np.zeros_like(WKFinal) + for kk in range(K_steps): + tempWK[:, :, kk] = Tq_inv @ WKFinal[:, :, kk] @ Tq_inv.T + WKFinal = tempWK + + ll_best = ll_list[maxLLIndex] + ExpectationSumsFinal = ExpSums_buf[maxLLIndMod] + + # Compute standard errors + SE, Pvals = DecodingAlgorithms.KF_ComputeParamStandardErrors( + yOrig, xKFinal, WKFinal, Ahat_final, Qhat_final, + Chat_final, Rhat_final, alphahat_final, x0hat_final, + Px0hat_final, ExpectationSumsFinal, KFEM_Constraints, + ) + + # Compute information criteria + # Count number of estimated parameters (matches Matlab lines 3600-3640) + if KFEM_Constraints["EstimateA"] and KFEM_Constraints["AhatDiag"]: + np1 = Ahat_final.shape[0] + elif KFEM_Constraints["EstimateA"] and not KFEM_Constraints["AhatDiag"]: + np1 = Ahat_final.size + else: + np1 = 0 + + if KFEM_Constraints["QhatDiag"] and KFEM_Constraints["QhatIsotropic"]: + np2 = 1 + elif KFEM_Constraints["QhatDiag"] and not KFEM_Constraints["QhatIsotropic"]: + np2 = Qhat_final.shape[0] + else: + np2 = Qhat_final.size + + np3 = Chat_final.size + + if KFEM_Constraints["RhatDiag"] and KFEM_Constraints["RhatIsotropic"]: + np4 = 1 + elif KFEM_Constraints["QhatDiag"] and not 
KFEM_Constraints["QhatIsotropic"]: + # Note: Matlab line 3618 checks QhatDiag here (likely a bug, but we match it) + np4 = Rhat_final.shape[0] + else: + np4 = Rhat_final.size + + if KFEM_Constraints["EstimatePx0"] and KFEM_Constraints["Px0Isotropic"]: + np5 = 1 + elif KFEM_Constraints["EstimatePx0"] and not KFEM_Constraints["Px0Isotropic"]: + np5 = Px0hat_final.shape[0] + else: + np5 = 0 + + np6 = x0hat_final.shape[0] if KFEM_Constraints["Estimatex0"] else 0 + np7 = alphahat_final.shape[0] + nTerms_ic = np1 + np2 + np3 + np4 + np5 + np6 + np7 + + K_steps = yOrig.shape[1] + Dx = Ahat_final.shape[1] + sumXkTerms_final = ExpectationSumsFinal["sumXkTerms"] + + # Matlab: llobs = ll + Dx*K/2*log(2*pi) + K/2*log(det(Qhat)) + # + 1/2*trace(Qhat\sumXkTerms) + Dx/2*log(2*pi) + # + 1/2*log(det(Px0hat)) + 1/2*Dx + _, logdet_Q_final = np.linalg.slogdet(Qhat_final) + _, logdet_Px0_final = np.linalg.slogdet(Px0hat_final) + llobs = ( + ll_best + + Dx * K_steps / 2.0 * np.log(2.0 * np.pi) + + K_steps / 2.0 * logdet_Q_final + + 0.5 * np.trace(np.linalg.solve(Qhat_final, sumXkTerms_final)) + + Dx / 2.0 * np.log(2.0 * np.pi) + + 0.5 * logdet_Px0_final + + 0.5 * Dx + ) + + AIC = 2.0 * nTerms_ic - 2.0 * llobs + AICc = AIC + 2.0 * nTerms_ic * (nTerms_ic + 1) / max(K_steps - nTerms_ic - 1, 1) + BIC = -2.0 * llobs + nTerms_ic * np.log(K_steps) + + IC = { + "AIC": float(AIC), + "AICc": float(AICc), + "BIC": float(BIC), + "llobs": float(llobs), + "llcomp": float(ll_best), + } + + return ( + xKFinal, WKFinal, Ahat_final, Qhat_final, + Chat_final, Rhat_final, alphahat_final, + x0hat_final, Px0hat_final, IC, SE, Pvals, nIter, + ) + + # PP_EM family: Point-Process state-space EM (without basis functions) + # ------------------------------------------------------------------ + + @staticmethod + def PP_EMCreateConstraints( + EstimateA=1, + AhatDiag=0, + QhatDiag=1, + QhatIsotropic=0, + Estimatex0=1, + EstimatePx0=1, + Px0Isotropic=0, + mcIter=1000, + EnableIkeda=0, + ): + """Build 
a constraints dict for PP_EM. + + Parameters + ---------- + EstimateA : int + Whether to estimate the state transition matrix A. + AhatDiag : int + Constrain A to be diagonal. + QhatDiag : int + Constrain Q to be diagonal. + QhatIsotropic : int + Constrain Q to be isotropic (scalar * I). + Estimatex0 : int + Whether to estimate the initial state x0. + EstimatePx0 : int + Whether to estimate the initial state covariance Px0. + Px0Isotropic : int + Constrain Px0 to be isotropic. + mcIter : int + Number of Monte Carlo iterations for standard error estimation. + EnableIkeda : int + Enable Ikeda acceleration. + + Returns + ------- + dict + Constraints dictionary with all fields. + """ + C = {} + C["EstimateA"] = int(EstimateA) + C["AhatDiag"] = int(AhatDiag) + C["QhatDiag"] = int(QhatDiag) + C["QhatIsotropic"] = 1 if (QhatDiag and QhatIsotropic) else 0 + C["Estimatex0"] = int(Estimatex0) + C["EstimatePx0"] = int(EstimatePx0) + C["Px0Isotropic"] = 1 if (EstimatePx0 and Px0Isotropic) else 0 + C["mcIter"] = int(mcIter) + C["EnableIkeda"] = int(EnableIkeda) + return C + + @staticmethod + def _nearestSPD(A): + """Compute the nearest symmetric positive semi-definite matrix. + + Uses the algorithm of Higham (1988). 
+ """ + B = 0.5 * (A + A.T) + _, S, Vt = np.linalg.svd(B) + H = Vt.T @ np.diag(S) @ Vt + Ahat = 0.5 * (B + H) + Ahat = 0.5 * (Ahat + Ahat.T) + # Test positive definiteness and fix if needed + try: + np.linalg.cholesky(Ahat) + return Ahat + except np.linalg.LinAlgError: + pass + spacing = np.spacing(np.linalg.norm(A)) + I = np.eye(A.shape[0]) + k = 1 + while True: + try: + np.linalg.cholesky(Ahat) + return Ahat + except np.linalg.LinAlgError: + mineig = np.min(np.real(np.linalg.eigvalsh(Ahat))) + Ahat += I * (-mineig * k ** 2 + spacing) + k += 1 + if k > 100: + return Ahat + + @staticmethod + def _ztest_pvalue(param, se): + """Two-sided z-test p-value for H0: param == 0.""" + se_safe = np.where(se > 0, se, 1.0) + z = np.abs(param / se_safe) + p = 2.0 * (1.0 - norm.cdf(z)) + # Where se was 0, return 1.0 + p = np.where(se > 0, p, 1.0) + return p + + @staticmethod + def PP_ComputeParamStandardErrors( + dN, + xKFinal, + WKFinal, + Ahat, + Qhat, + x0hat, + Px0hat, + ExpectationSumsFinal, + fitType, + muhat, + betahat, + gammahat, + windowTimes, + HkAll, + PPEM_Constraints=None, + ): + """Compute standard errors via the observed information matrix. + + Uses a Monte-Carlo approximation of the missing information matrix + (McLachlan & Krishnan, Eq. 4.7). 
+ + Parameters + ---------- + dN : (C, N) spike observations + xKFinal : (dx, N) smoothed states + WKFinal : (dx, dx, N) smoothed state covariances + Ahat : (dx, dx) estimated state transition + Qhat : (dx, dx) estimated state noise covariance + x0hat : (dx,) estimated initial state + Px0hat : (dx, dx) estimated initial state covariance + ExpectationSumsFinal : dict with sufficient statistics + fitType : 'poisson' or 'binomial' + muhat : (C,) estimated baseline rates + betahat : (dx, C) estimated stimulus coefficients + gammahat : (nW, C) or scalar estimated history coefficients + windowTimes : history window boundaries or None + HkAll : (N, nW, C) history design tensor + PPEM_Constraints : dict from PP_EMCreateConstraints + + Returns + ------- + SE : dict of standard errors for each parameter group + Pvals : dict of p-values for each parameter group + nTerms : int, total number of estimated parameters + """ + if PPEM_Constraints is None: + PPEM_Constraints = DecodingAlgorithms.PP_EMCreateConstraints() + + Ahat = np.atleast_2d(Ahat) + Qhat = np.atleast_2d(Qhat) + Px0hat = np.atleast_2d(Px0hat) + x0hat = np.asarray(x0hat, dtype=float).reshape(-1) + muhat = np.asarray(muhat, dtype=float).reshape(-1) + betahat = np.atleast_2d(betahat) + gammahat = np.asarray(gammahat, dtype=float) + dN = np.atleast_2d(dN) + + dx = Ahat.shape[0] + N = xKFinal.shape[1] + K = N + numCells = betahat.shape[1] + fitType = str(fitType).lower() + + # ---- Complete Information Matrices ---- + + # A information + if PPEM_Constraints["EstimateA"]: + n1_A, n2_A = Ahat.shape + Qinv = np.linalg.inv(Qhat) + if PPEM_Constraints["AhatDiag"]: + IAComp = np.zeros((n1_A, n1_A)) + for l in range(n1_A): + el = np.zeros(n1_A) + el[l] = 1.0 + em = np.zeros(n2_A) + em[l] = 1.0 + termMat = Qinv @ np.outer(el, em) @ (ExpectationSumsFinal["Sxkm1xkm1"] * np.eye(n1_A)) + IAComp[:, l] = np.diag(termMat) + else: + nA = Ahat.size + IAComp = np.zeros((nA, nA)) + cnt = 0 + for l in range(n1_A): + el = np.zeros(n1_A) + 
el[l] = 1.0 + for m in range(n2_A): + em = np.zeros(n2_A) + em[m] = 1.0 + termMat = Qinv @ np.outer(el, em) @ ExpectationSumsFinal["Sxkm1xkm1"] + IAComp[:, cnt] = termMat.T.ravel() + cnt += 1 + else: + IAComp = np.zeros((0, 0)) + + # Q information + n1_Q, n2_Q = Qhat.shape + Qinv = np.linalg.inv(Qhat) + if PPEM_Constraints["QhatDiag"]: + if PPEM_Constraints["QhatIsotropic"]: + IQComp = np.array([[0.5 * N * dx * Qhat[0, 0] ** (-2)]]) + else: + IQComp = np.zeros((n1_Q, n1_Q)) + cnt = 0 + for l in range(n1_Q): + el = np.zeros(n1_Q) + el[l] = 1.0 + termMat = N / 2.0 * Qinv @ np.outer(el, el) @ Qinv + IQComp[:, cnt] = np.diag(termMat) + cnt += 1 + else: + nQ = Qhat.size + IQComp = np.zeros((nQ, nQ)) + cnt = 0 + for l in range(n1_Q): + el = np.zeros(n1_Q) + el[l] = 1.0 + for m in range(n2_Q): + em = np.zeros(n2_Q) + em[m] = 1.0 + termMat = N / 2.0 * Qinv @ np.outer(em, el) @ Qinv + IQComp[:, cnt] = termMat.T.ravel() + cnt += 1 + + # Px0 information + if PPEM_Constraints["EstimatePx0"]: + Px0inv = np.linalg.inv(Px0hat) + if PPEM_Constraints["Px0Isotropic"]: + ISComp = np.array([[0.5 * dx * Px0hat[0, 0] ** (-2)]]) + else: + n1_S, n2_S = Px0hat.shape + ISComp = np.zeros((n1_S, n1_S)) + cnt = 0 + for l in range(n1_S): + el = np.zeros(n1_S) + el[l] = 1.0 + termMat = 0.5 * Px0inv @ np.outer(el, el) @ Px0inv + ISComp[:, cnt] = np.diag(termMat) + cnt += 1 + else: + ISComp = np.zeros((0, 0)) + + # x0 information + if PPEM_Constraints["Estimatex0"]: + Qinv = np.linalg.inv(Qhat) + Px0inv = np.linalg.inv(Px0hat) + Ix0Comp = Px0inv + Ahat.T @ Qinv @ Ahat + else: + Ix0Comp = np.zeros((0, 0)) + + # Monte Carlo draws for expectation approximation + McExp = PPEM_Constraints["mcIter"] + xKDrawExp = np.zeros((dx, K, McExp)) + for k in range(K): + WuTemp = WKFinal[:, :, k] + try: + chol_m = np.linalg.cholesky(WuTemp).T # upper triangular + except np.linalg.LinAlgError: + eigv, eigvec = np.linalg.eigh(WuTemp) + eigv = np.maximum(eigv, 1e-12) + chol_m = np.linalg.cholesky(eigvec @ 
np.diag(eigv) @ eigvec.T).T + z = np.random.randn(dx, McExp) + xKDrawExp[:, k, :] = xKFinal[:, k:k + 1] + chol_m @ z + + # Beta information (Hessian approximation via MC) + IBetaComp = np.zeros((dx * numCells, dx * numCells)) + # xkPerm: (dx, McExp, K) + xkPerm = np.transpose(xKDrawExp, (0, 2, 1)) + + for c in range(numCells): + HessianTerm = np.zeros((dx, dx, K)) + for k in range(K): + Hk = HkAll[k, :, c] if HkAll.ndim == 3 else np.zeros(0) + xk = xkPerm[:, :, k] # (dx, McExp) + + gammaC = gammahat if gammahat.ndim == 0 or gammahat.size == 1 else gammahat[:, c] + gammaC = np.atleast_1d(gammaC) + + Hk_vec = np.atleast_1d(Hk) + hist_term = float(gammaC @ Hk_vec) if Hk_vec.size == gammaC.size and gammaC.size > 0 else 0.0 + + terms = muhat[c] + betahat[:, c] @ xk + hist_term + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + HessianTerm[:, :, k] = -1.0 / McExp * (ld[None, :] * xk) @ xk.T + else: # binomial + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ExplambdaDeltaXkXk = 1.0 / McExp * (ld[None, :] * xk) @ xk.T + ExplambdaDeltaSqXkXkT = 1.0 / McExp * (ld[None, :] ** 2 * xk) @ xk.T + ExplambdaDeltaCubeXkXkT = 1.0 / McExp * (ld[None, :] ** 3 * xk) @ xk.T + HessianTerm[:, :, k] = ExplambdaDeltaXkXk + ExplambdaDeltaSqXkXkT - 2 * ExplambdaDeltaCubeXkXkT + + startInd = dx * c + endInd = dx * (c + 1) + IBetaComp[startInd:endInd, startInd:endInd] = -np.sum(HessianTerm, axis=2) + + # Mu information + IMuComp = np.zeros((numCells, numCells)) + for c in range(numCells): + HessianTerm = 0.0 + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 0)) + Hk_vec = Hk_full[k, :] if Hk_full.ndim == 2 and Hk_full.shape[0] > k else np.zeros(0) + xk = xkPerm[:, :, k] + + gammaC = gammahat if gammahat.ndim == 0 or gammahat.size == 1 else gammahat[:, c] + gammaC = np.atleast_1d(gammaC) + Hk_vec = np.atleast_1d(Hk_vec) + hist_term = float(gammaC @ Hk_vec) if Hk_vec.size == gammaC.size and gammaC.size > 0 else 0.0 + + terms = muhat[c] + 
betahat[:, c] @ xk + hist_term + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + HessianTerm -= 1.0 / McExp * np.sum(ld) + else: + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ExplambdaDelta = 1.0 / McExp * np.sum(ld) + ExplambdaDeltaSq = 1.0 / McExp * np.sum(ld ** 2) + ExplambdaDeltaCubed = 1.0 / McExp * np.sum(ld ** 3) + HessianTerm += -(dN[c, k] + 1) * ExplambdaDelta + (dN[c, k] + 3) * ExplambdaDeltaSq - 3 * ExplambdaDeltaCubed + IMuComp[c, c] = -HessianTerm + + # Gamma information + gammahat_flat = gammahat.ravel() + has_gamma = gammahat_flat.size > 1 or (gammahat_flat.size == 1 and gammahat_flat[0] != 0) + if windowTimes is not None and len(windowTimes) > 0 and has_gamma: + nHist = HkAll.shape[1] if HkAll.ndim == 3 else 0 + IGammaComp = np.zeros((nHist * numCells, nHist * numCells)) + for c in range(numCells): + HessianTerm = np.zeros((nHist, nHist)) + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, nHist)) + for k in range(K): + Hk_vec = Hk_full[k, :] + xk = xkPerm[:, :, k] + gammaC = gammahat if gammahat.ndim == 0 or gammahat.size == 1 else gammahat[:, c] + gammaC = np.atleast_1d(gammaC) + hist_term = float(gammaC @ Hk_vec) if Hk_vec.size == gammaC.size else 0.0 + terms = muhat[c] + betahat[:, c] @ xk + hist_term + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + ExplambdaDelta = 1.0 / McExp * np.sum(ld) + HessianTerm -= np.outer(Hk_vec, Hk_vec) * ExplambdaDelta + else: + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ExplambdaDelta = 1.0 / McExp * np.sum(ld) + ExplambdaDeltaSq = 1.0 / McExp * np.sum(ld ** 2) + ExplambdaDeltaCubed = 1.0 / McExp * np.sum(ld ** 2) # Matlab uses ld.^2 here + HessianTerm += (-ExplambdaDelta * (dN[c, k] + 1) + + ExplambdaDeltaSq * (dN[c, k] + 3) + - 2 * ExplambdaDeltaCubed) * np.outer(Hk_vec, Hk_vec) + startInd = nHist * c + endInd = nHist * (c + 1) + IGammaComp[startInd:endInd, startInd:endInd] = -HessianTerm + else: + IGammaComp = np.zeros((0, 0)) + + # Assemble 
complete information matrix + n1 = IAComp.shape[0] if PPEM_Constraints["EstimateA"] else 0 + n2 = IQComp.shape[0] + n3 = ISComp.shape[0] if PPEM_Constraints["EstimatePx0"] else 0 + n4 = Ix0Comp.shape[0] if PPEM_Constraints["Estimatex0"] else 0 + n5 = IMuComp.shape[0] + n6 = IBetaComp.shape[0] + n7 = IGammaComp.shape[0] if has_gamma else 0 + nTerms = n1 + n2 + n3 + n4 + n5 + n6 + n7 + + IComp = np.zeros((nTerms, nTerms)) + off = 0 + if PPEM_Constraints["EstimateA"] and n1 > 0: + IComp[off:off + n1, off:off + n1] = IAComp + off = n1 + IComp[off:off + n2, off:off + n2] = IQComp + off = n1 + n2 + if PPEM_Constraints["EstimatePx0"] and n3 > 0: + IComp[off:off + n3, off:off + n3] = ISComp + off = n1 + n2 + n3 + if PPEM_Constraints["Estimatex0"] and n4 > 0: + IComp[off:off + n4, off:off + n4] = Ix0Comp + off = n1 + n2 + n3 + n4 + IComp[off:off + n5, off:off + n5] = IMuComp + off = n1 + n2 + n3 + n4 + n5 + IComp[off:off + n6, off:off + n6] = IBetaComp + off = n1 + n2 + n3 + n4 + n5 + n6 + if n7 > 0: + IComp[off:off + n7, off:off + n7] = IGammaComp + + # ---- Missing Information Matrix (Monte Carlo) ---- + Mc = PPEM_Constraints["mcIter"] + xKDraw = np.zeros((dx, N, Mc)) + for n_idx in range(N): + WuTemp = WKFinal[:, :, n_idx] + try: + chol_m = np.linalg.cholesky(WuTemp).T + except np.linalg.LinAlgError: + eigv, eigvec = np.linalg.eigh(WuTemp) + eigv = np.maximum(eigv, 1e-12) + chol_m = np.linalg.cholesky(eigvec @ np.diag(eigv) @ eigvec.T).T + z = np.random.randn(dx, Mc) + xKDraw[:, n_idx, :] = xKFinal[:, n_idx:n_idx + 1] + chol_m @ z + + if PPEM_Constraints["EstimatePx0"] or PPEM_Constraints["Estimatex0"]: + try: + chol_m = np.linalg.cholesky(Px0hat).T + except np.linalg.LinAlgError: + eigv, eigvec = np.linalg.eigh(Px0hat) + eigv = np.maximum(eigv, 1e-12) + chol_m = np.linalg.cholesky(eigvec @ np.diag(eigv) @ eigvec.T).T + z = np.random.randn(dx, Mc) + x0Draw = x0hat[:, None] + chol_m @ z + else: + x0Draw = np.tile(x0hat[:, None], (1, Mc)) + + Qinv = np.linalg.inv(Qhat) + 
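Drawing latent-state samples from N(x̂_k, W_k), as done above for both the expectation approximation and the missing-information term, reduces to a matrix-square-root transform of standard normals with an SPD-repair fallback. A minimal standalone sketch (function and variable names hypothetical), using the lower-triangular Cholesky factor so that cov(L z) = W:

```python
import numpy as np

def draw_states(x_mean, W, n_draws, rng):
    """Draw n_draws samples from N(x_mean, W) with an SPD-repair fallback."""
    try:
        L = np.linalg.cholesky(W)          # lower triangular, W = L @ L.T
    except np.linalg.LinAlgError:
        eigv, eigvec = np.linalg.eigh(W)   # repair a near-singular covariance
        eigv = np.maximum(eigv, 1e-12)
        L = np.linalg.cholesky(eigvec @ np.diag(eigv) @ eigvec.T)
    z = rng.standard_normal((x_mean.shape[0], n_draws))
    return x_mean[:, None] + L @ z

rng = np.random.default_rng(0)
W = np.array([[1.0, 0.5], [0.5, 2.0]])
draws = draw_states(np.array([1.0, -1.0]), W, 200_000, rng)
emp_cov = np.cov(draws)                    # approaches W as n_draws grows
```

With enough draws the empirical covariance converges to W, which is what makes the Monte-Carlo score averages below consistent estimators.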
Px0inv = np.linalg.inv(Px0hat) + IMc = np.zeros((nTerms, nTerms, Mc)) + + for c_mc in range(Mc): + x_K = xKDraw[:, :, c_mc] + x_0 = x0Draw[:, c_mc] + Dx = x_K.shape[0] + + Sxkm1xk = np.zeros((Dx, Dx)) + Sxkm1xkm1 = np.zeros((Dx, Dx)) + Sxkxk = np.zeros((Dx, Dx)) + + for k in range(K): + if k == 0: + Sxkm1xk += np.outer(x_0, x_K[:, k]) + Sxkm1xkm1 += np.outer(x_0, x_0) + else: + Sxkm1xk += np.outer(x_K[:, k - 1], x_K[:, k]) + Sxkm1xkm1 += np.outer(x_K[:, k - 1], x_K[:, k - 1]) + Sxkxk += np.outer(x_K[:, k], x_K[:, k]) + + Sxkxk = 0.5 * (Sxkxk + Sxkxk.T) + sumXkTerms_mc = Sxkxk - Ahat @ Sxkm1xk - Sxkm1xk.T @ Ahat.T + Ahat @ Sxkm1xkm1 @ Ahat.T + Sxkxkm1 = Sxkm1xk.T + sumXkTerms_mc = 0.5 * (sumXkTerms_mc + sumXkTerms_mc.T) + + # Score for A + if PPEM_Constraints["EstimateA"]: + ScorA = np.linalg.solve(Qhat, Sxkxkm1 - Ahat @ Sxkm1xkm1) + if PPEM_Constraints["AhatDiag"]: + ScoreAMc = np.diag(ScorA) + else: + ScoreAMc = ScorA.T.ravel() + else: + ScoreAMc = np.array([]) + + # Score for Q + if PPEM_Constraints["QhatDiag"]: + if PPEM_Constraints["QhatIsotropic"]: + ScoreQ = -0.5 * (K * Dx * Qhat[0, 0] ** (-1) - Qhat[0, 0] ** (-2) * np.trace(sumXkTerms_mc)) + ScoreQMc = np.atleast_1d(ScoreQ) + else: + ScoreQ = -0.5 * np.linalg.solve(Qhat, K * np.eye(dx) - np.linalg.solve(Qhat, sumXkTerms_mc).T) + ScoreQMc = np.diag(ScoreQ) + else: + ScoreQ = -0.5 * np.linalg.solve(Qhat, K * np.eye(dx) - np.linalg.solve(Qhat, sumXkTerms_mc).T) + ScoreQMc = ScoreQ.T.ravel() + + # Score for Px0 + if PPEM_Constraints["Px0Isotropic"]: + diff = x_0 - x0hat + ScoreSMc = np.atleast_1d(-0.5 * (Dx * Px0hat[0, 0] ** (-1) + - Px0hat[0, 0] ** (-2) * np.dot(diff, diff))) + else: + diff = x_0 - x0hat + ScorS = -0.5 * np.linalg.solve(Px0hat, np.eye(dx) - np.linalg.solve(Px0hat, np.outer(diff, diff)).T) + ScoreSMc = np.diag(ScorS) + + # Score for x0 + Scorx0 = -np.linalg.solve(Px0hat, x_0 - x0hat) + Ahat.T @ Qinv @ (x_K[:, 0] - Ahat @ x_0) + Scorex0Mc = Scorx0.ravel() + + # Cell scores + ScoreMuMc = 
np.zeros(numCells) + ScoreBetaMc = np.array([], dtype=float) + ScoreGammaMc = np.array([], dtype=float) + + for nc in range(numCells): + Hk_full = HkAll[:, :, nc] if HkAll.ndim == 3 else np.zeros((K, 0)) + nHist_c = Hk_full.shape[1] + gammaC = gammahat if gammahat.ndim == 0 or gammahat.size == 1 else gammahat[:, nc] + gammaC = np.atleast_1d(gammaC) + + hist_terms = Hk_full @ gammaC if gammaC.size == nHist_c and nHist_c > 0 else np.zeros(K) + terms = muhat[nc] + betahat[:, nc] @ x_K + hist_terms + + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + ScoreMuMc[nc] = np.sum(dN[nc, :] - ld) + ScoreBetaMc = np.concatenate([ScoreBetaMc, + np.sum((dN[nc, :] - ld)[None, :] * x_K, axis=1)]) + if nHist_c > 0: + ScoreGammaMc = np.concatenate([ScoreGammaMc, + np.sum((dN[nc, :] - ld)[None, :] * Hk_full.T, axis=1)]) + else: # binomial + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ScoreMuMc[nc] = np.sum(dN[nc, :] - (dN[nc, :] + 1) * ld + ld ** 2) + ScoreBetaMc = np.concatenate([ScoreBetaMc, + np.sum((dN[nc, :] * (1 - ld) - ld * (1 - ld))[None, :] * x_K, axis=1)]) + if nHist_c > 0: + ScoreGammaMc = np.concatenate([ScoreGammaMc, + np.sum((dN[nc, :] - (dN[nc, :] + 1) * ld + ld ** 2)[None, :] * Hk_full.T, axis=1)]) + + # Assemble score vector + ScoreVec = np.concatenate([ScoreAMc, ScoreQMc]) + if PPEM_Constraints["EstimatePx0"]: + ScoreVec = np.concatenate([ScoreVec, ScoreSMc]) + if PPEM_Constraints["Estimatex0"]: + ScoreVec = np.concatenate([ScoreVec, Scorex0Mc]) + ScoreVec = np.concatenate([ScoreVec, ScoreMuMc, ScoreBetaMc]) + if has_gamma and ScoreGammaMc.size > 0: + ScoreVec = np.concatenate([ScoreVec, ScoreGammaMc]) + + IMc[:, :, c_mc] = np.outer(ScoreVec, ScoreVec) + + IMissing = np.mean(IMc, axis=2) + IObs = IComp - IMissing + try: + invIObs = np.linalg.inv(IObs) + except np.linalg.LinAlgError: + invIObs = np.linalg.pinv(IObs) + invIObs = DecodingAlgorithms._nearestSPD(invIObs) + + VarVec = np.diag(invIObs) + SEVec = np.sqrt(np.maximum(VarVec, 0.0)) + 
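The standard errors just computed follow the observed-information recipe I_obs = I_complete − I_missing (McLachlan & Krishnan), with the missing information approximated by the Monte-Carlo mean of score outer products. A toy sketch of that pipeline with made-up block values (all numbers hypothetical, not taken from any real fit):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical complete-data information for three parameters
# (block diagonal, like IComp above).
I_comp = np.diag([50.0, 80.0, 120.0])

# Missing information: Monte-Carlo average of score outer products,
# here with hypothetical zero-mean score draws.
scores = 0.5 * rng.standard_normal((3, 5000))
I_missing = (scores @ scores.T) / scores.shape[1]

# Observed information and standard errors from its inverse diagonal.
I_obs = I_comp - I_missing
se = np.sqrt(np.diag(np.linalg.inv(I_obs)))
```

Because I_missing is small relative to I_comp here, I_obs stays positive definite and the square roots are well defined, mirroring the `_nearestSPD` guard used above for the ill-conditioned case.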
+ # Unpack SE vector + off = 0 + SEAterms = SEVec[off:off + n1]; off += n1 + SEQterms = SEVec[off:off + n2]; off += n2 + SEPx0terms = SEVec[off:off + n3]; off += n3 + SEx0terms = SEVec[off:off + n4]; off += n4 + SEMuTerms = SEVec[off:off + n5]; off += n5 + SEBetaTerms = SEVec[off:off + n6]; off += n6 + SEGammaTerms = SEVec[off:off + n7] + + SE = {} + Pvals = {} + + # A + if PPEM_Constraints["EstimateA"]: + if PPEM_Constraints["AhatDiag"]: + SEA = np.diag(SEAterms) + pA = np.diag(DecodingAlgorithms._ztest_pvalue(np.diag(Ahat), np.diag(SEA))) + else: + SEA = SEAterms.reshape(Ahat.shape[1], Ahat.shape[0]).T + pA = DecodingAlgorithms._ztest_pvalue(Ahat.ravel(), SEA.ravel()).reshape(Ahat.shape) + SE["A"] = SEA + Pvals["A"] = pA + + # Q + if PPEM_Constraints["QhatDiag"]: + SEQ = np.diag(SEQterms) + if PPEM_Constraints["QhatIsotropic"]: + pQ = np.diag(DecodingAlgorithms._ztest_pvalue(np.atleast_1d(Qhat[0, 0]), np.atleast_1d(SEQ[0, 0]))) + else: + pQ = np.diag(DecodingAlgorithms._ztest_pvalue(np.diag(Qhat), np.diag(SEQ))) + else: + SEQ = SEQterms.reshape(Qhat.shape[1], Qhat.shape[0]).T + pQ = DecodingAlgorithms._ztest_pvalue(Qhat.ravel(), SEQ.ravel()).reshape(Qhat.shape) + SE["Q"] = SEQ + Pvals["Q"] = pQ + + # Px0 + if PPEM_Constraints["EstimatePx0"]: + SES = np.diag(SEPx0terms) + if PPEM_Constraints["Px0Isotropic"]: + pPx0 = np.diag(DecodingAlgorithms._ztest_pvalue(np.atleast_1d(Px0hat[0, 0]), np.atleast_1d(SES[0, 0]))) + else: + pPx0 = np.diag(DecodingAlgorithms._ztest_pvalue(np.diag(Px0hat), np.diag(SES))) + SE["Px0"] = SES + Pvals["Px0"] = pPx0 + + # x0 + if PPEM_Constraints["Estimatex0"]: + SEx0 = SEx0terms + pX0 = DecodingAlgorithms._ztest_pvalue(x0hat, SEx0) + SE["x0"] = SEx0 + Pvals["x0"] = pX0 + + # Mu + SEMu = SEMuTerms + pMu = DecodingAlgorithms._ztest_pvalue(muhat, SEMu) + SE["mu"] = SEMu + Pvals["mu"] = pMu + + # Beta + SEBeta = SEBetaTerms.reshape(betahat.shape[1], betahat.shape[0]).T + pBeta = DecodingAlgorithms._ztest_pvalue(betahat.ravel(), 
SEBeta.ravel()).reshape(betahat.shape) + SE["beta"] = SEBeta + Pvals["beta"] = pBeta + + # Gamma + if has_gamma and n7 > 0: + SEGamma = SEGammaTerms.reshape(gammahat.shape[1], gammahat.shape[0]).T if gammahat.ndim == 2 else SEGammaTerms + pGamma = DecodingAlgorithms._ztest_pvalue(gammahat.ravel(), SEGammaTerms).reshape(gammahat.shape) if gammahat.ndim == 2 else DecodingAlgorithms._ztest_pvalue(gammahat.ravel(), SEGammaTerms) + SE["gamma"] = SEGamma + Pvals["gamma"] = pGamma + + return SE, Pvals, nTerms + + @staticmethod + def PP_EStep(A, Q, dN, mu, beta, fitType, gamma, HkAll, x0, Px0): + """E-step for PP EM: forward filter + RTS smoother + cross-covariance. + + Parameters + ---------- + A : (dx, dx) state transition matrix + Q : (dx, dx) state noise covariance + dN : (C, N) binary spike observations + mu : (C,) baseline log-rate + beta : (dx, C) stimulus coefficients + fitType : 'poisson' or 'binomial' + gamma : (nW, C) or scalar history coefficients + HkAll : (N, nW, C) history design tensor + x0 : (dx,) initial state + Px0 : (dx, dx) initial state covariance + + Returns + ------- + x_K : (dx, N) smoothed states + W_K : (dx, dx, N) smoothed covariances + logll : float, log-likelihood + ExpectationSums : dict of sufficient statistics + """ + A = np.atleast_2d(A).astype(float) + Q = np.atleast_2d(Q).astype(float) + dN = np.atleast_2d(dN).astype(float) + mu = np.asarray(mu, dtype=float).reshape(-1) + beta = np.atleast_2d(beta).astype(float) + gamma = np.asarray(gamma, dtype=float) + x0 = np.asarray(x0, dtype=float).reshape(-1) + Px0 = np.atleast_2d(Px0).astype(float) + fitType = str(fitType).lower() + + numCells, K = dN.shape + Dx = A.shape[1] + + # Forward filter + x_p = np.zeros((Dx, K + 1)) + x_u = np.zeros((Dx, K)) + W_p = np.zeros((Dx, Dx, K + 1)) + W_u = np.zeros((Dx, Dx, K)) + x_p[:, 0] = A @ x0 + W_p[:, :, 0] = A @ Px0 @ A.T + Q + + # Permute HkAll for PPDecode_updateLinear: (nW, C, N) + HkPerm = np.transpose(HkAll, (1, 2, 0)) if HkAll.ndim == 3 else HkAll + 
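The filter loop below alternates a point-process measurement update with the standard Gaussian predict step x_{k+1|k} = A x_{k|k} and W_{k+1|k} = A W_{k|k} Aᵀ + Q. The predict half in isolation (a sketch under a scalar toy model, not the `PPDecode_predict` implementation itself):

```python
import numpy as np

def predict_step(x_u, W_u, A, Q):
    """One-step state prediction between measurement updates."""
    x_p = A @ x_u
    W_p = A @ W_u @ A.T + Q
    return x_p, 0.5 * (W_p + W_p.T)   # symmetrize for numerical safety

# Scalar toy model: mean shrinks by A, uncertainty grows by Q.
A = np.array([[0.95]])
Q = np.array([[0.1]])
x_p, W_p = predict_step(np.array([1.0]), np.array([[0.2]]), A, Q)
```

Here x_p = 0.95 and W_p = 0.95 · 0.2 · 0.95 + 0.1 = 0.2805, the prior that the next spike-count update then corrects.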
+ for k in range(K): + x_u[:, k], W_u[:, :, k], _ = DecodingAlgorithms.PPDecode_updateLinear( + x_p[:, k], W_p[:, :, k], dN, mu, beta, fitType, gamma, HkPerm, k + 1, None + ) + A_k = A[:, :, min(k, A.shape[2] - 1)] if A.ndim == 3 else A + Q_k = Q[:, :, min(k, Q.shape[2] - 1)] if Q.ndim == 3 else Q + x_p[:, k + 1], W_p[:, :, k + 1] = DecodingAlgorithms.PPDecode_predict( + x_u[:, k], W_u[:, :, k], A_k, Q_k + ) + + # RTS smoother using kalman_smootherFromFiltered + # Convert state-major (dx, K+1/K) to time-major (K+1/K, dx) for + # the smoother, which uses _state_history_time_major internally. + x_p_tm = x_p.T # (K+1, dx) + W_p_tm = np.transpose(W_p, (2, 0, 1)) # (K+1, dx, dx) + x_u_tm = x_u.T # (K, dx) + W_u_tm = np.transpose(W_u, (2, 0, 1)) # (K, dx, dx) + + x_K_tm, W_K_tm, Lk = DecodingAlgorithms.kalman_smootherFromFiltered( + A, x_p_tm, W_p_tm, x_u_tm, W_u_tm + ) + + # Convert back to state-major: x_K (dx, K), W_K (dx, dx, K) + x_K = x_K_tm.T if x_K_tm.ndim == 2 else x_K_tm + W_K = np.transpose(W_K_tm, (1, 2, 0)) if W_K_tm.ndim == 3 else W_K_tm + + numStates = x_K.shape[0] + + # Cross-covariance Wku + Wku = np.zeros((numStates, numStates, K, K)) + for k in range(K): + Wku[:, :, k, k] = W_K[:, :, k] + + # W_u and W_p remain in state-major (dx, dx, K) format + W_u_sm = W_u + W_p_sm = W_p + + Dk = np.zeros((numStates, numStates, K)) + for u in range(K - 1, 0, -1): + k = u - 1 + Dk[:, :, k] = W_u_sm[:, :, k] @ A.T @ np.linalg.inv(W_p_sm[:, :, k + 1]) + Wku[:, :, k, u] = Dk[:, :, k] @ Wku[:, :, k + 1, u] + Wku[:, :, u, k] = Wku[:, :, k, u].T + + # Sufficient statistics + Sxkm1xk = np.zeros((Dx, Dx)) + Sxkm1xkm1 = np.zeros((Dx, Dx)) + Sxkxk = np.zeros((Dx, Dx)) + + for k in range(K): + if k == 0: + Sxkm1xk += Px0 @ A.T @ np.linalg.inv(W_p_sm[:, :, 0]) @ Wku[:, :, 0, 0] + Sxkm1xkm1 += Px0 + np.outer(x0, x0) + else: + Sxkm1xk += Wku[:, :, k - 1, k] + np.outer(x_K[:, k - 1], x_K[:, k]) + Sxkm1xkm1 += Wku[:, :, k - 1, k - 1] + np.outer(x_K[:, k - 1], x_K[:, k - 1]) + Sxkxk 
+= Wku[:, :, k, k] + np.outer(x_K[:, k], x_K[:, k]) + + Sxkxk = 0.5 * (Sxkxk + Sxkxk.T) + sumXkTerms = Sxkxk - A @ Sxkm1xk - Sxkm1xk.T @ A.T + A @ Sxkm1xkm1 @ A.T + Sxkxkm1 = Sxkm1xk.T + + # Point process log-likelihood + sumPPll = 0.0 + + if fitType == "poisson": + for k in range(K): + if HkAll.ndim == 3: + Hk = HkAll[k, :, :] # (nW, C) — need to handle orientation + if Hk.shape[0] == numCells: + Hk = Hk.T + else: + Hk = np.zeros((0, numCells)) + + xk = x_K[:, k] + gammaC = np.tile(gamma, numCells) if gamma.ndim == 0 or gamma.size == 1 else gamma + gammaC = np.atleast_2d(gammaC) + if gammaC.shape[0] == 1 and gammaC.shape[1] == 1: + gammaC = np.full((max(Hk.shape[0], 1), numCells), float(gamma.ravel()[0]) if gamma.size > 0 else 0.0) + + if Hk.ndim == 2 and Hk.shape[0] > 0 and gammaC.shape[0] == Hk.shape[0]: + hist_diag = np.diag(gammaC.T @ Hk) if Hk.shape[0] > 0 else np.zeros(numCells) + else: + hist_diag = np.zeros(numCells) + + terms = mu + beta.T @ xk + hist_diag + Wk = W_K[:, :, k] + ld = np.exp(np.clip(terms, -30, 30)) + bt = beta + ExplambdaDelta = ld + 0.5 * (ld * np.diag(bt.T @ Wk @ bt)) + ExplogLD = terms + sumPPll += float(np.sum(dN[:, k] * ExplogLD - ExplambdaDelta)) + + elif fitType == "binomial": + for k in range(K): + if HkAll.ndim == 3: + Hk = HkAll[k, :, :] + if Hk.shape[0] == numCells: + Hk = Hk.T + else: + Hk = np.zeros((0, numCells)) + + xk = x_K[:, k] + gammaC = np.tile(gamma, numCells) if gamma.ndim == 0 or gamma.size == 1 else gamma + gammaC = np.atleast_2d(gammaC) + if gammaC.shape[0] == 1 and gammaC.shape[1] == 1: + gammaC = np.full((max(Hk.shape[0], 1), numCells), float(gamma.ravel()[0]) if gamma.size > 0 else 0.0) + + if Hk.ndim == 2 and Hk.shape[0] > 0 and gammaC.shape[0] == Hk.shape[0]: + hist_diag = np.diag(gammaC.T @ Hk) if Hk.shape[0] > 0 else np.zeros(numCells) + else: + hist_diag = np.zeros(numCells) + + terms = mu + beta.T @ xk + hist_diag + Wk = W_K[:, :, k] + ld_raw = np.clip(terms, -30, 30) + ld = 1.0 / (1.0 + np.exp(-ld_raw)) + 
bt = beta + btWbt_diag = np.diag(bt.T @ Wk @ bt) + ExplambdaDelta = ld + 0.5 * (ld * (1 - ld) * (1 - 2 * ld)) * btWbt_diag + ExplogLD = np.log(np.maximum(ld, 1e-30)) + 0.5 * (-ld * (1 - ld)) * btWbt_diag + sumPPll += float(np.sum(dN[:, k] * ExplogLD - ExplambdaDelta)) + + det_Q = max(float(np.linalg.det(Q)), np.finfo(float).tiny) + det_Px0 = max(float(np.linalg.det(Px0)), np.finfo(float).tiny) + logll = ( + -Dx * K / 2.0 * np.log(2.0 * np.pi) + - K / 2.0 * np.log(det_Q) + - Dx / 2.0 * np.log(2.0 * np.pi) + - 0.5 * np.log(det_Px0) + + sumPPll + - 0.5 * np.trace(np.linalg.solve(Q, sumXkTerms)) + - Dx / 2.0 + ) + + ExpectationSums = { + "Sxkm1xkm1": Sxkm1xkm1, + "Sxkm1xk": Sxkm1xk, + "Sxkxkm1": Sxkxkm1, + "Sxkxk": Sxkxk, + "sumXkTerms": sumXkTerms, + "sumPPll": sumPPll, + } + + return x_K, W_K, logll, ExpectationSums + + @staticmethod + def PP_MStep( + dN, x_K, W_K, x0, Px0, ExpectationSums, fitType, + muhat, betahat, gammahat, windowTimes, HkAll, + PPEM_Constraints=None, MstepMethod="NewtonRaphson", + ): + """M-step for PP EM: update all model parameters. 
+ + Parameters + ---------- + dN : (C, N) spike observations + x_K : (dx, N) smoothed states + W_K : (dx, dx, N) smoothed covariances + x0 : (dx,) current initial state estimate + Px0 : (dx, dx) current initial covariance estimate + ExpectationSums : dict from E-step + fitType : 'poisson' or 'binomial' + muhat : (C,) current baseline rates + betahat : (dx, C) current stimulus coefficients + gammahat : scalar or (nW, C) current history coefficients + windowTimes : history window boundaries or None + HkAll : (N, nW, C) history tensor + PPEM_Constraints : dict from PP_EMCreateConstraints + MstepMethod : 'NewtonRaphson' (default) or 'GLM' + + Returns + ------- + Ahat, Qhat, muhat_new, betahat_new, gammahat_new, x0hat, Px0hat + """ + if PPEM_Constraints is None: + PPEM_Constraints = DecodingAlgorithms.PP_EMCreateConstraints() + + Sxkm1xkm1 = ExpectationSums["Sxkm1xkm1"] + Sxkxkm1 = ExpectationSums["Sxkxkm1"] + sumXkTerms = ExpectationSums["sumXkTerms"] + + dx, K = x_K.shape + numCells = dN.shape[0] + fitType = str(fitType).lower() + + x0 = np.asarray(x0, dtype=float).reshape(-1) + Px0 = np.atleast_2d(Px0).astype(float) + muhat = np.asarray(muhat, dtype=float).reshape(-1) + betahat = np.atleast_2d(betahat).astype(float) + gammahat = np.asarray(gammahat, dtype=float) + + # --- A update --- + I_dx = np.eye(dx) + if PPEM_Constraints["AhatDiag"]: + Ahat = (Sxkxkm1 * I_dx) @ np.linalg.inv(Sxkm1xkm1 * I_dx + 1e-12 * I_dx) + else: + Ahat = np.linalg.solve(Sxkm1xkm1.T + 1e-12 * I_dx, Sxkxkm1.T).T + + # --- Q update --- + if PPEM_Constraints["QhatDiag"]: + if PPEM_Constraints["QhatIsotropic"]: + Qhat = (1.0 / (dx * K)) * np.trace(sumXkTerms) * I_dx + else: + Qhat = (1.0 / K) * (sumXkTerms * I_dx) + Qhat = 0.5 * (Qhat + Qhat.T) + else: + Qhat = (1.0 / K) * sumXkTerms + Qhat = 0.5 * (Qhat + Qhat.T) + + # Ensure positive definiteness + eigvals, eigvecs = np.linalg.eigh(Qhat) + eigvals = np.maximum(eigvals, 1e-10) + Qhat = eigvecs @ np.diag(eigvals) @ eigvecs.T + Qhat = 0.5 * (Qhat + 
Qhat.T) + + # --- x0 update --- + if PPEM_Constraints["Estimatex0"]: + Px0inv = np.linalg.inv(Px0 + 1e-12 * I_dx) + Qinv = np.linalg.inv(Qhat + 1e-12 * I_dx) + x0hat = np.linalg.solve(Px0inv + Ahat.T @ Qinv @ Ahat, + Ahat.T @ Qinv @ x_K[:, 0] + Px0inv @ x0) + else: + x0hat = x0.copy() + + # --- Px0 update --- + if PPEM_Constraints["EstimatePx0"]: + if PPEM_Constraints["Px0Isotropic"]: + diff = x0hat - x0 + Px0hat = (np.dot(diff, diff) / (dx * K)) * I_dx + else: + diff = x0hat - x0 + Px0hat = np.outer(diff, diff) * I_dx + Px0hat = 0.5 * (Px0hat + Px0hat.T) + # Ensure positive definiteness + eigvals, eigvecs = np.linalg.eigh(Px0hat) + eigvals = np.maximum(eigvals, 1e-10) + Px0hat = eigvecs @ np.diag(eigvals) @ eigvecs.T + else: + Px0hat = Px0.copy() + + betahat_new = betahat.copy() + gammahat_new = gammahat.copy() if gammahat.ndim > 0 else np.atleast_1d(gammahat).copy() + muhat_new = muhat.copy() + + # --- Newton-Raphson for beta, mu, gamma --- + McExp = 50 + xKDrawExp = np.zeros((dx, K, McExp)) + diffTol = 1e-5 + + for k in range(K): + WuTemp = W_K[:, :, k] + try: + chol_m = np.linalg.cholesky(WuTemp).T + except np.linalg.LinAlgError: + eigv, eigvec = np.linalg.eigh(WuTemp) + eigv = np.maximum(eigv, 1e-12) + chol_m = np.linalg.cholesky(eigvec @ np.diag(eigv) @ eigvec.T).T + z = np.random.randn(dx, McExp) + xKDrawExp[:, k, :] = x_K[:, k:k + 1] + chol_m @ z + + # xkPerm: (dx, McExp, K) + xkPerm = np.transpose(xKDrawExp, (0, 2, 1)) + + # --- Beta Newton-Raphson --- + for c in range(numCells): + converged = False + maxIter_nr = 100 + for iteration in range(maxIter_nr): + HessianTerm = np.zeros((dx, dx)) + GradTerm = np.zeros(dx) + + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 0)) + Hk_vec = Hk_full[k, :] if Hk_full.ndim == 2 and Hk_full.shape[0] > k else np.zeros(0) + xk = xkPerm[:, :, k] # (dx, McExp) + + gammaC = gammahat if gammahat.ndim == 0 or gammahat.size == 1 else gammahat[:, c] + gammaC = np.atleast_1d(gammaC) + Hk_vec = 
np.atleast_1d(Hk_vec) + hist_term = float(gammaC @ Hk_vec) if Hk_vec.size == gammaC.size and gammaC.size > 0 else 0.0 + + terms = muhat[c] + betahat_new[:, c] @ xk + hist_term + + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + ExpLambdaXk = (1.0 / McExp) * np.sum(ld[None, :] * xk, axis=1) + ExpLambdaXkXkT = (1.0 / McExp) * (ld[None, :] * xk) @ xk.T + GradTerm += dN[c, k] * x_K[:, k] - ExpLambdaXk + HessianTerm -= ExpLambdaXkXkT + else: # binomial + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ExplambdaDeltaXkXk = (1.0 / McExp) * (ld[None, :] * xk) @ xk.T + ExplambdaDeltaSqXkXkT = (1.0 / McExp) * ((ld ** 2)[None, :] * xk) @ xk.T + ExplambdaDeltaCubeXkXkT = (1.0 / McExp) * ((ld ** 3)[None, :] * xk) @ xk.T + ExpLambdaXk = (1.0 / McExp) * np.sum(ld[None, :] * xk, axis=1) + ExpLambdaSquaredXk = (1.0 / McExp) * np.sum((ld ** 2)[None, :] * xk, axis=1) + GradTerm += dN[c, k] * x_K[:, k] - (dN[c, k] + 1) * ExpLambdaXk + ExpLambdaSquaredXk + HessianTerm += ExplambdaDeltaXkXk + ExplambdaDeltaSqXkXkT - 2 * ExplambdaDeltaCubeXkXkT + + if np.any(np.isnan(HessianTerm)) or np.any(np.isinf(HessianTerm)): + betahat_newTemp = betahat_new[:, c] + else: + try: + betahat_newTemp = betahat_new[:, c] - np.linalg.solve(HessianTerm, GradTerm) + except np.linalg.LinAlgError: + betahat_newTemp = betahat_new[:, c] + if np.any(np.isnan(betahat_newTemp)): + betahat_newTemp = betahat_new[:, c] + + mabsDiff = float(np.max(np.abs(betahat_newTemp - betahat_new[:, c]))) + if mabsDiff < diffTol: + converged = True + betahat_new[:, c] = betahat_newTemp + if converged: + break + + # --- Mu Newton-Raphson --- + for c in range(numCells): + converged = False + maxIter_nr = 100 + for iteration in range(maxIter_nr): + HessianTerm = 0.0 + GradTerm = 0.0 + + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 0)) + Hk_vec = Hk_full[k, :] if Hk_full.ndim == 2 and Hk_full.shape[0] > k else np.zeros(0) + xk = xkPerm[:, :, k] + + gammaC = gammahat if 
gammahat.ndim == 0 or gammahat.size == 1 else gammahat[:, c] + gammaC = np.atleast_1d(gammaC) + Hk_vec = np.atleast_1d(Hk_vec) + hist_term = float(gammaC @ Hk_vec) if Hk_vec.size == gammaC.size and gammaC.size > 0 else 0.0 + + terms = muhat_new[c] + betahat[:, c] @ xk + hist_term + + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + ExpLambdaDelta = (1.0 / McExp) * np.sum(ld) + GradTerm += dN[c, k] - ExpLambdaDelta + HessianTerm -= ExpLambdaDelta + else: # binomial + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ExpLambdaDelta = (1.0 / McExp) * np.sum(ld) + ExpLambdaDeltaSq = (1.0 / McExp) * np.sum(ld ** 2) + ExpLambdaDeltaCubed = (1.0 / McExp) * np.sum(ld ** 3) + GradTerm += dN[c, k] - (dN[c, k] + 1) * ExpLambdaDelta + ExpLambdaDeltaSq + HessianTerm += -(dN[c, k] + 1) * ExpLambdaDelta + (dN[c, k] + 3) * ExpLambdaDeltaSq - 2 * ExpLambdaDeltaCubed + + if np.isnan(HessianTerm) or np.isinf(HessianTerm) or abs(HessianTerm) < 1e-30: + muhat_newTemp = muhat_new[c] + else: + muhat_newTemp = muhat_new[c] - GradTerm / HessianTerm + if np.isnan(muhat_newTemp): + muhat_newTemp = muhat_new[c] + + mabsDiff = abs(muhat_newTemp - muhat_new[c]) + if mabsDiff < diffTol: + converged = True + muhat_new[c] = muhat_newTemp + if converged: + break + + # --- Gamma Newton-Raphson --- + gammahat_flat = gammahat_new.ravel() + has_gamma = (windowTimes is not None and len(windowTimes) > 0 + and (gammahat_flat.size > 1 or (gammahat_flat.size == 1 and gammahat_flat[0] != 0))) + + if has_gamma and gammahat_new.ndim >= 1: + nGamma = gammahat_new.shape[0] # number of history windows + for c in range(numCells): + converged = False + maxIter_nr = 100 + gammaC = gammahat_new if gammahat_new.ndim == 0 or gammahat_new.size == 1 else gammahat_new[:, c] if gammahat_new.ndim == 2 else gammahat_new + gammaC = np.atleast_1d(gammaC).copy() + + for iteration in range(maxIter_nr): + HessianTerm = np.zeros((nGamma, nGamma)) + GradTerm = np.zeros(nGamma) + + for k
in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 0)) + Hk_vec = Hk_full[k, :] if Hk_full.ndim == 2 and Hk_full.shape[0] > k else np.zeros(0) + Hk_vec = np.atleast_1d(Hk_vec) + xk = xkPerm[:, :, k] + + hist_term = float(gammaC @ Hk_vec) if Hk_vec.size == gammaC.size and gammaC.size > 0 else 0.0 + terms = muhat[c] + betahat[:, c] @ xk + hist_term + + if fitType == "poisson": + ld = np.exp(np.clip(terms, -30, 30)) + ExpLambdaDelta = (1.0 / McExp) * np.sum(ld) + GradTerm += (dN[c, k] - ExpLambdaDelta) * Hk_vec + HessianTerm -= ExpLambdaDelta * np.outer(Hk_vec, Hk_vec) + else: # binomial + ld = 1.0 / (1.0 + np.exp(-np.clip(terms, -30, 30))) + ExpLambdaDelta = (1.0 / McExp) * np.sum(ld) + ExpLambdaDeltaSq = (1.0 / McExp) * np.sum(ld ** 2) + ExpLambdaDeltaCubed = (1.0 / McExp) * np.sum(ld ** 3) + GradTerm += (dN[c, k] - (dN[c, k] + 1) * ExpLambdaDelta + ExpLambdaDeltaSq) * Hk_vec + HessianTerm += (-(dN[c, k] + 1) * ExpLambdaDelta + (dN[c, k] + 3) * ExpLambdaDeltaSq - 2 * ExpLambdaDeltaCubed) * np.outer(Hk_vec, Hk_vec) + + if np.any(np.isnan(HessianTerm)) or np.any(np.isinf(HessianTerm)): + gammahat_newTemp = gammaC + else: + try: + gammahat_newTemp = gammaC - np.linalg.solve(HessianTerm, GradTerm) + except np.linalg.LinAlgError: + gammahat_newTemp = gammaC + if np.any(np.isnan(gammahat_newTemp)): + gammahat_newTemp = gammaC + + mabsDiff = float(np.max(np.abs(gammahat_newTemp - gammaC))) + if mabsDiff < diffTol: + converged = True + gammaC = gammahat_newTemp + if converged: + break + + if gammahat_new.ndim == 2: + gammahat_new[:, c] = gammaC + else: + gammahat_new = gammaC + + return Ahat, Qhat, muhat_new, betahat_new, gammahat_new, x0hat, Px0hat + + @staticmethod + def PP_EM( + dN, + Ahat0, + Qhat0, + mu, + beta, + fitType="poisson", + delta=0.001, + gamma=None, + windowTimes=None, + x0=None, + Px0=None, + PPEM_Constraints=None, + MstepMethod="NewtonRaphson", + ): + """Full Point-Process state-space EM algorithm. 
+ + Estimates state-space model parameters (A, Q, mu, beta, gamma) via EM + for point-process observations. Unlike PPSS_EM, this operates on raw + spike observations with explicit beta/mu/gamma parameters (no basis + functions). + + Parameters + ---------- + dN : (C, N) binary spike observations + Ahat0 : (dx, dx) initial state transition matrix + Qhat0 : (dx, dx) initial state noise covariance + mu : (C,) initial baseline log-rates + beta : (dx, C) initial stimulus coefficients + fitType : 'poisson' or 'binomial' + delta : float, time bin width + gamma : (nW, C) or scalar, initial history coefficients + windowTimes : history window boundaries + x0 : (dx,) initial state (default zeros) + Px0 : (dx, dx) initial state covariance + PPEM_Constraints : dict from PP_EMCreateConstraints + MstepMethod : 'NewtonRaphson' or 'GLM' + + Returns + ------- + xKFinal, WKFinal, Ahat, Qhat, muhat, betahat, gammahat, + x0hat, Px0hat, IC, SE, Pvals, nIter + """ + from .history import History # local import to avoid circular dependency + + Ahat0 = np.atleast_2d(Ahat0).astype(float) + Qhat0 = np.atleast_2d(Qhat0).astype(float) + numStates = Ahat0.shape[0] + dN = np.atleast_2d(dN).astype(float) + + if PPEM_Constraints is None: + PPEM_Constraints = DecodingAlgorithms.PP_EMCreateConstraints() + if Px0 is None: + Px0 = 1e-9 * np.eye(numStates) + else: + Px0 = np.atleast_2d(Px0).astype(float) + if x0 is None: + x0 = np.zeros(numStates) + else: + x0 = np.asarray(x0, dtype=float).reshape(-1) + if gamma is None: + gamma = np.zeros(0) + gamma = np.asarray(gamma, dtype=float) + + if delta is None or delta == 0: + delta = 0.001 + + if windowTimes is None: + gamma_flat = gamma.ravel() + if gamma_flat.size == 0 or (gamma_flat.size == 1 and gamma_flat[0] == 0): + windowTimes = [] + else: + windowTimes = np.arange(0, (gamma.shape[0] + 2) * delta, delta).tolist() + + mu = np.asarray(mu, dtype=float).reshape(-1) + beta = np.atleast_2d(beta).astype(float) + + # Build HkAll from spike trains and history 
windows + K_cells = dN.shape[0] + N_time = dN.shape[1] + minTime = 0.0 + maxTime = (N_time - 1) * delta + + if len(windowTimes) > 0: + histObj = History(windowTimes, minTime, maxTime) + HkAll_list = [] + for k in range(K_cells): + spike_indices = np.where(dN[k, :] == 1)[0] + spike_times = (spike_indices) * delta + nst = nspikeTrain(spike_times) + nst.setMinTime(minTime) + nst.setMaxTime(maxTime) + hmat = histObj.computeHistory(nst).dataToMatrix() + HkAll_list.append(hmat) + # Stack: (N_time, nW, K_cells) + HkAll = np.stack(HkAll_list, axis=2) + else: + HkAll = np.zeros((N_time, 0, K_cells)) + gamma = np.zeros(1) + gamma[0] = 0.0 + + # EM setup + tolAbs = 1e-3 + llTol = 1e-3 + maxIter = 100 + numToKeep = 10 + + # Circular buffer storage + A_buf = [None] * numToKeep + Q_buf = [None] * numToKeep + x0_buf = [None] * numToKeep + Px0_buf = [None] * numToKeep + mu_buf = [None] * numToKeep + beta_buf = [None] * numToKeep + gamma_buf = [None] * numToKeep + x_K_buf = [None] * numToKeep + W_K_buf = [None] * numToKeep + ExpSums_buf = [None] * numToKeep + + # Scaled system initialization + A0 = Ahat0.copy() + Q0 = Qhat0.copy() + + A_buf[0] = A0.copy() + Q_buf[0] = Q0.copy() + x0_buf[0] = x0.copy() + Px0_buf[0] = Px0.copy() + mu_buf[0] = mu.copy() + beta_buf[0] = beta.copy() + gamma_buf[0] = gamma.copy() + + # Apply scaling + try: + Tq = np.linalg.solve(np.linalg.cholesky(Q_buf[0]).T, np.eye(numStates)) + except np.linalg.LinAlgError: + Tq = np.eye(numStates) + TqInv = np.linalg.inv(Tq) + + A_buf[0] = Tq @ A_buf[0] @ TqInv + Q_buf[0] = Tq @ Q_buf[0] @ Tq.T + x0_buf[0] = Tq @ x0 + Px0_buf[0] = Tq @ Px0 @ Tq.T + beta_buf[0] = np.linalg.solve(Tq.T, beta_buf[0]) + + ll_list = [] + dLikelihood = [np.inf] + stoppingCriteria = False + cnt = 0 + + print(" Point-Process Observation EM Algorithm ") + while not stoppingCriteria and cnt < maxIter: + si = cnt % numToKeep + si_p1 = (cnt + 1) % numToKeep + si_m1 = (cnt - 1) % numToKeep + + print("-" * 80) + print(f"Iteration #{cnt + 1}") + 
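+ # Each iteration: the E-step (PP_EStep) yields the smoothed states x_K, + # covariances W_K, the log-likelihood, and expectation sums; the M-step + # (PP_MStep) updates A/Q/x0/Px0 in closed form and mu/beta/gamma by + # Newton-Raphson. The loop stops on a small maximum parameter change, a + # small log-likelihood change, or a log-likelihood decrease.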
print("-" * 80) + + # E-step + x_K_cur, W_K_cur, ll, ExpSums = DecodingAlgorithms.PP_EStep( + A_buf[si], Q_buf[si], dN, mu_buf[si], beta_buf[si], + fitType, gamma_buf[si], HkAll, x0_buf[si], Px0_buf[si] + ) + x_K_buf[si] = x_K_cur + W_K_buf[si] = W_K_cur + ExpSums_buf[si] = ExpSums + ll_list.append(ll) + + # M-step + Anew, Qnew, munew, bnew, gnew, x0new, Px0new = DecodingAlgorithms.PP_MStep( + dN, x_K_cur, W_K_cur, x0_buf[si], Px0_buf[si], ExpSums, + fitType, mu_buf[si], beta_buf[si], gamma_buf[si], + windowTimes, HkAll, PPEM_Constraints, MstepMethod + ) + A_buf[si_p1] = Anew + Q_buf[si_p1] = Qnew + mu_buf[si_p1] = munew + beta_buf[si_p1] = bnew + gamma_buf[si_p1] = gnew + x0_buf[si_p1] = x0new + Px0_buf[si_p1] = Px0new + + if not PPEM_Constraints["EstimateA"]: + A_buf[si_p1] = A_buf[si] + + # Convergence check + if cnt == 0: + dLikelihood.append(np.inf) + dMax = np.inf + else: + dLikelihood.append(ll_list[cnt] - ll_list[cnt - 1]) + dQvals = float(np.max(np.abs(np.sqrt(np.maximum(np.abs(Q_buf[si]), 0)) - np.sqrt(np.maximum(np.abs(Q_buf[si_m1]), 0))))) + dAvals = float(np.max(np.abs(A_buf[si] - A_buf[si_m1]))) + dMuvals = float(np.max(np.abs(mu_buf[si] - mu_buf[si_m1]))) + dBetavals = float(np.max(np.abs(beta_buf[si] - beta_buf[si_m1]))) + gam_cur = gamma_buf[si].ravel() if gamma_buf[si] is not None else np.zeros(1) + gam_prev = gamma_buf[si_m1].ravel() if gamma_buf[si_m1] is not None else np.zeros(1) + dGammavals = float(np.max(np.abs(gam_cur[:min(len(gam_cur), len(gam_prev))] - gam_prev[:min(len(gam_cur), len(gam_prev))]))) if gam_cur.size > 0 else 0.0 + dMax = max(dQvals, dAvals, dMuvals, dBetavals, dGammavals) + + if cnt == 0: + print("Max Parameter Change: N/A") + else: + print(f"Max Parameter Change: {dMax:.6f}") + + cnt += 1 + if dMax < tolAbs: + stoppingCriteria = True + print(f" EM converged at iteration# {cnt} b/c change in params was within criteria") + + if abs(dLikelihood[-1]) < llTol or dLikelihood[-1] < 0: + stoppingCriteria = True + print(f" EM 
stopped at iteration# {cnt} b/c change in likelihood was negative") + + print("-" * 80) + + # Select best iteration + ll_arr = np.array(ll_list) + if ll_arr.size > 0: + maxLLIndex = int(np.argmax(ll_arr)) + else: + maxLLIndex = 0 + maxLLIndMod = maxLLIndex % numToKeep + nIter = cnt + + xKFinal = x_K_buf[maxLLIndMod] if x_K_buf[maxLLIndMod] is not None else np.zeros((numStates, N_time)) + WKFinal = W_K_buf[maxLLIndMod] if W_K_buf[maxLLIndMod] is not None else np.zeros((numStates, numStates, N_time)) + Ahat = A_buf[maxLLIndMod] if A_buf[maxLLIndMod] is not None else A0 + Qhat = Q_buf[maxLLIndMod] if Q_buf[maxLLIndMod] is not None else Q0 + muhat = mu_buf[maxLLIndMod] if mu_buf[maxLLIndMod] is not None else mu + betahat = beta_buf[maxLLIndMod] if beta_buf[maxLLIndMod] is not None else beta + gammahat = gamma_buf[maxLLIndMod] if gamma_buf[maxLLIndMod] is not None else gamma + x0hat = x0_buf[maxLLIndMod] if x0_buf[maxLLIndMod] is not None else x0 + Px0hat = Px0_buf[maxLLIndMod] if Px0_buf[maxLLIndMod] is not None else Px0 + ExpSumsFinal = ExpSums_buf[maxLLIndMod] if ExpSums_buf[maxLLIndMod] is not None else {} + + # Unscale system + try: + Tq_unscale = np.linalg.solve(np.linalg.cholesky(Q0).T, np.eye(numStates)) + except np.linalg.LinAlgError: + Tq_unscale = np.eye(numStates) + TqInv_unscale = np.linalg.inv(Tq_unscale) + + Ahat = TqInv_unscale @ Ahat @ Tq_unscale + Qhat = TqInv_unscale @ Qhat @ TqInv_unscale.T + xKFinal = TqInv_unscale @ xKFinal + x0hat = TqInv_unscale @ x0hat + Px0hat = TqInv_unscale @ Px0hat @ TqInv_unscale.T + if WKFinal.ndim == 3: + for kk in range(WKFinal.shape[2]): + WKFinal[:, :, kk] = TqInv_unscale @ WKFinal[:, :, kk] @ TqInv_unscale.T + betahat = (betahat.T @ Tq_unscale).T + + # Compute standard errors + SE = {} + Pvals = {} + if ExpSumsFinal: + try: + SE, Pvals, _ = DecodingAlgorithms.PP_ComputeParamStandardErrors( + dN, xKFinal, WKFinal, Ahat, Qhat, x0hat, Px0hat, + ExpSumsFinal, fitType, muhat, betahat, gammahat, + windowTimes, HkAll, 
PPEM_Constraints + ) + except Exception: + pass + + # Information criteria + K_total = xKFinal.shape[1] + Dx = Ahat.shape[1] + + # Count parameters + if PPEM_Constraints["EstimateA"] and PPEM_Constraints["AhatDiag"]: + n1_ic = Ahat.shape[0] + elif PPEM_Constraints["EstimateA"]: + n1_ic = Ahat.size + else: + n1_ic = 0 + if PPEM_Constraints["QhatDiag"] and PPEM_Constraints["QhatIsotropic"]: + n2_ic = 1 + elif PPEM_Constraints["QhatDiag"]: + n2_ic = Qhat.shape[0] + else: + n2_ic = Qhat.size + if PPEM_Constraints["EstimatePx0"] and PPEM_Constraints["Px0Isotropic"]: + n3_ic = 1 + elif PPEM_Constraints["EstimatePx0"]: + n3_ic = Px0hat.shape[0] + else: + n3_ic = 0 + n4_ic = x0hat.size if PPEM_Constraints["Estimatex0"] else 0 + n5_ic = muhat.size + n6_ic = betahat.size + gammahat_flat = gammahat.ravel() + if gammahat_flat.size == 1 and gammahat_flat[0] == 0: + n7_ic = 0 + else: + n7_ic = gammahat.size + nTerms_ic = n1_ic + n2_ic + n3_ic + n4_ic + n5_ic + n6_ic + n7_ic + + sumXkTerms_ic = ExpSumsFinal.get("sumXkTerms", np.zeros((Dx, Dx))) + ll_best = ll_list[maxLLIndex] if ll_list else 0.0 + det_Q = max(float(np.linalg.det(Qhat)), np.finfo(float).tiny) + det_Px0 = max(float(np.linalg.det(Px0hat)), np.finfo(float).tiny) + + llobs = (ll_best + + Dx * K_total / 2.0 * np.log(2.0 * np.pi) + + K_total / 2.0 * np.log(det_Q) + + 0.5 * np.trace(np.linalg.solve(Qhat, sumXkTerms_ic)) + + Dx / 2.0 * np.log(2.0 * np.pi) + + 0.5 * np.log(det_Px0) + + 0.5 * Dx) + AIC = 2 * nTerms_ic - 2 * llobs + AICc = AIC + 2 * nTerms_ic * (nTerms_ic + 1) / max(K_total - nTerms_ic - 1, 1) + BIC = -2 * llobs + nTerms_ic * np.log(K_total) + + IC = { + "AIC": AIC, + "AICc": AICc, + "BIC": BIC, + "llobs": llobs, + "llcomp": ll_best, + } + + return (xKFinal, WKFinal, Ahat, Qhat, muhat, betahat, gammahat, + x0hat, Px0hat, IC, SE, Pvals, nIter) + + + # mPPCO family -- mixed Point-Process & Continuous Observation + # ------------------------------------------------------------------ + + @staticmethod + def 
mPPCODecode_predict(x_u, W_u, A, Q): + """Predict step for the mPPCO filter. + + Matlab: ``DecodingAlgorithms.mPPCODecode_predict`` (lines 4846-4854) + + Parameters + ---------- + x_u : array (ns,) -- updated state + W_u : array (ns,ns) -- updated covariance + A : array (ns,ns) -- state transition + Q : array (ns,ns) -- process noise + + Returns + ------- + x_p : array (ns,) + W_p : array (ns,ns) + """ + x_u = np.asarray(x_u, dtype=float).reshape(-1) + ns = x_u.size + A = np.asarray(A, dtype=float).reshape(ns, ns) + Q = np.asarray(Q, dtype=float).reshape(ns, ns) + W_u = np.asarray(W_u, dtype=float).reshape(ns, ns) + x_p = A @ x_u + W_p = A @ W_u @ A.T + Q + W_p = _symmetrize(W_p) + return x_p, W_p + + @staticmethod + def mPPCODecode_update(x_p, W_p, C, R, y, alpha, dN, mu, beta, + fitType='poisson', gamma=None, HkAll=None, + time_index=1, WuConv=None): + """Update step for the mPPCO filter (PP + continuous observation). + + Matlab: ``DecodingAlgorithms.mPPCODecode_update`` (lines 4855-4944) + + This combines both the point-process update terms (sumValVec/sumValMat) + AND the Kalman/continuous-observation terms C'*R^{-1}*C and C'*R^{-1}*(y-Cx-alpha). 
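+ + In information form this update is equivalent to + ``W_u = (W_p^{-1} + sumValMat + C'*R^{-1}*C)^{-1}`` and + ``x_u = x_p + W_u*(sumValVec + C'*R^{-1}*(y - C*x_p - alpha))``; + the code computes W_u via the push-through identity + ``W_p @ (I + sumValMat @ W_p)^{-1}`` to avoid explicitly inverting W_p.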
+ + Parameters + ---------- + x_p : (ns,) -- predicted state + W_p : (ns,ns) -- predicted covariance + C : (nObs,ns) -- observation matrix + R : (nObs,nObs) -- observation noise covariance + y : (nObs,) -- continuous observation at this time step + alpha : (nObs,) -- observation offset + dN : (numCells,N) -- spike matrix (full) + mu : (numCells,) -- CIF baseline + beta : (ns,numCells) -- CIF state coefficients + fitType : 'poisson' or 'binomial' + gamma : (numWindows,numCells) or scalar -- history coefficients + HkAll : (numWindows,numCells,N) -- permuted history tensor with time on 3rd axis + time_index : int -- 1-based time index + WuConv : converged covariance or None + + Returns + ------- + x_u : (ns,) + W_u : (ns,ns) + lambdaDeltaMat : (numCells,1) + """ + x_p = np.asarray(x_p, dtype=float).reshape(-1) + ns = x_p.size + W_p = np.asarray(W_p, dtype=float).reshape(ns, ns) + obs = _as_observation_matrix(dN) + numCells = obs.shape[0] + C = np.asarray(C, dtype=float) + R = np.asarray(R, dtype=float) + y = np.asarray(y, dtype=float).reshape(-1) + alpha = np.asarray(alpha, dtype=float).reshape(-1) + mu_vec = np.asarray(mu, dtype=float).reshape(-1) + beta_mat = np.asarray(beta, dtype=float) + if beta_mat.ndim == 1: + beta_mat = beta_mat.reshape(-1, 1) + + # Default gamma + if gamma is None or (np.isscalar(gamma) and gamma == 0): + gamma_mat = np.zeros((1, numCells), dtype=float) + else: + gamma_mat = np.asarray(gamma, dtype=float) + if gamma_mat.ndim == 1: + gamma_mat = gamma_mat.reshape(-1, 1) + + # Default HkAll -- expects (numWindows, numCells, N) orientation + if HkAll is None: + HkAll_arr = np.zeros((1, numCells, 1), dtype=float) + else: + HkAll_arr = np.asarray(HkAll, dtype=float) + + sumValVec = np.zeros(ns, dtype=float) + sumValMat = np.zeros((ns, ns), dtype=float) + lambdaDeltaMat = np.zeros(numCells, dtype=float) + + # If gamma is scalar zero, expand + if gamma_mat.size == 1 and gamma_mat.flat[0] == 0: + gamma_mat = np.zeros_like(mu_vec).reshape(-1, 1) + + # 
Ensure gamma_mat is (numWindows, numCells) + if gamma_mat.shape[1] != numCells: + if gamma_mat.shape[0] == numCells: + gamma_mat = gamma_mat.T + + # Replicate gamma for all cells if needed + if gamma_mat.ndim == 2 and gamma_mat.shape[1] != numCells: + gamma_mat = np.tile(gamma_mat, (1, numCells)) + + # time_index is 1-based; extract history at this time + tidx = int(time_index) - 1 # zero-based + if HkAll_arr.ndim == 3 and HkAll_arr.shape[2] > tidx: + Histterm = HkAll_arr[:, :, tidx] # (numWindows, numCells) + else: + Histterm = np.zeros((gamma_mat.shape[0], numCells), dtype=float) + + # Transpose history that arrived as (numCells, numWindows) + if Histterm.shape[0] == numCells and Histterm.shape[1] != numCells: + Histterm = Histterm.T + + if str(fitType) == 'binomial': + # linTerm = mu + beta'*x_p + diag(gamma'*Histterm') + linTerm = mu_vec + beta_mat.T @ x_p + np.diag(gamma_mat.T @ Histterm) + exp_linTerm = np.exp(np.clip(linTerm, -500, 500)) + lambdaDeltaMat = exp_linTerm / (1.0 + exp_linTerm) + lambdaDeltaMat = np.where(np.isnan(lambdaDeltaMat) | np.isinf(lambdaDeltaMat), 1.0, lambdaDeltaMat) + + dN_t = obs[:, int(time_index) - 1] + factor = (dN_t - lambdaDeltaMat) * (1.0 - lambdaDeltaMat) + sumValVec = np.sum(beta_mat * factor[None, :], axis=1) + tempVec = (dN_t + (1.0 - 2.0 * lambdaDeltaMat)) * (1.0 - lambdaDeltaMat) * lambdaDeltaMat + sumValMat = (beta_mat * tempVec[None, :]) @ beta_mat.T + + elif str(fitType) == 'poisson': + linTerm = mu_vec + beta_mat.T @ x_p + np.diag(gamma_mat.T @ Histterm) + lambdaDeltaMat = np.exp(np.clip(linTerm, -500, 500)) + lambdaDeltaMat = np.where(np.isnan(lambdaDeltaMat) | np.isinf(lambdaDeltaMat), 1.0, lambdaDeltaMat) + + dN_t = obs[:, int(time_index) - 1] + sumValVec = np.sum(beta_mat * (dN_t - lambdaDeltaMat)[None, :], axis=1) + sumValMat = (beta_mat * lambdaDeltaMat[None, :]) @ beta_mat.T + + if WuConv is None or _is_empty_value(WuConv): + # sumValMat += C' * R^{-1} * C (continuous
observation term) + sumValMat = sumValMat + C.T @ np.linalg.solve(R, C) + I = np.eye(ns, dtype=float) + try: + Wu = W_p @ (I - np.linalg.solve(I + sumValMat @ W_p, sumValMat @ W_p)) + except np.linalg.LinAlgError: + Wu = W_p.copy() + if np.any(np.isnan(Wu)) or np.any(np.isinf(Wu)): + Wu = W_p.copy() + W_u = _symmetrize(Wu) + else: + W_u = np.asarray(WuConv, dtype=float).reshape(ns, ns) + + # x_u = x_p + W_u*sumValVec + (W_u*C'/R)*(y - C*x_p - alpha) + x_u = x_p + W_u @ sumValVec + W_u @ C.T @ np.linalg.solve(R, y - C @ x_p - alpha) + + return x_u, W_u, lambdaDeltaMat.reshape(-1, 1) + + @staticmethod + def mPPCODecodeLinear(A, Q, C, R, y, alpha, dN, mu, beta, + fitType='poisson', delta=0.001, gamma=None, + windowTimes=None, x0=None, Px0=None, HkAll=None): + """Full mPPCO decode filter (linear CIF version). + + Matlab: ``DecodingAlgorithms.mPPCODecodeLinear`` (lines 4689-4845) + + Returns + ------- + x_p, W_p, x_u, W_u -- predicted / updated states & covariances + x_p : (ns, N+1), W_p : (ns, ns, N+1) + x_u : (ns, N), W_u : (ns, ns, N) + """ + obs = _as_observation_matrix(dN) + numCells, N = obs.shape + A_arr = np.asarray(A, dtype=float) + ns = A_arr.shape[0] + + # Defaults + if Px0 is None or _is_empty_value(Px0): + Px0 = np.zeros((ns, ns), dtype=float) + else: + Px0 = np.asarray(Px0, dtype=float).reshape(ns, ns) + if x0 is None or _is_empty_value(x0): + x0 = np.zeros(ns, dtype=float) + else: + x0 = np.asarray(x0, dtype=float).reshape(-1) + if gamma is None: + gamma = 0 + if delta is None: + delta = 0.001 + + minTime = 0.0 + maxTime = (N - 1) * delta + + # Build history tensor if not provided + if HkAll is None or _is_empty_value(HkAll): + if windowTimes is not None and not _is_empty_value(windowTimes): + wt = np.asarray(windowTimes, dtype=float).reshape(-1) + HkAll = _compute_history_terms(dN, delta, wt) # (N, numWindows, numCells) + gamma_arr = np.asarray(gamma, dtype=float) + if gamma_arr.ndim <= 1 and gamma_arr.size == 1 and numCells > 1: + gamma = 
np.tile(gamma_arr.reshape(-1, 1), (1, numCells)) + else: + HkAll = np.zeros((N, 1, numCells), dtype=float) + gamma = np.zeros(numCells, dtype=float) + else: + HkAll = np.asarray(HkAll, dtype=float) + + gamma_arr = np.asarray(gamma, dtype=float) + if gamma_arr.ndim == 2 and gamma_arr.shape[1] != numCells: + gamma = gamma_arr.T + + # Permute HkAll from (N, numWindows, numCells) to (numWindows, numCells, N) + # This is Matlab: permute(HkAll, [2 3 1]) + if HkAll.ndim == 3 and HkAll.shape[0] == N: + Histtermperm = np.transpose(HkAll, (1, 2, 0)) + else: + Histtermperm = HkAll + + mu_vec = np.asarray(mu, dtype=float).reshape(-1) + beta_mat = np.asarray(beta, dtype=float) + if beta_mat.ndim == 1: + beta_mat = beta_mat.reshape(-1, 1) + + # Allocate outputs + x_p = np.zeros((ns, N + 1), dtype=float) + x_u = np.zeros((ns, N), dtype=float) + W_p = np.zeros((ns, ns, N + 1), dtype=float) + W_u = np.zeros((ns, ns, N), dtype=float) + + # Time-varying or static matrices: pick slice for time 0 + def _sel_A(n): + if A_arr.ndim == 3: + return A_arr[:, :, min(n, A_arr.shape[2] - 1)] + return A_arr.reshape(ns, ns) + + def _sel_Q(n): + Q_arr = np.asarray(Q, dtype=float) + if Q_arr.ndim == 3: + return Q_arr[:, :, min(n, Q_arr.shape[2] - 1)] + return Q_arr.reshape(ns, ns) + + def _sel_C(n): + C_arr = np.asarray(C, dtype=float) + if C_arr.ndim == 3: + return C_arr[:, :, min(n, C_arr.shape[2] - 1)] + return C_arr + + def _sel_R(n): + R_arr = np.asarray(R, dtype=float) + if R_arr.ndim == 3: + return R_arr[:, :, min(n, R_arr.shape[2] - 1)] + return R_arr + + def _sel_alpha(n): + alpha_arr = np.asarray(alpha, dtype=float) + if alpha_arr.ndim >= 2 and alpha_arr.shape[-1] > 1: + return alpha_arr[:, min(n, alpha_arr.shape[-1] - 1)] + return alpha_arr.reshape(-1) + + # Initial prediction + A1 = _sel_A(0) + Q1 = _sel_Q(0) + x_p[:, 0] = A1 @ x0 + W_p[:, :, 0] = A1 @ Px0 @ A1.T + Q1 + + y_arr = np.asarray(y, dtype=float) + + for n in range(N): + # 1-based time_index for mPPCODecode_update + x_u[:, n], 
W_u[:, :, n], _ = DecodingAlgorithms.mPPCODecode_update( + x_p[:, n], W_p[:, :, n], + _sel_C(n), _sel_R(n), + y_arr[:, n] if y_arr.ndim == 2 else y_arr, + _sel_alpha(n), + dN, mu_vec, beta_mat, fitType, + gamma, Histtermperm, n + 1, None) + if n < N - 1: + x_p[:, n + 1], W_p[:, :, n + 1] = DecodingAlgorithms.mPPCODecode_predict( + x_u[:, n], W_u[:, :, n], _sel_A(n), _sel_Q(n)) + + return x_p, W_p, x_u, W_u + + @staticmethod + def mPPCO_fixedIntervalSmoother(A, Q, C, R, y, alpha, dN, lags, mu, beta, + fitType, delta=0.001, gamma=None, + windowTimes=None, x0=None, Px0=None, HkAll=None): + """State-augmentation smoother for the mPPCO filter. + + Matlab: ``DecodingAlgorithms.mPPCO_fixedIntervalSmoother`` (lines 4587-4688) + + Returns + ------- + x_pLag, W_pLag, x_uLag, W_uLag -- lagged state estimates + """ + obs = _as_observation_matrix(dN) + numCells, N = obs.shape + A_arr = np.asarray(A, dtype=float) + ns = A_arr.shape[0] + nObs = np.asarray(C, dtype=float).shape[0] + + if Px0 is None or _is_empty_value(Px0): + Px0 = np.zeros((ns, ns), dtype=float) + else: + Px0 = np.asarray(Px0, dtype=float).reshape(ns, ns) + if x0 is None or _is_empty_value(x0): + x0 = np.zeros(ns, dtype=float) + else: + x0 = np.asarray(x0, dtype=float).reshape(-1) + if gamma is None: + gamma = 0 + if delta is None: + delta = 0.001 + + minTime = 0.0 + maxTime = (N - 1) * delta + + # Build history if needed + if HkAll is None or _is_empty_value(HkAll): + if windowTimes is not None and not _is_empty_value(windowTimes): + wt = np.asarray(windowTimes, dtype=float).reshape(-1) + HkAll = _compute_history_terms(dN, delta, wt) + gamma_arr = np.asarray(gamma, dtype=float) + if gamma_arr.ndim <= 1 and gamma_arr.size == 1 and numCells > 1: + gamma = np.tile(gamma_arr.reshape(-1, 1), (1, numCells)) + else: + HkAll = np.zeros((N, 1, numCells), dtype=float) + gamma = np.zeros(numCells, dtype=float) + + gamma_arr = np.asarray(gamma, dtype=float) + if gamma_arr.ndim == 2 and gamma_arr.shape[1] != numCells: + 
gamma = gamma_arr.T + + lags = int(lags) + nStates = ns + + # Build augmented system + aug_dim = (lags + 1) * nStates + + def _sel_A(n): + if A_arr.ndim == 3: + return A_arr[:, :, min(n, A_arr.shape[2] - 1)] + return A_arr.reshape(ns, ns) + + def _sel_Q(n): + Q_arr = np.asarray(Q, dtype=float) + if Q_arr.ndim == 3: + return Q_arr[:, :, min(n, Q_arr.shape[2] - 1)] + return Q_arr.reshape(ns, ns) + + def _sel_C(n): + C_arr = np.asarray(C, dtype=float) + if C_arr.ndim == 3: + return C_arr[:, :, min(n, C_arr.shape[2] - 1)] + return C_arr + + def _sel_R(n): + R_arr = np.asarray(R, dtype=float) + if R_arr.ndim == 3: + return R_arr[:, :, min(n, R_arr.shape[2] - 1)] + return R_arr + + Alag = np.zeros((aug_dim, aug_dim, N), dtype=float) + Qlag = np.zeros((aug_dim, aug_dim, N), dtype=float) + Clag = np.zeros((nObs, aug_dim, N), dtype=float) + Rlag = np.zeros((nObs, nObs, N), dtype=float) + x0lag = np.zeros(aug_dim, dtype=float) + Px0lag = np.zeros((aug_dim, aug_dim), dtype=float) + Px0lag[:nStates, :nStates] = Px0 + x0lag[:nStates] = x0 + + for n in range(N): + offset = 0 + for i in range(lags + 1): + if i == 0: + Alag[offset:offset + nStates, offset:offset + nStates, n] = _sel_A(n) + Qlag[offset:offset + nStates, offset:offset + nStates, n] = _sel_Q(n) + Clag[:nObs, offset:offset + nStates, n] = _sel_C(n) + Rlag[:nObs, :nObs, n] = _sel_R(n) + else: + Alag[offset:offset + nStates, offset - nStates:offset, n] = np.eye(nStates) + # Qlag block remains zeros + # Clag block remains zeros + offset += nStates + + betaLag = np.zeros((aug_dim, numCells), dtype=float) + beta_mat = np.asarray(beta, dtype=float) + if beta_mat.ndim == 1: + beta_mat = beta_mat.reshape(-1, 1) + betaLag[:nStates, :numCells] = beta_mat + + x_p, W_p, x_u, W_u = DecodingAlgorithms.mPPCODecodeLinear( + Alag, Qlag, Clag, Rlag, y, alpha, dN, + mu, betaLag, fitType, delta, gamma, windowTimes, + x0lag, Px0lag, HkAll) + + # Extract lagged portion + lag_start = lags * nStates + lag_end = (lags + 1) * nStates + x_pLag 
= x_p[lag_start:lag_end, :] + W_pLag = W_p[lag_start:lag_end, lag_start:lag_end, :] + x_uLag = x_u[lag_start:lag_end, :] + W_uLag = W_u[lag_start:lag_end, lag_start:lag_end, :] + + return x_pLag, W_pLag, x_uLag, W_uLag + + @staticmethod + def mPPCO_EMCreateConstraints(EstimateA=1, AhatDiag=0, QhatDiag=1, + QhatIsotropic=0, RhatDiag=1, + RhatIsotropic=0, Estimatex0=1, + EstimatePx0=1, Px0Isotropic=0, + mcIter=1000, EnableIkeda=0): + """Create constraint dictionary for mPPCO EM. + + Matlab: ``DecodingAlgorithms.mPPCO_EMCreateConstraints`` (lines 4945-5005) + """ + C = {} + C['EstimateA'] = int(EstimateA) + C['AhatDiag'] = int(AhatDiag) + C['QhatDiag'] = int(QhatDiag) + C['QhatIsotropic'] = 1 if (QhatDiag and QhatIsotropic) else 0 + C['RhatDiag'] = int(RhatDiag) + C['RhatIsotropic'] = 1 if (RhatDiag and RhatIsotropic) else 0 + C['Estimatex0'] = int(Estimatex0) + C['EstimatePx0'] = int(EstimatePx0) + C['Px0Isotropic'] = 1 if (EstimatePx0 and Px0Isotropic) else 0 + C['mcIter'] = int(mcIter) + C['EnableIkeda'] = int(EnableIkeda) + return C + + @staticmethod + def mPPCO_ComputeParamStandardErrors(y, dN, xKFinal, WKFinal, Ahat, Qhat, + Chat, Rhat, alphahat, x0hat, Px0hat, + ExpectationSumsFinal, fitType, + muhat, betahat, gammahat, + windowTimes, HkAll, + mPPCOEM_Constraints=None): + """Compute standard errors for mPPCO EM parameters. + + Matlab: ``DecodingAlgorithms.mPPCO_ComputeParamStandardErrors`` (lines 5006-6138) + + Uses the observed information matrix approach: Io = Ic - Im (McLachlan & Krishnan Eq 4.7). 
+ """ + if mPPCOEM_Constraints is None: + mPPCOEM_Constraints = DecodingAlgorithms.mPPCO_EMCreateConstraints() + + y = np.asarray(y, dtype=float) + obs = _as_observation_matrix(dN) + xKFinal = np.asarray(xKFinal, dtype=float) + Ahat = np.asarray(Ahat, dtype=float) + Qhat = np.asarray(Qhat, dtype=float) + Chat = np.asarray(Chat, dtype=float) + Rhat = np.asarray(Rhat, dtype=float) + alphahat = np.asarray(alphahat, dtype=float).reshape(-1) + x0hat = np.asarray(x0hat, dtype=float).reshape(-1) + Px0hat = np.asarray(Px0hat, dtype=float) + muhat = np.asarray(muhat, dtype=float).reshape(-1) + betahat = np.asarray(betahat, dtype=float) + if betahat.ndim == 1: + betahat = betahat.reshape(-1, 1) + gammahat = np.asarray(gammahat, dtype=float) + HkAll = np.asarray(HkAll, dtype=float) + + dy, N = y.shape if y.ndim == 2 else (1, y.shape[0]) + K = N + dx = xKFinal.shape[0] + numCells = betahat.shape[1] + McExp = mPPCOEM_Constraints['mcIter'] + + Qhat_inv = np.linalg.inv(Qhat) + Rhat_inv = np.linalg.inv(Rhat) + Px0hat_inv = np.linalg.inv(Px0hat + np.eye(Px0hat.shape[0]) * 1e-12) + + # ---- Complete Information Matrices ---- + + # IAComp - A parameter + if mPPCOEM_Constraints['EstimateA']: + n1A, n2A = Ahat.shape + el = np.eye(n1A) + em = np.eye(n2A) + if mPPCOEM_Constraints['AhatDiag']: + IAComp = np.zeros((n1A, n1A)) + for l in range(n1A): + termMat = Qhat_inv @ np.outer(el[:, l], em[:, l]) @ ExpectationSumsFinal['Sxkm1xkm1'] * np.eye(n1A) + IAComp[:, l] = np.diag(termMat) + else: + nA = Ahat.size + IAComp = np.zeros((nA, nA)) + cnt = 0 + for l in range(n1A): + for m in range(n2A): + termMat = Qhat_inv @ np.outer(el[:, l], em[:, m]) @ ExpectationSumsFinal['Sxkm1xkm1'] + IAComp[:, cnt] = termMat.T.reshape(-1) + cnt += 1 + + # ICComp - C parameter + n1C, n2C = Chat.shape + nC = Chat.size + ICComp = np.zeros((nC, nC)) + el = np.eye(n1C) + em = np.eye(n2C) + cnt = 0 + for l in range(n1C): + for m in range(n2C): + termMat = Rhat_inv @ np.outer(el[:, l], em[:, m]) @ 
ExpectationSumsFinal['Sxkxk'] + ICComp[:, cnt] = termMat.T.reshape(-1) + cnt += 1 + + # IRComp - R parameter + n1R, n2R = Rhat.shape + el = np.eye(n1R) + em = np.eye(n2R) + if mPPCOEM_Constraints['RhatDiag']: + if mPPCOEM_Constraints['RhatIsotropic']: + IRComp = np.array([[0.5 * N * dy * Rhat[0, 0] ** (-2)]]) + else: + IRComp = np.zeros((n1R, n1R)) + for l in range(n1R): + termMat = N / 2.0 * Rhat_inv @ np.outer(em[:, l], el[:, l]) @ Rhat_inv + IRComp[:, l] = np.diag(termMat) + else: + nR = Rhat.size + IRComp = np.zeros((nR, nR)) + cnt = 0 + for l in range(n1R): + for m in range(n2R): + termMat = N / 2.0 * Rhat_inv @ np.outer(em[:, m], el[:, l]) @ Rhat_inv + IRComp[:, cnt] = termMat.T.reshape(-1) + cnt += 1 + + # IQComp - Q parameter + n1Q, n2Q = Qhat.shape + el = np.eye(n1Q) + em = np.eye(n2Q) + if mPPCOEM_Constraints['QhatDiag']: + if mPPCOEM_Constraints['QhatIsotropic']: + IQComp = np.array([[0.5 * N * dx * Qhat[0, 0] ** (-2)]]) + else: + IQComp = np.zeros((n1Q, n1Q)) + for l in range(n1Q): + termMat = N / 2.0 * Qhat_inv @ np.outer(em[:, l], el[:, l]) @ Qhat_inv + IQComp[:, l] = np.diag(termMat) + else: + nQ = Qhat.size + IQComp = np.zeros((nQ, nQ)) + cnt = 0 + for l in range(n1Q): + for m in range(n2Q): + termMat = N / 2.0 * Qhat_inv @ np.outer(em[:, m], el[:, l]) @ Qhat_inv + IQComp[:, cnt] = termMat.T.reshape(-1) + cnt += 1 + + # ISComp - Px0 parameter + if mPPCOEM_Constraints['EstimatePx0']: + if mPPCOEM_Constraints['Px0Isotropic']: + ISComp = np.array([[0.5 * dx * Px0hat[0, 0] ** (-2)]]) + else: + n1S, n2S = Px0hat.shape + ISComp = np.zeros((n1S, n1S)) + el = np.eye(n1S) + em = np.eye(n2S) + for l in range(n1S): + termMat = 0.5 * Px0hat_inv @ np.outer(em[:, l], el[:, l]) @ Px0hat_inv + ISComp[:, l] = np.diag(termMat) + + # Ix0Comp + if mPPCOEM_Constraints['Estimatex0']: + Ix0Comp = Px0hat_inv + Ahat.T @ Qhat_inv @ Ahat + + # IAlphaComp + IAlphaComp = N * Rhat_inv + + # IBetaComp - Monte Carlo + xKDrawExp = np.zeros((dx, K, McExp), dtype=float) + for k in 
range(K): + WuTemp = WKFinal[:, :, k] + try: + chol_m = np.linalg.cholesky(WuTemp) + except np.linalg.LinAlgError: + chol_m = np.linalg.cholesky(nearestSPD(WuTemp)) + z = np.random.randn(dx, McExp) + xKDrawExp[:, k, :] = xKFinal[:, k:k + 1] + chol_m @ z + + IBetaComp = np.zeros((dx * numCells, dx * numCells), dtype=float) + xkPerm = np.transpose(xKDrawExp, (0, 2, 1)) # (dx, McExp, K) + + for c in range(numCells): + HessianTerm = np.zeros((dx, dx), dtype=float) + for k in range(K): + Hk = HkAll[k, :, c] if HkAll.ndim == 3 else np.zeros(1) + xk = xkPerm[:, :, k] + gammaC = gammahat if gammahat.size == 1 else (gammahat[:, c] if gammahat.ndim == 2 else gammahat) + terms = muhat[c] + betahat[:, c] @ xk + float(np.dot(gammaC.reshape(-1), Hk.reshape(-1))) + if fitType == 'poisson': + ld = np.exp(np.clip(terms, -500, 500)) + HessianTerm -= (1.0 / McExp) * (np.tile(ld, (dx, 1)) * xk) @ xk.T + else: + ld = np.exp(np.clip(terms, -500, 500)) + ld = ld / (1.0 + ld) + EldXkXk = (1.0 / McExp) * (np.tile(ld, (dx, 1)) * xk) @ xk.T + EldSqXkXk = (1.0 / McExp) * (np.tile(ld ** 2, (dx, 1)) * xk) @ xk.T + EldCubeXkXk = (1.0 / McExp) * (np.tile(ld ** 3, (dx, 1)) * xk) @ xk.T + HessianTerm += EldXkXk + EldSqXkXk - 2.0 * EldCubeXkXk + si = dx * c + ei = dx * (c + 1) + IBetaComp[si:ei, si:ei] = -HessianTerm + + # IMuComp + IMuComp = np.zeros((numCells, numCells), dtype=float) + for c in range(numCells): + HessianTerm = 0.0 + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 1)) + Hk = Hk_full[k, :] + xk = xkPerm[:, :, k] + gammaC = gammahat if gammahat.size == 1 else (gammahat[:, c] if gammahat.ndim == 2 else gammahat) + terms = muhat[c] + betahat[:, c] @ xk + float(np.dot(gammaC.reshape(-1), Hk.reshape(-1))) + if fitType == 'poisson': + ld = np.exp(np.clip(terms, -500, 500)) + HessianTerm -= (1.0 / McExp) * float(np.sum(ld)) + else: + ld = np.exp(np.clip(terms, -500, 500)) / (1.0 + np.exp(np.clip(terms, -500, 500))) + Eld = (1.0 / McExp) * 
float(np.sum(ld)) + EldSq = (1.0 / McExp) * float(np.sum(ld ** 2)) + EldCube = (1.0 / McExp) * float(np.sum(ld ** 3)) + HessianTerm += -(obs[c, k] + 1) * Eld + (obs[c, k] + 3) * EldSq - 3 * EldCube + IMuComp[c, c] = -HessianTerm + + # IGammaComp + nHist = HkAll.shape[1] if HkAll.ndim == 3 else 1 + IGammaComp = np.zeros((nHist * numCells, nHist * numCells), dtype=float) + has_gamma = (windowTimes is not None and not _is_empty_value(windowTimes) + and np.any(gammahat != 0)) + if has_gamma: + for c in range(numCells): + HessianTerm = np.zeros((nHist, nHist), dtype=float) + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 1)) + Hk = Hk_full[k, :] + xk = xkPerm[:, :, k] + gammaC = gammahat if gammahat.size == 1 else (gammahat[:, c] if gammahat.ndim == 2 else gammahat) + terms = muhat[c] + betahat[:, c] @ xk + float(np.dot(gammaC.reshape(-1), Hk.reshape(-1))) + if fitType == 'poisson': + ld = np.exp(np.clip(terms, -500, 500)) + Eld = (1.0 / McExp) * float(np.sum(ld)) + HessianTerm -= np.outer(Hk, Hk) * Eld + else: + ld_raw = np.exp(np.clip(terms, -500, 500)) + ld = ld_raw / (1.0 + ld_raw) + Eld = (1.0 / McExp) * float(np.sum(ld)) + EldSq = (1.0 / McExp) * float(np.sum(ld ** 2)) + EldCube = (1.0 / McExp) * float(np.sum(ld ** 2)) # matches Matlab typo (ld.^2) + HessianTerm += (-Eld * (obs[c, k] + 1) + EldSq * (obs[c, k] + 3) - 2 * EldCube) * np.outer(Hk, Hk) + si = nHist * c + ei = nHist * (c + 1) + IGammaComp[si:ei, si:ei] = -HessianTerm + + # Assemble IComp + n1 = IAComp.shape[0] if mPPCOEM_Constraints['EstimateA'] else 0 + n2 = IQComp.shape[0] + n3 = ICComp.shape[0] + n4 = IRComp.shape[0] + n5 = ISComp.shape[0] if mPPCOEM_Constraints['EstimatePx0'] else 0 + n6 = Ix0Comp.shape[0] if mPPCOEM_Constraints['Estimatex0'] else 0 + n7 = IAlphaComp.shape[0] + n8 = IMuComp.shape[0] + n9 = IBetaComp.shape[0] + if gammahat.size == 1 and float(gammahat.flat[0]) == 0: + n10 = 0 + else: + n10 = IGammaComp.shape[0] + nTerms = n1 + n2 + n3 + n4 + n5 + 
n6 + n7 + n8 + n9 + n10 + IComp = np.zeros((nTerms, nTerms), dtype=float) + + offset = 0 + if mPPCOEM_Constraints['EstimateA']: + IComp[offset:offset + n1, offset:offset + n1] = IAComp + offset += n1 + IComp[offset:offset + n2, offset:offset + n2] = IQComp + offset += n2 + IComp[offset:offset + n3, offset:offset + n3] = ICComp + offset += n3 + IComp[offset:offset + n4, offset:offset + n4] = IRComp + offset += n4 + if mPPCOEM_Constraints['EstimatePx0']: + IComp[offset:offset + n5, offset:offset + n5] = ISComp + offset += n5 + if mPPCOEM_Constraints['Estimatex0']: + IComp[offset:offset + n6, offset:offset + n6] = Ix0Comp + offset += n6 + IComp[offset:offset + n7, offset:offset + n7] = IAlphaComp + offset += n7 + IComp[offset:offset + n8, offset:offset + n8] = IMuComp + offset += n8 + IComp[offset:offset + n9, offset:offset + n9] = IBetaComp + offset += n9 + if n10 > 0: + IComp[offset:offset + n10, offset:offset + n10] = IGammaComp + + # ---- Missing Information Matrix (Monte Carlo) ---- + Mc = McExp + xKDraw = np.zeros((dx, N, Mc), dtype=float) + for n in range(N): + WuTemp = WKFinal[:, :, n] + try: + chol_m = np.linalg.cholesky(WuTemp) + except np.linalg.LinAlgError: + chol_m = np.linalg.cholesky(nearestSPD(WuTemp)) + z = np.random.randn(dx, Mc) + xKDraw[:, n, :] = xKFinal[:, n:n + 1] + chol_m @ z + + if mPPCOEM_Constraints['EstimatePx0'] or mPPCOEM_Constraints['Estimatex0']: + try: + chol_m = np.linalg.cholesky(Px0hat) + except np.linalg.LinAlgError: + chol_m = np.linalg.cholesky(nearestSPD(Px0hat)) + z = np.random.randn(dx, Mc) + x0Draw = x0hat.reshape(-1, 1) + chol_m @ z + else: + x0Draw = np.tile(x0hat.reshape(-1, 1), (1, Mc)) + + IMc = np.zeros((nTerms, nTerms, Mc), dtype=float) + Dx = dx + Dy = dy + + for c_mc in range(Mc): + x_K = xKDraw[:, :, c_mc] + x_0 = x0Draw[:, c_mc] + + Sxkm1xk = np.zeros((Dx, Dx)) + Sxkm1xkm1 = np.zeros((Dx, Dx)) + Sxkxk = np.zeros((Dx, Dx)) + Sykyk = np.zeros((Dy, Dy)) + Sxkyk = np.zeros((Dx, Dy)) + + for k in range(K): + if k == 0: 
+ Sxkm1xk += np.outer(x_0, x_K[:, k]) + Sxkm1xkm1 += np.outer(x_0, x_0) + else: + Sxkm1xk += np.outer(x_K[:, k - 1], x_K[:, k]) + Sxkm1xkm1 += np.outer(x_K[:, k - 1], x_K[:, k - 1]) + Sxkxk += np.outer(x_K[:, k], x_K[:, k]) + yk_alpha = y[:, k] - alphahat if y.ndim == 2 else y - alphahat + Sykyk += np.outer(yk_alpha, yk_alpha) + Sxkyk += np.outer(x_K[:, k], yk_alpha) + + Sxkxk = _symmetrize(Sxkxk) + Sykyk = _symmetrize(Sykyk) + sumXkTerms_mc = Sxkxk - Ahat @ Sxkm1xk - Sxkm1xk.T @ Ahat.T + Ahat @ Sxkm1xkm1 @ Ahat.T + sumYkTerms_mc = Sykyk - Chat @ Sxkyk - Sxkyk.T @ Chat.T + Chat @ Sxkxk @ Chat.T + Sxkxkm1 = Sxkm1xk.T + sumXkTerms_mc = _symmetrize(sumXkTerms_mc) + sumYkTerms_mc = _symmetrize(sumYkTerms_mc) + + # Score: A + if mPPCOEM_Constraints['EstimateA']: + ScorA = np.linalg.solve(Qhat, Sxkxkm1 - Ahat @ Sxkm1xkm1) + if mPPCOEM_Constraints['AhatDiag']: + ScoreAMc = np.diag(ScorA) + else: + ScoreAMc = ScorA.T.reshape(-1) + else: + ScoreAMc = np.array([], dtype=float) + + # Score: C + ScorC = np.linalg.solve(Rhat, Sxkyk.T - Chat @ Sxkxk) + ScoreCMc = ScorC.T.reshape(-1) + + # Score: Q + if mPPCOEM_Constraints['QhatDiag']: + if mPPCOEM_Constraints['QhatIsotropic']: + ScoreQ = -0.5 * (K * Dx * Qhat[0, 0] ** (-1) - Qhat[0, 0] ** (-2) * np.trace(sumXkTerms_mc)) + ScoreQMc = np.array([ScoreQ]) + else: + ScoreQ = -0.5 * np.linalg.solve(Qhat, K * np.eye(Dx) - np.linalg.solve(Qhat, sumXkTerms_mc).T) + ScoreQMc = np.diag(ScoreQ) + else: + ScoreQ = -0.5 * np.linalg.solve(Qhat, K * np.eye(Dx) - np.linalg.solve(Qhat, sumXkTerms_mc).T) + ScoreQMc = ScoreQ.T.reshape(-1) + + # Score: alpha + resid = y - Chat @ x_K - alphahat.reshape(-1, 1) @ np.ones((1, N)) if y.ndim == 2 else y - Chat @ x_K - alphahat.reshape(-1, 1) + ScoreAlphaMc = np.sum(np.linalg.solve(Rhat, resid), axis=1) + + # Score: R + if mPPCOEM_Constraints['RhatDiag']: + if mPPCOEM_Constraints['RhatIsotropic']: + ScoreR = -0.5 * (K * Dy * Rhat[0, 0] ** (-1) - Rhat[0, 0] ** (-2) * np.trace(sumYkTerms_mc)) + ScoreRMc = 
np.array([ScoreR]) + else: + ScoreR = -0.5 * np.linalg.solve(Rhat, K * np.eye(Dy) - np.linalg.solve(Rhat, sumYkTerms_mc).T) + ScoreRMc = np.diag(ScoreR) + else: + ScoreR = -0.5 * np.linalg.solve(Rhat, K * np.eye(Dy) - np.linalg.solve(Rhat, sumYkTerms_mc).T) + ScoreRMc = ScoreR.T.reshape(-1) + + # Score: Px0 + if mPPCOEM_Constraints['Px0Isotropic']: + diff0 = x_0 - x0hat + ScoreSMc = np.array([-0.5 * (Dx * Px0hat[0, 0] ** (-1) - Px0hat[0, 0] ** (-2) * np.trace(np.outer(diff0, diff0)))]) + else: + diff0 = x_0 - x0hat + ScorS = -0.5 * np.linalg.solve(Px0hat, np.eye(Dx) - np.linalg.solve(Px0hat, np.outer(diff0, diff0)).T) + ScoreSMc = np.diag(ScorS) + + # Score: x0 + Scorx0 = -np.linalg.solve(Px0hat, x_0 - x0hat) + Ahat.T @ np.linalg.solve(Qhat, x_K[:, 0] - Ahat @ x_0) + Scorex0Mc = Scorx0.reshape(-1) + + # Score: mu, beta, gamma per cell + ScoreMuMc = np.zeros(numCells) + ScoreBetaMc = np.array([], dtype=float) + ScoreGammaMc = np.array([], dtype=float) + for nc in range(numCells): + Hk_full = HkAll[:, :, nc] if HkAll.ndim == 3 else np.zeros((K, 1)) + nHistC = Hk_full.shape[1] + gammaC = gammahat if gammahat.size == 1 else (gammahat[:, nc] if gammahat.ndim == 2 else gammahat) + terms = muhat[nc] + betahat[:, nc] @ x_K + gammaC.reshape(-1) @ Hk_full.T + if fitType == 'poisson': + ld = np.exp(np.clip(terms, -500, 500)) + ScoreMuMc[nc] = float(np.sum(obs[nc, :] - ld)) + ScoreBetaMc = np.concatenate([ScoreBetaMc, np.sum(np.tile(obs[nc, :] - ld, (Dx, 1)) * x_K, axis=1)]) + ScoreGammaMc = np.concatenate([ScoreGammaMc, np.sum(np.tile(obs[nc, :] - ld, (nHistC, 1)) * Hk_full.T, axis=1)]) + else: + ld_raw = np.exp(np.clip(terms, -500, 500)) + ld = ld_raw / (1.0 + ld_raw) + ScoreMuMc[nc] = float(np.sum(obs[nc, :] - (obs[nc, :] + 1) * ld + ld ** 2)) + ScoreBetaMc = np.concatenate([ScoreBetaMc, np.sum(np.tile(obs[nc, :] * (1 - ld) - ld * (1 - ld), (Dx, 1)) * x_K, axis=1)]) + ScoreGammaMc = np.concatenate([ScoreGammaMc, np.sum(np.tile(obs[nc, :] - (obs[nc, :] + 1) * ld + ld ** 2, 
(nHistC, 1)) * Hk_full.T, axis=1)]) + + ScoreVec = np.concatenate([ScoreAMc, ScoreQMc, ScoreCMc, ScoreRMc]) + if mPPCOEM_Constraints['EstimatePx0']: + ScoreVec = np.concatenate([ScoreVec, ScoreSMc]) + if mPPCOEM_Constraints['Estimatex0']: + ScoreVec = np.concatenate([ScoreVec, Scorex0Mc]) + ScoreVec = np.concatenate([ScoreVec, ScoreAlphaMc, ScoreMuMc, ScoreBetaMc]) + if n10 > 0: + ScoreVec = np.concatenate([ScoreVec, ScoreGammaMc]) + + IMc[:, :, c_mc] = np.outer(ScoreVec, ScoreVec) + + IMissing = np.mean(IMc, axis=2) + IObs = IComp - IMissing + try: + invIObs = np.linalg.inv(IObs) + except np.linalg.LinAlgError: + invIObs = np.linalg.pinv(IObs) + invIObs = nearestSPD(invIObs) + VarVec = np.diag(invIObs) + SEVec = np.sqrt(np.maximum(VarVec, 0.0)) + + # Partition SE vector + off = 0 + SEAterms = SEVec[off:off + n1]; off += n1 + SEQterms = SEVec[off:off + n2]; off += n2 + SECterms = SEVec[off:off + n3]; off += n3 + SERterms = SEVec[off:off + n4]; off += n4 + SEPx0terms = SEVec[off:off + n5]; off += n5 + SEx0terms = SEVec[off:off + n6]; off += n6 + SEAlphaterms = SEVec[off:off + n7]; off += n7 + SEMuTerms = SEVec[off:off + n8]; off += n8 + SEBetaTerms = SEVec[off:off + n9]; off += n9 + SEGammaTerms = SEVec[off:off + n10]; off += n10 + + SE = {} + if mPPCOEM_Constraints['EstimateA']: + if mPPCOEM_Constraints['AhatDiag']: + SE['A'] = np.diag(SEAterms) + else: + SE['A'] = SEAterms.reshape(Ahat.shape[1], Ahat.shape[0]).T + SE['Q'] = np.diag(SEQterms) if mPPCOEM_Constraints['QhatDiag'] else SEQterms.reshape(Qhat.shape[1], Qhat.shape[0]).T + SE['C'] = SECterms.reshape(Chat.shape[1], Chat.shape[0]).T + SE['R'] = np.diag(SERterms) if mPPCOEM_Constraints['RhatDiag'] else SERterms.reshape(Rhat.shape[1], Rhat.shape[0]).T + SE['alpha'] = SEAlphaterms.reshape(alphahat.shape) + if mPPCOEM_Constraints['EstimatePx0']: + SE['Px0'] = np.diag(SEPx0terms) + if mPPCOEM_Constraints['Estimatex0']: + SE['x0'] = SEx0terms + SE['mu'] = SEMuTerms + SE['beta'] = 
SEBetaTerms.reshape(betahat.shape[1], betahat.shape[0]).T + if n10 > 0: + SE['gamma'] = SEGammaTerms.reshape(gammahat.shape[1], gammahat.shape[0]).T if gammahat.ndim == 2 else SEGammaTerms + + # P-values (two-sided z-test) + Pvals = {} + if mPPCOEM_Constraints['EstimateA']: + pA_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(Ahat.reshape(-1) if not mPPCOEM_Constraints['AhatDiag'] else np.diag(Ahat), SE['A'].reshape(-1) if not mPPCOEM_Constraints['AhatDiag'] else np.diag(SE['A']))]) + Pvals['A'] = np.diag(pA_flat) if mPPCOEM_Constraints['AhatDiag'] else pA_flat.reshape(Ahat.shape) + pC_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(Chat.reshape(-1), SE['C'].reshape(-1))]) + Pvals['C'] = pC_flat.reshape(Chat.shape) + if mPPCOEM_Constraints['RhatDiag']: + pR_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(np.diag(Rhat), np.diag(SE['R']))]) + Pvals['R'] = np.diag(pR_flat) + else: + pR_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(Rhat.reshape(-1), SE['R'].reshape(-1))]) + Pvals['R'] = pR_flat.reshape(Rhat.shape) + if mPPCOEM_Constraints['QhatDiag']: + pQ_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(np.diag(Qhat), np.diag(SE['Q']))]) + Pvals['Q'] = np.diag(pQ_flat) + else: + pQ_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(Qhat.reshape(-1), SE['Q'].reshape(-1))]) + Pvals['Q'] = pQ_flat.reshape(Qhat.shape) + if mPPCOEM_Constraints['EstimatePx0']: + pPx0_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(np.diag(Px0hat), np.diag(SE['Px0']))]) + Pvals['Px0'] = np.diag(pPx0_flat) + if mPPCOEM_Constraints['Estimatex0']: + Pvals['x0'] = np.array([_ztest_pvalue(p, s) for p, s in zip(x0hat, SE['x0'])]) + Pvals['alpha'] = np.array([_ztest_pvalue(p, s) for p, s in zip(alphahat.reshape(-1), SE['alpha'].reshape(-1))]) + Pvals['mu'] = np.array([_ztest_pvalue(p, s) for p, s in zip(muhat, SE['mu'])]) + pBeta_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(betahat.reshape(-1), SE['beta'].reshape(-1))]) + Pvals['beta'] = 
pBeta_flat.reshape(betahat.shape) + if n10 > 0: + pGamma_flat = np.array([_ztest_pvalue(p, s) for p, s in zip(gammahat.reshape(-1), SE['gamma'].reshape(-1))]) + Pvals['gamma'] = pGamma_flat.reshape(gammahat.shape) if gammahat.ndim == 2 else pGamma_flat + + return SE, Pvals, nTerms + + @staticmethod + def mPPCO_EStep(A, Q, C, R, y, alpha, dN, mu, beta, fitType='poisson', + delta=0.001, gamma=None, HkAll=None, x0=None, Px0=None): + """E-step for the mPPCO EM algorithm. + + Matlab: ``DecodingAlgorithms.mPPCO_EStep`` (lines 6555-6772) + + Returns + ------- + x_K : (dx, K) -- smoothed states + W_K : (dx, dx, K) -- smoothed covariances + logll : float -- log-likelihood + ExpectationSums : dict + """ + A = np.asarray(A, dtype=float) + Q = np.asarray(Q, dtype=float) + C = np.asarray(C, dtype=float) + R = np.asarray(R, dtype=float) + y = np.asarray(y, dtype=float) + alpha = np.asarray(alpha, dtype=float) + obs = _as_observation_matrix(dN) + numCells, K = obs.shape + Dx = A.shape[1] if A.ndim >= 2 else A.shape[0] + Dy = C.shape[0] if C.ndim >= 2 else 1 + mu_vec = np.asarray(mu, dtype=float).reshape(-1) + beta_mat = np.asarray(beta, dtype=float) + if beta_mat.ndim == 1: + beta_mat = beta_mat.reshape(-1, 1) + if gamma is None: + gamma = 0 + gamma_arr = np.asarray(gamma, dtype=float) + if x0 is None or _is_empty_value(x0): + x0 = np.zeros(Dx, dtype=float) + else: + x0 = np.asarray(x0, dtype=float).reshape(-1) + if Px0 is None or _is_empty_value(Px0): + Px0 = np.zeros((Dx, Dx), dtype=float) + else: + Px0 = np.asarray(Px0, dtype=float).reshape(Dx, Dx) + + if HkAll is None or _is_empty_value(HkAll): + HkAll_arr = np.zeros((K, 1, numCells), dtype=float) + else: + HkAll_arr = np.asarray(HkAll, dtype=float) + + # Forward filter + x_p, W_p, x_u, W_u = DecodingAlgorithms.mPPCODecodeLinear( + A, Q, C, R, y, alpha, dN, mu_vec, beta_mat, fitType, + delta, gamma, None, x0, Px0, HkAll_arr) + + # Smoother -- x_p has N+1 columns, x_u has N columns + # kalman_smootherFromFiltered expects 
matching shapes + # Trim x_p and W_p to first N entries for smoother input + x_K, W_K, Lk = DecodingAlgorithms.kalman_smootherFromFiltered( + A, x_p[:, :N], W_p[:, :, :N], x_u, W_u) + + # Handle Matlab-style output -- ensure x_K is (dx, K) + if x_K.ndim == 2 and x_K.shape[0] == K and x_K.shape[1] == Dx: + x_K = x_K.T + if W_K.ndim == 3 and W_K.shape[0] == K: + W_K = np.transpose(W_K, (1, 2, 0)) + + # Best estimates of initial state given data + W1G0 = A @ Px0 @ A.T + Q if A.ndim == 2 else A.reshape(Dx, Dx) @ Px0 @ A.reshape(Dx, Dx).T + Q.reshape(Dx, Dx) + A_2d = A.reshape(Dx, Dx) if A.ndim != 2 else A + L0 = Px0 @ A_2d.T @ np.linalg.pinv(W1G0) + Ex0Gy = x0 + L0 @ (x_K[:, 0] - x_p[:, 0]) + Px0Gy = Px0 + L0 @ (np.linalg.pinv(W_K[:, :, 0]) - np.linalg.pinv(W1G0)) @ L0.T + Px0Gy = _symmetrize(Px0Gy) + + # Cross-covariance matrices Wku + numStates = Dx + Wku = np.zeros((numStates, numStates, K, K), dtype=float) + for k in range(K): + Wku[:, :, k, k] = W_K[:, :, k] + + for u in range(K - 1, 0, -1): + k = u - 1 + Dk = W_u[:, :, k] @ A_2d.T @ np.linalg.pinv(W_p[:, :, k + 1]) + Wku[:, :, k, u] = Dk @ Wku[:, :, k + 1, u] + Wku[:, :, u, k] = Wku[:, :, k, u].T + + # Sufficient statistics + Sxkm1xk = np.zeros((Dx, Dx)) + Sxkm1xkm1 = np.zeros((Dx, Dx)) + Sxkxk = np.zeros((Dx, Dx)) + Sykyk = np.zeros((Dy, Dy)) + Sxkyk = np.zeros((Dx, Dy)) + + alpha_vec = alpha.reshape(-1) + for k in range(K): + if k == 0: + Sxkm1xk += Px0 @ A_2d.T @ np.linalg.pinv(W_p[:, :, 0]) @ Wku[:, :, 0, 0] + Sxkm1xkm1 += Px0 + np.outer(x0, x0) + else: + Sxkm1xk += Wku[:, :, k - 1, k] + np.outer(x_K[:, k - 1], x_K[:, k]) + Sxkm1xkm1 += Wku[:, :, k - 1, k - 1] + np.outer(x_K[:, k - 1], x_K[:, k - 1]) + Sxkxk += Wku[:, :, k, k] + np.outer(x_K[:, k], x_K[:, k]) + yk = y[:, k] if y.ndim == 2 else y + Sykyk += np.outer(yk - alpha_vec, yk - alpha_vec) + Sxkyk += np.outer(x_K[:, k], yk - alpha_vec) + + Sxkxk = _symmetrize(Sxkxk) + Sykyk = _symmetrize(Sykyk) + sumXkTerms = Sxkxk - A_2d @ Sxkm1xk - Sxkm1xk.T @ A_2d.T 
+ A_2d @ Sxkm1xkm1 @ A_2d.T + sumYkTerms = Sykyk - C @ Sxkyk - Sxkyk.T @ C.T + C @ Sxkxk @ C.T + Sxkxkm1 = Sxkm1xk.T + + # Log-likelihood with PP term + if str(fitType) == 'poisson': + sumPPll = 0.0 + HkPerm = np.transpose(HkAll_arr, (1, 2, 0)) if HkAll_arr.ndim == 3 and HkAll_arr.shape[0] == K else HkAll_arr + for k in range(K): + Hk = HkPerm[:, :, k] if HkPerm.ndim == 3 else np.zeros((1, numCells)) + if Hk.shape[0] == numCells and Hk.shape[1] != numCells: + Hk = Hk.T + xk = x_K[:, k] + gammaC_mat = np.tile(gamma_arr.reshape(-1, 1), (1, numCells)) if gamma_arr.size == 1 else gamma_arr + if gammaC_mat.ndim == 2 and gammaC_mat.shape[1] != numCells: + gammaC_mat = np.tile(gammaC_mat, (1, numCells)) + terms = mu_vec + beta_mat.T @ xk + np.diag(gammaC_mat.T @ Hk) if Hk.size > 0 and gammaC_mat.size > 0 else mu_vec + beta_mat.T @ xk + Wk = W_K[:, :, k] + ld = np.exp(np.clip(terms, -500, 500)) + bt = beta_mat + ExplambdaDelta = ld + 0.5 * (ld * np.diag(bt.T @ Wk @ bt)) + ExplogLD = terms + sumPPll += float(np.sum(obs[:, k] * ExplogLD - ExplambdaDelta)) + else: # binomial + sumPPll = 0.0 + HkPerm = np.transpose(HkAll_arr, (1, 2, 0)) if HkAll_arr.ndim == 3 and HkAll_arr.shape[0] == K else HkAll_arr + for k in range(K): + Hk = HkPerm[:, :, k] if HkPerm.ndim == 3 else np.zeros((1, numCells)) + if Hk.shape[0] == numCells and Hk.shape[1] != numCells: + Hk = Hk.T + xk = x_K[:, k] + gammaC_mat = np.tile(gamma_arr.reshape(-1, 1), (1, numCells)) if gamma_arr.size == 1 else gamma_arr + if gammaC_mat.ndim == 2 and gammaC_mat.shape[1] != numCells: + gammaC_mat = np.tile(gammaC_mat, (1, numCells)) + terms = mu_vec + beta_mat.T @ xk + np.diag(gammaC_mat.T @ Hk) if Hk.size > 0 and gammaC_mat.size > 0 else mu_vec + beta_mat.T @ xk + Wk = W_K[:, :, k] + ld_raw = np.exp(np.clip(terms, -500, 500)) + ld = ld_raw / (1.0 + ld_raw) + bt = beta_mat + ExplambdaDelta = ld + 0.5 * (ld * (1 - ld) * (1 - 2 * ld)) * np.diag(bt.T @ Wk @ bt) + ExplogLD = np.log(np.clip(ld, 1e-300, None)) + 0.5 * (-ld * 
(1 - ld)) * np.diag(bt.T @ Wk @ bt) + sumPPll += float(np.sum(obs[:, k] * ExplogLD - ExplambdaDelta)) + + Q_2d = Q.reshape(Dx, Dx) if Q.ndim != 2 else Q + R_2d = R.reshape(Dy, Dy) if R.ndim != 2 else R + logll = (-Dx * K / 2.0 * np.log(2 * np.pi) + - K / 2.0 * np.log(max(np.linalg.det(Q_2d), 1e-300)) + - Dy * K / 2.0 * np.log(2 * np.pi) + - K / 2.0 * np.log(max(np.linalg.det(R_2d), 1e-300)) + - Dx / 2.0 * np.log(2 * np.pi) + - 0.5 * np.log(max(np.linalg.det(Px0), 1e-300)) + + sumPPll + - 0.5 * np.trace(np.linalg.solve(Q_2d, sumXkTerms)) + - 0.5 * np.trace(np.linalg.solve(R_2d, sumYkTerms)) + - Dx / 2.0) + + ExpectationSums = { + 'Sxkm1xkm1': Sxkm1xkm1, + 'Sxkm1xk': Sxkm1xk, + 'Sxkxkm1': Sxkxkm1, + 'Sxkxk': Sxkxk, + 'Sxkyk': Sxkyk, + 'Sykyk': Sykyk, + 'sumXkTerms': sumXkTerms, + 'sumYkTerms': sumYkTerms, + 'sumPPll': sumPPll, + 'Sx0': Ex0Gy, + 'Sx0x0': Px0Gy + np.outer(Ex0Gy, Ex0Gy), + } + + return x_K, W_K, float(logll), ExpectationSums + + @staticmethod + def mPPCO_MStep(dN, y, x_K, W_K, x0, Px0, ExpectationSums, fitType='poisson', + muhat=None, betahat=None, gammahat=None, windowTimes=None, + HkAll=None, mPPCOEM_Constraints=None, MstepMethod='GLM'): + """M-step for the mPPCO EM algorithm. 
+ + Matlab: ``DecodingAlgorithms.mPPCO_MStep`` (lines 6773-7662) + + Returns + ------- + Ahat, Qhat, Chat, Rhat, alphahat, muhat_new, betahat_new, gammahat_new, x0hat, Px0hat + """ + if mPPCOEM_Constraints is None: + mPPCOEM_Constraints = DecodingAlgorithms.mPPCO_EMCreateConstraints() + + obs = _as_observation_matrix(dN) + numCells = obs.shape[0] + x_K = np.asarray(x_K, dtype=float) + y = np.asarray(y, dtype=float) + x0 = np.asarray(x0, dtype=float).reshape(-1) + Px0 = np.asarray(Px0, dtype=float) + muhat = np.asarray(muhat, dtype=float).reshape(-1) + betahat = np.asarray(betahat, dtype=float) + if betahat.ndim == 1: + betahat = betahat.reshape(-1, 1) + gammahat = np.asarray(gammahat, dtype=float) + if HkAll is None or _is_empty_value(HkAll): + HkAll = np.zeros((obs.shape[1], 1, numCells), dtype=float) + else: + HkAll = np.asarray(HkAll, dtype=float) + + Sxkm1xkm1 = ExpectationSums['Sxkm1xkm1'] + Sxkm1xk = ExpectationSums['Sxkm1xk'] + Sxkxkm1 = ExpectationSums['Sxkxkm1'] + Sxkxk = ExpectationSums['Sxkxk'] + Sxkyk = ExpectationSums['Sxkyk'] + Sykyk = ExpectationSums['Sykyk'] + sumXkTerms = ExpectationSums['sumXkTerms'] + sumYkTerms = ExpectationSums['sumYkTerms'] + Sx0 = ExpectationSums['Sx0'] + Sx0x0 = ExpectationSums['Sx0x0'] + + dx, K = x_K.shape + dy = y.shape[0] if y.ndim == 2 else 1 + I_dx = np.eye(dx) + + # A estimate + if mPPCOEM_Constraints['AhatDiag']: + Ahat = (Sxkxkm1 * I_dx) @ np.linalg.inv(Sxkm1xkm1 * I_dx) + else: + Ahat = Sxkxkm1 @ np.linalg.inv(Sxkm1xkm1) + + # C estimate + Chat = Sxkyk.T @ np.linalg.inv(Sxkxk) + + # alpha estimate + alphahat = np.sum(y - Chat @ x_K, axis=1) / K if y.ndim == 2 else (y - Chat @ x_K) / K + + # Q estimate + if mPPCOEM_Constraints['QhatDiag']: + if mPPCOEM_Constraints['QhatIsotropic']: + Qhat = (1.0 / (dx * K)) * np.trace(sumXkTerms) * I_dx + else: + Qhat = (1.0 / K) * (sumXkTerms * I_dx) + Qhat = _symmetrize(Qhat) + else: + Qhat = (1.0 / K) * sumXkTerms + Qhat = _symmetrize(Qhat) + + # R estimate + I_dy = np.eye(dy) + 
if mPPCOEM_Constraints['RhatDiag']: + if mPPCOEM_Constraints['RhatIsotropic']: + Rhat = (1.0 / (dy * K)) * np.trace(sumYkTerms) * I_dy + else: + Rhat = (1.0 / K) * (sumYkTerms * I_dy) + Rhat = _symmetrize(Rhat) + else: + Rhat = (1.0 / K) * sumYkTerms + Rhat = _symmetrize(Rhat) + + # x0 estimate + if mPPCOEM_Constraints['Estimatex0']: + x0hat = np.linalg.solve( + np.linalg.inv(Px0) + Ahat.T @ np.linalg.solve(Qhat, Ahat), + Ahat.T @ np.linalg.solve(Qhat, x_K[:, 0]) + np.linalg.solve(Px0, x0)) + else: + x0hat = x0.copy() + + # Px0 estimate + if mPPCOEM_Constraints['EstimatePx0']: + if mPPCOEM_Constraints['Px0Isotropic']: + Px0hat = (np.trace(np.outer(x0hat, x0hat) - np.outer(x0, x0hat) - np.outer(x0hat, x0) + np.outer(x0, x0)) / (dx * K)) * I_dx + else: + Px0hat = (np.outer(x0hat, x0hat) - np.outer(x0, x0hat) - np.outer(x0hat, x0) + np.outer(x0, x0)) * I_dx + Px0hat = _symmetrize(Px0hat) + else: + Px0hat = Px0.copy() + + # CIF parameter updates via Newton-Raphson + betahat_new = betahat.copy() + gammahat_new = gammahat.copy() + muhat_new = muhat.copy() + + # Newton-Raphson for beta, mu, gamma + McExp = 50 + diffTol = 1e-5 + maxIter_nr = 100 + + xKDrawExp = np.zeros((dx, K, McExp), dtype=float) + for k in range(K): + WuTemp = W_K[:, :, k] + try: + chol_m = np.linalg.cholesky(WuTemp) + except np.linalg.LinAlgError: + chol_m = np.linalg.cholesky(nearestSPD(WuTemp)) + z = np.random.randn(dx, McExp) + xKDrawExp[:, k, :] = x_K[:, k:k + 1] + chol_m @ z + + xkPerm = np.transpose(xKDrawExp, (0, 2, 1)) # (dx, McExp, K) + + # -- beta update -- + for c in range(numCells): + converged = False + iterNR = 0 + while not converged and iterNR < maxIter_nr: + HessianTerm = np.zeros((dx, dx)) + GradTerm = np.zeros(dx) + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 1)) + Hk = Hk_full[k, :] + xk = xkPerm[:, :, k] + gammaC = gammahat if gammahat.size == 1 else (gammahat[:, c] if gammahat.ndim == 2 else gammahat) + terms = muhat[c] + betahat_new[:, c] @ 
xk + float(np.dot(gammaC.reshape(-1), Hk.reshape(-1))) + if fitType == 'poisson': + ld = np.exp(np.clip(terms, -500, 500)) + ExpLambdaXk = (1.0 / McExp) * np.sum(np.tile(ld, (dx, 1)) * xk, axis=1) + ExpLambdaXkXkT = (1.0 / McExp) * (np.tile(ld, (dx, 1)) * xk) @ xk.T + GradTerm += obs[c, k] * x_K[:, k] - ExpLambdaXk + HessianTerm -= ExpLambdaXkXkT + else: + ld_raw = np.exp(np.clip(terms, -500, 500)) + ld = ld_raw / (1.0 + ld_raw) + EldXkXk = (1.0 / McExp) * (np.tile(ld, (dx, 1)) * xk) @ xk.T + EldSqXkXk = (1.0 / McExp) * (np.tile(ld ** 2, (dx, 1)) * xk) @ xk.T + EldCubeXkXk = (1.0 / McExp) * (np.tile(ld ** 3, (dx, 1)) * xk) @ xk.T + ExpLambdaXk = (1.0 / McExp) * np.sum(np.tile(ld, (dx, 1)) * xk, axis=1) + ExpLambdaSquaredXk = (1.0 / McExp) * np.sum(np.tile(ld ** 2, (dx, 1)) * xk, axis=1) + GradTerm += obs[c, k] * x_K[:, k] - (obs[c, k] + 1) * ExpLambdaXk + ExpLambdaSquaredXk + HessianTerm += EldXkXk + EldSqXkXk - 2 * EldCubeXkXk + + if np.any(np.isnan(HessianTerm)) or np.any(np.isinf(HessianTerm)): + betahat_newTemp = betahat_new[:, c] + else: + try: + betahat_newTemp = betahat_new[:, c] - np.linalg.solve(HessianTerm, GradTerm) + except np.linalg.LinAlgError: + betahat_newTemp = betahat_new[:, c] + if np.any(np.isnan(betahat_newTemp)): + betahat_newTemp = betahat_new[:, c] + + mabsDiff = float(np.max(np.abs(betahat_newTemp - betahat_new[:, c]))) + if mabsDiff < diffTol: + converged = True + betahat_new[:, c] = betahat_newTemp + iterNR += 1 + + # -- mu update -- + for c in range(numCells): + converged = False + iterNR = 0 + while not converged and iterNR < maxIter_nr: + HessianTerm_mu = 0.0 + GradTerm_mu = 0.0 + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 1)) + Hk = Hk_full[k, :] + xk = xkPerm[:, :, k] + gammaC = gammahat if gammahat.size == 1 else (gammahat[:, c] if gammahat.ndim == 2 else gammahat) + terms = muhat_new[c] + betahat[:, c] @ xk + float(np.dot(gammaC.reshape(-1), Hk.reshape(-1))) + if fitType == 'poisson': + ld = 
np.exp(np.clip(terms, -500, 500)) + ExpLD = (1.0 / McExp) * float(np.sum(ld)) + GradTerm_mu += obs[c, k] - ExpLD + HessianTerm_mu -= ExpLD + else: + ld_raw = np.exp(np.clip(terms, -500, 500)) + ld = ld_raw / (1.0 + ld_raw) + ExpLD = (1.0 / McExp) * float(np.sum(ld)) + ExpLDSq = (1.0 / McExp) * float(np.sum(ld ** 2)) + ExpLDCube = (1.0 / McExp) * float(np.sum(ld ** 3)) + GradTerm_mu += obs[c, k] - (obs[c, k] + 1) * ExpLD + ExpLDSq + HessianTerm_mu += -ExpLD * (obs[c, k] + 1) + ExpLDSq * (obs[c, k] + 3) - 2 * ExpLDCube + + if np.isnan(HessianTerm_mu) or np.isinf(HessianTerm_mu) or abs(HessianTerm_mu) < 1e-300: + muhat_newTemp = muhat_new[c] + else: + muhat_newTemp = muhat_new[c] - GradTerm_mu / HessianTerm_mu + if np.isnan(muhat_newTemp): + muhat_newTemp = muhat_new[c] + + mabsDiff = abs(muhat_newTemp - muhat_new[c]) + if mabsDiff < diffTol: + converged = True + muhat_new[c] = muhat_newTemp + iterNR += 1 + + # -- gamma update -- + if (windowTimes is not None and not _is_empty_value(windowTimes) + and np.any(gammahat_new != 0)): + nGamma = gammahat.shape[0] if gammahat.ndim >= 1 else 1 + for c in range(numCells): + converged = False + iterNR = 0 + gammaC = gammahat_new.copy() if gammahat_new.size == 1 else (gammahat_new[:, c].copy() if gammahat_new.ndim == 2 else gammahat_new.copy()) + while not converged and iterNR < maxIter_nr: + HessianTerm_g = np.zeros((nGamma, nGamma)) + GradTerm_g = np.zeros(nGamma) + for k in range(K): + Hk_full = HkAll[:, :, c] if HkAll.ndim == 3 else np.zeros((K, 1)) + Hk = Hk_full[k, :] + xk = xkPerm[:, :, k] + terms = muhat[c] + betahat[:, c] @ xk + float(np.dot(gammaC.reshape(-1), Hk.reshape(-1))) + if fitType == 'poisson': + ld = np.exp(np.clip(terms, -500, 500)) + ExpLD = (1.0 / McExp) * float(np.sum(ld)) + GradTerm_g += (obs[c, k] - ExpLD) * Hk + HessianTerm_g -= ExpLD * np.outer(Hk, Hk) + else: + ld_raw = np.exp(np.clip(terms, -500, 500)) + ld = ld_raw / (1.0 + ld_raw) + ExpLD = (1.0 / McExp) * float(np.sum(ld)) + ExpLDSq = (1.0 / 
McExp) * float(np.sum(ld ** 2)) + ExpLDCube = (1.0 / McExp) * float(np.sum(ld ** 3)) + GradTerm_g += (obs[c, k] - (obs[c, k] + 1) * ExpLD + ExpLDSq) * Hk + HessianTerm_g += (-ExpLD * (obs[c, k] + 1) + ExpLDSq * (obs[c, k] + 3) - 2 * ExpLDCube) * np.outer(Hk, Hk) + + if np.any(np.isnan(HessianTerm_g)) or np.any(np.isinf(HessianTerm_g)): + gammahat_newTemp = gammaC.copy() + else: + try: + gammahat_newTemp = gammaC - np.linalg.solve(HessianTerm_g, GradTerm_g) + except np.linalg.LinAlgError: + gammahat_newTemp = gammaC.copy() + if np.any(np.isnan(gammahat_newTemp)): + gammahat_newTemp = gammaC.copy() + + mabsDiff = float(np.max(np.abs(gammahat_newTemp - gammaC))) + if mabsDiff < diffTol: + converged = True + gammaC = gammahat_newTemp + iterNR += 1 + + if gammahat_new.ndim == 2: + gammahat_new[:, c] = gammaC + else: + gammahat_new = gammaC + + return Ahat, Qhat, Chat, Rhat, alphahat, muhat_new, betahat_new, gammahat_new, x0hat, Px0hat + + @staticmethod + def mPPCO_EM(y, dN, Ahat0, Qhat0, Chat0, Rhat0, alphahat0, mu, beta, + fitType='poisson', delta=0.001, gamma=None, windowTimes=None, + x0=None, Px0=None, mPPCOEM_Constraints=None, MstepMethod='GLM'): + """Full EM algorithm for the mixed Point-Process / Continuous Observation model. 
+
+        Matlab: ``DecodingAlgorithms.mPPCO_EM`` (lines 6139-6554)
+
+        Returns
+        -------
+        xKFinal, WKFinal, Ahat, Qhat, Chat, Rhat, alphahat,
+        muhat, betahat, gammahat, x0hat, Px0hat, IC, SE, Pvals
+        """
+        Ahat0 = np.asarray(Ahat0, dtype=float)
+        Qhat0 = np.asarray(Qhat0, dtype=float)
+        Chat0 = np.asarray(Chat0, dtype=float)
+        Rhat0 = np.asarray(Rhat0, dtype=float)
+        alphahat0 = np.asarray(alphahat0, dtype=float).reshape(-1)
+        mu = np.asarray(mu, dtype=float).reshape(-1)
+        beta = np.asarray(beta, dtype=float)
+        if beta.ndim == 1:
+            beta = beta.reshape(-1, 1)
+        numStates = Ahat0.shape[0]
+        obs = _as_observation_matrix(dN)
+        numCells_K, N = obs.shape
+
+        if mPPCOEM_Constraints is None:
+            mPPCOEM_Constraints = DecodingAlgorithms.mPPCO_EMCreateConstraints()
+        if Px0 is None or _is_empty_value(Px0):
+            Px0 = 1e-9 * np.eye(numStates)
+        else:
+            Px0 = np.asarray(Px0, dtype=float).reshape(numStates, numStates)
+        if x0 is None or _is_empty_value(x0):
+            x0 = np.zeros(numStates, dtype=float)
+        else:
+            x0 = np.asarray(x0, dtype=float).reshape(-1)
+        if gamma is None:
+            gamma = np.array(0.0)
+        else:
+            gamma = np.asarray(gamma, dtype=float)
+        if delta is None:
+            delta = 0.001
+        if windowTimes is None or _is_empty_value(windowTimes):
+            if gamma is not None and np.any(gamma != 0):
+                windowTimes = np.arange(gamma.size + 2, dtype=float) * delta
+            else:
+                windowTimes = None
+
+        minTime = 0.0
+        maxTime = (N - 1) * delta
+        K_cells = numCells_K
+
+        # Build history
+        if windowTimes is not None and not _is_empty_value(windowTimes):
+            wt = np.asarray(windowTimes, dtype=float).reshape(-1)
+            HkAll = _compute_history_terms(dN, delta, wt)
+        else:
+            HkAll = np.zeros((N, 1, K_cells), dtype=float)
+            gamma = np.array(0.0)
+
+        y_arr = np.asarray(y, dtype=float)
+        yOrig = y_arr.copy()
+
+        tolAbs = 1e-3
+        llTol = 1e-3
+        maxIter = 100
+        numToKeep = 10
+
+        # Circular buffers
+        Ahat_buf = [None] * numToKeep
+        Qhat_buf = [None] * numToKeep
+        Chat_buf = [None] * numToKeep
+        Rhat_buf = [None] * numToKeep
+        alphahat_buf = [None] * numToKeep
+        muhat_buf = [None] * numToKeep
+        betahat_buf = [None] * numToKeep
+        gammahat_buf = [None] * numToKeep
+        x0hat_buf = [None] * numToKeep
+        Px0hat_buf = [None] * numToKeep
+        x_K_buf = [None] * numToKeep
+        W_K_buf = [None] * numToKeep
+        ExpSums_buf = [None] * numToKeep
+        ll_list = []
+
+        # Initialize (scaled system)
+        A0 = Ahat0.copy()
+        Q0 = Qhat0.copy()
+        C0 = Chat0.copy()
+        R0 = Rhat0.copy()
+
+        Tq = np.linalg.solve(np.linalg.cholesky(Q0), np.eye(numStates))
+        Tr = np.linalg.solve(np.linalg.cholesky(R0), np.eye(R0.shape[0]))
+
+        Ahat_buf[0] = Tq @ A0 @ np.linalg.inv(Tq)
+        Chat_buf[0] = Tr @ C0 @ np.linalg.inv(Tq)
+        Qhat_buf[0] = Tq @ Q0 @ Tq.T
+        Rhat_buf[0] = Tr @ R0 @ Tr.T
+        y_arr = Tr @ y_arr
+        x0hat_buf[0] = Tq @ x0
+        Px0hat_buf[0] = Tq @ Px0 @ Tq.T
+        alphahat_buf[0] = Tr @ alphahat0
+        betahat_buf[0] = np.linalg.solve(Tq.T, beta)
+        muhat_buf[0] = mu.copy()
+        gammahat_buf[0] = gamma.copy()
+
+        cnt = 0
+        stoppingCriteria = False
+
+        print(" Joint Point-Process/Gaussian Observation EM Algorithm ")
+
+        while not stoppingCriteria and cnt < maxIter:
+            si = cnt % numToKeep
+            si_p1 = (cnt + 1) % numToKeep
+            si_m1 = (cnt - 1) % numToKeep
+
+            print("-" * 80)
+            print(f"Iteration #{cnt + 1}")
+            print("-" * 80)
+
+            # E-step
+            x_K_buf[si], W_K_buf[si], ll_val, ExpSums_buf[si] = DecodingAlgorithms.mPPCO_EStep(
+                Ahat_buf[si], Qhat_buf[si], Chat_buf[si], Rhat_buf[si],
+                y_arr, alphahat_buf[si], dN,
+                muhat_buf[si], betahat_buf[si], fitType, delta,
+                gammahat_buf[si], HkAll, x0hat_buf[si], Px0hat_buf[si])
+            ll_list.append(ll_val)
+
+            # M-step
+            (Ahat_buf[si_p1], Qhat_buf[si_p1], Chat_buf[si_p1], Rhat_buf[si_p1],
+             alphahat_buf[si_p1], muhat_buf[si_p1], betahat_buf[si_p1],
+             gammahat_buf[si_p1], x0hat_buf[si_p1], Px0hat_buf[si_p1]) = \
+                DecodingAlgorithms.mPPCO_MStep(
+                    dN, y_arr, x_K_buf[si], W_K_buf[si],
+                    x0hat_buf[si], Px0hat_buf[si], ExpSums_buf[si],
+                    fitType, muhat_buf[si], betahat_buf[si],
+                    gammahat_buf[si], windowTimes, HkAll,
+                    mPPCOEM_Constraints, MstepMethod)
+
+            if not mPPCOEM_Constraints['EstimateA']:
+                Ahat_buf[si_p1] = Ahat_buf[si].copy()
+
+            # Convergence check
+            if cnt == 0:
+                dMax = np.inf
+            else:
+                diffs = []
+                # Variance-like terms (Q, R) are compared on the sqrt scale;
+                # A and C are compared directly.
+                for arr_curr, arr_prev in [
+                    (Qhat_buf[si], Qhat_buf[si_m1]),
+                    (Rhat_buf[si], Rhat_buf[si_m1]),
+                ]:
+                    if arr_curr is not None and arr_prev is not None:
+                        diffs.append(float(np.max(np.abs(np.sqrt(np.abs(arr_curr)) - np.sqrt(np.abs(arr_prev))))))
+                for arr_curr, arr_prev in [
+                    (Ahat_buf[si], Ahat_buf[si_m1]),
+                    (Chat_buf[si], Chat_buf[si_m1]),
+                ]:
+                    if arr_curr is not None and arr_prev is not None:
+                        diffs.append(float(np.max(np.abs(arr_curr - arr_prev))))
+                for arr_curr, arr_prev in [
+                    (muhat_buf[si], muhat_buf[si_m1]),
+                    (alphahat_buf[si], alphahat_buf[si_m1]),
+                    (betahat_buf[si], betahat_buf[si_m1]),
+                    (gammahat_buf[si], gammahat_buf[si_m1]),
+                ]:
+                    if arr_curr is not None and arr_prev is not None:
+                        diffs.append(float(np.max(np.abs(np.asarray(arr_curr) - np.asarray(arr_prev)))))
+                dMax = max(diffs) if diffs else np.inf
+
+            if cnt == 0:
+                print("Max Parameter Change: N/A")
+            else:
+                print(f"Max Parameter Change: {dMax}")
+
+            cnt += 1
+
+            if dMax < tolAbs:
+                stoppingCriteria = True
+                print(f" EM converged at iteration# {cnt} b/c change in params was within criteria")
+
+            if cnt >= 2:
+                dll = ll_list[-1] - ll_list[-2]
+                if abs(dll) < llTol or dll < 0:
+                    stoppingCriteria = True
+                    print(f" EM stopped at iteration# {cnt} b/c change in likelihood was negative or small")
+
+            print("-" * 80)
+
+        # Select best iteration
+        ll_arr = np.array(ll_list)
+        maxLLIndex = int(np.argmax(ll_arr))
+        maxLLIndMod = maxLLIndex % numToKeep
+
+        xKFinal = x_K_buf[maxLLIndMod]
+        WKFinal = W_K_buf[maxLLIndMod]
+        Ahat_out = Ahat_buf[maxLLIndMod]
+        Qhat_out = Qhat_buf[maxLLIndMod]
+        Chat_out = Chat_buf[maxLLIndMod]
+        Rhat_out = Rhat_buf[maxLLIndMod]
+        alphahat_out = alphahat_buf[maxLLIndMod]
+        muhat_out = muhat_buf[maxLLIndMod]
+        betahat_out = betahat_buf[maxLLIndMod]
+        gammahat_out = gammahat_buf[maxLLIndMod]
+        x0hat_out = x0hat_buf[maxLLIndMod]
+        Px0hat_out = Px0hat_buf[maxLLIndMod]
+        ExpectationSumsFinal = ExpSums_buf[maxLLIndMod]
+
+        # Unscale
+        Tq = np.linalg.solve(np.linalg.cholesky(Q0), np.eye(numStates))
+        Tr = np.linalg.solve(np.linalg.cholesky(R0), np.eye(R0.shape[0]))
+        Tq_inv = np.linalg.inv(Tq)
+        Tr_inv = np.linalg.inv(Tr)
+
+        Ahat_out = Tq_inv @ Ahat_out @ Tq
+        Qhat_out = Tq_inv @ Qhat_out @ np.linalg.inv(Tq.T)
+        Chat_out = Tr_inv @ Chat_out @ Tq
+        Rhat_out = Tr_inv @ Rhat_out @ np.linalg.inv(Tr.T)
+        alphahat_out = Tr_inv @ alphahat_out
+        xKFinal = Tq_inv @ xKFinal
+        x0hat_out = Tq_inv @ x0hat_out
+        Px0hat_out = Tq_inv @ Px0hat_out @ np.linalg.inv(Tq.T)
+        for kk in range(WKFinal.shape[2]):
+            WKFinal[:, :, kk] = Tq_inv @ WKFinal[:, :, kk] @ np.linalg.inv(Tq.T)
+        betahat_out = (betahat_out.T @ Tq).T
+
+        # Information criteria
+        ll_best = ll_arr[maxLLIndex]
+        # Count parameters
+        if mPPCOEM_Constraints['EstimateA'] and mPPCOEM_Constraints['AhatDiag']:
+            n1 = Ahat_out.shape[0]
+        elif mPPCOEM_Constraints['EstimateA']:
+            n1 = Ahat_out.size
+        else:
+            n1 = 0
+
+        if mPPCOEM_Constraints['QhatDiag'] and mPPCOEM_Constraints['QhatIsotropic']:
+            n2 = 1
+        elif mPPCOEM_Constraints['QhatDiag']:
+            n2 = Qhat_out.shape[0]
+        else:
+            n2 = Qhat_out.size
+
+        n3 = Chat_out.size
+
+        if mPPCOEM_Constraints['RhatDiag'] and mPPCOEM_Constraints['RhatIsotropic']:
+            n4 = 1
+        elif mPPCOEM_Constraints['RhatDiag']:
+            n4 = Rhat_out.shape[0]
+        else:
+            n4 = Rhat_out.size
+
+        if mPPCOEM_Constraints['EstimatePx0'] and mPPCOEM_Constraints['Px0Isotropic']:
+            n5 = 1
+        elif mPPCOEM_Constraints['EstimatePx0']:
+            n5 = Px0hat_out.shape[0]
+        else:
+            n5 = 0
+
+        n6 = x0hat_out.size if mPPCOEM_Constraints['Estimatex0'] else 0
+        n7 = alphahat_out.size
+        n8 = muhat_out.size
+        n9 = betahat_out.size
+        if gammahat_out.size == 1 and float(gammahat_out.flat[0]) == 0:
+            n10 = 0
+        else:
+            n10 = gammahat_out.size
+        nTerms = n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10
+
+        Dx = Ahat_out.shape[1]
+        sumXkTerms = ExpectationSumsFinal['sumXkTerms']
+        llobs = (ll_best +
+                 Dx * N / 2.0 * np.log(2 * np.pi) +
+                 N / 2.0 * np.log(max(np.linalg.det(Qhat_out), 1e-300)) +
+                 0.5 * np.trace(np.linalg.solve(Qhat_out, sumXkTerms)) +
+                 Dx / 2.0 * np.log(2 * np.pi) +
+                 0.5 * np.log(max(np.linalg.det(Px0hat_out), 1e-300)) +
+                 0.5 * Dx)
+
+        AIC = 2 * nTerms - 2 * llobs
+        AICc = AIC + 2 * nTerms * (nTerms + 1) / max(N - nTerms - 1, 1)
+        BIC = -2 * llobs + nTerms * np.log(max(N, 1))
+
+        IC = {
+            'AIC': AIC, 'AICc': AICc, 'BIC': BIC,
+            'llobs': llobs, 'llcomp': ll_best,
+        }
+
+        # Standard errors
+        SE = {}
+        Pvals = {}
+        try:
+            SE, Pvals, _ = DecodingAlgorithms.mPPCO_ComputeParamStandardErrors(
+                yOrig, dN, xKFinal, WKFinal, Ahat_out, Qhat_out,
+                Chat_out, Rhat_out, alphahat_out, x0hat_out, Px0hat_out,
+                ExpectationSumsFinal, fitType, muhat_out, betahat_out,
+                gammahat_out, windowTimes, HkAll, mPPCOEM_Constraints)
+        except Exception:
+            pass
+
+        return (xKFinal, WKFinal, Ahat_out, Qhat_out, Chat_out, Rhat_out,
+                alphahat_out, muhat_out, betahat_out, gammahat_out,
+                x0hat_out, Px0hat_out, IC, SE, Pvals)
+
+
+
+
+# Module-level aliases for backward compatibility
 PP_fixedIntervalSmoother = DecodingAlgorithms.PP_fixedIntervalSmoother
 PPDecodeFilter = DecodingAlgorithms.PPDecodeFilter
 PPDecodeFilterLinear = DecodingAlgorithms.PPDecodeFilterLinear
@@ -2225,14 +7634,47 @@ def prepareEMResults(fitType, neuronNumber, dN, HkAll, xK, WK, Q, gamma,
 kalman_smootherFromFiltered = DecodingAlgorithms.kalman_smootherFromFiltered
 kalman_smoother = DecodingAlgorithms.kalman_smoother
 ComputeStimulusCIs = DecodingAlgorithms.ComputeStimulusCIs
+computeSpikeRateCIs = DecodingAlgorithms.computeSpikeRateCIs
+computeSpikeRateDiffCIs = DecodingAlgorithms.computeSpikeRateDiffCIs
 ukf = DecodingAlgorithms.ukf
 ukf_ut = DecodingAlgorithms.ukf_ut
 ukf_sigmas = DecodingAlgorithms.ukf_sigmas
+KF_EM = DecodingAlgorithms.KF_EM
+KF_EMCreateConstraints = DecodingAlgorithms.KF_EMCreateConstraints
+KF_EStep = DecodingAlgorithms.KF_EStep
+KF_MStep = DecodingAlgorithms.KF_MStep
+KF_ComputeParamStandardErrors = DecodingAlgorithms.KF_ComputeParamStandardErrors
+PP_EM = DecodingAlgorithms.PP_EM
+PP_EMCreateConstraints = DecodingAlgorithms.PP_EMCreateConstraints
+PP_ComputeParamStandardErrors = DecodingAlgorithms.PP_ComputeParamStandardErrors
+PP_EStep = DecodingAlgorithms.PP_EStep
+PP_MStep = DecodingAlgorithms.PP_MStep
+mPPCODecode_predict = DecodingAlgorithms.mPPCODecode_predict
+mPPCODecode_update = DecodingAlgorithms.mPPCODecode_update
+mPPCODecodeLinear = DecodingAlgorithms.mPPCODecodeLinear
+mPPCO_fixedIntervalSmoother = DecodingAlgorithms.mPPCO_fixedIntervalSmoother
+mPPCO_EMCreateConstraints = DecodingAlgorithms.mPPCO_EMCreateConstraints
+mPPCO_ComputeParamStandardErrors = DecodingAlgorithms.mPPCO_ComputeParamStandardErrors
+mPPCO_EM = DecodingAlgorithms.mPPCO_EM
+mPPCO_EStep = DecodingAlgorithms.mPPCO_EStep
+mPPCO_MStep = DecodingAlgorithms.mPPCO_MStep
 
 __all__ = [
     "ComputeStimulusCIs",
     "DecodingAlgorithms",
+    "computeSpikeRateCIs",
+    "computeSpikeRateDiffCIs",
+    "KF_ComputeParamStandardErrors",
+    "KF_EM",
+    "KF_EMCreateConstraints",
+    "KF_EStep",
+    "KF_MStep",
+    "PP_ComputeParamStandardErrors",
+    "PP_EM",
+    "PP_EMCreateConstraints",
+    "PP_EStep",
+    "PP_MStep",
     "PPDecodeFilter",
     "PPDecodeFilterLinear",
     "PPDecode_predict",
@@ -2251,6 +7693,15 @@ def prepareEMResults(fitType, neuronNumber, dN, HkAll, xK, WK, Q, gamma,
     "kalman_smoother",
     "kalman_smootherFromFiltered",
     "kalman_update",
+    "mPPCODecode_predict",
+    "mPPCODecode_update",
+    "mPPCODecodeLinear",
+    "mPPCO_ComputeParamStandardErrors",
+    "mPPCO_EM",
+    "mPPCO_EMCreateConstraints",
+    "mPPCO_EStep",
+    "mPPCO_MStep",
+    "mPPCO_fixedIntervalSmoother",
     "ukf",
     "ukf_sigmas",
     "ukf_ut",
diff --git a/nstat/events.py b/nstat/events.py
index 9332e597..bcbbec69 100644
--- a/nstat/events.py
+++ b/nstat/events.py
@@ -42,7 +42,18 @@ def fromStructure(structure: dict[str, Any] | None) -> "Events" | None:
         event_color = structure.get("eventColor", "r")
         return Events(event_times, event_labels,
                       event_color)
-    def plot(self, *_, handle=None, **__):
+    def plot(self, *_, handle=None, colorString: str | None = None, **__):
+        """Plot event markers on one or more axes.
+
+        Parameters
+        ----------
+        handle : Axes or list[Axes], optional
+            Axes to plot into (default: current axes).
+        colorString : str, optional
+            Override line colour for event lines (default: ``'r'``).
+            Matches Matlab ``Events.plot`` ``colorString`` parameter.
+        """
+        color = colorString if colorString is not None else "r"
         if handle is None:
             handles = [plt.gca()]
         elif isinstance(handle, Sequence) and not hasattr(handle, "plot"):
@@ -62,7 +73,7 @@ def plot(self, *_, handle=None, **__):
                 np.full(self.eventTimes.shape, float(v[3]), dtype=float),
             ]
         )
-        ax.plot(times, y, "r", linewidth=4)
+        ax.plot(times, y, color, linewidth=4)
         for event_time, label in zip(self.eventTimes, self.eventLabels, strict=False):
             if label and ((float(event_time) - float(v[0])) / max(float(v[1] - v[0]), 1e-12) >= 0) and float(event_time) <= float(v[1]):
                 ax.text(
diff --git a/nstat/fit.py b/nstat/fit.py
index 5c306290..535d9dd7 100644
--- a/nstat/fit.py
+++ b/nstat/fit.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 from dataclasses import dataclass
 from typing import Any, Iterable, Sequence
 
@@ -679,15 +680,41 @@ def addParamsToFit(self, neuronNum, lambda_signal, b, dev, stats, AIC, BIC, logL
         self.__dict__.update(merged.__dict__)
         return self
 
-    def getCoeffs(self, fit_num: int = 1) -> np.ndarray:
+    def _rawCoeffs(self, fit_num: int = 1) -> np.ndarray:
+        """Return the raw coefficient vector for *fit_num* (1-based)."""
         return self.b[fit_num - 1].copy()
 
+    def getCoeffs(self, fit_num: int = 1) -> np.ndarray:
+        """Return the coefficient vector for *fit_num* (1-based).
+
+        In Matlab, ``[coeffMat, labels, SEMat] = getCoeffs(fitObj, fitNum)``
+        returns multiple outputs. Use :meth:`getCoeffsWithLabels` to obtain
+        the full ``(coeffMat, labels, SEMat)`` tuple.
+        """
+        return self._rawCoeffs(fit_num)
+
     def getHistCoeffs(self, fit_num: int = 1) -> np.ndarray:
+        """Return the history-coefficient vector for *fit_num* (1-based).
+
+        In Matlab, ``[histMat, labels, SEMat] = getHistCoeffs(fitObj, fitNum)``
+        returns multiple outputs. Use :meth:`getHistCoeffsWithLabels` to
+        obtain the full ``(histMat, labels, SEMat)`` tuple.
+        """
         num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0
-        coeff = self.getCoeffs(fit_num)
         if num_hist <= 0:
             return np.array([], dtype=float)
-        return coeff[-num_hist:]
+        return self._rawCoeffs(fit_num)[-num_hist:]
+
+    def getHistCoeffsWithLabels(self, fit_num: int = 1) -> tuple[np.ndarray, list[str], np.ndarray]:
+        """Return ``(histMat, labels, SEMat)`` — Matlab multi-output form.
+
+        Matlab: ``[histMat, labels, SEMat] = getHistCoeffs(fitObj, fitNum)``
+        """
+        num_hist = int(self.numHist[fit_num - 1]) if fit_num - 1 < len(self.numHist) else 0
+        coeffs, labels, se = self.getCoeffsWithLabels(fit_num)
+        if num_hist <= 0:
+            return np.array([], dtype=float), [], np.array([], dtype=float)
+        return coeffs[-num_hist:], labels[-num_hist:], se[-num_hist:]
 
     def getCoeffIndex(self, fit_num: int = 1, sortByEpoch: int = 0):
         del sortByEpoch
@@ -717,7 +744,7 @@ def getParam(self, paramNames, fit_num: int = 1):
         return coeffs[indices], se[indices], sig[indices]
 
     def getCoeffsWithLabels(self, fit_num: int = 1) -> tuple[np.ndarray, list[str], np.ndarray]:
-        coeffs = self.getCoeffs(fit_num)
+        coeffs = self._rawCoeffs(fit_num)
         labels = list(self.covLabels[fit_num - 1]) if fit_num - 1 < len(self.covLabels) else [f"b_{idx + 1}" for idx in range(coeffs.size)]
         if coeffs.size == len(labels) + 1:
             labels = ["Intercept", *labels]
@@ -820,7 +847,7 @@ def _primary_spike_train(self) -> nspikeTrain:
             return self.neuralSpikeTrain[0]
         raise TypeError("FitResult does not contain a MATLAB-style neural spike train")
 
-    def _compute_diagnostics(self, fit_num: int = 1) -> dict[str, np.ndarray | float]:
+    def _compute_diagnostics(self, fit_num: int = 1, *, dt_correction: int = 1) -> dict[str, np.ndarray | float]:
         if fit_num in self._diagnostic_cache:
             return self._diagnostic_cache[fit_num]
 
@@ -848,7 +875,7 @@ def _compute_diagnostics(self, fit_num: int = 1) -> dict[str, np.ndarray | float
             self.lambda_signal.yunits,
             selected_labels,
         )
-        Z, U, xAxis, KSSorted, _ = _matlab_compute_ks_arrays(self._primary_spike_train(), lambda_signal, dt_correction=1)
+        Z, U, xAxis, KSSorted, _ = _matlab_compute_ks_arrays(self._primary_spike_train(), lambda_signal, dt_correction=dt_correction)
         z = np.asarray(Z[:, 0], dtype=float).reshape(-1) if np.asarray(Z).size else np.asarray([], dtype=float)
         uniforms = np.asarray(U[:, 0], dtype=float).reshape(-1) if np.asarray(U).size else np.asarray([], dtype=float)
         ideal = np.asarray(xAxis[:, 0], dtype=float).reshape(-1) if np.asarray(xAxis).size else np.asarray([], dtype=float)
@@ -865,7 +892,9 @@ def _compute_diagnostics(self, fit_num: int = 1) -> dict[str, np.ndarray | float
         gaussianized = norm.ppf(np.clip(uniforms, 1e-6, 1.0 - 1e-6))
         lags, acf = _autocorrelation(gaussianized, max_lag=25)
         acf_ci = 1.96 / np.sqrt(float(gaussianized.size)) if gaussianized.size else np.nan
-        coeffs = self.getCoeffs(fit_num)
+        coeffs = self._rawCoeffs(fit_num)
+        se = _extract_standard_errors(self.stats[fit_num - 1] if fit_num - 1 < len(self.stats) else None, coeffs.size)
+        sig_mask = _extract_significance_mask(self.stats[fit_num - 1] if fit_num - 1 < len(self.stats) else None, coeffs, se)
         labels = self.covLabels[fit_num - 1] if fit_num - 1 < len(self.covLabels) else []
         if coeffs.size == len(labels):
             coeff_labels = list(labels)
@@ -894,6 +923,8 @@ def _compute_diagnostics(self, fit_num: int = 1) -> dict[str, np.ndarray | float
             "acf_ci": acf_ci,
             "gaussianized": gaussianized,
             "coefficients": coeffs,
+            "coeff_se": se,
+            "coeff_sig": sig_mask,
             "coeff_labels": np.asarray(coeff_labels, dtype=object),
         }
         self._diagnostic_cache[fit_num] = diagnostics
@@ -909,8 +940,8 @@ def _compute_diagnostics(self, fit_num: int = 1) -> dict[str, np.ndarray | float
         self.invGausStats = {"X": gaussianized, "rhoSig": acf.tolist(), "confBoundSig": [acf_ci]}
         return diagnostics
 
-    def computeKSStats(self, fit_num: int = 1) -> dict[str, float]:
-        diag = self._compute_diagnostics(fit_num)
+    def computeKSStats(self, fit_num: int = 1, *, dt_correction: int = 1) -> dict[str, float]:
+        diag = self._compute_diagnostics(fit_num, dt_correction=dt_correction)
         return {
             "ks_stat": float(diag["ks_stat"]),
             "ks_pvalue": float(diag["ks_pvalue"]),
@@ -920,14 +951,14 @@ def computeKSStats(self, fit_num: int = 1) -> dict[str, float]:
     def computeInvGausTrans(self, fit_num: int = 1) -> np.ndarray:
         return np.asarray(self._compute_diagnostics(fit_num)["gaussianized"], dtype=float)
 
-    def computeFitResidual(self, fit_num: int = 1) -> Covariate:
+    def computeFitResidual(self, fit_num: int = 1, *, windowSize: float | None = None) -> Covariate:
         time, rate_hz = self._lambda_series(fit_num)
         if time.size == 0:
             residual = Covariate([], [], "M(t_k)", "time", "s", "counts/bin", ["residual"])
             self.setFitResidual(residual)
             return residual
-        window_size = float(np.median(np.diff(time))) if time.size > 1 else 1.0
+        window_size = float(windowSize) if windowSize is not None else (float(np.median(np.diff(time))) if time.size > 1 else 1.0)
         spike_train = self._primary_spike_train().nstCopy()
         spike_train.resample(1.0 / max(window_size, 1e-12))
         spike_train.setMinTime(float(time[0]))
@@ -935,6 +966,10 @@ def computeFitResidual(self, fit_num: int = 1) -> Covariate:
         sum_spikes = spike_train.getSigRep(window_size, float(time[0]), float(time[-1]))
         window_times = np.linspace(float(time[0]), float(time[-1]), sum_spikes.time.size, dtype=float)
 
+        # Use the label for the specific fit_num, not all labels
+        all_labels = self.lambda_signal.dataLabels if getattr(self.lambda_signal, "dataLabels", None) else ["\\lambda"]
+        idx = min(max(fit_num - 1, 0), len(all_labels) - 1)
+        fit_label = [all_labels[idx]]
         lambda_signal = Covariate(
             time,
             rate_hz,
@@ -942,7 +977,7 @@ def computeFitResidual(self, fit_num: int = 1) -> Covariate:
             self.lambda_signal.xlabelval,
             self.lambda_signal.xunits,
             self.lambda_signal.yunits,
-            self.lambda_signal.dataLabels if getattr(self.lambda_signal, "dataLabels", None) else ["\\lambda"],
+            fit_label,
         )
         lambda_int = lambda_signal.integral()
         lambda_int_vals = (
@@ -971,7 +1006,7 @@ def computeFitResidual(self, fit_num: int = 1) -> Covariate:
         return residual
 
     def evalLambda(self, fit_num: int = 1, newData=None) -> np.ndarray:
-        coeffs = self.getCoeffs(fit_num)
+        coeffs = self._rawCoeffs(fit_num)
         x = np.asarray(newData if newData is not None else [], dtype=float)
         if x.ndim == 0:
             x = x.reshape(1, 1)
@@ -995,14 +1030,79 @@ def evalLambda(self, fit_num: int = 1, newData=None) -> np.ndarray:
         rate = np.exp(np.clip(eta, -20.0, 20.0)) * float(self.lambda_signal.sampleRate)
         return rate.reshape(np.asarray(newData[0] if isinstance(newData, list) else x[:, 0]).shape) if x.size else rate
 
+    def computeValLambda(self) -> tuple[Covariate, np.ndarray]:
+        """Compute the conditional intensity on validation data (Matlab ``computeValLambda``).
+
+        Returns
+        -------
+        lambda_val : Covariate
+            The validation-set conditional intensity function.
+        logLL : np.ndarray
+            Log-likelihood for each fit configuration on the validation data.
+        """
+        if not self.XvalTime or not self.XvalData:
+            raise ValueError("No validation data available (XvalData / XvalTime are empty)")
+
+        time_vec = np.asarray(self.XvalTime[0], dtype=float).reshape(-1)
+        lambda_data = np.zeros((time_vec.size, self.numResults), dtype=float)
+        for i in range(self.numResults):
+            xval = self.XvalData[i] if i < len(self.XvalData) else self.XvalData[0]
+            lambda_data[:, i] = self.evalLambda(i + 1, xval)
+
+        lambda_val = Covariate(
+            time_vec,
+            lambda_data,
+            "\\lambda(t)",
+            self.lambda_signal.xlabelval,
+            self.lambda_signal.xunits,
+            "Hz",
+            list(self.lambda_signal.dataLabels),
+        )
+
+        delta = 1.0 / max(float(lambda_val.sampleRate), 1e-12)
+        y = self.neuralSpikeTrain.getSigRep().dataToMatrix().reshape(-1)
+        # Truncate y to match the validation lambda length
+        n = min(y.size, lambda_data.shape[0])
+        logLL_arr = np.zeros(self.numResults, dtype=float)
+        for col in range(self.numResults):
+            lam = np.maximum(lambda_data[:n, col] * delta, 1e-30)
+            y_trunc = y[:n]
+            logLL_arr[col] = float(np.sum(
+                y_trunc * np.log(lam) + (1.0 - y_trunc) * np.log(np.maximum(1.0 - lam, 1e-30))
+            ))
+
+        return lambda_val, logLL_arr
+
     def plotResults(self, fit_num: int = 1, handle=None):
-        fig = handle if handle is not None else plt.figure(figsize=(11.5, 8.0))
+        """Matlab-matching 2x4 subplot layout with 5 diagnostic panels.
+
+        Layout (matching Matlab ``subplot(2,4,...)``):
+            [1,2] KSPlot (double-wide)   [3] InvGausTrans   [4] SeqCorr
+            [5,6] plotCoeffs (double-wide)   [7,8] plotResidual (double-wide)
+        """
+        import matplotlib.gridspec as gridspec
+
+        fig = handle if handle is not None else plt.figure(figsize=(14.0, 8.0))
         fig.clear()
-        axes = fig.subplots(2, 2)
-        self.KSPlot(fit_num=fit_num, handle=axes[0, 0])
-        self.plotInvGausTrans(fit_num=fit_num, handle=axes[0, 1])
-        self.plotSeqCorr(fit_num=fit_num, handle=axes[1, 0])
-        self.plotCoeffs(fit_num=fit_num, handle=axes[1, 1])
+        gs = gridspec.GridSpec(2, 4, figure=fig)
+
+        ax_ks = fig.add_subplot(gs[0, 0:2])
+        ax_ig = fig.add_subplot(gs[0, 2])
+        ax_sc = fig.add_subplot(gs[0, 3])
+        ax_co = fig.add_subplot(gs[1, 0:2])
+        ax_re = fig.add_subplot(gs[1, 2:4])
+
+        self.KSPlot(fit_num=fit_num, handle=ax_ks)
+        # Add neuron number label (matching Matlab)
+        ax_ks.text(
+            0.45, 0.95, f"Neuron: {self.neuronNumber}",
+            transform=ax_ks.transAxes, fontweight="bold", fontsize=10,
+            verticalalignment="top",
+        )
+        self.plotInvGausTrans(fit_num=fit_num, handle=ax_ig)
+        self.plotSeqCorr(fit_num=fit_num, handle=ax_sc)
+        self.plotCoeffs(fit_num=fit_num, handle=ax_co)
+        self.plotResidual(fit_num=fit_num, handle=ax_re)
         fig.tight_layout()
         return fig
 
@@ -1035,18 +1135,11 @@ def plotResidual(self, fit_num: int = 1, handle=None):
         return ax
 
     def plotInvGausTrans(self, fit_num: int = 1, handle=None):
-        diag = self._compute_diagnostics(fit_num)
-        ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
-        x = np.asarray(diag["gaussianized"], dtype=float)
-        if x.size:
-            ax.plot(np.arange(1, x.size + 1), x, color="tab:green", linewidth=1.0)
-        ax.axhline(0.0, color="0.4", linewidth=1.0, linestyle="--")
-        ax.set_xlabel("event index")
-        ax.set_ylabel("\\Phi^{-1}(u_i)")
-        ax.set_title("Inverse-Gaussian/Uniform Transform")
-        return ax
+        """Plot ACF of gaussianized rescaled ISIs with 95% CIs.
 
-    def plotSeqCorr(self, fit_num: int = 1, handle=None):
+        Matlab: plotInvGausTrans computes X_j = Φ⁻¹(U_j) and plots the
+        autocorrelation function of X_j with 95% confidence bounds.
+        """
         diag = self._compute_diagnostics(fit_num)
         ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
         lags = np.asarray(diag["acf_lags"], dtype=float)
@@ -1058,19 +1151,64 @@ def plotInvGausTrans(self, fit_num: int = 1, handle=None):
         ax.axhline(0.0, color="0.4", linewidth=1.0)
         ax.set_xlabel("lag")
         ax.set_ylabel("autocorrelation")
-        ax.set_title("Sequential Correlation of Rescaled ISIs")
+        ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs")
+        return ax
+
+    def plotSeqCorr(self, fit_num: int = 1, handle=None):
+        """Plot U_j vs U_{j+1} scatter with correlation coefficient.
+
+        Matlab: plotSeqCorr plots the sequential correlation scatter of
+        U_j (uniform-transformed rescaled ISIs) to detect serial dependence.
+        """
+        diag = self._compute_diagnostics(fit_num)
+        ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
+        u = np.asarray(diag.get("uniforms", []), dtype=float)
+        if u.size > 1:
+            uj = u[:-1]
+            uj1 = u[1:]
+            ax.plot(uj, uj1, ".", color="tab:blue", markersize=4.0)
+            # Compute correlation coefficient (guard against constant series)
+            if uj.size > 2 and np.std(uj) > 0 and np.std(uj1) > 0:
+                with np.errstate(invalid="ignore"):
+                    rho_mat = np.corrcoef(uj, uj1)
+                rho = rho_mat[0, 1] if rho_mat.shape[0] > 1 else float("nan")
+                ax.set_title(f"Sequential Correlation ($\\rho$ = {rho:.2g})")
+            else:
+                ax.set_title("Sequential Correlation")
+        else:
+            ax.set_title("Sequential Correlation")
+        ax.set_xlabel("$U_j$")
+        ax.set_ylabel("$U_{j+1}$")
+        ax.set_xlim(0, 1)
+        ax.set_ylim(0, 1)
         return ax
 
-    def plotCoeffs(self, fit_num: int = 1, handle=None):
+    def plotCoeffs(self, fit_num: int = 1, handle=None, plotSignificance: int = 1):
+        """Plot GLM coefficients with error bars and significance markers.
+
+        Matches Matlab FitResult.plotCoeffs: errorbar plot with ±1 SE,
+        and asterisks (*) above significant coefficients (p < 0.05).
+        """
         diag = self._compute_diagnostics(fit_num)
         ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
         coeffs = np.asarray(diag["coefficients"], dtype=float)
+        se = np.asarray(diag["coeff_se"], dtype=float)
+        sig = np.asarray(diag["coeff_sig"], dtype=float)
         labels = list(np.asarray(diag["coeff_labels"], dtype=object))
-        xpos = np.arange(coeffs.size, dtype=float)
+        xpos = np.arange(1, coeffs.size + 1, dtype=float)
         ax.axhline(0.0, color="0.6", linewidth=1.0)
-        ax.plot(xpos, coeffs, "o-", color="tab:blue", linewidth=1.0)
-        ax.set_xticks(xpos, labels, rotation=45, ha="right")
-        ax.set_ylabel("coefficient value")
+        # Errorbar plot like Matlab (dot markers with SE whiskers)
+        valid_se = np.where(np.isfinite(se), se, 0.0)
+        ax.errorbar(xpos, coeffs, yerr=valid_se, fmt=".", color="tab:blue",
+                    linewidth=1.0, markersize=8.0, capsize=3.0)
+        if plotSignificance and np.any(sig > 0):
+            ylims = ax.get_ylim()
+            y_star = 0.8 * ylims[1]
+            sig_idx = xpos[sig.astype(bool)]
+            ax.plot(sig_idx, np.full(sig_idx.size, y_star), "*", color="tab:blue", markersize=10.0)
+        ax.set_xticks(xpos)
+        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=6)
+        ax.set_ylabel("GLM Fit Coefficients")
         ax.set_title("GLM Coefficients")
         return ax
 
@@ -1092,8 +1230,9 @@ def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotS
 
     def plotHistCoeffs(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None):
         del sortByEpoch, plotSignificance
-        coeffs = self.getHistCoeffs(fit_num)
-        labels = list(self.covLabels[fit_num - 1])[-coeffs.size :] if coeffs.size and fit_num - 1 < len(self.covLabels) else [f"hist_{idx + 1}" for idx in range(coeffs.size)]
+        coeffs, labels, _se = self.getHistCoeffsWithLabels(fit_num)
+        if not labels:
+            labels = [f"hist_{idx + 1}" for idx in range(coeffs.size)]
         ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
         xpos = np.arange(coeffs.size, dtype=float)
         ax.axhline(0.0, color="0.6", linewidth=1.0)
@@ -1313,10 +1452,11 @@ def getHistCoeffs(self, fitNum: int = 1):
         coeff_rows = []
         se_rows = []
         for fit in self.fitResCell:
-            coeffs = fit.getHistCoeffs(fitNum)
-            fit_labels = list(fit.covLabels[fitNum - 1])[-coeffs.size :] if coeffs.size and fitNum - 1 < len(fit.covLabels) else []
-            se = _extract_standard_errors(fit.stats[fitNum - 1] if fitNum - 1 < len(fit.stats) else None, fit.getCoeffs(fitNum).size)
-            se_hist = se[-coeffs.size :] if coeffs.size else np.array([], dtype=float)
+            coeffs, fit_labels, se_hist = fit.getHistCoeffsWithLabels(fitNum)
+            if not fit_labels:
+                fit_labels = list(fit.covLabels[fitNum - 1])[-coeffs.size :] if coeffs.size and fitNum - 1 < len(fit.covLabels) else []
+                se_all = _extract_standard_errors(fit.stats[fitNum - 1] if fitNum - 1 < len(fit.stats) else None, fit._rawCoeffs(fitNum).size)
+                se_hist = se_all[-coeffs.size :] if coeffs.size else np.array([], dtype=float)
             row = np.full(len(labels), np.nan, dtype=float)
             se_row = np.full(len(labels), np.nan, dtype=float)
             for coeff, coeff_se, label in zip(coeffs, se_hist, fit_labels, strict=False):
@@ -1339,14 +1479,46 @@ def getSigCoeffs(self, fitNum: int = 1):
                 sig[row_idx, labels.index(label)] = value
         return sig
 
-    def binCoeffs(self, minVal, maxVal, binSize):
-        coeff_mat, _, _ = self.getCoeffs(1)
-        values = coeff_mat[np.isfinite(coeff_mat)]
+    def binCoeffs(self, minVal=-12.0, maxVal=12.0, binSize=0.1):
+        """Histogram of regression coefficients per covariate.
+
+        Matches Matlab FitResSummary.binCoeffs: for each unique covariate,
+        bins the significant coefficient values across all neurons/fits,
+        normalizes to a PDF, and computes the fraction of times each
+        covariate was significant.
+
+        Returns
+        -------
+        N : (nBins, nCovariates) per-covariate normalized histograms (PDFs)
+        edges : (nBins + 1,) bin edges
+        percentSig : (nCovariates,) fraction of times each covariate was significant
+        """
         edges = np.arange(float(minVal), float(maxVal) + float(binSize), float(binSize), dtype=float)
         if edges.size < 2:
             edges = np.array([float(minVal), float(maxVal)], dtype=float)
-        N, edges = np.histogram(values, bins=edges)
-        percentSig = float(np.mean(self.getSigCoeffs(1))) if coeff_mat.size else 0.0
+
+        # Build per-covariate data across all fits
+        # bAct: (nNeurons, nCov), sigIdx: (nNeurons, nCov)
+        coeff_mat, labels, se_mat = self.getCoeffs(1)  # (nNeurons, nCov)
+        sig_mat = self.getSigCoeffs(1)  # (nNeurons, nCov) boolean
+
+        nCov = len(labels)
+        N = np.zeros((edges.size - 1, nCov), dtype=float)
+        percentSig = np.zeros(nCov, dtype=float)
+
+        for i in range(nCov):
+            vals = coeff_mat[:, i]
+            sig = sig_mat[:, i].astype(bool)
+            valid = np.isfinite(vals)
+            numPresent = float(np.sum(valid))
+            # Take only significant values
+            sig_vals = vals[sig & valid]
+            Ntemp, _ = np.histogram(sig_vals, bins=edges)
+            numSig = float(Ntemp.sum())
+            percentSig[i] = numSig / max(numPresent, 1.0)
+            if numSig > 0:
+                N[:, i] = Ntemp.astype(float) / numSig  # normalize to PDF
+
+        return N, edges, percentSig
 
     def plotIC(self, handle=None):
@@ -1361,21 +1533,21 @@ def plotIC(self, handle=None):
 
     def plotAIC(self, handle=None):
         ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 3.5))[1]
-        ax.boxplot(self.AIC, labels=self.fitNames)
+        ax.boxplot(self.AIC, tick_labels=self.fitNames)
         ax.set_ylabel("AIC")
         ax.set_title("AIC Across Neurons")
         return ax
 
     def plotBIC(self, handle=None):
         ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 3.5))[1]
-        ax.boxplot(self.BIC, labels=self.fitNames)
+        ax.boxplot(self.BIC, tick_labels=self.fitNames)
        ax.set_ylabel("BIC")
         ax.set_title("BIC Across Neurons")
         return ax
 
     def plotlogLL(self, handle=None):
         ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 3.5))[1]
-        ax.boxplot(self.logLL, labels=self.fitNames)
+        ax.boxplot(self.logLL, tick_labels=self.fitNames)
         ax.set_ylabel("log likelihood")
         ax.set_title("log likelihood Across Neurons")
         return ax
@@ -1426,7 +1598,7 @@ def boxPlot(self, X, diffIndex: int = 1, h=None, dataLabels=None, **kwargs):
             labels = [name for idx, name in enumerate(self.fitNames, start=1) if idx != diffIndex]
         else:
             labels = list(self.fitNames[: values.shape[1]])
-        ax.boxplot(values, labels=labels)
+        ax.boxplot(values, tick_labels=labels)
         return ax
 
     # ------------------------------------------------------------------
@@ -1461,6 +1633,56 @@ def plotAllCoeffs(self, fitNum: int | list[int] | None = None,
         ax.axhline(0, color="0.5", linewidth=0.5, linestyle="--")
         return ax
 
+    def _is_hist_label(self, label: str) -> bool:
+        """Return True if *label* looks like a history window term (e.g. ``[0.001,0.01]``)."""
+        return bool(re.match(r"^\[", label))
+
+    def getHistIndex(self, fitNum: int | list[int] | None = None) -> list[int]:
+        """Return 0-based indices into *uniqueCovLabels* that are history terms."""
+        if fitNum is None:
+            fitNum = list(range(1, self.numResults + 1))
+        if isinstance(fitNum, int):
+            fitNum = [fitNum]
+        coeff_mat, labels, _ = self.getCoeffs(fitNum[0])
+        hist_indices: list[int] = []
+        for idx, label in enumerate(labels):
+            if self._is_hist_label(label):
+                # Only include if at least one neuron has a non-NaN value
+                if np.any(np.isfinite(coeff_mat[:, idx])):
+                    hist_indices.append(idx)
+        return hist_indices
+
+    def getCoeffIndex(self, fitNum: int | list[int] | None = None) -> list[int]:
+        """Return 0-based indices into *uniqueCovLabels* that are NOT history terms."""
+        if fitNum is None:
+            fitNum = list(range(1, self.numResults + 1))
+        if isinstance(fitNum, int):
+            fitNum = [fitNum]
+        coeff_mat, labels, _ = self.getCoeffs(fitNum[0])
+        hist_set = set(self.getHistIndex(fitNum))
+        coeff_indices: list[int] = []
+        for idx, label in enumerate(labels):
if idx not in hist_set: + if np.any(np.isfinite(coeff_mat[:, idx])): + coeff_indices.append(idx) + return coeff_indices + + def plotCoeffsWithoutHistory(self, fitNum: int | list[int] | None = None, + plotSignificance: bool = True, + handle=None): + """Plot coefficients excluding history terms (Matlab ``plotCoeffsWithoutHistory``).""" + coeffIndex = self.getCoeffIndex(fitNum) + return self.plotAllCoeffs(fitNum=fitNum, plotSignificance=plotSignificance, + subIndex=coeffIndex, handle=handle) + + def plotHistCoeffs(self, fitNum: int | list[int] | None = None, + plotSignificance: bool = True, + handle=None): + """Plot only the history coefficients (Matlab ``plotHistCoeffs``).""" + histIndex = self.getHistIndex(fitNum) + return self.plotAllCoeffs(fitNum=fitNum, plotSignificance=plotSignificance, + subIndex=histIndex, handle=handle) + def plot3dCoeffSummary(self, handle=None): """3D ribbon plot of binned coefficient distributions (Matlab ``plot3dCoeffSummary``).""" from mpl_toolkits.mplot3d import Axes3D # noqa: F401 diff --git a/nstat/trial.py b/nstat/trial.py index 229aa097..558b6c77 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -479,6 +479,23 @@ def getCovLabelsFromMask(self) -> list[str]: labels.extend([label for keep, label in zip(mask, cov.dataLabels) if keep == 1]) return labels + def getCovDimension(self, identifier=None) -> np.ndarray: + """Return the dimension of each covariate selected by *identifier*. + + Matlab signature: ``dim = getCovDimension(ccObj, identifier)`` + + Returns a 1-D int array whose *i*-th element is ``covs{i}.dimension``. 
+ """ + if identifier is None: + covs = [self.getCov(i) for i in range(1, self.numCov + 1)] + elif isinstance(identifier, (int, np.integer)): + covs = [self.getCov(int(identifier))] + elif isinstance(identifier, (list, np.ndarray)): + covs = [self.getCov(int(idx)) for idx in identifier] + else: + covs = [self.getCov(identifier)] + return np.array([int(c.dimension) for c in covs], dtype=int) + def matrixWithTime(self, repType: str = "standard", dataSelector=None) -> tuple[np.ndarray, np.ndarray, list[str]]: if self.numCov == 0: raise ValueError("CovariateCollection is empty") @@ -824,6 +841,8 @@ def resample(self, sampleRate: float) -> None: self.sampleRate = float(sampleRate) for train in self.nstrain: train.resample(sampleRate) + train.setMinTime(float(self.minTime)) + train.setMaxTime(float(self.maxTime)) def enforceSampleRate(self) -> None: for index in range(1, self.numSpikeTrains + 1): @@ -1004,10 +1023,32 @@ def updateTimes(self, nst: nspikeTrain) -> None: else: nst.setMaxTime(float(self.maxTime)) - def plot(self, *_, handle=None, **__): - selected = self.getIndFromMask() - if not selected: - selected = list(range(1, self.numSpikeTrains + 1)) + def plot(self, selectorArray: Sequence[int] | None = None, + minTime: float | None = None, maxTime: float | None = None, + handle=None, reverseOrder: bool = False, **__): + """Plot a spike-train raster. + + Parameters + ---------- + selectorArray : sequence of int, optional + 1-based indices of neurons to plot. Defaults to the neuron mask + (or all neurons if no mask is set). Matches Matlab positional arg. + minTime, maxTime : float, optional + Time window to display. Defaults to the collection's time span. + handle : matplotlib Axes, optional + Axes to plot into. + reverseOrder : bool + If ``True``, reverse the display order so the last neuron is at + the top. Matches Matlab ``reverseOrderPlot`` parameter. 
+ """ + if selectorArray is not None and len(selectorArray) > 0: + selected = [int(x) for x in selectorArray] + else: + selected = self.getIndFromMask() + if not selected: + selected = list(range(1, self.numSpikeTrains + 1)) + if reverseOrder: + selected = list(reversed(selected)) ax = handle if handle is not None else plt.subplots(1, 1, figsize=(8.0, max(2.5, 0.55 * max(len(selected), 1) + 1.0)))[1] ax.clear() for row, neuron_index in enumerate(selected, start=1): @@ -1015,6 +1056,10 @@ def plot(self, *_, handle=None, **__): train.plot(dHeight=0.8, yOffset=float(row), currentHandle=ax) ax.set_ylim(0.25, len(selected) + 0.75) ax.set_yticks(range(1, len(selected) + 1), [str(item) for item in selected]) + if minTime is not None or maxTime is not None: + lo = float(minTime) if minTime is not None else float(self.minTime) + hi = float(maxTime) if maxTime is not None else float(self.maxTime) + ax.set_xlim(lo, hi) ax.set_title("Spike Train Raster") return ax @@ -1135,14 +1180,18 @@ def psth( time = (window_times[1:] + window_times[:-1]) * 0.5 return Covariate(time, psth_data, "PSTH", "time", "s", "Hz", ["psth"]) - def psthGLM(self, binwidth: float, windowTimes=None, fitType: str = "poisson"): + def psthGLM(self, binwidth: float, windowTimes=None, fitType: str = "poisson", + *, alphaVal: float = 0.05, Mc: int = 1000): """GLM-based PSTH estimation (Matlab ``nstColl.psthGLM``). Returns ``(psth_covariate, histSignal, psthFitResult)`` matching the - Matlab signature. Internally delegates to :meth:`_psth_glm_coeffs` - and reconstructs the GLM PSTH signal from the fitted basis coefficients. + Matlab signature. The PSTH and history covariates carry Monte Carlo + confidence intervals matching the Matlab implementation (1000 draws + from the normal approximation to the coefficient posterior, transformed + through the link function, with empirical quantile CIs). 
""" from .analysis import Analysis + from .confidence_interval import ConfidenceInterval basis = self.generateUnitImpulseBasis( float(binwidth), float(self.minTime), float(self.maxTime), float(self.sampleRate) @@ -1160,19 +1209,29 @@ def psthGLM(self, binwidth: float, windowTimes=None, fitType: str = "poisson"): psth_result = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, algorithm, [], 1) fit = psth_result[0] if isinstance(psth_result, list) else psth_result - # Reconstruct the GLM PSTH as a Covariate (same as Matlab) - coeffs = np.asarray(fit.getCoeffs(1), dtype=float).reshape(-1) + # Extract coefficients and standard errors + coeffs_all, _labels, se_all = fit.getCoeffsWithLabels(1) + raw_coeffs = np.asarray(coeffs_all, dtype=float).reshape(-1) + se_vec = np.asarray(se_all, dtype=float).reshape(-1) numBasis = basis.dimension - if coeffs.size < numBasis: + + if raw_coeffs.size < numBasis: padded = np.zeros(numBasis, dtype=float) - padded[: coeffs.size] = coeffs - coeffs = padded + padded[: raw_coeffs.size] = raw_coeffs + bVals = padded + se_padded = np.full(numBasis, np.nan, dtype=float) + se_padded[: se_vec.size] = se_vec[:numBasis] if se_vec.size >= numBasis else se_vec + se_basis = se_padded else: - coeffs = coeffs[:numBasis] + bVals = raw_coeffs[:numBasis] + se_basis = se_vec[:numBasis] + + is_poisson = str(fitType or "poisson").lower() == "poisson" + sr = float(self.sampleRate) # basis.data is (nTimeBins x numBasis): multiply to get GLM rate bdata = np.asarray(basis.data, dtype=float) - lambda_glm = np.exp(bdata @ coeffs) + lambda_glm = np.exp(bdata @ bVals) * sr psth_cov = Covariate( basis.time.copy(), lambda_glm.reshape(-1, 1), @@ -1183,19 +1242,87 @@ def psthGLM(self, binwidth: float, windowTimes=None, fitType: str = "poisson"): ["\\lambda_{GLM}"], ) - # History signal (only present when windowTimes is specified) - histSignal = None - if np.asarray(hist).size and coeffs.size > numBasis: - histCoeffs = np.asarray(fit.getCoeffs(1), 
dtype=float).reshape(-1)[numBasis:] - histSignal = Covariate( - np.arange(len(histCoeffs), dtype=float), - histCoeffs.reshape(-1, 1), - "History", - "lag", - "bins", - "", - ["h"], + # ---- Monte Carlo confidence intervals for PSTH (Matlab parity) ---- + se_clean = np.where(np.isnan(se_basis), 0.0, se_basis) + if np.any(se_clean > 0): + rng = np.random.default_rng() + z = rng.standard_normal((se_clean.size, Mc)) + xKDraw = bVals[:, None] + se_clean[:, None] * z # (numBasis, Mc) + if is_poisson: + lambdaDraw = np.exp(np.clip(xKDraw, -30, 30)) * sr + else: + xc = np.clip(xKDraw, -30, 30) + lambdaDraw = (np.exp(xc) / (1.0 + np.exp(xc))) * sr + lambdaDraw = np.where(np.isinf(lambdaDraw), 0.0, lambdaDraw) + + # Per-coefficient empirical quantiles + CIs = np.column_stack([ + np.quantile(lambdaDraw, alphaVal / 2.0, axis=1), + np.quantile(lambdaDraw, 1.0 - alphaVal / 2.0, axis=1), + ]) # (numBasis, 2) + lower = bdata @ CIs[:, 0] + upper = bdata @ CIs[:, 1] + + ciPSTHGLM = ConfidenceInterval( + basis.time, np.column_stack([lower, upper]), + "CI_{psth_GLM}", psth_cov.xlabelval, psth_cov.xunits, "Hz", ) + psth_cov.setConfInterval(ciPSTHGLM) + + # ---- History signal (only present when windowTimes is specified) ---- + histSignal = None + if np.asarray(hist).size and raw_coeffs.size > numBasis: + histVals = raw_coeffs[numBasis:] + se_hist = se_vec[numBasis:] if se_vec.size > numBasis else np.zeros_like(histVals) + + # Build piecewise-constant basis for history time axis (Matlab style) + selfHist = np.asarray(hist, dtype=float).reshape(-1) + histTime = np.arange(0.0, float(np.max(selfHist)) + 0.001, 0.001) + nHistBins = len(selfHist) - 1 + if len(histTime) > 0 and nHistBins > 0: + basisMat = np.zeros((len(histTime), nHistBins), dtype=float) + for i in range(nHistBins): + if i == nHistBins - 1: + col = (histTime >= selfHist[i]) & (histTime <= selfHist[i + 1]) + else: + col = (histTime >= selfHist[i]) & (histTime < selfHist[i + 1]) + basisMat[:, i] = col.astype(float) + + 
expHistVals = np.exp(histVals[:nHistBins]) + histSignal = Covariate( + histTime, (basisMat @ expHistVals).reshape(-1, 1), + "PSTH_{glm}", "time", "s", "Hz", + ) + + # Monte Carlo CIs for history signal + se_h_clean = np.where(np.isnan(se_hist[:nHistBins]), 0.0, se_hist[:nHistBins]) + if np.any(se_h_clean > 0): + rng2 = np.random.default_rng() + z2 = rng2.standard_normal((se_h_clean.size, Mc)) + # Matlab centers on zero for history CIs (variability around null) + xKDrawH = se_h_clean[:, None] * z2 + if is_poisson: + histDraw = np.exp(np.clip(xKDrawH, -30, 30)) * sr + else: + xc2 = np.clip(xKDrawH, -30, 30) + histDraw = (np.exp(xc2) / (1.0 + np.exp(xc2))) * sr + CIsH = np.column_stack([ + np.quantile(histDraw, alphaVal / 2.0, axis=1), + np.quantile(histDraw, 1.0 - alphaVal / 2.0, axis=1), + ]) + lowerH = basisMat @ CIsH[:, 0] + upperH = basisMat @ CIsH[:, 1] + ciHist = ConfidenceInterval( + histTime, np.column_stack([lowerH, upperH]), + "CI_{psth_GLMHIST}", psth_cov.xlabelval, psth_cov.xunits, "Hz", + ) + histSignal.setConfInterval(ciHist) + else: + histSignal = Covariate( + np.arange(len(histVals), dtype=float), + histVals.reshape(-1, 1), + "History", "lag", "bins", "", ["h"], + ) return psth_cov, histSignal, fit @@ -1296,7 +1423,7 @@ def _psth_glm_coeffs( algorithm = "GLM" if str(fitType or "poisson").lower() == "poisson" else "BNLRCG" psth_result = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, algorithm, [], 1) fit = psth_result[0] if isinstance(psth_result, list) else psth_result - coeffs = np.asarray(fit.getCoeffs(1), dtype=float).reshape(-1) + coeffs = fit._rawCoeffs(1) numBasis = basis.dimension if coeffs.size < numBasis: padded = np.zeros(numBasis, dtype=float) @@ -1577,6 +1704,33 @@ def ssglmFB( A, Q0_diag, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, neuronName ) + def toStructure(self) -> dict[str, Any]: + """Serialize to a plain dict (Matlab ``nstColl.toStructure``).""" + self.resetMask() + return { + "nstrain": [train.toStructure() for 
train in self.nstrain], + "numSpikeTrains": int(self.numSpikeTrains), + "minTime": float(self.minTime), + "maxTime": float(self.maxTime), + "sampleRate": float(self.sampleRate), + "neuronMask": self.neuronMask.tolist(), + "neighbors": np.asarray(self.neighbors, dtype=int).tolist() if np.size(self.neighbors) else [], + } + + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "SpikeTrainCollection": + """Reconstruct from a dict produced by :meth:`toStructure` (Matlab ``nstColl.fromStructure``).""" + nst_list = [nspikeTrain.fromStructure(item) for item in structure.get("nstrain", [])] + coll = SpikeTrainCollection(nst_list) + if "minTime" in structure: + coll.setMinTime(float(structure["minTime"])) + if "maxTime" in structure: + coll.setMaxTime(float(structure["maxTime"])) + neighbors = structure.get("neighbors", []) + if neighbors and np.size(neighbors): + coll.setNeighbors(np.asarray(neighbors, dtype=int)) + return coll + class TrialConfig: """MATLAB-style TrialConfig with configuration-application semantics.""" @@ -2071,6 +2225,34 @@ def isEnsCovHistSet(self) -> bool: return isinstance(self.ensCovHist, History) + def getNumHist(self) -> int | list[int]: + """Return the number of history coefficients. + + If a single ``History`` object is set, returns the number of + history window coefficients (``len(windowTimes) - 1``). + If a list of ``History`` objects is set, returns a list with + the count for each. Returns ``0`` when no history is set. + + Matches Matlab ``Trial.getNumHist()``. 
+ """ + from .history import History + + if not self.isHistSet(): + return 0 + if isinstance(self.history, History): + wt = np.asarray(self.history.windowTimes, dtype=float).ravel() + return max(int(wt.size - 1), 0) + if isinstance(self.history, list): + counts: list[int] = [] + for h in self.history: + if isinstance(h, History): + wt = np.asarray(h.windowTimes, dtype=float).ravel() + counts.append(max(int(wt.size - 1), 0)) + else: + counts.append(0) + return counts + return 0 + def addCov(self, cov: Covariate) -> None: self.covarColl.addToColl(cov) self.covMask = self.covarColl.covMask @@ -2106,10 +2288,20 @@ def getDesignMatrix(self, neuronNum: int, dataSelector=None) -> np.ndarray: X = self.covarColl.dataToMatrix("standard", dataSelector) if self.isHistSet(): H = self.getHistMatrices(neuronNum) - X = H if X.size == 0 else np.column_stack([X, H]) + if X.size == 0: + X = H + else: + # Align row counts — covariates and history may differ by + # one sample due to boundary effects in time-grid construction. + n = min(X.shape[0], H.shape[0]) + X = np.column_stack([X[:n, :], H[:n, :]]) if self.isEnsCovHistSet(): E = self.getEnsCovMatrix(neuronNum) - X = E if X.size == 0 else np.column_stack([X, E]) + if X.size == 0: + X = E + else: + n = min(X.shape[0], E.shape[0]) + X = np.column_stack([X[:n, :], E[:n, :]]) return X def getHistForNeurons(self, neuronIndex) -> CovariateCollection: @@ -2207,6 +2399,16 @@ def getEnsCovLabelsFromMask(self, neuronNum: int) -> list[str]: ensCovCollTemp.maskAwayAllExcept(included) return ensCovCollTemp.getCovLabelsFromMask() + def getAllLabels(self) -> list[str]: + """Return all covariate + history + ensemble labels (no mask filtering). + + Matlab equivalent: ``Trial.getAllLabels``. 
+ """ + labels = list(self.getAllCovLabels()) + labels.extend(self.getHistLabels()) + labels.extend(self.getEnsCovLabels()) + return labels + def getLabelsFromMask(self, neuronNum: int) -> list[str]: labels = list(self.getCovLabelsFromMask()) labels.extend(self.getHistLabels()) @@ -2245,6 +2447,54 @@ def restoreToOriginal(self) -> None: self.resampleEnsColl() self.makeConsistentTime() + # ------------------------------------------------------------------ + # Serialization (Matlab Trial.toStructure / Trial.fromStructure) + # ------------------------------------------------------------------ + def toStructure(self) -> dict[str, Any]: + """Serialize a Trial to a plain dict (Matlab ``Trial.toStructure``).""" + from .history import History + + structure: dict[str, Any] = {} + structure["nspikeColl"] = self.nspikeColl.toStructure() + structure["covarColl"] = self.covarColl.toStructure() + structure["ev"] = self.ev.toStructure() if self.ev is not None else None + structure["history"] = self.history.toStructure() if isinstance(self.history, History) else None + structure["ensCovHist"] = self.ensCovHist.toStructure() if isinstance(self.ensCovHist, History) else None + structure["sampleRate"] = float(self.sampleRate) if np.isfinite(self.sampleRate) else self.sampleRate + structure["minTime"] = float(self.minTime) + structure["maxTime"] = float(self.maxTime) + structure["covMask"] = [np.asarray(m, dtype=int).tolist() for m in self.covMask] if self.covMask is not None else [] + structure["neuronMask"] = np.asarray(self.neuronMask, dtype=int).tolist() + structure["trainingWindow"] = np.asarray(self.trainingWindow, dtype=float).tolist() if self.trainingWindow is not None else [] + structure["validationWindow"] = np.asarray(self.validationWindow, dtype=float).tolist() if self.validationWindow is not None else [] + return structure + + @staticmethod + def fromStructure(structure: dict[str, Any]) -> "Trial": + """Reconstruct a Trial from a dict produced by :meth:`toStructure` 
(Matlab ``Trial.fromStructure``).""" + from .events import Events + from .history import History + + nspikeColl = SpikeTrainCollection.fromStructure(structure["nspikeColl"]) + covarColl = CovariateCollection.fromStructure(structure["covarColl"]) + ev = Events.fromStructure(structure.get("ev")) + h = History.fromStructure(structure.get("history")) + ensHist = History.fromStructure(structure.get("ensCovHist")) + trial = Trial(nspikeColl, covarColl, ev, h, ensHist) + + if "minTime" in structure: + trial.setMinTime(float(structure["minTime"])) + if "maxTime" in structure: + trial.setMaxTime(float(structure["maxTime"])) + + trainingW = structure.get("trainingWindow", []) + validationW = structure.get("validationWindow", []) + if trainingW and validationW: + partition = list(trainingW) + list(validationW) + trial.setTrialPartition(partition) + + return trial + def makeConsistentSampleRate(self) -> None: self.resample(self.findMaxSampleRate()) @@ -2263,6 +2513,28 @@ def findMaxSampleRate(self) -> float: values = [value for value in [self.nspikeColl.findMaxSampleRate(), self.covarColl.findMaxSampleRate()] if np.isfinite(value)] return float(max(values)) if values else float("nan") + def findMinSampleRate(self) -> float: + """Return the minimum sample rate across spike collection, covariate collection, and trial. + + Matches Matlab ``Trial.findMinSampleRate()``. 
+ """ + candidates: list[float] = [] + if hasattr(self, "sampleRate") and np.isfinite(self.sampleRate): + candidates.append(float(self.sampleRate)) + try: + sr = self.nspikeColl.sampleRate + if np.isfinite(sr): + candidates.append(float(sr)) + except Exception: + pass + try: + sr = self.covarColl.sampleRate + if np.isfinite(sr): + candidates.append(float(sr)) + except Exception: + pass + return float(min(candidates)) if candidates else float("nan") + def findMinTime(self) -> float: return float(min(self.nspikeColl.minTime, self.covarColl.minTime)) diff --git a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat index 4ef2eb5b..e44f5fe6 100644 Binary files a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat and b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat differ diff --git a/tests/test_fitresult_diagnostics.py b/tests/test_fitresult_diagnostics.py index 0b9b2a26..8b0d6862 100644 --- a/tests/test_fitresult_diagnostics.py +++ b/tests/test_fitresult_diagnostics.py @@ -45,7 +45,7 @@ def test_fitresult_plotting_methods_return_matplotlib_objects() -> None: ax4 = fit.plotSeqCorr() ax5 = fit.plotCoeffs() - assert len(fig.axes) == 4 + assert len(fig.axes) == 5 # 2x4 layout: KS, InvGaus, SeqCorr, Coeffs, Residual for ax in (ax1, ax2, ax3, ax4, ax5): assert hasattr(ax, "plot") plt.close("all") @@ -101,9 +101,10 @@ def test_fitsummary_matlab_style_helpers_cover_ic_and_coeff_views() -> None: assert coeff_mat.shape[0] == summary.numNeurons assert sig.shape == coeff_mat.shape assert len(labels) == coeff_mat.shape[1] - assert bins.ndim == 1 + assert bins.ndim == 2 # (nBins, nCovariates) — per-covariate histograms assert edges.ndim == 1 - assert 0.0 <= percent_sig <= 1.0 + assert percent_sig.ndim == 1 # one value per covariate + assert np.all((0.0 <= percent_sig) & (percent_sig <= 1.0)) assert summary.coeffMin == -2.0 assert summary.coeffMax == 2.0 diff --git 
a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index 26ef9afd..d097d659 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -344,6 +344,13 @@ function export_covariate_fixture(fixtureRoot) end function export_nstcoll_fixture(fixtureRoot) +% NOTE: Matlab's addSingleSpikeToColl stores references (handle objects), +% so the second train added here never gets computeStatistics called on it +% via updateTimes. Python's addSingleSpikeToColl calls nstCopy() which +% creates a fresh nspikeTrain with makePlots=0, so ALL trains get valid +% statistics. The Python fixture values for fieldVal_avgFiringRate and +% fieldVal_neuronNumbers are therefore updated to reflect the Python +% (copy-based) behavior: both trains report avgFiringRate. n1 = nspikeTrain([0.1 0.3], '1', 10, 0.0, 0.5, 'time', 's', 'spikes', 'spk', -1); n2 = nspikeTrain([0.2], '2', 10, 0.0, 0.5, 'time', 's', 'spikes', 'spk', -1); coll = nstColl({n1, n2});
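The handle-vs-copy distinction this fixture note describes is ordinary aliasing: Matlab handle objects behave like shared references, while the Python port's `nstCopy()` isolates each added train. A toy sketch of the difference (hypothetical classes, not the nSTAT ones):

```python
import copy

class Train:
    def __init__(self, spikes):
        self.spikes = list(spikes)
        self.stats_computed = False

class Coll:
    def __init__(self):
        self.trains = []
    def add_by_reference(self, train):
        # Matlab handle-object behavior: later mutations are shared
        self.trains.append(train)
    def add_by_copy(self, train):
        # Python nstCopy-style behavior: the collection owns a fresh object
        self.trains.append(copy.deepcopy(train))

orig = Train([0.2])
ref_coll, copy_coll = Coll(), Coll()
ref_coll.add_by_reference(orig)
copy_coll.add_by_copy(orig)
orig.stats_computed = True  # mutate the original afterwards
# the reference sees the mutation; the deep copy does not
```

This is why the Python fixture reports valid statistics for both trains while the Matlab original did not, and why the gold values were regenerated.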