From 9d2048e7041bac113e17bb148211fece1ac2fe41 Mon Sep 17 00:00:00 2001 From: Patrick Huck Date: Wed, 5 Mar 2025 11:37:09 -0800 Subject: [PATCH 1/7] exclude gnome for full downloads if needed --- mp_api/client/core/client.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 4cd98958..50d86292 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -473,6 +473,23 @@ def _query_resource( suffix = infix if suffix == "core" else suffix suffix = suffix.replace("_", "-") + # Check if user has access to GNoMe + has_gnome_access = bool( + self._submit_requests( + url=urljoin(self.endpoint, "materials/summary/"), + criteria={ + "batch_id": "gnome_r2scan_statics", + "_fields": "material_id", + }, + use_document_model=False, + num_chunks=1, + chunk_size=1, + timeout=timeout, + ) + .get("meta", {}) + .get("total_doc", 0) + ) + # Paginate over all entries in the bucket. # TODO: change when a subset of entries needed from DB if "tasks" in suffix: @@ -481,6 +498,11 @@ def _query_resource( bucket_suffix = "build" prefix = f"collections/{db_version}/{suffix}" + # only include prefixes accessible to user + # i.e. append `batch_id=others/core` to `prefix` + if not has_gnome_access: + prefix += "/batch_id=others" + bucket = f"materialsproject-{bucket_suffix}" paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket, Prefix=prefix) From 505ddfe0c311ef4de64d6dd9c19ea78c41b75754 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 22 Oct 2025 17:20:29 -0700 Subject: [PATCH 2/7] query s3 for trajectories --- mp_api/client/routes/materials/tasks.py | 33 +++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index c7865078..b0854d42 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -3,7 +3,11 @@ from datetime import datetime from typing import TYPE_CHECKING +import pyarrow as pa +from deltalake import DeltaTable, QueryBuilder +from emmet.core.mpid import AlphaID from emmet.core.tasks import CoreTaskDoc +from emmet.core.trajectory import RelaxTrajectory from mp_api.client.core import BaseRester, MPRestError from mp_api.client.core.utils import validate_ids @@ -16,6 +20,7 @@ class TaskRester(BaseRester): suffix: str = "materials/tasks" document_model: type[BaseModel] = CoreTaskDoc # type: ignore primary_key: str = "task_id" + delta_backed = True def get_trajectory(self, task_id): """Returns a Trajectory object containing the geometry of the @@ -26,16 +31,30 @@ def get_trajectory(self, task_id): task_id (str): Task ID """ - traj_data = self._query_resource_data( - {"task_ids": [task_id]}, suburl="trajectory/", use_document_model=False - )[0].get( - "trajectories", None - ) # type: ignore + as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] - if traj_data is None: + traj_tbl = DeltaTable( + "s3a://materialsproject-parsed/core/trajectories/", + storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + ) + + traj_data = pa.table( + QueryBuilder() + .register("traj", traj_tbl) + .execute( + f""" + SELECT * + FROM traj + WHERE identifier='{as_alpha}' + """ + ) + .read_all() + ).to_pylist(maps_as_pydicts="strict") + + if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") - return traj_data + return 
RelaxTrajectory(**traj_data[0]).to_pmg() def search( self, From aee0f8c117e01e514604b4c5996f144a4c3b560d Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 08:39:17 -0700 Subject: [PATCH 3/7] add deltalake query support --- mp_api/client/core/client.py | 192 ++++++++++++++++++++++++++++----- mp_api/client/core/settings.py | 14 +++ mp_api/client/core/utils.py | 63 +++++++++++ mp_api/client/mprester.py | 13 +++ pyproject.toml | 2 + 5 files changed, 258 insertions(+), 26 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 50d86292..5234ef34 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -9,6 +9,7 @@ import itertools import os import platform +import shutil import sys import warnings from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait @@ -18,15 +19,13 @@ from importlib.metadata import PackageNotFoundError, version from json import JSONDecodeError from math import ceil -from typing import ( - TYPE_CHECKING, - ForwardRef, - Optional, - get_args, -) +from typing import TYPE_CHECKING, ForwardRef, Optional, get_args from urllib.parse import quote, urljoin +import pyarrow as pa +import pyarrow.dataset as ds import requests +from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake from emmet.core.utils import jsanitize from pydantic import BaseModel, create_model from requests.adapters import HTTPAdapter @@ -36,7 +35,7 @@ from urllib3.util.retry import Retry from mp_api.client.core.settings import MAPIClientSettings -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import MPDataset, load_json, validate_ids try: import boto3 @@ -71,6 +70,7 @@ class BaseRester: document_model: type[BaseModel] | None = None supports_versions: bool = False primary_key: str = "material_id" + delta_backed: bool = False def __init__( self, @@ -85,6 +85,8 @@ def __init__( timeout: int = 20, headers: dict | None = None, mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, ): """Initialize the REST API helper class. @@ -116,6 +118,9 @@ def __init__( timeout: Time in seconds to wait until a request timeout error is thrown headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to 'materialsproject_datasets' in the user's home directory + force_renew: Option to overwrite existing local dataset """ # TODO: think about how to migrate from PMG_MAPI_KEY self.api_key = api_key or os.getenv("MP_API_KEY") @@ -129,6 +134,8 @@ def __init__( self.timeout = timeout self.headers = headers or {} self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self.db_version = BaseRester._get_database_version(self.endpoint) if self.suffix: @@ -212,7 +219,7 @@ def _get_database_version(endpoint): remains unchanged and available for querying via its task_id. The database version is set as a date in the format YYYY_MM_DD, - where "_DD" may be optional. An additional numerical or `postN` suffix + predicate "_DD" may be optional. An additional numerical or `postN` suffix might be added if multiple releases happen on the same day. 
Returns: database version as a string @@ -356,10 +363,7 @@ def _patch_resource( raise MPRestError(str(ex)) def _query_open_data( - self, - bucket: str, - key: str, - decoder: Callable | None = None, + self, bucket: str, key: str, decoder: Callable | None = None ) -> tuple[list[dict] | list[bytes], int]: """Query and deserialize Materials Project AWS open data s3 buckets. @@ -463,6 +467,12 @@ def _query_resource( url += "/" if query_s3: + pbar_message = ( # type: ignore + f"Retrieving {self.document_model.__name__} documents" # type: ignore + if self.document_model is not None + else "Retrieving documents" + ) + db_version = self.db_version.replace(".", "-") if "/" not in self.suffix: suffix = self.suffix @@ -474,9 +484,14 @@ def _query_resource( suffix = suffix.replace("_", "-") # Check if user has access to GNoMe + # temp suppress tqdm + re_enable = not self.mute_progress_bars + self.mute_progress_bars = True has_gnome_access = bool( self._submit_requests( - url=urljoin(self.endpoint, "materials/summary/"), + url=urljoin( + "https://api.materialsproject.org/", "materials/summary/" + ), criteria={ "batch_id": "gnome_r2scan_statics", "_fields": "material_id", @@ -489,21 +504,147 @@ def _query_resource( .get("meta", {}) .get("total_doc", 0) ) + self.mute_progress_bars = not re_enable - # Paginate over all entries in the bucket. - # TODO: change when a subset of entries needed from DB if "tasks" in suffix: - bucket_suffix, prefix = "parsed", "tasks_atomate2" + bucket_suffix, prefix = ("parsed", "core/tasks/") else: bucket_suffix = "build" prefix = f"collections/{db_version}/{suffix}" - # only include prefixes accessible to user - # i.e. append `batch_id=others/core` to `prefix` - if not has_gnome_access: - prefix += "/batch_id=others" - bucket = f"materialsproject-{bucket_suffix}" + + if self.delta_backed: + target_path = ( + self.local_dataset_cache + f"/{bucket_suffix}/{prefix}" + ) + os.makedirs(target_path, exist_ok=True) + + if DeltaTable.is_deltatable(target_path): + if self.force_renew: + shutil.rmtree(target_path) + warnings.warn( + f"Regenerating {suffix} dataset at {target_path}...", + MPLocalDatasetWarning, + ) + os.makedirs(target_path, exist_ok=True) + else: + warnings.warn( + f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset " + "or re-run search query with MPRester(force_renew=True)", + MPLocalDatasetWarning, + ) + + return { + "data": MPDataset( + path=target_path, + document_model=self.document_model, + use_document_model=self.use_document_model, + ) + } + + tbl = DeltaTable( + f"s3a://{bucket}/{prefix}", + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + }, + ) + + controlled_batch_str = ",".join( + [f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS] + ) + + predicate = ( + " WHERE batch_id NOT IN (" # don't delete leading space + + controlled_batch_str + + ")" + if not has_gnome_access + else "" + ) + + builder = QueryBuilder().register("tbl", tbl) + + # Setup progress bar + num_docs_needed = pa.table( + builder.execute("SELECT COUNT(*) FROM tbl").read_all() + )[0][0].as_py() + + # TODO: Update tasks (+ others?) 
resource to have emmet-api BatchIdQuery operator + # -> need to modify BatchIdQuery operator to handle root level + # batch_id, not only builder_meta.batch_id + # if not has_gnome_access: + # num_docs_needed = self.count( + # {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS} + # ) + + pbar = ( + tqdm( + desc=pbar_message, + total=num_docs_needed, + ) + if not self.mute_progress_bars + else None + ) + + iterator = builder.execute("SELECT * FROM tbl" + predicate) + + file_options = ds.ParquetFileFormat().make_write_options( + compression="zstd" + ) + + def _flush(accumulator, group): + ds.write_dataset( + accumulator, + base_dir=target_path, + format="parquet", + basename_template=f"group-{group}-" + + "part-{i}.zstd.parquet", + existing_data_behavior="overwrite_or_ignore", + max_rows_per_group=1024, + file_options=file_options, + ) + + group = 1 + size = 0 + accumulator = [] + for page in iterator: + # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer + accumulator.append(pa.record_batch(page)) + page_size = page.num_rows + size += page_size + + if pbar is not None: + pbar.update(page_size) + + if size >= SETTINGS.DATASET_FLUSH_THRESHOLD: + _flush(accumulator, group) + group += 1 + size = 0 + accumulator = [] + + if accumulator: + _flush(accumulator, group + 1) + + convert_to_deltalake(target_path) + + warnings.warn( + f"Dataset for {suffix} written to {target_path}. It is recommended to optimize " + "the table according to your usage patterns prior to running intensive workloads, " + "see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout", + MPLocalDatasetWarning, + ) + + return { + "data": MPDataset( + path=target_path, + document_model=self.document_model, + use_document_model=self.use_document_model, + ) + } + + # Paginate over all entries in the bucket. 
+ # TODO: change when a subset of entries needed from DB paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket, Prefix=prefix) @@ -540,11 +681,6 @@ def _query_resource( } # Setup progress bar - pbar_message = ( # type: ignore - f"Retrieving {self.document_model.__name__} documents" # type: ignore - if self.document_model is not None - else "Retrieving documents" - ) num_docs_needed = int(self.count()) pbar = ( tqdm( @@ -1372,3 +1508,7 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" + + +class MPLocalDatasetWarning(Warning): + """Raised when unrecoverable actions are performed on a local dataset.""" diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index 200b6778..9dbc6a38 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings): _MAX_LIST_LENGTH, description="Maximum length of query parameter list" ) + LOCAL_DATASET_CACHE: str = Field( + os.path.expanduser("~") + "/mp_datasets", + description="Target directory for downloading full datasets", + ) + + DATASET_FLUSH_THRESHOLD: int = Field( + 100000, + description="Threshold number of rows to accumulate in memory before flushing dataset to disk", + ) + + ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field( + ["gnome_r2scan_statics"], description="Batch ids with access restrictions" + ) + model_config = SettingsConfigDict(env_prefix="MPRESTER_") diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index c2d03fec..8fb48c14 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -1,12 +1,17 @@ from __future__ import annotations import re +from functools import cached_property +from itertools import chain from typing import TYPE_CHECKING, Literal import orjson +import pyarrow.dataset as ds +from deltalake import DeltaTable from emmet.core import __version__ as _EMMET_CORE_VER from monty.json import MontyDecoder from packaging.version import parse as parse_version +from pydantic._internal._model_construction import ModelMetaclass from mp_api.client.core.settings import MAPIClientSettings @@ -124,3 +129,61 @@ def validate_monty(cls, v, _): monty_cls.validate_monty_v2 = classmethod(validate_monty) return monty_cls + + +class MPDataset: + def __init__(self, path, document_model, use_document_model): + self._start = 0 + self._path = path + self._document_model = document_model + self._dataset = ds.dataset(path) + self._row_groups = list( + chain.from_iterable( + [ + fragment.split_by_row_group() + for fragment in self._dataset.get_fragments() + ] + ) + ) + self._use_document_model = use_document_model + + @property + def pyarrow_dataset(self) -> ds.Dataset: + return self._dataset + + @property + def pydantic_model(self) -> ModelMetaclass: + return self._document_model + + @property + def use_document_model(self) -> bool: + return self._use_document_model + + @use_document_model.setter + def use_document_model(self, value: bool): + self._use_document_model = value + + @cached_property + def delta_table(self) -> DeltaTable: + return DeltaTable(self._path) + + @cached_property + def num_chunks(self) -> int: + return len(self._row_groups) + + def __getitem__(self, idx): + return list( + map( + lambda x: self._document_model(**x) if self._use_document_model else x, + self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict"), + ) + ) + + def __len__(self) -> int: + return self.num_chunks 
+ + def __iter__(self): + current = self._start + while current < self.num_chunks: + yield self[current] + current += 1 diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 3fdc07f9..5537736a 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -133,6 +133,8 @@ def __init__( session: Session | None = None, headers: dict | None = None, mute_progress_bars: bool = _MAPI_SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: str | os.PathLike = _MAPI_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, ): """Initialize the MPRester. @@ -167,6 +169,9 @@ def __init__( session: Session object to use. By default (None), the client will create one. headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to "materialsproject_datasets" in the user's home directory + force_renew: Option to overwrite existing local dataset """ # SETTINGS tries to read API key from ~/.config/.pmgrc.yaml @@ -192,6 +197,8 @@ def __init__( self.use_document_model = use_document_model self.monty_decode = monty_decode self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self._contribs = None self._deprecated_attributes = [ @@ -267,6 +274,8 @@ def __init__( use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) for cls in self._all_resters if cls.suffix in core_suffix @@ -293,6 +302,8 @@ def __init__( use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) # type: BaseRester setattr( self, @@ -323,6 +334,8 @@ def __core_custom_getattr(_self, _attr, _rester_map): use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) # type: BaseRester setattr( diff --git a/pyproject.toml b/pyproject.toml index f202666c..063e8c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,8 @@ dependencies = [ "smart_open", "boto3", "orjson >= 3.10,<4", + "pyarrow >= 20.0.0", + "deltalake >= 1.2.0", ] dynamic = ["version"] From d5a25b19ca037771010f7c743ce3bae266aba0e6 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 09:01:57 -0700 Subject: [PATCH 4/7] linting + mistaken sed replace on 'where' --- mp_api/client/core/client.py | 2 +- mp_api/client/core/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 5234ef34..c8a49233 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -219,7 +219,7 @@ def _get_database_version(endpoint): remains unchanged and available for querying via its task_id. The database version is set as a date in the format YYYY_MM_DD, - predicate "_DD" may be optional. An additional numerical or `postN` suffix + where "_DD" may be optional. An additional numerical or `postN` suffix might be added if multiple releases happen on the same day. 
Returns: database version as a string diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 8fb48c14..9e7003ed 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -133,6 +133,7 @@ def validate_monty(cls, v, _): class MPDataset: def __init__(self, path, document_model, use_document_model): + """Convenience wrapper for pyarrow datasets stored on disk.""" self._start = 0 self._path = path self._document_model = document_model From 2de051df8fde2dd2d10e611598b0ea2efdf984a0 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:40:13 -0700 Subject: [PATCH 5/7] return trajectory as pmg dict --- mp_api/client/routes/materials/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index b0854d42..a879a93c 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -54,7 +54,7 @@ def get_trajectory(self, task_id): if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") - return RelaxTrajectory(**traj_data[0]).to_pmg() + return RelaxTrajectory(**traj_data[0]).to_pmg().as_dict() def search( self, From 7d0b8b749b3f163133a5028b6ee169f9bb39cc05 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:46:29 -0700 Subject: [PATCH 6/7] update trajectory test --- tests/materials/test_tasks.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/materials/test_tasks.py b/tests/materials/test_tasks.py index b35dfd93..1ddf12c5 100644 --- a/tests/materials/test_tasks.py +++ b/tests/materials/test_tasks.py @@ -1,8 +1,9 @@ import os -from core_function import client_search_testing -import pytest +import pytest +from core_function import client_search_testing from emmet.core.utils import utcnow + from mp_api.client.routes.materials.tasks import TaskRester @@ -53,7 +54,6 @@ def test_client(rester): def test_get_trajectories(rester): - trajectories = [traj for traj in rester.get_trajectory("mp-149")] + trajectory = rester.get_trajectory("mp-149") - for traj in trajectories: - assert ("@module", "pymatgen.core.trajectory") in traj.items() + assert trajectory["@module"] == "pymatgen.core.trajectory" From 7195adf9b11394898dae78b502e3235b74a18f75 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:48:39 -0700 Subject: [PATCH 7/7] correct docstrs --- mp_api/client/core/client.py | 2 +- mp_api/client/mprester.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index c8a49233..5b74e5dc 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -119,7 +119,7 @@ def __init__( headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. local_dataset_cache: Target directory for downloading full datasets. 
Defaults - to 'materialsproject_datasets' in the user's home directory + to 'mp_datasets' in the user's home directory force_renew: Option to overwrite existing local dataset """ # TODO: think about how to migrate from PMG_MAPI_KEY diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 5537736a..a60de0f3 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -170,7 +170,7 @@ def __init__( headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. local_dataset_cache: Target directory for downloading full datasets. Defaults - to "materialsproject_datasets" in the user's home directory + to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset """
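
Usage sketch for review purposes (not part of the patch series). It assumes the local-dataset and trajectory changes above land as-is; the API key and cache path are placeholders, and a filterless `search()` on the delta-backed tasks rester is assumed to take the new full-download path:

    from mp_api.client import MPRester

    with MPRester(
        api_key="<your-api-key>",                      # placeholder
        local_dataset_cache="/path/to/mp_datasets",    # placeholder; defaults to ~/mp_datasets
        force_renew=False,                             # True rebuilds an existing local dataset
    ) as mpr:
        # A full download on a delta-backed rester writes a Delta/parquet table to
        # local_dataset_cache and returns an MPDataset wrapper around it. For users
        # without GNoME access, rows in access-controlled batches (e.g.
        # gnome_r2scan_statics) are excluded via the batch_id predicate.
        dataset = mpr.materials.tasks.search()

        # Iterating yields one row group at a time; each row group is a list of
        # CoreTaskDoc models (or plain dicts when use_document_model=False).
        for chunk in dataset:
            for doc in chunk:
                pass  # process each task document

        # get_trajectory() now reads the trajectory Delta table on S3 and returns
        # a pymatgen Trajectory serialized as a dict.
        traj = mpr.materials.tasks.get_trajectory("mp-149")
        assert traj["@module"] == "pymatgen.core.trajectory"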