7
7
from ..group import get_feature_pairs , get_identifying_key , has_no_annotations , has_no_matching_annotations
8
8
from ...annotation_types import (ObjectAnnotation , ClassificationAnnotation ,
9
9
Mask , Geometry , Point , Line , Checklist , Text ,
10
- Radio , ScalarMetricValue )
10
+ TextEntity , Radio , ScalarMetricValue )
11
11
12
12
13
13
def miou (ground_truths : List [Union [ObjectAnnotation , ClassificationAnnotation ]],
@@ -61,6 +61,8 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation,
61
61
return vector_miou (ground_truths , predictions , include_subclasses )
62
62
elif isinstance (predictions [0 ], ClassificationAnnotation ):
63
63
return classification_miou (ground_truths , predictions )
64
+ elif isinstance (predictions [0 ].value , TextEntity ):
65
+ return ner_miou (ground_truths , predictions , include_subclasses )
64
66
else :
65
67
raise ValueError (
66
68
f"Unexpected annotation found. Found { type (predictions [0 ].value )} " )
@@ -269,3 +271,51 @@ def _ensure_valid_poly(poly):
269
271
def _mask_iou (mask1 : np .ndarray , mask2 : np .ndarray ) -> ScalarMetricValue :
270
272
"""Computes iou between two binary segmentation masks."""
271
273
return np .sum (mask1 & mask2 ) / np .sum (mask1 | mask2 )
274
+
275
+
276
+ def _get_ner_pairs (
277
+ ground_truths : List [ObjectAnnotation ], predictions : List [ObjectAnnotation ]
278
+ ) -> List [Tuple [ObjectAnnotation , ObjectAnnotation , ScalarMetricValue ]]:
279
+ """Get iou score for all possible pairs of ground truths and predictions"""
280
+ pairs = []
281
+ for ground_truth , prediction in product (ground_truths , predictions ):
282
+ score = _ner_iou (ground_truth .value , prediction .value )
283
+ pairs .append ((ground_truth , prediction , score ))
284
+ return pairs
285
+
286
+
287
+ def _ner_iou (ner1 : TextEntity , ner2 : TextEntity ):
288
+ """Computes iou between two text entity annotations"""
289
+ intersection_start , intersection_end = max (ner1 .start , ner2 .start ), min (
290
+ ner1 .end , ner2 .end )
291
+ union_start , union_end = min (ner1 .start ,
292
+ ner2 .start ), max (ner1 .end , ner2 .end )
293
+ #edge case of only one character in text
294
+ if union_start == union_end :
295
+ return 1
296
+ #if there is no intersection
297
+ if intersection_start > intersection_end :
298
+ return 0
299
+ return (intersection_end - intersection_start ) / (union_end - union_start )
300
+
301
+
302
+ def ner_miou (ground_truths : List [ObjectAnnotation ],
303
+ predictions : List [ObjectAnnotation ],
304
+ include_subclasses : bool ) -> Optional [ScalarMetricValue ]:
305
+ """
306
+ Computes iou score for all features with the same feature schema id.
307
+ Calculation includes subclassifications.
308
+
309
+ Args:
310
+ ground_truths: List of ground truth ner annotations
311
+ predictions: List of prediction ner annotations
312
+ Returns:
313
+ float representing the iou score for the feature type.
314
+ If there are no matches then this returns none
315
+ """
316
+ if has_no_matching_annotations (ground_truths , predictions ):
317
+ return 0.
318
+ elif has_no_annotations (ground_truths , predictions ):
319
+ return None
320
+ pairs = _get_ner_pairs (ground_truths , predictions )
321
+ return object_pair_miou (pairs , include_subclasses )
0 commit comments