class _MissingModelFieldError(AttributeError, KeyError):
    """Signal that a requested name is not a declared model field.

    Derives from ``AttributeError`` so handlers written against the previous
    behaviour keep working, and from ``KeyError`` so that subscript access
    follows the mapping protocol, matching what a plain ``dict`` document
    raises for a missing key.
    """

    def __str__(self) -> str:
        """Return the plain message instead of ``KeyError``'s quoted repr."""
        return str(self.args[0]) if self.args else ""


class _DictLikeAccess(BaseModel):
    """Define a pydantic mix-in which permits dict-like access to model fields.

    Only names present in the class-level ``model_fields`` mapping can be read
    or written through subscripting; other attributes and methods of the model
    are deliberately not reachable this way.
    """

    def __getitem__(self, item: str) -> Any:
        """Return the value of model field `item`.

        Args:
            item: Name of the model field to look up.

        Returns:
            The attribute stored on the model under `item`.

        Raises:
            _MissingModelFieldError: If `item` is not a declared model field.
                The error can be caught as `KeyError` or `AttributeError`.
        """
        if item in self.__class__.model_fields:
            return getattr(self, item)
        raise _MissingModelFieldError(
            f"{self.__class__.__name__} has no model field `{item}`."
        )

    def __setitem__(self, item: str, value: Any) -> None:
        """Assign `value` to model field `item`.

        Lets code that post-processes a document in place with
        ``doc[field] = ...`` work the same for model documents as it does for
        plain dicts.

        Args:
            item: Name of the model field to assign.
            value: New value, stored with ``setattr``.

        Raises:
            _MissingModelFieldError: If `item` is not a declared model field.
        """
        if item not in self.__class__.model_fields:
            raise _MissingModelFieldError(
                f"{self.__class__.__name__} has no model field `{item}`."
            )
        setattr(self, item, value)

    def get(self, item: str, default: Any = None) -> Any:
        """Return a model field `item`, or `default` if it doesn't exist.

        Args:
            item: Name of the model field to look up.
            default: Value returned when the lookup raises `AttributeError`.

        Returns:
            The field value, or `default`.
        """
        try:
            return self[item]
        except AttributeError:
            # _MissingModelFieldError is an AttributeError subclass, so both
            # an unknown field name and a failing getattr fall back to
            # `default`, exactly as before.
            return default
