Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions diff_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,12 +1460,15 @@ def _sc_weight_fw_numpy(
lam = np.ones(T0) / T0

vals = np.full(max_iter, np.nan)
converged = False
for t in range(max_iter):
lam = _fw_step(A, lam, b, eta)
err = Y @ np.append(lam, -1.0)
vals[t] = zeta**2 * np.sum(lam**2) + np.sum(err**2) / N
if t >= 1 and vals[t - 1] - vals[t] < min_decrease**2:
converged = True
break
warn_if_not_converged(converged, "Frank-Wolfe SC weight solver", max_iter, min_decrease)

return lam

Expand Down
2 changes: 1 addition & 1 deletion docs/methodology/REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi
P-value: analytical (normal distribution), not empirical.

*Edge cases:*
- **Frank-Wolfe non-convergence**: Returns current weights after max_iter iterations. No warning emitted; the convergence check `vals[t-1] - vals[t] < min_decrease²` simply does not trigger early exit, and the final iterate is returned.
- **Frank-Wolfe non-convergence**: Returns current weights after max_iter iterations when the convergence check `vals[t-1] - vals[t] < min_decrease²` never triggers early exit. The numpy-backend path (`_sc_weight_fw_numpy`) emits a `UserWarning` via `diff_diff.utils.warn_if_not_converged` in that case; the Rust-backend path silently returns the final iterate (Rust-side signature change required to thread convergence status — tracked as an axis-G backend-parity follow-up).
- **`_sparsify` all-zero input**: If `max(v) <= 0`, returns uniform weights `ones(len(v)) / len(v)`.
- **Single control unit**: `compute_sdid_unit_weights` returns `[1.0]` immediately (short-circuit before Frank-Wolfe).
- **Zero control units**: `compute_sdid_unit_weights` returns empty array `[]`.
Expand Down
53 changes: 53 additions & 0 deletions tests/test_methodology_sdid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_compute_regularization,
_fw_step,
_sc_weight_fw,
_sc_weight_fw_numpy,
_sparsify,
_sum_normalize,
compute_sdid_estimator,
Expand Down Expand Up @@ -207,6 +208,58 @@ def test_intercept_centering(self):
# They should be different because centering matters
assert not np.allclose(lam_intercept, lam_no_intercept, atol=1e-3)

def test_fw_warns_on_nonconvergence(self):
"""Silent-failure audit axis B: _sc_weight_fw_numpy must warn when max_iter exhausts."""
rng = np.random.default_rng(42)
Y = rng.standard_normal((15, 7)) # (N, T0+1) with T0=6

with pytest.warns(UserWarning, match="did not converge"):
_sc_weight_fw_numpy(Y, zeta=0.1, max_iter=1, min_decrease=1e-12)

def test_fw_no_warning_on_convergence(self):
"""Silent-failure audit axis B: no warning on well-conditioned convergent input."""
rng = np.random.default_rng(42)
Y = rng.standard_normal((15, 7))

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_sc_weight_fw_numpy(Y, zeta=0.1, max_iter=10000, min_decrease=1e-3)
assert not any("did not converge" in str(x.message) for x in w)

def test_fw_wrapper_warns_on_nonconvergence_without_rust(self):
"""Silent-failure audit axis B: public _sc_weight_fw wrapper must route
warnings through even when called via the dispatcher with the Rust
backend disabled. Pins the contract against refactors that would
bypass the numpy path."""
rng = np.random.default_rng(42)
Y = rng.standard_normal((15, 7))

with patch("diff_diff.utils.HAS_RUST_BACKEND", False):
with pytest.warns(UserWarning, match="did not converge"):
_sc_weight_fw(Y, zeta=0.1, max_iter=1, min_decrease=1e-12)

def test_fw_wrapper_no_warning_on_convergence_without_rust(self):
"""Silent-failure audit axis B: wrapper-level negative control."""
rng = np.random.default_rng(42)
Y = rng.standard_normal((15, 7))

with patch("diff_diff.utils.HAS_RUST_BACKEND", False):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_sc_weight_fw(Y, zeta=0.1, max_iter=10000, min_decrease=1e-3)
assert not any("did not converge" in str(x.message) for x in w)

def test_fw_max_iter_zero_warns(self):
"""Silent-failure audit axis B: max_iter=0 produces the uniform init
without iterating, which cannot converge by construction. The warning
must fire (consistent with the convention: if we exited the loop
without hitting the tolerance gate, we signal). Pins this contract."""
Y = np.random.default_rng(0).standard_normal((5, 4))

with patch("diff_diff.utils.HAS_RUST_BACKEND", False):
with pytest.warns(UserWarning, match="did not converge"):
_sc_weight_fw(Y, zeta=0.1, max_iter=0)


class TestSparsify:
"""Verify sparsification behavior."""
Expand Down
Loading