Skip to content

Commit 327800b

Browse files
chore: clean up and track test files
1 parent 9675c73 commit 327800b

File tree

5 files changed

+389
-17
lines changed

5 files changed

+389
-17
lines changed

libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ def serialize_model(self, handler):
6060
return res
6161

6262

63+
class FrameLocation(BaseModel):
64+
end: int
65+
start: int
66+
67+
68+
class VideoSupported(BaseModel):
69+
# Note that frames are only allowed as top level inferences for video
70+
frames: Optional[List[FrameLocation]] = None
71+
72+
@model_serializer(mode="wrap")
73+
def serialize_model(self, handler):
74+
res = handler(self)
75+
# This means these are no video frames ..
76+
if self.frames is None:
77+
res.pop("frames")
78+
return res
6379

6480

6581
class NDTextSubclass(NDAnswer):
@@ -226,14 +242,13 @@ def from_common(
226242
name=name,
227243
schema_id=feature_schema_id,
228244
uuid=uuid,
229-
frames=extra.get("frames"),
230245
message_id=message_id,
231246
confidence=text.confidence,
232247
custom_metrics=text.custom_metrics,
233248
)
234249

235250

236-
class NDChecklist(NDAnnotation, NDChecklistSubclass):
251+
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
237252
@model_serializer(mode="wrap")
238253
def serialize_model(self, handler):
239254
res = handler(self)
@@ -280,7 +295,7 @@ def from_common(
280295
)
281296

282297

283-
class NDRadio(NDAnnotation, NDRadioSubclass):
298+
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
284299
@classmethod
285300
def from_common(
286301
cls,
@@ -410,8 +425,7 @@ def to_common(
410425
def from_common(
411426
cls,
412427
annotation: Union[
413-
ClassificationAnnotation,
414-
VideoClassificationAnnotation,
428+
ClassificationAnnotation, VideoClassificationAnnotation
415429
],
416430
data: GenericDataRowData,
417431
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
@@ -434,8 +448,7 @@ def from_common(
434448
@staticmethod
435449
def lookup_classification(
436450
annotation: Union[
437-
ClassificationAnnotation,
438-
VideoClassificationAnnotation,
451+
ClassificationAnnotation, VideoClassificationAnnotation
439452
],
440453
) -> Union[NDText, NDChecklist, NDRadio]:
441454
return {Text: NDText, Checklist: NDChecklist, Radio: NDRadio}.get(

libs/labelbox/tests/conftest.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -688,12 +688,12 @@ def create_label():
688688
predictions,
689689
)
690690
upload_task.wait_until_done(sleep_time_seconds=5)
691-
assert upload_task.state == AnnotationImportState.FINISHED, (
692-
"Label Import did not finish"
693-
)
694-
assert len(upload_task.errors) == 0, (
695-
f"Label Import {upload_task.name} failed with errors {upload_task.errors}"
696-
)
691+
assert (
692+
upload_task.state == AnnotationImportState.FINISHED
693+
), "Label Import did not finish"
694+
assert (
695+
len(upload_task.errors) == 0
696+
), f"Label Import {upload_task.name} failed with errors {upload_task.errors}"
697697

698698
project.create_label = create_label
699699
project.create_label()

0 commit comments

Comments
 (0)