Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/figures/example05/fig06_hybrid_decoding_summary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 27 additions & 2 deletions nstat/decoding_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,26 @@ def PPHybridFilterLinear(
pNGivenS = np.zeros((n_models, num_steps), dtype=float)
S_est = np.zeros(num_steps, dtype=int)

# Fuse initial state prior with terminal constraint for each model
# (Srinivasan et al. Eq. 2.23). Matches PPDecodeFilterLinear and
# the corrected MATLAB PPHybridFilterLinear. Without this step the
# goal-directed filter starts from the raw prior, never incorporating
# target information into x0/Pi0.
for s in range(n_models):
if _has_target[s] and estimateTarget == 0:
dim = state_dims[s]
Pi0s = Pi0_models[s]
x0s = x0_models[s]
det_Pi0 = np.linalg.det(Pi0s)
if det_Pi0 != 0.0:
invPi0s = np.linalg.pinv(Pi0s)
invPitTs = np.linalg.pinv(PitT_m[s][:, :, 0])
Pi0New = np.linalg.pinv(invPi0s + invPitTs)
Pi0New = np.where(np.isnan(Pi0New), 0.0, Pi0New)
x0New = Pi0New @ (invPi0s @ x0s + invPitTs @ PhitT_m[s][:, :, 0] @ yT_models[s])
x0_models[s] = x0New
Pi0_models[s] = Pi0New

fit_type = str(fitType)

for time_index in range(num_steps):
Expand Down Expand Up @@ -2031,8 +2051,13 @@ def PPHybridFilterLinear(
for model_index in range(n_models):
dim = state_dims[model_index]
if _has_target[model_index]:
A_t = B_m[model_index][:, :, time_index]
Q_t = QT_m[model_index][:, :, time_index]
# Use B(:,:,0) and QT(:,:,-1) — matches the original
# Srinivasan et al. implementation where the prediction
# uses the initial modified dynamics B_1 with the terminal
# modified noise covariance QT_N for consistent goal
# correction across all time steps.
A_t = B_m[model_index][:, :, 0]
Q_t = QT_m[model_index][:, :, -1]
else:
A_t = _select_time_matrix(A_models[model_index], time_index, dim)
Q_t = _select_time_matrix(Q_models[model_index], time_index, dim)
Expand Down
Loading