11import itertools
2- from typing import Callable , Dict , List , Type , Union
2+ from typing import Callable , Dict , List , Optional , Union
33
44from nucleus .logger import logger
5- from nucleus .validate .eval_functions .base_eval_function import BaseEvalFunction
5+ from nucleus .validate .eval_functions .base_eval_function import (
6+ EvalFunctionConfig ,
7+ )
68
79from ..data_transfer_objects .eval_function import EvalFunctionEntry
810from ..errors import EvalFunctionNotAvailableError
911
1012MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes"
1113
1214
13- class BoundingBoxIOU (BaseEvalFunction ):
class PolygonIOUConfig(EvalFunctionConfig):
    """Configurable handle for the ``bbox_iou`` evaluation function."""

    def __call__(
        self,
        enforce_label_match: bool = False,
        iou_threshold: float = 0.0,
        confidence_threshold: float = 0.0,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonIOU` object.

        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_iou: PolygonIOUConfig = client.validate.eval_functions.bbox_iou
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_iou(confidence_threshold=0.8) > 0.5]
            )

        Args:
            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.0
            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
            **kwargs: extra keyword arguments, forwarded unchanged to ``EvalFunctionConfig.__call__``.

        Returns:
            Whatever ``EvalFunctionConfig.__call__`` returns for these settings. The
            base class is not defined in this module; per the example above the
            result can be compared against a threshold to form an evaluation criterion.
        """
        # The range constraints in the docstring are not checked here; the
        # values are handed to the base class as given.
        return super().__call__(
            enforce_label_match=enforce_label_match,
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        """Return ``"bbox_iou"``, the eval function name this config expects."""
        return "bbox_iou"
1752
1853
19- class BoundingBoxMeanAveragePrecision (BaseEvalFunction ):
class PolygonMAPConfig(EvalFunctionConfig):
    """Configurable handle for the ``bbox_map`` evaluation function."""

    def __call__(
        self,
        iou_threshold: float = 0.5,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonMAP` object.

        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_map: PolygonMAPConfig = client.validate.eval_functions.bbox_map
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_map(iou_threshold=0.6) > 0.8]
            )

        Args:
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.5
            **kwargs: extra keyword arguments, forwarded unchanged to ``EvalFunctionConfig.__call__``.

        Returns:
            Whatever ``EvalFunctionConfig.__call__`` returns for these settings. The
            base class is not defined in this module; per the example above the
            result can be compared against a threshold to form an evaluation criterion.
        """
        # The range constraint in the docstring is not checked here; the value
        # is handed to the base class as given.
        return super().__call__(
            iou_threshold=iou_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        """Return ``"bbox_map"``, the eval function name this config expects."""
        return "bbox_map"
2385
2486
25- class BoundingBoxRecall (BaseEvalFunction ):
class PolygonRecallConfig(EvalFunctionConfig):
    """Configurable handle for the ``bbox_recall`` evaluation function."""

    def __call__(
        self,
        enforce_label_match: bool = False,
        iou_threshold: float = 0.5,
        confidence_threshold: float = 0.0,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonRecall` object.

        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_recall: PolygonRecallConfig = client.validate.eval_functions.bbox_recall
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_recall(iou_threshold=0.6, confidence_threshold=0.4) > 0.9]
            )

        Args:
            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.5
            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
            **kwargs: extra keyword arguments, forwarded unchanged to ``EvalFunctionConfig.__call__``.

        Returns:
            Whatever ``EvalFunctionConfig.__call__`` returns for these settings. The
            base class is not defined in this module; per the example above the
            result can be compared against a threshold to form an evaluation criterion.
        """
        # The range constraints in the docstring are not checked here; the
        # values are handed to the base class as given.
        return super().__call__(
            enforce_label_match=enforce_label_match,
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        """Return ``"bbox_recall"``, the eval function name this config expects."""
        return "bbox_recall"
29124
30125
31- class BoundingBoxPrecision (BaseEvalFunction ):
class PolygonPrecisionConfig(EvalFunctionConfig):
    """Configurable handle for the ``bbox_precision`` evaluation function."""

    def __call__(
        self,
        enforce_label_match: bool = False,
        iou_threshold: float = 0.5,
        confidence_threshold: float = 0.0,
        **kwargs,
    ):
        """Configures a call to :class:`PolygonPrecision` object.

        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            bbox_precision: PolygonPrecisionConfig = client.validate.eval_functions.bbox_precision
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[bbox_precision(iou_threshold=0.6, confidence_threshold=0.4) > 0.9]
            )

        Args:
            enforce_label_match: whether to enforce that annotation and prediction labels must match. Defaults to False
            iou_threshold: IOU threshold to consider detection as valid. Must be in [0, 1]. Default 0.5
            confidence_threshold: minimum confidence threshold for predictions. Must be in [0, 1]. Default 0.0
            **kwargs: extra keyword arguments, forwarded unchanged to ``EvalFunctionConfig.__call__``.

        Returns:
            Whatever ``EvalFunctionConfig.__call__`` returns for these settings. The
            base class is not defined in this module; per the example above the
            result can be compared against a threshold to form an evaluation criterion.
        """
        # The range constraints in the docstring are not checked here; the
        # values are handed to the base class as given.
        return super().__call__(
            enforce_label_match=enforce_label_match,
            iou_threshold=iou_threshold,
            confidence_threshold=confidence_threshold,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        """Return ``"bbox_precision"``, the eval function name this config expects."""
        return "bbox_precision"
35163
36164
37- class CategorizationF1 (BaseEvalFunction ):
class CategorizationF1Config(EvalFunctionConfig):
    """Configurable handle for the ``cat_f1`` evaluation function."""

    def __call__(
        self,
        confidence_threshold: Optional[float] = None,
        f1_method: Optional[str] = None,
        **kwargs,
    ):
        """Configure an evaluation of :class:`CategorizationF1`.

        ::

            import nucleus

            client = nucleus.NucleusClient(YOUR_SCALE_API_KEY)
            cat_f1: CategorizationF1Config = client.validate.eval_functions.cat_f1
            slice_id = "slc_<your_slice>"
            scenario_test = client.validate.create_scenario_test(
                "Example test",
                slice_id=slice_id,
                evaluation_criteria=[cat_f1(confidence_threshold=0.6, f1_method="weighted") > 0.7]
            )

        Args:
            confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation.
                Must be in [0, 1]. ``None`` (the default) is passed through as-is;
                presumably the evaluation then uses 0.0 -- TODO confirm against the backend.
            f1_method: one of {'micro', 'macro', 'samples', 'weighted', 'binary'}.
                ``None`` (the default) is passed through as-is; presumably the
                evaluation then uses 'macro' -- TODO confirm against the backend.
                Determines the type of averaging performed on the data:

                ``'binary'``:
                    Only report results for the class specified by ``pos_label``.
                    This is applicable only if targets (``y_{true,pred}``) are binary.
                ``'micro'``:
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.
                ``'macro'``:
                    Calculate metrics for each label, and find their unweighted
                    mean.  This does not take label imbalance into account.
                ``'weighted'``:
                    Calculate metrics for each label, and find their average weighted
                    by support (the number of true instances for each label). This
                    alters 'macro' to account for label imbalance; it can result in an
                    F-score that is not between precision and recall.
                ``'samples'``:
                    Calculate metrics for each instance, and find their average (only
                    meaningful for multilabel classification where this differs from
                    :func:`accuracy_score`).
            **kwargs: extra keyword arguments, forwarded unchanged to ``EvalFunctionConfig.__call__``.

        Returns:
            Whatever ``EvalFunctionConfig.__call__`` returns for these settings. The
            base class is not defined in this module; per the example above the
            result can be compared against a threshold to form an evaluation criterion.
        """
        # Forward **kwargs like the sibling config classes do; previously any
        # extra keyword arguments accepted here were silently discarded.
        return super().__call__(
            confidence_threshold=confidence_threshold,
            f1_method=f1_method,
            **kwargs,
        )

    @classmethod
    def expected_name(cls) -> str:
        """Return ``"cat_f1"``, the eval function name this config expects."""
        return "cat_f1"
41221
42222
43- class CustomEvalFunction (BaseEvalFunction ):
class CustomEvalFunction(EvalFunctionConfig):
    """Config type used for eval function entries that are not public.

    Elsewhere in this file, instances are built from the available function
    entries whose ``is_public`` flag is false.
    """

    @classmethod
    def expected_name(cls) -> str:
        """Always raise: custom functions have no fixed class-level name.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "Custom evaluation functions are coming soon"
        )  # Placeholder: See super().eval_func_entry for actual name
49229
50230
51- class StandardEvalFunction (BaseEvalFunction ):
231+ class StandardEvalFunction (EvalFunctionConfig ):
52232 """Class for standard Model CI eval functions that have not been added as attributes on
53233 AvailableEvalFunctions yet.
54234 """
@@ -65,7 +245,7 @@ def expected_name(cls) -> str:
65245 return "public_function" # Placeholder: See super().eval_func_entry for actual name
66246
67247
68- class EvalFunctionNotAvailable (BaseEvalFunction ):
248+ class EvalFunctionNotAvailable (EvalFunctionConfig ):
69249 def __init__ (
70250 self , not_available_name : str
71251 ): # pylint: disable=super-init-not-called
@@ -89,13 +269,14 @@ def expected_name(cls) -> str:
89269
90270
# Type alias covering every eval function config class defined in this module.
# Members are the config classes themselves (instances are annotated with
# this alias), not ``Type[...]`` wrappers.
EvalFunction = Union[
    PolygonIOUConfig,
    PolygonMAPConfig,
    PolygonPrecisionConfig,
    PolygonRecallConfig,
    CategorizationF1Config,
    CustomEvalFunction,
    EvalFunctionNotAvailable,
    StandardEvalFunction,
]
100281
101282
@@ -124,24 +305,24 @@ def __init__(self, available_functions: List[EvalFunctionEntry]):
124305 f .name : f for f in available_functions if f .is_public
125306 }
126307 # NOTE: Public are assigned
127- self ._public_to_function : Dict [str , BaseEvalFunction ] = {}
308+ self ._public_to_function : Dict [str , EvalFunctionConfig ] = {}
128309 self ._custom_to_function : Dict [str , CustomEvalFunction ] = {
129310 f .name : CustomEvalFunction (f )
130311 for f in available_functions
131312 if not f .is_public
132313 }
133- self .bbox_iou = self ._assign_eval_function_if_defined (BoundingBoxIOU ) # type: ignore
134- self .bbox_precision = self ._assign_eval_function_if_defined (
135- BoundingBoxPrecision # type: ignore
314+ self .bbox_iou : PolygonIOUConfig = self ._assign_eval_function_if_defined (PolygonIOUConfig ) # type: ignore
315+ self .bbox_precision : PolygonPrecisionConfig = self ._assign_eval_function_if_defined (
316+ PolygonPrecisionConfig # type: ignore
136317 )
137- self .bbox_recall = self ._assign_eval_function_if_defined (
138- BoundingBoxRecall # type: ignore
318+ self .bbox_recall : PolygonRecallConfig = self ._assign_eval_function_if_defined (
319+ PolygonRecallConfig # type: ignore
139320 )
140- self .bbox_map = self ._assign_eval_function_if_defined (
141- BoundingBoxMeanAveragePrecision # type: ignore
321+ self .bbox_map : PolygonMAPConfig = self ._assign_eval_function_if_defined (
322+ PolygonMAPConfig # type: ignore
142323 )
143- self .cat_f1 = self ._assign_eval_function_if_defined (
144- CategorizationF1 # type: ignore
324+ self .cat_f1 : CategorizationF1Config = self ._assign_eval_function_if_defined (
325+ CategorizationF1Config # type: ignore
145326 )
146327
147328 # Add public entries that have not been implemented as an attribute on this class
@@ -163,7 +344,7 @@ def __repr__(self):
163344 )
164345
165346 @property
166- def public_functions (self ) -> Dict [str , BaseEvalFunction ]:
347+ def public_functions (self ) -> Dict [str , EvalFunctionConfig ]:
167348 """Standard functions provided by Model CI.
168349
169350 Notes:
0 commit comments