Add Deltalake query support #1023
base: main
Changes from all commits
9d2048e
505ddfe
aee0f8c
d5a25b1
2de051d
7d0b8b7
7195adf
33b787f
mp_api/client/core/client.py

@@ -9,6 +9,7 @@
 import itertools
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

@@ -21,7 +22,10 @@
 from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
 from urllib.parse import quote, urljoin

+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter

@@ -31,7 +35,7 @@
 from urllib3.util.retry import Retry

 from mp_api.client.core.settings import MAPIClientSettings
-from mp_api.client.core.utils import load_json, validate_ids
+from mp_api.client.core.utils import MPDataset, load_json, validate_ids

 try:
     import boto3

@@ -83,6 +87,7 @@ class BaseRester:
     document_model: type[BaseModel] | None = None
     supports_versions: bool = False
     primary_key: str = "material_id"
+    delta_backed: bool = False

     def __init__(
         self,

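For context, `delta_backed` is an opt-in flag: a rester subclass that should use the new Delta Lake download path sets it to `True`. A minimal sketch of what that might look like (the `ExampleRester` and its document model are hypothetical, not part of this diff):

```python
from pydantic import BaseModel

from mp_api.client.core.client import BaseRester


class ExampleDoc(BaseModel):
    """Hypothetical document model, for illustration only."""

    material_id: str


class ExampleRester(BaseRester):
    # Hypothetical rester that opts into the Delta Lake / local-dataset path.
    suffix = "materials/example"
    document_model = ExampleDoc
    primary_key = "material_id"
    delta_backed = True
```
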
@@ -97,6 +102,8 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
     ):
         """Initialize the REST API helper class.

@@ -128,6 +135,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to 'mp_datasets' in the user's home directory
+            force_renew: Option to overwrite existing local dataset
         """
         # TODO: think about how to migrate from PMG_MAPI_KEY
         self.api_key = api_key or os.getenv("MP_API_KEY")

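The warning text later in this diff refers to `MPRester(force_renew=True)`, so the new arguments are presumably reachable from the top-level client as well. A hedged usage sketch, assuming `MPRester` forwards both keywords to `BaseRester`:

```python
from mp_api.client import MPRester

# Cache full datasets under a scratch directory and rebuild any dataset
# that already exists there. Passing these through MPRester is an assumption;
# the defaults come from MAPIClientSettings (LOCAL_DATASET_CACHE, force_renew=False).
with MPRester(
    api_key="your-api-key",
    local_dataset_cache="/scratch/mp_datasets",
    force_renew=True,
) as mpr:
    ...
```
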
@@ -141,6 +151,8 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew
         self.db_version = BaseRester._get_database_version(self.endpoint)

         if self.suffix:

@@ -368,10 +380,7 @@ def _patch_resource(
             raise MPRestError(str(ex))

     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.

@@ -471,6 +480,12 @@ def _query_resource(
                 url += "/"

             if query_s3:
+                pbar_message = (  # type: ignore
+                    f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                    if self.document_model is not None
+                    else "Retrieving documents"
+                )
+
                 db_version = self.db_version.replace(".", "-")
                 if "/" not in self.suffix:
                     suffix = self.suffix

@@ -481,15 +496,168 @@ def _query_resource(
                     suffix = infix if suffix == "core" else suffix
                 suffix = suffix.replace("_", "-")

-                # Paginate over all entries in the bucket.
-                # TODO: change when a subset of entries needed from DB
+                # Check if user has access to GNoMe
+                # temp suppress tqdm
+                re_enable = not self.mute_progress_bars
+                self.mute_progress_bars = True
+                has_gnome_access = bool(
+                    self._submit_requests(
+                        url=urljoin(
+                            "https://api.materialsproject.org/", "materials/summary/"
+                        ),
+                        criteria={
+                            "batch_id": "gnome_r2scan_statics",
+                            "_fields": "material_id",
+                        },
+                        use_document_model=False,
+                        num_chunks=1,
+                        chunk_size=1,
+                        timeout=timeout,
+                    )
+                    .get("meta", {})
+                    .get("total_doc", 0)
+                )
+                self.mute_progress_bars = not re_enable

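The access probe saves and restores `mute_progress_bars` by hand via `re_enable`. A small, hypothetical context-manager sketch of the same save/restore pattern (not part of this diff), which would also restore the flag if the probe raised:

```python
from contextlib import contextmanager


@contextmanager
def muted_progress_bars(rester):
    """Temporarily mute a rester's progress bars, then restore the old value."""
    previous = rester.mute_progress_bars
    rester.mute_progress_bars = True
    try:
        yield rester
    finally:
        rester.mute_progress_bars = previous
```
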
                 if "tasks" in suffix:
-                    bucket_suffix, prefix = "parsed", "tasks_atomate2"
+                    bucket_suffix, prefix = ("parsed", "core/tasks/")
                 else:
                     bucket_suffix = "build"
                     prefix = f"collections/{db_version}/{suffix}"

                 bucket = f"materialsproject-{bucket_suffix}"

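To make the routing concrete, the branch above resolves to one of two buckets. A small illustration (the database version string is a made-up placeholder; the real value comes from `self.db_version`):

```python
def resolve_bucket(suffix: str, db_version: str) -> tuple[str, str]:
    # Mirrors the branch above: task documents come from the "parsed" bucket,
    # everything else from a versioned collection in the "build" bucket.
    if "tasks" in suffix:
        bucket_suffix, prefix = ("parsed", "core/tasks/")
    else:
        bucket_suffix = "build"
        prefix = f"collections/{db_version}/{suffix}"
    return f"materialsproject-{bucket_suffix}", prefix


print(resolve_bucket("summary", "2025-06-09"))  # hypothetical db_version
# ('materialsproject-build', 'collections/2025-06-09/summary')
print(resolve_bucket("tasks", "2025-06-09"))
# ('materialsproject-parsed', 'core/tasks/')
```
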
+                if self.delta_backed:
+                    target_path = (
+                        self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
+                    )
+                    os.makedirs(target_path, exist_ok=True)
+
+                    if DeltaTable.is_deltatable(target_path):
+                        if self.force_renew:
+                            shutil.rmtree(target_path)
+                            warnings.warn(
+                                f"Regenerating {suffix} dataset at {target_path}...",
+                                MPLocalDatasetWarning,
+                            )
+                            os.makedirs(target_path, exist_ok=True)
+                        else:
+                            warnings.warn(
+                                f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset "
+                                "or re-run search query with MPRester(force_renew=True)",
+                                MPLocalDatasetWarning,
+                            )
+
+                            return {
+                                "data": MPDataset(
+                                    path=target_path,
+                                    document_model=self.document_model,
+                                    use_document_model=self.use_document_model,
+                                )
+                            }

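Once a dataset exists at `target_path`, it can also be opened directly with the `deltalake` and `pyarrow` APIs, independent of the `MPDataset` wrapper. A sketch, assuming a summary dataset was already downloaded (the path and column names are examples, not taken from this diff):

```python
import os

from deltalake import DeltaTable

# Example layout: LOCAL_DATASET_CACHE + "/{bucket_suffix}/{prefix}" as built above.
path = os.path.expanduser("~/mp_datasets/build/collections/2025-06-09/summary")

dt = DeltaTable(path)
dataset = dt.to_pyarrow_dataset()

# Project a couple of columns without loading the whole table into memory.
table = dataset.to_table(columns=["material_id", "formula_pretty"])
print(table.num_rows)
```
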
+                    tbl = DeltaTable(
+                        f"s3a://{bucket}/{prefix}",
+                        storage_options={
+                            "AWS_SKIP_SIGNATURE": "true",
+                            "AWS_REGION": "us-east-1",
+                        },
+                    )

+                    controlled_batch_str = ",".join(
+                        [f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
+                    )
+
+                    predicate = (
+                        " WHERE batch_id NOT IN ("  # don't delete leading space
+                        + controlled_batch_str
+                        + ")"
+                        if not has_gnome_access
+                        else ""
+                    )

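With the default `ACCESS_CONTROLLED_BATCH_IDS = ["gnome_r2scan_statics"]` and no GNoMe access, the predicate reduces to a plain SQL suffix. A short illustration of the string construction:

```python
access_controlled_batch_ids = ["gnome_r2scan_statics"]
has_gnome_access = False

controlled_batch_str = ",".join(f"'{tag}'" for tag in access_controlled_batch_ids)
predicate = (
    " WHERE batch_id NOT IN (" + controlled_batch_str + ")"
    if not has_gnome_access
    else ""
)

print("SELECT * FROM tbl" + predicate)
# SELECT * FROM tbl WHERE batch_id NOT IN ('gnome_r2scan_statics')
```
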
+                    builder = QueryBuilder().register("tbl", tbl)
+
+                    # Setup progress bar
+                    num_docs_needed = pa.table(
+                        builder.execute("SELECT COUNT(*) FROM tbl").read_all()
+                    )[0][0].as_py()
+
+                    # TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
+                    # -> need to modify BatchIdQuery operator to handle root level
+                    #    batch_id, not only builder_meta.batch_id
+                    # if not has_gnome_access:
+                    #     num_docs_needed = self.count(
+                    #         {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
+                    #     )
+
+                    pbar = (
+                        tqdm(
+                            desc=pbar_message,
+                            total=num_docs_needed,
+                        )
+                        if not self.mute_progress_bars
+                        else None
+                    )
+
+                    iterator = builder.execute("SELECT * FROM tbl" + predicate)

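The same `QueryBuilder` machinery can be used interactively against the public table. A hedged sketch of a filtered query (the collection prefix and column names are assumptions, not taken from this diff):

```python
import pyarrow as pa
from deltalake import DeltaTable, QueryBuilder

tbl = DeltaTable(
    "s3a://materialsproject-build/collections/2025-06-09/summary",  # example prefix
    storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"},
)

builder = QueryBuilder().register("tbl", tbl)
reader = builder.execute(
    "SELECT material_id, formula_pretty FROM tbl WHERE nelements = 2"
)

# Convert the arro3 reader to a pyarrow Table, as the diff does above.
result = pa.table(reader.read_all())
print(result.num_rows)
```
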
+                    file_options = ds.ParquetFileFormat().make_write_options(
+                        compression="zstd"
+                    )
+
+                    def _flush(accumulator, group):
+                        ds.write_dataset(
+                            accumulator,
+                            base_dir=target_path,
+                            format="parquet",
+                            basename_template=f"group-{group}-"
+                            + "part-{i}.zstd.parquet",
+                            existing_data_behavior="overwrite_or_ignore",
+                            max_rows_per_group=1024,
+                            file_options=file_options,
+                        )

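`_flush` is a thin wrapper around `pyarrow.dataset.write_dataset`. A standalone illustration of the same call pattern with toy data (temporary output directory, same zstd write options and basename template scheme):

```python
import pyarrow as pa
import pyarrow.dataset as ds

batches = [
    pa.RecordBatch.from_pydict(
        {"material_id": ["mp-1", "mp-2"], "nelements": [1, 2]}
    )
]

file_options = ds.ParquetFileFormat().make_write_options(compression="zstd")

ds.write_dataset(
    batches,
    base_dir="/tmp/mp_dataset_demo",  # example output directory
    format="parquet",
    basename_template="group-1-part-{i}.zstd.parquet",
    existing_data_behavior="overwrite_or_ignore",
    max_rows_per_group=1024,
    file_options=file_options,
)
```
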
+                    group = 1
+                    size = 0
+                    accumulator = []
+                    for page in iterator:
+                        # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
+                        accumulator.append(pa.record_batch(page))
+                        page_size = page.num_rows
+                        size += page_size
+
+                        if pbar is not None:
+                            pbar.update(page_size)
+
+                        if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
+                            _flush(accumulator, group)
+                            group += 1
+                            size = 0
+                            accumulator = []

+                    if accumulator:
+                        _flush(accumulator, group + 1)

+                    convert_to_deltalake(target_path)
+
+                    warnings.warn(
+                        f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
+                        "the table according to your usage patterns prior to running intensive workloads, "
+                        "see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
+                        MPLocalDatasetWarning,
+                    )

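Following the best-practices link in the warning, a user could compact or Z-order the freshly converted table before heavy querying. A hedged sketch with the `deltalake` optimize API (the path and clustering column are examples):

```python
from deltalake import DeltaTable

dt = DeltaTable("/path/to/mp_datasets/build/collections/2025-06-09/summary")

# Rewrite many small parquet files into fewer, larger ones.
dt.optimize.compact()

# Optionally cluster on a commonly filtered column (column name assumed).
dt.optimize.z_order(["material_id"])
```
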
+                    return {
+                        "data": MPDataset(
+                            path=target_path,
+                            document_model=self.document_model,
+                            use_document_model=self.use_document_model,
+                        )
+                    }

+                # Paginate over all entries in the bucket.
+                # TODO: change when a subset of entries needed from DB
                 paginator = self.s3_client.get_paginator("list_objects_v2")
                 pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -526,11 +694,6 @@ def _query_resource(
                 }

             # Setup progress bar
-            pbar_message = (  # type: ignore
-                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                if self.document_model is not None
-                else "Retrieving documents"
-            )
             num_docs_needed = int(self.count())
             pbar = (
                 tqdm(

@@ -1359,3 +1522,7 @@ class MPRestError(Exception):

 class MPRestWarning(Warning):
     """Raised when a query is malformed but interpretable."""
+
+
+class MPLocalDatasetWarning(Warning):
+    """Raised when unrecoverable actions are performed on a local dataset."""

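Since destructive or long-running local-dataset actions are surfaced as `MPLocalDatasetWarning`, callers can escalate them to errors (or silence them) with the standard `warnings` filters. For example:

```python
import warnings

from mp_api.client.core.client import MPLocalDatasetWarning

# Fail loudly instead of silently regenerating or skipping a local dataset.
warnings.filterwarnings("error", category=MPLocalDatasetWarning)
```
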
mp_api/client/core/settings.py

@@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings):
         _MAX_LIST_LENGTH, description="Maximum length of query parameter list"
     )

+    LOCAL_DATASET_CACHE: str = Field(
+        os.path.expanduser("~") + "/mp_datasets",
+        description="Target directory for downloading full datasets",
+    )
+
+    DATASET_FLUSH_THRESHOLD: int = Field(
+        100000,
+        description="Threshold number of rows to accumulate in memory before flushing dataset to disk",
+    )
+
+    ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field(
+        ["gnome_r2scan_statics"], description="Batch ids with access restrictions"
+    )
+
     model_config = SettingsConfigDict(env_prefix="MPRESTER_")

Review discussion on ACCESS_CONTROLLED_BATCH_IDS:

settings is the right place for the first two. Access-controlled batch ids should probably be hardcoded and change with client releases.

akin to one of these? (Lines 475 to 488 in 254c7d0)

Re: the use of

Ah no, I mean for the access controlled batch ids.

Yes, that's a good idea.

Feel free to start a PR to add it to the

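Because `model_config` sets `env_prefix="MPRESTER_"`, the new settings can be overridden from the environment without code changes. For example:

```python
import os

# Override the defaults before the settings object is created
# (equivalently, export these variables in the shell).
os.environ["MPRESTER_LOCAL_DATASET_CACHE"] = "/scratch/mp_datasets"
os.environ["MPRESTER_DATASET_FLUSH_THRESHOLD"] = "250000"

from mp_api.client.core.settings import MAPIClientSettings

settings = MAPIClientSettings()
print(settings.LOCAL_DATASET_CACHE, settings.DATASET_FLUSH_THRESHOLD)
```
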
Review discussion on the GNoMe access check URL:

use self.endpoint

should it be self.base_endpoint? I think this tripped me up when I tried using self.endpoint originally (api/mp_api/client/core/client.py, lines 134 to 135 in 254c7d0): for the tasks rester, self.endpoint caused the urljoin here to yield something like {base_url}/materials/tasks/materials/summary

right, self.base_endpoint should work.