diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 2602e04..9b6cfe9 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -735,8 +735,8 @@ def _plot_part_c(result): fig5.tight_layout() - # ── Figure 6: 4×3 decode — matches MATLAB HybridFilterExample.m ── - # MATLAB overlays *individual* dashed traces first, then thick means on top. + # ── Figure 6: 4×3 decode — matches MATLAB example05_decoding_ppaf_pphf.m ── + # MATLAB example05 plots ONLY means (thick solid lines), no individual traces. fig6 = plt.figure(figsize=(14, 9)) S_estAll = result["S_estAll"] # (n_sims, T) @@ -745,7 +745,6 @@ def _plot_part_c(result): MU_estNTAll = result["MU_estNTAll"] X_estAll = result["X_estAll"] # (6, T, n_sims) X_estNTAll = result["X_estNTAll"] - n_sims = result["n_sims"] # Pre-create all subplots ax_s = fig6.add_subplot(4, 3, (1, 4)) @@ -756,30 +755,7 @@ def _plot_part_c(result): ax_xv = fig6.add_subplot(4, 3, 11) ax_yv = fig6.add_subplot(4, 3, 12) - # --- Individual traces (thin dashed, matching MATLAB loop) --- - for n in range(n_sims): - # State traces - ax_s.plot(time, S_estAll[n, :], "b-.", linewidth=0.5) - ax_s.plot(time, S_estNTAll[n, :], "g-.", linewidth=0.5) - # Movement probability - ax_p.plot(time, MU_estAll[n, 1, :], "b-.", linewidth=0.5) - ax_p.plot(time, MU_estNTAll[n, 1, :], "g-.", linewidth=0.5) - # 2D path - ax_2d.plot(100 * X_estAll[0, :, n], 100 * X_estAll[1, :, n], "b-.", - linewidth=0.5) - ax_2d.plot(100 * X_estNTAll[0, :, n], 100 * X_estNTAll[1, :, n], "g-.", - linewidth=0.5) - # Position/velocity traces - ax_xp.plot(time, 100 * X_estAll[0, :, n], "b-.", linewidth=0.5) - ax_xp.plot(time, 100 * X_estNTAll[0, :, n], "g-.", linewidth=0.5) - ax_yp.plot(time, 100 * X_estAll[1, :, n], "b-.", linewidth=0.5) - ax_yp.plot(time, 100 * X_estNTAll[1, :, n], "g-.", linewidth=0.5) - ax_xv.plot(time, 100 * X_estAll[2, :, n], "b-.", linewidth=0.5) - ax_xv.plot(time, 100 * X_estNTAll[2, :, n], "g-.", linewidth=0.5) - ax_yv.plot(time, 100 * X_estAll[3, :, n], "b-.", linewidth=0.5) - ax_yv.plot(time, 100 * X_estNTAll[3, :, n], "g-.", linewidth=0.5) - - # --- Mean traces (thick solid, on top) --- + # --- Mean traces (thick solid — MATLAB plots only means, no individual) --- mS_est = np.mean(S_estAll, axis=0) mS_estNT = np.mean(S_estNTAll, axis=0) mMU_est = np.mean(MU_estAll, axis=0)