diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index 5fe3eb93..3798927e 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -28,7 +28,7 @@ ImputationDiDResults, ) from diff_diff.linalg import solve_ols -from diff_diff.utils import safe_inference +from diff_diff.utils import safe_inference, warn_if_not_converged # ============================================================================= # Main Estimator @@ -909,6 +909,7 @@ def _iterative_fe( wsum_t = w_series.groupby(time_vals).transform("sum").values wsum_u = w_series.groupby(unit_vals).transform("sum").values + converged = False with np.errstate(invalid="ignore", divide="ignore"): for iteration in range(max_iter): resid_after_alpha = y - alpha @@ -943,7 +944,9 @@ def _iterative_fe( alpha = alpha_new beta = beta_new if max_change < tol: + converged = True break + warn_if_not_converged(converged, "ImputationDiD iterative FE solver", max_iter, tol) unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict() time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict() @@ -978,6 +981,7 @@ def _iterative_demean( wsum_t = w_series.groupby(time_vals).transform("sum").values wsum_u = w_series.groupby(unit_vals).transform("sum").values + converged = False with np.errstate(invalid="ignore", divide="ignore"): for _ in range(max_iter): if weights is not None: @@ -1001,8 +1005,10 @@ def _iterative_demean( result_new = result_after_time - unit_means if np.max(np.abs(result_new - result)) < tol: result = result_new + converged = True break result = result_new + warn_if_not_converged(converged, "ImputationDiD iterative demean", max_iter, tol) return result @staticmethod diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index ff411383..560ded0a 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -41,7 +41,7 @@ TwoStageBootstrapResults, # noqa: F401 TwoStageDiDResults, ) # noqa: F401 (re-export) -from diff_diff.utils import safe_inference +from diff_diff.utils import safe_inference, warn_if_not_converged # ============================================================================= # Main Estimator @@ -887,6 +887,7 @@ def _iterative_fe( wsum_t = w_series.groupby(time_vals).transform("sum").values wsum_u = w_series.groupby(unit_vals).transform("sum").values + converged = False with np.errstate(invalid="ignore", divide="ignore"): for iteration in range(max_iter): resid_after_alpha = y - alpha @@ -920,7 +921,9 @@ def _iterative_fe( alpha = alpha_new beta = beta_new if max_change < tol: + converged = True break + warn_if_not_converged(converged, "TwoStageDiD iterative FE solver", max_iter, tol) unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict() time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict() @@ -951,6 +954,7 @@ def _iterative_demean( wsum_t = w_series.groupby(time_vals).transform("sum").values wsum_u = w_series.groupby(unit_vals).transform("sum").values + converged = False with np.errstate(invalid="ignore", divide="ignore"): for _ in range(max_iter): if weights is not None: @@ -974,8 +978,10 @@ def _iterative_demean( result_new = result_after_time - unit_means if np.max(np.abs(result_new - result)) < tol: result = result_new + converged = True break result = result_new + warn_if_not_converged(converged, "TwoStageDiD iterative demean", max_iter, tol) return result def _fit_untreated_model( diff --git a/diff_diff/utils.py b/diff_diff/utils.py index c252c418..4a7ba5eb 100644 --- a/diff_diff/utils.py +++ b/diff_diff/utils.py @@ -65,6 +65,29 @@ def validate_binary(arr: np.ndarray, name: str) -> None: raise ValueError(f"{name} must be binary (0 or 1). " f"Found values: {unique_values}") +def warn_if_not_converged( + converged: bool, + method_name: str, + max_iter: int, + tol: Optional[float] = None, + stacklevel: int = 3, +) -> None: + """Emit a UserWarning when an iterative solver exhausts max_iter without converging. + + Shared helper for axis-B silent-failure fixes (iterative loops that otherwise + return the current iterate without signaling non-convergence). + """ + if converged: + return + tol_suffix = f" (tol={tol})" if tol is not None else "" + warnings.warn( + f"{method_name} did not converge in {max_iter} iterations{tol_suffix}. " + "Results may be inaccurate.", + UserWarning, + stacklevel=stacklevel, + ) + + def compute_robust_se( X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None ) -> np.ndarray: @@ -1791,6 +1814,8 @@ def within_transform( inplace: bool = False, suffix: str = "_demeaned", weights: Optional[np.ndarray] = None, + max_iter: int = 100, + tol: float = 1e-8, ) -> pd.DataFrame: """ Apply two-way within transformation to remove unit and time fixed effects. @@ -1818,6 +1843,14 @@ def within_transform( Suffix for new column names when inplace=False. weights : np.ndarray, optional Observation weights for weighted group means. + max_iter : int, default 100 + Maximum number of alternating-projection iterations. Used only when + ``weights`` is not ``None``; the unweighted path is a single pass and + ignores this argument. Emits a ``UserWarning`` per call when any + variable fails to converge within this budget. + tol : float, default 1e-8 + Convergence tolerance on the max absolute change across the iterate. + Used only when ``weights`` is not ``None``. Returns ------- @@ -1853,29 +1886,45 @@ def _weighted_group_demean(x, groups, w, w_sum): wx_sum = pd.Series(w * x).groupby(groups).transform("sum").values return x - wx_sum / w_sum + non_converged_vars: List[str] = [] if inplace: for var in variables: x = data[var].values.astype(np.float64) - for _iter in range(100): # max iterations + converged = False + for _iter in range(max_iter): x_old = x.copy() x = _weighted_group_demean(x, unit_groups, w, unit_w_sum) x = _weighted_group_demean(x, time_groups, w, time_w_sum) - if np.max(np.abs(x - x_old)) < 1e-8: + if np.max(np.abs(x - x_old)) < tol: + converged = True break + if not converged: + non_converged_vars.append(var) data[var] = x else: demeaned_data = {} for var in variables: x = data[var].values.astype(np.float64) - for _iter in range(100): + converged = False + for _iter in range(max_iter): x_old = x.copy() x = _weighted_group_demean(x, unit_groups, w, unit_w_sum) x = _weighted_group_demean(x, time_groups, w, time_w_sum) - if np.max(np.abs(x - x_old)) < 1e-8: + if np.max(np.abs(x - x_old)) < tol: + converged = True break + if not converged: + non_converged_vars.append(var) demeaned_data[f"{var}{suffix}"] = x demeaned_df = pd.DataFrame(demeaned_data, index=data.index) data = pd.concat([data, demeaned_df], axis=1) + if non_converged_vars: + warn_if_not_converged( + False, + f"within_transform weighted demean (variables: {non_converged_vars})", + max_iter, + tol, + ) else: # Cache groupby objects for efficiency unit_grouper = data.groupby(unit, sort=False) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 0bf3c5de..e5cb686b 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1083,6 +1083,7 @@ where `W_it(h) = 1[K_it = h]` are lead indicators, estimated on `Omega_0` only. - **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and design-based variance via `compute_survey_if_variance()`. PSU clustering, stratification, and FPC are fully supported in the Theorem 3 variance path. When `resolved_survey` is present, the observation-level influence function (`v_it * epsilon_tilde_it`) is passed to `compute_survey_if_variance()` which applies the stratified PSU-level sandwich with FPC correction. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights. - **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns. - **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean. +- **Note:** Both the iterative FE solver (`_iterative_fe`, Step 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization and the pre-trend test) emit `UserWarning` when `max_iter` exhausts without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`. **Reference implementation(s):** - Stata: `did_imputation` (Borusyak, Jaravel, Spiess; available from SSC) @@ -1160,6 +1161,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0. - **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU clustering, stratification, and FPC are fully supported in the meat matrix via `_compute_stratified_meat_from_psu_scores()`. When strata or FPC are present, the meat computation replaces `S' S` with the stratified formula `sum_h (1 - f_h) * (n_h/(n_h-1)) * centered_h' centered_h`. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights. +- **Note:** Both the iterative FE solver (`_iterative_fe`, Stage 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization) emit `UserWarning` when `max_iter` exhausts without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`. **Reference implementation(s):** - R: `did2s::did2s()` (Kyle Butts & John Gardner) @@ -1299,6 +1301,7 @@ The saturated ETWFE regression includes: The interaction coefficient `δ_{g,t}` identifies `ATT(g, t)` under parallel trends. - **Note:** OLS path uses iterative alternating-projection within-transformation (uniform weights) for exact FE absorption on both balanced and unbalanced panels. One-pass demeaning (`y - ȳ_i - ȳ_t + ȳ`) is only exact for balanced panels. +- **Note:** The weighted within-transformation (`utils.within_transform` with `weights`) is invoked on every WooldridgeDiD fit (survey weights when provided, `np.ones` otherwise) and emits a `UserWarning` on non-convergence per the shared convention documented under *Absorbed Fixed Effects with Survey Weights*. *Nonlinear extensions (Wooldridge 2023):* @@ -2519,6 +2522,15 @@ unequal selection probabilities). are rejected (single-pass sequential demeaning is not the correct weighted FWL projection for N > 1 dimensions; iterative alternating projections are needed but not yet implemented). +- **Note:** The shared weighted within-transformation path + (`diff_diff.utils.within_transform`, hit whenever `weights is not None`) emits + a `UserWarning` per call when any transformed variable exits the + alternating-projection loop without reaching `tol` within `max_iter`. + Defaults: `max_iter=100`, `tol=1e-8`. This signal applies uniformly across + TwoWayFixedEffects, SunAbraham, BaconDecomposition, and WooldridgeDiD whenever + they route through this helper (survey-weighted or otherwise). Silent return + of the current iterate was classified as a silent failure under the Phase 2 + audit and replaced with this explicit signal. ### Survey Degrees of Freedom diff --git a/tests/test_imputation.py b/tests/test_imputation.py index 5e2b46b7..53093d8d 100644 --- a/tests/test_imputation.py +++ b/tests/test_imputation.py @@ -2087,3 +2087,57 @@ def test_balanced_cohort_mask_requires_negative_horizons(self): df_treated, "first_treat", all_horizons, 1, cohort_rel_times ) assert all(mask1) + + def test_iterative_fe_warns_on_nonconvergence(self): + """Silent-failure audit axis B: _iterative_fe must warn when max_iter exhausts.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + y = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(y)) + est = ImputationDiD() + + with pytest.warns(UserWarning, match="did not converge"): + est._iterative_fe(y, units, times, idx, max_iter=1, tol=1e-15) + + def test_iterative_fe_no_warning_on_convergence(self): + """Silent-failure audit axis B: no warning on well-behaved convergent input.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + y = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(y)) + est = ImputationDiD() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + est._iterative_fe(y, units, times, idx) + assert not any("did not converge" in str(x.message) for x in w) + + def test_iterative_demean_warns_on_nonconvergence(self): + """Silent-failure audit axis B: _iterative_demean must warn when max_iter exhausts.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + vals = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(vals)) + + with pytest.warns(UserWarning, match="did not converge"): + ImputationDiD._iterative_demean(vals, units, times, idx, max_iter=1, tol=1e-15) + + def test_iterative_demean_no_warning_on_convergence(self): + """Silent-failure audit axis B: no warning on well-behaved convergent input.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + vals = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(vals)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ImputationDiD._iterative_demean(vals, units, times, idx) + assert not any("did not converge" in str(x.message) for x in w) diff --git a/tests/test_methodology_twfe.py b/tests/test_methodology_twfe.py index 5734db94..631d4658 100644 --- a/tests/test_methodology_twfe.py +++ b/tests/test_methodology_twfe.py @@ -235,6 +235,27 @@ def test_demeaned_outcome_sums_to_zero(self): np.testing.assert_allclose(unit_sums.values, 0, atol=1e-10) np.testing.assert_allclose(time_sums.values, 0, atol=1e-10) + def test_within_transform_weighted_warns_on_nonconvergence(self): + """Silent-failure audit axis B: within_transform weighted path must warn.""" + data = generate_twfe_panel(n_units=20, n_periods=4, seed=99) + weights = np.ones(len(data)) + + with pytest.warns(UserWarning, match="did not converge"): + within_transform( + data, ["outcome"], "unit", "period", + weights=weights, max_iter=1, tol=1e-15, + ) + + def test_within_transform_weighted_no_warning_on_convergence(self): + """Silent-failure audit axis B: no warning on well-behaved convergent input.""" + data = generate_twfe_panel(n_units=20, n_periods=4, seed=99) + weights = np.ones(len(data)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + within_transform(data, ["outcome"], "unit", "period", weights=weights) + assert not any("did not converge" in str(x.message) for x in w) + # ============================================================================= # Phase 2: R Comparison diff --git a/tests/test_two_stage.py b/tests/test_two_stage.py index bae7ff32..ebf093be 100644 --- a/tests/test_two_stage.py +++ b/tests/test_two_stage.py @@ -1324,3 +1324,57 @@ def test_item2_nan_ytilde_group(self): first_treat="first_treat", aggregate="group", ) + + def test_iterative_fe_warns_on_nonconvergence(self): + """Silent-failure audit axis B: _iterative_fe must warn when max_iter exhausts.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + y = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(y)) + est = TwoStageDiD() + + with pytest.warns(UserWarning, match="did not converge"): + est._iterative_fe(y, units, times, idx, max_iter=1, tol=1e-15) + + def test_iterative_fe_no_warning_on_convergence(self): + """Silent-failure audit axis B: no warning on well-behaved convergent input.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + y = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(y)) + est = TwoStageDiD() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + est._iterative_fe(y, units, times, idx) + assert not any("did not converge" in str(x.message) for x in w) + + def test_iterative_demean_warns_on_nonconvergence(self): + """Silent-failure audit axis B: _iterative_demean must warn when max_iter exhausts.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + vals = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(vals)) + + with pytest.warns(UserWarning, match="did not converge"): + TwoStageDiD._iterative_demean(vals, units, times, idx, max_iter=1, tol=1e-15) + + def test_iterative_demean_no_warning_on_convergence(self): + """Silent-failure audit axis B: no warning on well-behaved convergent input.""" + rng = np.random.default_rng(42) + n_units, n_periods = 8, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + vals = rng.standard_normal(n_units * n_periods) + idx = pd.RangeIndex(len(vals)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + TwoStageDiD._iterative_demean(vals, units, times, idx) + assert not any("did not converge" in str(x.message) for x in w)