Skip to content

Commit 223b790

Browse files
authored
Merge pull request #1012 from Labelbox/imuhammad/AL-5258-message-based-classification-annotation-types
[AL-5258] Annotation types for message-based conversation classifications
2 parents 278b711 + e627082 commit 223b790

File tree

6 files changed: 111 additions and 8 deletions

6 files changed: 111 additions and 8 deletions

labelbox/data/annotation_types/annotation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Any, Dict, List, Union
2+
from typing import Any, Dict, List, Optional, Union
33

44
from labelbox.data.mixins import ConfidenceMixin
55

@@ -27,10 +27,12 @@ class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin):
2727
name (Optional[str])
2828
feature_schema_id (Optional[Cuid])
2929
value (Union[Text, Checklist, Radio, Dropdown])
30+
message_id (Optional[str]): Message ID for message-based classifications of conversational text
3031
extra (Dict[str, Any])
3132
"""
3233

3334
value: Union[Text, Checklist, Radio, Dropdown]
35+
message_id: Optional[str] = None
3436

3537

3638
class ObjectAnnotation(BaseAnnotation, ConfidenceMixin):

labelbox/data/serialization/ndjson/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def dict(self, *args, **kwargs):
3838
class NDAnnotation(NDJsonBase):
3939
name: Optional[str] = None
4040
schema_id: Optional[Cuid] = None
41+
message_id: Optional[str] = None
4142
page: Optional[int] = None
4243
unit: Optional[str] = None
4344

labelbox/data/serialization/ndjson/classification.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,15 @@ def from_common(cls,
127127
feature_schema_id: Cuid,
128128
extra: Dict[str, Any],
129129
data: Union[TextData, ImageData],
130+
message_id: str,
130131
confidence: Optional[float] = None) -> "NDText":
131132
return cls(
132133
answer=text.answer,
133134
data_row=DataRow(id=data.uid, global_key=data.global_key),
134135
name=name,
135136
schema_id=feature_schema_id,
136137
uuid=extra.get('uuid'),
138+
message_id=message_id,
137139
confidence=confidence,
138140
)
139141

@@ -147,6 +149,7 @@ def from_common(cls,
147149
feature_schema_id: Cuid,
148150
extra: Dict[str, Any],
149151
data: Union[VideoData, TextData, ImageData],
152+
message_id: str,
150153
confidence: Optional[float] = None) -> "NDChecklist":
151154
return cls(answer=[
152155
NDFeature(name=answer.name,
@@ -159,6 +162,7 @@ def from_common(cls,
159162
schema_id=feature_schema_id,
160163
uuid=extra.get('uuid'),
161164
frames=extra.get('frames'),
165+
message_id=message_id,
162166
confidence=confidence)
163167

164168

@@ -171,6 +175,7 @@ def from_common(cls,
171175
feature_schema_id: Cuid,
172176
extra: Dict[str, Any],
173177
data: Union[VideoData, TextData, ImageData],
178+
message_id: str,
174179
confidence: Optional[float] = None) -> "NDRadio":
175180
return cls(answer=NDFeature(name=radio.answer.name,
176181
schema_id=radio.answer.feature_schema_id,
@@ -180,6 +185,7 @@ def from_common(cls,
180185
schema_id=feature_schema_id,
181186
uuid=extra.get('uuid'),
182187
frames=extra.get('frames'),
188+
message_id=message_id,
183189
confidence=confidence)
184190

185191

@@ -228,6 +234,7 @@ def to_common(
228234
name=annotation.name,
229235
feature_schema_id=annotation.schema_id,
230236
extra={'uuid': annotation.uuid},
237+
message_id=annotation.message_id,
231238
confidence=annotation.confidence)
232239
if getattr(annotation, 'frames', None) is None:
233240
return [common]
@@ -252,6 +259,7 @@ def from_common(
252259
return classify_obj.from_common(annotation.value, annotation.name,
253260
annotation.feature_schema_id,
254261
annotation.extra, data,
262+
annotation.message_id,
255263
annotation.confidence)
256264

257265
@staticmethod

labelbox/data/serialization/ndjson/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
logger = logging.getLogger(__name__)
88

9-
IGNORE_IF_NONE = ["page", "unit"]
9+
IGNORE_IF_NONE = ["page", "unit", "messageId"]
1010

1111

1212
class NDJsonConverter:

tests/data/annotation_types/classification/test_classification.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def test_subclass():
5151
'extra': {},
5252
'value': {
5353
'answer': answer
54-
}
54+
},
55+
'message_id': None,
5556
}
5657
classification = ClassificationAnnotation(
5758
value=Text(answer=answer),
@@ -64,7 +65,8 @@ def test_subclass():
6465
'value': {
6566
'answer': answer
6667
},
67-
'name': name
68+
'name': name,
69+
'message_id': None,
6870
}
6971
classification = ClassificationAnnotation(
7072
value=Text(answer=answer),
@@ -76,7 +78,8 @@ def test_subclass():
7678
'extra': {},
7779
'value': {
7880
'answer': answer
79-
}
81+
},
82+
'message_id': None,
8083
}
8184

8285

@@ -115,7 +118,8 @@ def test_radio():
115118
'extra': {},
116119
'confidence': 0.81
117120
}
118-
}
121+
},
122+
'message_id': None,
119123
}
120124

121125

@@ -156,6 +160,7 @@ def test_checklist():
156160
'confidence': 0.99
157161
}]
158162
},
163+
'message_id': None,
159164
}
160165

161166

@@ -194,5 +199,6 @@ def test_dropdown():
194199
'confidence': 1,
195200
'extra': {}
196201
}]
197-
}
202+
},
203+
'message_id': None,
198204
}

tests/data/serialization/ndjson/test_conversation.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,95 @@
11
import json
22

33
import pytest
4-
4+
import labelbox.types as lb_types
55
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
66

7+
radio_ndjson = [{
8+
'dataRow': {
9+
'globalKey': 'my_global_key'
10+
},
11+
'name': 'radio',
12+
'answer': {
13+
'name': 'first_radio_answer'
14+
},
15+
'messageId': '0'
16+
}]
17+
18+
radio_label = [
19+
lb_types.Label(
20+
data=lb_types.ConversationData(global_key='my_global_key'),
21+
annotations=[
22+
lb_types.ClassificationAnnotation(
23+
name='radio',
24+
value=lb_types.Radio(answer=lb_types.ClassificationAnswer(
25+
name="first_radio_answer")),
26+
message_id="0")
27+
])
28+
]
29+
30+
checklist_ndjson = [{
31+
'dataRow': {
32+
'globalKey': 'my_global_key'
33+
},
34+
'name': 'checklist',
35+
'answer': [
36+
{
37+
'name': 'first_checklist_answer'
38+
},
39+
{
40+
'name': 'second_checklist_answer'
41+
},
42+
],
43+
'messageId': '2'
44+
}]
45+
46+
checklist_label = [
47+
lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'),
48+
annotations=[
49+
lb_types.ClassificationAnnotation(
50+
name='checklist',
51+
message_id="2",
52+
value=lb_types.Checklist(answer=[
53+
lb_types.ClassificationAnswer(
54+
name="first_checklist_answer"),
55+
lb_types.ClassificationAnswer(
56+
name="second_checklist_answer")
57+
]))
58+
])
59+
]
60+
61+
free_text_ndjson = [{
62+
'dataRow': {
63+
'globalKey': 'my_global_key'
64+
},
65+
'name': 'free_text',
66+
'answer': 'sample text',
67+
'messageId': '0'
68+
}]
69+
free_text_label = [
70+
lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'),
71+
annotations=[
72+
lb_types.ClassificationAnnotation(
73+
name='free_text',
74+
message_id="0",
75+
value=lb_types.Text(answer="sample text"))
76+
])
77+
]
78+
79+
80+
@pytest.mark.parametrize(
81+
"label, ndjson",
82+
[[radio_label, radio_ndjson], [checklist_label, checklist_ndjson],
83+
[free_text_label, free_text_ndjson]])
84+
def test_message_based_radio_classification(label, ndjson):
85+
serialized_label = list(NDJsonConverter().serialize(label))
86+
serialized_label[0].pop('uuid')
87+
assert serialized_label == ndjson
88+
89+
deserialized_label = list(NDJsonConverter().deserialize(ndjson))
90+
deserialized_label[0].annotations[0].extra.pop('uuid')
91+
assert deserialized_label[0].annotations == label[0].annotations
92+
793

894
@pytest.mark.parametrize("filename", [
995
"tests/data/assets/ndjson/conversation_entity_import.json",

0 commit comments

Comments
 (0)