Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
from importlib.metadata import PackageNotFoundError, version
from json import JSONDecodeError
from math import ceil
from typing import (
TYPE_CHECKING,
ForwardRef,
Optional,
get_args,
)
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin

import requests
Expand Down Expand Up @@ -64,6 +59,23 @@
SETTINGS = MAPIClientSettings() # type: ignore


class _DictLikeAccess(BaseModel):
    """Pydantic mix-in exposing declared model fields through dict-style access.

    Only names present in the class's ``model_fields`` are reachable via
    ``obj[name]``; ``obj.get(name, default)`` falls back to ``default`` instead
    of raising.
    """

    def __getitem__(self, item: str) -> Any:
        """Look up the model field named `item`.

        Args:
            item: name of the model field to retrieve.

        Returns:
            The value of the attribute `item` on this instance.

        Raises:
            AttributeError: if `item` is not a declared model field.
        """
        model_cls = self.__class__
        # Guard clause: reject anything that is not a declared model field.
        if item not in model_cls.model_fields:
            raise AttributeError(f"{model_cls.__name__} has no model field `{item}`.")
        return getattr(self, item)

    def get(self, item: str, default: Any = None) -> Any:
        """Fetch the model field `item`, falling back to `default`.

        Args:
            item: name of the model field to retrieve.
            default: value returned when the lookup raises ``AttributeError``.

        Returns:
            The field value, or `default` if the lookup failed.
        """
        try:
            value = self[item]
        except AttributeError:
            return default
        return value


class BaseRester:
"""Base client class with core stubs."""

