Skip to content

Commit 817335e

Browse files
authored
Merge pull request #1009 from Labelbox/imuhammad/fix-data-type-coersion
Fix annotation data type coersion by Pydantic
2 parents 3f4e423 + 5ffc6dc commit 817335e

File tree

10 files changed

+73
-12
lines changed

10 files changed

+73
-12
lines changed
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class AudioData(BaseData):
5-
...
6+
class AudioData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["AudioData"] = "AudioData"
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

46
class ConversationData(BaseData):
5-
...
7+
class_name: Literal["ConversationData"] = "ConversationData"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class DicomData(BaseData):
5-
...
6+
class DicomData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["DicomData"] = "DicomData"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class DocumentData(BaseData):
5-
...
6+
class DocumentData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["DocumentData"] = "DocumentData"
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from labelbox.typing_imports import Literal
2+
from labelbox.utils import _NoCoercionMixin
13
from .base_data import BaseData
24

35

4-
class HTMLData(BaseData):
5-
...
6+
class HTMLData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["HTMLData"] = "HTMLData"

labelbox/data/annotation_types/data/text.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from pydantic import root_validator
77

88
from labelbox.exceptions import InternalServerError
9+
from labelbox.typing_imports import Literal
10+
from labelbox.utils import _NoCoercionMixin
911
from .base_data import BaseData
1012

1113

12-
class TextData(BaseData):
14+
class TextData(BaseData, _NoCoercionMixin):
1315
"""
1416
Represents text data. Requires arg file_path, text, or url
1517
@@ -20,6 +22,7 @@ class TextData(BaseData):
2022
text (str)
2123
url (str)
2224
"""
25+
class_name: Literal["TextData"] = "TextData"
2326
file_path: Optional[str] = None
2427
text: Optional[str] = None
2528
url: Optional[str] = None

labelbox/data/annotation_types/label.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
VideoClassificationAnnotation, VideoObjectAnnotation,
1212
DICOMObjectAnnotation)
1313
from .classification import ClassificationAnswer
14-
from .data import DicomData, VideoData, TextData, ImageData
14+
from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, MaskData, TextData, VideoData
1515
from .geometry import Mask
1616
from .metrics import ScalarMetric, ConfusionMatrixMetric
1717
from .types import Cuid
1818
from ..ontology import get_feature_schema_lookup
1919

20+
DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData,
21+
ConversationData, DicomData, DocumentData, HTMLData]
22+
2023

2124
class Label(BaseModel):
2225
"""Container for holding data and annotations
@@ -38,7 +41,7 @@ class Label(BaseModel):
3841
extra: additional context
3942
"""
4043
uid: Optional[Cuid] = None
41-
data: Union[VideoData, ImageData, TextData, TiledImageData]
44+
data: DataType
4245
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
4346
ScalarMetric, ConfusionMatrixMetric]] = []
4447
extra: Dict[str, Any] = {}

labelbox/typing_imports.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
This module imports types that differ across python versions, so other modules
3+
don't have to worry about where they should be imported from.
4+
"""
5+
6+
import sys
7+
if sys.version_info >= (3, 8):
8+
from typing import Literal
9+
else:
10+
from typing_extensions import Literal

labelbox/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,26 @@ class _CamelCaseMixin(BaseModel):
3535
class Config:
3636
allow_population_by_field_name = True
3737
alias_generator = camel_case
38+
39+
40+
class _NoCoercionMixin:
41+
"""
42+
When using Unions in type annotations, pydantic will try to coerce the type
43+
of the object to the type of the first Union member. Which results in
44+
uninteded behavior.
45+
46+
This mixin uses a class_name discriminator field to prevent pydantic from
47+
corecing the type of the object. Add a class_name field to the class you
48+
want to discrimniate and use this mixin class to remove the discriminator
49+
when serializing the object.
50+
51+
Example:
52+
class ConversationData(BaseData, _NoCoercionMixin):
53+
class_name: Literal["ConversationData"] = "ConversationData"
54+
55+
"""
56+
57+
def dict(self, *args, **kwargs):
58+
res = super().dict(*args, **kwargs)
59+
res.pop('class_name')
60+
return res

tests/data/annotation_types/test_label.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
import labelbox.types as lb_types
34
from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option
45
from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text,
56
ClassificationAnnotation,
@@ -181,3 +182,14 @@ def test_schema_assignment_confidence():
181182
])
182183

183184
assert label.annotations[0].confidence == 0.914
185+
186+
187+
def test_initialize_label_no_coercion():
188+
global_key = 'global-key'
189+
ner_annotation = lb_types.ObjectAnnotation(
190+
name="ner",
191+
value=lb_types.ConversationEntity(start=0, end=8, message_id="4"))
192+
label = Label(data=lb_types.ConversationData(global_key=global_key),
193+
annotations=[ner_annotation])
194+
assert isinstance(label.data, lb_types.ConversationData)
195+
assert label.data.global_key == global_key

0 commit comments

Comments
 (0)