|
1 |
| -from typing import Generator, List, Union, Any, TYPE_CHECKING |
| 1 | +from typing import Collection, Dict, Generator, List, Optional, Union, Any, TYPE_CHECKING |
2 | 2 | import os
|
3 | 3 | import json
|
4 | 4 | import logging
|
|
17 | 17 | from labelbox.orm.model import Entity, Field, Relationship
|
18 | 18 | from labelbox.orm import query
|
19 | 19 | from labelbox.exceptions import MalformedQueryException
|
20 |
| - |
21 |
| -if TYPE_CHECKING: |
22 |
| - from labelbox import Task, User, DataRow |
| 20 | +from labelbox.schema.data_row import DataRow |
| 21 | +from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters |
| 22 | +from labelbox.schema.export_params import CatalogExportParams |
| 23 | +from labelbox.schema.project import _validate_datetime |
| 24 | +from labelbox.schema.task import Task |
| 25 | +from labelbox.schema.user import User |
23 | 26 |
|
24 | 27 | logger = logging.getLogger(__name__)
|
25 | 28 |
|
@@ -534,3 +537,187 @@ def export_data_rows(self,
|
534 | 537 | logger.debug("Dataset '%s' data row export, waiting for server...",
|
535 | 538 | self.uid)
|
536 | 539 | time.sleep(sleep_time)
|
| 540 | + |
def export_v2(self,
              task_name: Optional[str] = None,
              filters: Optional[DatasetExportFilters] = None,
              params: Optional[CatalogExportParams] = None) -> Task:
    """
    Creates a dataset export task with the given params and returns the task.

    >>>     dataset = client.get_dataset(DATASET_ID)
    >>>     task = dataset.export_v2(
    ...         filters={
    ...             "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    ...             "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
    ...         },
    ...         params={
    ...             "performance_details": False,
    ...             "label_details": True
    ...         })
    >>>     task.wait_till_done()
    >>>     task.result

    Args:
        task_name: Name given to the export task. When None, a name of the
            form "Export v2: dataset - <dataset name>" is used.
        filters: Optional date-range filters. Each of "last_activity_at"
            and "label_created_at" is a (start, end) pair; either bound
            may be None to leave that side open. When both bounds are
            None the filter is ignored.
        params: Optional flags selecting which details the export
            includes. Every flag defaults to False when omitted.
            Only "media_type_override" and the include flags are forwarded
            by this method.

    Returns:
        The Task created for the export.

    Raises:
        ResourceNotFoundError: if the created task cannot be resolved to
            exactly one task of the current user.
    """

    _params = params or CatalogExportParams({
        "attachments": False,
        "metadata_fields": False,
        "data_row_details": False,
        "project_details": False,
        "performance_details": False,
        "label_details": False,
        "media_type_override": None,
        "model_runs_ids": None,
        "projects_ids": None,
    })

    _filters = filters or DatasetExportFilters({
        "last_activity_at": None,
        "label_created_at": None
    })

    def _get_timezone() -> str:
        """Return the current user's timezone, falling back to "UTC"."""
        timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
        tz_res = self.client.execute(timezone_query_str)
        return tz_res["user"]["timezone"] or "UTC"

    # Resolved lazily, and at most once, the first time a date filter is set.
    timezone: Optional[str] = None

    mutation_name = "exportDataRowsInCatalog"
    create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
        %s(input: $input) {taskId} }
        """ % (mutation_name)

    # Always restrict the export to this dataset; date filters are appended
    # to the same list below.
    search_query: List[Dict[str, Collection[str]]] = []
    search_query.append({
        "ids": [self.uid],
        "operator": "is",
        "type": "dataset"
    })
    media_type_override = _params.get('media_type_override', None)

    if task_name is None:
        task_name = f"Export v2: dataset - {self.name}"
    query_params = {
        "input": {
            "taskName": task_name,
            "filters": {
                "searchQuery": {
                    "scope": None,
                    # Same list object as above, so filters appended later
                    # are part of the request.
                    "query": search_query
                }
            },
            "params": {
                "mediaTypeOverride":
                    media_type_override.value
                    if media_type_override is not None else None,
                "includeAttachments":
                    _params.get('attachments', False),
                "includeMetadata":
                    _params.get('metadata_fields', False),
                "includeDataRowDetails":
                    _params.get('data_row_details', False),
                "includeProjectDetails":
                    _params.get('project_details', False),
                "includePerformanceDetails":
                    _params.get('performance_details', False),
                "includeLabelDetails":
                    _params.get('label_details', False)
            },
        }
    }

    def _append_date_range_filter(filter_key: str, search_type: str,
                                  send_timezone: bool) -> None:
        """Append the search condition for one (start, end) date filter."""
        nonlocal timezone
        if filter_key not in _filters or _filters[filter_key] is None:
            return
        if timezone is None:
            timezone = _get_timezone()
        start, end = _filters[filter_key]

        if start is not None and end is not None:
            operator = "BETWEEN"
            bounds = [start, end]
            value: Any = {"min": start, "max": end}
        elif start is not None:
            operator = "GREATER_THAN_OR_EQUAL"
            bounds = [start]
            value = start
        elif end is not None:
            operator = "LESS_THAN_OR_EQUAL"
            bounds = [end]
            value = end
        else:
            # Both bounds open: nothing to filter on.
            return

        for bound in bounds:
            _validate_datetime(bound)

        # Key order is kept as operator / timezone / value.
        condition: Dict[str, Any] = {"operator": operator}
        if send_timezone:
            condition["timezone"] = timezone
        condition["value"] = value
        search_query.append({"type": search_type, "value": condition})

    _append_date_range_filter("last_activity_at",
                              "data_row_last_activity_at",
                              send_timezone=True)
    # NOTE(review): the timezone is resolved for this filter too, but the
    # "labeled_at" condition is sent without it. Confirm against the API
    # whether it should be included.
    _append_date_range_filter("label_created_at",
                              "labeled_at",
                              send_timezone=False)

    res = self.client.execute(
        create_task_query_str,
        query_params,
    )
    res = res[mutation_name]
    task_id = res["taskId"]
    user: User = self.client.get_user()
    tasks: List[Task] = list(
        user.created_tasks(where=Entity.Task.uid == task_id))
    if len(tasks) != 1:
        raise ResourceNotFoundError(Entity.Task, task_id)
    task: Task = tasks[0]
    # Cache user in a private variable as the relationship can't be
    # resolved due to server-side limitations (see Task.created_by)
    # for more info.
    task._user = user
    return task
0 commit comments