Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,19 @@ data_cache/
data/
output/
verification_output/

# Build artifacts
dist/
*.egg-info/

# MATLAB / Simulink cache
*.slxc
slprj/

# IDE
.idea/

# Temporary debug / verification scripts
debug_*.py
verify_all_examples.py
tests/test_matlab_python_parity.py
Binary file modified docs/figures/example01/fig01_constant_mg_summary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example01/fig02_washout_raster_overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example01/fig03_piecewise_baseline_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example02/fig01_data_overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example02/fig02_lag_and_model_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example03/fig01_simulated_and_real_rasters.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example03/fig04_ssglm_fit_diagnostics.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example03/fig05_stimulus_effect_surfaces.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example03/fig06_learning_trial_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example05/fig03_reach_and_population_setup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example05/fig05_hybrid_setup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/figures/example05/fig06_hybrid_decoding_summary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 14 additions & 21 deletions examples/paper/example03_psth_and_ssglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,13 @@ def run_part_a(data_dir, export_dir=None):

# Top-left: CIF
ax = axes1[0, 0]
ax.plot(time, lambdaData, "b", linewidth=2)
ax.plot(time, lambdaData, "b", linewidth=2, label=r"$\lambda_i$")
ax.set_title("Simulated Conditional Intensity Function (CIF)",
fontweight="bold", fontsize=14, fontfamily="Arial")
ax.set_xlabel("time [s]", fontsize=12, fontweight="bold", fontfamily="Arial")
ax.set_ylabel(r"$\lambda(t)$ [spikes/sec]", fontsize=14, fontweight="bold",
fontfamily="Arial")
ax.legend(loc="upper right")

# Bottom-left: simulated raster
ax = axes1[1, 0]
Expand Down Expand Up @@ -499,38 +500,30 @@ def run_part_b(data_dir, export_dir=None):

# ------------------------------------------------------------------
# Figure 5: True/PSTH/SSGLM stimulus effect surfaces
# Match MATLAB: mesh(trial, time, data) with view([90 -90]) → top-down
# MATLAB orientation: time [s] on x-axis, Trial [k] on y-axis
# (matches fig03 bottom-panel "True Conditional Intensity Function")
# MATLAB: mesh(trial, time, data) with view([90 -90]) renders as a
# top-down colored heatmap (MATLAB applies its colormap to Z-values).
# Python equivalent: pcolormesh with viridis (≈MATLAB parula default).
# MATLAB orientation: trial on x-axis, time on y-axis (view [90 -90]).
# ------------------------------------------------------------------
from mpl_toolkits.mplot3d import Axes3D # noqa: F401

fig5 = plt.figure(figsize=(14, 9))
fig5, axes5 = plt.subplots(3, 1, figsize=(14, 9))
trial_axis = np.arange(1, numRealizations + 1)
T_act = min(actStimEffect.shape[0], len(basis_time))

# Build meshgrid: MATLAB mesh(trial, time, data) — trial on X, time on Y
K_mesh, T_mesh = np.meshgrid(trial_axis, basis_time[:T_act])

# MATLAB uses mesh() which renders wireframe only (no filled faces).
# plot_wireframe is the closest matplotlib equivalent.
surfaces = [
("True Stimulus Effect", actStimEffect[:T_act, :]),
("PSTH Estimated Stimulus Effect", psthSurface2D[:T_act, :]),
("SSGLM Estimated Stimulus Effect", estStimEffect[:T_act, :]),
]
for idx, (title, data) in enumerate(surfaces, 1):
ax = fig5.add_subplot(3, 1, idx, projection="3d")
ax.plot_wireframe(K_mesh, T_mesh, data,
rstride=5, cstride=1,
linewidth=0.3, color="k")
ax.view_init(elev=-90, azim=90)
for ax, (title, data) in zip(axes5, surfaces):
# MATLAB mesh(trial, time, data) viewed from above: x=trial, y=time
ax.pcolormesh(trial_axis, basis_time[:T_act], data,
shading="gouraud", cmap="viridis")
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(title, fontweight="bold", fontsize=14, fontfamily="Arial")

fig5.tight_layout()
print(" Figure 5: Stimulus effect surfaces (top-down mesh)")
print(" Figure 5: Stimulus effect surfaces (top-down heatmap)")

