 and have confidence that they’re always shipping the best model.
 """
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, Union

 from ..connection import Connection
-from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SLICE_ID_KEY
+from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SCENES_KEY, SLICE_ID_KEY
 from ..dataset_item import DatasetItem
+from ..scene import Scene
 from .constants import (
     EVAL_FUNCTION_ID_KEY,
     SCENARIO_TEST_ID_KEY,
     SCENARIO_TEST_METRICS_KEY,
     THRESHOLD_COMPARISON_KEY,
     THRESHOLD_KEY,
+    EntityLevel,
     ThresholdComparison,
 )
 from .data_transfer_objects.scenario_test_evaluations import EvaluationResult
@@ -162,16 +164,31 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]:
         ]
         return evaluations

-    def get_items(self) -> List[DatasetItem]:
+    def get_items(
+        self, level: EntityLevel = EntityLevel.ITEM
+    ) -> Union[List[DatasetItem], List[Scene]]:
170+ """Gets items within a scenario test at a given level, returning a list of DatasetItem or Scene objects.
171+
172+ Args:
173+ level: :class:`EntityLevel`
174+
175+ Returns:
176+ A list of :class:`ScenarioTestEvaluation` objects.
177+ """
         response = self.connection.get(
             f"validate/scenario_test/{self.id}/items",
         )
+        if level == EntityLevel.SCENE:
+            return [
+                Scene.from_json(scene, skip_validate=True)
+                for scene in response[SCENES_KEY]
+            ]
         return [
             DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
         ]

     def set_baseline_model(self, model_id: str):
-        """Set's a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
+        """Sets a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
         this scenario test must have been evaluated using that model. The baseline model's performance
         is used as the threshold for all metrics against which other models are compared.

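The two methods touched in this hunk would be exercised roughly as follows. This is a minimal sketch: the existing ScenarioTest instance, the model ID, and the import path for EntityLevel are assumptions, not shown in this diff.

    # Assumes `scenario_test` is an existing ScenarioTest and "model_123" was already evaluated on it.
    from nucleus.validate.constants import EntityLevel  # import path assumed from this package layout

    items = scenario_test.get_items()                          # defaults to EntityLevel.ITEM -> List[DatasetItem]
    scenes = scenario_test.get_items(level=EntityLevel.SCENE)  # -> List[Scene], parsed with skip_validate=True
    scenario_test.set_baseline_model("model_123")              # later runs are compared against this model's metrics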
@@ -205,14 +222,28 @@ def upload_external_evaluation_results(
             len(results) > 0
         ), "Submitting evaluation requires at least one result."

+        level = EntityLevel.ITEM
         metric_per_ref_id = {}
         weight_per_ref_id = {}
         aggregate_weighted_sum = 0.0
         aggregate_weight = 0.0
+
         # aggregation based on https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
         for r in results:
-            metric_per_ref_id[r.item_ref_id] = r.score
-            weight_per_ref_id[r.item_ref_id] = r.weight
+            # Ensure results are uploaded ONLY for items or ONLY for scenes
+            if r.scene_ref_id is not None:
+                level = EntityLevel.SCENE
+            if r.item_ref_id is not None and level == EntityLevel.SCENE:
+                raise ValueError(
+                    "All evaluation results must either pertain to a scene_ref_id or an item_ref_id, not both."
+                )
+            ref_id = (
+                r.item_ref_id if level == EntityLevel.ITEM else r.scene_ref_id
+            )
+
+            # Aggregate scores and weights
+            metric_per_ref_id[ref_id] = r.score
+            weight_per_ref_id[ref_id] = r.weight
             aggregate_weighted_sum += r.score * r.weight
             aggregate_weight += r.weight

@@ -224,6 +255,7 @@ def upload_external_evaluation_results(
224255 "overall_metric" : aggregate_weighted_sum / aggregate_weight ,
225256 "model_id" : model_id ,
226257 "slice_id" : self .slice_id ,
258+ "level" : level .value ,
227259 }
228260 response = self .connection .post (
229261 payload ,
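For the new scene-level path, a caller-side sketch of how results might be submitted. The EvaluationResult fields used here (scene_ref_id, score, weight) come from this diff, but the exact constructor and the evaluation-function argument (eval_fn) are assumptions since the full signature is not shown in this hunk; the final comment just reproduces the weighted arithmetic mean computed in the loop above.

    # Hypothetical scene-level results; ref IDs, scores, and weights are made up.
    results = [
        EvaluationResult(scene_ref_id="scene_a", score=0.8, weight=2.0),
        EvaluationResult(scene_ref_id="scene_b", score=0.5, weight=1.0),
    ]
    scenario_test.upload_external_evaluation_results(eval_fn, results, "model_123")
    # overall_metric = (0.8 * 2.0 + 0.5 * 1.0) / (2.0 + 1.0) = 0.7,
    # and the payload carries "level": EntityLevel.SCENE.value.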