
Unify Rust TROP inner solver to SVD (close finding #23 grid-search divergence) #348

Merged
igerber merged 2 commits into main from fix/trop-grid-search-parity
Apr 21, 2026

@igerber
Owner

@igerber igerber commented Apr 21, 2026

Summary

  • Rewrites Rust's TROP inner TWFE solver (rust/src/trop.rs::solve_joint_no_lowrank) from iterative block coordinate descent to SVD-based minimum-norm weighted least squares, mirroring Python's np.linalg.lstsq(rcond=None) step-for-step with numpy-compatible rcond = eps * max(n, k).
  • Closes the grid-search half of silent-failures finding #23 (TODO row 87). The ~6% ATT divergence between Rust and Python on rank-deficient Y (two near-parallel control units) is eliminated; test_grid_search_rank_deficient_Y is no longer xfailed.
  • Bootstrap-seed divergence (~28% SE gap from Rust rand vs numpy default_rng) is a separate root cause — tracked in TODO.md row 87 for a future PR.

Methodology references (required if estimator / math changes)

  • Method name(s): TROP (Triply Robust Panel) global estimation, inner weighted least-squares fit.
  • Paper / source link(s): Athey, Imbens, Qu & Viviano (2025), arXiv:2508.21536. TROP Equation 2 / LOOCV Equation 5. Canonical numerical reference for rank-deficient WLS: Golub & Van Loan, Matrix Computations, Ch. 5.5 (minimum-norm least squares via SVD).
  • Any intentional deviations from the source (and why): None. This PR makes Rust match Python on the canonical numerical path. REGISTRY.md TROP Global Estimation bullet updated in place (near line 2061) to spell out that both backends now use SVD-based minimum-norm WLS with numpy-compatible rcond.

Validation

  • Tests added/updated: tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_grid_search_rank_deficient_Y xfail decorator removed; test now asserts Rust/Python parity at atol=1e-6 on rank-deficient Y.
  • maturin develop --release --features accelerate: clean build, no warnings.
  • pytest tests/test_rust_backend.py::TestTROPRustEdgeCaseParity: 1 passed (grid-search), 1 xfailed (bootstrap-seed, out of scope).
  • pytest tests/test_rust_backend.py -k TROP -m '': 23 passed, 1 xfailed, no regressions.
  • pytest tests/test_trop.py: 83 passed, 37 deselected (slow-marked).
  • TestTROPGlobalRustVsNumpy (includes lambda_nn=0 low-rank FISTA path): 8 passed. FISTA TWFE step behavior preserved on well-conditioned data.
  • Pre-emptive grep for _ in 0..50 in rust/src/*.rs: no other iterative coordinate-descent patterns that would need similar treatment.

Security / privacy

  • Confirm no secrets/PII in this PR: Yes.

Diff: +148 / −96 across rust/src/trop.rs, rust/src/linalg.rs, tests/test_rust_backend.py, TODO.md, docs/methodology/REGISTRY.md.

…vergence

Closes the grid-search half of silent-failures finding #23 (TODO row 87).
The `xfail(strict=True)` regression `test_grid_search_rank_deficient_Y`
baselined a ~6% ATT divergence between Rust and Python on two near-parallel
control units. Root cause: Rust's `solve_joint_no_lowrank` used iterative
block coordinate descent (50 iter, tol=1e-8) while Python used SVD-based
minimum-norm least squares. On rank-deficient Y the two solvers converge
to different stationary points of the same objective.
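
As a toy illustration of that last point (not the Rust or Python code from this PR), the sketch below runs a plain cyclic coordinate descent and NumPy's minimum-norm solve on a two-column design with exactly duplicated columns. The function name `coordinate_descent_ls` and the data are hypothetical; the real solver was a block coordinate descent over TWFE parameters, and the real panel had near-parallel rather than identical units.

```python
import numpy as np

def coordinate_descent_ls(X, y, n_iter=50, tol=1e-8):
    """Plain cyclic coordinate descent for least squares, started at zero."""
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        max_step = 0.0
        for j in range(X.shape[1]):
            # Residual with column j's current contribution added back.
            partial = y - X @ beta + X[:, j] * beta[j]
            updated = (X[:, j] @ partial) / (X[:, j] @ X[:, j])
            max_step = max(max_step, abs(updated - beta[j]))
            beta[j] = updated
        if max_step < tol:
            break
    return beta

x = np.array([1.0, 2.0, 3.0, 4.0])
X = np.column_stack([x, x])   # two identical columns: rank 1
y = 2.0 * x

beta_cd = coordinate_descent_ls(X, y)               # first column absorbs everything
beta_mn = np.linalg.lstsq(X, y, rcond=None)[0]      # minimum-norm: weight split evenly

print("coordinate descent:", beta_cd)
print("minimum-norm lstsq:", beta_mn)
print("same fitted values:", np.allclose(X @ beta_cd, X @ beta_mn))
```

Both coefficient vectors attain the same (zero) residual here, yet they differ, which is why a parity test needs both backends to agree on which minimizer to return.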

Python is canonical (SVD / minimum-norm least squares per Golub & Van Loan).
Rust's iterative solver was a speed optimization, not a methodology choice.
Port the Rust inner TWFE step to SVD-based WLS that mirrors Python's
`np.linalg.lstsq(rcond=None)` step-for-step, with numpy-compatible
`rcond = eps * max(n, k)`.
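
A minimal NumPy sketch of the numerical recipe described above (sqrt-weight transform, thin SVD, truncation at `rcond = eps * max(n, k)` relative to the largest singular value). The name `solve_wls_svd_sketch` and the sample data are assumptions for illustration; this is not the Rust `solve_wls_svd` helper itself.

```python
import numpy as np

def solve_wls_svd_sketch(X, y, w):
    """Minimum-norm weighted least squares via thin SVD (illustrative only)."""
    sqrt_w = np.sqrt(w)
    Xw = X * sqrt_w[:, None]          # scale each row by sqrt(weight)
    yw = y * sqrt_w
    U, s, Vt = np.linalg.svd(Xw, full_matrices=False)
    n, k = Xw.shape
    rcond = np.finfo(np.float64).eps * max(n, k)
    cutoff = rcond * s.max()          # threshold relative to largest singular value
    s_inv = np.zeros_like(s)
    keep = s > cutoff
    s_inv[keep] = 1.0 / s[keep]       # truncated singular values contribute nothing
    return Vt.T @ (s_inv * (U.T @ yw))

# Rank-deficient design: the first two columns are identical.
X = np.array([[1.0, 1.0, 2.0],
              [1.0, 1.0, 3.0],
              [1.0, 1.0, 5.0],
              [1.0, 1.0, 7.0]])
y = np.array([1.0, 2.0, 4.0, 7.0])
w = np.array([1.0, 2.0, 0.5, 1.0])

beta = solve_wls_svd_sketch(X, y, w)
reference = np.linalg.lstsq(X * np.sqrt(w)[:, None], y * np.sqrt(w), rcond=None)[0]
print(beta, reference)
```

On this example the two duplicated columns should receive equal coefficients, which is the signature of the minimum-norm solution.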

Changes
- rust/src/linalg.rs: promote ndarray_to_faer to pub(crate) so trop.rs can reuse it.
- rust/src/trop.rs: new module-private solve_wls_svd helper — thin-SVD + rcond truncation, matches numpy's minimum-norm semantics. Rewrite solve_joint_no_lowrank body to flatten y/weights row-major, build the [intercept | unit_dummies[1..] | time_dummies[1..]] design matrix, apply sqrt-weights, and solve via solve_wls_svd. Function signature unchanged — all 4 call sites (LOOCV, FISTA TWFE step x2, bootstrap) benefit transitively.
- tests/test_rust_backend.py: remove @pytest.mark.xfail from test_grid_search_rank_deficient_Y; the gap is closed. Bootstrap-seed test retains its xfail (row 87 RNG mismatch, out of scope).
- docs/methodology/REGISTRY.md: update the TROP Global Estimation bullet at the existing `np.linalg.lstsq` line to note Rust and Python now both use SVD-based minimum-norm WLS with numpy-compatible rcond.
- TODO.md: delete row 87 (grid-search divergence entry).
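
The design-matrix layout named in the trop.rs bullet above can be sketched in NumPy as follows. This assumes `Y` is stored as `(n_periods, n_units)` so that row-major flattening gives index `t * n_units + i`; that layout, the function name `build_twfe_design`, and the sample data are assumptions, and NaN masking is omitted for brevity.

```python
import numpy as np

def build_twfe_design(n_periods, n_units):
    """[intercept | unit dummies for units 1.. | time dummies for periods 1..]."""
    t_idx = np.repeat(np.arange(n_periods), n_units)   # period of each flattened row
    i_idx = np.tile(np.arange(n_units), n_periods)     # unit of each flattened row
    X = np.zeros((n_periods * n_units, n_units + n_periods - 1))
    X[:, 0] = 1.0                                      # intercept
    unit_rows = np.flatnonzero(i_idx > 0)
    X[unit_rows, i_idx[unit_rows]] = 1.0               # unit i -> column i
    time_rows = np.flatnonzero(t_idx > 0)
    X[time_rows, n_units - 1 + t_idx[time_rows]] = 1.0 # period t -> column n_units-1+t
    return X

n_periods, n_units = 3, 4
alpha = np.array([0.0, 0.5, -1.0, 2.0])   # unit effects (unit 0 is the baseline)
gamma = np.array([0.0, 0.3, 0.9])         # time effects (period 0 is the baseline)
Y = 1.0 + gamma[:, None] + alpha[None, :] # purely additive panel, intercept 1.0

X = build_twfe_design(n_periods, n_units)
y = Y.ravel()                             # row-major flatten
w = np.linspace(0.5, 1.5, y.size)         # arbitrary positive weights
sqrt_w = np.sqrt(w)
beta = np.linalg.lstsq(X * sqrt_w[:, None], y * sqrt_w, rcond=None)[0]
print(beta)
```

Because the sample panel is exactly additive, the fit should recover the intercept followed by the non-baseline unit and time effects.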

Verification
- maturin develop --release --features accelerate: clean build, no warnings.
- pytest tests/test_rust_backend.py::TestTROPRustEdgeCaseParity: grid-search test now passes; bootstrap-seed test correctly xfails.
- pytest tests/test_rust_backend.py -k TROP -m '': 23 passed, 1 xfailed, no regressions.
- pytest tests/test_trop.py: 83 passed, 37 deselected (slow).
- TestTROPGlobalRustVsNumpy (incl. lambda_nn=0 low-rank FISTA path): 8 passed — FISTA TWFE step unchanged in behavior on well-conditioned data.
- grep for other 'for _ in 0..50' coordinate-descent patterns in rust/src/*.rs: none found.

Non-goals
- No changes to row 87 (bootstrap RNG mismatch — Rust rand crate vs numpy default_rng ~28% SE gap on seed=42). Separate PR.
- No changes to linalg.rs::solve_ols (rcond=1e-7 is load-bearing for MultiPeriodDiD / DiD / TWFE).
- No public API changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions

Overall Assessment
✅ Looks good

Executive Summary

  • Affected method: TROP global estimation’s inner weighted least-squares fit. The rewritten Rust helper in rust/src/trop.rs:L1234-L1395 now matches the Python reference solve in diff_diff/trop_global.py:L323-L410 for the λ_nn = ∞ path and the shared inner step used by the finite-λ_nn alternating-minimization path in rust/src/trop.rs:L1435-L1500. This preserves the paper’s weighted TROP objective and LOOCV criterion; the switch to an SVD/minimum-norm solve is a numerical implementation choice, not a methodology change. The registry update in docs/methodology/REGISTRY.md:L2059-L2065 documents that alignment.
  • I found no P0/P1 issues in estimator math, weighting, identification, variance/SE handling, or inference anti-patterns within the changed diff.
  • The separate Rust/Python bootstrap-seed divergence remains explicitly tracked in TODO.md:L86-L86 and intentionally xfailed in tests/test_rust_backend.py:L2255-L2298, so it is mitigated rather than blocking.
  • One minor issue only: the new regression-guard docstring still says “TODO row 87,” but after this PR row 87 now refers to a different item. tests/test_rust_backend.py:L2208-L2217, TODO.md:L86-L87

Methodology

  • No findings. The affected method is TROP global estimation’s inner WLS solve, used directly when λ_nn = ∞ and reused as the TWFE substep when λ_nn is finite. The new Rust code builds the same weighted dummy-variable system as the Python reference path and solves it with NumPy-style minimum-norm least squares, which is consistent with the paper and with the registry’s documented implementation note. rust/src/trop.rs:L1234-L1395, diff_diff/trop_global.py:L323-L410, docs/methodology/REGISTRY.md:L2059-L2065

Code Quality

Performance

  • No findings. The change intentionally prioritizes numerical parity on rank-deficient systems; I did not benchmark runtime in this read-only review environment.

Maintainability

  • No findings. The added in-code rationale for not reusing the generic solve_ols cutoff makes the TROP-specific numerical contract clearer. rust/src/trop.rs:L1321-L1333

Tech Debt

  • P3 — Impact: the Rust/Python bootstrap seed mismatch is still open, but it is explicitly tracked and xfailed, so it does not block this PR under the stated rubric. Concrete fix: none for this PR; resolve it in the follow-up by unifying the Rust and NumPy seed-to-replicate mapping. TODO.md:L86-L86, tests/test_rust_backend.py:L2255-L2298

Security

  • No findings.

Documentation/Tests

…t docstring

After this PR deletes the old row 87 from TODO.md, row 87 now points to a
different item. Replace the row-number breadcrumb with "Silent-failures
audit Finding #23 (grid-search half)" which is stable across future
TODO.md reshuffles.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@igerber
Owner Author

igerber commented Apr 21, 2026

/ai-review

@github-actions

🔁 AI review rerun (requested by @igerber)

Head SHA: bbbcc4db0e7d325d706ad6cc2178c838e4571a01


Overall Assessment
✅ Looks good

Executive Summary

  • Affected method: the TROP method="global" inner weighted least-squares subproblem used by the Rust LOOCV grid-search and bootstrap accelerators. The Python reference solve remains the numerical contract, and the Rust rewrite now mirrors it closely. rust/src/trop.rs:L1234-L1500, rust/src/trop.rs:L1738-L1897, diff_diff/trop_global.py:L323-L537, diff_diff/trop_global.py:L688-L706, diff_diff/trop_global.py:L964-L996
  • Cross-checking REGISTRY.md plus the in-code Eq. 2 / Eq. 5 references shows no methodology change: (1-W) masking, weighting, and LOOCV scoring are unchanged. The switch from iterative coordinate descent to SVD minimum-norm WLS is a numerical implementation choice for the same weighted LS subproblem, not an undocumented estimator deviation. docs/methodology/REGISTRY.md:L2030-L2085, rust/src/trop.rs:L1226-L1233, rust/src/trop.rs:L1505-L1515, diff_diff/trop_global.py:L323-L412
  • I found no unmitigated P0/P1 issues in estimator math, control-group composition, weighting, variance/SE handling, or inference anti-patterns in the changed diff.
  • The prior AI review’s only informational issue is resolved: the rank-deficient regression test no longer points to a stale TODO row and now documents Finding #23 directly. tests/test_rust_backend.py:L2207-L2218, TODO.md:L86-L87
  • The remaining Rust/Python bootstrap seed divergence is still explicitly tracked and xfailed, so it remains P3/informational rather than blocking. TODO.md:L86-L86, tests/test_rust_backend.py:L2256-L2302

Methodology

  • No findings. The PR preserves the TROP global objective and Eq. 5 LOOCV criterion documented in the registry; it only changes how the Rust backend solves the shared weighted LS subproblem on rank-deficient systems. The new Rust path matches the Python reference structure: row-major flattening, NaN masking, weighted dummy design, sqrt-weight transform, and NumPy-style rcond=None truncation. rust/src/trop.rs:L1234-L1395, diff_diff/trop_global.py:L323-L412, docs/methodology/REGISTRY.md:L2059-L2079

Code Quality

  • No findings. The change is localized, the reason not to reuse solve_ols is documented, and the helper exposure in linalg.rs is crate-local only. rust/src/trop.rs:L1326-L1333, rust/src/linalg.rs:L279-L284

Performance

  • No findings. I did not benchmark the new dense SVD path against the removed coordinate-descent loop in this review.

Maintainability

  • No findings. The faer SVD usage is consistent with existing crate conventions, including the U/S/V shape assumptions already exercised in rust/src/linalg.rs tests. rust/src/linalg.rs:L69-L135, rust/src/linalg.rs:L425-L442, rust/src/trop.rs:L1343-L1393

Tech Debt

  • Severity: P3
    Impact: Cross-backend bootstrap SE parity under identical seeds is still intentionally unresolved because Rust and Python use different RNG implementations; reproducibility across backends remains incomplete, but this limitation is tracked and explicitly xfailed.
    Concrete fix: Keep the current TODO/xfail for this PR; resolve it in follow-up by unifying the seed-to-replicate mapping across backends. TODO.md:L86-L86, tests/test_rust_backend.py:L2256-L2302

Security

  • No findings.

Documentation/Tests

  • No findings. The methodology registry now documents the Rust/Python SVD minimum-norm contract, and the new regression test directly guards the rank-deficient grid-search parity case that motivated the fix. docs/methodology/REGISTRY.md:L2059-L2065, tests/test_rust_backend.py:L2207-L2254
  • I did not execute the test suite in this read-only review environment.

@igerber igerber added the ready-for-ci label (triggers CI test workflows) Apr 21, 2026
@igerber igerber merged commit c76eea3 into main Apr 21, 2026
21 of 22 checks passed
@igerber igerber deleted the fix/trop-grid-search-parity branch April 21, 2026 09:49
igerber added a commit that referenced this pull request Apr 24, 2026
Rust and Python TROP backends produced different bootstrap standard
errors for the same `seed` value. On a tiny correlated panel under
`seed=42` the gap was ~28% of SE: Rust seeded `rand_xoshiro::
Xoshiro256PlusPlus` per replicate while Python's fallback consumed
`numpy.random.default_rng` (PCG64), so identical seeds mapped to
different bytestreams.

Canonicalize on numpy. New `stratified_bootstrap_indices` helper in
`diff_diff/bootstrap_utils.py` pre-generates per-replicate
(control, treated) positional index arrays from a numpy `Generator`
and hands them to both backends through the PyO3 surface — both
Rust bootstrap functions (`bootstrap_trop_variance_global`,
`bootstrap_trop_variance`) now accept `control_indices` and
`treated_indices` as `i64` arrays in place of `seed: u64`. Parallelism
is preserved. Sampling law (stratified: controls then treated, with
replacement) is unchanged.
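
A hedged sketch of the idea behind pre-generating bootstrap indices from a single numpy `Generator`. The function name `stratified_indices_sketch`, its signature, and the per-replicate draw order are assumptions for illustration; the actual `stratified_bootstrap_indices` helper may differ in all three.

```python
import numpy as np

def stratified_indices_sketch(rng, n_control, n_treated, n_bootstrap):
    """Per-replicate positional indices: controls first, then treated,
    each sampled with replacement within its own stratum."""
    control = np.empty((n_bootstrap, n_control), dtype=np.int64)
    treated = np.empty((n_bootstrap, n_treated), dtype=np.int64)
    for b in range(n_bootstrap):
        control[b] = rng.integers(0, n_control, size=n_control, dtype=np.int64)
        treated[b] = rng.integers(0, n_treated, size=n_treated, dtype=np.int64)
    return control, treated

# Any backend handed these arrays resamples the same units, so the result
# no longer depends on which RNG that backend would have used internally.
ctrl_a, trt_a = stratified_indices_sketch(np.random.default_rng(42), 6, 3, 4)
ctrl_b, trt_b = stratified_indices_sketch(np.random.default_rng(42), 6, 3, 4)
print(np.array_equal(ctrl_a, ctrl_b) and np.array_equal(trt_a, trt_b))
```

The design choice is to move all randomness to one side of the language boundary, so that the compiled backend becomes a deterministic function of its inputs.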

Global-method SE is now backend-invariant under the same seed to
machine precision: the prior `xfail(strict=True)` in
`test_bootstrap_seed_reproducibility` is flipped to a passing
`assert_allclose(atol=rtol=1e-14)` and parametrized over
`[0, 42, 12345]`.

A companion `test_bootstrap_seed_reproducibility_local` is added for
the local-method bootstrap. It is currently `xfail(strict=True)`
because aligning the RNG exposed two separate local-method backend
divergences beyond this PR's scope: Rust's `compute_weight_matrix`
normalizes time and unit weights to sum to 1, while Python's
`_compute_observation_weights` does not; and the Python fallback's
`_compute_observation_weights(_precomputed branch)` reads the
original-panel cached `Y`/`D` instead of the bootstrap-sample
arguments. Both are tracked as follow-up rows in `TODO.md` with
file:line pointers and will land in a separate methodology PR.

Closes the bootstrap half of silent-failures audit finding #23 (the
grid-search half closed in PR #348). Reference: Athey, Imbens, Qu &
Viviano (2025), "Triply Robust Panel Estimators", Algorithm 3.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
igerber added a commit that referenced this pull request Apr 24, 2026
…+ Python cache-fallthrough

Closes the local-method half of silent-failures audit finding #23
(RNG half closed in PR #354; grid-search half in PR #348). Two
methodology fixes, both isolated to the local-method path — global
is unaffected.

1. Rust weight-matrix normalization removed
   ------------------------------------------
   `rust/src/trop.rs::compute_weight_matrix` no longer divides
   `time_weights` and `unit_weights` by their respective sums before
   the outer product. The paper's Equation 2/3 (Athey, Imbens, Qu,
   Viviano 2025) and REGISTRY.md Requirements checklist
   (`[x] Unit weights: exp(-λ_unit × distance) (unnormalized,
   matching Eq. 2)`) both specify raw-exponential weights; Python's
   `_compute_observation_weights` was already REGISTRY-compliant.
   Rust's normalization inflated the effective nuclear-norm penalty
   relative to the data-fit term, changing the regularization
   trade-off. User-visible effect: Rust local-method ATT values may
   shift for fits with `lambda_nn < infinity`. For
   `lambda_nn = infinity` (factor model disabled) outputs are
   unchanged — uniform weight scaling leaves the minimum-norm WLS
   argmin invariant. Rust LOOCV-selected lambdas may also shift on
   that boundary; both backends now converge on the same selection.
   Affects both local-method Rust call sites (LOOCV at trop.rs:459,
   bootstrap at trop.rs:1096).

2. Python `_compute_observation_weights` cache-fallthrough removed
   ---------------------------------------------------------------
   Removed the `if self._precomputed is not None:` branch that
   silently substituted `self._precomputed["Y"]` / `["D"]` /
   `["time_dist_matrix"]` (original-panel cache populated during
   main fit) for the function-argument `Y, D`. Under bootstrap,
   `_fit_with_fixed_lambda` computes fresh `Y, D` from the resampled
   `boot_data` and passes them in; the helper was discarding those
   and recomputing unit distances from the original panel, so
   Python's local bootstrap resampled units but reused stale
   unit-distance weights. Rust's bootstrap was already correct
   (always consumed `y_boot, d_boot`).
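
For item 1 above, the invariance claim at `lambda_nn = infinity` can be checked numerically: normalizing the time and unit weights to sum to 1 rescales the whole outer-product weight matrix by one constant, and a uniform rescaling does not move the WLS solution. The weights, design, and helper name `wls` below are made up for illustration and are not taken from the codebase.

```python
import numpy as np

rng = np.random.default_rng(0)
n_periods, n_units = 5, 4

# Raw exponential weights (unnormalized), with arbitrary decay rates and distances.
time_w = np.exp(-0.3 * np.arange(n_periods)[::-1])
unit_w = np.exp(-0.5 * rng.random(n_units))

W_raw = np.outer(time_w, unit_w)
W_norm = np.outer(time_w / time_w.sum(), unit_w / unit_w.sum())
scale = time_w.sum() * unit_w.sum()      # W_norm equals W_raw / scale

X = np.column_stack([np.ones(n_periods * n_units),
                     rng.normal(size=(n_periods * n_units, 2))])
y = rng.normal(size=n_periods * n_units)

def wls(X, y, w):
    """Weighted least squares via the sqrt-weight transform."""
    sqrt_w = np.sqrt(w)
    return np.linalg.lstsq(X * sqrt_w[:, None], y * sqrt_w, rcond=None)[0]

beta_raw = wls(X, y, W_raw.ravel())
beta_norm = wls(X, y, W_norm.ravel())
print(np.allclose(beta_raw, beta_norm))
```

Once a penalty term is added the invariance is lost: minimizing the data-fit term with weights `W / scale` plus `lambda * penalty` is equivalent to minimizing the data-fit term with weights `W` plus `scale * lambda * penalty`, so normalization changes the effective penalty strength.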

Test changes
------------
- `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::
  test_bootstrap_seed_reproducibility_local`: flipped from
  `xfail(strict=True)` to passing `assert_allclose` at `atol=1e-5`
  across seeds `[0, 42, 12345]`. Residual ~1e-7 gap is Rust
  `estimate_model` vs numpy `lstsq` roundoff that accumulates
  differently across per-replicate bootstrap fits; follow-up TODO
  row tracks unifying Rust to the `solve_wls_svd` path (same SVD
  helper the global-method uses since PR #348) for sub-1e-14
  parity.
- New `test_local_method_main_fit_parity`: parametrized over
  `(lambda_nn=inf, atol=1e-14)` and `(lambda_nn=0.1, atol=1e-10)`.
  The `lambda_nn=inf` case asserts main-fit ATT parity at
  `atol=1e-14` (the regression guard for the normalization fix);
  the `lambda_nn=0.1` case covers the finite-`lambda_nn` FISTA path
  at `atol=1e-10`.

Verification
------------
Targeted regression sweep — all green:
- 9 `TestTROPRustEdgeCaseParity` tests (grid-search + global
  bootstrap × 3 seeds + local bootstrap × 3 seeds + local main-fit
  × 2 regimes)
- Full `test_rust_backend.py` suite: 92 passed
- Full `test_trop.py` suite under Rust backend: 120 passed

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
igerber added a commit that referenced this pull request Apr 24, 2026
P3 — the PR #354 [Unreleased] Fixed entry (line 21) said local-method
bit-identity SE remained blocked by the Rust-normalization and Python
cache-fallthrough divergences and was "tracked as a follow-up in
TODO.md." With the two TROP-local Fixed entries that this PR adds
(lines 22-27) closing exactly those divergences, the PR #354 tail
sentence is now internally inconsistent with the surrounding entries.
Rewritten to say the RNG half of finding #23 is closed here (bootstrap
contract), grid-search half was closed in PR #348, and the local-
method methodology half is closed by the two Fixed entries that
follow in the same release.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>