Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from neps.optimizers import AskFunction, OptimizerChoice, load_optimizer
from neps.runtime import _launch_runtime, _save_results
from neps.space.parsing import convert_to_space
from neps.state import NePSState, OptimizationState, SeedSnapshot
from neps.state import NePSState, OptimizationState, RNGStateManager
from neps.status.status import post_run_csv, trajectory_of_improvements
from neps.utils.common import dynamic_load_object
from neps.validation import _validate_imported_config, _validate_imported_result

if TYPE_CHECKING:
import numpy as np
import torch
from ConfigSpace import ConfigurationSpace

from neps.optimizers.algorithms import CustomOptimizer
Expand Down Expand Up @@ -59,6 +61,9 @@ def run( # noqa: C901, D417, PLR0913
| CustomOptimizer
| Literal["auto"]
) = "auto",
seed: int | None = None,
numpy_rng: np.random.Generator | None = None,
torch_rng: torch.Generator | None = None,
) -> None:
"""Run the optimization.

Expand Down Expand Up @@ -425,7 +430,20 @@ def __call__(

This is mainly meant for internal development but allows you to use the NePS
runtime to run your optimizer.

seed: An optional seed for the random number generators.
numpy_rng: An optional numpy random number generator.
torch_rng: An optional torch random number generator.

!!! tip "RNG Priority and Control"
When a previously created NePS state is loaded (by reusing the same root_directory),
all RNGs are reconstructed from the saved state. In this case, the parameters seed, numpy_rng,
and torch_rng are ignored, because the experiment continues from the exact stored RNG states.
If you provide numpy_rng or torch_rng explicitly, these generators take precedence and are used directly.
The seed parameter is not used to create new RNGs in this situation.
The overall priority is:
1. Saved NePS state — highest priority
2. User-provided numpy_rng and torch_rng
3. seed — used only when no RNG objects and no saved state are available
""" # noqa: E501
if (
evaluations_to_spend is None
Expand Down Expand Up @@ -479,7 +497,10 @@ def __call__(

logger.info(f"Starting neps.run using root directory {root_directory}")
space = convert_to_space(pipeline_space)
_optimizer_ask, _optimizer_info = load_optimizer(optimizer=optimizer, space=space)

_optimizer_ask_wrapper, _optimizer_info = load_optimizer(
optimizer=optimizer, space=space
)

multi_fidelity_optimizers = {
"successive_halving",
Expand Down Expand Up @@ -523,9 +544,13 @@ def __call__(
"'module:function'."
)

rng_manager = RNGStateManager.new_capture(
seed=seed, np_rng=numpy_rng, torch_rng=torch_rng
)

_launch_runtime(
evaluation_fn=_eval, # type: ignore
optimizer=_optimizer_ask,
optimizer_fn=_optimizer_ask_wrapper,
optimizer_info=_optimizer_info,
cost_to_spend=cost_to_spend,
fidelities_to_spend=fidelities_to_spend,
Expand All @@ -540,6 +565,7 @@ def __call__(
sample_batch_size=sample_batch_size,
write_summary_to_disk=write_summary_to_disk,
worker_id=worker_id,
rng_manager=rng_manager,
)

post_run_csv(root_directory)
Expand Down Expand Up @@ -634,25 +660,32 @@ def import_trials(
Exception: For unexpected errors during trial import.

Example:
>>> import neps
>>> from neps.state.pipeline_eval import UserResultDict
>>> pipeline_space = neps.SearchSpace({...})
>>> evaluated_trials = [
... ({"param1": 0.5, "param2": 10},
... UserResultDict(objective_to_minimize=-5.0)),
... ]
>>> neps.import_trials(pipeline_space, evaluated_trials, "my_results")
```python
import neps
from neps.state.pipeline_eval import UserResultDict
pipeline_space = neps.SearchSpace({...})
evaluated_trials = [
({"param1": 0.5, "param2": 10},
UserResultDict(objective_to_minimize=-5.0)),
]
neps.import_trials(pipeline_space, evaluated_trials, "my_results")
```
"""
if isinstance(root_directory, str):
root_directory = Path(root_directory)

optimizer_ask, optimizer_info = load_optimizer(optimizer, pipeline_space)
rng_manager = RNGStateManager.new_capture()

optimizer_ask_fn, optimizer_info = load_optimizer(
optimizer,
pipeline_space,
)

state = NePSState.create_or_load(
root_directory,
optimizer_info=optimizer_info,
optimizer_state=OptimizationState(
budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={}
budget=None, rng_state_manager=rng_manager, shared_state={}
),
)

Expand All @@ -675,7 +708,7 @@ def import_trials(
if tuple(sorted(t[0].items())) not in existing_configs
]

imported_trials = optimizer_ask.import_trials(
imported_trials = optimizer_ask_fn(rng_manager).import_trials(
external_evaluations=normalized_trials,
trials=state_trials,
)
Expand Down
27 changes: 20 additions & 7 deletions neps/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

if TYPE_CHECKING:
from neps.space import SearchSpace
from neps.state.seed_snapshot import RNGStateManager


def _load_optimizer_from_string(
optimizer: OptimizerChoice | Literal["auto"],
space: SearchSpace,
*,
optimizer_kwargs: Mapping[str, Any] | None = None,
) -> tuple[AskFunction, OptimizerInfo]:
) -> tuple[Callable[[RNGStateManager], AskFunction], OptimizerInfo]:
if optimizer == "auto":
_optimizer = determine_optimizer_automatically(space)
else:
Expand All @@ -37,7 +38,9 @@ def _load_optimizer_from_string(

keywords = extract_keyword_defaults(optimizer_build)
optimizer_kwargs = optimizer_kwargs or {}
opt = optimizer_build(space, **optimizer_kwargs)
opt = lambda rng_manager: optimizer_build(
space, rng_manager=rng_manager, **optimizer_kwargs
)
info = OptimizerInfo(name=_optimizer, info={**keywords, **optimizer_kwargs})
return opt, info

Expand All @@ -52,30 +55,40 @@ def load_optimizer(
| Literal["auto"]
),
space: SearchSpace,
) -> tuple[AskFunction, OptimizerInfo]:
) -> tuple[Callable[[RNGStateManager], AskFunction], OptimizerInfo]:
match optimizer:
# Predefined string (including "auto")
case str():
return _load_optimizer_from_string(optimizer, space)

# Predefined string with kwargs
case (opt, kwargs) if isinstance(opt, str):
return _load_optimizer_from_string(opt, space, optimizer_kwargs=kwargs) # type: ignore
return _load_optimizer_from_string(
opt, # type: ignore
space,
optimizer_kwargs=kwargs, # type: ignore
)

# Mapping with a name
case {"name": name, **_kwargs}:
return _load_optimizer_from_string(name, space, optimizer_kwargs=_kwargs) # type: ignore
return _load_optimizer_from_string(
name,
space,
optimizer_kwargs=_kwargs, # type: ignore
)

# Provided optimizer initializer
case _ if callable(optimizer):
keywords = extract_keyword_defaults(optimizer)
_optimizer = optimizer(space)
_optimizer = lambda rng_manager: optimizer(space, rng_manager=rng_manager)
info = OptimizerInfo(name=optimizer.__name__, info=keywords)
return _optimizer, info

# Custom optimizer, we create it
case CustomOptimizer(initialized=False):
_optimizer = optimizer.create(space)
_optimizer = lambda rng_manager=None: optimizer.create( # type: ignore
space, rng_manager=rng_manager
)
keywords = extract_keyword_defaults(optimizer.optimizer)
info = OptimizerInfo(
name=optimizer.name, info={**keywords, **optimizer.kwargs}
Expand Down
Loading