Commit f9e4457

Merge pull request #1010 from Labelbox/mno/al-5278

2 parents d1ea29c + 750cf3d, commit f9e4457

10 files changed, +453 -14 lines changed

labelbox/client.py

Lines changed: 1 addition & 2 deletions

@@ -1559,8 +1559,7 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice:
         Returns:
             CatalogSlice
         """
-        query_str = """
-            query getSavedQueryPyApi($id: ID!) {
+        query_str = """query getSavedQueryPyApi($id: ID!) {
             getSavedQuery(id: $id) {
                 id
                 name
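
For context, a minimal sketch of the public call backed by this query string; the API key and slice ID are placeholders, and the printed field is one of those selected by the query above:

    from labelbox import Client

    client = Client(api_key="LABELBOX_API_KEY")           # placeholder credentials
    catalog_slice = client.get_catalog_slice("SLICE_ID")  # issues getSavedQueryPyApi
    print(catalog_slice.name)                             # 'name' is selected by the query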

labelbox/schema/data_row.py

Lines changed: 108 additions & 2 deletions

@@ -1,14 +1,18 @@
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional
 import json
+from labelbox.exceptions import ResourceNotFoundError

 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.schema.data_row_metadata import DataRowMetadataField  # type: ignore
+from labelbox.schema.export_params import CatalogExportParams
+from labelbox.schema.task import Task
+from labelbox.schema.user import User  # type: ignore

 if TYPE_CHECKING:
-    from labelbox import AssetAttachment
+    from labelbox import AssetAttachment, Client

 logger = logging.getLogger(__name__)

@@ -150,3 +154,105 @@ def create_attachment(self,
             })
         return Entity.AssetAttachment(self.client,
                                       res["createDataRowAttachment"])
+
+    @staticmethod
+    def export_v2(client: 'Client',
+                  data_rows: List['DataRow'],
+                  task_name: Optional[str] = None,
+                  params: Optional[CatalogExportParams] = None) -> Task:
+        """
+        Creates a data rows export task with the given list, params and returns the task.
+
+        >>> dataset = client.get_dataset(DATASET_ID)
+        >>> task = DataRow.export_v2(
+        >>>     data_rows_ids=[data_row.uid for data_row in dataset.data_rows.list()],
+        >>>     filters={
+        >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
+        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
+        >>>     },
+        >>>     params={
+        >>>         "performance_details": False,
+        >>>         "label_details": True
+        >>>     })
+        >>> task.wait_till_done()
+        >>> task.result
+        """
+        print('export start')
+
+        _params = params or CatalogExportParams({
+            "attachments": False,
+            "metadata_fields": False,
+            "data_row_details": False,
+            "project_details": False,
+            "performance_details": False,
+            "label_details": False,
+            "media_type_override": None,
+            "model_runs_ids": None,
+            "projects_ids": None,
+        })
+
+        mutation_name = "exportDataRowsInCatalog"
+        create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
+            %s(input: $input) {taskId} }
+            """ % (mutation_name)
+
+        data_rows_ids = [data_row.uid for data_row in data_rows]
+        search_query: List[Dict[str, Collection[str]]] = []
+        search_query.append({
+            "ids": data_rows_ids,
+            "operator": "is",
+            "type": "data_row_id"
+        })
+
+        print(search_query)
+        media_type_override = _params.get('media_type_override', None)
+
+        if task_name is None:
+            task_name = f"Export v2: data rows (%s)" % len(data_rows_ids)
+        query_params = {
+            "input": {
+                "taskName": task_name,
+                "filters": {
+                    "searchQuery": {
+                        "scope": None,
+                        "query": search_query
+                    }
+                },
+                "params": {
+                    "mediaTypeOverride":
+                        media_type_override.value
+                        if media_type_override is not None else None,
+                    "includeAttachments":
+                        _params.get('attachments', False),
+                    "includeMetadata":
+                        _params.get('metadata_fields', False),
+                    "includeDataRowDetails":
+                        _params.get('data_row_details', False),
+                    "includeProjectDetails":
+                        _params.get('project_details', False),
+                    "includePerformanceDetails":
+                        _params.get('performance_details', False),
+                    "includeLabelDetails":
+                        _params.get('label_details', False)
+                },
+            }
+        }
+
+        res = client.execute(
+            create_task_query_str,
+            query_params,
+        )
+        print(res)
+        res = res[mutation_name]
+        task_id = res["taskId"]
+        user: User = client.get_user()
+        tasks: List[Task] = list(
+            user.created_tasks(where=Entity.Task.uid == task_id))
+        # Cache user in a private variable as the relationship can't be
+        # resolved due to server-side limitations (see Task.created_by)
+        # for more info.
+        if len(tasks) != 1:
+            raise ResourceNotFoundError(Entity.Task, task_id)
+        task: Task = tasks[0]
+        task._user = user
+        return task
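
A minimal usage sketch for the new static method, assuming a configured client and the usual Dataset.data_rows relationship for iterating rows; the API key and dataset ID are placeholders, and the params keys are the CatalogExportParams fields introduced in this commit (a plain dict is accepted because export_v2 only calls .get on it):

    from labelbox import Client
    from labelbox.schema.data_row import DataRow

    client = Client(api_key="LABELBOX_API_KEY")   # placeholder credentials
    dataset = client.get_dataset("DATASET_ID")

    # Export a specific set of data rows rather than a whole dataset or project.
    task = DataRow.export_v2(
        client=client,
        data_rows=[data_row for data_row in dataset.data_rows()],
        params={
            "data_row_details": True,
            "label_details": True,
        })
    task.wait_till_done()   # poll until the server-side export task finishes
    print(task.result)      # exported rows produced by the task

Note that the method takes DataRow objects and extracts their uids itself (data_rows_ids = [data_row.uid for ...]), so the caller does not pass raw IDs.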

labelbox/schema/dataset.py

Lines changed: 191 additions & 4 deletions

@@ -1,4 +1,4 @@
-from typing import Generator, List, Union, Any, TYPE_CHECKING
+from typing import Collection, Dict, Generator, List, Optional, Union, Any, TYPE_CHECKING
 import os
 import json
 import logging

@@ -17,9 +17,12 @@
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
-
-if TYPE_CHECKING:
-    from labelbox import Task, User, DataRow
+from labelbox.schema.data_row import DataRow
+from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters
+from labelbox.schema.export_params import CatalogExportParams
+from labelbox.schema.project import _validate_datetime
+from labelbox.schema.task import Task
+from labelbox.schema.user import User

 logger = logging.getLogger(__name__)

@@ -534,3 +537,187 @@ def export_data_rows(self,
             logger.debug("Dataset '%s' data row export, waiting for server...",
                          self.uid)
             time.sleep(sleep_time)
+
+    def export_v2(self,
+                  task_name: Optional[str] = None,
+                  filters: Optional[DatasetExportFilters] = None,
+                  params: Optional[CatalogExportParams] = None) -> Task:
+        """
+        Creates a dataset export task with the given params and returns the task.
+
+        >>> dataset = client.get_dataset(DATASET_ID)
+        >>> task = dataset.export_v2(
+        >>>     filters={
+        >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
+        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
+        >>>     },
+        >>>     params={
+        >>>         "performance_details": False,
+        >>>         "label_details": True
+        >>>     })
+        >>> task.wait_till_done()
+        >>> task.result
+        """
+
+        _params = params or CatalogExportParams({
+            "attachments": False,
+            "metadata_fields": False,
+            "data_row_details": False,
+            "project_details": False,
+            "performance_details": False,
+            "label_details": False,
+            "media_type_override": None,
+            "model_runs_ids": None,
+            "projects_ids": None,
+        })
+
+        _filters = filters or DatasetExportFilters({
+            "last_activity_at": None,
+            "label_created_at": None
+        })
+
+        def _get_timezone() -> str:
+            timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
+            tz_res = self.client.execute(timezone_query_str)
+            return tz_res["user"]["timezone"] or "UTC"
+
+        timezone: Optional[str] = None
+
+        mutation_name = "exportDataRowsInCatalog"
+        create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
+            %s(input: $input) {taskId} }
+            """ % (mutation_name)
+
+        search_query: List[Dict[str, Collection[str]]] = []
+        search_query.append({
+            "ids": [self.uid],
+            "operator": "is",
+            "type": "dataset"
+        })
+        media_type_override = _params.get('media_type_override', None)
+
+        if task_name is None:
+            task_name = f"Export v2: dataset - {self.name}"
+        query_params = {
+            "input": {
+                "taskName": task_name,
+                "filters": {
+                    "searchQuery": {
+                        "scope": None,
+                        "query": search_query
+                    }
+                },
+                "params": {
+                    "mediaTypeOverride":
+                        media_type_override.value
+                        if media_type_override is not None else None,
+                    "includeAttachments":
+                        _params.get('attachments', False),
+                    "includeMetadata":
+                        _params.get('metadata_fields', False),
+                    "includeDataRowDetails":
+                        _params.get('data_row_details', False),
+                    "includeProjectDetails":
+                        _params.get('project_details', False),
+                    "includePerformanceDetails":
+                        _params.get('performance_details', False),
+                    "includeLabelDetails":
+                        _params.get('label_details', False)
+                },
+            }
+        }
+
+        if "last_activity_at" in _filters and _filters[
+                'last_activity_at'] is not None:
+            if timezone is None:
+                timezone = _get_timezone()
+            values = _filters['last_activity_at']
+            start, end = values
+            if (start is not None and end is not None):
+                [_validate_datetime(date) for date in values]
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "BETWEEN",
+                        "timezone": timezone,
+                        "value": {
+                            "min": start,
+                            "max": end
+                        }
+                    }
+                })
+            elif (start is not None):
+                _validate_datetime(start)
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "GREATER_THAN_OR_EQUAL",
+                        "timezone": timezone,
+                        "value": start
+                    }
+                })
+            elif (end is not None):
+                _validate_datetime(end)
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "LESS_THAN_OR_EQUAL",
+                        "timezone": timezone,
+                        "value": end
+                    }
+                })
+
+        if "label_created_at" in _filters and _filters[
+                "label_created_at"] is not None:
+            if timezone is None:
+                timezone = _get_timezone()
+            values = _filters['label_created_at']
+            start, end = values
+            if (start is not None and end is not None):
+                [_validate_datetime(date) for date in values]
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "BETWEEN",
+                        "value": {
+                            "min": start,
+                            "max": end
+                        }
+                    }
+                })
+            elif (start is not None):
+                _validate_datetime(start)
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "GREATER_THAN_OR_EQUAL",
+                        "value": start
+                    }
+                })
+            elif (end is not None):
+                _validate_datetime(end)
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "LESS_THAN_OR_EQUAL",
+                        "value": end
+                    }
+                })
+
+        res = self.client.execute(
+            create_task_query_str,
+            query_params,
+        )
+        res = res[mutation_name]
+        task_id = res["taskId"]
+        user: User = self.client.get_user()
+        tasks: List[Task] = list(
+            user.created_tasks(where=Entity.Task.uid == task_id))
+        # Cache user in a private variable as the relationship can't be
+        # resolved due to server-side limitations (see Task.created_by)
+        # for more info.
+        if len(tasks) != 1:
+            raise ResourceNotFoundError(Entity.Task, task_id)
+        task: Task = tasks[0]
+        task._user = user
+        return task
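
A usage sketch for the instance method, continuing from the placeholder client and dataset above; either end of a date-range filter may be None, which maps onto the one-sided GREATER_THAN_OR_EQUAL / LESS_THAN_OR_EQUAL clauses built in the filter blocks, while a fully specified range becomes a BETWEEN clause with both dates validated:

    task = dataset.export_v2(
        task_name="Export v2: docs example",   # optional; defaults to "Export v2: dataset - <name>"
        filters={
            # Lower bound only -> GREATER_THAN_OR_EQUAL on data_row_last_activity_at.
            "last_activity_at": ["2020-01-01 00:00:00", None],
            # Full range -> BETWEEN on labeled_at.
            "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        },
        params={
            "performance_details": False,
            "label_details": True,
        })
    task.wait_till_done()
    print(task.result)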

labelbox/schema/export_filters.py

Lines changed: 9 additions & 1 deletion

@@ -9,7 +9,7 @@
 from typing import Tuple


-class ProjectExportFilters(TypedDict):
+class SharedExportFilters(TypedDict):
     label_created_at: Optional[Tuple[str, str]]
     """ Date range for labels created at
     Formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss"

@@ -26,3 +26,11 @@ class ProjectExportFilters(TypedDict):
     >>> [None, "2050-01-01 00:00:00"]
     >>> ["2000-01-01 00:00:00", None]
     """
+
+
+class ProjectExportFilters(SharedExportFilters):
+    pass
+
+
+class DatasetExportFilters(SharedExportFilters):
+    pass
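
A small sketch of how the new filter types are meant to be populated; because they are TypedDicts, plain dict literals with the same keys are accepted as well, and None on either side of a range leaves that bound open:

    from labelbox.schema.export_filters import DatasetExportFilters, ProjectExportFilters

    dataset_filters = DatasetExportFilters({
        "last_activity_at": ["2000-01-01 00:00:00", None],  # open-ended upper bound
        "label_created_at": None,                           # no label-date filter
    })

    project_filters = ProjectExportFilters({
        "last_activity_at": None,
        "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    })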

labelbox/schema/export_params.py

Lines changed: 10 additions & 1 deletion

@@ -1,6 +1,6 @@
 import sys

-from typing import Optional
+from typing import Optional, List

 from labelbox.schema.media_type import MediaType
 if sys.version_info >= (3, 8):

@@ -22,6 +22,15 @@ class ProjectExportParams(DataRowParams):
     performance_details: Optional[bool]


+class CatalogExportParams(DataRowParams):
+    project_details: Optional[bool]
+    label_details: Optional[bool]
+    performance_details: Optional[bool]
+    model_runs_ids: Optional[List[str]]
+    projects_ids: Optional[List[str]]
+    pass
+
+
 class ModelRunExportParams(DataRowParams):
     # TODO: Add model run fields
     pass
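
And a sketch of a fully populated CatalogExportParams; the first four keys are presumably inherited from the DataRowParams base (judging by the defaults export_v2 fills in), and any unset key simply falls back to False/None inside export_v2:

    from labelbox.schema.export_params import CatalogExportParams

    params = CatalogExportParams({
        "attachments": True,
        "metadata_fields": True,
        "data_row_details": True,
        "media_type_override": None,  # or a labelbox.schema.media_type.MediaType value
        "project_details": False,
        "performance_details": False,
        "label_details": True,
        "model_runs_ids": None,       # optionally scope to specific model runs
        "projects_ids": None,         # optionally scope to specific projects
    })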
