652 changes: 652 additions & 0 deletions Example Function Subset Algorithm.ipynb

Large diffs are not rendered by default.

447 changes: 447 additions & 0 deletions NN BoTorch Model.ipynb

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions bax_algorithms/algo_fns.py
@@ -0,0 +1,55 @@
import torch

def global_opt(f_x, x_grid, minimize=True):
    # locate the optimum of each posterior sample along the grid dimension
    if minimize:
        y_opt, opt_idx = torch.min(f_x, dim=-2)
    else:
        y_opt, opt_idx = torch.max(f_x, dim=-2)

    opt_idx = opt_idx.squeeze(dim=-1)
    x_opt = x_grid[opt_idx]

    return x_opt, y_opt

def single_level_band(f_x, x_grid, min_val=None, max_val=None):
    assert min_val is not None and max_val is not None, "min_val and max_val must both be provided"
    # torch.where on a boolean mask returns one index tensor per dimension of f_x
    idxs = torch.where((f_x >= min_val) & (f_x < max_val))

    # To do: add shape checking on f_x and x_grid
    y_opt = f_x[idxs].unsqueeze(-1)
    # idxs[1:-1] drops the leading sample index and the trailing property index,
    # keeping only the grid-point indices used to select rows of x_grid
    x_opt = x_grid[idxs[1:-1]]

    return x_opt, y_opt
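Since the indexing above leans on `torch.where` returning one index tensor per dimension, here is a toy sketch (shapes assumed to follow the `(n_samples, n_grid, n_properties)` convention produced by `evaluate_virtual_objective` in `base_discrete.py`):

```python
# Toy illustration with assumed shapes (not part of this diff): for a 3-d
# f_x, torch.where yields (sample_idx, grid_idx, property_idx), so
# idxs[1:-1] keeps only the grid indices used to pick rows of x_grid.
import torch

f_x = torch.randn(8, 50, 1)     # 8 posterior samples, 50 grid points, 1 property
x_grid = torch.rand(50, 2)      # 50 grid points in a 2-d input space
idxs = torch.where((f_x >= -0.5) & (f_x < 0.5))
assert len(idxs) == 3           # one index tensor per f_x dimension
x_in_band = x_grid[idxs[1:-1]]  # equivalent to x_grid[idxs[1]] in this case
```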

def multi_level_band(f_x, x_grid, bounds_list=None):
    assert bounds_list is not None, "bounds_list must be provided"
    assert f_x.shape[-1] == len(bounds_list), (
        f"len(bounds_list) ({len(bounds_list)}) must match the number of "
        f"property dimensions ({f_x.shape[-1]})"
    )

    # Start with a mask of all True values over the (sample, grid) dimensions
    condition = torch.ones(f_x.shape[:-1], dtype=torch.bool, device=f_x.device)

    for i, (lower, upper) in enumerate(bounds_list):
        condition &= (f_x[..., i] >= lower) & (f_x[..., i] < upper)

    idxs = torch.where(condition)
    y_opt = f_x[idxs]
    # idxs[1:] drops the leading sample index, keeping the grid-point indices
    # used to select rows of x_grid (the property dimension was already
    # reduced out of `condition`)
    x_opt = x_grid[idxs[1:]]

    return x_opt, y_opt
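A quick toy check of the multi-band selection (shapes assumed, consistent with the sketch above): keep only the grid points whose two properties fall in [0, 1) and [2, 3) simultaneously.

```python
import torch

f_x = torch.tensor([[[0.5, 2.5], [0.5, 5.0], [9.0, 2.5]]])  # (1 sample, 3 grid points, 2 properties)
x_grid = torch.tensor([[0.0], [0.5], [1.0]])                # (3 grid points, 1 input dim)
x_opt, y_opt = multi_level_band(f_x, x_grid, bounds_list=[(0.0, 1.0), (2.0, 3.0)])
# only the first grid point satisfies both bands:
assert torch.equal(x_opt, torch.tensor([[0.0]]))
```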



# Implementation adapted from: https://stackoverflow.com/questions/32791911/fast-calculation-of-pareto-front-in-python (Peter)
def obtain_discrete_pareto_optima(f_x, x_grid):
    # identifies the nondominated (Pareto-optimal) points, assuming all
    # objectives are to be maximized
    is_efficient = torch.arange(f_x.shape[0])
    next_point_index = 0  # next index in f_x to search for
    while next_point_index < len(f_x):
        nondominated_point_mask = torch.any(f_x >= f_x[next_point_index], dim=1)
        is_efficient = is_efficient[nondominated_point_mask]  # remove dominated points
        f_x = f_x[nondominated_point_mask]
        next_point_index = int(torch.sum(nondominated_point_mask[:next_point_index])) + 1

    # f_x has already been filtered down to the nondominated points;
    # is_efficient holds their indices into the original grid
    y_opt = f_x
    x_opt = x_grid[is_efficient]

    return x_opt, y_opt
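A sanity check of the Pareto routine on a toy two-objective trade-off (the `>=` comparison implies a maximization convention): with f = (x, 1 - x), no grid point dominates another, so the whole grid is returned.

```python
import torch

x_grid = torch.linspace(0, 1, 100).unsqueeze(-1)  # (100, 1)
f_x = torch.cat((x_grid, 1.0 - x_grid), dim=-1)   # (100, 2), objectives trade off
x_opt, y_opt = obtain_discrete_pareto_optima(f_x, x_grid)
assert x_opt.shape == (100, 1) and y_opt.shape == (100, 2)
```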
102 changes: 102 additions & 0 deletions bax_algorithms/base_discrete.py
@@ -0,0 +1,102 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple
from pydantic import Field
import torch
from torch import Tensor

from botorch.models.model import Model, ModelList
from xopt.generators.bayesian.bax.algorithms import Algorithm

class BaseDiscreteAlgoFn(ABC):
    @abstractmethod
    def __call__(self, posterior_samples: Tensor, x_grid: Tensor, **algo_kwargs) -> Tuple[Tensor, Tensor]:
        pass

class FunctionWrapper(BaseDiscreteAlgoFn):
    def __init__(self, fn: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]]):
        self.fn = fn

    def __call__(self, posterior_samples: Tensor, x_grid: Tensor, **algo_kwargs) -> Tuple[Tensor, Tensor]:
        return self.fn(posterior_samples, x_grid, **algo_kwargs)

class DiscreteSubsetAlgorithm(Algorithm, ABC):
    algo_fn: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]] = Field(
        None, description="Python function defining a BAX algorithm on a discrete grid")
    grid: Tensor = Field(
        None, description="n-d grid of discrete points")
    observable_names_ordered: list[str] = Field(
        ["y1"], description="keys designating output properties")
    algo_kwargs: dict = Field(
        {}, description="keyword args for generic subset algorithm")

    def get_execution_paths(self, model: Model, bounds: Tensor):
        test_points = self.grid

        # match the dtype/device of the model's training data
        if isinstance(model, ModelList):
            test_points = test_points.to(model.models[0].train_targets)
        else:
            test_points = test_points.to(model.train_targets)

        # get samples of the model posterior at the mesh points
        posterior_samples = self.evaluate_virtual_objective(
            model, test_points, bounds, self.n_samples
        )

        # wrap plain callables so they satisfy the BaseDiscreteAlgoFn interface
        if not isinstance(self.algo_fn, BaseDiscreteAlgoFn):
            self.algo_fn = FunctionWrapper(self.algo_fn)

        x_opt, y_opt = self.algo_fn(posterior_samples, test_points, **self.algo_kwargs)

        # get the solution center and entropy for TuRBO-style trust regions
        # note: the entropy calculation here drops a constant scaling factor
        solution_center = x_opt.mean(dim=0).cpu().numpy()
        solution_entropy = float(torch.log(x_opt.std(dim=0) ** 2).sum())

        # collect secondary results in a dict
        results_dict = {
            "test_points": test_points,
            "posterior_samples": posterior_samples,
            "execution_paths": torch.hstack((x_opt, y_opt)),
            "solution_center": solution_center,
            "solution_entropy": solution_entropy,
        }

        # return execution paths
        return x_opt.unsqueeze(-2), y_opt.unsqueeze(-2), results_dict

    def evaluate_virtual_objective(
        self,
        model: Model,
        x: Tensor,
        bounds: Tensor,
        n_samples: int,
        tkwargs: dict = None,
    ) -> Tensor:
        """
        Evaluate the virtual objective (samples).

        Parameters:
        -----------
        model : Model
            The model to use for evaluating the virtual objective.
        x : Tensor
            The inputs at which to evaluate the virtual objective.
        bounds : Tensor
            The bounds for the optimization.
        n_samples : int
            The number of posterior samples to draw.
        tkwargs : dict, optional
            Tensor keyword arguments (currently unused).

        Returns:
        --------
        Tensor
            Posterior samples of shape (n_samples, len(x), n_outputs).
        """
        # draw samples of the model posterior at the inputs given by x
        with torch.no_grad():
            post = model.posterior(x)
            objective_values = post.rsample(torch.Size([n_samples]))

        return objective_values
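For reference, a minimal end-to-end sketch of the sample-then-run flow that `get_execution_paths` automates (assumptions: the `bax_algorithms` module paths, and a recent BoTorch where the fitting helper is spelled `fit_gpytorch_mll`; older releases use `fit_gpytorch_model`):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from bax_algorithms.base_discrete import FunctionWrapper
from bax_algorithms.algo_fns import global_opt

# fit a toy single-output GP
train_x = torch.rand(10, 1, dtype=torch.double)
train_y = torch.sin(6.0 * train_x)
model = SingleTaskGP(train_x, train_y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# sample the posterior on a grid, then apply a wrapped algorithm function
grid = torch.linspace(0, 1, 101, dtype=torch.double).unsqueeze(-1)
with torch.no_grad():
    samples = model.posterior(grid).rsample(torch.Size([16]))  # (16, 101, 1)

algo_fn = FunctionWrapper(global_opt)
x_opt, y_opt = algo_fn(samples, grid, minimize=True)  # shapes (16, 1), (16, 1)
```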
32 changes: 32 additions & 0 deletions bax_algorithms/discrete_algos.py
@@ -0,0 +1,32 @@
from .base_discrete import DiscreteSubsetAlgorithm
from .algo_fns import global_opt, single_level_band, multi_level_band
from pydantic import Field

class GlobalOpt(DiscreteSubsetAlgorithm):
    minimize: bool = Field(True,
        description="If true, minimize function, otherwise maximize")

    def __init__(self, **data):
        data["algo_fn"] = global_opt
        # use .get() so the documented default survives when the field is omitted
        data["algo_kwargs"] = {"minimize": data.get("minimize", True)}
        super().__init__(**data)

class SingleLevelBand(DiscreteSubsetAlgorithm):
    min_val: float = Field(...,
        description="Min value of band")
    max_val: float = Field(...,
        description="Max value of band")

    def __init__(self, **data):
        data["algo_fn"] = single_level_band
        data["algo_kwargs"] = {"min_val": data["min_val"], "max_val": data["max_val"]}
        super().__init__(**data)

class MultiLevelBand(DiscreteSubsetAlgorithm):
    bounds_list: list = Field(...,
        description="List of (lower, upper) bound pairs, one per property dimension")

    def __init__(self, **data):
        data["algo_fn"] = multi_level_band
        data["algo_kwargs"] = {"bounds_list": data["bounds_list"]}
        super().__init__(**data)
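A hypothetical usage sketch for the prepackaged classes (the `n_samples` field is assumed to be inherited from Xopt's `Algorithm` base; names may differ):

```python
import torch
from bax_algorithms.discrete_algos import GlobalOpt

grid = torch.linspace(0, 1, 201, dtype=torch.double).unsqueeze(-1)
algo = GlobalOpt(grid=grid, n_samples=20, minimize=True)

# with a fitted BoTorch `model` and an input-space `bounds` tensor:
# x_exe, y_exe, results = algo.get_execution_paths(model, bounds)
```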