Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 84 additions & 33 deletions examples/paper/example05_decoding_ppaf_pphf.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,17 +518,19 @@ def _plot_part_a(result):
fig1.tight_layout()

# ── Figure 2: Decoded vs true (single axes, MATLAB style) ──
# MATLAB uses thin CI *lines* (not shaded fill), thick decoded/actual
fig2, ax2 = plt.subplots(1, 1, figsize=(10, 4))
ax2.fill_between(time, result["ci_low"], result["ci_high"],
color="0.75", alpha=0.4, label="95% CI")
ax2.plot(time, result["x_decoded"], "k-", linewidth=2.0, label="Decoded")
ax2.plot(time, stimSignal, "b-", linewidth=2.0, label="Actual")
ax2.plot(time, result["ci_low"], "k-", linewidth=0.5)
ax2.plot(time, result["ci_high"], "k-", linewidth=0.5)
hEst, = ax2.plot(time, result["x_decoded"], "k-", linewidth=4.0)
hAct, = ax2.plot(time, stimSignal, "b-", linewidth=4.0)
ax2.legend([hEst, hAct], ["Decoded", "Actual"], loc="upper right")
ax2.set_xlabel("time [s]")
ax2.set_ylabel("Stimulus")
ax2.set_title(
f"Decoded Stimulus $\\pm$ 95% CIs with {n_cells} cells",
fontweight="bold", fontsize=14, fontfamily="Arial",
f"Decoded Stimulus +/- 95% CIs with {n_cells} cells",
fontweight="bold", fontsize=18, fontfamily="Arial",
)
ax2.legend(loc="upper right")
fig2.tight_layout()

return fig1, fig2
Expand All @@ -546,8 +548,10 @@ def _plot_part_b(result):
# (4,2,[1,3]): 2D reach path with start/end markers
ax_path = fig3.add_subplot(4, 2, (1, 3))
ax_path.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=2)
ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14)
ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14)
ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14,
markerfacecolor="none", markeredgewidth=1.5)
ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14,
markerfacecolor="none", markeredgewidth=1.5)
ax_path.set_xlabel("X Position [cm]")
ax_path.set_ylabel("Y Position [cm]")
ax_path.set_title("Reach Path", fontweight="bold", fontsize=14,
Expand Down Expand Up @@ -602,16 +606,21 @@ def _plot_part_b(result):
all_goal = result["all_x_u_goal"]
all_free = result["all_x_u_free"]

# (4,2,1:4): 2D reach paths overlaid
# (4,2,1:4): 2D reach paths overlaid — MATLAB overlays all sims + Start/Finish
ax_paths = fig4.add_subplot(4, 2, (1, 4))
ax_paths.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=3)
ax_paths.set_title("Estimated vs. Actual Reach Paths", fontweight="bold",
fontsize=12, fontfamily="Arial")
for k in range(len(all_goal)):
ax_paths.plot(100 * all_goal[k][0, :], 100 * all_goal[k][1, :], "b",
linewidth=0.5, alpha=0.7)
ax_paths.plot(100 * all_free[k][0, :], 100 * all_free[k][1, :], "g",
linewidth=0.5, alpha=0.7)
hStart, = ax_paths.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo",
markersize=14, markerfacecolor="none", markeredgewidth=1.5)
hFinish, = ax_paths.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro",
markersize=14, markerfacecolor="none", markeredgewidth=1.5)
ax_paths.legend([hStart, hFinish], ["Start", "Finish"], loc="upper right")
ax_paths.set_title("Estimated vs. Actual Reach Paths", fontweight="bold",
fontsize=12, fontfamily="Arial")
ax_paths.set_xlabel("x [cm]")
ax_paths.set_ylabel("y [cm]")

Expand Down Expand Up @@ -669,11 +678,14 @@ def _plot_part_c(result):
# ── Figure 5: 4×2 setup — matches MATLAB subplot(4,2,...) ──
fig5 = plt.figure(figsize=(14, 9))

