88from typing import List , Optional , Union
99
1010from ..connection import Connection
11- from ..constants import DATASET_ITEMS_KEY , NAME_KEY , SCENES_KEY , SLICE_ID_KEY
11+ from ..constants import (
12+ DATASET_ITEMS_KEY ,
13+ NAME_KEY ,
14+ SCENES_KEY ,
15+ SLICE_ID_KEY ,
16+ TRACKS_KEY ,
17+ )
1218from ..dataset_item import DatasetItem
1319from ..scene import Scene
20+ from ..track import Track
1421from .constants import (
1522 EVAL_FUNCTION_ID_KEY ,
1623 SCENARIO_TEST_ID_KEY ,
@@ -166,8 +173,8 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]:
166173
167174 def get_items (
168175 self , level : EntityLevel = EntityLevel .ITEM
169- ) -> Union [List [DatasetItem ], List [Scene ]]:
170- """Gets items within a scenario test at a given level, returning a list of DatasetItem or Scene objects.
176+ ) -> Union [List [Track ], List [ DatasetItem ], List [Scene ]]:
177+ """Gets items within a scenario test at a given level, returning a list of Track, DatasetItem, or Scene objects.
171178
172179 Args:
173180 level: :class:`EntityLevel`
@@ -178,14 +185,22 @@ def get_items(
178185 response = self .connection .get (
179186 f"validate/scenario_test/{ self .id } /items" ,
180187 )
188+ if level == EntityLevel .TRACK :
189+ return [
190+ Track .from_json (track , connection = self .connection )
191+ for track in response .get (TRACKS_KEY , [])
192+ ]
181193 if level == EntityLevel .SCENE :
182194 return [
183195 Scene .from_json (scene , skip_validate = True )
184- for scene in response [ SCENES_KEY ]
196+ for scene in response . get ( SCENES_KEY , [])
185197 ]
186- return [
187- DatasetItem .from_json (item ) for item in response [DATASET_ITEMS_KEY ]
188- ]
198+ if level == EntityLevel .ITEM :
199+ return [
200+ DatasetItem .from_json (item )
201+ for item in response .get (DATASET_ITEMS_KEY , [])
202+ ]
203+ raise ValueError (f"Invalid entity level: { level } " )
189204
190205 def set_baseline_model (self , model_id : str ):
191206 """Sets a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
@@ -222,23 +237,41 @@ def upload_external_evaluation_results(
222237 len (results ) > 0
223238 ), "Submitting evaluation requires at least one result."
224239
225- level = EntityLevel . ITEM
240+ level : Optional [ EntityLevel ] = None
226241 metric_per_ref_id = {}
227242 weight_per_ref_id = {}
228243 aggregate_weighted_sum = 0.0
229244 aggregate_weight = 0.0
230245
246+ # Ensures results are provided at only one EntityLevel; otherwise raises a ValueError
def ensure_level_consistency_or_raise(
    cur_level: Optional[EntityLevel], new_level: EntityLevel
):
    """Raise if a result's level differs from the level already established.

    Args:
        cur_level: The level established by previously processed results,
            or ``None`` if no result has set a level yet.
        new_level: The level implied by the result currently being processed.

    Raises:
        ValueError: If ``cur_level`` is set and is not equal to ``new_level``.
    """
    # Compare against the explicit argument rather than a variable captured
    # from the enclosing scope, so the check and the error message always
    # refer to the same value.
    if cur_level is not None and cur_level != new_level:
        raise ValueError(
            f"All evaluation results must only pertain to one level. Received {cur_level} then {new_level}"
        )
254+
231255 # aggregation based on https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
232256 for r in results :
233- # Ensure results are uploaded ONLY for items or ONLY for scenes
257+ # Ensure results are uploaded for ONLY ONE of tracks, items, or scenes
258+ if r .track_ref_id is not None :
259+ ensure_level_consistency_or_raise (level , EntityLevel .TRACK )
260+ level = EntityLevel .TRACK
261+ if r .item_ref_id is not None :
262+ ensure_level_consistency_or_raise (level , EntityLevel .ITEM )
263+ level = EntityLevel .ITEM
234264 if r .scene_ref_id is not None :
265+ ensure_level_consistency_or_raise (level , EntityLevel .SCENE )
235266 level = EntityLevel .SCENE
236- if r .item_ref_id is not None and level == EntityLevel .SCENE :
237- raise ValueError (
238- "All evaluation results must either pertain to a scene_ref_id or an item_ref_id, not both."
239- )
240267 ref_id = (
241- r .item_ref_id if level == EntityLevel .ITEM else r .scene_ref_id
268+ r .track_ref_id
269+ if level == EntityLevel .TRACK
270+ else (
271+ r .item_ref_id
272+ if level == EntityLevel .ITEM
273+ else r .scene_ref_id
274+ )
242275 )
243276
244277 # Aggregate scores and weights
@@ -255,7 +288,7 @@ def upload_external_evaluation_results(
255288 "overall_metric" : aggregate_weighted_sum / aggregate_weight ,
256289 "model_id" : model_id ,
257290 "slice_id" : self .slice_id ,
258- "level" : level .value ,
291+ "level" : level .value if level else None ,
259292 }
260293 response = self .connection .post (
261294 payload ,
0 commit comments