diff --git a/src/dualip/preprocessing/precondition.py b/src/dualip/preprocessing/precondition.py
index e46ca4b..fea5403 100644
--- a/src/dualip/preprocessing/precondition.py
+++ b/src/dualip/preprocessing/precondition.py
@@ -1,6 +1,8 @@
 from pathlib import Path
+from typing import Optional
 
 import torch
+import torch.distributed as dist
 
 from dualip.utils.sparse_utils import left_multiply_sparse, row_norms_csc
 
@@ -28,6 +30,66 @@ def jacobi_precondition(A: torch.sparse_csc_tensor, b: torch.Tensor, norms_save_
     return row_norms
 
 
+def dist_jacobi_precondition(
+    A_local: torch.Tensor,
+    b: Optional[torch.Tensor],
+    norms_save_path: Optional[str] = None,
+) -> torch.Tensor:
+    """
+    Distributed Jacobi preconditioning for a column-partitioned matrix.
+
+    Each rank holds a column slice ``A_local`` of shape ``(m, n_local)``.
+    The full row L2-norms are reconstructed via an ``all_reduce(SUM)`` over
+    the per-rank partial sums of squared values, then each rank scales its
+    local slice in-place. Rank 0 also scales ``b`` when provided.
+
+    Requires ``torch.distributed`` to be initialized before calling.
+
+    Parameters
+    ----------
+    A_local : torch.sparse_csc_tensor
+        Local column slice of the constraint matrix, shape ``(m, n_local)``.
+        Modified in-place.
+    b : torch.Tensor or None
+        Right-hand-side vector of length ``m``. Pass the actual tensor on
+        rank 0 and ``None`` on all other ranks. Modified in-place when given.
+    norms_save_path : str, optional
+        If provided, rank 0 saves the row-norm tensor to this path so it can
+        be used later with :func:`jacobi_invert_precondition`.
+
+    Returns
+    -------
+    torch.Tensor
+        Dense 1-D tensor of length ``m`` with the full (global) row L2-norms.
+    """
+    if not dist.is_initialized():
+        raise RuntimeError(
+            "dist_jacobi_precondition requires torch.distributed to be initialized. "
+            "Call torch.distributed.init_process_group() before using this function."
+        )
+
+    n_rows = A_local.size(0)
+    row_idx = A_local.row_indices()
+    vals = A_local.values()
+
+    local_sq = torch.zeros(n_rows, dtype=vals.dtype, device=vals.device)
+    local_sq.scatter_add_(0, row_idx.to(torch.long), vals.pow(2))
+
+    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
+    row_norms = local_sq.pow(0.5)
+
+    if dist.get_rank() == 0 and norms_save_path:
+        torch.save(row_norms, Path(norms_save_path))
+
+    reciprocal = 1.0 / row_norms
+    left_multiply_sparse(reciprocal, A_local, A_local)
+
+    if b is not None:
+        b.mul_(reciprocal)
+
+    return row_norms
+
+
 def jacobi_invert_precondition(dual_val: torch.Tensor, norms_path_or_tensor: str | torch.Tensor):
     """
     Reverse the Jacobi pre-conditioning using row-norms saved on disk.
diff --git a/tests/preprocessing/test_precondition.py b/tests/preprocessing/test_precondition.py
index c40dac0..c218dff 100644
--- a/tests/preprocessing/test_precondition.py
+++ b/tests/preprocessing/test_precondition.py
@@ -1,9 +1,12 @@
+import os
 from pathlib import Path
 
 import pytest
 import torch
+import torch.distributed as dist
 
-from dualip.preprocessing.precondition import jacobi_invert_precondition, jacobi_precondition
+from dualip.preprocessing.precondition import dist_jacobi_precondition, jacobi_invert_precondition, jacobi_precondition
+from dualip.utils.sparse_utils import split_csc_by_cols
 
 ccol_indices = [0, 2, 3, 5, 8, 10, 12, 15, 16]
 row_indices = [2, 3, 3, 1, 2, 0, 1, 2, 0, 2, 0, 3, 1, 2, 3, 2]
@@ -93,3 +96,103 @@ def test_invert_precondition(norms_path):
     reciprocal = 1.0 / row_norms
 
     assert torch.allclose(restored, reciprocal)
+
+
+# ---------------------------------------------------------------------------
+# Distributed Jacobi preconditioner tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.skipif("RANK" in os.environ, reason="Only runs in non-distributed (no torchrun) context")
+def test_dist_jacobi_precondition_raises_without_dist():
+    """dist_jacobi_precondition must raise when torch.distributed is not initialized."""
+    assert not dist.is_initialized(), "Expected dist to be uninitialized in this test"
+    with pytest.raises(RuntimeError, match="torch.distributed"):
+        dist_jacobi_precondition(A_test.clone(), b_test.clone())
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA GPUs")
+@pytest.mark.skipif(
+    "RANK" not in os.environ,
+    reason="Requires torchrun - run with: torchrun --nproc_per_node=2 -m pytest ...",
+)
+def test_dist_jacobi_precondition():
+    """
+    Distributed preconditioning gives the same row norms, A scaling, and b
+    scaling as the single-process version when data is split across 2 ranks.
+    """
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+
+    try:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        device = f"cuda:{rank}"
+
+        A_full = A_test.to(device)
+        b_full = b_test.to(device)
+
+        expected_row_norms = A_full.to_dense().norm(2, dim=1)
+
+        n_cols = A_full.size(1)
+        split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)]
+        a_splits = split_csc_by_cols(A_full, split_sizes)
+
+        A_local = a_splits[rank].to(device)
+        original_vals = A_local.values().clone()
+        original_row_idx = A_local.row_indices().clone()
+
+        b_local = b_full.clone() if rank == 0 else None
+
+        row_norms = dist_jacobi_precondition(A_local, b_local)
+
+        assert torch.allclose(row_norms, expected_row_norms, atol=1e-5), f"Rank {rank}: row norms mismatch"
+
+        expected_vals = original_vals / expected_row_norms[original_row_idx.to(torch.long)]
+        assert torch.allclose(
+            A_local.values(), expected_vals, atol=1e-5
+        ), f"Rank {rank}: A_local values not scaled correctly"
+
+        if rank == 0:
+            expected_b = b_full / expected_row_norms
+            assert torch.allclose(b_local, expected_b, atol=1e-5), "Rank 0: b not scaled correctly"
+
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA GPUs")
+@pytest.mark.skipif(
+    "RANK" not in os.environ,
+    reason="Requires torchrun - run with: torchrun --nproc_per_node=2 -m pytest ...",
+)
+def test_dist_jacobi_precondition_saves_norms(tmp_path):
+    """Rank 0 saves row norms to disk; the saved norms match the global row norms."""
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+
+    try:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        device = f"cuda:{rank}"
+
+        norms_path = str(tmp_path / "row_norms.pt") if rank == 0 else None
+
+        A_full = A_test.to(device)
+        n_cols = A_full.size(1)
+        split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)]
+        a_splits = split_csc_by_cols(A_full, split_sizes)
+        A_local = a_splits[rank].to(device)
+        b_local = b_test.to(device).clone() if rank == 0 else None
+
+        row_norms = dist_jacobi_precondition(A_local, b_local, norms_save_path=norms_path)
+
+        if rank == 0:
+            assert Path(norms_path).exists(), "Rank 0 did not save the norm file"
+            saved = torch.load(norms_path, map_location=device)
+            assert torch.allclose(saved, row_norms, atol=1e-5), "Saved norms differ from returned row norms"
+
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()