From 72d87a97694d9074269c8b217b2b31c213b62815 Mon Sep 17 00:00:00 2001 From: Sanjana Garg Date: Mon, 27 Apr 2026 12:32:06 -0700 Subject: [PATCH 1/3] Add distributed Jacobi preconditioning dist_jacobi_precondition reconstructs full row L2-norms from a column-partitioned matrix via all_reduce(SUM) over per-rank partial sums of squares, then scales each rank's local slice in place. Rank 0 optionally scales b and persists the row norms. Includes a no-dist guard test plus two torchrun-gated tests covering correctness of the global norms / A scaling / b scaling and the norm-save path. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/dualip/preprocessing/precondition.py | 62 ++++++++++++ tests/preprocessing/test_precondition.py | 120 ++++++++++++++++++++++- 2 files changed, 181 insertions(+), 1 deletion(-) diff --git a/src/dualip/preprocessing/precondition.py b/src/dualip/preprocessing/precondition.py index e46ca4b..fea5403 100644 --- a/src/dualip/preprocessing/precondition.py +++ b/src/dualip/preprocessing/precondition.py @@ -1,6 +1,8 @@ from pathlib import Path +from typing import Optional import torch +import torch.distributed as dist from dualip.utils.sparse_utils import left_multiply_sparse, row_norms_csc @@ -28,6 +30,66 @@ def jacobi_precondition(A: torch.sparse_csc_tensor, b: torch.Tensor, norms_save_ return row_norms +def dist_jacobi_precondition( + A_local: torch.Tensor, + b: Optional[torch.Tensor], + norms_save_path: Optional[str] = None, +) -> torch.Tensor: + """ + Distributed Jacobi preconditioning for a column-partitioned matrix. + + Each rank holds a column slice ``A_local`` of shape ``(m, n_local)``. + The full row L2-norms are reconstructed via an ``all_reduce(SUM)`` over + the per-rank partial sums of squared values, then each rank scales its + local slice in-place. Rank 0 also scales ``b`` when provided. + + Requires ``torch.distributed`` to be initialized before calling. + + Parameters + ---------- + A_local : torch.sparse_csc_tensor + Local column slice of the constraint matrix, shape ``(m, n_local)``. + Modified in-place. + b : torch.Tensor or None + Right-hand-side vector of length ``m``. Pass the actual tensor on + rank 0 and ``None`` on all other ranks. Modified in-place when given. + norms_save_path : str, optional + If provided, rank 0 saves the row-norm tensor to this path so it can + be used later with :func:`jacobi_invert_precondition`. + + Returns + ------- + torch.Tensor + Dense 1-D tensor of length ``m`` with the full (global) row L2-norms. + """ + if not dist.is_initialized(): + raise RuntimeError( + "dist_jacobi_precondition requires torch.distributed to be initialized. " + "Call torch.distributed.init_process_group() before using this function." + ) + + n_rows = A_local.size(0) + row_idx = A_local.row_indices() + vals = A_local.values() + + local_sq = torch.zeros(n_rows, dtype=vals.dtype, device=vals.device) + local_sq.scatter_add_(0, row_idx.to(torch.long), vals.pow(2)) + + dist.all_reduce(local_sq, op=dist.ReduceOp.SUM) + row_norms = local_sq.pow(0.5) + + if dist.get_rank() == 0 and norms_save_path: + torch.save(row_norms, Path(norms_save_path)) + + reciprocal = 1.0 / row_norms + left_multiply_sparse(reciprocal, A_local, A_local) + + if b is not None: + b.mul_(reciprocal) + + return row_norms + + def jacobi_invert_precondition(dual_val: torch.Tensor, norms_path_or_tensor: str | torch.Tensor): """ Reverse the Jacobi pre-conditioning using row-norms saved on disk. diff --git a/tests/preprocessing/test_precondition.py b/tests/preprocessing/test_precondition.py index c40dac0..89f8f51 100644 --- a/tests/preprocessing/test_precondition.py +++ b/tests/preprocessing/test_precondition.py @@ -1,9 +1,16 @@ +import os from pathlib import Path import pytest import torch +import torch.distributed as dist -from dualip.preprocessing.precondition import jacobi_invert_precondition, jacobi_precondition +from dualip.preprocessing.precondition import ( + dist_jacobi_precondition, + jacobi_invert_precondition, + jacobi_precondition, +) +from dualip.utils.sparse_utils import split_csc_by_cols ccol_indices = [0, 2, 3, 5, 8, 10, 12, 15, 16] row_indices = [2, 3, 3, 1, 2, 0, 1, 2, 0, 2, 0, 3, 1, 2, 3, 2] @@ -93,3 +100,114 @@ def test_invert_precondition(norms_path): reciprocal = 1.0 / row_norms assert torch.allclose(restored, reciprocal) + + +# --------------------------------------------------------------------------- +# Distributed Jacobi preconditioner tests +# --------------------------------------------------------------------------- + +@pytest.mark.skipif("RANK" in os.environ, reason="Only runs in non-distributed (no torchrun) context") +def test_dist_jacobi_precondition_raises_without_dist(): + """dist_jacobi_precondition must raise when torch.distributed is not initialized.""" + assert not dist.is_initialized(), "Expected dist to be uninitialized in this test" + with pytest.raises(RuntimeError, match="torch.distributed"): + dist_jacobi_precondition(A_test.clone(), b_test.clone()) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA GPUs") +@pytest.mark.skipif( + "RANK" not in os.environ, + reason="Requires torchrun - run with: torchrun --nproc_per_node=2 -m pytest ...", +) +def test_dist_jacobi_precondition(): + """ + Distributed preconditioning gives the same row norms, A scaling, and b + scaling as the single-process version when data is split across 2 ranks. + """ + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + device = f"cuda:{rank}" + + A_full = A_test.to(device) + b_full = b_test.to(device) + + expected_row_norms = A_full.to_dense().norm(2, dim=1) + + n_cols = A_full.size(1) + split_sizes = [ + n_cols // world_size + (1 if i < n_cols % world_size else 0) + for i in range(world_size) + ] + a_splits = split_csc_by_cols(A_full, split_sizes) + + A_local = a_splits[rank].to(device) + original_vals = A_local.values().clone() + original_row_idx = A_local.row_indices().clone() + + b_local = b_full.clone() if rank == 0 else None + + row_norms = dist_jacobi_precondition(A_local, b_local) + + assert torch.allclose(row_norms, expected_row_norms, atol=1e-5), ( + f"Rank {rank}: row norms mismatch" + ) + + expected_vals = original_vals / expected_row_norms[original_row_idx.to(torch.long)] + assert torch.allclose(A_local.values(), expected_vals, atol=1e-5), ( + f"Rank {rank}: A_local values not scaled correctly" + ) + + if rank == 0: + expected_b = b_full / expected_row_norms + assert torch.allclose(b_local, expected_b, atol=1e-5), ( + "Rank 0: b not scaled correctly" + ) + + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA GPUs") +@pytest.mark.skipif( + "RANK" not in os.environ, + reason="Requires torchrun - run with: torchrun --nproc_per_node=2 -m pytest ...", +) +def test_dist_jacobi_precondition_saves_norms(tmp_path): + """Rank 0 saves row norms to disk; the saved norms match the global row norms.""" + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + device = f"cuda:{rank}" + + norms_path = str(tmp_path / "row_norms.pt") if rank == 0 else None + + A_full = A_test.to(device) + n_cols = A_full.size(1) + split_sizes = [ + n_cols // world_size + (1 if i < n_cols % world_size else 0) + for i in range(world_size) + ] + a_splits = split_csc_by_cols(A_full, split_sizes) + A_local = a_splits[rank].to(device) + b_local = b_test.to(device).clone() if rank == 0 else None + + row_norms = dist_jacobi_precondition(A_local, b_local, norms_save_path=norms_path) + + if rank == 0: + assert Path(norms_path).exists(), "Rank 0 did not save the norm file" + saved = torch.load(norms_path, map_location=device) + assert torch.allclose(saved, row_norms, atol=1e-5), ( + "Saved norms differ from returned row norms" + ) + + finally: + if dist.is_initialized(): + dist.destroy_process_group() From 667fe84c31d04469cbe8243d2a030fb0f9fb4824 Mon Sep 17 00:00:00 2001 From: Sanjana Garg Date: Mon, 27 Apr 2026 14:51:27 -0700 Subject: [PATCH 2/3] Apply black formatting Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/preprocessing/test_precondition.py | 29 ++++++++---------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/tests/preprocessing/test_precondition.py b/tests/preprocessing/test_precondition.py index 89f8f51..91c378f 100644 --- a/tests/preprocessing/test_precondition.py +++ b/tests/preprocessing/test_precondition.py @@ -106,6 +106,7 @@ def test_invert_precondition(norms_path): # Distributed Jacobi preconditioner tests # --------------------------------------------------------------------------- + @pytest.mark.skipif("RANK" in os.environ, reason="Only runs in non-distributed (no torchrun) context") def test_dist_jacobi_precondition_raises_without_dist(): """dist_jacobi_precondition must raise when torch.distributed is not initialized.""" @@ -138,10 +139,7 @@ def test_dist_jacobi_precondition(): expected_row_norms = A_full.to_dense().norm(2, dim=1) n_cols = A_full.size(1) - split_sizes = [ - n_cols // world_size + (1 if i < n_cols % world_size else 0) - for i in range(world_size) - ] + split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)] a_splits = split_csc_by_cols(A_full, split_sizes) A_local = a_splits[rank].to(device) @@ -152,20 +150,16 @@ def test_dist_jacobi_precondition(): row_norms = dist_jacobi_precondition(A_local, b_local) - assert torch.allclose(row_norms, expected_row_norms, atol=1e-5), ( - f"Rank {rank}: row norms mismatch" - ) + assert torch.allclose(row_norms, expected_row_norms, atol=1e-5), f"Rank {rank}: row norms mismatch" expected_vals = original_vals / expected_row_norms[original_row_idx.to(torch.long)] - assert torch.allclose(A_local.values(), expected_vals, atol=1e-5), ( - f"Rank {rank}: A_local values not scaled correctly" - ) + assert torch.allclose( + A_local.values(), expected_vals, atol=1e-5 + ), f"Rank {rank}: A_local values not scaled correctly" if rank == 0: expected_b = b_full / expected_row_norms - assert torch.allclose(b_local, expected_b, atol=1e-5), ( - "Rank 0: b not scaled correctly" - ) + assert torch.allclose(b_local, expected_b, atol=1e-5), "Rank 0: b not scaled correctly" finally: if dist.is_initialized(): @@ -191,10 +185,7 @@ def test_dist_jacobi_precondition_saves_norms(tmp_path): A_full = A_test.to(device) n_cols = A_full.size(1) - split_sizes = [ - n_cols // world_size + (1 if i < n_cols % world_size else 0) - for i in range(world_size) - ] + split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)] a_splits = split_csc_by_cols(A_full, split_sizes) A_local = a_splits[rank].to(device) b_local = b_test.to(device).clone() if rank == 0 else None @@ -204,9 +195,7 @@ def test_dist_jacobi_precondition_saves_norms(tmp_path): if rank == 0: assert Path(norms_path).exists(), "Rank 0 did not save the norm file" saved = torch.load(norms_path, map_location=device) - assert torch.allclose(saved, row_norms, atol=1e-5), ( - "Saved norms differ from returned row norms" - ) + assert torch.allclose(saved, row_norms, atol=1e-5), "Saved norms differ from returned row norms" finally: if dist.is_initialized(): From e9db7dfce8abd07fd7c6d3345550672b511d3eb5 Mon Sep 17 00:00:00 2001 From: Sanjana Garg Date: Mon, 27 Apr 2026 15:07:01 -0700 Subject: [PATCH 3/3] Fix isort ordering Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/preprocessing/test_precondition.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/preprocessing/test_precondition.py b/tests/preprocessing/test_precondition.py index 91c378f..c218dff 100644 --- a/tests/preprocessing/test_precondition.py +++ b/tests/preprocessing/test_precondition.py @@ -5,11 +5,7 @@ import torch import torch.distributed as dist -from dualip.preprocessing.precondition import ( - dist_jacobi_precondition, - jacobi_invert_precondition, - jacobi_precondition, -) +from dualip.preprocessing.precondition import dist_jacobi_precondition, jacobi_invert_precondition, jacobi_precondition from dualip.utils.sparse_utils import split_csc_by_cols ccol_indices = [0, 2, 3, 5, 8, 10, 12, 15, 16]