diff --git a/neps/optimizers/acquisition/__init__.py b/neps/optimizers/acquisition/__init__.py
index 0d2d27efa..d8e4a4771 100644
--- a/neps/optimizers/acquisition/__init__.py
+++ b/neps/optimizers/acquisition/__init__.py
@@ -1,5 +1,11 @@
 from neps.optimizers.acquisition.cost_cooling import cost_cooled_acq
 from neps.optimizers.acquisition.pibo import pibo_acquisition
 from neps.optimizers.acquisition.weighted_acquisition import WeightedAcquisition
+from neps.optimizers.acquisition.wrapped_acquisition import WrappedAcquisition
 
-__all__ = ["WeightedAcquisition", "cost_cooled_acq", "pibo_acquisition"]
+__all__ = [
+    "WeightedAcquisition",
+    "WrappedAcquisition",
+    "cost_cooled_acq",
+    "pibo_acquisition",
+]
diff --git a/neps/optimizers/acquisition/wrapped_acquisition.py b/neps/optimizers/acquisition/wrapped_acquisition.py
new file mode 100644
index 000000000..21abe872c
--- /dev/null
+++ b/neps/optimizers/acquisition/wrapped_acquisition.py
@@ -0,0 +1,94 @@
+"""Wrap an existing acquisition function to handle mixed search spaces.
+
+For mixed search spaces, we first keep the categorical dimensions fixed to some
+randomly chosen values and optimize over the continuous dimensions. Next, we take
+the numerical dimensions of the returned best candidate and keep them fixed, while
+we run `optimize_acqf_discrete_local_search` over the categorical dimensions.
+
+For this, we need to wrap the existing acquisition function so that it accepts
+tensors containing only the categorical dimensions, since BoTorch does not natively
+support keeping numerical dimensions fixed in `optimize_acqf_discrete_local_search`.
+
+Inside `WrappedAcquisition`, we concatenate the fixed numerical dimensions to the
+tensor containing only the categoricals before passing it to the original
+acquisition function.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from botorch.acquisition import AcquisitionFunction
+from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
+
+if TYPE_CHECKING:
+    from torch import Tensor
+
+    from neps.space.encoding import ConfigEncoder
+
+
+class WrappedAcquisition(AcquisitionFunction):
+    """Acquisition function wrapper for mixed search spaces."""
+
+    def __init__(
+        self,
+        acq: AcquisitionFunction,
+        encoder: ConfigEncoder,
+        fixed_numericals: dict[int, float],
+    ) -> None:
+        """Initialize the wrapped acquisition function.
+
+        Args:
+            acq: The base acquisition function.
+            encoder: The encoder used for encoding the configurations.
+            fixed_numericals: A dictionary mapping numerical dimension indices to
+                their fixed values.
+        """
+        super().__init__(model=acq.model)
+        # NOTE: Remove X_pending from the base acquisition function and track it
+        # on the wrapper instead. See the similar note in WeightedAcquisition.
+        X_pending = getattr(acq, "X_pending", None)
+        acq.set_X_pending(None)
+        self.set_X_pending(X_pending)
+
+        self.acq = acq
+        self.encoder = encoder
+        self.fixed_numericals = fixed_numericals
+
+    @concatenate_pending_points  # type: ignore
+    @t_batch_mode_transform()  # type: ignore
+    def forward(self, X: Tensor) -> Tensor:
+        """Evaluate the wrapped acquisition function on the candidate set X
+        after concatenating the fixed numerical dimensions.
+
+        Args:
+            X: A `batch_shape x q x d_categorical`-dim tensor of candidates, where
+                `d_categorical` is the number of categorical dimensions.
+
+        Returns:
+            A `batch_shape`-dim tensor of acquisition function values at the input
+            candidates.
+        """
+        batch, q, c_dims = X.shape
+        n_dims = len(self.fixed_numericals)
+        new_X_shape = (batch, q, c_dims + n_dims)
+
+        # Create a new tensor to hold the concatenated dimensions
+        x_full: torch.Tensor = torch.empty(new_X_shape, dtype=X.dtype, device=X.device)
+
+        # Create a mask identifying the positions of the categorical and the fixed
+        # numerical dimensions in the full tensor
+        mask = torch.ones(c_dims + n_dims, dtype=torch.bool, device=X.device)
+        insert_idxs = torch.tensor(list(self.fixed_numericals.keys()), device=X.device)
+        mask[insert_idxs] = False
+
+        # Fill in the fixed numerical values and the input categorical values
+        for idx, val in self.fixed_numericals.items():
+            x_full[:, :, idx] = val
+        x_full[:, :, mask] = X
+
+        # Pass the concatenated tensor to the original acquisition function
+        return self.acq(x_full)
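# Minimal standalone sketch (not part of the diff) of the mask-based
# concatenation that `WrappedAcquisition.forward` performs. The indices and
# values are illustrative: the categorical columns land at indices 0 and 2,
# the fixed numericals at indices 1 and 3.
import torch

X = torch.tensor([[[5.0, 7.0]]])       # batch=1, q=1, d_categorical=2
fixed_numericals = {1: 0.25, 3: 0.75}  # dim index -> fixed value

batch, q, c_dims = X.shape
n_dims = len(fixed_numericals)
x_full = torch.empty((batch, q, c_dims + n_dims), dtype=X.dtype)

mask = torch.ones(c_dims + n_dims, dtype=torch.bool)
mask[torch.tensor(list(fixed_numericals.keys()))] = False

for idx, val in fixed_numericals.items():
    x_full[:, :, idx] = val
x_full[:, :, mask] = X

print(x_full)  # tensor([[[5.0000, 0.2500, 7.0000, 0.7500]]])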
diff --git a/neps/optimizers/models/gp.py b/neps/optimizers/models/gp.py
index 586ba371e..365caa77e 100644
--- a/neps/optimizers/models/gp.py
+++ b/neps/optimizers/models/gp.py
@@ -3,26 +3,34 @@
 from __future__ import annotations
 
 import logging
+import warnings
 from collections.abc import Mapping, Sequence
 from contextlib import nullcontext
 from dataclasses import dataclass
-from functools import reduce
-from itertools import product
 from typing import TYPE_CHECKING, Any
 
 import gpytorch.constraints
+import numpy as np
 import torch
+from botorch.exceptions.warnings import InputDataWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.models import SingleTaskGP
 from botorch.models.gp_regression import Log, get_covar_module_with_dim_scaled_prior
 from botorch.models.gp_regression_mixed import CategoricalKernel, OutcomeTransform
 from botorch.models.transforms.outcome import ChainedOutcomeTransform, Standardize
-from botorch.optim import optimize_acqf, optimize_acqf_mixed
+from botorch.optim import (
+    optimize_acqf,
+    optimize_acqf_discrete_local_search,
+)
 from gpytorch import ExactMarginalLogLikelihood
 from gpytorch.kernels import ScaleKernel
 from gpytorch.utils.warnings import NumericalWarning
 
-from neps.optimizers.acquisition import cost_cooled_acq, pibo_acquisition
+from neps.optimizers.acquisition import (
+    WrappedAcquisition,
+    cost_cooled_acq,
+    pibo_acquisition,
+)
 from neps.space.encoding import CategoricalToIntegerTransformer, ConfigEncoder
 from neps.utils.common import disable_warnings
@@ -35,6 +43,8 @@
 logger = logging.getLogger(__name__)
 
+warnings.filterwarnings("ignore", category=InputDataWarning)
+
 
 @dataclass
 class GPEncodedData:
@@ -132,10 +142,36 @@ def optimize_acq(
     n_intial_start_points: int | None = None,
     acq_options: Mapping[str, Any] | None = None,
     fixed_features: dict[str, Any] | None = None,
-    maximum_allowed_categorical_combinations: int = 30,
     hide_warnings: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """Optimize the acquisition function."""
+    """Optimize the acquisition function.
+
+    For purely numerical spaces, this uses botorch's `optimize_acqf()`.
+    For purely categorical spaces, this uses `optimize_acqf_discrete_local_search()`.
+    For mixed spaces, this uses a two-step sequential optimization:
+
+    1. Sample a random categorical combination, keep it fixed, and optimize the
+       acquisition function over the continuous dimensions.
+    2. Wrap the acquisition function to fix the numerical dimensions of the best
+       continuous candidate and optimize over the categorical space using
+       `optimize_acqf_discrete_local_search()`.
+
+    NOTE: `optimize_acqf_discrete_local_search()` scales much better than
+    `optimize_acqf_mixed()` when the search space has many categorical dimensions.
+
+    Args:
+        acq_fn: The acquisition function to optimize.
+        encoder: The encoder used for encoding the configurations.
+        n_candidates_required: The number of candidates to return.
+        num_restarts: The number of restarts to use during optimization.
+        n_intial_start_points: The number of initial start points to use during
+            optimization.
+        acq_options: Additional options to pass to the botorch `optimize_acqf`
+            function.
+        fixed_features: The features to fix to a certain value during acquisition.
+        hide_warnings: Whether to hide numerical warnings issued during GP routines.
+
+    Returns:
+        The (encoded) optimized candidate(s) and corresponding acquisition value(s).
+    """
     warning_context = (
         disable_warnings(NumericalWarning) if hide_warnings else nullcontext()
     )
@@ -161,6 +197,8 @@ def optimize_acq(
         )
     }
 
+    num_numericals = len(encoder.domains) - len(cat_transformers)
+
     # Proceed with regular numerical acquisition
     if not any(cat_transformers):
         # Small heuristic to increase the number of candidates as our
@@ -181,23 +219,6 @@ def optimize_acq(
             **acq_options,
         )
 
-    # We need to generate the product of all possible combinations of categoricals,
-    # first we do a sanity check
-    n_combos = reduce(
-        lambda x, y: x * y,  # type: ignore
-        [t.domain.cardinality for t in cat_transformers.values()],
-        1,
-    )
-    if n_combos > maximum_allowed_categorical_combinations:
-        raise ValueError(
-            "The number of fixed categorical dimensions is too high. "
-            "This will lead to an explosion in the number of possible "
-            f"combinations. Got: {n_combos} while the setting for the function"
-            f" is: {maximum_allowed_categorical_combinations=}. Consider reducing the "
-            "dimensions or consider encoding your categoricals in some other format."
-        )
-
-    # Right, now we generate all possible combinations
     # First, just collect the possible values per cat column
    # {hp_name: [v1, v2], hp_name2: [v1, v2, v3], ...}
     cats: dict[int, list[float]] = {
@@ -207,34 +228,100 @@ def optimize_acq(
         ]
         for name, transformer in cat_transformers.items()
     }
+    cat_keys = list(cats.keys())
+    choices = [torch.tensor(cats[k], dtype=torch.float) for k in cat_keys]
+    fixed_cat: dict[int, float] = {}
 
-    # Second, generate all possible combinations
-    fixed_cats: list[dict[int, float]]
-    if len(cats) == 1:
-        col, choice_indices = next(iter(cats.items()))
-        fixed_cats = [{col: i} for i in choice_indices]
-    else:
-        fixed_cats = [
-            dict(zip(cats.keys(), combo, strict=False))
-            for combo in product(*cats.values())
-        ]
-
-    # Make sure to include caller's fixed features if provided
-    if len(_fixed_features) > 0:
-        fixed_cats = [{**cat, **_fixed_features} for cat in fixed_cats]
-
-    with warning_context:
-        # TODO: we should deterministically shuffle the fixed_categoricals
-        # as the underlying function does not.
-        return optimize_acqf_mixed(  # type: ignore
-            acq_function=acq_fn,
-            bounds=bounds,
-            num_restarts=min(num_restarts // n_combos, 2),
-            raw_samples=n_intial_start_points,
-            q=n_candidates_required,
-            fixed_features_list=fixed_cats,
-            **acq_options,
-        )
+    if num_numericals > 0:
+        with warning_context:
+            # Sample a random categorical combination and keep it fixed during
+            # the continuous optimization step.
+            fixed_cat = {key: float(np.random.choice(cats[key])) for key in cat_keys}
+
+            # Make sure to include the caller's fixed features if provided
+            if len(_fixed_features) > 0:
+                fixed_cat.update(_fixed_features)
+
+            # Step 1: Optimize the acquisition function over the continuous space
+            best_x_continuous, _ = optimize_acqf(
+                acq_function=acq_fn,
+                bounds=bounds,
+                q=n_candidates_required,
+                num_restarts=num_restarts,
+                raw_samples=n_intial_start_points,
+                fixed_features=fixed_cat,
+                **acq_options,
+            )
+
+            # Extract the numerical dims from the optimized continuous vector
+            fixed_numericals = {
+                i: float(best_x_continuous[0, i].item())
+                for i in range(len(encoder.domains))
+                if i not in cat_keys
+            }
+
+            # Update fixed_numericals with the caller's fixed features
+            fixed_numericals.update(_fixed_features)
+
+            # Step 2: Wrap the acquisition function so it accepts only the
+            # categorical dimensions and re-inserts the fixed numericals
+            wrapped_acq = WrappedAcquisition(
+                acq=acq_fn,
+                encoder=encoder,
+                fixed_numericals=fixed_numericals,
+            )
+
+            # Step 3: Run discrete local search over the categorical space
+            # with the wrapped acquisition function
+            best_cat_tensor, _ = optimize_acqf_discrete_local_search(
+                acq_function=wrapped_acq,
+                discrete_choices=choices,
+                q=n_candidates_required,
+                num_restarts=num_restarts,
+                raw_samples=n_intial_start_points,
+            )
+
+            # Step 4: Concatenate the best categorical and numerical dims, along
+            # with any fixed features provided by the caller
+            q, c_dims = best_cat_tensor.shape
+            n_dims = len(fixed_numericals)
+            new_X_shape = (q, c_dims + n_dims)
+
+            # Create a new tensor to hold the concatenated dimensions
+            best_x_full: torch.Tensor = torch.empty(
+                new_X_shape, dtype=best_cat_tensor.dtype, device=best_cat_tensor.device
+            )
+
+            # Create a mask identifying the positions of the categorical and
+            # numerical dimensions
+            mask = torch.ones(
+                c_dims + n_dims, dtype=torch.bool, device=best_cat_tensor.device
+            )
+            insert_idxs = torch.tensor(
+                list(fixed_numericals.keys()), device=best_cat_tensor.device
+            )
+            mask[insert_idxs] = False
+
+            # Fill in the fixed numerical values and the best categorical values
+            for idx, val in fixed_numericals.items():
+                best_x_full[:, idx] = val
+            best_x_full[:, mask] = best_cat_tensor
+
+            # Evaluate the final acquisition value
+            with torch.no_grad():
+                best_val_final = acq_fn(best_x_full)
+
+            return best_x_full, best_val_final
+
+    else:
+        with warning_context:
+            return optimize_acqf_discrete_local_search(  # type: ignore
+                acq_function=acq_fn,
+                discrete_choices=choices,
+                q=n_candidates_required,
+                num_restarts=num_restarts,
+                raw_samples=n_intial_start_points,
+                **acq_options,
+            )
 
 
 def encode_trials_for_gp(
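# Minimal standalone sketch (not part of the diff) of the two-step mixed
# optimization above, with a toy quadratic standing in for the acquisition
# function, plain gradient ascent standing in for `optimize_acqf`, and
# exhaustive enumeration standing in for `optimize_acqf_discrete_local_search`.
# All names and values here are illustrative.
from itertools import product

import torch

TARGET = torch.tensor([0.3, 1.0, 0.7, 2.0])

def toy_acq(x: torch.Tensor) -> torch.Tensor:
    # Toy acquisition value, maximized at x == TARGET
    return -((x - TARGET) ** 2).sum(-1)

cats = {1: [0.0, 1.0, 2.0], 3: [0.0, 1.0, 2.0]}  # categorical dim -> choices
num_idxs = [0, 2]                                # numerical dims

# Step 1: fix a random categorical combination, optimize the numericals.
x = torch.full((4,), 0.5)
for k, vals in cats.items():
    x[k] = vals[torch.randint(len(vals), ()).item()]
x_num = x[num_idxs].clone().requires_grad_(True)
opt = torch.optim.Adam([x_num], lr=0.05)
for _ in range(200):
    opt.zero_grad()
    x_try = x.clone()
    x_try[num_idxs] = x_num
    (-toy_acq(x_try)).backward()
    opt.step()
x[num_idxs] = x_num.detach()

# Step 2: keep the numericals fixed, search the categorical combinations.
def with_cats(combo: tuple[float, ...]) -> torch.Tensor:
    x_full = x.clone()
    x_full[1], x_full[3] = combo
    return x_full

best = max(product(*cats.values()), key=lambda c: toy_acq(with_cats(c)).item())
x[1], x[3] = best
print(x)  # approximately [0.3, 1.0, 0.7, 2.0]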
@@ -318,7 +405,6 @@ def fit_and_acquire_from_gp(
     n_candidates_required: int | None = None,
     num_restarts: int = 20,
     n_initial_start_points: int = 256,
-    maximum_allowed_categorical_combinations: int = 30,
     fixed_acq_features: dict[str, Any] | None = None,
     acq_options: Mapping[str, Any] | None = None,
     hide_warnings: bool = False,
@@ -362,9 +448,6 @@ def fit_and_acquire_from_gp(
         num_restarts: The number of restarts to use during optimization.
         n_initial_start_points: The number of initial start points to use during
             optimization.
-        maximum_allowed_categorical_combinations: The maximum number of categorical
-            combinations to allow. If the number of combinations exceeds this, an error
-            will be raised.
         acq_options: Additional options to pass to the botorch `optimizer_acqf` function.
         hide_warnings: Whether to hide numerical warnings issued during GP routines.
@@ -445,7 +528,6 @@ def fit_and_acquire_from_gp(
         n_intial_start_points=n_initial_start_points,
         fixed_features=fixed_acq_features,
         acq_options=acq_options,
-        maximum_allowed_categorical_combinations=maximum_allowed_categorical_combinations,
         hide_warnings=hide_warnings,
     )
     return candidates
diff --git a/pyproject.toml b/pyproject.toml
index 22f5c8341..024fc8d2b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ dependencies = [
     "botorch>=0.12",
     "gpytorch==1.13.0",
     "ifbo>=0.3.13",
+    "pymoo>=0.6.1.5",
 ]
 
 [project.urls]
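# Minimal standalone sketch (not part of the diff) of the botorch routine this
# PR adopts for the categorical step, run on a purely categorical toy problem
# with two binary dimensions. Data and settings are illustrative.
import torch
from botorch.acquisition import LogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf_discrete_local_search
from gpytorch import ExactMarginalLogLikelihood

train_X = torch.rand(8, 2).round()           # 8 observed binary configs
train_Y = train_X.sum(dim=-1, keepdim=True)  # toy objective to maximize
gp = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(gp.likelihood, gp))

acq = LogExpectedImprovement(model=gp, best_f=train_Y.max())
best_x, best_val = optimize_acqf_discrete_local_search(
    acq_function=acq,
    discrete_choices=[torch.tensor([0.0, 1.0]), torch.tensor([0.0, 1.0])],
    q=1,
    num_restarts=4,
    raw_samples=32,
)
print(best_x, best_val)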