Merged
2 changes: 1 addition & 1 deletion src/idmodels/__init__.py
@@ -1,2 +1,2 @@
-__version__ = "1.1.0"
+__version__ = "1.2.0"

35 changes: 21 additions & 14 deletions src/idmodels/gbqr.py
@@ -56,12 +56,10 @@ def run(self, run_config):
         if (run_config.states == []) & (run_config.hsas == []):
             raise ValueError("User must request a non-empty set of locations to forecast for.")
 
-        if (run_config.states != []) & (run_config.hsas != []):
-            raise NotImplementedError("Functionality for simultaneously forecasting state- and hsa-level locations is not yet implemented.")
-
         df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
         df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
         df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)
+        df["unique_id"] = df["agg_level"] + df["location"]
 
         # augment data with features and target values
         if run_config.disease == "flu":
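Dropping the NotImplementedError means state- and HSA-level locations can now be forecast in a single run, and the new unique_id column keeps series distinct even if a state FIPS code and an HSA code were ever represented by the same string. A minimal illustration with hypothetical values:

import pandas as pd

# same 2-digit code at two aggregation levels; concatenation disambiguates
df = pd.DataFrame({"agg_level": ["state", "hsa"], "location": ["06", "06"]})
df["unique_id"] = df["agg_level"] + df["location"]
print(df["unique_id"].tolist())  # ['state06', 'hsa06']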
@@ -101,12 +99,12 @@ def run(self, run_config):
 
         # train model and obtain test set predictions
         if self.model_config.fit_locations_separately:
-            locations = df_test["location"].unique()
+            unique_ids = df_test["unique_id"].unique()
             preds_df = [
                 self._train_gbq_and_predict(
                     run_config,
                     df_train, df_test, feat_names, location
-                ) for location in locations
+                ) for location in unique_ids
             ]
             preds_df = pd.concat(preds_df, axis=0)
         else:
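When fit_locations_separately is set, one model is now fit per unique_id rather than per bare location code. A hypothetical sketch of the pattern (fit_per_series and fit_one are illustrative names, not the repo's actual _train_gbq_and_predict API):

import pandas as pd

def fit_per_series(df_train, df_test, fit_one):
    # fit_one is any callable(train_subset, test_subset) -> DataFrame of predictions
    preds = [
        fit_one(df_train.loc[df_train["unique_id"] == uid],
                df_test.loc[df_test["unique_id"] == uid])
        for uid in df_test["unique_id"].unique()
    ]
    return pd.concat(preds, axis=0)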
@@ -165,7 +163,7 @@ def _train_gbq_and_predict(self, run_config,
 
         # melt to get columns into rows, keeping only the things we need to invert data
         # transforms later on
-        cols_to_keep = ["source", "location", "wk_end_date", "pop",
+        cols_to_keep = ["source", "agg_level", "location", "wk_end_date", "pop",
                         "inc_trans_cs", "horizon",
                         "inc_trans_center_factor", "inc_trans_scale_factor"]
         preds_df = df_test_w_preds[cols_to_keep + run_config.q_labels]
@@ -197,15 +195,19 @@ def _train_gbq_and_predict(self, run_config,
             preds_df["value"] = np.minimum(preds_df["value"], 1.0)
         target_name = "wk inc " + run_config.disease + " prop ed visits"
 
+        keep_agg_levels = False
+        gcols = ["location", "reference_date", "horizon", "target_end_date", "target", "output_type"]
+        # we count national as state since it is coded using the same 2-digit fips code
+        preds_df["geo_level"] = np.where(preds_df["agg_level"] == "national", "state", preds_df["agg_level"])
+        if len(preds_df["geo_level"].unique()) > 1:
+            keep_agg_levels = True
+            gcols.insert(0, "agg_level")
+
         preds_df["value"] = np.maximum(preds_df["value"], 0.0)
-        preds_df = self._format_as_flusight_output(preds_df, run_config.ref_date, target_name)
+        preds_df = self._format_as_flusight_output(preds_df, run_config.ref_date, target_name, keep_agg_levels)
 
         # sort quantiles to avoid quantile crossing
-        preds_df = self._quantile_noncrossing(
-            preds_df,
-            gcols = ["location", "reference_date", "horizon", "target_end_date",
-                     "target", "output_type"]
-        )
+        preds_df = self._quantile_noncrossing(preds_df, gcols = gcols)
 
         return preds_df
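The grouping columns for the non-crossing step are now assembled once, gaining agg_level whenever the output mixes geographic levels, so quantiles are sorted within the correct series. As intuition for what such a step does, here is a minimal sketch of quantile non-crossing via within-group sorting (assumed behavior; the repo's _quantile_noncrossing may be implemented differently):

import numpy as np
import pandas as pd

def quantile_noncrossing(preds_df, gcols):
    # order rows so quantile levels ascend within each group
    # (assumes output_type_id labels sort correctly as strings, e.g. "0.025" < "0.5")
    out = preds_df.sort_values(gcols + ["output_type_id"]).copy()
    # reassign values in ascending order within each group, so a higher
    # quantile level can no longer fall below a lower one
    out["value"] = out.groupby(gcols)["value"].transform(np.sort)
    return out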

@@ -293,9 +295,14 @@ def _get_test_quantile_predictions(self, run_config,
         return test_pred_qs_df
 
 
-    def _format_as_flusight_output(self, preds_df, ref_date, target_name):
+    def _format_as_flusight_output(self, preds_df, ref_date, target_name, keep_agg_levels = False):
         # keep just required columns and rename to match hub format
-        preds_df = preds_df[["location", "wk_end_date", "horizon", "quantile", "value"]] \
+        req_cols = ["location", "wk_end_date", "horizon", "quantile", "value"]
+
+        if keep_agg_levels:
+            req_cols.insert(0, "agg_level")
+
+        preds_df = preds_df[req_cols] \
             .rename(columns={"quantile": "output_type_id"})
 
         preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
25 changes: 14 additions & 11 deletions src/idmodels/sarix.py
@@ -1,4 +1,3 @@
-
 import numpy as np
 import pandas as pd
 from iddata.loader import DiseaseDataLoader
@@ -40,12 +39,10 @@ def run(self, run_config):
         if (run_config.states == []) & (run_config.hsas == []):
             raise ValueError("User must request a non-empty set of locations to forecast for.")
 
-        if (run_config.states != []) & (run_config.hsas != []):
-            raise NotImplementedError("Functionality for simultaneously forecasting state- and hsa-level locations is not yet implemented.")
-
         df_states = df.loc[(df["location"].isin(run_config.states)) & (df["agg_level"] != "hsa")]
         df_hsas = df.loc[(df["location"].isin(run_config.hsas)) & (df["agg_level"] == "hsa")]
         df = pd.concat([df_states, df_hsas], join = "inner", axis = 0)
+        df["unique_id"] = df["agg_level"] + df["location"]
 
         # season week relative to christmas
         df = df.merge(
@@ -61,8 +58,7 @@
         # missing values are interpolated when possible
         xy_colnames = self.model_config.x + ["inc_trans_cs"]
         df = df.query("wk_end_date >= '2022-10-01'").interpolate()
-        unique_locations = len(df_states["location"].unique()) + len(df_hsas["location"].unique())
-        batched_xy = df[xy_colnames].values.reshape(unique_locations, -1, len(xy_colnames))
+        batched_xy = df[xy_colnames].values.reshape(len(df["unique_id"].unique()), -1, len(xy_colnames))
 
         # Get any extra parameters for the SARIX constructor
         extra_params = self._get_extra_sarix_params(df)
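The reshape stacks the long-format frame into an array of shape (n_series, n_weeks, n_vars), which implicitly assumes every unique_id contributes the same number of rows and that rows are ordered series by series. A self-contained illustration with made-up data:

import pandas as pd

# two series x three weeks in long format, rows grouped by series
df = pd.DataFrame({
    "unique_id": ["state06"] * 3 + ["hsa001"] * 3,
    "x": [0.1, 0.2, 0.3, 1.1, 1.2, 1.3],
    "inc_trans_cs": [5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
})
xy_colnames = ["x", "inc_trans_cs"]
batched_xy = df[xy_colnames].values.reshape(len(df["unique_id"].unique()), -1, len(xy_colnames))
print(batched_xy.shape)  # (2, 3, 2): series x weeks x variables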
@@ -87,18 +83,18 @@
         pred_qs = _np_percentile(sarix_fit_all_locs_theta_pooled.predictions[..., :, :, 0],
                                  np.array(run_config.q_levels) * 100, axis=0)
 
-        df_data_last_obs = df.groupby(["location", "agg_level"]).tail(1)
+        df_data_last_obs = df.groupby(["unique_id", "agg_level"]).tail(1)
 
         preds_df = pd.concat([
             pd.DataFrame(pred_qs[i, :, :]) \
-                .set_axis(df_data_last_obs["location"], axis="index") \
+                .set_axis(df_data_last_obs["unique_id"], axis="index") \
                 .set_axis(np.arange(1, run_config.max_horizon+1), axis="columns") \
                 .assign(output_type_id = q_label) \
             for i, q_label in enumerate(run_config.q_labels)
         ]) \
             .reset_index() \
-            .melt(["location", "output_type_id"], var_name="horizon") \
-            .merge(df_data_last_obs, on="location", how="left")
+            .melt(["unique_id", "output_type_id"], var_name="horizon") \
+            .merge(df_data_last_obs, on="unique_id", how="left")
 
         # build data frame with predictions on the original scale
         preds_df["value"] = (preds_df["value"] + preds_df["inc_trans_center_factor"]) * preds_df["inc_trans_scale_factor"]
@@ -119,7 +115,14 @@ def run(self, run_config):
             preds_df["value"] = np.minimum(preds_df["value"], 1.0)
 
         # keep just required columns and rename to match hub format
-        preds_df = preds_df[["location", "wk_end_date", "horizon", "output_type_id", "value"]]
+        req_cols = ["location", "wk_end_date", "horizon", "output_type_id", "value"]
+
+        # we count national as state since it is coded using the same 2-digit fips code
+        preds_df["geo_level"] = np.where(preds_df["agg_level"] == "national", "state", preds_df["agg_level"])
+        if len(preds_df["geo_level"].unique()) > 1:
+            req_cols.insert(0, "agg_level")
+
+        preds_df = preds_df[req_cols]
 
         preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
         preds_df["reference_date"] = run_config.ref_date

This file was deleted.
