diff --git a/dev/generate_mcp_tools.py b/dev/generate_mcp_tools.py index 4bd57700..57ca9511 100644 --- a/dev/generate_mcp_tools.py +++ b/dev/generate_mcp_tools.py @@ -53,13 +53,13 @@ def regenerate_tools( from datetime import datetime from typing import Literal +from emmet.core.band_theory import BSPathType from emmet.core.chemenv import ( COORDINATION_GEOMETRIES, COORDINATION_GEOMETRIES_IUCR, COORDINATION_GEOMETRIES_IUPAC, COORDINATION_GEOMETRIES_NAMES, ) -from emmet.core.band_theory import BSPathType from emmet.core.electronic_structure import DOSProjectionType from emmet.core.grain_boundary import GBTypeEnum from emmet.core.mpid import MPID diff --git a/mp_api/_test_utils.py b/mp_api/_test_utils.py index 5d4044c9..c11b0ea0 100644 --- a/mp_api/_test_utils.py +++ b/mp_api/_test_utils.py @@ -4,6 +4,8 @@ from __future__ import annotations +from enum import Enum + try: import pytest except ImportError as exc: @@ -86,19 +88,64 @@ def client_search_testing( assert doc[alt_name_dict.get(param, param)] is not None -def client_pagination(search_method: Callable, id_name: str): - page_1 = search_method(_page=1, chunk_size=NUM_DOCS, fields=[id_name]) - page_2 = search_method(_page=2, chunk_size=NUM_DOCS, fields=[id_name]) +def client_pagination( + search_method: Callable, id_name: str, additional_fields: list[str] | None = None +) -> None: + """Test pagination on an endpoint. + + Args: + search_method (Callable) : Client search method to use + id_name (str) : the name of a field which uniquely indexes a series of documents + additional_fields (list of str) : Optional other fields to retrieve. + + Raises: + AssertionError if pagination does not result in unique sets of documents + """ + fields = [id_name, *(additional_fields or [])] + page_1 = search_method(_page=1, chunk_size=NUM_DOCS, fields=fields) + page_2 = search_method(_page=2, chunk_size=NUM_DOCS, fields=fields) assert all(len(results) == NUM_DOCS for results in (page_1, page_2)) assert {str(getattr(doc, id_name)) for doc in page_1}.intersection( {str(getattr(doc, id_name)) for doc in page_2} ) == set() -def client_sort(search_method: Callable, sort_fields: str | Sequence[str]): +def client_sort( + search_method: Callable, + sort_fields: str | Sequence[str], + aux_query: dict[str, Any] | None = None, + default_fields: tuple[str, ...] = ("deprecated", "material_id"), +): + """Test sorting on an endpoint. + + Args: + search_method (Callable) : Client search method to use + sort_fields (str or Sequence of str) : fields to sort on + aux_query (dict) : auxiliary query needed to filter documents + default_fields (list): default fields to return + + Raises: + AssertionError if sorting in ascending or descending order does not work. + """ + + def _normalize(doc, field: str): + v = getattr(doc, field) + # serialize enums + return v.value if isinstance(v, Enum) else v + + user_query = { + k: v + for k, v in (aux_query or {}).items() + if k not in ("_page", "_sort_fields", "chunk_size", "fields") + } for sort_field in [sort_fields] if isinstance(sort_fields, str) else sort_fields: + asc = search_method( - _page=1, _sort_fields=sort_field, chunk_size=NUM_DOCS, fields=[sort_field] + _page=1, + _sort_fields=sort_field, + chunk_size=NUM_DOCS, + fields=[sort_field, *default_fields], + **user_query, ) desc = search_method( _page=1, @@ -108,12 +155,12 @@ def client_sort(search_method: Callable, sort_fields: str | Sequence[str]): ) idxs = list(range(NUM_DOCS)) - assert sorted(idxs, key=lambda idx: getattr(asc[idx], sort_field)) == idxs + assert sorted(idxs, key=lambda idx: _normalize(asc[idx], sort_field)) == idxs assert ( sorted( idxs, - key=lambda idx: getattr(desc[idx], sort_field), + key=lambda idx: _normalize(desc[idx], sort_field), reverse=True, ) == idxs diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 5816c473..f5a9c5ab 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -42,6 +42,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry +from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -52,7 +53,6 @@ from mp_api.client.core.utils import ( MPDataset, load_json, - validate_api_key, validate_endpoint, validate_ids, ) @@ -68,6 +68,17 @@ except PackageNotFoundError: # pragma: no cover __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") +STATIC_COLLECTIONS = [ + "eos", + "grain_boundaries", + "jcesr", + "molecules", + "phonon", + "snls", + "surface-properties", + "synth-descriptions", + "xas", +] hdlr = logging.StreamHandler() fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s") @@ -86,33 +97,52 @@ def _batched(iterable: Iterable, n: int) -> Iterator: yield batch -class BaseRester: - """Base client class with core stubs.""" +class QueryBuilderWithCache(QueryBuilder): - suffix: str = "" - document_model: type[BaseModel] = _DictLikeAccess - primary_key: str = "material_id" - delta_backed: bool = False + def __init__(self) -> None: + """Extend deltalake.QueryBuilder with stored DeltaTables. + + The deltalake.QueryBuilder class does not permit introspection + of registered DeltaTables through the python API. + + Re-registering a DeltaTable + (1) wastes time by reading its metadata + (2) raises an exception because a table is already registered + + This class simply allows for caching the DeltaTable instances + and table names on the QueryBuilder class. + """ + # Dict of table names (labels) to DeltaTable instances + self._delta_tables: dict[str, DeltaTable] = {} + super().__init__() + + def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder: + """Register and cache a DeltaTable.""" + self._delta_tables[table_name] = delta_table + return super().register(table_name, delta_table) + + +class _Rester: + """Define base attributes of a REST client.""" def __init__( self, api_key: str | None = None, endpoint: str | None = None, include_user_agent: bool = True, - session: requests.Session | None = None, - s3_client: Any | None = None, - debug: bool = False, use_document_model: bool = True, - timeout: int = 20, + session: requests.Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + db_version: str | None = None, local_dataset_cache: ( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, + query_builder: QueryBuilderWithCache | None = None, **kwargs, - ): - """Initialize the REST API helper class. + ) -> None: + """Initialize a RESTer. Arguments: api_key: A String API key for accessing the MaterialsProject @@ -131,49 +161,56 @@ def __init__( making the API request. This helps MP support pymatgen users, and is similar to what most web browsers send with each page request. Set to False to disable the user agent. - session: requests Session object with which to connect to the API, for - advanced usage only. - s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. - debug: if True, print the URL for every request use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. - timeout: Time in seconds to wait until a request timeout error is thrown + session: requests Session object with which to connect to the API, for + advanced usage only. headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. + db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database + than what is currently deployed. The Materials Project cannot guarantee that all + features will still work. local_dataset_cache: Target directory for downloading full datasets. Defaults to 'mp_datasets' in the user's home directory force_renew: Option to overwrite existing local dataset + query_builder : Instance of QueryBuilderWithCache to use in querying delta tables + NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored. **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = validate_api_key(api_key) - self.base_endpoint = validate_endpoint(endpoint) - self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) + self.api_key = get_user_api_key(api_key=api_key) + self.endpoint = validate_endpoint(endpoint) - self.debug = debug self.include_user_agent = include_user_agent self.use_document_model = use_document_model - self.timeout = timeout - self.headers = headers or {} - self.mute_progress_bars = mute_progress_bars - ( - self.db_version, - self.access_controlled_batch_ids, - ) = BaseRester._get_heartbeat_info(self.base_endpoint) + self.headers = headers or get_consumer() + self._session = session or _Rester._create_session( + api_key=self.api_key, + include_user_agent=self.include_user_agent, + headers=self.headers, + ) - self.local_dataset_cache: Path = Path(local_dataset_cache) - self.force_renew = force_renew + if is_dev_env(): + self._session.headers["x-api-key"] = self.api_key or "" - self._session = session - self._s3_client = s3_client + self.use_document_model = use_document_model + self.mute_progress_bars = mute_progress_bars + self.db_version: str = db_version or "" + self.local_dataset_cache = Path(local_dataset_cache) + self.force_renew = force_renew + self._query_builder = ( + query_builder if isinstance(query_builder, QueryBuilderWithCache) else None + ) if "monty_decode" in kwargs: + # Pop to not repeatedly trigger warning to the user + kwargs.pop("monty_decode", None) warnings.warn( "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." "The client by default returns results consistent with `monty_decode=True`.", - category=MPRestWarning, stacklevel=2, + category=MPRestWarning, ) @property @@ -185,13 +222,10 @@ def session(self) -> requests.Session: return self._session @property - def s3_client(self): - if not self._s3_client: - self._s3_client = boto3.client( - "s3", - config=Config(signature_version=UNSIGNED), # type: ignore - ) - return self._s3_client + def query_builder(self): + if not self._query_builder: + self._query_builder = QueryBuilderWithCache() + return self._query_builder @staticmethod def _create_session(api_key, include_user_agent, headers): @@ -270,6 +304,112 @@ def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]: response = get_resp.json() return response["db_version"], response["access_controlled_batch_ids"] + +class BaseRester(_Rester): + """Base client class with core stubs.""" + + suffix: str = "" + document_model: type[BaseModel] = _DictLikeAccess + primary_key: str = "material_id" + delta_backed: bool = True + + def __init__( + self, + api_key: str | None = None, + endpoint: str | None = None, + include_user_agent: bool = True, + use_document_model: bool = True, + session: requests.Session | None = None, + headers: dict | None = None, + mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + db_version: str | None = None, + local_dataset_cache: ( + str | os.PathLike + ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, + query_builder: QueryBuilderWithCache | None = None, + s3_client: Any | None = None, + timeout: int = 20, + **kwargs, + ): + """Initialize the REST API helper class. + + s3_client: boto3 S3 client object with which to connect to the object stores. + timeout: Time in seconds to wait until a request timeout error is thrown + + Arguments: + api_key: A String API key for accessing the MaterialsProject + REST interface. Please obtain your API key at + https://www.materialsproject.org/dashboard. If this is None, + the code will check if there is a "PMG_MAPI_KEY" setting. + If so, it will use that environment variable. This makes + easier for heavy users to simply add this environment variable to + their setups and MPRester can then be called without any arguments. + endpoint: Url of endpoint to access the MaterialsProject REST + interface. Defaults to the standard Materials Project REST + address at "https://api.materialsproject.org", but + can be changed to other urls implementing a similar interface. + include_user_agent: If True, will include a user agent with the + HTTP request including information on pymatgen and system version + making the API request. This helps MP support pymatgen users, and + is similar to what most web browsers send with each page request. + Set to False to disable the user agent. + session: requests Session object with which to connect to the API, for + advanced usage only. + use_document_model: If False, skip the creating the document model and return data + as a dictionary. This can be simpler to work with but bypasses data validation + and will not give auto-complete for available fields. + headers: Custom headers for localhost connections. + mute_progress_bars: Whether to disable progress bars. + db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database + than what is currently deployed. The Materials Project cannot guarantee that all + features will still work. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to 'mp_datasets' in the user's home directory + force_renew: Option to overwrite existing local dataset + query_builder : Instance of QueryBuilderWithCache to use in querying delta tables + NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored. + s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. + timeout: Time in seconds to wait until a request timeout error is thrown + **kwargs: access to legacy kwargs that may be in the process of being deprecated + """ + super().__init__( + api_key=api_key, + endpoint=endpoint, + include_user_agent=include_user_agent, + use_document_model=use_document_model, + session=session, + headers=headers, + mute_progress_bars=mute_progress_bars, + db_version=db_version, + local_dataset_cache=local_dataset_cache, + force_renew=force_renew, + query_builder=query_builder, + **kwargs, + ) + + self.base_endpoint = validate_endpoint(endpoint) + self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) + + ( + hb_db_version, + self.access_controlled_batch_ids, + ) = self._get_heartbeat_info(self.base_endpoint) + if not self.db_version: + self.db_version = hb_db_version + + self.timeout = timeout + self._s3_client = s3_client + + @property + def s3_client(self): + if not self._s3_client: + self._s3_client = boto3.client( + "s3", + config=Config(signature_version=UNSIGNED), # type: ignore + ) + return self._s3_client + def _post_resource( self, body: dict | None = None, @@ -440,18 +580,120 @@ def _query_open_data( return decoded_data, len(decoded_data) # type: ignore + def _get_delta_table( + self, + bucket: str, + prefix: str, + connector: str = "s3a", + label: str | None = None, + ) -> tuple[str, DeltaTable]: + """Either create a new DeltaTable, or retrieve a cached one. + + If creating a new DeltaTable, will also register in self.query_builder + + Args: + bucket (str) : name of the bucket in S3 + prefix (str) : name of the prefix in S3 + connector (str) : s3, s3n, s3a (default), or other + valid Hadoop connector string. + label (str or None) : optional label for the table in the + cached query builder + If `None`, will be gleaned from the URI + + Returns: + str : the table name in the stored query builder + DeltaTable : If one exists at the specified bucket / prefix, + will retrieve the cached instance. + """ + delta_timeout = f"{self.timeout}s" + full_key = f"{bucket}/{prefix}" + qb_label = label or full_key.replace("/", "_").replace("-", "_") + + uri = f"{connector}://{full_key}" + if not uri.endswith("/"): + uri += "/" + + try: + stored_label, delta_table = next( + (_label, _table) + for _label, _table in self.query_builder._delta_tables.items() + if _table.table_uri == uri + ) + except StopIteration: + stored_label = None + + if stored_label is None: + delta_table = DeltaTable( + uri, + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + "timeout": delta_timeout, + "connect_timeout": delta_timeout, + "retry_delay": "3", + "max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}", + }, + ) + self.query_builder.register(qb_label, delta_table) + + elif stored_label != qb_label: + warnings.warn( + f"DeltaTable with URI {uri} already found with different label: " + f"Stored label = {stored_label}; submitted label {qb_label}. " + "Using stored DeltaTable.", + category=MPRestWarning, + stacklevel=2, + ) + return stored_label, delta_table + + return qb_label, delta_table + + def _query_delta_single(self, query: str) -> pa.Table: + """Execute a SQL query against a registered Delta table. + + Wraps the query execution in a try/except to provide a more + actionable error message when the underlying Delta query engine + fails (e.g., due to network timeouts, missing tables, or + malformed queries). + + Args: + query (str): A SQL query string compatible with the + QueryBuilder engine. + + Returns: + pa.Table: The query result as a PyArrow Table. + + Raises: + MPRestError: If query execution fails for any reason, + including network timeouts, connectivity issues, or + invalid queries. Inspect the chained exception for + the underlying cause. + """ + try: + return pa.table(self.query_builder.execute(query).read_all()) + except Exception as e: + raise MPRestError( + f"Failed to retrieve object due to: {e}. " + f"If this is a timeout error, try increasing the 'timeout' " + f"parameter on MPRester (current value: {self.timeout}s)." + ) from e + def _query_delta_backed( self, bucket: str, prefix: str, + access_controlled: bool = True, timeout: int | None = None, + label: str | None = None, ) -> dict[str, Any]: """Retrieve data from S3 backed by a DeltaTable. Args: bucket (str) : S3 OpenData bucket prefix (str) : S3 object prefix + access_controlled (bool): whether or not table has access controlled data timeout (int or None) : timeout on getting access-controlled groups + label (str or None) : label of the table in QueryBuilder Returns: dict of str to Any @@ -508,13 +750,7 @@ def _query_delta_backed( ) } - tbl = DeltaTable( - f"s3a://{bucket}/{prefix}", - storage_options={ - "AWS_SKIP_SIGNATURE": "true", - "AWS_REGION": "us-east-1", - }, - ) + tbl_lbl, tbl = self._get_delta_table(bucket, prefix, label=label) controlled_batch_str = ",".join( [f"'{tag}'" for tag in self.access_controlled_batch_ids] @@ -522,19 +758,23 @@ def _query_delta_backed( predicate = ( f"WHERE batch_id NOT IN ({controlled_batch_str})" - if not has_gnome_access + if not has_gnome_access and controlled_batch_str and access_controlled else "" ) - - builder = QueryBuilder().register("tbl", tbl) + # TODO: do we need something like this? + # predicate += f"{' AND ' if predicate else 'WHERE '}version='{self.db_version}'" # Setup progress bar num_docs_needed: int = tbl.count() if not has_gnome_access: - num_docs_needed = self.count( - {"batch_id_neq_any": self.access_controlled_batch_ids} - ) + try: + num_docs_needed = self.count( + {"batch_id_neq_any": self.access_controlled_batch_ids} + ) + except MPRestError: + # batch_id isn't a valid field + num_docs_needed = self.count() pbar = ( tqdm( @@ -549,7 +789,7 @@ def _query_delta_backed( else None ) - iterator = builder.execute(f"SELECT * FROM tbl {predicate}") + iterator = self.query_builder.execute(f"SELECT * FROM {tbl_lbl} {predicate}") file_options = ds.ParquetFileFormat().make_write_options(compression="zstd") @@ -695,14 +935,21 @@ def _query_resource( if "tasks" in suffix: bucket_suffix, prefix = ("parsed", "core/tasks/") + elif suffix in STATIC_COLLECTIONS: + bucket_suffix = "build" + prefix = f"static-collections/{suffix}" else: + # TODO: remove once all collections are migrated to delta-backed format bucket_suffix = "build" - prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}" + prefix = f"collections/{suffix}" bucket = f"materialsproject-{bucket_suffix}" if self.delta_backed: - return self._query_delta_backed(bucket, prefix, timeout=timeout) + access_controlled = suffix not in STATIC_COLLECTIONS + return self._query_delta_backed( + bucket, prefix, access_controlled, timeout=timeout + ) # Paginate over all entries in the bucket. # TODO: change when a subset of entries needed from DB @@ -1448,8 +1695,10 @@ def __getattr__(self, v: str): use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + db_version=self.db_version, local_dataset_cache=self.local_dataset_cache, force_renew=self.force_renew, + query_builder=self._query_builder, ) return self.sub_resters[v] raise AttributeError(f"{self.__class__} has no attribute {v}") diff --git a/mp_api/client/core/schemas.py b/mp_api/client/core/schemas.py index a8a00181..d721ef69 100644 --- a/mp_api/client/core/schemas.py +++ b/mp_api/client/core/schemas.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import cached_property from importlib import import_module from itertools import chain from typing import TYPE_CHECKING, ForwardRef, get_args @@ -166,6 +167,12 @@ def new_dict(self, *args, **kwargs): data_model.__getattr__ = new_getattr data_model.dict = new_dict + for attr in dir(document_model): + if isinstance( + prop_method := getattr(document_model, attr), property | cached_property + ): + setattr(data_model, attr, prop_method) + return data_model, set_fields, fields_not_requested diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 5c15f8b4..116b7d1e 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -13,7 +13,7 @@ import pyarrow.dataset as ds from deltalake import DeltaTable from emmet.core import __version__ as _EMMET_CORE_VER -from emmet.core.mpid_ext import validate_identifier +from emmet.core.mpid import validate_identifier from monty.json import MontyDecoder from packaging.version import parse as parse_version @@ -27,6 +27,7 @@ from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS if TYPE_CHECKING: + from collections.abc import Iterable from typing import Any, Literal from pydantic._internal._model_construction import ModelMetaclass @@ -116,10 +117,8 @@ def validate_ids(id_list: list[str]) -> list[str]: " data for all IDs and filter locally." ) - # TODO: after the transition to AlphaID in the document models, - # The following line should be changed to - # return [validate_identifier(idx,serialize=True) for idx in id_list] - return [str(validate_identifier(idx)) for idx in id_list] + validated = [validate_identifier(idx, serialize=False) for idx in id_list] + return [getattr(idx, "string", str(idx)) for idx in validated] def validate_endpoint(endpoint: str | None, suffix: str | None = None) -> str: @@ -243,6 +242,14 @@ def __getattr__(self, v: str) -> Any: if hasattr(self._imported, v): return getattr(self._imported, v) + raise AttributeError( + f"{self._module_name}{'.' + self._class_name if self._class_name else ''} " + f"has no attribute {v}" + ) + + def __dir__(self) -> Iterable[str]: + return self._obj.__dir__() if hasattr(self._obj, "__dir__") else [] + class MPDataset: """Convenience wrapper for pyarrow datasets stored on disk.""" diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 33166c45..6b6ebc13 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -10,8 +10,12 @@ from emmet.core.band_theory import BSPathType from emmet.core.mpid import MPID, AlphaID from emmet.core.types.enums import ThermoType +from emmet.core.types.pymatgen_types.computed_entries_adapter import ( + ComputedStructureEntryType, +) from emmet.core.vasp.calc_types import CalcType from packaging import version +from pydantic import BaseModel, TypeAdapter from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry from pymatgen.core import Composition, Element, Structure @@ -21,21 +25,15 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env -from mp_api.client.core import BaseRester from mp_api.client.core._oxygen_evolution import OxygenEvolution +from mp_api.client.core.client import _Rester from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, _emit_status_warning, ) from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS -from mp_api.client.core.utils import ( - LazyImport, - load_json, - validate_endpoint, - validate_ids, -) +from mp_api.client.core.utils import LazyImport, load_json, validate_ids from mp_api.client.routes import GENERIC_RESTERS from mp_api.client.routes.materials import MATERIALS_RESTERS from mp_api.client.routes.molecules import MOLECULES_RESTERS @@ -49,6 +47,7 @@ from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.pourbaix_diagram import PourbaixEntry + from pymatgen.electronic_structure.dos import Dos from pymatgen.entries.compatibility import Compatibility from pymatgen.entries.computed_entries import ( ComputedEntry, @@ -56,6 +55,7 @@ ) from pymatgen.util.typing import SpeciesLike + from mp_api.client.core.client import QueryBuilderWithCache from mp_api.client.core.schemas import _DictLikeAccess DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U_R2SCAN"]} @@ -85,23 +85,25 @@ ] -class MPRester: +class MPRester(_Rester): """Access the new Materials Project API.""" def __init__( self, api_key: str | None = None, endpoint: str | None = None, - notify_db_version: bool = False, include_user_agent: bool = True, use_document_model: bool = True, session: Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + db_version: str | None = None, local_dataset_cache: ( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, + query_builder: QueryBuilderWithCache | None = None, + notify_db_version: bool = False, **kwargs, ): """Initialize the MPRester. @@ -118,13 +120,6 @@ def __init__( interface. Defaults to the standard Materials Project REST address at "https://api.materialsproject.org", but can be changed to other URLs implementing a similar interface. - notify_db_version (bool): If True, the current MP database version will - be retrieved and logged locally in the ~/.mprester.log.yaml. If the database - version changes, you will be notified. The current database version is - also printed on instantiation. These local logs are not sent to - materialsproject.org and are not associated with your API key, so be - aware that a notification may not be presented if you run MPRester - from multiple computing environments. include_user_agent (bool): If True, will include a user agent with the HTTP request including information on pymatgen and system version making the API request. This helps MP support pymatgen users, and @@ -136,28 +131,38 @@ def __init__( session: Session object to use. By default (None), the client will create one. headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. + db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database + than what is currently deployed. The Materials Project cannot guarantee that all + features will still work. local_dataset_cache: Target directory for downloading full datasets. Defaults to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset - **kwargs: access to ContribsClient kwargs or (possibly-deprecated) legacy kwargs + query_builder : Instance of QueryBuilderWithCache to use in querying delta tables + NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored. + notify_db_version (bool): If True, the current MP database version will + be retrieved and logged locally in the ~/.mprester.log.yaml. If the database + version changes, you will be notified. The current database version is + also printed on instantiation. These local logs are not sent to + materialsproject.org and are not associated with your API key, so be + aware that a notification may not be presented if you run MPRester + from multiple computing environments. + **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = get_user_api_key(api_key=api_key) - - self.endpoint = validate_endpoint(endpoint) - - self.headers = headers or get_consumer() - self.session = session or BaseRester._create_session( - api_key=self.api_key, + super().__init__( + api_key=api_key, + endpoint=endpoint, include_user_agent=include_user_agent, - headers=self.headers, + use_document_model=use_document_model, + session=session, + headers=headers, + mute_progress_bars=mute_progress_bars, + db_version=db_version, + local_dataset_cache=local_dataset_cache, + force_renew=force_renew, + query_builder=query_builder, + **kwargs, ) - if is_dev_env(): - self.session.headers["x-api-key"] = self.api_key or "" - self._include_user_agent = include_user_agent - self.use_document_model = use_document_model - self.mute_progress_bars = mute_progress_bars - self.local_dataset_cache = local_dataset_cache - self.force_renew = force_renew + self._contribs = None self._contribs_kwargs = { k: kwargs[k] @@ -198,14 +203,6 @@ def __init__( "chemenv", ] - if "monty_decode" in kwargs: - warnings.warn( - "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." - "The client by default returns results consistent with `monty_decode=True`.", - stacklevel=2, - category=MPRestWarning, - ) - # Check if emmet version of server is compatible if (emmet_version := MPRester.get_emmet_version(self.endpoint)) and ( version.parse(emmet_version.base_version) @@ -218,6 +215,17 @@ def __init__( stacklevel=2, ) + if self.db_version: + warnings.warn( + "Specifying an explicit database version is an experimental " + "feature. The Materials Project cannot guarantee " + "functionality at this time, use at your own risk!", + stacklevel=2, + category=MPRestWarning, + ) + else: + self.db_version = self._get_heartbeat_info(self.endpoint)[0] + if notify_db_version: self._db_version_check() @@ -236,13 +244,15 @@ def __init__( lazy_rester( api_key=self.api_key, endpoint=self.endpoint, - include_user_agent=self._include_user_agent, + include_user_agent=self.include_user_agent, session=self.session, use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + db_version=self.db_version, local_dataset_cache=self.local_dataset_cache, force_renew=self.force_renew, + query_builder=self._query_builder, ), ) @@ -285,14 +295,6 @@ def contribs(self): ) return self._contribs - def __enter__(self): - """Support for "with" context.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Support for "with" context.""" - self.session.close() - def __getattr__(self, attr): if attr in self._deprecated_attributes: warnings.warn( @@ -315,8 +317,7 @@ def __dir__(self): ) def __repr__(self) -> str: - db_version = self.get_database_version() - return f"MPRester({'v' + db_version if db_version else 'unknown version'})" + return f"MPRester({'v' + self.db_version if self.db_version else 'unknown version'})" def get_task_ids_associated_with_material_id( self, material_id: str, calc_types: list[CalcType] | None = None @@ -379,7 +380,9 @@ def get_structure_by_material_id( return structure_data def get_database_version(self) -> str | None: - """The Materials Project database is periodically updated and has a + """DEPRECATED: use `self.db_version` instead. + + The Materials Project database is periodically updated and has a database version associated with it. When the database is updated, consolidated data (information about "a material") may and does change, while calculation data about a specific calculation task @@ -391,10 +394,13 @@ def get_database_version(self) -> str | None: Returns: database version as a string if accessible, None otherwise """ - if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403: - _emit_status_warning() - return None - return get_resp.json()["db_version"] + warnings.warn( + "`get_database_version` has been deprecated in favor of " + "MPRester().db_version.", + stacklevel=2, + category=MPRestWarning, + ) + return self.db_version @staticmethod @cache @@ -639,9 +645,11 @@ def get_entries( ) for doc in docs: - entry_list = doc["entries"].values() - for entry in entry_list: - entry_dict: dict = entry.as_dict() if hasattr(entry, "as_dict") else entry # type: ignore + entry_list = (doc.model_dump() if isinstance(doc, BaseModel) else doc)[ + "entries" + ].values() + + for entry_dict in entry_list: if not compatible_only: entry_dict["correction"] = 0.0 entry_dict["energy_adjustments"] = [] @@ -671,7 +679,9 @@ def get_entries( correction["n_atoms"] *= site_ratio # Need to store object to permit de-duplication - entries.add(ComputedStructureEntry.from_dict(entry_dict)) + entries.add( + TypeAdapter(ComputedStructureEntryType).validate_python(entry_dict) + ) return list(entries) @@ -1157,18 +1167,34 @@ def get_bandstructure_by_material_id( material_id=material_id, path_type=path_type, line_mode=line_mode ) - def get_dos_by_material_id(self, material_id: str): - """Get the complete density of states pymatgen object associated with a Materials Project ID. + def get_dos_by_material_id(self, material_id: str) -> Dos: + """Get the density of states pymatgen object associated with a Materials Project ID. Arguments: material_id (str): Materials Project ID for a material Returns: - dos (CompleteDos): CompleteDos object + pymatgen Dos """ - return self.materials.electronic_structure_dos.get_dos_from_material_id( - material_id=material_id - ) # type: ignore + if ( + not ( + es_doc := self.materials.electronic_structure.search( + material_ids=material_id, fields=["dos"] + ) + ) + or not es_doc[0]["dos"] + ): + raise MPRestError(f"No DOS found for {material_id}") + + dos_data = es_doc[0]["dos"] + task_id = dos_data.task_id if self.use_document_model else dos_data["task_id"] + run_type = self.materials.tasks.search(task_ids=[task_id], fields=["run_type"])[ + 0 + ]["run_type"] + return self.materials.electronic_structure_dos.get_dos_from_task_id( + task_id, + run_type=run_type, + ) def get_phonon_dos_by_material_id(self, material_id: str): """Get phonon density of states data corresponding to a material_id. @@ -1575,24 +1601,26 @@ def get_stability( } chemsys_str = "-".join(sorted(str(ele) for ele in chemsys)) - thermo_type = ( - ThermoType(thermo_type) if isinstance(thermo_type, str) else thermo_type + thermo_type_valid_str: str = ( + ThermoType(thermo_type).value + if (isinstance(thermo_type, str) and thermo_type != "r2SCAN") + else str(thermo_type) ) corrector: Compatibility | None = None - if thermo_type == ThermoType.GGA_GGA_U: + if thermo_type_valid_str == ThermoType.GGA_GGA_U.value: from pymatgen.entries.compatibility import MaterialsProject2020Compatibility corrector = MaterialsProject2020Compatibility() - elif thermo_type == ThermoType.GGA_GGA_U_R2SCAN: + elif thermo_type_valid_str == ThermoType.GGA_GGA_U_R2SCAN.value: from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme corrector = MaterialsProjectDFTMixingScheme(run_type_2="r2SCAN") try: pd = self.materials.thermo.get_phase_diagram_from_chemsys( - chemsys_str, thermo_type=thermo_type + chemsys_str, thermo_type=thermo_type_valid_str ) except MPRestError: pd = None @@ -1600,7 +1628,7 @@ def get_stability( if not pd: warnings.warn( f"No phase diagram data available for chemical system {chemsys_str} " - f"and thermo type {thermo_type}.", + f"and thermo type {thermo_type_valid_str}.", category=MPRestWarning, stacklevel=2, ) @@ -1674,7 +1702,6 @@ def _db_version_check(self) -> None: """Check if the database version has drifted.""" import yaml # type: ignore[import-untyped] - db_version = self.get_database_version() old_db_version = None if MAPI_CLIENT_SETTINGS.LOG_FILE.exists(): old_db_version = ( @@ -1685,15 +1712,15 @@ def _db_version_check(self) -> None: if not isinstance(old_db_version, str): old_db_version = None - if old_db_version != db_version: + if old_db_version != self.db_version: MAPI_CLIENT_SETTINGS.LOG_FILE.write_text( - yaml.safe_dump({"MAPI_DB_VERSION": db_version}) + yaml.safe_dump({"MAPI_DB_VERSION": self.db_version}) ) if old_db_version: warnings.warn( "Materials Project database version has changed " - f"from v{old_db_version} to v{db_version}.", + f"from v{old_db_version} to v{self.db_version}.", category=MPRestWarning, stacklevel=2, ) diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index c55e3758..26b268ca 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -12,6 +12,7 @@ class DOIRester(BaseRester): suffix = "doi" document_model = DOIDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/electrodes.py b/mp_api/client/routes/materials/electrodes.py index 0a74fc1c..209f0a4d 100644 --- a/mp_api/client/routes/materials/electrodes.py +++ b/mp_api/client/routes/materials/electrodes.py @@ -3,12 +3,15 @@ import warnings from collections import defaultdict -from emmet.core.electrode import ConversionElectrodeDoc, InsertionElectrodeDoc +from emmet.core.electrode import ( + ConversionElectrodeDoc, + InsertionElectrodeDoc, + validate_battery_id, +) from pymatgen.core.periodic_table import Element from mp_api.client.core import BaseRester -from mp_api.client.core.exceptions import MPRestWarning -from mp_api.client.core.utils import validate_ids +from mp_api.client.core.exceptions import MPRestError, MPRestWarning class BaseElectrodeRester(BaseRester): @@ -104,9 +107,19 @@ def search( {f"{param}_min": value[0], f"{param}_max": value[1]} ) elif param == "battery_ids": - query_params[param] = ",".join( - validate_ids([value] if isinstance(value, str) else value) - ) + _battery_ids = [value] if isinstance(value, str) else value + try: + for battery_id in _battery_ids: + validate_battery_id(battery_id) + query_params[param] = ",".join(_battery_ids) + + except Exception: + raise MPRestError( + f"At least one battery_id in: {value} is invalid." + " Try using the validate_battery_id function from emmet.core.electrode" + " to test your inputs." + ) + elif param == "working_ion": query_params["working_ion"] = ",".join( str(ele) diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index b0bee09e..c207e5f4 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -4,20 +4,23 @@ from collections import defaultdict from typing import TYPE_CHECKING -from emmet.core.band_theory import BSPathType -from emmet.core.electronic_structure import ( - DOSProjectionType, - ElectronicStructureDoc, -) +from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos +from emmet.core.electronic_structure import DOSProjectionType, ElectronicStructureDoc +from emmet.core.mpid import AlphaID +from emmet.core.vasp.calc_types.enums import RunType from pymatgen.analysis.magnetism.analyzer import Ordering from pymatgen.core.periodic_table import Element +from pymatgen.electronic_structure.bandstructure import ( + BandStructure, + BandStructureSymmLine, +) from pymatgen.electronic_structure.core import OrbitalType, Spin from mp_api.client.core import BaseRester, MPRestError -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: - from pymatgen.electronic_structure.dos import CompleteDos + from pymatgen.electronic_structure.dos import Dos class ElectronicStructureRester(BaseRester): @@ -167,6 +170,7 @@ def es_rester(self) -> ElectronicStructureRester: class BandStructureRester(BaseESPropertyRester): suffix = "materials/electronic_structure/bandstructure" + delta_backed = False def search_bandstructure_summary(self, *args, **kwargs): # pragma: no cover """Deprecated.""" @@ -255,20 +259,51 @@ def search( **query_params, ) - def get_bandstructure_from_task_id(self, task_id: str): + def get_bandstructure_from_task_id( + self, + task_id: str, + run_type: str | RunType | None = None, + path_type: str | BSPathType | None = None, + ) -> BandStructure: """Get the band structure pymatgen object associated with a given task ID. Arguments: task_id (str): Task ID for the band structure calculation - + run_type (str, RunType, or None): Optional run type, + will speed up query due to delta table partitioning. + path_type (str, BSPathType, or None) : Optional path type to + speed up query Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - return self._query_open_data( # type: ignore[call-overload] - bucket="materialsproject-parsed", - key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=True), - )[0][0]["data"] + bs_lbl, _ = self._get_delta_table( + "materialsproject-parsed", + "core/electronic-structure/bandstructures/", + label="bandstructure", + ) + + query = f""" + SELECT * + FROM {bs_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' + """ + + if run_type: + rt = RunType(run_type) if isinstance(run_type, str) else run_type + query += f"\nAND run_type='{rt.value}'" + if path_type: + query += f"\nAND path_convention='{path_type}'" + + table = self._query_delta_single(query) + if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: + emmet_bs = ElectronicBS(**deser[0]) + return emmet_bs.to_pmg( + pmg_cls=BandStructureSymmLine if emmet_bs.labels_dict else BandStructure + ) + raise MPRestError( + f"No bandstructure data found for {task_id=}" + + (f"run_type={rt}" if run_type else "") + ) def get_bandstructure_from_material_id( self, @@ -291,7 +326,9 @@ def get_bandstructure_from_material_id( material_ids=material_id, fields=["bandstructure"] ) if not bs_doc: - raise MPRestError("No electronic structure data found.") + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." + ) if (_bs_data := bs_doc[0]["bandstructure"]) is None: raise MPRestError( @@ -314,7 +351,9 @@ def get_bandstructure_from_material_id( material_ids=material_id, fields=["dos"] ) ): - raise MPRestError("No electronic structure data found.") + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." + ) if (_bs_data := bs_doc[0]["dos"]) is None: raise MPRestError( @@ -327,9 +366,12 @@ def get_bandstructure_from_material_id( raise MPRestError( f"No uniform band structure data found for {material_id}" ) - bs_task_id = bs_data["total"]["1"]["task_id"] + bs_task_id = bs_data["task_id"] - bs_obj = self.get_bandstructure_from_task_id(bs_task_id) + bs_obj = self.get_bandstructure_from_task_id( + bs_task_id, + path_type=path_type if line_mode else BSPathType.unknown, + ) if bs_obj: return bs_obj @@ -338,6 +380,7 @@ def get_bandstructure_from_material_id( class DosRester(BaseESPropertyRester): suffix = "materials/electronic_structure/dos" + delta_backed = False def search_dos_summary(self, *args, **kwargs): # pragma: no cover """Deprecated.""" @@ -451,42 +494,62 @@ def search( **query_params, ) - def get_dos_from_task_id(self, task_id: str) -> CompleteDos: + def get_dos_from_task_id( + self, task_id: str, run_type: str | RunType | None = None + ) -> Dos: """Get the density of states pymatgen object associated with a given calculation ID. Arguments: task_id (str): Task ID for the density of states calculation + run_type (str, RunType, or None): Optional run type to query by. + Will speed up query due to delta table partitioning. Returns: - bandstructure (CompleteDos): CompleteDos object + pymatgen Dos + """ + dos_lbl, _ = self._get_delta_table( + "materialsproject-parsed", + "core/electronic-structure/total-dos/", + label="total_dos", + ) + + query = f""" + SELECT * + FROM {dos_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' """ - return self._query_open_data( # type: ignore[call-overload] - bucket="materialsproject-parsed", - key=f"dos/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=True), - )[0][0]["data"] - def get_dos_from_material_id(self, material_id: str): + if run_type: + rt = RunType(run_type) if isinstance(run_type, str) else run_type + query += f"\nAND run_type='{rt.value}'" + + table = self._query_delta_single(query) + if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: + return ElectronicDos(**deser[0]).to_pmg() + raise MPRestError( + f"No DOS data found for {task_id=}" + (f"run_type={rt}" if run_type else "") + ) + + def get_dos_from_material_id(self, material_id: str) -> Dos: """Get the complete density of states pymatgen object associated with a Materials Project ID. Arguments: material_id (str): Materials Project ID for a material Returns: - dos (CompleteDos): CompleteDos object + pymatgen Dos """ if not ( dos_doc := self.es_rester.search(material_ids=material_id, fields=["dos"]) ): - return None + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." + ) if not (dos_data := dos_doc[0].get("dos")): raise MPRestError(f"No density of states data found for {material_id}") dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[ - "total" - ]["1"]["task_id"] - if dos_obj := self.get_dos_from_task_id(dos_task_id): - return dos_obj - - raise MPRestError("No density of states object found.") + "task_id" + ] + return self.get_dos_from_task_id(dos_task_id) diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index 0182eb6f..aad22f5f 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -1,32 +1,34 @@ from __future__ import annotations +import warnings from collections import defaultdict from emmet.core.eos import EOSDoc -from mp_api.client.core import BaseRester +from mp_api.client.core import BaseRester, MPRestError, MPRestWarning from mp_api.client.core.utils import validate_ids class EOSRester(BaseRester): suffix = "materials/eos" document_model = EOSDoc # type: ignore - primary_key = "material_id" + primary_key = "task_id" def search( self, - material_ids: str | list[str] | None = None, + task_ids: str | list[str] | None = None, energies: tuple[float, float] | None = None, volumes: tuple[float, float] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + **kwargs, ) -> list[EOSDoc] | list[dict]: """Query equations of state docs using a variety of search criteria. Arguments: - material_ids (str, List[str]): Search for equation of states associated with the specified Material IDs + task_ids (str, List[str]): Search for equation of states associated with the specified task IDs energies (Tuple[float,float]): Minimum and maximum energy in eV/atom to consider for EOS plot range. volumes (Tuple[float,float]): Minimum and maximum volume in A³/atom to consider for EOS plot range. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. @@ -34,17 +36,31 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in EOSDoc to return data for. Default is material_id only if all_fields is False. + **kwargs : used for handling deprecated kwargs Returns: ([EOSDoc], [dict]) List of equations of state docs or dictionaries. """ query_params: dict = defaultdict(dict) - if material_ids: - if isinstance(material_ids, str): - material_ids = [material_ids] + if "material_ids" in kwargs: + if task_ids: + raise MPRestError( + "You have specified both `task_ids` and the deprecated `material_ids` tag. " + "Please specify only `task_ids`." + ) + task_ids = kwargs.pop("material_ids") + warnings.warn( + "`material_id` has been replaced by `task_id` in the EOS endpoint. " + "Please migrate to using the newer field name.", + stacklevel=2, + category=MPRestWarning, + ) - query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + if task_ids: + query_params["task_ids"] = ",".join( + validate_ids([task_ids] if isinstance(task_ids, str) else task_ids) + ) if volumes: query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]}) diff --git a/mp_api/client/routes/materials/grain_boundaries.py b/mp_api/client/routes/materials/grain_boundaries.py index 6949b9de..d9ac75c3 100644 --- a/mp_api/client/routes/materials/grain_boundaries.py +++ b/mp_api/client/routes/materials/grain_boundaries.py @@ -12,6 +12,7 @@ class GrainBoundaryRester(BaseRester): suffix = "materials/grain_boundaries" document_model = GrainBoundaryDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 0373cd0d..42b92fb0 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -1,11 +1,12 @@ from __future__ import annotations +import warnings from collections import defaultdict from typing import TYPE_CHECKING from emmet.core.phonon import PhononBS, PhononBSDOSDoc, PhononDOS -from mp_api.client.core import BaseRester, MPRestError +from mp_api.client.core import BaseRester, MPRestError, MPRestWarning from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: @@ -17,21 +18,22 @@ class PhononRester(BaseRester): suffix = "materials/phonon" document_model = PhononBSDOSDoc # type: ignore - primary_key = "material_id" + primary_key = "identifier" def search( self, - material_ids: str | list[str] | None = None, + identifiers: str | list[str] | None = None, phonon_method: str | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + **kwargs, ) -> list[PhononBSDOSDoc] | list[dict]: """Query phonon docs using a variety of search criteria. Arguments: - material_ids (str, List[str]): A single Material ID string or list of strings + identifiers (str, List[str]): A single Phonon Task ID string or list of strings (e.g., mp-149, [mp-149, mp-13]). phonon_method (str): phonon method to search (dfpt, phonopy, pheasy) num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. @@ -39,17 +41,34 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in PhononBSDOSDoc to return data for. Default is material_id, last_updated, and formula_pretty if all_fields is False. + **kwargs : used for handling deprecated kwargs Returns: ([PhononBSDOSDoc], [dict]) List of phonon documents or dictionaries. """ query_params: dict = defaultdict(dict) - if material_ids: - if isinstance(material_ids, str): - material_ids = [material_ids] - - query_params["material_ids"] = ",".join(validate_ids(material_ids)) + if "material_ids" in kwargs: + if identifiers: + raise MPRestError( + "You have specified both `identifiers` and the deprecated `material_ids` tag. " + "Please specify only `identifiers`." + ) + identifiers = kwargs.pop("material_ids") + warnings.warn( + "`material_id` has been replaced by `identifier` in the phonon endpoint. " + "Please migrate to using the newer field name and the more generic `identifiers` kwarg " + "for searching.", + stacklevel=2, + category=MPRestWarning, + ) + + if identifiers: + query_params["identifiers"] = ",".join( + validate_ids( + [identifiers] if isinstance(identifiers, str) else identifiers + ) + ) if phonon_method and phonon_method in {"dfpt", "phonopy", "pheasy"}: query_params["phonon_method"] = phonon_method @@ -68,13 +87,13 @@ def search( **query_params, ) - def get_bandstructure_from_material_id( - self, material_id: str, phonon_method: str + def get_bandstructure_from_phonon_id( + self, identifier: str, phonon_method: str ) -> PhononBS | dict[str, Any]: - """Get the phonon band structure pymatgen object associated with a given material ID and phonon method. + """Get the phonon band structure pymatgen object associated with a given phonon ID and phonon method. Arguments: - material_id (str): Material ID for the phonon band structure calculation + identifier (str): Phonon ID for the phonon band structure calculation phonon_method (str): phonon method, i.e. pheasy or dfpt Returns: @@ -82,7 +101,7 @@ def get_bandstructure_from_material_id( """ result = self._query_open_data( bucket="materialsproject-parsed", - key=f"ph-bandstructures/{phonon_method}/{material_id}.json.gz", + key=f"ph-bandstructures/{phonon_method}/{identifier}.json.gz", )[0][0] return ( @@ -91,13 +110,25 @@ def get_bandstructure_from_material_id( else result # type: ignore[return-value] ) - def get_dos_from_material_id( + def get_bandstructure_from_material_id( self, material_id: str, phonon_method: str + ) -> PhononBS | dict[str, Any]: + """Deprecated: use `get_bandstructure_from_phonon_id` instead.""" + warnings.warn( + "`material_id` has been replaced by `identifier` in the phonon endpoint. " + "Please migrate to using `get_bandstructure_from_phonon_id` with the `identifier` kwarg.", + stacklevel=2, + category=MPRestWarning, + ) + return self.get_bandstructure_from_phonon_id(material_id, phonon_method) + + def get_dos_from_phonon_id( + self, identifier: str, phonon_method: str ) -> PhononDOS | dict[str, Any]: - """Get the phonon dos pymatgen object associated with a given material ID and phonon method. + """Get the phonon dos pymatgen object associated with a given phonon ID and phonon method. Arguments: - material_id (str): Material ID for the phonon dos calculation + identifier (str): Phonon ID for the phonon dos calculation phonon_method (str): phonon method, i.e. pheasy or dfpt Returns: @@ -105,7 +136,7 @@ def get_dos_from_material_id( """ result = self._query_open_data( bucket="materialsproject-parsed", - key=f"ph-dos/{phonon_method}/{material_id}.json.gz", + key=f"ph-dos/{phonon_method}/{identifier}.json.gz", )[0][0] return ( @@ -114,41 +145,88 @@ def get_dos_from_material_id( else result # type: ignore[return-value] ) - def get_forceconstants_from_material_id( - self, material_id: str + def get_dos_from_material_id( + self, material_id: str, phonon_method: str + ) -> PhononDOS | dict[str, Any]: + """Deprecated: use `get_dos_from_phonon_id` instead.""" + warnings.warn( + "`material_id` has been replaced by `identifier` in the phonon endpoint. " + "Please migrate to using `get_dos_from_phonon_id` with the `identifier` kwarg.", + stacklevel=2, + category=MPRestWarning, + ) + return self.get_dos_from_phonon_id(material_id, phonon_method) + + def get_forceconstants_from_phonon_id( + self, identifier: str ) -> list[list[Matrix3D]]: - """Get the force constants associated with a given material ID. + """Get the force constants associated with a given phonon ID. Arguments: - material_id (str): Material ID for the force constants calculation + identifier (str): Phonon ID for the force constants calculation Returns: - force constants (list[list[Matrix3D]]): PhononDOS object + force constants (list[list[Matrix3D]]): force constants """ return self._query_open_data( # type: ignore[return-value] bucket="materialsproject-parsed", - key=f"ph-force-constants/{material_id}.json.gz", + key=f"ph-force-constants/{identifier}.json.gz", )[0][0] - def compute_thermo_quantities(self, material_id: str, phonon_method: str): - """Compute thermodynamical quantities for given material ID and phonon_method. + def get_forceconstants_from_material_id( + self, material_id: str + ) -> list[list[Matrix3D]]: + """Deprecated: use `get_forceconstants_from_phonon_id` instead.""" + warnings.warn( + "`material_id` has been replaced by `identifier` in the phonon endpoint. " + "Please migrate to using `get_forceconstants_from_phonon_id` with the `identifier` kwarg.", + stacklevel=2, + category=MPRestWarning, + ) + return self.get_forceconstants_from_phonon_id(material_id) + + def compute_thermo_quantities( + self, + identifier: str | None = None, + phonon_method: str | None = None, + **kwargs, + ): + """Compute thermodynamical quantities for given phonon ID and phonon_method. Arguments: - material_id (str): Material ID to calculate quantities for + identifier (str): Phonon ID to calculate quantities for phonon_method (str): phonon method, i.e. pheasy or dfpt + **kwargs : used for handling deprecated kwargs Returns: quantities (dict): thermodynamical quantities """ + if "material_id" in kwargs: + if identifier: + raise MPRestError( + "You have specified both `identifier` and the deprecated `material_id` tag. " + "Please specify only `identifier`." + ) + identifier = kwargs.pop("material_id") + warnings.warn( + "`material_id` has been replaced by `identifier` in the phonon endpoint. " + "Please migrate to using the newer field name and the more generic `identifier` kwarg.", + stacklevel=2, + category=MPRestWarning, + ) + + if identifier is None: + raise MPRestError("`identifier` must be specified.") + use_document_model = self.use_document_model self.use_document_model = False - docs = self.search(material_ids=material_id, phonon_method=phonon_method) + docs = self.search(identifiers=identifier, phonon_method=phonon_method) if not docs or not docs[0]: raise MPRestError("No phonon document found") self.use_document_model = True - docs[0]["phonon_dos"] = self.get_dos_from_material_id( # type: ignore[index] - material_id, phonon_method + docs[0]["phonon_dos"] = self.get_dos_from_phonon_id( # type: ignore[index] + identifier, phonon_method # type: ignore[arg-type] ) doc = PhononBSDOSDoc(**docs[0]) # type: ignore[arg-type] self.use_document_model = use_document_model diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index aa6cab71..0ba8c5b7 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -26,6 +26,7 @@ class SimilarityRester(BaseRester): suffix = "materials/similarity" document_model = SimilarityDoc # type: ignore primary_key = "material_id" + delta_backed = False _fingerprinter: SimilarityScorer | None = None diff --git a/mp_api/client/routes/materials/substrates.py b/mp_api/client/routes/materials/substrates.py index 62eaa676..6f1096b1 100644 --- a/mp_api/client/routes/materials/substrates.py +++ b/mp_api/client/routes/materials/substrates.py @@ -11,6 +11,7 @@ class SubstratesRester(BaseRester): suffix = "materials/substrates" document_model = SubstratesDoc # type: ignore primary_key = "film_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/surface_properties.py b/mp_api/client/routes/materials/surface_properties.py index 76d9e60c..3a36d5f9 100644 --- a/mp_api/client/routes/materials/surface_properties.py +++ b/mp_api/client/routes/materials/surface_properties.py @@ -12,6 +12,7 @@ class SurfacePropertiesRester(BaseRester): suffix = "materials/surface_properties" document_model = SurfacePropDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/synthesis.py b/mp_api/client/routes/materials/synthesis.py index 6788814c..4567c51f 100644 --- a/mp_api/client/routes/materials/synthesis.py +++ b/mp_api/client/routes/materials/synthesis.py @@ -12,6 +12,7 @@ class SynthesisRester(BaseRester): suffix = "materials/synthesis" document_model = SynthesisSearchResultModel # type: ignore + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 66a03758..6feefbf9 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -3,8 +3,6 @@ from datetime import datetime from typing import TYPE_CHECKING -import pyarrow as pa -from deltalake import DeltaTable, QueryBuilder from emmet.core.mpid import MPID, AlphaID from emmet.core.tasks import CoreTaskDoc from emmet.core.trajectory import RelaxTrajectory @@ -40,23 +38,23 @@ def get_trajectory( dict representing emmet.core.trajectory.RelaxTrajectory """ as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] - predicate = ( - f"WHERE run_type='{str(run_type)}' AND identifier='{as_alpha}'" - if run_type - else f"WHERE identifier='{as_alpha}'" - ) + f"WHERE run_type='{str(run_type)}' AND " if run_type else "" + ) + f"WHERE identifier='{as_alpha}'" - traj_tbl = DeltaTable( - "s3a://materialsproject-parsed/core/trajectories/", - storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + traj_lbl, _ = self._get_delta_table( + "materialsproject-parsed", + "core/trajectories/", + label="traj", ) - traj_data = pa.table(QueryBuilder().register("traj", traj_tbl).execute(f""" - SELECT * - FROM traj - {predicate}; - """).read_all()).to_pylist(maps_as_pydicts="strict") + query = f""" + SELECT * + FROM {traj_lbl} + {predicate}; + """ + + traj_data = self._query_delta_single(query).to_pylist(maps_as_pydicts="strict") if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index a2088a7f..409fadb2 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -1,22 +1,58 @@ from __future__ import annotations from collections import defaultdict +from typing import TYPE_CHECKING import numpy as np -from emmet.core.thermo import ThermoDoc +from emmet.core.thermo import ThermoDoc, validate_thermo_id from emmet.core.types.enums import ThermoType +from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType +from pydantic import TypeAdapter from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Element -from pymatgen.core import __version__ as __pmg_version__ from mp_api.client.core import BaseRester -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.exceptions import MPRestError +from mp_api.client.core.utils import validate_ids + +if TYPE_CHECKING: + from collections.abc import Sequence + + from enums import Enum class ThermoRester(BaseRester): suffix = "materials/thermo" document_model = ThermoDoc # type: ignore - primary_key = "thermo_id" + primary_key = "material_id" + + @staticmethod + def _check_thermo_types(thermo_types: Sequence[str | Enum]) -> set[str]: + """Check if a user has input any invalid thermo types. + + Args: + thermo_types (Sequence of str or Enum) : list of thermo types + the user has queried for + + phase-diagram tbl has "r2SCAN", not "R2SCAN" + mixing of ThermoType/RunType in emmet -_- + TODO: coerce upstream? allow case-insensitivity in emmet? + + Returns: + set of str: validated thermo types + + Raises: + ValueError if any invalid thermo types are input + """ + t_types: set[str] = {t if isinstance(t, str) else t.value for t in thermo_types} + t_types = {"r2SCAN" if t == "R2SCAN" else t for t in t_types} + valid_types = {"r2SCAN", *map(str, ThermoType.__members__.values())} + + if invalid_types := t_types - valid_types: + raise ValueError( + f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}" + ) + return t_types def search( self, @@ -55,7 +91,7 @@ def search( material_ids (List[str]): List of Materials Project IDs to return data for. thermo_ids (List[str]): List of thermo IDs to return data for. This is a combination of the Materials Project ID and thermo type (e.g. mp-149_GGA_GGA+U). - thermo_types (List[ThermoType]): List of thermo types to return data for (e.g. ThermoType.GGA_GGA_U). + thermo_types (List[ThermoType or str]): List of thermo/run types to return data for (e.g. ThermoType.GGA_GGA_U). num_elements (Tuple[int,int]): Minimum and maximum number of elements in the material to consider. total_energy (Tuple[float,float]): Minimum and maximum corrected total energy in eV/atom to consider. uncorrected_energy (Tuple[float,float]): Minimum and maximum uncorrected total @@ -90,16 +126,22 @@ def search( query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) if thermo_ids: - query_params.update({"thermo_ids": ",".join(validate_ids(thermo_ids))}) + try: + for thermo_id in thermo_ids: + validate_thermo_id(thermo_id) + query_params.update({"thermo_ids": ",".join(thermo_ids)}) + + except Exception: + raise MPRestError( + f"At least one thermo_id in: {thermo_ids} is invalid." + " Try using the validate_thermo_id function from emmet.core.thermo" + " to test your inputs." + ) if thermo_types: - t_types = {t if isinstance(t, str) else t.value for t in thermo_types} - valid_types = {*map(str, ThermoType.__members__.values())} - if invalid_types := t_types - valid_types: - raise ValueError( - f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}" - ) - query_params.update({"thermo_types": ",".join(t_types)}) + query_params.update( + {"thermo_types": ",".join(self._check_thermo_types(thermo_types))} + ) if num_elements: if isinstance(num_elements, int): @@ -156,32 +198,29 @@ def get_phase_diagram_from_chemsys( Returns: (PhaseDiagram): Pymatgen phase diagram object. """ - t_type = thermo_type if isinstance(thermo_type, str) else thermo_type.value - valid_types = {*map(str, ThermoType.__members__.values())} - if invalid_types := {t_type} - valid_types: - raise ValueError( - f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}" - ) + validated_thermo_type = self._check_thermo_types([thermo_type]).pop() sorted_chemsys = "-".join(sorted(chemsys.split("-"))) - phdiag_id = f"thermo_type={t_type}/chemsys={sorted_chemsys}" version = self.db_version.replace(".", "-") - obj_key = f"objects/{version}/phase-diagrams/{phdiag_id}.jsonl.gz" - pd_dct = self._query_open_data( # type: ignore[union-attr] - bucket="materialsproject-build", - key=obj_key, - decoder=lambda x: load_json(x, deser=False), - )[0][0].get("phase_diagram") - - pd = PhaseDiagram.from_dict( - { # type: ignore[arg-type] - k: v if k != "elements" else [e.get("element", e) for e in v] - for k, v in pd_dct.items() # type: ignore[union-attr] - } # post pymatgen/-core split, different serialization behavior - if int(__pmg_version__.split(".", 1)[0]) >= 2026 - else pd_dct # pymatgen<=2025.10.7 + + pd_lbl, _ = self._get_delta_table( + "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams" ) + query = f""" + SELECT phase_diagram + FROM {pd_lbl} + WHERE chemsys='{sorted_chemsys}' + AND version='{version}' + AND thermo_type='{validated_thermo_type}' + """ + table = self._query_delta_single(query) + as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict") + + pd: PhaseDiagram | None = None + if len(pds := TypeAdapter(list[PhaseDiagramType]).validate_python(as_py)) > 0: + pd = pds[0] + # Ensure el_ref keys are Element objects for PDPlotter. # Ensure qhull_data is a numpy array # This should be fixed in pymatgen diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index a4f164f8..01dfed41 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -1,11 +1,13 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING -from emmet.core.xas import XASDoc +from emmet.core.xas import XASDoc, validate_xas_spectrum_id from pymatgen.core.periodic_table import Element -from mp_api.client.core import BaseRester +from mp_api.client.core import BaseRester, MPRestWarning +from mp_api.client.core.exceptions import MPRestError if TYPE_CHECKING: from typing import Any @@ -16,7 +18,8 @@ class XASRester(BaseRester): suffix = "materials/xas" document_model = XASDoc # type: ignore - primary_key = "spectrum_id" + primary_key = "task_id" + delta_backed = False def search( self, @@ -25,7 +28,7 @@ def search( formula: str | None = None, chemsys: str | list[str] | None = None, elements: list[str] | None = None, - material_ids: list[str] | None = None, + task_ids: list[str] | None = None, spectrum_type: XasType | None = None, spectrum_ids: str | list[str] | None = None, num_chunks: int | None = None, @@ -34,6 +37,7 @@ def search( fields: list[str] | None = None, _page: int | None = None, _sort_fields: str | None = None, + **kwargs, ): """Query core XAS docs using a variety of search criteria. @@ -45,7 +49,7 @@ def search( chemsys (str, List[str]): A chemical system or list of chemical systems (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]). elements (List[str]): A list of elements. - material_ids (str, List[str]): A single Material ID string or list of strings + task_ids (str, List[str]): A single Task ID string or list of strings (e.g., mp-149, [mp-149, mp-13]). spectrum_type (XasType): Spectrum type (e.g. EXAFS, XAFS, or XANES). spectrum_ids (str, List[str]): A single Spectrum ID string or list of strings @@ -57,10 +61,26 @@ def search( Default is material_id, last_updated, and formula_pretty if all_fields is False. _page (int or None) : Page of the results to skip to. _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. + **kwargs : used for handling deprecated kwargs Returns: ([MaterialsDoc]) List of material documents """ + if "material_ids" in kwargs: + if task_ids: + raise MPRestError( + "You have specified both `task_ids` and the deprecated `material_ids` tag. " + "Please specify only `task_ids`." + ) + task_ids = kwargs.pop("material_ids") + warnings.warn( + "`material_id` has been replaced by `task_id` in the xas endpoint. " + "Please migrate to using the newer field name and the `task_ids` kwarg " + "for searching.", + stacklevel=2, + category=MPRestWarning, + ) + _locals = locals() query_params: dict[str, Any] = { k: _locals[k] @@ -78,9 +98,21 @@ def search( ) } ) - for k in ("chemsys", "elements", "material_ids", "spectrum_ids"): + for k in ("chemsys", "elements", "task_ids", "spectrum_ids"): if (v := _locals.get(k)) is not None: - query_params[k] = ",".join([v] if isinstance(v, str) else v) + _v = [v] if isinstance(v, str) else v + if k == "spectrum_ids": + try: + for spectrum_id in k: + validate_xas_spectrum_id(spectrum_id) + except Exception: + raise MPRestError( + f"At least one spectrum_id in: {_v} is invalid." + " Try using the validate_xas_spectrum_id function from emmet.core.xas" + " to test your inputs." + ) + + query_params[k] = ",".join(_v) query_params = { entry: query_params[entry] diff --git a/mp_api/client/routes/molecules/jcesr.py b/mp_api/client/routes/molecules/jcesr.py index 2d462c19..24d3f5e6 100644 --- a/mp_api/client/routes/molecules/jcesr.py +++ b/mp_api/client/routes/molecules/jcesr.py @@ -15,6 +15,7 @@ class JcesrMoleculesRester(BaseRester): suffix = "molecules/jcesr" document_model = MoleculesDoc # type: ignore primary_key = "task_id" + delta_backed = False def __init__(self, **kwargs): """Throw deprecation warning when JCESR client is initialized.""" diff --git a/mp_api/client/routes/molecules/molecules.py b/mp_api/client/routes/molecules/molecules.py index b7600328..3171b55c 100644 --- a/mp_api/client/routes/molecules/molecules.py +++ b/mp_api/client/routes/molecules/molecules.py @@ -20,3 +20,4 @@ class MoleculeRester(CoreRester): primary_key = "molecule_id" suffix = "molecules/core" _sub_resters = MOLECULES_RESTERS + delta_backed = False diff --git a/mp_api/client/routes/molecules/summary.py b/mp_api/client/routes/molecules/summary.py index 4be3aab5..2f91677e 100644 --- a/mp_api/client/routes/molecules/summary.py +++ b/mp_api/client/routes/molecules/summary.py @@ -12,6 +12,7 @@ class MoleculesSummaryRester(BaseRester): suffix = "molecules/summary" document_model = MoleculeSummaryDoc # type: ignore primary_key = "molecule_id" + delta_backed = False def search( self, diff --git a/pyproject.toml b/pyproject.toml index da7257ff..dcff023c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", "monty>=2024.12.10", - "emmet-core>=0.86.4rc1,<0.86.5", + "emmet-core>=0.87.0rc1,<0.87.2", "boto3", "orjson >= 3.10,<4", "pyarrow >= 20.0.0", @@ -49,7 +49,7 @@ contribs = [ ] all = [ "custodian", - "emmet-core[all]>=0.86.4rc1,<0.86.5", + "emmet-core[all]>=0.87.0rc1,<0.87.2", "fastmcp", "flask", ] diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index c9da5cf8..f871bb4e 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -12,9 +12,9 @@ bibtexparser==1.4.4 # via pymatgen-core blake3==1.0.8 # via emmet-core -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer @@ -30,11 +30,11 @@ deltalake==1.5.1 # via mp-api (pyproject.toml) deprecated==1.3.1 # via deltalake -emmet-core==0.86.4 +emmet-core==0.87.0rc1 # via mp-api (pyproject.toml) fonttools==4.62.1 # via matplotlib -idna==3.13 +idna==3.15 # via requests inflect==7.5.0 # via emmet-core @@ -61,7 +61,7 @@ more-itertools==11.0.2 # via inflect mpmath==1.3.0 # via sympy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via pymatgen-core @@ -85,7 +85,7 @@ packaging==26.2 # plotly palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via pymatgen-core pillow==12.2.0 # via matplotlib @@ -104,7 +104,7 @@ pydantic==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # pymatgen-io-validation @@ -130,7 +130,7 @@ python-dotenv==1.2.2 # via pydantic-settings pyyaml==6.0.3 # via pybtex -requests==2.33.1 +requests==2.34.0 # via # mp-api (pyproject.toml) # pymatgen-core @@ -170,7 +170,7 @@ typing-inspection==0.4.2 # pydantic-settings uncertainties==3.2.3 # via pymatgen-core -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index 72b37027..8f40e9cd 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -24,6 +24,8 @@ arrow==1.4.0 # via isoduration ase==3.28.0 # via pymatgen-analysis-diffusion +ast-serialize==0.3.0 + # via mypy asttokens==3.0.1 # via stack-data attrs==26.1.0 @@ -47,15 +49,15 @@ blinker==1.9.0 # via flask boltons==25.0.0 # via mp-api (pyproject.toml) -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer bravado==12.0.1 # via mp-api (pyproject.toml) -bravado-core==6.3.1 +bravado-core==6.4.1 # via bravado cachetools==7.1.1 # via @@ -80,7 +82,7 @@ click==8.3.3 # uvicorn contourpy==1.3.3 # via matplotlib -coverage[toml]==7.13.5 +coverage[toml]==7.14.0 # via pytest-cov cryptography==48.0.0 # via @@ -92,7 +94,7 @@ custodian==2025.12.14 # via mp-api (pyproject.toml) cycler==0.12.1 # via matplotlib -cyclopts==4.11.2 +cyclopts==4.12.0 # via fastmcp decorator==5.2.1 # via ipython @@ -115,7 +117,7 @@ docutils==0.22.4 # sphinx email-validator==2.3.0 # via pydantic -emmet-core[all]==0.86.4 +emmet-core[all]==0.87.0rc1 # via mp-api (pyproject.toml) exceptiongroup==1.3.1 # via fastmcp @@ -159,7 +161,7 @@ httpx-sse==0.4.3 # via mcp identify==2.6.19 # via pre-commit -idna==3.13 +idna==3.15 # via # anyio # email-validator @@ -241,11 +243,13 @@ latexcodec==3.0.1 # via pybtex lazy-loader==0.5 # via scikit-image -librt==0.10.0 +librt==0.11.0 # via mypy +lobsterpy==0.6.1 + # via emmet-core lxml==6.1.0 # via pymatgen-core -markdown-it-py==4.0.0 +markdown-it-py==4.2.0 # via rich markupsafe==3.0.3 # via @@ -257,11 +261,11 @@ matplotlib==3.10.9 # ase # pymatgen-core # seaborn -matplotlib-inline==0.2.1 +matplotlib-inline==0.2.2 # via ipython mccabe==0.7.0 # via flake8 -mcp==1.27.0 +mcp==1.27.1 # via fastmcp mdurl==0.1.2 # via markdown-it-py @@ -286,13 +290,13 @@ msgpack==1.1.2 # via # bravado # bravado-core -mypy==1.20.2 +mypy==2.1.0 # via mp-api (pyproject.toml) mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via @@ -305,6 +309,7 @@ numpy==2.4.4 # ase # contourpy # imageio + # lobsterpy # matplotlib # monty # pandas @@ -338,7 +343,7 @@ packaging==26.2 # sphinx palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via # pymatgen-core # seaborn @@ -391,6 +396,8 @@ pyarrow==24.0.0 # mp-api (pyproject.toml) pybtex==0.26.1 # via emmet-core +pycodcif==3.0.1 + # via emmet-core pycodestyle==2.14.0 # via # flake8 @@ -407,7 +414,7 @@ pydantic[email]==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # mcp @@ -428,6 +435,7 @@ pyjwt[crypto]==2.12.1 pymatgen==2026.5.4 # via # emmet-core + # lobsterpy # mp-api (pyproject.toml) # mp-pyrho # pymatgen-analysis-alloys @@ -475,13 +483,13 @@ python-dateutil==2.9.0.post0 # bravado-core # matplotlib # pandas -python-discovery==1.3.0 +python-discovery==1.3.1 # via virtualenv python-dotenv==1.2.2 # via # fastmcp # pydantic-settings -python-multipart==0.0.27 +python-multipart==0.0.28 # via mcp pytz==2026.2 # via bravado-core @@ -499,7 +507,7 @@ referencing==0.37.0 # jsonschema # jsonschema-path # jsonschema-specifications -requests==2.33.1 +requests==2.34.0 # via # bravado # bravado-core @@ -580,7 +588,7 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sse-starlette==3.4.1 +sse-starlette==3.4.4 # via mcp stack-data==0.6.3 # via ipython @@ -606,13 +614,13 @@ traitlets==5.15.0 # matplotlib-inline typeguard==4.5.1 # via inflect -types-requests==2.33.0.20260503 +types-requests==2.33.0.20260513 # via # mp-api (pyproject.toml) # types-tqdm -types-setuptools==82.0.0.20260408 +types-setuptools==82.0.0.20260508 # via mp-api (pyproject.toml) -types-tqdm==4.67.3.20260408 +types-tqdm==4.67.3.20260508 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via @@ -653,7 +661,7 @@ uncertainties==3.2.3 # via pymatgen-core uri-template==1.3.0 # via jsonschema -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests @@ -662,7 +670,7 @@ uvicorn==0.46.0 # via # fastmcp # mcp -virtualenv==21.3.1 +virtualenv==21.3.2 # via pre-commit watchfiles==1.1.1 # via fastmcp diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index f46aae42..2f793027 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -12,9 +12,9 @@ bibtexparser==1.4.4 # via pymatgen-core blake3==1.0.8 # via emmet-core -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer @@ -30,11 +30,11 @@ deltalake==1.5.1 # via mp-api (pyproject.toml) deprecated==1.3.1 # via deltalake -emmet-core==0.86.4 +emmet-core==0.87.0rc1 # via mp-api (pyproject.toml) fonttools==4.62.1 # via matplotlib -idna==3.13 +idna==3.15 # via requests inflect==7.5.0 # via emmet-core @@ -61,7 +61,7 @@ more-itertools==11.0.2 # via inflect mpmath==1.3.0 # via sympy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via pymatgen-core @@ -85,7 +85,7 @@ packaging==26.2 # plotly palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via pymatgen-core pillow==12.2.0 # via matplotlib @@ -104,7 +104,7 @@ pydantic==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # pymatgen-io-validation @@ -130,7 +130,7 @@ python-dotenv==1.2.2 # via pydantic-settings pyyaml==6.0.3 # via pybtex -requests==2.33.1 +requests==2.34.0 # via # mp-api (pyproject.toml) # pymatgen-core @@ -168,7 +168,7 @@ typing-inspection==0.4.2 # pydantic-settings uncertainties==3.2.3 # via pymatgen-core -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index 09f9ae78..69e23e09 100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -24,6 +24,8 @@ arrow==1.4.0 # via isoduration ase==3.28.0 # via pymatgen-analysis-diffusion +ast-serialize==0.3.0 + # via mypy asttokens==3.0.1 # via stack-data attrs==26.1.0 @@ -45,15 +47,15 @@ blinker==1.9.0 # via flask boltons==25.0.0 # via mp-api (pyproject.toml) -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer bravado==12.0.1 # via mp-api (pyproject.toml) -bravado-core==6.3.1 +bravado-core==6.4.1 # via bravado cachetools==7.1.1 # via @@ -78,7 +80,7 @@ click==8.3.3 # uvicorn contourpy==1.3.3 # via matplotlib -coverage[toml]==7.13.5 +coverage[toml]==7.14.0 # via pytest-cov cryptography==48.0.0 # via @@ -90,7 +92,7 @@ custodian==2025.12.14 # via mp-api (pyproject.toml) cycler==0.12.1 # via matplotlib -cyclopts==4.11.2 +cyclopts==4.12.0 # via fastmcp decorator==5.2.1 # via ipython @@ -113,7 +115,7 @@ docutils==0.22.4 # sphinx email-validator==2.3.0 # via pydantic -emmet-core[all]==0.86.4 +emmet-core[all]==0.87.0rc1 # via mp-api (pyproject.toml) exceptiongroup==1.3.1 # via fastmcp @@ -157,7 +159,7 @@ httpx-sse==0.4.3 # via mcp identify==2.6.19 # via pre-commit -idna==3.13 +idna==3.15 # via # anyio # email-validator @@ -237,11 +239,13 @@ latexcodec==3.0.1 # via pybtex lazy-loader==0.5 # via scikit-image -librt==0.10.0 +librt==0.11.0 # via mypy +lobsterpy==0.6.1 + # via emmet-core lxml==6.1.0 # via pymatgen-core -markdown-it-py==4.0.0 +markdown-it-py==4.2.0 # via rich markupsafe==3.0.3 # via @@ -253,11 +257,11 @@ matplotlib==3.10.9 # ase # pymatgen-core # seaborn -matplotlib-inline==0.2.1 +matplotlib-inline==0.2.2 # via ipython mccabe==0.7.0 # via flake8 -mcp==1.27.0 +mcp==1.27.1 # via fastmcp mdurl==0.1.2 # via markdown-it-py @@ -282,13 +286,13 @@ msgpack==1.1.2 # via # bravado # bravado-core -mypy==1.20.2 +mypy==2.1.0 # via mp-api (pyproject.toml) mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via @@ -301,6 +305,7 @@ numpy==2.4.4 # ase # contourpy # imageio + # lobsterpy # matplotlib # monty # pandas @@ -334,7 +339,7 @@ packaging==26.2 # sphinx palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via # pymatgen-core # seaborn @@ -387,6 +392,8 @@ pyarrow==24.0.0 # mp-api (pyproject.toml) pybtex==0.26.1 # via emmet-core +pycodcif==3.0.1 + # via emmet-core pycodestyle==2.14.0 # via # flake8 @@ -403,7 +410,7 @@ pydantic[email]==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # mcp @@ -424,6 +431,7 @@ pyjwt[crypto]==2.12.1 pymatgen==2026.5.4 # via # emmet-core + # lobsterpy # mp-api (pyproject.toml) # mp-pyrho # pymatgen-analysis-alloys @@ -471,13 +479,13 @@ python-dateutil==2.9.0.post0 # bravado-core # matplotlib # pandas -python-discovery==1.3.0 +python-discovery==1.3.1 # via virtualenv python-dotenv==1.2.2 # via # fastmcp # pydantic-settings -python-multipart==0.0.27 +python-multipart==0.0.28 # via mcp pytz==2026.2 # via bravado-core @@ -495,7 +503,7 @@ referencing==0.37.0 # jsonschema # jsonschema-path # jsonschema-specifications -requests==2.33.1 +requests==2.34.0 # via # bravado # bravado-core @@ -576,7 +584,7 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sse-starlette==3.4.1 +sse-starlette==3.4.4 # via mcp stack-data==0.6.3 # via ipython @@ -602,13 +610,13 @@ traitlets==5.15.0 # matplotlib-inline typeguard==4.5.1 # via inflect -types-requests==2.33.0.20260503 +types-requests==2.33.0.20260513 # via # mp-api (pyproject.toml) # types-tqdm -types-setuptools==82.0.0.20260408 +types-setuptools==82.0.0.20260508 # via mp-api (pyproject.toml) -types-tqdm==4.67.3.20260408 +types-tqdm==4.67.3.20260508 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via @@ -646,7 +654,7 @@ uncertainties==3.2.3 # via pymatgen-core uri-template==1.3.0 # via jsonschema -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests @@ -655,7 +663,7 @@ uvicorn==0.46.0 # via # fastmcp # mcp -virtualenv==21.3.1 +virtualenv==21.3.2 # via pre-commit watchfiles==1.1.1 # via fastmcp diff --git a/requirements/requirements-ubuntu-latest_py3.13.txt b/requirements/requirements-ubuntu-latest_py3.13.txt index be79e368..b64f7e02 100644 --- a/requirements/requirements-ubuntu-latest_py3.13.txt +++ b/requirements/requirements-ubuntu-latest_py3.13.txt @@ -12,9 +12,9 @@ bibtexparser==1.4.4 # via pymatgen-core blake3==1.0.8 # via emmet-core -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer @@ -30,11 +30,11 @@ deltalake==1.5.1 # via mp-api (pyproject.toml) deprecated==1.3.1 # via deltalake -emmet-core==0.86.4 +emmet-core==0.87.0rc1 # via mp-api (pyproject.toml) fonttools==4.62.1 # via matplotlib -idna==3.13 +idna==3.15 # via requests inflect==7.5.0 # via emmet-core @@ -61,7 +61,7 @@ more-itertools==11.0.2 # via inflect mpmath==1.3.0 # via sympy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via pymatgen-core @@ -85,7 +85,7 @@ packaging==26.2 # plotly palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via pymatgen-core pillow==12.2.0 # via matplotlib @@ -104,7 +104,7 @@ pydantic==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # pymatgen-io-validation @@ -130,7 +130,7 @@ python-dotenv==1.2.2 # via pydantic-settings pyyaml==6.0.3 # via pybtex -requests==2.33.1 +requests==2.34.0 # via # mp-api (pyproject.toml) # pymatgen-core @@ -167,7 +167,7 @@ typing-inspection==0.4.2 # pydantic-settings uncertainties==3.2.3 # via pymatgen-core -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests diff --git a/requirements/requirements-ubuntu-latest_py3.13_extras.txt b/requirements/requirements-ubuntu-latest_py3.13_extras.txt index 24c3a2c1..1006ae32 100644 --- a/requirements/requirements-ubuntu-latest_py3.13_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.13_extras.txt @@ -24,6 +24,8 @@ arrow==1.4.0 # via isoduration ase==3.28.0 # via pymatgen-analysis-diffusion +ast-serialize==0.3.0 + # via mypy asttokens==3.0.1 # via stack-data attrs==26.1.0 @@ -45,15 +47,15 @@ blinker==1.9.0 # via flask boltons==25.0.0 # via mp-api (pyproject.toml) -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer bravado==12.0.1 # via mp-api (pyproject.toml) -bravado-core==6.3.1 +bravado-core==6.4.1 # via bravado cachetools==7.1.1 # via @@ -78,7 +80,7 @@ click==8.3.3 # uvicorn contourpy==1.3.3 # via matplotlib -coverage[toml]==7.13.5 +coverage[toml]==7.14.0 # via pytest-cov cryptography==48.0.0 # via @@ -90,7 +92,7 @@ custodian==2025.12.14 # via mp-api (pyproject.toml) cycler==0.12.1 # via matplotlib -cyclopts==4.11.2 +cyclopts==4.12.0 # via fastmcp decorator==5.2.1 # via ipython @@ -113,7 +115,7 @@ docutils==0.22.4 # sphinx email-validator==2.3.0 # via pydantic -emmet-core[all]==0.86.4 +emmet-core[all]==0.87.0rc1 # via mp-api (pyproject.toml) exceptiongroup==1.3.1 # via fastmcp @@ -157,7 +159,7 @@ httpx-sse==0.4.3 # via mcp identify==2.6.19 # via pre-commit -idna==3.13 +idna==3.15 # via # anyio # email-validator @@ -237,11 +239,13 @@ latexcodec==3.0.1 # via pybtex lazy-loader==0.5 # via scikit-image -librt==0.10.0 +librt==0.11.0 # via mypy +lobsterpy==0.6.1 + # via emmet-core lxml==6.1.0 # via pymatgen-core -markdown-it-py==4.0.0 +markdown-it-py==4.2.0 # via rich markupsafe==3.0.3 # via @@ -253,11 +257,11 @@ matplotlib==3.10.9 # ase # pymatgen-core # seaborn -matplotlib-inline==0.2.1 +matplotlib-inline==0.2.2 # via ipython mccabe==0.7.0 # via flake8 -mcp==1.27.0 +mcp==1.27.1 # via fastmcp mdurl==0.1.2 # via markdown-it-py @@ -282,13 +286,13 @@ msgpack==1.1.2 # via # bravado # bravado-core -mypy==1.20.2 +mypy==2.1.0 # via mp-api (pyproject.toml) mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via @@ -301,6 +305,7 @@ numpy==2.4.4 # ase # contourpy # imageio + # lobsterpy # matplotlib # monty # pandas @@ -334,7 +339,7 @@ packaging==26.2 # sphinx palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via # pymatgen-core # seaborn @@ -387,6 +392,8 @@ pyarrow==24.0.0 # mp-api (pyproject.toml) pybtex==0.26.1 # via emmet-core +pycodcif==3.0.1 + # via emmet-core pycodestyle==2.14.0 # via # flake8 @@ -403,7 +410,7 @@ pydantic[email]==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # mcp @@ -424,6 +431,7 @@ pyjwt[crypto]==2.12.1 pymatgen==2026.5.4 # via # emmet-core + # lobsterpy # mp-api (pyproject.toml) # mp-pyrho # pymatgen-analysis-alloys @@ -471,13 +479,13 @@ python-dateutil==2.9.0.post0 # bravado-core # matplotlib # pandas -python-discovery==1.3.0 +python-discovery==1.3.1 # via virtualenv python-dotenv==1.2.2 # via # fastmcp # pydantic-settings -python-multipart==0.0.27 +python-multipart==0.0.28 # via mcp pytz==2026.2 # via bravado-core @@ -495,7 +503,7 @@ referencing==0.37.0 # jsonschema # jsonschema-path # jsonschema-specifications -requests==2.33.1 +requests==2.34.0 # via # bravado # bravado-core @@ -576,7 +584,7 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sse-starlette==3.4.1 +sse-starlette==3.4.4 # via mcp stack-data==0.6.3 # via ipython @@ -602,13 +610,13 @@ traitlets==5.15.0 # matplotlib-inline typeguard==4.5.1 # via inflect -types-requests==2.33.0.20260503 +types-requests==2.33.0.20260513 # via # mp-api (pyproject.toml) # types-tqdm -types-setuptools==82.0.0.20260408 +types-setuptools==82.0.0.20260508 # via mp-api (pyproject.toml) -types-tqdm==4.67.3.20260408 +types-tqdm==4.67.3.20260508 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via @@ -640,7 +648,7 @@ uncertainties==3.2.3 # via pymatgen-core uri-template==1.3.0 # via jsonschema -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests @@ -649,7 +657,7 @@ uvicorn==0.46.0 # via # fastmcp # mcp -virtualenv==21.3.1 +virtualenv==21.3.2 # via pre-commit watchfiles==1.1.1 # via fastmcp diff --git a/requirements/requirements-ubuntu-latest_py3.14.txt b/requirements/requirements-ubuntu-latest_py3.14.txt index f806d209..1e5b11bc 100644 --- a/requirements/requirements-ubuntu-latest_py3.14.txt +++ b/requirements/requirements-ubuntu-latest_py3.14.txt @@ -12,9 +12,9 @@ bibtexparser==1.4.4 # via pymatgen-core blake3==1.0.8 # via emmet-core -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer @@ -30,11 +30,11 @@ deltalake==1.5.1 # via mp-api (pyproject.toml) deprecated==1.3.1 # via deltalake -emmet-core==0.86.4 +emmet-core==0.87.0rc1 # via mp-api (pyproject.toml) fonttools==4.62.1 # via matplotlib -idna==3.13 +idna==3.15 # via requests inflect==7.5.0 # via emmet-core @@ -61,7 +61,7 @@ more-itertools==11.0.2 # via inflect mpmath==1.3.0 # via sympy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via pymatgen-core @@ -85,7 +85,7 @@ packaging==26.2 # plotly palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via pymatgen-core pillow==12.2.0 # via matplotlib @@ -104,7 +104,7 @@ pydantic==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # pymatgen-io-validation @@ -130,7 +130,7 @@ python-dotenv==1.2.2 # via pydantic-settings pyyaml==6.0.3 # via pybtex -requests==2.33.1 +requests==2.34.0 # via # mp-api (pyproject.toml) # pymatgen-core @@ -167,7 +167,7 @@ typing-inspection==0.4.2 # pydantic-settings uncertainties==3.2.3 # via pymatgen-core -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests diff --git a/requirements/requirements-ubuntu-latest_py3.14_extras.txt b/requirements/requirements-ubuntu-latest_py3.14_extras.txt index 9af80371..b6739c5b 100644 --- a/requirements/requirements-ubuntu-latest_py3.14_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.14_extras.txt @@ -24,6 +24,8 @@ arrow==1.4.0 # via isoduration ase==3.28.0 # via pymatgen-analysis-diffusion +ast-serialize==0.3.0 + # via mypy asttokens==3.0.1 # via stack-data attrs==26.1.0 @@ -45,15 +47,15 @@ blinker==1.9.0 # via flask boltons==25.0.0 # via mp-api (pyproject.toml) -boto3==1.43.4 +boto3==1.43.6 # via mp-api (pyproject.toml) -botocore==1.43.4 +botocore==1.43.6 # via # boto3 # s3transfer bravado==12.0.1 # via mp-api (pyproject.toml) -bravado-core==6.3.1 +bravado-core==6.4.1 # via bravado cachetools==7.1.1 # via @@ -78,7 +80,7 @@ click==8.3.3 # uvicorn contourpy==1.3.3 # via matplotlib -coverage[toml]==7.13.5 +coverage[toml]==7.14.0 # via pytest-cov cryptography==48.0.0 # via @@ -90,7 +92,7 @@ custodian==2025.12.14 # via mp-api (pyproject.toml) cycler==0.12.1 # via matplotlib -cyclopts==4.11.2 +cyclopts==4.12.0 # via fastmcp decorator==5.2.1 # via ipython @@ -113,7 +115,7 @@ docutils==0.22.4 # sphinx email-validator==2.3.0 # via pydantic -emmet-core[all]==0.86.4 +emmet-core[all]==0.87.0rc1 # via mp-api (pyproject.toml) exceptiongroup==1.3.1 # via fastmcp @@ -157,7 +159,7 @@ httpx-sse==0.4.3 # via mcp identify==2.6.19 # via pre-commit -idna==3.13 +idna==3.15 # via # anyio # email-validator @@ -237,11 +239,13 @@ latexcodec==3.0.1 # via pybtex lazy-loader==0.5 # via scikit-image -librt==0.10.0 +librt==0.11.0 # via mypy +lobsterpy==0.6.1 + # via emmet-core lxml==6.1.0 # via pymatgen-core -markdown-it-py==4.0.0 +markdown-it-py==4.2.0 # via rich markupsafe==3.0.3 # via @@ -253,11 +257,11 @@ matplotlib==3.10.9 # ase # pymatgen-core # seaborn -matplotlib-inline==0.2.1 +matplotlib-inline==0.2.2 # via ipython mccabe==0.7.0 # via flake8 -mcp==1.27.0 +mcp==1.27.1 # via fastmcp mdurl==0.1.2 # via markdown-it-py @@ -282,13 +286,13 @@ msgpack==1.1.2 # via # bravado # bravado-core -mypy==1.20.2 +mypy==2.1.0 # via mp-api (pyproject.toml) mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.20.0 +narwhals==2.21.0 # via plotly networkx==3.6.1 # via @@ -301,6 +305,7 @@ numpy==2.4.4 # ase # contourpy # imageio + # lobsterpy # matplotlib # monty # pandas @@ -334,7 +339,7 @@ packaging==26.2 # sphinx palettable==3.3.3 # via pymatgen-core -pandas==3.0.2 +pandas==3.0.3 # via # pymatgen-core # seaborn @@ -387,6 +392,8 @@ pyarrow==24.0.0 # mp-api (pyproject.toml) pybtex==0.26.1 # via emmet-core +pycodcif==3.0.1 + # via emmet-core pycodestyle==2.14.0 # via # flake8 @@ -403,7 +410,7 @@ pydantic[email]==2.13.4 # pymatgen-io-validation pydantic-core==2.46.4 # via pydantic -pydantic-settings==2.14.0 +pydantic-settings==2.14.1 # via # emmet-core # mcp @@ -424,6 +431,7 @@ pyjwt[crypto]==2.12.1 pymatgen==2026.5.4 # via # emmet-core + # lobsterpy # mp-api (pyproject.toml) # mp-pyrho # pymatgen-analysis-alloys @@ -471,13 +479,13 @@ python-dateutil==2.9.0.post0 # bravado-core # matplotlib # pandas -python-discovery==1.3.0 +python-discovery==1.3.1 # via virtualenv python-dotenv==1.2.2 # via # fastmcp # pydantic-settings -python-multipart==0.0.27 +python-multipart==0.0.28 # via mcp pytz==2026.2 # via bravado-core @@ -495,7 +503,7 @@ referencing==0.37.0 # jsonschema # jsonschema-path # jsonschema-specifications -requests==2.33.1 +requests==2.34.0 # via # bravado # bravado-core @@ -576,7 +584,7 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sse-starlette==3.4.1 +sse-starlette==3.4.4 # via mcp stack-data==0.6.3 # via ipython @@ -602,13 +610,13 @@ traitlets==5.15.0 # matplotlib-inline typeguard==4.5.1 # via inflect -types-requests==2.33.0.20260503 +types-requests==2.33.0.20260513 # via # mp-api (pyproject.toml) # types-tqdm -types-setuptools==82.0.0.20260408 +types-setuptools==82.0.0.20260508 # via mp-api (pyproject.toml) -types-tqdm==4.67.3.20260408 +types-tqdm==4.67.3.20260508 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via @@ -640,7 +648,7 @@ uncertainties==3.2.3 # via pymatgen-core uri-template==1.3.0 # via jsonschema -urllib3==2.6.3 +urllib3==2.7.0 # via # botocore # requests @@ -649,7 +657,7 @@ uvicorn==0.46.0 # via # fastmcp # mcp -virtualenv==21.3.1 +virtualenv==21.3.2 # via pre-commit watchfiles==1.1.1 # via fastmcp diff --git a/tests/client/materials/test_chemenv.py b/tests/client/materials/test_chemenv.py index 77e693c0..8fd04f98 100644 --- a/tests/client/materials/test_chemenv.py +++ b/tests/client/materials/test_chemenv.py @@ -3,7 +3,6 @@ import pytest from mp_api._test_utils import client_search_testing, requires_api_key - from mp_api.client.routes.materials.chemenv import ChemenvRester @@ -34,11 +33,11 @@ def rester(): } custom_field_tests: dict = { - "material_ids": ["mp-22526"], + "material_ids": ["mp-149"], "elements": ["Si", "O"], "exclude_elements": ["Si", "O"], "chemenv_symbol": ["S:1"], - "chemenv_iupac": ["IC-12"], + "chemenv_iupac": ["A-2"], "chemenv_iucr": ["[2l]"], "chemenv_name": ["Octahedron"], "species": ["Cu2+"], diff --git a/tests/client/materials/test_electrodes.py b/tests/client/materials/test_electrodes.py index e6a53008..cb37d486 100644 --- a/tests/client/materials/test_electrodes.py +++ b/tests/client/materials/test_electrodes.py @@ -4,15 +4,14 @@ from pymatgen.core.periodic_table import Element from mp_api._test_utils import ( - client_search_testing, client_pagination, + client_search_testing, client_sort, requires_api_key, ) - from mp_api.client.routes.materials.electrodes import ( - ElectrodeRester, ConversionElectrodeRester, + ElectrodeRester, ) @@ -43,9 +42,9 @@ def conversion_rester(): sub_doc_fields: list = [] alt_name_dict: dict = { - "battery_ids": "battery_id", - "formula": "battery_id", - "exclude_elements": "battery_id", + "battery_ids": "battery_type", + "formula": "battery_type", + "exclude_elements": "battery_type", "num_elements": "nelements", "num_sites": "nsites", } @@ -93,14 +92,11 @@ def test_conversion_client(conversion_rester): @requires_api_key def test_pagination(): with ElectrodeRester() as rester: - client_pagination(rester.search, "battery_id") + client_pagination(rester.search, "material_ids") -@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) @requires_api_key -@pytest.mark.parametrize( - "sort_field", ["battery_id", "stability_charge", "average_voltage"] -) +@pytest.mark.parametrize("sort_field", ["stability_charge", "average_voltage"]) def test_sort(sort_field): with ElectrodeRester() as rester: - client_sort(rester.search, sort_field) + client_sort(rester.search, sort_field, default_fields=()) diff --git a/tests/client/materials/test_electronic_structure.py b/tests/client/materials/test_electronic_structure.py index a89cc730..f17378c5 100644 --- a/tests/client/materials/test_electronic_structure.py +++ b/tests/client/materials/test_electronic_structure.py @@ -1,9 +1,9 @@ +from typing import Any + import pytest from pymatgen.analysis.magnetism import Ordering -from typing import Any from mp_api._test_utils import client_search_testing, requires_api_key - from mp_api.client.core.exceptions import MPRestError from mp_api.client.routes.materials.electronic_structure import ( BandStructureRester, @@ -104,7 +104,7 @@ def test_bs_client(): with pytest.raises(MPRestError, match="No electronic structure data found."): _ = bs_rester.get_bandstructure_from_material_id("mp-0") - with pytest.raises(MPRestError, match="No object found"): + with pytest.raises(MPRestError, match="No bandstructure data found"): _ = bs_rester.get_bandstructure_from_task_id("mp-0") @@ -157,6 +157,8 @@ def test_dos_client(): with pytest.raises(MPRestError, match="To query orbital-projected DOS"): _ = dos_rester.search(projection_type="orbital") - assert dos_rester.get_dos_from_material_id("mp-0") is None - with pytest.raises(MPRestError, match="No object found"): + with pytest.raises(MPRestError, match="No electronic structure data found"): + _ = dos_rester.get_dos_from_material_id("mp-0") + + with pytest.raises(MPRestError, match="No DOS data found for task_id"): _ = dos_rester.get_dos_from_task_id("mp-0") diff --git a/tests/client/materials/test_eos.py b/tests/client/materials/test_eos.py index 3e633e49..dbce1bc2 100644 --- a/tests/client/materials/test_eos.py +++ b/tests/client/materials/test_eos.py @@ -3,7 +3,7 @@ import pytest from mp_api._test_utils import client_search_testing, requires_api_key - +from mp_api.client.core.exceptions import MPRestError, MPRestWarning from mp_api.client.routes.materials.eos import EOSRester @@ -26,9 +26,9 @@ def rester(): sub_doc_fields: list = [] -alt_name_dict: dict = {"material_ids": "material_id"} +alt_name_dict: dict = {"task_ids": "eos"} -custom_field_tests: dict = {"material_ids": ["mp-149"]} +custom_field_tests: dict = {"task_ids": ["mp-149"]} @requires_api_key @@ -42,3 +42,15 @@ def test_client(rester): custom_field_tests=custom_field_tests, sub_doc_fields=sub_doc_fields, ) + + +@requires_api_key +def test_warnings_errors(rester): + + with pytest.warns( + MPRestWarning, match="`material_id` has been replaced by `task_id`" + ): + rester.search(material_ids=["mp-149"], num_chunks=1, chunk_size=1) + + with pytest.raises(MPRestError, match="You have specified both"): + rester.search(material_ids=["mp-149"], task_ids=["mp-1"]) diff --git a/tests/client/materials/test_phonon.py b/tests/client/materials/test_phonon.py index 1beb1978..58050d5f 100644 --- a/tests/client/materials/test_phonon.py +++ b/tests/client/materials/test_phonon.py @@ -2,11 +2,9 @@ import numpy as np import pytest - from emmet.core.phonon import PhononBS, PhononDOS from mp_api._test_utils import client_search_testing, requires_api_key - from mp_api.client.core.exceptions import MPRestError from mp_api.client.routes.materials.phonon import PhononRester @@ -23,17 +21,22 @@ def test_phonon_search(): "fields", ], alt_name_dict={ - "material_ids": "material_id", + "phonon_ids": "identifier", }, custom_field_tests={ + # test search kwarg backwards compat "material_ids": ["mp-149", "mp-13"], - "material_ids": "mp-149", + "phonon_ids": ["ft", "mp-13"], + "phonon_ids": "mp-149", "phonon_method": "dfpt", }, sub_doc_fields=[], ) +# NOTE: below funcs still query legacy jsonl s3 objects +# they are key by 'mp-123' -> don't change id search strings +# to Alpha version @requires_api_key @pytest.mark.parametrize("use_document_model", [True, False]) def test_phonon_get_methods(use_document_model): diff --git a/tests/client/materials/test_provenance.py b/tests/client/materials/test_provenance.py index 9a460c7e..da855ff8 100644 --- a/tests/client/materials/test_provenance.py +++ b/tests/client/materials/test_provenance.py @@ -25,7 +25,7 @@ def rester(): alt_name_dict: dict = {"material_ids": "material_id"} -custom_field_tests: dict = {"material_ids": ["mp-149"]} +custom_field_tests: dict = {"material_ids": ["mp-13"]} @requires_api_key diff --git a/tests/client/materials/test_summary.py b/tests/client/materials/test_summary.py index 1d908342..7869cb2b 100644 --- a/tests/client/materials/test_summary.py +++ b/tests/client/materials/test_summary.py @@ -161,4 +161,4 @@ def test_pagination(): @pytest.mark.parametrize("sort_field", summary_sort_fields) def test_sort(sort_field: str): with SummaryRester() as rester: - client_sort(rester.search, sort_field) + client_sort(rester.search, sort_field, aux_query={sort_field: (0, 10)}) diff --git a/tests/client/materials/test_thermo.py b/tests/client/materials/test_thermo.py index 83885440..51dfd66c 100644 --- a/tests/client/materials/test_thermo.py +++ b/tests/client/materials/test_thermo.py @@ -5,7 +5,6 @@ from pymatgen.analysis.phase_diagram import PhaseDiagram from mp_api._test_utils import client_search_testing, requires_api_key - from mp_api.client.routes.materials.thermo import ThermoRester @@ -30,7 +29,7 @@ def rester(): alt_name_dict: dict = { "formula": "formula_pretty", "material_ids": "material_id", - "thermo_ids": "thermo_id", + "thermo_ids": "material_id", "thermo_types": "thermo_type", "total_energy": "energy_per_atom", "formation_energy": "formation_energy_per_atom", @@ -65,7 +64,8 @@ def test_client(rester): def test_get_phase_diagram_from_chemsys(): # Test that a phase diagram is returned + pd = ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm", thermo_type="GGA_GGA+U") assert isinstance( - ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm", thermo_type="GGA_GGA+U"), + pd, PhaseDiagram, ) diff --git a/tests/client/materials/test_xas.py b/tests/client/materials/test_xas.py index 850d6b8c..aea83274 100644 --- a/tests/client/materials/test_xas.py +++ b/tests/client/materials/test_xas.py @@ -1,16 +1,15 @@ -import pytest from typing import Any +import pytest from emmet.core.types.enums import XasEdge, XasType from pymatgen.core.periodic_table import Element from mp_api._test_utils import ( - client_search_testing, client_pagination, + client_search_testing, client_sort, requires_api_key, ) - from mp_api.client.routes.materials.xas import XASRester @@ -35,8 +34,8 @@ def rester(): alt_name_dict: dict[str, str] = { "required_elements": "elements", "formula": "formula_pretty", - "exclude_elements": "material_id", - "spectrum_ids": "spectrum_id", + "exclude_elements": "absorbing_element", + "spectrum_ids": "absorbing_element", } custom_field_tests: dict[str, Any] = { @@ -47,7 +46,6 @@ def rester(): "formula": "Ce(WO4)2", "chemsys": "Ce-O-W", "elements": ["Ce"], - "spectrum_ids": ["mp-1194531-XANES-Fe-L2", "mp-1194531-XANES-Fe-K"], } @@ -69,14 +67,19 @@ def test_client(rester): @requires_api_key def test_pagination(): with XASRester() as rester: - client_pagination(rester.search, "spectrum_id") + client_pagination( + rester.search, "task_id", additional_fields=["spectrum_type", "edge"] + ) -@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) @requires_api_key @pytest.mark.parametrize( - "sort_field", ["material_id", "absorbing_element", "spectrum_id"] + "sort_field", + [ + "spectrum_type", + "absorbing_element", + ], ) def test_sort(sort_field): with XASRester() as rester: - client_sort(rester.search, sort_field) + client_sort(rester.search, sort_field, default_fields=()) diff --git a/tests/client/molecules/test_jcesr.py b/tests/client/molecules/test_jcesr.py index deb8bf21..cd7dd97b 100644 --- a/tests/client/molecules/test_jcesr.py +++ b/tests/client/molecules/test_jcesr.py @@ -1,14 +1,14 @@ import os + +import pytest +from pymatgen.core.periodic_table import Element + from mp_api._test_utils import ( - client_search_testing, client_pagination, + client_search_testing, client_sort, requires_api_key, ) - -import pytest -from pymatgen.core.periodic_table import Element - from mp_api.client.core.exceptions import MPRestWarning from mp_api.client.routes.molecules.jcesr import JcesrMoleculesRester @@ -68,9 +68,13 @@ def test_pagination(): client_pagination(rester.search, "task_id") -@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) @requires_api_key -@pytest.mark.parametrize("sort_field", ["task_id", "IE", "EA"]) +@pytest.mark.parametrize( + "sort_field", + [ + "task_id", + ], +) def test_sort(sort_field): with JcesrMoleculesRester() as rester: - client_sort(rester.search, sort_field) + client_sort(rester.search, sort_field, default_fields=()) diff --git a/tests/client/molecules/test_summary.py b/tests/client/molecules/test_summary.py index 910adf47..cc96b3b7 100644 --- a/tests/client/molecules/test_summary.py +++ b/tests/client/molecules/test_summary.py @@ -5,8 +5,8 @@ from emmet.core.mpid import MPculeID from mp_api._test_utils import ( - client_search_testing, client_pagination, + client_search_testing, client_sort, requires_api_key, ) @@ -66,9 +66,8 @@ def test_pagination(): client_pagination(rester.search, "molecule_id") -@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) @requires_api_key -@pytest.mark.parametrize("sort_field", ["molecule_id", "charge", "spin_multiplicity"]) +@pytest.mark.parametrize("sort_field", ["charge", "spin_multiplicity"]) def test_sort(sort_field): with MoleculesSummaryRester() as rester: - client_sort(rester.search, sort_field) + client_sort(rester.search, sort_field, default_fields=()) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 7d91f6e5..caf786f8 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -48,12 +48,16 @@ ] # temp # Temporarily ignore molecules resters while molecules query operators are changed +# Temp ignore eos rester for material_id -> task_id schema change with MPRester() as mpr: resters_to_test = [ rester for rester in mpr._all_resters if ( - "molecule" not in rester._class_name.lower() + not any( + substr in rester._class_name.lower() + for substr in ("molecule", "electrode") + ) and not (pmg_alloys is None and "alloys" in str(rester).lower()) ) ] diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index 521518fc..06fc3799 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -1,18 +1,17 @@ +import importlib import itertools import os import random -import importlib -import requests from tempfile import NamedTemporaryFile import numpy as np import pytest +import requests from emmet.core.mpid import MPID, AlphaID +from emmet.core.phonon import PhononBS, PhononDOS from emmet.core.tasks import TaskDoc -from emmet.core.vasp.calc_types import CalcType -from emmet.core.phonon import PhononDOS, PhononBS from emmet.core.types.enums import ThermoType - +from emmet.core.vasp.calc_types import CalcType from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry, PourbaixDiagram, PourbaixEntry from pymatgen.analysis.wulff import WulffShape @@ -26,16 +25,15 @@ ) from pymatgen.electronic_structure.dos import CompleteDos from pymatgen.entries.compatibility import ( - MaterialsProjectAqueousCompatibility, MaterialsProject2020Compatibility, + MaterialsProjectAqueousCompatibility, ) -from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme from pymatgen.entries.computed_entries import ComputedEntry, GibbsComputedStructureEntry +from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme from pymatgen.io.cif import CifParser from pymatgen.io.vasp import Chgcar from mp_api._test_utils import requires_api_key - from mp_api.client import MPRester from mp_api.client.core import MPRestError, MPRestWarning from mp_api.client.core.settings import _DEFAULT_ENDPOINT @@ -56,7 +54,9 @@ def mpr(): @requires_api_key class TestMPRester: fake_mp_api_key = "12345678901234567890123456789012" # 32 chars - default_endpoint = _DEFAULT_ENDPOINT + default_endpoint = _DEFAULT_ENDPOINT + ( + "/" if not _DEFAULT_ENDPOINT.endswith("/") else "" + ) def test_get_structure_by_material_id(self, mpr): s0 = mpr.get_structure_by_material_id("mp-149") @@ -69,9 +69,14 @@ def test_get_structure_by_material_id(self, mpr): assert {s.formula for s in s2} == {"Si2"} def test_get_database_version(self, mpr): - db_version = mpr.get_database_version() + db_version = mpr.db_version assert db_version is not None + with pytest.warns( + MPRestWarning, match="`get_database_version` has been deprecated" + ): + assert db_version == mpr.get_database_version() + def test_get_material_id_from_task_id(self, mpr): assert mpr.get_material_id_from_task_id("mp-540081") == "mp-19017" @@ -82,8 +87,7 @@ def test_get_task_ids_associated_with_material_id(self, mpr): assert len(results) > 0 def test_get_material_id_references(self, mpr): - data = mpr.get_material_id_references("mp-123") - assert len(data) > 5 + assert len(mpr.get_material_id_references("mp-123")) > 0 def test_get_material_id_doc(self, mpr): mp_ids = mpr.get_material_ids("Al2O3") @@ -168,6 +172,7 @@ def test_find_structure( <= 2 ) + @pytest.mark.skip(reason="Re-enable after deployment of db version 2026.04.13") def test_get_bandstructure_by_material_id(self, mpr): bs = mpr.get_bandstructure_by_material_id("mp-149") assert isinstance(bs, BandStructureSymmLine) @@ -175,6 +180,7 @@ def test_get_bandstructure_by_material_id(self, mpr): assert isinstance(bs_uniform, BandStructure) assert not isinstance(bs_uniform, BandStructureSymmLine) + @pytest.mark.skip(reason="Re-enable after deployment of db version 2026.04.13") def test_get_dos_by_id(self, mpr): dos = mpr.get_dos_by_material_id("mp-149") assert isinstance(dos, CompleteDos) @@ -185,74 +191,77 @@ def test_get_entry_by_material_id(self, mpr): assert e[0].composition.reduced_formula == "LiFePO4" def test_get_entries(self, mpr): + + # Avoiding "golden test data": freshly retrieve 5 thermo docs and + # perform entry querying based off those entries + thermo_docs = mpr.materials.thermo.search( + num_chunks=1, chunk_size=5, num_elements=(2, 3) + ) + syms = ["Li", "Fe", "O"] chemsys = "Li-Fe-O" with pytest.warns( DeprecationWarning, match="The `inc_structure` argument is deprecated" ): - entries = mpr.get_entries(chemsys, inc_structure=False) + entries = mpr.get_entries(thermo_docs[0].chemsys, inc_structure=False) - elements = {Element(sym) for sym in syms} - for e in entries: - assert isinstance(e, ComputedEntry) - assert set(e.composition.elements).issubset(elements) + assert all(isinstance(e, ComputedEntry) for e in entries) + assert all( + set(e.composition.elements).issubset(thermo_docs[0].elements) + for e in entries + ) # Formula - formula = "SiO2" - entries = mpr.get_entries(formula) - - for e in entries: - assert isinstance(e, ComputedEntry) + entries = mpr.get_entries(thermo_docs[1].formula_pretty) + assert all(isinstance(e, ComputedEntry) for e in entries) # Property data - formula = "BiFeO3" - entries = mpr.get_entries(formula, property_data=["energy_above_hull"]) + entries = mpr.get_entries( + thermo_docs[2].formula_pretty, property_data=["energy_above_hull"] + ) - for e in entries: - assert e.data.get("energy_above_hull", None) is not None + assert all(e.data.get("energy_above_hull", None) is not None for e in entries) # Conventional structure - entry = next( - e - for e in mpr.get_entry_by_material_id( - "mp-22526", conventional_unit_cell=True - ) - if e.entry_id == "mp-22526-r2SCAN" + as_conv = mpr.get_entry_by_material_id( + thermo_docs[3].material_id, conventional_unit_cell=True ) - - s = entry.structure - assert pytest.approx(s.lattice.a) == s.lattice.b - assert pytest.approx(s.lattice.a) != s.lattice.c - assert pytest.approx(s.lattice.alpha) == 90 - assert pytest.approx(s.lattice.beta) == 90 - assert pytest.approx(s.lattice.gamma) == 120 + assert all(e.structure == e.structure.to_conventional() for e in as_conv) # Ensure energy per atom is same - entry = next( - e - for e in mpr.get_entry_by_material_id( - "mp-22526", conventional_unit_cell=False + non_standardized = mpr.get_entry_by_material_id( + thermo_docs[3].material_id, conventional_unit_cell=False + ) + assert all( + e.uncorrected_energy_per_atom + == pytest.approx( + next( + f for f in non_standardized if f.entry_id == e.entry_id + ).uncorrected_energy_per_atom ) - if e.entry_id == "mp-22526-r2SCAN" + for e in as_conv ) - s = entry.structure - assert pytest.approx(s.lattice.a) == s.lattice.b - assert pytest.approx(s.lattice.a, abs=1e-3) == s.lattice.c - assert pytest.approx(s.lattice.alpha, abs=1e-3) == s.lattice.beta - assert pytest.approx(s.lattice.alpha, abs=1e-3) == s.lattice.gamma # Additional criteria entry = mpr.get_entries( - "mp-149", - additional_criteria={"energy_above_hull": (0.0, 10)}, + thermo_docs[4].material_id, + additional_criteria={ + "energy_above_hull": (0.0, 2 * thermo_docs[4].energy_above_hull) + }, property_data=["energy_above_hull"], )[0] assert "energy_above_hull" in entry.data + # Test out of range entries = mpr.get_entries( - "mp-149", - additional_criteria={"energy_above_hull": (1, 10)}, + thermo_docs[4].material_id, + additional_criteria={ + "energy_above_hull": ( + 1.5 * thermo_docs[4].energy_above_hull, + 2 * thermo_docs[4].energy_above_hull, + ) + }, property_data=["energy_above_hull"], ) @@ -378,7 +387,7 @@ def test_get_ion_entries(self, mpr): # the rf factor correction is necessary to make sure the composition # of the reference solid is normalized to a single formula unit ref_solid_entry = next( - e for e in ion_ref_entries if e.entry_id.startswith("mp-4770") + e for e in ion_ref_entries if str(e.entry_id).startswith("mp-4770") ) rf = ref_solid_entry.composition.get_reduced_composition_and_factor()[1] solid_energy = ion_ref_pd.get_form_energy(ref_solid_entry) / rf @@ -419,14 +428,16 @@ def test_get_wulff_shape(self, mpr): assert isinstance(ws, WulffShape) def test_large_list(self, mpr): + num_chunks = 10 + chunk_size = 500 mpids = [ str(doc.material_id) for doc in mpr.materials.summary.search( - chunk_size=1000, num_chunks=10, fields=["material_id"] + chunk_size=chunk_size, num_chunks=num_chunks, fields=["material_id"] ) ] docs = mpr.materials.summary.search(material_ids=mpids, fields=["material_id"]) - assert len(docs) == 10000 + assert len(docs) == chunk_size * num_chunks def test_get_api_key_endpoint_from_env_var(self, monkeypatch: pytest.MonkeyPatch): """Ensure the MP_API_KEY and MP_API_ENDPOINT from environment variable @@ -553,81 +564,83 @@ def test_get_cohesive_energy(self): mpr.get_cohesive_energy("mp-1") @pytest.mark.parametrize( - "chemsys, thermo_type", - [ - [("Fe", "P"), "GGA_GGA+U"], - [("Li", "S"), ThermoType.GGA_GGA_U_R2SCAN], - [("Ni", "Se"), ThermoType.R2SCAN], - [("Ni", "Kr"), "R2SCAN"], - ], + "thermo_type", ["GGA_GGA+U", ThermoType.GGA_GGA_U_R2SCAN, "r2SCAN"] ) - def test_get_stability(self, chemsys, thermo_type): + def test_get_stability(self, thermo_type): """ This test is adapted from the pymatgen one - the scope is broadened to include more diverse chemical environments and thermo types which reflect the scope of the current MP database. """ with MPRester() as mpr: - entries = mpr.get_entries_in_chemsys( - chemsys, additional_criteria={"thermo_types": [thermo_type]} - ) - no_compound_entries = all( - len(entry.composition.elements) == 1 for entry in entries - ) + # No golden test data. Always test on fetched thermo data + chemsys_to_test: set[str] = { + doc.chemsys + for doc in mpr.materials.thermo.search( + thermo_types=[thermo_type], + num_elements=2, + num_chunks=1, + chunk_size=4, + fields=["chemsys"], + ) + } - modified_entries = [ - ComputedEntry( - entry.composition, - entry.uncorrected_energy + 0.01, - parameters=entry.parameters, - entry_id=f"mod_{entry.entry_id}", + for chemsys in chemsys_to_test: + + entries = mpr.get_entries_in_chemsys( + chemsys, additional_criteria={"thermo_types": [thermo_type]} ) - for entry in entries - if entry.composition.reduced_formula in ["Fe2P", "".join(chemsys)] - ] - if len(modified_entries) == 0: - # create fake entry to get PD retrieval to fail modified_entries = [ ComputedEntry( - "".join(chemsys), - np.average([entry.energy for entry in entries]), - entry_id=f"hypothetical", + entry.composition, + entry.uncorrected_energy + 0.01, + parameters=entry.parameters, + entry_id=f"mod_{entry.entry_id}", ) + for entry in entries + if entry.entry_id == entries[0].entry_id ] - if no_compound_entries: - with pytest.warns( - MPRestWarning, match="No phase diagram data available" + if ( + all(len(entry.composition.elements) == 1 for entry in entries) + and chemsys.count("-") > 0 ): - mpr.get_stability(modified_entries, thermo_type=thermo_type) - return - - else: - rester_ehulls = mpr.get_stability( - modified_entries, thermo_type=thermo_type - ) - - all_entries = entries + modified_entries - - compat = None - if thermo_type == "GGA_GGA+U": - compat = MaterialsProject2020Compatibility() - elif thermo_type == "GGA_GGA+U_R2SCAN": - compat = MaterialsProjectDFTMixingScheme(run_type_2="r2SCAN") - - if compat: - all_entries = compat.process_entries(all_entries) + # For a multi-element chemsys with no multinaries, only elementals, + # there should be no phase diagram data available. + with pytest.warns( + MPRestWarning, match="No phase diagram data available" + ): + mpr.get_stability(modified_entries, thermo_type=thermo_type) + return + + else: + rester_ehulls = mpr.get_stability( + modified_entries, thermo_type=thermo_type + ) - pd = PhaseDiagram(all_entries) - for entry in all_entries: - if str(entry.entry_id).startswith("mod"): - for dct in rester_ehulls: - if dct["entry_id"] == entry.entry_id: - data = dct - break - assert pd.get_e_above_hull(entry) == pytest.approx(data["e_above_hull"]) + all_entries = entries + modified_entries + + compat = None + if thermo_type == "GGA_GGA+U": + compat = MaterialsProject2020Compatibility() + elif thermo_type == "GGA_GGA+U_R2SCAN": + compat = MaterialsProjectDFTMixingScheme(run_type_2="r2SCAN") + + if compat: + all_entries = compat.process_entries(all_entries) + + pd = PhaseDiagram(all_entries) + for entry in all_entries: + if str(entry.entry_id).startswith("mod"): + for dct in rester_ehulls: + if dct["entry_id"] == entry.entry_id: + data = dct + break + assert pd.get_e_above_hull(entry) == pytest.approx( + data["e_above_hull"] + ) @pytest.mark.parametrize( "mpid, working_ion, thermo_type", @@ -673,7 +686,7 @@ def test_nomad_integration(self, mpr): target_mpid, file_patterns=["some_pattern"] ) assert all( - isinstance(entry["task_id"], MPID) + isinstance(entry["task_id"], AlphaID) and isinstance(entry["calc_type"], CalcType) for entry in calc_type_map[target_mpid] ) @@ -687,7 +700,7 @@ def test_nomad_integration(self, mpr): [MPID(target_mpid)], calc_types=["GGA Deformation"] ) assert all( - isinstance(entry["task_id"], MPID) + isinstance(entry["task_id"], AlphaID) and entry["calc_type"].value == "GGA Deformation" for entry in calc_type_map[target_mpid] ) @@ -700,14 +713,16 @@ def test_nomad_integration(self, mpr): def test_db_warning(self, monkeypatch: pytest.MonkeyPatch): from pathlib import Path + import yaml + from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS with NamedTemporaryFile(suffix=".yaml") as tmp_log: monkeypatch.setattr(MAPI_CLIENT_SETTINGS, "LOG_FILE", Path(tmp_log.name)) with MPRester(notify_db_version=True) as mpr: - db_version = mpr.get_database_version() + db_version = mpr.db_version parsed_db_ver = yaml.safe_load(Path(tmp_log.name).read_text()).get( "MAPI_DB_VERSION" diff --git a/tests/mcp/test_tools.py b/tests/mcp/test_tools.py index 519abf96..22005ea0 100644 --- a/tests/mcp/test_tools.py +++ b/tests/mcp/test_tools.py @@ -33,7 +33,7 @@ def test_core_search_tools(): and doc.metadata is None and doc.title.startswith("mp-") and doc.text == robo_descs[doc.id] - and doc.url == f"https://next-gen.materialsproject.org/materials/{doc.id}" + and doc.url == f"https://next-gen.materialsproject.org/materials/{doc.title}" for doc in search_results.results )