field could be requested but not available in the raw doc fields_not_requested=(list[str], unset_fields), + __base__=_DictLikeAccess, __doc__=".".join( [ getattr(self.document_model, k, "") diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 3fdc07f9..e6f65c96 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -3,6 +3,7 @@ import itertools import os import warnings +from collections import defaultdict from functools import cache, lru_cache from typing import TYPE_CHECKING @@ -424,11 +425,7 @@ def get_task_ids_associated_with_material_id( if not tasks: return [] - calculations = ( - tasks[0].calc_types # type: ignore - if self.use_document_model - else tasks[0]["calc_types"] # type: ignore - ) + calculations = tasks[0]["calc_types"] if calc_types: return [ @@ -436,8 +433,7 @@ def get_task_ids_associated_with_material_id( for task, calc_type in calculations.items() if calc_type in calc_types ] - else: - return list(calculations.keys()) + return list(calculations.keys()) def get_structure_by_material_id( self, material_id: str, final: bool = True, conventional_unit_cell: bool = False @@ -539,11 +535,7 @@ def get_material_id_references(self, material_id: str) -> list[str]: List of BibTeX references ([str]) """ docs = self.materials.provenance.search(material_ids=material_id) - - if not docs: - return [] - - return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore + return docs[0]["references"] if docs else [] def get_material_ids( self, @@ -558,17 +550,16 @@ def get_material_ids( Returns: List of all materials ids ([MPID]) """ + inp_k = "formula" if isinstance(chemsys_formula, list) or ( isinstance(chemsys_formula, str) and "-" in chemsys_formula ): - input_params = {"chemsys": chemsys_formula} - else: - input_params = {"formula": chemsys_formula} + inp_k = "chemsys" return sorted( - doc.material_id if self.use_document_model else doc["material_id"] # type: ignore + doc["material_id"] for 
doc in self.materials.search( - **input_params, # type: ignore + **{inp_k: chemsys_formula}, all_fields=False, fields=["material_id"], ) @@ -601,10 +592,8 @@ def get_structures( all_fields=False, fields=["structure"], ) - if not self.use_document_model: - return [doc["structure"] for doc in docs] # type: ignore - return [doc.structure for doc in docs] # type: ignore + return [doc["structure"] for doc in docs] else: structures = [] @@ -613,12 +602,7 @@ def get_structures( all_fields=False, fields=["initial_structures"], ): - initial_structures = ( - doc.initial_structures # type: ignore - if self.use_document_model - else doc["initial_structures"] # type: ignore - ) - structures.extend(initial_structures) + structures.extend(doc["initial_structures"]) return structures @@ -723,7 +707,7 @@ def get_entries( if additional_criteria: input_params = {**input_params, **additional_criteria} - entries = [] + entries: set[ComputedStructureEntry] = set() fields = ( ["entries", "thermo_type"] @@ -738,24 +722,17 @@ def get_entries( ) for doc in docs: - entry_list = ( - doc.entries.values() # type: ignore - if self.use_document_model - else doc["entries"].values() # type: ignore - ) + entry_list = doc["entries"].values() for entry in entry_list: - entry_dict: dict = entry.as_dict() if self.monty_decode else entry # type: ignore + entry_dict: dict = entry.as_dict() if hasattr(entry, "as_dict") else entry # type: ignore if not compatible_only: entry_dict["correction"] = 0.0 entry_dict["energy_adjustments"] = [] if property_data: - for property in property_data: - entry_dict["data"][property] = ( - doc.model_dump()[property] # type: ignore - if self.use_document_model - else doc[property] # type: ignore - ) + entry_dict["data"] = { + property: doc[property] for property in property_data + } if conventional_unit_cell: entry_struct = Structure.from_dict(entry_dict["structure"]) @@ -776,15 +753,10 @@ def get_entries( if "n_atoms" in correction: correction["n_atoms"] *= site_ratio - 
entry = ( - ComputedStructureEntry.from_dict(entry_dict) - if self.monty_decode - else entry_dict - ) + # Need to store object to permit de-duplication + entries.add(ComputedStructureEntry.from_dict(entry_dict)) - entries.append(entry) - - return entries + return [e if self.monty_decode else e.as_dict() for e in entries] def get_pourbaix_entries( self, @@ -1315,9 +1287,7 @@ def get_wulff_shape(self, material_id: str): if not doc: return None - surfaces: list = ( - doc[0].surfaces if self.use_document_model else doc[0]["surfaces"] # type: ignore - ) + surfaces: list = doc[0]["surfaces"] lattice = ( SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice @@ -1387,17 +1357,8 @@ def get_charge_density_from_material_id( if len(results) == 0: return None - latest_doc = max( # type: ignore - results, - key=lambda x: ( - x.last_updated # type: ignore - if self.use_document_model - else x["last_updated"] - ), # type: ignore - ) - task_id = ( - latest_doc.task_id if self.use_document_model else latest_doc["task_id"] - ) + latest_doc = max(results, key=lambda x: x["last_updated"]) + task_id = latest_doc["task_id"] return self.get_charge_density_from_task_id(task_id, inc_task_doc) def get_download_info(self, material_ids, calc_types=None, file_patterns=None): @@ -1419,20 +1380,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None): else [] ) - meta = {} + meta = defaultdict(list) for doc in self.materials.search( # type: ignore task_ids=material_ids, fields=["calc_types", "deprecated_tasks", "material_id"], ): - doc_dict: dict = doc.model_dump() if self.use_document_model else doc # type: ignore - for task_id, calc_type in doc_dict["calc_types"].items(): + for task_id, calc_type in doc["calc_types"].items(): if calc_types and calc_type not in calc_types: continue - mp_id = doc_dict["material_id"] - if meta.get(mp_id) is None: - meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}] - else: - meta[mp_id].append({"task_id": 
task_id, "calc_type": calc_type}) + mp_id = doc["material_id"] + meta[mp_id].append({"task_id": task_id, "calc_type": calc_type}) + if not meta: raise ValueError(f"No tasks found for material id {material_ids}.") diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index ccaaaad8..94298206 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -276,61 +276,47 @@ def get_bandstructure_from_material_id( if not bs_doc: raise MPRestError("No electronic structure data found.") - bs_data = ( - bs_doc[0].bandstructure # type: ignore - if self.use_document_model - else bs_doc[0]["bandstructure"] # type: ignore - ) - - if bs_data is None: + if (bs_data := bs_doc[0]["bandstructure"]) is None: raise MPRestError( f"No {path_type.value} band structure data found for {material_id}" ) - else: - bs_data: dict = ( - bs_data.model_dump() if self.use_document_model else bs_data # type: ignore - ) - if bs_data.get(path_type.value, None): - bs_task_id = bs_data[path_type.value]["task_id"] - else: + bs_data: dict = ( + bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + ) + + if bs_data.get(path_type.value, None) is None: raise MPRestError( f"No {path_type.value} band structure data found for {material_id}" ) - else: - bs_doc = es_rester.search(material_ids=material_id, fields=["dos"]) + bs_task_id = bs_data[path_type.value]["task_id"] - if not bs_doc: + else: + if not ( + bs_doc := es_rester.search(material_ids=material_id, fields=["dos"]) + ): raise MPRestError("No electronic structure data found.") - bs_data = ( - bs_doc[0].dos # type: ignore - if self.use_document_model - else bs_doc[0]["dos"] # type: ignore - ) - - if bs_data is None: + if (bs_data := bs_doc[0]["dos"]) is None: raise MPRestError( f"No uniform band structure data found for {material_id}" ) - else: - bs_data: dict = ( - bs_data.model_dump() if 
self.use_document_model else bs_data # type: ignore - ) - if bs_data.get("total", None): - bs_task_id = bs_data["total"]["1"]["task_id"] - else: + bs_data: dict = ( + bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + ) + + if bs_data.get("total", None) is None: raise MPRestError( f"No uniform band structure data found for {material_id}" ) + bs_task_id = bs_data["total"]["1"]["task_id"] bs_obj = self.get_bandstructure_from_task_id(bs_task_id) if bs_obj: return bs_obj - else: - raise MPRestError("No band structure object found.") + raise MPRestError("No band structure object found.") class DosRester(BaseRester): @@ -456,22 +442,16 @@ def get_dos_from_material_id(self, material_id: str): mute_progress_bars=self.mute_progress_bars, ) - dos_doc = es_rester.search(material_ids=material_id, fields=["dos"]) - if not dos_doc: + if not (dos_doc := es_rester.search(material_ids=material_id, fields=["dos"])): return None - dos_data: dict = ( - dos_doc[0].model_dump() if self.use_document_model else dos_doc[0] # type: ignore - ) - - if dos_data["dos"]: - dos_task_id = dos_data["dos"]["total"]["1"]["task_id"] - else: + if not (dos_data := dos_doc[0].get("dos")): raise MPRestError(f"No density of states data found for {material_id}") - dos_obj = self.get_dos_from_task_id(dos_task_id) - - if dos_obj: + dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[ + "total" + ]["1"]["task_id"] + if dos_obj := self.get_dos_from_task_id(dos_task_id): return dos_obj - else: - raise MPRestError("No density of states object found.") + + raise MPRestError("No density of states object found.") diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index df24e93b..6b35fabd 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -126,10 +126,17 @@ def get_structure_by_material_id( response = self.search(material_ids=material_id, fields=[field]) 
- if response: - response = ( - response[0].model_dump() if self.use_document_model else response[0] # type: ignore - ) + if response and response[0]: + response = response[0] + # Ensure that return type is a Structure regardless of `monty_decode` or `model_dump` output + if isinstance(response[field], dict): + response[field] = Structure.from_dict(response[field]) + elif isinstance(response[field], list) and any( + isinstance(struct, dict) for struct in response[field] + ): + response[field] = [ + Structure.from_dict(struct) for struct in response[field] + ] return response[field] if response else response # type: ignore @@ -305,7 +312,4 @@ def find_structure( ) return results # type: ignore - if results: - return results[0]["material_id"] - else: - return [] + return results[0]["material_id"] if (results and results[0]) else [] diff --git a/mp_api/client/routes/molecules/molecules.py b/mp_api/client/routes/molecules/molecules.py index 922f136e..4493932e 100644 --- a/mp_api/client/routes/molecules/molecules.py +++ b/mp_api/client/routes/molecules/molecules.py @@ -34,13 +34,7 @@ def get_molecule_by_mpculeid( field = "molecule" if final else "initial_molecules" response = self.search(molecule_ids=[mpcule_id], fields=[field]) # type: ignore - - if response: - response = ( - response[0].model_dump() if self.use_document_model else response[0] # type: ignore - ) - - return response[field] if response else response # type: ignore + return response[0][field] if (response and response[0]) else response # type: ignore def find_molecule( self, @@ -96,10 +90,7 @@ def find_molecule( ) return results # type: ignore - if results: - return results[0]["molecule_id"] - else: - return [] + return results[0]["molecule_id"] if (results and results[0]) else [] def search( self, diff --git a/pyproject.toml b/pyproject.toml index f202666c..afefc1e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", 
"monty>=2024.12.10", - "emmet-core>=0.85.1rc0,<0.86", + "emmet-core>=0.85.1rc0", "smart_open", "boto3", "orjson >= 3.10,<4", diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index 780c5ceb..c993b2f5 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -24,7 +24,7 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib -emmet-core==0.85.1 +emmet-core==0.86.0rc1 # via mp-api (pyproject.toml) fonttools==4.60.1 # via matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index 67ad81fc..1b70c23f 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -24,7 +24,7 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib -emmet-core==0.85.1 +emmet-core==0.86.0rc1 # via mp-api (pyproject.toml) fonttools==4.60.1 # via matplotlib diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 0cc9d271..dbb5e395 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -148,7 +148,13 @@ def test_get_entries(self, mpr): assert e.data.get("energy_above_hull", None) is not None # Conventional structure - entry = mpr.get_entry_by_material_id("mp-22526", conventional_unit_cell=True)[1] + entry = next( + e + for e in mpr.get_entry_by_material_id( + "mp-22526", conventional_unit_cell=True + ) + if e.entry_id == "mp-22526-r2SCAN" + ) s = entry.structure assert pytest.approx(s.lattice.a) == s.lattice.b @@ -158,9 +164,14 @@ def test_get_entries(self, mpr): assert pytest.approx(s.lattice.gamma) == 120 # Ensure energy per atom is same - prim = mpr.get_entry_by_material_id("mp-22526", conventional_unit_cell=False)[1] - - s = prim.structure + entry = next( + e + for e in mpr.get_entry_by_material_id( + "mp-22526", conventional_unit_cell=False + ) + if e.entry_id == 
"mp-22526-r2SCAN" + ) + s = entry.structure assert pytest.approx(s.lattice.a) == s.lattice.b assert pytest.approx(s.lattice.a, abs=1e-3) == s.lattice.c assert pytest.approx(s.lattice.alpha, abs=1e-3) == s.lattice.beta