@@ -99,15 +99,16 @@ def tell(self, report: Trial.Report[HEBOTrial]) -> None:
9999 # than HEBO should be fine with these reported.
100100 # Either way, we don't actually have to look at the status of the trial to give
101101 # the info to hebo.
102-
103102 # Make sure we have a value for each
104103 _lookup : dict [str , Metric .Value ] = {
105104 v .metric .name : v for v in report .metric_values
106105 }
107106 metric_values = [
108107 _lookup .get (metric .name , metric .worst ) for metric in self .metrics
109108 ]
110- raw_y = np .array ([[v .value for v in metric_values ]])
109+
110+ costs = [self .cost (v ) for v in metric_values ]
111+ raw_y = np .array ([costs ]) # Yep, it needs 2d, for single report tells
111112 self .optimizer .observe (raw_x , raw_y )
112113
113114 @override
@@ -221,13 +222,12 @@ def create(
221222 )
222223 case Sequence ():
223224 assert len (metrics ) > 1
224- # TODO: Not really sure if I should give a ref point or not, especially
225- # if there are unbounded metrics.
226- ref_point = np .array ([metric .worst .value for metric in metrics ])
227225 optimizer = GeneralBO (
228226 space = space ,
229227 num_obj = len (metrics ),
230- ref_point = ref_point ,
228+ # TODO: Not really sure if I should give a ref point or not,
229+ # especially if there are unbounded metrics.
230+ ref_point = np .array (cls .worst_possible_cost (metrics )),
231231 ** optimizer_kwargs ,
232232 )
233233
@@ -237,3 +237,36 @@ def create(
237237 @classmethod
238238 def preferred_parser (cls ) -> HEBOParser :
239239 return parser
240+
241+ @overload
242+ @classmethod
243+ def worst_possible_cost (cls , metric : Metric ) -> float :
244+ ...
245+
246+ @overload
247+ @classmethod
248+ def worst_possible_cost (cls , metric : Sequence [Metric ]) -> list [float ]:
249+ ...
250+
251+ @classmethod
252+ def worst_possible_cost (
253+ cls ,
254+ metric : Metric | Sequence [Metric ],
255+ ) -> float | list [float ]:
256+ """Get the crash cost for a metric for SMAC."""
257+ match metric :
258+ case Metric (bounds = (lower , upper )): # Bounded metrics
259+ return abs (upper - lower )
260+ case Metric (): # Unbounded metric
261+ return np .inf
262+ case metrics :
263+ return [cls .worst_possible_cost (m ) for m in metrics ]
264+
265+ @classmethod
266+ def cost (cls , value : Metric .Value ) -> float :
267+ """Get the cost for a metric value for HEBO."""
268+ match value .distance_to_optimal :
269+ case None : # If we can't compute the distance, use the loss
270+ return value .loss
271+ case distance : # If we can compute the distance, use that
272+ return distance
0 commit comments