Expand Down Expand Up @@ -427,13 +439,9 @@ def _query_resource(
if use_document_model is None:
use_document_model = self.use_document_model

if timeout is None:
timeout = self.timeout
timeout = self.timeout if timeout is None else timeout

if criteria:
criteria = {k: v for k, v in criteria.items() if v is not None}
else:
criteria = {}
criteria = {k: v for k, v in (criteria or {}).items() if v is not None}

# Query s3 if no query is passed and all documents are asked for
# TODO also skip fields set to same as their default
Expand Down Expand Up @@ -1080,6 +1088,7 @@ def _generate_returned_model(
# TODO fields_not_requested is not the same as unset_fields
# i.e. field could be requested but not available in the raw doc
fields_not_requested=(list[str], unset_fields),
__base__=_DictLikeAccess,
__doc__=".".join(
[
getattr(self.document_model, k, "")
Expand Down
96 changes: 27 additions & 69 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import os
import warnings
from collections import defaultdict
from functools import cache, lru_cache
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -424,20 +425,15 @@ def get_task_ids_associated_with_material_id(
if not tasks:
return []

calculations = (
tasks[0].calc_types # type: ignore
if self.use_document_model
else tasks[0]["calc_types"] # type: ignore
)
calculations = tasks[0]["calc_types"]

if calc_types:
return [
task
for task, calc_type in calculations.items()
if calc_type in calc_types
]
else:
return list(calculations.keys())
return list(calculations.keys())

def get_structure_by_material_id(
self, material_id: str, final: bool = True, conventional_unit_cell: bool = False
Expand Down Expand Up @@ -539,11 +535,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
List of BibTeX references ([str])
"""
docs = self.materials.provenance.search(material_ids=material_id)

if not docs:
return []

return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore
return docs[0]["references"] if docs else []

def get_material_ids(
self,
Expand All @@ -558,17 +550,16 @@ def get_material_ids(
Returns:
List of all materials ids ([MPID])
"""
inp_k = "formula"
if isinstance(chemsys_formula, list) or (
isinstance(chemsys_formula, str) and "-" in chemsys_formula
):
input_params = {"chemsys": chemsys_formula}
else:
input_params = {"formula": chemsys_formula}
inp_k = "chemsys"

return sorted(
doc.material_id if self.use_document_model else doc["material_id"] # type: ignore
doc["material_id"]
for doc in self.materials.search(
**input_params, # type: ignore
**{inp_k: chemsys_formula},
all_fields=False,
fields=["material_id"],
)
Expand Down Expand Up @@ -601,10 +592,8 @@ def get_structures(
all_fields=False,
fields=["structure"],
)
if not self.use_document_model:
return [doc["structure"] for doc in docs] # type: ignore

return [doc.structure for doc in docs] # type: ignore
return [doc["structure"] for doc in docs]
else:
structures = []

Expand All @@ -613,12 +602,7 @@ def get_structures(
all_fields=False,
fields=["initial_structures"],
):
initial_structures = (
doc.initial_structures # type: ignore
if self.use_document_model
else doc["initial_structures"] # type: ignore
)
structures.extend(initial_structures)
structures.extend(doc["initial_structures"])

return structures

Expand Down Expand Up @@ -723,7 +707,7 @@ def get_entries(
if additional_criteria:
input_params = {**input_params, **additional_criteria}

entries = []
entries: set[ComputedStructureEntry] = set()

fields = (
["entries", "thermo_type"]
Expand All @@ -738,24 +722,17 @@ def get_entries(
)

for doc in docs:
entry_list = (
doc.entries.values() # type: ignore
if self.use_document_model
else doc["entries"].values() # type: ignore
)
entry_list = doc["entries"].values()
for entry in entry_list:
entry_dict: dict = entry.as_dict() if self.monty_decode else entry # type: ignore
entry_dict: dict = entry.as_dict() if hasattr(entry, "as_dict") else entry # type: ignore
if not compatible_only:
entry_dict["correction"] = 0.0
entry_dict["energy_adjustments"] = []

if property_data:
for property in property_data:
entry_dict["data"][property] = (
doc.model_dump()[property] # type: ignore
if self.use_document_model
else doc[property] # type: ignore
)
entry_dict["data"] = {
property: doc[property] for property in property_data
}

if conventional_unit_cell:
entry_struct = Structure.from_dict(entry_dict["structure"])
Expand All @@ -776,15 +753,10 @@ def get_entries(
if "n_atoms" in correction:
correction["n_atoms"] *= site_ratio

entry = (
ComputedStructureEntry.from_dict(entry_dict)
if self.monty_decode
else entry_dict
)
# Need to store object to permit de-duplication
entries.add(ComputedStructureEntry.from_dict(entry_dict))

entries.append(entry)

return entries
return [e if self.monty_decode else e.as_dict() for e in entries]

def get_pourbaix_entries(
self,
Expand Down Expand Up @@ -1315,9 +1287,7 @@ def get_wulff_shape(self, material_id: str):
if not doc:
return None

surfaces: list = (
doc[0].surfaces if self.use_document_model else doc[0]["surfaces"] # type: ignore
)
surfaces: list = doc[0]["surfaces"]

lattice = (
SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice
Expand Down Expand Up @@ -1387,17 +1357,8 @@ def get_charge_density_from_material_id(
if len(results) == 0:
return None

latest_doc = max( # type: ignore
results,
key=lambda x: (
x.last_updated # type: ignore
if self.use_document_model
else x["last_updated"]
), # type: ignore
)
task_id = (
latest_doc.task_id if self.use_document_model else latest_doc["task_id"]
)
latest_doc = max(results, key=lambda x: x["last_updated"])
task_id = latest_doc["task_id"]
return self.get_charge_density_from_task_id(task_id, inc_task_doc)

def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
Expand All @@ -1419,20 +1380,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
else []
)

meta = {}
meta = defaultdict(list)
for doc in self.materials.search( # type: ignore
task_ids=material_ids,
fields=["calc_types", "deprecated_tasks", "material_id"],
):
doc_dict: dict = doc.model_dump() if self.use_document_model else doc # type: ignore
for task_id, calc_type in doc_dict["calc_types"].items():
for task_id, calc_type in doc["calc_types"].items():
if calc_types and calc_type not in calc_types:
continue
mp_id = doc_dict["material_id"]
if meta.get(mp_id) is None:
meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}]
else:
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
mp_id = doc["material_id"]
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})

if not meta:
raise ValueError(f"No tasks found for material id {material_ids}.")

Expand Down
74 changes: 27 additions & 47 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,61 +276,47 @@ def get_bandstructure_from_material_id(
if not bs_doc:
raise MPRestError("No electronic structure data found.")

bs_data = (
bs_doc[0].bandstructure # type: ignore
if self.use_document_model
else bs_doc[0]["bandstructure"] # type: ignore
)

if bs_data is None:
if (bs_data := bs_doc[0]["bandstructure"]) is None:
raise MPRestError(
f"No {path_type.value} band structure data found for {material_id}"
)
else:
bs_data: dict = (
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
)

if bs_data.get(path_type.value, None):
bs_task_id = bs_data[path_type.value]["task_id"]
else:
bs_data: dict = (
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
)

if bs_data.get(path_type.value, None) is None:
raise MPRestError(
f"No {path_type.value} band structure data found for {material_id}"
)
else:
bs_doc = es_rester.search(material_ids=material_id, fields=["dos"])
bs_task_id = bs_data[path_type.value]["task_id"]

if not bs_doc:
else:
if not (
bs_doc := es_rester.search(material_ids=material_id, fields=["dos"])
):
raise MPRestError("No electronic structure data found.")

bs_data = (
bs_doc[0].dos # type: ignore
if self.use_document_model
else bs_doc[0]["dos"] # type: ignore
)

if bs_data is None:
if (bs_data := bs_doc[0]["dos"]) is None:
raise MPRestError(
f"No uniform band structure data found for {material_id}"
)
else:
bs_data: dict = (
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
)

if bs_data.get("total", None):
bs_task_id = bs_data["total"]["1"]["task_id"]
else:
bs_data: dict = (
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
)

if bs_data.get("total", None) is None:
raise MPRestError(
f"No uniform band structure data found for {material_id}"
)
bs_task_id = bs_data["total"]["1"]["task_id"]

bs_obj = self.get_bandstructure_from_task_id(bs_task_id)

if bs_obj:
return bs_obj
else:
raise MPRestError("No band structure object found.")
raise MPRestError("No band structure object found.")


class DosRester(BaseRester):
Expand Down Expand Up @@ -456,22 +442,16 @@ def get_dos_from_material_id(self, material_id: str):
mute_progress_bars=self.mute_progress_bars,
)

dos_doc = es_rester.search(material_ids=material_id, fields=["dos"])
if not dos_doc:
if not (dos_doc := es_rester.search(material_ids=material_id, fields=["dos"])):
return None

dos_data: dict = (
dos_doc[0].model_dump() if self.use_document_model else dos_doc[0] # type: ignore
)

if dos_data["dos"]:
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
else:
if not (dos_data := dos_doc[0].get("dos")):
raise MPRestError(f"No density of states data found for {material_id}")

dos_obj = self.get_dos_from_task_id(dos_task_id)

if dos_obj:
dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[
"total"
]["1"]["task_id"]
if dos_obj := self.get_dos_from_task_id(dos_task_id):
return dos_obj
else:
raise MPRestError("No density of states object found.")

raise MPRestError("No density of states object found.")
Loading