From 5a02c733d8ecac9e5563769a09e19383f8fb63b2 Mon Sep 17 00:00:00 2001 From: Nastaran Alipour Date: Fri, 14 Nov 2025 09:45:02 +0100 Subject: [PATCH 1/2] add support of seeding --- neps/api.py | 45 +++- neps/optimizers/__init__.py | 27 ++- neps/optimizers/algorithms.py | 57 +++++- neps/optimizers/bayesian_optimization.py | 8 +- neps/optimizers/bracket_optimizer.py | 11 +- neps/optimizers/grid_search.py | 3 + neps/optimizers/ifbo.py | 17 +- neps/optimizers/models/gp.py | 139 +++++++------ neps/optimizers/mopriors.py | 13 +- neps/optimizers/primo.py | 13 +- neps/optimizers/priorband.py | 34 +++- neps/optimizers/random_search.py | 6 +- neps/optimizers/utils/brackets.py | 35 +++- .../optimizers/utils/multiobjective/epsnet.py | 11 +- neps/runtime.py | 12 +- neps/sampling/distributions.py | 61 +++++- neps/sampling/priors.py | 24 ++- neps/sampling/samplers.py | 3 - neps/state/__init__.py | 4 +- neps/state/neps_state.py | 4 +- neps/state/optimizer.py | 4 +- neps/state/seed_snapshot.py | 192 ++++++++++++------ neps/utils/trial_io.py | 4 +- neps_examples/basic_usage/hyperparameters.py | 3 +- .../test_default_report_values.py | 17 +- .../test_error_handling_strategies.py | 16 +- .../test_save_evaluation_results.py | 8 +- tests/test_runtime/test_stopping_criterion.py | 44 +++- tests/test_runtime/test_worker_creation.py | 12 +- tests/test_state/test_filebased_neps_state.py | 6 +- tests/test_state/test_neps_state.py | 11 +- tests/test_state/test_rng.py | 125 ++++++++++-- 32 files changed, 694 insertions(+), 275 deletions(-) diff --git a/neps/api.py b/neps/api.py index 10ca2702f..8602b8524 100644 --- a/neps/api.py +++ b/neps/api.py @@ -12,12 +12,14 @@ from neps.optimizers import AskFunction, OptimizerChoice, load_optimizer from neps.runtime import _launch_runtime, _save_results from neps.space.parsing import convert_to_space -from neps.state import NePSState, OptimizationState, SeedSnapshot +from neps.state import NePSState, OptimizationState, RNGStateManager from neps.status.status import post_run_csv, trajectory_of_improvements from neps.utils.common import dynamic_load_object from neps.validation import _validate_imported_config, _validate_imported_result if TYPE_CHECKING: + import numpy as np + import torch from ConfigSpace import ConfigurationSpace from neps.optimizers.algorithms import CustomOptimizer @@ -59,6 +61,9 @@ def run( # noqa: C901, D417, PLR0913 | CustomOptimizer | Literal["auto"] ) = "auto", + seed: int | None = None, + numpy_rng: np.random.Generator | None = None, + torch_rng: torch.Generator | None = None, ) -> None: """Run the optimization. @@ -425,7 +430,20 @@ def __call__( This is mainly meant for internal development but allows you to use the NePS runtime to run your optimizer. - + seed: An optional seed for the random number generators. + numpy_rng: An optional numpy random number generator. + torch_rng: An optional torch random number generator. + + ??? tip "RNG Priority and Control" + When a previously created NePS state is loaded (by reusing the same root_directory), + all RNGs are reconstructed from the saved state. In this case, the parameters seed, numpy_rng, + and torch_rng are ignored, because the experiment continues from the exact stored RNG states. + If you provide numpy_rng or torch_rng explicitly, these generators take precedence and are used directly. + The seed parameter is not used to create new RNGs in this situation. + The overall priority is: + 1. Saved NePS state — highest priority + 2. User-provided numpy_rng and torch_rng + 3. 
seed — used only when no RNG objects and no saved state are available """ # noqa: E501 if ( evaluations_to_spend is None @@ -479,7 +497,10 @@ def __call__( logger.info(f"Starting neps.run using root directory {root_directory}") space = convert_to_space(pipeline_space) - _optimizer_ask, _optimizer_info = load_optimizer(optimizer=optimizer, space=space) + + _optimizer_ask_wrapper, _optimizer_info = load_optimizer( + optimizer=optimizer, space=space + ) multi_fidelity_optimizers = { "successive_halving", @@ -523,9 +544,13 @@ def __call__( "'module:function'." ) + rng_manager = RNGStateManager.new_capture( + seed=seed, np_rng=numpy_rng, torch_rng=torch_rng + ) + _launch_runtime( evaluation_fn=_eval, # type: ignore - optimizer=_optimizer_ask, + optimizer_fn=_optimizer_ask_wrapper, optimizer_info=_optimizer_info, cost_to_spend=cost_to_spend, fidelities_to_spend=fidelities_to_spend, @@ -540,6 +565,7 @@ def __call__( sample_batch_size=sample_batch_size, write_summary_to_disk=write_summary_to_disk, worker_id=worker_id, + rng_manager=rng_manager, ) post_run_csv(root_directory) @@ -646,13 +672,18 @@ def import_trials( if isinstance(root_directory, str): root_directory = Path(root_directory) - optimizer_ask, optimizer_info = load_optimizer(optimizer, pipeline_space) + rng_manager = RNGStateManager.new_capture() + + optimizer_ask_fn, optimizer_info = load_optimizer( + optimizer, + pipeline_space, + ) state = NePSState.create_or_load( root_directory, optimizer_info=optimizer_info, optimizer_state=OptimizationState( - budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={} + budget=None, rng_state_manager=rng_manager, shared_state={} ), ) @@ -675,7 +706,7 @@ def import_trials( if tuple(sorted(t[0].items())) not in existing_configs ] - imported_trials = optimizer_ask.import_trials( + imported_trials = optimizer_ask_fn(rng_manager).import_trials( external_evaluations=normalized_trials, trials=state_trials, ) diff --git a/neps/optimizers/__init__.py b/neps/optimizers/__init__.py index 9b97790a9..89a670fe9 100644 --- a/neps/optimizers/__init__.py +++ b/neps/optimizers/__init__.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from neps.space import SearchSpace + from neps.state.seed_snapshot import RNGStateManager def _load_optimizer_from_string( @@ -21,7 +22,7 @@ def _load_optimizer_from_string( space: SearchSpace, *, optimizer_kwargs: Mapping[str, Any] | None = None, -) -> tuple[AskFunction, OptimizerInfo]: +) -> tuple[Callable[[RNGStateManager], AskFunction], OptimizerInfo]: if optimizer == "auto": _optimizer = determine_optimizer_automatically(space) else: @@ -37,7 +38,9 @@ def _load_optimizer_from_string( keywords = extract_keyword_defaults(optimizer_build) optimizer_kwargs = optimizer_kwargs or {} - opt = optimizer_build(space, **optimizer_kwargs) + opt = lambda rng_manager: optimizer_build( + space, rng_manager=rng_manager, **optimizer_kwargs + ) info = OptimizerInfo(name=_optimizer, info={**keywords, **optimizer_kwargs}) return opt, info @@ -52,7 +55,7 @@ def load_optimizer( | Literal["auto"] ), space: SearchSpace, -) -> tuple[AskFunction, OptimizerInfo]: +) -> tuple[Callable[[RNGStateManager], AskFunction], OptimizerInfo]: match optimizer: # Predefined string (including "auto") case str(): @@ -60,22 +63,32 @@ def load_optimizer( # Predefined string with kwargs case (opt, kwargs) if isinstance(opt, str): - return _load_optimizer_from_string(opt, space, optimizer_kwargs=kwargs) # type: ignore + return _load_optimizer_from_string( + opt, # type: ignore + space, + optimizer_kwargs=kwargs, # type: 
ignore + ) # Mapping with a name case {"name": name, **_kwargs}: - return _load_optimizer_from_string(name, space, optimizer_kwargs=_kwargs) # type: ignore + return _load_optimizer_from_string( + name, + space, + optimizer_kwargs=_kwargs, # type: ignore + ) # Provided optimizer initializer case _ if callable(optimizer): keywords = extract_keyword_defaults(optimizer) - _optimizer = optimizer(space) + _optimizer = lambda rng_manager: optimizer(space, rng_manager=rng_manager) info = OptimizerInfo(name=optimizer.__name__, info=keywords) return _optimizer, info # Custom optimizer, we create it case CustomOptimizer(initialized=False): - _optimizer = optimizer.create(space) + _optimizer = lambda rng_manager=None: optimizer.create( # type: ignore + space, rng_manager=rng_manager + ) keywords = extract_keyword_defaults(optimizer.optimizer) info = OptimizerInfo( name=optimizer.name, info={**keywords, **optimizer.kwargs} diff --git a/neps/optimizers/algorithms.py b/neps/optimizers/algorithms.py index e80377e1e..79cd8e369 100644 --- a/neps/optimizers/algorithms.py +++ b/neps/optimizers/algorithms.py @@ -45,6 +45,7 @@ from neps.optimizers.utils.brackets import Bracket from neps.space import SearchSpace + from neps.state import RNGStateManager logger = logging.getLogger(__name__) @@ -56,6 +57,7 @@ def _bo( use_priors: bool, cost_aware: bool | Literal["log"], sample_prior_first: bool, + rng_manager: RNGStateManager, ignore_fidelity: bool = False, device: torch.device | str | None, reference_point: tuple[float, ...] | None = None, @@ -130,6 +132,7 @@ def _bo( prior=Prior.from_parameters(parameters) if use_priors is True else None, sample_prior_first=sample_prior_first, device=device, + rng_manager=rng_manager, reference_point=reference_point, ) @@ -143,6 +146,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 | PriorBandSampler | MOPriorSampler | Sampler, + rng_manager: RNGStateManager, bayesian_optimization_kick_in_point: int | float | None, sample_prior_first: bool | Literal["highest_fidelity"], # NOTE: This is the only argument to get a default, since it @@ -264,6 +268,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 rung_sizes=rung_sizes, is_multi_objective=multi_objective, mo_selector=mo_selector, + np_rng=rng_manager.np_rng, ) case "hyperband": @@ -277,6 +282,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 bracket_layouts=bracket_layouts, is_multi_objective=multi_objective, mo_selector=mo_selector, + np_rng=rng_manager.np_rng, ) case "asha": @@ -292,6 +298,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 eta=eta, is_multi_objective=multi_objective, mo_selector=mo_selector, + np_rng=rng_manager.np_rng, ) case "async_hb": @@ -308,6 +315,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 eta=eta, is_multi_objective=multi_objective, mo_selector=mo_selector, + np_rng=rng_manager.np_rng, ) case _: raise ValueError(f"Unknown bracket type: {bracket_type}") @@ -319,7 +327,9 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 case "uniform": _sampler = Sampler.uniform(ndim=encoder.ndim) case "prior": - _sampler = Prior.from_parameters(parameters) + _sampler = Prior.from_parameters( + parameters, + ) case "priorband": _sampler = PriorBandSampler( parameters=parameters, @@ -331,6 +341,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 early_stopping_rate if early_stopping_rate is not None else 0 ), fid_bounds=(fidelity.lower, fidelity.upper), + rng_manager=rng_manager, ) case "mopriorsampler": assert prior_centers is not None @@ -340,6 +351,7 @@ 
def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 prior_centers=prior_centers, confidence_values=prior_confidences, encoder=encoder, + rng_manager=rng_manager, ) case _: raise ValueError(f"Unknown sampler: {sampler}") @@ -386,6 +398,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 sampler=_sampler, sample_prior_first=sample_prior_first, create_brackets=create_brackets, + rng_manager=rng_manager, ) @@ -413,6 +426,7 @@ def random_search( *, use_priors: bool = False, ignore_fidelity: bool | Literal["highest fidelity"] = False, + rng_manager: RNGStateManager, ) -> RandomSearch: """A simple random search algorithm that samples configurations uniformly at random. @@ -476,12 +490,15 @@ def random_search( if use_priors else Uniform(ndim=len(parameters)) ), + rng_manager=rng_manager, ) def grid_search( pipeline_space: SearchSpace, ignore_fidelity: bool = False, # noqa: FBT001, FBT002 + *, + rng_manager: RNGStateManager, ) -> GridSearch: """A simple grid search algorithm which discretizes the search space and evaluates all possible configurations. @@ -504,7 +521,8 @@ def grid_search( ) return GridSearch( - configs_list=make_grid(pipeline_space, ignore_fidelity=ignore_fidelity) + configs_list=make_grid(pipeline_space, ignore_fidelity=ignore_fidelity), + rng_manager=rng_manager, ) @@ -518,6 +536,7 @@ def ifbo( device: torch.device | str | None = None, surrogate_path: str | Path | None = None, surrogate_version: str = "0.0.1", + rng_manager: RNGStateManager, ) -> IFBO: """A transformer that has been trained to predict loss curves of deep-learing models, used to guide the optimization procedure and select configurations which @@ -624,6 +643,7 @@ def ifbo( for cat_name, cat in space.categoricals.items() }, ), + rng_manager=rng_manager, ) @@ -634,6 +654,7 @@ def successive_halving( eta: int = 3, early_stopping_rate: int = 0, sample_prior_first: bool | Literal["highest_fidelity"] = False, + rng_manager: RNGStateManager, ) -> BracketOptimizer: """ A bandit-based optimization algorithm that uses a _fidelity_ parameter @@ -706,6 +727,7 @@ def successive_halving( # TODO: Implement this bayesian_optimization_kick_in_point=None, device=None, + rng_manager=rng_manager, ) @@ -715,6 +737,7 @@ def hyperband( eta: int = 3, sampler: Literal["uniform", "prior"] = "uniform", sample_prior_first: bool | Literal["highest_fidelity"] = False, + rng_manager: RNGStateManager, ) -> BracketOptimizer: """Another bandit-based optimization algorithm that uses a _fidelity_ parameter, very similar to [`successive_halving`][neps.optimizers.algorithms.successive_halving], @@ -770,6 +793,7 @@ def hyperband( # TODO: Implement this bayesian_optimization_kick_in_point=None, device=None, + rng_manager=rng_manager, ) @@ -780,6 +804,7 @@ def mo_hyperband( sampler: Literal["uniform", "prior"] = "uniform", sample_prior_first: bool | Literal["highest_fidelity"] = False, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + rng_manager: RNGStateManager, ) -> BracketOptimizer: """Multi-objective version of hyperband using the same candidate selection method as MOASHA. 
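As a side note on the pattern above: every factory in `algorithms.py` now takes a keyword-only `rng_manager`, so callers who construct optimizers directly must supply one. A minimal sketch, assuming a `SearchSpace` named `space` built elsewhere and the `hyperband` signature shown in this patch:

```python
# Sketch only: wiring an explicit RNG manager into an algorithm factory.
# `space` is assumed to be a neps SearchSpace constructed elsewhere.
from neps.optimizers.algorithms import hyperband
from neps.state.seed_snapshot import RNGStateManager

rng_manager = RNGStateManager.new_capture(seed=42)  # seeds python/numpy/torch generators
optimizer = hyperband(space, eta=3, rng_manager=rng_manager)
```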
@@ -796,6 +821,7 @@ def mo_hyperband( device=None, multi_objective=True, mo_selector=mo_selector, + rng_manager=rng_manager, ) @@ -806,6 +832,7 @@ def asha( early_stopping_rate: int = 0, sampler: Literal["uniform", "prior"] = "uniform", sample_prior_first: bool | Literal["highest_fidelity"] = False, + rng_manager: RNGStateManager, ) -> BracketOptimizer: """A bandit-based optimization algorithm that uses a _fidelity_ parameter, the _asynchronous_ version of @@ -860,6 +887,7 @@ def asha( # TODO: Implement this bayesian_optimization_kick_in_point=None, device=None, + rng_manager=rng_manager, ) @@ -871,6 +899,7 @@ def moasha( sampler: Literal["uniform", "prior"] = "uniform", sample_prior_first: bool | Literal["highest_fidelity"] = False, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + rng_manager: RNGStateManager, ) -> BracketOptimizer: return _bracket_optimizer( pipeline_space=space, @@ -884,6 +913,7 @@ def moasha( device=None, multi_objective=True, mo_selector=mo_selector, + rng_manager=rng_manager, ) @@ -893,6 +923,7 @@ def async_hb( eta: int = 3, sampler: Literal["uniform", "prior"] = "uniform", sample_prior_first: bool = False, + rng_manager: RNGStateManager, ) -> BracketOptimizer: """An _asynchronous_ version of [`hyperband`][neps.optimizers.algorithms.hyperband], where the brackets are run asynchronously, and the promotion rule is based on the @@ -944,6 +975,7 @@ def async_hb( # TODO: Implement this bayesian_optimization_kick_in_point=None, device=None, + rng_manager=rng_manager, ) @@ -954,6 +986,7 @@ def priorband( sample_prior_first: bool | Literal["highest_fidelity"] = False, base: Literal["successive_halving", "hyperband", "asha", "async_hb"] = "hyperband", bayesian_optimization_kick_in_point: int | float | None = None, + rng_manager: RNGStateManager, ) -> BracketOptimizer: """Priorband is also a bandit-based optimization algorithm that uses a _fidelity_, providing a general purpose sampling extension to other algorithms. It makes better @@ -1006,6 +1039,7 @@ def priorband( early_stopping_rate=0 if base in ("successive_halving", "asha") else None, bayesian_optimization_kick_in_point=bayesian_optimization_kick_in_point, device=None, + rng_manager=rng_manager, ) @@ -1017,6 +1051,7 @@ def bayesian_optimization( ignore_fidelity: bool = False, device: torch.device | str | None = None, reference_point: tuple[float, ...] | None = None, + rng_manager: RNGStateManager, ) -> BayesianOptimization: """Models the relation between hyperparameters in your `pipeline_space` and the results of `evaluate_pipeline` using bayesian optimization. 
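The reason a `torch.Generator` is threaded through these optimizers (rather than relying on the global torch seed) is that two generators holding the same state draw identical samples without touching global state. A minimal, self-contained illustration of that property:

```python
import torch

# Two generators seeded identically produce the same draws, independent of
# torch's global RNG -- this is the property the rng_manager relies on when
# it hands `torch_manual_rng` to the initial design and acquisition steps.
g1 = torch.Generator().manual_seed(123)
g2 = torch.Generator().manual_seed(123)
torch.manual_seed(0)  # perturbing the global seed does not affect g1/g2
assert torch.equal(torch.rand(4, generator=g1), torch.rand(4, generator=g2))
```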
@@ -1101,6 +1136,7 @@ def bayesian_optimization( sample_prior_first=False, ignore_fidelity=ignore_fidelity, reference_point=reference_point, + rng_manager=rng_manager, ) @@ -1112,6 +1148,7 @@ def pibo( device: torch.device | str | None = None, sample_prior_first: bool = False, ignore_fidelity: bool = False, + rng_manager: RNGStateManager, ) -> BayesianOptimization: """A modification of [`bayesian_optimization`][neps.optimizers.algorithms.bayesian_optimization] @@ -1162,6 +1199,7 @@ def pibo( use_priors=True, sample_prior_first=sample_prior_first, ignore_fidelity=ignore_fidelity, + rng_manager=rng_manager, ) @@ -1172,13 +1210,14 @@ def primo( sample_prior_first: bool | Literal["highest_fidelity"] = False, # noqa: ARG001 eta: int = 3, epsilon: float = 0.25, - prior_centers: Mapping[str, Mapping[str, Any]], + prior_centers: Mapping[str, Mapping[str, Any]] | None = None, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", prior_confidences: Mapping[str, Mapping[str, float]] | None = None, initial_design_size: int | Literal["ndim"] = "ndim", cost_aware: bool | Literal["log"] = False, # noqa: ARG001 device: torch.device | str | None = None, bo_scalar_weights: dict[str, float] | None = None, + rng_manager: RNGStateManager, ) -> PriMO: """Replaces the initial design of Bayesian optimization with MOASHA, then switches to BO after N*max_fidelity worth of evaluations, where N is the initial_design_size.""" @@ -1193,6 +1232,7 @@ def primo( sample_prior_first=False, early_stopping_rate=0, device=device, + rng_manager=rng_manager, ) parameters = space.searchables @@ -1232,6 +1272,7 @@ def primo( device=device, priors=_priors, epsilon=epsilon, + rng_manager=rng_manager, ) @@ -1248,9 +1289,15 @@ class CustomOptimizer: kwargs: Mapping[str, Any] = field(default_factory=dict) initialized: bool = False - def create(self, space: SearchSpace) -> AskFunction: + def create( + self, space: SearchSpace, rng_manager: RNGStateManager | None = None + ) -> AskFunction: assert not self.initialized, "Custom optimizer already initialized." - return self.optimizer(space, **self.kwargs) # type: ignore + args = dict(self.kwargs) + if rng_manager is not None: + args["rng_manager"] = rng_manager + + return self.optimizer(space, **args) # type: ignore def custom( diff --git a/neps/optimizers/bayesian_optimization.py b/neps/optimizers/bayesian_optimization.py index 1160a9fb2..d18ac0a2e 100644 --- a/neps/optimizers/bayesian_optimization.py +++ b/neps/optimizers/bayesian_optimization.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from neps.sampling import Prior from neps.space import ConfigEncoder, SearchSpace - from neps.state import BudgetInfo, Trial + from neps.state import BudgetInfo, RNGStateManager, Trial from neps.state.pipeline_eval import UserResultDict @@ -86,6 +86,9 @@ class BayesianOptimization: device: torch.device | None """The device to use for the optimization.""" + rng_manager: RNGStateManager + """The RNG state manager to use for seeding.""" + reference_point: tuple[float, ...] 
| None = None """The reference point to use for the multi-objective optimization.""" @@ -124,7 +127,7 @@ def __call__( # noqa: C901, PLR0912, PLR0915 # noqa: C901, PLR0912 encoder=self.encoder, sample_prior_first=self.sample_prior_first if n_sampled == 0 else False, sampler=self.prior if self.prior is not None else "uniform", - seed=None, # TODO: Seeding, however we need to avoid repeating configs + seed=self.rng_manager.torch_manual_rng, sample_size=self.n_initial_design, ) @@ -271,6 +274,7 @@ def __call__( # noqa: C901, PLR0912, PLR0915 # noqa: C901, PLR0912 cost_percentage_used=cost_percent, costs_on_log_scale=self.cost_aware == "log", hide_warnings=True, + seed=self.rng_manager.torch_manual_rng, ) configs = encoder.decode(candidates) diff --git a/neps/optimizers/bracket_optimizer.py b/neps/optimizers/bracket_optimizer.py index c357b8419..364e3dd24 100644 --- a/neps/optimizers/bracket_optimizer.py +++ b/neps/optimizers/bracket_optimizer.py @@ -38,6 +38,7 @@ from neps.space.parameters import Parameter from neps.state.optimizer import BudgetInfo from neps.state.pipeline_eval import UserResultDict + from neps.state.seed_snapshot import RNGStateManager from neps.state.trial import Trial @@ -132,6 +133,7 @@ def sample_config( trials: Mapping[str, Trial], budget_info: BudgetInfo | None, target_fidelity: int | float, + seed: torch.Generator | None = None, ) -> dict[str, Any]: """Samples a configuration using the GP model. @@ -186,6 +188,7 @@ def sample_config( cost_percentage_used=None, costs_on_log_scale=False, hide_warnings=True, + seed=seed, ) assert len(candidates) == N @@ -258,6 +261,9 @@ class BracketOptimizer: fid_name: str """The name of the fidelity in the space.""" + rng_manager: RNGStateManager + """The RNG state manager to use for seeding.""" + def __call__( # noqa: C901, PLR0912 self, trials: Mapping[str, Trial], @@ -356,6 +362,7 @@ def __call__( # noqa: C901, PLR0912 trials, budget_info=None, # TODO: budget_info not supported yet target_fidelity=target_fidelity, + seed=self.rng_manager.torch_manual_rng, ) config.update(space.constants) return SampledConfig(id=f"{nxt_id}_rung_{rung}", config=config) @@ -366,7 +373,9 @@ def __call__( # noqa: C901, PLR0912 # Otherwise, we proceed with the original sampler match self.sampler: case Sampler(): - config = self.sampler.sample_config(to=self.encoder) + config = self.sampler.sample_config( + to=self.encoder, seed=self.rng_manager.torch_manual_rng + ) config = { **config, **space.constants, diff --git a/neps/optimizers/grid_search.py b/neps/optimizers/grid_search.py index dcd24ef33..54c7407e5 100644 --- a/neps/optimizers/grid_search.py +++ b/neps/optimizers/grid_search.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from neps.state import BudgetInfo, Trial from neps.state.pipeline_eval import UserResultDict + from neps.state.seed_snapshot import RNGStateManager @dataclass @@ -18,6 +19,8 @@ class GridSearch: configs_list: list[dict[str, Any]] """The list of configurations to evaluate.""" + rng_manager: RNGStateManager + def __call__( self, trials: Mapping[str, Trial], diff --git a/neps/optimizers/ifbo.py b/neps/optimizers/ifbo.py index 6ab898a95..d641e8bb0 100755 --- a/neps/optimizers/ifbo.py +++ b/neps/optimizers/ifbo.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -import numpy as np import torch from neps.optimizers.models.ftpfn import ( @@ -23,6 +22,7 @@ if TYPE_CHECKING: from neps.state import BudgetInfo, Trial from neps.state.pipeline_eval import UserResultDict + from neps.state.seed_snapshot import 
RNGStateManager # NOTE: Ifbo was trained using 32 bit FTPFN_DTYPE = torch.float32 @@ -130,10 +130,12 @@ class IFBO: n_fidelity_bins: int """The number of bins to divide the fidelity domain into. - Each one will be treated as an individual fidelity level. """ + rng_manager: RNGStateManager + """The RNG state manager to use for seeding.""" + def __call__( self, trials: Mapping[str, Trial], @@ -180,7 +182,7 @@ def __call__( encoder=self.encoder, sample_prior_first=self.sample_prior_first, sampler="sobol" if self.prior is None else self.prior, - seed=None, # TODO: + seed=self.rng_manager.torch_manual_rng, sample_size=self.n_initial_design, ) @@ -216,14 +218,15 @@ def __call__( # objective_to_minimize # 2. The budget is the second column # 3. The budget is encoded between 1/max_fid and 1 - rng = np.random.RandomState(len(trials)) # Cast the a random budget index into the ftpfn budget domain horizon_increment = budget_domain.cast_one( - rng.randint(*budget_index_domain.bounds) + 1, + self.rng_manager.np_rng.integers(*budget_index_domain.bounds) + 1, frm=budget_index_domain, ) f_best = y.max().item() - threshold = f_best + (10 ** rng.uniform(-4, -1)) * (1 - f_best) + threshold = f_best + (10 ** self.rng_manager.np_rng.uniform(-4, -1)) * ( + 1 - f_best + ) def _mfpi_random(samples: torch.Tensor) -> torch.Tensor: # HACK: Because we are modifying the samples inplace, we do, @@ -254,7 +257,7 @@ def _mfpi_random(samples: torch.Tensor) -> torch.Tensor: (Sampler.uniform(ndim=sample_dims), 512), (Sampler.borders(ndim=sample_dims), 256), ], - seed=None, # TODO: Seeding + seed=self.rng_manager.torch_manual_rng, # A next step local sampling around best point found by initial_samplers local_search_sample_size=256, local_search_confidence=0.95, diff --git a/neps/optimizers/models/gp.py b/neps/optimizers/models/gp.py index 586ba371e..590182e1c 100644 --- a/neps/optimizers/models/gp.py +++ b/neps/optimizers/models/gp.py @@ -24,6 +24,7 @@ from neps.optimizers.acquisition import cost_cooled_acq, pibo_acquisition from neps.space.encoding import CategoricalToIntegerTransformer, ConfigEncoder +from neps.state.seed_snapshot import use_generator_globally from neps.utils.common import disable_warnings if TYPE_CHECKING: @@ -314,7 +315,7 @@ def fit_and_acquire_from_gp( costs: torch.Tensor | None = None, cost_percentage_used: float | None = None, costs_on_log_scale: bool = True, - seed: int | None = None, + seed: torch.Generator | None = None, n_candidates_required: int | None = None, num_restarts: int = 20, n_initial_start_points: int = 256, @@ -372,80 +373,78 @@ def fit_and_acquire_from_gp( The encoded next configuration(s) to evaluate. Use the encoder you provided to decode the configuration. """ - if seed is not None: - raise NotImplementedError("Seed is not implemented yet for gps") - - fit_gpytorch_mll(ExactMarginalLogLikelihood(likelihood=gp.likelihood, model=gp)) - - if prior: - if pibo_exp_term is None: - raise ValueError( - "If providing a prior, you must provide the `pibo_exp_term`." + with use_generator_globally(generator=seed): + fit_gpytorch_mll(ExactMarginalLogLikelihood(likelihood=gp.likelihood, model=gp)) + + if prior: + if pibo_exp_term is None: + raise ValueError( + "If providing a prior, you must provide the `pibo_exp_term`." 
+ ) + + acquisition = pibo_acquisition( + acquisition, + prior=prior, + prior_exponent=pibo_exp_term, + x_domain=encoder.domains, ) - acquisition = pibo_acquisition( - acquisition, - prior=prior, - prior_exponent=pibo_exp_term, - x_domain=encoder.domains, - ) - - if costs is not None: - if cost_percentage_used is None: - raise ValueError( - "If providing costs, you must provide `cost_percentage_used`." + if costs is not None: + if cost_percentage_used is None: + raise ValueError( + "If providing costs, you must provide `cost_percentage_used`." + ) + + # We simply ignore missing costs when training the cost GP. + missing_costs = torch.isnan(costs) + if missing_costs.any(): + raise ValueError( + "Must have at least some configurations reported with a cost" + " if using costs with a GP." + ) + + if missing_costs.any(): + not_missing_mask = ~missing_costs + x_train_cost = costs[not_missing_mask] + y_train_cost = x_train[not_missing_mask] + else: + x_train_cost = x_train + y_train_cost = costs + + if costs_on_log_scale: + transform = ChainedOutcomeTransform( + log=Log(), + standardize=Standardize(m=1), + ) + else: + transform = Standardize(m=1) + + cost_gp = make_default_single_obj_gp( + x_train_cost, + y_train_cost, + encoder=encoder, + y_transform=transform, ) - - # We simply ignore missing costs when training the cost GP. - missing_costs = torch.isnan(costs) - if missing_costs.any(): - raise ValueError( - "Must have at least some configurations reported with a cost" - " if using costs with a GP." + fit_gpytorch_mll( + ExactMarginalLogLikelihood(likelihood=cost_gp.likelihood, model=cost_gp) ) - - if missing_costs.any(): - not_missing_mask = ~missing_costs - x_train_cost = costs[not_missing_mask] - y_train_cost = x_train[not_missing_mask] - else: - x_train_cost = x_train - y_train_cost = costs - - if costs_on_log_scale: - transform = ChainedOutcomeTransform( - log=Log(), - standardize=Standardize(m=1), + acquisition = cost_cooled_acq( + acq_fn=acquisition, + model=cost_gp, + used_max_cost_total_percentage=cost_percentage_used, ) - else: - transform = Standardize(m=1) - cost_gp = make_default_single_obj_gp( - x_train_cost, - y_train_cost, - encoder=encoder, - y_transform=transform, - ) - fit_gpytorch_mll( - ExactMarginalLogLikelihood(likelihood=cost_gp.likelihood, model=cost_gp) - ) - acquisition = cost_cooled_acq( - acq_fn=acquisition, - model=cost_gp, - used_max_cost_total_percentage=cost_percentage_used, - ) + _n = n_candidates_required if n_candidates_required is not None else 1 - _n = n_candidates_required if n_candidates_required is not None else 1 - - candidates, _scores = optimize_acq( - acquisition, - encoder, - n_candidates_required=_n, - num_restarts=num_restarts, - n_intial_start_points=n_initial_start_points, - fixed_features=fixed_acq_features, - acq_options=acq_options, - maximum_allowed_categorical_combinations=maximum_allowed_categorical_combinations, - hide_warnings=hide_warnings, - ) + candidates, _scores = optimize_acq( + acquisition, + encoder, + n_candidates_required=_n, + num_restarts=num_restarts, + n_intial_start_points=n_initial_start_points, + fixed_features=fixed_acq_features, + acq_options=acq_options, + maximum_allowed_categorical_combinations=maximum_allowed_categorical_combinations, + hide_warnings=hide_warnings, + ) return candidates diff --git a/neps/optimizers/mopriors.py b/neps/optimizers/mopriors.py index f5a61bc3d..77b335a83 100644 --- a/neps/optimizers/mopriors.py +++ b/neps/optimizers/mopriors.py @@ -4,13 +4,12 @@ from dataclasses import dataclass from 
typing import TYPE_CHECKING, Any -import numpy as np - from neps.sampling.priors import Prior if TYPE_CHECKING: from neps.space import ConfigEncoder from neps.space.parameters import Parameter + from neps.state.seed_snapshot import RNGStateManager @dataclass @@ -24,6 +23,8 @@ class MOPriorSampler: encoder: ConfigEncoder + rng_manager: RNGStateManager + @classmethod def dists_from_centers_and_confidences( cls, @@ -62,6 +63,7 @@ def create_sampler( prior_centers: Mapping[str, Mapping[str, float]], confidence_values: Mapping[str, Mapping[str, float]], encoder: ConfigEncoder, + rng_manager: RNGStateManager, ) -> MOPriorSampler: """Creates a MOPriorSampler instance. @@ -85,6 +87,7 @@ def create_sampler( prior_dists=_priors, parameters=parameters, encoder=encoder, + rng_manager=rng_manager, ) def sample_config(self) -> dict[str, Any]: @@ -93,9 +96,11 @@ def sample_config(self) -> dict[str, Any]: Returns: The sampled configuration. """ - _prior_choice: Prior = np.random.choice( + _prior_choice: Prior = self.rng_manager.np_rng.choice( list(self.prior_dists.values()), ) # Sample a configuration from the chosen prior - return _prior_choice.sample_config(to=self.encoder) + return _prior_choice.sample_config( + to=self.encoder, seed=self.rng_manager.torch_manual_rng + ) diff --git a/neps/optimizers/primo.py b/neps/optimizers/primo.py index 748479676..4473f993a 100644 --- a/neps/optimizers/primo.py +++ b/neps/optimizers/primo.py @@ -33,6 +33,7 @@ from neps.space.encoding import ConfigEncoder from neps.state import BudgetInfo, Trial from neps.state.pipeline_eval import UserResultDict + from neps.state.seed_snapshot import RNGStateManager @dataclass @@ -57,6 +58,9 @@ class PriMO: fid_name: str """The name of the fidelity in the BracketOptimizer's search space.""" + rng_manager: RNGStateManager + """The RNG state manager to use for sampling.""" + scalarization_weights: dict[str, float] | None = None """The scalarization weights to use for the objectives for BO.""" @@ -123,7 +127,9 @@ def __call__( # noqa: C901, PLR0912 # Set scalarization weights if not set if self.scalarization_weights is None: - self.scalarization_weights = np.random.uniform(size=num_objectives) + self.scalarization_weights = self.rng_manager.np_rng.uniform( + size=num_objectives + ) self.scalarization_weights /= np.sum(self.scalarization_weights) # Scalarize trials.report.objective_to_minimize and remove fidelity @@ -249,11 +255,11 @@ def sample_using_bo( selected_prior = None if self.priors is not None: - selected_prior = np.random.choice( + selected_prior = self.rng_manager.np_rng.choice( list(self.priors.values()), ) - selected_prior = np.random.choice( + selected_prior = self.rng_manager.np_rng.choice( [selected_prior, None], p=[1 - self.epsilon, self.epsilon], ) @@ -292,6 +298,7 @@ def sample_using_bo( n_candidates_required=n_to_acquire, pibo_exp_term=pibo_exp_term, hide_warnings=True, + seed=self.rng_manager.torch_manual_rng, ) return encoder.decode_one(candidates) diff --git a/neps/optimizers/priorband.py b/neps/optimizers/priorband.py index 9d6d23e4b..6f3ae66b1 100644 --- a/neps/optimizers/priorband.py +++ b/neps/optimizers/priorband.py @@ -17,6 +17,7 @@ import pandas as pd from neps.space.parameters import Parameter + from neps.state.seed_snapshot import RNGStateManager @dataclass @@ -47,6 +48,9 @@ class PriorBandSampler: fid_bounds: tuple[int, int] | tuple[float, float] """The fidelity bounds.""" + rng_manager: RNGStateManager + """The RNG state manager to use for seeding.""" + def sample_config(self, table: pd.DataFrame, 
rung: int) -> dict[str, Any]: """Samples a configuration using the PriorBand algorithm. @@ -64,7 +68,9 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: ) max_rung = max(rung_sizes) - prior_dist = Prior.from_parameters(self.parameters) + prior_dist = Prior.from_parameters( + self.parameters, + ) # Below we will follow the "geomtric" spacing w_random = 1 / (1 + self.eta**rung) @@ -103,13 +109,19 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: or spent_one_sh_bracket_worth_of_fidelity is False or any_rung_with_eta_evals is False ): - policy = np.random.choice(["prior", "random"], p=[w_prior, w_random]) + policy = self.rng_manager.np_rng.choice( + ["prior", "random"], p=[w_prior, w_random] + ) match policy: case "prior": - config = prior_dist.sample_config(to=self.encoder) + config = prior_dist.sample_config( + to=self.encoder, seed=self.rng_manager.torch_manual_rng + ) case "random": _sampler = Sampler.uniform(ndim=self.encoder.ndim) - config = _sampler.sample_config(to=self.encoder) + config = _sampler.sample_config( + to=self.encoder, seed=self.rng_manager.torch_manual_rng + ) return config @@ -126,7 +138,10 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: # 2. Get the global incumbent, and build a prior distribution around it inc = completed.loc[completed["perf"].idxmin()]["config"] - inc_dist = Prior.from_parameters(self.parameters, center_values=inc) + inc_dist = Prior.from_parameters( + self.parameters, + center_values=inc, + ) # 3. Calculate a ratio score of how likely each of the top K configs are under # the prior and inc distribution, weighing them by their position in the top K @@ -148,7 +163,7 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: assert np.isclose(w_prior + w_inc + w_random, 1.0) # Now we use these weights to choose which sampling distribution to sample from - policy = np.random.choice( + policy = self.rng_manager.np_rng.choice( ["prior", "inc", "random"], p=[w_prior, w_inc, w_random], ) @@ -165,7 +180,7 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: parameters=self.parameters, mutation_rate=self.mutation_rate, std=self.mutation_std, - seed=None, + seed=self.rng_manager.torch_manual_rng, ) raise RuntimeError(f"Unknown policy: {policy}") @@ -179,9 +194,6 @@ def mutate_config( std: float = 0.25, seed: torch.Generator | None = None, ) -> dict[str, Any]: - if seed is not None: - raise NotImplementedError("Seed is not implemented yet.") - # This prior places a guassian on the numericals and places a 0 probability on the # current value of the categoricals. 
mutatation_prior = Prior.from_parameters( @@ -196,7 +208,7 @@ def mutate_config( ) config_encoder = ConfigEncoder.from_parameters(parameters) - mutant: dict[str, Any] = mutatation_prior.sample_config(to=config_encoder) + mutant: dict[str, Any] = mutatation_prior.sample_config(to=config_encoder, seed=seed) mutatant_selection = torch.rand(len(config), generator=seed) < mutation_rate return { diff --git a/neps/optimizers/random_search.py b/neps/optimizers/random_search.py index 0de6d28a3..154097671 100644 --- a/neps/optimizers/random_search.py +++ b/neps/optimizers/random_search.py @@ -11,6 +11,7 @@ from neps.space import ConfigEncoder, SearchSpace from neps.state import BudgetInfo, Trial from neps.state.pipeline_eval import UserResultDict + from neps.state.seed_snapshot import RNGStateManager @dataclass @@ -20,6 +21,7 @@ class RandomSearch: space: SearchSpace encoder: ConfigEncoder sampler: Sampler + rng_manager: RNGStateManager def __call__( self, @@ -29,7 +31,9 @@ def __call__( ) -> SampledConfig | list[SampledConfig]: n_trials = len(trials) _n = 1 if n is None else n - configs = self.sampler.sample(_n, to=self.encoder.domains) + configs = self.sampler.sample( + _n, to=self.encoder.domains, seed=self.rng_manager.torch_manual_rng + ) config_dicts = self.encoder.decode(configs) for config in config_dicts: diff --git a/neps/optimizers/utils/brackets.py b/neps/optimizers/utils/brackets.py index 690c5fab2..332c008cb 100644 --- a/neps/optimizers/utils/brackets.py +++ b/neps/optimizers/utils/brackets.py @@ -83,7 +83,9 @@ def calculate_hb_bracket_layouts( return rung_to_fidelity, bracket_layouts -def async_hb_sample_bracket_to_run(max_rung: int, eta: int) -> int: +def async_hb_sample_bracket_to_run( + max_rung: int, eta: int, np_rng: np.random.Generator +) -> int: # Sampling distribution derived from Appendix A (https://arxiv.org/abs/2003.10865) # Adapting the distribution based on the current optimization state # s \in [0, max_rung] and to with the denominator's constraint, we have K > s - 1 @@ -92,7 +94,7 @@ def async_hb_sample_bracket_to_run(max_rung: int, eta: int) -> int: K = max_rung bracket_probs = [eta ** (K - s) * (K + 1) / (K - s + 1) for s in range(max_rung + 1)] bracket_probs = np.array(bracket_probs) / sum(bracket_probs) - return int(np.random.choice(range(max_rung + 1), p=bracket_probs)) + return int(np_rng.choice(range(max_rung + 1), p=bracket_probs)) @dataclass @@ -143,6 +145,7 @@ def pareto_promotion_sync( k: int, exclude: Sequence[Hashable] = [], mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + np_rng: np.random.Generator, ) -> tuple[int, dict[str, Any], float] | None: """Selects the best configurations based on Pareto front for sync bracket optimizers. @@ -157,6 +160,7 @@ def pareto_promotion_sync( selector=mo_selector, contenders=contenders, k=k, + np_rng=np_rng, ) _idx, _rung = _df.index[0] row = _df.loc[(_idx, _rung)] @@ -173,6 +177,7 @@ def mo_selector( selector: Literal["nsga2", "epsnet"] = "epsnet", contenders: pd.DataFrame | None = None, k: int, + np_rng: np.random.Generator, ) -> pd.DataFrame: """Replaces top_k in single objective Bracket Optimizers with a multi-objective selector, which selects the best @@ -191,6 +196,7 @@ def mo_selector( return self.epsnet_selector( k=k, contenders=contenders, + np_rng=np_rng, ) case _: raise ValueError( @@ -210,6 +216,7 @@ def epsnet_selector( *, k: int, contenders: pd.DataFrame, + np_rng: np.random.Generator, ) -> pd.DataFrame: """Selects the best configurations based on epsilon-net sorting strategy. 
Uses Epsilon-net based sorting from SyneTune. @@ -220,6 +227,7 @@ def epsnet_selector( indices = nondominated_sort( X=mo_costs, max_items=k, + np_rng=np_rng, ) return contenders.iloc[indices] @@ -239,6 +247,8 @@ class Sync: mo_selector: Literal["nsga2", "epsnet"] = field(default="epsnet") """The selector to use for multi-objective optimization.""" + np_rng: np.random.Generator = field(default=False) + def __post_init__(self) -> None: if not all_unique(rung.value for rung in self.rungs): raise ValueError(f"Got rungs with duplicate values\n{self.rungs}") @@ -282,6 +292,7 @@ def next(self) -> BracketAction: mo_selector=self.mo_selector, k=1, exclude=upper.config_ids, + np_rng=self.np_rng, ) else: promote_config = lower.best_to_promote(exclude=upper.config_ids) @@ -305,6 +316,7 @@ def create_repeating( rung_sizes: dict[int, int], is_multi_objective: bool = False, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + np_rng: np.random.Generator, ) -> list[Sync]: """Create a list of brackets from the table. @@ -384,6 +396,7 @@ def create_repeating( ], is_multi_objective=is_multi_objective, mo_selector=mo_selector, + np_rng=np_rng, ) for bracket_data in all_N_bracket_datas ] @@ -411,6 +424,8 @@ class Async: mo_selector: Literal["nsga2", "epsnet"] = field(default="epsnet") """The selector to use for multi-objective optimization.""" + np_rng: np.random.Generator = field(default=False) + def __post_init__(self) -> None: self.rungs = sorted(self.rungs, key=lambda rung: rung.value) if any(rung.capacity is not None for rung in self.rungs): @@ -437,7 +452,9 @@ def next(self) -> BracketAction: continue # Not enough configs to promote yet if self.is_multi_objective: - best_k = lower_dropped.mo_selector(selector=self.mo_selector, k=k) + best_k = lower_dropped.mo_selector( + selector=self.mo_selector, k=k, np_rng=self.np_rng + ) else: best_k = lower_dropped.top_k(k) candidates = best_k.copy(deep=True) @@ -461,6 +478,7 @@ def create( eta: int, is_multi_objective: bool = False, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + np_rng: np.random.Generator, ) -> Async: return cls( rungs=[ @@ -474,6 +492,7 @@ def create( eta=eta, is_multi_objective=is_multi_objective, mo_selector=mo_selector, + np_rng=np_rng, ) @@ -506,6 +525,7 @@ def create_repeating( bracket_layouts: list[dict[int, int]], is_multi_objective: bool = False, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + np_rng: np.random.Generator, ) -> list[Hyperband]: """Create a list of brackets from the table. @@ -597,6 +617,7 @@ def create_repeating( ], is_multi_objective=is_multi_objective, mo_selector=mo_selector, + np_rng=np_rng, ) sh_brackets.append(bracket) @@ -648,6 +669,9 @@ class AsyncHyperband: is_multi_objective: bool = field(default=False) """Whether the BracketOptimizer is multi-objective or not.""" + np_rng: np.random.Generator = field(default=False) + """The random number generator to use for sampling.""" + mo_selector: Literal["nsga2", "epsnet"] = field(default="epsnet") """The selector to use for multi-objective optimization.""" @@ -671,6 +695,7 @@ def create( eta: int, is_multi_objective: bool = False, mo_selector: Literal["nsga2", "epsnet"] = "epsnet", + np_rng: np.random.Generator, ) -> AsyncHyperband: """Create an AsyncHyperbandBrackets from the table. 
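For context on why the brackets now carry an explicit `np.random.Generator`: the asynchronous-hyperband bracket choice previously went through the global `np.random`, which made runs irreproducible across workers. With a passed-in generator the same computation becomes deterministic. A small sketch mirroring `async_hb_sample_bracket_to_run` above (the `eta`/`max_rung` values are illustrative):

```python
import numpy as np

eta, max_rung = 3, 3
rng = np.random.default_rng(0)  # the generator normally supplied by RNGStateManager

# Same bracket-probability formula as async_hb_sample_bracket_to_run.
probs = np.array(
    [eta ** (max_rung - s) * (max_rung + 1) / (max_rung - s + 1)
     for s in range(max_rung + 1)],
    dtype=float,
)
probs /= probs.sum()
bracket_ix = int(rng.choice(range(max_rung + 1), p=probs))
```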
@@ -705,15 +730,17 @@ def create( eta=eta, is_multi_objective=is_multi_objective, mo_selector=mo_selector, + np_rng=np_rng, ) for layout in bracket_rungs ], eta=eta, + np_rng=np_rng, ) def next(self) -> BracketAction: # Each ASHA bracket always has an action, sample which to take - bracket_ix = async_hb_sample_bracket_to_run(self._max_rung, self.eta) + bracket_ix = async_hb_sample_bracket_to_run(self._max_rung, self.eta, self.np_rng) bracket = self.asha_brackets[bracket_ix] return bracket.next() diff --git a/neps/optimizers/utils/multiobjective/epsnet.py b/neps/optimizers/utils/multiobjective/epsnet.py index 780ea7413..3ca4e50f3 100644 --- a/neps/optimizers/utils/multiobjective/epsnet.py +++ b/neps/optimizers/utils/multiobjective/epsnet.py @@ -43,7 +43,11 @@ def pareto_efficient(X: np.ndarray) -> np.ndarray: return mask -def compute_epsilon_net(X: np.ndarray, dim: int | None = None) -> np.ndarray: +def compute_epsilon_net( + X: np.ndarray, + np_rng: np.random.Generator, + dim: int | None = None, +) -> np.ndarray: """ Outputs an order of the items in the provided array such that the items are spaced well. This means that after choosing a seed item, the next item is @@ -71,7 +75,7 @@ def compute_epsilon_net(X: np.ndarray, dim: int | None = None) -> np.ndarray: # Choose the seed item according to dim if dim is None: - initial_index = np.random.choice(X.shape[0]) + initial_index = np_rng.choice(X.shape[0]) else: initial_index = np.argmin(X, axis=0)[dim] @@ -104,6 +108,7 @@ def nondominated_sort( max_items: int | None = None, *, flatten: bool = True, + np_rng: np.random.Generator, ) -> list[int] | list[list[int]]: """ Performs a multi-objective sort by iteratively computing the Pareto front @@ -138,7 +143,7 @@ def nondominated_sort( # Compute the Pareto front and sort the items within pareto_mask = pareto_efficient(X[remaining]) pareto_front = remaining[pareto_mask] - pareto_order = compute_epsilon_net(X[pareto_front], dim=dim) + pareto_order = compute_epsilon_net(X[pareto_front], dim=dim, np_rng=np_rng) # Add order to the indices indices.append(pareto_front[pareto_order].tolist()) diff --git a/neps/runtime.py b/neps/runtime.py index 988c8f500..15ecccceb 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -42,7 +42,7 @@ NePSState, OnErrorPossibilities, OptimizationState, - SeedSnapshot, + RNGStateManager, Trial, UserResult, WorkerSettings, @@ -542,7 +542,7 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915 _trace_lock_path.touch(exist_ok=True) logger.info( - "Summary files can be found in the “summary” folder inside" + "Summary files are stored in the “summary” folder inside " "the root directory: %s", summary_dir, ) @@ -952,7 +952,7 @@ def _launch_ddp_runtime( def _launch_runtime( # noqa: PLR0913 *, evaluation_fn: Callable[..., EvaluatePipelineReturn], - optimizer: AskFunction, + optimizer_fn: Callable[[RNGStateManager], AskFunction], optimizer_info: OptimizerInfo, optimization_dir: Path, cost_to_spend: float | None, @@ -967,6 +967,7 @@ def _launch_runtime( # noqa: PLR0913 sample_batch_size: int | None, write_summary_to_disk: bool = True, worker_id: str | None = None, + rng_manager: RNGStateManager, ) -> None: default_report_values = _make_default_report_values( objective_value_on_error=objective_value_on_error, @@ -991,12 +992,12 @@ def _launch_runtime( # noqa: PLR0913 for _retry_count in range(MAX_RETRIES_CREATE_LOAD_STATE): try: - neps_state = NePSState.create_or_load( + neps_state: NePSState = NePSState.create_or_load( path=optimization_dir, load_only=False, 
optimizer_info=optimizer_info, optimizer_state=OptimizationState( - seed_snapshot=SeedSnapshot.new_capture(), + rng_state_manager=rng_manager, budget=( BudgetInfo( cost_to_spend=cost_to_spend, @@ -1023,6 +1024,7 @@ def _launch_runtime( # noqa: PLR0913 f" {MAX_RETRIES_CREATE_LOAD_STATE} attempts. Bailing!" " Please enable debug logging to see the errors that occured." ) + optimizer = optimizer_fn(neps_state._optimizer_state.rng_state_manager) settings = WorkerSettings( on_error=( diff --git a/neps/sampling/distributions.py b/neps/sampling/distributions.py index 371979b29..d3c7b18c1 100644 --- a/neps/sampling/distributions.py +++ b/neps/sampling/distributions.py @@ -10,7 +10,7 @@ from typing_extensions import override import torch -from torch.distributions import Distribution, Uniform, constraints +from torch.distributions import Categorical, Distribution, Uniform, constraints from torch.distributions.utils import broadcast_all from neps.space import Domain @@ -147,17 +147,68 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: self._validate_sample(value) return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 # type: ignore - @override # type: ignore - def rsample(self, sample_shape: torch.Size | None = None) -> torch.Tensor: + # Not overrided becuase we changed the signiture by passing + # generator through arguments + def sample( # noqa: D102 + self, + sample_shape: torch.Size | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + with torch.no_grad(): + return self.rsample(sample_shape=sample_shape, generator=generator) + + # Not overrided becuase we changed the signiture by passing + # generator through arguments + def rsample( # noqa: D102 + self, + sample_shape: torch.Size | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: if sample_shape is None: sample_shape = torch.Size([]) shape = self._extended_shape(sample_shape) p = torch.empty(shape, device=self.a.device).uniform_( - self._dtype_min_gt_0, self._dtype_max_lt_1 + self._dtype_min_gt_0, self._dtype_max_lt_1, generator=generator ) return self.icdf(p) +class CategoricalWithGenerator(Categorical): # noqa: D101 + def sample( # noqa: D102 + self, + sample_shape: torch.Size = torch.Size(), # noqa: B008 + generator: torch.Generator | None = None, + ) -> torch.Tensor: + if not isinstance(sample_shape, torch.Size): + sample_shape = torch.Size(sample_shape) + probs_2d = self.probs.reshape(-1, self._num_events) + samples_2d = torch.multinomial( + probs_2d, sample_shape.numel(), replacement=True, generator=generator + ).T + return samples_2d.reshape(self._extended_shape(sample_shape)) + + +class UniformWithGenerator(Uniform): # noqa: D101 + def sample( # noqa: D102 + self, + sample_shape: torch.Size = torch.Size(), # noqa: B008 + generator: torch.Generator | None = None, + ) -> torch.Tensor: + with torch.no_grad(): + return self.rsample(sample_shape=sample_shape, generator=generator) + + def rsample( # noqa: D102 + self, + sample_shape: torch.Size = torch.Size(), # noqa: B008 + generator: torch.Generator | None = None, + ) -> torch.Tensor: + shape = self._extended_shape(sample_shape) + rand = torch.rand( + shape, dtype=self.low.dtype, device=self.low.device, generator=generator + ) + return self.low + rand * (self.high - self.low) + + class TruncatedNormal(TruncatedStandardNormal): """Truncated Normal distribution. 
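The generator-aware distribution subclasses in `distributions.py` exist because torch's stock `Categorical.sample()` and `Uniform.rsample()` expose no `generator` argument; the patch therefore re-implements the draw via `torch.multinomial` / `torch.rand`, which do accept one. A minimal illustration of the categorical case (probabilities are placeholders):

```python
import torch

gen = torch.Generator().manual_seed(7)
probs = torch.tensor([0.2, 0.5, 0.3])

# torch.multinomial accepts an explicit generator, so repeated runs from the
# same generator state yield the same category indices.
idx = torch.multinomial(probs, num_samples=4, replacement=True, generator=gen)
```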
@@ -229,7 +280,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value) - self._log_scale # type: ignore -class UniformWithUpperBound(Uniform): +class UniformWithUpperBound(UniformWithGenerator): """Uniform distribution with upper bound inclusive. This is mostly a hack because torch's version of Uniform does not include diff --git a/neps/sampling/priors.py b/neps/sampling/priors.py index dfff8d612..6685f82ff 100644 --- a/neps/sampling/priors.py +++ b/neps/sampling/priors.py @@ -19,6 +19,7 @@ from neps.sampling.distributions import ( UNIT_UNIFORM_DIST, + CategoricalWithGenerator, TorchDistributionWithDomain, TruncatedNormal, ) @@ -140,6 +141,7 @@ def from_parameters( used for determining the strength of the prior. Values should be between 0 and 1. Overwrites whatever is set by default in the `.prior-confidence`. + seed: custom torch random number generator Returns: The prior distribution @@ -211,6 +213,7 @@ def from_domains_and_centers( domain. All confidence levels should be within the `[0, 1]` range. device: Device to place the tensors on for distributions. + seed: custom torch random number generator Returns: A prior for the search space. @@ -231,8 +234,10 @@ def from_domains_and_centers( # Uniform categorical n_cats = domain.cardinality assert n_cats is not None + # hack: torch.distributions.Categorical does not support generators, + # but in sample function we are using multinomial, which needs seed. dist = TorchDistributionWithDomain( - distribution=torch.distributions.Categorical( + distribution=CategoricalWithGenerator( probs=torch.ones(n_cats, device=device) / n_cats, validate_args=False, ), @@ -266,7 +271,7 @@ def from_domains_and_centers( weights[int(center_index)] = conf dist = TorchDistributionWithDomain( - distribution=torch.distributions.Categorical( + distribution=CategoricalWithGenerator( probs=weights, validate_args=False ), domain=domain, @@ -418,9 +423,6 @@ def sample( device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> torch.Tensor: - if seed is not None: - raise NotImplementedError("Seeding is not yet implemented.") - _out_shape = ( torch.Size((n, self.ncols)) if isinstance(n, int) @@ -430,7 +432,10 @@ def sample( out = torch.empty(_out_shape, device=device, dtype=dtype) for i, dist in enumerate(self.distributions): - out[..., i] = dist.distribution.sample(_n) + # the abstract torch.distributions.Distribution + # does not suppor generator in its signiture. 
+ # but whichever class we add here should + out[..., i] = dist.distribution.sample(_n, generator=seed) return Domain.translate(out, frm=self._distribution_domains, to=to, dtype=dtype) @@ -471,9 +476,6 @@ def sample( device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> torch.Tensor: - if seed is not None: - raise NotImplementedError("Seeding is not yet implemented.") - _n = ( torch.Size((n, self.ndim)) if isinstance(n, int) @@ -481,8 +483,8 @@ def sample( ) # Doesn't like integer dtypes if dtype is not None and dtype.is_floating_point: - samples = torch.rand(_n, device=device, dtype=dtype) + samples = torch.rand(_n, device=device, dtype=dtype, generator=seed) else: - samples = torch.rand(_n, device=device) + samples = torch.rand(_n, device=device, generator=seed) return Domain.translate(samples, frm=Domain.unit_float(), to=to, dtype=dtype) diff --git a/neps/sampling/samplers.py b/neps/sampling/samplers.py index c6f70b6af..5d41a69eb 100644 --- a/neps/sampling/samplers.py +++ b/neps/sampling/samplers.py @@ -167,9 +167,6 @@ def sample( device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> torch.Tensor: - if seed is not None: - raise NotImplementedError("Setting the seed is not supported yet") - # Sobol can only produce 2d tensors. To handle batches or arbitrary # dimensions, we get a count of the total number of samples needed # and reshape the output tensor to the desired shape, if needed. diff --git a/neps/state/__init__.py b/neps/state/__init__.py index 2daa057f3..34bd24790 100644 --- a/neps/state/__init__.py +++ b/neps/state/__init__.py @@ -1,7 +1,7 @@ from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState from neps.state.pipeline_eval import EvaluatePipelineReturn, UserResult, evaluate_trial -from neps.state.seed_snapshot import SeedSnapshot +from neps.state.seed_snapshot import RNGStateManager from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.state.trial import State, Trial @@ -12,7 +12,7 @@ "NePSState", "OnErrorPossibilities", "OptimizationState", - "SeedSnapshot", + "RNGStateManager", "State", "Trial", "UserResult", diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index cd8e88058..a2b5c74f7 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -421,8 +421,6 @@ def _sample_trial( with self._optimizer_state_path.open("rb") as f: opt_state: OptimizationState = pickle.load(f) # noqa: S301 - opt_state.seed_snapshot.set_as_global_seed_state() - assert callable(optimizer) if opt_state.budget is not None: # NOTE: All other values of budget are ones that should remain @@ -480,7 +478,7 @@ def _sample_trial( sampled_trials.append(trial) opt_state.shared_state = shared_state - opt_state.seed_snapshot.recapture() + opt_state.rng_state_manager.capture_local() with self._optimizer_state_path.open("wb") as f: pickle.dump(opt_state, f, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/neps/state/optimizer.py b/neps/state/optimizer.py index 9f95191f9..04d0421b4 100644 --- a/neps/state/optimizer.py +++ b/neps/state/optimizer.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from neps.state.seed_snapshot import SeedSnapshot + from neps.state.seed_snapshot import RNGStateManager @dataclass @@ -31,7 +31,7 @@ class OptimizationState: budget: BudgetInfo | None """Information regarind the budget used by the optimization trajectory.""" - seed_snapshot: SeedSnapshot + rng_state_manager: RNGStateManager """The 
state of the random number generators at the time of the last sample.""" shared_state: dict[str, Any] | None diff --git a/neps/state/seed_snapshot.py b/neps/state/seed_snapshot.py index 4a26370b0..c3457e619 100644 --- a/neps/state/seed_snapshot.py +++ b/neps/state/seed_snapshot.py @@ -4,14 +4,14 @@ import contextlib import random +from collections.abc import Generator from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeAlias import numpy as np +import torch if TYPE_CHECKING: - import torch - NP_RNG_STATE: TypeAlias = tuple[str, np.ndarray, int, int, float] PY_RNG_STATE: TypeAlias = tuple[int, tuple[int, ...], int | None] TORCH_RNG_STATE: TypeAlias = torch.Tensor @@ -19,96 +19,158 @@ @dataclass -class SeedSnapshot: +class RNGStateManager: """State of the global rng. - Primarly enables storing of the rng state to disk using a binary format - native to each library, allowing for potential version mistmatches between + Primarly used as a seed manager, having all the RNGs needed and enables storing of + the rng state to disk using a binary format native to each library, + allowing for potential version mistmatches between processes loading the state, as long as they can read the binary format. """ - np_rng: NP_RNG_STATE - py_rng: PY_RNG_STATE - torch_rng: TORCH_RNG_STATE | None - torch_cuda_rng: TORCH_CUDA_RNG_STATE | None + np_rng_state: NP_RNG_STATE + py_rng_state: PY_RNG_STATE + torch_rng_state: TORCH_RNG_STATE + torch_cuda_rng_state: TORCH_CUDA_RNG_STATE | None + + # Not appear in the dump + py_rng: random.Random + np_rng: np.random.Generator + torch_manual_rng: torch.Generator + torch_cuda_rng: list[torch.Generator] | None = None @classmethod - def new_capture(cls) -> SeedSnapshot: + def new_capture( + cls, + seed: int | None = None, + np_rng: np.random.Generator | None = None, + torch_rng: torch.Generator | None = None, + ) -> RNGStateManager: """Current state of the global rng. Takes a snapshot, including cloning or copying any arrays, tensors, etc. 
""" - self = cls(None, None, None, None) # type: ignore - self.recapture() + self = cls( + np_rng_state=None, # type: ignore + py_rng_state=None, # type: ignore + torch_rng_state=None, # type: ignore + torch_cuda_rng_state=None, + py_rng=random.Random(seed), + np_rng=np_rng or np.random.default_rng(seed), + torch_manual_rng=torch_rng + or torch.Generator().manual_seed(seed or torch.seed()), + torch_cuda_rng=None, + ) + if torch.cuda.is_available(): + self.torch_cuda_rng = [ + torch.Generator(device=f"cuda:{i}").manual_seed( + (seed or seed or torch.seed()) + i + ) + for i in range(torch.cuda.device_count()) + ] + + self.capture_local() return self - def recapture(self) -> None: - """Reread the state of the global rng into this snapshot.""" - # https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html - - self.py_rng = random.getstate() - - np_keys = np.random.get_state(legacy=True) - assert np_keys[0] == "MT19937" # type: ignore - self.np_rng = (np_keys[0], np_keys[1].copy(), *np_keys[2:]) # type: ignore - - with contextlib.suppress(Exception): - import torch - - self.torch_rng = torch.random.get_rng_state().clone() - torch_cuda_keys: list[torch.Tensor] | None = None - if torch.cuda.is_available(): - torch_cuda_keys = [c.clone() for c in torch.cuda.get_rng_state_all()] - self.torch_cuda_rng = torch_cuda_keys - - def set_as_global_seed_state(self) -> None: - """Set the global rng to the given state.""" - np.random.set_state(self.np_rng) - random.setstate(self.py_rng) - - if self.torch_rng is not None or self.torch_cuda_rng is not None: - import torch - - if self.torch_rng is not None: - torch.random.set_rng_state(self.torch_rng) - - if self.torch_cuda_rng is not None and torch.cuda.is_available(): - torch.cuda.set_rng_state_all(self.torch_cuda_rng) - - def __eq__(self, other: Any, /) -> bool: # noqa: PLR0911 - if not isinstance(other, SeedSnapshot): + def capture_local(self) -> None: + """Capture the current state of the local rngs.""" + # Capture Python RNG state + self.py_rng_state = self.py_rng.getstate() + + # Capture NumPy RNG state + self.np_rng_state = self.np_rng.bit_generator.state + + # Capture PyTorch CPU generator state + self.torch_rng_state = self.torch_manual_rng.get_state().clone() + + # Capture PyTorch CUDA generators state + if self.torch_cuda_rng is not None: + self.torch_cuda_rng_state = [ + g.get_state().clone() for g in self.torch_cuda_rng + ] + + def __getstate__(self) -> dict[str, Any]: + return { + "np_rng_state": self.np_rng_state, + "py_rng_state": self.py_rng_state, + "torch_rng_state": self.torch_rng_state, + "torch_cuda_rng_state": self.torch_cuda_rng_state, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.np_rng_state = state["np_rng_state"] + self.py_rng_state = state["py_rng_state"] + self.torch_rng_state = state["torch_rng_state"] + self.torch_cuda_rng_state = state.get("torch_cuda_rng_state") + + self.py_rng = random.Random() # or whatever Random object you want + self.py_rng.setstate(self.py_rng_state) + + # Restore NumPy RNG + self.np_rng = np.random.default_rng() # create a new Generator + self.np_rng.bit_generator.state = self.np_rng_state + + # Restore PyTorch CPU generator + self.torch_manual_rng = torch.Generator() + self.torch_manual_rng.set_state(self.torch_rng_state) + + # Restore PyTorch CUDA generators + if self.torch_cuda_rng_state is not None: + self.torch_cuda_rng = [ + torch.Generator(device=f"cuda:{i}").set_state(state) + for i, state in enumerate(self.torch_cuda_rng_state) + ] + + def 
__eq__(self, other: Any, /) -> bool: + if not isinstance(other, RNGStateManager): return False - if not (self.py_rng == other.py_rng): + if not (self.py_rng_state == other.py_rng_state): return False - if not ( - self.np_rng[0] == other.np_rng[0] - and self.np_rng[2] == other.np_rng[2] - and self.np_rng[3] == other.np_rng[3] - and self.np_rng[4] == other.np_rng[4] - ): + if not self.np_rng_state == other.np_rng_state: return False - if not np.array_equal(self.np_rng[1], other.np_rng[1]): - return False - - if self.torch_rng is not None and other.torch_rng is not None: + if self.torch_rng_state is not None and other.torch_rng_state is not None: import torch - if not torch.equal(self.torch_rng, other.torch_rng): + if not torch.equal(self.torch_rng_state, other.torch_rng_state): return False - if self.torch_cuda_rng is not None and other.torch_cuda_rng is not None: + if ( + self.torch_cuda_rng_state is not None + and other.torch_cuda_rng_state is not None + ): import torch if not all( torch.equal(a, b) - for a, b in zip(self.torch_cuda_rng, other.torch_cuda_rng, strict=False) + for a, b in zip( + self.torch_cuda_rng_state, other.torch_cuda_rng_state, strict=False + ) ): return False - if not isinstance(self.torch_rng, type(other.torch_rng)): - return False - - return isinstance(self.torch_cuda_rng, type(other.torch_cuda_rng)) + return True + + +@contextlib.contextmanager +def use_generator_globally( # noqa: D103 + generator: torch.Generator, + device: str = "cpu", +) -> Generator[Any, Any, Any]: + if device == "cpu": + old_state = torch.get_rng_state() + torch.set_rng_state(generator.get_state()) + else: + old_state = torch.cuda.get_rng_state_all() + torch.cuda.set_rng_state_all( + [generator.get_state() for _ in range(torch.cuda.device_count())] + ) + try: + yield + finally: + if device == "cpu": + torch.set_rng_state(old_state) + else: + torch.cuda.set_rng_state_all(old_state) diff --git a/neps/utils/trial_io.py b/neps/utils/trial_io.py index 8a9c948eb..145ee5e21 100644 --- a/neps/utils/trial_io.py +++ b/neps/utils/trial_io.py @@ -7,8 +7,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from neps.state.neps_state import TrialRepo - if TYPE_CHECKING: from neps.state.trial import Trial @@ -26,6 +24,8 @@ def load_trials_from_pickle( each containing the trial configuration and its corresponding report as a dictionary. 
""" + from neps.state.neps_state import TrialRepo + if isinstance(root_dir, str): root_dir = Path(root_dir) trials: ValuesView[Trial] = ( diff --git a/neps_examples/basic_usage/hyperparameters.py b/neps_examples/basic_usage/hyperparameters.py index 0f3fdc898..8980e7f1d 100644 --- a/neps_examples/basic_usage/hyperparameters.py +++ b/neps_examples/basic_usage/hyperparameters.py @@ -23,11 +23,12 @@ def evaluate_pipeline(float1, float2, categorical, integer1, integer2): integer2=neps.Integer(lower=1, upper=1000, log=True), ) -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) neps.run( evaluate_pipeline=evaluate_pipeline, pipeline_space=pipeline_space, root_directory="results/hyperparameters_example", evaluations_to_spend=30, worker_id=f"worker_1-{socket.gethostname()}-{os.getpid()}", + optimizer="bayesian_optimization" ) diff --git a/tests/test_runtime/test_default_report_values.py b/tests/test_runtime/test_default_report_values.py index 45ad5e5cc..b974105f6 100644 --- a/tests/test_runtime/test_default_report_values.py +++ b/tests/test_runtime/test_default_report_values.py @@ -13,7 +13,7 @@ NePSState, OnErrorPossibilities, OptimizationState, - SeedSnapshot, + RNGStateManager, Trial, WorkerSettings, ) @@ -25,7 +25,7 @@ def neps_state(tmp_path: Path) -> NePSState: path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), optimizer_state=OptimizationState( - budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={} + budget=None, rng_state_manager=RNGStateManager.new_capture(), shared_state={} ), ) @@ -33,7 +33,10 @@ def neps_state(tmp_path: Path) -> NePSState: def test_default_values_on_error( neps_state: NePSState, ) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues( @@ -86,7 +89,9 @@ def eval_function(*args, **kwargs) -> float: def test_default_values_on_not_specified( neps_state: NePSState, ) -> None: - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues( @@ -137,7 +142,9 @@ def eval_function(*args, **kwargs) -> float: def test_default_value_objective_to_minimize_curve_take_objective_to_minimize_value( neps_state: NePSState, ) -> None: - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues( diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index 0549f87b5..8a1654ca1 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -17,7 +17,7 @@ NePSState, OnErrorPossibilities, OptimizationState, - SeedSnapshot, + RNGStateManager, Trial, WorkerSettings, ) @@ -30,7 +30,7 @@ def neps_state(tmp_path: Path) -> NePSState: optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), optimizer_state=OptimizationState( budget=None, - seed_snapshot=SeedSnapshot.new_capture(), 
+ rng_state_manager=RNGStateManager.new_capture(), shared_state=None, ), ) @@ -44,7 +44,9 @@ def test_worker_raises_when_error_in_self( neps_state: NePSState, on_error: OnErrorPossibilities, ) -> None: - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) settings = WorkerSettings( on_error=on_error, # <- Highlight default_report_values=DefaultReportValues(), @@ -85,7 +87,9 @@ def eval_function(*args, **kwargs) -> float: def test_worker_raises_when_error_in_other_worker(neps_state: NePSState) -> None: - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) settings = WorkerSettings( on_error=OnErrorPossibilities.RAISE_ANY_ERROR, # <- Highlight default_report_values=DefaultReportValues(), @@ -146,7 +150,9 @@ def test_worker_does_not_raise_when_error_in_other_worker( neps_state: NePSState, on_error: OnErrorPossibilities, ) -> None: - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) settings = WorkerSettings( on_error=on_error, # <- Highlight default_report_values=DefaultReportValues(), diff --git a/tests/test_runtime/test_save_evaluation_results.py b/tests/test_runtime/test_save_evaluation_results.py index b0db8b163..4d483aade 100644 --- a/tests/test_runtime/test_save_evaluation_results.py +++ b/tests/test_runtime/test_save_evaluation_results.py @@ -14,7 +14,7 @@ NePSState, OnErrorPossibilities, OptimizationState, - SeedSnapshot, + RNGStateManager, Trial, WorkerSettings, ) @@ -26,13 +26,15 @@ def neps_state(tmp_path: Path) -> NePSState: path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), optimizer_state=OptimizationState( - budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={} + budget=None, rng_state_manager=RNGStateManager.new_capture(), shared_state={} ), ) def test_async_happy_path_changes_state(neps_state: NePSState) -> None: - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues( diff --git a/tests/test_runtime/test_stopping_criterion.py b/tests/test_runtime/test_stopping_criterion.py index 67f15365f..7495a2a9b 100644 --- a/tests/test_runtime/test_stopping_criterion.py +++ b/tests/test_runtime/test_stopping_criterion.py @@ -14,7 +14,7 @@ NePSState, OnErrorPossibilities, OptimizationState, - SeedSnapshot, + RNGStateManager, Trial, WorkerSettings, ) @@ -27,7 +27,7 @@ def neps_state(tmp_path: Path) -> NePSState: optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), optimizer_state=OptimizationState( budget=None, - seed_snapshot=SeedSnapshot.new_capture(), + rng_state_manager=RNGStateManager.new_capture(), shared_state=None, ), ) @@ -36,7 +36,10 @@ def neps_state(tmp_path: Path) -> NePSState: def test_max_evaluations_total_stopping_criterion( neps_state: NePSState, ) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, 
default_report_values=DefaultReportValues(), @@ -89,7 +92,10 @@ def eval_function(*args, **kwargs) -> float: def test_worker_evaluations_total_stopping_criterion( neps_state: NePSState, ) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), @@ -151,7 +157,10 @@ def eval_function(*args, **kwargs) -> float: def test_include_in_progress_evaluations_towards_maximum_with_work_eval_count( neps_state: NePSState, ) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), @@ -206,7 +215,10 @@ def eval_function(*args, **kwargs) -> float: def test_max_cost_total(neps_state: NePSState) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), @@ -255,7 +267,10 @@ def eval_function(*args, **kwargs) -> dict: def test_worker_cost_total(neps_state: NePSState) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), @@ -312,7 +327,10 @@ def eval_function(*args, **kwargs) -> dict: def test_worker_wallclock_time(neps_state: NePSState) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), @@ -368,7 +386,10 @@ def eval_function(*args, **kwargs) -> float: def test_max_worker_evaluation_time(neps_state: NePSState) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), @@ -425,7 +446,10 @@ def eval_function(*args, **kwargs) -> float: def test_max_evaluation_time_global(neps_state: NePSState) -> None: - optimizer = random_search(pipeline_space=SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + pipeline_space=SearchSpace({"a": Float(0, 1)}), + rng_manager=RNGStateManager.new_capture(), + ) settings = WorkerSettings( on_error=OnErrorPossibilities.IGNORE, default_report_values=DefaultReportValues(), diff --git a/tests/test_runtime/test_worker_creation.py b/tests/test_runtime/test_worker_creation.py index 4b741640a..355f00dad 100644 --- a/tests/test_runtime/test_worker_creation.py +++ b/tests/test_runtime/test_worker_creation.py @@ -13,7 +13,7 @@ WorkerSettings, ) from neps.space import Float, SearchSpace -from neps.state import 
NePSState, OptimizationState, SeedSnapshot +from neps.state import NePSState, OptimizationState, RNGStateManager @pytest.fixture @@ -22,7 +22,7 @@ def neps_state(tmp_path: Path) -> NePSState: path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), optimizer_state=OptimizationState( - budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={} + budget=None, rng_state_manager=RNGStateManager.new_capture(), shared_state={} ), ) @@ -47,7 +47,9 @@ def eval_fn(config: dict) -> float: return 1.0 test_worker_id = "my_worker_123" - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) worker = DefaultWorker.new( state=neps_state, @@ -80,7 +82,9 @@ def test_create_worker_auto_id(neps_state: NePSState) -> None: def eval_fn(config: dict) -> float: return 1.0 - optimizer = random_search(SearchSpace({"a": Float(0, 1)})) + optimizer = random_search( + SearchSpace({"a": Float(0, 1)}), rng_manager=RNGStateManager.new_capture() + ) worker = DefaultWorker.new( state=neps_state, diff --git a/tests/test_state/test_filebased_neps_state.py b/tests/test_state/test_filebased_neps_state.py index 5572abb4d..652d5555b 100644 --- a/tests/test_state/test_filebased_neps_state.py +++ b/tests/test_state/test_filebased_neps_state.py @@ -16,7 +16,7 @@ from neps.state.err_dump import ErrDump from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState -from neps.state.seed_snapshot import SeedSnapshot +from neps.state.seed_snapshot import RNGStateManager @fixture @@ -28,7 +28,7 @@ def optimizer_state( ) -> OptimizationState: return OptimizationState( budget=budget_info, - seed_snapshot=SeedSnapshot.new_capture(), + rng_state_manager=RNGStateManager.new_capture(), shared_state=shared_state, ) @@ -82,7 +82,7 @@ def test_create_or_load_with_load_filebased_neps_state( # was passed in. 
different_state = OptimizationState( budget=BudgetInfo(cost_to_spend=20, used_cost_budget=10), - seed_snapshot=SeedSnapshot.new_capture(), + rng_state_manager=RNGStateManager.new_capture(), shared_state=None, ) neps_state2 = NePSState.create_or_load( diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index bf7512fae..9a2bbe76f 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -26,7 +26,9 @@ Integer, SearchSpace, ) -from neps.state import BudgetInfo, NePSState, OptimizationState, SeedSnapshot +from neps.state import BudgetInfo, NePSState, OptimizationState, RNGStateManager + +rng_manager = RNGStateManager.new_capture() @case @@ -179,7 +181,7 @@ def case_neps_state_filebased( optimizer_info=optimizer_info, optimizer_state=OptimizationState( budget=cost_to_spend, - seed_snapshot=SeedSnapshot.new_capture(), + rng_state_manager=rng_manager, shared_state=shared_state, ), ) @@ -191,7 +193,8 @@ def test_sample_trial( optimizer_and_key_and_search_space: tuple[AskFunction, str, SearchSpace], capsys, ) -> None: - optimizer, key, search_space = optimizer_and_key_and_search_space + optimizer_fn, key, search_space = optimizer_and_key_and_search_space + optimizer = optimizer_fn(rng_manager) assert neps_state.lock_and_read_trials() == {} assert neps_state.lock_and_get_next_pending_trial() is None @@ -233,7 +236,7 @@ def test_optimizers_work_roughly( optimizer_and_key_and_search_space: tuple[AskFunction, str, SearchSpace], ) -> None: opt, key, search_space = optimizer_and_key_and_search_space - ask_and_tell = AskAndTell(opt) + ask_and_tell = AskAndTell(opt(rng_manager)) for _ in range(20): trial = ask_and_tell.ask() diff --git a/tests/test_state/test_rng.py b/tests/test_state/test_rng.py index 2605433a4..fb2137eb1 100644 --- a/tests/test_state/test_rng.py +++ b/tests/test_state/test_rng.py @@ -1,38 +1,129 @@ from __future__ import annotations +import pickle import random from collections.abc import Callable from pathlib import Path +from typing import Any import numpy as np import pytest import torch -from neps.state.seed_snapshot import SeedSnapshot +from neps.state.seed_snapshot import RNGStateManager, use_generator_globally @pytest.mark.parametrize( - "make_ints", + ("make_ints", "rng_factory"), [ - lambda: [random.randint(0, 100) for _ in range(10)], - lambda: list(np.random.randint(0, 100, (10,))), - lambda: list(torch.randint(0, 100, (10,))), + # Python RNG + ( + lambda rng: [rng.randint(0, 100) for _ in range(10)], + lambda seed: random.Random(seed), + ), + # NumPy Generator + ( + lambda rng: list(rng.integers(0, 100, size=10)), + lambda seed: np.random.default_rng(seed), + ), + # PyTorch Generator + ( + lambda rng: list(torch.randint(0, 100, (10,), generator=rng)), + lambda seed: torch.Generator().manual_seed(seed), + ), ], ) def test_randomstate_consistent( - tmp_path: Path, make_ints: Callable[[], list[int]] -) -> None: - random.seed(42) - np.random.seed(42) - torch.manual_seed(42) + tmp_path: Path, + make_ints: Callable[[Any], list[int]], + rng_factory: Callable[[int], Any], +): + seed = 230 - seed_dir = tmp_path / "seed_dir" - seed_dir.mkdir(exist_ok=True, parents=True) + rng1 = rng_factory(seed) + integers_1 = make_ints(rng1) - seed_state = SeedSnapshot.new_capture() - integers_1 = make_ints() + rng2 = rng_factory(seed) + integers_2 = make_ints(rng2) - seed_state.set_as_global_seed_state() - - integers_2 = make_ints() assert integers_1 == integers_2 + + rng3 = rng_factory(1111) + integers_3 = make_ints(rng3) + + 
assert integers_1 != integers_3 + + +@pytest.mark.parametrize("seed", [0, 42, 999]) +def test_capture_reproducibility(seed): + """Ensure new_capture gives reproducible RNG states.""" + mgr1 = RNGStateManager.new_capture( + seed, torch_rng=torch.Generator().manual_seed(seed) + ) + mgr2 = RNGStateManager.new_capture( + seed, torch_rng=torch.Generator().manual_seed(seed) + ) + + # Check that the Python RNG produces the same values + py_vals1 = [mgr1.py_rng.randint(0, 100) for _ in range(10)] + py_vals2 = [mgr2.py_rng.randint(0, 100) for _ in range(10)] + assert py_vals1 == py_vals2 + + # Check NumPy RNG + np_vals1 = mgr1.np_rng.integers(0, 100, size=10).tolist() + np_vals2 = mgr2.np_rng.integers(0, 100, size=10).tolist() + assert np_vals1 == np_vals2 + + # Check Torch CPU RNG + t_vals1 = torch.randint(0, 100, (10,), generator=mgr1.torch_manual_rng) + t_vals2 = torch.randint(0, 100, (10,), generator=mgr2.torch_manual_rng) + assert torch.equal(t_vals1, t_vals2) + + +@pytest.mark.parametrize("seed", [123]) +def test_getsetstate_pickling(seed): + """Ensure pickling and unpickling preserves RNG states.""" + mgr = RNGStateManager.new_capture(seed) + dumped = pickle.dumps(mgr) + loaded = pickle.loads(dumped) # noqa: S301 + + assert mgr == loaded + + # Check reproducibility after restoring + val_before = mgr.py_rng.randint(0, 100) + val_after = loaded.py_rng.randint(0, 100) + assert val_before == val_after + + +@pytest.mark.parametrize("seed", [7]) +def test_context_manager_sets_global_rng(seed): + """Ensure use_generator_globally temporarily sets global RNG.""" + gen = torch.Generator().manual_seed(seed) + + torch.manual_seed(999) + torch.randint(0, 100, (1,)).item() + + with use_generator_globally(gen): + val_inside = torch.randint(0, 100, (1,)).item() + val_after = torch.randint(0, 100, (1,)).item() + + # Inside the context, the value should match gen's first draw + gen_check = torch.randint(0, 100, (1,), generator=gen).item() + assert val_inside == gen_check + + # Outside, the global RNG is restored + assert val_after != val_inside + + +@pytest.mark.parametrize("seed", [42]) +def test_cuda_generator_if_available(seed): + """Ensure CUDA generator captures work if CUDA is available.""" + if torch.cuda.is_available(): + mgr = RNGStateManager.new_capture(seed) + assert mgr.torch_cuda_rng is not None + assert len(mgr.torch_cuda_rng) == torch.cuda.device_count() + # Generate a tensor on GPU with the captured generator + t_gpu = torch.randint( + 0, 100, (10,), generator=mgr.torch_cuda_rng[0], device="cuda" + ) + assert t_gpu.device.type == "cuda" From f1bfa982e310757eeb3002f40ac64d38fc863c88 Mon Sep 17 00:00:00 2001 From: Nastaran Alipour Date: Fri, 14 Nov 2025 10:05:43 +0100 Subject: [PATCH 2/2] fix docs --- neps/api.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/neps/api.py b/neps/api.py index 8602b8524..2d0da5c5f 100644 --- a/neps/api.py +++ b/neps/api.py @@ -434,16 +434,16 @@ def __call__( numpy_rng: An optional numpy random number generator. torch_rng: An optional torch random number generator. - ??? tip "RNG Priority and Control" - When a previously created NePS state is loaded (by reusing the same root_directory), - all RNGs are reconstructed from the saved state. In this case, the parameters seed, numpy_rng, - and torch_rng are ignored, because the experiment continues from the exact stored RNG states. - If you provide numpy_rng or torch_rng explicitly, these generators take precedence and are used directly. 
- The seed parameter is not used to create new RNGs in this situation. - The overall priority is: - 1. Saved NePS state — highest priority - 2. User-provided numpy_rng and torch_rng - 3. seed — used only when no RNG objects and no saved state are available + !!! tip "RNG Priority and Control" + When a previously created NePS state is loaded (by reusing the same root_directory), + all RNGs are reconstructed from the saved state. In this case, the parameters seed, numpy_rng, + and torch_rng are ignored, because the experiment continues from the exact stored RNG states. + If you provide numpy_rng or torch_rng explicitly, these generators take precedence and are used directly. + The seed parameter is not used to create new RNGs in this situation. + The overall priority is: + 1. Saved NePS state — highest priority + 2. User-provided numpy_rng and torch_rng + 3. seed — used only when no RNG objects and no saved state are available """ # noqa: E501 if ( evaluations_to_spend is None @@ -660,14 +660,16 @@ def import_trials( Exception: For unexpected errors during trial import. Example: - >>> import neps - >>> from neps.state.pipeline_eval import UserResultDict - >>> pipeline_space = neps.SearchSpace({...}) - >>> evaluated_trials = [ - ... ({"param1": 0.5, "param2": 10}, - ... UserResultDict(objective_to_minimize=-5.0)), - ... ] - >>> neps.import_trials(pipeline_space, evaluated_trials, "my_results") + ```python + import neps + from neps.state.pipeline_eval import UserResultDict + pipeline_space = neps.SearchSpace({...}) + evaluated_trials = [ + ({"param1": 0.5, "param2": 10}, + UserResultDict(objective_to_minimize=-5.0)), + ] + neps.import_trials(pipeline_space, evaluated_trials, "my_results") + ``` """ if isinstance(root_directory, str): root_directory = Path(root_directory)
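A usage sketch (not part of the patch) illustrating the RNG priority described in the updated `neps.run` docstring. It assumes only the `seed`, `numpy_rng`, and `torch_rng` parameters introduced by this patch and follows the conventions of the basic hyperparameters example; the search space and objective below are placeholders, not code from the repository.

```python
# Minimal sketch, assuming the seeding parameters added to neps.run in this
# patch. The pipeline and space definitions are illustrative only.
import numpy as np
import torch

import neps


def evaluate_pipeline(x: float) -> float:
    # Toy objective: distance from 0.5.
    return (x - 0.5) ** 2


pipeline_space = dict(x=neps.Float(lower=0.0, upper=1.0))

# Priority 3: a plain integer seed. It is only used because no RNG objects
# are passed and no saved state exists yet under this root_directory.
neps.run(
    evaluate_pipeline=evaluate_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/seeded_run",
    evaluations_to_spend=10,
    seed=42,
)

# Priority 2: explicitly provided generators take precedence over `seed`.
neps.run(
    evaluate_pipeline=evaluate_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/generator_run",
    evaluations_to_spend=10,
    numpy_rng=np.random.default_rng(0),
    torch_rng=torch.Generator().manual_seed(0),
)

# Priority 1: reusing an existing root_directory restores the RNGs from the
# saved NePS state, so the seed passed here is ignored.
neps.run(
    evaluate_pipeline=evaluate_pipeline,
    pipeline_space=pipeline_space,
    root_directory="results/seeded_run",
    evaluations_to_spend=20,
    seed=123,  # ignored: saved state has highest priority
)
```

Reusing the same `root_directory` in the last call is what triggers the highest-priority path: the RNGs are reconstructed from the pickled `RNGStateManager` in the optimizer state rather than re-seeded from the arguments.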