diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b31f290b..798f6368 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -233,3 +233,31 @@ jobs: python -m pip install -e .[dev] - name: Verify audited MATLAB-facing runtime surface run: pytest -q tests/test_matlab_symbol_surface.py tests/test_class_fidelity_audit.py tests/test_parity_report.py + + regenerate-figures: + runs-on: ubuntu-latest + env: + NSTAT_DATA_DIR: ${{ github.workspace }}/.nstat_data_cache + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Cache example data + uses: actions/cache@v4 + with: + path: ${{ env.NSTAT_DATA_DIR }} + key: nstat-example-data-v1 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[dev] + - name: Regenerate all paper figures + run: python examples/paper/regenerate_all_figures.py + - name: Upload regenerated figures + uses: actions/upload-artifact@v4 + with: + name: paper-figures + path: docs/figures/ + retention-days: 30 diff --git a/docs/figures/example01/fig01_constant_mg_summary.png b/docs/figures/example01/fig01_constant_mg_summary.png index 9da8d1d1..97afa7cf 100644 Binary files a/docs/figures/example01/fig01_constant_mg_summary.png and b/docs/figures/example01/fig01_constant_mg_summary.png differ diff --git a/docs/figures/example01/fig02_washout_raster_overview.png b/docs/figures/example01/fig02_washout_raster_overview.png index c65e2a06..0d8075b5 100644 Binary files a/docs/figures/example01/fig02_washout_raster_overview.png and b/docs/figures/example01/fig02_washout_raster_overview.png differ diff --git a/docs/figures/example01/fig03_piecewise_baseline_comparison.png b/docs/figures/example01/fig03_piecewise_baseline_comparison.png index c9ef13ca..c194ff0d 100644 Binary files a/docs/figures/example01/fig03_piecewise_baseline_comparison.png and b/docs/figures/example01/fig03_piecewise_baseline_comparison.png differ diff --git a/docs/figures/example02/fig01_data_overview.png b/docs/figures/example02/fig01_data_overview.png index de99cf48..f6aede37 100644 Binary files a/docs/figures/example02/fig01_data_overview.png and b/docs/figures/example02/fig01_data_overview.png differ diff --git a/docs/figures/example02/fig02_lag_and_model_comparison.png b/docs/figures/example02/fig02_lag_and_model_comparison.png index 73592f75..f43947ff 100644 Binary files a/docs/figures/example02/fig02_lag_and_model_comparison.png and b/docs/figures/example02/fig02_lag_and_model_comparison.png differ diff --git a/docs/figures/example03/fig01_simulated_and_real_rasters.png b/docs/figures/example03/fig01_simulated_and_real_rasters.png index 0bb68340..95ece6e0 100644 Binary files a/docs/figures/example03/fig01_simulated_and_real_rasters.png and b/docs/figures/example03/fig01_simulated_and_real_rasters.png differ diff --git a/docs/figures/example03/fig02_psth_comparison.png b/docs/figures/example03/fig02_psth_comparison.png index a7457956..3976e35e 100644 Binary files a/docs/figures/example03/fig02_psth_comparison.png and b/docs/figures/example03/fig02_psth_comparison.png differ diff --git a/docs/figures/example03/fig03_ssglm_simulation_summary.png b/docs/figures/example03/fig03_ssglm_simulation_summary.png index 0e0864ca..7f5ac534 100644 Binary files a/docs/figures/example03/fig03_ssglm_simulation_summary.png and b/docs/figures/example03/fig03_ssglm_simulation_summary.png differ diff --git a/docs/figures/example03/fig04_ssglm_fit_diagnostics.png b/docs/figures/example03/fig04_ssglm_fit_diagnostics.png index c5846300..de9efc25 100644 Binary files a/docs/figures/example03/fig04_ssglm_fit_diagnostics.png and b/docs/figures/example03/fig04_ssglm_fit_diagnostics.png differ diff --git a/docs/figures/example03/fig05_stimulus_effect_surfaces.png b/docs/figures/example03/fig05_stimulus_effect_surfaces.png index 3f9fd4ad..24274b27 100644 Binary files a/docs/figures/example03/fig05_stimulus_effect_surfaces.png and b/docs/figures/example03/fig05_stimulus_effect_surfaces.png differ diff --git a/docs/figures/example03/fig06_learning_trial_comparison.png b/docs/figures/example03/fig06_learning_trial_comparison.png index cb058492..7bfeccbd 100644 Binary files a/docs/figures/example03/fig06_learning_trial_comparison.png and b/docs/figures/example03/fig06_learning_trial_comparison.png differ diff --git a/docs/figures/example04/fig01_example_cells_path_overlay.png b/docs/figures/example04/fig01_example_cells_path_overlay.png index a05302b9..47f41cdd 100644 Binary files a/docs/figures/example04/fig01_example_cells_path_overlay.png and b/docs/figures/example04/fig01_example_cells_path_overlay.png differ diff --git a/docs/figures/example04/fig02_model_summary_statistics.png b/docs/figures/example04/fig02_model_summary_statistics.png index a53e29d7..5a17f5fe 100644 Binary files a/docs/figures/example04/fig02_model_summary_statistics.png and b/docs/figures/example04/fig02_model_summary_statistics.png differ diff --git a/docs/figures/example04/fig03_gaussian_place_fields_animal1.png b/docs/figures/example04/fig03_gaussian_place_fields_animal1.png index 86b8b753..b18cb706 100644 Binary files a/docs/figures/example04/fig03_gaussian_place_fields_animal1.png and b/docs/figures/example04/fig03_gaussian_place_fields_animal1.png differ diff --git a/docs/figures/example04/fig04_zernike_place_fields_animal1.png b/docs/figures/example04/fig04_zernike_place_fields_animal1.png index e3cda44b..c6d47fc7 100644 Binary files a/docs/figures/example04/fig04_zernike_place_fields_animal1.png and b/docs/figures/example04/fig04_zernike_place_fields_animal1.png differ diff --git a/docs/figures/example04/fig05_gaussian_place_fields_animal2.png b/docs/figures/example04/fig05_gaussian_place_fields_animal2.png index acadca6f..b1f79da1 100644 Binary files a/docs/figures/example04/fig05_gaussian_place_fields_animal2.png and b/docs/figures/example04/fig05_gaussian_place_fields_animal2.png differ diff --git a/docs/figures/example04/fig06_zernike_place_fields_animal2.png b/docs/figures/example04/fig06_zernike_place_fields_animal2.png index 3931b1df..35c787ee 100644 Binary files a/docs/figures/example04/fig06_zernike_place_fields_animal2.png and b/docs/figures/example04/fig06_zernike_place_fields_animal2.png differ diff --git a/docs/figures/example04/fig07_example_cell_mesh_comparison.png b/docs/figures/example04/fig07_example_cell_mesh_comparison.png index 61a3dbdf..a27857fc 100644 Binary files a/docs/figures/example04/fig07_example_cell_mesh_comparison.png and b/docs/figures/example04/fig07_example_cell_mesh_comparison.png differ diff --git a/docs/figures/example05/fig01_univariate_setup.png b/docs/figures/example05/fig01_univariate_setup.png index 94dc2a8f..ace432bb 100644 Binary files a/docs/figures/example05/fig01_univariate_setup.png and b/docs/figures/example05/fig01_univariate_setup.png differ diff --git a/docs/figures/example05/fig02_univariate_decoding.png b/docs/figures/example05/fig02_univariate_decoding.png index fbee2a41..bfb470e0 100644 Binary files a/docs/figures/example05/fig02_univariate_decoding.png and b/docs/figures/example05/fig02_univariate_decoding.png differ diff --git a/docs/figures/example05/fig03_reach_and_population_setup.png b/docs/figures/example05/fig03_reach_and_population_setup.png index 4e87d03e..ab09ae8f 100644 Binary files a/docs/figures/example05/fig03_reach_and_population_setup.png and b/docs/figures/example05/fig03_reach_and_population_setup.png differ diff --git a/docs/figures/example05/fig04_ppaf_goal_vs_free.png b/docs/figures/example05/fig04_ppaf_goal_vs_free.png index afef5c89..0ce2fdec 100644 Binary files a/docs/figures/example05/fig04_ppaf_goal_vs_free.png and b/docs/figures/example05/fig04_ppaf_goal_vs_free.png differ diff --git a/docs/figures/example05/fig05_hybrid_setup.png b/docs/figures/example05/fig05_hybrid_setup.png index c00124d3..3839a193 100644 Binary files a/docs/figures/example05/fig05_hybrid_setup.png and b/docs/figures/example05/fig05_hybrid_setup.png differ diff --git a/docs/figures/example05/fig06_hybrid_decoding_summary.png b/docs/figures/example05/fig06_hybrid_decoding_summary.png index d01845bc..dbbf4033 100644 Binary files a/docs/figures/example05/fig06_hybrid_decoding_summary.png and b/docs/figures/example05/fig06_hybrid_decoding_summary.png differ diff --git a/examples/paper/example01_mepsc_poisson.py b/examples/paper/example01_mepsc_poisson.py index 5ca60dd1..a74456f4 100644 --- a/examples/paper/example01_mepsc_poisson.py +++ b/examples/paper/example01_mepsc_poisson.py @@ -140,12 +140,36 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non print(f" AIC: {resultConst.AIC}") print(f" BIC: {resultConst.BIC}") - # --- Figure 1: Constant Mg2+ diagnostics (Matlab-matching plotResults) --- - # Matlab calls resultConst.plotResults which creates a 2x4 grid: - # [KSPlot (2-wide)] [InvGausTrans] [SeqCorr] - # [plotCoeffs (2-wide)] [plotResidual (2-wide)] - fig1 = plt.figure(figsize=(14, 9)) - resultConst.plotResults(handle=fig1) + # --- Figure 1: Constant Mg2+ diagnostics (Matlab-matching 2x2 layout) --- + # Matlab uses subplot(2,2,...) with: raster, InvGausTrans, KSPlot, lambda + fig1, axes1 = plt.subplots(2, 2, figsize=(14, 9)) + + # (2,2,1): Neural raster + ax = axes1[0, 0] + spikeCollConst.plot(handle=ax) + ax.set_title("Neural Raster with constant Mg$^{2+}$ Concentration", + fontweight="bold", fontsize=12) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_yticks([0, 1]) + + # (2,2,2): Inverse Gaussian transform (ACF) + resultConst.plotInvGausTrans(fit_num=None, handle=axes1[0, 1]) + + # (2,2,3): KS plot + resultConst.KSPlot(fit_num=None, handle=axes1[1, 0]) + + # (2,2,4): Lambda estimate + ax = axes1[1, 1] + lam = resultConst.lambda_signal + ax.plot(np.asarray(lam.time, dtype=float), + np.asarray(lam.data[:, 0], dtype=float), + "b", linewidth=2) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_ylabel(lam.ylabel if hasattr(lam, "ylabel") else "spikes/sec", + fontsize=12, fontweight="bold") + ax.legend(["$\\lambda_{const}$"], loc="upper right") + fig1.suptitle("Example 01 — Figure 1: Constant Mg$^{2+}$ Summary", fontsize=14, fontweight="bold") fig1.tight_layout() @@ -229,10 +253,39 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non print(f" AIC: {resultWashout.AIC}") print(f" BIC: {resultWashout.BIC}") - # --- Figure 3: Piecewise model diagnostics (Matlab-matching plotResults) --- - # Matlab calls resultWashout.plotResults which creates the same 2x4 grid. - fig3 = plt.figure(figsize=(14, 9)) - resultWashout.plotResults(handle=fig3) + # --- Figure 3: Piecewise model diagnostics (Matlab-matching 2x2 layout) --- + # Matlab uses subplot(2,2,...) with: raster+epoch lines, InvGausTrans, KSPlot, lambda comparison + fig3, axes3 = plt.subplots(2, 2, figsize=(14, 9)) + + # (2,2,1): Neural raster with epoch boundary lines + ax = axes3[0, 0] + spikeCollWashout.plot(handle=ax) + ax.set_title("Neural Raster with decreasing Mg$^{2+}$ Concentration", + fontweight="bold", fontsize=12) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_yticklabels([]) + ax.axvline(495, color="r", linewidth=4) + ax.axvline(765, color="r", linewidth=4) + + # (2,2,2): Inverse Gaussian transform (ACF) — all fits + resultWashout.plotInvGausTrans(fit_num=None, handle=axes3[0, 1]) + + # (2,2,3): KS plot — all fits + resultWashout.KSPlot(fit_num=None, handle=axes3[1, 0]) + + # (2,2,4): Lambda comparison (two models overlaid) + ax = axes3[1, 1] + lam = resultWashout.lambda_signal + t = np.asarray(lam.time, dtype=float) + ax.plot(t, np.asarray(lam.data[:, 0], dtype=float), "b", linewidth=2) + if lam.data.shape[1] > 1: + ax.plot(t, np.asarray(lam.data[:, 1], dtype=float), "g", linewidth=2) + ax.set_ylim(0, 5) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_ylabel(lam.ylabel if hasattr(lam, "ylabel") else "spikes/sec", + fontsize=12, fontweight="bold") + ax.legend(["$\\lambda_{const}$", "$\\lambda_{const-epoch}$"], loc="upper right") + fig3.suptitle("Example 01 — Figure 3: Piecewise Baseline Comparison", fontsize=14, fontweight="bold") fig3.tight_layout() diff --git a/examples/paper/example03_psth_and_ssglm.py b/examples/paper/example03_psth_and_ssglm.py index 61cfe649..34c39a80 100644 --- a/examples/paper/example03_psth_and_ssglm.py +++ b/examples/paper/example03_psth_and_ssglm.py @@ -494,45 +494,30 @@ def run_part_b(data_dir, export_dir=None): estStimEffect = np.exp(basisMat @ xK) / delta # (T, K) # ------------------------------------------------------------------ - # Figure 5: True/PSTH/SSGLM stimulus effect surfaces (3D mesh) + # Figure 5: True/PSTH/SSGLM stimulus effect surfaces + # Match MATLAB: mesh(trial, time, data) with view([90 -90]) → top-down + # Using pcolormesh for clean 2D rendering matching MATLAB's top-down mesh # ------------------------------------------------------------------ - from mpl_toolkits.mplot3d import Axes3D # noqa: F401 - - fig5 = plt.figure(figsize=(10, 12)) + fig5, axes5 = plt.subplots(3, 1, figsize=(14, 9)) trial_axis = np.arange(1, numRealizations + 1) T_act = min(actStimEffect.shape[0], len(basis_time)) - T_mesh, K_mesh = np.meshgrid( - basis_time[:T_act], trial_axis, indexing="ij" - ) - ax = fig5.add_subplot(3, 1, 1, projection="3d") - ax.plot_surface(K_mesh, T_mesh, actStimEffect[:T_act, :], - cmap="viridis", edgecolor="none") - ax.view_init(elev=-90, azim=90) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("True Stimulus Effect", fontweight="bold", fontsize=14) - - ax = fig5.add_subplot(3, 1, 2, projection="3d") - ax.plot_surface(K_mesh, T_mesh, psthSurface2D[:T_act, :], - cmap="viridis", edgecolor="none") - ax.view_init(elev=-90, azim=90) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("PSTH Estimated Stimulus Effect", fontweight="bold", - fontsize=14) - - ax = fig5.add_subplot(3, 1, 3, projection="3d") - ax.plot_surface(K_mesh, T_mesh, estStimEffect[:T_act, :], - cmap="viridis", edgecolor="none") - ax.view_init(elev=-90, azim=90) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title("SSGLM Estimated Stimulus Effect", fontweight="bold", - fontsize=14) + surfaces = [ + (actStimEffect[:T_act, :], "True Stimulus Effect"), + (psthSurface2D[:T_act, :], "PSTH Estimated Stimulus Effect"), + (estStimEffect[:T_act, :], "SSGLM Estimated Stimulus Effect"), + ] + for ax, (data, title_str) in zip(axes5, surfaces): + # data is (T, K) — time on rows, trials on columns + # MATLAB: imagesc shows trial on x, time on y + ax.pcolormesh(trial_axis, basis_time[:T_act], data, cmap="viridis", + shading="auto") + ax.set_ylabel("time [s]") + ax.set_title(title_str, fontweight="bold", fontsize=14) + axes5[-1].set_xlabel("Trial [k]") fig5.tight_layout() - print(" Figure 5: Stimulus effect surfaces (3D mesh)") + print(" Figure 5: Stimulus effect surfaces (top-down mesh)") # ------------------------------------------------------------------ # 6. Learning-trial analysis: spike rate CIs diff --git a/examples/paper/example04_place_cells_continuous_stimulus.py b/examples/paper/example04_place_cells_continuous_stimulus.py index 05543f5d..c5062e23 100644 --- a/examples/paper/example04_place_cells_continuous_stimulus.py +++ b/examples/paper/example04_place_cells_continuous_stimulus.py @@ -269,19 +269,19 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non fig2, axes2 = plt.subplots(1, 3, figsize=(14, 5)) axes2[0].boxplot([dKS1[np.isfinite(dKS1)], dKS2[np.isfinite(dKS2)]], - tick_labels=["Animal 1", "Animal 2"]) + labels=["Animal 1", "Animal 2"]) axes2[0].set_ylabel(r"$\Delta$KS (Gaussian - Zernike)") axes2[0].set_title("KS Statistic Difference") axes2[0].axhline(0, color="gray", linestyle="--", linewidth=0.5) axes2[1].boxplot([dAIC1[np.isfinite(dAIC1)], dAIC2[np.isfinite(dAIC2)]], - tick_labels=["Animal 1", "Animal 2"]) + labels=["Animal 1", "Animal 2"]) axes2[1].set_ylabel(r"$\Delta$AIC (Zernike - Gaussian)") axes2[1].set_title("AIC Difference") axes2[1].axhline(0, color="gray", linestyle="--", linewidth=0.5) axes2[2].boxplot([dBIC1[np.isfinite(dBIC1)], dBIC2[np.isfinite(dBIC2)]], - tick_labels=["Animal 1", "Animal 2"]) + labels=["Animal 1", "Animal 2"]) axes2[2].set_ylabel(r"$\Delta$BIC (Zernike - Gaussian)") axes2[2].set_title("BIC Difference") axes2[2].axhline(0, color="gray", linestyle="--", linewidth=0.5) diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 923cdabd..db4f6574 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -455,7 +455,7 @@ def _plot_part_b(result): fig4, axes4 = plt.subplots(1, 4, figsize=(14, 4)) for d, (ax, lab) in enumerate(zip(axes4, labels)): data = [result["rmse_free"][:, d], result["rmse_goal"][:, d]] - bp = ax.boxplot(data, tick_labels=["Free", "Goal"]) + bp = ax.boxplot(data, labels=["Free", "Goal"]) ax.set_title(f"RMSE: {lab}") ax.set_ylabel("RMSE") diff --git a/examples/paper/regenerate_all_figures.py b/examples/paper/regenerate_all_figures.py new file mode 100644 index 00000000..de63d6e7 --- /dev/null +++ b/examples/paper/regenerate_all_figures.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +"""Regenerate all paper example figures. + +Runs each example script with --export-figures and saves PNGs to +docs/figures/exampleNN/. Called by CI on every push and can also be +run locally: + + python examples/paper/regenerate_all_figures.py +""" +from __future__ import annotations + +import importlib +import sys +import traceback +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") + +THIS_DIR = Path(__file__).resolve().parent +REPO_ROOT = THIS_DIR.parents[1] +FIGURES_ROOT = REPO_ROOT / "docs" / "figures" + +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +EXAMPLES = [ + ("example01_mepsc_poisson", "example01", "run_example01"), + ("example02_whisker_stimulus_thalamus", "example02", "run_example02"), + ("example03_psth_and_ssglm", "example03", "run_example03"), + ("example04_place_cells_continuous_stimulus", "example04", "run_example04"), + ("example05_decoding_ppaf_pphf", "example05", "run_example05"), +] + + +def main() -> int: + # Ensure example data is available + from nstat.data_manager import ensure_example_data + ensure_example_data(download=True) + + failed = 0 + for mod_name, dir_name, run_fn_name in EXAMPLES: + export_dir = FIGURES_ROOT / dir_name + print(f"\n{'='*60}") + print(f" {mod_name}") + print(f"{'='*60}") + try: + mod = importlib.import_module( + f"examples.paper.{mod_name}" + ) + run_fn = getattr(mod, run_fn_name) + run_fn(export_figures=True, export_dir=export_dir) + print(f" OK: figures saved to {export_dir}") + except Exception: + traceback.print_exc() + failed += 1 + + total = len(EXAMPLES) + print(f"\n=== Done. {total - failed}/{total} examples succeeded ===") + return 1 if failed else 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/nstat/cif.py b/nstat/cif.py index f9731d9d..c275667e 100644 --- a/nstat/cif.py +++ b/nstat/cif.py @@ -319,9 +319,7 @@ def from_linear_terms( name: str = "lambda", ) -> "CIFModel": eta = intercept + np.asarray(design_matrix, dtype=float) @ np.asarray(coefficients, dtype=float) - p = np.exp(np.clip(eta, -20.0, 20.0)) - p = p / (1.0 + p) - rate = p / max(float(dt), 1e-12) + rate = np.exp(np.clip(eta, -20.0, 20.0)) # Poisson log link: lambda = exp(eta) return cls(np.asarray(time, dtype=float).reshape(-1), rate, name) diff --git a/nstat/fit.py b/nstat/fit.py index f37dcf30..8e6bc3a3 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -56,7 +56,12 @@ def _pad_rows(rows: Sequence[np.ndarray], fill_value: float = np.nan) -> np.ndar return out -def _autocorrelation(values: np.ndarray, max_lag: int = 25) -> tuple[np.ndarray, np.ndarray]: +def _autocorrelation(values: np.ndarray, max_lag: int | None = None) -> tuple[np.ndarray, np.ndarray]: + """Compute normalized autocorrelation (xcov/xcov[0]) for lags 1..max_lag. + + Matches MATLAB ``xcov`` normalization: ``rho(k) = xcov(k) / xcov(0)``. + When *max_lag* is None (default) the full range is returned, matching MATLAB. + """ centered = np.asarray(values, dtype=float).reshape(-1) - float(np.mean(values)) if centered.size < 2 or float(np.var(centered)) <= 0.0: return np.asarray([], dtype=float), np.asarray([], dtype=float) @@ -64,8 +69,8 @@ def _autocorrelation(values: np.ndarray, max_lag: int = 25) -> tuple[np.ndarray, corr = corr[corr.size // 2 :] corr = corr / corr[0] lags = np.arange(corr.shape[0], dtype=float) - max_lag = int(min(max_lag, corr.shape[0] - 1)) - return lags[1 : max_lag + 1], corr[1 : max_lag + 1] + end = corr.shape[0] - 1 if max_lag is None else int(min(max_lag, corr.shape[0] - 1)) + return lags[1 : end + 1], corr[1 : end + 1] def _time_rescaled_uniforms(y: np.ndarray, lam_per_bin: np.ndarray) -> np.ndarray: @@ -249,7 +254,7 @@ def _matlab_compute_ks_arrays( else: Z = np.zeros((1, n_dims), dtype=float) - U = 1.0 - np.exp(-Z) + U = 1.0 - np.exp(-np.clip(Z, -700, 700)) # FIX: clip to avoid overflow → -inf in uniforms if U.ndim == 1: U = U[:, None] @@ -941,7 +946,7 @@ def _compute_diagnostics(self, fit_num: int = 1, *, dt_correction: int = 1) -> d ks_pvalue = np.nan within = np.nan gaussianized = norm.ppf(np.clip(uniforms, 1e-6, 1.0 - 1e-6)) - lags, acf = _autocorrelation(gaussianized, max_lag=25) + lags, acf = _autocorrelation(gaussianized) acf_ci = 1.96 / np.sqrt(float(gaussianized.size)) if gaussianized.size else np.nan coeffs = self._rawCoeffs(fit_num) se = _extract_standard_errors(self.stats[fit_num - 1] if fit_num - 1 < len(self.stats) else None, coeffs.size) @@ -1167,10 +1172,10 @@ def plotResults(self, fit_num: int = 1, handle=None): transform=ax_ks.transAxes, fontweight="bold", fontsize=10, verticalalignment="top", ) - self.plotInvGausTrans(fit_num=fit_num, handle=ax_ig) + self.plotInvGausTrans(fit_num=None, handle=ax_ig) self.plotSeqCorr(fit_num=None, handle=ax_sc) - self.plotCoeffs(fit_num=fit_num, handle=ax_co) - self.plotResidual(fit_num=fit_num, handle=ax_re) + self.plotCoeffs(fit_num=None, handle=ax_co) + self.plotResidual(fit_num=None, handle=ax_re) fig.tight_layout() return fig @@ -1233,32 +1238,97 @@ def KSPlot(self, fit_num: int | list[int] | None = None, handle=None): ax.set_title("KS Plot of Rescaled ISIs\nwith 95% Confidence Intervals", fontweight="bold", fontsize=11) return ax - def plotResidual(self, fit_num: int = 1, handle=None): - """Plot the martingale residual M(t) (Matlab ``plotResidual``).""" + def plotResidual(self, fit_num: int | list[int] | None = None, handle=None): + """Plot the martingale residual M(t) for one or more fits. + + Matches Matlab ``plotResidual``: plots all residuals with per-fit + colours and a legend using ``lambda.dataLabels``. + """ + if fit_num is None: + fit_nums = list(range(1, self.numResults + 1)) + elif isinstance(fit_num, int): + fit_nums = [fit_num] + else: + fit_nums = list(fit_num) + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] - residual = self.computeFitResidual(fit_num) - ax.plot(np.asarray(residual.time, dtype=float), np.asarray(residual.data[:, 0], dtype=float), color="tab:purple", linewidth=1.0) + _SEQ_COLORS = ["tab:blue", "tab:green", "tab:red", "tab:cyan", "tab:purple", "tab:olive", "k"] + data_labels = ( + list(self.lambda_signal.dataLabels) + if getattr(self.lambda_signal, "dataLabels", None) + else [] + ) + for i, fn in enumerate(fit_nums): + residual = self.computeFitResidual(fn) + color = _SEQ_COLORS[i % len(_SEQ_COLORS)] + label = _ensure_mathtext( + data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}" + ) + ax.plot( + np.asarray(residual.time, dtype=float), + np.asarray(residual.data[:, 0], dtype=float), + color=color, linewidth=1.0, label=label, + ) ax.axhline(0.0, color="0.4", linewidth=1.0, linestyle="--") + if len(fit_nums) > 1: + ax.legend(loc="upper right", fontsize=8) ax.set_xlabel("time [s]") ax.set_ylabel("count residual") - ax.set_title("Fit Residual") + ax.set_title("Point Process Residual", fontweight="bold", fontsize=11) + # Match MATLAB: symmetric y-axis with 10% margin + ylims = ax.get_ylim() + max_y = max(abs(ylims[0]), abs(ylims[1])) * 1.1 + ax.set_ylim(-max_y, max_y) return ax - def plotInvGausTrans(self, fit_num: int = 1, handle=None): + def plotInvGausTrans(self, fit_num: int | list[int] | None = None, handle=None): """Plot ACF of gaussianized rescaled ISIs with 95% CIs. Matlab: plotInvGausTrans computes X_j = Φ⁻¹(U_j) and plots the autocorrelation function of X_j with 95% confidence bounds. + Supports multi-fit overlay with per-fit colours matching KS/SeqCorr. """ - diag = self._compute_diagnostics(fit_num) + if fit_num is None: + fit_nums = list(range(1, self.numResults + 1)) + elif isinstance(fit_num, int): + fit_nums = [fit_num] + else: + fit_nums = list(fit_num) + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] - lags = np.asarray(diag["acf_lags"], dtype=float) - acf = np.asarray(diag["acf_values"], dtype=float) - if lags.size: - ax.vlines(lags, 0.0, acf, color="tab:orange", linewidth=1.4) - ax.axhline(float(diag["acf_ci"]), color="tab:red", linewidth=1.0) - ax.axhline(-float(diag["acf_ci"]), color="tab:red", linewidth=1.0) - ax.axhline(0.0, color="0.4", linewidth=1.0) + data_labels = ( + list(self.lambda_signal.dataLabels) + if getattr(self.lambda_signal, "dataLabels", None) + else [] + ) + _SEQ_COLORS = ["tab:blue", "tab:green", "tab:red", "tab:cyan", "tab:purple", "tab:olive", "k"] + legend_handles: list[object] = [] + legend_labels: list[str] = [] + ci_val = None + + for i, fn in enumerate(fit_nums): + diag = self._compute_diagnostics(fn) + lags = np.asarray(diag["acf_lags"], dtype=float) + acf = np.asarray(diag["acf_values"], dtype=float) + color = _SEQ_COLORS[i % len(_SEQ_COLORS)] + base_label = _ensure_mathtext( + data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}" + ) + if lags.size: + h, = ax.plot(lags, acf, ".", color=color, markersize=4.0) + legend_handles.append(h) + legend_labels.append(base_label) + if ci_val is None: + ci_val = float(diag["acf_ci"]) + + # Plot 95% CI lines without legend entries + if ci_val is not None: + ax.axhline(ci_val, color="0.4", linewidth=0.8, linestyle="--") + ax.axhline(-ci_val, color="0.4", linewidth=0.8, linestyle="--") + ax.axhline(0.0, color="0.4", linewidth=0.8) + + if legend_handles: + ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=8) ax.set_xlabel("lag") ax.set_ylabel("autocorrelation") ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs") @@ -1308,9 +1378,11 @@ def plotSeqCorr(self, fit_num: int | list[int] | None = None, handle=None): uj = u[:-1] uj1 = u[1:] h, = ax.plot(uj, uj1, ".", color=color, markersize=4.0) - # Compute correlation coefficient and p-value - if uj.size > 2 and np.std(uj) > 0 and np.std(uj1) > 0: - rho, pval = pearsonr(uj, uj1) + # FIX: filter non-finite values before correlation + finite = np.isfinite(uj) & np.isfinite(uj1) + uj_f, uj1_f = uj[finite], uj1[finite] + if uj_f.size > 2 and np.std(uj_f) > 0 and np.std(uj1_f) > 0: + rho, pval = pearsonr(uj_f, uj1_f) label = f"{base_label}, $\\rho$={rho:.2g} (p={pval:.2g})" else: label = base_label @@ -1320,40 +1392,63 @@ def plotSeqCorr(self, fit_num: int | list[int] | None = None, handle=None): if legend_handles: ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=10) - ax.set_title("Sequential Correlation") + ax.set_title("Sequential Correlation of\nRescaled ISIs", fontweight="bold", fontsize=11) ax.set_xlabel("$U_j$") ax.set_ylabel("$U_{j+1}$") ax.set_xlim(0, 1) ax.set_ylim(0, 1) return ax - def plotCoeffs(self, fit_num: int = 1, handle=None, plotSignificance: int = 1): + def plotCoeffs(self, fit_num: int | list[int] | None = None, handle=None, plotSignificance: int = 1): """Plot GLM coefficients with error bars and significance markers. - Matches Matlab FitResult.plotCoeffs: errorbar plot with ±1 SE, - and asterisks (*) above significant coefficients (p < 0.05). + Matches Matlab FitResult.plotCoeffs: when *fit_num* is ``None`` + (default) all fits are overlaid with per-fit colours, errorbar + plots with ±1 SE, and asterisks (*) above significant coefficients + (p < 0.05). """ - diag = self._compute_diagnostics(fit_num) + if fit_num is None: + fit_nums = list(range(1, self.numResults + 1)) + elif isinstance(fit_num, int): + fit_nums = [fit_num] + else: + fit_nums = list(fit_num) + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] - coeffs = np.asarray(diag["coefficients"], dtype=float) - se = np.asarray(diag["coeff_se"], dtype=float) - sig = np.asarray(diag["coeff_sig"], dtype=float) - labels = list(np.asarray(diag["coeff_labels"], dtype=object)) - xpos = np.arange(1, coeffs.size + 1, dtype=float) + _SEQ_COLORS = ["tab:blue", "tab:green", "tab:red", "tab:cyan", "tab:purple", "tab:olive", "k"] ax.axhline(0.0, color="0.6", linewidth=1.0) - # Errorbar plot like Matlab (dot markers with SE whiskers) - valid_se = np.where(np.isfinite(se), se, 0.0) - ax.errorbar(xpos, coeffs, yerr=valid_se, fmt=".", color="tab:blue", - linewidth=1.0, markersize=8.0, capsize=3.0) - if plotSignificance and np.any(sig > 0): - ylims = ax.get_ylim() - y_star = 0.8 * ylims[1] - sig_idx = xpos[sig.astype(bool)] - ax.plot(sig_idx, np.full(sig_idx.size, y_star), "*", color="tab:blue", markersize=10.0) - ax.set_xticks(xpos) - ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=6) + + # Collect all labels across fits to build a unified x-axis + all_labels: list[str] = [] + for fn in fit_nums: + diag = self._compute_diagnostics(fn) + for lbl in np.asarray(diag["coeff_labels"], dtype=object): + if lbl not in all_labels: + all_labels.append(lbl) + label_to_x = {lbl: float(j + 1) for j, lbl in enumerate(all_labels)} + xpos_all = np.arange(1, len(all_labels) + 1, dtype=float) + + for i, fn in enumerate(fit_nums): + diag = self._compute_diagnostics(fn) + coeffs = np.asarray(diag["coefficients"], dtype=float) + se = np.asarray(diag["coeff_se"], dtype=float) + sig = np.asarray(diag["coeff_sig"], dtype=float) + fit_labels = list(np.asarray(diag["coeff_labels"], dtype=object)) + xpos = np.array([label_to_x[lbl] for lbl in fit_labels]) + color = _SEQ_COLORS[i % len(_SEQ_COLORS)] + valid_se = np.where(np.isfinite(se), se, 0.0) + ax.errorbar(xpos, coeffs, yerr=valid_se, fmt=".", color=color, + linewidth=1.0, markersize=8.0, capsize=3.0) + if plotSignificance and np.any(sig > 0): + ylims = ax.get_ylim() + y_star = 0.8 * ylims[1] - i * 0.1 + sig_idx = xpos[sig.astype(bool)] + ax.plot(sig_idx, np.full(sig_idx.size, y_star), "*", color=color, markersize=10.0) + + ax.set_xticks(xpos_all) + ax.set_xticklabels(all_labels, rotation=45, ha="right", fontsize=6) ax.set_ylabel("GLM Fit Coefficients") - ax.set_title("GLM Coefficients") + ax.set_title("GLM Coefficients", fontweight="bold", fontsize=11) return ax def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None):