From 2e2bb7d49fe33454c7e88b2b3c2107917093dbb6 Mon Sep 17 00:00:00 2001 From: igerber Date: Thu, 23 Apr 2026 19:48:38 -0400 Subject: [PATCH 1/8] HAD Phase 3 follow-up: joint Stute pretest + event-study workflow dispatch Adds `stute_joint_pretest` (residuals-in core) plus `joint_pretrends_test` and `joint_homogeneity_test` data-in wrappers for the paper's step-2 (mean-independence pre-trends) and step-3 (linearity joint extension) nulls. Extends `did_had_pretest_workflow` with `aggregate="event_study"` multi-period dispatch that closes the "paper step 2 deferred" gap previously flagged on two-period reports. Implementation highlights: - Sum-of-CvMs aggregation (Delgado 1993; Escanciano 2006) with shared Mammen wild bootstrap multiplier across horizons per unit to preserve vector-valued empirical-process unit-level dependence (Delgado-Manteiga 2001; Hlavka-Huskova 2020). - Per-horizon scale- and translation-invariant exact-linear short-circuit (a single degenerate horizon does not collapse the joint test). - Reciprocal front-door guards on both wrappers: non-empty horizon list, base_period ordering, D=0 invariant (pre-trends) and D>0 existence (post-homogeneity). - Backward-compatible HADPretestReport extension: new fields pretrends_joint, homogeneity_joint, aggregate with defaults; stute and yatchew become Optional. summary, to_dict, to_dataframe, and __repr__ branch on aggregate and preserve Phase 3 schemas bit-exactly on the aggregate="overall" path. - Eq (18) linear-trend detrending (paper Section 5.2 Pierce-Schott p=0.51) deferred to Phase 4 replication harness where the published value serves as parity anchor; TODO row migrated accordingly. 46 new tests (115 total in tests/test_had_pretests.py) covering: K=1 parity with stute_test, shared-eta white-box, per-horizon short- circuit independence, full reciprocal-validator matrix, event-study verdict priority, serialization round-trip across aggregates. 
Includes regression tests asserting the "paper step 2 deferred" string is absent from any event-study verdict. Closes TODO.md Phase 3 rows for joint Eq 18 and multi-period workflow dispatch. See REGISTRY.md HeterogeneousAdoptionDiD "Joint Stute tests" for algorithm, invariants, and the no-joint-Yatchew acknowledgment (the paper does not derive one; multi-period Yatchew remains available per-horizon via yatchew_hr_test). Co-Authored-By: Claude Opus 4.7 (1M context) --- CHANGELOG.md | 3 + TODO.md | 3 +- diff_diff/__init__.py | 10 + diff_diff/had_pretests.py | 1401 +++++++++++++++-- docs/methodology/REGISTRY.md | 27 +- .../papers/dechaisemartin-2026-review.md | 18 +- tests/test_had_pretests.py | 1225 +++++++++++++- 7 files changed, 2542 insertions(+), 145 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58ab0324..4a9aebed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- **`stute_joint_pretest`, `joint_pretrends_test`, `joint_homogeneity_test` + `StuteJointResult`** (HeterogeneousAdoptionDiD Phase 3 follow-up). Joint Cramér-von Mises pretests across K horizons with shared-η Mammen wild bootstrap (preserves vector-valued empirical-process unit-level dependence per Delgado-Manteiga 2001 / Hlávka-Hušková 2020). The core `stute_joint_pretest` is residuals-in; two thin data-in wrappers construct per-horizon residuals for the two nulls the paper spells out: mean-independence (step 2 pre-trends, `OLS(Y_t − Y_base ~ 1)` per pre-period) and linearity (step 3 joint, `OLS(Y_t − Y_base ~ 1 + D)` per post-period). Sum-of-CvMs aggregation (`S_joint = Σ_k S_k`); per-horizon scale-invariant exact-linear short-circuit. Closes the paper Section 4.2 step-2 gap that Phase 3 `did_had_pretest_workflow` previously flagged with an "Assumption 7 pre-trends test NOT run" caveat. 
See `docs/methodology/REGISTRY.md` §HeterogeneousAdoptionDiD "Joint Stute tests" for algorithm, invariants, and scope exclusion of Eq 18 linear-trend detrending (deferred to Phase 4 Pierce-Schott replication). +- **`did_had_pretest_workflow(aggregate="event_study")`**: multi-period dispatch on balanced ≥3-period panels. Runs QUG at `F` + joint pre-trends Stute across earlier pre-periods + joint homogeneity-linearity Stute across post-periods. Reuses the Phase 2b event-study panel validator (last-cohort auto-filter under staggered timing with `UserWarning`; `ValueError` when `first_treat_col=None` and the panel is staggered). `HADPretestReport` extended with `pretrends_joint`, `homogeneity_joint`, and `aggregate` fields; `summary`, `to_dict`, `to_dataframe`, `__repr__` branch on `aggregate` and preserve Phase 3 schemas bit-exactly on the `aggregate="overall"` path. - **`target_parameter` block in BR/DR schemas (experimental; schema version bumped to 2.0)** — `BUSINESS_REPORT_SCHEMA_VERSION` and `DIAGNOSTIC_REPORT_SCHEMA_VERSION` bumped from `"1.0"` to `"2.0"` because the new `"no_scalar_by_design"` value on the `headline.status` / `headline_metric.status` enum (dCDH `trends_linear=True, L_max>=2` configuration) is a breaking change per the REPORTING.md stability policy. BusinessReport and DiagnosticReport now emit a top-level `target_parameter` block naming what the headline scalar actually represents for each of the 16 result classes. Closes BR/DR foundation gap #6 (target-parameter clarity). Fields: `name`, `definition`, `aggregation` (machine-readable dispatch tag), `headline_attribute` (raw result attribute), `reference` (citation pointer). BR's summary emits the short `name` right after the headline; DR's overall-interpretation paragraph does the same; both full reports carry a "## Target Parameter" section with the full definition. 
Per-estimator dispatch is sourced from REGISTRY.md and lives in the new `diff_diff/_reporting_helpers.py::describe_target_parameter`. A few branches read fit-time config (`EfficientDiDResults.pt_assumption`, `StackedDiDResults.clean_control`, `ChaisemartinDHaultfoeuilleResults.L_max` / `covariate_residuals` / `linear_trends_effects`); others emit a fixed tag (the fit-time `aggregate` kwarg on CS / Imputation / TwoStage / Wooldridge does not change the `overall_att` scalar — disambiguating horizon / group tables is tracked under gap #9). See `docs/methodology/REPORTING.md` "Target parameter" section. - SyntheticDiD coverage Monte Carlo calibration table added to `docs/methodology/REGISTRY.md` §SyntheticDiD — rejection rates at α ∈ {0.01, 0.05, 0.10} across `placebo` / `bootstrap` / `jackknife` on 3 representative DGPs (balanced / exchangeable, unbalanced, and Arkhangelsky et al. (2021) AER §6.3 non-exchangeable). Artifact at `benchmarks/data/sdid_coverage.json` (500 seeds × B=200), regenerable via `benchmarks/python/coverage_sdid.py`. @@ -18,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - SyntheticDiD bootstrap now retries degenerate resamples (all-control or all-treated, or non-finite `τ_b`) until exactly `n_bootstrap` valid replicates are accumulated, matching R's `synthdid::bootstrap_sample` and Arkhangelsky et al. (2021) Algorithm 2. Previously the Python path counted attempts (with degenerate draws silently dropped), producing fewer valid replicates than requested. A bounded-attempt guard (`20 × n_bootstrap`) prevents pathological-input hangs. ### Changed +- **`did_had_pretest_workflow(aggregate="event_study")` verdict no longer emits the "paper step 2 deferred to Phase 3 follow-up" caveat** — the joint pre-trends Stute test closes that gap. The two-period `aggregate="overall"` path retains the existing caveat since the joint variant does not apply to single-pre-period panels. 
Downstream code that greps verdict strings for the Phase 3 caveat will see it suppressed on the event-study path. - **SyntheticDiD bootstrap no longer supports survey designs** (capability regression). The removed fixed-weight bootstrap path was the only SDID variance method that supported strata/PSU/FPC (via Rao-Wu rescaled bootstrap); the new paper-faithful refit bootstrap rejects all survey designs (including pweight-only) with `NotImplementedError`. Pweight-only users can switch to `variance_method="placebo"` or `"jackknife"`. Strata/PSU/FPC users have no SDID variance option on this release. Composing Rao-Wu rescaled weights with Frank-Wolfe re-estimation requires a separate derivation (weighted FW solver); sketch and reusable scaffolding pointers are in `docs/methodology/REGISTRY.md` §SyntheticDiD and `TODO.md`. ## [3.2.0] - 2026-04-19 diff --git a/TODO.md b/TODO.md index da80f6ea..c1e25203 100644 --- a/TODO.md +++ b/TODO.md @@ -95,11 +95,10 @@ Deferred items from PR reviews that were not addressed before merge. | `HeterogeneousAdoptionDiD`: `weights=` support. Deferred jointly with survey integration. nprobust's `lprobust` has no weight argument so the nonparametric continuous path needs a derivation; the 2SLS mass-point path needs weighted-sandwich parity. | `diff_diff/had.py` | Phase 2a | Medium | | `HeterogeneousAdoptionDiD` mass-point: `vcov_type in {"hc2", "hc2_bm"}` raises `NotImplementedError` pending a 2SLS-specific leverage derivation. The OLS leverage `x_i' (X'X)^{-1} x_i` is wrong for 2SLS; the correct finite-sample correction uses `x_i' (Z'X)^{-1} (...) (X'Z)^{-1} x_i`. Needs derivation plus an R / Stata (`ivreg2 small robust`) parity anchor. | `diff_diff/had.py::_fit_mass_point_2sls` | Phase 2a | Medium | | `HeterogeneousAdoptionDiD` continuous paths: thread `cluster=` through `bias_corrected_local_linear` (Phase 1c's wrapper already supports cluster; Phase 2a ignores it with a `UserWarning` on the continuous path to keep scope tight). 
| `diff_diff/had.py`, `diff_diff/local_linear.py` | Phase 2a | Low | -| `HeterogeneousAdoptionDiD` Phase 3 joint Equation 18 cross-horizon Stute test: paper's step 2 of the four-step pre-testing workflow tests joint pre-trends via a stacked-residual CvM across pre-period placebos. Phase 3 shipped the single-horizon Stute in `did_had_pretest_workflow()`; the joint variant needs the exact stacked-residual formula extracted from the paper PDF (not reproduced in `dechaisemartin-2026-review.md`). Follow-up patch. | `diff_diff/had_pretests.py` | Phase 3 | Medium | +| `HeterogeneousAdoptionDiD` Eq 18 linear-trend detrending (Pierce-Schott style): the joint-Stute infrastructure shipped in the Phase 3 follow-up supports pre-trends (mean-indep) and post-homogeneity (linearity) nulls. The Pierce-Schott application (paper Section 5.2) uses a LINEAR-TREND detrending of pre-period outcomes before the joint CvM — `Y_{g,t} - Y_{g,t_anchor} - (t - t_anchor)*(Y_{g,t_anchor} - Y_{g,t_anchor-1})` — reaching p=0.51 on US-China tariff data. Extends `joint_pretrends_test` with a detrending mode or a separate Eq 18-specific helper. Deferred to Phase 4 replication harness (where the published p=0.51 serves as the parity anchor). | `diff_diff/had_pretests.py::joint_pretrends_test` | Phase 4 | Medium | | `HeterogeneousAdoptionDiD` Phase 3 Stute performance: Appendix D vectorized matrix form replaces the per-iteration OLS refit with a single precomputed `M = I - X(X'X)^{-1}X'` applied to `eps * eta`. Functionally identical, ~2x faster. Shipped literal-refit form in Phase 3 to match paper text and keep reviewer surface small. | `diff_diff/had_pretests.py::stute_test` | Phase 3 | Low | | `HeterogeneousAdoptionDiD` Phase 3 R-parity: Phase 3 ships coverage-rate validation on synthetic DGPs (not tight point parity against `chaisemartin::stute_test` / `yatchew_test`). Tight numerical parity requires aligning bootstrap seed semantics and `B` across numpy/R and is deferred. 
| `tests/test_had_pretests.py` | Phase 3 | Low | | `HeterogeneousAdoptionDiD` Phase 3 nprobust bandwidth for Stute: some Stute variants on continuous regressors use nprobust-style optimal bandwidth selection. Phase 3 uses OLS residuals from a 2-parameter linear fit (no bandwidth selection). nprobust integration is a future enhancement; not in paper scope. | `diff_diff/had_pretests.py::stute_test` | Phase 3 | Low | -| `HeterogeneousAdoptionDiD` Phase 3 multi-period workflow dispatch: `did_had_pretest_workflow` accepts two-period overall-path panels only. Multi-period users pre-slice to `(F-1, F)` before calling. A follow-up could add `aggregate="event_study"`-like dispatch for joint pre-trend diagnostics alongside Equation 18. | `diff_diff/had_pretests.py::did_had_pretest_workflow` | Phase 3 | Low | | `HeterogeneousAdoptionDiD` Phase 4: Pierce-Schott (2016) replication harness; reproduce paper Figure 2 values and Table 1 coverage rates. | `benchmarks/`, `tests/` | Phase 2a | Low | | `HeterogeneousAdoptionDiD` Phase 5: `practitioner_next_steps()` integration, tutorial notebook, and `llms.txt` updates (preserving UTF-8 fingerprint). | `diff_diff/practitioner.py`, `tutorials/`, `diff_diff/guides/` | Phase 2a | Low | | `HeterogeneousAdoptionDiD` time-varying dose on event study: Phase 2b REJECTS panels where `D_{g,t}` varies within a unit for `t >= F` (the aggregation uses `D_{g, F}` as the single regressor for all horizons, paper Appendix B.2 constant-dose convention). A follow-up PR could add a time-varying-dose estimator for these panels; current behavior is front-door rejection with a redirect to `ChaisemartinDHaultfoeuille`. 
| `diff_diff/had.py::_validate_had_panel_event_study` | Phase 2b | Low | diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index 3dcb8d26..4a9b93c4 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -63,10 +63,14 @@ from diff_diff.had_pretests import ( HADPretestReport, QUGTestResults, + StuteJointResult, StuteTestResults, YatchewTestResults, did_had_pretest_workflow, + joint_homogeneity_test, + joint_pretrends_test, qug_test, + stute_joint_pretest, stute_test, yatchew_hr_test, ) @@ -461,6 +465,12 @@ "StuteTestResults", "YatchewTestResults", "HADPretestReport", + # HAD joint pre-tests (Phase 3 follow-up) — multi-period event-study + # workflow dispatch via did_had_pretest_workflow(aggregate="event_study") + "stute_joint_pretest", + "joint_pretrends_test", + "joint_homogeneity_test", + "StuteJointResult", # Datasets "load_card_krueger", "load_castle_doctrine", diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index fff55aea..78ef0970 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -40,6 +40,7 @@ _aggregate_first_difference, _json_safe_scalar, _validate_had_panel, + _validate_had_panel_event_study, ) from diff_diff.utils import _generate_mammen_weights @@ -47,10 +48,14 @@ "QUGTestResults", "StuteTestResults", "YatchewTestResults", + "StuteJointResult", "HADPretestReport", "qug_test", "stute_test", "yatchew_hr_test", + "stute_joint_pretest", + "joint_pretrends_test", + "joint_homogeneity_test", "did_had_pretest_workflow", ] @@ -364,159 +369,391 @@ def to_dataframe(self) -> pd.DataFrame: return pd.DataFrame([self.to_dict()]) +@dataclass +class StuteJointResult: + """Result of :func:`stute_joint_pretest` (joint Cramer-von Mises across horizons). + + Aggregates the per-horizon Stute (1997) CvM statistic into a joint + specification test: ``S_joint = sum_k S_k``, where ``S_k`` is the + single-horizon CvM on residuals ``eps_{g,k}``. 
Inference is via + Mammen (1993) wild bootstrap with a **shared** multiplier ``eta_g`` + across horizons per unit (Delgado-Manteiga 2001; Hlavka-Huskova 2020) + to preserve the unit-level dependence structure of the vector-valued + empirical process. + + Two nulls are supported via the thin wrappers + :func:`joint_pretrends_test` (mean-independence: ``E[Y_t - Y_base | D] + = mu_t``, design matrix ``[1]``) and :func:`joint_homogeneity_test` + (linearity: ``E[Y_t - Y_base | D_t] = beta_{0,t} + beta_{fe,t} * D``, + design matrix ``[1, D]``). Eq 18 linear-trend detrending (paper + Section 5.2 Pierce-Schott application) is a Phase 4 follow-up. + + Attributes + ---------- + cvm_stat_joint : float + Joint statistic ``S_joint = sum_k S_k``. NaN on NaN-propagation. + p_value : float + Bootstrap p-value ``(1 + sum(S*_b >= S_joint)) / (B + 1)``. NaN + when the statistic is NaN. ``1.0`` when the per-horizon exact- + linear short-circuit fires (all horizons machine-exact linear). + reject : bool + ``True`` iff ``p_value <= alpha``. Always ``False`` on NaN. + alpha : float + Significance level. + horizon_labels : list of str + Horizon identifiers as ``str(t)`` for each period. **String + identity only** - NOT a chronological ordering key. Callers who + need chronological order should preserve the original period + values alongside (a downstream plotter sorting labels + lexicographically will misorder e.g. + ``["2003-Q10", "2003-Q2", ...]``). + per_horizon_stats : dict[str, float] + ``{label: S_k}`` diagnostic. Per-horizon p-values are NOT + exposed (decomposing the joint bootstrap into K independent + loops is a K-fold memory/time cost; deferred). Callers who need + per-horizon p-values can call :func:`stute_test` separately on + each (period, residual) pair. 
+ + On NaN-propagation (any horizon has NaN input), this dict is + preserved with ``{label: np.nan for label in horizon_labels}``, + NOT an empty dict, NOT a partial dict: the keys carry diagnostic + value (which horizons were attempted), the NaN values signal + non-propagation. + n_bootstrap : int + n_obs : int + Number of units ``G``. + n_horizons : int + seed : int or None + null_form : str + ``"mean_independence"`` (from :func:`joint_pretrends_test`) or + ``"linearity"`` (from :func:`joint_homogeneity_test`). + ``"custom"`` when called directly via :func:`stute_joint_pretest` + without a wrapper. + exact_linear_short_circuited : bool + ``True`` when every horizon's residual SSR to centered TSS ratio + is below :data:`_EXACT_LINEAR_RELATIVE_TOL`; bootstrap is + skipped and ``p_value = 1.0``. The per-horizon check ensures a + single degenerate horizon does not collapse the joint test when + other horizons have nontrivial residuals. + """ + + cvm_stat_joint: float + p_value: float + reject: bool + alpha: float + horizon_labels: list + per_horizon_stats: Dict[str, float] + n_bootstrap: int + n_obs: int + n_horizons: int + seed: Optional[int] + null_form: str + exact_linear_short_circuited: bool + + def __repr__(self) -> str: + return ( + f"StuteJointResult(cvm_stat_joint={self.cvm_stat_joint:.4f}, " + f"p_value={self.p_value:.4f}, reject={self.reject}, " + f"n_horizons={self.n_horizons}, null_form={self.null_form!r}, " + f"n_obs={self.n_obs})" + ) + + def summary(self) -> str: + """Formatted summary table.""" + width = 64 + per_horizon_lines = [ + f" {label:<20} {stat:>20.4f}" for label, stat in self.per_horizon_stats.items() + ] + null_label = { + "mean_independence": "mean-independence (pre-trends)", + "linearity": "linearity (post-homogeneity)", + }.get(self.null_form, self.null_form) + lines = [ + "=" * width, + f"Joint Stute CvM test ({null_label})".center(width), + "=" * width, + f"{'Joint CvM statistic:':<30} {self.cvm_stat_joint:>20.4f}", + f"{'Bootstrap 
p-value:':<30} {self.p_value:>20.4f}", + f"{'Reject H_0:':<30} {str(self.reject):>20}", + f"{'alpha:':<30} {self.alpha:>20.4f}", + f"{'Bootstrap replications:':<30} {self.n_bootstrap:>20}", + f"{'Horizons:':<30} {self.n_horizons:>20}", + f"{'Observations:':<30} {self.n_obs:>20}", + f"{'Seed:':<30} {str(self.seed):>20}", + f"{'Exact-linear short-circuit:':<30} " f"{str(self.exact_linear_short_circuited):>20}", + "-" * width, + "Per-horizon statistics:", + *per_horizon_lines, + "=" * width, + ] + return "\n".join(lines) + + def print_summary(self) -> None: + """Print the summary to stdout.""" + print(self.summary()) + + def to_dict(self) -> Dict[str, Any]: + """Return results as a JSON-safe dict.""" + return { + "test": "stute_joint", + "cvm_stat_joint": _json_safe_scalar(self.cvm_stat_joint), + "p_value": _json_safe_scalar(self.p_value), + "reject": bool(self.reject), + "alpha": float(self.alpha), + "horizon_labels": [str(label) for label in self.horizon_labels], + "per_horizon_stats": { + str(k): _json_safe_scalar(v) for k, v in self.per_horizon_stats.items() + }, + "n_bootstrap": int(self.n_bootstrap), + "n_obs": int(self.n_obs), + "n_horizons": int(self.n_horizons), + "seed": None if self.seed is None else int(self.seed), + "null_form": str(self.null_form), + "exact_linear_short_circuited": bool(self.exact_linear_short_circuited), + } + + def to_dataframe(self) -> pd.DataFrame: + """Return a one-row DataFrame of the top-level result fields.""" + return pd.DataFrame( + [ + { + "test": "stute_joint", + "cvm_stat_joint": _json_safe_scalar(self.cvm_stat_joint), + "p_value": _json_safe_scalar(self.p_value), + "reject": bool(self.reject), + "alpha": float(self.alpha), + "n_bootstrap": int(self.n_bootstrap), + "n_obs": int(self.n_obs), + "n_horizons": int(self.n_horizons), + "null_form": str(self.null_form), + } + ] + ) + + @dataclass class HADPretestReport: """Composite output of :func:`did_had_pretest_workflow`. 
- Bundles the three individual tests with an overall verdict string. + Two dispatch shapes, distinguished by :attr:`aggregate`: + + ``aggregate="overall"`` (default, two-period panel): bundles paper + steps 1 (QUG) and 3 (linearity via Stute + Yatchew-HR) on a + two-period first-differenced sample. Step 2 (Assumption 7 pre-trends) + is NOT implemented on this path and is explicitly flagged in the + verdict; callers must run pre-trends separately. - .. important:: - This report reflects a **partial** workflow: Phase 3 ships paper - Sections 4.2-4.3 steps 1 (QUG) and 3 (linearity via Stute + - Yatchew-HR), but **NOT** step 2 (Assumption 7 pre-trends test via - Equation 18). Even when ``all_pass`` is ``True``, the paper's - four-step certification for TWFE validity is incomplete — pre- - trends testing is a separate diagnostic that must be run via the - user's own event-study / placebo analysis until the Phase 3 - follow-up patch lands the joint Equation 18 Stute test. + ``aggregate="event_study"`` (multi-period panel, >= 3 periods): + bundles QUG + joint pre-trends Stute + joint homogeneity-linearity + Stute. The joint Stute variants close the paper step-2 gap; the + event-study verdict does NOT emit the "paper step 2 deferred" + caveat. Step 3 adjudication uses joint Stute only - no joint Yatchew + variant exists because the paper does not derive one; users who need + Yatchew robustness under multi-period data can run + :func:`yatchew_hr_test` on each (base, post) pair manually. Attributes ---------- qug : QUGTestResults - stute : StuteTestResults - yatchew : YatchewTestResults + Always populated. + stute : StuteTestResults or None + Populated when ``aggregate == "overall"``; ``None`` when + ``aggregate == "event_study"``. + yatchew : YatchewTestResults or None + Populated when ``aggregate == "overall"``; ``None`` when + ``aggregate == "event_study"``. 
+ pretrends_joint : StuteJointResult or None + Populated when ``aggregate == "event_study"`` and at least one + earlier pre-period exists; ``None`` on the overall path or when + only the immediate base pre-period is available. + homogeneity_joint : StuteJointResult or None + Populated when ``aggregate == "event_study"``; ``None`` on the + overall path. all_pass : bool - ``True`` iff (a) QUG is conclusive (step 1), (b) at least ONE of - Stute / Yatchew is conclusive (step 3 - paper's "Stute or - Yatchew" wording), AND (c) no conclusive test rejects. This - gating follows the paper's four-step workflow exactly: step 3 - accepts either linearity test, so a conclusive Stute is - sufficient even when Yatchew returns NaN (e.g. tied-dose - panels). Even when ``all_pass`` is ``True``, the report is a - PARTIAL indicator: it does not certify Assumption 7 (pre-trends), - which is not tested by Phase 3. + On the overall path: same Phase 3 semantics - True iff QUG is + conclusive AND at least one of Stute/Yatchew is conclusive AND + no conclusive test rejects. On the event-study path: True iff + ``np.isfinite(qug.p_value)``, + ``pretrends_joint is not None and + np.isfinite(pretrends_joint.p_value)``, + ``np.isfinite(homogeneity_joint.p_value)``, AND none of the + three rejects. Mirrors Phase 3's ``bool(np.isfinite(p_value))`` + convention - no ``.conclusive()`` helper on any result dataclass. verdict : str - Human-readable classification. Paper rule: TWFE is admissible - only if NONE of the implemented tests rejects. A conclusive - rejection must therefore never be hidden by a purely-inconclusive - verdict just because another step happens to be NaN. - - Priority: - - 1. If any CONCLUSIVE test rejected, that is the primary verdict: - a bundled string naming each failed assumption, - ``"support infimum rejected - continuous_at_zero design - invalid (QUG)"`` and/or ``"linearity rejected - heterogeneity - bias ({Stute[,Yatchew]})"``. 
If another step is unresolved - (QUG NaN, or BOTH linearity tests NaN), an - ``"; additional steps unresolved: ..."`` suffix is APPENDED - rather than replacing the rejection. - 2. If no conclusive rejection but a required step is unresolved, - the verdict is ``"inconclusive - QUG NaN"`` when step 1 is - the only unresolved piece, ``"inconclusive - both Stute and - Yatchew linearity tests NaN"`` when step 3 lacks any - conclusive linearity test, or ``"inconclusive - QUG NaN; - both Stute and Yatchew linearity tests NaN"`` when both are - unresolved. - 3. Otherwise (all required steps conclusive and none reject), - the partial-workflow fail-to-reject verdict: - ``"QUG and linearity diagnostics fail-to-reject[ (Yatchew - NaN - skipped)]; Assumption 7 pre-trends test NOT run (paper - step 2 deferred to Phase 3 follow-up)"``. The - ``" (... - skipped)"`` suffix appears when Stute OR Yatchew - was NaN but the other was conclusive (step 3 resolved via - the paper's "Stute OR Yatchew" wording). + Human-readable classification. Paper rule applies symmetrically: + TWFE is admissible only if NONE of the implemented tests + rejects. Conclusive rejections are the primary verdict; + unresolved steps append as ``"; additional steps unresolved: + ..."`` rather than replacing the rejection. alpha : float - Significance level shared across tests. n_obs : int - Unit count after aggregation to the two-period first-difference. + Unit count. For overall: units after two-period first-difference + aggregation. For event_study: units after balanced-panel + validation and (if applicable) last-cohort auto-filter. + aggregate : str + ``"overall"`` or ``"event_study"``. Determines which component + fields are populated and which branch of serialization methods + to render. 
""" qug: QUGTestResults - stute: StuteTestResults - yatchew: YatchewTestResults + stute: Optional[StuteTestResults] + yatchew: Optional[YatchewTestResults] all_pass: bool verdict: str alpha: float n_obs: int + pretrends_joint: Optional[StuteJointResult] = None + homogeneity_joint: Optional[StuteJointResult] = None + aggregate: str = "overall" def __repr__(self) -> str: return ( - f"HADPretestReport(all_pass={self.all_pass}, " + f"HADPretestReport(aggregate={self.aggregate!r}, " + f"all_pass={self.all_pass}, " f"verdict={self.verdict!r}, n_obs={self.n_obs})" ) def summary(self) -> str: - """Formatted summary of all three tests and the verdict.""" + """Formatted summary of all tests and the verdict.""" width = 72 - parts = [ + header = [ "=" * width, "HAD pre-test workflow".center(width), + f"aggregate: {self.aggregate}".center(width), "=" * width, self.qug.summary(), "", - self.stute.summary(), - "", - self.yatchew.summary(), - "", + ] + if self.aggregate == "event_study": + if self.pretrends_joint is not None: + body = [self.pretrends_joint.summary(), ""] + else: + body = [ + "(joint pre-trends skipped - no earlier pre-period)", + "", + ] + if self.homogeneity_joint is not None: + body += [self.homogeneity_joint.summary(), ""] + else: + # aggregate == "overall" + body = [] + if self.stute is not None: + body += [self.stute.summary(), ""] + if self.yatchew is not None: + body += [self.yatchew.summary(), ""] + footer = [ "=" * width, f"{'All pass:':<30} {str(self.all_pass):>40}", f"Verdict: {self.verdict}", "=" * width, ] - return "\n".join(parts) + return "\n".join(header + body + footer) def print_summary(self) -> None: """Print the summary to stdout.""" print(self.summary()) def to_dict(self) -> Dict[str, Any]: - """Return a JSON-safe nested dict of the full report.""" - return { + """Return a JSON-safe nested dict of the full report. + + The ``aggregate`` key identifies which component fields are + present; ``None``-valued components are emitted as JSON null. 
+ """ + base: Dict[str, Any] = { + "aggregate": str(self.aggregate), "qug": self.qug.to_dict(), - "stute": self.stute.to_dict(), - "yatchew": self.yatchew.to_dict(), "all_pass": bool(self.all_pass), "verdict": str(self.verdict), "alpha": float(self.alpha), "n_obs": int(self.n_obs), } + if self.aggregate == "event_study": + base["pretrends_joint"] = ( + None if self.pretrends_joint is None else self.pretrends_joint.to_dict() + ) + base["homogeneity_joint"] = ( + None if self.homogeneity_joint is None else self.homogeneity_joint.to_dict() + ) + else: + # aggregate == "overall" - Phase 3 schema preserved bit-exactly + base["stute"] = None if self.stute is None else self.stute.to_dict() + base["yatchew"] = None if self.yatchew is None else self.yatchew.to_dict() + return base def to_dataframe(self) -> pd.DataFrame: - """Return a tidy 3-row DataFrame (one row per test). + """Return a tidy 3-row DataFrame (one row per implemented test). + + Columns (stable across aggregates): + ``[test, statistic_name, statistic_value, p_value, reject, alpha, + n_obs]``. Row identifiers vary by aggregate: - Columns (in order): ``[test, statistic_name, statistic_value, - p_value, reject, alpha, n_obs]``. + - ``aggregate="overall"``: rows are ``qug``, ``stute``, + ``yatchew_hr`` (Phase 3 schema, unchanged). + - ``aggregate="event_study"``: rows are ``qug``, + ``pretrends_joint``, ``homogeneity_joint``. + + Rows for ``None``-valued components (e.g. ``pretrends_joint`` when + no earlier pre-period exists) are emitted with NaN statistic + values and ``reject=False`` to preserve the 3-row shape. 
""" - rows = [ - { - "test": "qug", - "statistic_name": "t_stat", - "statistic_value": _json_safe_scalar(self.qug.t_stat), - "p_value": _json_safe_scalar(self.qug.p_value), - "reject": bool(self.qug.reject), - "alpha": float(self.qug.alpha), - "n_obs": int(self.qug.n_obs), - }, - { - "test": "stute", - "statistic_name": "cvm_stat", - "statistic_value": _json_safe_scalar(self.stute.cvm_stat), - "p_value": _json_safe_scalar(self.stute.p_value), - "reject": bool(self.stute.reject), - "alpha": float(self.stute.alpha), - "n_obs": int(self.stute.n_obs), - }, - { - "test": "yatchew_hr", - "statistic_name": "t_stat_hr", - "statistic_value": _json_safe_scalar(self.yatchew.t_stat_hr), - "p_value": _json_safe_scalar(self.yatchew.p_value), - "reject": bool(self.yatchew.reject), - "alpha": float(self.yatchew.alpha), - "n_obs": int(self.yatchew.n_obs), - }, - ] + qug_row = { + "test": "qug", + "statistic_name": "t_stat", + "statistic_value": _json_safe_scalar(self.qug.t_stat), + "p_value": _json_safe_scalar(self.qug.p_value), + "reject": bool(self.qug.reject), + "alpha": float(self.qug.alpha), + "n_obs": int(self.qug.n_obs), + } + if self.aggregate == "event_study": + pre_row = self._joint_row_or_nan("pretrends_joint", self.pretrends_joint) + hom_row = self._joint_row_or_nan("homogeneity_joint", self.homogeneity_joint) + rows = [qug_row, pre_row, hom_row] + else: + stute_row = ( + { + "test": "stute", + "statistic_name": "cvm_stat", + "statistic_value": _json_safe_scalar(self.stute.cvm_stat), + "p_value": _json_safe_scalar(self.stute.p_value), + "reject": bool(self.stute.reject), + "alpha": float(self.stute.alpha), + "n_obs": int(self.stute.n_obs), + } + if self.stute is not None + else { + "test": "stute", + "statistic_name": "cvm_stat", + "statistic_value": float("nan"), + "p_value": float("nan"), + "reject": False, + "alpha": float(self.alpha), + "n_obs": int(self.n_obs), + } + ) + yatchew_row = ( + { + "test": "yatchew_hr", + "statistic_name": "t_stat_hr", + 
"statistic_value": _json_safe_scalar(self.yatchew.t_stat_hr), + "p_value": _json_safe_scalar(self.yatchew.p_value), + "reject": bool(self.yatchew.reject), + "alpha": float(self.yatchew.alpha), + "n_obs": int(self.yatchew.n_obs), + } + if self.yatchew is not None + else { + "test": "yatchew_hr", + "statistic_name": "t_stat_hr", + "statistic_value": float("nan"), + "p_value": float("nan"), + "reject": False, + "alpha": float(self.alpha), + "n_obs": int(self.n_obs), + } + ) + rows = [qug_row, stute_row, yatchew_row] cols = [ "test", "statistic_name", @@ -528,6 +765,35 @@ def to_dataframe(self) -> pd.DataFrame: ] return pd.DataFrame(rows).reindex(columns=cols) + def _joint_row_or_nan( + self, test_label: str, joint: Optional[StuteJointResult] + ) -> Dict[str, Any]: + """Build a to_dataframe row for a joint-Stute component. + + When ``joint`` is ``None`` (e.g. pretrends_joint skipped because + no earlier pre-period), emit a NaN row preserving the 3-row + shape for downstream plotting. + """ + if joint is None: + return { + "test": test_label, + "statistic_name": "cvm_stat_joint", + "statistic_value": float("nan"), + "p_value": float("nan"), + "reject": False, + "alpha": float(self.alpha), + "n_obs": int(self.n_obs), + } + return { + "test": test_label, + "statistic_name": "cvm_stat_joint", + "statistic_value": _json_safe_scalar(joint.cvm_stat_joint), + "p_value": _json_safe_scalar(joint.p_value), + "reject": bool(joint.reject), + "alpha": float(joint.alpha), + "n_obs": int(joint.n_obs), + } + # ============================================================================= # Private helpers @@ -1302,73 +1568,937 @@ def yatchew_hr_test(d: np.ndarray, dy: np.ndarray, alpha: float = 0.05) -> Yatch ) -def did_had_pretest_workflow( +def _validate_multi_period_panel( + data: pd.DataFrame, + outcome_col: str, + dose_col: str, + time_col: str, + unit_col: str, + first_treat_col: Optional[str], +) -> "tuple[Any, list, list, pd.DataFrame, Optional[Dict[str, Any]]]": + """Validate a 
multi-period HAD panel for joint pre-test dispatch. + + Thin wrapper over :func:`_validate_had_panel_event_study` (had.py) that + inherits the full contract: + + - ``first_treat_col=None`` combined with a staggered panel → raises + ``ValueError`` (the had.py helper does NOT silently accept; it + requires an explicit first-treatment column to identify cohorts). + - ``first_treat_col`` provided but identifies only one cohort → no + auto-filter, proceeds. + - ``first_treat_col`` provided with multiple cohorts → auto-filters + to last-cohort + never-treated, emits ``UserWarning`` with + ``filter_info`` summary. + - Requires ≥ 3 time periods, balanced panel, ordered time dtype, and + the pre-period D=0 invariant across all pre-periods. + + Additional guards on top of had.py: + - ``len(t_pre_list) >= 1`` (need ≥ 1 pre-period for joint pre-trends + infrastructure; had.py already enforces this). + - ``len(t_post_list) >= 1`` (need ≥ 1 post-period for joint + homogeneity; had.py already enforces this). + + Returns the same 5-tuple as the had.py helper: + ``(F, t_pre_list, t_post_list, data_filtered, filter_info)``. + """ + return _validate_had_panel_event_study( + data, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + first_treat_col=first_treat_col, + ) + + +def _aggregate_for_joint_test( + data: pd.DataFrame, + outcome_col: str, + dose_col: str, + time_col: str, + unit_col: str, + horizons: list, + base_period: Any, +) -> "tuple[np.ndarray, Dict[str, np.ndarray], np.ndarray]": + """Aggregate a multi-period panel for a joint-Stute test. + + Builds per-horizon first differences ``dy_t = Y_{g,t} - Y_{g,base}`` + and the unit-level dose ``D_g`` for the joint-Stute test. All units + must appear in every (horizon + base_period) period, matching the + balanced-panel invariant of the single-period :func:`stute_test`. + + Dose extraction: ``D_g = max_t D_{g,t}`` under the HAD contract + "once treated, stay treated with same dose". 
For pre-periods + ``D_{g,t} = 0`` and for post-periods ``D_{g,t}`` is time-invariant + per unit, so ``max`` recovers the realized post-period dose. + + Parameters + ---------- + data : pd.DataFrame + outcome_col, dose_col, time_col, unit_col : str + horizons : list + Non-empty list of period labels to build ``dy_t`` for. + ``base_period`` must not be in ``horizons``. All ``horizons`` + and ``base_period`` must exist in the time column. + base_period : period label + The reference period for the first difference. + + Returns + ------- + d_arr : np.ndarray, shape (G,) + dy_by_horizon : dict[str, np.ndarray] + Keys are ``str(t)`` per horizon, values are ``dy_t`` arrays of + shape ``(G,)``. Insertion order follows ``horizons``. + unit_ids : np.ndarray, shape (G,) + """ + required = [outcome_col, dose_col, time_col, unit_col] + missing = [c for c in required if c not in data.columns] + if missing: + raise ValueError(f"Missing column(s) in data: {missing}. Required: {required}.") + if len(horizons) == 0: + raise ValueError("horizons must be a non-empty list of period labels.") + data_periods = set(data[time_col].unique()) + needed_periods = list(horizons) + [base_period] + missing_periods = [t for t in needed_periods if t not in data_periods] + if missing_periods: + raise ValueError( + f"Period(s) {missing_periods} not found in time_col " + f"{time_col!r}. Available periods: " + f"{sorted(data_periods, key=lambda x: (x is None, x))}." + ) + if base_period in horizons: + raise ValueError( + f"base_period={base_period!r} must not appear in horizons " f"{list(horizons)!r}." + ) + + mask = data[time_col].isin(needed_periods) + subset = data.loc[mask].copy() + + for col in [outcome_col, dose_col, unit_col]: + col_series = subset[col] + if bool(pd.isna(col_series).any()): + n_nan = int(pd.isna(col_series).sum()) + raise ValueError( + f"{n_nan} NaN value(s) found in column {col!r} across " + f"periods {needed_periods}. 
Joint pre-test does not " + f"silently drop rows; drop or impute before calling." + ) + + counts = subset.groupby(unit_col).size() + n_needed = len(needed_periods) + if (counts != n_needed).any(): + n_bad = int((counts != n_needed).sum()) + raise ValueError( + f"Panel unbalanced across needed periods {needed_periods}: " + f"{n_bad} unit(s) do not appear in all {n_needed} period(s). " + f"Joint pre-test requires a balanced sub-panel." + ) + + wide_y = subset.pivot(index=unit_col, columns=time_col, values=outcome_col) + wide_y = wide_y.sort_index() + unit_ids = np.asarray(wide_y.index) + + base_y = wide_y[base_period].to_numpy(dtype=np.float64) + dy_by_horizon: Dict[str, np.ndarray] = {} + for t in horizons: + y_t = wide_y[t].to_numpy(dtype=np.float64) + dy_by_horizon[str(t)] = y_t - base_y + + # Dose per unit is the HAD time-invariant post-period dose: + # D_g = max_t D_{g,t}. Critically, compute this over the FULL data, + # not just the subset of needed_periods - for joint pre-trends, + # needed_periods contains only pre-periods (all D=0), so taking max + # over the subset would yield D_g = 0 for every unit and collapse + # the CvM sort to arbitrary ties. Paper HAD convention: dose is + # fixed per unit once treated; pre-period zero-dose is enforced by + # the upstream validator. + d_per_unit = data.groupby(unit_col)[dose_col].max().sort_index() + # Align dose with the subset's unit ordering (pivot sort_index uses + # natural unit_col order; groupby/sort_index on the full data gives + # the same order). + d_per_unit = d_per_unit.loc[unit_ids] + d_arr = d_per_unit.to_numpy(dtype=np.float64) + + return d_arr, dy_by_horizon, unit_ids + + +def _compose_verdict_event_study( + qug: QUGTestResults, + pretrends_joint: Optional[StuteJointResult], + homogeneity_joint: Optional[StuteJointResult], +) -> str: + """Build the event-study :class:`HADPretestReport` verdict. 
+ + + Mirrors :func:`_compose_verdict` (two-period path) idiom verbatim: + hyphen-separated ``"<reason> - <detail> (<test>)"`` reason + strings, ``"; "`` join, ``"; additional steps unresolved: ..."`` + suffix for conclusive rejections that coexist with unresolved + steps, lowercase concerns. + + Coverage: + - Step 1 (QUG): always runs on the event-study path. + - Step 2 (Assumption 7 pre-trends): runs via ``pretrends_joint`` + when at least one earlier pre-period is available. When skipped + (only the immediate base pre-period), the verdict flags the skip + but does NOT emit the Phase-3 "paper step 2 deferred to Phase 3 + follow-up" caveat - this PR closes that gap. + - Step 3 (Assumption 8 linearity/homogeneity): runs via + ``homogeneity_joint`` (joint Stute only; no joint Yatchew variant + exists in the paper). + - Step 4 (alternative linearity via Yatchew): not run on the + event-study path; adjudicated by joint Stute above. + + Priority: + 1. Any conclusive test rejecting → primary verdict bundles each + rejection reason. Unresolved / skipped steps append as a suffix. + 2. No conclusive rejection but a required step unresolved → + ``"inconclusive - ..."``. + 3. All required steps conclusive and none reject → admissible + fail-to-reject string (Section 4 coverage). 
+ """ + qug_ok = bool(np.isfinite(qug.p_value)) + pretrends_ok = pretrends_joint is not None and bool(np.isfinite(pretrends_joint.p_value)) + homogeneity_ok = homogeneity_joint is not None and bool(np.isfinite(homogeneity_joint.p_value)) + + qug_rej = qug_ok and qug.reject + pretrends_rej = pretrends_joint is not None and pretrends_ok and bool(pretrends_joint.reject) + homogeneity_rej = ( + homogeneity_joint is not None and homogeneity_ok and bool(homogeneity_joint.reject) + ) + + reasons = [] + if qug_rej: + reasons.append("support infimum rejected - continuous_at_zero design invalid (QUG)") + if pretrends_rej: + reasons.append("joint pre-trends rejected - assumption 7 violated (joint Stute)") + if homogeneity_rej: + reasons.append("joint linearity rejected - heterogeneity bias (joint Stute)") + + unresolved = [] + if not qug_ok: + unresolved.append("QUG NaN") + if pretrends_joint is None: + unresolved.append("joint pre-trends skipped (no earlier pre-period)") + elif not pretrends_ok: + unresolved.append("joint pre-trends NaN") + if homogeneity_joint is None: + unresolved.append("joint linearity skipped") + elif not homogeneity_ok: + unresolved.append("joint linearity NaN") + + if reasons: + verdict = "; ".join(reasons) + if unresolved: + verdict += "; additional steps unresolved: " + "; ".join(unresolved) + return verdict + + if unresolved: + return "inconclusive - " + "; ".join(unresolved) + + return ( + "QUG, joint pre-trends, and joint linearity diagnostics " + "fail-to-reject (TWFE admissible under Section 4 assumptions)" + ) + + +def stute_joint_pretest( + residuals_by_horizon: Dict[Any, np.ndarray], + fitted_by_horizon: Dict[Any, np.ndarray], + doses: np.ndarray, + design_matrix: np.ndarray, + *, + alpha: float = 0.05, + n_bootstrap: int = 999, + seed: Optional[int] = None, + null_form: str = "custom", +) -> StuteJointResult: + """Joint Cramer-von Mises pretest across multiple horizons. 
+ + Generalizes :func:`stute_test` to K horizons with the joint + statistic ``S_joint = sum_k S_k``, where ``S_k`` is the single- + horizon CvM on residuals ``eps_{g,k}``. Inference is via Mammen wild + bootstrap with a **shared** multiplier ``eta_g`` across horizons per + unit to preserve the vector-valued empirical process's unit-level + dependence. + + **Note:** sum-of-CvMs aggregation follows the standard joint + specification-test construction (Delgado 1993; Escanciano 2006). The + paper does not prescribe an aggregation; sum-of-CvMs balances power + across diffuse vs concentrated alternatives and bootstraps cleanly + with the shared-eta structure. + + Bootstrap uses the literal per-iteration OLS refit form (paper + Appendix D) for consistency with Phase 3's :func:`stute_test`. + ``XtX_inv_Xt`` is precomputed once (same design matrix each + iteration), so the refit cost is O(Gp) per bootstrap draw and the + overall loop is dominated by :func:`_cvm_statistic` across K + horizons. + + Parameters + ---------- + residuals_by_horizon : dict[str, np.ndarray] + ``{label: eps_g}`` per horizon. All values must have identical + length ``G`` and be unit-ordered consistently with ``doses``. + fitted_by_horizon : dict[str, np.ndarray] + ``{label: fitted_g}`` per horizon. Required to reconstruct + bootstrap outcomes ``dy*_{g,k} = fitted_{g,k} + eps_{g,k} * + eta_g`` under the null. + doses : np.ndarray, shape (G,) + Dose per unit. Shared across horizons (HAD contract: dose is + time-invariant per unit). Must be finite and non-negative. + design_matrix : np.ndarray, shape (G, p) + Regression design used in the per-horizon bootstrap refit. + Mean-independence: ``[1]`` (intercept only). Linearity: + ``[1, doses]``. The matrix is identical across horizons. + alpha, n_bootstrap, seed : see :func:`stute_test`. + null_form : str + Diagnostic label recorded on the result + (``"mean_independence"`` | ``"linearity"`` | ``"custom"``). 
+ The wrappers :func:`joint_pretrends_test` and + :func:`joint_homogeneity_test` set this automatically. + + Returns + ------- + StuteJointResult + + Raises + ------ + ValueError + On empty input, key-mismatch, shape-mismatch, ``doses`` + containing negative values, ``G < _MIN_G_STUTE``, or + ``n_bootstrap < _MIN_N_BOOTSTRAP``. + """ + if not isinstance(residuals_by_horizon, dict) or not isinstance(fitted_by_horizon, dict): + raise ValueError( + "residuals_by_horizon and fitted_by_horizon must be dicts " "keyed by horizon label." + ) + if len(residuals_by_horizon) == 0: + raise ValueError("residuals_by_horizon must contain at least one horizon.") + if set(residuals_by_horizon.keys()) != set(fitted_by_horizon.keys()): + raise ValueError( + "residuals_by_horizon and fitted_by_horizon must have " + "identical keys. Got " + f"residuals keys: {sorted(residuals_by_horizon.keys())!r}, " + f"fitted keys: {sorted(fitted_by_horizon.keys())!r}." + ) + + doses_arr = _validate_1d_numeric(np.asarray(doses), "doses") + G = doses_arr.shape[0] + if np.any(doses_arr < 0): + raise ValueError( + "doses must be non-negative (HAD contract - paper Section 2). " + f"Found {int(np.sum(doses_arr < 0))} negative value(s)." 
+ ) + + if G < _MIN_G_STUTE: + raise ValueError(f"Joint Stute test requires G >= {_MIN_G_STUTE} units; got " f"G = {G}.") + if n_bootstrap < _MIN_N_BOOTSTRAP: + raise ValueError(f"n_bootstrap must be >= {_MIN_N_BOOTSTRAP}; got " f"{n_bootstrap}.") + if not isinstance(alpha, (int, float)) or not (0 < float(alpha) < 1): + raise ValueError(f"alpha must be in (0, 1); got {alpha!r}.") + + X = np.asarray(design_matrix, dtype=np.float64) + if X.ndim != 2 or X.shape[0] != G: + raise ValueError(f"design_matrix must have shape (G, p) with G={G}; got " f"{X.shape}.") + if not np.all(np.isfinite(X)): + raise ValueError("design_matrix contains non-finite values (NaN/inf).") + + horizon_labels = list(residuals_by_horizon.keys()) + K = len(horizon_labels) + any_nan = False + residuals_arrays: Dict[str, np.ndarray] = {} + fitted_arrays: Dict[str, np.ndarray] = {} + for k in horizon_labels: + eps_k = np.asarray(residuals_by_horizon[k], dtype=np.float64) + fit_k = np.asarray(fitted_by_horizon[k], dtype=np.float64) + if eps_k.shape != (G,) or fit_k.shape != (G,): + raise ValueError( + f"Horizon {k!r}: residuals shape {eps_k.shape} and " + f"fitted shape {fit_k.shape} must both be ({G},) to " + f"align with doses." + ) + if not (np.all(np.isfinite(eps_k)) and np.all(np.isfinite(fit_k))): + any_nan = True + residuals_arrays[str(k)] = eps_k + fitted_arrays[str(k)] = fit_k + + # Re-key to str labels consistently (wrappers already pass str; direct + # callers may pass int/object). String identity per the documented + # horizon_labels contract. 
+ horizon_labels = [str(k) for k in horizon_labels] + + if any_nan: + return StuteJointResult( + cvm_stat_joint=float("nan"), + p_value=float("nan"), + reject=False, + alpha=float(alpha), + horizon_labels=horizon_labels, + per_horizon_stats={k: float("nan") for k in horizon_labels}, + n_bootstrap=int(n_bootstrap), + n_obs=int(G), + n_horizons=int(K), + seed=None if seed is None else int(seed), + null_form=str(null_form), + exact_linear_short_circuited=False, + ) + + idx = np.argsort(doses_arr, kind="stable") + d_sorted = doses_arr[idx] + + per_horizon_stats: Dict[str, float] = {} + for k in horizon_labels: + per_horizon_stats[k] = _cvm_statistic(residuals_arrays[k][idx], d_sorted) + S_joint = float(sum(per_horizon_stats.values())) + + # Per-horizon exact-linear short-circuit (scale- and translation- + # invariant, matches Phase 3 invariant). A single degenerate horizon + # does NOT collapse the joint test if other horizons have nontrivial + # residuals. + short_circuit = True + for k in horizon_labels: + eps_k = residuals_arrays[k] + fit_k = fitted_arrays[k] + dy_k = fit_k + eps_k + tss_centered = float(np.sum((dy_k - dy_k.mean()) ** 2)) + if tss_centered == 0.0: + # Outcome identically constant: treat as trivially linear for + # this horizon (ratio = 0). Does not force short-circuit + # because other horizons may still be nontrivial. 
+ ratio = 0.0 + else: + ratio = float(np.sum(eps_k**2) / tss_centered) + if ratio >= _EXACT_LINEAR_RELATIVE_TOL: + short_circuit = False + break + + if short_circuit: + return StuteJointResult( + cvm_stat_joint=S_joint, + p_value=1.0, + reject=False, + alpha=float(alpha), + horizon_labels=horizon_labels, + per_horizon_stats=per_horizon_stats, + n_bootstrap=int(n_bootstrap), + n_obs=int(G), + n_horizons=int(K), + seed=None if seed is None else int(seed), + null_form=str(null_form), + exact_linear_short_circuited=True, + ) + + # Precompute OLS projection matrix once: same X per bootstrap draw, + # so (X'X)^-1 X' is constant across iterations. Keeps refit O(Gp) + # per draw without changing semantics from the literal paper form. + XtX_inv_Xt = np.linalg.solve(X.T @ X, X.T) + + rng = np.random.default_rng(seed) + bootstrap_S = np.empty(n_bootstrap, dtype=np.float64) + for b in range(n_bootstrap): + # SHARED eta across horizons - preserves unit-level dependence + # in the vector-valued empirical process. Independent-per-horizon + # draws would overstate precision. 
+ eta = _generate_mammen_weights(G, rng) + S_b = 0.0 + for k in horizon_labels: + dy_b = fitted_arrays[k] + residuals_arrays[k] * eta + beta_b = XtX_inv_Xt @ dy_b + eps_b = dy_b - X @ beta_b + S_b += _cvm_statistic(eps_b[idx], d_sorted) + bootstrap_S[b] = S_b + + p_value = float((1.0 + np.sum(bootstrap_S >= S_joint)) / (n_bootstrap + 1)) + reject = bool(p_value <= alpha) + + return StuteJointResult( + cvm_stat_joint=S_joint, + p_value=p_value, + reject=reject, + alpha=float(alpha), + horizon_labels=horizon_labels, + per_horizon_stats=per_horizon_stats, + n_bootstrap=int(n_bootstrap), + n_obs=int(G), + n_horizons=int(K), + seed=None if seed is None else int(seed), + null_form=str(null_form), + exact_linear_short_circuited=False, + ) + + +def joint_pretrends_test( data: pd.DataFrame, outcome_col: str, dose_col: str, time_col: str, unit_col: str, + pre_periods: list, + base_period: Any, first_treat_col: Optional[str] = None, + *, alpha: float = 0.05, n_bootstrap: int = 999, seed: Optional[int] = None, -) -> HADPretestReport: - """Run a PARTIAL HAD pre-test workflow on a two-period panel - (paper Section 4.2-4.3, steps 1 and 3 only; step 2 deferred). +) -> StuteJointResult: + """Joint Stute pre-trends test (paper Section 4.2 step 2). + + Data-in wrapper around :func:`stute_joint_pretest` for the + mean-independence null + ``E[Y_{g,t} - Y_{g,base} | D_{g,treat}] = mu_t`` + across multiple pre-period placebos. For each ``t in pre_periods``, + residuals are the deviations of ``Y_{g,t} - Y_{g,base}`` from their + cross-unit mean (an intercept-only OLS fit); the joint CvM tests + that the conditional mean depends on ``D``. + + Use this wrapper to close the paper's step-2 pre-trends gap that + :func:`did_had_pretest_workflow` otherwise flags. On a panel with + at least one earlier pre-period, the + ``aggregate="event_study"`` dispatch calls this wrapper internally. 
- Phase 3 scope runs: + Parameters + ---------- + data : pd.DataFrame + outcome_col, dose_col, time_col, unit_col : str + pre_periods : list + Non-empty list of pre-period labels (all ``< base_period``, all + with ``D = 0`` across every unit). Empty list raises; the + workflow dispatch handles the "no earlier pre-period" case by + setting ``pretrends_joint=None`` rather than calling this + wrapper. + base_period : period label + The reference period. Must not be in ``pre_periods``. Must also + satisfy ``D = 0`` across every unit (reciprocal of the pre-period + HAD invariant - base is itself a pre-period in the four-step + workflow). + first_treat_col : str or None + Forwarded to the underlying panel validator; matched cohort + handling follows the HAD contract (staggered auto-filter warns + and proceeds on last cohort; solo cohort proceeds). + alpha, n_bootstrap, seed : as in :func:`stute_test`. - - Step 1: :func:`qug_test` (``H_0: d_lower = 0``, Theorem 4). - - Step 3: :func:`stute_test` and :func:`yatchew_hr_test` (linearity - of ``E[ΔY | D_2]``, Assumption 8). + Returns + ------- + StuteJointResult with ``null_form = "mean_independence"``. + """ + if len(pre_periods) == 0: + raise ValueError( + "pre_periods must be non-empty. Workflow dispatch handles " + "the empty case by setting pretrends_joint=None; direct " + "callers should not pass an empty list." + ) + if base_period in pre_periods: + raise ValueError( + f"base_period={base_period!r} must not appear in " f"pre_periods {list(pre_periods)!r}." + ) - Phase 3 does **NOT** run step 2 (Assumption 7 pre-trends test via - paper Equation 18); that joint cross-horizon Stute variant is - deferred to a follow-up patch. Users should continue to perform their - own pre-trends / placebo analysis until the follow-up ships. The - returned :class:`HADPretestReport` verdict explicitly flags the - Assumption 7 gap when all implemented diagnostics fail-to-reject, so - callers do not receive a misleading "TWFE safe" signal. 
+ # Ordering check: all pre_periods strictly < base_period (natural + # order on the column dtype). We rely on the time column being + # comparable (numeric, datetime, or ordered categorical); other + # dtypes would silently misorder. The multi-period validator (when + # called via the workflow) enforces an ordered dtype; direct callers + # get a TypeError here on incomparable types. + try: + out_of_order = [t for t in pre_periods if not (t < base_period)] + except TypeError as exc: + raise TypeError( + "pre_periods and base_period must be comparable " + "(numeric, datetime, or ordered categorical values). " + f"Got pre_periods={list(pre_periods)!r}, " + f"base_period={base_period!r}." + ) from exc + if out_of_order: + raise ValueError( + f"All pre_periods must be strictly < base_period. " + f"Violators: {out_of_order!r} (base_period={base_period!r})." + ) - The workflow reduces the panel to unit-level first differences using - the Phase 2a validator + aggregator, then calls the three tests with - shared ``alpha`` and a single-source seed passthrough (``seed`` is - forwarded to :func:`stute_test` only; QUG and Yatchew are - deterministic). + d_arr, dy_by_horizon, _ = _aggregate_for_joint_test( + data, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + horizons=list(pre_periods), + base_period=base_period, + ) + G = d_arr.shape[0] + + # HAD invariant: D_{g,t} = 0 for every g and every pre_period (and + # for base_period - it is itself a pre-period relative to the + # treatment onset). We check this on the passed-in panel subset. + needed_all_zero = list(pre_periods) + [base_period] + subset_zero_check = data[data[time_col].isin(needed_all_zero)] + if (subset_zero_check[dose_col] != 0).any(): + n_nonzero = int((subset_zero_check[dose_col] != 0).sum()) + raise ValueError( + f"Pre-trends test requires D = 0 in every pre-period " + f"(including base_period). 
Found {n_nonzero} non-zero " + f"dose observation(s) across periods " + f"{needed_all_zero!r}. HAD contract (paper Section 2) and " + f"pre-trends test design both require the zero-dose " + f"invariant to hold in ALL periods used as placebo or " + f"anchor." + ) + + residuals_by_horizon: Dict[str, np.ndarray] = {} + fitted_by_horizon: Dict[str, np.ndarray] = {} + for label, dy_t in dy_by_horizon.items(): + mean_t = float(dy_t.mean()) + fitted_t = np.full(G, mean_t, dtype=np.float64) + residuals_t = dy_t - fitted_t + residuals_by_horizon[label] = residuals_t + fitted_by_horizon[label] = fitted_t + + design_matrix = np.ones((G, 1), dtype=np.float64) + + return stute_joint_pretest( + residuals_by_horizon=residuals_by_horizon, + fitted_by_horizon=fitted_by_horizon, + doses=d_arr, + design_matrix=design_matrix, + alpha=alpha, + n_bootstrap=n_bootstrap, + seed=seed, + null_form="mean_independence", + ) + + +def joint_homogeneity_test( + data: pd.DataFrame, + outcome_col: str, + dose_col: str, + time_col: str, + unit_col: str, + post_periods: list, + base_period: Any, + first_treat_col: Optional[str] = None, + *, + alpha: float = 0.05, + n_bootstrap: int = 999, + seed: Optional[int] = None, +) -> StuteJointResult: + """Joint Stute homogeneity-linearity test (paper Section 4.3 joint). + + Data-in wrapper around :func:`stute_joint_pretest` for the + linearity null + ``E[Y_{g,t} - Y_{g,base} | D_{g,t}] = beta_{0,t} + beta_{fe,t} * D_{g,t}`` + across multiple post-period horizons. For each ``t in post_periods``, + residuals are from an OLS regression of ``Y_{g,t} - Y_{g,base}`` on + ``[1, D_g]``; the joint CvM tests whether the conditional mean is + nonlinear in ``D`` in any horizon. 
+ + Used by :func:`did_had_pretest_workflow` with + ``aggregate="event_study"`` as the step-3 test (no joint Yatchew + variant exists - the paper does not derive one; users who need + Yatchew-style adjacent-difference robustness can call + :func:`yatchew_hr_test` on each (base, post) pair manually). Parameters ---------- data : pd.DataFrame - Balanced two-period HAD panel. The dose column must be 0 for all - units at the pre-period (HAD no-unit-untreated pre-period - contract). outcome_col, dose_col, time_col, unit_col : str - Column names. + post_periods : list + Non-empty list of post-period labels (all ``>= base_period`` by + time order; each with ``D > 0`` for some unit, i.e. at least one + treated unit per horizon). + base_period : period label + The reference period (last pre-period in the event-study + convention). Must not be in ``post_periods``. + first_treat_col : str or None + Forwarded to the underlying panel validator. + alpha, n_bootstrap, seed : as in :func:`stute_test`. + + Returns + ------- + StuteJointResult with ``null_form = "linearity"``. + """ + if len(post_periods) == 0: + raise ValueError( + "post_periods must be non-empty. Workflow dispatch handles " + "the empty case upstream; direct callers should not pass " + "an empty list." + ) + if base_period in post_periods: + raise ValueError( + f"base_period={base_period!r} must not appear in " + f"post_periods {list(post_periods)!r}." + ) + + # Ordering: all post_periods >= base_period (and in fact strictly + # greater under the HAD contract where base is the last pre-period). + try: + out_of_order = [t for t in post_periods if not (t > base_period)] + except TypeError as exc: + raise TypeError( + "post_periods and base_period must be comparable " + "(numeric, datetime, or ordered categorical values). " + f"Got post_periods={list(post_periods)!r}, " + f"base_period={base_period!r}." + ) from exc + if out_of_order: + raise ValueError( + f"All post_periods must be strictly > base_period. 
" + f"Violators: {out_of_order!r} (base_period={base_period!r})." + ) + + d_arr, dy_by_horizon, _ = _aggregate_for_joint_test( + data, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + horizons=list(post_periods), + base_period=base_period, + ) + G = d_arr.shape[0] + + # HAD invariant for the homogeneity path: base_period has D = 0 + # (last pre-period contract); each post_period has D > 0 for SOME + # unit (existence) and is NOT identically zero across all units + # (reciprocal twin of the pretrends guard - an all-zero post-period + # contradicts the HAD treatment-onset contract). + base_doses = data.loc[data[time_col] == base_period, dose_col] + if (base_doses != 0).any(): + n_nonzero = int((base_doses != 0).sum()) + raise ValueError( + f"base_period={base_period!r} must have D = 0 across every " + f"unit (HAD last-pre-period invariant). Found {n_nonzero} " + f"non-zero dose observation(s) in base_period." + ) + for t in post_periods: + post_doses = data.loc[data[time_col] == t, dose_col] + if not (post_doses > 0).any(): + raise ValueError( + f"post_period={t!r} has D = 0 for every unit. HAD " + f"contract requires at least some unit to have D > 0 " + f"in each post-period (reciprocal of the pre-period " + f"zero-dose invariant)." 
+ ) + + residuals_by_horizon: Dict[str, np.ndarray] = {} + fitted_by_horizon: Dict[str, np.ndarray] = {} + for label, dy_t in dy_by_horizon.items(): + a_hat, b_hat, residuals_t = _fit_ols_intercept_slope(d_arr, dy_t) + fitted_t = a_hat + b_hat * d_arr + residuals_by_horizon[label] = residuals_t + fitted_by_horizon[label] = fitted_t + + design_matrix = np.column_stack([np.ones(G, dtype=np.float64), d_arr.astype(np.float64)]) + + return stute_joint_pretest( + residuals_by_horizon=residuals_by_horizon, + fitted_by_horizon=fitted_by_horizon, + doses=d_arr, + design_matrix=design_matrix, + alpha=alpha, + n_bootstrap=n_bootstrap, + seed=seed, + null_form="linearity", + ) + + +_VALID_AGGREGATES = ("overall", "event_study") + + +def did_had_pretest_workflow( + data: pd.DataFrame, + outcome_col: str, + dose_col: str, + time_col: str, + unit_col: str, + first_treat_col: Optional[str] = None, + alpha: float = 0.05, + n_bootstrap: int = 999, + seed: Optional[int] = None, + *, + aggregate: str = "overall", +) -> HADPretestReport: + """Run the HAD pre-test workflow (paper Section 4.2-4.3). + + Two dispatch modes via ``aggregate``: + + ``aggregate="overall"`` (default, two-period panel): runs paper + steps 1 (:func:`qug_test`) and 3 (:func:`stute_test` + + :func:`yatchew_hr_test`). Step 2 (Assumption 7 pre-trends) is NOT + implemented on this path because a single-pre-period panel cannot + support the joint Stute variant; the returned verdict flags the + Assumption 7 gap explicitly so callers do not receive a misleading + "TWFE safe" signal. For multi-period panels, pass + ``aggregate="event_study"`` to close the step-2 gap. + + ``aggregate="event_study"`` (multi-period panel, >= 3 periods): runs + QUG + joint pre-trends Stute + joint homogeneity-linearity Stute, + covering paper Section 4 steps 1-3 together. 
Step 4 (Yatchew-style + linearity as an alternative to Stute) is subsumed by the joint Stute + in this path - the paper does not derive a joint Yatchew variant, so + users who need Yatchew robustness under multi-period data should + call :func:`yatchew_hr_test` on each (base, post) pair manually. + + Eq 18 linear-trend detrending (paper Section 5.2 Pierce-Schott + application) is a Phase 4 follow-up; the event-study path here + implements the simpler mean-independence / linearity nulls. + + Parameters + ---------- + data : pd.DataFrame + HAD panel. For ``aggregate="overall"``: balanced two-period + panel with pre-period dose = 0 for every unit. For + ``aggregate="event_study"``: balanced multi-period panel with + >= 3 periods, an ordered time dtype (numeric, datetime, or + ordered categorical), and the pre-period D=0 invariant across + all pre-periods. + outcome_col, dose_col, time_col, unit_col : str first_treat_col : str or None, default None - Optional first-treatment-period column for cross-validation - (see :func:`HeterogeneousAdoptionDiD.fit`). + Optional first-treatment-period column. Required on the + ``aggregate="event_study"`` path when the panel is staggered + (multi-cohort); the panel validator auto-filters to the last + cohort and emits ``UserWarning``. The overall path uses this for + cross-validation only. alpha : float, default 0.05 n_bootstrap : int, default 999 - Replication count for :func:`stute_test`. + Replication count for the single-horizon Stute (overall) or + joint Stute (event_study). seed : int or None, default None - Seed forwarded to :func:`stute_test` only. + Seed forwarded to the Stute bootstrap. QUG / Yatchew are + deterministic. + aggregate : str, keyword-only, default ``"overall"`` + Dispatch mode. Invalid values raise ``ValueError``. Returns ------- HADPretestReport + On the overall path: ``stute`` and ``yatchew`` populated, + ``pretrends_joint`` / ``homogeneity_joint`` are ``None``. 
On the + event-study path: ``pretrends_joint`` (``None`` if no earlier + pre-period) and ``homogeneity_joint`` populated, ``stute`` / + ``yatchew`` are ``None``. ``aggregate`` is recorded on the + report for serialization dispatch. - Notes - ----- - Phase 3 scope is two-period overall-path only. For multi-period - panels, slice to ``(F - 1, F)`` before calling. A future patch will - add a multi-period dispatch and the joint Equation 18 pre-trend - Stute test. + Raises + ------ + ValueError + On invalid ``aggregate`` or any downstream front-door failure + (panel balance, dtype, dose invariant). References ---------- - de Chaisemartin et al. (2026), Section 4.2-4.3, Theorem 4, Appendix D, - Theorem 7. + de Chaisemartin et al. (2026), Section 4.2-4.3, Theorem 4, Appendix + D, Theorem 7. """ + if aggregate not in _VALID_AGGREGATES: + raise ValueError( + f"aggregate must be one of {list(_VALID_AGGREGATES)!r}; " f"got {aggregate!r}." + ) + + if aggregate == "event_study": + F, t_pre_list, t_post_list, data_filtered, filter_info = _validate_multi_period_panel( + data, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + first_treat_col=first_treat_col, + ) + if filter_info is not None: + warnings.warn( + f"HAD event-study pre-test: staggered panel auto-" + f"filtered to last cohort " + f"(F_last={filter_info['F_last']!r}, " + f"n_kept={filter_info['n_kept']}, " + f"n_dropped={filter_info['n_dropped']}, " + f"dropped_cohorts={filter_info['dropped_cohorts']}). " + f"Paper Appendix B.2 prescription.", + UserWarning, + stacklevel=2, + ) + + # Base period for both joint tests is the last pre-period + # (paper convention: anchor at F-1 under natural time order). + # This is t_pre_list[-1] - NOT an arithmetic F-1, since the + # time column may be non-integer (datetime, ordered categorical). + base_period = t_pre_list[-1] + + # Step 1: QUG on dose distribution at F. Doses are + # time-invariant in HAD, so D_g at F equals max_t D_{g,t}. 
+ doses_at_F = ( + data_filtered.loc[data_filtered[time_col] == F, [unit_col, dose_col]] + .set_index(unit_col) + .sort_index()[dose_col] + .to_numpy(dtype=np.float64) + ) + qug_res = qug_test(doses_at_F, alpha=alpha) + + # Step 2: joint pre-trends on earlier pre-periods (those + # strictly before base_period). If only the base pre-period is + # available (len(t_pre_list) == 1), there are no earlier + # placebos; set pretrends_joint=None and flag in verdict. + earlier_pre = [t for t in t_pre_list if t < base_period] + if len(earlier_pre) >= 1: + pretrends_joint = joint_pretrends_test( + data_filtered, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + pre_periods=earlier_pre, + base_period=base_period, + first_treat_col=first_treat_col, + alpha=alpha, + n_bootstrap=n_bootstrap, + seed=seed, + ) + else: + pretrends_joint = None + + # Step 3: joint homogeneity-linearity on post-periods. + homogeneity_joint = joint_homogeneity_test( + data_filtered, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + post_periods=list(t_post_list), + base_period=base_period, + first_treat_col=first_treat_col, + alpha=alpha, + n_bootstrap=n_bootstrap, + seed=seed, + ) + + # Event-study `all_pass`: True iff every implemented step is + # conclusive AND none reject. `pretrends_joint` must exist + # (cannot be None) for the step-2 gap to be closed. Uses + # `np.isfinite(p_value)` per Phase 3 convention (no + # `.conclusive()` helper on result dataclasses). 
+ qug_ok = bool(np.isfinite(qug_res.p_value)) + pretrends_ok = pretrends_joint is not None and bool(np.isfinite(pretrends_joint.p_value)) + homogeneity_ok = bool(np.isfinite(homogeneity_joint.p_value)) + all_pass = bool( + qug_ok + and pretrends_ok + and pretrends_joint is not None + and not pretrends_joint.reject + and homogeneity_ok + and not homogeneity_joint.reject + and not qug_res.reject + ) + verdict = _compose_verdict_event_study(qug_res, pretrends_joint, homogeneity_joint) + + return HADPretestReport( + qug=qug_res, + stute=None, + yatchew=None, + all_pass=all_pass, + verdict=verdict, + alpha=alpha, + n_obs=int(doses_at_F.shape[0]), + pretrends_joint=pretrends_joint, + homogeneity_joint=homogeneity_joint, + aggregate="event_study", + ) + + # aggregate == "overall" - Phase 3 behavior, unchanged. t_pre, t_post = _validate_had_panel( data, outcome_col, dose_col, time_col, unit_col, first_treat_col ) @@ -1412,4 +2542,5 @@ def did_had_pretest_workflow( verdict=verdict, alpha=alpha, n_obs=int(d_arr.shape[0]), + aggregate="overall", ) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index d6972cd8..09b37a4c 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -2325,13 +2325,32 @@ Shipped in `diff_diff/had_pretests.py` as `yatchew_hr_test()`. Alternative to St 7. Inference on `β̂_{fe}` conditional on accepting the linearity test is asymptotically valid (Theorem 7, Point 1; citing de Chaisemartin and D'Haultfœuille 2024 arXiv:2407.03725). *Four-step pre-testing workflow (Section 4.2-4.3):* -Shipped as `did_had_pretest_workflow()` in Phase 3. The paper's decision rule for TWFE reliability in HADs: +Shipped as `did_had_pretest_workflow()` in Phase 3 (two-period `aggregate="overall"`) and extended in the Phase 3 follow-up with `aggregate="event_study"` dispatch that closes the step-2 pre-trends gap on multi-period panels. The paper's decision rule for TWFE reliability in HADs: 1. 
Test the null of a QUG (`H_0: d̲ = 0`) using `qug_test()`. -2. Run a pre-trends test of Assumption 7 (requires a pre-period `t=0`). -3. Test that `E(ΔY | D_2)` is linear (`stute_test` or `yatchew_hr_test`). +2. Run a pre-trends test of Assumption 7 (requires at least one earlier pre-period). +3. Test that `E(ΔY | D_2)` is linear (`stute_test` or `yatchew_hr_test`; or the joint Stute variants below in event-study dispatch). 4. If NONE of the three is rejected, `β̂_{fe}` from TWFE may be used to estimate the treatment effect. -**Phase 3 delivery:** `did_had_pretest_workflow()` runs steps 1 + 3 (QUG + Stute + Yatchew) on a two-period panel and returns a verdict. Step 2 (pre-trends test via Equation 18 joint cross-horizon Stute on pre-period placebos) is deferred to a Phase 3 follow-up patch; see `TODO.md`. The `practitioner_next_steps()` integration is queued for Phase 5. +**Phase 3 delivery (`aggregate="overall"`, two-period):** `did_had_pretest_workflow()` runs steps 1 + 3 (QUG + Stute + Yatchew). Step 2 is NOT run on this path because a two-period panel has no pre-period placebo horizon to test against; the verdict explicitly flags the Assumption 7 gap via the "paper step 2 deferred" caveat. + +**Phase 3 follow-up delivery (`aggregate="event_study"`, multi-period):** `did_had_pretest_workflow(..., aggregate="event_study")` dispatches on a balanced ≥3-period panel. Runs QUG at `F` + joint pre-trends Stute across earlier pre-periods (step 2, mean-independence null) + joint homogeneity-linearity Stute across post-periods (step 3 joint extension, linearity null). The verdict on this path does NOT emit the "paper step 2 deferred" caveat — the gap is closed. + +*Algorithm variant - Joint Stute tests (Section 4.2-4.3 joint; Phase 3 follow-up 2026-04):* +Shipped in `diff_diff/had_pretests.py` as `stute_joint_pretest()` (residuals-in core) plus two thin data-in wrappers `joint_pretrends_test()` (mean-independence null) and `joint_homogeneity_test()` (linearity null). 
Generalizes the single-horizon Stute CvM (above) to K horizons with joint inference. +1. Per-horizon statistic: for each horizon `k`, compute `S_k` via the tie-safe CvM on residuals `ε̂_{g,k}` sorted by dose `D_g`. +2. Joint aggregation: `S_joint = Σ_k S_k` (sum-of-CvMs). +3. Wild bootstrap with **shared** Mammen multiplier `η_g` across horizons per unit — preserves the vector-valued empirical process's unit-level dependence (Delgado-Manteiga 2001; Hlávka-Hušková 2020 for related vector wild-bootstrap theory). Per-horizon OLS refit with shared design matrix precomputes `(X'X)^{-1} X'` once; the bootstrap loop cost per draw is `O(G·p·K)` for K horizons. +4. Per-horizon exact-linear short-circuit: scale- and translation-invariant `Σ eps_k² / centered_TSS_k < _EXACT_LINEAR_RELATIVE_TOL` test applied per horizon. Joint short-circuit fires only when EVERY horizon is machine-exact linear; a single degenerate horizon does not collapse the test when others have nontrivial residuals. +5. Two data-in wrappers: + - `joint_pretrends_test(pre_periods, base_period)`: `null_form="mean_independence"`, design matrix `[1]`; residuals from `OLS(Y_t - Y_base ~ 1)` per pre-period (paper Section 4.2 footnote 6 + Section 4.3 paragraph 1: "regress Y_1 − Y_0 on a constant [only], then apply CvM to residuals vs D_2"). + - `joint_homogeneity_test(post_periods, base_period)`: `null_form="linearity"`, design matrix `[1, D]`; residuals from `OLS(Y_t - Y_base ~ 1 + D)` per post-period (paper Section 4.3 page 32 joint across post-periods, Pierce-Schott reports p=0.40). +- **Note:** Sum-of-CvMs aggregation is a standard joint specification-test construction (Delgado 1993; Escanciano 2006); the paper does not prescribe an aggregation rule. Sum-of-CvMs balances power across diffuse vs concentrated alternatives and bootstraps cleanly with shared-η. +- **Note:** Event-study dispatch adjudicates step 3 via joint Stute only; there is no joint Yatchew variant because the paper does not derive one. 
The overall two-period path still uses the Phase 3 "Stute OR Yatchew" adjudication. Users who need Yatchew-style adjacent-difference variance-ratio robustness under multi-period data can run `yatchew_hr_test` on each (base, post) pair manually. +- **Note:** Eq 18 linear-trend detrending (paper Section 5.2 Pierce-Schott application, p=0.51) is DEFERRED to Phase 4 where the Pierce-Schott replication harness exercises the exact detrending formula against the published value. The Phase 3 follow-up ships the simpler mean-independence null that paper step 2 requires. +- **Note:** Horizon labels in `StuteJointResult.horizon_labels` are `str(t)` verbatim and carry STRING IDENTITY ONLY — NOT a chronological ordering key. Callers who need chronological order must preserve the original period values alongside (e.g. from the `pre_periods` / `post_periods` argument). +- **Note:** NaN propagation is explicit: when any horizon has NaN in residuals, `cvm_stat_joint=NaN`, `p_value=NaN`, `reject=False`, AND `per_horizon_stats={label: np.nan for every horizon}` (full dict preserved with NaN values — not empty, not partial). + +**Phase 3 follow-up delivery:** `stute_joint_pretest()`, `joint_pretrends_test()`, `joint_homogeneity_test()`, `StuteJointResult`, and `did_had_pretest_workflow(aggregate="event_study")` shipped together in PR #353 (2026-04). The `practitioner_next_steps()` integration and tutorial are queued for Phase 5. **Reference implementation(s):** - R: `did_had` (de Chaisemartin, Ciccia, D'Haultfœuille, Knau 2024a); `stute_test` (2024c); `yatchew_test` (Online Appendix, Table 3). diff --git a/docs/methodology/papers/dechaisemartin-2026-review.md b/docs/methodology/papers/dechaisemartin-2026-review.md index 15c62eee..1d7abac1 100644 --- a/docs/methodology/papers/dechaisemartin-2026-review.md +++ b/docs/methodology/papers/dechaisemartin-2026-review.md @@ -184,14 +184,26 @@ Alternative to Stute when `G` is large or heteroskedasticity is suspected. 
- [ ] Local-linear regression backend (kernel weights, bandwidth selector). - [ ] Integration with bias-corrected CI from Calonico-Cattaneo-Farrell. - [x] QUG null test (`T = D_{2,(1)} / (D_{2,(2)} - D_{2,(1)})`, rejection region `{T > 1/α - 1}`). **Phase 3 implementation (2026-04):** `qug_test()` in `diff_diff/had_pretests.py`. Asymptotic p-value `1/(1+T)` under Exp(1)/Exp(1) limit law. Zero-dose observations filtered upfront with `UserWarning`; tie-break `D_{(1)} == D_{(2)}` returns all-NaN inference. Tight closed-form parity at `atol=1e-12`. -- [x] Stute Cramér-von Mises test with Mammen wild bootstrap. **Phase 3 implementation (2026-04):** `stute_test()` in `diff_diff/had_pretests.py`. Literal per-iteration OLS refit per paper Appendix D Algorithm. `n_bootstrap=999` default, `n_bootstrap >= 99` validated. Single-horizon only; joint Equation 18 cross-horizon Stute (Paper Section 5.2 Pierce-Schott application) deferred to a Phase 3 follow-up patch. +- [x] Stute Cramér-von Mises test with Mammen wild bootstrap. **Phase 3 implementation (2026-04):** `stute_test()` in `diff_diff/had_pretests.py`. Literal per-iteration OLS refit per paper Appendix D Algorithm. `n_bootstrap=999` default, `n_bootstrap >= 99` validated. - [x] Yatchew heteroskedasticity-robust linearity test. **Phase 3 implementation (2026-04):** `yatchew_hr_test()` in `diff_diff/had_pretests.py`. Test statistic `T_hr = sqrt(G)·(σ²_lin - σ²_diff)/σ²_W` from paper Equation 29. `σ²_diff` normalizes by `2G` (paper-literal), NOT `2(G-1)` (finite-sample equivalent but tests pin the paper-literal form). Standard-normal critical value, one-sided. -- [x] Composite workflow `did_had_pretest_workflow()` (paper Section 4.2-4.3). **Phase 3 implementation (2026-04):** Two-period panel entry point runs all three tests and returns `HADPretestReport` with priority-ordered verdict string. The paper's step 2 (pre-trends test of Assumption 7) requires Equation 18 and is deferred to the Phase 3 follow-up patch. 
+- [x] Composite workflow `did_had_pretest_workflow()` (paper Section 4.2-4.3). **Phase 3 implementation (2026-04):** `aggregate="overall"` (default, two-period) runs QUG + Stute + Yatchew on a two-period panel; step 2 is NOT run on this path because a two-period panel has no pre-period placebo horizon. **Phase 3 follow-up (2026-04):** `aggregate="event_study"` (multi-period) runs QUG at F + joint pre-trends Stute + joint homogeneity-linearity Stute; closes the paper step-2 gap. - [ ] Warnings for staggered treatment timing (direct users to existing `ChaisemartinDHaultfoeuille` in diff-diff). - [ ] Warnings for extensive-margin effects / positive mass of untreated (not fatal; suggests running existing DiD). - [ ] Documentation of non-testability of Assumptions 5 and 6. - [x] Multi-period event-study extension (Appendix B.2). **Phase 2b implementation (2026-04):** `aggregate="event_study"` returns per-event-time WAS estimates using uniform `F-1` anchor. Staggered timing auto-filtered to last cohort with `UserWarning` per Appendix B.2 prescription. Pointwise CIs per horizon (no joint cross-horizon covariance; matches paper's Pierce-Schott Figure 2). Pre-period placebos at `e <= -2`; the anchor `e = -1` is skipped since `ΔY = 0` there by construction. -- [ ] Joint Stute test (Equation 18) across pre-periods. Deferred to a **Phase 3 follow-up patch** — Phase 3 (2026-04) shipped the single-horizon Stute test but not the joint cross-horizon variant; the exact stacked-residual formula needs extraction from the paper PDF (not reproduced in this review). Tracked in `TODO.md`. +- [x] Joint Stute tests (paper Section 4.2 step 2 + Section 4.3 joint extension, pages 23-25 + 32). **Phase 3 follow-up (2026-04):** `stute_joint_pretest()` (residuals-in core) + `joint_pretrends_test()` (mean-independence null) + `joint_homogeneity_test()` (linearity null) in `diff_diff/had_pretests.py`. 
Sum-of-CvMs aggregation, shared-η Mammen wild bootstrap across horizons (Delgado-Manteiga 2001), per-horizon exact-linear short-circuit. Paper Eq (18) linear-trend detrending variant (Section 5.2 Pierce-Schott p=0.51) deferred to Phase 4 replication harness where the published value serves as parity anchor.
+
+**Eq (18) transcription (paper page 31):** The Pierce-Schott linear-trend-detrended joint Stute test of pre-trends reads
+```
+E( Y_{g,t} − Y_{g,1999} − (t − 1999)·(Y_{g,2000} − Y_{g,1999}) | D_{g,2001} ) = μ_t ∀ t ∈ {1998, 1997}
+```
+The paper reports p=0.51 on US-China tariff data. The detrended outcome replaces a raw first-difference: `Y_{g,t}` is first linearly-detrended using the `(Y_{g,2000} − Y_{g,1999})` pre-period slope per unit, then tested against `D_{g,2001}` via the joint CvM. The Phase 3 follow-up ships the simpler mean-independence joint Stute (no detrending); Phase 4 extends it with the Eq (18) detrending wired to the Pierce-Schott replication.
+
+**Joint Stute construction (paper Section 4.2-4.3 non-linear-trend variant, Phase 3 follow-up delivery):** For a set of horizons `{t_1, ..., t_K}` with residuals `{ε̂_{g,k}}_{k}` per unit and shared doses `D_g`:
+1. Per-horizon CvM `S_k = (1/G²) · Σ_g (Σ_{h ≤ g (in dose order)} ε̂_{(h),k})²` (tie-safe via block-collapsed cumsum).
+2. Joint statistic `S_joint = Σ_k S_k`.
+3. Wild bootstrap p-value `p = (1 + #{S*_b ≥ S_joint}) / (B+1)`, with `η_g` drawn once per iteration and applied as a SHARED multiplier across horizons for each unit (vector-valued empirical-process convention). Per-horizon OLS refit on the same design matrix each iteration; `(X'X)^{-1}X'` precomputed.
+The paper's text does not prescribe the joint aggregation rule; sum-of-CvMs is the standard joint specification-test construction (Delgado 1993; Escanciano 2006).
--- diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index a47e5991..77327f42 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -11,14 +11,21 @@ from diff_diff import ( QUGTestResults, + StuteJointResult, StuteTestResults, YatchewTestResults, did_had_pretest_workflow, + joint_homogeneity_test, + joint_pretrends_test, qug_test, + stute_joint_pretest, stute_test, yatchew_hr_test, ) -from diff_diff.had_pretests import _compose_verdict +from diff_diff.had_pretests import ( + _compose_verdict, + _compose_verdict_event_study, +) # ============================================================================= # Helpers @@ -592,6 +599,7 @@ def test_rejects_on_quadratic_plus_shifted_support(self): assert "linearity rejected" in report.verdict # At least QUG and one of {Stute, Yatchew} rejected. assert report.qug.reject is True + assert report.stute is not None and report.yatchew is not None assert report.stute.reject or report.yatchew.reject def test_workflow_handles_tied_zero_doses_via_stute_fallback(self): @@ -618,6 +626,7 @@ def test_workflow_handles_tied_zero_doses_via_stute_fallback(self): report = did_had_pretest_workflow( panel, "y", "d", "time", "unit", n_bootstrap=199, seed=42 ) + assert report.stute is not None and report.yatchew is not None # Yatchew must NaN (ties from the 20 zero doses). 
assert np.isnan(report.yatchew.p_value) assert report.yatchew.reject is False @@ -663,6 +672,8 @@ def test_workflow_seed_controls_stute_only(self): report_b = did_had_pretest_workflow( panel, "y", "d", "time", "unit", n_bootstrap=199, seed=42 ) + assert report_a.stute is not None and report_b.stute is not None + assert report_a.yatchew is not None and report_b.yatchew is not None assert report_a.stute.p_value == report_b.stute.p_value assert report_a.qug.t_stat == report_b.qug.t_stat assert report_a.yatchew.t_stat_hr == report_b.yatchew.t_stat_hr @@ -694,6 +705,7 @@ def test_constant_dy_trivially_satisfies_linearity(self): report = did_had_pretest_workflow( panel, "y", "d", "time", "unit", n_bootstrap=199, seed=42 ) + assert report.stute is not None and report.yatchew is not None assert report.stute.reject is False assert report.yatchew.reject is False assert report.stute.p_value > 0.5 @@ -1112,3 +1124,1214 @@ def test_report_summary_bundles_all(self): assert "Stute" in s or "CvM" in s assert "Yatchew" in s assert "Verdict:" in s + + +# ============================================================================= +# Phase 3 follow-up: joint Stute tests + event-study workflow dispatch +# ============================================================================= + + +def _make_multi_period_panel( + G: int, + periods: list, + first_treat_period, + dose_fn=None, + outcome_fn=None, + seed: int = 42, +) -> pd.DataFrame: + """Construct a multi-period HAD panel. + + Parameters + ---------- + G : int + Number of units. + periods : list + Time labels, ordered (numeric or ordered dtype). Must contain at + least one pre-period (``t < first_treat_period``) and one + post-period (``t >= first_treat_period``). + first_treat_period : period label + The first period where any unit has ``D > 0``. For pre-periods, + ``D = 0`` for every unit. + dose_fn : callable or None + ``dose_fn(rng, G) -> np.ndarray`` returning per-unit dose. + Default: uniform on [0.05, 1.0]. 
+ outcome_fn : callable or None + ``outcome_fn(rng, unit_id, t, d, is_post, first_treat) -> float`` + returning the outcome for a single (unit, period) cell. Default: + linear effect ``0.5 * d`` on post-periods plus Gaussian noise. + seed : int + """ + rng = np.random.default_rng(seed) + if dose_fn is None: + doses = rng.uniform(0.05, 1.0, size=G) + else: + doses = dose_fn(rng, G) + unit_effects = rng.normal(0.0, 0.3, size=G) + + def _default_outcome(rng_, g, t, d, is_post, _ft): + eff = 0.5 * d if is_post else 0.0 + return float(unit_effects[g] + eff + rng_.normal(0.0, 0.1)) + + if outcome_fn is None: + outcome_fn = _default_outcome + + rows = [] + for g in range(G): + for t in periods: + is_post = t >= first_treat_period + d = float(doses[g]) if is_post else 0.0 + y = outcome_fn(rng, g, t, d, is_post, first_treat_period) + rows.append({"unit": g, "period": t, "y": y, "d": d}) + return pd.DataFrame(rows) + + +def _nonlinear_outcome(d_effect_fn): + """Build an outcome_fn applying d_effect_fn(d) at post-periods.""" + + def _fn(rng_, g, t, d, is_post, _ft): + eff = d_effect_fn(d) if is_post else 0.0 + noise = rng_.normal(0.0, 0.1) + return float(0.3 * g / 100.0 + eff + noise) + + return _fn + + +def _multi_period_residuals(G: int, K: int, seed: int = 42): + """Random Gaussian residuals + zero fitted + uniform doses.""" + rng = np.random.default_rng(seed) + horizon_labels = [f"t={1995 + k}" for k in range(K)] + residuals = {k: rng.normal(0.0, 1.0, size=G) for k in horizon_labels} + fitted = {k: np.zeros(G) for k in horizon_labels} + doses = rng.uniform(0.0, 1.0, size=G) + return residuals, fitted, doses + + +class TestStuteJointPretest: + """Tests for :func:`stute_joint_pretest` (residuals-in core).""" + + def test_k1_parity_with_single_horizon_stute(self): + """K=1 joint matches stute_test on same residuals (refit semantics).""" + rng = np.random.default_rng(42) + G = 100 + d = rng.uniform(0.0, 1.0, G) + dy = 0.3 * d + rng.normal(0.0, 0.2, G) + # stute_test fits 
OLS(dy ~ 1 + d) internally. Mirror the fit here + # so the residuals passed to the joint helper are identical. + x = np.column_stack([np.ones(G), d]) + beta = np.linalg.solve(x.T @ x, x.T @ dy) + fitted = x @ beta + resid = dy - fitted + joint = stute_joint_pretest( + residuals_by_horizon={"only": resid}, + fitted_by_horizon={"only": fitted}, + doses=d, + design_matrix=x, + n_bootstrap=999, + seed=123, + null_form="linearity", + ) + single = stute_test(d, dy, n_bootstrap=999, seed=123) + np.testing.assert_allclose(joint.cvm_stat_joint, single.cvm_stat, atol=1e-14, rtol=1e-14) + # p_value can differ slightly due to RNG draw order (joint draws + # one eta vector per iteration, single draws one - same shape); + # the statistic being bit-identical is the critical check. + + def test_linear_dgp_fails_to_reject(self): + """Linear DGP across all horizons: joint test should not reject.""" + rng = np.random.default_rng(2) + G = 80 + d = rng.uniform(0.0, 1.0, G) + residuals = {} + fitted = {} + for k in range(3): + dy = 0.4 * d + rng.normal(0.0, 0.2, G) + x = np.column_stack([np.ones(G), d]) + beta = np.linalg.solve(x.T @ x, x.T @ dy) + fit = x @ beta + residuals[f"h{k}"] = dy - fit + fitted[f"h{k}"] = fit + result = stute_joint_pretest( + residuals_by_horizon=residuals, + fitted_by_horizon=fitted, + doses=d, + design_matrix=np.column_stack([np.ones(G), d]), + n_bootstrap=499, + seed=99, + ) + assert result.p_value > 0.05, f"unexpected rejection: p={result.p_value}" + assert result.reject is False + + def test_violated_dgp_in_single_horizon_reject(self): + """Quadratic effect in one of 3 horizons: joint test should reject.""" + rng = np.random.default_rng(7) + G = 150 + d = rng.uniform(0.05, 1.0, G) + residuals = {} + fitted = {} + for k in range(3): + if k == 1: + dy = 4.0 * (d**2) + rng.normal(0.0, 0.2, G) + else: + dy = 0.4 * d + rng.normal(0.0, 0.2, G) + x = np.column_stack([np.ones(G), d]) + beta = np.linalg.solve(x.T @ x, x.T @ dy) + fit = x @ beta + residuals[f"h{k}"] 
= dy - fit + fitted[f"h{k}"] = fit + result = stute_joint_pretest( + residuals_by_horizon=residuals, + fitted_by_horizon=fitted, + doses=d, + design_matrix=np.column_stack([np.ones(G), d]), + n_bootstrap=999, + seed=99, + ) + assert result.reject is True, f"expected rejection, got p={result.p_value}" + + def test_shared_eta_across_horizons_white_box(self): + """Bootstrap uses the same eta for all horizons in each iteration. + + White-box check: construct residuals where horizon 0 and horizon + 1 have the EXACT SAME residuals. The joint bootstrap under a + SHARED eta must produce bootstrap outcomes dy_b_h0 == dy_b_h1 + exactly (same fitted + same residuals * same eta). The refit + then gives the same residuals, and the CvM statistic for both + horizons is identical. If eta were drawn independently per + horizon, the two bootstrap residual streams would diverge. + + We verify by checking that S_joint_bootstrap = 2 * S_single_bootstrap + across many iterations (same underlying process duplicated). + """ + rng = np.random.default_rng(11) + G = 60 + d = rng.uniform(0.0, 1.0, G) + dy = 0.5 * d + rng.normal(0.0, 0.3, G) + x = np.column_stack([np.ones(G), d]) + beta = np.linalg.solve(x.T @ x, x.T @ dy) + fit = x @ beta + resid = dy - fit + joint = stute_joint_pretest( + residuals_by_horizon={"a": resid, "b": resid.copy()}, + fitted_by_horizon={"a": fit, "b": fit.copy()}, + doses=d, + design_matrix=x, + n_bootstrap=499, + seed=77, + ) + single = stute_joint_pretest( + residuals_by_horizon={"a": resid}, + fitted_by_horizon={"a": fit}, + doses=d, + design_matrix=x, + n_bootstrap=499, + seed=77, + ) + # Under shared eta, the joint stat is exactly 2x the single stat + # (both horizons identical). Under independent eta, the joint + # would be the sum of two INDEPENDENT draws - bootstrap + # distributions would differ from 2x the single distribution. 
+ np.testing.assert_allclose( + joint.cvm_stat_joint, 2.0 * single.cvm_stat_joint, atol=1e-14, rtol=1e-14 + ) + # Under shared eta, each bootstrap S*_b_joint = 2 * S*_b_single, + # so the p-value (P(S*_b >= S_obs)) is identical to the single. + np.testing.assert_allclose(joint.p_value, single.p_value, atol=1e-14) + + def test_seed_reproducibility(self): + """Same seed -> bit-identical results across calls.""" + rng = np.random.default_rng(3) + G = 80 + d = rng.uniform(0.0, 1.0, G) + resid = {"h0": rng.normal(0.0, 1.0, G), "h1": rng.normal(0.0, 1.0, G)} + fit = {"h0": np.zeros(G), "h1": np.zeros(G)} + r1 = stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=299, + seed=55, + ) + r2 = stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=299, + seed=55, + ) + np.testing.assert_allclose(r1.cvm_stat_joint, r2.cvm_stat_joint, atol=1e-14, rtol=1e-14) + np.testing.assert_allclose(r1.p_value, r2.p_value, atol=1e-14, rtol=1e-14) + assert r1.reject == r2.reject + + def test_nan_propagation(self): + """NaN in any residual -> p_value=NaN, reject=False, dict preserved.""" + G = 20 + rng = np.random.default_rng(4) + resid = { + "h0": rng.normal(0.0, 1.0, G), + "h1": rng.normal(0.0, 1.0, G), + } + resid["h1"][5] = np.nan + fit = {"h0": np.zeros(G), "h1": np.zeros(G)} + d = rng.uniform(0.0, 1.0, G) + result = stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + assert np.isnan(result.p_value) + assert np.isnan(result.cvm_stat_joint) + assert result.reject is False + assert result.exact_linear_short_circuited is False + # per_horizon_stats must preserve ALL keys with NaN values (not + # empty, not partial) - feedback_no_silent_failures. 
+ assert set(result.per_horizon_stats.keys()) == {"h0", "h1"} + assert all(np.isnan(v) for v in result.per_horizon_stats.values()) + assert result.horizon_labels == ["h0", "h1"] + + def test_negative_dose_raises(self): + G = 20 + resid, fit, _ = _multi_period_residuals(G, K=2) + doses_neg = np.full(G, -0.1) + with pytest.raises(ValueError, match="non-negative"): + stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=doses_neg, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + + def test_small_G_raises(self): + G = 5 # below _MIN_G_STUTE (10) + resid, fit, d = _multi_period_residuals(G, K=2) + with pytest.raises(ValueError, match="G >="): + stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + + def test_small_bootstrap_raises(self): + G = 50 + resid, fit, d = _multi_period_residuals(G, K=2) + with pytest.raises(ValueError, match="n_bootstrap"): + stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=50, + seed=0, + ) + + def test_empty_residuals_raises(self): + with pytest.raises(ValueError, match="at least one horizon"): + stute_joint_pretest( + residuals_by_horizon={}, + fitted_by_horizon={}, + doses=np.arange(30, dtype=np.float64), + design_matrix=np.ones((30, 1)), + n_bootstrap=199, + ) + + def test_key_mismatch_raises(self): + G = 30 + with pytest.raises(ValueError, match="identical keys"): + stute_joint_pretest( + residuals_by_horizon={"a": np.zeros(G)}, + fitted_by_horizon={"b": np.zeros(G)}, + doses=np.arange(G, dtype=np.float64), + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + ) + + def test_exact_linear_short_circuit_per_horizon(self): + """All-horizons exact linear -> short-circuit (p=1, no bootstrap).""" + G = 40 + rng = np.random.default_rng(5) + d = rng.uniform(0.0, 1.0, G) + # Two horizons, both perfectly linear in d 
(residuals near-zero) + dy1 = 2.0 * d + 1.0 + dy2 = -0.5 * d + 3.0 + x = np.column_stack([np.ones(G), d]) + beta1 = np.linalg.solve(x.T @ x, x.T @ dy1) + beta2 = np.linalg.solve(x.T @ x, x.T @ dy2) + fit1 = x @ beta1 + fit2 = x @ beta2 + resid1 = dy1 - fit1 + resid2 = dy2 - fit2 + result = stute_joint_pretest( + residuals_by_horizon={"h1": resid1, "h2": resid2}, + fitted_by_horizon={"h1": fit1, "h2": fit2}, + doses=d, + design_matrix=x, + n_bootstrap=199, + seed=1, + ) + assert result.exact_linear_short_circuited is True + assert result.p_value == 1.0 + assert result.reject is False + + def test_exact_linear_short_circuit_scale_invariant(self): + """Scale-invariant: rescaling residuals by 1e10 preserves short-circuit.""" + G = 40 + rng = np.random.default_rng(6) + d = rng.uniform(0.0, 1.0, G) + dy = 2.0 * d + 1.0 + x = np.column_stack([np.ones(G), d]) + beta = np.linalg.solve(x.T @ x, x.T @ dy) + fit = x @ beta + resid = dy - fit + # Scale by 1e10 + result = stute_joint_pretest( + residuals_by_horizon={"h1": resid * 1e10}, + fitted_by_horizon={"h1": fit * 1e10}, + doses=d, + design_matrix=x, + n_bootstrap=199, + seed=1, + ) + assert result.exact_linear_short_circuited is True + assert result.p_value == 1.0 + + def test_per_horizon_short_circuit_independence(self): + """Degenerate horizon + nontrivial horizon -> no short-circuit.""" + G = 80 + rng = np.random.default_rng(8) + d = rng.uniform(0.05, 1.0, G) + # Horizon 1: exact linear (fitted = dy, residuals ~ 0) + dy1 = 2.0 * d + 1.0 + x = np.column_stack([np.ones(G), d]) + beta = np.linalg.solve(x.T @ x, x.T @ dy1) + fit1 = x @ beta + resid1 = dy1 - fit1 + # Horizon 2: strong quadratic (nontrivial residuals) + dy2 = 5.0 * (d**2) + rng.normal(0.0, 0.1, G) + beta2 = np.linalg.solve(x.T @ x, x.T @ dy2) + fit2 = x @ beta2 + resid2 = dy2 - fit2 + result = stute_joint_pretest( + residuals_by_horizon={"lin": resid1, "quad": resid2}, + fitted_by_horizon={"lin": fit1, "quad": fit2}, + doses=d, + design_matrix=x, + 
n_bootstrap=999, + seed=3, + ) + # Must NOT short-circuit - the quadratic horizon is informative. + assert result.exact_linear_short_circuited is False + # Strong nonlinearity in horizon 2 should make the joint reject. + assert result.reject is True, f"expected rejection; p={result.p_value}" + + def test_horizon_labels_preserved_as_strings(self): + """Int / str / pd.Period labels all get str()'d; order preserved.""" + G = 40 + rng = np.random.default_rng(9) + d = rng.uniform(0.0, 1.0, G) + resid_int_keyed = {1997: rng.normal(0.0, 1.0, G), 1998: rng.normal(0.0, 1.0, G)} + fit_int_keyed = {1997: np.zeros(G), 1998: np.zeros(G)} + result = stute_joint_pretest( + residuals_by_horizon=resid_int_keyed, + fitted_by_horizon=fit_int_keyed, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + assert result.horizon_labels == ["1997", "1998"] + assert set(result.per_horizon_stats.keys()) == {"1997", "1998"} + + +class TestJointPretrendsTest: + """Tests for :func:`joint_pretrends_test` data-in wrapper.""" + + def test_smoke_runs_on_valid_panel(self): + df = _make_multi_period_panel( + G=50, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=11, + ) + result = joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + n_bootstrap=299, + seed=7, + ) + assert isinstance(result, StuteJointResult) + assert result.null_form == "mean_independence" + assert result.n_horizons == 2 + assert result.n_obs == 50 + assert np.isfinite(result.p_value) + # Linear DGP on post-periods; pre-periods have D=0 so no relationship + # between dy_pre_t and D is expected. Fail-to-reject is the target. 
+ assert result.p_value > 0.05 + + def test_matches_manually_constructed_residuals(self): + """Data-in path reproduces explicit residuals-in call exactly.""" + df = _make_multi_period_panel( + G=60, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=12, + ) + # Data-in dispatch + data_result = joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + n_bootstrap=399, + seed=22, + ) + # Manual construction: pivot, compute dy_t = Y_t - Y_base, then + # center per-horizon to build residuals. + pivot = df.pivot(index="unit", columns="period", values="y").sort_index() + d_per = df.groupby("unit")["d"].max().sort_index().to_numpy() + G = len(d_per) + base = pivot[1999].to_numpy(dtype=np.float64) + resid = {} + fit = {} + for t in [1997, 1998]: + dy_t = pivot[t].to_numpy(dtype=np.float64) - base + mean_t = float(dy_t.mean()) + fit[str(t)] = np.full(G, mean_t) + resid[str(t)] = dy_t - mean_t + manual_result = stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d_per, + design_matrix=np.ones((G, 1)), + n_bootstrap=399, + seed=22, + null_form="mean_independence", + ) + np.testing.assert_allclose( + data_result.cvm_stat_joint, + manual_result.cvm_stat_joint, + atol=1e-14, + rtol=1e-14, + ) + np.testing.assert_allclose( + data_result.p_value, + manual_result.p_value, + atol=1e-14, + rtol=1e-14, + ) + + def test_empty_pre_periods_raises(self): + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999], first_treat_period=1999, seed=1 + ) + with pytest.raises(ValueError, match="non-empty"): + joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[], + base_period=1998, + n_bootstrap=199, + seed=0, + ) + + def test_base_period_in_pre_periods_raises(self): + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999], first_treat_period=1999, seed=1 + ) + with pytest.raises(ValueError, match="must not appear"): + 
joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1998, + n_bootstrap=199, + seed=0, + ) + + def test_out_of_order_pre_period_raises(self): + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=2000, seed=1 + ) + with pytest.raises(ValueError, match="strictly < base_period"): + joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1999], + base_period=1998, + n_bootstrap=199, + seed=0, + ) + + def test_non_zero_dose_in_pre_period_raises(self): + """HAD contract: pre-periods have D=0 for every unit.""" + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=2000, seed=1 + ) + # Contaminate pre-period 1998 with a non-zero dose for one unit + df.loc[(df["unit"] == 0) & (df["period"] == 1998), "d"] = 0.5 + with pytest.raises(ValueError, match="D = 0"): + joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + n_bootstrap=199, + seed=0, + ) + + def test_non_zero_dose_in_base_period_raises(self): + """Reciprocal: base_period (last pre-period) must also satisfy D=0.""" + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=2000, seed=1 + ) + df.loc[(df["unit"] == 0) & (df["period"] == 1999), "d"] = 0.3 + with pytest.raises(ValueError, match="D = 0"): + joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + n_bootstrap=199, + seed=0, + ) + + +class TestJointHomogeneityTest: + """Tests for :func:`joint_homogeneity_test` data-in wrapper.""" + + def test_smoke_runs_on_linear_dgp(self): + df = _make_multi_period_panel( + G=50, + periods=[1998, 1999, 2000, 2001, 2002], + first_treat_period=2000, + seed=13, + ) + result = joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[2000, 2001, 2002], + base_period=1999, + n_bootstrap=299, + 
seed=21, + ) + assert isinstance(result, StuteJointResult) + assert result.null_form == "linearity" + assert result.n_horizons == 3 + assert np.isfinite(result.p_value) + assert result.reject is False + + def test_rejects_on_quadratic_post_effect(self): + """Quadratic effect in D across post-periods -> joint homogeneity rejects.""" + df = _make_multi_period_panel( + G=120, + periods=[1998, 1999, 2000, 2001], + first_treat_period=2000, + outcome_fn=_nonlinear_outcome(lambda d: 4.0 * (d**2)), + seed=14, + ) + result = joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[2000, 2001], + base_period=1999, + n_bootstrap=999, + seed=31, + ) + assert result.reject is True, f"expected rejection; p={result.p_value}" + + def test_matches_manually_constructed_residuals(self): + df = _make_multi_period_panel( + G=60, + periods=[1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=15, + ) + data_result = joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[2000, 2001], + base_period=1999, + n_bootstrap=399, + seed=41, + ) + # Manual construction: OLS(dy ~ 1 + D) per horizon + pivot = df.pivot(index="unit", columns="period", values="y").sort_index() + d_per = df.groupby("unit")["d"].max().sort_index().to_numpy() + G = len(d_per) + base = pivot[1999].to_numpy(dtype=np.float64) + X = np.column_stack([np.ones(G), d_per.astype(np.float64)]) + resid = {} + fit = {} + for t in [2000, 2001]: + dy_t = pivot[t].to_numpy(dtype=np.float64) - base + beta = np.linalg.solve(X.T @ X, X.T @ dy_t) + fit[str(t)] = X @ beta + resid[str(t)] = dy_t - fit[str(t)] + manual_result = stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d_per, + design_matrix=X, + n_bootstrap=399, + seed=41, + null_form="linearity", + ) + np.testing.assert_allclose( + data_result.cvm_stat_joint, + manual_result.cvm_stat_joint, + atol=1e-14, + rtol=1e-14, + ) + np.testing.assert_allclose( + data_result.p_value, + 
manual_result.p_value, + atol=1e-14, + rtol=1e-14, + ) + + def test_empty_post_periods_raises(self): + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999], first_treat_period=1999, seed=1 + ) + with pytest.raises(ValueError, match="non-empty"): + joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[], + base_period=1998, + n_bootstrap=199, + seed=0, + ) + + def test_base_period_in_post_periods_raises(self): + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999], first_treat_period=1999, seed=1 + ) + with pytest.raises(ValueError, match="must not appear"): + joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[1999], + base_period=1999, + n_bootstrap=199, + seed=0, + ) + + def test_post_period_before_base_raises(self): + """All post_periods must be strictly > base_period.""" + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=1999, seed=1 + ) + with pytest.raises(ValueError, match="strictly > base_period"): + joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[1998, 2000], + base_period=1999, + n_bootstrap=199, + seed=0, + ) + + def test_all_zero_dose_post_period_raises(self): + """Post-period with D=0 for every unit contradicts HAD contract.""" + df = _make_multi_period_panel( + G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=1999, seed=1 + ) + # Zero out all doses at post-period 2000 (keep 1999 post intact for base contract) + df.loc[df["period"] == 2000, "d"] = 0.0 + with pytest.raises(ValueError, match="D = 0 for every unit"): + joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[1999, 2000], + base_period=1998, + n_bootstrap=199, + seed=0, + ) + + +class TestMultiPeriodWorkflow: + """Tests for :func:`did_had_pretest_workflow` event-study dispatch.""" + + def _linear_panel(self, seed: int = 100) -> pd.DataFrame: + return _make_multi_period_panel( + G=80, + 
periods=[1996, 1997, 1998, 1999, 2000, 2001], + first_treat_period=1999, + seed=seed, + ) + + def test_overall_aggregate_unchanged(self): + """Default aggregate='overall' preserves Phase 3 behavior.""" + d, dy = _linear_dgp(G=50, seed=42) + panel = _make_two_period_panel(50, d, dy, seed=42) + report = did_had_pretest_workflow(panel, "y", "d", "time", "unit", n_bootstrap=299, seed=42) + assert report.aggregate == "overall" + assert report.stute is not None + assert report.yatchew is not None + assert report.pretrends_joint is None + assert report.homogeneity_joint is None + # Phase 3 step-2 gap string STILL present on the overall path + assert "paper step 2 deferred" in report.verdict + + def test_event_study_linear_dgp_all_pass(self): + df = self._linear_panel(seed=101) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=299, + seed=17, + ) + assert report.aggregate == "event_study" + assert report.stute is None and report.yatchew is None + assert report.pretrends_joint is not None + assert report.homogeneity_joint is not None + assert report.all_pass is True + assert "TWFE admissible under Section 4" in report.verdict + # The Phase 3 "paper step 2 deferred" string MUST NOT appear on + # the event-study path - the gap is closed. + assert "paper step 2 deferred" not in report.verdict + + def test_event_study_pretrend_violation_flagged(self): + """Strong pre-trend correlated with D -> pretrends_joint rejects.""" + + def pretrend_outcome(rng_, g, t, d, is_post, ft): + # D is fixed at F for this unit; simulate correlated pre-trend + # via knowing what the unit's eventual dose will be. + # Placeholder: rng seeds unit g the same way _make_multi_period_panel does + # NOTE: this is hacky; we bake-in correlation via g as proxy for dose. 
+ trend = (g / 100.0) * (t - 1998) # pre-existing linear trend + eff = 0.3 * d if is_post else 0.0 + return float(trend + eff + rng_.normal(0.0, 0.05)) + + # Construct panel where dose correlates strongly with unit id (proxy for trend strength) + def dose_fn(rng_, G): + return np.linspace(0.05, 1.0, G) + + df = _make_multi_period_panel( + G=80, + periods=[1996, 1997, 1998, 1999, 2000], + first_treat_period=1999, + dose_fn=dose_fn, + outcome_fn=pretrend_outcome, + seed=200, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=499, + seed=29, + ) + assert report.pretrends_joint is not None + assert report.pretrends_joint.reject is True + assert "joint pre-trends rejected - assumption 7 violated" in report.verdict + + def test_event_study_homogeneity_violation_flagged(self): + """Strong quadratic effect across post-periods -> homogeneity rejects.""" + df = _make_multi_period_panel( + G=100, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=1999, + outcome_fn=_nonlinear_outcome(lambda d: 4.0 * (d**2)), + seed=210, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=999, + seed=31, + ) + assert report.homogeneity_joint is not None + assert report.homogeneity_joint.reject is True + assert "joint linearity rejected - heterogeneity bias" in report.verdict + + def test_event_study_qug_violation_flagged(self): + """Shift dose support away from 0 -> QUG rejects.""" + + def shifted_dose_fn(rng_, G): + # All doses far from 0; D_2,(1) and D_2,(2) close -> T small -> reject + return rng_.uniform(0.5, 1.0, G) + + df = _make_multi_period_panel( + G=80, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=1999, + dose_fn=shifted_dose_fn, + seed=220, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=299, + seed=42, + ) + # QUG T statistic should be 
large (shifted support) -> reject. + if report.qug.reject: + assert report.verdict.startswith("support infimum rejected") + + def test_invalid_aggregate_raises(self): + df = self._linear_panel(seed=102) + with pytest.raises(ValueError, match="aggregate must be one of"): + did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="bogus", + n_bootstrap=199, + ) + + def test_single_pre_period_yields_pretrends_skipped(self): + """If t_pre_list has only the base pre-period, no earlier placebos + exist -> pretrends_joint is None and verdict flags the skip.""" + df = _make_multi_period_panel( + G=50, + periods=[1998, 1999, 2000, 2001], + first_treat_period=1999, + seed=130, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=299, + seed=52, + ) + assert report.pretrends_joint is None + # Even with a fail-to-reject homogeneity test, all_pass should be + # False because pretrends_joint is None (step 2 not closed). + if report.homogeneity_joint is not None and not report.homogeneity_joint.reject: + assert report.all_pass is False + # Verdict should mention the pre-trends skip. + assert "joint pre-trends skipped" in report.verdict + + def test_no_paper_step_2_deferred_string_on_event_study(self): + """Regression: event-study verdict must not emit the Phase 3 caveat.""" + df = self._linear_panel(seed=111) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=299, + seed=61, + ) + assert "paper step 2 deferred" not in report.verdict + assert "deferred to Phase 3 follow-up" not in report.verdict + + def test_first_treat_col_none_with_staggered_raises(self): + """Inherited contract: staggered panel + no first_treat_col -> raises.""" + # Build two cohorts: one treated at 1999, one at 2000. 
+ parts = [] + for cohort_ft, cohort_range in [(1999, (0, 40)), (2000, (40, 80))]: + for g in range(*cohort_range): + dose = 0.05 + 0.01 * (g - cohort_range[0]) + for t in [1997, 1998, 1999, 2000, 2001]: + is_post = t >= cohort_ft + parts.append( + { + "unit": g, + "period": t, + "y": 0.1 * g + (0.3 * dose if is_post else 0.0), + "d": dose if is_post else 0.0, + } + ) + df = pd.DataFrame(parts) + with pytest.raises(ValueError): + did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + + def test_staggered_auto_filter_warns(self): + """With first_treat_col provided, staggered panel auto-filters with warning.""" + parts = [] + for cohort_ft, cohort_range in [(1999, (0, 30)), (2000, (30, 80))]: + for g in range(*cohort_range): + dose = 0.05 + 0.01 * (g - cohort_range[0]) + for t in [1997, 1998, 1999, 2000, 2001]: + is_post = t >= cohort_ft + parts.append( + { + "unit": g, + "period": t, + "y": 0.1 * g + (0.3 * dose if is_post else 0.0), + "d": dose if is_post else 0.0, + "first_treat": cohort_ft, + } + ) + df = pd.DataFrame(parts) + with pytest.warns(UserWarning, match="staggered"): + did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + first_treat_col="first_treat", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + + def test_event_study_verdict_priority_qug_first(self): + """Rejections bundle; QUG rejection appears first.""" + qug = _mk_qug(reject=True, p=0.01) + pretrends = StuteJointResult( + cvm_stat_joint=1.0, + p_value=0.01, + reject=True, + alpha=0.05, + horizon_labels=["t"], + per_horizon_stats={"t": 1.0}, + n_bootstrap=999, + n_obs=50, + n_horizons=1, + seed=1, + null_form="mean_independence", + exact_linear_short_circuited=False, + ) + homogeneity = StuteJointResult( + cvm_stat_joint=0.5, + p_value=0.5, + reject=False, + alpha=0.05, + horizon_labels=["t"], + per_horizon_stats={"t": 0.5}, + n_bootstrap=999, + n_obs=50, + n_horizons=1, + seed=1, + 
null_form="linearity", + exact_linear_short_circuited=False, + ) + verdict = _compose_verdict_event_study(qug, pretrends, homogeneity) + # QUG appears before pre-trends + assert verdict.index("QUG") < verdict.index("assumption 7") + + def test_event_study_all_conclusive_no_reject_admissible(self): + qug = _mk_qug(reject=False, p=0.8) + pretrends = StuteJointResult( + cvm_stat_joint=0.3, + p_value=0.7, + reject=False, + alpha=0.05, + horizon_labels=["t"], + per_horizon_stats={"t": 0.3}, + n_bootstrap=999, + n_obs=50, + n_horizons=1, + seed=1, + null_form="mean_independence", + exact_linear_short_circuited=False, + ) + homogeneity = StuteJointResult( + cvm_stat_joint=0.5, + p_value=0.6, + reject=False, + alpha=0.05, + horizon_labels=["t"], + per_horizon_stats={"t": 0.5}, + n_bootstrap=999, + n_obs=50, + n_horizons=1, + seed=1, + null_form="linearity", + exact_linear_short_circuited=False, + ) + verdict = _compose_verdict_event_study(qug, pretrends, homogeneity) + assert "TWFE admissible under Section 4" in verdict + + +class TestHADPretestReportSerialization: + """Tests for HADPretestReport serialization branching by aggregate.""" + + def test_to_dict_overall_preserves_phase3_schema(self): + d, dy = _linear_dgp(G=50) + panel = _make_two_period_panel(50, d, dy, seed=42) + report = did_had_pretest_workflow(panel, "y", "d", "time", "unit", n_bootstrap=199, seed=42) + out = report.to_dict() + assert out["aggregate"] == "overall" + assert "qug" in out and "stute" in out and "yatchew" in out + # Event-study keys absent on overall + assert "pretrends_joint" not in out and "homogeneity_joint" not in out + # Round-trip JSON safely + json.dumps(out) + + def test_to_dict_event_study_emits_joint_keys(self): + df = _make_multi_period_panel( + G=60, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=131, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + out = 
report.to_dict() + assert out["aggregate"] == "event_study" + assert "qug" in out + assert "pretrends_joint" in out and "homogeneity_joint" in out + # Overall-path keys absent on event-study + assert "stute" not in out and "yatchew" not in out + json.dumps(out) + + def test_to_dataframe_stable_3_row_shape(self): + """to_dataframe returns 3 rows for both aggregates.""" + d, dy = _linear_dgp(G=50) + panel_overall = _make_two_period_panel(50, d, dy, seed=42) + report_overall = did_had_pretest_workflow( + panel_overall, "y", "d", "time", "unit", n_bootstrap=199, seed=42 + ) + df_o = report_overall.to_dataframe() + assert df_o.shape[0] == 3 + assert list(df_o["test"]) == ["qug", "stute", "yatchew_hr"] + + panel_es = _make_multi_period_panel( + G=50, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=132, + ) + report_es = did_had_pretest_workflow( + panel_es, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + df_e = report_es.to_dataframe() + assert df_e.shape[0] == 3 + assert list(df_e["test"]) == ["qug", "pretrends_joint", "homogeneity_joint"] + # Columns identical across aggregates + assert set(df_o.columns) == set(df_e.columns) + + def test_summary_includes_aggregate_header(self): + df = _make_multi_period_panel( + G=50, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=133, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + s = report.summary() + assert "aggregate: event_study" in s + + def test_repr_includes_aggregate(self): + df = _make_multi_period_panel( + G=50, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + seed=134, + ) + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + r = repr(report) + assert "aggregate='event_study'" in r From 
8da8e43699e3a1260ad2e9af528713d3c2f47c07 Mon Sep 17 00:00:00 2001 From: igerber Date: Thu, 23 Apr 2026 20:05:58 -0400 Subject: [PATCH 2/8] Address PR #353 CI review round 1 (2 P1 + 1 P2 + 2 P3) P1 - wrapper validator + first_treat_col wiring (had_pretests.py): `joint_pretrends_test` and `joint_homogeneity_test` now route through `_validate_had_panel_event_study` when the panel has >= 3 periods, so direct wrapper calls inherit the Appendix B.2 last-cohort filter, constant-post-dose invariant, and staggered/no-first_treat_col raise contract. `first_treat_col` is actually wired through instead of accepted-but-ignored. Subset checks (base_period in validated t_pre_list; pre_periods / post_periods subsets of the corresponding validated set) run after the validator, so callers get crisp errors on mistyped horizons rather than silent miscomputation. P1 - constant-d degeneracy guard in `stute_joint_pretest`: When `ptp(doses) <= 0` (all units share identical dose), warn and return all-NaN inference instead of computing a mechanically-zero CvM (mean-independence null - bogus fail-to-reject) or attempting a singular `[1, d]` refit (linearity null - matrix solve crash). Uses `np.ptp` rather than `np.var` because var-of-constant yields ~1e-32 rounding noise that would slip past a `<= 0` comparison. Mirrors stute_test's intent at single-horizon scale. P2 - bit-exact overall-path serialization: `HADPretestReport.__repr__`, `summary()`, and `to_dict()` now produce Phase 3-identical output when `aggregate="overall"` - no `aggregate` key in the dict, no header line in the summary, no new kwarg in the repr. The `aggregate` field remains on the dataclass internally and is surfaced in these serializations only on `aggregate="event_study"`. Restores the CHANGELOG's bit-exact compatibility claim. 
P3 - regression tests + docs: Four new tests cover the P1 edge cases: constant-d core path, direct-wrapper staggered panel (with and without first_treat_col), and wrapper-level constant-d propagation. REGISTRY.md and CHANGELOG.md document that step-2 closure requires >=2 pre-periods (the base `F-1` plus at least one earlier placebo); on single-pre- period panels the workflow emits `pretrends_joint=None` with a skip note in the verdict and `all_pass=False`. Existing tests updated for the new validator path: the pre-period D=0 and all-zero post-period checks now fire via the event-study validator's staggered-cohort or contiguous-dose guards before the wrapper's local reciprocal guards can run; regex matchers widened to accept either error surface. `test_to_dict_overall_preserves_phase3_schema` now asserts the ABSENCE of the `aggregate` key on the overall path to match the restored bit-exact schema. 119 tests pass (115 + 4 new R1 regressions); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- CHANGELOG.md | 2 +- diff_diff/had_pretests.py | 226 ++++++++++++++++++++++++++++++----- docs/methodology/REGISTRY.md | 2 +- tests/test_had_pretests.py | 151 +++++++++++++++++++++-- 4 files changed, 340 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a9aebed..a940c723 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - **`stute_joint_pretest`, `joint_pretrends_test`, `joint_homogeneity_test` + `StuteJointResult`** (HeterogeneousAdoptionDiD Phase 3 follow-up). Joint Cramér-von Mises pretests across K horizons with shared-η Mammen wild bootstrap (preserves vector-valued empirical-process unit-level dependence per Delgado-Manteiga 2001 / Hlávka-Hušková 2020). 
The core `stute_joint_pretest` is residuals-in; two thin data-in wrappers construct per-horizon residuals for the two nulls the paper spells out: mean-independence (step 2 pre-trends, `OLS(Y_t − Y_base ~ 1)` per pre-period) and linearity (step 3 joint, `OLS(Y_t − Y_base ~ 1 + D)` per post-period). Sum-of-CvMs aggregation (`S_joint = Σ_k S_k`); per-horizon scale-invariant exact-linear short-circuit. Closes the paper Section 4.2 step-2 gap that Phase 3 `did_had_pretest_workflow` previously flagged with an "Assumption 7 pre-trends test NOT run" caveat. See `docs/methodology/REGISTRY.md` §HeterogeneousAdoptionDiD "Joint Stute tests" for algorithm, invariants, and scope exclusion of Eq 18 linear-trend detrending (deferred to Phase 4 Pierce-Schott replication). -- **`did_had_pretest_workflow(aggregate="event_study")`**: multi-period dispatch on balanced ≥3-period panels. Runs QUG at `F` + joint pre-trends Stute across earlier pre-periods + joint homogeneity-linearity Stute across post-periods. Reuses the Phase 2b event-study panel validator (last-cohort auto-filter under staggered timing with `UserWarning`; `ValueError` when `first_treat_col=None` and the panel is staggered). `HADPretestReport` extended with `pretrends_joint`, `homogeneity_joint`, and `aggregate` fields; `summary`, `to_dict`, `to_dataframe`, `__repr__` branch on `aggregate` and preserve Phase 3 schemas bit-exactly on the `aggregate="overall"` path. +- **`did_had_pretest_workflow(aggregate="event_study")`**: multi-period dispatch on balanced ≥3-period panels. Runs QUG at `F` + joint pre-trends Stute across earlier pre-periods + joint homogeneity-linearity Stute across post-periods. Step 2 closure requires ≥2 pre-periods; with only a single pre-period (the base `F-1`) `pretrends_joint=None` and the verdict flags the skip. Reuses the Phase 2b event-study panel validator (last-cohort auto-filter under staggered timing with `UserWarning`; `ValueError` when `first_treat_col=None` and the panel is staggered). 
The data-in wrappers `joint_pretrends_test` and `joint_homogeneity_test` also route through that same validator internally, so direct wrapper calls inherit the last-cohort filter and constant-post-dose invariant. `HADPretestReport` extended with `pretrends_joint`, `homogeneity_joint`, and `aggregate` fields; serialization methods (`summary`, `to_dict`, `to_dataframe`, `__repr__`) preserve the Phase 3 output bit-exactly on `aggregate="overall"` — no `aggregate` key, no header row, no schema drift — and only surface the new fields on `aggregate="event_study"`. - **`target_parameter` block in BR/DR schemas (experimental; schema version bumped to 2.0)** — `BUSINESS_REPORT_SCHEMA_VERSION` and `DIAGNOSTIC_REPORT_SCHEMA_VERSION` bumped from `"1.0"` to `"2.0"` because the new `"no_scalar_by_design"` value on the `headline.status` / `headline_metric.status` enum (dCDH `trends_linear=True, L_max>=2` configuration) is a breaking change per the REPORTING.md stability policy. BusinessReport and DiagnosticReport now emit a top-level `target_parameter` block naming what the headline scalar actually represents for each of the 16 result classes. Closes BR/DR foundation gap #6 (target-parameter clarity). Fields: `name`, `definition`, `aggregation` (machine-readable dispatch tag), `headline_attribute` (raw result attribute), `reference` (citation pointer). BR's summary emits the short `name` right after the headline; DR's overall-interpretation paragraph does the same; both full reports carry a "## Target Parameter" section with the full definition. Per-estimator dispatch is sourced from REGISTRY.md and lives in the new `diff_diff/_reporting_helpers.py::describe_target_parameter`. 
A few branches read fit-time config (`EfficientDiDResults.pt_assumption`, `StackedDiDResults.clean_control`, `ChaisemartinDHaultfoeuilleResults.L_max` / `covariate_residuals` / `linear_trends_effects`); others emit a fixed tag (the fit-time `aggregate` kwarg on CS / Imputation / TwoStage / Wooldridge does not change the `overall_att` scalar — disambiguating horizon / group tables is tracked under gap #9). See `docs/methodology/REPORTING.md` "Target parameter" section. - SyntheticDiD coverage Monte Carlo calibration table added to `docs/methodology/REGISTRY.md` §SyntheticDiD — rejection rates at α ∈ {0.01, 0.05, 0.10} across `placebo` / `bootstrap` / `jackknife` on 3 representative DGPs (balanced / exchangeable, unbalanced, and Arkhangelsky et al. (2021) AER §6.3 non-exchangeable). Artifact at `benchmarks/data/sdid_coverage.json` (500 seeds × B=200), regenerable via `benchmarks/python/coverage_sdid.py`. diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index 78ef0970..a1ccff74 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -608,24 +608,36 @@ class HADPretestReport: aggregate: str = "overall" def __repr__(self) -> str: + # Preserve Phase 3 repr bit-exactly on the overall path. The + # aggregate kwarg is only surfaced on the event-study path so + # downstream consumers comparing repr strings on two-period + # reports see identical output. 
+ if self.aggregate == "event_study": + return ( + f"HADPretestReport(aggregate={self.aggregate!r}, " + f"all_pass={self.all_pass}, " + f"verdict={self.verdict!r}, n_obs={self.n_obs})" + ) return ( - f"HADPretestReport(aggregate={self.aggregate!r}, " - f"all_pass={self.all_pass}, " + f"HADPretestReport(all_pass={self.all_pass}, " f"verdict={self.verdict!r}, n_obs={self.n_obs})" ) def summary(self) -> str: """Formatted summary of all tests and the verdict.""" width = 72 - header = [ - "=" * width, - "HAD pre-test workflow".center(width), - f"aggregate: {self.aggregate}".center(width), - "=" * width, - self.qug.summary(), - "", - ] + # Preserve Phase 3 summary bit-exactly on the overall path. The + # `aggregate: ...` header line is only rendered on the event- + # study path; two-period reports produce the Phase 3 layout. if self.aggregate == "event_study": + header = [ + "=" * width, + "HAD pre-test workflow".center(width), + f"aggregate: {self.aggregate}".center(width), + "=" * width, + self.qug.summary(), + "", + ] if self.pretrends_joint is not None: body = [self.pretrends_joint.summary(), ""] else: @@ -636,7 +648,14 @@ def summary(self) -> str: if self.homogeneity_joint is not None: body += [self.homogeneity_joint.summary(), ""] else: - # aggregate == "overall" + # aggregate == "overall" - Phase 3 layout preserved. + header = [ + "=" * width, + "HAD pre-test workflow".center(width), + "=" * width, + self.qug.summary(), + "", + ] body = [] if self.stute is not None: body += [self.stute.summary(), ""] @@ -657,29 +676,39 @@ def print_summary(self) -> None: def to_dict(self) -> Dict[str, Any]: """Return a JSON-safe nested dict of the full report. - The ``aggregate`` key identifies which component fields are - present; ``None``-valued components are emitted as JSON null. + On ``aggregate="overall"``, the output schema is bit-exact with + Phase 3 (``{qug, stute, yatchew, all_pass, verdict, alpha, + n_obs}``) - no new keys, no aggregate field. 
On + ``aggregate="event_study"``, the output carries ``aggregate``, + ``pretrends_joint``, ``homogeneity_joint`` and omits the + ``None``-valued ``stute`` / ``yatchew`` keys entirely. """ - base: Dict[str, Any] = { - "aggregate": str(self.aggregate), + if self.aggregate == "event_study": + return { + "aggregate": str(self.aggregate), + "qug": self.qug.to_dict(), + "pretrends_joint": ( + None if self.pretrends_joint is None else self.pretrends_joint.to_dict() + ), + "homogeneity_joint": ( + None if self.homogeneity_joint is None else self.homogeneity_joint.to_dict() + ), + "all_pass": bool(self.all_pass), + "verdict": str(self.verdict), + "alpha": float(self.alpha), + "n_obs": int(self.n_obs), + } + # aggregate == "overall" - Phase 3 schema preserved bit-exactly, + # including key order and the absence of the aggregate field. + return { "qug": self.qug.to_dict(), + "stute": None if self.stute is None else self.stute.to_dict(), + "yatchew": None if self.yatchew is None else self.yatchew.to_dict(), "all_pass": bool(self.all_pass), "verdict": str(self.verdict), "alpha": float(self.alpha), "n_obs": int(self.n_obs), } - if self.aggregate == "event_study": - base["pretrends_joint"] = ( - None if self.pretrends_joint is None else self.pretrends_joint.to_dict() - ) - base["homogeneity_joint"] = ( - None if self.homogeneity_joint is None else self.homogeneity_joint.to_dict() - ) - else: - # aggregate == "overall" - Phase 3 schema preserved bit-exactly - base["stute"] = None if self.stute is None else self.stute.to_dict() - base["yatchew"] = None if self.yatchew is None else self.yatchew.to_dict() - return base def to_dataframe(self) -> pd.DataFrame: """Return a tidy 3-row DataFrame (one row per implemented test). @@ -1943,6 +1972,42 @@ def stute_joint_pretest( exact_linear_short_circuited=False, ) + # Zero-variation-in-D degeneracy guard: mirrors stute_test's intent + # (had_pretests.py:~1233). 
The CvM cusum is defined against the
+    # dose regressor; constant d has no cross-sectional variation for
+    # the test to detect nonlinearity. Under the mean-independence null
+    # this yields a mechanically-zero statistic (bogus fail-to-reject);
+    # under the linearity null a singular [1, d] design matrix crashes
+    # the refit. Emit warning + NaN result instead.
+    #
+    # Uses ``ptp`` (peak-to-peak = max - min) rather than ``np.var`` for
+    # the degeneracy check: ``np.var`` of a truly constant array can
+    # return a tiny non-zero value when the accumulated mean rounds
+    # away from the constant, so a ``<= 0`` comparison can miss the
+    # degeneracy. ``ptp`` is bit-exact for identical inputs.
+    if float(np.ptp(doses_arr)) <= 0.0:
+        warnings.warn(
+            "stute_joint_pretest: constant doses (zero cross-sectional "
+            "variation); the joint Stute CvM requires dose variation. "
+            "Returning NaN result.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return StuteJointResult(
+            cvm_stat_joint=float("nan"),
+            p_value=float("nan"),
+            reject=False,
+            alpha=float(alpha),
+            horizon_labels=horizon_labels,
+            per_horizon_stats={k: float("nan") for k in horizon_labels},
+            n_bootstrap=int(n_bootstrap),
+            n_obs=int(G),
+            n_horizons=int(K),
+            seed=None if seed is None else int(seed),
+            null_form=str(null_form),
+            exact_linear_short_circuited=False,
+        )
+
     idx = np.argsort(doses_arr, kind="stable")
     d_sorted = doses_arr[idx]
 
@@ -2113,8 +2178,59 @@ def joint_pretrends_test(
             f"Violators: {out_of_order!r} (base_period={base_period!r})."
         )
 
+    # Event-study validation contract (paper Appendix B.2):
+    # When the panel has >= 3 distinct periods, always route through
+    # `_validate_had_panel_event_study`. This enforces (a) balanced
+    # panel, (b) ordered time dtype, (c) D = 0 across every pre-period,
+    # (d) last-cohort auto-filter under staggered timing with
+    # UserWarning, (e) constant post-treatment dose within unit. 
When + # first_treat_col is None and the panel is staggered, the validator + # RAISES - matching the workflow dispatch contract. For 2-period + # panels the validator does not apply; skip and fall through to the + # simpler balance/invariant guards in `_aggregate_for_joint_test`. + n_periods = int(data[time_col].nunique()) + data_filtered: pd.DataFrame = data + if n_periods >= 3: + F_val, t_pre_list, _t_post_list, data_filtered, filter_info = ( + _validate_had_panel_event_study( + data, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + first_treat_col=first_treat_col, + ) + ) + if filter_info is not None: + warnings.warn( + f"joint_pretrends_test: staggered panel auto-filtered to " + f"last cohort (F_last={filter_info['F_last']!r}, " + f"n_kept={filter_info['n_kept']}, " + f"n_dropped={filter_info['n_dropped']}). " + f"Paper Appendix B.2 prescription.", + UserWarning, + stacklevel=2, + ) + # Subset invariants: the caller's base_period and pre_periods + # must be pre-treatment periods under the validator's partition. + if base_period not in t_pre_list: + raise ValueError( + f"base_period={base_period!r} is not in the validated " + f"pre-period set {list(t_pre_list)!r} (periods before " + f"first-treatment period F={F_val!r}). For the HAD " + f"pre-trends workflow, base_period must be a pre-period " + f"anchor (typically the last pre-period, F-1)." + ) + not_pre = [t for t in pre_periods if t not in t_pre_list] + if not_pre: + raise ValueError( + f"pre_periods must all be validated pre-treatment " + f"periods. Not-pre entries: {not_pre!r}. Validator's " + f"pre-period set: {list(t_pre_list)!r}." + ) + d_arr, dy_by_horizon, _ = _aggregate_for_joint_test( - data, + data_filtered, outcome_col=outcome_col, dose_col=dose_col, time_col=time_col, @@ -2128,7 +2244,7 @@ def joint_pretrends_test( # for base_period - it is itself a pre-period relative to the # treatment onset). We check this on the passed-in panel subset. 
needed_all_zero = list(pre_periods) + [base_period] - subset_zero_check = data[data[time_col].isin(needed_all_zero)] + subset_zero_check = data_filtered[data_filtered[time_col].isin(needed_all_zero)] if (subset_zero_check[dose_col] != 0).any(): n_nonzero = int((subset_zero_check[dose_col] != 0).sum()) raise ValueError( @@ -2242,8 +2358,54 @@ def joint_homogeneity_test( f"Violators: {out_of_order!r} (base_period={base_period!r})." ) + # Event-study validation contract (paper Appendix B.2) - twin of + # `joint_pretrends_test`. Same gating by `n_periods >= 3`; same + # subset-invariant checks; emits the staggered-filter UserWarning. + # The validator also enforces constant post-treatment dose within + # unit, which is critical for the homogeneity path because a + # time-varying post-dose would make the per-horizon refit on + # `[1, D_g]` misspecify the regressor. + n_periods = int(data[time_col].nunique()) + data_filtered: pd.DataFrame = data + if n_periods >= 3: + F_val, t_pre_list, t_post_list, data_filtered, filter_info = ( + _validate_had_panel_event_study( + data, + outcome_col=outcome_col, + dose_col=dose_col, + time_col=time_col, + unit_col=unit_col, + first_treat_col=first_treat_col, + ) + ) + if filter_info is not None: + warnings.warn( + f"joint_homogeneity_test: staggered panel auto-filtered " + f"to last cohort (F_last={filter_info['F_last']!r}, " + f"n_kept={filter_info['n_kept']}, " + f"n_dropped={filter_info['n_dropped']}). " + f"Paper Appendix B.2 prescription.", + UserWarning, + stacklevel=2, + ) + if base_period not in t_pre_list: + raise ValueError( + f"base_period={base_period!r} is not in the validated " + f"pre-period set {list(t_pre_list)!r} (periods before " + f"first-treatment period F={F_val!r}). For the HAD " + f"homogeneity workflow, base_period must be a pre-period " + f"anchor (typically the last pre-period, F-1)." 
+ ) + not_post = [t for t in post_periods if t not in t_post_list] + if not_post: + raise ValueError( + f"post_periods must all be validated post-treatment " + f"periods. Not-post entries: {not_post!r}. Validator's " + f"post-period set: {list(t_post_list)!r}." + ) + d_arr, dy_by_horizon, _ = _aggregate_for_joint_test( - data, + data_filtered, outcome_col=outcome_col, dose_col=dose_col, time_col=time_col, @@ -2258,7 +2420,7 @@ def joint_homogeneity_test( # unit (existence) and is NOT identically zero across all units # (reciprocal twin of the pretrends guard - an all-zero post-period # contradicts the HAD treatment-onset contract). - base_doses = data.loc[data[time_col] == base_period, dose_col] + base_doses = data_filtered.loc[data_filtered[time_col] == base_period, dose_col] if (base_doses != 0).any(): n_nonzero = int((base_doses != 0).sum()) raise ValueError( @@ -2267,7 +2429,7 @@ def joint_homogeneity_test( f"non-zero dose observation(s) in base_period." ) for t in post_periods: - post_doses = data.loc[data[time_col] == t, dose_col] + post_doses = data_filtered.loc[data_filtered[time_col] == t, dose_col] if not (post_doses > 0).any(): raise ValueError( f"post_period={t!r} has D = 0 for every unit. HAD " diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 09b37a4c..224f2c8b 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -2333,7 +2333,7 @@ Shipped as `did_had_pretest_workflow()` in Phase 3 (two-period `aggregate="overa **Phase 3 delivery (`aggregate="overall"`, two-period):** `did_had_pretest_workflow()` runs steps 1 + 3 (QUG + Stute + Yatchew). Step 2 is NOT run on this path because a two-period panel has no pre-period placebo horizon to test against; the verdict explicitly flags the Assumption 7 gap via the "paper step 2 deferred" caveat. 
-**Phase 3 follow-up delivery (`aggregate="event_study"`, multi-period):** `did_had_pretest_workflow(..., aggregate="event_study")` dispatches on a balanced ≥3-period panel. Runs QUG at `F` + joint pre-trends Stute across earlier pre-periods (step 2, mean-independence null) + joint homogeneity-linearity Stute across post-periods (step 3 joint extension, linearity null). The verdict on this path does NOT emit the "paper step 2 deferred" caveat — the gap is closed. +**Phase 3 follow-up delivery (`aggregate="event_study"`, multi-period):** `did_had_pretest_workflow(..., aggregate="event_study")` dispatches on a balanced ≥3-period panel. Runs QUG at `F` + joint pre-trends Stute across earlier pre-periods (step 2, mean-independence null) + joint homogeneity-linearity Stute across post-periods (step 3 joint extension, linearity null). Step 2 closure requires at least TWO pre-periods (the base pre-period plus one earlier placebo); on panels with only a single pre-period (the base `F-1`) the workflow emits `pretrends_joint=None` and the verdict flags the skip ("joint pre-trends skipped (no earlier pre-period)"). `all_pass` is False in this degenerate case. The verdict on the event-study path does NOT emit the "paper step 2 deferred" caveat when step 2 runs. *Algorithm variant - Joint Stute tests (Section 4.2-4.3 joint; Phase 3 follow-up 2026-04):* Shipped in `diff_diff/had_pretests.py` as `stute_joint_pretest()` (residuals-in core) plus two thin data-in wrappers `joint_pretrends_test()` (mean-independence null) and `joint_homogeneity_test()` (linearity null). Generalizes the single-horizon Stute CvM (above) to K horizons with joint inference. 
diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index 77327f42..f223f80c 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -1550,6 +1550,31 @@ def test_horizon_labels_preserved_as_strings(self): assert result.horizon_labels == ["1997", "1998"] assert set(result.per_horizon_stats.keys()) == {"1997", "1998"} + def test_constant_d_returns_nan_with_warning(self): + """R1: constant doses - no cross-sectional variation to detect + nonlinearity. Must warn and return NaN inference rather than + a mechanically-zero CvM (mean-indep null) or singular refit + (linearity null). Mirrors stute_test's single-horizon guard.""" + G = 30 + resid, fit, _ = _multi_period_residuals(G, K=2, seed=123) + d_constant = np.full(G, 0.5, dtype=np.float64) + with pytest.warns(UserWarning, match="constant doses"): + result = stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d_constant, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + assert np.isnan(result.cvm_stat_joint) + assert np.isnan(result.p_value) + assert result.reject is False + assert result.exact_linear_short_circuited is False + # Per-horizon stats preserved with NaN values (diagnostic surface) + assert set(result.per_horizon_stats.keys()) == set(resid.keys()) + assert all(np.isnan(v) for v in result.per_horizon_stats.values()) + class TestJointPretrendsTest: """Tests for :func:`joint_pretrends_test` data-in wrapper.""" @@ -1688,13 +1713,16 @@ def test_out_of_order_pre_period_raises(self): ) def test_non_zero_dose_in_pre_period_raises(self): - """HAD contract: pre-periods have D=0 for every unit.""" + """HAD contract: pre-periods have D=0 for every unit. 
The + event-study validator catches this via its staggered-cohort + detection (a pre-period unit with D>0 looks like an earlier + treatment cohort).""" df = _make_multi_period_panel( G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=2000, seed=1 ) # Contaminate pre-period 1998 with a non-zero dose for one unit df.loc[(df["unit"] == 0) & (df["period"] == 1998), "d"] = 0.5 - with pytest.raises(ValueError, match="D = 0"): + with pytest.raises(ValueError, match="Staggered|dose invariant|D = 0"): joint_pretrends_test( df, "y", @@ -1708,12 +1736,14 @@ def test_non_zero_dose_in_pre_period_raises(self): ) def test_non_zero_dose_in_base_period_raises(self): - """Reciprocal: base_period (last pre-period) must also satisfy D=0.""" + """Reciprocal: base_period (last pre-period) must also satisfy + D=0. Caught by the event-study validator before our local + guard runs.""" df = _make_multi_period_panel( G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=2000, seed=1 ) df.loc[(df["unit"] == 0) & (df["period"] == 1999), "d"] = 0.3 - with pytest.raises(ValueError, match="D = 0"): + with pytest.raises(ValueError, match="Staggered|dose invariant|D = 0"): joint_pretrends_test( df, "y", @@ -1726,6 +1756,107 @@ def test_non_zero_dose_in_base_period_raises(self): seed=0, ) + def test_staggered_panel_without_first_treat_col_raises(self): + """R1: direct wrapper call on a staggered panel without + first_treat_col must raise via the event-study validator + contract (same behavior as did_had_pretest_workflow's + event-study dispatch).""" + parts = [] + for cohort_ft, cohort_range in [(1999, (0, 15)), (2000, (15, 30))]: + for g in range(*cohort_range): + dose = 0.05 + 0.01 * (g - cohort_range[0]) + for t in [1997, 1998, 1999, 2000, 2001]: + is_post = t >= cohort_ft + parts.append( + { + "unit": g, + "period": t, + "y": 0.1 * g + (0.3 * dose if is_post else 0.0), + "d": dose if is_post else 0.0, + } + ) + df = pd.DataFrame(parts) + with pytest.raises(ValueError, 
match="Staggered"): + joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + n_bootstrap=199, + seed=0, + ) + + def test_staggered_panel_with_first_treat_col_warns_and_filters(self): + """R1: direct wrapper call on a staggered panel WITH + first_treat_col auto-filters to last cohort + never-treated + and emits UserWarning.""" + parts = [] + for cohort_ft, cohort_range in [(1999, (0, 10)), (2000, (10, 40))]: + for g in range(*cohort_range): + dose = 0.05 + 0.01 * (g - cohort_range[0]) + for t in [1997, 1998, 1999, 2000, 2001]: + is_post = t >= cohort_ft + parts.append( + { + "unit": g, + "period": t, + "y": 0.1 * g + (0.3 * dose if is_post else 0.0), + "d": dose if is_post else 0.0, + "first_treat": cohort_ft, + } + ) + df = pd.DataFrame(parts) + with pytest.warns(UserWarning, match="staggered|Staggered"): + result = joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + first_treat_col="first_treat", + n_bootstrap=199, + seed=0, + ) + assert isinstance(result, StuteJointResult) + # Last cohort (F=2000) + never-treated kept (none in this fixture + # - all units are treated); n_obs reflects the filter. + assert result.n_obs == 30 # 10 filtered + 30 kept... 
actually just 30 kept + + def test_constant_d_wrapper_path_returns_nan_with_warning(self): + """R1: direct wrapper call on a panel where ALL units have the + same dose - propagates the joint core's constant-d guard and + returns NaN inference rather than a spurious fail-to-reject.""" + + def const_dose(rng_, G): # noqa: ARG001 + return np.full(G, 0.4, dtype=np.float64) + + df = _make_multi_period_panel( + G=30, + periods=[1997, 1998, 1999, 2000, 2001], + first_treat_period=2000, + dose_fn=const_dose, + seed=33, + ) + with pytest.warns(UserWarning, match="constant doses"): + result = joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=[1997, 1998], + base_period=1999, + n_bootstrap=199, + seed=0, + ) + assert np.isnan(result.p_value) + assert result.reject is False + class TestJointHomogeneityTest: """Tests for :func:`joint_homogeneity_test` data-in wrapper.""" @@ -1882,13 +2013,16 @@ def test_post_period_before_base_raises(self): ) def test_all_zero_dose_post_period_raises(self): - """Post-period with D=0 for every unit contradicts HAD contract.""" + """Post-period with D=0 for every unit contradicts HAD contract. 
+ Caught either by the event-study validator's contiguous-dose + invariant (post-zero breaks the monotone transition from pre + D=0 to post D>0) or by our local reciprocal guard.""" df = _make_multi_period_panel( G=30, periods=[1997, 1998, 1999, 2000], first_treat_period=1999, seed=1 ) # Zero out all doses at post-period 2000 (keep 1999 post intact for base contract) df.loc[df["period"] == 2000, "d"] = 0.0 - with pytest.raises(ValueError, match="D = 0 for every unit"): + with pytest.raises(ValueError, match="dose invariant|D = 0 for every unit"): joint_homogeneity_test( df, "y", @@ -2231,7 +2365,10 @@ def test_to_dict_overall_preserves_phase3_schema(self): panel = _make_two_period_panel(50, d, dy, seed=42) report = did_had_pretest_workflow(panel, "y", "d", "time", "unit", n_bootstrap=199, seed=42) out = report.to_dict() - assert out["aggregate"] == "overall" + # Phase 3 schema is bit-exact: no `aggregate` key on the overall + # path (only emitted on event_study) - Phase 3 downstream + # consumers must not see a new key. + assert "aggregate" not in out assert "qug" in out and "stute" in out and "yatchew" in out # Event-study keys absent on overall assert "pretrends_joint" not in out and "homogeneity_joint" not in out From 84835defc69e7cdb36598f7c9b51425ce0ea2e5e Mon Sep 17 00:00:00 2001 From: igerber Date: Thu, 23 Apr 2026 20:21:26 -0400 Subject: [PATCH 3/8] Address PR #353 CI review round 2 (1 P1 + 1 P3) P1 - ordered-categorical chronology: raw `t < base_period` / `t > base_period` comparisons in `joint_pretrends_test`, `joint_homogeneity_test`, and `did_had_pretest_workflow(aggregate= "event_study")` silently misorder ordered-categorical time columns whose lexical and chronological order disagree (e.g. categories ["q1", "q2", "q10"] sort lexically as "q1" < "q10" < "q2"). 
On such panels the raw comparison could (a) silently drop valid pre-period horizons via the raw `<` check, (b) emit a spurious "joint pre-trends skipped" verdict from the workflow's `earlier_pre` filter, or (c) raise on valid post-period inputs. Fix: new private helper `_build_period_rank` returns a {period_label: chronological_rank} map using the ordered- categorical category order when applicable, natural sort on numeric / datetime otherwise. Both wrappers compare period labels via rank (`rank[t1] < rank[t2]`) instead of raw Python `<`/`>`. The workflow's `earlier_pre` replaces the raw-< filter with `list(t_pre_list[:-1])` - `t_pre_list` is already chronologically sorted by the validator (via its `_sort_key`), so excluding the last element yields the earlier pre-periods regardless of dtype. P3 - ordered-categorical regression tests: new `TestOrderedCategoricalChronology` class (4 tests) with a fixture using categories `["q1", "q2", "q10", "post"]`. Covers (a) direct pretrends wrapper picks up both earlier placebos, (b) pretrends wrapper rejects lexically-ordered-but-chrono-invalid input (e.g. pre=["q10"], base="q2"), (c) homogeneity wrapper accepts valid post-period input, (d) workflow event-study dispatch surfaces both earlier placebos in `pretrends_joint.horizon_labels` without the false skip note. 123 tests pass (119 + 4 new); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/had_pretests.py | 111 +++++++++++++++++++++++---------- tests/test_had_pretests.py | 124 +++++++++++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+), 31 deletions(-) diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index a1ccff74..6ddb1243 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -1640,6 +1640,30 @@ def _validate_multi_period_panel( ) +def _build_period_rank(data: pd.DataFrame, time_col: str) -> Dict[Any, int]: + """Build a ``{period_label: chronological_rank}`` map. 
+ + For ordered categorical time columns, uses the declared category + order so that e.g. ``["q1", "q2", "q10"]`` ranks chronologically + even though it sorts lexically in the opposite order. For numeric + or datetime time columns, uses natural Python `sorted` order on + the unique period labels. Object dtypes would fall back to + lexicographic order - callers relying on chronology with object- + dtype labels should convert to an ordered categorical first + (this mirrors the contract in ``_validate_had_panel_event_study``). + + The rank map lets the joint-pretest wrappers compare period labels + chronologically via ``rank[t1] < rank[t2]`` instead of raw Python + ``t1 < t2``, which would silently misorder ordered-categorical + panels (paper Appendix B.2 support contract). + """ + time_dtype = data[time_col].dtype + if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered: + return {c: i for i, c in enumerate(time_dtype.categories)} + periods = sorted(data[time_col].unique()) + return {p: i for i, p in enumerate(periods)} + + def _aggregate_for_joint_test( data: pd.DataFrame, outcome_col: str, @@ -2157,25 +2181,32 @@ def joint_pretrends_test( f"base_period={base_period!r} must not appear in " f"pre_periods {list(pre_periods)!r}." ) - # Ordering check: all pre_periods strictly < base_period (natural - # order on the column dtype). We rely on the time column being - # comparable (numeric, datetime, or ordered categorical); other - # dtypes would silently misorder. The multi-period validator (when - # called via the workflow) enforces an ordered dtype; direct callers - # get a TypeError here on incomparable types. - try: - out_of_order = [t for t in pre_periods if not (t < base_period)] - except TypeError as exc: - raise TypeError( - "pre_periods and base_period must be comparable " - "(numeric, datetime, or ordered categorical values). " - f"Got pre_periods={list(pre_periods)!r}, " - f"base_period={base_period!r}." 
- ) from exc + # Ordering check: all pre_periods strictly < base_period in + # chronological order. Uses `_build_period_rank` to handle ordered- + # categorical time columns correctly (raw Python `<` would fail on + # categories whose lexical order disagrees with chronology, e.g. + # ["q1", "q2", "q10"]). Numeric / datetime dtypes get natural order. + period_rank = _build_period_rank(data, time_col) + if base_period not in period_rank: + raise ValueError( + f"base_period={base_period!r} not found in time_col " + f"{time_col!r}. Available: " + f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." + ) + missing_pre_in_data = [t for t in pre_periods if t not in period_rank] + if missing_pre_in_data: + raise ValueError( + f"pre_periods entries {missing_pre_in_data!r} not found in " + f"time_col {time_col!r}. Available: " + f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." + ) + base_rank = period_rank[base_period] + out_of_order = [t for t in pre_periods if period_rank[t] >= base_rank] if out_of_order: raise ValueError( - f"All pre_periods must be strictly < base_period. " - f"Violators: {out_of_order!r} (base_period={base_period!r})." + f"All pre_periods must be strictly < base_period in " + f"chronological order. Violators: {out_of_order!r} " + f"(base_period={base_period!r})." ) # Event-study validation contract (paper Appendix B.2): @@ -2341,21 +2372,31 @@ def joint_homogeneity_test( f"post_periods {list(post_periods)!r}." ) - # Ordering: all post_periods >= base_period (and in fact strictly - # greater under the HAD contract where base is the last pre-period). - try: - out_of_order = [t for t in post_periods if not (t > base_period)] - except TypeError as exc: - raise TypeError( - "post_periods and base_period must be comparable " - "(numeric, datetime, or ordered categorical values). " - f"Got post_periods={list(post_periods)!r}, " - f"base_period={base_period!r}." 
- ) from exc + # Ordering: all post_periods strictly > base_period in + # chronological order. Uses `_build_period_rank` for ordered- + # categorical correctness (raw Python `>` would misorder e.g. + # "q10" > "q2"). + period_rank = _build_period_rank(data, time_col) + if base_period not in period_rank: + raise ValueError( + f"base_period={base_period!r} not found in time_col " + f"{time_col!r}. Available: " + f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." + ) + missing_post_in_data = [t for t in post_periods if t not in period_rank] + if missing_post_in_data: + raise ValueError( + f"post_periods entries {missing_post_in_data!r} not found in " + f"time_col {time_col!r}. Available: " + f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." + ) + base_rank = period_rank[base_period] + out_of_order = [t for t in post_periods if period_rank[t] <= base_rank] if out_of_order: raise ValueError( - f"All post_periods must be strictly > base_period. " - f"Violators: {out_of_order!r} (base_period={base_period!r})." + f"All post_periods must be strictly > base_period in " + f"chronological order. Violators: {out_of_order!r} " + f"(base_period={base_period!r})." ) # Event-study validation contract (paper Appendix B.2) - twin of @@ -2595,7 +2636,15 @@ def did_had_pretest_workflow( # strictly before base_period). If only the base pre-period is # available (len(t_pre_list) == 1), there are no earlier # placebos; set pretrends_joint=None and flag in verdict. - earlier_pre = [t for t in t_pre_list if t < base_period] + # ``t_pre_list`` is returned chronologically sorted by + # ``_validate_had_panel_event_study`` (using the column's + # ordered-categorical category order or the natural numeric / + # datetime order), so taking everything but the last element + # gives the earlier pre-periods regardless of dtype. Raw + # ``t < base_period`` would misorder ordered-categorical labels + # whose lexical and chronological order disagree (e.g. 
"q10" < + # "q2" lexically but > chronologically). + earlier_pre = list(t_pre_list[:-1]) if len(earlier_pre) >= 1: pretrends_joint = joint_pretrends_test( data_filtered, diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index f223f80c..65826f49 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -2357,6 +2357,130 @@ def test_event_study_all_conclusive_no_reject_admissible(self): assert "TWFE admissible under Section 4" in verdict +class TestOrderedCategoricalChronology: + """R2 P1 regressions: ordered-categorical time columns whose lexical + and chronological order disagree (e.g. ``"q10"`` < ``"q2"`` + lexically but > chronologically). Raw ``t < base_period`` comparisons + misorder these panels; the wrappers and workflow must use validated- + rank comparisons to apply the test to the intended horizons.""" + + @staticmethod + def _categorical_panel( + G: int = 60, + categories=("q1", "q2", "q10", "post"), + first_treat="post", + seed: int = 501, + ) -> pd.DataFrame: + """Panel with ordered-categorical time whose lexical order + (``"q1" < "q10" < "q2" < "post"``) differs from chronological + order (``"q1" < "q2" < "q10" < "post"``).""" + cat_type = pd.CategoricalDtype(categories=list(categories), ordered=True) + rng = np.random.default_rng(seed) + doses = rng.uniform(0.05, 1.0, size=G) + rows = [] + for g in range(G): + for t in categories: + is_post = t == first_treat + d = float(doses[g]) if is_post else 0.0 + y = 0.1 * g + (0.4 * d if is_post else 0.0) + rng.normal(0.0, 0.1) + rows.append({"unit": g, "period": t, "y": y, "d": d}) + df = pd.DataFrame(rows) + df["period"] = df["period"].astype(cat_type) + return df + + def test_joint_pretrends_test_uses_chronological_rank(self): + """Direct wrapper call with categories ["q1", "q2", "q10"] where + the lexical order puts "q10" BEFORE "q2" but chronologically + "q10" comes AFTER "q2". 
All three pre-periods must be accepted + without a false out-of-order error.""" + df = self._categorical_panel() + result = joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=["q1", "q2"], + base_period="q10", + n_bootstrap=199, + seed=3, + ) + assert result.n_horizons == 2 + assert set(result.horizon_labels) == {"q1", "q2"} + # The detrended-outcome residuals are mean-centered; under null + # (no pre-trend correlated with D), p should be > 0.05 on this + # weakly-noisy DGP. + assert np.isfinite(result.p_value) + + def test_joint_pretrends_raises_on_lexically_ordered_but_chrono_invalid(self): + """With base_period="q2" and pre_periods=["q10"], chronologically + q10 > q2 so this is out-of-order - the rank-based check must + raise. Raw `<` on the lexical side would INCORRECTLY accept + it since "q10" < "q2" lexically.""" + df = self._categorical_panel() + with pytest.raises(ValueError, match="chronological order"): + joint_pretrends_test( + df, + "y", + "d", + "period", + "unit", + pre_periods=["q10"], + base_period="q2", + n_bootstrap=199, + seed=0, + ) + + def test_joint_homogeneity_test_uses_chronological_rank(self): + """Homogeneity wrapper twin of the pretrends test. Post-period + "post" comes after all pre-periods chronologically; base="q10" + is the last pre-period. Lexically "post" > "q10" too (coincides + here), but the rank-based check must not rely on that.""" + df = self._categorical_panel() + result = joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=["post"], + base_period="q10", + n_bootstrap=199, + seed=7, + ) + assert result.n_horizons == 1 + assert result.horizon_labels == ["post"] + assert np.isfinite(result.p_value) + + def test_workflow_event_study_ordered_categorical(self): + """did_had_pretest_workflow(aggregate="event_study") must pick + up BOTH earlier pre-periods ("q1", "q2") from an ordered- + categorical panel where lexical order would silently drop one + of them. 
Regression against the `earlier_pre` raw-< fix.""" + df = self._categorical_panel() + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + aggregate="event_study", + n_bootstrap=199, + seed=13, + ) + assert report.aggregate == "event_study" + assert report.pretrends_joint is not None + # t_pre_list = ["q1", "q2", "q10"] chronologically; base = "q10" + # (last pre-period); earlier_pre should be ["q1", "q2"] - both + # placebo horizons must appear in pretrends_joint. + assert set(report.pretrends_joint.horizon_labels) == {"q1", "q2"} + assert report.homogeneity_joint is not None + assert report.homogeneity_joint.horizon_labels == ["post"] + # Verdict does not emit the step-2-skipped flag (both earlier + # placebos were found). + assert "joint pre-trends skipped" not in report.verdict + + class TestHADPretestReportSerialization: """Tests for HADPretestReport serialization branching by aggregate.""" From 0040bad7b4679a369ec5b404389bccdec94cf25c Mon Sep 17 00:00:00 2001 From: igerber Date: Thu, 23 Apr 2026 20:30:55 -0400 Subject: [PATCH 4/8] Address PR #353 CI review round 3 (1 P1 + 1 P3) P1 - row-level non-negative-dose guard in `_aggregate_for_joint_test`: On a 2-period direct call to `joint_pretrends_test` or `joint_homogeneity_test`, the n_periods < 3 path skips `_validate_had_panel_event_study` (which requires >= 3 periods) and falls through to `_aggregate_for_joint_test`. That helper collapsed unit dose via `groupby(unit_col)[dose_col].max()`, which silently recodes a negative post dose to 0 (`max(0, -d) = 0` for positive pre-period d), allowing finite joint-Stute output on data that violates the HAD support restriction `D_{g,t} >= 0` (paper Section 2). Fix: add a row-level `dose_col >= 0` check in `_aggregate_for_joint_test` BEFORE the groupby/max collapse. Centralizes the guard so both data-in wrappers inherit it on the n_periods < 3 fallback path. 
The multi-period path already enforces the same invariant via `_validate_had_panel_event_study`, so the contract is consistent across both wrapper dispatch modes. P3 - regression test: new `TestJointHomogeneityTest::test_two_period_negative_post_dose_raises` constructs a 2-period panel with a single unit carrying a negative post dose and asserts the wrapper raises `ValueError` with the "negative dose value" substring rather than producing a finite statistic via the groupby-max collapse. 124 tests pass (123 + 1 new R3 regression); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/had_pretests.py | 20 +++++++++++++++++++ tests/test_had_pretests.py | 39 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index 6ddb1243..843ec351 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -1737,6 +1737,26 @@ def _aggregate_for_joint_test( f"silently drop rows; drop or impute before calling." ) + # Row-level non-negative-dose guard (paper Section 2 HAD support + # restriction `D_{g,t} >= 0`). Must run BEFORE the groupby/max() + # collapse below, otherwise a negative post dose would silently + # become 0 in the per-unit dose vector (since `max(0, -d) = 0` for + # positive d), letting the wrappers run on invalid data and + # potentially return finite results. This is the direct-wrapper + # equivalent of the row-level check inside + # `_validate_had_panel_event_study`, centralized so both + # `joint_pretrends_test` and `joint_homogeneity_test` inherit it on + # the `n_periods < 3` fallback path that skips the validator. + negative_dose_mask = subset[dose_col] < 0 + if bool(negative_dose_mask.any()): + n_neg = int(negative_dose_mask.sum()) + raise ValueError( + f"{n_neg} negative dose value(s) found in column " + f"{dose_col!r} across periods {needed_periods}. 
HAD support " + f"restriction (paper Section 2) requires D_{{g,t}} >= 0 " + f"for every (unit, period)." + ) + + counts = subset.groupby(unit_col).size() + n_needed = len(needed_periods) + if (counts != n_needed).any(): diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index 65826f49..865d15e8 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -2035,6 +2035,45 @@ def test_all_zero_dose_post_period_raises(self): seed=0, ) + def test_two_period_negative_post_dose_raises(self): + """R3 P1 regression: direct wrapper call on a 2-period panel + with a negative post dose must raise rather than silently + collapse to zero via ``groupby.max()`` and produce a finite + result. The 2-period path skips the event-study validator + (``n_periods < 3``) so the row-level non-negative guard must + live in ``_aggregate_for_joint_test`` itself.""" + G = 20 + rng = np.random.default_rng(601) + doses = rng.uniform(0.1, 1.0, size=G) + # Flip one unit's post dose to a negative value. 
+ doses[0] = -0.3 + rows = [] + for g in range(G): + # pre-period + rows.append({"unit": g, "period": 0, "y": rng.normal(0, 0.1), "d": 0.0}) + # post-period (with negative dose injected for unit 0) + rows.append( + { + "unit": g, + "period": 1, + "y": rng.normal(0, 0.1) + 0.3 * doses[g], + "d": float(doses[g]), + } + ) + df = pd.DataFrame(rows) + with pytest.raises(ValueError, match="negative dose value"): + joint_homogeneity_test( + df, + "y", + "d", + "period", + "unit", + post_periods=[1], + base_period=0, + n_bootstrap=199, + seed=0, + ) + class TestMultiPeriodWorkflow: """Tests for :func:`did_had_pretest_workflow` event-study dispatch.""" From db170bd62366c83ddaab794294c69363679c9a17 Mon Sep 17 00:00:00 2001 From: igerber Date: Thu, 23 Apr 2026 20:44:13 -0400 Subject: [PATCH 5/8] Address PR #353 CI review round 4 (1 P1 + 2 P3) P1 - stringified-label collision guard in stute_joint_pretest: The core indexed residuals_arrays / fitted_arrays by `str(k)` with no uniqueness check on the stringified keys. Two distinct raw keys whose str() forms collide (e.g. {1: ..., "1": ...} both stringify to "1", or custom objects with identical __str__) would silently overwrite one entry and then be double-counted in S_joint = sum(S_k) because the surviving horizon's statistic gets summed twice while n_horizons still reports K=2. That produces wrong methodology output with no diagnostic. Fix: compute the stringified labels once up front and reject any collision explicitly with a ValueError listing which raw keys collide to which stringified form. Centralizes the check before any residual/fitted array is dropped. Replaces the ad-hoc post-hoc re-keying with a reuse of the pre-computed collision-free list. P3 - dedupe staggered-filter UserWarning: `_validate_had_panel_event_study` already warns on the staggered auto-filter path; both joint-pretest wrappers and the event-study workflow were re-emitting the same information with a wrapper-prefixed message. 
Each staggered call therefore surfaced two warnings to the user. Removes the secondary emissions; wrappers now consume `_filter_info` silently. Existing tests still pass because the validator's own `"Staggered-timing panel detected"` message satisfies the regex matchers. P3 - collision regression test: new `TestStuteJointPretest::test_stringified_key_collision_raises` exercises (a) the int 1 + str "1" case and (b) a pair of custom objects with identical __str__ but distinct hash; both must raise `ValueError` with "collision after str" in the message. 125 tests pass (124 + 1 new R4 collision regression); black/ruff/ mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/had_pretests.py | 85 ++++++++++++++++++++------------------ tests/test_had_pretests.py | 58 ++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 40 deletions(-) diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index 843ec351..7409d431 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -1976,12 +1976,36 @@ def stute_joint_pretest( if not np.all(np.isfinite(X)): raise ValueError("design_matrix contains non-finite values (NaN/inf).") - horizon_labels = list(residuals_by_horizon.keys()) - K = len(horizon_labels) + raw_horizon_labels = list(residuals_by_horizon.keys()) + K = len(raw_horizon_labels) + + # Stringified-label collision guard: distinct raw keys whose str() + # representations collide (e.g. {1: ..., "1": ..., 1.0: ...}) would + # overwrite each other in residuals_arrays / fitted_arrays, letting + # the surviving horizon be double-counted in S_joint = sum of S_k + # and leaving `n_horizons` inconsistent with the number of distinct + # diagnostic statistics. Reject explicitly rather than silently + # collapsing the test. 
+ str_labels = [str(k) for k in raw_horizon_labels] + if len(set(str_labels)) != len(str_labels): + from collections import Counter + + dup_strs = [s for s, c in Counter(str_labels).items() if c > 1] + collisions = {s: [k for k in raw_horizon_labels if str(k) == s] for s in dup_strs} + raise ValueError( + f"Horizon label collision after str() stringification: " + f"{collisions!r}. The joint Stute helpers index residuals " + f"and fitted values by str(label); distinct raw keys whose " + f"stringified form collides would silently overwrite each " + f"other and double-count the surviving horizon in S_joint. " + f"Use string-distinct horizon labels (e.g. 1997 and 1998 " + f'as int, or "1997" and "1998" as str; not both).' + ) + any_nan = False residuals_arrays: Dict[str, np.ndarray] = {} fitted_arrays: Dict[str, np.ndarray] = {} - for k in horizon_labels: + for k in raw_horizon_labels: eps_k = np.asarray(residuals_by_horizon[k], dtype=np.float64) fit_k = np.asarray(fitted_by_horizon[k], dtype=np.float64) if eps_k.shape != (G,) or fit_k.shape != (G,): @@ -1997,8 +2021,9 @@ def stute_joint_pretest( # Re-key to str labels consistently (wrappers already pass str; direct # callers may pass int/object). String identity per the documented - # horizon_labels contract. - horizon_labels = [str(k) for k in horizon_labels] + # horizon_labels contract. The collision guard above ensures this + # stringification is injective on the provided keys. 
+ horizon_labels = str_labels if any_nan: return StuteJointResult( @@ -2242,7 +2267,7 @@ def joint_pretrends_test( n_periods = int(data[time_col].nunique()) data_filtered: pd.DataFrame = data if n_periods >= 3: - F_val, t_pre_list, _t_post_list, data_filtered, filter_info = ( + F_val, t_pre_list, _t_post_list, data_filtered, _filter_info = ( _validate_had_panel_event_study( data, outcome_col=outcome_col, @@ -2252,16 +2277,10 @@ def joint_pretrends_test( first_treat_col=first_treat_col, ) ) - if filter_info is not None: - warnings.warn( - f"joint_pretrends_test: staggered panel auto-filtered to " - f"last cohort (F_last={filter_info['F_last']!r}, " - f"n_kept={filter_info['n_kept']}, " - f"n_dropped={filter_info['n_dropped']}). " - f"Paper Appendix B.2 prescription.", - UserWarning, - stacklevel=2, - ) + # `_validate_had_panel_event_study` already emits its own + # `UserWarning` on the staggered-filter path; the wrapper + # consumes `_filter_info` silently to avoid duplicated console + # noise (R4 code-quality fix). # Subset invariants: the caller's base_period and pre_periods # must be pre-treatment periods under the validator's partition. if base_period not in t_pre_list: @@ -2429,7 +2448,7 @@ def joint_homogeneity_test( n_periods = int(data[time_col].nunique()) data_filtered: pd.DataFrame = data if n_periods >= 3: - F_val, t_pre_list, t_post_list, data_filtered, filter_info = ( + F_val, t_pre_list, t_post_list, data_filtered, _filter_info = ( _validate_had_panel_event_study( data, outcome_col=outcome_col, @@ -2439,16 +2458,10 @@ def joint_homogeneity_test( first_treat_col=first_treat_col, ) ) - if filter_info is not None: - warnings.warn( - f"joint_homogeneity_test: staggered panel auto-filtered " - f"to last cohort (F_last={filter_info['F_last']!r}, " - f"n_kept={filter_info['n_kept']}, " - f"n_dropped={filter_info['n_dropped']}). 
" - f"Paper Appendix B.2 prescription.", - UserWarning, - stacklevel=2, - ) + # `_validate_had_panel_event_study` already emits its own + # `UserWarning` on the staggered-filter path; the wrapper + # consumes `_filter_info` silently to avoid duplicated console + # noise (R4 code-quality fix). if base_period not in t_pre_list: raise ValueError( f"base_period={base_period!r} is not in the validated " @@ -2615,7 +2628,7 @@ def did_had_pretest_workflow( ) if aggregate == "event_study": - F, t_pre_list, t_post_list, data_filtered, filter_info = _validate_multi_period_panel( + F, t_pre_list, t_post_list, data_filtered, _filter_info = _validate_multi_period_panel( data, outcome_col=outcome_col, dose_col=dose_col, @@ -2623,18 +2636,10 @@ def did_had_pretest_workflow( unit_col=unit_col, first_treat_col=first_treat_col, ) - if filter_info is not None: - warnings.warn( - f"HAD event-study pre-test: staggered panel auto-" - f"filtered to last cohort " - f"(F_last={filter_info['F_last']!r}, " - f"n_kept={filter_info['n_kept']}, " - f"n_dropped={filter_info['n_dropped']}, " - f"dropped_cohorts={filter_info['dropped_cohorts']}). " - f"Paper Appendix B.2 prescription.", - UserWarning, - stacklevel=2, - ) + # `_validate_multi_period_panel` delegates to + # `_validate_had_panel_event_study`, which already emits its own + # `UserWarning` on the staggered-filter path; we do NOT warn a + # second time here (R4 code-quality fix - single emission point). # Base period for both joint tests is the last pre-period # (paper convention: anchor at F-1 under natural time order). 
diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index 865d15e8..0f8dfa60 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -1575,6 +1575,64 @@ def test_constant_d_returns_nan_with_warning(self): assert set(result.per_horizon_stats.keys()) == set(resid.keys()) assert all(np.isnan(v) for v in result.per_horizon_stats.values()) + def test_stringified_key_collision_raises(self): + """R4 P1 regression: two raw keys whose str() representations + collide (e.g. int 1 and str '1', or int 1 and float 1.0) must + raise explicitly rather than silently overwrite one horizon in + the internal residuals_arrays map and double-count the survivor + in the sum-of-CvMs S_joint.""" + G = 20 + rng = np.random.default_rng(701) + d = rng.uniform(0.0, 1.0, G) + # int / str collision: str(1) == "1" + resid_int_str_collision = { + 1: rng.normal(0.0, 1.0, G), + "1": rng.normal(0.0, 1.0, G), + } + fit_int_str_collision = {1: np.zeros(G), "1": np.zeros(G)} + with pytest.raises(ValueError, match="collision after str"): + stute_joint_pretest( + residuals_by_horizon=resid_int_str_collision, + fitted_by_horizon=fit_int_str_collision, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + + # int / float collision: str(1) == "1" but str(1.0) == "1.0" + # so these actually don't collide. Test a real collision case: + # two different string representations of the same label. + # Python: str(True) == "True"; bool(1) == True but that's the + # same key. Use: str(None) == "None" collides if passed twice, + # but keys must be unique per dict. Safer: two equal-after-str + # object keys that were distinct before str conversion. 
+ class _WeirdLabel: + def __init__(self, s): + self._s = s + + def __str__(self): + return self._s + + def __hash__(self): + return hash((id(self), self._s)) + + a = _WeirdLabel("horizon-1") + b = _WeirdLabel("horizon-1") # same str, different object + assert a is not b + assert str(a) == str(b) + resid_obj_collision = {a: rng.normal(0.0, 1.0, G), b: rng.normal(0.0, 1.0, G)} + fit_obj_collision = {a: np.zeros(G), b: np.zeros(G)} + with pytest.raises(ValueError, match="collision after str"): + stute_joint_pretest( + residuals_by_horizon=resid_obj_collision, + fitted_by_horizon=fit_obj_collision, + doses=d, + design_matrix=np.ones((G, 1)), + n_bootstrap=199, + seed=0, + ) + class TestJointPretrendsTest: """Tests for :func:`joint_pretrends_test` data-in wrapper.""" From e3f7450022fc0ca781a1ecdd59e0fabd33641b27 Mon Sep 17 00:00:00 2001 From: igerber Date: Thu, 23 Apr 2026 21:01:46 -0400 Subject: [PATCH 6/8] Address PR #353 CI review round 5 (1 P1 + 1 P3) P1 - stute_joint_pretest G<_MIN_G_STUTE warn+NaN contract: The joint core raised `ValueError` on G < 10, while single-horizon `stute_test` emits a `UserWarning` and returns a NaN result on the same condition. Because the event-study workflow dispatches into the joint core for both step-2 pre-trends and step-3 homogeneity, a staggered panel whose last-cohort auto-filter leaves fewer than 10 units would now crash the workflow instead of surfacing an inconclusive report - a regression versus Phase 3's two-period behavior. Fix: mirror the single-horizon contract. Emit `UserWarning` ("below the minimum ... Returning NaN result") and return a `StuteJointResult` with `cvm_stat_joint=nan`, `p_value=nan`, `reject=False`, and a full-NaN `per_horizon_stats` dict keyed by the validated horizon labels (so the diagnostic surface is consistent with the NaN-propagation branch). `n_bootstrap < _MIN_N_BOOTSTRAP` and non-numeric `alpha` still raise; only the small-G branch relaxes. 
Test updates: - `test_small_G_raises` renamed to `test_small_G_warns_returns_nan` and rewritten to assert the new contract. - New `test_event_study_small_panel_after_filter_inconclusive_not_ crash` covers the workflow-level regression: a staggered fixture with 40 early-cohort + 6 late-cohort units filters to G=6 after the validator's last-cohort auto-filter; `did_had_pretest_ workflow(aggregate="event_study")` now completes without exception, emits the "below the minimum" warning, and surfaces a NaN joint-Stute report with `all_pass=False`. P3 - module docstring refresh: `had_pretests.py` top-level docstring still said Phase 3 shipped steps 1 + 3 only, that step 2 was deferred, and that `did_had_pretest_workflow` was a two-period-only entry point. That drifted after the joint-pretest follow-up landed. Rewrote the docstring to describe: (a) the three single-horizon tests, (b) the three new joint helpers (`stute_joint_pretest`, `joint_pretrends_test`, `joint_homogeneity_test`), (c) both workflow dispatch modes (`aggregate="overall"` two-period and `aggregate="event_study"` multi-period), and (d) the narrowed deferment - only Eq. 18 linear-trend detrending remains, tracked in TODO for Phase 4 alongside the Pierce-Schott replication. 126 tests pass (125 + 1 new R5 workflow regression, -0 + 1 converted from raise to warn); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/had_pretests.py | 93 ++++++++++++++++++++++++++++++++------ tests/test_had_pretests.py | 71 +++++++++++++++++++++++++++-- 2 files changed, 147 insertions(+), 17 deletions(-) diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index 7409d431..0618b840 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -2,8 +2,9 @@ Paper Section 4 (de Chaisemartin, Ciccia, D'Haultfoeuille, Knau 2026, arXiv:2405.04465v6) prescribes a four-step pre-testing workflow for TWFE -validity in HADs. 
Phase 3 ships steps 1 and 3 of that workflow (step 2 is -deferred): +validity in HADs. This module ships the tests and the composite workflow: + +Single-horizon tests: 1. :func:`qug_test` - order-statistic ratio test of the support infimum ``H_0: d_lower = 0`` (paper Theorem 4). Closed-form, tuning-free. @@ -14,16 +15,43 @@ linearity test (paper Theorem 7 / Equation 29). Feasible at ``G >= 100k``. -The composite :func:`did_had_pretest_workflow` runs the three implemented -tests in sequence on a two-period HAD panel and returns a -:class:`HADPretestReport` with a partial-workflow verdict. When all three -fail-to-reject, the verdict explicitly flags that **the paper's step 2 -pre-trends test (Assumption 7) is NOT run** — callers do not receive an -unconditional "TWFE safe" signal; the Assumption 7 check must be performed -separately (e.g., via an event-study / placebo analysis) until the Phase 3 -follow-up patch lands the joint Equation 18 cross-horizon Stute variant. - -See ``docs/methodology/REGISTRY.md`` and ``TODO.md`` for the deferred items. +Joint / multi-period tests (Phase 3 follow-up): + +4. :func:`stute_joint_pretest` - residuals-in core that generalizes the + single-horizon Stute CvM to K horizons with shared-η wild bootstrap + and sum-of-CvMs aggregation (Delgado 1993; Escanciano 2006). +5. :func:`joint_pretrends_test` - data-in wrapper for the mean- + independence null (paper step 2 pre-trends across pre-period + placebos, Section 4.2 footnote 6 + Section 4.3 paragraph 1). +6. :func:`joint_homogeneity_test` - data-in wrapper for the linearity + null across post-periods (paper Section 4.3 joint extension, + page 32). + +Composite workflow: + +:func:`did_had_pretest_workflow` has two dispatch modes: + +- ``aggregate="overall"`` (default, two-period panel): runs steps 1 + 3 + via :func:`qug_test` + :func:`stute_test` + :func:`yatchew_hr_test`. 
+ Paper step 2 is NOT run on this path (a two-period panel has no pre- + period placebo); the verdict explicitly flags the Assumption 7 gap + via the ``"paper step 2 deferred"`` caveat so callers do not get an + unconditional "TWFE safe" signal. +- ``aggregate="event_study"`` (multi-period panel, >= 3 periods): runs + QUG at ``F`` + joint pre-trends Stute across earlier pre-periods + + joint homogeneity-linearity Stute across post-periods. Closes the + paper step-2 gap and does NOT emit the step-2-deferred caveat in the + verdict when at least one earlier pre-period is available. Step 4 + (alternative linearity via Yatchew) is subsumed by joint Stute on + this path; the paper does not derive a joint Yatchew variant, so + users who need Yatchew robustness under multi-period data can call + :func:`yatchew_hr_test` on each ``(base, post)`` pair manually. + +Eq. 18 linear-trend detrending (paper Section 5.2 Pierce-Schott +application, published p=0.51) is the one remaining deferred item; +tracked in ``TODO.md`` and slated for Phase 4 alongside the replication +harness. See ``docs/methodology/REGISTRY.md`` for the full algorithm +narrative, invariants, and deviation notes. """ from __future__ import annotations @@ -1963,8 +1991,16 @@ def stute_joint_pretest( f"Found {int(np.sum(doses_arr < 0))} negative value(s)." ) - if G < _MIN_G_STUTE: - raise ValueError(f"Joint Stute test requires G >= {_MIN_G_STUTE} units; got " f"G = {G}.") + # G < _MIN_G_STUTE (CvM statistic not well-calibrated): mirror the + # single-horizon `stute_test` contract - warn + return NaN result + # rather than raise, so callers (including the event-study workflow + # on a staggered panel whose last-cohort filter leaves fewer than + # 10 units) get an inconclusive diagnostic instead of a crash. The + # NaN return still satisfies the workflow's `np.isfinite(p_value)` + # gating, so `all_pass` becomes False downstream. 
+ # Note: the actual `warn + return` happens below after horizon + # labels are validated and collision-checked, so the NaN result + # carries full per-horizon diagnostic keys. if n_bootstrap < _MIN_N_BOOTSTRAP: raise ValueError(f"n_bootstrap must be >= {_MIN_N_BOOTSTRAP}; got " f"{n_bootstrap}.") if not isinstance(alpha, (int, float)) or not (0 < float(alpha) < 1): @@ -2025,6 +2061,35 @@ def stute_joint_pretest( # stringification is injective on the provided keys. horizon_labels = str_labels + # Small-G NaN result (paired with the comment near the top of this + # function): mirror the single-horizon stute_test contract so the + # event-study workflow on a small or staggered-filtered panel gets + # an inconclusive diagnostic rather than an exception. Positioned + # AFTER the label-collision / shape-alignment guards so the NaN + # result carries a consistent per-horizon diagnostic surface. + if G < _MIN_G_STUTE: + warnings.warn( + f"stute_joint_pretest: G = {G} is below the minimum " + f"{_MIN_G_STUTE} for the CvM statistic to be well-calibrated. 
" + f"Returning NaN result.", + UserWarning, + stacklevel=2, + ) + return StuteJointResult( + cvm_stat_joint=float("nan"), + p_value=float("nan"), + reject=False, + alpha=float(alpha), + horizon_labels=horizon_labels, + per_horizon_stats={k: float("nan") for k in horizon_labels}, + n_bootstrap=int(n_bootstrap), + n_obs=int(G), + n_horizons=int(K), + seed=None if seed is None else int(seed), + null_form=str(null_form), + exact_linear_short_circuited=False, + ) + if any_nan: return StuteJointResult( cvm_stat_joint=float("nan"), diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index 0f8dfa60..6276697f 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -1407,11 +1407,15 @@ def test_negative_dose_raises(self): seed=0, ) - def test_small_G_raises(self): + def test_small_G_warns_returns_nan(self): + """R5: G < _MIN_G_STUTE mirrors single-horizon stute_test - + warn + NaN result instead of raise. Prevents event-study + workflow crash when a last-cohort filter leaves fewer than 10 + units.""" G = 5 # below _MIN_G_STUTE (10) resid, fit, d = _multi_period_residuals(G, K=2) - with pytest.raises(ValueError, match="G >="): - stute_joint_pretest( + with pytest.warns(UserWarning, match="below the minimum"): + result = stute_joint_pretest( residuals_by_horizon=resid, fitted_by_horizon=fit, doses=d, @@ -1419,6 +1423,13 @@ def test_small_G_raises(self): n_bootstrap=199, seed=0, ) + assert np.isnan(result.cvm_stat_joint) + assert np.isnan(result.p_value) + assert result.reject is False + assert result.n_obs == G + # Full diagnostic surface preserved on the NaN result + assert set(result.per_horizon_stats.keys()) == set(str(k) for k in resid.keys()) + assert all(np.isnan(v) for v in result.per_horizon_stats.values()) def test_small_bootstrap_raises(self): G = 50 @@ -2453,6 +2464,60 @@ def test_event_study_all_conclusive_no_reject_admissible(self): verdict = _compose_verdict_event_study(qug, pretrends, homogeneity) assert "TWFE admissible 
under Section 4" in verdict + def test_event_study_small_panel_after_filter_inconclusive_not_crash(self): + """R5: staggered-panel last-cohort filter can leave fewer than + `_MIN_G_STUTE` (10) units. The joint Stute core must warn + + return NaN on small G (matching single-horizon stute_test) so + the event-study workflow surfaces an inconclusive report + rather than crashing. Regression against the original + ValueError-on-G<10 contract.""" + parts = [] + # First cohort: 40 units treated at 1999 - will be DROPPED by + # the last-cohort filter (F_last=2000 > 1999). + # Second cohort: only 6 units treated at 2000 - kept. After + # filter G = 6 < _MIN_G_STUTE, so the joint CvM is ill- + # calibrated and must return NaN via warn. + for cohort_ft, cohort_range in [(1999, (0, 40)), (2000, (40, 46))]: + for g in range(*cohort_range): + dose = 0.05 + 0.01 * (g - cohort_range[0]) + for t in [1997, 1998, 1999, 2000, 2001]: + is_post = t >= cohort_ft + parts.append( + { + "unit": g, + "period": t, + "y": 0.1 * g + (0.3 * dose if is_post else 0.0), + "d": dose if is_post else 0.0, + "first_treat": cohort_ft, + } + ) + df = pd.DataFrame(parts) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + report = did_had_pretest_workflow( + df, + "y", + "d", + "period", + "unit", + first_treat_col="first_treat", + aggregate="event_study", + n_bootstrap=199, + seed=0, + ) + # Workflow must complete (no crash) and surface an inconclusive + # report. Both joint tests (pretrends + homogeneity) should + # return NaN on the post-filter G=6 panel. + assert report.aggregate == "event_study" + if report.pretrends_joint is not None: + assert np.isnan(report.pretrends_joint.p_value) + assert report.homogeneity_joint is not None + assert np.isnan(report.homogeneity_joint.p_value) + assert report.all_pass is False + # At least one "below the minimum" warning from the joint core. 
+ msgs = [str(w.message) for w in caught] + assert any("below the minimum" in m for m in msgs) + class TestOrderedCategoricalChronology: """R2 P1 regressions: ordered-categorical time columns whose lexical From f381ed53d6392cd0c05525828ab5484e44fb43e2 Mon Sep 17 00:00:00 2001 From: igerber Date: Fri, 24 Apr 2026 05:37:29 -0400 Subject: [PATCH 7/8] Address PR #353 CI review round 6 (1 P3) P3 - stute_joint_pretest docstring drift: The Raises block still listed `G < _MIN_G_STUTE` as a ValueError condition, but R5 converted that branch to a UserWarning + full-NaN StuteJointResult return to match single-horizon stute_test and keep the event-study workflow from crashing on staggered-filtered small panels. Fix: rewrote the Returns and Raises docstring blocks to describe the actual contract. Returns now enumerates the three NaN-result branches (small G, constant dose, any-NaN residuals / fitted) with their warning behavior. Raises is narrowed to the genuinely-raising conditions: empty input, key-mismatch, str-label collision, shape mismatch, negative doses, too-few bootstrap replicates, invalid alpha. Explicitly notes that small-G does NOT raise. No code changes; docstring-only edit. 126 tests still pass; black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/had_pretests.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index 0618b840..a80a27ed 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -1961,13 +1961,25 @@ def stute_joint_pretest( Returns ------- StuteJointResult + On the common path, a populated result with bootstrap-based + ``p_value`` and ``cvm_stat_joint``. 
On the small-sample branch + (``G < _MIN_G_STUTE``), constant-dose branch + (``np.ptp(doses) <= 0``), or any-NaN branch in the input + residuals / fitted arrays, returns an all-NaN result (with + ``reject=False`` and the full ``per_horizon_stats`` dict keyed + by the validated horizon labels) and emits a ``UserWarning`` + for the first two branches. Mirrors the single-horizon + :func:`stute_test` contract so event-study workflows on small + or staggered-filtered panels surface an inconclusive report + rather than crashing. Raises ------ ValueError - On empty input, key-mismatch, shape-mismatch, ``doses`` - containing negative values, ``G < _MIN_G_STUTE``, or - ``n_bootstrap < _MIN_N_BOOTSTRAP``. + On empty input, key-mismatch, stringified-label collisions + between distinct raw keys, shape-mismatch, ``doses`` containing + negative values, ``n_bootstrap < _MIN_N_BOOTSTRAP``, or invalid + ``alpha``. ``G < _MIN_G_STUTE`` does NOT raise; see Returns. """ if not isinstance(residuals_by_horizon, dict) or not isinstance(fitted_by_horizon, dict): raise ValueError( From 7c7d5cdbffb5e7a8e6df3ccae9a788047d6e03d1 Mon Sep 17 00:00:00 2001 From: igerber Date: Fri, 24 Apr 2026 05:51:07 -0400 Subject: [PATCH 8/8] Address PR #353 CI review round 7 (1 P2 + 2 P3) P2 - explicit ValueError on singular design_matrix: `stute_joint_pretest` previously surfaced a raw `np.linalg.LinAlgError` to direct callers when `design_matrix` was rank-deficient (e.g. duplicate columns), breaking the front-door validation style of the rest of the function. Wrap the `np.linalg.solve(X.T @ X, X.T)` precompute in a try/except and re-raise as `ValueError` with a message naming the likely cause (linearly-dependent columns) and the shape. Regression: new `TestStuteJointPretest::test_singular_design_matrix_raises_valueerror` constructs a (G, 2) design with two identical columns and asserts the explicit `ValueError("rank-deficient")`. 
P3 - Yatchew "step 4" -> "step 3 alternative" docstring drift: Two docstrings (module header and `_compose_verdict_event_study`) referred to the Yatchew-HR test as "step 4". Paper Section 4.2-4.3 defines step 4 as the final admissibility decision ("use TWFE if none of the tests rejects"), not a separate diagnostic; Yatchew is the alternative linearity test within step 3. Updated both docstrings to describe Yatchew as the step-3 alternative (subsumed by joint Stute on the event-study path) and clarified that paper step 4 has no separate code path. P3 - `joint_homogeneity_test` post_periods docstring: Text said `>= base_period` but the actual guard is strict `> base_period` in chronological order. Tightened the Parameters block to match. 127 tests pass (126 + 1 new R7 regression); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/had_pretests.py | 35 +++++++++++++++++++++++++---------- tests/test_had_pretests.py | 21 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/diff_diff/had_pretests.py b/diff_diff/had_pretests.py index a80a27ed..b2587665 100644 --- a/diff_diff/had_pretests.py +++ b/diff_diff/had_pretests.py @@ -41,11 +41,13 @@ QUG at ``F`` + joint pre-trends Stute across earlier pre-periods + joint homogeneity-linearity Stute across post-periods. Closes the paper step-2 gap and does NOT emit the step-2-deferred caveat in the - verdict when at least one earlier pre-period is available. Step 4 - (alternative linearity via Yatchew) is subsumed by joint Stute on - this path; the paper does not derive a joint Yatchew variant, so + verdict when at least one earlier pre-period is available. The + step-3 alternative (Yatchew-HR linearity) is subsumed by joint Stute + on this path; the paper does not derive a joint Yatchew variant, so users who need Yatchew robustness under multi-period data can call :func:`yatchew_hr_test` on each ``(base, post)`` pair manually. 
+ (Step 4 in the paper's workflow is the decision itself - "use TWFE + if none of the tests rejects" - not a separate test.) Eq. 18 linear-trend detrending (paper Section 5.2 Pierce-Schott application, published p=0.51) is the one remaining deferred item; @@ -1845,9 +1847,10 @@ def _compose_verdict_event_study( follow-up" caveat - this PR closes that gap. - Step 3 (Assumption 8 linearity/homogeneity): runs via ``homogeneity_joint`` (joint Stute only; no joint Yatchew variant - exists in the paper). - - Step 4 (alternative linearity via Yatchew): not run on the - event-study path; adjudicated by joint Stute above. + exists in the paper). The step-3 alternative Yatchew-HR test is + subsumed by joint Stute on this path. (Paper step 4 is the + decision itself - "use TWFE if none of the tests rejects" - not + a separate diagnostic, so it has no code path here.) Priority: 1. Any conclusive test rejecting → primary verdict bundles each @@ -2202,7 +2205,19 @@ def stute_joint_pretest( # Precompute OLS projection matrix once: same X per bootstrap draw, # so (X'X)^-1 X' is constant across iterations. Keeps refit O(Gp) # per draw without changing semantics from the literal paper form. - XtX_inv_Xt = np.linalg.solve(X.T @ X, X.T) + # Catch rank-deficient designs explicitly rather than surfacing a + # raw ``np.linalg.LinAlgError`` to direct callers of the public + # residuals-in core; matches the front-door validation style of + # the other guards in this function. + try: + XtX_inv_Xt = np.linalg.solve(X.T @ X, X.T) + except np.linalg.LinAlgError as exc: + raise ValueError( + f"design_matrix is rank-deficient (singular X^T X); cannot " + f"compute the OLS projection (X^T X)^-1 X^T for the " + f"bootstrap refit. Check for duplicate or linearly-" + f"dependent columns. shape={X.shape}." 
+ ) from exc rng = np.random.default_rng(seed) bootstrap_S = np.empty(n_bootstrap, dtype=np.float64) @@ -2462,9 +2477,9 @@ def joint_homogeneity_test( data : pd.DataFrame outcome_col, dose_col, time_col, unit_col : str post_periods : list - Non-empty list of post-period labels (all ``>= base_period`` by - time order; each with ``D > 0`` for some unit, i.e. at least one - treated unit per horizon). + Non-empty list of post-period labels (all strictly ``> + base_period`` by chronological order; each with ``D > 0`` for + some unit, i.e. at least one treated unit per horizon). base_period : period label The reference period (last pre-period in the event-study convention). Must not be in ``post_periods``. diff --git a/tests/test_had_pretests.py b/tests/test_had_pretests.py index 6276697f..d5d1fabb 100644 --- a/tests/test_had_pretests.py +++ b/tests/test_had_pretests.py @@ -1586,6 +1586,27 @@ def test_constant_d_returns_nan_with_warning(self): assert set(result.per_horizon_stats.keys()) == set(resid.keys()) assert all(np.isnan(v) for v in result.per_horizon_stats.values()) + def test_singular_design_matrix_raises_valueerror(self): + """R7 P2: rank-deficient custom design_matrix (e.g. duplicate + columns) must raise an explicit ValueError from the front-door, + not a raw np.linalg.LinAlgError from the internal solve().""" + G = 30 + rng = np.random.default_rng(801) + d = rng.uniform(0.0, 1.0, G) + resid = {"h0": rng.normal(0.0, 1.0, G), "h1": rng.normal(0.0, 1.0, G)} + fit = {"h0": np.zeros(G), "h1": np.zeros(G)} + # design_matrix with two identical columns (rank deficient). + singular_X = np.column_stack([d, d]) + with pytest.raises(ValueError, match="rank-deficient"): + stute_joint_pretest( + residuals_by_horizon=resid, + fitted_by_horizon=fit, + doses=d, + design_matrix=singular_X, + n_bootstrap=199, + seed=0, + ) + def test_stringified_key_collision_raises(self): """R4 P1 regression: two raw keys whose str() representations collide (e.g. 
int 1 and str '1', or int 1 and float 1.0) must