diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index a37c6e3e..a8245eaa 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -2043,7 +2043,18 @@ def PPHybridFilterLinear( X_u[model_index][:, time_index] = upd_x W_u[model_index][:, :, time_index] = upd_W - det_ratio = np.sqrt(max(np.linalg.det(upd_W), 0.0)) / max(np.sqrt(max(np.linalg.det(pred_W), 0.0)), 1e-15) + # Likelihood via Laplace approximation (Srinivasan et al. 2007). + # Compute sqrt(det(W_u)) / sqrt(det(W_p)) as sqrt(det(W_u)/det(W_p)) + # to avoid a fixed floor that destroys the ratio for + # high-dimensional models with tiny absolute determinants. + det_upd = np.linalg.det(upd_W) + det_pred = np.linalg.det(pred_W) + if det_pred > 0.0 and det_upd >= 0.0: + det_ratio = np.sqrt(det_upd / det_pred) + elif det_upd == 0.0 and det_pred == 0.0: + det_ratio = 1.0 + else: + det_ratio = 0.0 log_term = np.sum(obs[:, time_index] * np.log(np.clip(lambda_delta.reshape(-1), 1e-12, np.inf)) - lambda_delta.reshape(-1)) likelihoods[model_index] = float(det_ratio * np.exp(np.clip(log_term, -200.0, 50.0))) @@ -2212,7 +2223,18 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, X_u[model_index][:, time_index] = upd_x W_u[model_index][:, :, time_index] = upd_W - det_ratio = np.sqrt(max(np.linalg.det(upd_W), 0.0)) / max(np.sqrt(max(np.linalg.det(pred_W), 0.0)), 1e-15) + # Likelihood via Laplace approximation (Srinivasan et al. 2007). + # Compute sqrt(det(W_u)) / sqrt(det(W_p)) as sqrt(det(W_u)/det(W_p)) + # to avoid a fixed floor that destroys the ratio for + # high-dimensional models with tiny absolute determinants. + det_upd = np.linalg.det(upd_W) + det_pred = np.linalg.det(pred_W) + if det_pred > 0.0 and det_upd >= 0.0: + det_ratio = np.sqrt(det_upd / det_pred) + elif det_upd == 0.0 and det_pred == 0.0: + det_ratio = 1.0 + else: + det_ratio = 0.0 log_term = np.sum(obs[:, time_index] * np.log(np.clip(lambda_delta.reshape(-1), 1e-12, np.inf)) - lambda_delta.reshape(-1)) likelihoods[model_index] = float(det_ratio * np.exp(np.clip(log_term, -200.0, 50.0)))