Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
992eb4a
Update `iddata` requirement
lshandross Oct 7, 2025
443ef52
Update sarix to handle nssp data
lshandross Oct 17, 2025
299386c
Update gbqr to handle nssp data
lshandross Oct 17, 2025
5290921
Restrict NSSP GBQR preds to [0, 1]
lshandross Oct 21, 2025
e77dfdd
Update tests and test data for NSSP support
lshandross Oct 21, 2025
0f4e462
Update requirements text files
lshandross Oct 21, 2025
06f3b4d
Use reichlab/sarix as dependency
lshandross Oct 22, 2025
465cf6e
Remove debugging code
lshandross Oct 23, 2025
7a45c3e
Use if-elif instead of two ifs
lshandross Oct 23, 2025
3872814
Use not instead of ~
lshandross Oct 23, 2025
1da182f
Update sarix package dependency version
lshandross Oct 23, 2025
c29f044
Update sarix dependency commit hash
lshandross Oct 27, 2025
25f6238
Convert NSSP observed percentage to proportion
lshandross Oct 31, 2025
c9c7b11
Update SARIX NSSP integration test data
lshandross Oct 31, 2025
efa0ff5
Update sarix for hsa and state-level forecasting
lshandross Nov 13, 2025
2db039f
Refactor sarix tests to create configs w helpers
lshandross Nov 13, 2025
d46122e
Add more sarix tests + update test data
lshandross Nov 14, 2025
3dec0e4
Update + refactor gbqr tests
lshandross Nov 14, 2025
08c8817
Update gbqr for hsa and state-level forecasting
lshandross Nov 14, 2025
c2c27da
Restrict rates to counts conversion to nhsn; update tests + test data
lshandross Nov 14, 2025
07e32f2
Update `iddata` dependency in requirements files
lshandross Nov 17, 2025
80fdec8
Merge branch 'main' into ls/sarix-gbqr-takes-nssp-data/9
lshandross Nov 18, 2025
583f9e5
Fix linter I001 error
lshandross Nov 18, 2025
3c392b2
Update sarix fourier tests + config creation methods
lshandross Nov 18, 2025
d20f7a0
Update nssp tests to use `ref_date=2025-09-20` + update test data
lshandross Nov 18, 2025
5e145aa
Internal nssp loading uses provided ref date
lshandross Nov 18, 2025
724f8ed
Remove redundant code
lshandross Nov 24, 2025
b35fcc4
Update unique location calculation
lshandross Nov 24, 2025
a63d488
Add in-line comment about interpolation
lshandross Nov 24, 2025
397827b
Increment to v1.0.0
lshandross Nov 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
"lightgbm",
"numpy",
"pandas",
"sarix @ git+https://github.com/reichlab/sarix@35eea2379a9790e0457b1aed41d13509e5d5056f",
"sarix @ git+https://github.com/reichlab/sarix",
"scikit-learn",
"tqdm",
"timeseriesutils @ git+https://github.com/reichlab/timeseriesutils"
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ frozenlist==1.5.0
# aiosignal
fsspec==2024.10.0
# via s3fs
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8
# via idmodels (pyproject.toml)
identify==2.6.1
# via pre-commit
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ frozenlist==1.5.0
# aiosignal
fsspec==2024.10.0
# via s3fs
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
iddata @ git+https://github.com/reichlab/iddata@5a7e74d7823d39b8a8ef6334c5191e440bc669d8
# via idmodels (pyproject.toml)
idna==3.10
# via yarl
Expand Down
2 changes: 1 addition & 1 deletion src/idmodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "1.0.0"
58 changes: 45 additions & 13 deletions src/idmodels/gbqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,38 @@ def run(self, run_config):
ilinet_kwargs = {"scale_to_positive": False}
flusurvnet_kwargs = {"burden_adj": False}

valid_sources = ["flusurvnet", "nhsn", "ilinet", "nssp"]
if not np.isin(np.array(self.model_config.sources), valid_sources).all():
raise ValueError("For GBQR, the only supported data sources are 'nhsn', 'flusurvnet', 'ilinet', or 'nssp'.")

# Check if both nhsn and nssp data are included as sources
if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")

fdl = DiseaseDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
if run_config.locations is not None:
df = df.loc[df["location"].isin(run_config.locations)]
if "nhsn" in self.model_config.sources:
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
elif "nssp" in self.model_config.sources:
df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)

if (run_config.states == []) & (run_config.hsas == []):
raise ValueError("User must request a non-empty set of locations to forecast for.")

