34 changes: 33 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.3.1]

### Changed
- **Breaking**: `SARIXRunConfig` and `GBQRRunConfig` removed; use `RunConfig` directly (it was an abstract base and is now directly instantiable)
- **Breaking**: `num_warmup`, `num_samples`, `num_chains` moved from `SARIXRunConfig` to `SARIXModelConfig`
- **Breaking**: `save_feat_importance` moved from `GBQRRunConfig` to `GBQRModelConfig`

## [1.3.0]

### Added
- Configuration dataclasses: abstract bases `ModelConfig` and `RunConfig`, plus concrete `SARIXModelConfig`, `SARIXRunConfig`, `GBQRModelConfig`, `GBQRRunConfig`
- `SARIXFourierModelConfig` dataclass with `fourier_K` and `fourier_pooling` fields
- Wave feature fields on `GBQRModelConfig` (`use_directional_waves`, `wave_directions`, etc.), disabled by default
- Enum types: `DataSource`, `Disease`, `PowerTransform`, `PoolingStrategy`
- Docstrings for `ModelConfig` and `RunConfig` base classes
- All config types exported from `idmodels.__init__`

### Changed
- **Breaking**: `model_config.sources` now expects `list[DataSource]` instead of `list[str]`
- **Breaking**: `model_config.power_transform` now expects `PowerTransform` instead of `str`
- **Breaking**: `model_config.disease` now expects `Disease` instead of `str`
- Source validation in `sarix.py` and `gbqr.py` uses `DataSource` enums and set operations instead of `np.isin` with string arrays
- All tests use concrete config dataclasses instead of `SimpleNamespace`
- Updated `directional_wave_features.md` examples to use config dataclasses
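
For callers still holding string-based settings, the breaking changes above could be absorbed with a small conversion step. The following is a minimal sketch, not part of the PR: the enum definitions are copied from the new `config.py` so the snippet runs on its own, and `coerce_sources` and the `legacy` dict are hypothetical names used only for illustration.

```python
from enum import Enum


# Stand-in copies of two enums from the PR's config.py, repeated here so the
# snippet runs without idmodels installed.
class DataSource(str, Enum):
    NHSN = "nhsn"
    NSSP = "nssp"
    FLUSURVNET = "flusurvnet"
    ILINET = "ilinet"


class PowerTransform(str, Enum):
    FOURTH_ROOT = "4rt"
    NONE = "none"


def coerce_sources(raw):
    """Hypothetical helper: map legacy strings (or enum members) to DataSource."""
    # Enum lookup by value raises ValueError for anything that is not a member.
    return [DataSource(item) for item in raw]


# Hypothetical pre-1.3.0 settings that used plain strings.
legacy = {"sources": ["nhsn", "ilinet"], "power_transform": "4rt"}

sources = coerce_sources(legacy["sources"])
power_transform = PowerTransform(legacy["power_transform"])
print(sources, power_transform)
```

Note that dataclasses do not enforce their type annotations at construction time, so a conversion like this (or an explicit check) is what would actually reject a misspelled source name.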

### Removed
- `SimpleNamespace` usage throughout tests and documentation
- `model_class` field from model configurations (implied by the dataclass type)
- `num_bags` from `GBQRRunConfig` test helper (it is a `GBQRModelConfig` field)
- `save_feat_importance` from `SARIXRunConfig` test helper (not a SARIX field)

## [1.1.0] - 2025-12-08

### Added
@@ -71,7 +102,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Updated to latest iddata API

[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.1.0...HEAD
[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.3.0...HEAD
[1.3.0]: https://github.com/reichlab/idmodels/compare/v1.1.0...v1.3.0
[1.1.0]: https://github.com/reichlab/idmodels/compare/v1.0.0...v1.1.0
[1.0.0]: https://github.com/reichlab/idmodels/compare/v0.1.0...v1.0.0
[0.1.0]: https://github.com/reichlab/idmodels/compare/v0.0.1...v0.1.0
55 changes: 29 additions & 26 deletions docs/directional_wave_features.md
@@ -43,10 +43,13 @@ For each location and time point, the following features are generated:
Directional wave features are **disabled by default** for backwards compatibility. To enable them, add the following parameters to your `model_config`:

```python
from types import SimpleNamespace
from idmodels.config import DataSource, GBQRModelConfig, PowerTransform

model_config = SimpleNamespace(
# ... existing parameters ...
model_config = GBQRModelConfig(
model_name = "gbqr_with_waves",
sources = [DataSource.NHSN],
fit_locations_separately = False,
power_transform = PowerTransform.FOURTH_ROOT,

# Directional wave features (disabled by default)
use_directional_waves = True, # Set to True to enable
@@ -111,8 +114,8 @@ model_config = SimpleNamespace(

### Minimal Configuration (4 cardinal directions)
```python
model_config = SimpleNamespace(
# ... other params ...
model_config = GBQRModelConfig(
# ... required base params ...
use_directional_waves = True,
wave_directions = ['N', 'S', 'E', 'W']
)
@@ -121,8 +124,8 @@ Generates: 4 base + 4 aggregate + (4+1)×2 lags = **14 features**

### Standard Configuration (8 directions)
```python
model_config = SimpleNamespace(
# ... other params ...
model_config = GBQRModelConfig(
# ... required base params ...
use_directional_waves = True,
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
wave_temporal_lags = [1, 2]
@@ -132,8 +135,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features**

### Maximum Information (all options)
```python
model_config = SimpleNamespace(
# ... other params ...
model_config = GBQRModelConfig(
# ... required base params ...
use_directional_waves = True,
wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
wave_temporal_lags = [1, 2],
@@ -147,8 +150,8 @@ Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features
### Hypothesis-Driven (specific directions)
```python
# If you suspect disease spreads along NE-SW axis
model_config = SimpleNamespace(
# ... other params ...
model_config = GBQRModelConfig(
# ... required base params ...
use_directional_waves = True,
wave_directions = ['NE', 'SW'],
wave_temporal_lags = [1, 2, 3], # Longer lags for slower spread
@@ -240,22 +243,22 @@ The implementation includes validation that warns about:
## Example: Complete GBQR Configuration

```python
from types import SimpleNamespace
import datetime
from pathlib import Path
from idmodels.config import DataSource, Disease, GBQRModelConfig, PowerTransform, RunConfig
from idmodels.gbqr import GBQRModel

# Model configuration with directional wave features
model_config = SimpleNamespace(
model_class = "gbqr",
model_config = GBQRModelConfig(
model_name = "gbqr_with_waves",

# Standard GBQR parameters
sources = [DataSource.NHSN],
fit_locations_separately = False,
power_transform = PowerTransform.FOURTH_ROOT,
incl_level_feats = True,
num_bags = 10,
bag_frac_samples = 0.7,
reporting_adj = False,
sources = ["nhsn"],
fit_locations_separately = False,
power_transform = "4rt",
save_feat_importance = True,

# Directional wave features
use_directional_waves = True,
@@ -267,16 +270,16 @@ model_config = SimpleNamespace(
)

# Run configuration
run_config = SimpleNamespace(
disease = "flu",
run_config = RunConfig(
disease = Disease.FLU,
ref_date = datetime.date(2024, 1, 6),
output_root = "output/",
artifact_store_root = "artifacts/",
save_feat_importance = True,
locations = None, # All locations
output_root = Path("output/"),
artifact_store_root = Path("artifacts/"),
max_horizon = 4,
states = ["US", "01", "06", "13", "36", "48"],
hsas = [],
q_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975],
q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"]
q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"],
)

# Run model
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -14,8 +14,8 @@ dependencies = [
"iddata @ git+https://github.com/reichlab/iddata",
"lightgbm",
"numpy",
"pandas",
"sarix @ git+https://github.com/reichlab/sarix",
"pandas~=2.0", # pandas 3.0 breaks compatibility; remove cap once validated
"sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f",
"scikit-learn",
"tqdm",
"timeseriesutils @ git+https://github.com/reichlab/timeseriesutils"
17 changes: 16 additions & 1 deletion src/idmodels/__init__.py
@@ -1,2 +1,17 @@
__version__ = "1.2.0"
from idmodels.config import (
    DataSource,
    Disease,
    GBQRModelConfig,
    PoolingStrategy,
    PowerTransform,
    RunConfig,
    SARIXFourierModelConfig,
    SARIXModelConfig,
)
from idmodels.gbqr import GBQRModel
from idmodels.sarix import SARIXFourierModel, SARIXModel

__all__ = ["DataSource", "Disease", "GBQRModel", "GBQRModelConfig", "PoolingStrategy", "PowerTransform", "RunConfig",
           "SARIXFourierModel", "SARIXFourierModelConfig", "SARIXModel", "SARIXModelConfig"]

__version__ = "1.3.1"
104 changes: 104 additions & 0 deletions src/idmodels/config.py
@trobacker (Contributor), Feb 24, 2026

I like this! One thing to keep in mind: if model development goes through many iterations, we may end up with a lot of enums in the configs. I think we should keep a doc describing their purpose and intent, e.g. what "Pooling" means to a user of the model (even if it may seem obvious).

@@ -0,0 +1,104 @@
import datetime
from abc import ABC
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path


class DataSource(str, Enum):
    NHSN = "nhsn"
    NSSP = "nssp"
    FLUSURVNET = "flusurvnet"
    ILINET = "ilinet"


class Disease(str, Enum):
    FLU = "flu"
    COVID = "covid"
    RSV = "rsv"

Contributor

RSV should be added


class PowerTransform(str, Enum):
    FOURTH_ROOT = "4rt"
    NONE = "none"

Contributor

Add none

Contributor

Agreed, we should have a NONE option as well.


class PoolingStrategy(str, Enum):
    NONE = "none"
    SHARED = "shared"


@dataclass
class ModelConfig(ABC):
    """
    Abstract base for model configuration.

    Holds settings that describe *what* model to run and how it processes data (sources, transforms, pooling).
    Not instantiated directly - use :class:`SARIXModelConfig` or :class:`GBQRModelConfig`.
    """

    model_name: str
    sources: list[DataSource]
    fit_locations_separately: bool
    power_transform: PowerTransform

    def __post_init__(self):
        if type(self) is ModelConfig:
            raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig")


@dataclass
class RunConfig:
    """
    Run configuration.

    Holds settings that describe a single execution: which disease, which locations, output paths, quantile levels, etc.
    """

    disease: Disease
    ref_date: datetime.date
    output_root: Path
    artifact_store_root: Path | None
    max_horizon: int
    states: list[str]
    hsas: list[str]
    q_levels: list[float]
    q_labels: list[str]


@dataclass
class SARIXModelConfig(ModelConfig):
    p: int = 0
    P: int = 0
    d: int = 0
    D: int = 0
    season_period: int = 1
    theta_pooling: PoolingStrategy = PoolingStrategy.NONE
    sigma_pooling: PoolingStrategy = PoolingStrategy.NONE
    x: list = field(default_factory=list)
    num_warmup: int = 2000
    num_samples: int = 2000
    num_chains: int = 1

Contributor

I think we should check with @nickreich if these are indeed the defaults we want for the SARIX model class parameters (and we should do the same for the GBQR model class)

Member

I don't feel that strongly about what the defaults are/should be. Defaulting to NONE seems reasonable to me.


@dataclass
class SARIXFourierModelConfig(SARIXModelConfig):
    fourier_K: int = 1
    fourier_pooling: PoolingStrategy = PoolingStrategy.NONE


@dataclass
class GBQRModelConfig(ModelConfig):
    incl_level_feats: bool = True
    num_bags: int = 100
    bag_frac_samples: float = 0.7
    reporting_adj: bool = False
    save_feat_importance: bool = False

    # directional wave features (disabled by default)
    use_directional_waves: bool = False
    wave_directions: list[str] = field(default_factory=lambda: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
    wave_temporal_lags: list[int] = field(default_factory=lambda: [1, 2])
    wave_max_distance_km: float = 1000.0
    wave_include_velocity: bool = False
    wave_include_aggregate: bool = True
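
A note on the `__post_init__` guard in `ModelConfig`: inheriting from `ABC` does not by itself block instantiation when a class declares no abstract methods, which is presumably why the explicit `type(self) is ModelConfig` check is there. The sketch below illustrates that behavior with trimmed stand-ins for the PR's classes (only one field each); `build_or_error` is a hypothetical helper, not part of the PR.

```python
from abc import ABC
from dataclasses import dataclass


@dataclass
class ModelConfig(ABC):
    # Trimmed stand-in: the PR's ModelConfig declares more fields.
    model_name: str

    def __post_init__(self):
        # ABC alone would allow ModelConfig(...) because there are no
        # abstract methods, so the base class rejects itself explicitly.
        if type(self) is ModelConfig:
            raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig")


@dataclass
class GBQRModelConfig(ModelConfig):
    # Trimmed stand-in with a single defaulted field.
    num_bags: int = 100


def build_or_error(cls, **kwargs):
    """Hypothetical helper: return the new instance, or the TypeError raised."""
    try:
        return cls(**kwargs)
    except TypeError as exc:
        return exc


base_attempt = build_or_error(ModelConfig, model_name="base")
concrete = build_or_error(GBQRModelConfig, model_name="gbqr_demo")
print(type(base_attempt).__name__)
print(concrete)
```

Subclasses inherit `__post_init__`, but `type(self)` is then the subclass, so they construct normally.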
41 changes: 21 additions & 20 deletions src/idmodels/gbqr.py
@@ -6,6 +7,7 @@
from iddata.loader import DiseaseDataLoader
from tqdm.autonotebook import tqdm

from idmodels.config import DataSource, Disease, PowerTransform
from idmodels.preprocess import create_directional_wave_features, create_features_and_targets
from idmodels.utils import build_save_path

@@ -31,22 +32,22 @@ def run(self, run_config):
ilinet_kwargs = {"scale_to_positive": False}
flusurvnet_kwargs = {"burden_adj": False}

valid_sources = ["flusurvnet", "nhsn", "ilinet", "nssp"]
if not np.isin(np.array(self.model_config.sources), valid_sources).all():
valid_sources = {DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET, DataSource.NSSP}
if not set(self.model_config.sources) <= valid_sources:
raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.")

# Check if both nhsn and nssp data are included as sources
if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
if (DataSource.NHSN in self.model_config.sources) and (DataSource.NSSP in self.model_config.sources):
raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")

fdl = DiseaseDataLoader()
if "nhsn" in self.model_config.sources:
if DataSource.NHSN in self.model_config.sources:
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
elif "nssp" in self.model_config.sources:
elif DataSource.NSSP in self.model_config.sources:
df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
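
The two source checks in this hunk (a subset test against the supported set, then mutual exclusion of NHSN and NSSP) can be exercised in isolation. This is a minimal sketch under stated assumptions: `DataSource` is copied from the PR's `config.py`, while `validate_sources` is a hypothetical standalone function, not a method that exists in `gbqr.py`.

```python
from enum import Enum


# Stand-in copy of the enum from the PR's config.py.
class DataSource(str, Enum):
    NHSN = "nhsn"
    NSSP = "nssp"
    FLUSURVNET = "flusurvnet"
    ILINET = "ilinet"


VALID_SOURCES = {DataSource.FLUSURVNET, DataSource.NHSN, DataSource.ILINET, DataSource.NSSP}


def validate_sources(sources):
    """Hypothetical helper mirroring the two checks in GBQR's run()."""
    selected = set(sources)
    # Subset test: every requested source must be a supported one.
    if not selected <= VALID_SOURCES:
        raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.")
    # NHSN and NSSP are mutually exclusive.
    if DataSource.NHSN in selected and DataSource.NSSP in selected:
        raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")
    return selected


print(validate_sources([DataSource.NHSN, DataSource.ILINET]))
```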
@@ -62,20 +63,20 @@ def run(self, run_config):
df["unique_id"] = df["agg_level"] + df["location"]

# augment data with features and target values
if run_config.disease == "flu":
if (run_config.disease == Disease.FLU) or (run_config.disease == Disease.RSV):
init_feats = ["inc_trans_cs", "season_week", "log_pop"]
elif run_config.disease == "covid":
elif run_config.disease == Disease.COVID:
init_feats = ["inc_trans_cs", "log_pop"]

# Create directional wave features if enabled
if hasattr(self.model_config, "use_directional_waves") and self.model_config.use_directional_waves:
if self.model_config.use_directional_waves:
wave_config = {
"enabled": True,
"directions": getattr(self.model_config, "wave_directions", ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]),
"temporal_lags": getattr(self.model_config, "wave_temporal_lags", [1, 2]),
"max_distance_km": getattr(self.model_config, "wave_max_distance_km", 1000),
"include_velocity": getattr(self.model_config, "wave_include_velocity", False),
"include_aggregate": getattr(self.model_config, "wave_include_aggregate", True)
"directions": self.model_config.wave_directions,
"temporal_lags": self.model_config.wave_temporal_lags,
"max_distance_km": self.model_config.wave_max_distance_km,
"include_velocity": self.model_config.wave_include_velocity,
"include_aggregate": self.model_config.wave_include_aggregate,
}
df, wave_feat_names = create_directional_wave_features(df, wave_config)
init_feats = init_feats + wave_feat_names
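
The disease branch earlier in this hunk treats flu and RSV alike (both get `season_week`) and COVID separately. One way to make that grouping explicit is sketched below. `Disease` is copied from the PR's `config.py`; `SEASONAL_DISEASES` and `initial_features` are hypothetical names, and the final `ValueError` is an addition of this sketch, since the diff shows no fallback branch.

```python
from enum import Enum


# Stand-in copy of the enum from the PR's config.py.
class Disease(str, Enum):
    FLU = "flu"
    COVID = "covid"
    RSV = "rsv"


# Hypothetical grouping: diseases whose initial features include season_week.
SEASONAL_DISEASES = {Disease.FLU, Disease.RSV}


def initial_features(disease):
    """Hypothetical helper returning the starting feature list for a disease."""
    if disease in SEASONAL_DISEASES:
        return ["inc_trans_cs", "season_week", "log_pop"]
    if disease is Disease.COVID:
        return ["inc_trans_cs", "log_pop"]
    # Not in the PR: fail loudly instead of leaving the list undefined.
    raise ValueError(f"unsupported disease: {disease!r}")


for d in Disease:
    print(d.value, initial_features(d))
```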
@@ -87,7 +88,7 @@ def run(self, run_config):
curr_feat_names=init_feats)

# keep only rows that are in-season
if run_config.disease == "flu":
if (run_config.disease == Disease.FLU) or (run_config.disease == Disease.RSV):
df = df.query("season_week >= 5 and season_week <= 45")

# "test set" df used to generate look-ahead predictions
@@ -176,12 +177,12 @@ def _train_gbq_and_predict(self, run_config,
# build data frame with predictions on the original scale
preds_df["inc_trans_cs_target_hat"] = preds_df["inc_trans_cs"] + preds_df["delta_hat"]
preds_df["inc_trans_target_hat"] = (preds_df["inc_trans_cs_target_hat"] + preds_df["inc_trans_center_factor"]) * (preds_df["inc_trans_scale_factor"] + 0.01)
if self.model_config.power_transform == "4rt":
if self.model_config.power_transform == PowerTransform.FOURTH_ROOT:
inv_power = 4
elif self.model_config.power_transform is None:
elif self.model_config.power_transform == PowerTransform.NONE:
inv_power = 1
else:
raise ValueError('unsupported power_transform: must be "4rt" or None')
raise ValueError(f"unsupported power_transform: {self.model_config.power_transform!r}")

preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4)
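
The back-transform in this hunk picks an inverse exponent from the `PowerTransform` member and then applies the same offsets in both cases. A scalar sketch of that logic follows, using a lookup table instead of the `if`/`elif` chain. This is an illustrative alternative rather than what the PR does: `INVERSE_POWER`, `inverse_power` and `back_transform` are hypothetical names, and plain floats stand in for the DataFrame columns.

```python
from enum import Enum


# Stand-in copy of the enum from the PR's config.py.
class PowerTransform(str, Enum):
    FOURTH_ROOT = "4rt"
    NONE = "none"


# Hypothetical lookup table replacing the if/elif chain in the diff.
INVERSE_POWER = {PowerTransform.FOURTH_ROOT: 4, PowerTransform.NONE: 1}


def inverse_power(power_transform):
    """Return the exponent that undoes the transform, or raise ValueError."""
    try:
        return INVERSE_POWER[power_transform]
    except KeyError:
        raise ValueError(f"unsupported power_transform: {power_transform!r}") from None


def back_transform(inc_trans_target_hat, power_transform):
    """Scalar version of the 'value' expression in the diff (offsets copied as-is)."""
    clipped = max(inc_trans_target_hat, 0.0)
    return clipped ** inverse_power(power_transform) - 0.01 - 0.75**4


print(back_transform(1.5, PowerTransform.FOURTH_ROOT))
print(back_transform(1.5, PowerTransform.NONE))
```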

@@ -276,7 +277,7 @@ def _get_test_quantile_predictions(self, run_config,
test_preds_by_bag[:, b, q_ind] = model.predict(X=x_test)

# combine and save feature importance scores
if run_config.save_feat_importance:
if self.model_config.save_feat_importance:
feat_importance = pd.concat(feat_importance, axis=0)
save_path = build_save_path(
root=run_config.artifact_store_root,