diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index c682153..e9f5b38 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -14,7 +14,15 @@ def __init__(self, model_config): def _get_extra_sarix_params(self, df): """Return extra parameters to pass to SARIX constructor. Returns empty dict by default.""" - return {} + extra_params = {} + + # Add innovation distribution parameters if specified + if hasattr(self.model_config, "innovation_dist"): + extra_params["innovation_dist"] = self.model_config.innovation_dist + if hasattr(self.model_config, "innovation_df_prior_scale"): + extra_params["innovation_df_prior_scale"] = self.model_config.innovation_df_prior_scale + + return extra_params def run(self, run_config): fdl = DiseaseDataLoader() @@ -118,15 +126,21 @@ class SARIXFourierModel(SARIXModel): """ def _get_extra_sarix_params(self, df): """Return Fourier-specific parameters for SARIX constructor.""" + # Get base parameters (includes innovation_dist if specified) + extra_params = super()._get_extra_sarix_params(df) + # Extract day-of-year from dates for Fourier features # Take the first location's dates (same for all locations after reshaping) day_of_year = df.groupby("location")["wk_end_date"].apply(lambda x: x.dt.dayofyear.values).iloc[0] - return { + # Add Fourier-specific parameters + extra_params.update({ "day_of_year": day_of_year, "fourier_K": self.model_config.fourier_K, "fourier_pooling": self.model_config.fourier_pooling - } + }) + + return extra_params def _np_percentile(predictions, q_levels, axis): diff --git a/tests/integration/test_sarix.py b/tests/integration/test_sarix.py index 7619443..9449bbc 100644 --- a/tests/integration/test_sarix.py +++ b/tests/integration/test_sarix.py @@ -333,6 +333,78 @@ def test_sarix_fourier_missing_pooling_parameter(): f"Error should mention fourier_pooling, got: {str(e)}" +def test_sarix_tdist_innovations(tmp_path): + """Test SARIX model with t-distributed innovations.""" + model_config = SimpleNamespace( + model_class="sarix", + model_name="sarix_p6_4rt_thetashared_sigmanone_tdist", + + # data sources and adjustments for reporting issues + sources=["nhsn"], + + # fit locations separately or jointly + fit_locations_separately=False, + + # SARI model parameters + p=6, + P=0, + d=0, + D=0, + season_period=1, + + # power transform applied to surveillance signals + power_transform="4rt", + + # sharing of information about parameters + theta_pooling="shared", + sigma_pooling="none", + + # innovation distribution parameters + innovation_dist="t", + innovation_df_prior_scale=10.0, + + # covariates + x=[] + ) + + run_config = SimpleNamespace( + disease="flu", + ref_date=datetime.date.fromisoformat("2024-01-06"), + output_root=tmp_path / "model-output", + artifact_store_root=tmp_path / "artifact-store", + save_feat_importance=False, + locations=["US", "01", "02", "04", "05"], # Reduced for faster testing + max_horizon=2, # Reduced for faster testing + q_levels=[0.025, 0.50, 0.975], + q_labels=["0.025", "0.5", "0.975"], + num_warmup=100, # Reduced for faster testing + num_samples=100, + num_chains=1 + ) + + model = SARIXModel(model_config) + model.run(run_config) + + # Verify output structure + actual_df = pd.read_csv( + run_config.output_root / "UMass-sarix_p6_4rt_thetashared_sigmanone_tdist" / + "2024-01-06-UMass-sarix_p6_4rt_thetashared_sigmanone_tdist.csv" + ) + + # Assertions + assert len(actual_df) > 0, "Output dataframe should not be empty" + assert set(actual_df["location"].unique()) == set(run_config.locations), \ + "Output should contain predictions for all input locations" + assert all(actual_df["output_type"] == "quantile"), \ + "All outputs should be quantiles" + assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels), \ + "Output should contain all specified quantile levels" + assert actual_df["value"].notna().all(), \ + "All predictions should be non-null" + assert (actual_df["value"] >= 0).all(), \ + "All predictions should be non-negative" + + def _np_percentile_val(): return numpy.array( [[[2.22541624e-01, 1.82324940e-01, 1.27709944e-01],