Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/figures/example01/fig01_constant_mg_summary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example01/fig03_piecewise_baseline_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example03/fig04_ssglm_fit_diagnostics.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example05/fig01_univariate_setup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example05/fig02_univariate_decoding.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 45 additions & 26 deletions examples/paper/example05_decoding_ppaf_pphf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

Expected outputs:
- Figure 1: CIF tuning curves and simulated spike raster.
- Figure 2: Decoded stimulus vs true (with ±2σ confidence band).
- Figure 2: Decoded stimulus vs true (with 95% confidence band).
- Figure 3: Reach trajectory and population spike raster.
- Figure 4: PPAF comparison (free vs goal-informed, 20 runs box plot).
- Figure 5: Hybrid filter setup (state sequence, spike raster).
Expand Down Expand Up @@ -96,10 +96,9 @@ def _run_part_a(seed=11, n_cells=20):
x_true = np.sin(2.0 * np.pi * 2.0 * time) # (T,)

# ── Encoding model: logistic CIF ──
# mu_c ~ log(10*delta) + N(0, 0.3) — baseline firing probability
# beta_c ~ N(1.0, 0.5) — stimulus gain
b0 = np.log(10.0 * delta) * np.ones(n_cells) + rng.normal(0.0, 0.3, n_cells)
b1 = rng.normal(1.0, 0.5, n_cells)
# MATLAB: b0 = log(10*delta) + randn(C,1); b1 = randn(C,1);
b0 = np.log(10.0 * delta) + rng.standard_normal(n_cells)
b1 = rng.standard_normal(n_cells)

# Simulate spikes
x_2d = x_true.reshape(1, -1) # (1, T) — scalar state
Expand All @@ -108,8 +107,10 @@ def _run_part_a(seed=11, n_cells=20):

# ── State-space model ──
# x(t+1) = A * x(t) + w, w ~ N(0, Q)
# MATLAB: Q = std(stim.data(2:end) - stim.data(1:end-1)); A = 1;
A = np.array([[1.0]])
Q = np.array([[0.001]])
Q_val = float(np.std(np.diff(x_true)))
Q = np.array([[Q_val]])
x0 = np.array([0.0])
Pi0 = 0.5 * np.eye(1)

Expand All @@ -119,11 +120,12 @@ def _run_part_a(seed=11, n_cells=20):
A, Q, dN, b0, beta, "binomial", delta, None, None, x0, Pi0
)

# Extract decoded signal and ±2σ confidence band
# Extract decoded signal and 95% CI (±1.96σ, matching MATLAB zVal=1.96)
x_decoded = x_u[0, :] # (T,)
sigma = np.sqrt(np.maximum(W_u[0, 0, :], 0.0))
ci_low = x_decoded - 2.0 * sigma
ci_high = x_decoded + 2.0 * sigma
z_val = 1.96
ci_low = np.minimum(x_decoded - z_val * sigma, x_decoded + z_val * sigma)
ci_high = np.maximum(x_decoded - z_val * sigma, x_decoded + z_val * sigma)
rmse = float(np.sqrt(np.mean((x_decoded - x_true) ** 2)))

