diff --git a/docs/figures/example01/fig01_constant_mg_summary.png b/docs/figures/example01/fig01_constant_mg_summary.png index 0b026f75..9da8d1d1 100644 Binary files a/docs/figures/example01/fig01_constant_mg_summary.png and b/docs/figures/example01/fig01_constant_mg_summary.png differ diff --git a/docs/figures/example01/fig03_piecewise_baseline_comparison.png b/docs/figures/example01/fig03_piecewise_baseline_comparison.png index dca681a0..c9ef13ca 100644 Binary files a/docs/figures/example01/fig03_piecewise_baseline_comparison.png and b/docs/figures/example01/fig03_piecewise_baseline_comparison.png differ diff --git a/docs/figures/example03/fig04_ssglm_fit_diagnostics.png b/docs/figures/example03/fig04_ssglm_fit_diagnostics.png index 74e94daa..c5846300 100644 Binary files a/docs/figures/example03/fig04_ssglm_fit_diagnostics.png and b/docs/figures/example03/fig04_ssglm_fit_diagnostics.png differ diff --git a/docs/figures/example05/fig01_univariate_setup.png b/docs/figures/example05/fig01_univariate_setup.png index 7c3007fa..94dc2a8f 100644 Binary files a/docs/figures/example05/fig01_univariate_setup.png and b/docs/figures/example05/fig01_univariate_setup.png differ diff --git a/docs/figures/example05/fig02_univariate_decoding.png b/docs/figures/example05/fig02_univariate_decoding.png index a1fcb7ac..fbee2a41 100644 Binary files a/docs/figures/example05/fig02_univariate_decoding.png and b/docs/figures/example05/fig02_univariate_decoding.png differ diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 89e41a67..923cdabd 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -27,7 +27,7 @@ Expected outputs: - Figure 1: CIF tuning curves and simulated spike raster. - - Figure 2: Decoded stimulus vs true (with ±2σ confidence band). + - Figure 2: Decoded stimulus vs true (with 95% confidence band). - Figure 3: Reach trajectory and population spike raster. - Figure 4: PPAF comparison (free vs goal-informed, 20 runs box plot). - Figure 5: Hybrid filter setup (state sequence, spike raster). @@ -96,10 +96,9 @@ def _run_part_a(seed=11, n_cells=20): x_true = np.sin(2.0 * np.pi * 2.0 * time) # (T,) # ── Encoding model: logistic CIF ── - # mu_c ~ log(10*delta) + N(0, 0.3) — baseline firing probability - # beta_c ~ N(1.0, 0.5) — stimulus gain - b0 = np.log(10.0 * delta) * np.ones(n_cells) + rng.normal(0.0, 0.3, n_cells) - b1 = rng.normal(1.0, 0.5, n_cells) + # MATLAB: b0 = log(10*delta) + randn(C,1); b1 = randn(C,1); + b0 = np.log(10.0 * delta) + rng.standard_normal(n_cells) + b1 = rng.standard_normal(n_cells) # Simulate spikes x_2d = x_true.reshape(1, -1) # (1, T) — scalar state @@ -108,8 +107,10 @@ def _run_part_a(seed=11, n_cells=20): # ── State-space model ── # x(t+1) = A * x(t) + w, w ~ N(0, Q) + # MATLAB: Q = std(stim.data(2:end) - stim.data(1:end-1)); A = 1; A = np.array([[1.0]]) - Q = np.array([[0.001]]) + Q_val = float(np.std(np.diff(x_true))) + Q = np.array([[Q_val]]) x0 = np.array([0.0]) Pi0 = 0.5 * np.eye(1) @@ -119,11 +120,12 @@ def _run_part_a(seed=11, n_cells=20): A, Q, dN, b0, beta, "binomial", delta, None, None, x0, Pi0 ) - # Extract decoded signal and ±2σ confidence band + # Extract decoded signal and 95% CI (±1.96σ, matching MATLAB zVal=1.96) x_decoded = x_u[0, :] # (T,) sigma = np.sqrt(np.maximum(W_u[0, 0, :], 0.0)) - ci_low = x_decoded - 2.0 * sigma - ci_high = x_decoded + 2.0 * sigma + z_val = 1.96 + ci_low = np.minimum(x_decoded - z_val * sigma, x_decoded + z_val * sigma) + ci_high = np.maximum(x_decoded - z_val * sigma, x_decoded + z_val * sigma) rmse = float(np.sqrt(np.mean((x_decoded - x_true) ** 2))) return { @@ -371,36 +373,53 @@ def _plot_part_a(result): time = result["time"] x_true = result["x_true"] dN = result["dN"] + delta = time[1] - time[0] - # ── Figure 1: CIF tuning and spike raster ── - fig1, axes1 = plt.subplots(2, 1, figsize=(10, 6), sharex=True) + # ── Figure 1: stimulus, CIF, spike raster (3 panels, matching MATLAB) ── + fig1, axes1 = plt.subplots(3, 1, figsize=(10, 8), sharex=True) - # Top: true stimulus + # Top: driving stimulus axes1[0].plot(time, x_true, "k-", linewidth=1.5) - axes1[0].set_ylabel("Stimulus x(t)") - axes1[0].set_title("Part A: Sinusoidal Stimulus Encoding") + axes1[0].set_ylabel("Stimulus") + axes1[0].set_title("Driving Stimulus", fontweight="bold", fontsize=14) + axes1[0].tick_params(labelbottom=False) - # Bottom: spike raster + # Middle: conditional intensity functions (firing rates in spikes/sec) + b0 = result["b0"] + b1 = result["b1"] n_cells = dN.shape[0] + for c in range(n_cells): + eta = b1[c] * x_true + b0[c] + exp_eta = np.exp(eta) + lam = (exp_eta / (1.0 + exp_eta)) / delta # probability → rate (Hz) + axes1[1].plot(time, lam, "k-", linewidth=1.0) + axes1[1].set_ylabel("Firing Rate [spikes/sec]") + axes1[1].set_title("Conditional Intensity Functions", fontweight="bold", fontsize=14) + axes1[1].tick_params(labelbottom=False) + + # Bottom: spike raster for c in range(n_cells): spike_times = time[dN[c, :] > 0] - axes1[1].plot(spike_times, np.full_like(spike_times, c + 1), "|", color="k", markersize=2) - axes1[1].set_ylabel("Neuron") - axes1[1].set_xlabel("Time (s)") - axes1[1].set_ylim(0.5, n_cells + 0.5) + axes1[2].plot(spike_times, np.full_like(spike_times, c + 1), "|", color="k", markersize=2) + axes1[2].set_ylabel("Cell Number") + axes1[2].set_xlabel("time [s]") + axes1[2].set_ylim(0.5, n_cells + 0.5) + axes1[2].set_yticks(np.arange(0, n_cells + 1, 10)) + axes1[2].set_title("Point Process Sample Paths", fontweight="bold", fontsize=14) fig1.tight_layout() - # ── Figure 2: Decoding results ── + # ── Figure 2: Decoding results (MATLAB: black=decoded, blue=actual) ── fig2, ax2 = plt.subplots(1, 1, figsize=(10, 4)) - ax2.plot(time, x_true, "k-", linewidth=1.5, label="True stimulus") - ax2.plot(time, result["x_decoded"], "r-", linewidth=1.0, label="PPAF decoded") ax2.fill_between( time, result["ci_low"], result["ci_high"], - color="red", alpha=0.15, label="±2σ CI" + color="0.75", alpha=0.4, label="95% CI" ) - ax2.set_xlabel("Time (s)") - ax2.set_ylabel("x(t)") - ax2.set_title(f"PPDecodeFilterLinear — Decoded Stimulus (RMSE = {result['rmse']:.4f})") + ax2.plot(time, result["x_decoded"], "k-", linewidth=2.0, label="Decoded") + ax2.plot(time, x_true, "b-", linewidth=2.0, label="Actual") + ax2.set_xlabel("time [s]") + ax2.set_ylabel("") + ax2.set_title(f"Decoded Stimulus $\\pm$ 95% CIs with {result['n_cells']} cells", + fontweight="bold", fontsize=14) ax2.legend(loc="upper right") fig2.tight_layout() diff --git a/nstat/fit.py b/nstat/fit.py index 6df778d3..f37dcf30 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -9,7 +9,7 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np -from scipy.stats import norm +from scipy.stats import norm, pearsonr from .core import Covariate, nspikeTrain @@ -18,6 +18,20 @@ def _ordered_unique(labels: Sequence[str]) -> list[str]: return list(dict.fromkeys(str(label) for label in labels)) +def _ensure_mathtext(label: str) -> str: + """Wrap a label in ``$...$`` if it contains LaTeX commands but isn't already wrapped.""" + s = str(label) + if not s: + return s + # Already wrapped in math delimiters — leave as-is + if s.startswith("$") and s.endswith("$"): + return s + # Contains LaTeX commands (e.g. \lambda, \rho) — wrap in $...$ + if re.search(r"\\[a-zA-Z]", s): + return f"${s}$" + return s + + def _parse_neuron_number(spike_obj: nspikeTrain | Sequence[nspikeTrain]) -> str | float: if isinstance(spike_obj, Sequence) and not isinstance(spike_obj, nspikeTrain): names = [str(item.name) for item in spike_obj if getattr(item, "name", "")] @@ -1154,7 +1168,7 @@ def plotResults(self, fit_num: int = 1, handle=None): verticalalignment="top", ) self.plotInvGausTrans(fit_num=fit_num, handle=ax_ig) - self.plotSeqCorr(fit_num=fit_num, handle=ax_sc) + self.plotSeqCorr(fit_num=None, handle=ax_sc) self.plotCoeffs(fit_num=fit_num, handle=ax_co) self.plotResidual(fit_num=fit_num, handle=ax_re) fig.tight_layout() @@ -1202,7 +1216,8 @@ def KSPlot(self, fit_num: int | list[int] | None = None, handle=None): ideal = np.asarray(diag["ks_ideal"], dtype=float) empirical = np.asarray(diag["ks_empirical"], dtype=float) color = self._MATLAB_KS_COLORS[i % len(self._MATLAB_KS_COLORS)] - label = data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}" + raw_label = data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}" + label = _ensure_mathtext(raw_label) if ideal.size: h, = ax.plot(ideal, empirical, color=color, linewidth=2.0) handles_for_legend.append(h) @@ -1249,29 +1264,63 @@ def plotInvGausTrans(self, fit_num: int = 1, handle=None): ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs") return ax - def plotSeqCorr(self, fit_num: int = 1, handle=None): - """Plot U_j vs U_{j+1} scatter with correlation coefficient. + def plotSeqCorr(self, fit_num: int | list[int] | None = None, handle=None): + """Plot U_j vs U_{j+1} scatter with correlation coefficient and p-value. Matlab: plotSeqCorr plots the sequential correlation scatter of U_j (uniform-transformed rescaled ISIs) to detect serial dependence. + When multiple models are present, each is plotted with a different + colour and a legend entry showing ``label, ρ=X.XX (p=Y.YY)``. + + Parameters + ---------- + fit_num : int, list of int, or None + Which model(s) to plot. ``None`` (default) plots all models, + matching the MATLAB default behaviour. + handle : matplotlib Axes, optional + Axes to draw on. A new figure is created when *None*. """ - diag = self._compute_diagnostics(fit_num) - ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] - u = np.asarray(diag.get("uniforms", []), dtype=float) - if u.size > 1: - uj = u[:-1] - uj1 = u[1:] - ax.plot(uj, uj1, ".", color="tab:blue", markersize=4.0) - # Compute correlation coefficient (guard against constant series) - if uj.size > 2 and np.std(uj) > 0 and np.std(uj1) > 0: - with np.errstate(invalid="ignore"): - rho_mat = np.corrcoef(uj, uj1) - rho = rho_mat[0, 1] if rho_mat.shape[0] > 1 else float("nan") - ax.set_title(f"Sequential Correlation ($\\rho$ = {rho:.2g})") - else: - ax.set_title("Sequential Correlation") + if fit_num is None: + fit_nums = list(range(1, self.numResults + 1)) + elif isinstance(fit_num, int): + fit_nums = [fit_num] else: - ax.set_title("Sequential Correlation") + fit_nums = list(fit_num) + + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1] + data_labels = ( + list(self.lambda_signal.dataLabels) + if getattr(self.lambda_signal, "dataLabels", None) + else [] + ) + _SEQ_COLORS = ["tab:blue", "tab:green", "tab:red", "tab:cyan", "tab:purple", "tab:olive", "k"] + legend_labels: list[str] = [] + legend_handles: list[object] = [] + + for i, fn in enumerate(fit_nums): + diag = self._compute_diagnostics(fn) + u = np.asarray(diag.get("uniforms", []), dtype=float) + base_label = _ensure_mathtext( + data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}" + ) + color = _SEQ_COLORS[i % len(_SEQ_COLORS)] + if u.size > 1: + uj = u[:-1] + uj1 = u[1:] + h, = ax.plot(uj, uj1, ".", color=color, markersize=4.0) + # Compute correlation coefficient and p-value + if uj.size > 2 and np.std(uj) > 0 and np.std(uj1) > 0: + rho, pval = pearsonr(uj, uj1) + label = f"{base_label}, $\\rho$={rho:.2g} (p={pval:.2g})" + else: + label = base_label + legend_handles.append(h) + legend_labels.append(label) + + if legend_handles: + ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=10) + + ax.set_title("Sequential Correlation") ax.set_xlabel("$U_j$") ax.set_ylabel("$U_{j+1}$") ax.set_xlim(0, 1) diff --git a/nstat/paper_examples.py b/nstat/paper_examples.py index f2bda945..856306fe 100644 --- a/nstat/paper_examples.py +++ b/nstat/paper_examples.py @@ -226,8 +226,8 @@ def run_experiment5(seed: int = 11, *, return_payload: bool = False) -> Summary n_cells = 20 spikes = np.zeros((time.shape[0], n_cells), dtype=float) for i in range(n_cells): - b1 = rng.normal(1.0, 0.5) - b0 = np.log(10.0 * dt) + rng.normal(0.0, 0.3) + b1 = rng.standard_normal() + b0 = np.log(10.0 * dt) + rng.standard_normal() eta = b1 * stim + b0 p = np.exp(eta) p = p / (1.0 + p)