11from abc import abstractmethod
22from dataclasses import dataclass
3- from typing import List , Set , Tuple , Union
3+ from typing import List , Optional , Set , Tuple , Union
44
55from sklearn .metrics import f1_score
66
77from nucleus .annotation import AnnotationList , CategoryAnnotation
88from nucleus .metrics .base import Metric , MetricResult , ScalarResult
9+ from nucleus .metrics .filtering import ListOfAndFilters , ListOfOrAndFilters
910from nucleus .metrics .filters import confidence_filter
1011from nucleus .prediction import CategoryPrediction , PredictionList
1112
@@ -56,12 +57,37 @@ class CategorizationMetric(Metric):
5657 def __init__ (
5758 self ,
5859 confidence_threshold : float = 0.0 ,
60+ annotation_filters : Optional [
61+ Union [ListOfOrAndFilters , ListOfAndFilters ]
62+ ] = None ,
63+ prediction_filters : Optional [
64+ Union [ListOfOrAndFilters , ListOfAndFilters ]
65+ ] = None ,
5966 ):
6067 """Initializes CategorizationMetric abstract object.
6168
6269 Args:
6370 confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0
71+ annotation_filters: Filter predicates. Allowed formats are:
72+ ListOfAndFilters where each Filter forms a chain of AND predicates.
73+ or
74+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
75+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
76+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
77+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
78+ (AND), forming a more selective `and` multiple field predicate.
79+ Finally, the outermost list combines these filters as a disjunction (OR).
80+ prediction_filters: Filter predicates. Allowed formats are:
81+ ListOfAndFilters where each Filter forms a chain of AND predicates.
82+ or
83+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
84+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
85+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
86+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
87+ (AND), forming a more selective `and` multiple field predicate.
88+ Finally, the outermost list combines these filters as a disjunction (OR).
6489 """
90+ super ().__init__ (annotation_filters , prediction_filters )
6591 assert 0 <= confidence_threshold <= 1
6692 self .confidence_threshold = confidence_threshold
6793
@@ -83,7 +109,7 @@ def eval(
83109 def aggregate_score (self , results : List [CategorizationResult ]) -> ScalarResult : # type: ignore[override]
84110 pass
85111
86- def __call__ (
112+ def call_metric (
87113 self , annotations : AnnotationList , predictions : PredictionList
88114 ) -> CategorizationResult :
89115 if self .confidence_threshold > 0 :
@@ -139,7 +165,15 @@ class CategorizationF1(CategorizationMetric):
139165 """Evaluation method that matches categories and returns a CategorizationF1Result that aggregates to the F1 score"""
140166
141167 def __init__ (
142- self , confidence_threshold : float = 0.0 , f1_method : str = "macro"
168+ self ,
169+ confidence_threshold : float = 0.0 ,
170+ f1_method : str = "macro" ,
171+ annotation_filters : Optional [
172+ Union [ListOfOrAndFilters , ListOfAndFilters ]
173+ ] = None ,
174+ prediction_filters : Optional [
175+ Union [ListOfOrAndFilters , ListOfAndFilters ]
176+ ] = None ,
143177 ):
144178 """
145179 Args:
@@ -169,8 +203,28 @@ def __init__(
169203 Calculate metrics for each instance, and find their average (only
170204 meaningful for multilabel classification where this differs from
171205 :func:`accuracy_score`).
206+ annotation_filters: Filter predicates. Allowed formats are:
207+ ListOfAndFilters where each Filter forms a chain of AND predicates.
208+ or
209+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
210+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
211+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
212+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
213+ (AND), forming a more selective `and` multiple field predicate.
214+ Finally, the outermost list combines these filters as a disjunction (OR).
215+ prediction_filters: Filter predicates. Allowed formats are:
216+ ListOfAndFilters where each Filter forms a chain of AND predicates.
217+ or
218+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
219+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
220+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
221+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
222+ (AND), forming a more selective `and` multiple field predicate.
223+ Finally, the outermost list combines these filters as a disjunction (OR).
172224 """
173- super ().__init__ (confidence_threshold )
225+ super ().__init__ (
226+ confidence_threshold , annotation_filters , prediction_filters
227+ )
174228 assert (
175229 f1_method in F1_METHODS
176230 ), f"Invalid f1_method { f1_method } , expected one of { F1_METHODS } "
0 commit comments