diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d0a8af..11e9bdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.3.1] + +### Changed +- **Breaking**: `SARIXRunConfig` and `GBQRRunConfig` removed; use `RunConfig` directly (was abstract but is now directly instantiable) +- **Breaking**: `num_warmup`, `num_samples`, `num_chains` moved from `SARIXRunConfig` to `SARIXModelConfig` +- **Breaking**: `save_feat_importance` moved from `GBQRRunConfig` to `GBQRModelConfig` + +## [1.3.0] + +### Added +- Concrete configuration dataclasses: `ModelConfig`, `RunConfig` (abstract bases), `SARIXModelConfig`, `SARIXRunConfig`, `GBQRModelConfig`, `GBQRRunConfig` +- `SARIXFourierModelConfig` dataclass with `fourier_K` and `fourier_pooling` fields +- Wave feature fields on `GBQRModelConfig` (`use_directional_waves`, `wave_directions`, etc.), disabled by default +- Enum types: `DataSource`, `Disease`, `PowerTransform`, `PoolingStrategy` +- Docstrings for `ModelConfig` and `RunConfig` base classes +- All config types exported from `idmodels.__init__` + +### Changed +- **Breaking**: `model_config.sources` now expects `list[DataSource]` instead of `list[str]` +- **Breaking**: `model_config.power_transform` now expects `PowerTransform` instead of `str` +- **Breaking**: `model_config.disease` now expects `Disease` instead of `str` +- Source validation in `sarix.py` and `gbqr.py` uses `DataSource` enums and set operations instead of `np.isin` with string arrays +- All tests use concrete config dataclasses instead of `SimpleNamespace` +- Updated `directional_wave_features.md` examples to use config dataclasses + +### Removed +- `SimpleNamespace` usage throughout tests and documentation +- `model_class` field from model configurations (implied by the dataclass type) +- `num_bags` from `GBQRRunConfig` test helper (it is a `GBQRModelConfig` field) +- `save_feat_importance` from 
`SARIXRunConfig` test helper (not a SARIX field) + ## [1.1.0] - 2025-12-08 ### Added @@ -71,7 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Updated to latest iddata API -[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.1.0...HEAD +[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.3.1...HEAD +[1.3.1]: https://github.com/reichlab/idmodels/compare/v1.3.0...v1.3.1 +[1.3.0]: https://github.com/reichlab/idmodels/compare/v1.1.0...v1.3.0 [1.1.0]: https://github.com/reichlab/idmodels/compare/v1.0.0...v1.1.0 [1.0.0]: https://github.com/reichlab/idmodels/compare/v0.1.0...v1.0.0 [0.1.0]: https://github.com/reichlab/idmodels/compare/v0.0.1...v0.1.0 diff --git a/docs/directional_wave_features.md b/docs/directional_wave_features.md index c565bfd..5f6258b 100644 --- a/docs/directional_wave_features.md +++ b/docs/directional_wave_features.md @@ -43,10 +43,13 @@ For each location and time point, the following features are generated: Directional wave features are **disabled by default** for backwards compatibility. To enable them, add the following parameters to your `model_config`: ```python -from types import SimpleNamespace +from idmodels.config import DataSource, GBQRModelConfig, PowerTransform -model_config = SimpleNamespace( - # ... existing parameters ... +model_config = GBQRModelConfig( + model_name = "gbqr_with_waves", + sources = [DataSource.NHSN], + fit_locations_separately = False, + power_transform = PowerTransform.FOURTH_ROOT, # Directional wave features (disabled by default) use_directional_waves = True, # Set to True to enable @@ -111,8 +114,8 @@ model_config = SimpleNamespace( ### Minimal Configuration (4 cardinal directions) ```python -model_config = SimpleNamespace( - # ... other params ... +model_config = GBQRModelConfig( + # ... required base params ...
use_directional_waves = True, wave_directions = ['N', 'S', 'E', 'W'] ) @@ -121,8 +124,8 @@ Generates: 4 base + 4 aggregate + (4+1)×2 lags = **14 features** ### Standard Configuration (8 directions) ```python -model_config = SimpleNamespace( - # ... other params ... +model_config = GBQRModelConfig( + # ... required base params ... use_directional_waves = True, wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], wave_temporal_lags = [1, 2] @@ -132,8 +135,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features** ### Maximum Information (all options) ```python -model_config = SimpleNamespace( - # ... other params ... +model_config = GBQRModelConfig( + # ... required base params ... use_directional_waves = True, wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], wave_temporal_lags = [1, 2], @@ -147,8 +150,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features ### Hypothesis-Driven (specific directions) ```python # If you suspect disease spreads along NE-SW axis -model_config = SimpleNamespace( - # ... other params ... +model_config = GBQRModelConfig( + # ... required base params ... 
use_directional_waves = True, wave_directions = ['NE', 'SW'], wave_temporal_lags = [1, 2, 3], # Longer lags for slower spread @@ -240,22 +243,22 @@ The implementation includes validation that warns about: ## Example: Complete GBQR Configuration ```python -from types import SimpleNamespace +import datetime +from pathlib import Path +from idmodels.config import DataSource, Disease, GBQRModelConfig, PowerTransform, RunConfig from idmodels.gbqr import GBQRModel # Model configuration with directional wave features -model_config = SimpleNamespace( - model_class = "gbqr", +model_config = GBQRModelConfig( model_name = "gbqr_with_waves", - - # Standard GBQR parameters + sources = [DataSource.NHSN], + fit_locations_separately = False, + power_transform = PowerTransform.FOURTH_ROOT, incl_level_feats = True, num_bags = 10, bag_frac_samples = 0.7, reporting_adj = False, - sources = ["nhsn"], - fit_locations_separately = False, - power_transform = "4rt", + save_feat_importance = True, # Directional wave features use_directional_waves = True, @@ -267,16 +270,16 @@ model_config = SimpleNamespace( ) # Run configuration -run_config = SimpleNamespace( - disease = "flu", +run_config = RunConfig( + disease = Disease.FLU, ref_date = datetime.date(2024, 1, 6), - output_root = "output/", - artifact_store_root = "artifacts/", - save_feat_importance = True, - locations = None, # All locations + output_root = Path("output/"), + artifact_store_root = Path("artifacts/"), max_horizon = 4, + states = ["US", "01", "06", "13", "36", "48"], + hsas = [], q_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975], - q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"] + q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"], ) # Run model diff --git a/pyproject.toml b/pyproject.toml index 2280971..5e4ec54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,8 @@ dependencies = [ "iddata @ git+https://github.com/reichlab/iddata", "lightgbm", "numpy", - "pandas", - 
"sarix @ git+https://github.com/reichlab/sarix", + "pandas~=2.0", # pandas 3.0 breaks compatibility; remove cap once validated + "sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f", "scikit-learn", "tqdm", "timeseriesutils @ git+https://github.com/reichlab/timeseriesutils" diff --git a/src/idmodels/__init__.py b/src/idmodels/__init__.py index c1afe52..3fb0078 100644 --- a/src/idmodels/__init__.py +++ b/src/idmodels/__init__.py @@ -1,2 +1,17 @@ -__version__ = "1.2.0" +from idmodels.config import ( + DataSource, + Disease, + GBQRModelConfig, + PoolingStrategy, + PowerTransform, + RunConfig, + SARIXFourierModelConfig, + SARIXModelConfig, +) +from idmodels.gbqr import GBQRModel +from idmodels.sarix import SARIXFourierModel, SARIXModel +__all__ = ["DataSource", "Disease", "GBQRModel", "GBQRModelConfig", "PoolingStrategy", "PowerTransform", "RunConfig", + "SARIXFourierModel", "SARIXFourierModelConfig", "SARIXModel", "SARIXModelConfig"] + +__version__ = "1.3.1" diff --git a/src/idmodels/config.py b/src/idmodels/config.py new file mode 100644 index 0000000..f4574b6 --- /dev/null +++ b/src/idmodels/config.py @@ -0,0 +1,104 @@ +import datetime +from abc import ABC +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path + + +class DataSource(str, Enum): + NHSN = "nhsn" + NSSP = "nssp" + FLUSURVNET = "flusurvnet" + ILINET = "ilinet" + + +class Disease(str, Enum): + FLU = "flu" + COVID = "covid" + RSV = "rsv" + + +class PowerTransform(str, Enum): + FOURTH_ROOT = "4rt" + NONE = "none" + + +class PoolingStrategy(str, Enum): + NONE = "none" + SHARED = "shared" + + +@dataclass +class ModelConfig(ABC): + """ + Abstract base for model configuration. + + Holds settings that describe *what* model to run and how it processes data (sources, transforms, pooling). + Not instantiated directly - use :class:`SARIXModelConfig` or :class:`GBQRModelConfig`. 
+ """ + + model_name: str + sources: list[DataSource] + fit_locations_separately: bool + power_transform: PowerTransform + + def __post_init__(self): + if type(self) is ModelConfig: + raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig") + + +@dataclass +class RunConfig: + """ + Run configuration. + + Holds settings that describe a single execution: which disease, which locations, output paths, quantile levels, etc. + """ + + disease: Disease + ref_date: datetime.date + output_root: Path + artifact_store_root: Path | None + max_horizon: int + states: list[str] + hsas: list[str] + q_levels: list[float] + q_labels: list[str] + + +@dataclass +class SARIXModelConfig(ModelConfig): + p: int = 0 + P: int = 0 + d: int = 0 + D: int = 0 + season_period: int = 1 + theta_pooling: PoolingStrategy = PoolingStrategy.NONE + sigma_pooling: PoolingStrategy = PoolingStrategy.NONE + x: list = field(default_factory=list) + num_warmup: int = 2000 + num_samples: int = 2000 + num_chains: int = 1 + + +@dataclass +class SARIXFourierModelConfig(SARIXModelConfig): + fourier_K: int = 1 + fourier_pooling: PoolingStrategy = PoolingStrategy.NONE + + +@dataclass +class GBQRModelConfig(ModelConfig): + incl_level_feats: bool = True + num_bags: int = 100 + bag_frac_samples: float = 0.7 + reporting_adj: bool = False + save_feat_importance: bool = False + + # directional wave features (disabled by default) + use_directional_waves: bool = False + wave_directions: list[str] = field(default_factory=lambda: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) + wave_temporal_lags: list[int] = field(default_factory=lambda: [1, 2]) + wave_max_distance_km: float = 1000.0 + wave_include_velocity: bool = False + wave_include_aggregate: bool = True diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index 92f9158..374addb 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -6,6 +6,7 @@ from iddata.loader import DiseaseDataLoader from tqdm.autonotebook import tqdm +from 
idmodels.config import DataSource, Disease, PowerTransform from idmodels.preprocess import create_directional_wave_features, create_features_and_targets from idmodels.utils import build_save_path @@ -31,22 +32,22 @@ def run(self, run_config): ilinet_kwargs = {"scale_to_positive": False} flusurvnet_kwargs = {"burden_adj": False} - valid_sources = ["flusurvnet", "nhsn", "ilinet", "nssp"] - if not np.isin(np.array(self.model_config.sources), valid_sources).all(): + valid_sources = {DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET, DataSource.NSSP} + if not set(self.model_config.sources) <= valid_sources: raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.") - + # Check if both nhsn and nssp data are included as sources - if all(src in self.model_config.sources for src in ["nhsn", "nssp"]): + if (DataSource.NHSN in self.model_config.sources) and (DataSource.NSSP in self.model_config.sources): raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.") - + fdl = DiseaseDataLoader() - if "nhsn" in self.model_config.sources: + if DataSource.NHSN in self.model_config.sources: df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, ilinet_kwargs=ilinet_kwargs, flusurvnet_kwargs=flusurvnet_kwargs, sources=self.model_config.sources, power_transform=self.model_config.power_transform) - elif "nssp" in self.model_config.sources: + elif DataSource.NSSP in self.model_config.sources: df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, ilinet_kwargs=ilinet_kwargs, flusurvnet_kwargs=flusurvnet_kwargs, @@ -62,20 +63,20 @@ def run(self, run_config): df["unique_id"] = df["agg_level"] + df["location"] # augment data with features and target values - if run_config.disease == "flu": + if (run_config.disease == Disease.FLU) or (run_config.disease == Disease.RSV): init_feats = ["inc_trans_cs", "season_week", "log_pop"] - elif 
run_config.disease == "covid": + elif run_config.disease == Disease.COVID: init_feats = ["inc_trans_cs", "log_pop"] # Create directional wave features if enabled - if hasattr(self.model_config, "use_directional_waves") and self.model_config.use_directional_waves: + if self.model_config.use_directional_waves: wave_config = { "enabled": True, - "directions": getattr(self.model_config, "wave_directions", ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]), - "temporal_lags": getattr(self.model_config, "wave_temporal_lags", [1, 2]), - "max_distance_km": getattr(self.model_config, "wave_max_distance_km", 1000), - "include_velocity": getattr(self.model_config, "wave_include_velocity", False), - "include_aggregate": getattr(self.model_config, "wave_include_aggregate", True) + "directions": self.model_config.wave_directions, + "temporal_lags": self.model_config.wave_temporal_lags, + "max_distance_km": self.model_config.wave_max_distance_km, + "include_velocity": self.model_config.wave_include_velocity, + "include_aggregate": self.model_config.wave_include_aggregate, } df, wave_feat_names = create_directional_wave_features(df, wave_config) init_feats = init_feats + wave_feat_names @@ -87,7 +88,7 @@ def run(self, run_config): curr_feat_names=init_feats) # keep only rows that are in-season - if run_config.disease == "flu": + if (run_config.disease == Disease.FLU) or (run_config.disease == Disease.RSV): df = df.query("season_week >= 5 and season_week <= 45") # "test set" df used to generate look-ahead predictions @@ -176,12 +177,12 @@ def _train_gbq_and_predict(self, run_config, # build data frame with predictions on the original scale preds_df["inc_trans_cs_target_hat"] = preds_df["inc_trans_cs"] + preds_df["delta_hat"] preds_df["inc_trans_target_hat"] = (preds_df["inc_trans_cs_target_hat"] + preds_df["inc_trans_center_factor"]) * (preds_df["inc_trans_scale_factor"] + 0.01) - if self.model_config.power_transform == "4rt": + if self.model_config.power_transform == 
PowerTransform.FOURTH_ROOT: inv_power = 4 - elif self.model_config.power_transform is None: + elif self.model_config.power_transform == PowerTransform.NONE: inv_power = 1 else: - raise ValueError('unsupported power_transform: must be "4rt" or None') + raise ValueError(f"unsupported power_transform: {self.model_config.power_transform!r}") preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4) @@ -276,7 +277,7 @@ def _get_test_quantile_predictions(self, run_config, test_preds_by_bag[:, b, q_ind] = model.predict(X=x_test) # combine and save feature importance scores - if run_config.save_feat_importance: + if self.model_config.save_feat_importance: feat_importance = pd.concat(feat_importance, axis=0) save_path = build_save_path( root=run_config.artifact_store_root, diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index b89ba2a..2dcda12 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -4,6 +4,7 @@ from iddata.utils import get_holidays from sarix import sarix +from idmodels.config import DataSource, PowerTransform, SARIXFourierModelConfig from idmodels.utils import build_save_path @@ -16,21 +17,21 @@ def _get_extra_sarix_params(self, df): return {} def run(self, run_config): - valid_sources = np.array(["nhsn", "nssp"]) - if not np.isin(np.array(self.model_config.sources), valid_sources).all(): + valid_sources = {DataSource.NHSN, DataSource.NSSP} + if not set(self.model_config.sources) <= valid_sources: raise ValueError("For SARIX, the only supported data sources are 'nhsn' or 'nssp'.") - + # Check if both nhsn and nssp data are included as sources - if all(src in self.model_config.sources for src in ["nhsn", "nssp"]): + if (DataSource.NHSN in self.model_config.sources) and (DataSource.NSSP in self.model_config.sources): raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.") fdl = DiseaseDataLoader() - if "nhsn" in self.model_config.sources: + if DataSource.NHSN in 
self.model_config.sources: df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, sources=self.model_config.sources, power_transform=self.model_config.power_transform) target_name = "wk inc " + run_config.disease + " hosp" - elif "nssp" in self.model_config.sources: + elif DataSource.NSSP in self.model_config.sources: df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease}, sources=self.model_config.sources, power_transform=self.model_config.power_transform) @@ -57,6 +58,7 @@ def run(self, run_config): # missing values are interpolated when possible xy_colnames = self.model_config.x + ["inc_trans_cs"] + df = df.query("wk_end_date >= '2022-10-01'").interpolate() batched_xy = df[xy_colnames].values.reshape(len(df["unique_id"].unique()), -1, len(xy_colnames)) @@ -73,10 +75,10 @@ def run(self, run_config): transform="none", # transformations are handled outside of SARIX theta_pooling=self.model_config.theta_pooling, sigma_pooling=self.model_config.sigma_pooling, - forecast_horizon = run_config.max_horizon, - num_warmup = run_config.num_warmup, - num_samples = run_config.num_samples, - num_chains = run_config.num_chains, + forecast_horizon=run_config.max_horizon, + num_warmup=self.model_config.num_warmup, + num_samples=self.model_config.num_samples, + num_chains=self.model_config.num_chains, **extra_params ) @@ -98,7 +100,7 @@ def run(self, run_config): # build data frame with predictions on the original scale preds_df["value"] = (preds_df["value"] + preds_df["inc_trans_center_factor"]) * preds_df["inc_trans_scale_factor"] - if self.model_config.power_transform == "4rt": + if self.model_config.power_transform == PowerTransform.FOURTH_ROOT: preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 4 else: preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 2 @@ -152,6 +154,13 @@ class SARIXFourierModel(SARIXModel): - fourier_K: Number of Fourier harmonic pairs (int) - fourier_pooling: How 
to share Fourier coefficients across locations ('none' or 'shared') """ + def __init__(self, model_config): + if not isinstance(model_config, SARIXFourierModelConfig): + raise TypeError( + f"SARIXFourierModel requires a SARIXFourierModelConfig, got {type(model_config).__name__}" + ) + super().__init__(model_config) + def _get_extra_sarix_params(self, df): """Return Fourier-specific parameters for SARIX constructor.""" # Extract day-of-year from dates for Fourier features diff --git a/tests/integration/test_gbqr.py b/tests/integration/test_gbqr.py index 22ac8e9..9ce97c2 100644 --- a/tests/integration/test_gbqr.py +++ b/tests/integration/test_gbqr.py @@ -1,6 +1,5 @@ import datetime from pathlib import Path -from types import SimpleNamespace from unittest.mock import patch import lightgbm @@ -9,6 +8,7 @@ import pytest from pandas.testing import assert_frame_equal +from idmodels.config import DataSource, Disease, GBQRModelConfig, PowerTransform, RunConfig from idmodels.gbqr import GBQRModel @@ -20,7 +20,7 @@ def test_gbqr_nhsn(tmp_path): "33", "34", "35", "36", "37", "38", "39", "40", "41", "42", "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", "55", "56", "72"] - model_config = create_test_gbqr_model_config(sources = ["flusurvnet", "nhsn", "ilinet"]) + model_config = create_test_gbqr_model_config(sources = [DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET]) run_config = create_test_gbqr_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) # patch lgb.LGBMRegressor's `predict()` to return the same values to make the tests reproducible across OSs @@ -46,7 +46,7 @@ def test_gbqr_nhsn(tmp_path): ]) def test_gbqr_nssp(tmp_path, fips_codes, nci_ids): date = datetime.date.fromisoformat("2025-11-22") - model_config = create_test_gbqr_model_config(sources=["nssp"]) + model_config = create_test_gbqr_model_config(sources=[DataSource.NSSP]) run_config = create_test_gbqr_run_config(ref_date=date, states=fips_codes, hsas=nci_ids, 
tmp_path=tmp_path) # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs @@ -77,17 +77,16 @@ def test_gbqr_nssp(tmp_path, fips_codes, nci_ids): def create_test_gbqr_model_config(sources): - if "nhsn" in sources: - main_source = "nhsn" - elif "nssp" in sources: - main_source = "nssp" + if DataSource.NHSN in sources: + main_source = DataSource.NHSN + elif DataSource.NSSP in sources: + main_source = DataSource.NSSP else: main_source = None - - model_config = SimpleNamespace( - model_class = "gbqr", - model_name = "gbqr_" + main_source + "_no_reporting_adj", - + + model_config = GBQRModelConfig( + model_name = "gbqr_" + main_source.value + "_no_reporting_adj", + incl_level_feats = True, # bagging setup @@ -96,31 +95,29 @@ def create_test_gbqr_model_config(sources): # adjustments to reporting reporting_adj = False, - + # data sources and adjustments for reporting issues sources = sources, - + # fit locations separately or jointly fit_locations_separately = False, - + # power transform applied to surveillance signals - power_transform = "4rt", + power_transform = PowerTransform.FOURTH_ROOT, ) return model_config def create_test_gbqr_run_config(ref_date, states, hsas, tmp_path): - run_config = SimpleNamespace( - disease="flu", + run_config = RunConfig( + disease=Disease.FLU, ref_date=ref_date, output_root=tmp_path / "model-output", artifact_store_root=tmp_path / "artifact-store", - save_feat_importance=False, states=states, - hsas = hsas, + hsas=hsas, max_horizon=3, - q_levels = [0.025, 0.50, 0.975], - q_labels = ["0.025", "0.5", "0.975"], - num_bags = 10 + q_levels=[0.025, 0.50, 0.975], + q_labels=["0.025", "0.5", "0.975"], ) return run_config diff --git a/tests/integration/test_gbqr_wave_features.py b/tests/integration/test_gbqr_wave_features.py index c81419c..e85e8b9 100644 --- a/tests/integration/test_gbqr_wave_features.py +++ b/tests/integration/test_gbqr_wave_features.py @@ -1,11 +1,10 @@ """Integration test 
for GBQR model with directional wave features.""" -from types import SimpleNamespace - import numpy as np import pandas as pd +from idmodels.config import DataSource, GBQRModelConfig, PowerTransform from idmodels.preprocess import create_directional_wave_features, create_features_and_targets @@ -185,7 +184,11 @@ def test_gbqr_wave_features_with_model_config_pattern(): df = create_realistic_test_data() # Simulate model_config with wave feature settings - model_config = SimpleNamespace( + model_config = GBQRModelConfig( + model_name="gbqr_wave_test", + sources=[DataSource.NHSN], + fit_locations_separately=False, + power_transform=PowerTransform.FOURTH_ROOT, use_directional_waves=True, wave_directions=["N", "S", "E", "W"], wave_temporal_lags=[1, 2], @@ -229,9 +232,12 @@ def test_gbqr_wave_features_backwards_compatibility(): df = create_realistic_test_data() # Model config WITHOUT wave feature settings (backwards compatibility) - model_config = SimpleNamespace( - incl_level_feats=True, - # No wave feature attributes + model_config = GBQRModelConfig( + model_name="gbqr_no_waves", + sources=[DataSource.NHSN], + fit_locations_separately=False, + power_transform=PowerTransform.FOURTH_ROOT, + # use_directional_waves defaults to False ) init_feats = ["inc_trans_cs", "log_pop"] diff --git a/tests/integration/test_sarix.py b/tests/integration/test_sarix.py index bc716be..4d8f2c0 100644 --- a/tests/integration/test_sarix.py +++ b/tests/integration/test_sarix.py @@ -1,6 +1,5 @@ import datetime from pathlib import Path -from types import SimpleNamespace from unittest.mock import patch import numpy @@ -8,6 +7,15 @@ import pytest from pandas.testing import assert_frame_equal +from idmodels.config import ( + DataSource, + Disease, + PoolingStrategy, + PowerTransform, + RunConfig, + SARIXFourierModelConfig, + SARIXModelConfig, +) from idmodels.sarix import SARIXFourierModel, SARIXModel @@ -19,21 +27,22 @@ def test_sarix_nhsn(tmp_path): "33", "34", "35", "36", "37", "38", "39", "40", 
"41", "42", "44", "45", "46", "47", "48", "49", "50", "51", "53", "54", "55", "56", "72"] - model_config = create_test_sarix_model_config(main_source=["nhsn"], theta_pooling="shared", sigma_pooling="none") - run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], num=200, tmp_path=tmp_path) - + model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling=PoolingStrategy.SHARED, + sigma_pooling=PoolingStrategy.NONE, num=200) + run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) + # patch the `_np_percentile()` helper function return the same values to make the tests reproducible across OSs with patch("idmodels.sarix._np_percentile", return_value=_np_percentile_val()): model = SARIXModel(model_config) model.run(run_config) actual_df = pd.read_csv( - run_config.output_root / f"UMass-{model_config.model_name}" / + run_config.output_root / f"UMass-{model_config.model_name}" / f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" ) expected_df = pd.read_csv( Path("tests") / "integration" / "data" / - f"UMass-{model_config.model_name}" / + f"UMass-{model_config.model_name}" / f"{str(run_config.ref_date)}-UMass-{model_config.model_name}.csv" ) assert_frame_equal(actual_df, expected_df) @@ -46,8 +55,9 @@ def test_sarix_nhsn(tmp_path): ]) def test_sarix_nssp(tmp_path, fips_codes, nci_ids): date = datetime.date.fromisoformat("2025-11-22") - model_config = create_test_sarix_model_config(main_source=["nssp"], theta_pooling="shared", sigma_pooling="none") - run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=nci_ids, num=200, tmp_path=tmp_path) + model_config = create_test_sarix_model_config(main_source=[DataSource.NSSP], theta_pooling=PoolingStrategy.SHARED, + sigma_pooling=PoolingStrategy.NONE, num=200) + run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=nci_ids, tmp_path=tmp_path) # patch the 
`_np_percentile()` helper function return the same values to make the tests reproducible across OSs if (fips_codes != []) & (nci_ids == []): @@ -81,8 +91,9 @@ def test_sarix_shared_sigma_pooling_multiple_batches(tmp_path): # Use multiple locations to ensure we have multiple batches date = datetime.date.fromisoformat("2024-01-06") fips_codes = ["US", "01", "02", "04", "05"] # Multiple locs = multiple batches - model_config = create_test_sarix_model_config(main_source=["nhsn"], theta_pooling="none", sigma_pooling="shared") - run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], num=200, tmp_path=tmp_path) + model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling=PoolingStrategy.NONE, + sigma_pooling=PoolingStrategy.SHARED, num=200) + run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) model = SARIXModel(model_config) model.run(run_config) @@ -109,42 +120,28 @@ def test_sarix_shared_sigma_pooling_multiple_batches(tmp_path): def test_sarix_fourier_none_pooling(tmp_path): """Test SARIXFourierModel with fourier_pooling='none' (unpooled).""" - model_config = SimpleNamespace( - model_class="sarix_fourier", + model_config = SARIXFourierModelConfig( model_name="sarix_p2_fourier_K2_none", - - # data sources - sources=["nhsn"], - - # fit locations separately or jointly + sources=[DataSource.NHSN], fit_locations_separately=False, - - # SARIX model parameters p=2, P=0, d=0, D=0, season_period=1, - - # power transform - power_transform="4rt", - - # parameter pooling - theta_pooling="shared", - sigma_pooling="shared", - - # Fourier parameters + power_transform=PowerTransform.FOURTH_ROOT, + theta_pooling=PoolingStrategy.SHARED, + sigma_pooling=PoolingStrategy.SHARED, fourier_K=2, - fourier_pooling="none", # Unpooled Fourier coefficients - - # covariates - x=[] - ) + fourier_pooling=PoolingStrategy.NONE, + x=[], + num_warmup=50, + num_samples=50) date = 
datetime.date.fromisoformat("2024-01-06") fips_codes = ["US", "01", "02", "04", "05"] # fewer locs for faster testing - # model_config = create_test_sarix_model_config(main_source=["nhsn"], theta_pooling="shared", sigma_pooling="none") - run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], num=50, tmp_path=tmp_path) + # model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling="shared", sigma_pooling="none") + run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) model = SARIXFourierModel(model_config) model.run(run_config) @@ -171,42 +168,28 @@ def test_sarix_fourier_none_pooling(tmp_path): def test_sarix_fourier_shared_pooling(tmp_path): """Test SARIXFourierModel with fourier_pooling='shared' (pooled across locations).""" - model_config = SimpleNamespace( - model_class="sarix_fourier", + model_config = SARIXFourierModelConfig( model_name="sarix_p2_fourier_K2_shared", - - # data sources - sources=["nhsn"], - - # fit locations separately or jointly + sources=[DataSource.NHSN], fit_locations_separately=False, - - # SARIX model parameters p=2, P=0, d=0, D=0, season_period=1, - - # power transform - power_transform="4rt", - - # parameter pooling - theta_pooling="shared", - sigma_pooling="shared", - - # Fourier parameters + power_transform=PowerTransform.FOURTH_ROOT, + theta_pooling=PoolingStrategy.SHARED, + sigma_pooling=PoolingStrategy.SHARED, fourier_K=2, - fourier_pooling="shared", # Shared Fourier coefficients - - # covariates - x=[] - ) + fourier_pooling=PoolingStrategy.SHARED, + x=[], + num_warmup=50, + num_samples=50) date = datetime.date.fromisoformat("2024-01-06") fips_codes = ["US", "01", "02", "04", "05"] # fewer locs for faster testing - # model_config = create_test_sarix_model_config(main_source=["nhsn"], theta_pooling="shared", sigma_pooling="none") - run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], 
num=50, tmp_path=tmp_path) + # model_config = create_test_sarix_model_config(main_source=[DataSource.NHSN], theta_pooling="shared", sigma_pooling="none") + run_config = create_test_sarix_run_config(ref_date=date, states=fips_codes, hsas=[], tmp_path=tmp_path) model = SARIXFourierModel(model_config) model.run(run_config) @@ -231,60 +214,33 @@ def test_sarix_fourier_shared_pooling(tmp_path): "All predictions should be non-negative" -def test_sarix_fourier_missing_pooling_parameter(): - """Test that SARIXFourierModel raises error when fourier_pooling is missing.""" - model_config = SimpleNamespace( - model_class="sarix_fourier", - model_name="sarix_p2_fourier_K2_nopooling", - sources=["nhsn"], +def test_sarix_fourier_wrong_config_type(): + """Test that SARIXFourierModel raises TypeError when given a SARIXModelConfig instead of SARIXFourierModelConfig.""" + model_config = SARIXModelConfig( + model_name="sarix_p2", + sources=[DataSource.NHSN], fit_locations_separately=False, p=2, P=0, d=0, D=0, season_period=1, - power_transform="4rt", - theta_pooling="shared", - sigma_pooling="shared", - fourier_K=2, - # fourier_pooling is MISSING - should cause error + power_transform=PowerTransform.FOURTH_ROOT, + theta_pooling=PoolingStrategy.SHARED, + sigma_pooling=PoolingStrategy.SHARED, x=[] ) - run_config = SimpleNamespace( - disease="flu", - ref_date=datetime.date.fromisoformat("2024-01-06"), - output_root=Path("/tmp") / "model-output", - artifact_store_root=Path("/tmp") / "artifact-store", - save_feat_importance=False, - states=["US"], - hsas=[], - max_horizon=1, - q_levels=[0.5], - q_labels=["0.5"], - num_warmup=10, - num_samples=10, - num_chains=1 - ) - - model = SARIXFourierModel(model_config) + with pytest.raises(TypeError, match="SARIXFourierModel requires a SARIXFourierModelConfig"): + SARIXFourierModel(model_config) - # Should raise AttributeError when trying to access missing fourier_pooling - try: - model.run(run_config) - assert False, "Should have raised 
AttributeError for missing fourier_pooling" - except AttributeError as e: - assert "fourier_pooling" in str(e), \ - f"Error should mention fourier_pooling, got: {str(e)}" +def create_test_sarix_model_config(main_source, theta_pooling: PoolingStrategy, sigma_pooling: PoolingStrategy, num: int = 200): + model_config = SARIXModelConfig( + model_name = "sarix_" + main_source[0].value + "_p6_4rt_theta" + theta_pooling.value + "_sigma" + sigma_pooling.value, -def create_test_sarix_model_config(main_source, theta_pooling, sigma_pooling): - model_config = SimpleNamespace( - model_class = "sarix", - model_name = "sarix_" + main_source[0] + "_p6_4rt_theta" + theta_pooling + "_sigma" + sigma_pooling, - # data sources and adjustments for reporting issues sources = main_source, - + # fit locations separately or jointly fit_locations_separately = False, - + # SARI model parameters p = 6, P = 0, @@ -293,32 +249,32 @@ def create_test_sarix_model_config(main_source, theta_pooling, sigma_pooling): season_period = 1, # power transform applied to surveillance signals - power_transform = "4rt", + power_transform = PowerTransform.FOURTH_ROOT, # sharing of information about parameters theta_pooling=theta_pooling, sigma_pooling=sigma_pooling, - + # covariates - x = [] + x=[], + + num_warmup=num, + num_samples=num, + num_chains=1 ) return model_config -def create_test_sarix_run_config(ref_date, states, hsas, num, tmp_path): - run_config = SimpleNamespace( - disease="flu", +def create_test_sarix_run_config(ref_date, states, hsas, tmp_path): + run_config = RunConfig( + disease=Disease.FLU, ref_date=ref_date, output_root=tmp_path / "model-output", artifact_store_root=tmp_path / "artifact-store", - save_feat_importance=False, states=states, hsas = hsas, max_horizon=3, - q_levels = [0.025, 0.50, 0.975], - q_labels = ["0.025", "0.5", "0.975"], - num_warmup = num, - num_samples = num, - num_chains = 1 + q_levels=[0.025, 0.50, 0.975], + q_labels=["0.025", "0.5", "0.975"], ) return run_config diff 
--git a/uv.lock b/uv.lock index 08be397..c028cce 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14'", @@ -562,8 +562,8 @@ wheels = [ [[package]] name = "iddata" -version = "0.0.1" -source = { git = "https://github.com/reichlab/iddata#b804f63d85b8f4f1c48cf44e3cb9f6764167bab2" } +version = "0.1.0" +source = { git = "https://github.com/reichlab/iddata#d049cd28ebe65ac1a0cba82e84691adc34e6ff09" } dependencies = [ { name = "numpy" }, { name = "pandas" }, @@ -610,11 +610,11 @@ requires-dist = [ { name = "iddata", git = "https://github.com/reichlab/iddata" }, { name = "lightgbm" }, { name = "numpy" }, - { name = "pandas" }, + { name = "pandas", specifier = "~=2.0" }, { name = "pre-commit", marker = "extra == 'dev'" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "ruff", marker = "extra == 'dev'" }, - { name = "sarix", git = "https://github.com/reichlab/sarix?rev=35eea2379a9790e0457b1aed41d13509e5d5056f" }, + { name = "sarix", git = "https://github.com/reichlab/sarix" }, { name = "scikit-learn" }, { name = "timeseriesutils", git = "https://github.com/reichlab/timeseriesutils" }, { name = "tqdm" }, @@ -1635,7 +1635,7 @@ wheels = [ [[package]] name = "sarix" version = "0.2.0" -source = { git = "https://github.com/reichlab/sarix?rev=35eea2379a9790e0457b1aed41d13509e5d5056f#35eea2379a9790e0457b1aed41d13509e5d5056f" } +source = { git = "https://github.com/reichlab/sarix#35eea2379a9790e0457b1aed41d13509e5d5056f" } dependencies = [ { name = "jax" }, { name = "matplotlib" },