diff --git a/src/dualip/gamma_scheduler.py b/src/dualip/gamma_scheduler.py new file mode 100644 index 0000000..458af8e --- /dev/null +++ b/src/dualip/gamma_scheduler.py @@ -0,0 +1,84 @@ +import warnings +from typing import Callable + +from dualip.objectives.base import BaseObjective +from dualip.utils.mlflow_utils import log_metrics + + +def _interval_schedule(itr: int, gamma: float, p: dict) -> float: + """Piecewise-constant schedule: gammas[i] is held for intervals[i] iterations. + + Example: intervals=[100, 100, 500], gammas=[0.1, 0.1, 0.01] + itr in 1..100 -> 0.1 + itr in 101..200 -> 0.1 + itr in 201..700 -> 0.01 + itr > 700 -> 0.01 (last value held) + """ + intervals = p["intervals"] + gammas = p["gammas"] + if len(intervals) != len(gammas): + raise ValueError( + f"'interval' schedule requires intervals and gammas of equal length, " + f"got {len(intervals)} and {len(gammas)}" + ) + cumulative = 0 + new_gamma = gammas[-1] + for length, g in zip(intervals, gammas): + cumulative += length + if itr <= cumulative: + new_gamma = g + break + if new_gamma > gamma: + warnings.warn( + f"'interval' schedule increased gamma from {gamma} to {new_gamma} at itr={itr}; " + f"gamma schedules are typically non-increasing.", + stacklevel=2, + ) + return new_gamma + + +# Maps decay_type -> fn(itr, gamma, params) -> new_gamma +_SCHEDULES: dict[str, Callable[[int, float, dict], float]] = { + "step": lambda itr, gamma, p: (gamma * p["decay_factor"] if itr % p["decay_steps"] == 0 else gamma), + "interval": _interval_schedule, +} + +_REQUIRED_PARAMS: dict[str, list[str]] = { + "step": ["decay_steps", "decay_factor"], + "interval": ["intervals", "gammas"], +} + + +class GammaScheduler: + """ + Drives gamma decay on the objective each optimizer iteration. + + To add a new schedule type, register it in _SCHEDULES: + _SCHEDULES["my_type"] = lambda itr, gamma, params: new_gamma + """ + + def __init__( + self, + objective: BaseObjective, + initial_gamma: float, + decay_type: str, + decay_params: dict, + ): + if decay_type not in _SCHEDULES: + raise ValueError(f"Unsupported gamma decay type: {decay_type}") + required = _REQUIRED_PARAMS.get(decay_type, []) + missing = [k for k in required if k not in decay_params] + if missing: + raise ValueError(f"decay_params missing required keys for '{decay_type}': {missing}") + self.objective = objective + self.gamma = initial_gamma + self.decay_params = decay_params + self._schedule_fn = _SCHEDULES[decay_type] + + def step(self, itr: int) -> None: + """Called after each optimizer iteration. Updates gamma on the objective if a decay fires.""" + new_gamma = self._schedule_fn(itr, self.gamma, self.decay_params) + if new_gamma != self.gamma: + self.gamma = new_gamma + self.objective.set_gamma(new_gamma) + log_metrics({"gamma": self.gamma}, step=itr) diff --git a/src/dualip/objectives/base.py b/src/dualip/objectives/base.py index d307485..d8c33f0 100644 --- a/src/dualip/objectives/base.py +++ b/src/dualip/objectives/base.py @@ -24,3 +24,7 @@ class BaseObjective(ABC): @abstractmethod def calculate(self) -> ObjectiveResult: pass + + def set_gamma(self, gamma: float) -> None: + """Update the regularization parameter. Override in subclasses that use gamma.""" + pass diff --git a/src/dualip/objectives/matching.py b/src/dualip/objectives/matching.py index 3ff6b5d..556b872 100644 --- a/src/dualip/objectives/matching.py +++ b/src/dualip/objectives/matching.py @@ -6,6 +6,7 @@ from dualip.objectives.base import BaseInputArgs, BaseObjective, ObjectiveResult from dualip.projections.base import ProjectionEntry, project +from dualip.utils.objective_utils import calc_grad from dualip.utils.sparse_utils import apply_F_to_columns, elementwise_csc, left_multiply_sparse, row_sums_csc @@ -22,18 +23,6 @@ class MatchingInputArgs(BaseInputArgs): equality_mask: torch.Tensor = None -def calc_grad( - dual_grad: torch.Tensor, - dual_obj: torch.Tensor, - dual_val: torch.Tensor, - b_vec: torch.Tensor, - reg_penalty: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - dual_grad = dual_grad - b_vec - dual_obj = dual_obj + reg_penalty + torch.dot(dual_val, dual_grad) - return dual_grad, dual_obj - - class MatchingSolverDualObjectiveFunction(BaseObjective): """ Computes dual gradient, objective, and regularization penalty @@ -113,25 +102,21 @@ def _compute_buckets(self, indices: list[torch.Tensor]) -> list[list[torch.Tenso buckets.append(bucket) return buckets - def calculate( - self, dual_val: torch.Tensor, gamma: float = None, save_primal: bool = False, **kwargs - ) -> ObjectiveResult: + def set_gamma(self, gamma: float) -> None: + self.gamma = gamma + self.c_rescaled = -1.0 / gamma * self.c + + def calculate(self, dual_val: torch.Tensor, save_primal: bool = False, **kwargs) -> ObjectiveResult: """ Compute dual gradient, objective, and reg penalty. Args: dual_val: current dual variables - gamma: regularization parameter save_primal: if True, save the primal variable Returns: ObjectiveResult """ - if gamma is not None and gamma != self.gamma: - self.gamma = gamma - # Recompute c_rescaled when gamma changes - self.c_rescaled = -1.0 / gamma * self.c - # -dual_val/gamma scaled = -1.0 / self.gamma * dual_val @@ -244,12 +229,16 @@ def __init__( # Create single-GPU objective with local data self.local_objective = MatchingSolverDualObjectiveFunction(local_matching_input_args, gamma, batching) + def set_gamma(self, gamma: float) -> None: + self.gamma = gamma + self.local_objective.set_gamma(gamma) + def calculate( self, dual_val: torch.Tensor, - gamma: float = None, save_primal: bool = False, rank: int = 0, + **kwargs, ) -> ObjectiveResult: """Compute and reduce gradients/objectives across all GPUs.""" if save_primal: @@ -258,7 +247,7 @@ def calculate( # dual_val is on cuda:rank (each rank has it on its own device) # local_objective data is also on cuda:rank # Compute local partition - objective_result = self.local_objective.calculate(dual_val, gamma, save_primal=False) + objective_result = self.local_objective.calculate(dual_val, save_primal=False) # Keep results on local device (cuda:rank) for NCCL reduce # NCCL expects each rank to have tensor on its own GPU diff --git a/src/dualip/objectives/miplib.py b/src/dualip/objectives/miplib.py index 7411b90..041d65f 100644 --- a/src/dualip/objectives/miplib.py +++ b/src/dualip/objectives/miplib.py @@ -34,8 +34,10 @@ class MIPLIB2017ObjectiveFunction(BaseObjective): def __init__( self, miplib_input_args: MIPLIBInputArgs, + gamma: float = 1.0, use_jacobi_precondition: bool = False, ): + self.gamma = gamma self.A = miplib_input_args.A # Store CSR and CSC versions of A for efficient computations if needed self.A_csr = self.A.to_sparse_csr() if self.A.is_sparse else self.A @@ -57,13 +59,15 @@ def __init__( else: self.row_norms = None - def calculate(self, dual_val: torch.Tensor, gamma: float, save_primal: bool = False, **kwargs) -> ObjectiveResult: + def set_gamma(self, gamma: float) -> None: + self.gamma = gamma + + def calculate(self, dual_val: torch.Tensor, save_primal: bool = False, **kwargs) -> ObjectiveResult: """ Compute dual gradient, objective, and reg penalty. Args: dual_val: current dual variables - gamma: regularization parameter save_primal: if True, save the primal variable Returns: @@ -73,7 +77,7 @@ def calculate(self, dual_val: torch.Tensor, gamma: float, save_primal: bool = Fa if self.row_norms is not None: dual_val = 1 / self.row_norms * dual_val - z = -1.0 / gamma * (self.A.T @ dual_val + self.c) + z = -1.0 / self.gamma * (self.A.T @ dual_val + self.c) # Apply projection on z based on projection_map projected_sol = z.clone() @@ -94,7 +98,7 @@ def calculate(self, dual_val: torch.Tensor, gamma: float, save_primal: bool = Fa else: dual_gradient = self.A_csr @ projected_sol - self.b_vec - reg_penalty = gamma / 2.0 * torch.norm(projected_sol) ** 2 + reg_penalty = self.gamma / 2.0 * torch.norm(projected_sol) ** 2 dual_obj = self.c @ projected_sol + reg_penalty + dual_val @ (self.A_csr @ projected_sol - self.b_vec) primal_obj = self.c @ projected_sol diff --git a/src/dualip/optimizers/agd.py b/src/dualip/optimizers/agd.py index 8f45d5e..7006ac6 100644 --- a/src/dualip/optimizers/agd.py +++ b/src/dualip/optimizers/agd.py @@ -64,26 +64,26 @@ def _fmt(name, val): class AcceleratedGradientDescent: + """ + Accelerated Gradient Descent optimizer (pure dual update). + + Gamma scheduling is handled externally via step_callback passed to maximize(). + See GammaScheduler for the built-in step / interval decay implementations. + """ + def __init__( self, max_iter: int, - gamma: float, initial_step_size: float = 1e-5, max_step_size: float = 0.1, - gamma_decay_type: str = None, - gamma_decay_params: dict = {}, save_primal: bool = False, iteration_callback: Optional[Callable[[int, ObjectiveResult], None]] = None, ): - self.initial_step_size = initial_step_size self.max_step_size = max_step_size self.max_iter = max_iter self.beta_seq = self._compute_beta_seq(self.max_iter) self.streams = None - self.gamma = gamma - self.gamma_decay_type = gamma_decay_type - self.gamma_decay_params = gamma_decay_params self.save_primal = save_primal # Default behavior: print summary line each iteration; can be overridden by passing a callback self.iteration_callback: Callable[[int, ObjectiveResult], None] = ( @@ -99,15 +99,6 @@ def _compute_beta_seq(self, max_iter: int) -> torch.Tensor: beta_seq[i] = (1 - t_seq[i + 1]) / t_seq[i + 2] return beta_seq - def _update_gamma(self, itr: int, step_size: float): - if self.gamma_decay_type == "step": - if itr % self.gamma_decay_params["decay_steps"] == 0: - decay_factor = self.gamma_decay_params["decay_factor"] - self.gamma = self.gamma * decay_factor - self.max_step_size = step_size * decay_factor - else: - raise ValueError(f"Unsupported gamma decay type: {self.gamma_decay_type}") - def _default_iteration_callback(self, iteration: int, objective_result: ObjectiveResult) -> None: """ Default iteration callback that prints a one-line summary. @@ -118,23 +109,26 @@ def _default_iteration_callback(self, iteration: int, objective_result: Objectiv # Ensure optimizer never crashes due to logging/printing pass - def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0) -> SolverResult: + def maximize( + self, + f: BaseObjective, + initial_value: torch.Tensor, + rank: int = 0, + step_callback: Optional[Callable[[int], None]] = None, + ) -> SolverResult: """ - Maximizes the dual-primal objective function f. - f must provide a method: - - f.calculate(x) returning an object with attributes: - * dual_gradient (torch.Tensor) - * dual_objective (float) - * dual_val (torch.Tensor) + Maximizes the dual objective f. Args: - f: The objective function to maximize - initial_value: Initial dual variable values - rank: Process rank for distributed training (default: 0 for single-GPU) - - Returns a tuple: (final solution, final result, dual_obj_log, step_size_log), - where dual_obj_log is the list of dual objective values recorded at each iteration - and step_size_log is the list of the dynamic step size. + f: objective implementing BaseObjective.calculate(dual_val, save_primal). + Objectives that use a regularization parameter own it internally; + update it externally via f.set_gamma(). + initial_value: starting dual variable + rank: distributed rank (0 = primary) + step_callback: optional callable(itr) -> None, called after each iteration. + Use this to drive gamma scheduling via GammaScheduler. + + Returns a SolverResult with the final dual / objective and per-iteration logs. """ grad_history = [] dual_history = [] @@ -149,15 +143,11 @@ def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0) i = 1 while i <= self.max_iter: - gamma_params = {"gamma": self.gamma} if self.gamma is not None else {} - # ALL ranks participate in calculate (for distributed objectives) if i == self.max_iter and self.save_primal: - objective_result: ObjectiveResult = f.calculate( - dual_val=x, **gamma_params, save_primal=self.save_primal, rank=rank - ) + objective_result: ObjectiveResult = f.calculate(dual_val=x, save_primal=self.save_primal, rank=rank) else: - objective_result: ObjectiveResult = f.calculate(dual_val=x, **gamma_params, rank=rank) + objective_result: ObjectiveResult = f.calculate(dual_val=x, rank=rank) # Only rank 0 performs optimizer updates if rank == 0: @@ -183,8 +173,10 @@ def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0) # Accelerated update. x = (y_new * (1.0 - self.beta_seq[i - 1])) + (y * self.beta_seq[i - 1]) y = y_new - if self.gamma is not None and self.gamma_decay_type is not None: - self._update_gamma(i, step_size) + + # Drive external scheduling (e.g. gamma decay via GammaScheduler) + if step_callback is not None: + step_callback(i) # Log iteration metrics (will check MLflow state internally) iteration_metrics = { @@ -192,9 +184,6 @@ def maximize(self, f: BaseObjective, initial_value: torch.Tensor, rank: int = 0) "dual_objective": dual_obj, } - if self.gamma is not None: - iteration_metrics["gamma"] = self.gamma - log_metrics(iteration_metrics, step=i) # Log objective result details diff --git a/src/dualip/run_solver.py b/src/dualip/run_solver.py index 31a3c1e..9e409d6 100644 --- a/src/dualip/run_solver.py +++ b/src/dualip/run_solver.py @@ -3,6 +3,7 @@ import torch +from dualip.gamma_scheduler import GammaScheduler from dualip.objectives.base import BaseInputArgs from dualip.objectives.matching import ( MatchingSolverDualObjectiveFunction, @@ -53,7 +54,9 @@ def build_objective( if objective_type == "miplib2017": objective_kwargs = objective_kwargs or {} - objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args, **objective_kwargs) + objective = MIPLIB2017ObjectiveFunction( + miplib_input_args=input_args, gamma=solver_args.gamma, **objective_kwargs + ) elif objective_type == "matching": if compute_device_num == 1: objective = MatchingSolverDualObjectiveFunction(matching_input_args=input_args, gamma=solver_args.gamma) @@ -117,9 +120,6 @@ def run_solver( initial_step_size=solver_args.initial_step_size, max_iter=solver_args.max_iter, max_step_size=solver_args.max_step_size, - gamma=solver_args.gamma, - gamma_decay_type=solver_args.gamma_decay_type, - gamma_decay_params=solver_args.gamma_decay_params, save_primal=solver_args.save_primal, ) @@ -131,7 +131,18 @@ def run_solver( ) initial_dual = initial_dual.to(host_device) - solver_result = solver.maximize(objective, initial_dual) + # Wire up gamma scheduling when configured + step_callback = None + if solver_args.gamma_decay_type is not None: + scheduler = GammaScheduler( + objective=objective, + initial_gamma=solver_args.gamma, + decay_type=solver_args.gamma_decay_type, + decay_params=solver_args.gamma_decay_params or {}, + ) + step_callback = scheduler.step + + solver_result = solver.maximize(objective, initial_dual, step_callback=step_callback) use_jacobi_precondition = getattr(objective, "use_jacobi_precondition", None) if use_jacobi_precondition: diff --git a/src/dualip/types.py b/src/dualip/types.py index 6911e6c..f9cdb59 100644 --- a/src/dualip/types.py +++ b/src/dualip/types.py @@ -11,7 +11,7 @@ class SolverArgs: gamma: float = 1e-3 max_step_size: float = 0.1 initial_dual_path: Optional[str] = None - gamma_decay_type: Optional[Literal["step"]] = None + gamma_decay_type: Optional[Literal["step", "interval"]] = None gamma_decay_params: Optional[dict] = None save_primal: bool = False diff --git a/src/dualip/utils/objective_utils.py b/src/dualip/utils/objective_utils.py new file mode 100644 index 0000000..108fcad --- /dev/null +++ b/src/dualip/utils/objective_utils.py @@ -0,0 +1,13 @@ +import torch + + +def calc_grad( + dual_grad: torch.Tensor, + dual_obj: torch.Tensor, + dual_val: torch.Tensor, + b_vec: torch.Tensor, + reg_penalty: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + dual_grad = dual_grad - b_vec + dual_obj = dual_obj + reg_penalty + torch.dot(dual_val, dual_grad) + return dual_grad, dual_obj diff --git a/tests/distributed/test_matching_distributed.py b/tests/distributed/test_matching_distributed.py index 9308b8d..1670672 100644 --- a/tests/distributed/test_matching_distributed.py +++ b/tests/distributed/test_matching_distributed.py @@ -175,7 +175,7 @@ def test_simplex_solver_inequality_distributed(init_distributed): initial_dual = 0.1 * torch.ones(5, device=device) - solver = AcceleratedGradientDescent(max_iter=30, gamma=gamma) + solver = AcceleratedGradientDescent(max_iter=30) solver_result = solver.maximize(f, initial_dual, rank=rank) # Only rank 0 checks results diff --git a/tests/objectives/test_dualip_matching_simplex.py b/tests/objectives/test_dualip_matching_simplex.py index 14aa7ea..72b04fc 100644 --- a/tests/objectives/test_dualip_matching_simplex.py +++ b/tests/objectives/test_dualip_matching_simplex.py @@ -122,7 +122,7 @@ def test_simplex_solver_inequality(): initial_dual = 0.1 * torch.ones(5, device=HOST_DEVICE) - solver = AcceleratedGradientDescent(max_iter=30, gamma=gamma) + solver = AcceleratedGradientDescent(max_iter=30) solver_result = solver.maximize(objective, initial_dual) diff --git a/tests/objectives/test_miplib_objective.py b/tests/objectives/test_miplib_objective.py index 3a4ab7c..acab65e 100644 --- a/tests/objectives/test_miplib_objective.py +++ b/tests/objectives/test_miplib_objective.py @@ -112,12 +112,9 @@ def test_miplib_general_convergence_criteria_III(): b_vec=b, equality_mask=equality_mask, ) - objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args) + objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args, gamma=0.001) solver = AcceleratedGradientDescent( max_iter=500, - gamma=0.001, - gamma_decay_type=None, - gamma_decay_params=None, save_primal=True, ) initial_dual = torch.zeros(2) @@ -148,14 +145,11 @@ def test_miplib_convergence_with_one_sided_x_bound_I(): b_vec=b, equality_mask=equality_mask, ) - objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args) + objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args, gamma=0.001) solver = AcceleratedGradientDescent( initial_step_size=1e-6, max_step_size=1e-5, max_iter=10000, - gamma=0.001, - gamma_decay_type=None, - gamma_decay_params=None, save_primal=True, ) initial_dual = torch.zeros(2) @@ -186,14 +180,11 @@ def test_miplib_convergence_with_one_sided_x_bound_II(): b_vec=b, equality_mask=equality_mask, ) - objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args) + objective = MIPLIB2017ObjectiveFunction(miplib_input_args=input_args, gamma=0.001) solver = AcceleratedGradientDescent( initial_step_size=1e-6, max_step_size=1e-5, max_iter=10000, - gamma=0.001, - gamma_decay_type=None, - gamma_decay_params=None, save_primal=True, ) initial_dual = torch.zeros(2) diff --git a/tests/test_agd.py b/tests/test_agd.py index a863adf..3c03387 100644 --- a/tests/test_agd.py +++ b/tests/test_agd.py @@ -52,7 +52,7 @@ def test_quadratic_1d_function(): default_step_size = 1e-5 # Test with the default initial_step_size. - solver_default = AcceleratedGradientDescent(max_iter=1, gamma=None) + solver_default = AcceleratedGradientDescent(max_iter=1) solver_default_result = solver_default.maximize(Quadratic1DObjective(), torch.tensor([0.0], device=HOST_DEVICE)) assert abs(solver_default_result.dual_val[0] - (initial_gradient * default_step_size)) < 1e-10, ( f"Test fails for default initialStepSize: expected {initial_gradient * default_step_size}, " @@ -61,7 +61,7 @@ def test_quadratic_1d_function(): # Test with a new initial_step_size. new_step_size = 0.1 - solver_new_step_size = AcceleratedGradientDescent(max_iter=1, gamma=None, initial_step_size=new_step_size) + solver_new_step_size = AcceleratedGradientDescent(max_iter=1, initial_step_size=new_step_size) solver_new_step_size_result = solver_new_step_size.maximize( Quadratic1DObjective(), torch.tensor([0.0], device=HOST_DEVICE) ) @@ -85,7 +85,7 @@ def test_simple_objective_dual_value(): # With a very small step, the dual objective value will increase slightly toward the optimum f(3,0) = -25. default_step_size = 1e-5 - solver = AcceleratedGradientDescent(max_iter=30, gamma=None, initial_step_size=default_step_size) + solver = AcceleratedGradientDescent(max_iter=30, initial_step_size=default_step_size) solver_result = solver.maximize(SimpleObjective(), torch.tensor([0.0, 0.0], device=HOST_DEVICE)) for i, (dual, step) in enumerate(zip(solver_result.dual_objective_log[:25], solver_result.step_size_log[:25])): print(f"Iteration: {i + 1}. Dual: {dual}. Step: {step}") diff --git a/tests/test_equality_constraints.py b/tests/test_equality_constraints.py index 3c6270d..01dec6d 100644 --- a/tests/test_equality_constraints.py +++ b/tests/test_equality_constraints.py @@ -47,9 +47,10 @@ def test_solver_with_equality_constraint(): ) objective = MIPLIB2017ObjectiveFunction( miplib_input_args=input_args, + gamma=gamma, ) - solver = AcceleratedGradientDescent(max_iter=1000, gamma=gamma) + solver = AcceleratedGradientDescent(max_iter=1000) solver_result = solver.maximize(objective, initial_dual) # Verify the solution is correct within tolerance diff --git a/tests/test_gamma_scheduler.py b/tests/test_gamma_scheduler.py new file mode 100644 index 0000000..d728d54 --- /dev/null +++ b/tests/test_gamma_scheduler.py @@ -0,0 +1,240 @@ +import re + +import pytest +import torch + +from dualip.gamma_scheduler import GammaScheduler +from dualip.objectives.base import BaseObjective +from dualip.optimizers.agd import AcceleratedGradientDescent +from dualip.types import ObjectiveResult + + +class GammaTrackingObjective(BaseObjective): + """Minimal objective that records every gamma it is given via set_gamma().""" + + def __init__(self, gamma: float): + self.gamma = gamma + self.gamma_history: list[float] = [gamma] + self.equality_mask = None + + def set_gamma(self, gamma: float) -> None: + self.gamma = gamma + self.gamma_history.append(gamma) + + def calculate(self, dual_val: torch.Tensor, save_primal: bool = False, **kwargs) -> ObjectiveResult: + grad = -2.0 * dual_val + obj = -(dual_val**2).sum() + return ObjectiveResult(dual_gradient=grad, dual_objective=obj) + + +def test_step_scheduler_calls_set_gamma(): + """GammaScheduler with 'step' decay calls set_gamma on the objective at the right iterations.""" + objective = GammaTrackingObjective(gamma=1.0) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="step", + decay_params={"decay_steps": 3, "decay_factor": 0.5}, + ) + + for itr in range(1, 7): + scheduler.step(itr) + + # Decay fires at itr=3 and itr=6 + assert len(objective.gamma_history) == 3 # initial + 2 decays + assert objective.gamma_history[0] == pytest.approx(1.0) + assert objective.gamma_history[1] == pytest.approx(0.5) + assert objective.gamma_history[2] == pytest.approx(0.25) + + +def test_step_scheduler_decays_gamma_at_correct_iterations(): + """GammaScheduler decays gamma at the right iterations and leaves it unchanged otherwise.""" + objective = GammaTrackingObjective(gamma=1.0) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="step", + decay_params={"decay_steps": 2, "decay_factor": 0.5}, + ) + + scheduler.step(1) # no decay + assert scheduler.gamma == pytest.approx(1.0) + scheduler.step(2) # decay fires + assert scheduler.gamma == pytest.approx(0.5) + + +def test_agd_step_callback_drives_gamma(): + """step_callback correctly drives GammaScheduler which updates objective gamma.""" + objective = GammaTrackingObjective(gamma=1.0) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="step", + decay_params={"decay_steps": 5, "decay_factor": 0.5}, + ) + + solver = AcceleratedGradientDescent( + max_iter=10, + initial_step_size=1e-3, + max_step_size=0.1, + ) + initial = torch.zeros(2) + solver.maximize(objective, initial, step_callback=scheduler.step) + + # Decay fires at itr=5 and itr=10 → 2 updates on top of initial + assert len(objective.gamma_history) == 3 + assert objective.gamma_history[1] == pytest.approx(0.5) + assert objective.gamma_history[2] == pytest.approx(0.25) + + +def test_unsupported_decay_type_raises(): + objective = GammaTrackingObjective(gamma=1.0) + with pytest.raises(ValueError, match="Unsupported gamma decay type"): + GammaScheduler(objective=objective, initial_gamma=1.0, decay_type="none", decay_params={}) + + +def test_missing_decay_params_raises(): + objective = GammaTrackingObjective(gamma=1.0) + expected_message = "decay_params missing required keys for 'step': ['decay_steps', 'decay_factor']" + with pytest.raises(ValueError, match=re.escape(expected_message)): + GammaScheduler(objective=objective, initial_gamma=1.0, decay_type="step", decay_params={}) + + +def test_agd_with_gamma_scheduler_decays_gamma(): + """GammaScheduler decays objective gamma at the right iterations when wired via step_callback.""" + objective = GammaTrackingObjective(gamma=1.0) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="step", + decay_params={"decay_steps": 3, "decay_factor": 0.5}, + ) + solver = AcceleratedGradientDescent( + max_iter=6, + initial_step_size=1e-3, + max_step_size=0.1, + ) + solver.maximize(objective, torch.zeros(2), step_callback=scheduler.step) + + # Decay fires at itr=3 and itr=6 → gamma halved twice + assert scheduler.gamma == pytest.approx(0.25) + assert len(objective.gamma_history) == 3 + # solver's max_step_size is untouched + assert solver.max_step_size == pytest.approx(0.1) + + +def test_interval_scheduler_piecewise_constant_gamma(): + """GammaScheduler with 'interval' decay holds gammas[i] for intervals[i] iterations.""" + objective = GammaTrackingObjective(gamma=0.1) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=0.1, + decay_type="interval", + decay_params={"intervals": [3, 2, 4], "gammas": [0.1, 0.05, 0.01]}, + ) + + # itrs 1..3 -> 0.1 (no change from initial) + for itr in range(1, 4): + scheduler.step(itr) + assert scheduler.gamma == pytest.approx(0.1) + assert len(objective.gamma_history) == 1 + + # itrs 4..5 -> 0.05 + scheduler.step(4) + assert scheduler.gamma == pytest.approx(0.05) + scheduler.step(5) + assert scheduler.gamma == pytest.approx(0.05) + assert len(objective.gamma_history) == 2 + + # itrs 6..9 -> 0.01 + scheduler.step(6) + assert scheduler.gamma == pytest.approx(0.01) + for itr in range(7, 10): + scheduler.step(itr) + assert scheduler.gamma == pytest.approx(0.01) + + # past end -> last value held + scheduler.step(100) + assert scheduler.gamma == pytest.approx(0.01) + assert objective.gamma_history == pytest.approx([0.1, 0.05, 0.01]) + + +def test_interval_scheduler_mismatched_lengths_raises(): + objective = GammaTrackingObjective(gamma=1.0) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="interval", + decay_params={"intervals": [10, 20], "gammas": [0.1, 0.05, 0.01]}, + ) + with pytest.raises(ValueError, match="intervals and gammas of equal length"): + scheduler.step(1) + + +def test_interval_scheduler_missing_params_raises(): + objective = GammaTrackingObjective(gamma=1.0) + expected_message = "decay_params missing required keys for 'interval': ['intervals', 'gammas']" + with pytest.raises(ValueError, match=re.escape(expected_message)): + GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="interval", + decay_params={}, + ) + + +def test_interval_scheduler_warns_on_gamma_increase(): + """Warn when the schedule would raise gamma above the current value.""" + objective = GammaTrackingObjective(gamma=0.01) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=0.01, + decay_type="interval", + decay_params={"intervals": [2, 2], "gammas": [0.1, 0.05]}, + ) + with pytest.warns(UserWarning, match="increased gamma"): + scheduler.step(1) + assert scheduler.gamma == pytest.approx(0.1) + + +def test_agd_with_interval_gamma_scheduler(): + """GammaScheduler with 'interval' decay wired to AGD via step_callback.""" + objective = GammaTrackingObjective(gamma=0.1) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=0.1, + decay_type="interval", + decay_params={"intervals": [3, 3], "gammas": [0.1, 0.01]}, + ) + solver = AcceleratedGradientDescent( + max_iter=6, + initial_step_size=1e-3, + max_step_size=0.1, + ) + solver.maximize(objective, torch.zeros(2), step_callback=scheduler.step) + + # Gamma switches from 0.1 -> 0.01 at itr=4 + assert scheduler.gamma == pytest.approx(0.01) + assert objective.gamma_history == pytest.approx([0.1, 0.01]) + + +def test_agd_with_gamma_scheduler_no_decay_between_steps(): + """Gamma and solver max_step_size are unchanged when no decay fires.""" + objective = GammaTrackingObjective(gamma=1.0) + scheduler = GammaScheduler( + objective=objective, + initial_gamma=1.0, + decay_type="step", + decay_params={"decay_steps": 10, "decay_factor": 0.5}, + ) + solver = AcceleratedGradientDescent( + max_iter=5, + initial_step_size=1e-3, + max_step_size=0.1, + ) + solver.maximize(objective, torch.zeros(2), step_callback=scheduler.step) + + # No decay fires within 5 iterations (decay_steps=10) + assert scheduler.gamma == pytest.approx(1.0) + assert solver.max_step_size == pytest.approx(0.1) + assert len(objective.gamma_history) == 1 # only the initial gamma