 and have confidence that they’re always shipping the best model.
 """
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Optional
 
 from ..connection import Connection
 from ..constants import NAME_KEY, SLICE_ID_KEY
 from ..dataset_item import DatasetItem
-from .data_transfer_objects.eval_function import EvaluationCriterion
+from .constants import (
+    EVAL_FUNCTION_ID_KEY,
+    SCENARIO_TEST_ID_KEY,
+    SCENARIO_TEST_METRICS_KEY,
+    THRESHOLD_COMPARISON_KEY,
+    THRESHOLD_KEY,
+    ThresholdComparison,
+)
 from .data_transfer_objects.scenario_test_evaluations import GetEvalHistory
-from .data_transfer_objects.scenario_test_metric import AddScenarioTestMetric
+from .data_transfer_objects.scenario_test_metric import AddScenarioTestFunction
+from .eval_functions.available_eval_functions import EvalFunction
 from .scenario_test_evaluation import ScenarioTestEvaluation
 from .scenario_test_metric import ScenarioTestMetric
 
@@ -36,6 +44,7 @@ class ScenarioTest:
     connection: Connection = field(repr=False)
     name: str = field(init=False)
     slice_id: str = field(init=False)
+    baseline_model_id: Optional[str] = None
 
     def __post_init__(self):
         # TODO(gunnar): Remove this pattern. It's too slow. We should get all the info required in one call
@@ -45,10 +54,10 @@ def __post_init__(self):
         self.name = response[NAME_KEY]
         self.slice_id = response[SLICE_ID_KEY]
 
-    def add_criterion(
-        self, evaluation_criterion: EvaluationCriterion
+    def add_eval_function(
+        self, eval_function: EvalFunction
     ) -> ScenarioTestMetric:
-        """Creates and adds a new criteria to the :class:`ScenarioTest`. ::
+        """Creates and adds a new evaluation metric to the :class:`ScenarioTest`. ::
 
            import nucleus
            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
@@ -58,49 +67,51 @@ def add_criterion(
 
            e = client.validate.eval_functions
            # Assuming a user would like to add all available public evaluation functions as criteria
-           scenario_test.add_criterion(
-               e.bbox_iou() > 0.5
+           scenario_test.add_eval_function(
+               e.bbox_iou
            )
-           scenario_test.add_criterion(
-               e.bbox_map() > 0.85
+           scenario_test.add_eval_function(
+               e.bbox_map
            )
-           scenario_test.add_criterion(
-               e.bbox_precision() > 0.7
+           scenario_test.add_eval_function(
+               e.bbox_precision
            )
-           scenario_test.add_criterion(
-               e.bbox_recall() > 0.6
+           scenario_test.add_eval_function(
+               e.bbox_recall
            )
 
        Args:
-           evaluation_criterion: :class:`EvaluationCriterion` created by comparison with an :class:`EvalFunction`
+           eval_function: :class:`EvalFunction` to add to the scenario test
 
        Returns:
            The created ScenarioTestMetric object.
        """
        response = self.connection.post(
-           AddScenarioTestMetric(
+           AddScenarioTestFunction(
                scenario_test_name=self.name,
-               eval_function_id=evaluation_criterion.eval_function_id,
-               threshold=evaluation_criterion.threshold,
-               threshold_comparison=evaluation_criterion.threshold_comparison,
+               eval_function_id=eval_function.id,
            ).dict(),
-           "validate/scenario_test_metric",
+           "validate/scenario_test_eval_function",
        )
        return ScenarioTestMetric(
-           scenario_test_id=response["scenario_test_id"],
-           eval_function_id=response["eval_function_id"],
-           threshold=evaluation_criterion.threshold,
-           threshold_comparison=evaluation_criterion.threshold_comparison,
+           scenario_test_id=response[SCENARIO_TEST_ID_KEY],
+           eval_function_id=response[EVAL_FUNCTION_ID_KEY],
+           threshold=response.get(THRESHOLD_KEY, None),
+           threshold_comparison=response.get(
+               THRESHOLD_COMPARISON_KEY,
+               ThresholdComparison.GREATER_THAN_EQUAL_TO,
+           ),
+           connection=self.connection,
        )
 
-   def get_criteria(self) -> List[ScenarioTestMetric]:
+   def get_eval_functions(self) -> List[ScenarioTestMetric]:
        """Retrieves all criteria of the :class:`ScenarioTest`. ::
 
            import nucleus
            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
            scenario_test = client.validate.scenario_tests[0]
 
-           scenario_test.get_criteria()
+           scenario_test.get_eval_functions()
 
        Returns:
            A list of ScenarioTestMetric objects.
@@ -109,8 +121,8 @@ def get_criteria(self) -> List[ScenarioTestMetric]:
109121 f"validate/scenario_test/{ self .id } /metrics" ,
110122 )
111123 return [
112- ScenarioTestMetric (** metric )
113- for metric in response ["scenario_test_metrics" ]
124+ ScenarioTestMetric (** metric , connection = self . connection )
125+ for metric in response [SCENARIO_TEST_METRICS_KEY ]
114126 ]
115127
116128 def get_eval_history (self ) -> List [ScenarioTestEvaluation ]:
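Taken together, the renamed methods give the following flow. A minimal usage sketch, assuming a scenario test already exists and that `e.bbox_iou` resolves to a public :class:`EvalFunction` as in the docstring example above:

    import nucleus

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
    scenario_test = client.validate.scenario_tests[0]

    # Attach an evaluation function by reference. Unlike the old add_criterion
    # path, no explicit threshold is sent; the response supplies one, or the
    # comparison defaults to GREATER_THAN_EQUAL_TO.
    e = client.validate.eval_functions
    metric = scenario_test.add_eval_function(e.bbox_iou)

    # Each attached function is returned as a ScenarioTestMetric.
    for m in scenario_test.get_eval_functions():
        print(m.eval_function_id, m.threshold, m.threshold_comparison)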
@@ -141,3 +153,24 @@ def get_items(self) -> List[DatasetItem]:
        return [
            DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
        ]
+
+   def set_baseline_model(self, model_id: str):
+       """Sets a new baseline model for the ScenarioTest. To be eligible as a
+       baseline, this scenario test must have been evaluated using that model.
+       The baseline model's performance is used as the threshold for all
+       metrics against which other models are compared. ::
+
+           import nucleus
+           client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+           scenario_test = client.validate.scenario_tests[0]
+
+           scenario_test.set_baseline_model('my_baseline_model_id')
+
+       Returns:
+           The ID of the newly set baseline model.
+       """
+       response = self.connection.post(
+           {},
+           f"validate/scenario_test/{self.id}/set_baseline_model/{model_id}",
+       )
+       self.baseline_model_id = response.get("baseline_model_id")
+       return self.baseline_model_id
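The new baseline flow, under the same assumptions; the model ID below is a placeholder, and the model must already have an evaluation on this scenario test:

    # Promote a previously evaluated model to baseline. Its per-metric scores
    # become the thresholds that other models are compared against.
    baseline_id = scenario_test.set_baseline_model("my_baseline_model_id")
    assert baseline_id == scenario_test.baseline_model_id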