diff --git a/data_cache/nstat_data/paperHybridFilterExample.h5 b/data_cache/nstat_data/paperHybridFilterExample.h5 new file mode 100644 index 00000000..06b78aa6 Binary files /dev/null and b/data_cache/nstat_data/paperHybridFilterExample.h5 differ diff --git a/data_cache/nstat_data/paperHybridFilterExample.mat b/data_cache/nstat_data/paperHybridFilterExample.mat new file mode 100644 index 00000000..355e2232 Binary files /dev/null and b/data_cache/nstat_data/paperHybridFilterExample.mat differ diff --git a/docs/figures/example05/fig01_univariate_setup.png b/docs/figures/example05/fig01_univariate_setup.png index ace432bb..a35b5f9e 100644 Binary files a/docs/figures/example05/fig01_univariate_setup.png and b/docs/figures/example05/fig01_univariate_setup.png differ diff --git a/docs/figures/example05/fig02_univariate_decoding.png b/docs/figures/example05/fig02_univariate_decoding.png index bfb470e0..acca9506 100644 Binary files a/docs/figures/example05/fig02_univariate_decoding.png and b/docs/figures/example05/fig02_univariate_decoding.png differ diff --git a/docs/figures/example05/fig03_reach_and_population_setup.png b/docs/figures/example05/fig03_reach_and_population_setup.png index ab09ae8f..b4a25e82 100644 Binary files a/docs/figures/example05/fig03_reach_and_population_setup.png and b/docs/figures/example05/fig03_reach_and_population_setup.png differ diff --git a/docs/figures/example05/fig04_ppaf_goal_vs_free.png b/docs/figures/example05/fig04_ppaf_goal_vs_free.png index 0ce2fdec..457cbdf8 100644 Binary files a/docs/figures/example05/fig04_ppaf_goal_vs_free.png and b/docs/figures/example05/fig04_ppaf_goal_vs_free.png differ diff --git a/docs/figures/example05/fig05_hybrid_setup.png b/docs/figures/example05/fig05_hybrid_setup.png index 3839a193..fffa8a2f 100644 Binary files a/docs/figures/example05/fig05_hybrid_setup.png and b/docs/figures/example05/fig05_hybrid_setup.png differ diff --git a/docs/figures/example05/fig06_hybrid_decoding_summary.png b/docs/figures/example05/fig06_hybrid_decoding_summary.png index dbbf4033..710692b1 100644 Binary files a/docs/figures/example05/fig06_hybrid_decoding_summary.png and b/docs/figures/example05/fig06_hybrid_decoding_summary.png differ diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index db4f6574..112f0426 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -13,25 +13,26 @@ 3. Decode the stimulus using ``PPDecodeFilterLinear`` (PPAF). Part B — 4-State Arm Reach with PPAF (Figures 3–4): - 4. Simulate reaching trajectories (position + velocity, 4-D state). - 5. Encode with 20-cell cosine-tuning population (binomial CIF). - 6. Decode with PPAF (free) and PPAF + Goal; compare across 20 simulations. + 4. Simulate minimum-energy reaching trajectory (position + velocity, 4-D). + 5. Encode with 20-cell velocity-tuned population (binomial CIF). + 6. Decode with PPAF (free) and PPAF + Goal; overlay 20 simulations. Part C — Hybrid Filter (Figures 5–6): - 7. Simulate 40-cell population with 2 discrete reach-states (rest / reach) - that modulate baseline firing rate, plus velocity-tuned continuous state. - 8. Decode joint discrete + continuous state via ``PPHybridFilterLinear``. + 7. Load hybrid-filter trajectory fixture with 2 discrete movement states. + 8. Encode with 40-cell population (velocity-tuned, binomial CIF). + 9. Decode joint discrete + continuous state via ``PPHybridFilterLinear`` + over 20 simulations; average results. Paper mapping: - Section 2.5 (point-process adaptive filter) and Section 2.6 (hybrid filter). + Sections 2.3.6–2.3.7 (decoding); Figs. 8, 9, 14 plus hybrid extension. Expected outputs: - - Figure 1: CIF tuning curves and simulated spike raster. - - Figure 2: Decoded stimulus vs true (with 95% confidence band). - - Figure 3: Reach trajectory and population spike raster. - - Figure 4: PPAF comparison (free vs goal-informed, 20 runs box plot). - - Figure 5: Hybrid filter setup (state sequence, spike raster). - - Figure 6: Hybrid decoding results (state probabilities, decoded kinematics). + - Figure 1: Driving stimulus, CIFs, and simulated spike raster. + - Figure 2: Decoded stimulus vs. true with 95% CIs. + - Figure 3: Reach path, neural raster, position/velocity traces, CIFs. + - Figure 4: 20-simulation overlaid decoded reach paths (PPAF vs PPAF+Goal). + - Figure 5: Hybrid setup — reach path, raster, kinematics, discrete state. + - Figure 6: Hybrid 20-sim averaged decoding — state, reach path, kinematics. """ from __future__ import annotations @@ -56,28 +57,46 @@ # ────────────────────────────────────────────────────────────────────────────── -def _simulate_binomial_spikes(x, mu, beta, rng): - """Simulate spikes from binomial CIF: p_c = sigmoid(mu_c + beta_c @ x). +def _simulate_binomial_spikes_from_lambda(lambdaRate, delta, rng): + """Simulate spikes from precomputed lambda rates via thinning. Parameters ---------- - x : (ns, T) array — stimulus/state trajectory - mu : (C,) array — baseline log-odds per cell - beta : (ns, C) array — tuning coefficients + lambdaRate : (C, T) array — firing rates [spikes/sec] per cell per time + delta : float — bin width in seconds rng : numpy Generator Returns ------- dN : (C, T) array — binary spike indicators """ - ns, T = x.shape - C = mu.size - dN = np.zeros((C, T), dtype=float) - for t in range(T): - eta = mu + beta.T @ x[:, t] # (C,) - p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0))) - dN[:, t] = (rng.random(C) < p).astype(float) - return dN + prob = lambdaRate * delta # convert rate to probability per bin + prob = np.clip(prob, 0.0, 1.0) + return (rng.random(prob.shape) < prob).astype(float) + + +def _logistic_cif(dataMat, coeffs, delta): + """Compute binomial CIF rates matching MATLAB's logistic link. + + Parameters + ---------- + dataMat : (T, p) — design matrix [1, covariates] + coeffs : (C, p) — per-cell coefficients [mu, betas] + delta : float — bin width + + Returns + ------- + lambdaRate : (C, T) — firing rates in spikes/sec + """ + C = coeffs.shape[0] + T = dataMat.shape[0] + lambdaRate = np.zeros((C, T)) + for c in range(C): + eta = dataMat @ coeffs[c, :] + expEta = np.exp(np.clip(eta, -20.0, 20.0)) + p = expEta / (1.0 + expEta) + lambdaRate[c, :] = p / delta + return lambdaRate # ────────────────────────────────────────────────────────────────────────────── @@ -85,60 +104,67 @@ def _simulate_binomial_spikes(x, mu, beta, rng): # ────────────────────────────────────────────────────────────────────────────── -def _run_part_a(seed=11, n_cells=20): - """Encode/decode a 1-D sinusoidal stimulus with 20-cell binomial CIF.""" +def _run_part_a(seed=0, n_cells=20): + """Encode/decode a 1-D sinusoidal stimulus with 20-cell binomial CIF. + + Matches MATLAB: rng(0,'twister'), delta=0.001, f=2Hz, b0~N(log(10*delta),1), + b1~N(0,1), logistic CIF, PPDecodeFilterLinear with A=1, Q=std(diff(stim)). + """ rng = np.random.default_rng(seed) - delta = 0.001 # 1 ms bins - time = np.arange(0.0, 1.0 + delta, delta) + delta = 0.001 + tmax = 1.0 + time = np.arange(0.0, tmax + delta, delta) T = len(time) + f = 2.0 - # True stimulus: sinusoidal - x_true = np.sin(2.0 * np.pi * 2.0 * time) # (T,) - - # ── Encoding model: logistic CIF ── - # MATLAB: b0 = log(10*delta) + randn(C,1); b1 = randn(C,1); - b0 = np.log(10.0 * delta) + rng.standard_normal(n_cells) + # Encoding model — matches MATLAB exactly b1 = rng.standard_normal(n_cells) + b0 = np.log(10.0 * delta) + rng.standard_normal(n_cells) + stimSignal = np.sin(2.0 * np.pi * f * time) - # Simulate spikes - x_2d = x_true.reshape(1, -1) # (1, T) — scalar state - beta = b1.reshape(1, -1) # (1, C) — stimulus coefficients - dN = _simulate_binomial_spikes(x_2d, b0, beta, rng) + # Compute CIF and simulate spikes per cell + dN = np.zeros((n_cells, T)) + lambdaAll = np.zeros((n_cells, T)) + for c in range(n_cells): + eta = b1[c] * stimSignal + b0[c] + expEta = np.exp(np.clip(eta, -20.0, 20.0)) + p = expEta / (1.0 + expEta) + lambdaAll[c, :] = p / delta + dN[c, :] = (rng.random(T) < p).astype(float) - # ── State-space model ── - # x(t+1) = A * x(t) + w, w ~ N(0, Q) - # MATLAB: Q = std(stim.data(2:end) - stim.data(1:end-1)); A = 1; + # State-space model: x(t+1) = A*x(t) + w A = np.array([[1.0]]) - Q_val = float(np.std(np.diff(x_true))) + Q_val = float(np.std(np.diff(stimSignal))) Q = np.array([[Q_val]]) x0 = np.array([0.0]) Pi0 = 0.5 * np.eye(1) - # ── Decode with PPDecodeFilterLinear ── - # dN is (C, T) — the API expects (num_cells, num_steps) + # Decode + beta = b1.reshape(1, -1) # (1, C) x_p, W_p, x_u, W_u, _, _, _, _ = DecodingAlgorithms.PPDecodeFilterLinear( A, Q, dN, b0, beta, "binomial", delta, None, None, x0, Pi0 ) - # Extract decoded signal and 95% CI (±1.96σ, matching MATLAB zVal=1.96) - x_decoded = x_u[0, :] # (T,) + x_decoded = x_u[0, :] sigma = np.sqrt(np.maximum(W_u[0, 0, :], 0.0)) z_val = 1.96 ci_low = np.minimum(x_decoded - z_val * sigma, x_decoded + z_val * sigma) ci_high = np.maximum(x_decoded - z_val * sigma, x_decoded + z_val * sigma) - rmse = float(np.sqrt(np.mean((x_decoded - x_true) ** 2))) + rmse = float(np.sqrt(np.mean((x_decoded - stimSignal) ** 2))) return { "time": time, - "x_true": x_true, + "stimSignal": stimSignal, "x_decoded": x_decoded, "ci_low": ci_low, "ci_high": ci_high, "dN": dN, + "lambdaAll": lambdaAll, "b0": b0, "b1": b1, "rmse": rmse, "n_cells": n_cells, + "delta": delta, } @@ -147,108 +173,117 @@ def _run_part_a(seed=11, n_cells=20): # ────────────────────────────────────────────────────────────────────────────── -def _simulate_reach(delta, T_total, rng): - """Simulate a 2-D reaching trajectory with 4-D state [x, y, vx, vy]. +def _generate_reach_trajectory(delta, T_total, x0, xT): + """Generate minimum-energy reaching trajectory matching MATLAB. - Uses a simple sinusoidal trajectory to mimic a reaching task. + Uses the forcing function: + x(k) = A*x(k-1) + (delta/2)*(pi/T)^2*cos(pi*t/T)*[0; 0; dx; dy] """ - time = np.arange(0.0, T_total + delta, delta) - T = len(time) - - # Smooth trajectory - x_pos = 0.25 * np.sin(2.0 * np.pi * 0.15 * time) - y_pos = 0.20 * np.cos(2.0 * np.pi * 0.10 * time) - vx = np.gradient(x_pos, delta) - vy = np.gradient(y_pos, delta) - - state = np.vstack([x_pos, y_pos, vx, vy]) # (4, T) - return time, state - - -def _run_part_b(seed=19, n_cells=20, n_sims=20): - """Compare PPAF free vs goal-directed decoding for arm reach.""" - rng = np.random.default_rng(seed) - delta = 0.01 # 10 ms bins - ns = 4 # state dimension + time = np.arange(0.0, T_total + delta / 2, delta) + nT = len(time) - # State-space model (constant-velocity kinematic model) A = np.array([ [1, 0, delta, 0], [0, 1, 0, delta], [0, 0, 1, 0], [0, 0, 0, 1], ], dtype=float) - Q = 0.001 * np.eye(ns, dtype=float) - # Encoding model: cosine tuning to velocity - # mu_c ~ N(-3.0, 0.2) - # beta_c = [0, 0, w_vx, w_vy] — velocity tuned - b0 = rng.normal(-3.0, 0.2, n_cells) - beta = np.zeros((ns, n_cells), dtype=float) - for c in range(n_cells): - beta[2, c] = 3.0 * rng.normal(0.0, 1.0) # vx weight - beta[3, c] = 3.0 * rng.normal(0.0, 1.0) # vy weight - - # Run multiple simulations to compare free vs goal-directed - rmse_free = np.zeros((n_sims, ns), dtype=float) - rmse_goal = np.zeros((n_sims, ns), dtype=float) - - # Store one example run for plotting - example_run = None - - for sim_idx in range(n_sims): - sim_rng = np.random.default_rng(seed + sim_idx + 100) - time, state = _simulate_reach(delta, 10.0, sim_rng) - T = state.shape[1] - - # Simulate spikes - dN = _simulate_binomial_spikes(state, b0, beta, sim_rng) - - # Initial conditions - x0 = state[:, 0] - Pi0 = 0.1 * np.eye(ns) - - # --- Free decode (no goal) --- - x_p_free, _, x_u_free, W_u_free, _, _, _, _ = ( - DecodingAlgorithms.PPDecodeFilterLinear( - A, Q, dN, b0, beta, "binomial", delta, - None, None, x0, Pi0 - ) + xState = np.zeros((4, nT)) + xState[:, 0] = x0 + for k in range(1, nT): + forcing = (delta / 2.0) * (np.pi / T_total) ** 2 * np.cos(np.pi * time[k] / T_total) * \ + np.array([0.0, 0.0, xT[0] - x0[0], xT[1] - x0[1]]) + xState[:, k] = A @ xState[:, k - 1] + forcing + + return time, xState, A + + +def _run_part_b(seed=0, n_cells=20, n_sims=20): + """Arm reaching simulation and PPAF decoding — matches MATLAB exactly. + + Generates minimum-energy reach, encodes with velocity-tuned binomial CIF, + decodes with PPAF (free) and PPAF+Goal over 20 simulations. + """ + rng = np.random.default_rng(seed) + delta = 0.001 + T_total = 2.0 + + x0 = np.array([0.0, 0.0, 0.0, 0.0]) + xT_target = np.array([-0.35, 0.2, 0.0, 0.0]) + + time, xState, A = _generate_reach_trajectory(delta, T_total, x0, xT_target) + nT = len(time) + xT_actual = xState[:, -1] + + # Process noise: Qreach = diag(var(diff(xState))) * 100 + Qreach = np.diag(np.var(np.diff(xState, axis=1), axis=1)) * 100 + + # First simulation: generate CIFs and spike data for Figure 3 + bCoeffs = 10.0 * (rng.random((n_cells, 2)) - 0.5) + muCoeffs = np.log(10.0 * delta) + rng.standard_normal(n_cells) + coeffs = np.column_stack([muCoeffs, bCoeffs]) # (C, 3) + dataMat = np.column_stack([np.ones(nT), xState[2, :], xState[3, :]]) # (T, 3): [1, vx, vy] + + # Compute CIF for all cells + lambdaAll = _logistic_cif(dataMat, coeffs, delta) + + # Simulate spikes + dN_setup = _simulate_binomial_spikes_from_lambda(lambdaAll, delta, rng) + + # Store setup data for Figure 3 + setup_data = { + "time": time, + "xState": xState, + "lambdaAll": lambdaAll, + "dN": dN_setup, + "n_cells": n_cells, + } + + # 20 repeated simulations for Figure 4 + all_x_u_goal = [] # PPAF+Goal decoded paths + all_x_u_free = [] # PPAF free decoded paths + + for k in range(n_sims): + bCoeffs_k = 10.0 * (rng.random((n_cells, 2)) - 0.5) + muCoeffs_k = np.log(10.0 * delta) + rng.standard_normal(n_cells) + coeffs_k = np.column_stack([muCoeffs_k, bCoeffs_k]) + + lambdaK = _logistic_cif(dataMat, coeffs_k, delta) + dN_k = _simulate_binomial_spikes_from_lambda(lambdaK, delta, rng) + dN_k = np.minimum(dN_k, 1.0) # cap at 1 + + # beta for decoding: (4, C) — zeros for position, bCoeffs for velocity + beta_k = np.zeros((4, n_cells)) + beta_k[2, :] = bCoeffs_k[:, 0] + beta_k[3, :] = bCoeffs_k[:, 1] + + Pi0 = np.diag([1e-6, 1e-6, 1e-6, 1e-6]) + PiT = np.diag([1e-6, 1e-6, 1e-6, 1e-6]) + + # PPAF+Goal + _, _, x_u_goal, _, _, _, _, _ = DecodingAlgorithms.PPDecodeFilterLinear( + A, Qreach, dN_k, muCoeffs_k, beta_k, "binomial", delta, + None, None, x0, Pi0, xT_actual, PiT, 0 ) - # --- Goal-directed decode --- - yT = state[:, -1] # target = final state - PiT = 0.01 * np.eye(ns) # tight target uncertainty - x_p_goal, _, x_u_goal, W_u_goal, _, _, _, _ = ( - DecodingAlgorithms.PPDecodeFilterLinear( - A, Q, dN, b0, beta, "binomial", delta, - None, None, x0, Pi0, yT, PiT, 0 - ) + # PPAF free + _, _, x_u_free, _, _, _, _, _ = DecodingAlgorithms.PPDecodeFilterLinear( + A, Qreach, dN_k, muCoeffs_k, beta_k, "binomial", delta, + None, None, x0, Pi0 ) - # Compute RMSE per state dimension - for d in range(ns): - rmse_free[sim_idx, d] = np.sqrt(np.mean((x_u_free[d, :] - state[d, :]) ** 2)) - rmse_goal[sim_idx, d] = np.sqrt(np.mean((x_u_goal[d, :] - state[d, :]) ** 2)) - - if sim_idx == 0: - example_run = { - "time": time, - "state": state, - "dN": dN, - "x_u_free": x_u_free, - "x_u_goal": x_u_goal, - "W_u_free": W_u_free, - "W_u_goal": W_u_goal, - } + all_x_u_goal.append(x_u_goal) + all_x_u_free.append(x_u_free) return { - "rmse_free": rmse_free, - "rmse_goal": rmse_goal, - "example": example_run, + "setup": setup_data, + "time": time, + "xState": xState, + "all_x_u_goal": all_x_u_goal, + "all_x_u_free": all_x_u_free, "n_cells": n_cells, "n_sims": n_sims, - "state_labels": ["x", "y", "vx", "vy"], } @@ -257,109 +292,184 @@ def _run_part_b(seed=19, n_cells=20, n_sims=20): # ────────────────────────────────────────────────────────────────────────────── -def _run_part_c(seed=37, n_cells=40): - """PPHybridFilterLinear: joint discrete/continuous state decoding.""" +def _load_hybrid_fixture(): + """Load the hybrid filter trajectory fixture (HDF5 preferred, .mat fallback).""" + # Prefer HDF5 (needs h5py; fall back to .mat via scipy if unavailable) + h5_path = REPO_ROOT / "data_cache" / "nstat_data" / "paperHybridFilterExample.h5" + try: + import h5py # noqa: F811 + except ImportError: + h5py = None # type: ignore[assignment] + if h5py is not None and h5_path.exists(): + with h5py.File(str(h5_path), "r") as f: + d = { + "time": f["time"][:], + "delta": float(f["delta"][()]), + "X": f["X"][:], + "mstate": f["mstate"][:].astype(int), + "p_ij": f["p_ij"][:], + "A": [f[f"A/{i}"][:] for i in range(2)], + "Q": [f[f"Q/{i}"][:] for i in range(2)], + "Px0": [f[f"Px0/{i}"][:] for i in range(2)], + "ind": [f[f"ind/{i}"][:].astype(int) for i in range(2)], + } + return d + + # Fallback: .mat file (scipy required) + import scipy.io as sio + candidates = [ + REPO_ROOT / "data_cache" / "nstat_data" / "paperHybridFilterExample.mat", + REPO_ROOT.parent / "nSTAT_currentRelease_Local" / "helpfiles" / "paperHybridFilterExample.mat", + ] + for path in candidates: + if path.exists(): + d = sio.loadmat(str(path), squeeze_me=True) + # Normalize cell arrays to lists + d["A"] = [d["A"][i] for i in range(2)] + d["Q"] = [d["Q"][i] for i in range(2)] + d["Px0"] = [d["Px0"][i] for i in range(2)] + d["ind"] = [d["ind"][i].flatten().astype(int) for i in range(2)] + return d + raise FileNotFoundError( + "Cannot find paperHybridFilterExample.h5 or .mat. " + "Run the MATLAB export script or copy the data to data_cache/nstat_data/." + ) + + +def _run_part_c(seed=0, n_cells=40, n_sims=20): + """PPHybridFilterLinear: joint discrete/continuous state decoding. + + Loads trajectory fixture, encodes with velocity-tuned binomial CIF, + runs 20 simulations and averages decoded results — matching MATLAB. + """ rng = np.random.default_rng(seed) - delta = 0.01 # 10 ms bins - ns = 4 # continuous state dimension (x, y, vx, vy) - # ── Simulate trajectory ── - time = np.arange(0.0, 10.0, delta, dtype=float) - T = len(time) - x_pos = 0.3 * np.sin(2.0 * np.pi * 0.15 * time) - y_pos = 0.25 * np.cos(2.0 * np.pi * 0.10 * time) - vx = np.gradient(x_pos, delta) - vy = np.gradient(y_pos, delta) - state = np.vstack([x_pos, y_pos, vx, vy]) # (4, T) - - # Discrete state: alternating reach / hold (period ~6s) - true_mode = np.where(np.sin(2.0 * np.pi * time / 6.0) > 0.0, 1, 2).astype(int) - # Add stochastic flips - flip = rng.random(T) < 0.01 - true_mode[flip] = 3 - true_mode[flip] - - # ── State-space models (one per mode) ── - A_reach = np.array([ - [1, 0, delta, 0], - [0, 1, 0, delta], - [0, 0, 1, 0], - [0, 0, 0, 1], - ], dtype=float) - Q_reach = 0.001 * np.eye(ns) + fixture = _load_hybrid_fixture() + time = fixture["time"] + delta = float(fixture["delta"]) + X = fixture["X"] # (6, T) + mstate = fixture["mstate"].astype(int) + A_list = list(fixture["A"]) + Q_list = list(fixture["Q"]) + p_ij = fixture["p_ij"] + ind = [np.asarray(v).flatten().astype(int) - 1 for v in fixture["ind"]] # 0-based + Px0 = list(fixture["Px0"]) + + nT = len(time) + + # Clamp hold-state noise (matches MATLAB: minCovVal = 1e-12) + Q_list[0] = 1e-12 * np.eye(Q_list[0].shape[0]) + + # Compute actual process noise from trajectory + nonMovingInd = np.intersect1d(np.where(X[4, :] == 0)[0], np.where(X[5, :] == 0)[0]) + movingInd = np.setdiff1d(np.arange(nT), nonMovingInd) + if len(movingInd) > 1: + Q_list[1] = np.diag(np.var(np.diff(X[:, movingInd], axis=1), axis=1)) + Q_list[1][:4, :4] = 0.0 + if len(nonMovingInd) > 1: + varNV = np.diag(np.var(np.diff(X[:, nonMovingInd], axis=1), axis=1)) + n0 = Q_list[0].shape[0] + Q_list[0] = varNV[:n0, :n0] + + # Setup encoding: first simulation for Figure 5 + muCoeffs_0 = np.log(10.0 * delta) + rng.standard_normal(n_cells) + coeffs_0 = np.column_stack([ + muCoeffs_0, + np.zeros((n_cells, 2)), + 10.0 * (rng.random((n_cells, 2)) - 0.5), + np.zeros((n_cells, 2)), + ]) # (C, 7): [mu, 0, 0, b_vx, b_vy, 0, 0] + + dataMat = np.column_stack([np.ones(nT), X.T]) # (T, 7) + lambdaAll_0 = _logistic_cif(dataMat, coeffs_0, delta) + dN_0 = _simulate_binomial_spikes_from_lambda(lambdaAll_0, delta, rng) + + # Setup data for Figure 5 + setup_data = { + "time": time, + "X": X, + "mstate": mstate, + "dN": dN_0, + "n_cells": n_cells, + } - # Hold state: damped velocity - A_hold = np.array([ - [1, 0, delta, 0], - [0, 1, 0, delta], - [0, 0, 0.95, 0], - [0, 0, 0, 0.95], - ], dtype=float) - Q_hold = 0.0005 * np.eye(ns) - - # ── Encoding model ── - # Neurons tuned to ALL state dimensions (position + velocity). - # Mode-dependent baseline: mode 1 (reach) has different rate than mode 2 (hold). - b0_mode1 = rng.normal(-3.5, 0.2, n_cells) # reach baseline - b0_mode2 = rng.normal(-2.5, 0.2, n_cells) # hold baseline - - # Full state tuning: position + velocity - beta_mat = np.zeros((ns, n_cells), dtype=float) - beta_mat[0, :] = rng.normal(0.0, 2.0, n_cells) # x position - beta_mat[1, :] = rng.normal(0.0, 2.0, n_cells) # y position - beta_mat[2, :] = rng.normal(0.0, 3.0, n_cells) # vx - beta_mat[3, :] = rng.normal(0.0, 3.0, n_cells) # vy - - # Simulate spikes with mode-dependent baseline (binomial) - dN = np.zeros((n_cells, T), dtype=float) - for t in range(T): - b0 = b0_mode1 if true_mode[t] == 1 else b0_mode2 - eta = b0 + beta_mat.T @ state[:, t] - p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0))) - dN[:, t] = (rng.random(n_cells) < p).astype(float) - - # ── Transition matrix ── - p_ij = np.array([[0.985, 0.015], [0.02, 0.98]], dtype=float) - - # ── Decode with PPHybridFilterLinear ── - Mu0 = np.array([0.5, 0.5]) - x0 = [state[:, 0], state[:, 0]] - Pi0 = [0.5 * np.eye(ns), 0.5 * np.eye(ns)] - - S_est, X_est, W_est, MU_u, _, _, _ = DecodingAlgorithms.PPHybridFilterLinear( - [A_reach, A_hold], - [Q_reach, Q_hold], - p_ij, - Mu0, - dN, - [b0_mode1, b0_mode2], - [beta_mat, beta_mat], - "binomial", - delta, - None, # gamma - None, # windowTimes - x0, - Pi0, - ) + # 20 repeated simulations for Figure 6 + X_estAll = np.zeros((X.shape[0], nT, n_sims)) + X_estNTAll = np.zeros((X.shape[0], nT, n_sims)) + S_estAll = np.zeros((n_sims, nT)) + S_estNTAll = np.zeros((n_sims, nT)) + MU_estAll = [] + MU_estNTAll = [] + + for n in range(n_sims): + muCoeffs_n = np.log(10.0 * delta) + rng.standard_normal(n_cells) + coeffs_n = np.column_stack([ + muCoeffs_n, + np.zeros((n_cells, 2)), + 10.0 * (rng.random((n_cells, 2)) - 0.5), + np.zeros((n_cells, 2)), + ]) + + lambdaAll_n = _logistic_cif(dataMat, coeffs_n, delta) + dN_n = _simulate_binomial_spikes_from_lambda(lambdaAll_n, delta, rng) + dN_n = np.minimum(dN_n, 1.0) + + mu0 = 0.5 * np.ones(p_ij.shape[0]) + beta_full = coeffs_n[:, 1:].T # (6, C) + # Per-model beta slices: model 0 uses ind[0] states, model 1 uses ind[1] + beta_list = [beta_full[ind[0], :], beta_full[ind[1], :]] + + x0_list = [X[ind[0], 0], X[ind[1], 0]] + yT_list = [X[ind[0], -1], X[ind[1], -1]] + PiT_list = [ + 1e-9 * np.eye(len(ind[0])), + 1e-9 * np.eye(len(ind[1])), + ] + + # PPAF+Goal (with target) + S_est, X_est, _, MU_est, _, _, _ = DecodingAlgorithms.PPHybridFilterLinear( + A_list, Q_list, p_ij, mu0, dN_n, + muCoeffs_n, beta_list, "binomial", delta, + None, None, x0_list, Px0, yT_list, PiT_list + ) + + # PPAF free (no target) + S_estNT, X_estNT, _, MU_estNT, _, _, _ = DecodingAlgorithms.PPHybridFilterLinear( + A_list, Q_list, p_ij, mu0, dN_n, + muCoeffs_n, beta_list, "binomial", delta, + None, None, x0_list, Px0 + ) + + X_estAll[:, :, n] = X_est + X_estNTAll[:, :, n] = X_estNT + S_estAll[n, :] = S_est + S_estNTAll[n, :] = S_estNT + MU_estAll.append(MU_est) + MU_estNTAll.append(MU_estNT) - # Classification accuracy - state_acc = float(np.mean(S_est == true_mode)) + MU_estAll = np.array(MU_estAll) # (n_sims, n_modes, T) + MU_estNTAll = np.array(MU_estNTAll) # (n_sims, n_modes, T) - # Position RMSE - rmse_x = float(np.sqrt(np.mean((X_est[0, :] - x_pos) ** 2))) - rmse_y = float(np.sqrt(np.mean((X_est[1, :] - y_pos) ** 2))) + state_acc = float(np.mean(np.mean(S_estAll, axis=0).round() == mstate)) + rmse_x = float(np.sqrt(np.mean((np.mean(X_estAll[0, :, :], axis=1) - X[0, :]) ** 2))) + rmse_y = float(np.sqrt(np.mean((np.mean(X_estAll[1, :, :], axis=1) - X[1, :]) ** 2))) return { + "setup": setup_data, "time": time, - "state": state, - "true_mode": true_mode, - "S_est": S_est, - "X_est": X_est, - "MU_u": MU_u, - "dN": dN, + "X": X, + "mstate": mstate, + "X_estAll": X_estAll, + "X_estNTAll": X_estNTAll, + "S_estAll": S_estAll, + "S_estNTAll": S_estNTAll, + "MU_estAll": MU_estAll, + "MU_estNTAll": MU_estNTAll, "state_acc": state_acc, "rmse_x": rmse_x, "rmse_y": rmse_y, "n_cells": n_cells, + "n_sims": n_sims, } @@ -369,57 +479,55 @@ def _run_part_c(seed=37, n_cells=40): def _plot_part_a(result): - """Figure 1: CIF setup & raster. Figure 2: Decoded vs true stimulus.""" + """Figure 1: stimulus + CIF + raster (3×1). Figure 2: decoded vs true.""" time = result["time"] - x_true = result["x_true"] + stimSignal = result["stimSignal"] dN = result["dN"] - delta = time[1] - time[0] + n_cells = result["n_cells"] - # ── Figure 1: stimulus, CIF, spike raster (3 panels, matching MATLAB) ── + # ── Figure 1: 3×1 matching MATLAB subplot(3,1,...) ── fig1, axes1 = plt.subplots(3, 1, figsize=(10, 8), sharex=True) - # Top: driving stimulus - axes1[0].plot(time, x_true, "k-", linewidth=1.5) + # (3,1,1): Driving stimulus + axes1[0].plot(time, stimSignal, "k", linewidth=1.5) axes1[0].set_ylabel("Stimulus") - axes1[0].set_title("Driving Stimulus", fontweight="bold", fontsize=14) + axes1[0].set_title("Driving Stimulus", fontweight="bold", fontsize=14, + fontfamily="Arial") axes1[0].tick_params(labelbottom=False) - # Middle: conditional intensity functions (firing rates in spikes/sec) - b0 = result["b0"] - b1 = result["b1"] - n_cells = dN.shape[0] + # (3,1,2): CIFs overlaid in black + lambdaAll = result["lambdaAll"] for c in range(n_cells): - eta = b1[c] * x_true + b0[c] - exp_eta = np.exp(eta) - lam = (exp_eta / (1.0 + exp_eta)) / delta # probability → rate (Hz) - axes1[1].plot(time, lam, "k-", linewidth=1.0) + axes1[1].plot(time, lambdaAll[c, :], "k", linewidth=0.5) axes1[1].set_ylabel("Firing Rate [spikes/sec]") - axes1[1].set_title("Conditional Intensity Functions", fontweight="bold", fontsize=14) + axes1[1].set_title("Conditional Intensity Functions", fontweight="bold", + fontsize=14, fontfamily="Arial") axes1[1].tick_params(labelbottom=False) - # Bottom: spike raster + # (3,1,3): Spike raster for c in range(n_cells): - spike_times = time[dN[c, :] > 0] - axes1[2].plot(spike_times, np.full_like(spike_times, c + 1), "|", color="k", markersize=2) + spike_t = time[dN[c, :] > 0] + axes1[2].plot(spike_t, np.full_like(spike_t, c + 1), "|", color="k", + markersize=2) axes1[2].set_ylabel("Cell Number") axes1[2].set_xlabel("time [s]") axes1[2].set_ylim(0.5, n_cells + 0.5) axes1[2].set_yticks(np.arange(0, n_cells + 1, 10)) - axes1[2].set_title("Point Process Sample Paths", fontweight="bold", fontsize=14) + axes1[2].set_title("Point Process Sample Paths", fontweight="bold", + fontsize=14, fontfamily="Arial") fig1.tight_layout() - # ── Figure 2: Decoding results (MATLAB: black=decoded, blue=actual) ── + # ── Figure 2: Decoded vs true (single axes, MATLAB style) ── fig2, ax2 = plt.subplots(1, 1, figsize=(10, 4)) - ax2.fill_between( - time, result["ci_low"], result["ci_high"], - color="0.75", alpha=0.4, label="95% CI" - ) + ax2.fill_between(time, result["ci_low"], result["ci_high"], + color="0.75", alpha=0.4, label="95% CI") ax2.plot(time, result["x_decoded"], "k-", linewidth=2.0, label="Decoded") - ax2.plot(time, x_true, "b-", linewidth=2.0, label="Actual") + ax2.plot(time, stimSignal, "b-", linewidth=2.0, label="Actual") ax2.set_xlabel("time [s]") - ax2.set_ylabel("") - ax2.set_title(f"Decoded Stimulus $\\pm$ 95% CIs with {result['n_cells']} cells", - fontweight="bold", fontsize=14) + ax2.set_title( + f"Decoded Stimulus $\\pm$ 95% CIs with {n_cells} cells", + fontweight="bold", fontsize=14, fontfamily="Arial", + ) ax2.legend(loc="upper right") fig2.tight_layout() @@ -427,104 +535,281 @@ def _plot_part_a(result): def _plot_part_b(result): - """Figure 3: Reach trajectory & encoding. Figure 4: RMSE comparison.""" - ex = result["example"] - time = ex["time"] - state = ex["state"] - - # ── Figure 3: Example reach with decoded trajectories ── - fig3, axes3 = plt.subplots(2, 2, figsize=(12, 8)) - labels = result["state_labels"] - ylabels = ["x (m)", "y (m)", "vx (m/s)", "vy (m/s)"] - - for d, (ax, lab, ylab) in enumerate(zip(axes3.ravel(), labels, ylabels)): - ax.plot(time, state[d, :], "k-", linewidth=1.0, label="True") - ax.plot(time, ex["x_u_free"][d, :], "b-", linewidth=0.7, alpha=0.8, label="PPAF free") - ax.plot(time, ex["x_u_goal"][d, :], "r-", linewidth=0.7, alpha=0.8, label="PPAF+Goal") - ax.set_ylabel(ylab) - ax.set_title(f"State: {lab}") - if d >= 2: - ax.set_xlabel("Time (s)") - if d == 0: - ax.legend(loc="upper right", fontsize=8) - - fig3.suptitle("Part B: Arm Reach — PPAF Decoding (Example Run)", fontsize=12) + """Figure 3: Reach setup (4×2). Figure 4: 20-sim decoded paths (4×2).""" + time = result["time"] + xState = result["xState"] + setup = result["setup"] + + # ── Figure 3: 4×2 setup — matches MATLAB subplot(4,2,...) ── + fig3 = plt.figure(figsize=(14, 9)) + + # (4,2,[1,3]): 2D reach path with start/end markers + ax_path = fig3.add_subplot(4, 2, (1, 3)) + ax_path.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=2) + ax_path.plot(100 * xState[0, 0], 100 * xState[1, 0], "bo", markersize=14) + ax_path.plot(100 * xState[0, -1], 100 * xState[1, -1], "ro", markersize=14) + ax_path.set_xlabel("X Position [cm]") + ax_path.set_ylabel("Y Position [cm]") + ax_path.set_title("Reach Path", fontweight="bold", fontsize=14, + fontfamily="Arial") + ax_path.legend(["Path", "Start", "Finish"], loc="upper right") + + # (4,2,[2,4]): Spike raster + ax_raster = fig3.add_subplot(4, 2, (2, 4)) + dN = setup["dN"] + n_cells = setup["n_cells"] + for c in range(n_cells): + spike_t = time[dN[c, :] > 0] + ax_raster.plot(spike_t, np.full_like(spike_t, c + 1), "|", color="k", + markersize=2) + ax_raster.set_ylim(0.5, n_cells + 0.5) + ax_raster.set_xticklabels([]) + ax_raster.set_title("Neural Raster", fontweight="bold", fontsize=14, + fontfamily="Arial") + ax_raster.set_xlabel("time [s]") + ax_raster.set_ylabel("Cell Number") + + # (4,2,5): Position traces x, y + ax_pos = fig3.add_subplot(4, 2, 5) + h1, = ax_pos.plot(time, 100 * xState[0, :], "k", linewidth=2) + h2, = ax_pos.plot(time, 100 * xState[1, :], "k-.", linewidth=2) + ax_pos.legend([h1, h2], ["x", "y"], loc="upper right") + ax_pos.set_xlabel("time [s]") + ax_pos.set_ylabel("Position [cm]") + + # (4,2,7): Velocity traces vx, vy + ax_vel = fig3.add_subplot(4, 2, 7) + h1, = ax_vel.plot(time, 100 * xState[2, :], "k", linewidth=2) + h2, = ax_vel.plot(time, 100 * xState[3, :], "k-.", linewidth=2) + ax_vel.legend([h1, h2], ["v_x", "v_y"], loc="upper right") + ax_vel.set_xlabel("time [s]") + ax_vel.set_ylabel("Velocity [cm/s]") + + # (4,2,[6,8]): CIFs overlaid in black + ax_cif = fig3.add_subplot(4, 2, (6, 8)) + lambdaAll = setup["lambdaAll"] + for c in range(n_cells): + ax_cif.plot(time, lambdaAll[c, :], "k", linewidth=0.5) + ax_cif.set_title("Neural Conditional Intensity Functions", + fontweight="bold", fontsize=14, fontfamily="Arial") + ax_cif.set_xlabel("time [s]") + ax_cif.set_ylabel("Firing Rate [spikes/sec]") + fig3.tight_layout() - # ── Figure 4: RMSE box plot (free vs goal) ── - fig4, axes4 = plt.subplots(1, 4, figsize=(14, 4)) - for d, (ax, lab) in enumerate(zip(axes4, labels)): - data = [result["rmse_free"][:, d], result["rmse_goal"][:, d]] - bp = ax.boxplot(data, labels=["Free", "Goal"]) - ax.set_title(f"RMSE: {lab}") - ax.set_ylabel("RMSE") - - fig4.suptitle( - f"Part B: PPAF Free vs Goal ({result['n_sims']} simulations, {result['n_cells']} cells)", - fontsize=12, - ) + # ── Figure 4: 4×2 overlaid decoded paths — matches MATLAB ── + fig4 = plt.figure(figsize=(14, 9)) + all_goal = result["all_x_u_goal"] + all_free = result["all_x_u_free"] + + # (4,2,1:4): 2D reach paths overlaid + ax_paths = fig4.add_subplot(4, 2, (1, 4)) + ax_paths.plot(100 * xState[0, :], 100 * xState[1, :], "k", linewidth=3) + ax_paths.set_title("Estimated vs. Actual Reach Paths", fontweight="bold", + fontsize=12, fontfamily="Arial") + for k in range(len(all_goal)): + ax_paths.plot(100 * all_goal[k][0, :], 100 * all_goal[k][1, :], "b", + linewidth=0.5, alpha=0.7) + ax_paths.plot(100 * all_free[k][0, :], 100 * all_free[k][1, :], "g", + linewidth=0.5, alpha=0.7) + ax_paths.set_xlabel("x [cm]") + ax_paths.set_ylabel("y [cm]") + + # (4,2,5): x(t) + ax_x = fig4.add_subplot(4, 2, 5) + ax_x.plot(time, 100 * xState[0, :], "k", linewidth=3) + for k in range(len(all_goal)): + ax_x.plot(time, 100 * all_goal[k][0, :], "b", linewidth=0.5, alpha=0.5) + ax_x.plot(time, 100 * all_free[k][0, :], "g", linewidth=0.5, alpha=0.5) + ax_x.set_ylabel("x(t) [cm]") + ax_x.set_xticklabels([]) + + # (4,2,6): y(t) with legend + ax_y = fig4.add_subplot(4, 2, 6) + hA, = ax_y.plot(time, 100 * xState[1, :], "k", linewidth=3) + hB, = ax_y.plot(time, 100 * all_goal[0][1, :], "b", linewidth=0.5) + hC, = ax_y.plot(time, 100 * all_free[0][1, :], "g", linewidth=0.5) + for k in range(1, len(all_goal)): + ax_y.plot(time, 100 * all_goal[k][1, :], "b", linewidth=0.5, alpha=0.5) + ax_y.plot(time, 100 * all_free[k][1, :], "g", linewidth=0.5, alpha=0.5) + ax_y.legend([hA, hB, hC], ["Actual", "PPAF+Goal", "PPAF"], loc="lower right") + ax_y.set_ylabel("y(t) [cm]") + ax_y.set_xticklabels([]) + + # (4,2,7): vx(t) + ax_vx = fig4.add_subplot(4, 2, 7) + ax_vx.plot(time, 100 * xState[2, :], "k", linewidth=3) + for k in range(len(all_goal)): + ax_vx.plot(time, 100 * all_goal[k][2, :], "b", linewidth=0.5, alpha=0.5) + ax_vx.plot(time, 100 * all_free[k][2, :], "g", linewidth=0.5, alpha=0.5) + ax_vx.set_ylabel("v_x(t) [cm/s]") + ax_vx.set_xlabel("time [s]") + + # (4,2,8): vy(t) + ax_vy = fig4.add_subplot(4, 2, 8) + ax_vy.plot(time, 100 * xState[3, :], "k", linewidth=3) + for k in range(len(all_goal)): + ax_vy.plot(time, 100 * all_goal[k][3, :], "b", linewidth=0.5, alpha=0.5) + ax_vy.plot(time, 100 * all_free[k][3, :], "g", linewidth=0.5, alpha=0.5) + ax_vy.set_ylabel("v_y(t) [cm/s]") + ax_vy.set_xlabel("time [s]") + fig4.tight_layout() return fig3, fig4 def _plot_part_c(result): - """Figure 5: Hybrid setup. Figure 6: Hybrid decoding results.""" + """Figure 5: Hybrid setup (4×2). Figure 6: 20-sim decode (4×3).""" time = result["time"] + X = result["X"] + mstate = result["mstate"] + setup = result["setup"] + + # ── Figure 5: 4×2 setup — matches MATLAB subplot(4,2,...) ── + fig5 = plt.figure(figsize=(14, 9)) + + # (4,2,[1,3]): Reach path with markers + ax_path = fig5.add_subplot(4, 2, (1, 3)) + ax_path.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=2) + ax_path.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=16) + ax_path.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=16) + ax_path.set_xlabel("X [cm]") + ax_path.set_ylabel("Y [cm]") + ax_path.set_title("Reach Path", fontweight="bold", fontsize=14, + fontfamily="Arial") + + # (4,2,[2,4]): Spike raster + ax_raster = fig5.add_subplot(4, 2, (2, 4)) + dN = setup["dN"] + n_cells = setup["n_cells"] + for c in range(n_cells): + spike_t = time[dN[c, :] > 0] + ax_raster.plot(spike_t, np.full_like(spike_t, c + 1), "|", color="k", + markersize=2) + ax_raster.set_ylim(0.5, n_cells + 0.5) + ax_raster.set_xticklabels([]) + ax_raster.set_title("Neural Raster", fontweight="bold", fontsize=14, + fontfamily="Arial") + ax_raster.set_xlabel("time [s]") + ax_raster.set_ylabel("Cell Number") + + # (4,2,5): Position traces + ax_pos = fig5.add_subplot(4, 2, 5) + h1, = ax_pos.plot(time, 100 * X[0, :], "k", linewidth=2) + h2, = ax_pos.plot(time, 100 * X[1, :], "k-.", linewidth=2) + ax_pos.legend([h1, h2], ["x", "y"], loc="upper right") + ax_pos.set_xlabel("time [s]") + ax_pos.set_ylabel("Position [cm]") + + # (4,2,7): Velocity traces + ax_vel = fig5.add_subplot(4, 2, 7) + h1, = ax_vel.plot(time, 100 * X[2, :], "k", linewidth=2) + h2, = ax_vel.plot(time, 100 * X[3, :], "k-.", linewidth=2) + ax_vel.legend([h1, h2], ["v_x", "v_y"], loc="upper right") + ax_vel.set_xlabel("time [s]") + ax_vel.set_ylabel("Velocity [cm/s]") + + # (4,2,[6,8]): Discrete movement state + ax_state = fig5.add_subplot(4, 2, (6, 8)) + ax_state.plot(time, mstate, "k", linewidth=2) + ax_state.set_ylim(0, 3) + ax_state.set_yticks([1, 2]) + ax_state.set_yticklabels(["N", "M"]) + ax_state.set_xlabel("time [s]") + ax_state.set_ylabel("state") + ax_state.set_title("Discrete Movement State", fontweight="bold", + fontsize=14, fontfamily="Arial") - # ── Figure 5: Setup — state sequence + raster ── - fig5, axes5 = plt.subplots(2, 1, figsize=(12, 5), sharex=True) - - # Top: discrete state - axes5[0].plot(time, result["true_mode"], "k-", linewidth=1.0, label="True mode") - axes5[0].set_ylabel("Discrete State") - axes5[0].set_yticks([1, 2]) - axes5[0].set_yticklabels(["Reach", "Hold"]) - axes5[0].set_title("Part C: Hybrid Filter Setup") - axes5[0].legend() - - # Bottom: spike raster (first 20 cells) - dN = result["dN"] - n_show = min(20, dN.shape[0]) - for c in range(n_show): - idx = np.where(dN[c, :] > 0)[0] - spike_t = time[idx] - axes5[1].plot(spike_t, np.full_like(spike_t, c + 1), "|", color="k", markersize=2) - axes5[1].set_ylabel("Neuron") - axes5[1].set_xlabel("Time (s)") - axes5[1].set_ylim(0.5, n_show + 0.5) fig5.tight_layout() - # ── Figure 6: Decoding results ── - fig6, axes6 = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + # ── Figure 6: 4×3 averaged decode — matches MATLAB subplot(4,3,...) ── + fig6 = plt.figure(figsize=(14, 9)) + + mS_est = np.mean(result["S_estAll"], axis=0) + mS_estNT = np.mean(result["S_estNTAll"], axis=0) + mMU_est = np.mean(result["MU_estAll"], axis=0) # (n_modes, T) + mMU_estNT = np.mean(result["MU_estNTAll"], axis=0) + mX_est = 100 * np.mean(result["X_estAll"], axis=2) # (6, T) in cm + mX_estNT = 100 * np.mean(result["X_estNTAll"], axis=2) + + # (4,3,[1,4]): Estimated vs actual state + ax_s = fig6.add_subplot(4, 3, (1, 4)) + ax_s.plot(time, mstate, "k", linewidth=3) + ax_s.plot(time, mS_est, "b", linewidth=3) + ax_s.plot(time, mS_estNT, "g", linewidth=3) + ax_s.set_xticklabels([]) + ax_s.set_yticks([1, 2.1]) + ax_s.set_yticklabels(["N", "M"]) + ax_s.set_ylabel("state") + ax_s.set_title("Estimated vs. Actual State", fontweight="bold", + fontsize=12, fontfamily="Arial") + + # (4,3,[7,10]): P(s(t)=M|data) + ax_p = fig6.add_subplot(4, 3, (7, 10)) + ax_p.plot(time, mMU_est[1, :], "b", linewidth=3) + ax_p.plot(time, mMU_estNT[1, :], "g", linewidth=3) + ax_p.set_xlim(time[0], time[-1]) + ax_p.set_ylim(0, 1.1) + ax_p.set_xlabel("time [s]") + ax_p.set_ylabel("P(s(t)=M | data)") + ax_p.set_title("Probability of State", fontweight="bold", fontsize=12, + fontfamily="Arial") + + # (4,3,[2,3,5,6]): 2D reach path + ax_2d = fig6.add_subplot(4, 3, (2, 6)) + ax_2d.plot(100 * X[0, :], 100 * X[1, :], "k", linewidth=1) + ax_2d.plot(mX_est[0, :], mX_est[1, :], "b", linewidth=3) + ax_2d.plot(mX_estNT[0, :], mX_estNT[1, :], "g", linewidth=3) + ax_2d.plot(100 * X[0, 0], 100 * X[1, 0], "bo", markersize=14) + ax_2d.plot(100 * X[0, -1], 100 * X[1, -1], "ro", markersize=14) + ax_2d.set_xlabel("x [cm]") + ax_2d.set_ylabel("y [cm]") + ax_2d.set_title("Estimated vs. Actual Reach Path", fontweight="bold", + fontsize=12, fontfamily="Arial") + + # (4,3,8): X position + ax_xp = fig6.add_subplot(4, 3, 8) + ax_xp.plot(time, 100 * X[0, :], "k", linewidth=3) + ax_xp.plot(time, mX_est[0, :], "b", linewidth=3) + ax_xp.plot(time, mX_estNT[0, :], "g", linewidth=3) + ax_xp.set_ylabel("x(t) [cm]") + ax_xp.set_xticklabels([]) + ax_xp.set_title("X Position", fontweight="bold", fontsize=12, + fontfamily="Arial") + + # (4,3,9): Y position with legend + ax_yp = fig6.add_subplot(4, 3, 9) + h1, = ax_yp.plot(time, 100 * X[1, :], "k", linewidth=3) + h2, = ax_yp.plot(time, mX_est[1, :], "b", linewidth=3) + h3, = ax_yp.plot(time, mX_estNT[1, :], "g", linewidth=3) + ax_yp.legend([h1, h2, h3], ["Actual", "PPAF+Goal", "PPAF"], + loc="lower right") + ax_yp.set_ylabel("y(t) [cm]") + ax_yp.set_xticklabels([]) + ax_yp.set_title("Y Position", fontweight="bold", fontsize=12, + fontfamily="Arial") + + # (4,3,11): X velocity + ax_xv = fig6.add_subplot(4, 3, 11) + ax_xv.plot(time, 100 * X[2, :], "k", linewidth=3) + ax_xv.plot(time, mX_est[2, :], "b", linewidth=3) + ax_xv.plot(time, mX_estNT[2, :], "g", linewidth=3) + ax_xv.set_ylabel("v_x(t) [cm/s]") + ax_xv.set_xlabel("time [s]") + ax_xv.set_title("X Velocity", fontweight="bold", fontsize=12, + fontfamily="Arial") + + # (4,3,12): Y velocity + ax_yv = fig6.add_subplot(4, 3, 12) + ax_yv.plot(time, 100 * X[3, :], "k", linewidth=3) + ax_yv.plot(time, mX_est[3, :], "b", linewidth=3) + ax_yv.plot(time, mX_estNT[3, :], "g", linewidth=3) + ax_yv.set_ylabel("v_y(t) [cm/s]") + ax_yv.set_xlabel("time [s]") + ax_yv.set_title("Y Velocity", fontweight="bold", fontsize=12, + fontfamily="Arial") - # Top: model probabilities - axes6[0].plot(time, result["MU_u"][0, :], "b-", linewidth=0.5, label="P(Reach)") - axes6[0].plot(time, result["MU_u"][1, :], "r-", linewidth=0.5, label="P(Hold)") - axes6[0].axhline(0.5, color="gray", linestyle="--", linewidth=0.5) - axes6[0].set_ylabel("Model Prob") - axes6[0].set_title( - f"PPHybridFilterLinear — State Accuracy: {result['state_acc']:.1%}" - ) - axes6[0].legend(loc="upper right", fontsize=8) - - # Middle: decoded x-position - axes6[1].plot(time, result["state"][0, :], "k-", linewidth=1.0, label="True") - axes6[1].plot(time, result["X_est"][0, :], "b-", linewidth=0.7, alpha=0.8, label="Decoded") - axes6[1].set_ylabel("x (m)") - axes6[1].legend(loc="upper right", fontsize=8) - - # Bottom: decoded y-position - axes6[2].plot(time, result["state"][1, :], "k-", linewidth=1.0, label="True") - axes6[2].plot(time, result["X_est"][1, :], "r-", linewidth=0.7, alpha=0.8, label="Decoded") - axes6[2].set_ylabel("y (m)") - axes6[2].set_xlabel("Time (s)") - axes6[2].legend(loc="upper right", fontsize=8) - - fig6.suptitle( - f"Hybrid Decoding (RMSE: x={result['rmse_x']:.4f}, y={result['rmse_y']:.4f})", - fontsize=12, - ) fig6.tight_layout() return fig5, fig6 @@ -538,43 +823,38 @@ def _plot_part_c(result): def run_example05(*, export_figures=False, export_dir=None, show=False): """Run Example 05: PPAF and PPHF decoding. - Analysis workflow (mirrors Matlab ``example05_decoding_ppaf_pphf.m``): + Mirrors MATLAB ``example05_decoding_ppaf_pphf.m``: - Part A — Univariate stimulus decoding: - 1. Define 20-cell population with sinusoidal tuning. - 2. Simulate spikes from binomial CIF. - 3. Decode stimulus via PPDecodeFilterLinear. + Part A — Univariate stimulus decoding (Figs 1–2): + 1. 20-cell sinusoidal-tuned population, binomial CIF. + 2. PPDecodeFilterLinear decoding with 95% CIs. - Part B — Arm-reach PPAF: - 4. Simulate 4-state reaching movements (position + velocity). - 5. Encode with 20-cell cosine-tuning population. - 6. Decode with PPAF (free) and PPAF+Goal; compare across 20 runs. + Part B — Arm-reach PPAF (Figs 3–4): + 3. Minimum-energy reaching trajectory (4-D state). + 4. Velocity-tuned 20-cell population, binomial CIF. + 5. 20 simulations: PPAF free vs PPAF+Goal, overlaid decoded paths. - Part C — Hybrid filter: - 7. Simulate 40-cell population with discrete state modulation. - 8. Decode joint discrete/continuous state via PPHybridFilterLinear. + Part C — Hybrid filter (Figs 5–6): + 6. Fixture trajectory with 2 discrete movement states. + 7. 40-cell velocity-tuned population, binomial CIF. + 8. 20 simulations: PPHybridFilterLinear, averaged decode results. """ print("=" * 70) print("Example 05: Stimulus Decoding with PPAF and PPHF") print("=" * 70) - # --- Part A: Univariate sinusoidal stimulus --- + # --- Part A --- print("\n--- Part A: Univariate Sinusoidal Stimulus ---") result_a = _run_part_a() print(f" {result_a['n_cells']} cells, decode RMSE = {result_a['rmse']:.4f}") - # --- Part B: Arm-reach PPAF --- + # --- Part B --- print("\n--- Part B: Arm Reach PPAF (20 simulations) ---") result_b = _run_part_b() - mean_free = result_b["rmse_free"].mean(axis=0) - mean_goal = result_b["rmse_goal"].mean(axis=0) - print(f" Mean RMSE (free): x={mean_free[0]:.4f}, y={mean_free[1]:.4f}, " - f"vx={mean_free[2]:.4f}, vy={mean_free[3]:.4f}") - print(f" Mean RMSE (goal): x={mean_goal[0]:.4f}, y={mean_goal[1]:.4f}, " - f"vx={mean_goal[2]:.4f}, vy={mean_goal[3]:.4f}") - - # --- Part C: Hybrid filter --- - print("\n--- Part C: Hybrid Filter ---") + print(f" {result_b['n_cells']} cells, {result_b['n_sims']} sims completed") + + # --- Part C --- + print("\n--- Part C: Hybrid Filter (20 simulations) ---") result_c = _run_part_c() print(f" {result_c['n_cells']} cells, state accuracy = {result_c['state_acc']:.1%}") print(f" Position RMSE: x={result_c['rmse_x']:.4f}, y={result_c['rmse_y']:.4f}") @@ -588,8 +868,6 @@ def run_example05(*, export_figures=False, export_dir=None, show=False): "experiment5b": { "num_cells": float(result_b["n_cells"]), "n_sims": float(result_b["n_sims"]), - "mean_rmse_free_x": float(mean_free[0]), - "mean_rmse_goal_x": float(mean_goal[0]), }, "experiment6": { "num_cells": float(result_c["n_cells"]), @@ -611,13 +889,13 @@ def run_example05(*, export_figures=False, export_dir=None, show=False): export_dir = THIS_DIR / "figures" / "example05" export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - for i, fig in enumerate(figures, 1): - fig_names = [ - "fig01_univariate_setup", "fig02_univariate_decoding", - "fig03_reach_and_population_setup", "fig04_ppaf_goal_vs_free", - "fig05_hybrid_setup", "fig06_hybrid_decoding_summary", - ] - path = export_dir / f"{fig_names[i - 1]}.png" + fig_names = [ + "fig01_univariate_setup", "fig02_univariate_decoding", + "fig03_reach_and_population_setup", "fig04_ppaf_goal_vs_free", + "fig05_hybrid_setup", "fig06_hybrid_decoding_summary", + ] + for i, fig in enumerate(figures): + path = export_dir / f"{fig_names[i]}.png" fig.savefig(path, dpi=150, bbox_inches="tight") print(f" Saved: {path}") @@ -637,7 +915,8 @@ def run_example05(*, export_figures=False, export_dir=None, show=False): parser.add_argument("--export-figures", action="store_true") parser.add_argument("--export-dir", type=Path, default=None) parser.add_argument("--output-json", type=Path, default=None) - parser.add_argument("--show", action="store_true", help="Display figures interactively") + parser.add_argument("--show", action="store_true", + help="Display figures interactively") args = parser.parse_args() result = run_example05( @@ -646,4 +925,5 @@ def run_example05(*, export_figures=False, export_dir=None, show=False): show=args.show, ) if args.output_json: - args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") + args.output_json.write_text(json.dumps(result, indent=2), + encoding="utf-8")