if (run_config.states != []) & (run_config.hsas != []):
raise NotImplementedError("Functionality for simultaneously forecasting state- and hsa-level locations is not yet implemented.")

df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)

# augment data with features and target values
if run_config.disease == "flu":
init_feats = ["inc_trans_cs", "season_week", "log_pop"]
Expand Down Expand Up @@ -133,7 +156,7 @@ def _train_gbq_and_predict(self, run_config,
"inc_trans_cs", "horizon",
"inc_trans_center_factor", "inc_trans_scale_factor"]
preds_df = df_test_w_preds[cols_to_keep + run_config.q_labels]
preds_df = preds_df.loc[(preds_df["source"] == "nhsn")]
preds_df = preds_df.loc[preds_df["source"].isin(["nhsn", "nssp"])]
preds_df = pd.melt(preds_df,
id_vars=cols_to_keep,
var_name="quantile",
Expand All @@ -149,11 +172,20 @@ def _train_gbq_and_predict(self, run_config,
else:
raise ValueError('unsupported power_transform: must be "4rt" or None')

preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4) * preds_df["pop"] / 100000
preds_df["value"] = np.maximum(preds_df["value"], 0.0)
preds_df["value"] = (np.maximum(preds_df["inc_trans_target_hat"], 0.0) ** inv_power - 0.01 - 0.75**4)

# get predictions into the format needed for FluSight hub submission
preds_df = self._format_as_flusight_output(preds_df, run_config.ref_date, run_config.disease)
if "nhsn" in preds_df["source"].unique():
# turn nhsn rates back into counts
preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000
target_name = "wk inc " + run_config.disease + " hosp"
elif "nssp" in preds_df["source"].unique():
preds_df["value"] = preds_df["value"] / 100 # percentage to proportion
preds_df["value"] = np.minimum(preds_df["value"], 1.0)
target_name = "wk inc " + run_config.disease + " prop ed visits"

preds_df["value"] = np.maximum(preds_df["value"], 0.0)
preds_df = self._format_as_flusight_output(preds_df, run_config.ref_date, target_name)

# sort quantiles to avoid quantile crossing
preds_df = self._quantile_noncrossing(
Expand Down Expand Up @@ -248,15 +280,15 @@ def _get_test_quantile_predictions(self, run_config,
return test_pred_qs_df


def _format_as_flusight_output(self, preds_df, ref_date, disease):
def _format_as_flusight_output(self, preds_df, ref_date, target_name):
# keep just required columns and rename to match hub format
preds_df = preds_df[["location", "wk_end_date", "horizon", "quantile", "value"]] \
.rename(columns={"quantile": "output_type_id"})

preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
preds_df["reference_date"] = ref_date
preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - ref_date).dt.days / 7).astype(int)
preds_df["target"] = "wk inc " + disease + " hosp"
preds_df["target"] = target_name

preds_df["output_type"] = "quantile"
preds_df.drop(columns="wk_end_date", inplace=True)
Expand Down
59 changes: 46 additions & 13 deletions src/idmodels/sarix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,35 @@ def _get_extra_sarix_params(self, df):
return {}

def run(self, run_config):
valid_sources = np.array(["nhsn", "nssp"])
if not np.isin(np.array(self.model_config.sources), valid_sources).all():
raise ValueError("For SARIX, the only supported data sources are 'nhsn' or 'nssp'.")

# Check if both nhsn and nssp data are included as sources
if all(src in self.model_config.sources for src in ["nhsn", "nssp"]):
raise ValueError("Only one of 'nhsn' or 'nssp' may be selected as a data source.")

fdl = DiseaseDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
if run_config.locations is not None:
df = df.loc[df["location"].isin(run_config.locations)]
if "nhsn" in self.model_config.sources:
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
target_name = "wk inc " + run_config.disease + " hosp"
elif "nssp" in self.model_config.sources:
df = fdl.load_data(nssp_kwargs={"as_of": run_config.ref_date, "disease": run_config.disease},
sources=self.model_config.sources,
power_transform=self.model_config.power_transform)
target_name = "wk inc " + run_config.disease + " prop ed visits"

if (run_config.states == []) & (run_config.hsas == []):
raise ValueError("User must request a non-empty set of locations to forecast for.")

if (run_config.states != []) & (run_config.hsas != []):
raise NotImplementedError("Functionality for simultaneously forecasting state- and hsa-level locations is not yet implemented.")

df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)

