diff --git a/docs/modules/datasources_analysis.rst b/docs/modules/datasources_analysis.rst index 7fc88beb8..0f541604b 100644 --- a/docs/modules/datasources_analysis.rst +++ b/docs/modules/datasources_analysis.rst @@ -19,7 +19,7 @@ Used for fetching initial conditions for inference and validation data for scori .. currentmodule:: earth2studio -.. badge-filter:: region:global region:na region:as +.. badge-filter:: region:global region:na region:as region:europe dataclass:analysis dataclass:reanalysis dataclass:observation dataclass:simulation product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu :filter-mode: or diff --git a/docs/modules/datasources_forecast.rst b/docs/modules/datasources_forecast.rst index c8d50b545..416ea980d 100644 --- a/docs/modules/datasources_forecast.rst +++ b/docs/modules/datasources_forecast.rst @@ -23,6 +23,7 @@ Typically used in intercomparison workflows. :template: datasource.rst data.AIFS_FX + data.CAMS_FX data.AIFS_ENS_FX data.GFS_FX data.GEFS_FX diff --git a/earth2studio/data/__init__.py b/earth2studio/data/__init__.py index a27049741..2147e9a84 100644 --- a/earth2studio/data/__init__.py +++ b/earth2studio/data/__init__.py @@ -17,6 +17,7 @@ from .ace2 import ACE2ERA5Data from .arco import ARCO from .base import DataSource, ForecastSource +from .cams import CAMS_FX from .cbottle import CBottle3D from .cds import CDS from .cmip6 import CMIP6, CMIP6MultiRealm diff --git a/earth2studio/data/cams.py b/earth2studio/data/cams.py new file mode 100644 index 000000000..f2226f57c --- /dev/null +++ b/earth2studio/data/cams.py @@ -0,0 +1,463 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import hashlib +import os +import pathlib +import shutil +import tempfile +from dataclasses import dataclass +from datetime import datetime, timedelta +from time import sleep + +import numpy as np +import xarray as xr +from loguru import logger + +from earth2studio.data.utils import ( + datasource_cache_root, + prep_forecast_inputs, +) +from earth2studio.lexicon import CAMSGlobalLexicon +from earth2studio.utils.imports import ( + OptionalDependencyFailure, + check_optional_dependencies, +) +from earth2studio.utils.type import LeadTimeArray, TimeArray, VariableArray + +try: + import cdsapi +except ImportError: + OptionalDependencyFailure("data") + cdsapi = None + + +@dataclass +class _CAMSVarInfo: + e2s_name: str + api_name: str + nc_key: str + dataset: str + level: str + index: int + + +def _resolve_variable(e2s_name: str, index: int) -> _CAMSVarInfo: + cams_key, _ = CAMSGlobalLexicon[e2s_name] + dataset, api_name, nc_key, level = cams_key.split("::") + return _CAMSVarInfo( + e2s_name=e2s_name, + api_name=api_name, + nc_key=nc_key, + dataset=dataset, + level=level, + index=index, + ) + + +def _download_cams_netcdf( + client: "cdsapi.Client", + dataset: str, + request_body: dict, + cache_path: pathlib.Path, + verbose: bool = True, +) -> pathlib.Path: + if cache_path.is_file(): + return cache_path + + r = client.retrieve(dataset, request_body) + while True: + r.update() + reply = r.reply + if verbose: + logger.debug(f"Request ID:{reply['request_id']}, state: {reply['state']}") + if reply["state"] == "completed": + break + elif reply["state"] in ("queued", "running"): + sleep(5.0) + elif reply["state"] in ("failed",): + raise RuntimeError( + f"CAMS request failed for {dataset}: " + + reply.get("error", {}).get("message", "unknown error") + ) + else: + sleep(2.0) + tmp_fd, tmp_name = tempfile.mkstemp(dir=cache_path.parent, suffix=".nc.tmp") + try: + os.close(tmp_fd) + r.download(tmp_name) + os.replace(tmp_name, cache_path) + except Exception: + pathlib.Path(tmp_name).unlink(missing_ok=True) + raise + return cache_path + + +def _extract_field( + ds: xr.Dataset, + nc_key: str, + lead_time_hours: int | None = None, + pressure_level: int | None = None, +) -> np.ndarray: + if nc_key not in ds: + raise ValueError( + f"Variable '{nc_key}' not found in NetCDF. Available: {list(ds.data_vars)}" + ) + field = ds[nc_key] + non_spatial = [d for d in field.dims if d not in ("latitude", "longitude")] + isel: dict[str, int] = {} + for d in non_spatial: + if d == "forecast_period" and lead_time_hours is not None: + fp_vals = field.coords["forecast_period"].values.astype(float) + target = float(lead_time_hours) + nearest_idx = int(np.argmin(np.abs(fp_vals - target))) + isel[d] = nearest_idx + elif d in ("pressure_level", "isobaricInhPa") and pressure_level is not None: + pl_vals = field.coords[d].values.astype(float) + target_pl = float(pressure_level) + nearest_idx = int(np.argmin(np.abs(pl_vals - target_pl))) + isel[d] = nearest_idx + else: + isel[d] = 0 + if isel: + field = field.isel(isel) + return field.values + + +@check_optional_dependencies() +class CAMS_FX: + """CAMS Global atmospheric composition forecast data source. + + Uses the ``cams-global-atmospheric-composition-forecasts`` dataset. + Grid is 0.4 deg global (451 x 900). + + Parameters + ---------- + cache : bool, optional + Cache data source on local memory, by default True + verbose : bool, optional + Print download progress, by default True + + Warning + ------- + This is a remote data source and can potentially download a large amount of data + to your local machine for large requests. + + Note + ---- + Additional information on the data repository, registration, and authentication can + be referenced here: + + - https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts + - https://ads.atmosphere.copernicus.eu/how-to-api + + The API endpoint for this data source varies from the Climate Data Store (CDS), be + sure your api config has the correct url. + + Badges + ------ + region:global dataclass:simulation product:wind product:temp product:atmos + """ + + MAX_LEAD_HOURS = 120 + CAMS_MIN_TIME = datetime(2015, 1, 1) + CAMS_DATASET_URI = "cams-global-atmospheric-composition-forecasts" + CAMS_LAT = np.linspace(90, -90, 451) + CAMS_LON = np.linspace(0, 359.6, 900) + + def __init__(self, cache: bool = True, verbose: bool = True): + self._cache = cache + self._verbose = verbose + self._cds_client: "cdsapi.Client | None" = None + + @property + def _client(self) -> "cdsapi.Client": + if self._cds_client is None: + if cdsapi is None: + raise ImportError( + "cdsapi is required for CAMS_FX. " + "Install with: pip install 'earth2studio[data]'" + ) + self._cds_client = cdsapi.Client( + debug=False, quiet=True, wait_until_complete=False + ) + return self._cds_client + + def __call__( + self, + time: datetime | list[datetime] | TimeArray, + lead_time: timedelta | list[timedelta] | LeadTimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Retrieve CAMS Global forecast data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Forecast initialization times (UTC). + lead_time : timedelta | list[timedelta] | LeadTimeArray + Forecast lead times. + variable : str | list[str] | VariableArray + Variables to return. Must be in CAMSGlobalLexicon. + + Returns + ------- + xr.DataArray + CAMS forecast data array with dims [time, lead_time, variable, lat, lon] + """ + time, lead_time, variable = prep_forecast_inputs(time, lead_time, variable) + self._validate_time(time) + self.cache.mkdir(parents=True, exist_ok=True) + + data_arrays = [] + for t0 in time: + da = self._fetch_forecast(t0, lead_time, variable) + data_arrays.append(da) + + if not self._cache: + shutil.rmtree(self.cache) + + return xr.concat(data_arrays, dim="time") + + async def fetch( + self, + time: datetime | list[datetime] | TimeArray, + lead_time: timedelta | list[timedelta] | LeadTimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Async retrieval of CAMS Global forecast data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Forecast initialization times (UTC). + lead_time : timedelta | list[timedelta] | LeadTimeArray + Forecast lead times. + variable : str | list[str] | VariableArray + Variables to return. Must be in CAMSGlobalLexicon. + + Returns + ------- + xr.DataArray + CAMS forecast data array with dims [time, lead_time, variable, lat, lon] + """ + return await asyncio.to_thread(self.__call__, time, lead_time, variable) + + @classmethod + def available( + cls, + time: datetime | list[datetime], + ) -> bool: + """Check if CAMS Global forecast data is available for the requested times. + + Parameters + ---------- + time : datetime | list[datetime] + Timestamps to check availability for. + + Returns + ------- + bool + True if all requested times are within the valid range. + """ + if isinstance(time, datetime): + time = [time] + return all( + (t.replace(tzinfo=None) if t.tzinfo else t) >= cls.CAMS_MIN_TIME + for t in time + ) + + @classmethod + def _validate_time(cls, times: list[datetime]) -> None: + """Validate that requested times are valid for CAMS Global forecast.""" + for t in times: + t_naive = t.replace(tzinfo=None) if t.tzinfo else t + if t_naive < cls.CAMS_MIN_TIME: + raise ValueError( + f"Requested time {t} is before CAMS Global forecast availability " + f"(earliest: {cls.CAMS_MIN_TIME})" + ) + if t_naive.minute != 0 or t_naive.second != 0: + raise ValueError( + f"Requested time {t} must be on the hour for CAMS Global forecast" + ) + + @staticmethod + def _validate_leadtime(lead_times: list[timedelta], max_hours: int) -> None: + """Validate that requested lead times are valid.""" + for lt in lead_times: + hours = int(lt.total_seconds() // 3600) + if lt.total_seconds() % 3600 != 0: + raise ValueError(f"Lead time {lt} must be a whole number of hours") + if hours < 0 or hours > max_hours: + raise ValueError( + f"Lead time {lt} ({hours}h) outside valid range [0, {max_hours}]h" + ) + + def _fetch_forecast( + self, + time: datetime, + lead_times: np.ndarray, + variables: list[str], + ) -> xr.DataArray: + var_infos = [] + for i, v in enumerate(variables): + info = _resolve_variable(v, i) + var_infos.append(info) + + lead_hours = [ + str(int(np.timedelta64(lt, "h").astype(int))) for lt in lead_times + ] + + self._validate_leadtime( + [timedelta(hours=int(h)) for h in lead_hours], self.MAX_LEAD_HOURS + ) + + # Separate surface and pressure-level variables; they need different + # API requests because pressure-level vars require the pressure_level + # parameter and are only available at 3-hourly lead times. + surface_infos = [vi for vi in var_infos if not vi.level] + pressure_infos = [vi for vi in var_infos if vi.level] + + # Validate that pressure-level lead times are multiples of 3 hours + if pressure_infos: + for lt_h in lead_hours: + if int(lt_h) % 3 != 0: + raise ValueError( + f"Lead time {lt_h}h is not a multiple of 3 hours. " + "Pressure-level variables in CAMS Global forecasts " + "are only available at 3-hourly intervals (0, 3, 6, ...)." + ) + + # Download surface variables + surface_ds: xr.Dataset | None = None + if surface_infos: + surface_api_vars = list(dict.fromkeys(vi.api_name for vi in surface_infos)) + nc_path = self._download_cached(time, surface_api_vars, lead_hours) + surface_ds = xr.open_dataset(nc_path, decode_timedelta=False) + + # Download pressure-level variables (grouped by unique levels) + pressure_ds: xr.Dataset | None = None + if pressure_infos: + pressure_api_vars = list( + dict.fromkeys(vi.api_name for vi in pressure_infos) + ) + pressure_levels = sorted({vi.level for vi in pressure_infos}, key=int) + nc_path = self._download_cached( + time, pressure_api_vars, lead_hours, pressure_levels + ) + pressure_ds = xr.open_dataset(nc_path, decode_timedelta=False) + + # Use whichever dataset is available to read grid coordinates + ref_ds = surface_ds if surface_ds is not None else pressure_ds + if ref_ds is None: + raise ValueError( + "No variables to fetch – both surface and pressure lists are empty." + ) + + da = xr.DataArray( + data=np.empty( + ( + 1, + len(lead_times), + len(variables), + len(self.CAMS_LAT), + len(self.CAMS_LON), + ) + ), + dims=["time", "lead_time", "variable", "lat", "lon"], + coords={ + "time": [time], + "lead_time": lead_times, + "variable": variables, + "lat": self.CAMS_LAT, + "lon": self.CAMS_LON, + }, + ) + + for lt_idx, lt_h in enumerate(lead_hours): + for info in var_infos: + _, modifier = CAMSGlobalLexicon[info.e2s_name] + ds = pressure_ds if info.level else surface_ds + if ds is None: # pragma: no cover + raise RuntimeError( + f"Dataset for variable {info.e2s_name} is unexpectedly None" + ) + da[0, lt_idx, info.index] = modifier( + _extract_field( + ds, + info.nc_key, + lead_time_hours=int(lt_h), + pressure_level=int(info.level) if info.level else None, + ) + ) + + if surface_ds is not None: + surface_ds.close() + if pressure_ds is not None: + pressure_ds.close() + return da + + def _download_cached( + self, + time: datetime, + api_vars: list[str], + lead_hours: list[str], + pressure_levels: list[str] | None = None, + ) -> pathlib.Path: + date_str = time.strftime("%Y-%m-%d") + pl_part = ( + f"_pl{'_'.join(sorted(pressure_levels, key=int))}" + if pressure_levels + else "" + ) + sha = hashlib.sha256( + f"cams_fx_{'_'.join(sorted(api_vars))}" + f"_{'_'.join(sorted(lead_hours, key=int))}" + f"{pl_part}" + f"_{date_str}_{time.hour:02d}".encode() + ) + cache_path = self.cache / (sha.hexdigest() + ".nc") + + request_body: dict = { + "variable": api_vars, + "date": [f"{date_str}/{date_str}"], + "type": ["forecast"], + "time": [f"{time.hour:02d}:00"], + "leadtime_hour": lead_hours, + "data_format": "netcdf", + } + if pressure_levels: + request_body["pressure_level"] = pressure_levels + + if self._verbose: + logger.info( + f"Fetching CAMS Global forecast for {date_str} " + f"{time.hour:02d}:00 lead_hours={lead_hours} vars={api_vars}" + + (f" pressure_levels={pressure_levels}" if pressure_levels else "") + ) + return _download_cams_netcdf( + self._client, self.CAMS_DATASET_URI, request_body, cache_path, self._verbose + ) + + @property + def cache(self) -> pathlib.Path: + """Cache location.""" + cache_location = pathlib.Path(datasource_cache_root()) / "cams" + if not self._cache: + cache_location = cache_location / "tmp_cams_fx" + return cache_location diff --git a/earth2studio/lexicon/__init__.py b/earth2studio/lexicon/__init__.py index 34b01a31d..7eadc1c8e 100644 --- a/earth2studio/lexicon/__init__.py +++ b/earth2studio/lexicon/__init__.py @@ -16,6 +16,7 @@ from .ace import ACELexicon from .arco import ARCOLexicon +from .cams import CAMSGlobalLexicon from .cbottle import CBottleLexicon from .cds import CDSLexicon from .cmip6 import CMIP6Lexicon diff --git a/earth2studio/lexicon/base.py b/earth2studio/lexicon/base.py index 88f5fdac2..a85fc989f 100644 --- a/earth2studio/lexicon/base.py +++ b/earth2studio/lexicon/base.py @@ -279,6 +279,16 @@ def __contains__(cls, val: object) -> bool: "strd06": "surface long-wave (thermal) radiation downwards (J m-2) past 6 hours", "sf": "snowfall water equivalent (kg m-2)", "ro": "runoff water equivalent (surface plus subsurface) (kg m-2)", + "aod550": "total aerosol optical depth at 550 nm (dimensionless)", + "duaod550": "dust aerosol optical depth at 550 nm (dimensionless)", + "omaod550": "organic matter aerosol optical depth at 550 nm (dimensionless)", + "bcaod550": "black carbon aerosol optical depth at 550 nm (dimensionless)", + "ssaod550": "sea salt aerosol optical depth at 550 nm (dimensionless)", + "suaod550": "sulphate aerosol optical depth at 550 nm (dimensionless)", + "tcco": "total column carbon monoxide (kg m-2)", + "tcno2": "total column nitrogen dioxide (kg m-2)", + "tco3": "total column ozone (kg m-2)", + "tcso2": "total column sulphur dioxide (kg m-2)", } diff --git a/earth2studio/lexicon/cams.py b/earth2studio/lexicon/cams.py new file mode 100644 index 000000000..bd7455d28 --- /dev/null +++ b/earth2studio/lexicon/cams.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import numpy as np + +from .base import LexiconType + +_GLOBAL = "cams-global-atmospheric-composition-forecasts" + + +class CAMSGlobalLexicon(metaclass=LexiconType): + """Copernicus Atmosphere Monitoring Service Global Forecast Lexicon + + CAMS specified ``::::::`` + + The API variable name (used in the cdsapi request) differs from the NetCDF + key (used to index the downloaded file). Both are stored in the VOCAB. + + Note + ---- + Additional resources: + https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts + """ + + VOCAB = { + # Surface meteorological variables + "u10m": f"{_GLOBAL}::10m_u_component_of_wind::u10::", + "v10m": f"{_GLOBAL}::10m_v_component_of_wind::v10::", + "t2m": f"{_GLOBAL}::2m_temperature::t2m::", + "d2m": f"{_GLOBAL}::2m_dewpoint_temperature::d2m::", + "sp": f"{_GLOBAL}::surface_pressure::sp::", + "msl": f"{_GLOBAL}::mean_sea_level_pressure::msl::", + "tcwv": f"{_GLOBAL}::total_column_water_vapour::tcwv::", + "tp": f"{_GLOBAL}::total_precipitation::tp::", + "skt": f"{_GLOBAL}::skin_temperature::skt::", + "tcc": f"{_GLOBAL}::total_cloud_cover::tcc::", + "z": f"{_GLOBAL}::surface_geopotential::z::", + "lsm": f"{_GLOBAL}::land_sea_mask::lsm::", + # Aerosol optical depth variables (550nm) + "aod550": f"{_GLOBAL}::total_aerosol_optical_depth_550nm::aod550::", + "duaod550": f"{_GLOBAL}::dust_aerosol_optical_depth_550nm::duaod550::", + "omaod550": f"{_GLOBAL}::organic_matter_aerosol_optical_depth_550nm::omaod550::", + "bcaod550": f"{_GLOBAL}::black_carbon_aerosol_optical_depth_550nm::bcaod550::", + "ssaod550": f"{_GLOBAL}::sea_salt_aerosol_optical_depth_550nm::ssaod550::", + "suaod550": f"{_GLOBAL}::sulphate_aerosol_optical_depth_550nm::suaod550::", + # Total column trace gases + "tcco": f"{_GLOBAL}::total_column_carbon_monoxide::tcco::", + "tcno2": f"{_GLOBAL}::total_column_nitrogen_dioxide::tcno2::", + "tco3": f"{_GLOBAL}::total_column_ozone::gtco3::", + "tcso2": f"{_GLOBAL}::total_column_sulphur_dioxide::tcso2::", + # Pressure-level variables (multi-level, available every 3h lead time) + # u-component of wind + "u200": f"{_GLOBAL}::u_component_of_wind::u::200", + "u250": f"{_GLOBAL}::u_component_of_wind::u::250", + "u300": f"{_GLOBAL}::u_component_of_wind::u::300", + "u400": f"{_GLOBAL}::u_component_of_wind::u::400", + "u500": f"{_GLOBAL}::u_component_of_wind::u::500", + "u600": f"{_GLOBAL}::u_component_of_wind::u::600", + "u700": f"{_GLOBAL}::u_component_of_wind::u::700", + "u850": f"{_GLOBAL}::u_component_of_wind::u::850", + "u925": f"{_GLOBAL}::u_component_of_wind::u::925", + "u1000": f"{_GLOBAL}::u_component_of_wind::u::1000", + # v-component of wind + "v200": f"{_GLOBAL}::v_component_of_wind::v::200", + "v250": f"{_GLOBAL}::v_component_of_wind::v::250", + "v300": f"{_GLOBAL}::v_component_of_wind::v::300", + "v400": f"{_GLOBAL}::v_component_of_wind::v::400", + "v500": f"{_GLOBAL}::v_component_of_wind::v::500", + "v600": f"{_GLOBAL}::v_component_of_wind::v::600", + "v700": f"{_GLOBAL}::v_component_of_wind::v::700", + "v850": f"{_GLOBAL}::v_component_of_wind::v::850", + "v925": f"{_GLOBAL}::v_component_of_wind::v::925", + "v1000": f"{_GLOBAL}::v_component_of_wind::v::1000", + # Temperature + "t200": f"{_GLOBAL}::temperature::t::200", + "t250": f"{_GLOBAL}::temperature::t::250", + "t300": f"{_GLOBAL}::temperature::t::300", + "t400": f"{_GLOBAL}::temperature::t::400", + "t500": f"{_GLOBAL}::temperature::t::500", + "t600": f"{_GLOBAL}::temperature::t::600", + "t700": f"{_GLOBAL}::temperature::t::700", + "t850": f"{_GLOBAL}::temperature::t::850", + "t925": f"{_GLOBAL}::temperature::t::925", + "t1000": f"{_GLOBAL}::temperature::t::1000", + # Geopotential + "z200": f"{_GLOBAL}::geopotential::z::200", + "z250": f"{_GLOBAL}::geopotential::z::250", + "z300": f"{_GLOBAL}::geopotential::z::300", + "z400": f"{_GLOBAL}::geopotential::z::400", + "z500": f"{_GLOBAL}::geopotential::z::500", + "z600": f"{_GLOBAL}::geopotential::z::600", + "z700": f"{_GLOBAL}::geopotential::z::700", + "z850": f"{_GLOBAL}::geopotential::z::850", + "z925": f"{_GLOBAL}::geopotential::z::925", + "z1000": f"{_GLOBAL}::geopotential::z::1000", + # Specific humidity + "q200": f"{_GLOBAL}::specific_humidity::q::200", + "q250": f"{_GLOBAL}::specific_humidity::q::250", + "q300": f"{_GLOBAL}::specific_humidity::q::300", + "q400": f"{_GLOBAL}::specific_humidity::q::400", + "q500": f"{_GLOBAL}::specific_humidity::q::500", + "q600": f"{_GLOBAL}::specific_humidity::q::600", + "q700": f"{_GLOBAL}::specific_humidity::q::700", + "q850": f"{_GLOBAL}::specific_humidity::q::850", + "q925": f"{_GLOBAL}::specific_humidity::q::925", + "q1000": f"{_GLOBAL}::specific_humidity::q::1000", + } + + @classmethod + def get_item(cls, val: str) -> tuple[str, Callable]: + """Return name in CAMS vocabulary.""" + cams_key = cls.VOCAB[val] + + def mod(x: np.ndarray) -> np.ndarray: + return x + + return cams_key, mod diff --git a/test/data/test_cams.py b/test/data/test_cams.py new file mode 100644 index 000000000..1aa11fed1 --- /dev/null +++ b/test/data/test_cams.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import pathlib +import shutil +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import xarray as xr + +from earth2studio.data import CAMS_FX + +CAMS_ADS_URL = "https://ads.atmosphere.copernicus.eu/api" + + +@pytest.fixture(autouse=True) +def _set_cdsapi_url(monkeypatch): + """Point cdsapi at the ADS endpoint for all tests in this module.""" + monkeypatch.setenv("CDSAPI_URL", CAMS_ADS_URL) + + +YESTERDAY = datetime.datetime.now(datetime.UTC).replace( + hour=0, minute=0, second=0, microsecond=0 +) - datetime.timedelta(days=1) + + +@pytest.mark.slow +@pytest.mark.xfail +@pytest.mark.timeout(120) +@pytest.mark.parametrize("variable", ["aod550", ["aod550", "tcco"]]) +@pytest.mark.parametrize( + "lead_time", + [ + datetime.timedelta(hours=0), + [datetime.timedelta(hours=0), datetime.timedelta(hours=24)], + ], +) +def test_cams_fx_fetch(variable, lead_time): + time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) + ds = CAMS_FX(cache=False) + data = ds(time, lead_time, variable) + shape = data.shape + + if isinstance(variable, str): + variable = [variable] + if isinstance(lead_time, datetime.timedelta): + lead_time = [lead_time] + + assert shape[0] == 1 # time + assert shape[1] == len(lead_time) + assert shape[2] == len(variable) + assert len(data.coords["lat"]) > 0 + assert len(data.coords["lon"]) > 0 + assert not np.isnan(data.values).all() + + +@pytest.mark.slow +@pytest.mark.xfail +@pytest.mark.timeout(120) +@pytest.mark.parametrize("cache", [True, False]) +def test_cams_fx_cache(cache): + time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) + lead_time = datetime.timedelta(hours=0) + ds = CAMS_FX(cache=cache) + data = ds(time, lead_time, ["aod550", "tcco"]) + shape = data.shape + + assert shape[0] == 1 + assert shape[2] == 2 + assert not np.isnan(data.values).all() + assert pathlib.Path(ds.cache).is_dir() == cache + + data = ds(time, lead_time, "aod550") + assert data.shape[2] == 1 + + try: + shutil.rmtree(ds.cache) + except FileNotFoundError: + pass + + +@pytest.mark.timeout(30) +def test_cams_fx_invalid(): + with pytest.raises((ValueError, KeyError)): + ds = CAMS_FX() + ds(YESTERDAY, datetime.timedelta(hours=0), "nonexistent_var") + + +def test_cams_fx_time_validation(): + with pytest.raises(ValueError, match="CAMS Global forecast"): + CAMS_FX._validate_time([datetime.datetime(2014, 1, 1)]) + + +def test_cams_fx_available(): + assert CAMS_FX.available(datetime.datetime(2024, 1, 1)) + assert not CAMS_FX.available(datetime.datetime(2010, 1, 1)) + assert CAMS_FX.available(datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)) + + +def test_cams_fx_api_vars_dedup(): + """Variables sharing the same API name must not produce duplicate requests.""" + from earth2studio.data.cams import _resolve_variable + + info_a = _resolve_variable("aod550", 0) + info_b = _resolve_variable("aod550", 1) + api_vars = list(dict.fromkeys([info_a.api_name, info_b.api_name])) + assert len(api_vars) == 1 + + +def test_cams_fx_call_mock(tmp_path: pathlib.Path): + """Test CAMS_FX __call__ with surface and pressure-level variables (mocked).""" + lat = CAMS_FX.CAMS_LAT + lon = CAMS_FX.CAMS_LON + forecast_period = np.array([0, 3, 6], dtype=np.float64) + pressure_level = np.array([500.0, 850.0]) + + # Surface NetCDF + mock_surface_ds = xr.Dataset( + { + "aod550": ( + ["forecast_period", "latitude", "longitude"], + np.random.rand(len(forecast_period), len(lat), len(lon)), + ), + "tcco": ( + ["forecast_period", "latitude", "longitude"], + np.random.rand(len(forecast_period), len(lat), len(lon)), + ), + }, + coords={ + "forecast_period": forecast_period, + "latitude": lat, + "longitude": lon, + }, + ) + surface_path = tmp_path / "mock_surface.nc" + mock_surface_ds.to_netcdf(surface_path) + + # Pressure-level NetCDF + mock_pressure_ds = xr.Dataset( + { + "u": ( + ["forecast_period", "pressure_level", "latitude", "longitude"], + np.random.rand( + len(forecast_period), len(pressure_level), len(lat), len(lon) + ), + ), + "t": ( + ["forecast_period", "pressure_level", "latitude", "longitude"], + np.random.rand( + len(forecast_period), len(pressure_level), len(lat), len(lon) + ), + ), + }, + coords={ + "forecast_period": forecast_period, + "pressure_level": pressure_level, + "latitude": lat, + "longitude": lon, + }, + ) + pressure_path = tmp_path / "mock_pressure.nc" + mock_pressure_ds.to_netcdf(pressure_path) + + with patch("earth2studio.data.cams.cdsapi") as mock_cdsapi: + mock_cdsapi.Client = MagicMock() + time = datetime.datetime(2024, 6, 1, 0, 0) + lead_time = [datetime.timedelta(hours=0), datetime.timedelta(hours=6)] + + # --- surface-only fetch: single download --- + with patch("earth2studio.data.cams._download_cams_netcdf") as mock_dl: + mock_dl.return_value = surface_path + data = CAMS_FX(cache=False)(time, lead_time, ["aod550", "tcco"]) + + assert data.shape == (1, 2, 2, len(lat), len(lon)) + assert list(data.coords["variable"].values) == ["aod550", "tcco"] + mock_dl.assert_called_once() + + # --- pressure-level-only fetch: single download --- + with patch("earth2studio.data.cams._download_cams_netcdf") as mock_dl: + mock_dl.return_value = pressure_path + data = CAMS_FX(cache=False)(time, lead_time, ["u500", "t850"]) + + assert data.shape == (1, 2, 2, len(lat), len(lon)) + assert list(data.coords["variable"].values) == ["u500", "t850"] + mock_dl.assert_called_once() + + # --- mixed fetch: two separate downloads --- + call_count = 0 + + def _side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + return surface_path if call_count == 1 else pressure_path + + with patch("earth2studio.data.cams._download_cams_netcdf") as mock_dl: + mock_dl.side_effect = _side_effect + data = CAMS_FX(cache=False)(time, lead_time, ["aod550", "u500"]) + + assert data.shape == (1, 2, 2, len(lat), len(lon)) + assert list(data.coords["variable"].values) == ["aod550", "u500"] + assert mock_dl.call_count == 2 + + +def test_cams_fx_pressure_level_leadtime_validation(): + """Pressure-level vars at non-3h lead times must raise ValueError.""" + with patch("earth2studio.data.cams.cdsapi") as mock_cdsapi: + mock_cdsapi.Client = MagicMock() + ds = CAMS_FX(cache=False) + with pytest.raises(ValueError, match="multiple of 3 hours"): + ds( + datetime.datetime(2024, 6, 1, 0, 0), + datetime.timedelta(hours=1), + "u500", + ) diff --git a/test/data/test_cds.py b/test/data/test_cds.py index 150e5c6df..9f8718be8 100644 --- a/test/data/test_cds.py +++ b/test/data/test_cds.py @@ -23,6 +23,14 @@ from earth2studio.data import CDS +CDS_API_URL = "https://cds.climate.copernicus.eu/api" + + +@pytest.fixture(autouse=True) +def _set_cdsapi_url(monkeypatch): + """Point cdsapi at the CDS endpoint for all tests in this module.""" + monkeypatch.setenv("CDSAPI_URL", CDS_API_URL) + @pytest.mark.slow @pytest.mark.xfail diff --git a/test/lexicon/test_cams_lexicon.py b/test/lexicon/test_cams_lexicon.py new file mode 100644 index 000000000..8431f80e5 --- /dev/null +++ b/test/lexicon/test_cams_lexicon.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from earth2studio.lexicon import CAMSGlobalLexicon + + +@pytest.mark.parametrize( + "variable", + [ + # Surface variables + ["u10m", "v10m", "t2m"], + ["d2m", "sp", "msl"], + ["tcwv", "tp", "tcc"], + # Aerosol optical depth + ["aod550"], + ["duaod550", "tcno2"], + ["bcaod550", "ssaod550", "suaod550"], + # Trace gases + ["tcco", "tco3", "tcso2"], + ["omaod550"], + # Pressure level variables + ["u500", "v500", "t500"], + ["z850", "q700"], + ], +) +def test_cams_lexicon(variable): + data = np.random.randn(len(variable), 8) + for v in variable: + label, modifier = CAMSGlobalLexicon[v] + output = modifier(data) + assert isinstance(label, str) + assert data.shape == output.shape + + +def test_cams_lexicon_invalid(): + with pytest.raises(KeyError): + CAMSGlobalLexicon["nonexistent_variable"] + + +def test_cams_lexicon_vocab_format(): + for key, value in CAMSGlobalLexicon.VOCAB.items(): + parts = value.split("::") + assert ( + len(parts) == 4 + ), f"VOCAB entry '{key}' must have format 'dataset::api_var::nc_key::level'" + assert ( + parts[0] == "cams-global-atmospheric-composition-forecasts" + ), f"Expected global dataset in VOCAB entry '{key}', got '{parts[0]}'" + + +def test_cams_lexicon_surface_vars(): + """Test that all expected surface variables are present.""" + expected_surface = [ + "u10m", + "v10m", + "t2m", + "d2m", + "sp", + "msl", + "tcwv", + "tp", + "skt", + "z", + "lsm", + "tcc", + ] + for var in expected_surface: + assert var in CAMSGlobalLexicon.VOCAB, f"Missing surface variable: {var}" + + +def test_cams_lexicon_aod_vars(): + """Test that all aerosol optical depth variables are present.""" + expected_aod = [ + "aod550", + "duaod550", + "omaod550", + "bcaod550", + "ssaod550", + "suaod550", + ] + for var in expected_aod: + assert var in CAMSGlobalLexicon.VOCAB, f"Missing AOD variable: {var}" + + +def test_cams_lexicon_trace_gas_vars(): + """Test that all trace gas variables are present.""" + expected_gases = ["tcco", "tcno2", "tco3", "tcso2"] + for var in expected_gases: + assert var in CAMSGlobalLexicon.VOCAB, f"Missing trace gas variable: {var}" + + +def test_cams_lexicon_pressure_level_vars(): + """Test that pressure level variables are present for common levels.""" + pressure_levels = [200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] + variables = ["u", "v", "t", "z", "q"] + + for var in variables: + for level in pressure_levels: + key = f"{var}{level}" + assert ( + key in CAMSGlobalLexicon.VOCAB + ), f"Missing pressure level variable: {key}"