# (4,2,[1,3]): Reach path with markers
# (4,2,[1,3]): Reach path with markers — MATLAB uses unfilled circles + legend
ax_path = fig5.add_subplot(4, 2, (1, 3))
ax_path.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=2)
ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16)
ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16)
hStart, = ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16,
markerfacecolor="none", markeredgewidth=1.5)
hFinish, = ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16,
markerfacecolor="none", markeredgewidth=1.5)
ax_path.legend([hStart, hFinish], ["Start", "Finish"], loc="upper right")
ax_path.set_xlabel("X [cm]")
ax_path.set_ylabel("Y [cm]")
ax_path.set_title("Reach Path", fontweight="bold", fontsize=14,
Expand Down Expand Up @@ -723,30 +735,71 @@ def _plot_part_c(result):

fig5.tight_layout()

# ── Figure 6: 4×3 averaged decode — matches MATLAB subplot(4,3,...) ──
# ── Figure 6: 4×3 decode — matches MATLAB HybridFilterExample.m ──
# MATLAB overlays *individual* dashed traces first, then thick means on top.
fig6 = plt.figure(figsize=(14, 9))

mS_est = np.mean(result["S_estAll"], axis=0)
mS_estNT = np.mean(result["S_estNTAll"], axis=0)
mMU_est = np.mean(result["MU_estAll"], axis=0) # (n_modes, T)
mMU_estNT = np.mean(result["MU_estNTAll"], axis=0)
mX_est = 100 * np.mean(result["X_estAll"], axis=2) # (6, T) in cm
mX_estNT = 100 * np.mean(result["X_estNTAll"], axis=2)
S_estAll = result["S_estAll"] # (n_sims, T)
S_estNTAll = result["S_estNTAll"]
MU_estAll = result["MU_estAll"] # (n_sims, n_modes, T)
MU_estNTAll = result["MU_estNTAll"]
X_estAll = result["X_estAll"] # (6, T, n_sims)
X_estNTAll = result["X_estNTAll"]
n_sims = result["n_sims"]

# (4,3,[1,4]): Estimated vs actual state
# Pre-create all subplots
ax_s = fig6.add_subplot(4, 3, (1, 4))
ax_p = fig6.add_subplot(4, 3, (7, 10))
ax_2d = fig6.add_subplot(4, 3, (2, 6))
ax_xp = fig6.add_subplot(4, 3, 8)
ax_yp = fig6.add_subplot(4, 3, 9)
ax_xv = fig6.add_subplot(4, 3, 11)
ax_yv = fig6.add_subplot(4, 3, 12)

# --- Individual traces (thin dashed, matching MATLAB loop) ---
for n in range(n_sims):
# State traces
ax_s.plot(time, S_estAll[n, :], "b-.", linewidth=0.5)
ax_s.plot(time, S_estNTAll[n, :], "g-.", linewidth=0.5)
# Movement probability
ax_p.plot(time, MU_estAll[n, 1, :], "b-.", linewidth=0.5)
ax_p.plot(time, MU_estNTAll[n, 1, :], "g-.", linewidth=0.5)
# 2D path
ax_2d.plot(100 * X_estAll[0, :, n], 100 * X_estAll[1, :, n], "b-.",
linewidth=0.5)
ax_2d.plot(100 * X_estNTAll[0, :, n], 100 * X_estNTAll[1, :, n], "g-.",
linewidth=0.5)
# Position/velocity traces
ax_xp.plot(time, 100 * X_estAll[0, :, n], "b-.", linewidth=0.5)
ax_xp.plot(time, 100 * X_estNTAll[0, :, n], "g-.", linewidth=0.5)
ax_yp.plot(time, 100 * X_estAll[1, :, n], "b-.", linewidth=0.5)
ax_yp.plot(time, 100 * X_estNTAll[1, :, n], "g-.", linewidth=0.5)
ax_xv.plot(time, 100 * X_estAll[2, :, n], "b-.", linewidth=0.5)
ax_xv.plot(time, 100 * X_estNTAll[2, :, n], "g-.", linewidth=0.5)
ax_yv.plot(time, 100 * X_estAll[3, :, n], "b-.", linewidth=0.5)
ax_yv.plot(time, 100 * X_estNTAll[3, :, n], "g-.", linewidth=0.5)

# --- Mean traces (thick solid, on top) ---
mS_est = np.mean(S_estAll, axis=0)
mS_estNT = np.mean(S_estNTAll, axis=0)
mMU_est = np.mean(MU_estAll, axis=0)
mMU_estNT = np.mean(MU_estNTAll, axis=0)
mX_est = 100 * np.mean(X_estAll, axis=2)
mX_estNT = 100 * np.mean(X_estNTAll, axis=2)

# (4,3,[1,4]): State
ax_s.plot(time, mstate, "k", linewidth=3)
ax_s.plot(time, mS_est, "b", linewidth=3)
ax_s.plot(time, mS_estNT, "g", linewidth=3)
ax_s.set_xticklabels([])
ax_s.set_ylim(0.5, 2.5)
ax_s.set_yticks([1, 2.1])
ax_s.set_yticklabels(["N", "M"])
ax_s.set_ylabel("state")
ax_s.set_title("Estimated vs. Actual State", fontweight="bold",
fontsize=12, fontfamily="Arial")

# (4,3,[7,10]): P(s(t)=M|data)
ax_p = fig6.add_subplot(4, 3, (7, 10))
# (4,3,[7,10]): P(s=M|data)
ax_p.plot(time, mMU_est[1, :], "b", linewidth=3)
ax_p.plot(time, mMU_estNT[1, :], "g", linewidth=3)
ax_p.set_xlim(time[0], time[-1])
Expand All @@ -756,20 +809,21 @@ def _plot_part_c(result):
ax_p.set_title("Probability of State", fontweight="bold", fontsize=12,
fontfamily="Arial")

# (4,3,[2,3,5,6]): 2D reach path
ax_2d = fig6.add_subplot(4, 3, (2, 6))
# (4,3,[2,3,5,6]): 2D reach — actual + mean + Start/Finish
ax_2d.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=1)
ax_2d.plot(mX_est[0, :], mX_est[1, :], "b", linewidth=3)
ax_2d.plot(mX_estNT[0, :], mX_estNT[1, :], "g", linewidth=3)
ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14)
ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14)
hStart, = ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14,
markerfacecolor="none", markeredgewidth=1.5)
hFinish, = ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14,
markerfacecolor="none", markeredgewidth=1.5)
ax_2d.legend([hStart, hFinish], ["Start", "Finish"], loc="upper right")
ax_2d.set_xlabel("x [cm]")
ax_2d.set_ylabel("y [cm]")
ax_2d.set_title("Estimated vs. Actual Reach Path", fontweight="bold",
fontsize=12, fontfamily="Arial")