return {
Expand Down Expand Up @@ -371,36 +373,53 @@ def _plot_part_a(result):
time = result["time"]
x_true = result["x_true"]
dN = result["dN"]
delta = time[1] - time[0]

# ── Figure 1: CIF tuning and spike raster ──
fig1, axes1 = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
# ── Figure 1: stimulus, CIF, spike raster (3 panels, matching MATLAB) ──
fig1, axes1 = plt.subplots(3, 1, figsize=(10, 8), sharex=True)

# Top: true stimulus
# Top: driving stimulus
axes1[0].plot(time, x_true, "k-", linewidth=1.5)
axes1[0].set_ylabel("Stimulus x(t)")
axes1[0].set_title("Part A: Sinusoidal Stimulus Encoding")
axes1[0].set_ylabel("Stimulus")
axes1[0].set_title("Driving Stimulus", fontweight="bold", fontsize=14)
axes1[0].tick_params(labelbottom=False)

# Bottom: spike raster
# Middle: conditional intensity functions (firing rates in spikes/sec)
b0 = result["b0"]
b1 = result["b1"]
n_cells = dN.shape[0]
for c in range(n_cells):
eta = b1[c] * x_true + b0[c]
exp_eta = np.exp(eta)
lam = (exp_eta / (1.0 + exp_eta)) / delta # probability → rate (Hz)
axes1[1].plot(time, lam, "k-", linewidth=1.0)
axes1[1].set_ylabel("Firing Rate [spikes/sec]")
axes1[1].set_title("Conditional Intensity Functions", fontweight="bold", fontsize=14)
axes1[1].tick_params(labelbottom=False)

# Bottom: spike raster
for c in range(n_cells):
spike_times = time[dN[c, :] > 0]
axes1[1].plot(spike_times, np.full_like(spike_times, c + 1), "|", color="k", markersize=2)
axes1[1].set_ylabel("Neuron")
axes1[1].set_xlabel("Time (s)")
axes1[1].set_ylim(0.5, n_cells + 0.5)
axes1[2].plot(spike_times, np.full_like(spike_times, c + 1), "|", color="k", markersize=2)
axes1[2].set_ylabel("Cell Number")
axes1[2].set_xlabel("time [s]")
axes1[2].set_ylim(0.5, n_cells + 0.5)
axes1[2].set_yticks(np.arange(0, n_cells + 1, 10))
axes1[2].set_title("Point Process Sample Paths", fontweight="bold", fontsize=14)
fig1.tight_layout()

# ── Figure 2: Decoding results ──
# ── Figure 2: Decoding results (MATLAB: black=decoded, blue=actual) ──
fig2, ax2 = plt.subplots(1, 1, figsize=(10, 4))
ax2.plot(time, x_true, "k-", linewidth=1.5, label="True stimulus")
ax2.plot(time, result["x_decoded"], "r-", linewidth=1.0, label="PPAF decoded")
ax2.fill_between(
time, result["ci_low"], result["ci_high"],
color="red", alpha=0.15, label="±2σ CI"
color="0.75", alpha=0.4, label="95% CI"
)
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("x(t)")
ax2.set_title(f"PPDecodeFilterLinear — Decoded Stimulus (RMSE = {result['rmse']:.4f})")
ax2.plot(time, result["x_decoded"], "k-", linewidth=2.0, label="Decoded")
ax2.plot(time, x_true, "b-", linewidth=2.0, label="Actual")
ax2.set_xlabel("time [s]")
ax2.set_ylabel("")
ax2.set_title(f"Decoded Stimulus $\\pm$ 95% CIs with {result['n_cells']} cells",
fontweight="bold", fontsize=14)
ax2.legend(loc="upper right")
fig2.tight_layout()

Expand Down
91 changes: 70 additions & 21 deletions nstat/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from scipy.stats import norm, pearsonr

from .core import Covariate, nspikeTrain

Expand All @@ -18,6 +18,20 @@ def _ordered_unique(labels: Sequence[str]) -> list[str]:
return list(dict.fromkeys(str(label) for label in labels))


def _ensure_mathtext(label: str) -> str:
"""Wrap a label in ``$...$`` if it contains LaTeX commands but isn't already wrapped."""
s = str(label)
if not s:
return s
# Already wrapped in math delimiters — leave as-is
if s.startswith("$") and s.endswith("$"):
return s
# Contains LaTeX commands (e.g. \lambda, \rho) — wrap in $...$
if re.search(r"\\[a-zA-Z]", s):
return f"${s}$"
return s


def _parse_neuron_number(spike_obj: nspikeTrain | Sequence[nspikeTrain]) -> str | float:
if isinstance(spike_obj, Sequence) and not isinstance(spike_obj, nspikeTrain):
names = [str(item.name) for item in spike_obj if getattr(item, "name", "")]
Expand Down Expand Up @@ -1154,7 +1168,7 @@ def plotResults(self, fit_num: int = 1, handle=None):
verticalalignment="top",
)
self.plotInvGausTrans(fit_num=fit_num, handle=ax_ig)
self.plotSeqCorr(fit_num=fit_num, handle=ax_sc)
self.plotSeqCorr(fit_num=None, handle=ax_sc)
self.plotCoeffs(fit_num=fit_num, handle=ax_co)
self.plotResidual(fit_num=fit_num, handle=ax_re)
fig.tight_layout()
Expand Down Expand Up @@ -1202,7 +1216,8 @@ def KSPlot(self, fit_num: int | list[int] | None = None, handle=None):
ideal = np.asarray(diag["ks_ideal"], dtype=float)
empirical = np.asarray(diag["ks_empirical"], dtype=float)
color = self._MATLAB_KS_COLORS[i % len(self._MATLAB_KS_COLORS)]
label = data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}"
raw_label = data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}"
label = _ensure_mathtext(raw_label)
if ideal.size:
h, = ax.plot(ideal, empirical, color=color, linewidth=2.0)
handles_for_legend.append(h)
Expand Down Expand Up @@ -1249,29 +1264,63 @@ def plotInvGausTrans(self, fit_num: int = 1, handle=None):
ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs")
return ax

