diff --git a/docs/figures/example05/fig06_hybrid_decoding_summary.png b/docs/figures/example05/fig06_hybrid_decoding_summary.png index 4d0bea4..519f66d 100644 Binary files a/docs/figures/example05/fig06_hybrid_decoding_summary.png and b/docs/figures/example05/fig06_hybrid_decoding_summary.png differ diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index 8c7b783..15bcd51 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -1993,6 +1993,26 @@ def PPHybridFilterLinear( pNGivenS = np.zeros((n_models, num_steps), dtype=float) S_est = np.zeros(num_steps, dtype=int) + # Fuse initial state prior with terminal constraint for each model + # (Srinivasan et al. Eq. 2.23). Matches PPDecodeFilterLinear and + # the corrected MATLAB PPHybridFilterLinear. Without this step the + # goal-directed filter starts from the raw prior, never incorporating + # target information into x0/Pi0. + for s in range(n_models): + if _has_target[s] and estimateTarget == 0: + dim = state_dims[s] + Pi0s = Pi0_models[s] + x0s = x0_models[s] + det_Pi0 = np.linalg.det(Pi0s) + if det_Pi0 != 0.0: + invPi0s = np.linalg.pinv(Pi0s) + invPitTs = np.linalg.pinv(PitT_m[s][:, :, 0]) + Pi0New = np.linalg.pinv(invPi0s + invPitTs) + Pi0New = np.where(np.isnan(Pi0New), 0.0, Pi0New) + x0New = Pi0New @ (invPi0s @ x0s + invPitTs @ PhitT_m[s][:, :, 0] @ yT_models[s]) + x0_models[s] = x0New + Pi0_models[s] = Pi0New + fit_type = str(fitType) for time_index in range(num_steps): @@ -2031,8 +2051,13 @@ def PPHybridFilterLinear( for model_index in range(n_models): dim = state_dims[model_index] if _has_target[model_index]: - A_t = B_m[model_index][:, :, time_index] - Q_t = QT_m[model_index][:, :, time_index] + # Use B(:,:,0) and QT(:,:,-1) — matches the original + # Srinivasan et al. implementation where the prediction + # uses the initial modified dynamics B_1 with the terminal + # modified noise covariance QT_N for consistent goal + # correction across all time steps. + A_t = B_m[model_index][:, :, 0] + Q_t = QT_m[model_index][:, :, -1] else: A_t = _select_time_matrix(A_models[model_index], time_index, dim) Q_t = _select_time_matrix(Q_models[model_index], time_index, dim)