-
Notifications
You must be signed in to change notification settings - Fork 0
replace SimpleNamespace configs with typed dataclasses and enums #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0d0698a
5b3ca1c
de6be08
d74fe38
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,17 @@ | ||
| __version__ = "1.2.0" | ||
| from idmodels.config import ( | ||
| DataSource, | ||
| Disease, | ||
| GBQRModelConfig, | ||
| PoolingStrategy, | ||
| PowerTransform, | ||
| RunConfig, | ||
| SARIXFourierModelConfig, | ||
| SARIXModelConfig, | ||
| ) | ||
| from idmodels.gbqr import GBQRModel | ||
| from idmodels.sarix import SARIXFourierModel, SARIXModel | ||
|
|
||
| __all__ = ["DataSource", "Disease", "GBQRModel", "GBQRModelConfig", "PoolingStrategy", "PowerTransform", "RunConfig", | ||
| "SARIXFourierModel", "SARIXFourierModelConfig", "SARIXModel", "SARIXModelConfig"] | ||
|
|
||
| __version__ = "1.3.1" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| import datetime | ||
| from abc import ABC | ||
| from dataclasses import dataclass, field | ||
| from enum import Enum | ||
| from pathlib import Path | ||
|
|
||
|
|
||
| class DataSource(str, Enum): | ||
| NHSN = "nhsn" | ||
| NSSP = "nssp" | ||
| FLUSURVNET = "flusurvnet" | ||
| ILINET = "ilinet" | ||
|
|
||
|
|
||
| class Disease(str, Enum): | ||
| FLU = "flu" | ||
| COVID = "covid" | ||
| RSV = "rsv" | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RSV should be added |
||
|
|
||
| class PowerTransform(str, Enum): | ||
| FOURTH_ROOT = "4rt" | ||
| NONE = "none" | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add none
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, we should have a NONE option as well. |
||
|
|
||
| class PoolingStrategy(str, Enum): | ||
| NONE = "none" | ||
| SHARED = "shared" | ||
|
|
||
|
|
||
| @dataclass | ||
| class ModelConfig(ABC): | ||
| """ | ||
| Abstract base for model configuration. | ||
|
|
||
| Holds settings that describe *what* model to run and how it processes data (sources, transforms, pooling). | ||
| Not instantiated directly - use :class:`SARIXModelConfig` or :class:`GBQRModelConfig`. | ||
| """ | ||
|
|
||
| model_name: str | ||
| sources: list[DataSource] | ||
| fit_locations_separately: bool | ||
| power_transform: PowerTransform | ||
|
|
||
| def __post_init__(self): | ||
| if type(self) is ModelConfig: | ||
| raise TypeError("ModelConfig is abstract - use SARIXModelConfig or GBQRModelConfig") | ||
|
|
||
|
|
||
| @dataclass | ||
| class RunConfig: | ||
| """ | ||
| Run configuration. | ||
|
|
||
| Holds settings that describe a single execution: which disease, which locations, output paths, quantile levels, etc. | ||
| """ | ||
|
|
||
| disease: Disease | ||
| ref_date: datetime.date | ||
| output_root: Path | ||
| artifact_store_root: Path | None | ||
| max_horizon: int | ||
| states: list[str] | ||
| hsas: list[str] | ||
| q_levels: list[float] | ||
| q_labels: list[str] | ||
|
|
||
|
|
||
| @dataclass | ||
| class SARIXModelConfig(ModelConfig): | ||
| p: int = 0 | ||
| P: int = 0 | ||
| d: int = 0 | ||
| D: int = 0 | ||
| season_period: int = 1 | ||
| theta_pooling: PoolingStrategy = PoolingStrategy.NONE | ||
| sigma_pooling: PoolingStrategy = PoolingStrategy.NONE | ||
| x: list = field(default_factory=list) | ||
| num_warmup: int = 2000 | ||
| num_samples: int = 2000 | ||
| num_chains: int = 1 | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should check with @nickreich if these are indeed the defaults we want for the SARIX model class parameters (and we should do the same for the GBQR model class)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't feel that strongly about what the defaults are/should be. Defaulting to NONE seems reasonable to me. |
||
|
|
||
| @dataclass | ||
| class SARIXFourierModelConfig(SARIXModelConfig): | ||
| fourier_K: int = 1 | ||
| fourier_pooling: PoolingStrategy = PoolingStrategy.NONE | ||
|
|
||
|
|
||
| @dataclass | ||
| class GBQRModelConfig(ModelConfig): | ||
| incl_level_feats: bool = True | ||
| num_bags: int = 100 | ||
| bag_frac_samples: float = 0.7 | ||
| reporting_adj: bool = False | ||
| save_feat_importance: bool = False | ||
|
|
||
| # directional wave features (disabled by default) | ||
| use_directional_waves: bool = False | ||
| wave_directions: list[str] = field(default_factory=lambda: ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) | ||
| wave_temporal_lags: list[int] = field(default_factory=lambda: [1, 2]) | ||
| wave_max_distance_km: float = 1000.0 | ||
| wave_include_velocity: bool = False | ||
| wave_include_aggregate: bool = True | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this! One thing to keep in mind is that if model development goes through many iterations, we may end up having a lot of enums in the configs. I think we should keep some kind of a doc for their purpose and intention. E.g. what does "Pooling" mean to a user of the model (even if it may be obvious).