8 changes: 7 additions & 1 deletion diff_diff/imputation.py
@@ -28,7 +28,7 @@
     ImputationDiDResults,
 )
 from diff_diff.linalg import solve_ols
-from diff_diff.utils import safe_inference
+from diff_diff.utils import safe_inference, warn_if_not_converged
 
 # =============================================================================
 # Main Estimator
@@ -909,6 +909,7 @@ def _iterative_fe(
         wsum_t = w_series.groupby(time_vals).transform("sum").values
         wsum_u = w_series.groupby(unit_vals).transform("sum").values
 
+        converged = False
         with np.errstate(invalid="ignore", divide="ignore"):
             for iteration in range(max_iter):
                 resid_after_alpha = y - alpha
@@ -943,7 +944,9 @@
                 alpha = alpha_new
                 beta = beta_new
                 if max_change < tol:
+                    converged = True
                     break
+        warn_if_not_converged(converged, "ImputationDiD iterative FE solver", max_iter, tol)
 
         unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
         time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
@@ -978,6 +981,7 @@ def _iterative_demean(
         wsum_t = w_series.groupby(time_vals).transform("sum").values
         wsum_u = w_series.groupby(unit_vals).transform("sum").values
 
+        converged = False
         with np.errstate(invalid="ignore", divide="ignore"):
             for _ in range(max_iter):
                 if weights is not None:
@@ -1001,8 +1005,10 @@
                 result_new = result_after_time - unit_means
                 if np.max(np.abs(result_new - result)) < tol:
                     result = result_new
+                    converged = True
                     break
                 result = result_new
+        warn_if_not_converged(converged, "ImputationDiD iterative demean", max_iter, tol)
         return result
 
     @staticmethod
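Both hunks in this file follow the same converged-flag pattern: initialize `converged = False` before the loop, set it on early exit, and signal once after the loop instead of silently returning the last iterate. A minimal self-contained sketch of that pattern (illustrative only, not the library's internal code; `demean_two_way` is a hypothetical name):

```python
import warnings

import numpy as np


def demean_two_way(x, unit_ids, time_ids, max_iter=100, tol=1e-8):
    """Alternating projections: strip unit means, then time means, to a fixed point."""
    x = np.asarray(x, dtype=np.float64).copy()
    unit_ids = np.asarray(unit_ids)
    time_ids = np.asarray(time_ids)
    converged = False
    for _ in range(max_iter):
        x_old = x.copy()
        for ids in (unit_ids, time_ids):
            # Group means via bincount; ids are assumed to be 0..k contiguous codes.
            group_means = np.bincount(ids, weights=x) / np.bincount(ids)
            x -= group_means[ids]
        if np.max(np.abs(x - x_old)) < tol:
            converged = True
            break
    if not converged:
        # Explicit signal replaces the silent return of the current iterate.
        warnings.warn(
            f"demean_two_way did not converge in {max_iter} iterations (tol={tol}).",
            UserWarning,
            stacklevel=2,
        )
    return x
```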
8 changes: 7 additions & 1 deletion diff_diff/two_stage.py
@@ -41,7 +41,7 @@
     TwoStageBootstrapResults,  # noqa: F401
     TwoStageDiDResults,
 )  # noqa: F401 (re-export)
-from diff_diff.utils import safe_inference
+from diff_diff.utils import safe_inference, warn_if_not_converged
 
 # =============================================================================
 # Main Estimator
@@ -887,6 +887,7 @@ def _iterative_fe(
         wsum_t = w_series.groupby(time_vals).transform("sum").values
         wsum_u = w_series.groupby(unit_vals).transform("sum").values
 
+        converged = False
         with np.errstate(invalid="ignore", divide="ignore"):
             for iteration in range(max_iter):
                 resid_after_alpha = y - alpha
@@ -920,7 +921,9 @@
                 alpha = alpha_new
                 beta = beta_new
                 if max_change < tol:
+                    converged = True
                     break
+        warn_if_not_converged(converged, "TwoStageDiD iterative FE solver", max_iter, tol)
 
         unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
         time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
@@ -951,6 +954,7 @@ def _iterative_demean(
         wsum_t = w_series.groupby(time_vals).transform("sum").values
         wsum_u = w_series.groupby(unit_vals).transform("sum").values
 
+        converged = False
         with np.errstate(invalid="ignore", divide="ignore"):
             for _ in range(max_iter):
                 if weights is not None:
@@ -974,8 +978,10 @@
                 result_new = result_after_time - unit_means
                 if np.max(np.abs(result_new - result)) < tol:
                     result = result_new
+                    converged = True
                     break
                 result = result_new
+        warn_if_not_converged(converged, "TwoStageDiD iterative demean", max_iter, tol)
         return result
 
     def _fit_untreated_model(
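Downstream, the new signal can be promoted to a hard error in strict pipelines with nothing but the stdlib `warnings` filters. A runnable sketch (the `flaky_solver` stub is a stand-in for any diff_diff code path that reaches these iterative solvers; it is not part of the library):

```python
import warnings


def flaky_solver(max_iter=1, tol=1e-15):
    """Stub that emits the same shape of signal as the solvers in this PR."""
    warnings.warn(
        f"flaky_solver did not converge in {max_iter} iterations (tol={tol}). "
        "Results may be inaccurate.",
        UserWarning,
        stacklevel=2,
    )
    return 0.0


with warnings.catch_warnings():
    # Escalate non-convergence from a warning to an exception.
    warnings.filterwarnings("error", message=".*did not converge.*", category=UserWarning)
    try:
        flaky_solver()
    except UserWarning as exc:
        print(f"escalated to error: {exc}")
```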
57 changes: 53 additions & 4 deletions diff_diff/utils.py
@@ -65,6 +65,29 @@ def validate_binary(arr: np.ndarray, name: str) -> None:
         raise ValueError(f"{name} must be binary (0 or 1). " f"Found values: {unique_values}")
 
 
+def warn_if_not_converged(
+    converged: bool,
+    method_name: str,
+    max_iter: int,
+    tol: Optional[float] = None,
+    stacklevel: int = 3,
+) -> None:
+    """Emit a UserWarning when an iterative solver exhausts max_iter without converging.
+
+    Shared helper for axis-B silent-failure fixes (iterative loops that otherwise
+    return the current iterate without signaling non-convergence).
+    """
+    if converged:
+        return
+    tol_suffix = f" (tol={tol})" if tol is not None else ""
+    warnings.warn(
+        f"{method_name} did not converge in {max_iter} iterations{tol_suffix}. "
+        "Results may be inaccurate.",
+        UserWarning,
+        stacklevel=stacklevel,
+    )
+
+
 def compute_robust_se(
     X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None
 ) -> np.ndarray:
@@ -1791,6 +1814,8 @@ def within_transform(
     inplace: bool = False,
     suffix: str = "_demeaned",
     weights: Optional[np.ndarray] = None,
+    max_iter: int = 100,
+    tol: float = 1e-8,
 ) -> pd.DataFrame:
     """
     Apply two-way within transformation to remove unit and time fixed effects.
@@ -1818,6 +1843,14 @@
         Suffix for new column names when inplace=False.
     weights : np.ndarray, optional
         Observation weights for weighted group means.
+    max_iter : int, default 100
+        Maximum number of alternating-projection iterations. Used only when
+        ``weights`` is not ``None``; the unweighted path is a single pass and
+        ignores this argument. Emits a ``UserWarning`` per call when any
+        variable fails to converge within this budget.
+    tol : float, default 1e-8
+        Convergence tolerance on the maximum absolute change in the iterate
+        between successive sweeps. Used only when ``weights`` is not ``None``.
 
     Returns
     -------
@@ -1853,29 +1886,45 @@ def _weighted_group_demean(x, groups, w, w_sum):
             wx_sum = pd.Series(w * x).groupby(groups).transform("sum").values
             return x - wx_sum / w_sum
 
+        non_converged_vars: List[str] = []
         if inplace:
             for var in variables:
                 x = data[var].values.astype(np.float64)
-                for _iter in range(100):  # max iterations
+                converged = False
+                for _iter in range(max_iter):
                     x_old = x.copy()
                     x = _weighted_group_demean(x, unit_groups, w, unit_w_sum)
                     x = _weighted_group_demean(x, time_groups, w, time_w_sum)
-                    if np.max(np.abs(x - x_old)) < 1e-8:
+                    if np.max(np.abs(x - x_old)) < tol:
+                        converged = True
                         break
+                if not converged:
+                    non_converged_vars.append(var)
                 data[var] = x
         else:
             demeaned_data = {}
             for var in variables:
                 x = data[var].values.astype(np.float64)
-                for _iter in range(100):
+                converged = False
+                for _iter in range(max_iter):
                     x_old = x.copy()
                     x = _weighted_group_demean(x, unit_groups, w, unit_w_sum)
                     x = _weighted_group_demean(x, time_groups, w, time_w_sum)
-                    if np.max(np.abs(x - x_old)) < 1e-8:
+                    if np.max(np.abs(x - x_old)) < tol:
+                        converged = True
                         break
+                if not converged:
+                    non_converged_vars.append(var)
                 demeaned_data[f"{var}{suffix}"] = x
             demeaned_df = pd.DataFrame(demeaned_data, index=data.index)
             data = pd.concat([data, demeaned_df], axis=1)
+        if non_converged_vars:
+            warn_if_not_converged(
+                False,
+                f"within_transform weighted demean (variables: {non_converged_vars})",
+                max_iter,
+                tol,
+            )
     else:
         # Cache groupby objects for efficiency
         unit_grouper = data.groupby(unit, sort=False)
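The helper's contract is easy to exercise in isolation. A minimal sketch, assuming `diff_diff` is importable; `toy_fixed_point` is a contrived iteration written for this illustration, not part of the library:

```python
import warnings

from diff_diff.utils import warn_if_not_converged  # helper added in this PR


def toy_fixed_point(x0: float, max_iter: int = 10, tol: float = 1e-12) -> float:
    """Deliberately slow fixed-point iteration, used only to trigger the signal."""
    x, converged = x0, False
    for _ in range(max_iter):
        x_new = 0.999 * x  # converges to 0.0, but far too slowly for this budget
        if abs(x_new - x) < tol:
            converged = True
            break
        x = x_new
    warn_if_not_converged(converged, "toy_fixed_point", max_iter, tol)
    return x


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    toy_fixed_point(1.0)
print(caught[0].message)
# toy_fixed_point did not converge in 10 iterations (tol=1e-12). Results may be inaccurate.
```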
12 changes: 12 additions & 0 deletions docs/methodology/REGISTRY.md
@@ -1083,6 +1083,7 @@ where `W_it(h) = 1[K_it = h]` are lead indicators, estimated on `Omega_0` only.
 - **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and design-based variance via `compute_survey_if_variance()`. PSU clustering, stratification, and FPC are fully supported in the Theorem 3 variance path. When `resolved_survey` is present, the observation-level influence function (`v_it * epsilon_tilde_it`) is passed to `compute_survey_if_variance()` which applies the stratified PSU-level sandwich with FPC correction. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
 - **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns.
 - **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean.
+- **Note:** Both the iterative FE solver (`_iterative_fe`, Step 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization and the pre-trend test) emit `UserWarning` when `max_iter` is exhausted before reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
 
 **Reference implementation(s):**
 - Stata: `did_imputation` (Borusyak, Jaravel, Spiess; available from SSC)
@@ -1160,6 +1161,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus
 - **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0.
 - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0.
 - **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU clustering, stratification, and FPC are fully supported in the meat matrix via `_compute_stratified_meat_from_psu_scores()`. When strata or FPC are present, the meat computation replaces `S' S` with the stratified formula `sum_h (1 - f_h) * (n_h/(n_h-1)) * centered_h' centered_h`. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
+- **Note:** Both the iterative FE solver (`_iterative_fe`, Stage 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization) emit `UserWarning` when `max_iter` is exhausted before reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
 
 **Reference implementation(s):**
 - R: `did2s::did2s()` (Kyle Butts & John Gardner)
@@ -1299,6 +1301,7 @@ The saturated ETWFE regression includes:
 
 The interaction coefficient `δ_{g,t}` identifies `ATT(g, t)` under parallel trends.
 - **Note:** OLS path uses iterative alternating-projection within-transformation (uniform weights) for exact FE absorption on both balanced and unbalanced panels. One-pass demeaning (`y - ȳ_i - ȳ_t + ȳ`) is only exact for balanced panels.
+- **Note:** The weighted within-transformation (`utils.within_transform` with `weights`) is invoked on every WooldridgeDiD fit (survey weights when provided, `np.ones` otherwise) and emits a `UserWarning` on non-convergence per the shared convention documented under *Absorbed Fixed Effects with Survey Weights*.
 
 *Nonlinear extensions (Wooldridge 2023):*
 
@@ -2519,6 +2522,15 @@
 unequal selection probabilities).
 are rejected (single-pass sequential demeaning is not the correct weighted
 FWL projection for N > 1 dimensions; iterative alternating projections are
 needed but not yet implemented).
+- **Note:** The shared weighted within-transformation path
+  (`diff_diff.utils.within_transform`, hit whenever `weights is not None`) emits
+  a `UserWarning` per call when any transformed variable exits the
+  alternating-projection loop without reaching `tol` within `max_iter`.
+  Defaults: `max_iter=100`, `tol=1e-8`. This signal applies uniformly across
+  TwoWayFixedEffects, SunAbraham, BaconDecomposition, and WooldridgeDiD whenever
+  they route through this helper (survey-weighted or otherwise). Silent return
+  of the current iterate was classified as a silent failure under the Phase 2
+  audit and replaced with this explicit signal.
 
 ### Survey Degrees of Freedom
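The registry note above documents the knobs; the call below mirrors the public signature as exercised in the new `tests/test_methodology_twfe.py` tests (the toy panel values are made up for illustration):

```python
import numpy as np
import pandas as pd

from diff_diff.utils import within_transform

# Tiny unbalanced weighted panel: a non-None `weights` forces the iterative
# alternating-projection path with the documented defaults.
data = pd.DataFrame(
    {
        "unit": [0, 0, 1, 1, 2],
        "period": [0, 1, 0, 1, 1],
        "outcome": [1.0, 2.0, 2.5, 4.0, 3.0],
    }
)
weights = np.array([1.0, 2.0, 1.0, 1.0, 3.0])

out = within_transform(
    data, ["outcome"], "unit", "period",
    weights=weights, max_iter=100, tol=1e-8,
)
print(out["outcome_demeaned"])  # a UserWarning here would flag an exhausted budget
```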
54 changes: 54 additions & 0 deletions tests/test_imputation.py
@@ -2087,3 +2087,57 @@ def test_balanced_cohort_mask_requires_negative_horizons(self):
             df_treated, "first_treat", all_horizons, 1, cohort_rel_times
         )
         assert all(mask1)
+
+    def test_iterative_fe_warns_on_nonconvergence(self):
+        """Silent-failure audit axis B: _iterative_fe must warn when max_iter exhausts."""
+        rng = np.random.default_rng(42)
+        n_units, n_periods = 8, 5
+        units = np.repeat(np.arange(n_units), n_periods)
+        times = np.tile(np.arange(n_periods), n_units)
+        y = rng.standard_normal(n_units * n_periods)
+        idx = pd.RangeIndex(len(y))
+        est = ImputationDiD()
+
+        with pytest.warns(UserWarning, match="did not converge"):
+            est._iterative_fe(y, units, times, idx, max_iter=1, tol=1e-15)
+
+    def test_iterative_fe_no_warning_on_convergence(self):
+        """Silent-failure audit axis B: no warning on well-behaved convergent input."""
+        rng = np.random.default_rng(42)
+        n_units, n_periods = 8, 5
+        units = np.repeat(np.arange(n_units), n_periods)
+        times = np.tile(np.arange(n_periods), n_units)
+        y = rng.standard_normal(n_units * n_periods)
+        idx = pd.RangeIndex(len(y))
+        est = ImputationDiD()
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            est._iterative_fe(y, units, times, idx)
+        assert not any("did not converge" in str(x.message) for x in w)
+
+    def test_iterative_demean_warns_on_nonconvergence(self):
+        """Silent-failure audit axis B: _iterative_demean must warn when max_iter exhausts."""
+        rng = np.random.default_rng(42)
+        n_units, n_periods = 8, 5
+        units = np.repeat(np.arange(n_units), n_periods)
+        times = np.tile(np.arange(n_periods), n_units)
+        vals = rng.standard_normal(n_units * n_periods)
+        idx = pd.RangeIndex(len(vals))
+
+        with pytest.warns(UserWarning, match="did not converge"):
+            ImputationDiD._iterative_demean(vals, units, times, idx, max_iter=1, tol=1e-15)
+
+    def test_iterative_demean_no_warning_on_convergence(self):
+        """Silent-failure audit axis B: no warning on well-behaved convergent input."""
+        rng = np.random.default_rng(42)
+        n_units, n_periods = 8, 5
+        units = np.repeat(np.arange(n_units), n_periods)
+        times = np.tile(np.arange(n_periods), n_units)
+        vals = rng.standard_normal(n_units * n_periods)
+        idx = pd.RangeIndex(len(vals))
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            ImputationDiD._iterative_demean(vals, units, times, idx)
+        assert not any("did not converge" in str(x.message) for x in w)
21 changes: 21 additions & 0 deletions tests/test_methodology_twfe.py
@@ -235,6 +235,27 @@ def test_demeaned_outcome_sums_to_zero(self):
         np.testing.assert_allclose(unit_sums.values, 0, atol=1e-10)
         np.testing.assert_allclose(time_sums.values, 0, atol=1e-10)
 
+    def test_within_transform_weighted_warns_on_nonconvergence(self):
+        """Silent-failure audit axis B: within_transform weighted path must warn."""
+        data = generate_twfe_panel(n_units=20, n_periods=4, seed=99)
+        weights = np.ones(len(data))
+
+        with pytest.warns(UserWarning, match="did not converge"):
+            within_transform(
+                data, ["outcome"], "unit", "period",
+                weights=weights, max_iter=1, tol=1e-15,
+            )
+
+    def test_within_transform_weighted_no_warning_on_convergence(self):
+        """Silent-failure audit axis B: no warning on well-behaved convergent input."""
+        data = generate_twfe_panel(n_units=20, n_periods=4, seed=99)
+        weights = np.ones(len(data))
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            within_transform(data, ["outcome"], "unit", "period", weights=weights)
+        assert not any("did not converge" in str(x.message) for x in w)
+
 
 # =============================================================================
 # Phase 2: R Comparison