Merged
2 changes: 1 addition & 1 deletion TODO.md
@@ -83,7 +83,7 @@ Deferred items from PR reviews that were not addressed before merge.
| Weighted CR2 Bell-McCaffrey cluster-robust (`vcov_type="hc2_bm"` + `cluster_ids` + `weights`) currently raises `NotImplementedError`. Weighted hat matrix and residual rebalancing need threading per clubSandwich WLS handling. | `linalg.py::_compute_cr2_bm` | Phase 1a | Medium |
| Regenerate `benchmarks/data/clubsandwich_cr2_golden.json` from R (`Rscript benchmarks/R/generate_clubsandwich_golden.R`). Current JSON has `source: python_self_reference` as a stability anchor until an authoritative R run. | `benchmarks/R/generate_clubsandwich_golden.R` | Phase 1a | Medium |
| `honest_did.py:1907` `np.linalg.solve(A_sys, b_sys) / except LinAlgError: continue` is a silent basis-rejection in the vertex-enumeration loop that is algorithmically intentional (try the next basis). Consider surfacing a count of rejected bases as a diagnostic when ARP enumeration is exhausted, so users see when the vertex search was heavily constrained. Not a silent failure in the sense of the Phase 2 audit (the algorithm is supposed to skip), but the diagnostic would help debug borderline cases. | `honest_did.py` | #334 | Low |
| `compute_synthetic_weights` backend algorithm mismatch: Rust path uses Frank-Wolfe (`_rust_synthetic_weights` in `utils.py:1184`); Python fallback uses projected gradient descent (`_compute_synthetic_weights_numpy` in `utils.py:1228`). Both solve the same constrained QP but converge to different simplex vertices on near-degenerate / extreme-scale inputs (e.g. `Y~1e9`, or near-singular `Y'Y`). Unified backend (one algorithm) would close the parity gap surfaced by audit finding #22. Two `@pytest.mark.xfail(strict=True)` tests in `tests/test_rust_backend.py::TestSyntheticWeightsBackendParity` baseline the divergence so we notice when/if the algorithms align. | `utils.py`, `rust/` | follow-up | Medium |
| Rust `compute_synthetic_weights` + `compute_synthetic_weights_internal` (now dead code) can be removed from `rust/src/weights.rs:43-117` in a future Rust-cleanup PR. Python-side wrapper was deleted (post-audit cleanup for finding #22) and its sole caller now inlines Frank-Wolfe via `_sc_weight_fw`. The Rust symbol remains callable via `from diff_diff._rust_backend import compute_synthetic_weights` but no Python code calls it. Removal requires `maturin develop` rebuild. No functional impact of leaving it. | `rust/src/weights.rs` | follow-up | Low |
| TROP Rust vs Python grid-search divergence on rank-deficient Y: on two near-parallel control units, LOOCV grid-search ATT diverges ~6% between Rust (`trop_global.py:688`) and Python fallback (`trop_global.py:753`). Either grid-winner ties are broken differently or the per-λ solver reaches different stationary points under rank deficiency. Audit finding #23 flagged this surface. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_grid_search_rank_deficient_Y` baselines the gap. | `trop_global.py`, `rust/` | follow-up | Medium |
| TROP Rust vs Python bootstrap SE divergence under fixed seed: `seed=42` on a tiny panel produces ~28% bootstrap-SE gap. Root cause: Rust bootstrap uses its own RNG (`rand` crate) while Python uses `numpy.random.default_rng`; same seed value maps to different bytestreams across backends. Audit axis-H (RNG/seed) adjacent. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_bootstrap_seed_reproducibility` baselines the gap. Unifying RNG (threading a numpy-generated seed-sequence into Rust, or porting Python to ChaCha) would close it; a seed-derivation sketch follows this table. | `trop_global.py`, `rust/` | follow-up | Medium |

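For the RNG-unification row above, a minimal Python-side sketch of one option: derive per-stream integer seeds from a single user-facing seed via `numpy.random.SeedSequence` and hand those across the FFI boundary. Passing derived seeds to the Rust backend is an assumption here, not an API this PR adds.

```python
import numpy as np

def derive_backend_seeds(seed: int, n_streams: int) -> list[int]:
    # Spawn independent child sequences from one user-facing seed, then
    # reduce each child to a u64 that a Rust-side RNG could be seeded with.
    ss = np.random.SeedSequence(seed)
    return [int(child.generate_state(1, dtype=np.uint64)[0])
            for child in ss.spawn(n_streams)]

# Both backends would then consume identical integers (e.g. one per
# bootstrap replicate) instead of interpreting `seed=42` independently.
seeds = derive_backend_seeds(42, n_streams=1000)
```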
1 change: 0 additions & 1 deletion diff_diff/__init__.py
@@ -21,7 +21,6 @@
_rust_compute_robust_vcov,
_rust_project_simplex,
_rust_solve_ols,
_rust_synthetic_weights,
)

from diff_diff.bacon import (
4 changes: 0 additions & 4 deletions diff_diff/_backend.py
@@ -19,7 +19,6 @@
try:
from diff_diff._rust_backend import (
generate_bootstrap_weights_batch as _rust_bootstrap_weights,
compute_synthetic_weights as _rust_synthetic_weights,
project_simplex as _rust_project_simplex,
solve_ols as _rust_solve_ols,
compute_robust_vcov as _rust_compute_robust_vcov,
@@ -43,7 +42,6 @@
except ImportError:
_rust_available = False
_rust_bootstrap_weights = None
_rust_synthetic_weights = None
_rust_project_simplex = None
_rust_solve_ols = None
_rust_compute_robust_vcov = None
@@ -66,7 +64,6 @@
# Force pure Python mode - disable Rust even if available
HAS_RUST_BACKEND = False
_rust_bootstrap_weights = None
_rust_synthetic_weights = None
_rust_project_simplex = None
_rust_solve_ols = None
_rust_compute_robust_vcov = None
@@ -115,7 +112,6 @@ def rust_backend_info():
"HAS_RUST_BACKEND",
"rust_backend_info",
"_rust_bootstrap_weights",
"_rust_synthetic_weights",
"_rust_project_simplex",
"_rust_solve_ols",
"_rust_compute_robust_vcov",
78 changes: 74 additions & 4 deletions diff_diff/prep.py
@@ -9,6 +9,7 @@
re-exported here for backward compatibility.
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -36,7 +37,7 @@
compute_replicate_if_variance,
compute_survey_if_variance,
)
from diff_diff.utils import compute_synthetic_weights
from diff_diff.utils import _compute_noise_level, _sc_weight_fw

# Constants for rank_control_units
_SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar"
@@ -837,7 +838,10 @@ def rank_control_units(
- quality_score: Combined quality score (0-1, higher is better)
- outcome_trend_score: Pre-treatment outcome trend similarity
- covariate_score: Covariate match score (NaN if no covariates)
- synthetic_weight: Weight from synthetic control optimization
- synthetic_weight: Informational heuristic weight from a single-pass
uncentered Frank-Wolfe solve; does NOT factor into ``quality_score``
(ranking) and is NOT the canonical SDID unit weight. For canonical
SDID weights use ``SyntheticDiD.fit()``.
- pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
- is_required: Whether unit was in require_units

@@ -989,8 +993,74 @@
# -------------------------------------------------------------------------
# Compute outcome trend scores
# -------------------------------------------------------------------------
# Synthetic weights (higher = better match)
synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg)
# Informational `synthetic_weight` column. This is a RANKING HEURISTIC,
# not an estimator: it gives a rough "which controls would a synthetic
# regression weight heavily" signal that's reported alongside RMSE and
# covariate distance. The actual ranking (`quality_score`) is computed
# below from `outcome_trend_score` (RMSE-based) + `covariate_score`; the
# `synthetic_weight` column does NOT factor into the ranking decision.
#
# Solver choice. We use a single-pass uncentered Frank-Wolfe via the
# shared `_sc_weight_fw` dispatcher to solve:
#
# min_w ||Y_treated_mean - Y_control @ w||^2 + lambda_reg * ||w||^2
# s.t. w >= 0, sum(w) = 1
#
# Mapped to the FW objective `zeta^2 ||w||^2 + (1/N) ||Aw - b||^2` via
# `zeta = sqrt(lambda_reg / N)`. intercept=False because this QP does
# no column-centering, max_iter=1000 to bound ranking-loop cost,
# min_weight=1e-6 post-processing for interpretability.
#
# NOTE — this is INTENTIONALLY NOT the canonical SDID / R
# `synthdid::sc.weight.fw` two-pass unit-weight procedure (that uses
# intercept=TRUE, 100-iter -> sparsify -> 10000-iter). SDID estimation
# still uses that canonical path (`_sc_weight_fw_numpy` in
# `utils.py`) via `compute_sdid_unit_weights`; this
# ranking heuristic uses a simpler single-pass call to the same solver
# for a cheap diagnostic score.
#
# Replaces the former `compute_synthetic_weights` wrapper, whose Rust
# (Frank-Wolfe) and Python (projected-gradient) backends diverged
# (audit finding #22). Net effect: users on default `lambda_reg=0` with
# typical data see `synthetic_weight` values that agree with the old
# code to ~1e-7; extreme Y or `lambda_reg > 0` cases produce values
# that differ from the old code (which was mathematically wrong).
_Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
_Y_treated_mean = np.ascontiguousarray(Y_treated_mean, dtype=np.float64)
_n_pre, _n_control = _Y_control.shape
if _n_control == 0:
synthetic_weights = np.array([], dtype=np.float64)
elif _n_control == 1:
synthetic_weights = np.array([1.0])
else:
_zeta = float(np.sqrt(lambda_reg / _n_pre)) if lambda_reg > 0 else 0.0
# Scale stopping threshold by noise level so convergence stays
# meaningful at any data magnitude.
_sigma = _compute_noise_level(_Y_control)
_min_decrease = 1e-5 * max(_sigma, 1e-12)
_Y_fw = np.column_stack([_Y_control, _Y_treated_mean])
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=r".*did not converge.*",
category=UserWarning,
)
synthetic_weights = _sc_weight_fw(
_Y_fw,
zeta=_zeta,
intercept=False,
min_decrease=_min_decrease,
max_iter=1000,
)
# Set small weights to zero for interpretability, then renormalize.
synthetic_weights = np.asarray(synthetic_weights, dtype=np.float64)
_min_weight = 1e-6
synthetic_weights[synthetic_weights < _min_weight] = 0.0
_total = float(np.sum(synthetic_weights))
if _total > 0:
synthetic_weights = synthetic_weights / _total
else:
synthetic_weights = np.ones(_n_control) / _n_control

# RMSE for each control vs treated mean (use nanmean to handle missing data)
rmse_scores = []
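A note on the objective mapping in the comment block above: dividing the QP by N = n_pre gives (1/N)||Aw - b||^2 + (lambda_reg/N)||w||^2, which is exactly the FW form with zeta = sqrt(lambda_reg / N), as the code computes. As a cross-check, a minimal sketch (not part of the PR; assumes SciPy, which this diff does not introduce as a dependency; `reference_qp_weights` is a hypothetical name) that solves the same simplex-constrained QP directly:

```python
import numpy as np
from scipy.optimize import minimize

def reference_qp_weights(Y_control, y_target, lambda_reg=0.0):
    # Directly minimize ||y_target - Y_control @ w||^2 + lambda_reg * ||w||^2
    # over the probability simplex -- the QP the inlined Frank-Wolfe targets.
    n = Y_control.shape[1]
    objective = lambda w: (
        np.sum((y_target - Y_control @ w) ** 2) + lambda_reg * np.sum(w**2)
    )
    result = minimize(
        objective,
        x0=np.full(n, 1.0 / n),           # start from uniform weights
        method="SLSQP",
        bounds=[(0.0, 1.0)] * n,          # w >= 0 (and trivially <= 1)
        constraints=[{"type": "eq", "fun": lambda w: np.sum(w) - 1.0}],
    )
    return result.x
```

On well-conditioned inputs this should agree with the `synthetic_weight` column to roughly the 1e-6 sparsification threshold; near-degenerate Y can admit multiple optima, which is the parity gap the TODO rows describe.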
118 changes: 8 additions & 110 deletions diff_diff/utils.py
@@ -17,7 +17,6 @@
from diff_diff._backend import (
HAS_RUST_BACKEND,
_rust_project_simplex,
_rust_synthetic_weights,
_rust_sdid_unit_weights,
_rust_compute_time_weights,
_rust_compute_noise_level,
@@ -1131,115 +1130,14 @@ def equivalence_test_trends(
}


def compute_synthetic_weights(
Y_control: np.ndarray, Y_treated: np.ndarray, lambda_reg: float = 0.0, min_weight: float = 1e-6
) -> np.ndarray:
"""
Compute synthetic control unit weights using constrained optimization.

Finds weights ω that minimize the squared difference between the
weighted average of control unit outcomes and the treated unit outcomes
during pre-treatment periods.

Parameters
----------
Y_control : np.ndarray
Control unit outcomes matrix of shape (n_pre_periods, n_control_units).
Each column is a control unit, each row is a pre-treatment period.
Y_treated : np.ndarray
Treated unit mean outcomes of shape (n_pre_periods,).
Average across treated units for each pre-treatment period.
lambda_reg : float, default=0.0
L2 regularization parameter. Larger values shrink weights toward
uniform (1/n_control). Helps prevent overfitting when n_pre < n_control.
min_weight : float, default=1e-6
Minimum weight threshold. Weights below this are set to zero.

Returns
-------
np.ndarray
Unit weights of shape (n_control_units,) that sum to 1.

Notes
-----
Solves the quadratic program:

min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
s.t. ω >= 0, sum(ω) = 1

Uses a simplified coordinate descent approach with projection onto simplex.
"""
n_pre, n_control = Y_control.shape

if n_control == 0:
return np.asarray([])

if n_control == 1:
return np.asarray([1.0])

# Use Rust backend if available
if HAS_RUST_BACKEND:
Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64)
weights = _rust_synthetic_weights(
Y_control, Y_treated, lambda_reg, _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL
)
else:
# Fallback to NumPy implementation
weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)

# Set small weights to zero for interpretability
weights[weights < min_weight] = 0
if np.sum(weights) > 0:
weights = weights / np.sum(weights)
else:
# Fallback to uniform if all weights are zeroed
weights = np.ones(n_control) / n_control

return np.asarray(weights)


def _compute_synthetic_weights_numpy(
Y_control: np.ndarray,
Y_treated: np.ndarray,
lambda_reg: float = 0.0,
) -> np.ndarray:
"""NumPy fallback implementation of compute_synthetic_weights."""
n_pre, n_control = Y_control.shape

# Initialize with uniform weights
weights = np.ones(n_control) / n_control

# Precompute matrices for optimization
# Objective: ||Y_treated - Y_control @ w||^2 + lambda * ||w - w_uniform||^2
# = w' @ (Y_control' @ Y_control + lambda * I) @ w - 2 * (Y_control' @ Y_treated + lambda * w_uniform)' @ w + const
YtY = Y_control.T @ Y_control
YtT = Y_control.T @ Y_treated
w_uniform = np.ones(n_control) / n_control

# Add regularization
H = YtY + lambda_reg * np.eye(n_control)
f = YtT + lambda_reg * w_uniform

# Solve with projected gradient descent
# Project onto probability simplex
step_size = 1.0 / (np.linalg.norm(H, 2) + _NUMERICAL_EPS)

for _ in range(_OPTIMIZATION_MAX_ITER):
weights_old = weights.copy()

# Gradient step: minimize ||Y - Y_control @ w||^2
grad = H @ weights - f
weights = weights - step_size * grad

# Project onto simplex (sum to 1, non-negative)
weights = _project_simplex(weights)

# Check convergence
if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL:
break

return weights
# compute_synthetic_weights and _compute_synthetic_weights_numpy removed in the
# silent-failures audit post-cleanup (finding #22). The one caller
# (`diff_diff.prep.rank_control_units`) inlines a single-pass, uncentered
# Frank-Wolfe via the shared `_sc_weight_fw` dispatcher — a ranking heuristic,
# NOT the canonical SDID/R `synthdid::sc.weight.fw` two-pass procedure
# (intercept=True, 100-iter -> sparsify -> 10000-iter). Canonical SDID unit
# weights go through `compute_sdid_unit_weights` (see `_sc_weight_fw_numpy`
# below and REGISTRY.md SDID section).


def _project_simplex(v: np.ndarray) -> np.ndarray:
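The deleted PGD fallback above projected onto the probability simplex each iteration; `_project_simplex` itself survives this PR (and `test_project_simplex` in the tests hunk below still covers it). For readers without the source, a self-contained sketch of the standard sort-based projection it plausibly implements (an assumption; the in-repo version may differ in details):

```python
import numpy as np

def project_simplex(v: np.ndarray) -> np.ndarray:
    # Euclidean projection onto {w : w >= 0, sum(w) = 1} via the
    # sort-and-threshold construction of Duchi et al. (2008).
    u = np.sort(v)[::-1]                     # sort descending
    css = np.cumsum(u)
    idx = np.arange(1, v.size + 1)
    rho = np.nonzero(u + (1.0 - css) / idx > 0)[0][-1]
    theta = (css[rho] - 1.0) / (rho + 1.0)   # uniform shift
    return np.maximum(v - theta, 0.0)
```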
22 changes: 3 additions & 19 deletions tests/test_estimators.py
@@ -3124,25 +3124,9 @@ def test_project_simplex(self):
assert abs(np.sum(projected) - 1.0) < 1e-6
assert np.all(projected >= 0)

def test_compute_synthetic_weights(self):
"""Test synthetic weight computation."""
from diff_diff.utils import compute_synthetic_weights

np.random.seed(42)
n_pre = 5
n_control = 10

Y_control = np.random.randn(n_pre, n_control)
Y_treated = np.random.randn(n_pre)

weights = compute_synthetic_weights(Y_control, Y_treated)

# Weights should sum to 1
assert abs(np.sum(weights) - 1.0) < 1e-6
# Weights should be non-negative
assert np.all(weights >= 0)
# Should have correct length
assert len(weights) == n_control
# test_compute_synthetic_weights removed in the silent-failures audit
# post-cleanup (finding #22). Helper deleted; behavior now covered via
# tests/test_prep.py::TestRankControlUnits (its sole caller).

def test_compute_time_weights(self):
"""Test time weight computation with Frank-Wolfe solver."""