From 9181ec1e9eaae71de97b3e3e7c4b72aa7f44af09 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Thu, 19 Mar 2026 23:03:04 -0500 Subject: [PATCH] fix: match example05 figures to MATLAB reference exactly Fig 2 (univariate decoding): - Replace fill_between CI shading with thin CI *lines* matching MATLAB - Increase decoded/actual linewidth to 4 (was 2) matching MATLAB - Add ylabel "Stimulus" and increase title fontsize to 18 Fig 3 (reach setup): - Use unfilled circle markers for Start/Finish (MATLAB default) Fig 4 (PPAF goal vs free): - Add Start/Finish unfilled circle markers with legend on 2D path Fig 5 (hybrid setup): - Use unfilled circle markers for Start/Finish - Add Start/Finish legend on reach path subplot Fig 6 (hybrid decoding summary): - Overlay all 20 individual simulation traces as thin dashed lines (matching MATLAB's in-loop plotting of S_est, X_est, MU_est) - Plot thick solid mean traces on top (matching MATLAB post-loop code) - Add Start/Finish unfilled markers and legend on 2D reach path - Set state axis ylim to [0.5, 2.5] matching MATLAB Co-Authored-By: Claude Opus 4.6 (1M context) --- .../paper/example05_decoding_ppaf_pphf.py | 117 +++++++++++++----- 1 file changed, 84 insertions(+), 33 deletions(-) diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 8430129..2602e04 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -518,17 +518,19 @@ def _plot_part_a(result): fig1.tight_layout() # ── Figure 2: Decoded vs true (single axes, MATLAB style) ── + # MATLAB uses thin CI *lines* (not shaded fill), thick decoded/actual fig2, ax2 = plt.subplots(1, 1, figsize=(10, 4)) - ax2.fill_between(time, result["ci_low"], result["ci_high"], - color="0.75", alpha=0.4, label="95% CI") - ax2.plot(time, result["x_decoded"], "k-", linewidth=2.0, label="Decoded") - ax2.plot(time, stimSignal, "b-", linewidth=2.0, label="Actual") + ax2.plot(time, result["ci_low"], "k-", linewidth=0.5) + ax2.plot(time, result["ci_high"], "k-", linewidth=0.5) + hEst, = ax2.plot(time, result["x_decoded"], "k-", linewidth=4.0) + hAct, = ax2.plot(time, stimSignal, "b-", linewidth=4.0) + ax2.legend([hEst, hAct], ["Decoded", "Actual"], loc="upper right") ax2.set_xlabel("time [s]") + ax2.set_ylabel("Stimulus") ax2.set_title( - f"Decoded Stimulus $\\pm$ 95% CIs with {n_cells} cells", - fontweight="bold", fontsize=14, fontfamily="Arial", + f"Decoded Stimulus +/- 95% CIs with {n_cells} cells", + fontweight="bold", fontsize=18, fontfamily="Arial", ) - ax2.legend(loc="upper right") fig2.tight_layout() return fig1, fig2 @@ -546,8 +548,10 @@ def _plot_part_b(result): # (4,2,[1,3]): 2D reach path with start/end markers ax_path = fig3.add_subplot(4, 2, (1, 3)) ax_path.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=2) - ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14) - ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14) + ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14, + markerfacecolor="none", markeredgewidth=1.5) + ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14, + markerfacecolor="none", markeredgewidth=1.5) ax_path.set_xlabel("X Position [cm]") ax_path.set_ylabel("Y Position [cm]") ax_path.set_title("Reach Path", fontweight="bold", fontsize=14, @@ -602,16 +606,21 @@ def _plot_part_b(result): all_goal = result["all_x_u_goal"] all_free = result["all_x_u_free"] - # (4,2,1:4): 2D reach paths overlaid + # (4,2,1:4): 2D reach paths overlaid — MATLAB overlays all sims + Start/Finish ax_paths = fig4.add_subplot(4, 2, (1, 4)) ax_paths.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=3) - ax_paths.set_title("Estimated vs. Actual Reach Paths", fontweight="bold", - fontsize=12, fontfamily="Arial") for k in range(len(all_goal)): ax_paths.plot(100 * all_goal[k][0, :], 100 * all_goal[k][1, :], "b", linewidth=0.5, alpha=0.7) ax_paths.plot(100 * all_free[k][0, :], 100 * all_free[k][1, :], "g", linewidth=0.5, alpha=0.7) + hStart, = ax_paths.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", + markersize=14, markerfacecolor="none", markeredgewidth=1.5) + hFinish, = ax_paths.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", + markersize=14, markerfacecolor="none", markeredgewidth=1.5) + ax_paths.legend([hStart, hFinish], ["Start", "Finish"], loc="upper right") + ax_paths.set_title("Estimated vs. Actual Reach Paths", fontweight="bold", + fontsize=12, fontfamily="Arial") ax_paths.set_xlabel("x [cm]") ax_paths.set_ylabel("y [cm]") @@ -669,11 +678,14 @@ def _plot_part_c(result): # ── Figure 5: 4×2 setup — matches MATLAB subplot(4,2,...) ── fig5 = plt.figure(figsize=(14, 9)) - # (4,2,[1,3]): Reach path with markers + # (4,2,[1,3]): Reach path with markers — MATLAB uses unfilled circles + legend ax_path = fig5.add_subplot(4, 2, (1, 3)) ax_path.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=2) - ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16) - ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16) + hStart, = ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16, + markerfacecolor="none", markeredgewidth=1.5) + hFinish, = ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16, + markerfacecolor="none", markeredgewidth=1.5) + ax_path.legend([hStart, hFinish], ["Start", "Finish"], loc="upper right") ax_path.set_xlabel("X [cm]") ax_path.set_ylabel("Y [cm]") ax_path.set_title("Reach Path", fontweight="bold", fontsize=14, @@ -723,30 +735,71 @@ def _plot_part_c(result): fig5.tight_layout() - # ── Figure 6: 4×3 averaged decode — matches MATLAB subplot(4,3,...) ── + # ── Figure 6: 4×3 decode — matches MATLAB HybridFilterExample.m ── + # MATLAB overlays *individual* dashed traces first, then thick means on top. fig6 = plt.figure(figsize=(14, 9)) - mS_est = np.mean(result["S_estAll"], axis=0) - mS_estNT = np.mean(result["S_estNTAll"], axis=0) - mMU_est = np.mean(result["MU_estAll"], axis=0) # (n_modes, T) - mMU_estNT = np.mean(result["MU_estNTAll"], axis=0) - mX_est = 100 * np.mean(result["X_estAll"], axis=2) # (6, T) in cm - mX_estNT = 100 * np.mean(result["X_estNTAll"], axis=2) + S_estAll = result["S_estAll"] # (n_sims, T) + S_estNTAll = result["S_estNTAll"] + MU_estAll = result["MU_estAll"] # (n_sims, n_modes, T) + MU_estNTAll = result["MU_estNTAll"] + X_estAll = result["X_estAll"] # (6, T, n_sims) + X_estNTAll = result["X_estNTAll"] + n_sims = result["n_sims"] - # (4,3,[1,4]): Estimated vs actual state + # Pre-create all subplots ax_s = fig6.add_subplot(4, 3, (1, 4)) + ax_p = fig6.add_subplot(4, 3, (7, 10)) + ax_2d = fig6.add_subplot(4, 3, (2, 6)) + ax_xp = fig6.add_subplot(4, 3, 8) + ax_yp = fig6.add_subplot(4, 3, 9) + ax_xv = fig6.add_subplot(4, 3, 11) + ax_yv = fig6.add_subplot(4, 3, 12) + + # --- Individual traces (thin dashed, matching MATLAB loop) --- + for n in range(n_sims): + # State traces + ax_s.plot(time, S_estAll[n, :], "b-.", linewidth=0.5) + ax_s.plot(time, S_estNTAll[n, :], "g-.", linewidth=0.5) + # Movement probability + ax_p.plot(time, MU_estAll[n, 1, :], "b-.", linewidth=0.5) + ax_p.plot(time, MU_estNTAll[n, 1, :], "g-.", linewidth=0.5) + # 2D path + ax_2d.plot(100 * X_estAll[0, :, n], 100 * X_estAll[1, :, n], "b-.", + linewidth=0.5) + ax_2d.plot(100 * X_estNTAll[0, :, n], 100 * X_estNTAll[1, :, n], "g-.", + linewidth=0.5) + # Position/velocity traces + ax_xp.plot(time, 100 * X_estAll[0, :, n], "b-.", linewidth=0.5) + ax_xp.plot(time, 100 * X_estNTAll[0, :, n], "g-.", linewidth=0.5) + ax_yp.plot(time, 100 * X_estAll[1, :, n], "b-.", linewidth=0.5) + ax_yp.plot(time, 100 * X_estNTAll[1, :, n], "g-.", linewidth=0.5) + ax_xv.plot(time, 100 * X_estAll[2, :, n], "b-.", linewidth=0.5) + ax_xv.plot(time, 100 * X_estNTAll[2, :, n], "g-.", linewidth=0.5) + ax_yv.plot(time, 100 * X_estAll[3, :, n], "b-.", linewidth=0.5) + ax_yv.plot(time, 100 * X_estNTAll[3, :, n], "g-.", linewidth=0.5) + + # --- Mean traces (thick solid, on top) --- + mS_est = np.mean(S_estAll, axis=0) + mS_estNT = np.mean(S_estNTAll, axis=0) + mMU_est = np.mean(MU_estAll, axis=0) + mMU_estNT = np.mean(MU_estNTAll, axis=0) + mX_est = 100 * np.mean(X_estAll, axis=2) + mX_estNT = 100 * np.mean(X_estNTAll, axis=2) + + # (4,3,[1,4]): State ax_s.plot(time, mstate, "k", linewidth=3) ax_s.plot(time, mS_est, "b", linewidth=3) ax_s.plot(time, mS_estNT, "g", linewidth=3) ax_s.set_xticklabels([]) + ax_s.set_ylim(0.5, 2.5) ax_s.set_yticks([1, 2.1]) ax_s.set_yticklabels(["N", "M"]) ax_s.set_ylabel("state") ax_s.set_title("Estimated vs. Actual State", fontweight="bold", fontsize=12, fontfamily="Arial") - # (4,3,[7,10]): P(s(t)=M|data) - ax_p = fig6.add_subplot(4, 3, (7, 10)) + # (4,3,[7,10]): P(s=M|data) ax_p.plot(time, mMU_est[1, :], "b", linewidth=3) ax_p.plot(time, mMU_estNT[1, :], "g", linewidth=3) ax_p.set_xlim(time[0], time[-1]) @@ -756,20 +809,21 @@ def _plot_part_c(result): ax_p.set_title("Probability of State", fontweight="bold", fontsize=12, fontfamily="Arial") - # (4,3,[2,3,5,6]): 2D reach path - ax_2d = fig6.add_subplot(4, 3, (2, 6)) + # (4,3,[2,3,5,6]): 2D reach — actual + mean + Start/Finish ax_2d.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=1) ax_2d.plot(mX_est[0, :], mX_est[1, :], "b", linewidth=3) ax_2d.plot(mX_estNT[0, :], mX_estNT[1, :], "g", linewidth=3) - ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14) - ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14) + hStart, = ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14, + markerfacecolor="none", markeredgewidth=1.5) + hFinish, = ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14, + markerfacecolor="none", markeredgewidth=1.5) + ax_2d.legend([hStart, hFinish], ["Start", "Finish"], loc="upper right") ax_2d.set_xlabel("x [cm]") ax_2d.set_ylabel("y [cm]") ax_2d.set_title("Estimated vs. Actual Reach Path", fontweight="bold", fontsize=12, fontfamily="Arial") # (4,3,8): X position - ax_xp = fig6.add_subplot(4, 3, 8) ax_xp.plot(time, 100 * X[0, :], "k", linewidth=3) ax_xp.plot(time, mX_est[0, :], "b", linewidth=3) ax_xp.plot(time, mX_estNT[0, :], "g", linewidth=3) @@ -779,7 +833,6 @@ def _plot_part_c(result): fontfamily="Arial") # (4,3,9): Y position with legend - ax_yp = fig6.add_subplot(4, 3, 9) h1, = ax_yp.plot(time, 100 * X[1, :], "k", linewidth=3) h2, = ax_yp.plot(time, mX_est[1, :], "b", linewidth=3) h3, = ax_yp.plot(time, mX_estNT[1, :], "g", linewidth=3) @@ -791,7 +844,6 @@ def _plot_part_c(result): fontfamily="Arial") # (4,3,11): X velocity - ax_xv = fig6.add_subplot(4, 3, 11) ax_xv.plot(time, 100 * X[2, :], "k", linewidth=3) ax_xv.plot(time, mX_est[2, :], "b", linewidth=3) ax_xv.plot(time, mX_estNT[2, :], "g", linewidth=3) @@ -801,7 +853,6 @@ def _plot_part_c(result): fontfamily="Arial") # (4,3,12): Y velocity - ax_yv = fig6.add_subplot(4, 3, 12) ax_yv.plot(time, 100 * X[3, :], "k", linewidth=3) ax_yv.plot(time, mX_est[3, :], "b", linewidth=3) ax_yv.plot(time, mX_estNT[3, :], "g", linewidth=3)