
Commit aee0f8c

add deltalake query support
1 parent 505ddfe commit aee0f8c

5 files changed: 258 additions, 26 deletions

mp_api/client/core/client.py

Lines changed: 166 additions & 26 deletions
@@ -9,6 +9,7 @@
 import itertools
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -18,15 +19,13 @@
 from importlib.metadata import PackageNotFoundError, version
 from json import JSONDecodeError
 from math import ceil
-from typing import (
-    TYPE_CHECKING,
-    ForwardRef,
-    Optional,
-    get_args,
-)
+from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
 from urllib.parse import quote, urljoin

+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter
@@ -36,7 +35,7 @@
 from urllib3.util.retry import Retry

 from mp_api.client.core.settings import MAPIClientSettings
-from mp_api.client.core.utils import load_json, validate_ids
+from mp_api.client.core.utils import MPDataset, load_json, validate_ids

 try:
     import boto3
@@ -71,6 +70,7 @@ class BaseRester:
     document_model: type[BaseModel] | None = None
     supports_versions: bool = False
     primary_key: str = "material_id"
+    delta_backed: bool = False

     def __init__(
         self,
@@ -85,6 +85,8 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
     ):
         """Initialize the REST API helper class.

@@ -116,6 +118,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to 'materialsproject_datasets' in the user's home directory
+            force_renew: Option to overwrite existing local dataset
         """
         # TODO: think about how to migrate from PMG_MAPI_KEY
         self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -129,6 +134,8 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew
         self.db_version = BaseRester._get_database_version(self.endpoint)

         if self.suffix:
@@ -212,7 +219,7 @@ def _get_database_version(endpoint):
         remains unchanged and available for querying via its task_id.

         The database version is set as a date in the format YYYY_MM_DD,
-        where "_DD" may be optional. An additional numerical or `postN` suffix
+        where "_DD" may be optional. An additional numerical or `postN` suffix
         might be added if multiple releases happen on the same day.

         Returns: database version as a string
@@ -356,10 +363,7 @@ def _patch_resource(
             raise MPRestError(str(ex))

     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.

@@ -463,6 +467,12 @@ def _query_resource(
             url += "/"

         if query_s3:
+            pbar_message = (  # type: ignore
+                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                if self.document_model is not None
+                else "Retrieving documents"
+            )
+
             db_version = self.db_version.replace(".", "-")
             if "/" not in self.suffix:
                 suffix = self.suffix
474484
suffix = suffix.replace("_", "-")
475485

476486
# Check if user has access to GNoMe
487+
# temp suppress tqdm
488+
re_enable = not self.mute_progress_bars
489+
self.mute_progress_bars = True
477490
has_gnome_access = bool(
478491
self._submit_requests(
479-
url=urljoin(self.endpoint, "materials/summary/"),
492+
url=urljoin(
493+
"https://api.materialsproject.org/", "materials/summary/"
494+
),
480495
criteria={
481496
"batch_id": "gnome_r2scan_statics",
482497
"_fields": "material_id",
@@ -489,21 +504,147 @@ def _query_resource(
489504
.get("meta", {})
490505
.get("total_doc", 0)
491506
)
507+
self.mute_progress_bars = not re_enable
492508

493-
# Paginate over all entries in the bucket.
494-
# TODO: change when a subset of entries needed from DB
495509
if "tasks" in suffix:
496-
bucket_suffix, prefix = "parsed", "tasks_atomate2"
510+
bucket_suffix, prefix = ("parsed", "core/tasks/")
497511
else:
498512
bucket_suffix = "build"
499513
prefix = f"collections/{db_version}/{suffix}"
500514

501-
# only include prefixes accessible to user
502-
# i.e. append `batch_id=others/core` to `prefix`
503-
if not has_gnome_access:
504-
prefix += "/batch_id=others"
505-
506515
bucket = f"materialsproject-{bucket_suffix}"
516+
517+
if self.delta_backed:
518+
target_path = (
519+
self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
520+
)
521+
os.makedirs(target_path, exist_ok=True)
522+
523+
if DeltaTable.is_deltatable(target_path):
524+
if self.force_renew:
525+
shutil.rmtree(target_path)
526+
warnings.warn(
527+
f"Regenerating {suffix} dataset at {target_path}...",
528+
MPLocalDatasetWarning,
529+
)
530+
os.makedirs(target_path, exist_ok=True)
531+
else:
532+
warnings.warn(
533+
f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset "
534+
"or re-run search query with MPRester(force_renew=True)",
535+
MPLocalDatasetWarning,
536+
)
537+
538+
return {
539+
"data": MPDataset(
540+
path=target_path,
541+
document_model=self.document_model,
542+
use_document_model=self.use_document_model,
543+
)
544+
}
545+
546+
tbl = DeltaTable(
547+
f"s3a://{bucket}/{prefix}",
548+
storage_options={
549+
"AWS_SKIP_SIGNATURE": "true",
550+
"AWS_REGION": "us-east-1",
551+
},
552+
)
553+
554+
controlled_batch_str = ",".join(
555+
[f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
556+
)
557+
558+
predicate = (
559+
" WHERE batch_id NOT IN (" # don't delete leading space
560+
+ controlled_batch_str
561+
+ ")"
562+
if not has_gnome_access
563+
else ""
564+
)
565+
566+
builder = QueryBuilder().register("tbl", tbl)
567+
568+
# Setup progress bar
569+
num_docs_needed = pa.table(
570+
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
571+
)[0][0].as_py()
572+
573+
# TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
574+
# -> need to modify BatchIdQuery operator to handle root level
575+
# batch_id, not only builder_meta.batch_id
576+
# if not has_gnome_access:
577+
# num_docs_needed = self.count(
578+
# {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
579+
# )
580+
581+
pbar = (
582+
tqdm(
583+
desc=pbar_message,
584+
total=num_docs_needed,
585+
)
586+
if not self.mute_progress_bars
587+
else None
588+
)
589+
590+
iterator = builder.execute("SELECT * FROM tbl" + predicate)
591+
592+
file_options = ds.ParquetFileFormat().make_write_options(
593+
compression="zstd"
594+
)
595+
596+
def _flush(accumulator, group):
597+
ds.write_dataset(
598+
accumulator,
599+
base_dir=target_path,
600+
format="parquet",
601+
basename_template=f"group-{group}-"
602+
+ "part-{i}.zstd.parquet",
603+
existing_data_behavior="overwrite_or_ignore",
604+
max_rows_per_group=1024,
605+
file_options=file_options,
606+
)
607+
608+
group = 1
609+
size = 0
610+
accumulator = []
611+
for page in iterator:
612+
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
613+
accumulator.append(pa.record_batch(page))
614+
page_size = page.num_rows
615+
size += page_size
616+
617+
if pbar is not None:
618+
pbar.update(page_size)
619+
620+
if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
621+
_flush(accumulator, group)
622+
group += 1
623+
size = 0
624+
accumulator = []
625+
626+
if accumulator:
627+
_flush(accumulator, group + 1)
628+
629+
convert_to_deltalake(target_path)
630+
631+
warnings.warn(
632+
f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
633+
"the table according to your usage patterns prior to running intensive workloads, "
634+
"see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
635+
MPLocalDatasetWarning,
636+
)
637+
638+
return {
639+
"data": MPDataset(
640+
path=target_path,
641+
document_model=self.document_model,
642+
use_document_model=self.use_document_model,
643+
)
644+
}
645+
646+
# Paginate over all entries in the bucket.
647+
# TODO: change when a subset of entries needed from DB
507648
paginator = self.s3_client.get_paginator("list_objects_v2")
508649
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
509650

@@ -540,11 +681,6 @@ def _query_resource(
             }

             # Setup progress bar
-            pbar_message = (  # type: ignore
-                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                if self.document_model is not None
-                else "Retrieving documents"
-            )
             num_docs_needed = int(self.count())
             pbar = (
                 tqdm(
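
The Delta-backed download loop added above keeps peak memory bounded: record batches are accumulated in memory and flushed to zstd-compressed parquet once roughly SETTINGS.DATASET_FLUSH_THRESHOLD rows are buffered, and the resulting directory is converted to a Delta table at the end. A stripped-down sketch of that accumulate-and-flush pattern, for orientation only (the flush function and sizes here are stand-ins, not the commit's code):

import pyarrow as pa

FLUSH_THRESHOLD = 100_000  # mirrors SETTINGS.DATASET_FLUSH_THRESHOLD

def flush(batches: list[pa.RecordBatch], group: int) -> None:
    # stand-in for the ds.write_dataset(...) call in the real code
    print(f"group {group}: {sum(b.num_rows for b in batches)} rows")

def stream_to_disk(pages) -> None:
    group, size, accumulator = 1, 0, []
    for page in pages:  # each page is a pyarrow RecordBatch
        accumulator.append(page)
        size += page.num_rows
        if size >= FLUSH_THRESHOLD:  # flush once enough rows are buffered
            flush(accumulator, group)
            group, size, accumulator = group + 1, 0, []
    if accumulator:  # write out the final partial group
        flush(accumulator, group)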
@@ -1372,3 +1508,7 @@ class MPRestError(Exception):

 class MPRestWarning(Warning):
     """Raised when a query is malformed but interpretable."""
+
+
+class MPLocalDatasetWarning(Warning):
+    """Raised when unrecoverable actions are performed on a local dataset."""

mp_api/client/core/settings.py

Lines changed: 14 additions & 0 deletions
@@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings):
         _MAX_LIST_LENGTH, description="Maximum length of query parameter list"
     )

+    LOCAL_DATASET_CACHE: str = Field(
+        os.path.expanduser("~") + "/mp_datasets",
+        description="Target directory for downloading full datasets",
+    )
+
+    DATASET_FLUSH_THRESHOLD: int = Field(
+        100000,
+        description="Threshold number of rows to accumulate in memory before flushing dataset to disk",
+    )
+
+    ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field(
+        ["gnome_r2scan_statics"], description="Batch ids with access restrictions"
+    )
+
     model_config = SettingsConfigDict(env_prefix="MPRESTER_")
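
Because MAPIClientSettings is a pydantic BaseSettings class with env_prefix="MPRESTER_", the new fields can also be overridden through environment variables. A small sketch (values are examples only); note that the client builds a module-level SETTINGS instance at import time, so overrides should be in place before mp_api is imported:

import os

# Must be set before MAPIClientSettings is instantiated.
os.environ["MPRESTER_LOCAL_DATASET_CACHE"] = "/scratch/mp_datasets"
os.environ["MPRESTER_DATASET_FLUSH_THRESHOLD"] = "250000"

from mp_api.client.core.settings import MAPIClientSettings

settings = MAPIClientSettings()
print(settings.LOCAL_DATASET_CACHE)      # /scratch/mp_datasets
print(settings.DATASET_FLUSH_THRESHOLD)  # 250000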

mp_api/client/core/utils.py

Lines changed: 63 additions & 0 deletions
@@ -1,12 +1,17 @@
 from __future__ import annotations

 import re
+from functools import cached_property
+from itertools import chain
 from typing import TYPE_CHECKING, Literal

 import orjson
+import pyarrow.dataset as ds
+from deltalake import DeltaTable
 from emmet.core import __version__ as _EMMET_CORE_VER
 from monty.json import MontyDecoder
 from packaging.version import parse as parse_version
+from pydantic._internal._model_construction import ModelMetaclass

 from mp_api.client.core.settings import MAPIClientSettings

@@ -124,3 +129,61 @@ def validate_monty(cls, v, _):
     monty_cls.validate_monty_v2 = classmethod(validate_monty)

     return monty_cls
+
+
+class MPDataset:
+    def __init__(self, path, document_model, use_document_model):
+        self._start = 0
+        self._path = path
+        self._document_model = document_model
+        self._dataset = ds.dataset(path)
+        self._row_groups = list(
+            chain.from_iterable(
+                [
+                    fragment.split_by_row_group()
+                    for fragment in self._dataset.get_fragments()
+                ]
+            )
+        )
+        self._use_document_model = use_document_model
+
+    @property
+    def pyarrow_dataset(self) -> ds.Dataset:
+        return self._dataset
+
+    @property
+    def pydantic_model(self) -> ModelMetaclass:
+        return self._document_model
+
+    @property
+    def use_document_model(self) -> bool:
+        return self._use_document_model
+
+    @use_document_model.setter
+    def use_document_model(self, value: bool):
+        self._use_document_model = value
+
+    @cached_property
+    def delta_table(self) -> DeltaTable:
+        return DeltaTable(self._path)
+
+    @cached_property
+    def num_chunks(self) -> int:
+        return len(self._row_groups)
+
+    def __getitem__(self, idx):
+        return list(
+            map(
+                lambda x: self._document_model(**x) if self._use_document_model else x,
+                self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict"),
+            )
+        )
+
+    def __len__(self) -> int:
+        return self.num_chunks
+
+    def __iter__(self):
+        current = self._start
+        while current < self.num_chunks:
+            yield self[current]
+            current += 1
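
MPDataset exposes the cached collection chunk by chunk (one parquet row group per chunk), optionally validating each record through the rester's pydantic document model, and also hands out raw pyarrow and DeltaTable handles for heavier analytics. A consumption sketch, assuming a dataset has already been materialized by a delta-backed search; the path below is a made-up example:

from mp_api.client.core.utils import MPDataset

dataset = MPDataset(
    path="/home/user/mp_datasets/build/collections/2025-06-01/summary",  # hypothetical location
    document_model=None,       # or the rester's document model
    use_document_model=False,  # return plain dicts instead of pydantic objects
)

print(len(dataset))        # number of row-group chunks
first_chunk = dataset[0]   # list of dicts for the first row group

# Column pruning can bypass the document model entirely via pyarrow:
table = dataset.pyarrow_dataset.to_table(columns=["material_id"])

for chunk in dataset:      # stream chunk by chunk without loading everything
    ...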
