From 954b7c03535b184fbab263d811cccac59e99a9c8 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 18 Apr 2026 13:44:55 -0400 Subject: [PATCH 1/5] Signal non-convergence in TROP alternating-minimization solvers Addresses axis B findings #6 and #7 from the silent-failures audit: trop_global.py:448 outer alternating-min loop, trop_global.py:466 hard-coded range(20) inner FISTA loop, and trop_local.py:680 alternating-minimization loop all exited silently on max_iter exhaustion, returning the current iterate as if converged. - trop_global._solve_global_with_lowrank: thread a converged flag through the outer loop; count non-convergence events from the inner FISTA and surface the count in the outer warning for diagnostic context. One warn_if_not_converged call per solver invocation. - trop_local._estimate_model: thread a converged flag through the outer alternating-min loop; call warn_if_not_converged on exhaustion. - REGISTRY updated under TROP. New TestTROPConvergenceWarnings class (4 tests) exercises both global and local paths with forced non-convergence (max_iter=1, tol=1e-15) and a convergent negative control. Notable: the default TROP local config (max_iter=100, tol=1e-6) does not converge within max_iter on typical synthetic panels, so this PR surfaces a previously silent non-convergence that affected routine user fits. No numerical change in the returned iterate; the warning is additive. Axis-B regression-lint baseline: 5 -> 2 silent range(max_iter) loops remaining (minor loops in honest_did/power not yet addressed). Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/trop_global.py | 20 ++++++- diff_diff/trop_local.py | 6 ++ docs/methodology/REGISTRY.md | 1 + tests/test_trop.py | 106 +++++++++++++++++++++++++++++++++++ 4 files changed, 131 insertions(+), 2 deletions(-) diff --git a/diff_diff/trop_global.py b/diff_diff/trop_global.py index 3d17e4ac..2367dca1 100644 --- a/diff_diff/trop_global.py +++ b/diff_diff/trop_global.py @@ -26,7 +26,7 @@ ) from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment from diff_diff.trop_results import TROPResults -from diff_diff.utils import safe_inference +from diff_diff.utils import safe_inference, warn_if_not_converged class TROPGlobalMixin: @@ -445,6 +445,9 @@ def _solve_global_with_lowrank( # Initialize L = 0 L = np.zeros((n_periods, n_units)) + _FISTA_MAX_ITER = 20 + inner_nonconverged_count = 0 + outer_converged = False for iteration in range(max_iter): L_old = L.copy() @@ -463,7 +466,8 @@ def _solve_global_with_lowrank( L_inner_prev = L_inner # share reference initially (no copy needed) t_fista = 1.0 - for _ in range(20): + inner_converged = False + for _ in range(_FISTA_MAX_ITER): # FISTA momentum t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0 momentum = (t_fista - 1.0) / t_fista_new @@ -479,14 +483,26 @@ def _solve_global_with_lowrank( # Convergence check (L_inner_prev holds the pre-SVD value) if np.max(np.abs(L_inner - L_inner_prev)) < tol: + inner_converged = True break + if not inner_converged: + inner_nonconverged_count += 1 L = L_inner # Outer convergence check if np.max(np.abs(L - L_old)) < tol: + outer_converged = True break + if not outer_converged: + detail = ( + f"TROP global alternating minimization " + f"(inner FISTA non-converged in {inner_nonconverged_count}/{max_iter} " + f"outer iterations, FISTA max_iter={_FISTA_MAX_ITER})" + ) + warn_if_not_converged(False, detail, max_iter, tol) + # Final re-solve with converged L (match Rust behavior) Y_adj = Y_safe - L mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked) diff --git a/diff_diff/trop_local.py b/diff_diff/trop_local.py index fcfe59b4..ada6d746 100644 --- a/diff_diff/trop_local.py +++ b/diff_diff/trop_local.py @@ -25,6 +25,7 @@ _rust_unit_distance_matrix, ) from diff_diff.trop_results import _PrecomputedStructures +from diff_diff.utils import warn_if_not_converged def _validate_and_pivot_treatment(data, time, unit, treatment, all_periods, all_units): @@ -677,6 +678,7 @@ def _estimate_model( # Alternating minimization following Algorithm 1 (page 9) # Minimize: sum W_{ti}(Y_{ti} - alpha_i - beta_t - L_{ti})^2 + lambda_nn||L||_* + converged = False for _ in range(self.max_iter): alpha_old = alpha.copy() beta_old = beta.copy() @@ -717,7 +719,11 @@ def _estimate_model( L_diff = np.max(np.abs(L - L_old)) if max(alpha_diff, beta_diff, L_diff) < self.tol: + converged = True break + warn_if_not_converged( + converged, "TROP local alternating minimization", self.max_iter, self.tol + ) return alpha, beta, L diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index b4cd419f..fdbccd89 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1972,6 +1972,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - **Bootstrap minimum**: `n_bootstrap` must be >= 2 (enforced via `ValueError`). TROP uses bootstrap for all variance estimation — there is no analytical SE formula. - **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages - **Inference CI distribution**: After `safe_inference()` migration, CI uses t-distribution (df = max(1, n_treated_obs - 1)), consistent with p_value. Previously CI used normal-distribution while p_value used t-distribution (inconsistent). This is a minor behavioral change; CIs may be slightly wider for small n_treated_obs. +- **Note:** Both the `local` alternating-minimization solver (`_estimate_model`) and the `global` alternating-minimization solver (`_solve_global_with_lowrank`, including its hard-coded inner FISTA loop of 20 iterations) emit `UserWarning` via `diff_diff.utils.warn_if_not_converged` when the outer loop exhausts `max_iter` without reaching `tol`. The global-method warning surfaces the inner-FISTA non-convergence count as diagnostic context. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the convention used across other iterative solvers in the library. **Reference implementation(s):** - Authors' replication code (forthcoming) diff --git a/tests/test_trop.py b/tests/test_trop.py index eb28d407..2824491c 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -3936,3 +3936,109 @@ def test_observed_treatment_nan_raises_local(self): ) with pytest.raises(ValueError, match="missing treatment values"): trop_est.fit(df, "outcome", "treated", "unit", "time") + + +class TestTROPConvergenceWarnings: + """Silent-failure audit axis B: TROP alternating minimization must warn on non-convergence.""" + + @staticmethod + def _panel_matrices(simple_panel_data): + """Pivot simple_panel_data into (Y, D, n_units, n_periods, treated_periods).""" + all_units = sorted(simple_panel_data["unit"].unique()) + all_periods = sorted(simple_panel_data["period"].unique()) + n_units = len(all_units) + n_periods = len(all_periods) + Y = ( + simple_panel_data.pivot(index="period", columns="unit", values="outcome") + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + simple_panel_data.pivot(index="period", columns="unit", values="treated") + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + treated_periods = int(np.sum(np.any(D == 1, axis=1))) + return Y, D, n_units, n_periods, treated_periods + + def test_global_alternating_min_warns_on_nonconvergence(self, simple_panel_data): + """_solve_global_with_lowrank must warn when outer alternating-min loop exhausts max_iter.""" + Y, D, n_units, n_periods, treated_periods = self._panel_matrices(simple_panel_data) + + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + seed=42, + ) + delta = trop_est._compute_global_weights( + Y, D, 1.0, 1.0, treated_periods, n_units, n_periods + ) + + with pytest.warns(UserWarning, match="did not converge"): + trop_est._solve_global_with_lowrank(Y, delta, lambda_nn=0.1, max_iter=1, tol=1e-15) + + def test_global_alternating_min_no_warning_on_convergence(self, simple_panel_data): + """_solve_global_with_lowrank must not warn on a well-behaved fit with generous max_iter.""" + Y, D, n_units, n_periods, treated_periods = self._panel_matrices(simple_panel_data) + + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + seed=42, + ) + delta = trop_est._compute_global_weights( + Y, D, 1.0, 1.0, treated_periods, n_units, n_periods + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + trop_est._solve_global_with_lowrank(Y, delta, lambda_nn=0.1, max_iter=500, tol=1e-6) + assert not any("did not converge" in str(x.message) for x in w) + + def test_local_alternating_min_warns_on_nonconvergence(self, simple_panel_data): + """TROP local _estimate_model must warn when alternating-min exhausts max_iter.""" + Y, D, n_units, n_periods, _ = self._panel_matrices(simple_panel_data) + control_mask = (np.sum(D, axis=0) == 0) # units never treated + + trop_est = TROP( + method="local", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + max_iter=1, + tol=1e-15, + seed=42, + ) + W = np.where(D == 0, 1.0, 0.0) + + with pytest.warns(UserWarning, match="did not converge"): + trop_est._estimate_model(Y, control_mask, W, lambda_nn=0.1, + n_units=n_units, n_periods=n_periods) + + def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data): + """TROP local _estimate_model must not warn on a well-behaved fit.""" + Y, D, n_units, n_periods, _ = self._panel_matrices(simple_panel_data) + control_mask = (np.sum(D, axis=0) == 0) + + trop_est = TROP( + method="local", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + max_iter=500, + tol=1e-6, + seed=42, + ) + W = np.where(D == 0, 1.0, 0.0) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + trop_est._estimate_model(Y, control_mask, W, lambda_nn=0.1, + n_units=n_units, n_periods=n_periods) + assert not any("did not converge" in str(x.message) for x in w) From 810d8629817bfee09fb55bfb15d7ac61b8ce10bc Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 18 Apr 2026 14:19:53 -0400 Subject: [PATCH 2/5] Aggregate TROP convergence warnings at top-level call sites AI review on PR #317 flagged a P2 fan-out problem: the new warnings fired from low-level solver helpers that are called inside LOOCV, per-treated-observation, and bootstrap loops. A single non-convergent configuration could emit dozens to hundreds of duplicate warnings, which is noisy, slow on Python fallback paths, and a hard failure in environments that escalate warnings to errors. Fix pattern: add an optional _nonconvergence_tracker: list kwarg to _solve_global_with_lowrank, _solve_global_model, _fit_global_with_fixed_lambda, _estimate_model, and _fit_with_fixed_lambda. When provided, the solver appends non-convergence events to the tracker instead of warning directly. Each top-level caller (LOOCV, bootstrap, Rao-Wu bootstrap, per-treated-observation fit) supplies a tracker and emits a single consolidated warning summarizing the count of non-converged fits. Six call sites wrapped: - trop.py:768 local per-treated-observation main fit loop - trop_local.py:815 local LOOCV - trop_local.py:1044 local bootstrap - trop_local.py:1199 local Rao-Wu bootstrap - trop_global.py:283 global LOOCV - trop_global.py:1048 global bootstrap - trop_global.py:1226 global Rao-Wu bootstrap Also addresses the P3 test-shape finding: the convergence tests now pass observation-level `control_mask = (D == 0)` matching the production call contract at trop.py:567 and trop_local.py:625 (not the unit-level mask I had earlier). Plus a new fit()-level test `test_local_fit_emits_single_aggregate_warning` that pins the aggregate-per-call warning contract. Smoke check: test_basic_fit previously emitted one per-observation warning; it now emits a single aggregate like "TROP local per-treated-observation fit: 15 of 15 fits did not converge". Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/trop.py | 16 ++++++++-- diff_diff/trop_global.py | 68 ++++++++++++++++++++++++++++++++++------ diff_diff/trop_local.py | 52 +++++++++++++++++++++++++++--- tests/test_trop.py | 47 +++++++++++++++++++++++++-- 4 files changed, 163 insertions(+), 20 deletions(-) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index d06ec96c..3d19a903 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -37,7 +37,7 @@ _PrecomputedStructures, TROPResults, ) -from diff_diff.utils import safe_inference +from diff_diff.utils import safe_inference, warn_if_not_converged class TROP(TROPLocalMixin, TROPGlobalMixin): @@ -748,6 +748,7 @@ def fit( # Use pre-computed treated observations treated_observations = self._precomputed["treated_observations"] + nonconverg_tracker: list = [] for t, i in treated_observations: unit_id = idx_to_unit[i] @@ -766,7 +767,8 @@ def fit( # Fit model with these weights alpha_hat, beta_hat, L_hat = self._estimate_model( - Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods + Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods, + _nonconvergence_tracker=nonconverg_tracker, ) # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it} @@ -782,6 +784,16 @@ def fit( beta_estimates.append(beta_hat) L_estimates.append(L_hat) + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP local per-treated-observation fit: " + f"{len(nonconverg_tracker)} of {len(treated_observations)} " + f"fits did not converge", + self.max_iter, + self.tol, + ) + # Count valid treated observations n_valid_treated = len(tau_values) if n_valid_treated == 0: diff --git a/diff_diff/trop_global.py b/diff_diff/trop_global.py index 2367dca1..2008eca4 100644 --- a/diff_diff/trop_global.py +++ b/diff_diff/trop_global.py @@ -156,6 +156,7 @@ def _solve_global_model( Y: np.ndarray, delta: np.ndarray, lambda_nn: float, + _nonconvergence_tracker: Optional[List[int]] = None, ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: """ Dispatch to no-lowrank or with-lowrank solver based on lambda_nn. @@ -168,7 +169,8 @@ def _solve_global_model( L = np.zeros((n_periods, n_units)) else: mu, alpha, beta, L = self._solve_global_with_lowrank( - Y, delta, lambda_nn, self.max_iter, self.tol + Y, delta, lambda_nn, self.max_iter, self.tol, + _nonconvergence_tracker=_nonconvergence_tracker, ) return mu, alpha, beta, L @@ -273,6 +275,7 @@ def _loocv_score_global( tau_sq_sum = 0.0 n_valid = 0 + nonconverg_tracker: List[int] = [] for t_ex, i_ex in control_obs: # Create modified delta with excluded observation zeroed out @@ -280,7 +283,10 @@ def _loocv_score_global( delta_ex[t_ex, i_ex] = 0.0 try: - mu, alpha, beta, L = self._solve_global_model(Y, delta_ex, lambda_nn) + mu, alpha, beta, L = self._solve_global_model( + Y, delta_ex, lambda_nn, + _nonconvergence_tracker=nonconverg_tracker, + ) # Pseudo treatment effect: tau = Y - mu - alpha - beta - L if np.isfinite(Y[t_ex, i_ex]): @@ -292,6 +298,16 @@ def _loocv_score_global( # Any failure means this lambda combination is invalid per Equation 5 return np.inf + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP global LOOCV: {len(nonconverg_tracker)} of {len(control_obs)} " + f"per-observation fits did not converge " + f"(\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}))", + self.max_iter, + self.tol, + ) + if n_valid == 0: return np.inf @@ -395,6 +411,7 @@ def _solve_global_with_lowrank( lambda_nn: float, max_iter: int = 100, tol: float = 1e-6, + _nonconvergence_tracker: Optional[List[int]] = None, ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]: """ Solve TWFE + low-rank on control data via alternating minimization. @@ -496,12 +513,15 @@ def _solve_global_with_lowrank( break if not outer_converged: - detail = ( - f"TROP global alternating minimization " - f"(inner FISTA non-converged in {inner_nonconverged_count}/{max_iter} " - f"outer iterations, FISTA max_iter={_FISTA_MAX_ITER})" - ) - warn_if_not_converged(False, detail, max_iter, tol) + if _nonconvergence_tracker is not None: + _nonconvergence_tracker.append(inner_nonconverged_count) + else: + detail = ( + f"TROP global alternating minimization " + f"(inner FISTA non-converged in {inner_nonconverged_count}/{max_iter} " + f"outer iterations, FISTA max_iter={_FISTA_MAX_ITER})" + ) + warn_if_not_converged(False, detail, max_iter, tol) # Final re-solve with converged L (match Rust behavior) Y_adj = Y_safe - L @@ -1000,6 +1020,7 @@ def _bootstrap_variance_global( n_control_units = len(control_units) bootstrap_estimates_list: List[float] = [] + nonconverg_tracker: List[int] = [] for _ in range(self.n_bootstrap): # Stratified sampling @@ -1034,6 +1055,7 @@ def _bootstrap_variance_global( optimal_lambda, treated_periods, survey_design=survey_design, + _nonconvergence_tracker=nonconverg_tracker, ) if np.isfinite(tau): bootstrap_estimates_list.append(tau) @@ -1042,6 +1064,15 @@ def _bootstrap_variance_global( bootstrap_estimates = np.array(bootstrap_estimates_list) + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP global bootstrap: {len(nonconverg_tracker)} of " + f"{self.n_bootstrap} replicate fits did not converge", + self.max_iter, + self.tol, + ) + if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning @@ -1185,6 +1216,7 @@ def _bootstrap_rao_wu_global( ) bootstrap_estimates_list: List[float] = [] + nonconverg_tracker: List[int] = [] for _ in range(self.n_bootstrap): try: @@ -1203,7 +1235,10 @@ def _bootstrap_rao_wu_global( delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) - mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) + mu, alpha, beta, L = self._solve_global_model( + Y, delta, lambda_nn, + _nonconvergence_tracker=nonconverg_tracker, + ) # Extract weighted ATT using Rao-Wu rescaled weights att, _, _ = self._extract_posthoc_tau( @@ -1217,6 +1252,15 @@ def _bootstrap_rao_wu_global( bootstrap_estimates = np.array(bootstrap_estimates_list) + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP global Rao-Wu bootstrap: {len(nonconverg_tracker)} of " + f"{self.n_bootstrap} replicate fits did not converge", + self.max_iter, + self.tol, + ) + if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", @@ -1238,6 +1282,7 @@ def _fit_global_with_fixed_lambda( fixed_lambda: Tuple[float, float, float], treated_periods: int, survey_design=None, + _nonconvergence_tracker: Optional[List[int]] = None, ) -> float: """ Fit global model with fixed tuning parameters. @@ -1279,7 +1324,10 @@ def _fit_global_with_fixed_lambda( ) # Fit model on control data and extract post-hoc tau - mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) + mu, alpha, beta, L = self._solve_global_model( + Y, delta, lambda_nn, + _nonconvergence_tracker=_nonconvergence_tracker, + ) att, _, _ = self._extract_posthoc_tau( Y, D, mu, alpha, beta, L, unit_weights=local_weight_arr ) diff --git a/diff_diff/trop_local.py b/diff_diff/trop_local.py index ada6d746..0ac731fd 100644 --- a/diff_diff/trop_local.py +++ b/diff_diff/trop_local.py @@ -12,7 +12,7 @@ import logging import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple import numpy as np import pandas as pd @@ -610,6 +610,7 @@ def _estimate_model( n_units: int, n_periods: int, exclude_obs: Optional[Tuple[int, int]] = None, + _nonconvergence_tracker: Optional[List[int]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Estimate the model: Y = alpha + beta + L + tau*D + eps with nuclear norm penalty on L. @@ -721,9 +722,14 @@ def _estimate_model( if max(alpha_diff, beta_diff, L_diff) < self.tol: converged = True break - warn_if_not_converged( - converged, "TROP local alternating minimization", self.max_iter, self.tol - ) + if not converged: + if _nonconvergence_tracker is not None: + _nonconvergence_tracker.append(1) + else: + warn_if_not_converged( + converged, "TROP local alternating minimization", + self.max_iter, self.tol, + ) return alpha, beta, L @@ -802,6 +808,7 @@ def _loocv_score_obs_specific( tau_squared_sum = 0.0 n_valid = 0 + nonconverg_tracker: List[int] = [] for t, i in control_obs: try: @@ -820,6 +827,7 @@ def _loocv_score_obs_specific( n_units, n_periods, exclude_obs=(t, i), + _nonconvergence_tracker=nonconverg_tracker, ) # Pseudo treatment effect @@ -838,6 +846,16 @@ def _loocv_score_obs_specific( ) return np.inf + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP local LOOCV: {len(nonconverg_tracker)} of " + f"{len(control_obs)} per-observation fits did not converge " + f"(\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}))", + self.max_iter, + self.tol, + ) + # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8): # Q(lambda) = sum_{j,s: D_js=0} [tau_js^loocv(lambda)]^2 return tau_squared_sum @@ -995,6 +1013,7 @@ def _bootstrap_variance( n_control_units = len(control_units) bootstrap_estimates_list = [] + nonconverg_tracker: List[int] = [] for _ in range(self.n_bootstrap): # Stratified sampling: sample control and treated units separately @@ -1031,6 +1050,7 @@ def _bootstrap_variance( time, optimal_lambda, survey_design=survey_design, + _nonconvergence_tracker=nonconverg_tracker, ) if np.isfinite(att): bootstrap_estimates_list.append(att) @@ -1039,6 +1059,15 @@ def _bootstrap_variance( bootstrap_estimates = np.array(bootstrap_estimates_list) + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP local bootstrap: {len(nonconverg_tracker)} non-converged " + f"per-observation fits across {self.n_bootstrap} bootstrap replicates", + self.max_iter, + self.tol, + ) + if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. " @@ -1167,6 +1196,7 @@ def _bootstrap_rao_wu_local( # weights, mirroring the physical-resampling bootstrap but using weight # perturbation instead of unit resampling. bootstrap_estimates_list = [] + nonconverg_tracker: List[int] = [] for _ in range(self.n_bootstrap): try: @@ -1187,6 +1217,7 @@ def _bootstrap_rao_wu_local( optimal_lambda, survey_design=survey_design, unit_weight_arr=boot_weights, + _nonconvergence_tracker=nonconverg_tracker, ) if np.isfinite(att): @@ -1196,6 +1227,15 @@ def _bootstrap_rao_wu_local( bootstrap_estimates = np.array(bootstrap_estimates_list) + if nonconverg_tracker: + warn_if_not_converged( + False, + f"TROP local Rao-Wu bootstrap: {len(nonconverg_tracker)} non-converged " + f"per-observation fits across {self.n_bootstrap} bootstrap replicates", + self.max_iter, + self.tol, + ) + if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. " @@ -1218,6 +1258,7 @@ def _fit_with_fixed_lambda( fixed_lambda: Tuple[float, float, float], survey_design=None, unit_weight_arr: Optional[np.ndarray] = None, + _nonconvergence_tracker: Optional[List[int]] = None, ) -> float: """ Fit model with fixed tuning parameters (for bootstrap). @@ -1297,7 +1338,8 @@ def _fit_with_fixed_lambda( # Fit model with these weights alpha, beta, L = self._estimate_model( - Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods + Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods, + _nonconvergence_tracker=_nonconvergence_tracker, ) # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it} diff --git a/tests/test_trop.py b/tests/test_trop.py index 2824491c..4fa50efe 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -4002,9 +4002,12 @@ def test_global_alternating_min_no_warning_on_convergence(self, simple_panel_dat assert not any("did not converge" in str(x.message) for x in w) def test_local_alternating_min_warns_on_nonconvergence(self, simple_panel_data): - """TROP local _estimate_model must warn when alternating-min exhausts max_iter.""" + """TROP local _estimate_model must warn when alternating-min exhausts max_iter. + + Uses observation-level control_mask matching the production call contract. + """ Y, D, n_units, n_periods, _ = self._panel_matrices(simple_panel_data) - control_mask = (np.sum(D, axis=0) == 0) # units never treated + control_mask = D == 0 # observation-level, matching trop.py/trop_local.py usage trop_est = TROP( method="local", @@ -4024,7 +4027,7 @@ def test_local_alternating_min_warns_on_nonconvergence(self, simple_panel_data): def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data): """TROP local _estimate_model must not warn on a well-behaved fit.""" Y, D, n_units, n_periods, _ = self._panel_matrices(simple_panel_data) - control_mask = (np.sum(D, axis=0) == 0) + control_mask = D == 0 # observation-level, matching production trop_est = TROP( method="local", @@ -4042,3 +4045,41 @@ def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data trop_est._estimate_model(Y, control_mask, W, lambda_nn=0.1, n_units=n_units, n_periods=n_periods) assert not any("did not converge" in str(x.message) for x in w) + + def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data): + """Fit-level warning aggregation: per-treated-observation non-convergence must + surface as at most one aggregate warning per call, not one per observation. + + Pins the P2 fan-out fix: warnings are accumulated via the + `_nonconvergence_tracker` kwarg and emitted once at the top-level fit.""" + trop_est = TROP( + method="local", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + max_iter=1, + tol=1e-15, + n_bootstrap=2, + seed=42, + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + # The per-treated-observation fit loop must emit exactly one aggregate + # warning of the form "TROP local per-treated-observation fit: N of M fits + # did not converge", not N separate warnings. + per_obs_warnings = [ + x for x in w if "per-treated-observation" in str(x.message) + ] + assert len(per_obs_warnings) <= 1, ( + f"Expected at most one aggregated per-treated-observation warning, " + f"got {len(per_obs_warnings)}: {[str(x.message) for x in per_obs_warnings]}" + ) From 3d21dd87a5c06d61b023b1daf65529ea227d422a Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 18 Apr 2026 14:44:42 -0400 Subject: [PATCH 3/5] Extend TROP convergence tests to cover LOOCV and bootstrap aggregation AI review on PR #317 flagged that my earlier fit()-level test only covered the per-treated-observation aggregation path, not the LOOCV or bootstrap wrapper paths. A regression in _nonconvergence_tracker plumbing for those paths could slip through. - test_local_fit_emits_single_aggregate_warning: expanded to assert per-obs, LOOCV, and bootstrap warnings each appear at most once per .fit(). - test_global_fit_emits_single_aggregate_warning: new test mirroring the local one for method="global" (LOOCV + bootstrap paths). Both use n_bootstrap=2, minimal lambda grid, and max_iter=1/tol=1e-15 to keep cost low: ~3.4s for all 6 TROP convergence tests combined. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_trop.py | 57 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/tests/test_trop.py b/tests/test_trop.py index 4fa50efe..99991a46 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -4047,11 +4047,9 @@ def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data assert not any("did not converge" in str(x.message) for x in w) def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data): - """Fit-level warning aggregation: per-treated-observation non-convergence must - surface as at most one aggregate warning per call, not one per observation. - - Pins the P2 fan-out fix: warnings are accumulated via the - `_nonconvergence_tracker` kwarg and emitted once at the top-level fit.""" + """Fit-level warning aggregation: per-treated-observation, LOOCV, and + bootstrap non-convergence each surface as at most one aggregate warning + per wrapping call, not one per inner fit. Pins the P2 fan-out fix.""" trop_est = TROP( method="local", lambda_time_grid=[1.0], @@ -4073,13 +4071,44 @@ def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data): time="period", ) - # The per-treated-observation fit loop must emit exactly one aggregate - # warning of the form "TROP local per-treated-observation fit: N of M fits - # did not converge", not N separate warnings. - per_obs_warnings = [ - x for x in w if "per-treated-observation" in str(x.message) - ] - assert len(per_obs_warnings) <= 1, ( - f"Expected at most one aggregated per-treated-observation warning, " - f"got {len(per_obs_warnings)}: {[str(x.message) for x in per_obs_warnings]}" + def count_matching(needle: str) -> int: + return sum(1 for x in w if needle in str(x.message)) + + # Per-treated-observation aggregation (called once per .fit()). + assert count_matching("per-treated-observation") <= 1 + # LOOCV aggregation (called once per (lambda_time, lambda_unit, lambda_nn) combo; + # grid has exactly 1 combo). + assert count_matching("local LOOCV") <= 1 + # Bootstrap aggregation (called once per .fit()). + assert count_matching("local bootstrap") <= 1 + + def test_global_fit_emits_single_aggregate_warning(self, simple_panel_data): + """Global-method fit-level warning aggregation: LOOCV and bootstrap + non-convergence each surface as at most one aggregate warning per + wrapping call, mirroring the local test above.""" + trop_est = TROP( + method="global", + lambda_time_grid=[1.0], + lambda_unit_grid=[1.0], + lambda_nn_grid=[0.1], + max_iter=1, + tol=1e-15, + n_bootstrap=2, + seed=42, ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) + + def count_matching(needle: str) -> int: + return sum(1 for x in w if needle in str(x.message)) + + assert count_matching("global LOOCV") <= 1 + assert count_matching("global bootstrap") <= 1 From 71e4153e9d3217667b73939d591f43e362316700 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 18 Apr 2026 14:54:29 -0400 Subject: [PATCH 4/5] Force Python backend in TROP aggregation tests and tighten assertions AI review on PR #317 flagged that my fit()-level tests did not force HAS_RUST_BACKEND=False, so in Rust-enabled environments they could pass without exercising the Python aggregation code they were intended to cover. The earlier <= 1 assertion also would not catch a dropped warning. Changes: - patch.object(sys.modules[...], "HAS_RUST_BACKEND", False) across diff_diff.trop, diff_diff.trop_local, diff_diff.trop_global so the LOOCV and bootstrap paths route through the Python aggregation wrappers. (Uses sys.modules to bypass the name collision between the trop() convenience function and the trop module at diff_diff.trop.) - Tightened assertions: per-treated-observation and bootstrap aggregation are called exactly once per fit() so assert == 1. LOOCV is called multiple times by the coordinate-descent grid refinement in trop.py, so the per-call single-emission contract is verified via message format ("N of M per-observation fits") on every LOOCV aggregate rather than by global occurrence count. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_trop.py | 115 ++++++++++++++++++++++++++++++--------------- 1 file changed, 77 insertions(+), 38 deletions(-) diff --git a/tests/test_trop.py b/tests/test_trop.py index 99991a46..5ea51db6 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -1,6 +1,8 @@ """Tests for Triply Robust Panel (TROP) estimator.""" +import sys import warnings +from unittest.mock import patch import numpy as np import pandas as pd @@ -4047,9 +4049,19 @@ def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data assert not any("did not converge" in str(x.message) for x in w) def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data): - """Fit-level warning aggregation: per-treated-observation, LOOCV, and - bootstrap non-convergence each surface as at most one aggregate warning - per wrapping call, not one per inner fit. Pins the P2 fan-out fix.""" + """Fit-level warning aggregation: when routed through the Python + backend, every aggregation wrapper (per-treated-observation, LOOCV, + bootstrap) emits exactly one aggregate warning per call, not per + inner fit. + + Forces HAS_RUST_BACKEND=False so the new Python aggregation paths are + actually exercised; without this the LOOCV and bootstrap paths would + dispatch to Rust in wheel-built environments and skip the changed code. + + LOOCV count is >= 1 (not == 1) because fit() calls it multiple times + during coordinate-descent refinement of the lambda grid; the contract + this test pins is *per-call* single emission, asserted via message + format rather than global occurrence count.""" trop_est = TROP( method="local", lambda_time_grid=[1.0], @@ -4061,31 +4073,47 @@ def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data): seed=42, ) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - trop_est.fit( - simple_panel_data, - outcome="outcome", - treatment="treated", - unit="unit", - time="period", - ) - - def count_matching(needle: str) -> int: - return sum(1 for x in w if needle in str(x.message)) + trop_mod = sys.modules["diff_diff.trop"] + trop_local_mod = sys.modules["diff_diff.trop_local"] + with patch.object(trop_mod, "HAS_RUST_BACKEND", False), \ + patch.object(trop_local_mod, "HAS_RUST_BACKEND", False): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) - # Per-treated-observation aggregation (called once per .fit()). - assert count_matching("per-treated-observation") <= 1 - # LOOCV aggregation (called once per (lambda_time, lambda_unit, lambda_nn) combo; - # grid has exactly 1 combo). - assert count_matching("local LOOCV") <= 1 - # Bootstrap aggregation (called once per .fit()). - assert count_matching("local bootstrap") <= 1 + def matching(needle: str): + return [str(x.message) for x in w if needle in str(x.message)] + + # Per-treated-observation aggregation (called exactly once per .fit()). + per_obs = matching("per-treated-observation") + assert len(per_obs) == 1, f"expected 1 per-obs aggregate, got {len(per_obs)}" + + # Bootstrap aggregation (called exactly once per .fit()). + boot = matching("local bootstrap") + assert len(boot) == 1, f"expected 1 bootstrap aggregate, got {len(boot)}" + + # LOOCV: at least one aggregate fired (Python path exercised), and each + # fired message is itself an aggregate (has the "N of M" fan-out-reduced + # format), not one warning per inner observation. + loocv = matching("local LOOCV") + assert len(loocv) >= 1, "expected at least one LOOCV aggregate warning" + for msg in loocv: + assert "of" in msg and "per-observation fits" in msg, ( + f"LOOCV warning is not in aggregate format (fan-out not reduced): {msg}" + ) def test_global_fit_emits_single_aggregate_warning(self, simple_panel_data): - """Global-method fit-level warning aggregation: LOOCV and bootstrap - non-convergence each surface as at most one aggregate warning per - wrapping call, mirroring the local test above.""" + """Global-method fit-level warning aggregation: mirrors the local test. + + Forces HAS_RUST_BACKEND=False to exercise the Python aggregation path. + LOOCV count is >= 1 by the same grid-refinement reasoning; each fired + message must be in the aggregate format.""" trop_est = TROP( method="global", lambda_time_grid=[1.0], @@ -4097,18 +4125,29 @@ def test_global_fit_emits_single_aggregate_warning(self, simple_panel_data): seed=42, ) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - trop_est.fit( - simple_panel_data, - outcome="outcome", - treatment="treated", - unit="unit", - time="period", - ) + trop_mod = sys.modules["diff_diff.trop"] + trop_global_mod = sys.modules["diff_diff.trop_global"] + with patch.object(trop_mod, "HAS_RUST_BACKEND", False), \ + patch.object(trop_global_mod, "HAS_RUST_BACKEND", False): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + trop_est.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="period", + ) - def count_matching(needle: str) -> int: - return sum(1 for x in w if needle in str(x.message)) + def matching(needle: str): + return [str(x.message) for x in w if needle in str(x.message)] - assert count_matching("global LOOCV") <= 1 - assert count_matching("global bootstrap") <= 1 + boot = matching("global bootstrap") + assert len(boot) == 1, f"expected 1 bootstrap aggregate, got {len(boot)}" + + loocv = matching("global LOOCV") + assert len(loocv) >= 1, "expected at least one LOOCV aggregate warning" + for msg in loocv: + assert "of" in msg and "per-observation fits" in msg, ( + f"LOOCV warning is not in aggregate format (fan-out not reduced): {msg}" + ) From 41b6a5f38d21303577b917bbe04b999c4f7bea3e Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 18 Apr 2026 15:03:35 -0400 Subject: [PATCH 5/5] Use attempt count, not treated-observation count, in per-obs aggregate warning AI review noted that the per-treated-observation aggregate warning used len(treated_observations) as the denominator, but the loop skips cells with non-finite outcomes before calling _estimate_model(). On panels with missing treated outcomes, the reported non-convergence rate would be understated because attempted-but-failed fits were compared against a total that included never-attempted cells. Track n_fits_attempted separately and use that as the denominator. Report is now "X of N-attempted fits did not converge" rather than "X of N-treated-cells fits did not converge". Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/trop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 3d19a903..76a44b43 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -749,6 +749,7 @@ def fit( # Use pre-computed treated observations treated_observations = self._precomputed["treated_observations"] nonconverg_tracker: list = [] + n_fits_attempted = 0 for t, i in treated_observations: unit_id = idx_to_unit[i] @@ -766,6 +767,7 @@ def fit( ) # Fit model with these weights + n_fits_attempted += 1 alpha_hat, beta_hat, L_hat = self._estimate_model( Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods, _nonconvergence_tracker=nonconverg_tracker, @@ -788,7 +790,7 @@ def fit( warn_if_not_converged( False, f"TROP local per-treated-observation fit: " - f"{len(nonconverg_tracker)} of {len(treated_observations)} " + f"{len(nonconverg_tracker)} of {n_fits_attempted} " f"fits did not converge", self.max_iter, self.tol,