Skip to content

Commit 2fa5f79

Browse files
committed
fix(HEBO): Make sure to report costs always
1 parent 8d5308d commit 2fa5f79

File tree

1 file changed

+39
-6
lines changed
  • src/amltk/optimization/optimizers

1 file changed

+39
-6
lines changed

src/amltk/optimization/optimizers/hebo.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,16 @@ def tell(self, report: Trial.Report[HEBOTrial]) -> None:
9999
# than HEBO should be fine with these reported.
100100
# Either way, we don't actually have to look at the status of the trial to give
101101
# the info to hebo.
102-
103102
# Make sure we have a value for each
104103
_lookup: dict[str, Metric.Value] = {
105104
v.metric.name: v for v in report.metric_values
106105
}
107106
metric_values = [
108107
_lookup.get(metric.name, metric.worst) for metric in self.metrics
109108
]
110-
raw_y = np.array([[v.value for v in metric_values]])
109+
110+
costs = [self.cost(v) for v in metric_values]
111+
raw_y = np.array([costs]) # Yep, it needs 2d, for single report tells
111112
self.optimizer.observe(raw_x, raw_y)
112113

113114
@override
@@ -221,13 +222,12 @@ def create(
221222
)
222223
case Sequence():
223224
assert len(metrics) > 1
224-
# TODO: Not really sure if I should give a ref point or not, especially
225-
# if there are unbounded metrics.
226-
ref_point = np.array([metric.worst.value for metric in metrics])
227225
optimizer = GeneralBO(
228226
space=space,
229227
num_obj=len(metrics),
230-
ref_point=ref_point,
228+
# TODO: Not really sure if I should give a ref point or not,
229+
# especially if there are unbounded metrics.
230+
ref_point=np.array(cls.worst_possible_cost(metrics)),
231231
**optimizer_kwargs,
232232
)
233233

@@ -237,3 +237,36 @@ def create(
237237
@classmethod
238238
def preferred_parser(cls) -> HEBOParser:
239239
return parser
240+
241+
@overload
242+
@classmethod
243+
def worst_possible_cost(cls, metric: Metric) -> float:
244+
...
245+
246+
@overload
247+
@classmethod
248+
def worst_possible_cost(cls, metric: Sequence[Metric]) -> list[float]:
249+
...
250+
251+
@classmethod
252+
def worst_possible_cost(
253+
cls,
254+
metric: Metric | Sequence[Metric],
255+
) -> float | list[float]:
256+
"""Get the crash cost for a metric for SMAC."""
257+
match metric:
258+
case Metric(bounds=(lower, upper)): # Bounded metrics
259+
return abs(upper - lower)
260+
case Metric(): # Unbounded metric
261+
return np.inf
262+
case metrics:
263+
return [cls.worst_possible_cost(m) for m in metrics]
264+
265+
@classmethod
266+
def cost(cls, value: Metric.Value) -> float:
267+
"""Get the cost for a metric value for HEBO."""
268+
match value.distance_to_optimal:
269+
case None: # If we can't compute the distance, use the loss
270+
return value.loss
271+
case distance: # If we can compute the distance, use that
272+
return distance

0 commit comments

Comments
 (0)