# season week relative to christmas
df = df.merge(
Expand All @@ -34,10 +57,12 @@ def run(self, run_config):
on="season") \
.assign(delta_xmas = lambda x: x["season_week"] - x["xmas_week"])
df["xmas_spike"] = np.maximum(3 - np.abs(df["delta_xmas"]), 0)


# missing values are interpolated when possible
xy_colnames = self.model_config.x + ["inc_trans_cs"]
df = df.query("wk_end_date >= '2022-10-01'").interpolate()
batched_xy = df[xy_colnames].values.reshape(len(df["location"].unique()), -1, len(xy_colnames))
unique_locations = len(df_states["location"].unique()) + len(df_hsas["location"].unique())
batched_xy = df[xy_colnames].values.reshape(unique_locations, -1, len(xy_colnames))

# Get any extra parameters for the SARIX constructor
extra_params = self._get_extra_sarix_params(df)
Expand All @@ -62,18 +87,18 @@ def run(self, run_config):
pred_qs = _np_percentile(sarix_fit_all_locs_theta_pooled.predictions[..., :, :, 0],
np.array(run_config.q_levels) * 100, axis=0)

df_nhsn_last_obs = df.groupby(["location"]).tail(1)
df_data_last_obs = df.groupby(["location", "agg_level"]).tail(1)

preds_df = pd.concat([
pd.DataFrame(pred_qs[i, :, :]) \
.set_axis(df_nhsn_last_obs["location"], axis="index") \
.set_axis(df_data_last_obs["location"], axis="index") \
.set_axis(np.arange(1, run_config.max_horizon+1), axis="columns") \
.assign(output_type_id = q_label) \
for i, q_label in enumerate(run_config.q_labels)
]) \
.reset_index() \
.melt(["location", "output_type_id"], var_name="horizon") \
.merge(df_nhsn_last_obs, on="location", how="left")
.merge(df_data_last_obs, on="location", how="left")

# build data frame with predictions on the original scale
preds_df["value"] = (preds_df["value"] + preds_df["inc_trans_center_factor"]) * preds_df["inc_trans_scale_factor"]
Expand All @@ -82,19 +107,27 @@ def run(self, run_config):
else:
preds_df["value"] = np.maximum(preds_df["value"], 0.0) ** 2

preds_df["value"] = (preds_df["value"] - 0.01 - 0.75**4) * preds_df["pop"] / 100000
preds_df["value"] = (preds_df["value"] - 0.01 - 0.75**4)
preds_df["value"] = np.maximum(preds_df["value"], 0.0)

if "nhsn" in preds_df["source"].unique():
# turn nhsn rates back into counts
preds_df["value"] = preds_df["value"] * preds_df["pop"] / 100000

if target_name == "wk inc " + run_config.disease + " prop ed visits":
preds_df["value"] = preds_df["value"] / 100 # percentage to proportion
preds_df["value"] = np.minimum(preds_df["value"], 1.0)

# keep just required columns and rename to match hub format
preds_df = preds_df[["location", "wk_end_date", "horizon", "output_type_id", "value"]]

preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
preds_df["reference_date"] = run_config.ref_date
preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - run_config.ref_date).dt.days / 7).astype(int)
preds_df["output_type"] = "quantile"
preds_df["target"] = "wk inc " + run_config.disease + " hosp"
preds_df["target"] = target_name
preds_df.drop(columns="wk_end_date", inplace=True)

# save
save_path = build_save_path(
root=run_config.output_root,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
location,reference_date,horizon,target_end_date,target,output_type,output_type_id,value
1,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.025,0.0
25,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.025,0.0
99,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.025,0.0
1,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.025,0.0
25,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.025,0.0
99,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.025,0.0
1,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.025,0.0
25,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.025,0.0
99,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.025,0.0
1,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.5,0.0
25,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.5,0.0
99,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.5,0.0
1,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.5,0.0
25,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.5,0.0
99,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.5,0.0
1,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.5,0.0
25,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.5,0.0
99,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.5,0.0
1,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.975,0.0
25,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.975,0.0
99,2025-09-20,-1,2025-09-13,wk inc flu prop ed visits,quantile,0.975,0.0
1,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.975,0.0
25,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.975,0.0
99,2025-09-20,0,2025-09-20,wk inc flu prop ed visits,quantile,0.975,0.0
1,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.975,0.0
25,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.975,0.0
99,2025-09-20,1,2025-09-27,wk inc flu prop ed visits,quantile,0.975,0.0
Loading