Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion aepsych/acquisition/lookahead.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gpytorch import GPyTorchModel
from botorch.utils.transforms import t_batch_mode_transform
from botorch.utils.transforms import (
average_over_ensemble_models,
t_batch_mode_transform,
)
from scipy.stats import norm
from torch import Tensor

Expand Down Expand Up @@ -171,6 +174,7 @@ def __init__(
self.posterior_transform = posterior_transform

@t_batch_mode_transform(expected_q=1)
@average_over_ensemble_models
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Evaluate acquisition function at X.
Expand Down Expand Up @@ -291,6 +295,7 @@ def __init__(
self.register_buffer("Xq", Xq)

@t_batch_mode_transform(expected_q=1)
@average_over_ensemble_models
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Evaluate acquisition function at X.
Expand Down
6 changes: 5 additions & 1 deletion aepsych/acquisition/mc_posterior_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform
from botorch.utils.transforms import (
average_over_ensemble_models,
t_batch_mode_transform,
)


def balv_acq(obj_samps: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -61,6 +64,7 @@ def __init__(
self.objective = objective

@t_batch_mode_transform()
@average_over_ensemble_models
def forward(self, X: torch.Tensor) -> torch.Tensor:
r"""Evaluate MCPosteriorVariance on the candidate set `X`.

Expand Down
6 changes: 5 additions & 1 deletion aepsych/acquisition/mutual_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform
from botorch.utils.transforms import (
average_over_ensemble_models,
t_batch_mode_transform,
)
from torch import Tensor
from torch.distributions.bernoulli import Bernoulli

Expand Down Expand Up @@ -82,6 +85,7 @@ def __init__(
)

@t_batch_mode_transform()
@average_over_ensemble_models
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate mutual information on the candidate set `X`.

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"pandas",
"aepsych_client",
"statsmodels",
"botorch==0.13.0",
"botorch",
]

BENCHMARK_REQUIRES = ["tqdm", "pathos", "multiprocess", "torcheval"]
Expand Down
11 changes: 8 additions & 3 deletions tests/test_mean_covar_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def f_1d(x):

[init_strat]
generator = SobolGenerator
min_asks = 200
min_asks = 250

[opt_strat]
generator = MixedOptimizeAcqfGenerator
Expand Down Expand Up @@ -780,5 +780,10 @@ def f_1d(x):
y_0, _ = strat.predict(x_0, probability_space=True)
y_1, _ = strat.predict(x_1, probability_space=True)

self.assertTrue(torch.allclose(x[torch.argmax(y_0)], torch.tensor([0.0])))
self.assertTrue(torch.allclose(x[torch.argmax(y_1)], torch.tensor([0.4])))
# Loose test
self.assertLessEqual(
(torch.abs(x[torch.argmax(y_0)] - torch.tensor([0.0]))).item(), 0.05
)
self.assertLessEqual(
(torch.abs(x[torch.argmax(y_1)] - torch.tensor([0.4]))).item(), 0.05
)
Loading