Skip to content

Commit 173a1dc

Browse files
authored
Merge pull request #791 from Labelbox/lgluszek/ontology-media-type
[CCV-1905][CCV-1906][CCV-1907] Pass media type when creating an ontology
2 parents 2a0d146 + 67b83b8 commit 173a1dc

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

labelbox/client.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from labelbox.schema.slice import CatalogSlice
3737
from labelbox.schema.queue_mode import QueueMode
3838

39-
from labelbox.schema.media_type import MediaType
39+
from labelbox.schema.media_type import MediaType, get_media_type_validation_error
4040

4141
logger = logging.getLogger(__name__)
4242

@@ -853,15 +853,18 @@ def rootSchemaPayloadToFeatureSchema(client, payload):
853853
rootSchemaPayloadToFeatureSchema,
854854
['rootSchemaNodes', 'nextCursor'])
855855

856-
def create_ontology_from_feature_schemas(self, name,
857-
feature_schema_ids) -> Ontology:
856+
def create_ontology_from_feature_schemas(self,
857+
name,
858+
feature_schema_ids,
859+
media_type=None) -> Ontology:
858860
"""
859861
Creates an ontology from a list of feature schema ids
860862
861863
Args:
862864
name (str): Name of the ontology
863865
feature_schema_ids (List[str]): List of feature schema ids corresponding to
864866
top level tools and classifications to include in the ontology
867+
media_type (MediaType or None): Media type of a new ontology
865868
Returns:
866869
The created Ontology
867870
"""
@@ -891,9 +894,9 @@ def create_ontology_from_feature_schemas(self, name,
891894
"Neither `tool` or `classification` found in the normalized feature schema"
892895
)
893896
normalized = {'tools': tools, 'classifications': classifications}
894-
return self.create_ontology(name, normalized)
897+
return self.create_ontology(name, normalized, media_type)
895898

896-
def create_ontology(self, name, normalized) -> Ontology:
899+
def create_ontology(self, name, normalized, media_type=None) -> Ontology:
897900
"""
898901
Creates an ontology from normalized data
899902
>>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []}
@@ -910,13 +913,27 @@ def create_ontology(self, name, normalized) -> Ontology:
910913
Args:
911914
name (str): Name of the ontology
912915
normalized (dict): A normalized ontology payload. See above for details.
916+
media_type (MediaType or None): Media type of a new ontology
913917
Returns:
914918
The created Ontology
915919
"""
920+
921+
if media_type:
922+
if MediaType.is_supported(media_type):
923+
media_type = media_type.value
924+
else:
925+
raise get_media_type_validation_error(media_type)
926+
916927
query_str = """mutation upsertRootSchemaNodePyApi($data: UpsertOntologyInput!){
917928
upsertOntology(data: $data){ %s }
918929
} """ % query.results_query_part(Entity.Ontology)
919-
params = {'data': {'name': name, 'normalized': json.dumps(normalized)}}
930+
params = {
931+
'data': {
932+
'name': name,
933+
'normalized': json.dumps(normalized),
934+
'mediaType': media_type
935+
}
936+
}
920937
res = self.execute(query_str, params)
921938
return Entity.Ontology(self, res['upsertOntology'])
922939

@@ -1035,9 +1052,9 @@ def _format_failed_rows(rows: Dict[str, str],
10351052
)
10361053

10371054
# Start assign global keys to data rows job
1038-
query_str = """mutation assignGlobalKeysToDataRowsPyApi($globalKeyDataRowLinks: [AssignGlobalKeyToDataRowInput!]!) {
1039-
assignGlobalKeysToDataRows(data: {assignInputs: $globalKeyDataRowLinks}) {
1040-
jobId
1055+
query_str = """mutation assignGlobalKeysToDataRowsPyApi($globalKeyDataRowLinks: [AssignGlobalKeyToDataRowInput!]!) {
1056+
assignGlobalKeysToDataRows(data: {assignInputs: $globalKeyDataRowLinks}) {
1057+
jobId
10411058
}
10421059
}
10431060
"""
@@ -1172,7 +1189,7 @@ def _format_failed_rows(rows: List[str],
11721189

11731190
# Query string for retrieving job status and result, if job is done
11741191
result_query_str = """query getDataRowsForGlobalKeysResultPyApi($jobId: ID!) {
1175-
dataRowsForGlobalKeysResult(jobId: {id: $jobId}) { data {
1192+
dataRowsForGlobalKeysResult(jobId: {id: $jobId}) { data {
11761193
fetchedDataRows { id }
11771194
notFoundGlobalKeys
11781195
accessDeniedGlobalKeys
@@ -1246,8 +1263,8 @@ def clear_global_keys(
12461263
12471264
'Results' contains a list global keys that were successfully cleared.
12481265
1249-
'Errors' contains a list of global_keys correspond to the data rows that could not be
1250-
modified, accessed by the user, or not found.
1266+
'Errors' contains a list of global_keys correspond to the data rows that could not be
1267+
modified, accessed by the user, or not found.
12511268
Examples:
12521269
>>> job_result = client.get_data_row_ids_for_global_keys(["key1","key2"])
12531270
>>> print(job_result['status'])
@@ -1271,7 +1288,7 @@ def _format_failed_rows(rows: List[str],
12711288

12721289
# Query string for retrieving job status and result, if job is done
12731290
result_query_str = """query clearGlobalKeysResultPyApi($jobId: ID!) {
1274-
clearGlobalKeysResult(jobId: {id: $jobId}) { data {
1291+
clearGlobalKeysResult(jobId: {id: $jobId}) { data {
12751292
clearedGlobalKeys
12761293
failedToClearGlobalKeys
12771294
notFoundGlobalKeys

labelbox/schema/media_type.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@ def get_supported_members(cls):
4444
item for item in cls.__members__
4545
if item not in ["Unknown", "Unsupported"]
4646
]
47+
48+
49+
def get_media_type_validation_error(media_type):
50+
return TypeError(f"{media_type} is not a valid media type. Use"
51+
f" any of {MediaType.get_supported_members()}"
52+
" from MediaType. Example: MediaType.Image.")

0 commit comments

Comments
 (0)