Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/dualip/gamma_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import warnings
from typing import Callable

from dualip.objectives.base import BaseObjective
from dualip.utils.mlflow_utils import log_metrics


def _interval_schedule(itr: int, gamma: float, p: dict) -> float:
"""Piecewise-constant schedule: gammas[i] is held for intervals[i] iterations.

Example: intervals=[100, 100, 500], gammas=[0.1, 0.1, 0.01]
itr in 1..100 -> 0.1
itr in 101..200 -> 0.1
itr in 201..700 -> 0.01
itr > 700 -> 0.01 (last value held)
"""
intervals = p["intervals"]
gammas = p["gammas"]
if len(intervals) != len(gammas):
raise ValueError(
f"'interval' schedule requires intervals and gammas of equal length, "
f"got {len(intervals)} and {len(gammas)}"
)
cumulative = 0
new_gamma = gammas[-1]
for length, g in zip(intervals, gammas):
cumulative += length
if itr <= cumulative:
new_gamma = g
break
if new_gamma > gamma:
warnings.warn(
f"'interval' schedule increased gamma from {gamma} to {new_gamma} at itr={itr}; "
f"gamma schedules are typically non-increasing.",
stacklevel=2,
)
return new_gamma


def _step_schedule(itr: int, gamma: float, p: dict) -> float:
    """Multiply gamma by decay_factor whenever itr is a multiple of decay_steps."""
    if itr % p["decay_steps"] == 0:
        return gamma * p["decay_factor"]
    return gamma


# Registry of schedules: decay_type -> fn(itr, gamma, params) -> new_gamma
_SCHEDULES: dict[str, Callable[[int, float, dict], float]] = {
    "step": _step_schedule,
    "interval": _interval_schedule,
}

# Keys that must be present in decay_params for each decay_type.
_REQUIRED_PARAMS: dict[str, list[str]] = {
    "step": ["decay_steps", "decay_factor"],
    "interval": ["intervals", "gammas"],
}


class GammaScheduler:
    """
    Applies a gamma decay schedule to an objective, one optimizer iteration at a time.

    New schedule types are added by registering them in _SCHEDULES:
        _SCHEDULES["my_type"] = lambda itr, gamma, params: new_gamma
    """

    def __init__(
        self,
        objective: BaseObjective,
        initial_gamma: float,
        decay_type: str,
        decay_params: dict,
    ):
        # Validate the schedule type and its parameters up front so bad configs
        # fail at construction time rather than mid-optimization.
        if decay_type not in _SCHEDULES:
            raise ValueError(f"Unsupported gamma decay type: {decay_type}")
        required_keys = _REQUIRED_PARAMS.get(decay_type, [])
        missing = [key for key in required_keys if key not in decay_params]
        if missing:
            raise ValueError(f"decay_params missing required keys for '{decay_type}': {missing}")
        self.objective = objective
        self.gamma = initial_gamma
        self.decay_params = decay_params
        self._schedule_fn = _SCHEDULES[decay_type]

    def step(self, itr: int) -> None:
        """Advance the schedule by one iteration; push any gamma change to the objective.

        Called after each optimizer iteration. When the schedule yields a new
        value, updates the objective via set_gamma and logs the new gamma.
        """
        candidate = self._schedule_fn(itr, self.gamma, self.decay_params)
        if candidate == self.gamma:
            return
        self.gamma = candidate
        self.objective.set_gamma(candidate)
        log_metrics({"gamma": candidate}, step=itr)
4 changes: 4 additions & 0 deletions src/dualip/objectives/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ class BaseObjective(ABC):
@abstractmethod
def calculate(self) -> ObjectiveResult:
pass

def set_gamma(self, gamma: float) -> None:
    """Update the regularization parameter gamma.

    No-op by default; subclasses whose objective depends on gamma override
    this (e.g. to store gamma and refresh gamma-scaled cached quantities).
    """
    pass
35 changes: 12 additions & 23 deletions src/dualip/objectives/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dualip.objectives.base import BaseInputArgs, BaseObjective, ObjectiveResult
from dualip.projections.base import ProjectionEntry, project
from dualip.utils.objective_utils import calc_grad
from dualip.utils.sparse_utils import apply_F_to_columns, elementwise_csc, left_multiply_sparse, row_sums_csc


Expand All @@ -22,18 +23,6 @@ class MatchingInputArgs(BaseInputArgs):
equality_mask: torch.Tensor = None


