193 changes: 180 additions & 13 deletions mp_api/client/core/client.py
@@ -9,6 +9,7 @@
import itertools
import os
import platform
import shutil
import sys
import warnings
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -21,7 +22,10 @@
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin

import pyarrow as pa
import pyarrow.dataset as ds
import requests
from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
@@ -31,7 +35,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import MPDataset, load_json, validate_ids

try:
import boto3
@@ -83,6 +87,7 @@ class BaseRester:
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"
delta_backed: bool = False

def __init__(
self,
@@ -97,6 +102,8 @@ def __init__(
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE,
force_renew: bool = False,
):
"""Initialize the REST API helper class.

Expand Down Expand Up @@ -128,6 +135,9 @@ def __init__(
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
local_dataset_cache: Target directory for downloading full datasets. Defaults
to 'mp_datasets' in the user's home directory.
force_renew: Whether to overwrite an existing local dataset.
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -141,6 +151,8 @@
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.local_dataset_cache = local_dataset_cache
self.force_renew = force_renew
self.db_version = BaseRester._get_database_version(self.endpoint)

if self.suffix:
@@ -368,10 +380,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable | None = None,
self, bucket: str, key: str, decoder: Callable | None = None
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

@@ -471,6 +480,12 @@ def _query_resource(
url += "/"

if query_s3:
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)

db_version = self.db_version.replace(".", "-")
if "/" not in self.suffix:
suffix = self.suffix
@@ -481,15 +496,168 @@ def _query_resource(
suffix = infix if suffix == "core" else suffix
suffix = suffix.replace("_", "-")

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
# Check if user has access to GNoMe
# temp suppress tqdm
re_enable = not self.mute_progress_bars
self.mute_progress_bars = True
has_gnome_access = bool(
self._submit_requests(
url=urljoin(
"https://api.materialsproject.org/", "materials/summary/"
[Review comment from Member]
use self.endpoint

[tsmathis (Collaborator, Author), Oct 23, 2025]
should it be self.base_endpoint? I think this tripped me up when I tried using self.endpoint originally:

    if self.suffix:
        self.endpoint = urljoin(self.endpoint, self.suffix)

for the tasks rester, self.endpoint caused the urljoin here to yield something like {base_url}/materials/tasks/materials/summary

[Member]
right, self.base_endpoint should work.
),
criteria={
"batch_id": "gnome_r2scan_statics",
"_fields": "material_id",
},
use_document_model=False,
num_chunks=1,
chunk_size=1,
timeout=timeout,
)
.get("meta", {})
.get("total_doc", 0)
)
self.mute_progress_bars = not re_enable

if "tasks" in suffix:
bucket_suffix, prefix = "parsed", "tasks_atomate2"
bucket_suffix, prefix = ("parsed", "core/tasks/")
else:
bucket_suffix = "build"
prefix = f"collections/{db_version}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"

if self.delta_backed:
target_path = (
self.local_dataset_cache + f"/{bucket_suffix}/{prefix}"
)
os.makedirs(target_path, exist_ok=True)

if DeltaTable.is_deltatable(target_path):
if self.force_renew:
shutil.rmtree(target_path)
warnings.warn(
f"Regenerating {suffix} dataset at {target_path}...",
MPLocalDatasetWarning,
)
os.makedirs(target_path, exist_ok=True)
else:
warnings.warn(
f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset "
"or re-run search query with MPRester(force_renew=True)",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

tbl = DeltaTable(
f"s3a://{bucket}/{prefix}",
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
},
)

controlled_batch_str = ",".join(
[f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS]
)

predicate = (
" WHERE batch_id NOT IN (" # don't delete leading space
+ controlled_batch_str
+ ")"
if not has_gnome_access
else ""
)

builder = QueryBuilder().register("tbl", tbl)

# Setup progress bar
num_docs_needed = pa.table(
builder.execute("SELECT COUNT(*) FROM tbl").read_all()
)[0][0].as_py()

# TODO: Update tasks (+ others?) resource to have emmet-api BatchIdQuery operator
# -> need to modify BatchIdQuery operator to handle root level
# batch_id, not only builder_meta.batch_id
# if not has_gnome_access:
# num_docs_needed = self.count(
# {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS}
# )

pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
else None
)

iterator = builder.execute("SELECT * FROM tbl" + predicate)

file_options = ds.ParquetFileFormat().make_write_options(
compression="zstd"
)

def _flush(accumulator, group):
ds.write_dataset(
accumulator,
base_dir=target_path,
format="parquet",
basename_template=f"group-{group}-"
+ "part-{i}.zstd.parquet",
existing_data_behavior="overwrite_or_ignore",
max_rows_per_group=1024,
file_options=file_options,
)

group = 1
size = 0
accumulator = []
for page in iterator:
# arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
accumulator.append(pa.record_batch(page))
page_size = page.num_rows
size += page_size

if pbar is not None:
pbar.update(page_size)

if size >= SETTINGS.DATASET_FLUSH_THRESHOLD:
_flush(accumulator, group)
group += 1
size = 0
accumulator = []
[Review comment from Member]
accumulator.clear() for better memory management?
if accumulator:
_flush(accumulator, group + 1)

convert_to_deltalake(target_path)

warnings.warn(
f"Dataset for {suffix} written to {target_path}. It is recommended to optimize "
"the table according to your usage patterns prior to running intensive workloads, "
"see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout",
MPLocalDatasetWarning,
)

return {
"data": MPDataset(
path=target_path,
document_model=self.document_model,
use_document_model=self.use_document_model,
)
}

# Paginate over all entries in the bucket.
# TODO: change when a subset of entries needed from DB
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -526,11 +694,6 @@ def _query_resource(
}

# Setup progress bar
pbar_message = ( # type: ignore
f"Retrieving {self.document_model.__name__} documents" # type: ignore
if self.document_model is not None
else "Retrieving documents"
)
num_docs_needed = int(self.count())
pbar = (
tqdm(
@@ -1359,3 +1522,7 @@ class MPRestError(Exception):

class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""


class MPLocalDatasetWarning(Warning):
"""Raised when unrecoverable actions are performed on a local dataset."""
14 changes: 14 additions & 0 deletions mp_api/client/core/settings.py
@@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings):
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
)

LOCAL_DATASET_CACHE: str = Field(
os.path.expanduser("~") + "/mp_datasets",
description="Target directory for downloading full datasets",
)

DATASET_FLUSH_THRESHOLD: int = Field(
100000,
description="Threshold number of rows to accumulate in memory before flushing dataset to disk",
)

ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field(
[Review comment from Member]
settings is the right place for the first two. Access-controlled batch ids should probably be hardcoded and change with client releases.

[tsmathis (Collaborator, Author)]
akin to one of these?

    def get_database_version(self):
        """The Materials Project database is periodically updated and has a
        database version associated with it. When the database is updated,
        consolidated data (information about "a material") may and does
        change, while calculation data about a specific calculation task
        remains unchanged and available for querying via its task_id.
        The database version is set as a date in the format YYYY_MM_DD,
        where "_DD" may be optional. An additional numerical suffix
        might be added if multiple releases happen on the same day.
        Returns: database version as a string
        """
        return get(url=self.endpoint + "heartbeat").json()["db_version"]

[Member]
Re: the use of self.endpoint? If so, then yes :)

[tsmathis (Collaborator, Author), Oct 23, 2025]
Ah no, I mean for the access controlled batch ids. Should those be added to the heartbeat so they aren't defined in the client code/settings? And then the client can just call get_access_controlled_batch_ids()

[Member]
Yes, that's a good idea.

[Member]
Feel free to start a PR to add it to the heartbeat_meta here.

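A minimal sketch of the helper floated in this thread; the function name and the heartbeat payload key are hypothetical until the server-side heartbeat_meta change lands:

    import requests

    def get_access_controlled_batch_ids(endpoint: str) -> list[str]:
        """Hypothetical helper: fetch access-controlled batch ids from the
        heartbeat endpoint instead of hardcoding them in client settings."""
        payload = requests.get(url=endpoint + "heartbeat").json()
        # "access_controlled_batch_ids" is an assumed key; it does not exist
        # in the heartbeat response today.
        return payload.get("access_controlled_batch_ids", [])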
["gnome_r2scan_statics"], description="Batch ids with access restrictions"
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")
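Because model_config sets env_prefix="MPRESTER_", each of the new fields can be overridden through the environment without code changes; a small sketch (values illustrative):

    import os

    # Must be set before MAPIClientSettings is instantiated
    os.environ["MPRESTER_LOCAL_DATASET_CACHE"] = "/tmp/mp_datasets"
    os.environ["MPRESTER_DATASET_FLUSH_THRESHOLD"] = "50000"

    from mp_api.client.core.settings import MAPIClientSettings

    assert MAPIClientSettings().DATASET_FLUSH_THRESHOLD == 50000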
64 changes: 64 additions & 0 deletions mp_api/client/core/utils.py
@@ -1,12 +1,17 @@
from __future__ import annotations

import re
from functools import cached_property
from itertools import chain
from typing import TYPE_CHECKING, Literal

import orjson
import pyarrow.dataset as ds
from deltalake import DeltaTable
from emmet.core import __version__ as _EMMET_CORE_VER
from monty.json import MontyDecoder
from packaging.version import parse as parse_version
from pydantic._internal._model_construction import ModelMetaclass

from mp_api.client.core.settings import MAPIClientSettings

@@ -124,3 +129,62 @@ def validate_monty(cls, v, _):
monty_cls.validate_monty_v2 = classmethod(validate_monty)

return monty_cls


class MPDataset:
def __init__(self, path, document_model, use_document_model):
"""Convenience wrapper for pyarrow datasets stored on disk."""
self._start = 0
self._path = path
self._document_model = document_model
self._dataset = ds.dataset(path)
self._row_groups = list(
chain.from_iterable(
[
fragment.split_by_row_group()
for fragment in self._dataset.get_fragments()
]
)
)
self._use_document_model = use_document_model

@property
def pyarrow_dataset(self) -> ds.Dataset:
return self._dataset

@property
def pydantic_model(self) -> ModelMetaclass:
return self._document_model

@property
def use_document_model(self) -> bool:
return self._use_document_model

@use_document_model.setter
def use_document_model(self, value: bool):
self._use_document_model = value

@cached_property
def delta_table(self) -> DeltaTable:
return DeltaTable(self._path)

@cached_property
def num_chunks(self) -> int:
return len(self._row_groups)

def __getitem__(self, idx):
return list(
map(
lambda x: self._document_model(**x) if self._use_document_model else x,
self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict"),
)
)

def __len__(self) -> int:
return self.num_chunks

def __iter__(self):
current = self._start
while current < self.num_chunks:
yield self[current]
current += 1
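
A short usage sketch of the chunked access pattern MPDataset exposes (the local path and document model are illustrative; per client.py above, delta-backed search queries return this wrapper as their "data" payload):

    from emmet.core.summary import SummaryDoc  # illustrative document model

    from mp_api.client.core.utils import MPDataset

    dataset = MPDataset(
        path="/tmp/mp_datasets/build/collections/2025-09-25/summary",  # illustrative
        document_model=SummaryDoc,
        use_document_model=True,
    )

    print(len(dataset))       # number of row-group chunks, not documents
    first_chunk = dataset[0]  # one chunk: a list of SummaryDoc instances

    dataset.use_document_model = False  # switch to plain dicts
    for chunk in dataset:               # iterate chunk by chunk
        for doc in chunk:
            ...

    # Drop to pyarrow for columnar queries over the whole dataset
    table = dataset.pyarrow_dataset.to_table(columns=["material_id"])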