11from typing import Dict , Optional , List
22from .annotation import (
33 BoxAnnotation ,
4+ CategoryAnnotation ,
45 Point ,
56 PolygonAnnotation ,
67 Segment ,
1314 BOX_TYPE ,
1415 CUBOID_TYPE ,
1516 POLYGON_TYPE ,
17+ CATEGORY_TYPE ,
1618 REFERENCE_ID_KEY ,
1719 METADATA_KEY ,
1820 GEOMETRY_KEY ,
1921 LABEL_KEY ,
22+ TAXONOMY_NAME_KEY ,
2023 TYPE_KEY ,
2124 X_KEY ,
2225 Y_KEY ,
@@ -40,6 +43,8 @@ def from_json(payload: dict):
4043 return PolygonPrediction .from_json (payload )
4144 elif payload .get (TYPE_KEY , None ) == CUBOID_TYPE :
4245 return CuboidPrediction .from_json (payload )
46+ elif payload .get (TYPE_KEY , None ) == CATEGORY_TYPE :
47+ return CategoryPrediction .from_json (payload )
4348 else :
4449 return SegmentationPrediction .from_json (payload )
4550
@@ -207,3 +212,43 @@ def from_json(cls, payload: dict):
207212 metadata = payload .get (METADATA_KEY , {}),
208213 class_pdf = payload .get (CLASS_PDF_KEY , None ),
209214 )
215+
216+
217+ class CategoryPrediction (CategoryAnnotation ):
218+ def __init__ (
219+ self ,
220+ label : str ,
221+ taxonomy_name : str ,
222+ reference_id : str ,
223+ confidence : Optional [float ] = None ,
224+ metadata : Optional [Dict ] = None ,
225+ class_pdf : Optional [Dict ] = None ,
226+ ):
227+ super ().__init__ (
228+ label = label ,
229+ taxonomy_name = taxonomy_name ,
230+ reference_id = reference_id ,
231+ metadata = metadata ,
232+ )
233+ self .confidence = confidence
234+ self .class_pdf = class_pdf
235+
236+ def to_payload (self ) -> dict :
237+ payload = super ().to_payload ()
238+ if self .confidence is not None :
239+ payload [CONFIDENCE_KEY ] = self .confidence
240+ if self .class_pdf is not None :
241+ payload [CLASS_PDF_KEY ] = self .class_pdf
242+
243+ return payload
244+
245+ @classmethod
246+ def from_json (cls , payload : dict ):
247+ return cls (
248+ label = payload .get (LABEL_KEY , 0 ),
249+ taxonomy_name = payload .get (TAXONOMY_NAME_KEY , None ),
250+ reference_id = payload [REFERENCE_ID_KEY ],
251+ confidence = payload .get (CONFIDENCE_KEY , None ),
252+ metadata = payload .get (METADATA_KEY , {}),
253+ class_pdf = payload .get (CLASS_PDF_KEY , None ),
254+ )
0 commit comments