def calc_grad(
dual_grad: torch.Tensor,
dual_obj: torch.Tensor,
dual_val: torch.Tensor,
b_vec: torch.Tensor,
reg_penalty: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
dual_grad = dual_grad - b_vec
dual_obj = dual_obj + reg_penalty + torch.dot(dual_val, dual_grad)
return dual_grad, dual_obj


class MatchingSolverDualObjectiveFunction(BaseObjective):
"""
Computes dual gradient, objective, and regularization penalty
Expand Down Expand Up @@ -113,25 +102,21 @@ def _compute_buckets(self, indices: list[torch.Tensor]) -> list[list[torch.Tenso
buckets.append(bucket)
return buckets

def calculate(
self, dual_val: torch.Tensor, gamma: float = None, save_primal: bool = False, **kwargs
) -> ObjectiveResult:
def set_gamma(self, gamma: float) -> None:
    """Set gamma and refresh the cached gamma-scaled cost vector c_rescaled.

    Keeping c_rescaled in sync here lets calculate() reuse it without
    re-deriving it every call. NOTE(review): gamma must be non-zero —
    a zero gamma would divide by zero here.
    """
    self.gamma = gamma
    self.c_rescaled = -1.0 / gamma * self.c

def calculate(self, dual_val: torch.Tensor, save_primal: bool = False, **kwargs) -> ObjectiveResult:
"""
Compute dual gradient, objective, and reg penalty.

Args:
dual_val: current dual variables
gamma: regularization parameter
save_primal: if True, save the primal variable

Returns:
ObjectiveResult
"""
if gamma is not None and gamma != self.gamma:
self.gamma = gamma
# Recompute c_rescaled when gamma changes
self.c_rescaled = -1.0 / gamma * self.c

# -dual_val/gamma
scaled = -1.0 / self.gamma * dual_val

Expand Down Expand Up @@ -244,12 +229,16 @@ def __init__(
# Create single-GPU objective with local data
self.local_objective = MatchingSolverDualObjectiveFunction(local_matching_input_args, gamma, batching)

def set_gamma(self, gamma: float) -> None:
    """Set gamma on this wrapper and forward it to the per-GPU local objective."""
    self.gamma = gamma
    self.local_objective.set_gamma(gamma)

def calculate(
self,
dual_val: torch.Tensor,
gamma: float = None,
save_primal: bool = False,
rank: int = 0,
**kwargs,
) -> ObjectiveResult:
"""Compute and reduce gradients/objectives across all GPUs."""
if save_primal:
Expand All @@ -258,7 +247,7 @@ def calculate(
# dual_val is on cuda:rank (each rank has it on its own device)
# local_objective data is also on cuda:rank
# Compute local partition
objective_result = self.local_objective.calculate(dual_val, gamma, save_primal=False)
objective_result = self.local_objective.calculate(dual_val, save_primal=False)

# Keep results on local device (cuda:rank) for NCCL reduce
# NCCL expects each rank to have tensor on its own GPU
Expand Down
12 changes: 8 additions & 4 deletions src/dualip/objectives/miplib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ class MIPLIB2017ObjectiveFunction(BaseObjective):
def __init__(
self,
miplib_input_args: MIPLIBInputArgs,
gamma: float = 1.0,
use_jacobi_precondition: bool = False,
):
self.gamma = gamma
self.A = miplib_input_args.A
# Store CSR and CSC versions of A for efficient computations if needed
self.A_csr = self.A.to_sparse_csr() if self.A.is_sparse else self.A
Expand All @@ -57,13 +59,15 @@ def __init__(
else:
self.row_norms = None

def calculate(self, dual_val: torch.Tensor, gamma: float, save_primal: bool = False, **kwargs) -> ObjectiveResult:
def set_gamma(self, gamma: float) -> None:
    """Set the regularization parameter used by calculate(); no cached state to refresh."""
    self.gamma = gamma

def calculate(self, dual_val: torch.Tensor, save_primal: bool = False, **kwargs) -> ObjectiveResult:
"""
Compute dual gradient, objective, and reg penalty.

Args:
dual_val: current dual variables
gamma: regularization parameter
save_primal: if True, save the primal variable

Returns:
Expand All @@ -73,7 +77,7 @@ def calculate(self, dual_val: torch.Tensor, gamma: float, save_primal: bool = Fa
if self.row_norms is not None:
dual_val = 1 / self.row_norms * dual_val

z = -1.0 / gamma * (self.A.T @ dual_val + self.c)
z = -1.0 / self.gamma * (self.A.T @ dual_val + self.c)

# Apply projection on z based on projection_map
projected_sol = z.clone()
Expand All @@ -94,7 +98,7 @@ def calculate(self, dual_val: torch.Tensor, gamma: float, save_primal: bool = Fa
else:
dual_gradient = self.A_csr @ projected_sol - self.b_vec

reg_penalty = gamma / 2.0 * torch.norm(projected_sol) ** 2
reg_penalty = self.gamma / 2.0 * torch.norm(projected_sol) ** 2

dual_obj = self.c @ projected_sol + reg_penalty + dual_val @ (self.A_csr @ projected_sol - self.b_vec)
primal_obj = self.c @ projected_sol
Expand Down
71 changes: 30 additions & 41 deletions src/dualip/optimizers/agd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,26 @@ def _fmt(name, val):


class AcceleratedGradientDescent:
"""
Accelerated Gradient Descent optimizer (pure dual update).

Gamma scheduling is handled externally via step_callback passed to maximize().
See GammaScheduler for the built-in step / interval decay implementations.
"""

def __init__(
self,
max_iter: int,
gamma: float,
initial_step_size: float = 1e-5,
max_step_size: float = 0.1,
gamma_decay_type: str = None,
gamma_decay_params: dict = {},
save_primal: bool = False,
iteration_callback: Optional[Callable[[int, ObjectiveResult], None]] = None,
):

self.initial_step_size = initial_step_size
self.max_step_size = max_step_size
self.max_iter = max_iter
self.beta_seq = self._compute_beta_seq(self.max_iter)
self.streams = None
self.gamma = gamma
self.gamma_decay_type = gamma_decay_type
self.gamma_decay_params = gamma_decay_params
self.save_primal = save_primal
# Default behavior: print summary line each iteration; can be overridden by passing a callback
self.iteration_callback: Callable[[int, ObjectiveResult], None] = (
Expand All @@ -99,15 +99,6 @@ def _compute_beta_seq(self, max_iter: int) -> torch.Tensor:
beta_seq[i] = (1 - t_seq[i + 1]) / t_seq[i + 2]
return beta_seq

def _update_gamma(self, itr: int, step_size: float):
if self.gamma_decay_type == "step":
if itr % self.gamma_decay_params["decay_steps"] == 0:
decay_factor = self.gamma_decay_params["decay_factor"]
self.gamma = self.gamma * decay_factor
self.max_step_size = step_size * decay_factor
else:
raise ValueError(f"Unsupported gamma decay type: {self.gamma_decay_type}")

def _default_iteration_callback(self, iteration: int, objective_result: ObjectiveResult) -> None:
"""
Default iteration callback that prints a one-line summary.
Expand All @@ -118,23 +109,26 @@ def _default_iteration_callback(self, iteration: int, objective_result: Objectiv
# Ensure optimizer never crashes due to logging/printing
pass

def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0) -> SolverResult:
def maximize(
self,
f: BaseObjective,
initial_value: torch.Tensor,
rank: int = 0,
step_callback: Optional[Callable[[int], None]] = None,
) -> SolverResult:
"""
Maximizes the dual-primal objective function f.
f must provide a method:
- f.calculate(x) returning an object with attributes:
* dual_gradient (torch.Tensor)
* dual_objective (float)
* dual_val (torch.Tensor)
Maximizes the dual objective f.

Args:
f: The objective function to maximize
initial_value: Initial dual variable values
rank: Process rank for distributed training (default: 0 for single-GPU)

Returns a tuple: (final solution, final result, dual_obj_log, step_size_log),
where dual_obj_log is the list of dual objective values recorded at each iteration
and step_size_log is the list of the dynamic step size.
f: objective implementing BaseObjective.calculate(dual_val, save_primal).
Objectives that use a regularization parameter own it internally;
update it externally via f.set_gamma().
initial_value: starting dual variable
rank: distributed rank (0 = primary)
step_callback: optional callable(itr) -> None, called after each iteration.
Use this to drive gamma scheduling via GammaScheduler.

Returns a SolverResult with the final dual / objective and per-iteration logs.
"""
grad_history = []
dual_history = []
Expand All @@ -149,15 +143,11 @@ def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0)
i = 1
while i <= self.max_iter:

gamma_params = {"gamma": self.gamma} if self.gamma is not None else {}

# ALL ranks participate in calculate (for distributed objectives)
if i == self.max_iter and self.save_primal:
objective_result: ObjectiveResult = f.calculate(
dual_val=x, **gamma_params, save_primal=self.save_primal, rank=rank
)
objective_result: ObjectiveResult = f.calculate(dual_val=x, save_primal=self.save_primal, rank=rank)
else:
objective_result: ObjectiveResult = f.calculate(dual_val=x, **gamma_params, rank=rank)
objective_result: ObjectiveResult = f.calculate(dual_val=x, rank=rank)

# Only rank 0 performs optimizer updates
if rank == 0:
Expand All @@ -183,18 +173,17 @@ def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0)
# Accelerated update.
x = (y_new * (1.0 - self.beta_seq[i - 1])) + (y * self.beta_seq[i - 1])
y = y_new
if self.gamma is not None and self.gamma_decay_type is not None:
self._update_gamma(i, step_size)

# Drive external scheduling (e.g. gamma decay via GammaScheduler)
if step_callback is not None:
step_callback(i)

# Log iteration metrics (will check MLflow state internally)
iteration_metrics = {
"step_size": step_size,
"dual_objective": dual_obj,
}

if self.gamma is not None:
iteration_metrics["gamma"] = self.gamma

log_metrics(iteration_metrics, step=i)

# Log objective result details
Expand Down
Loading
Loading