def plotSeqCorr(self, fit_num: int = 1, handle=None):
"""Plot U_j vs U_{j+1} scatter with correlation coefficient.
def plotSeqCorr(self, fit_num: int | list[int] | None = None, handle=None):
"""Plot U_j vs U_{j+1} scatter with correlation coefficient and p-value.

Matlab: plotSeqCorr plots the sequential correlation scatter of
U_j (uniform-transformed rescaled ISIs) to detect serial dependence.
When multiple models are present, each is plotted with a different
colour and a legend entry showing ``label, ρ=X.XX (p=Y.YY)``.

Parameters
----------
fit_num : int, list of int, or None
Which model(s) to plot. ``None`` (default) plots all models,
matching the MATLAB default behaviour.
handle : matplotlib Axes, optional
Axes to draw on. A new figure is created when *None*.
"""
diag = self._compute_diagnostics(fit_num)
ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
u = np.asarray(diag.get("uniforms", []), dtype=float)
if u.size > 1:
uj = u[:-1]
uj1 = u[1:]
ax.plot(uj, uj1, ".", color="tab:blue", markersize=4.0)
# Compute correlation coefficient (guard against constant series)
if uj.size > 2 and np.std(uj) > 0 and np.std(uj1) > 0:
with np.errstate(invalid="ignore"):
rho_mat = np.corrcoef(uj, uj1)
rho = rho_mat[0, 1] if rho_mat.shape[0] > 1 else float("nan")
ax.set_title(f"Sequential Correlation ($\\rho$ = {rho:.2g})")
else:
ax.set_title("Sequential Correlation")
if fit_num is None:
fit_nums = list(range(1, self.numResults + 1))
elif isinstance(fit_num, int):
fit_nums = [fit_num]
else:
ax.set_title("Sequential Correlation")
fit_nums = list(fit_num)

ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 3.5))[1]
data_labels = (
list(self.lambda_signal.dataLabels)
if getattr(self.lambda_signal, "dataLabels", None)
else []
)
_SEQ_COLORS = ["tab:blue", "tab:green", "tab:red", "tab:cyan", "tab:purple", "tab:olive", "k"]
legend_labels: list[str] = []
legend_handles: list[object] = []

for i, fn in enumerate(fit_nums):
diag = self._compute_diagnostics(fn)
u = np.asarray(diag.get("uniforms", []), dtype=float)
base_label = _ensure_mathtext(
data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}"
)
color = _SEQ_COLORS[i % len(_SEQ_COLORS)]
if u.size > 1:
uj = u[:-1]
uj1 = u[1:]
h, = ax.plot(uj, uj1, ".", color=color, markersize=4.0)
# Compute correlation coefficient and p-value
if uj.size > 2 and np.std(uj) > 0 and np.std(uj1) > 0:
rho, pval = pearsonr(uj, uj1)
label = f"{base_label}, $\\rho$={rho:.2g} (p={pval:.2g})"
else:
label = base_label
legend_handles.append(h)
legend_labels.append(label)

if legend_handles:
ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=10)

ax.set_title("Sequential Correlation")
ax.set_xlabel("$U_j$")
ax.set_ylabel("$U_{j+1}$")
ax.set_xlim(0, 1)
Expand Down
4 changes: 2 additions & 2 deletions nstat/paper_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def run_experiment5(seed: int = 11, *, return_payload: bool = False) -> Summary
n_cells = 20
spikes = np.zeros((time.shape[0], n_cells), dtype=float)
for i in range(n_cells):
b1 = rng.normal(1.0, 0.5)
b0 = np.log(10.0 * dt) + rng.normal(0.0, 0.3)
b1 = rng.standard_normal()
b0 = np.log(10.0 * dt) + rng.standard_normal()
eta = b1 * stim + b0
p = np.exp(eta)
p = p / (1.0 + p)
Expand Down
Loading