diff --git a/diff_diff/utils.py b/diff_diff/utils.py index 4a7ba5eb..704a2138 100644 --- a/diff_diff/utils.py +++ b/diff_diff/utils.py @@ -1460,12 +1460,15 @@ def _sc_weight_fw_numpy( lam = np.ones(T0) / T0 vals = np.full(max_iter, np.nan) + converged = False for t in range(max_iter): lam = _fw_step(A, lam, b, eta) err = Y @ np.append(lam, -1.0) vals[t] = zeta**2 * np.sum(lam**2) + np.sum(err**2) / N if t >= 1 and vals[t - 1] - vals[t] < min_decrease**2: + converged = True break + warn_if_not_converged(converged, "Frank-Wolfe SC weight solver", max_iter, min_decrease) return lam diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 391482d2..b4cd419f 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1496,7 +1496,7 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi P-value: analytical (normal distribution), not empirical. *Edge cases:* -- **Frank-Wolfe non-convergence**: Returns current weights after max_iter iterations. No warning emitted; the convergence check `vals[t-1] - vals[t] < min_decrease²` simply does not trigger early exit, and the final iterate is returned. +- **Frank-Wolfe non-convergence**: Returns current weights after max_iter iterations when the convergence check `vals[t-1] - vals[t] < min_decrease²` never triggers early exit. The numpy-backend path (`_sc_weight_fw_numpy`) emits a `UserWarning` via `diff_diff.utils.warn_if_not_converged` in that case; the Rust-backend path silently returns the final iterate (Rust-side signature change required to thread convergence status — tracked as an axis-G backend-parity follow-up). - **`_sparsify` all-zero input**: If `max(v) <= 0`, returns uniform weights `ones(len(v)) / len(v)`. - **Single control unit**: `compute_sdid_unit_weights` returns `[1.0]` immediately (short-circuit before Frank-Wolfe). - **Zero control units**: `compute_sdid_unit_weights` returns empty array `[]`. diff --git a/tests/test_methodology_sdid.py b/tests/test_methodology_sdid.py index 36b8162a..c7a77545 100644 --- a/tests/test_methodology_sdid.py +++ b/tests/test_methodology_sdid.py @@ -21,6 +21,7 @@ _compute_regularization, _fw_step, _sc_weight_fw, + _sc_weight_fw_numpy, _sparsify, _sum_normalize, compute_sdid_estimator, @@ -207,6 +208,58 @@ def test_intercept_centering(self): # They should be different because centering matters assert not np.allclose(lam_intercept, lam_no_intercept, atol=1e-3) + def test_fw_warns_on_nonconvergence(self): + """Silent-failure audit axis B: _sc_weight_fw_numpy must warn when max_iter exhausts.""" + rng = np.random.default_rng(42) + Y = rng.standard_normal((15, 7)) # (N, T0+1) with T0=6 + + with pytest.warns(UserWarning, match="did not converge"): + _sc_weight_fw_numpy(Y, zeta=0.1, max_iter=1, min_decrease=1e-12) + + def test_fw_no_warning_on_convergence(self): + """Silent-failure audit axis B: no warning on well-conditioned convergent input.""" + rng = np.random.default_rng(42) + Y = rng.standard_normal((15, 7)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _sc_weight_fw_numpy(Y, zeta=0.1, max_iter=10000, min_decrease=1e-3) + assert not any("did not converge" in str(x.message) for x in w) + + def test_fw_wrapper_warns_on_nonconvergence_without_rust(self): + """Silent-failure audit axis B: public _sc_weight_fw wrapper must route + warnings through even when called via the dispatcher with the Rust + backend disabled. Pins the contract against refactors that would + bypass the numpy path.""" + rng = np.random.default_rng(42) + Y = rng.standard_normal((15, 7)) + + with patch("diff_diff.utils.HAS_RUST_BACKEND", False): + with pytest.warns(UserWarning, match="did not converge"): + _sc_weight_fw(Y, zeta=0.1, max_iter=1, min_decrease=1e-12) + + def test_fw_wrapper_no_warning_on_convergence_without_rust(self): + """Silent-failure audit axis B: wrapper-level negative control.""" + rng = np.random.default_rng(42) + Y = rng.standard_normal((15, 7)) + + with patch("diff_diff.utils.HAS_RUST_BACKEND", False): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _sc_weight_fw(Y, zeta=0.1, max_iter=10000, min_decrease=1e-3) + assert not any("did not converge" in str(x.message) for x in w) + + def test_fw_max_iter_zero_warns(self): + """Silent-failure audit axis B: max_iter=0 produces the uniform init + without iterating, which cannot converge by construction. The warning + must fire (consistent with the convention: if we exited the loop + without hitting the tolerance gate, we signal). Pins this contract.""" + Y = np.random.default_rng(0).standard_normal((5, 4)) + + with patch("diff_diff.utils.HAS_RUST_BACKEND", False): + with pytest.warns(UserWarning, match="did not converge"): + _sc_weight_fw(Y, zeta=0.1, max_iter=0) + class TestSparsify: """Verify sparsification behavior."""