62 changes: 62 additions & 0 deletions src/dualip/preprocessing/precondition.py
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import Optional

import torch
import torch.distributed as dist

from dualip.utils.sparse_utils import left_multiply_sparse, row_norms_csc

@@ -28,6 +30,66 @@ def jacobi_precondition(A: torch.sparse_csc_tensor, b: torch.Tensor, norms_save_
return row_norms


def dist_jacobi_precondition(
    A_local: torch.sparse_csc_tensor,
b: Optional[torch.Tensor],
norms_save_path: Optional[str] = None,
) -> torch.Tensor:
"""
Distributed Jacobi preconditioning for a column-partitioned matrix.

Each rank holds a column slice ``A_local`` of shape ``(m, n_local)``.
The full row L2-norms are reconstructed via an ``all_reduce(SUM)`` over
the per-rank partial sums of squared values, then each rank scales its
local slice in-place. Rank 0 also scales ``b`` when provided.

Requires ``torch.distributed`` to be initialized before calling.

Parameters
----------
A_local : torch.sparse_csc_tensor
Local column slice of the constraint matrix, shape ``(m, n_local)``.
Modified in-place.
b : torch.Tensor or None
Right-hand-side vector of length ``m``. Pass the actual tensor on
rank 0 and ``None`` on all other ranks. Modified in-place when given.
norms_save_path : str, optional
If provided, rank 0 saves the row-norm tensor to this path so it can
be used later with :func:`jacobi_invert_precondition`.

Returns
-------
torch.Tensor
Dense 1-D tensor of length ``m`` with the full (global) row L2-norms.
"""
if not dist.is_initialized():
raise RuntimeError(
"dist_jacobi_precondition requires torch.distributed to be initialized. "
"Call torch.distributed.init_process_group() before using this function."
)

n_rows = A_local.size(0)
row_idx = A_local.row_indices()
vals = A_local.values()

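    # Per-rank partial sums of squared entries, bucketed by row index.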
local_sq = torch.zeros(n_rows, dtype=vals.dtype, device=vals.device)
local_sq.scatter_add_(0, row_idx.to(torch.long), vals.pow(2))

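    # Sum the partial squared norms across ranks so every rank holds the global row norms.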
dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
row_norms = local_sq.pow(0.5)

if dist.get_rank() == 0 and norms_save_path:
torch.save(row_norms, Path(norms_save_path))

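    # Scale the local column slice (and b, when provided) in place.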
reciprocal = 1.0 / row_norms
left_multiply_sparse(reciprocal, A_local, A_local)

if b is not None:
b.mul_(reciprocal)

return row_norms


def jacobi_invert_precondition(dual_val: torch.Tensor, norms_path_or_tensor: str | torch.Tensor):
"""
Reverse the Jacobi pre-conditioning using row-norms saved on disk.
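As context for reviewers (not part of this diff), a minimal end-to-end sketch of the new API: every rank builds the same matrix, takes its column slice with split_csc_by_cols, and only rank 0 passes b and a save path. The gloo backend, the toy 4x8 matrix, and the file name row_norms.pt are illustrative assumptions; launch with torchrun --nproc_per_node=2.

import torch
import torch.distributed as dist

from dualip.preprocessing.precondition import dist_jacobi_precondition
from dualip.utils.sparse_utils import split_csc_by_cols

dist.init_process_group(backend="gloo")  # use "nccl" when each rank owns a GPU
rank, world_size = dist.get_rank(), dist.get_world_size()

# Identical seed on every rank so all ranks construct the same A.
torch.manual_seed(0)
A = torch.rand(4, 8).to_sparse_csc()
b = torch.ones(4) if rank == 0 else None  # b lives on rank 0 only

n_cols = A.size(1)
split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)]
A_local = split_csc_by_cols(A, split_sizes)[rank]

row_norms = dist_jacobi_precondition(A_local, b, norms_save_path="row_norms.pt" if rank == 0 else None)
dist.destroy_process_group()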
105 changes: 104 additions & 1 deletion tests/preprocessing/test_precondition.py
@@ -1,9 +1,12 @@
import os
from pathlib import Path

import pytest
import torch
import torch.distributed as dist

from dualip.preprocessing.precondition import jacobi_invert_precondition, jacobi_precondition
from dualip.preprocessing.precondition import dist_jacobi_precondition, jacobi_invert_precondition, jacobi_precondition
from dualip.utils.sparse_utils import split_csc_by_cols

ccol_indices = [0, 2, 3, 5, 8, 10, 12, 15, 16]
row_indices = [2, 3, 3, 1, 2, 0, 1, 2, 0, 2, 0, 3, 1, 2, 3, 2]
@@ -93,3 +96,103 @@ def test_invert_precondition(norms_path):
reciprocal = 1.0 / row_norms

assert torch.allclose(restored, reciprocal)


# ---------------------------------------------------------------------------
# Distributed Jacobi preconditioner tests
# ---------------------------------------------------------------------------


@pytest.mark.skipif("RANK" in os.environ, reason="Only runs in non-distributed (no torchrun) context")
def test_dist_jacobi_precondition_raises_without_dist():
"""dist_jacobi_precondition must raise when torch.distributed is not initialized."""
assert not dist.is_initialized(), "Expected dist to be uninitialized in this test"
with pytest.raises(RuntimeError, match="torch.distributed"):
dist_jacobi_precondition(A_test.clone(), b_test.clone())


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA GPUs")
@pytest.mark.skipif(
"RANK" not in os.environ,
reason="Requires torchrun - run with: torchrun --nproc_per_node=2 -m pytest ...",
)
def test_dist_jacobi_precondition():
"""
Distributed preconditioning gives the same row norms, A scaling, and b
scaling as the single-process version when data is split across 2 ranks.
"""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")

try:
rank = dist.get_rank()
world_size = dist.get_world_size()
device = f"cuda:{rank}"

A_full = A_test.to(device)
b_full = b_test.to(device)

expected_row_norms = A_full.to_dense().norm(2, dim=1)

n_cols = A_full.size(1)
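        # Balanced column split: the first (n_cols % world_size) ranks get one extra column.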
split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)]
a_splits = split_csc_by_cols(A_full, split_sizes)

A_local = a_splits[rank].to(device)
original_vals = A_local.values().clone()
original_row_idx = A_local.row_indices().clone()

b_local = b_full.clone() if rank == 0 else None

row_norms = dist_jacobi_precondition(A_local, b_local)

assert torch.allclose(row_norms, expected_row_norms, atol=1e-5), f"Rank {rank}: row norms mismatch"

expected_vals = original_vals / expected_row_norms[original_row_idx.to(torch.long)]
assert torch.allclose(
A_local.values(), expected_vals, atol=1e-5
), f"Rank {rank}: A_local values not scaled correctly"

if rank == 0:
expected_b = b_full / expected_row_norms
assert torch.allclose(b_local, expected_b, atol=1e-5), "Rank 0: b not scaled correctly"

finally:
if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA GPUs")
@pytest.mark.skipif(
"RANK" not in os.environ,
reason="Requires torchrun - run with: torchrun --nproc_per_node=2 -m pytest ...",
)
def test_dist_jacobi_precondition_saves_norms(tmp_path):
"""Rank 0 saves row norms to disk; the saved norms match the global row norms."""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")

try:
rank = dist.get_rank()
world_size = dist.get_world_size()
device = f"cuda:{rank}"

norms_path = str(tmp_path / "row_norms.pt") if rank == 0 else None

A_full = A_test.to(device)
n_cols = A_full.size(1)
split_sizes = [n_cols // world_size + (1 if i < n_cols % world_size else 0) for i in range(world_size)]
a_splits = split_csc_by_cols(A_full, split_sizes)
A_local = a_splits[rank].to(device)
b_local = b_test.to(device).clone() if rank == 0 else None

row_norms = dist_jacobi_precondition(A_local, b_local, norms_save_path=norms_path)

if rank == 0:
assert Path(norms_path).exists(), "Rank 0 did not save the norm file"
saved = torch.load(norms_path, map_location=device)
assert torch.allclose(saved, row_norms, atol=1e-5), "Saved norms differ from returned row norms"

finally:
if dist.is_initialized():
dist.destroy_process_group()
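Downstream, the norms saved by rank 0 feed the existing inverse step. A hypothetical post-solve snippet, using only the documented signature jacobi_invert_precondition(dual_val, norms_path_or_tensor); the dual vector here is a stand-in for real solver output:

import torch

from dualip.preprocessing.precondition import jacobi_invert_precondition

dual_scaled = torch.ones(4)  # stand-in for duals of the preconditioned problem
dual_orig = jacobi_invert_precondition(dual_scaled, "row_norms.pt")  # path written by rank 0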