# ------------------------------------------------------------------
# 6. Learning-trial analysis: spike rate CIs
Expand Down Expand Up @@ -588,9 +581,9 @@ def run_part_b(data_dir, export_dir=None):
ax3.fill_between(basis_time, ci1_lo, ci1_hi, alpha=0.3, color="gray")
ax3.fill_between(basis_time, cilt_lo, cilt_hi, alpha=0.3, color="red")
h1, = ax3.plot(basis_time, stim1_data, "k", linewidth=4,
label="\\lambda_1(t)")
label=r"$\lambda_1(t)$")
h2, = ax3.plot(basis_time, stimlt_data, "r", linewidth=4,
label=f"\\lambda_{{{lt}}}(t)")
label=rf"$\lambda_{{{lt}}}(t)$")
ax3.legend(handles=[h1, h2])
ax3.set_xlabel("time [s]")
ax3.set_ylabel("Firing Rate [spikes/sec]")
Expand Down
18 changes: 12 additions & 6 deletions examples/paper/example05_decoding_ppaf_pphf.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,10 @@ def _plot_part_b(result):
# Top-left [1,3]: 2D reach path (in cm)
ax_path = fig3.add_subplot(4, 2, (1, 3))
ax_path.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=2)
ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14)
ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14)
ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14,
markerfacecolor="none", markeredgewidth=2)
ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14,
markerfacecolor="none", markeredgewidth=2)
ax_path.legend(["Path", "Start", "Finish"], loc="upper right")
ax_path.set_xlabel("X Position [cm]")
ax_path.set_ylabel("Y Position [cm]")
Expand Down Expand Up @@ -690,8 +692,10 @@ def _plot_part_c(result):
# Top-left [1,3]: 2D reach path
ax_path = fig5.add_subplot(4, 2, (1, 3))
ax_path.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=2)
ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16)
ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16)
ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16,
markerfacecolor="none", markeredgewidth=2)
ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16,
markerfacecolor="none", markeredgewidth=2)
ax_path.set_xlabel("X [cm]")
ax_path.set_ylabel("Y [cm]")
ax_path.set_title("Reach Path", fontweight="bold", fontsize=14)
Expand Down Expand Up @@ -773,8 +777,10 @@ def _plot_part_c(result):
ax_2d.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=1)
ax_2d.plot(mX_est[0, :], mX_est[1, :], "b", linewidth=3)
ax_2d.plot(mX_estNT[0, :], mX_estNT[1, :], "g", linewidth=3)
ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14)
ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14)
ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14,
markerfacecolor="none", markeredgewidth=2)
ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14,
markerfacecolor="none", markeredgewidth=2)
ax_2d.set_xlabel("x [cm]")
ax_2d.set_ylabel("y [cm]")
ax_2d.set_title("Estimated vs. Actual Reach Path",
Expand Down
Binary file modified examples/paper/figures/example01/fig01_constant_mg_summary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/paper/figures/example01/fig02_washout_raster_overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/paper/figures/example02/fig01_data_overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/paper/figures/example05/fig05_hybrid_setup.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 7 additions & 4 deletions nstat/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,9 @@ def plotInvGausTrans(self, fit_num: int | list[int] | None = None, handle=None):
data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}"
)
if lags.size:
h, = ax.plot(lags, acf, ".", color=color, markersize=4.0)
# MATLAB: scatter dots (plot with '.') — use small markers
# so individual lags are visible, not bar-like blobs.
h, = ax.plot(lags, acf, ".", color=color, markersize=3.0)
legend_handles.append(h)
legend_labels.append(base_label)
if ci_val is None:
Expand Down Expand Up @@ -1456,9 +1458,10 @@ def plotCoeffs(self, fit_num: int | list[int] | None = None, handle=None, plotSi
xpos = np.array([label_to_x[lbl] for lbl in fit_labels])
color = _SEQ_COLORS[i % len(_SEQ_COLORS)]
valid_se = np.where(np.isfinite(se), se, 0.0)
# Larger markers and thicker error bars to match MATLAB visibility
h = ax.errorbar(xpos, coeffs, yerr=valid_se, fmt=".", color=color,
linewidth=1.5, markersize=12.0, capsize=5.0,
# MATLAB plots 95% CI = 1.96 * SE as error bars
ci95 = 1.96 * valid_se
h = ax.errorbar(xpos, coeffs, yerr=ci95, fmt="o", color=color,
linewidth=1.5, markersize=8.0, capsize=5.0,
markeredgecolor=color, markerfacecolor=color)
errorbar_handles.append(h)
if plotSignificance and np.any(sig > 0):
Expand Down
Loading