From ed05ae5608eb30d26e27c120fa6929a5e5b7ebb8 Mon Sep 17 00:00:00 2001 From: "Claude Sonnet 4.5" Date: Sun, 29 Mar 2026 19:22:20 +0200 Subject: [PATCH 1/8] feat: add CAMS atmospheric composition data source and lexicon MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add DataSource (CAMS) and ForecastSource (CAMS_FX) for Copernicus Atmosphere Monitoring Service data via the CDS API. CAMS provides atmospheric composition / air quality data not currently available in earth2studio — complementing the existing weather-focused data sources (GFS, IFS, ERA5, etc.). Data sources: - CAMS: EU air quality analysis (0.1 deg, 9 pollutants, 10 height levels) - CAMS_FX: EU + Global forecasts (EU 0.1 deg up to 96h, Global 0.4 deg up to 120h) Variables include: dust, PM2.5, PM10, SO2, NO2, O3, CO, NH3, NO (EU surface and multi-level), plus AOD and total column products (Global). Lexicon: 101 entries covering all 9 pollutants at all 9 EU altitude levels (50-5000m), plus surface and 11 global column/AOD variables. 
Implementation follows upstream conventions: - Protocol-compliant __call__ and async fetch methods - Badges section for API doc filtering - Time validation, available() classmethod - Lazy CDS client initialization - pathlib-based caching with SHA256 keys - Tests with @pytest.mark.xfail for CI without CDS credentials Requires: cdsapi (already in the 'data' optional dependency group) Co-Authored-By: Claude Opus 4.6 (1M context) --- earth2studio/data/__init__.py | 1 + earth2studio/data/cams.py | 629 ++++++++++++++++++++++++++++++ earth2studio/lexicon/__init__.py | 1 + earth2studio/lexicon/cams.py | 95 +++++ test/data/test_cams.py | 136 +++++++ test/lexicon/test_cams_lexicon.py | 68 ++++ 6 files changed, 930 insertions(+) create mode 100644 earth2studio/data/cams.py create mode 100644 earth2studio/lexicon/cams.py create mode 100644 test/data/test_cams.py create mode 100644 test/lexicon/test_cams_lexicon.py diff --git a/earth2studio/data/__init__.py b/earth2studio/data/__init__.py index a27049741..cbd35443c 100644 --- a/earth2studio/data/__init__.py +++ b/earth2studio/data/__init__.py @@ -17,6 +17,7 @@ from .ace2 import ACE2ERA5Data from .arco import ARCO from .base import DataSource, ForecastSource +from .cams import CAMS, CAMS_FX from .cbottle import CBottle3D from .cds import CDS from .cmip6 import CMIP6, CMIP6MultiRealm diff --git a/earth2studio/data/cams.py b/earth2studio/data/cams.py new file mode 100644 index 000000000..9bbac9a11 --- /dev/null +++ b/earth2studio/data/cams.py @@ -0,0 +1,629 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import hashlib +import pathlib +import shutil +from dataclasses import dataclass +from datetime import datetime, timedelta +from time import sleep + +import numpy as np +import xarray as xr +from loguru import logger + +from earth2studio.data.utils import ( + datasource_cache_root, + prep_data_inputs, + prep_forecast_inputs, +) +from earth2studio.lexicon import CAMSLexicon +from earth2studio.utils.imports import ( + OptionalDependencyFailure, + check_optional_dependencies, +) +from earth2studio.utils.type import LeadTimeArray, TimeArray, VariableArray + +try: + import cdsapi +except ImportError: + OptionalDependencyFailure("data") + cdsapi = None + +# CAMS EU analysis available from 2019-07-01 onward +_CAMS_EU_MIN_TIME = datetime(2019, 7, 1) +# CAMS Global forecast available from 2015-01-01 onward +_CAMS_GLOBAL_MIN_TIME = datetime(2015, 1, 1) + +EU_DATASET = "cams-europe-air-quality-forecasts" + + +@dataclass +class _CAMSVarInfo: + e2s_name: str + api_name: str + nc_key: str + dataset: str + level: str + index: int + + +def _resolve_variable(e2s_name: str, index: int) -> _CAMSVarInfo: + cams_key, _ = CAMSLexicon[e2s_name] + dataset, api_name, nc_key, level = cams_key.split("::") + return _CAMSVarInfo( + e2s_name=e2s_name, + api_name=api_name, + nc_key=nc_key, + dataset=dataset, + level=level, + index=index, + ) + + +def _download_cams_netcdf( + client: "cdsapi.Client", + dataset: str, + request_body: dict, + cache_path: pathlib.Path, + verbose: bool = True, +) -> pathlib.Path: + if cache_path.is_file(): + return cache_path + + r = 
client.retrieve(dataset, request_body) + while True: + r.update() + reply = r.reply + if verbose: + logger.debug( + f"Request ID:{reply['request_id']}, state: {reply['state']}" + ) + if reply["state"] == "completed": + break + elif reply["state"] in ("queued", "running"): + sleep(5.0) + elif reply["state"] in ("failed",): + raise RuntimeError( + f"CAMS request failed for {dataset}: " + + reply.get("error", {}).get("message", "unknown error") + ) + else: + sleep(2.0) + r.download(str(cache_path)) + return cache_path + + +def _extract_field( + ds: xr.Dataset, nc_key: str, level: str = "0", time_index: int = 0 +) -> np.ndarray: + if nc_key not in ds: + raise ValueError( + f"Variable '{nc_key}' not found in NetCDF. " + f"Available: {list(ds.data_vars)}" + ) + field = ds[nc_key] + non_spatial = [d for d in field.dims if d not in ("latitude", "longitude")] + sel: dict[str, int] = {} + for d in non_spatial: + if d == "level" and level: + level_val = float(level) if level else 0.0 + level_coords = field.coords["level"].values + nearest_idx = int(np.argmin(np.abs(level_coords - level_val))) + sel[d] = nearest_idx + elif d in ("time", "forecast_period", "forecast_reference_time"): + sel[d] = time_index + elif field.sizes[d] > 1: + sel[d] = 0 + else: + sel[d] = 0 + if sel: + field = field.isel(sel) + return field.values + + +def _validate_cams_time( + times: list[datetime], min_time: datetime, name: str +) -> None: + for t in times: + t_naive = t.replace(tzinfo=None) if t.tzinfo else t + if t_naive < min_time: + raise ValueError( + f"Requested time {t} is before {name} availability " + f"(earliest: {min_time})" + ) + if t_naive.minute != 0 or t_naive.second != 0: + raise ValueError( + f"Requested time {t} must be on the hour for {name}" + ) + + +def _validate_cams_leadtime(lead_times: list[timedelta], max_hours: int) -> None: + for lt in lead_times: + hours = int(lt.total_seconds() // 3600) + if lt.total_seconds() % 3600 != 0: + raise ValueError( + f"Lead time {lt} must be a 
whole number of hours" + ) + if hours < 0 or hours > max_hours: + raise ValueError( + f"Lead time {lt} ({hours}h) outside valid range [0, {max_hours}]h" + ) + + +@check_optional_dependencies() +class CAMS: + """CAMS European Air Quality analysis data source. + + Uses the ``cams-europe-air-quality-forecasts`` dataset with ``type=analysis``. + Grid is 0.1 deg over Europe, read dynamically from the downloaded NetCDF. + + Parameters + ---------- + cache : bool, optional + Cache data source on local memory, by default True + verbose : bool, optional + Print download progress, by default True + + Warning + ------- + This is a remote data source and can potentially download a large amount of data + to your local machine for large requests. + + Note + ---- + Additional information on the data repository, registration, and authentication can + be referenced here: + + - https://ads.atmosphere.copernicus.eu/datasets/cams-europe-air-quality-forecasts + - https://cds.climate.copernicus.eu/how-to-api + + Badges + ------ + region:europe dataclass:analysis product:airquality + """ + + def __init__(self, cache: bool = True, verbose: bool = True): + self._cache = cache + self._verbose = verbose + self._cds_client: "cdsapi.Client | None" = None + + @property + def _client(self) -> "cdsapi.Client": + if self._cds_client is None: + if cdsapi is None: + raise ImportError( + "cdsapi is required for CAMS. " + "Install with: pip install 'earth2studio[data]'" + ) + self._cds_client = cdsapi.Client( + debug=False, quiet=True, wait_until_complete=False + ) + return self._cds_client + + def __call__( + self, + time: datetime | list[datetime] | TimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Retrieve CAMS EU analysis data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Timestamps to return data for (UTC). + variable : str | list[str] | VariableArray + Variables to return. Must be in CAMSLexicon with EU dataset. 
+ + Returns + ------- + xr.DataArray + CAMS data array with dims [time, variable, lat, lon] + """ + time, variable = prep_data_inputs(time, variable) + self._validate_time(time) + self.cache.mkdir(parents=True, exist_ok=True) + + data_arrays = [] + for t0 in time: + da = self._fetch_analysis(t0, variable) + data_arrays.append(da) + + if not self._cache: + shutil.rmtree(self.cache) + + return xr.concat(data_arrays, dim="time") + + async def fetch( + self, + time: datetime | list[datetime] | TimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Async retrieval of CAMS EU analysis data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Timestamps to return data for (UTC). + variable : str | list[str] | VariableArray + Variables to return. Must be in CAMSLexicon with EU dataset. + + Returns + ------- + xr.DataArray + CAMS data array with dims [time, variable, lat, lon] + """ + return await asyncio.to_thread(self.__call__, time, variable) + + @classmethod + def available( + cls, + time: datetime | list[datetime], + ) -> bool: + """Check if CAMS EU analysis data is available for the requested times. + + Parameters + ---------- + time : datetime | list[datetime] + Timestamps to check availability for. + + Returns + ------- + bool + True if all requested times are within the valid range. + """ + if isinstance(time, datetime): + time = [time] + return all(t >= _CAMS_EU_MIN_TIME for t in time) + + @staticmethod + def _validate_time(times: list[datetime]) -> None: + _validate_cams_time(times, _CAMS_EU_MIN_TIME, "CAMS EU analysis") + + def _fetch_analysis(self, time: datetime, variables: list[str]) -> xr.DataArray: + var_infos = [] + for i, v in enumerate(variables): + info = _resolve_variable(v, i) + if info.dataset != EU_DATASET: + raise ValueError( + f"CAMS analysis only supports EU dataset, got '{info.dataset}' " + f"for variable '{v}'. Use CAMS_FX for global forecast variables." 
+ ) + var_infos.append(info) + + api_vars = [vi.api_name for vi in var_infos] + levels = sorted(set(vi.level for vi in var_infos if vi.level)) + if not levels: + levels = ["0"] + nc_path = self._download_cached(time, api_vars, levels) + + ds = xr.open_dataset(nc_path, decode_timedelta=False) + lat = ds.latitude.values + lon = ds.longitude.values + + da = xr.DataArray( + data=np.empty((1, len(variables), len(lat), len(lon))), + dims=["time", "variable", "lat", "lon"], + coords={ + "time": [time], + "variable": variables, + "lat": lat, + "lon": lon, + }, + ) + + for info in var_infos: + _, modifier = CAMSLexicon[info.e2s_name] + da[0, info.index] = modifier( + _extract_field(ds, info.nc_key, level=info.level) + ) + + ds.close() + return da + + def _download_cached( + self, time: datetime, api_vars: list[str], levels: list[str] + ) -> pathlib.Path: + date_str = time.strftime("%Y-%m-%d") + sha = hashlib.sha256( + f"cams_eu_{'_'.join(sorted(api_vars))}" + f"_{'_'.join(sorted(levels))}_{date_str}_{time.hour:02d}".encode() + ) + cache_path = self.cache / (sha.hexdigest() + ".nc") + + request_body = { + "variable": api_vars, + "model": ["ensemble"], + "level": levels, + "date": [f"{date_str}/{date_str}"], + "type": ["analysis"], + "time": [f"{time.hour:02d}:00"], + "leadtime_hour": ["0"], + "data_format": "netcdf", + } + + if self._verbose: + logger.info( + f"Fetching CAMS EU analysis for {date_str} {time.hour:02d}:00 " + f"vars={api_vars}" + ) + return _download_cams_netcdf( + self._client, EU_DATASET, request_body, cache_path, self._verbose + ) + + @property + def cache(self) -> pathlib.Path: + """Cache location.""" + cache_location = pathlib.Path(datasource_cache_root()) / "cams" + if not self._cache: + cache_location = cache_location / "tmp_cams" + return cache_location + + +@check_optional_dependencies() +class CAMS_FX: + """CAMS forecast data source. 
+ + Supports both EU (``cams-europe-air-quality-forecasts``) and Global + (``cams-global-atmospheric-composition-forecasts``) forecast datasets. + The dataset is determined automatically from the requested variables via CAMSLexicon. + + Parameters + ---------- + cache : bool, optional + Cache data source on local memory, by default True + verbose : bool, optional + Print download progress, by default True + + Warning + ------- + This is a remote data source and can potentially download a large amount of data + to your local machine for large requests. + + Note + ---- + Additional information on the data repository, registration, and authentication can + be referenced here: + + - https://ads.atmosphere.copernicus.eu/datasets/cams-europe-air-quality-forecasts + - https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts + - https://cds.climate.copernicus.eu/how-to-api + + Badges + ------ + region:europe region:global dataclass:simulation product:airquality + """ + + # EU forecasts go up to 96h, Global up to 120h + MAX_EU_LEAD_HOURS = 96 + MAX_GLOBAL_LEAD_HOURS = 120 + + def __init__(self, cache: bool = True, verbose: bool = True): + self._cache = cache + self._verbose = verbose + self._cds_client: "cdsapi.Client | None" = None + + @property + def _client(self) -> "cdsapi.Client": + if self._cds_client is None: + if cdsapi is None: + raise ImportError( + "cdsapi is required for CAMS_FX. " + "Install with: pip install 'earth2studio[data]'" + ) + self._cds_client = cdsapi.Client( + debug=False, quiet=True, wait_until_complete=False + ) + return self._cds_client + + def __call__( + self, + time: datetime | list[datetime] | TimeArray, + lead_time: timedelta | list[timedelta] | LeadTimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Retrieve CAMS forecast data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Forecast initialization times (UTC). 
+ lead_time : timedelta | list[timedelta] | LeadTimeArray + Forecast lead times. + variable : str | list[str] | VariableArray + Variables to return. Must be in CAMSLexicon. + + Returns + ------- + xr.DataArray + CAMS forecast data array with dims [time, lead_time, variable, lat, lon] + """ + time, lead_time, variable = prep_forecast_inputs(time, lead_time, variable) + self._validate_time(time) + self.cache.mkdir(parents=True, exist_ok=True) + + data_arrays = [] + for t0 in time: + da = self._fetch_forecast(t0, lead_time, variable) + data_arrays.append(da) + + if not self._cache: + shutil.rmtree(self.cache) + + return xr.concat(data_arrays, dim="time") + + async def fetch( + self, + time: datetime | list[datetime] | TimeArray, + lead_time: timedelta | list[timedelta] | LeadTimeArray, + variable: str | list[str] | VariableArray, + ) -> xr.DataArray: + """Async retrieval of CAMS forecast data. + + Parameters + ---------- + time : datetime | list[datetime] | TimeArray + Forecast initialization times (UTC). + lead_time : timedelta | list[timedelta] | LeadTimeArray + Forecast lead times. + variable : str | list[str] | VariableArray + Variables to return. Must be in CAMSLexicon. + + Returns + ------- + xr.DataArray + CAMS forecast data array with dims [time, lead_time, variable, lat, lon] + """ + return await asyncio.to_thread(self.__call__, time, lead_time, variable) + + @classmethod + def available( + cls, + time: datetime | list[datetime], + ) -> bool: + """Check if CAMS forecast data is available for the requested times. + + Parameters + ---------- + time : datetime | list[datetime] + Timestamps to check availability for. + + Returns + ------- + bool + True if all requested times are within the valid range. 
+ """ + if isinstance(time, datetime): + time = [time] + return all(t >= _CAMS_GLOBAL_MIN_TIME for t in time) + + @staticmethod + def _validate_time(times: list[datetime]) -> None: + _validate_cams_time(times, _CAMS_GLOBAL_MIN_TIME, "CAMS forecast") + + def _fetch_forecast( + self, + time: datetime, + lead_times: np.ndarray, + variables: list[str], + ) -> xr.DataArray: + datasets: dict[str, list[_CAMSVarInfo]] = {} + for i, v in enumerate(variables): + info = _resolve_variable(v, i) + datasets.setdefault(info.dataset, []).append(info) + + if len(datasets) > 1: + raise ValueError( + "Cannot mix EU and Global CAMS variables in a single CAMS_FX call. " + f"Got datasets: {list(datasets.keys())}" + ) + + dataset_name = next(iter(datasets)) + var_infos = datasets[dataset_name] + api_vars = [vi.api_name for vi in var_infos] + levels = sorted(set(vi.level for vi in var_infos if vi.level)) + lead_hours = [ + str(int(np.timedelta64(lt, "h").astype(int))) for lt in lead_times + ] + + is_eu = dataset_name == EU_DATASET + max_hours = self.MAX_EU_LEAD_HOURS if is_eu else self.MAX_GLOBAL_LEAD_HOURS + _validate_cams_leadtime( + [timedelta(hours=int(h)) for h in lead_hours], max_hours + ) + + nc_path = self._download_cached( + time, dataset_name, api_vars, lead_hours, levels + ) + + ds = xr.open_dataset(nc_path, decode_timedelta=False) + lat = ds.latitude.values + lon = ds.longitude.values + + da = xr.DataArray( + data=np.empty( + (1, len(lead_times), len(variables), len(lat), len(lon)) + ), + dims=["time", "lead_time", "variable", "lat", "lon"], + coords={ + "time": [time], + "lead_time": lead_times, + "variable": variables, + "lat": lat, + "lon": lon, + }, + ) + + for lt_idx in range(len(lead_hours)): + for info in var_infos: + _, modifier = CAMSLexicon[info.e2s_name] + da[0, lt_idx, info.index] = modifier( + _extract_field( + ds, info.nc_key, level=info.level, time_index=lt_idx + ) + ) + + ds.close() + return da + + def _download_cached( + self, + time: datetime, + dataset: str, 
+ api_vars: list[str], + lead_hours: list[str], + levels: list[str] | None = None, + ) -> pathlib.Path: + date_str = time.strftime("%Y-%m-%d") + level_str = "_".join(sorted(levels)) if levels else "none" + sha = hashlib.sha256( + f"cams_fx_{dataset}_{'_'.join(sorted(api_vars))}" + f"_{'_'.join(lead_hours)}_{level_str}" + f"_{date_str}_{time.hour:02d}".encode() + ) + cache_path = self.cache / (sha.hexdigest() + ".nc") + + is_eu = dataset == EU_DATASET + + request_body: dict = { + "variable": api_vars, + "date": [f"{date_str}/{date_str}"], + "type": ["forecast"], + "time": [f"{time.hour:02d}:00"], + "leadtime_hour": lead_hours, + "data_format": "netcdf", + } + if is_eu: + request_body["model"] = ["ensemble"] + request_body["level"] = levels if levels else ["0"] + + if self._verbose: + logger.info( + f"Fetching CAMS forecast ({dataset.split('-')[1]}) for {date_str} " + f"{time.hour:02d}:00 lead_hours={lead_hours} vars={api_vars}" + ) + return _download_cams_netcdf( + self._client, dataset, request_body, cache_path, self._verbose + ) + + @property + def cache(self) -> pathlib.Path: + """Cache location.""" + cache_location = pathlib.Path(datasource_cache_root()) / "cams" + if not self._cache: + cache_location = cache_location / "tmp_cams_fx" + return cache_location diff --git a/earth2studio/lexicon/__init__.py b/earth2studio/lexicon/__init__.py index 34b01a31d..e800e03d0 100644 --- a/earth2studio/lexicon/__init__.py +++ b/earth2studio/lexicon/__init__.py @@ -16,6 +16,7 @@ from .ace import ACELexicon from .arco import ARCOLexicon +from .cams import CAMSLexicon from .cbottle import CBottleLexicon from .cds import CDSLexicon from .cmip6 import CMIP6Lexicon diff --git a/earth2studio/lexicon/cams.py b/earth2studio/lexicon/cams.py new file mode 100644 index 000000000..46f3e458a --- /dev/null +++ b/earth2studio/lexicon/cams.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import numpy as np + +from .base import LexiconType + +_EU = "cams-europe-air-quality-forecasts" +_GLOBAL = "cams-global-atmospheric-composition-forecasts" + +# All EU levels available in the CAMS API (meters above ground) +_EU_LEVELS = [50, 100, 250, 500, 750, 1000, 2000, 3000, 5000] + +# short_name -> (api_request_name, netcdf_key, surface_e2s_name) +_EU_POLLUTANTS = { + "dust": ("dust", "dust", "dust"), + "pm2p5": ("particulate_matter_2.5um", "pm2p5_conc", "pm2p5"), + "pm10": ("particulate_matter_10um", "pm10_conc", "pm10"), + "so2": ("sulphur_dioxide", "so2_conc", "so2sfc"), + "no2": ("nitrogen_dioxide", "no2_conc", "no2sfc"), + "o3": ("ozone", "o3_conc", "o3sfc"), + "co": ("carbon_monoxide", "co_conc", "cosfc"), + "nh3": ("ammonia", "nh3_conc", "nh3sfc"), + "no": ("nitrogen_monoxide", "no_conc", "nosfc"), +} + + +def _build_eu_vocab() -> dict[str, str]: + vocab: dict[str, str] = {} + for short, (api, nc, sfc_name) in _EU_POLLUTANTS.items(): + vocab[sfc_name] = f"{_EU}::{api}::{nc}::0" + for level in _EU_LEVELS: + vocab[f"{short}_{level}m"] = f"{_EU}::{api}::{nc}::{level}" + return vocab + + +class CAMSLexicon(metaclass=LexiconType): + """Copernicus Atmosphere Monitoring Service Lexicon + + CAMS specified ``<dataset>::<api_var>::<nc_key>::<level>`` + + The API variable name (used in the cdsapi request) differs from the NetCDF + key (used to index the downloaded file).
Both are stored in the VOCAB. + + Note + ---- + EU multi-level variables are available at: 0 (surface), 50, 100, 250, 500, + 750, 1000, 2000, 3000, 5000 m. All pollutants are mapped at all available + levels. + + Additional resources: + https://ads.atmosphere.copernicus.eu/datasets/cams-europe-air-quality-forecasts + https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts + """ + + VOCAB = { + **_build_eu_vocab(), + # ---- CAMS Global (column/AOD, 0.4 deg grid) ---- + "aod550": f"{_GLOBAL}::total_aerosol_optical_depth_550nm::aod550::", + "duaod550": f"{_GLOBAL}::dust_aerosol_optical_depth_550nm::duaod550::", + "omaod550": f"{_GLOBAL}::organic_matter_aerosol_optical_depth_550nm::omaod550::", + "bcaod550": f"{_GLOBAL}::black_carbon_aerosol_optical_depth_550nm::bcaod550::", + "ssaod550": f"{_GLOBAL}::sea_salt_aerosol_optical_depth_550nm::ssaod550::", + "suaod550": f"{_GLOBAL}::sulphate_aerosol_optical_depth_550nm::suaod550::", + "tcco": f"{_GLOBAL}::total_column_carbon_monoxide::tcco::", + "tcno2": f"{_GLOBAL}::total_column_nitrogen_dioxide::tcno2::", + "tco3": f"{_GLOBAL}::total_column_ozone::tco3::", + "tcso2": f"{_GLOBAL}::total_column_sulphur_dioxide::tcso2::", + "gtco3": f"{_GLOBAL}::gems_total_column_ozone::gtco3::", + } + + @classmethod + def get_item(cls, val: str) -> tuple[str, Callable]: + """Return name in CAMS vocabulary.""" + cams_key = cls.VOCAB[val] + + def mod(x: np.ndarray) -> np.ndarray: + return x + + return cams_key, mod diff --git a/test/data/test_cams.py b/test/data/test_cams.py new file mode 100644 index 000000000..1742ae7dd --- /dev/null +++ b/test/data/test_cams.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import pathlib +import shutil + +import numpy as np +import pytest + +from earth2studio.data import CAMS, CAMS_FX + +YESTERDAY = datetime.datetime.now(datetime.UTC).replace( + hour=0, minute=0, second=0, microsecond=0 +) - datetime.timedelta(days=1) + + +@pytest.mark.slow +@pytest.mark.xfail +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + "time", + [ + [YESTERDAY], + np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]), + ], +) +@pytest.mark.parametrize("variable", ["dust", ["dust", "pm2p5"]]) +def test_cams_fetch(time, variable): + ds = CAMS(cache=False) + data = ds(time, variable) + shape = data.shape + + if isinstance(variable, str): + variable = [variable] + + assert shape[0] == 1 + assert shape[1] == len(variable) + assert len(data.coords["lat"]) > 0 + assert len(data.coords["lon"]) > 0 + assert not np.isnan(data.values).all() + assert np.array_equal(data.coords["variable"].values, np.array(variable)) + + +@pytest.mark.slow +@pytest.mark.xfail +@pytest.mark.timeout(120) +@pytest.mark.parametrize("variable", [["dust", "so2sfc"]]) +@pytest.mark.parametrize("cache", [True, False]) +def test_cams_cache(variable, cache): + time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) + ds = CAMS(cache=cache) + data = ds(time, variable) + shape = data.shape + + assert shape[0] == 1 + assert shape[1] == 2 + assert not np.isnan(data.values).all() + assert pathlib.Path(ds.cache).is_dir() == cache + + data = ds(time, variable[0]) + assert data.shape[1] == 1 + + try: + shutil.rmtree(ds.cache) + except 
FileNotFoundError: + pass + + +@pytest.mark.timeout(30) +def test_cams_invalid(): + with pytest.raises((ValueError, KeyError)): + ds = CAMS() + ds(YESTERDAY, "nonexistent_var") + + +def test_cams_time_validation(): + with pytest.raises(ValueError): + ds = CAMS() + ds(datetime.datetime(2018, 1, 1), "dust") + + +def test_cams_available(): + assert CAMS.available(datetime.datetime(2024, 1, 1)) + assert not CAMS.available(datetime.datetime(2015, 1, 1)) + + +# ---- CAMS_FX tests ---- + + +@pytest.mark.slow +@pytest.mark.xfail +@pytest.mark.timeout(120) +@pytest.mark.parametrize("variable", ["dust", ["dust", "pm2p5"]]) +@pytest.mark.parametrize( + "lead_time", + [ + datetime.timedelta(hours=0), + [datetime.timedelta(hours=0), datetime.timedelta(hours=24)], + ], +) +def test_cams_fx_fetch(variable, lead_time): + time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) + ds = CAMS_FX(cache=False) + data = ds(time, lead_time, variable) + shape = data.shape + + if isinstance(variable, str): + variable = [variable] + if isinstance(lead_time, datetime.timedelta): + lead_time = [lead_time] + + assert shape[0] == 1 # time + assert shape[1] == len(lead_time) + assert shape[2] == len(variable) + assert len(data.coords["lat"]) > 0 + assert len(data.coords["lon"]) > 0 + assert not np.isnan(data.values).all() + + +def test_cams_fx_available(): + assert CAMS_FX.available(datetime.datetime(2024, 1, 1)) + assert not CAMS_FX.available(datetime.datetime(2010, 1, 1)) diff --git a/test/lexicon/test_cams_lexicon.py b/test/lexicon/test_cams_lexicon.py new file mode 100644 index 000000000..144c935ad --- /dev/null +++ b/test/lexicon/test_cams_lexicon.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from earth2studio.lexicon import CAMSLexicon + + +@pytest.mark.parametrize( + "variable", + [ + ["dust"], + ["so2sfc", "pm2p5"], + ["no2sfc", "o3sfc", "cosfc"], + ["aod550"], + ["duaod550", "tcno2"], + ["dust_500m", "pm2p5_1000m"], + ], +) +def test_cams_lexicon(variable): + data = np.random.randn(len(variable), 8) + for v in variable: + label, modifier = CAMSLexicon[v] + output = modifier(data) + assert isinstance(label, str) + assert data.shape == output.shape + + +def test_cams_lexicon_invalid(): + with pytest.raises(KeyError): + CAMSLexicon["nonexistent_variable"] + + +def test_cams_lexicon_vocab_format(): + for key, value in CAMSLexicon.VOCAB.items(): + parts = value.split("::") + assert len(parts) == 4, ( + f"VOCAB entry '{key}' must have format " + "'dataset::api_var::nc_key::level'" + ) + dataset = parts[0] + assert dataset in ( + "cams-europe-air-quality-forecasts", + "cams-global-atmospheric-composition-forecasts", + ), f"Unknown dataset in VOCAB entry '{key}': {dataset}" + + +def test_cams_lexicon_all_levels_covered(): + levels = [50, 100, 250, 500, 750, 1000, 2000, 3000, 5000] + pollutants = ["dust", "pm2p5", "pm10", "so2", "no2", "o3", "co", "nh3", "no"] + for p in pollutants: + for lev in levels: + key = f"{p}_{lev}m" + assert key in CAMSLexicon.VOCAB, f"Missing level entry: {key}" From 5ae645f1d0e2ebcfdd02edd2c9e4f24c0fa0b4e8 Mon Sep 17 00:00:00 2001 From: "Claude Sonnet 4.5" Date: Sun, 29 Mar 2026 19:35:57 +0200 Subject: [PATCH 2/8] =?UTF-8?q?fix:=20address=20review=20findings=20?= 
=?UTF-8?q?=E2=80=94=20atomic=20download,=20tz-aware=20available(),=20coor?= =?UTF-8?q?dinate-based=20lead-time=20selection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P1: Use atomic write-then-rename in _download_cams_netcdf to prevent corrupt partial files from being cached on interrupted downloads. P1: Fix TypeError in CAMS.available() and CAMS_FX.available() when called with timezone-aware datetimes (strip tzinfo before comparing against naive min-time constants, matching _validate_cams_time). P2: Replace positional lead-time indexing in _extract_field with coordinate-based selection via forecast_period dimension values, avoiding silent data misassignment if API reorders slices. Co-Authored-By: Claude Opus 4.6 (1M context) --- earth2studio/data/cams.py | 50 +++++++++++++++++++++++++++------------ test/data/test_cams.py | 4 ++++ 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/earth2studio/data/cams.py b/earth2studio/data/cams.py index 9bbac9a11..abdd71572 100644 --- a/earth2studio/data/cams.py +++ b/earth2studio/data/cams.py @@ -16,8 +16,10 @@ import asyncio import hashlib +import os import pathlib import shutil +import tempfile from dataclasses import dataclass from datetime import datetime, timedelta from time import sleep @@ -104,12 +106,22 @@ def _download_cams_netcdf( ) else: sleep(2.0) - r.download(str(cache_path)) + tmp_fd, tmp_name = tempfile.mkstemp(dir=cache_path.parent, suffix=".nc.tmp") + try: + os.close(tmp_fd) + r.download(tmp_name) + os.replace(tmp_name, cache_path) + except Exception: + pathlib.Path(tmp_name).unlink(missing_ok=True) + raise return cache_path def _extract_field( - ds: xr.Dataset, nc_key: str, level: str = "0", time_index: int = 0 + ds: xr.Dataset, + nc_key: str, + level: str = "0", + lead_time_hours: int | None = None, ) -> np.ndarray: if nc_key not in ds: raise ValueError( @@ -118,21 +130,22 @@ def _extract_field( ) field = ds[nc_key] non_spatial = [d for d in 
field.dims if d not in ("latitude", "longitude")] - sel: dict[str, int] = {} + isel: dict[str, int] = {} for d in non_spatial: if d == "level" and level: level_val = float(level) if level else 0.0 level_coords = field.coords["level"].values nearest_idx = int(np.argmin(np.abs(level_coords - level_val))) - sel[d] = nearest_idx - elif d in ("time", "forecast_period", "forecast_reference_time"): - sel[d] = time_index - elif field.sizes[d] > 1: - sel[d] = 0 + isel[d] = nearest_idx + elif d == "forecast_period" and lead_time_hours is not None: + fp_vals = field.coords["forecast_period"].values.astype(float) + target = float(lead_time_hours) + nearest_idx = int(np.argmin(np.abs(fp_vals - target))) + isel[d] = nearest_idx else: - sel[d] = 0 - if sel: - field = field.isel(sel) + isel[d] = 0 + if isel: + field = field.isel(isel) return field.values @@ -288,7 +301,10 @@ def available( """ if isinstance(time, datetime): time = [time] - return all(t >= _CAMS_EU_MIN_TIME for t in time) + return all( + (t.replace(tzinfo=None) if t.tzinfo else t) >= _CAMS_EU_MIN_TIME + for t in time + ) @staticmethod def _validate_time(times: list[datetime]) -> None: @@ -509,7 +525,10 @@ def available( """ if isinstance(time, datetime): time = [time] - return all(t >= _CAMS_GLOBAL_MIN_TIME for t in time) + return all( + (t.replace(tzinfo=None) if t.tzinfo else t) >= _CAMS_GLOBAL_MIN_TIME + for t in time + ) @staticmethod def _validate_time(times: list[datetime]) -> None: @@ -568,12 +587,13 @@ def _fetch_forecast( }, ) - for lt_idx in range(len(lead_hours)): + for lt_idx, lt_h in enumerate(lead_hours): for info in var_infos: _, modifier = CAMSLexicon[info.e2s_name] da[0, lt_idx, info.index] = modifier( _extract_field( - ds, info.nc_key, level=info.level, time_index=lt_idx + ds, info.nc_key, level=info.level, + lead_time_hours=int(lt_h), ) ) diff --git a/test/data/test_cams.py b/test/data/test_cams.py index 1742ae7dd..b271d159c 100644 --- a/test/data/test_cams.py +++ b/test/data/test_cams.py @@ 
-96,6 +96,8 @@ def test_cams_time_validation(): def test_cams_available(): assert CAMS.available(datetime.datetime(2024, 1, 1)) assert not CAMS.available(datetime.datetime(2015, 1, 1)) + # timezone-aware datetimes must not raise TypeError + assert CAMS.available(datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)) # ---- CAMS_FX tests ---- @@ -134,3 +136,5 @@ def test_cams_fx_fetch(variable, lead_time): def test_cams_fx_available(): assert CAMS_FX.available(datetime.datetime(2024, 1, 1)) assert not CAMS_FX.available(datetime.datetime(2010, 1, 1)) + # timezone-aware datetimes must not raise TypeError + assert CAMS_FX.available(datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)) From e944d291c1e568faf8466de872abe67238dbffbd Mon Sep 17 00:00:00 2001 From: "Claude Sonnet 4.5" Date: Tue, 31 Mar 2026 00:53:20 +0200 Subject: [PATCH 3/8] docs: add CAMS and CAMS_FX to datasource documentation pages Add CAMS to analysis datasources and CAMS_FX to forecast datasources. Add region:europe and product:airquality to badge filters. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/modules/datasources_analysis.rst | 5 +++-- docs/modules/datasources_forecast.rst | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/modules/datasources_analysis.rst b/docs/modules/datasources_analysis.rst index 7fc88beb8..2c27002dc 100644 --- a/docs/modules/datasources_analysis.rst +++ b/docs/modules/datasources_analysis.rst @@ -19,9 +19,9 @@ Used for fetching initial conditions for inference and validation data for scori .. currentmodule:: earth2studio -.. badge-filter:: region:global region:na region:as +.. 
badge-filter:: region:global region:na region:as region:europe dataclass:analysis dataclass:reanalysis dataclass:observation dataclass:simulation - product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu + product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu product:airquality :filter-mode: or :badge-order-fixed: :group-visibility-toggle: @@ -33,6 +33,7 @@ Used for fetching initial conditions for inference and validation data for scori :template: datasource.rst data.ARCO + data.CAMS data.CDS data.CMIP6 data.CMIP6MultiRealm diff --git a/docs/modules/datasources_forecast.rst b/docs/modules/datasources_forecast.rst index c8d50b545..8570a7f97 100644 --- a/docs/modules/datasources_forecast.rst +++ b/docs/modules/datasources_forecast.rst @@ -9,9 +9,9 @@ Typically used in intercomparison workflows. .. currentmodule:: earth2studio -.. badge-filter:: region:global region:na region:as +.. badge-filter:: region:global region:na region:as region:europe dataclass:analysis dataclass:reanalysis dataclass:observation dataclass:simulation - product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu + product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu product:airquality :filter-mode: or :badge-order-fixed: :group-visibility-toggle: @@ -23,6 +23,7 @@ Typically used in intercomparison workflows. 
:template: datasource.rst data.AIFS_FX + data.CAMS_FX data.AIFS_ENS_FX data.GFS_FX data.GEFS_FX From 7f00dc501e06bd3d558fd7a7c46360896fe865b6 Mon Sep 17 00:00:00 2001 From: "Claude Sonnet 4.5" Date: Wed, 1 Apr 2026 13:17:12 +0200 Subject: [PATCH 4/8] fix: address P2 review findings in CAMS data source - Deduplicate api_vars via dict.fromkeys() to avoid duplicate variable names in CDS API requests (CAMS and CAMS_FX) - Use dataset-specific min-time validation in CAMS_FX (EU: 2019-07-01, Global: 2015-01-01) instead of global minimum for all datasets - Sort lead_hours in CAMS_FX cache key so identical lead times in different order produce the same cache hit Co-Authored-By: Claude Opus 4.6 (1M context) --- earth2studio/data/cams.py | 44 +++++++++++++++++++-------------------- test/data/test_cams.py | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/earth2studio/data/cams.py b/earth2studio/data/cams.py index abdd71572..0d9351c42 100644 --- a/earth2studio/data/cams.py +++ b/earth2studio/data/cams.py @@ -92,9 +92,7 @@ def _download_cams_netcdf( r.update() reply = r.reply if verbose: - logger.debug( - f"Request ID:{reply['request_id']}, state: {reply['state']}" - ) + logger.debug(f"Request ID:{reply['request_id']}, state: {reply['state']}") if reply["state"] == "completed": break elif reply["state"] in ("queued", "running"): @@ -125,8 +123,7 @@ def _extract_field( ) -> np.ndarray: if nc_key not in ds: raise ValueError( - f"Variable '{nc_key}' not found in NetCDF. " - f"Available: {list(ds.data_vars)}" + f"Variable '{nc_key}' not found in NetCDF. 
Available: {list(ds.data_vars)}" ) field = ds[nc_key] non_spatial = [d for d in field.dims if d not in ("latitude", "longitude")] @@ -149,9 +146,7 @@ def _extract_field( return field.values -def _validate_cams_time( - times: list[datetime], min_time: datetime, name: str -) -> None: +def _validate_cams_time(times: list[datetime], min_time: datetime, name: str) -> None: for t in times: t_naive = t.replace(tzinfo=None) if t.tzinfo else t if t_naive < min_time: @@ -160,18 +155,14 @@ def _validate_cams_time( f"(earliest: {min_time})" ) if t_naive.minute != 0 or t_naive.second != 0: - raise ValueError( - f"Requested time {t} must be on the hour for {name}" - ) + raise ValueError(f"Requested time {t} must be on the hour for {name}") def _validate_cams_leadtime(lead_times: list[timedelta], max_hours: int) -> None: for lt in lead_times: hours = int(lt.total_seconds() // 3600) if lt.total_seconds() % 3600 != 0: - raise ValueError( - f"Lead time {lt} must be a whole number of hours" - ) + raise ValueError(f"Lead time {lt} must be a whole number of hours") if hours < 0 or hours > max_hours: raise ValueError( f"Lead time {lt} ({hours}h) outside valid range [0, {max_hours}]h" @@ -321,7 +312,7 @@ def _fetch_analysis(self, time: datetime, variables: list[str]) -> xr.DataArray: ) var_infos.append(info) - api_vars = [vi.api_name for vi in var_infos] + api_vars = list(dict.fromkeys(vi.api_name for vi in var_infos)) levels = sorted(set(vi.level for vi in var_infos if vi.level)) if not levels: levels = ["0"] @@ -531,8 +522,14 @@ def available( ) @staticmethod - def _validate_time(times: list[datetime]) -> None: - _validate_cams_time(times, _CAMS_GLOBAL_MIN_TIME, "CAMS forecast") + def _validate_time(times: list[datetime], dataset: str | None = None) -> None: + if dataset == EU_DATASET: + min_time = _CAMS_EU_MIN_TIME + name = "CAMS EU forecast" + else: + min_time = _CAMS_GLOBAL_MIN_TIME + name = "CAMS global forecast" + _validate_cams_time(times, min_time, name) def _fetch_forecast( 
self, @@ -552,8 +549,9 @@ def _fetch_forecast( ) dataset_name = next(iter(datasets)) + self._validate_time([time], dataset=dataset_name) var_infos = datasets[dataset_name] - api_vars = [vi.api_name for vi in var_infos] + api_vars = list(dict.fromkeys(vi.api_name for vi in var_infos)) levels = sorted(set(vi.level for vi in var_infos if vi.level)) lead_hours = [ str(int(np.timedelta64(lt, "h").astype(int))) for lt in lead_times @@ -574,9 +572,7 @@ def _fetch_forecast( lon = ds.longitude.values da = xr.DataArray( - data=np.empty( - (1, len(lead_times), len(variables), len(lat), len(lon)) - ), + data=np.empty((1, len(lead_times), len(variables), len(lat), len(lon))), dims=["time", "lead_time", "variable", "lat", "lon"], coords={ "time": [time], @@ -592,7 +588,9 @@ def _fetch_forecast( _, modifier = CAMSLexicon[info.e2s_name] da[0, lt_idx, info.index] = modifier( _extract_field( - ds, info.nc_key, level=info.level, + ds, + info.nc_key, + level=info.level, lead_time_hours=int(lt_h), ) ) @@ -612,7 +610,7 @@ def _download_cached( level_str = "_".join(sorted(levels)) if levels else "none" sha = hashlib.sha256( f"cams_fx_{dataset}_{'_'.join(sorted(api_vars))}" - f"_{'_'.join(lead_hours)}_{level_str}" + f"_{'_'.join(sorted(lead_hours, key=int))}_{level_str}" f"_{date_str}_{time.hour:02d}".encode() ) cache_path = self.cache / (sha.hexdigest() + ".nc") diff --git a/test/data/test_cams.py b/test/data/test_cams.py index b271d159c..09af7a409 100644 --- a/test/data/test_cams.py +++ b/test/data/test_cams.py @@ -92,6 +92,18 @@ def test_cams_time_validation(): ds = CAMS() ds(datetime.datetime(2018, 1, 1), "dust") + # CAMS_FX must reject EU variables before 2019-07-01 + with pytest.raises(ValueError, match="CAMS EU forecast"): + CAMS_FX._validate_time( + [datetime.datetime(2018, 1, 1)], + dataset="cams-europe-air-quality-forecasts", + ) + # but accept global variables at the same date + CAMS_FX._validate_time( + [datetime.datetime(2018, 1, 1)], + 
dataset="cams-global-atmospheric-composition-forecasts", + ) + def test_cams_available(): assert CAMS.available(datetime.datetime(2024, 1, 1)) @@ -138,3 +150,29 @@ def test_cams_fx_available(): assert not CAMS_FX.available(datetime.datetime(2010, 1, 1)) # timezone-aware datetimes must not raise TypeError assert CAMS_FX.available(datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)) + + +def test_cams_fx_cache_key_lead_hours_order(): + """lead_hours in different order must produce the same cache key.""" + # Mirrors the cache key construction in CAMS_FX._download_cached + import hashlib + + def _make_cache_key(lead_hours): + return hashlib.sha256( + f"cams_fx_eu_dust_{'_'.join(sorted(lead_hours, key=int))}" + f"_0_2024-06-01_00".encode() + ).hexdigest() + + assert _make_cache_key(["0", "24", "48"]) == _make_cache_key(["48", "0", "24"]) + assert _make_cache_key(["12", "6"]) == _make_cache_key(["6", "12"]) + + +def test_cams_api_vars_dedup(): + """Variables sharing the same API name must not produce duplicate requests.""" + from earth2studio.data.cams import _resolve_variable + + # dust and dust_50m both resolve to api_name "dust" + info_a = _resolve_variable("dust", 0) + info_b = _resolve_variable("dust_50m", 1) + api_vars = list(dict.fromkeys([info_a.api_name, info_b.api_name])) + assert api_vars == ["dust"], f"Expected deduplicated list, got {api_vars}" From e59fe35c6393330eb355d2f07f81d999ee5d28ef Mon Sep 17 00:00:00 2001 From: "Claude Sonnet 4.5" Date: Wed, 1 Apr 2026 19:45:40 +0200 Subject: [PATCH 5/8] refactor: decouple CAMS to Global-only forecast source Per reviewer feedback (NickGeneva): - Remove CAMS analysis class (no ML models need it currently) - Remove EU dataset support from CAMS_FX (1:1 mapping with remote store) - Reduce CAMSLexicon to 11 Global variables (AOD, column products) - Update docs and tests accordingly Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/modules/datasources_analysis.rst | 1 - earth2studio/data/__init__.py | 2 +- 
earth2studio/data/cams.py | 299 +++----------------------- earth2studio/lexicon/cams.py | 33 --- test/data/test_cams.py | 120 +++-------- test/lexicon/test_cams_lexicon.py | 42 ++-- 6 files changed, 84 insertions(+), 413 deletions(-) diff --git a/docs/modules/datasources_analysis.rst b/docs/modules/datasources_analysis.rst index 2c27002dc..03e1c8c25 100644 --- a/docs/modules/datasources_analysis.rst +++ b/docs/modules/datasources_analysis.rst @@ -33,7 +33,6 @@ Used for fetching initial conditions for inference and validation data for scori :template: datasource.rst data.ARCO - data.CAMS data.CDS data.CMIP6 data.CMIP6MultiRealm diff --git a/earth2studio/data/__init__.py b/earth2studio/data/__init__.py index cbd35443c..2147e9a84 100644 --- a/earth2studio/data/__init__.py +++ b/earth2studio/data/__init__.py @@ -17,7 +17,7 @@ from .ace2 import ACE2ERA5Data from .arco import ARCO from .base import DataSource, ForecastSource -from .cams import CAMS, CAMS_FX +from .cams import CAMS_FX from .cbottle import CBottle3D from .cds import CDS from .cmip6 import CMIP6, CMIP6MultiRealm diff --git a/earth2studio/data/cams.py b/earth2studio/data/cams.py index 0d9351c42..d9c70d00e 100644 --- a/earth2studio/data/cams.py +++ b/earth2studio/data/cams.py @@ -30,7 +30,6 @@ from earth2studio.data.utils import ( datasource_cache_root, - prep_data_inputs, prep_forecast_inputs, ) from earth2studio.lexicon import CAMSLexicon @@ -46,12 +45,10 @@ OptionalDependencyFailure("data") cdsapi = None -# CAMS EU analysis available from 2019-07-01 onward -_CAMS_EU_MIN_TIME = datetime(2019, 7, 1) # CAMS Global forecast available from 2015-01-01 onward _CAMS_GLOBAL_MIN_TIME = datetime(2015, 1, 1) -EU_DATASET = "cams-europe-air-quality-forecasts" +_GLOBAL_DATASET = "cams-global-atmospheric-composition-forecasts" @dataclass @@ -118,7 +115,6 @@ def _download_cams_netcdf( def _extract_field( ds: xr.Dataset, nc_key: str, - level: str = "0", lead_time_hours: int | None = None, ) -> np.ndarray: if nc_key not in 
ds: @@ -129,12 +125,7 @@ def _extract_field( non_spatial = [d for d in field.dims if d not in ("latitude", "longitude")] isel: dict[str, int] = {} for d in non_spatial: - if d == "level" and level: - level_val = float(level) if level else 0.0 - level_coords = field.coords["level"].values - nearest_idx = int(np.argmin(np.abs(level_coords - level_val))) - isel[d] = nearest_idx - elif d == "forecast_period" and lead_time_hours is not None: + if d == "forecast_period" and lead_time_hours is not None: fp_vals = field.coords["forecast_period"].values.astype(float) target = float(lead_time_hours) nearest_idx = int(np.argmin(np.abs(fp_vals - target))) @@ -169,225 +160,12 @@ def _validate_cams_leadtime(lead_times: list[timedelta], max_hours: int) -> None ) -@check_optional_dependencies() -class CAMS: - """CAMS European Air Quality analysis data source. - - Uses the ``cams-europe-air-quality-forecasts`` dataset with ``type=analysis``. - Grid is 0.1 deg over Europe, read dynamically from the downloaded NetCDF. - - Parameters - ---------- - cache : bool, optional - Cache data source on local memory, by default True - verbose : bool, optional - Print download progress, by default True - - Warning - ------- - This is a remote data source and can potentially download a large amount of data - to your local machine for large requests. 
- - Note - ---- - Additional information on the data repository, registration, and authentication can - be referenced here: - - - https://ads.atmosphere.copernicus.eu/datasets/cams-europe-air-quality-forecasts - - https://cds.climate.copernicus.eu/how-to-api - - Badges - ------ - region:europe dataclass:analysis product:airquality - """ - - def __init__(self, cache: bool = True, verbose: bool = True): - self._cache = cache - self._verbose = verbose - self._cds_client: "cdsapi.Client | None" = None - - @property - def _client(self) -> "cdsapi.Client": - if self._cds_client is None: - if cdsapi is None: - raise ImportError( - "cdsapi is required for CAMS. " - "Install with: pip install 'earth2studio[data]'" - ) - self._cds_client = cdsapi.Client( - debug=False, quiet=True, wait_until_complete=False - ) - return self._cds_client - - def __call__( - self, - time: datetime | list[datetime] | TimeArray, - variable: str | list[str] | VariableArray, - ) -> xr.DataArray: - """Retrieve CAMS EU analysis data. - - Parameters - ---------- - time : datetime | list[datetime] | TimeArray - Timestamps to return data for (UTC). - variable : str | list[str] | VariableArray - Variables to return. Must be in CAMSLexicon with EU dataset. - - Returns - ------- - xr.DataArray - CAMS data array with dims [time, variable, lat, lon] - """ - time, variable = prep_data_inputs(time, variable) - self._validate_time(time) - self.cache.mkdir(parents=True, exist_ok=True) - - data_arrays = [] - for t0 in time: - da = self._fetch_analysis(t0, variable) - data_arrays.append(da) - - if not self._cache: - shutil.rmtree(self.cache) - - return xr.concat(data_arrays, dim="time") - - async def fetch( - self, - time: datetime | list[datetime] | TimeArray, - variable: str | list[str] | VariableArray, - ) -> xr.DataArray: - """Async retrieval of CAMS EU analysis data. - - Parameters - ---------- - time : datetime | list[datetime] | TimeArray - Timestamps to return data for (UTC). 
- variable : str | list[str] | VariableArray - Variables to return. Must be in CAMSLexicon with EU dataset. - - Returns - ------- - xr.DataArray - CAMS data array with dims [time, variable, lat, lon] - """ - return await asyncio.to_thread(self.__call__, time, variable) - - @classmethod - def available( - cls, - time: datetime | list[datetime], - ) -> bool: - """Check if CAMS EU analysis data is available for the requested times. - - Parameters - ---------- - time : datetime | list[datetime] - Timestamps to check availability for. - - Returns - ------- - bool - True if all requested times are within the valid range. - """ - if isinstance(time, datetime): - time = [time] - return all( - (t.replace(tzinfo=None) if t.tzinfo else t) >= _CAMS_EU_MIN_TIME - for t in time - ) - - @staticmethod - def _validate_time(times: list[datetime]) -> None: - _validate_cams_time(times, _CAMS_EU_MIN_TIME, "CAMS EU analysis") - - def _fetch_analysis(self, time: datetime, variables: list[str]) -> xr.DataArray: - var_infos = [] - for i, v in enumerate(variables): - info = _resolve_variable(v, i) - if info.dataset != EU_DATASET: - raise ValueError( - f"CAMS analysis only supports EU dataset, got '{info.dataset}' " - f"for variable '{v}'. Use CAMS_FX for global forecast variables." 
- ) - var_infos.append(info) - - api_vars = list(dict.fromkeys(vi.api_name for vi in var_infos)) - levels = sorted(set(vi.level for vi in var_infos if vi.level)) - if not levels: - levels = ["0"] - nc_path = self._download_cached(time, api_vars, levels) - - ds = xr.open_dataset(nc_path, decode_timedelta=False) - lat = ds.latitude.values - lon = ds.longitude.values - - da = xr.DataArray( - data=np.empty((1, len(variables), len(lat), len(lon))), - dims=["time", "variable", "lat", "lon"], - coords={ - "time": [time], - "variable": variables, - "lat": lat, - "lon": lon, - }, - ) - - for info in var_infos: - _, modifier = CAMSLexicon[info.e2s_name] - da[0, info.index] = modifier( - _extract_field(ds, info.nc_key, level=info.level) - ) - - ds.close() - return da - - def _download_cached( - self, time: datetime, api_vars: list[str], levels: list[str] - ) -> pathlib.Path: - date_str = time.strftime("%Y-%m-%d") - sha = hashlib.sha256( - f"cams_eu_{'_'.join(sorted(api_vars))}" - f"_{'_'.join(sorted(levels))}_{date_str}_{time.hour:02d}".encode() - ) - cache_path = self.cache / (sha.hexdigest() + ".nc") - - request_body = { - "variable": api_vars, - "model": ["ensemble"], - "level": levels, - "date": [f"{date_str}/{date_str}"], - "type": ["analysis"], - "time": [f"{time.hour:02d}:00"], - "leadtime_hour": ["0"], - "data_format": "netcdf", - } - - if self._verbose: - logger.info( - f"Fetching CAMS EU analysis for {date_str} {time.hour:02d}:00 " - f"vars={api_vars}" - ) - return _download_cams_netcdf( - self._client, EU_DATASET, request_body, cache_path, self._verbose - ) - - @property - def cache(self) -> pathlib.Path: - """Cache location.""" - cache_location = pathlib.Path(datasource_cache_root()) / "cams" - if not self._cache: - cache_location = cache_location / "tmp_cams" - return cache_location - - @check_optional_dependencies() class CAMS_FX: - """CAMS forecast data source. + """CAMS Global atmospheric composition forecast data source. 
- Supports both EU (``cams-europe-air-quality-forecasts``) and Global - (``cams-global-atmospheric-composition-forecasts``) forecast datasets. - The dataset is determined automatically from the requested variables via CAMSLexicon. + Uses the ``cams-global-atmospheric-composition-forecasts`` dataset. + Grid is 0.4 deg global, read dynamically from the downloaded NetCDF. Parameters ---------- @@ -406,18 +184,15 @@ class CAMS_FX: Additional information on the data repository, registration, and authentication can be referenced here: - - https://ads.atmosphere.copernicus.eu/datasets/cams-europe-air-quality-forecasts - https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts - https://cds.climate.copernicus.eu/how-to-api Badges ------ - region:europe region:global dataclass:simulation product:airquality + region:global dataclass:simulation product:airquality """ - # EU forecasts go up to 96h, Global up to 120h - MAX_EU_LEAD_HOURS = 96 - MAX_GLOBAL_LEAD_HOURS = 120 + MAX_LEAD_HOURS = 120 def __init__(self, cache: bool = True, verbose: bool = True): self._cache = cache @@ -443,7 +218,7 @@ def __call__( lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Retrieve CAMS forecast data. + """Retrieve CAMS Global forecast data. Parameters ---------- @@ -479,7 +254,7 @@ async def fetch( lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async retrieval of CAMS forecast data. + """Async retrieval of CAMS Global forecast data. Parameters ---------- @@ -502,7 +277,7 @@ def available( cls, time: datetime | list[datetime], ) -> bool: - """Check if CAMS forecast data is available for the requested times. + """Check if CAMS Global forecast data is available for the requested times. 
Parameters ---------- @@ -522,14 +297,8 @@ def available( ) @staticmethod - def _validate_time(times: list[datetime], dataset: str | None = None) -> None: - if dataset == EU_DATASET: - min_time = _CAMS_EU_MIN_TIME - name = "CAMS EU forecast" - else: - min_time = _CAMS_GLOBAL_MIN_TIME - name = "CAMS global forecast" - _validate_cams_time(times, min_time, name) + def _validate_time(times: list[datetime]) -> None: + _validate_cams_time(times, _CAMS_GLOBAL_MIN_TIME, "CAMS Global forecast") def _fetch_forecast( self, @@ -537,35 +306,26 @@ def _fetch_forecast( lead_times: np.ndarray, variables: list[str], ) -> xr.DataArray: - datasets: dict[str, list[_CAMSVarInfo]] = {} + var_infos = [] for i, v in enumerate(variables): info = _resolve_variable(v, i) - datasets.setdefault(info.dataset, []).append(info) - - if len(datasets) > 1: - raise ValueError( - "Cannot mix EU and Global CAMS variables in a single CAMS_FX call. " - f"Got datasets: {list(datasets.keys())}" - ) + if info.dataset != _GLOBAL_DATASET: + raise ValueError( + f"CAMS_FX only supports Global dataset, got '{info.dataset}' " + f"for variable '{v}'." 
+ ) + var_infos.append(info) - dataset_name = next(iter(datasets)) - self._validate_time([time], dataset=dataset_name) - var_infos = datasets[dataset_name] api_vars = list(dict.fromkeys(vi.api_name for vi in var_infos)) - levels = sorted(set(vi.level for vi in var_infos if vi.level)) lead_hours = [ str(int(np.timedelta64(lt, "h").astype(int))) for lt in lead_times ] - is_eu = dataset_name == EU_DATASET - max_hours = self.MAX_EU_LEAD_HOURS if is_eu else self.MAX_GLOBAL_LEAD_HOURS _validate_cams_leadtime( - [timedelta(hours=int(h)) for h in lead_hours], max_hours + [timedelta(hours=int(h)) for h in lead_hours], self.MAX_LEAD_HOURS ) - nc_path = self._download_cached( - time, dataset_name, api_vars, lead_hours, levels - ) + nc_path = self._download_cached(time, api_vars, lead_hours) ds = xr.open_dataset(nc_path, decode_timedelta=False) lat = ds.latitude.values @@ -590,7 +350,6 @@ def _fetch_forecast( _extract_field( ds, info.nc_key, - level=info.level, lead_time_hours=int(lt_h), ) ) @@ -601,22 +360,17 @@ def _fetch_forecast( def _download_cached( self, time: datetime, - dataset: str, api_vars: list[str], lead_hours: list[str], - levels: list[str] | None = None, ) -> pathlib.Path: date_str = time.strftime("%Y-%m-%d") - level_str = "_".join(sorted(levels)) if levels else "none" sha = hashlib.sha256( - f"cams_fx_{dataset}_{'_'.join(sorted(api_vars))}" - f"_{'_'.join(sorted(lead_hours, key=int))}_{level_str}" + f"cams_fx_{'_'.join(sorted(api_vars))}" + f"_{'_'.join(sorted(lead_hours, key=int))}" f"_{date_str}_{time.hour:02d}".encode() ) cache_path = self.cache / (sha.hexdigest() + ".nc") - is_eu = dataset == EU_DATASET - request_body: dict = { "variable": api_vars, "date": [f"{date_str}/{date_str}"], @@ -625,17 +379,14 @@ def _download_cached( "leadtime_hour": lead_hours, "data_format": "netcdf", } - if is_eu: - request_body["model"] = ["ensemble"] - request_body["level"] = levels if levels else ["0"] if self._verbose: logger.info( - f"Fetching CAMS forecast 
({dataset.split('-')[1]}) for {date_str} " + f"Fetching CAMS Global forecast for {date_str} " f"{time.hour:02d}:00 lead_hours={lead_hours} vars={api_vars}" ) return _download_cams_netcdf( - self._client, dataset, request_body, cache_path, self._verbose + self._client, _GLOBAL_DATASET, request_body, cache_path, self._verbose ) @property diff --git a/earth2studio/lexicon/cams.py b/earth2studio/lexicon/cams.py index 46f3e458a..770d1252c 100644 --- a/earth2studio/lexicon/cams.py +++ b/earth2studio/lexicon/cams.py @@ -20,34 +20,8 @@ from .base import LexiconType -_EU = "cams-europe-air-quality-forecasts" _GLOBAL = "cams-global-atmospheric-composition-forecasts" -# All EU levels available in the CAMS API (meters above ground) -_EU_LEVELS = [50, 100, 250, 500, 750, 1000, 2000, 3000, 5000] - -# short_name -> (api_request_name, netcdf_key, surface_e2s_name) -_EU_POLLUTANTS = { - "dust": ("dust", "dust", "dust"), - "pm2p5": ("particulate_matter_2.5um", "pm2p5_conc", "pm2p5"), - "pm10": ("particulate_matter_10um", "pm10_conc", "pm10"), - "so2": ("sulphur_dioxide", "so2_conc", "so2sfc"), - "no2": ("nitrogen_dioxide", "no2_conc", "no2sfc"), - "o3": ("ozone", "o3_conc", "o3sfc"), - "co": ("carbon_monoxide", "co_conc", "cosfc"), - "nh3": ("ammonia", "nh3_conc", "nh3sfc"), - "no": ("nitrogen_monoxide", "no_conc", "nosfc"), -} - - -def _build_eu_vocab() -> dict[str, str]: - vocab: dict[str, str] = {} - for short, (api, nc, sfc_name) in _EU_POLLUTANTS.items(): - vocab[sfc_name] = f"{_EU}::{api}::{nc}::0" - for level in _EU_LEVELS: - vocab[f"{short}_{level}m"] = f"{_EU}::{api}::{nc}::{level}" - return vocab - class CAMSLexicon(metaclass=LexiconType): """Copernicus Atmosphere Monitoring Service Lexicon @@ -59,18 +33,11 @@ class CAMSLexicon(metaclass=LexiconType): Note ---- - EU multi-level variables are available at: 0 (surface), 50, 100, 250, 500, - 750, 1000, 2000, 3000, 5000 m. All pollutants are mapped at all available - levels. 
- Additional resources: - https://ads.atmosphere.copernicus.eu/datasets/cams-europe-air-quality-forecasts https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts """ VOCAB = { - **_build_eu_vocab(), - # ---- CAMS Global (column/AOD, 0.4 deg grid) ---- "aod550": f"{_GLOBAL}::total_aerosol_optical_depth_550nm::aod550::", "duaod550": f"{_GLOBAL}::dust_aerosol_optical_depth_550nm::duaod550::", "omaod550": f"{_GLOBAL}::organic_matter_aerosol_optical_depth_550nm::omaod550::", diff --git a/test/data/test_cams.py b/test/data/test_cams.py index 09af7a409..41590e6a2 100644 --- a/test/data/test_cams.py +++ b/test/data/test_cams.py @@ -15,13 +15,14 @@ # limitations under the License. import datetime +import hashlib import pathlib import shutil import numpy as np import pytest -from earth2studio.data import CAMS, CAMS_FX +from earth2studio.data import CAMS_FX YESTERDAY = datetime.datetime.now(datetime.UTC).replace( hour=0, minute=0, second=0, microsecond=0 @@ -31,48 +32,51 @@ @pytest.mark.slow @pytest.mark.xfail @pytest.mark.timeout(120) +@pytest.mark.parametrize("variable", ["aod550", ["aod550", "tcco"]]) @pytest.mark.parametrize( - "time", + "lead_time", [ - [YESTERDAY], - np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]), + datetime.timedelta(hours=0), + [datetime.timedelta(hours=0), datetime.timedelta(hours=24)], ], ) -@pytest.mark.parametrize("variable", ["dust", ["dust", "pm2p5"]]) -def test_cams_fetch(time, variable): - ds = CAMS(cache=False) - data = ds(time, variable) +def test_cams_fx_fetch(variable, lead_time): + time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) + ds = CAMS_FX(cache=False) + data = ds(time, lead_time, variable) shape = data.shape if isinstance(variable, str): variable = [variable] + if isinstance(lead_time, datetime.timedelta): + lead_time = [lead_time] - assert shape[0] == 1 - assert shape[1] == len(variable) + assert shape[0] == 1 # time + assert shape[1] == len(lead_time) + 
assert shape[2] == len(variable) assert len(data.coords["lat"]) > 0 assert len(data.coords["lon"]) > 0 assert not np.isnan(data.values).all() - assert np.array_equal(data.coords["variable"].values, np.array(variable)) @pytest.mark.slow @pytest.mark.xfail @pytest.mark.timeout(120) -@pytest.mark.parametrize("variable", [["dust", "so2sfc"]]) @pytest.mark.parametrize("cache", [True, False]) -def test_cams_cache(variable, cache): +def test_cams_fx_cache(cache): time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) - ds = CAMS(cache=cache) - data = ds(time, variable) + lead_time = datetime.timedelta(hours=0) + ds = CAMS_FX(cache=cache) + data = ds(time, lead_time, ["aod550", "tcco"]) shape = data.shape assert shape[0] == 1 - assert shape[1] == 2 + assert shape[2] == 2 assert not np.isnan(data.values).all() assert pathlib.Path(ds.cache).is_dir() == cache - data = ds(time, variable[0]) - assert data.shape[1] == 1 + data = ds(time, lead_time, "aod550") + assert data.shape[2] == 1 try: shutil.rmtree(ds.cache) @@ -81,68 +85,15 @@ def test_cams_cache(variable, cache): @pytest.mark.timeout(30) -def test_cams_invalid(): +def test_cams_fx_invalid(): with pytest.raises((ValueError, KeyError)): - ds = CAMS() - ds(YESTERDAY, "nonexistent_var") - - -def test_cams_time_validation(): - with pytest.raises(ValueError): - ds = CAMS() - ds(datetime.datetime(2018, 1, 1), "dust") - - # CAMS_FX must reject EU variables before 2019-07-01 - with pytest.raises(ValueError, match="CAMS EU forecast"): - CAMS_FX._validate_time( - [datetime.datetime(2018, 1, 1)], - dataset="cams-europe-air-quality-forecasts", - ) - # but accept global variables at the same date - CAMS_FX._validate_time( - [datetime.datetime(2018, 1, 1)], - dataset="cams-global-atmospheric-composition-forecasts", - ) - - -def test_cams_available(): - assert CAMS.available(datetime.datetime(2024, 1, 1)) - assert not CAMS.available(datetime.datetime(2015, 1, 1)) - # timezone-aware datetimes must not raise TypeError - 
assert CAMS.available(datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)) + ds = CAMS_FX() + ds(YESTERDAY, datetime.timedelta(hours=0), "nonexistent_var") -# ---- CAMS_FX tests ---- - - -@pytest.mark.slow -@pytest.mark.xfail -@pytest.mark.timeout(120) -@pytest.mark.parametrize("variable", ["dust", ["dust", "pm2p5"]]) -@pytest.mark.parametrize( - "lead_time", - [ - datetime.timedelta(hours=0), - [datetime.timedelta(hours=0), datetime.timedelta(hours=24)], - ], -) -def test_cams_fx_fetch(variable, lead_time): - time = np.array([np.datetime64(YESTERDAY.strftime("%Y-%m-%dT%H:%M"))]) - ds = CAMS_FX(cache=False) - data = ds(time, lead_time, variable) - shape = data.shape - - if isinstance(variable, str): - variable = [variable] - if isinstance(lead_time, datetime.timedelta): - lead_time = [lead_time] - - assert shape[0] == 1 # time - assert shape[1] == len(lead_time) - assert shape[2] == len(variable) - assert len(data.coords["lat"]) > 0 - assert len(data.coords["lon"]) > 0 - assert not np.isnan(data.values).all() +def test_cams_fx_time_validation(): + with pytest.raises(ValueError, match="CAMS Global forecast"): + CAMS_FX._validate_time([datetime.datetime(2014, 1, 1)]) def test_cams_fx_available(): @@ -154,25 +105,22 @@ def test_cams_fx_available(): def test_cams_fx_cache_key_lead_hours_order(): """lead_hours in different order must produce the same cache key.""" - # Mirrors the cache key construction in CAMS_FX._download_cached - import hashlib def _make_cache_key(lead_hours): return hashlib.sha256( - f"cams_fx_eu_dust_{'_'.join(sorted(lead_hours, key=int))}" - f"_0_2024-06-01_00".encode() + f"cams_fx_aod550_{'_'.join(sorted(lead_hours, key=int))}" + f"_2024-06-01_00".encode() ).hexdigest() assert _make_cache_key(["0", "24", "48"]) == _make_cache_key(["48", "0", "24"]) assert _make_cache_key(["12", "6"]) == _make_cache_key(["6", "12"]) -def test_cams_api_vars_dedup(): +def test_cams_fx_api_vars_dedup(): """Variables sharing the same API name must not produce duplicate 
requests.""" from earth2studio.data.cams import _resolve_variable - # dust and dust_50m both resolve to api_name "dust" - info_a = _resolve_variable("dust", 0) - info_b = _resolve_variable("dust_50m", 1) + info_a = _resolve_variable("aod550", 0) + info_b = _resolve_variable("aod550", 1) api_vars = list(dict.fromkeys([info_a.api_name, info_b.api_name])) - assert api_vars == ["dust"], f"Expected deduplicated list, got {api_vars}" + assert len(api_vars) == 1 diff --git a/test/lexicon/test_cams_lexicon.py b/test/lexicon/test_cams_lexicon.py index 144c935ad..d5e6b6a30 100644 --- a/test/lexicon/test_cams_lexicon.py +++ b/test/lexicon/test_cams_lexicon.py @@ -23,12 +23,11 @@ @pytest.mark.parametrize( "variable", [ - ["dust"], - ["so2sfc", "pm2p5"], - ["no2sfc", "o3sfc", "cosfc"], ["aod550"], ["duaod550", "tcno2"], - ["dust_500m", "pm2p5_1000m"], + ["bcaod550", "ssaod550", "suaod550"], + ["tcco", "tco3", "tcso2"], + ["gtco3", "omaod550"], ], ) def test_cams_lexicon(variable): @@ -49,20 +48,27 @@ def test_cams_lexicon_vocab_format(): for key, value in CAMSLexicon.VOCAB.items(): parts = value.split("::") assert len(parts) == 4, ( - f"VOCAB entry '{key}' must have format " - "'dataset::api_var::nc_key::level'" + f"VOCAB entry '{key}' must have format 'dataset::api_var::nc_key::level'" + ) + assert parts[0] == "cams-global-atmospheric-composition-forecasts", ( + f"Expected global dataset in VOCAB entry '{key}', got '{parts[0]}'" ) - dataset = parts[0] - assert dataset in ( - "cams-europe-air-quality-forecasts", - "cams-global-atmospheric-composition-forecasts", - ), f"Unknown dataset in VOCAB entry '{key}': {dataset}" -def test_cams_lexicon_all_levels_covered(): - levels = [50, 100, 250, 500, 750, 1000, 2000, 3000, 5000] - pollutants = ["dust", "pm2p5", "pm10", "so2", "no2", "o3", "co", "nh3", "no"] - for p in pollutants: - for lev in levels: - key = f"{p}_{lev}m" - assert key in CAMSLexicon.VOCAB, f"Missing level entry: {key}" +def test_cams_lexicon_all_global_vars(): + 
expected = [ + "aod550", + "duaod550", + "omaod550", + "bcaod550", + "ssaod550", + "suaod550", + "tcco", + "tcno2", + "tco3", + "tcso2", + "gtco3", + ] + for var in expected: + assert var in CAMSLexicon.VOCAB, f"Missing global variable: {var}" + assert len(CAMSLexicon.VOCAB) == len(expected) From 490044ac773a99a37c5c22abb7daeb742a9e11ea Mon Sep 17 00:00:00 2001 From: Nicholas Geneva Date: Thu, 2 Apr 2026 02:52:13 +0000 Subject: [PATCH 6/8] Various updates and fixes --- earth2studio/data/cams.py | 175 ++++++++++++++++++++---------- earth2studio/lexicon/__init__.py | 2 +- earth2studio/lexicon/base.py | 10 ++ earth2studio/lexicon/cams.py | 78 ++++++++++++- test/data/test_cams.py | 124 ++++++++++++++++++--- test/lexicon/test_cams_lexicon.py | 86 +++++++++++---- 6 files changed, 379 insertions(+), 96 deletions(-) diff --git a/earth2studio/data/cams.py b/earth2studio/data/cams.py index d9c70d00e..f2226f57c 100644 --- a/earth2studio/data/cams.py +++ b/earth2studio/data/cams.py @@ -32,7 +32,7 @@ datasource_cache_root, prep_forecast_inputs, ) -from earth2studio.lexicon import CAMSLexicon +from earth2studio.lexicon import CAMSGlobalLexicon from earth2studio.utils.imports import ( OptionalDependencyFailure, check_optional_dependencies, @@ -45,11 +45,6 @@ OptionalDependencyFailure("data") cdsapi = None -# CAMS Global forecast available from 2015-01-01 onward -_CAMS_GLOBAL_MIN_TIME = datetime(2015, 1, 1) - -_GLOBAL_DATASET = "cams-global-atmospheric-composition-forecasts" - @dataclass class _CAMSVarInfo: @@ -62,7 +57,7 @@ class _CAMSVarInfo: def _resolve_variable(e2s_name: str, index: int) -> _CAMSVarInfo: - cams_key, _ = CAMSLexicon[e2s_name] + cams_key, _ = CAMSGlobalLexicon[e2s_name] dataset, api_name, nc_key, level = cams_key.split("::") return _CAMSVarInfo( e2s_name=e2s_name, @@ -116,6 +111,7 @@ def _extract_field( ds: xr.Dataset, nc_key: str, lead_time_hours: int | None = None, + pressure_level: int | None = None, ) -> np.ndarray: if nc_key not in ds: raise ValueError( @@ 
-130,6 +126,11 @@ def _extract_field( target = float(lead_time_hours) nearest_idx = int(np.argmin(np.abs(fp_vals - target))) isel[d] = nearest_idx + elif d in ("pressure_level", "isobaricInhPa") and pressure_level is not None: + pl_vals = field.coords[d].values.astype(float) + target_pl = float(pressure_level) + nearest_idx = int(np.argmin(np.abs(pl_vals - target_pl))) + isel[d] = nearest_idx else: isel[d] = 0 if isel: @@ -137,35 +138,12 @@ def _extract_field( return field.values -def _validate_cams_time(times: list[datetime], min_time: datetime, name: str) -> None: - for t in times: - t_naive = t.replace(tzinfo=None) if t.tzinfo else t - if t_naive < min_time: - raise ValueError( - f"Requested time {t} is before {name} availability " - f"(earliest: {min_time})" - ) - if t_naive.minute != 0 or t_naive.second != 0: - raise ValueError(f"Requested time {t} must be on the hour for {name}") - - -def _validate_cams_leadtime(lead_times: list[timedelta], max_hours: int) -> None: - for lt in lead_times: - hours = int(lt.total_seconds() // 3600) - if lt.total_seconds() % 3600 != 0: - raise ValueError(f"Lead time {lt} must be a whole number of hours") - if hours < 0 or hours > max_hours: - raise ValueError( - f"Lead time {lt} ({hours}h) outside valid range [0, {max_hours}]h" - ) - - @check_optional_dependencies() class CAMS_FX: """CAMS Global atmospheric composition forecast data source. Uses the ``cams-global-atmospheric-composition-forecasts`` dataset. - Grid is 0.4 deg global, read dynamically from the downloaded NetCDF. + Grid is 0.4 deg global (451 x 900). Parameters ---------- @@ -185,14 +163,21 @@ class CAMS_FX: be referenced here: - https://ads.atmosphere.copernicus.eu/datasets/cams-global-atmospheric-composition-forecasts - - https://cds.climate.copernicus.eu/how-to-api + - https://ads.atmosphere.copernicus.eu/how-to-api + + The API endpoint for this data source varies from the Climate Data Store (CDS), be + sure your api config has the correct url. 
Badges ------ - region:global dataclass:simulation product:airquality + region:global dataclass:simulation product:wind product:temp product:atmos """ MAX_LEAD_HOURS = 120 + CAMS_MIN_TIME = datetime(2015, 1, 1) + CAMS_DATASET_URI = "cams-global-atmospheric-composition-forecasts" + CAMS_LAT = np.linspace(90, -90, 451) + CAMS_LON = np.linspace(0, 359.6, 900) def __init__(self, cache: bool = True, verbose: bool = True): self._cache = cache @@ -227,7 +212,7 @@ def __call__( lead_time : timedelta | list[timedelta] | LeadTimeArray Forecast lead times. variable : str | list[str] | VariableArray - Variables to return. Must be in CAMSLexicon. + Variables to return. Must be in CAMSGlobalLexicon. Returns ------- @@ -263,7 +248,7 @@ async def fetch( lead_time : timedelta | list[timedelta] | LeadTimeArray Forecast lead times. variable : str | list[str] | VariableArray - Variables to return. Must be in CAMSLexicon. + Variables to return. Must be in CAMSGlobalLexicon. Returns ------- @@ -292,13 +277,36 @@ def available( if isinstance(time, datetime): time = [time] return all( - (t.replace(tzinfo=None) if t.tzinfo else t) >= _CAMS_GLOBAL_MIN_TIME + (t.replace(tzinfo=None) if t.tzinfo else t) >= cls.CAMS_MIN_TIME for t in time ) + @classmethod + def _validate_time(cls, times: list[datetime]) -> None: + """Validate that requested times are valid for CAMS Global forecast.""" + for t in times: + t_naive = t.replace(tzinfo=None) if t.tzinfo else t + if t_naive < cls.CAMS_MIN_TIME: + raise ValueError( + f"Requested time {t} is before CAMS Global forecast availability " + f"(earliest: {cls.CAMS_MIN_TIME})" + ) + if t_naive.minute != 0 or t_naive.second != 0: + raise ValueError( + f"Requested time {t} must be on the hour for CAMS Global forecast" + ) + @staticmethod - def _validate_time(times: list[datetime]) -> None: - _validate_cams_time(times, _CAMS_GLOBAL_MIN_TIME, "CAMS Global forecast") + def _validate_leadtime(lead_times: list[timedelta], max_hours: int) -> None: + """Validate that 
requested lead times are valid.""" + for lt in lead_times: + hours = int(lt.total_seconds() // 3600) + if lt.total_seconds() % 3600 != 0: + raise ValueError(f"Lead time {lt} must be a whole number of hours") + if hours < 0 or hours > max_hours: + raise ValueError( + f"Lead time {lt} ({hours}h) outside valid range [0, {max_hours}]h" + ) def _fetch_forecast( self, @@ -309,52 +317,99 @@ def _fetch_forecast( var_infos = [] for i, v in enumerate(variables): info = _resolve_variable(v, i) - if info.dataset != _GLOBAL_DATASET: - raise ValueError( - f"CAMS_FX only supports Global dataset, got '{info.dataset}' " - f"for variable '{v}'." - ) var_infos.append(info) - api_vars = list(dict.fromkeys(vi.api_name for vi in var_infos)) lead_hours = [ str(int(np.timedelta64(lt, "h").astype(int))) for lt in lead_times ] - _validate_cams_leadtime( + self._validate_leadtime( [timedelta(hours=int(h)) for h in lead_hours], self.MAX_LEAD_HOURS ) - nc_path = self._download_cached(time, api_vars, lead_hours) + # Separate surface and pressure-level variables; they need different + # API requests because pressure-level vars require the pressure_level + # parameter and are only available at 3-hourly lead times. + surface_infos = [vi for vi in var_infos if not vi.level] + pressure_infos = [vi for vi in var_infos if vi.level] + + # Validate that pressure-level lead times are multiples of 3 hours + if pressure_infos: + for lt_h in lead_hours: + if int(lt_h) % 3 != 0: + raise ValueError( + f"Lead time {lt_h}h is not a multiple of 3 hours. " + "Pressure-level variables in CAMS Global forecasts " + "are only available at 3-hourly intervals (0, 3, 6, ...)." 
+ ) - ds = xr.open_dataset(nc_path, decode_timedelta=False) - lat = ds.latitude.values - lon = ds.longitude.values + # Download surface variables + surface_ds: xr.Dataset | None = None + if surface_infos: + surface_api_vars = list(dict.fromkeys(vi.api_name for vi in surface_infos)) + nc_path = self._download_cached(time, surface_api_vars, lead_hours) + surface_ds = xr.open_dataset(nc_path, decode_timedelta=False) + + # Download pressure-level variables (grouped by unique levels) + pressure_ds: xr.Dataset | None = None + if pressure_infos: + pressure_api_vars = list( + dict.fromkeys(vi.api_name for vi in pressure_infos) + ) + pressure_levels = sorted({vi.level for vi in pressure_infos}, key=int) + nc_path = self._download_cached( + time, pressure_api_vars, lead_hours, pressure_levels + ) + pressure_ds = xr.open_dataset(nc_path, decode_timedelta=False) + + # Use whichever dataset is available to read grid coordinates + ref_ds = surface_ds if surface_ds is not None else pressure_ds + if ref_ds is None: + raise ValueError( + "No variables to fetch – both surface and pressure lists are empty." 
+ ) da = xr.DataArray( - data=np.empty((1, len(lead_times), len(variables), len(lat), len(lon))), + data=np.empty( + ( + 1, + len(lead_times), + len(variables), + len(self.CAMS_LAT), + len(self.CAMS_LON), + ) + ), dims=["time", "lead_time", "variable", "lat", "lon"], coords={ "time": [time], "lead_time": lead_times, "variable": variables, - "lat": lat, - "lon": lon, + "lat": self.CAMS_LAT, + "lon": self.CAMS_LON, }, ) for lt_idx, lt_h in enumerate(lead_hours): for info in var_infos: - _, modifier = CAMSLexicon[info.e2s_name] + _, modifier = CAMSGlobalLexicon[info.e2s_name] + ds = pressure_ds if info.level else surface_ds + if ds is None: # pragma: no cover + raise RuntimeError( + f"Dataset for variable {info.e2s_name} is unexpectedly None" + ) da[0, lt_idx, info.index] = modifier( _extract_field( ds, info.nc_key, lead_time_hours=int(lt_h), + pressure_level=int(info.level) if info.level else None, ) ) - ds.close() + if surface_ds is not None: + surface_ds.close() + if pressure_ds is not None: + pressure_ds.close() return da def _download_cached( @@ -362,11 +417,18 @@ def _download_cached( time: datetime, api_vars: list[str], lead_hours: list[str], + pressure_levels: list[str] | None = None, ) -> pathlib.Path: date_str = time.strftime("%Y-%m-%d") + pl_part = ( + f"_pl{'_'.join(sorted(pressure_levels, key=int))}" + if pressure_levels + else "" + ) sha = hashlib.sha256( f"cams_fx_{'_'.join(sorted(api_vars))}" f"_{'_'.join(sorted(lead_hours, key=int))}" + f"{pl_part}" f"_{date_str}_{time.hour:02d}".encode() ) cache_path = self.cache / (sha.hexdigest() + ".nc") @@ -379,14 +441,17 @@ def _download_cached( "leadtime_hour": lead_hours, "data_format": "netcdf", } + if pressure_levels: + request_body["pressure_level"] = pressure_levels if self._verbose: logger.info( f"Fetching CAMS Global forecast for {date_str} " f"{time.hour:02d}:00 lead_hours={lead_hours} vars={api_vars}" + + (f" pressure_levels={pressure_levels}" if pressure_levels else "") ) return _download_cams_netcdf( 
- self._client, _GLOBAL_DATASET, request_body, cache_path, self._verbose + self._client, self.CAMS_DATASET_URI, request_body, cache_path, self._verbose ) @property diff --git a/earth2studio/lexicon/__init__.py b/earth2studio/lexicon/__init__.py index e800e03d0..7eadc1c8e 100644 --- a/earth2studio/lexicon/__init__.py +++ b/earth2studio/lexicon/__init__.py @@ -16,7 +16,7 @@ from .ace import ACELexicon from .arco import ARCOLexicon -from .cams import CAMSLexicon +from .cams import CAMSGlobalLexicon from .cbottle import CBottleLexicon from .cds import CDSLexicon from .cmip6 import CMIP6Lexicon diff --git a/earth2studio/lexicon/base.py b/earth2studio/lexicon/base.py index 88f5fdac2..a85fc989f 100644 --- a/earth2studio/lexicon/base.py +++ b/earth2studio/lexicon/base.py @@ -279,6 +279,16 @@ def __contains__(cls, val: object) -> bool: "strd06": "surface long-wave (thermal) radiation downwards (J m-2) past 6 hours", "sf": "snowfall water equivalent (kg m-2)", "ro": "runoff water equivalent (surface plus subsurface) (kg m-2)", + "aod550": "total aerosol optical depth at 550 nm (dimensionless)", + "duaod550": "dust aerosol optical depth at 550 nm (dimensionless)", + "omaod550": "organic matter aerosol optical depth at 550 nm (dimensionless)", + "bcaod550": "black carbon aerosol optical depth at 550 nm (dimensionless)", + "ssaod550": "sea salt aerosol optical depth at 550 nm (dimensionless)", + "suaod550": "sulphate aerosol optical depth at 550 nm (dimensionless)", + "tcco": "total column carbon monoxide (kg m-2)", + "tcno2": "total column nitrogen dioxide (kg m-2)", + "tco3": "total column ozone (kg m-2)", + "tcso2": "total column sulphur dioxide (kg m-2)", } diff --git a/earth2studio/lexicon/cams.py b/earth2studio/lexicon/cams.py index 770d1252c..bd7455d28 100644 --- a/earth2studio/lexicon/cams.py +++ b/earth2studio/lexicon/cams.py @@ -23,8 +23,8 @@ _GLOBAL = "cams-global-atmospheric-composition-forecasts" -class CAMSLexicon(metaclass=LexiconType): - """Copernicus Atmosphere 
Monitoring Service Lexicon +class CAMSGlobalLexicon(metaclass=LexiconType): + """Copernicus Atmosphere Monitoring Service Global Forecast Lexicon CAMS specified ``::::::`` @@ -38,17 +38,87 @@ class CAMSLexicon(metaclass=LexiconType): """ VOCAB = { + # Surface meteorological variables + "u10m": f"{_GLOBAL}::10m_u_component_of_wind::u10::", + "v10m": f"{_GLOBAL}::10m_v_component_of_wind::v10::", + "t2m": f"{_GLOBAL}::2m_temperature::t2m::", + "d2m": f"{_GLOBAL}::2m_dewpoint_temperature::d2m::", + "sp": f"{_GLOBAL}::surface_pressure::sp::", + "msl": f"{_GLOBAL}::mean_sea_level_pressure::msl::", + "tcwv": f"{_GLOBAL}::total_column_water_vapour::tcwv::", + "tp": f"{_GLOBAL}::total_precipitation::tp::", + "skt": f"{_GLOBAL}::skin_temperature::skt::", + "tcc": f"{_GLOBAL}::total_cloud_cover::tcc::", + "z": f"{_GLOBAL}::surface_geopotential::z::", + "lsm": f"{_GLOBAL}::land_sea_mask::lsm::", + # Aerosol optical depth variables (550nm) "aod550": f"{_GLOBAL}::total_aerosol_optical_depth_550nm::aod550::", "duaod550": f"{_GLOBAL}::dust_aerosol_optical_depth_550nm::duaod550::", "omaod550": f"{_GLOBAL}::organic_matter_aerosol_optical_depth_550nm::omaod550::", "bcaod550": f"{_GLOBAL}::black_carbon_aerosol_optical_depth_550nm::bcaod550::", "ssaod550": f"{_GLOBAL}::sea_salt_aerosol_optical_depth_550nm::ssaod550::", "suaod550": f"{_GLOBAL}::sulphate_aerosol_optical_depth_550nm::suaod550::", + # Total column trace gases "tcco": f"{_GLOBAL}::total_column_carbon_monoxide::tcco::", "tcno2": f"{_GLOBAL}::total_column_nitrogen_dioxide::tcno2::", - "tco3": f"{_GLOBAL}::total_column_ozone::tco3::", + "tco3": f"{_GLOBAL}::total_column_ozone::gtco3::", "tcso2": f"{_GLOBAL}::total_column_sulphur_dioxide::tcso2::", - "gtco3": f"{_GLOBAL}::gems_total_column_ozone::gtco3::", + # Pressure-level variables (multi-level, available every 3h lead time) + # u-component of wind + "u200": f"{_GLOBAL}::u_component_of_wind::u::200", + "u250": f"{_GLOBAL}::u_component_of_wind::u::250", + "u300": 
f"{_GLOBAL}::u_component_of_wind::u::300", + "u400": f"{_GLOBAL}::u_component_of_wind::u::400", + "u500": f"{_GLOBAL}::u_component_of_wind::u::500", + "u600": f"{_GLOBAL}::u_component_of_wind::u::600", + "u700": f"{_GLOBAL}::u_component_of_wind::u::700", + "u850": f"{_GLOBAL}::u_component_of_wind::u::850", + "u925": f"{_GLOBAL}::u_component_of_wind::u::925", + "u1000": f"{_GLOBAL}::u_component_of_wind::u::1000", + # v-component of wind + "v200": f"{_GLOBAL}::v_component_of_wind::v::200", + "v250": f"{_GLOBAL}::v_component_of_wind::v::250", + "v300": f"{_GLOBAL}::v_component_of_wind::v::300", + "v400": f"{_GLOBAL}::v_component_of_wind::v::400", + "v500": f"{_GLOBAL}::v_component_of_wind::v::500", + "v600": f"{_GLOBAL}::v_component_of_wind::v::600", + "v700": f"{_GLOBAL}::v_component_of_wind::v::700", + "v850": f"{_GLOBAL}::v_component_of_wind::v::850", + "v925": f"{_GLOBAL}::v_component_of_wind::v::925", + "v1000": f"{_GLOBAL}::v_component_of_wind::v::1000", + # Temperature + "t200": f"{_GLOBAL}::temperature::t::200", + "t250": f"{_GLOBAL}::temperature::t::250", + "t300": f"{_GLOBAL}::temperature::t::300", + "t400": f"{_GLOBAL}::temperature::t::400", + "t500": f"{_GLOBAL}::temperature::t::500", + "t600": f"{_GLOBAL}::temperature::t::600", + "t700": f"{_GLOBAL}::temperature::t::700", + "t850": f"{_GLOBAL}::temperature::t::850", + "t925": f"{_GLOBAL}::temperature::t::925", + "t1000": f"{_GLOBAL}::temperature::t::1000", + # Geopotential + "z200": f"{_GLOBAL}::geopotential::z::200", + "z250": f"{_GLOBAL}::geopotential::z::250", + "z300": f"{_GLOBAL}::geopotential::z::300", + "z400": f"{_GLOBAL}::geopotential::z::400", + "z500": f"{_GLOBAL}::geopotential::z::500", + "z600": f"{_GLOBAL}::geopotential::z::600", + "z700": f"{_GLOBAL}::geopotential::z::700", + "z850": f"{_GLOBAL}::geopotential::z::850", + "z925": f"{_GLOBAL}::geopotential::z::925", + "z1000": f"{_GLOBAL}::geopotential::z::1000", + # Specific humidity + "q200": f"{_GLOBAL}::specific_humidity::q::200", + 
"q250": f"{_GLOBAL}::specific_humidity::q::250", + "q300": f"{_GLOBAL}::specific_humidity::q::300", + "q400": f"{_GLOBAL}::specific_humidity::q::400", + "q500": f"{_GLOBAL}::specific_humidity::q::500", + "q600": f"{_GLOBAL}::specific_humidity::q::600", + "q700": f"{_GLOBAL}::specific_humidity::q::700", + "q850": f"{_GLOBAL}::specific_humidity::q::850", + "q925": f"{_GLOBAL}::specific_humidity::q::925", + "q1000": f"{_GLOBAL}::specific_humidity::q::1000", } @classmethod diff --git a/test/data/test_cams.py b/test/data/test_cams.py index 41590e6a2..a9efce556 100644 --- a/test/data/test_cams.py +++ b/test/data/test_cams.py @@ -15,12 +15,13 @@ # limitations under the License. import datetime -import hashlib import pathlib import shutil +from unittest.mock import MagicMock, patch import numpy as np import pytest +import xarray as xr from earth2studio.data import CAMS_FX @@ -99,23 +100,9 @@ def test_cams_fx_time_validation(): def test_cams_fx_available(): assert CAMS_FX.available(datetime.datetime(2024, 1, 1)) assert not CAMS_FX.available(datetime.datetime(2010, 1, 1)) - # timezone-aware datetimes must not raise TypeError assert CAMS_FX.available(datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC)) -def test_cams_fx_cache_key_lead_hours_order(): - """lead_hours in different order must produce the same cache key.""" - - def _make_cache_key(lead_hours): - return hashlib.sha256( - f"cams_fx_aod550_{'_'.join(sorted(lead_hours, key=int))}" - f"_2024-06-01_00".encode() - ).hexdigest() - - assert _make_cache_key(["0", "24", "48"]) == _make_cache_key(["48", "0", "24"]) - assert _make_cache_key(["12", "6"]) == _make_cache_key(["6", "12"]) - - def test_cams_fx_api_vars_dedup(): """Variables sharing the same API name must not produce duplicate requests.""" from earth2studio.data.cams import _resolve_variable @@ -124,3 +111,110 @@ def test_cams_fx_api_vars_dedup(): info_b = _resolve_variable("aod550", 1) api_vars = list(dict.fromkeys([info_a.api_name, info_b.api_name])) assert 
len(api_vars) == 1 + + +def test_cams_fx_call_mock(tmp_path: pathlib.Path): + """Test CAMS_FX __call__ with surface and pressure-level variables (mocked).""" + lat = CAMS_FX.CAMS_LAT + lon = CAMS_FX.CAMS_LON + forecast_period = np.array([0, 3, 6], dtype=np.float64) + pressure_level = np.array([500.0, 850.0]) + + # Surface NetCDF + mock_surface_ds = xr.Dataset( + { + "aod550": ( + ["forecast_period", "latitude", "longitude"], + np.random.rand(len(forecast_period), len(lat), len(lon)), + ), + "tcco": ( + ["forecast_period", "latitude", "longitude"], + np.random.rand(len(forecast_period), len(lat), len(lon)), + ), + }, + coords={ + "forecast_period": forecast_period, + "latitude": lat, + "longitude": lon, + }, + ) + surface_path = tmp_path / "mock_surface.nc" + mock_surface_ds.to_netcdf(surface_path) + + # Pressure-level NetCDF + mock_pressure_ds = xr.Dataset( + { + "u": ( + ["forecast_period", "pressure_level", "latitude", "longitude"], + np.random.rand( + len(forecast_period), len(pressure_level), len(lat), len(lon) + ), + ), + "t": ( + ["forecast_period", "pressure_level", "latitude", "longitude"], + np.random.rand( + len(forecast_period), len(pressure_level), len(lat), len(lon) + ), + ), + }, + coords={ + "forecast_period": forecast_period, + "pressure_level": pressure_level, + "latitude": lat, + "longitude": lon, + }, + ) + pressure_path = tmp_path / "mock_pressure.nc" + mock_pressure_ds.to_netcdf(pressure_path) + + with patch("earth2studio.data.cams.cdsapi") as mock_cdsapi: + mock_cdsapi.Client = MagicMock() + time = datetime.datetime(2024, 6, 1, 0, 0) + lead_time = [datetime.timedelta(hours=0), datetime.timedelta(hours=6)] + + # --- surface-only fetch: single download --- + with patch("earth2studio.data.cams._download_cams_netcdf") as mock_dl: + mock_dl.return_value = surface_path + data = CAMS_FX(cache=False)(time, lead_time, ["aod550", "tcco"]) + + assert data.shape == (1, 2, 2, len(lat), len(lon)) + assert list(data.coords["variable"].values) == ["aod550", 
"tcco"] + mock_dl.assert_called_once() + + # --- pressure-level-only fetch: single download --- + with patch("earth2studio.data.cams._download_cams_netcdf") as mock_dl: + mock_dl.return_value = pressure_path + data = CAMS_FX(cache=False)(time, lead_time, ["u500", "t850"]) + + assert data.shape == (1, 2, 2, len(lat), len(lon)) + assert list(data.coords["variable"].values) == ["u500", "t850"] + mock_dl.assert_called_once() + + # --- mixed fetch: two separate downloads --- + call_count = 0 + + def _side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + return surface_path if call_count == 1 else pressure_path + + with patch("earth2studio.data.cams._download_cams_netcdf") as mock_dl: + mock_dl.side_effect = _side_effect + data = CAMS_FX(cache=False)(time, lead_time, ["aod550", "u500"]) + + assert data.shape == (1, 2, 2, len(lat), len(lon)) + assert list(data.coords["variable"].values) == ["aod550", "u500"] + assert mock_dl.call_count == 2 + + +def test_cams_fx_pressure_level_leadtime_validation(): + """Pressure-level vars at non-3h lead times must raise ValueError.""" + with patch("earth2studio.data.cams.cdsapi") as mock_cdsapi: + mock_cdsapi.Client = MagicMock() + ds = CAMS_FX(cache=False) + with pytest.raises(ValueError, match="multiple of 3 hours"): + ds( + datetime.datetime(2024, 6, 1, 0, 0), + datetime.timedelta(hours=1), + "u500", + ) diff --git a/test/lexicon/test_cams_lexicon.py b/test/lexicon/test_cams_lexicon.py index d5e6b6a30..8431f80e5 100644 --- a/test/lexicon/test_cams_lexicon.py +++ b/test/lexicon/test_cams_lexicon.py @@ -17,23 +17,32 @@ import numpy as np import pytest -from earth2studio.lexicon import CAMSLexicon +from earth2studio.lexicon import CAMSGlobalLexicon @pytest.mark.parametrize( "variable", [ + # Surface variables + ["u10m", "v10m", "t2m"], + ["d2m", "sp", "msl"], + ["tcwv", "tp", "tcc"], + # Aerosol optical depth ["aod550"], ["duaod550", "tcno2"], ["bcaod550", "ssaod550", "suaod550"], + # Trace gases ["tcco", "tco3", 
"tcso2"], - ["gtco3", "omaod550"], + ["omaod550"], + # Pressure level variables + ["u500", "v500", "t500"], + ["z850", "q700"], ], ) def test_cams_lexicon(variable): data = np.random.randn(len(variable), 8) for v in variable: - label, modifier = CAMSLexicon[v] + label, modifier = CAMSGlobalLexicon[v] output = modifier(data) assert isinstance(label, str) assert data.shape == output.shape @@ -41,34 +50,69 @@ def test_cams_lexicon(variable): def test_cams_lexicon_invalid(): with pytest.raises(KeyError): - CAMSLexicon["nonexistent_variable"] + CAMSGlobalLexicon["nonexistent_variable"] def test_cams_lexicon_vocab_format(): - for key, value in CAMSLexicon.VOCAB.items(): + for key, value in CAMSGlobalLexicon.VOCAB.items(): parts = value.split("::") - assert len(parts) == 4, ( - f"VOCAB entry '{key}' must have format 'dataset::api_var::nc_key::level'" - ) - assert parts[0] == "cams-global-atmospheric-composition-forecasts", ( - f"Expected global dataset in VOCAB entry '{key}', got '{parts[0]}'" - ) + assert ( + len(parts) == 4 + ), f"VOCAB entry '{key}' must have format 'dataset::api_var::nc_key::level'" + assert ( + parts[0] == "cams-global-atmospheric-composition-forecasts" + ), f"Expected global dataset in VOCAB entry '{key}', got '{parts[0]}'" -def test_cams_lexicon_all_global_vars(): - expected = [ +def test_cams_lexicon_surface_vars(): + """Test that all expected surface variables are present.""" + expected_surface = [ + "u10m", + "v10m", + "t2m", + "d2m", + "sp", + "msl", + "tcwv", + "tp", + "skt", + "z", + "lsm", + "tcc", + ] + for var in expected_surface: + assert var in CAMSGlobalLexicon.VOCAB, f"Missing surface variable: {var}" + + +def test_cams_lexicon_aod_vars(): + """Test that all aerosol optical depth variables are present.""" + expected_aod = [ "aod550", "duaod550", "omaod550", "bcaod550", "ssaod550", "suaod550", - "tcco", - "tcno2", - "tco3", - "tcso2", - "gtco3", ] - for var in expected: - assert var in CAMSLexicon.VOCAB, f"Missing global variable: 
{var}" - assert len(CAMSLexicon.VOCAB) == len(expected) + for var in expected_aod: + assert var in CAMSGlobalLexicon.VOCAB, f"Missing AOD variable: {var}" + + +def test_cams_lexicon_trace_gas_vars(): + """Test that all trace gas variables are present.""" + expected_gases = ["tcco", "tcno2", "tco3", "tcso2"] + for var in expected_gases: + assert var in CAMSGlobalLexicon.VOCAB, f"Missing trace gas variable: {var}" + + +def test_cams_lexicon_pressure_level_vars(): + """Test that pressure level variables are present for common levels.""" + pressure_levels = [200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] + variables = ["u", "v", "t", "z", "q"] + + for var in variables: + for level in pressure_levels: + key = f"{var}{level}" + assert ( + key in CAMSGlobalLexicon.VOCAB + ), f"Missing pressure level variable: {key}" From 3ae66eb914df0d6d59dfcf1a9193da26372e503b Mon Sep 17 00:00:00 2001 From: Nicholas Geneva Date: Thu, 2 Apr 2026 02:57:36 +0000 Subject: [PATCH 7/8] Test Fixes --- test/data/test_cams.py | 9 +++++++++ test/data/test_cds.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/test/data/test_cams.py b/test/data/test_cams.py index a9efce556..1aa11fed1 100644 --- a/test/data/test_cams.py +++ b/test/data/test_cams.py @@ -25,6 +25,15 @@ from earth2studio.data import CAMS_FX +CAMS_ADS_URL = "https://ads.atmosphere.copernicus.eu/api" + + +@pytest.fixture(autouse=True) +def _set_cdsapi_url(monkeypatch): + """Point cdsapi at the ADS endpoint for all tests in this module.""" + monkeypatch.setenv("CDSAPI_URL", CAMS_ADS_URL) + + YESTERDAY = datetime.datetime.now(datetime.UTC).replace( hour=0, minute=0, second=0, microsecond=0 ) - datetime.timedelta(days=1) diff --git a/test/data/test_cds.py b/test/data/test_cds.py index 150e5c6df..9f8718be8 100644 --- a/test/data/test_cds.py +++ b/test/data/test_cds.py @@ -23,6 +23,14 @@ from earth2studio.data import CDS +CDS_API_URL = "https://cds.climate.copernicus.eu/api" + + +@pytest.fixture(autouse=True) +def 
_set_cdsapi_url(monkeypatch): + """Point cdsapi at the CDS endpoint for all tests in this module.""" + monkeypatch.setenv("CDSAPI_URL", CDS_API_URL) + @pytest.mark.slow @pytest.mark.xfail From b0c5f8f519be334e9ce4a57c30d2ca5c859afc75 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva Date: Thu, 2 Apr 2026 03:00:17 +0000 Subject: [PATCH 8/8] Test Fixes --- docs/modules/datasources_analysis.rst | 2 +- docs/modules/datasources_forecast.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/modules/datasources_analysis.rst b/docs/modules/datasources_analysis.rst index 03e1c8c25..0f541604b 100644 --- a/docs/modules/datasources_analysis.rst +++ b/docs/modules/datasources_analysis.rst @@ -21,7 +21,7 @@ Used for fetching initial conditions for inference and validation data for scori .. badge-filter:: region:global region:na region:as region:europe dataclass:analysis dataclass:reanalysis dataclass:observation dataclass:simulation - product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu product:airquality + product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu :filter-mode: or :badge-order-fixed: :group-visibility-toggle: diff --git a/docs/modules/datasources_forecast.rst b/docs/modules/datasources_forecast.rst index 8570a7f97..416ea980d 100644 --- a/docs/modules/datasources_forecast.rst +++ b/docs/modules/datasources_forecast.rst @@ -9,9 +9,9 @@ Typically used in intercomparison workflows. .. currentmodule:: earth2studio -.. badge-filter:: region:global region:na region:as region:europe +.. 
badge-filter:: region:global region:na region:as dataclass:analysis dataclass:reanalysis dataclass:observation dataclass:simulation - product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu product:airquality + product:wind product:precip product:temp product:atmos product:ocean product:land product:veg product:solar product:radar product:sat product:insitu :filter-mode: or :badge-order-fixed: :group-visibility-toggle: