8 changes: 7 additions & 1 deletion neps/optimizers/acquisition/__init__.py
@@ -1,5 +1,11 @@
from neps.optimizers.acquisition.cost_cooling import cost_cooled_acq
from neps.optimizers.acquisition.pibo import pibo_acquisition
from neps.optimizers.acquisition.weighted_acquisition import WeightedAcquisition
from neps.optimizers.acquisition.wrapped_acquisition import WrappedAcquisition

__all__ = ["WeightedAcquisition", "cost_cooled_acq", "pibo_acquisition"]
__all__ = [
"WeightedAcquisition",
"WrappedAcquisition",
"cost_cooled_acq",
"pibo_acquisition",
]
94 changes: 94 additions & 0 deletions neps/optimizers/acquisition/wrapped_acquisition.py
@@ -0,0 +1,94 @@
"""Module to wrap the existing acquisition function to account for mixed search spaces.

For mixed search spaces, we first keep the categorical dimensions fixed to some randomly
chosen values and perform optimization over the continuous dimensions.
Next, we select the numerical dimensions from the returned best candidate and keep them
fixed, while we use `optimize_acqf_discrete_local_search` over the categorical dimensions.

For this, we need to wrap the existing acquisition function to accept tensors containing
only the categorical dimensions since BoTorch does not natively support keeping numerical
dimensions fixed in `optimize_acqf_discrete_local_search`.

Inside `WrappedAcquisition`, we concatenate the fixed numerical dimensions to the tensor
containing only the categoricals before passing it to the original acquisition function.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from botorch.acquisition import AcquisitionFunction
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform

if TYPE_CHECKING:
from torch import Tensor

from neps.space.encoding import ConfigEncoder


class WrappedAcquisition(AcquisitionFunction):
"""Acquisition function wrapper for mixed search spaces."""

def __init__(
self,
acq: AcquisitionFunction,
encoder: ConfigEncoder,
fixed_numericals: dict[int, float],
) -> None:
"""Initialize the wrapped acquisition function.

Args:
acq: The base acquisition function.
encoder: The encoder used to encode the configurations.
fixed_numericals: A dictionary mapping numerical dimension indices to their
fixed values.
"""
super().__init__(model=acq.model)
# NOTE: Remove X_pending from the base acquisition function and track it
# here instead. See the similar note in WeightedAcquisition.
X_pending = getattr(acq, "X_pending", None)
acq.set_X_pending(None)
self.set_X_pending(X_pending)

self.acq = acq
self.encoder = encoder
self.fixed_numericals = fixed_numericals

@concatenate_pending_points # type: ignore
@t_batch_mode_transform() # type: ignore
def forward(self, X: Tensor) -> Tensor:
"""Evaluate the wrapped acquisition function on the candidate set X
after concatenating the fixed numerical dimensions.

Args:
X: A `batch_shape x q x d_categorical`-dim tensor of candidates, where
`d_categorical` is the number of categorical dimensions.

Returns:
A `batch_shape`-dim tensor of acquisition function values at the input
candidates.
"""
batch, q, c_dims = X.shape
n_dims = len(self.fixed_numericals)
new_X_shape = (batch, q, c_dims + n_dims)

# Create a new tensor to hold the concatenated dimensions
x_full: torch.Tensor = torch.empty(new_X_shape, dtype=X.dtype, device=X.device)

# Create a mask to identify positions of categorical and numerical dimensions
mask = torch.ones(c_dims + n_dims, dtype=torch.bool, device=X.device)
insert_idxs = torch.tensor(list(self.fixed_numericals.keys()), device=X.device)
mask[insert_idxs] = False

# Fill in the fixed numerical values and the input categorical values
for idx, val in self.fixed_numericals.items():
x_full[:, :, idx] = val
x_full[:, :, mask] = X

# Pass the concatenated tensor to the original acquisition function
return self.acq(x_full)
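As a reference for reviewers, here is a minimal usage sketch of the wrapper (not part of this PR). It assumes a fitted `SingleTaskGP` and an MC acquisition function such as `qLogExpectedImprovement` (analytic acquisition functions would raise in `set_X_pending`, so an MC variant is used), and `encoder=None` is a hypothetical shortcut since the encoder is stored but not used inside `forward`:

    # Minimal sketch: 3 dims total, numericals at indices 0 and 2,
    # one categorical at index 1.
    import torch
    from botorch.acquisition import qLogExpectedImprovement
    from botorch.models import SingleTaskGP

    from neps.optimizers.acquisition import WrappedAcquisition

    train_X = torch.rand(8, 3, dtype=torch.float64)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acq = qLogExpectedImprovement(model=model, best_f=train_Y.max())

    wrapped = WrappedAcquisition(
        acq=acq,
        encoder=None,  # hypothetical: a real ConfigEncoder in NePS
        fixed_numericals={0: 0.25, 2: 0.75},  # from the continuous step
    )

    # Candidates carry only the categorical dim: batch_shape x q x 1.
    X_cat = torch.tensor([[[0.0]], [[1.0]]], dtype=torch.float64)
    print(wrapped(X_cat).shape)  # torch.Size([2])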
190 changes: 136 additions & 54 deletions neps/optimizers/models/gp.py
@@ -3,26 +3,34 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Mapping, Sequence
from contextlib import nullcontext
from dataclasses import dataclass
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any

import gpytorch.constraints
import numpy as np
import torch
from botorch.exceptions.warnings import InputDataWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.gp_regression import Log, get_covar_module_with_dim_scaled_prior
from botorch.models.gp_regression_mixed import CategoricalKernel, OutcomeTransform
from botorch.models.transforms.outcome import ChainedOutcomeTransform, Standardize
from botorch.optim import optimize_acqf, optimize_acqf_mixed
from botorch.optim import (
optimize_acqf,
optimize_acqf_discrete_local_search,
)
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.kernels import ScaleKernel
from gpytorch.utils.warnings import NumericalWarning

from neps.optimizers.acquisition import cost_cooled_acq, pibo_acquisition
from neps.optimizers.acquisition import (
WrappedAcquisition,
cost_cooled_acq,
pibo_acquisition,
)
from neps.space.encoding import CategoricalToIntegerTransformer, ConfigEncoder
from neps.utils.common import disable_warnings

@@ -35,6 +43,8 @@

logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=InputDataWarning)


@dataclass
class GPEncodedData:
@@ -132,10 +142,36 @@ def optimize_acq(
n_intial_start_points: int | None = None,
acq_options: Mapping[str, Any] | None = None,
fixed_features: dict[str, Any] | None = None,
maximum_allowed_categorical_combinations: int = 30,
hide_warnings: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Optimize the acquisition function."""
"""Optimize the acquisition function.

For purely numerical spaces, this uses BoTorch's `optimize_acqf()`.
For purely categorical spaces, this uses `optimize_acqf_discrete_local_search()`.
For mixed spaces, this uses a two-step sequential optimization:

1. Sample a random categorical combination, keep it fixed, and optimize the
acquisition function over the continuous dimensions.
2. Wrap the acquisition function so that the numerical dimensions of the
resulting candidate stay fixed, and optimize over the categorical space
using `optimize_acqf_discrete_local_search()`.

NOTE: `optimize_acqf_discrete_local_search()` scales much better than
`optimize_acqf_mixed()` when the search space has many categorical dimensions.

Args:
acq_fn: The acquisition function to optimize.
encoder: The encoder used for encoding the configurations.
n_candidates_required: The number of candidates to return.
num_restarts: The number of restarts to use during optimization.
n_intial_start_points: The number of initial start points to use during
optimization.
acq_options: Additional options to pass to the botorch `optimize_acqf` function.
fixed_features: The features to fix to a certain value during acquisition.
hide_warnings: Whether to hide numerical warnings issued during GP routines.

Returns:
The (encoded) optimized candidate(s) and corresponding acquisition value(s).
"""
warning_context = (
disable_warnings(NumericalWarning) if hide_warnings else nullcontext()
)
@@ -161,6 +197,8 @@ def optimize_acq(
)
}

num_numericals = len(encoder.domains) - len(cat_transformers)

# Proceed with regular numerical acquisition
if not any(cat_transformers):
# Small heuristic to increase the number of candidates as our
@@ -181,23 +219,6 @@
**acq_options,
)

# We need to generate the product of all possible combinations of categoricals,
# first we do a sanity check
n_combos = reduce(
lambda x, y: x * y, # type: ignore
[t.domain.cardinality for t in cat_transformers.values()],
1,
)
if n_combos > maximum_allowed_categorical_combinations:
raise ValueError(
"The number of fixed categorical dimensions is too high. "
"This will lead to an explosion in the number of possible "
f"combinations. Got: {n_combos} while the setting for the function"
f" is: {maximum_allowed_categorical_combinations=}. Consider reducing the "
"dimensions or consider encoding your categoricals in some other format."
)

# Right, now we generate all possible combinations
# First, just collect the possible values per cat column
# {hp_name: [v1, v2], hp_name2: [v1, v2, v3], ...}
cats: dict[int, list[float]] = {
@@ -207,34 +228,100 @@ def optimize_acq(
]
for name, transformer in cat_transformers.items()
}
cat_keys = list(cats.keys())
choices = [torch.tensor(cats[k], dtype=torch.float) for k in cat_keys]
fixed_cat: dict[int, float] = {}

# Second, generate all possible combinations
fixed_cats: list[dict[int, float]]
if len(cats) == 1:
col, choice_indices = next(iter(cats.items()))
fixed_cats = [{col: i} for i in choice_indices]
else:
fixed_cats = [
dict(zip(cats.keys(), combo, strict=False))
for combo in product(*cats.values())
]
if num_numericals > 0:
with warning_context:
# Sample a random categorical combination and keep it fixed during
# the continuous optimization step
fixed_cat = {key: float(np.random.choice(cats[key])) for key in cat_keys}

# Make sure to include caller's fixed features if provided
if len(_fixed_features) > 0:
fixed_cats = [{**cat, **_fixed_features} for cat in fixed_cats]

with warning_context:
# TODO: we should deterministically shuffle the fixed_categoricals
# as the underlying function does not.
return optimize_acqf_mixed( # type: ignore
acq_function=acq_fn,
bounds=bounds,
num_restarts=min(num_restarts // n_combos, 2),
raw_samples=n_intial_start_points,
q=n_candidates_required,
fixed_features_list=fixed_cats,
**acq_options,
)
# Step 1: Optimize acquisition function over the continuous space

if len(_fixed_features) > 0:
fixed_cat.update(_fixed_features)

best_x_continuous, _ = optimize_acqf(
acq_function=acq_fn,
bounds=bounds,
q=n_candidates_required,
num_restarts=num_restarts,
raw_samples=n_intial_start_points,
fixed_features=fixed_cat,
**acq_options,
)

# Extract the numerical dims from the optimized continuous vector
fixed_numericals = {
i: float(best_x_continuous[0, i].item())
for i in range(len(encoder.domains))
if i not in cat_keys
}

# Update fixed_numericals with _fixed_features
fixed_numericals.update(_fixed_features)

# Step 2: Wrap acquisition function for discrete search
wrapped_acq = WrappedAcquisition(
acq=acq_fn,
encoder=encoder,
fixed_numericals=fixed_numericals,
)

# Step 3: Run discrete local search over the categorical space
# with the wrapped acquisition function
best_cat_tensor, _ = optimize_acqf_discrete_local_search(
acq_function=wrapped_acq,
discrete_choices=choices,
q=n_candidates_required,
num_restarts=num_restarts,
raw_samples=n_intial_start_points,
)

# Step 4: Concatenate best categorical and numerical dims, along with
# any fixed features provided by the caller

q, c_dims = best_cat_tensor.shape
n_dims = len(fixed_numericals)
new_X_shape = (q, c_dims + n_dims)

# Create a new tensor to hold the concatenated dimensions
best_x_full: torch.Tensor = torch.empty(
new_X_shape, dtype=best_cat_tensor.dtype, device=best_cat_tensor.device
)

# Create a mask to identify positions of categorical and numerical dimensions
mask = torch.ones(
c_dims + n_dims, dtype=torch.bool, device=best_cat_tensor.device
)
insert_idxs = torch.tensor(
list(fixed_numericals.keys()), device=best_cat_tensor.device
)
mask[insert_idxs] = False

# Fill in the fixed numerical values and the input categorical values
for idx, val in fixed_numericals.items():
best_x_full[:, idx] = val
best_x_full[:, mask] = best_cat_tensor

# Evaluate the final acquisition value
with torch.no_grad():
best_val_final = acq_fn(best_x_full)

return best_x_full, best_val_final

else:
with warning_context:
return optimize_acqf_discrete_local_search( # type: ignore
acq_function=acq_fn,
discrete_choices=choices,
q=n_candidates_required,
num_restarts=num_restarts,
raw_samples=n_intial_start_points,
**acq_options,
)


def encode_trials_for_gp(
Expand Down Expand Up @@ -318,7 +405,6 @@ def fit_and_acquire_from_gp(
n_candidates_required: int | None = None,
num_restarts: int = 20,
n_initial_start_points: int = 256,
maximum_allowed_categorical_combinations: int = 30,
fixed_acq_features: dict[str, Any] | None = None,
acq_options: Mapping[str, Any] | None = None,
hide_warnings: bool = False,
@@ -362,9 +448,6 @@
num_restarts: The number of restarts to use during optimization.
n_initial_start_points: The number of initial start points to use during
optimization.
maximum_allowed_categorical_combinations: The maximum number of categorical
combinations to allow. If the number of combinations exceeds this, an error
will be raised.
acq_options: Additional options to pass to the botorch `optimize_acqf` function.
hide_warnings: Whether to hide numerical warnings issued during GP routines.

@@ -445,7 +528,6 @@ def fit_and_acquire_from_gp(
n_intial_start_points=n_initial_start_points,
fixed_features=fixed_acq_features,
acq_options=acq_options,
maximum_allowed_categorical_combinations=maximum_allowed_categorical_combinations,
hide_warnings=hide_warnings,
)
return candidates
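As an aside for reviewers, the index-mask reinsertion used both in `WrappedAcquisition.forward` and in Step 4 above can be shown in isolation. This is a self-contained sketch with made-up values, not code from this PR:

    # Rebuild a full feature row by scattering fixed numerical values back
    # into their original column indices around the categorical columns.
    import torch

    fixed_numericals = {0: 0.25, 2: 0.75}  # dim index -> fixed value
    best_cat = torch.tensor([[1.0]])       # q x d_categorical (q=1, one cat dim)

    q, c_dims = best_cat.shape
    d = c_dims + len(fixed_numericals)
    full = torch.empty(q, d)

    # True marks categorical slots, False marks fixed numerical slots.
    mask = torch.ones(d, dtype=torch.bool)
    mask[list(fixed_numericals.keys())] = False

    for idx, val in fixed_numericals.items():
        full[:, idx] = val    # numericals return to their original indices
    full[:, mask] = best_cat  # categoricals fill the remaining (True) slots

    print(full)  # tensor([[0.2500, 1.0000, 0.7500]])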
1 change: 1 addition & 0 deletions pyproject.toml
@@ -65,6 +65,7 @@ dependencies = [
"botorch>=0.12",
"gpytorch==1.13.0",
"ifbo>=0.3.13",
"pymoo>=0.6.1.5"
]

[project.urls]