Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 3 additions & 27 deletions examples/paper/example05_decoding_ppaf_pphf.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,8 @@ def _plot_part_c(result):

fig5.tight_layout()

# ── Figure 6: 4×3 decode — matches MATLAB HybridFilterExample.m ──
# MATLAB overlays *individual* dashed traces first, then thick means on top.
# ── Figure 6: 4×3 decode — matches MATLAB example05_decoding_ppaf_pphf.m ──
# MATLAB example05 plots ONLY means (thick solid lines), no individual traces.
fig6 = plt.figure(figsize=(14, 9))

S_estAll = result["S_estAll"] # (n_sims, T)
Expand All @@ -745,7 +745,6 @@ def _plot_part_c(result):
MU_estNTAll = result["MU_estNTAll"]
X_estAll = result["X_estAll"] # (6, T, n_sims)
X_estNTAll = result["X_estNTAll"]
n_sims = result["n_sims"]

# Pre-create all subplots
ax_s = fig6.add_subplot(4, 3, (1, 4))
Expand All @@ -756,30 +755,7 @@ def _plot_part_c(result):
ax_xv = fig6.add_subplot(4, 3, 11)
ax_yv = fig6.add_subplot(4, 3, 12)

# --- Individual traces (thin dashed, matching MATLAB loop) ---
for n in range(n_sims):
# State traces
ax_s.plot(time, S_estAll[n, :], "b-.", linewidth=0.5)
ax_s.plot(time, S_estNTAll[n, :], "g-.", linewidth=0.5)
# Movement probability
ax_p.plot(time, MU_estAll[n, 1, :], "b-.", linewidth=0.5)
ax_p.plot(time, MU_estNTAll[n, 1, :], "g-.", linewidth=0.5)
# 2D path
ax_2d.plot(100 * X_estAll[0, :, n], 100 * X_estAll[1, :, n], "b-.",
linewidth=0.5)
ax_2d.plot(100 * X_estNTAll[0, :, n], 100 * X_estNTAll[1, :, n], "g-.",
linewidth=0.5)
# Position/velocity traces
ax_xp.plot(time, 100 * X_estAll[0, :, n], "b-.", linewidth=0.5)
ax_xp.plot(time, 100 * X_estNTAll[0, :, n], "g-.", linewidth=0.5)
ax_yp.plot(time, 100 * X_estAll[1, :, n], "b-.", linewidth=0.5)
ax_yp.plot(time, 100 * X_estNTAll[1, :, n], "g-.", linewidth=0.5)
ax_xv.plot(time, 100 * X_estAll[2, :, n], "b-.", linewidth=0.5)
ax_xv.plot(time, 100 * X_estNTAll[2, :, n], "g-.", linewidth=0.5)
ax_yv.plot(time, 100 * X_estAll[3, :, n], "b-.", linewidth=0.5)
ax_yv.plot(time, 100 * X_estNTAll[3, :, n], "g-.", linewidth=0.5)

# --- Mean traces (thick solid, on top) ---
# --- Mean traces (thick solid — MATLAB plots only means, no individual) ---
mS_est = np.mean(S_estAll, axis=0)
mS_estNT = np.mean(S_estNTAll, axis=0)
mMU_est = np.mean(MU_estAll, axis=0)
Expand Down
Loading