diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index 019796c3..8f433b25 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -227,7 +227,50 @@ def fit( f"Dose must be time-invariant. Units with varying dose: {bad_units[:5]}" ) - # Normalize first_treat: inf → 0 + # Normalize first_treat: +inf → 0 (R-style never-treated encoding). + # Count rows recategorized so users can see how many units just + # crossed from "treated at some point" to "never treated" — silent + # recategorization here would shift the control composition (axis-E + # silent coercion). Only positive infinity is recoded (to match the + # existing `.replace([np.inf, float("inf")], 0)` call that follows the + # validation checks below). + first_treat_vals = df[first_treat].values + # Reject NaN first_treat explicitly. NaN survives preprocessing but + # satisfies neither the treated (g > 0) nor never-treated (g == 0) + # mask, so affected units would be silently excluded from the + # estimator (same silent-failure shape as `first_treat < 0`). + nan_mask = pd.isna(df[first_treat]) + n_nan_first_treat = int(nan_mask.sum()) + if n_nan_first_treat > 0: + raise ValueError( + f"{n_nan_first_treat} row(s) have NaN '{first_treat}' " + f"values. Valid values are 0 (never-treated) or a positive " + f"treatment period; such units would otherwise be silently " + f"excluded from both treated and control pools." + ) + inf_mask = np.isposinf(first_treat_vals) + n_inf_first_treat = int(inf_mask.sum()) + if n_inf_first_treat > 0: + warnings.warn( + f"{n_inf_first_treat} row(s) have inf in '{first_treat}'; " + f"treating the corresponding units as never-treated. Pass an " + f"explicit never-treated marker (0) if this is not intended.", + UserWarning, + stacklevel=2, + ) + # Reject negative first_treat values (including -inf) explicitly. 
+ # Without this guard they would survive preprocessing but fall out of + # both the treated (g > 0) and never-treated (g == 0) masks, silently + # excluding the affected units. + negative_mask = first_treat_vals < 0 + n_negative_first_treat = int(negative_mask.sum()) + if n_negative_first_treat > 0: + raise ValueError( + f"{n_negative_first_treat} row(s) have negative '{first_treat}' " + f"values (including -inf). Valid values are 0 (never-treated) " + f"or a positive treatment period; such units would otherwise " + f"be silently excluded from both treated and control pools." + ) df[first_treat] = df[first_treat].replace([np.inf, float("inf")], 0) # Drop units with positive first_treat but zero dose (R convention) @@ -265,9 +308,22 @@ def fit( stacklevel=2, ) - # Force dose=0 for never-treated units with nonzero dose + # Force dose=0 for never-treated units with nonzero dose. Report the + # affected row count via UserWarning so users can see whether their + # never-treated rows had unintended nonzero doses — silent zeroing + # here would quietly shift part of the control trajectory (axis-E + # silent coercion, paired with the `first_treat=inf -> 0` fix above). never_treated_mask = df[first_treat] == 0 - if (df.loc[never_treated_mask, dose] != 0).any(): + nonzero_dose_rows = never_treated_mask & (df[dose] != 0) + n_nonzero_dose_never_treated = int(nonzero_dose_rows.sum()) + if n_nonzero_dose_never_treated > 0: + warnings.warn( + f"{n_nonzero_dose_never_treated} row(s) have '{first_treat}'=0 " + f"(never-treated) but nonzero '{dose}'; zeroing the dose. 
Pass " + f"dose=0 for never-treated rows to avoid this coercion.", + UserWarning, + stacklevel=2, + ) df.loc[never_treated_mask, dose] = 0.0 # Verify balanced panel diff --git a/diff_diff/staggered_triple_diff.py b/diff_diff/staggered_triple_diff.py index 758d518b..85086aa3 100644 --- a/diff_diff/staggered_triple_diff.py +++ b/diff_diff/staggered_triple_diff.py @@ -284,6 +284,19 @@ def fit( if first_treat != "first_treat": df["first_treat"] = df[first_treat] + # Surface the inf → 0 recategorization the same way StaggeredDiD does + # (see `staggered.py:1508-1519`). Silently recoding inf would shift + # units between treated and never-treated pools with no signal + # (axis-E silent coercion under the Phase 2 audit). + _inf_mask = np.isposinf(df["first_treat"].values) + if _inf_mask.any(): + n_inf_rows = int(_inf_mask.sum()) + warnings.warn( + f"{n_inf_rows} row(s) have first_treat=inf; recoding to 0 " + f"(never-treated). Use first_treat=0 to suppress this warning.", + UserWarning, + stacklevel=2, + ) df["first_treat"] = df["first_treat"].replace([np.inf, float("inf")], 0) precomputed = self._precompute_structures( diff --git a/diff_diff/utils.py b/diff_diff/utils.py index 704a2138..47df08e1 100644 --- a/diff_diff/utils.py +++ b/diff_diff/utils.py @@ -821,7 +821,8 @@ def check_parallel_trends_robust( # Compute outcome changes treated_changes, control_changes = _compute_outcome_changes( - pre_data, outcome, time, treatment_group, unit + pre_data, outcome, time, treatment_group, unit, + caller_label="check_parallel_trends_robust", ) if len(treated_changes) < 2 or len(control_changes) < 2: @@ -897,7 +898,12 @@ def check_parallel_trends_robust( def _compute_outcome_changes( - data: pd.DataFrame, outcome: str, time: str, treatment_group: str, unit: Optional[str] = None + data: pd.DataFrame, + outcome: str, + time: str, + treatment_group: str, + unit: Optional[str] = None, + caller_label: str = "parallel-trend diagnostic", ) -> Tuple[np.ndarray, np.ndarray]: """ Compute 
period-to-period outcome changes for treated and control groups. @@ -925,7 +931,24 @@ def _compute_outcome_changes( data_sorted = data.sort_values([unit, time]) data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff() - # Remove NaN from first period of each unit + # Remove NaN from first period of each unit. The first period per unit + # has no prior observation to diff against, so n_units drops are + # expected. Anything beyond that is a silent side-effect of gaps or + # NaN outcomes — surface the excess via warning (axis-E drop counter). + n_units_observed = int(data_sorted[unit].nunique()) + n_dropped = int(data_sorted["_outcome_change"].isna().sum()) + n_unexpected_drops = max(0, n_dropped - n_units_observed) + if n_unexpected_drops > 0: + warnings.warn( + f"{caller_label}: dropped {n_dropped} row(s) with NaN " + f"first-differences; {n_units_observed} are the expected " + f"first-period-per-unit drops, and {n_unexpected_drops} are " + f"additional NaN first-differences (e.g. NaN outcomes or " + f"unit-period gaps upstream). 
Parallel-trend statistics are " + f"computed on the remaining rows.", + UserWarning, + stacklevel=3, + ) changes_data = data_sorted.dropna(subset=["_outcome_change"]) treated_changes = changes_data[changes_data[treatment_group] == 1]["_outcome_change"].values @@ -1001,7 +1024,8 @@ def equivalence_test_trends( # Compute outcome changes treated_changes, control_changes = _compute_outcome_changes( - pre_data, outcome, time, treatment_group, unit + pre_data, outcome, time, treatment_group, unit, + caller_label="equivalence_test_trends", ) # Need at least 2 observations per group to compute variance diff --git a/diff_diff/wooldridge.py b/diff_diff/wooldridge.py index 4bc2c0ce..fb63a29a 100644 --- a/diff_diff/wooldridge.py +++ b/diff_diff/wooldridge.py @@ -13,6 +13,7 @@ from __future__ import annotations +import warnings from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -112,6 +113,26 @@ def _resolve_survey_for_wooldridge(survey_design, sample, cluster_ids, cluster_n return resolved, survey_weights, survey_weight_type, survey_metadata, df_inf +def _warn_and_fill_nan_cohort(df: pd.DataFrame, cohort: str, stacklevel: int) -> pd.DataFrame: + """Fill NaN cohort with 0 (never-treated) and warn with the row count. + + Used by both `_filter_sample` (pre-fit) and `WooldridgeDiD.fit()` so the + silent recategorization is surfaced on whichever entry path the caller + hits first. See REGISTRY.md §WooldridgeDiD (axis-E silent coercion). + """ + n_nan_cohort = int(df[cohort].isna().sum()) + if n_nan_cohort > 0: + warnings.warn( + f"{n_nan_cohort} row(s) have NaN cohort values; filling with 0 " + f"and treating the corresponding units as never-treated. Pass " + f"an explicit never-treated marker (0) if this is not intended.", + UserWarning, + stacklevel=stacklevel, + ) + df[cohort] = df[cohort].fillna(0) + return df + + def _filter_sample( data: pd.DataFrame, unit: str, @@ -128,8 +149,7 @@ def _filter_sample( (see _build_interaction_matrix). 
""" df = data.copy() - # Normalise never-treated: fill NaN cohort with 0 - df[cohort] = df[cohort].fillna(0) + df = _warn_and_fill_nan_cohort(df, cohort, stacklevel=3) treated_mask = df[cohort] > 0 @@ -396,7 +416,7 @@ def fit( ``NotImplementedError``. """ df = data.copy() - df[cohort] = df[cohort].fillna(0) + df = _warn_and_fill_nan_cohort(df, cohort, stacklevel=3) # 0a. Validate cohort is time-invariant within unit cohort_per_unit = df.groupby(unit)[cohort].nunique() diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index bdcfef34..1504a67b 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -720,6 +720,8 @@ See `docs/methodology/continuous-did.md` Section 4 for full details. - [ ] Lowest-dose-as-control (Remark 3.1) - [x] Survey design support (Phase 3): weighted B-spline OLS, TSL on influence functions; bootstrap+survey supported (Phase 6) - **Note:** ContinuousDiD bootstrap with survey weights supported (Phase 6) via PSU-level multiplier weights +- **Note:** The R-style convention of coding never-treated units as `first_treat=inf` is still accepted and normalized to `first_treat=0` internally, but the estimator now emits a `UserWarning` reporting the row count so the silent recategorization is surfaced (axis-E silent coercion under the Phase 2 audit). Only `+inf` is recoded (matching the R convention). Any **negative** (including `-inf`) or **NaN** `first_treat` value raises `ValueError` with the row count, since such units would otherwise silently fall out of both the treated (`g > 0`) and never-treated (`g == 0`) masks. Pass `0` directly for never-treated units to avoid the warning. +- **Note:** Rows with `first_treat=0` (never-treated) that carry a nonzero `dose` have the dose forced to zero for internal consistency (never-treated cells must have `D=0` in the dose response). 
The estimator now emits a `UserWarning` with the affected row count before the zeroing, so unintended nonzero doses on never-treated rows are no longer absorbed without a signal (axis-E silent coercion). --- @@ -1303,6 +1305,7 @@ The saturated ETWFE regression includes: The interaction coefficient `δ_{g,t}` identifies `ATT(g, t)` under parallel trends. - **Note:** OLS path uses iterative alternating-projection within-transformation (uniform weights) for exact FE absorption on both balanced and unbalanced panels. One-pass demeaning (`y - ȳ_i - ȳ_t + ȳ`) is only exact for balanced panels. - **Note:** The weighted within-transformation (`utils.within_transform` with `weights`) is invoked on every WooldridgeDiD fit (survey weights when provided, `np.ones` otherwise) and emits a `UserWarning` on non-convergence per the shared convention documented under *Absorbed Fixed Effects with Survey Weights*. +- **Note:** NaN values in the `cohort` column are filled with 0 (treated as never-treated), both in `_filter_sample` and in `fit()`. This recategorization now emits a `UserWarning` reporting the affected row count so it is no longer silent (axis-E silent coercion under the Phase 2 audit). Pass `0` directly for never-treated units to avoid the warning. *Nonlinear extensions (Wooldridge 2023):* @@ -1689,6 +1692,7 @@ Balanced panel. Key variables: - `Q_i` (`eligibility`): binary, time-invariant eligibility indicator - Treatment: `D_{i,t} = 1{t >= S_i AND Q_i = 1}` (absorbing) - Covariates `X_i`: time-invariant (first observation per unit used) +- **Note:** `first_treat=inf` (R-style never-enabled marker) is accepted and normalized to `0` internally. The recoding now emits a `UserWarning` reporting the affected row count so the reclassification is not silent (axis-E silent coercion under the Phase 2 audit, mirroring the StaggeredDiD behavior). Pass `first_treat=0` directly to avoid the warning. 
*Estimator equation (Equation 4.1 in paper, as implemented):* diff --git a/tests/test_continuous_did.py b/tests/test_continuous_did.py index c519c3bb..56c773cb 100644 --- a/tests/test_continuous_did.py +++ b/tests/test_continuous_did.py @@ -641,16 +641,195 @@ def test_few_treated_units(self): assert isinstance(results, ContinuousDiDResults) def test_inf_first_treat_normalization(self): - """first_treat=inf should be treated as never-treated.""" + """first_treat=inf should be treated as never-treated, and the caller + must receive a UserWarning reporting the affected row count so the + recategorization is not silent (axis-E counter).""" data = generate_continuous_did_data(n_units=50, n_periods=3, seed=42) data["first_treat"] = data["first_treat"].astype(float) - data.loc[data["first_treat"] == 0, "first_treat"] = np.inf + inf_mask = data["first_treat"] == 0 + n_inf_rows = int(inf_mask.sum()) + data.loc[inf_mask, "first_treat"] = np.inf est = ContinuousDiD() - results = est.fit( - data, "outcome", "unit", "period", "first_treat", "dose" - ) + + with pytest.warns( + UserWarning, + match=rf"{n_inf_rows} row\(s\) have inf in 'first_treat'", + ): + results = est.fit( + data, "outcome", "unit", "period", "first_treat", "dose" + ) assert results.n_control_units > 0 + def test_no_inf_first_treat_no_warning(self): + """No inf rows in first_treat — no recategorization warning.""" + import warnings + + data = generate_continuous_did_data(n_units=50, n_periods=3, seed=42) + data["first_treat"] = data["first_treat"].astype(float) + est = ContinuousDiD() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + inf_warnings = [x for x in w if "inf in 'first_treat'" in str(x.message)] + assert inf_warnings == [] + + def test_nonzero_dose_on_never_treated_warns(self): + """first_treat=0 (never-treated) rows with nonzero dose must now surface + a UserWarning with the affected row 
count before the zeroing coercion. + Before PR #331's CI-review follow-up this was silent.""" + # 4 units x 3 periods (12 rows). 2 units are never-treated (first_treat=0) + # but carry dose=1.5 on every row — 6 rows should be reported. + rows = [] + for unit in range(4): + if unit < 2: + ft, dose_val = 0.0, 1.5 # never-treated with nonzero dose + else: + ft, dose_val = 2.0, 1.0 # treated + for t in range(1, 4): + rows.append({ + "unit": unit, "period": t, "outcome": float(unit + t), + "first_treat": ft, "dose": dose_val, + }) + data = pd.DataFrame(rows) + est = ContinuousDiD() + + with pytest.warns( + UserWarning, + match=r"6 row\(s\) have 'first_treat'=0 \(never-treated\) but nonzero 'dose'", + ): + try: + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + except Exception: + # Downstream validation may reject this minimal panel (too few + # treated for OLS); we only care about the dose-coercion warning. + pass + + def test_clean_never_treated_doses_silent(self): + """Never-treated rows with dose=0 must not trigger the coercion warning.""" + import warnings + data = generate_continuous_did_data(n_units=50, n_periods=3, seed=42) + # generate_continuous_did_data already sets dose=0 for never-treated. + est = ContinuousDiD() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + coerce_warnings = [ + x for x in w + if "never-treated" in str(x.message) and "nonzero 'dose'" in str(x.message) + ] + assert coerce_warnings == [] + + def test_negative_first_treat_raises_with_row_count(self): + """Negative `first_treat` (including -inf) must raise ValueError with + the affected row count. Without this guard the affected units fall + out of both the treated (g > 0) and never-treated (g == 0) masks and + are silently excluded from the estimator.""" + rows = [] + for unit in range(4): + # Unit 0: -inf. Unit 1: -2. Others: valid (0 or positive). 
+ if unit == 0: + ft = -np.inf + elif unit == 1: + ft = -2.0 + else: + ft = 0.0 + for t in range(1, 4): + rows.append({ + "unit": unit, "period": t, "outcome": float(unit + t), + "first_treat": ft, "dose": 0.0, + }) + data = pd.DataFrame(rows) + est = ContinuousDiD() + + with pytest.raises( + ValueError, + match=r"6 row\(s\) have negative 'first_treat' values", + ): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_nan_first_treat_raises_with_row_count(self): + """NaN `first_treat` must raise ValueError with the row count. Without + this guard, NaN rows survive preprocessing but match neither the + treated (g > 0) nor never-treated (g == 0) mask, so the affected + units would be silently excluded.""" + rows = [] + for unit in range(4): + # Unit 0 has NaN first_treat across all 3 periods (3 NaN rows). + ft = np.nan if unit == 0 else 0.0 + for t in range(1, 4): + rows.append({ + "unit": unit, "period": t, "outcome": float(unit + t), + "first_treat": ft, "dose": 0.0, + }) + data = pd.DataFrame(rows) + est = ContinuousDiD() + + with pytest.raises( + ValueError, + match=r"3 row\(s\) have NaN 'first_treat' values", + ): + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + + def test_positive_inf_warning_silent_when_no_inf(self): + """+inf warning is gated on +inf rows only; panels with only valid + non-negative values (including just 0 and positive periods) must + never trigger the recategorization warning.""" + import warnings + rows = [] + for unit in range(4): + ft = 0.0 if unit < 2 else 2.0 + for t in range(1, 4): + rows.append({ + "unit": unit, "period": t, "outcome": float(unit + t), + "first_treat": ft, "dose": 0.0 if unit < 2 else 1.0, + }) + data = pd.DataFrame(rows) + est = ContinuousDiD() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + except Exception: + pass + + inf_warnings = [x for x in w if "inf 
in 'first_treat'" in str(x.message)] + assert inf_warnings == [] + + def test_inf_first_treat_warning_counts_rows_not_units(self): + """The warning counts affected rows (not units). On a panel with + multiple periods per unit, each inf row must count separately so the + message surface matches the per-row semantics of `.replace(inf, 0)`.""" + # Build a 4-unit, 3-period panel (12 rows). 2 units have inf across + # all 3 periods → 6 inf rows, 2 units, so row-count != unit-count. + rows = [] + for unit in range(4): + ft = np.inf if unit < 2 else 2.0 + dose = 0.0 if unit < 2 else 1.0 + for t in range(1, 4): + rows.append({ + "unit": unit, "period": t, "outcome": float(unit + t), + "first_treat": ft, "dose": dose, + }) + data = pd.DataFrame(rows) + est = ContinuousDiD() + + with pytest.warns( + UserWarning, + match=r"6 row\(s\) have inf in 'first_treat'", + ): + try: + est.fit(data, "outcome", "unit", "period", "first_treat", "dose") + except Exception: + # Downstream validation may reject this minimal panel (too few + # treated for OLS). We only care that the inf-row warning fires + # with the correct row count. 
+ pass + def test_custom_dvals(self): data = generate_continuous_did_data(n_units=100, n_periods=3, seed=42) custom_grid = np.array([1.0, 2.0, 3.0]) diff --git a/tests/test_staggered_triple_diff.py b/tests/test_staggered_triple_diff.py index 6954b255..a47d1758 100644 --- a/tests/test_staggered_triple_diff.py +++ b/tests/test_staggered_triple_diff.py @@ -397,12 +397,23 @@ def test_missing_column_raises(self, simple_data): est.fit(simple_data, "outcome", "unit", "period", "nonexistent", "eligibility") def test_inf_first_treat_works(self): - """Never-enabled units encoded as inf should work.""" + """Never-enabled units encoded as inf should be recoded to 0, and the + recoding must surface a UserWarning with the affected row count + (axis-E silent coercion, mirroring the StaggeredDiD behavior).""" data = generate_staggered_ddd_data(n_units=100, seed=33) data["first_treat"] = data["first_treat"].astype(float) - data.loc[data["first_treat"] == 0, "first_treat"] = np.inf + inf_mask = data["first_treat"] == 0 + n_inf_rows = int(inf_mask.sum()) + data.loc[inf_mask, "first_treat"] = np.inf est = StaggeredTripleDifference() - res = est.fit(data, "outcome", "unit", "period", "first_treat", "eligibility") + + with pytest.warns( + UserWarning, + match=rf"{n_inf_rows} row\(s\) have first_treat=inf; recoding to 0", + ): + res = est.fit( + data, "outcome", "unit", "period", "first_treat", "eligibility" + ) assert np.isfinite(res.overall_att) def test_survey_design_invalid_type_raises(self, simple_data): diff --git a/tests/test_utils.py b/tests/test_utils.py index 61a0ed80..61f2d353 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -827,6 +827,86 @@ def test_changes_reflect_trend(self): np.testing.assert_array_almost_equal(treated_changes, 2.0, decimal=5) np.testing.assert_array_almost_equal(control_changes, 2.0, decimal=5) + def test_silent_on_balanced_panel(self): + """Balanced panel: only first-period-per-unit drops, no warning.""" + import warnings + + rng = 
np.random.default_rng(0) + rows = [] + for unit in range(10): + treated = int(unit >= 5) + for t in range(1, 5): + rows.append({ + "unit": unit, "period": t, + "treated": treated, "outcome": rng.normal(), + }) + df = pd.DataFrame(rows) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _compute_outcome_changes( + df, outcome="outcome", time="period", + treatment_group="treated", unit="unit", + ) + + # Generic filter on "dropped" catches both the old and new label so a + # regression in the label wouldn't hide a real silent-drop warning. + drop_warnings = [x for x in w if "dropped" in str(x.message).lower()] + assert drop_warnings == [] + + def test_warns_on_nan_outcomes_with_excess_drop_count(self): + """Extra NaN-outcome rows beyond first-period drops must surface via + a UserWarning reporting the excess count (axis-E drop counter).""" + rng = np.random.default_rng(0) + rows = [] + for unit in range(10): + treated = int(unit >= 5) + for t in range(1, 5): + rows.append({ + "unit": unit, "period": t, + "treated": treated, "outcome": rng.normal(), + }) + df = pd.DataFrame(rows) + df.loc[[5, 12, 22], "outcome"] = np.nan + + with pytest.warns( + UserWarning, + match=r"parallel-trend diagnostic: dropped \d+ row\(s\).*additional NaN first-differences", + ): + _compute_outcome_changes( + df, outcome="outcome", time="period", + treatment_group="treated", unit="unit", + ) + + def test_warning_label_reflects_public_caller(self): + """`check_parallel_trends_robust` and `equivalence_test_trends` must + each surface the axis-E excess-drop warning under their own name so + users can trace the signal back to the function they called.""" + rng = np.random.default_rng(0) + rows = [] + for unit in range(10): + treated = int(unit >= 5) + for t in range(1, 5): + rows.append({ + "unit": unit, "period": t, + "treated": treated, "outcome": rng.normal(), + }) + df = pd.DataFrame(rows) + df.loc[[5, 12, 22], "outcome"] = np.nan + + with 
pytest.warns(UserWarning, match="check_parallel_trends_robust:"): + check_parallel_trends_robust( + df, outcome="outcome", time="period", + treatment_group="treated", unit="unit", + n_permutations=100, seed=0, + ) + + with pytest.warns(UserWarning, match="equivalence_test_trends:"): + equivalence_test_trends( + df, outcome="outcome", time="period", + treatment_group="treated", unit="unit", + ) + # ============================================================================= # Tests for check_parallel_trends_robust diff --git a/tests/test_wooldridge.py b/tests/test_wooldridge.py index 19027fce..2cedbf56 100644 --- a/tests/test_wooldridge.py +++ b/tests/test_wooldridge.py @@ -1593,3 +1593,50 @@ def test_survey_aggregate_and_summary(self, survey_panel): s = r.summary() assert "Survey Design" in s assert "pweight" in s + + +class TestCohortNaNWarning: + """Axis-E: silent recategorization of NaN cohort rows as never-treated.""" + + @staticmethod + def _make_panel_with_nan_cohort(): + rows = [] + for unit in range(10): + cohort_val = np.nan if unit < 2 else 0.0 + for t in range(1, 5): + rows.append({ + "unit": unit, "time": t, "cohort": cohort_val, + "y": unit + t + np.random.default_rng(unit).normal(0, 0.1), + }) + return pd.DataFrame(rows) + + def test_fit_warns_on_nan_cohort_with_count(self): + df = self._make_panel_with_nan_cohort() + est = WooldridgeDiD(method="ols") + with pytest.warns(UserWarning, match=r"8 row\(s\) have NaN cohort values"): + try: + est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort") + except Exception: + pass + + def test_fit_silent_on_clean_cohort(self): + import warnings + df = self._make_panel_with_nan_cohort() + df["cohort"] = df["cohort"].fillna(0) + est = WooldridgeDiD(method="ols") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort") + except Exception: + pass + nan_warnings = [x for x in w if "NaN cohort values" in 
str(x.message)] + assert nan_warnings == [] + + def test_select_sample_helper_warns(self): + df = self._make_panel_with_nan_cohort() + with pytest.warns(UserWarning, match=r"8 row\(s\) have NaN cohort values"): + _filter_sample( + df, unit="unit", time="time", cohort="cohort", + control_group="never_treated", anticipation=0, + )