1- from typing import Dict , Optional , List , Union , Type
1+ from typing import Dict , List , Optional , Type , Union
2+
3+ from nucleus .annotation import check_all_annotation_paths_remote
4+ from nucleus .job import AsyncJob
5+ from nucleus .utils import serialize_and_write_to_presigned_url
6+
27from .constants import (
38 ANNOTATIONS_KEY ,
4- DEFAULT_ANNOTATION_UPDATE_MODE ,
59 BOX_TYPE ,
10+ DEFAULT_ANNOTATION_UPDATE_MODE ,
11+ MASK_TYPE ,
612 POLYGON_TYPE ,
7- SEGMENTATION_TYPE ,
13+ REQUEST_ID_KEY ,
14+ UPDATE_KEY ,
815)
916from .prediction import (
1017 BoxPrediction ,
@@ -84,7 +91,9 @@ def predict(
8491 Union [BoxPrediction , PolygonPrediction , SegmentationPrediction ]
8592 ],
8693 update : Optional [bool ] = DEFAULT_ANNOTATION_UPDATE_MODE ,
87- ) -> dict :
94+ asynchronous : bool = False ,
95+ dataset_id : Optional [str ] = None ,
96+ ) -> Union [dict , AsyncJob ]:
8897 """
8998 Uploads model outputs as predictions for a model_run. Returns info about the upload.
9099 :param annotations: List[Union[BoxPrediction, PolygonPrediction]],
@@ -95,7 +104,24 @@ def predict(
95104 "predictions_ignored": int,
96105 }
97106 """
98- return self ._client .predict (self .model_run_id , annotations , update )
107+ if asynchronous :
108+ check_all_annotation_paths_remote (annotations )
109+
110+ assert (
111+ dataset_id is not None
112+ ), "For now, you must pass a dataset id to predict for asynchronous uploads."
113+
114+ request_id = serialize_and_write_to_presigned_url (
115+ annotations , dataset_id , self ._client
116+ )
117+ response = self ._client .make_request (
118+ payload = {REQUEST_ID_KEY : request_id , UPDATE_KEY : update },
119+ route = f"modelRun/{ self .model_run_id } /predict?async=1" ,
120+ )
121+
122+ return AsyncJob (response ["job_id" ], self ._client )
123+ else :
124+ return self ._client .predict (self .model_run_id , annotations , update )
99125
100126 def iloc (self , i : int ):
101127 """
@@ -153,7 +179,7 @@ def _format_prediction_response(
153179 ] = {
154180 BOX_TYPE : BoxPrediction ,
155181 POLYGON_TYPE : PolygonPrediction ,
156- SEGMENTATION_TYPE : SegmentationPrediction ,
182+ MASK_TYPE : SegmentationPrediction ,
157183 }
158184 for type_key in annotation_payload :
159185 type_class = type_key_to_class [type_key ]
0 commit comments