# (4,3,8): X position
ax_xp = fig6.add_subplot(4, 3, 8)
ax_xp.plot(time, 100 * X[0, :], "k", linewidth=3)
ax_xp.plot(time, mX_est[0, :], "b", linewidth=3)
ax_xp.plot(time, mX_estNT[0, :], "g", linewidth=3)
Expand All @@ -779,7 +833,6 @@ def _plot_part_c(result):
fontfamily="Arial")

# (4,3,9): Y position with legend
ax_yp = fig6.add_subplot(4, 3, 9)
h1, = ax_yp.plot(time, 100 * X[1, :], "k", linewidth=3)
h2, = ax_yp.plot(time, mX_est[1, :], "b", linewidth=3)
h3, = ax_yp.plot(time, mX_estNT[1, :], "g", linewidth=3)
Expand All @@ -791,7 +844,6 @@ def _plot_part_c(result):
fontfamily="Arial")

# (4,3,11): X velocity
ax_xv = fig6.add_subplot(4, 3, 11)
ax_xv.plot(time, 100 * X[2, :], "k", linewidth=3)
ax_xv.plot(time, mX_est[2, :], "b", linewidth=3)
ax_xv.plot(time, mX_estNT[2, :], "g", linewidth=3)
Expand All @@ -801,7 +853,6 @@ def _plot_part_c(result):
fontfamily="Arial")

# (4,3,12): Y velocity
ax_yv = fig6.add_subplot(4, 3, 12)
ax_yv.plot(time, 100 * X[3, :], "k", linewidth=3)
ax_yv.plot(time, mX_est[3, :], "b", linewidth=3)
ax_yv.plot(time, mX_estNT[3, :], "g", linewidth=3)
Expand Down
Loading