@@ -185,6 +185,19 @@ def _calculate_confusion_matrix(
185185 ) = convert_to_instance_seg_confusion (
186186 confusion , annotation , prediction
187187 )
188+ else :
189+ ann_labels = list (
190+ dict .fromkeys (s .label for s in annotation .annotations )
191+ )
192+ pred_labels = list (
193+ dict .fromkeys (s .label for s in prediction .annotations )
194+ )
195+ missing_or_filtered_labels = set (ann_labels ) - set (pred_labels )
196+ non_taxonomy_classes = {
197+ segment .index
198+ for segment in annotation .annotations
199+ if segment .label in missing_or_filtered_labels
200+ }
188201
189202 return confusion , non_taxonomy_classes
190203
@@ -644,9 +657,13 @@ def _metric_impl(
644657 + confusion .sum (axis = 0 )
645658 - np .diag (confusion )
646659 )
647- freq = confusion .sum (axis = 0 ) / confusion .sum ()
648- fwavacc = (freq [freq > 0 ] * iu [freq > 0 ]).sum ()
649- fwavacc .put (list (non_taxonomy_classes ), np .nan )
660+ predicted_counts = confusion .sum (axis = 0 ).astype (np .float_ )
661+ predicted_counts .put (list (non_taxonomy_classes ), np .nan )
662+ freq = predicted_counts / np .nansum (predicted_counts )
663+ iu .put (list (non_taxonomy_classes ), np .nan )
664+ fwavacc = (
665+ np .nan_to_num (freq [freq > 0 ]) * np .nan_to_num (iu [freq > 0 ])
666+ ).sum ()
650667 mean_fwavacc = np .nanmean (fwavacc )
651668 return ScalarResult (value = np .nan_to_num (mean_fwavacc ), weight = confusion .sum ()) # type: ignore
652669
0 commit comments