From 706ec0559059e8ef2e5898ea49be4e898787882b Mon Sep 17 00:00:00 2001 From: Nicholas Reich Date: Sun, 16 Nov 2025 23:10:44 -0500 Subject: [PATCH 1/2] Add support for t-distributed innovations in SARIXModel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extend _get_extra_sarix_params() to pass innovation_dist and innovation_df_prior_scale - Update SARIXFourierModel to properly merge base and Fourier parameters - Add test_sarix_tdist_innovations() integration test - Validates model runs successfully with t-distributed errors - Verifies output structure and prediction quality Depends on sarix library support for t-distributed innovations. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/idmodels/sarix.py | 20 +++++++-- tests/integration/test_sarix.py | 72 +++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index c682153..b0139f4 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -14,7 +14,15 @@ def __init__(self, model_config): def _get_extra_sarix_params(self, df): """Return extra parameters to pass to SARIX constructor. Returns empty dict by default.""" - return {} + extra_params = {} + + # Add innovation distribution parameters if specified + if hasattr(self.model_config, 'innovation_dist'): + extra_params['innovation_dist'] = self.model_config.innovation_dist + if hasattr(self.model_config, 'innovation_df_prior_scale'): + extra_params['innovation_df_prior_scale'] = self.model_config.innovation_df_prior_scale + + return extra_params def run(self, run_config): fdl = DiseaseDataLoader() @@ -118,15 +126,21 @@ class SARIXFourierModel(SARIXModel): """ def _get_extra_sarix_params(self, df): """Return Fourier-specific parameters for SARIX constructor.""" + # Get base parameters (includes innovation_dist if specified) + extra_params = super()._get_extra_sarix_params(df) + # Extract day-of-year from dates for Fourier features # Take the first location's dates (same for all locations after reshaping) day_of_year = df.groupby("location")["wk_end_date"].apply(lambda x: x.dt.dayofyear.values).iloc[0] - return { + # Add Fourier-specific parameters + extra_params.update({ "day_of_year": day_of_year, "fourier_K": self.model_config.fourier_K, "fourier_pooling": self.model_config.fourier_pooling - } + }) + + return extra_params def _np_percentile(predictions, q_levels, axis): diff --git a/tests/integration/test_sarix.py b/tests/integration/test_sarix.py index 7619443..9449bbc 100644 --- a/tests/integration/test_sarix.py +++ b/tests/integration/test_sarix.py @@ -333,6 +333,78 @@ def test_sarix_fourier_missing_pooling_parameter(): f"Error should mention fourier_pooling, got: {str(e)}" +def test_sarix_tdist_innovations(tmp_path): + """Test SARIX model with t-distributed innovations.""" + model_config = SimpleNamespace( + model_class="sarix", + model_name="sarix_p6_4rt_thetashared_sigmanone_tdist", + + # data sources and adjustments for reporting issues + sources=["nhsn"], + + # fit locations separately or jointly + fit_locations_separately=False, + + # SARI model parameters + p=6, + P=0, + d=0, + D=0, + season_period=1, + + # power transform applied to surveillance signals + power_transform="4rt", + + # sharing of information about parameters + theta_pooling="shared", + sigma_pooling="none", + + # innovation distribution parameters + innovation_dist="t", + innovation_df_prior_scale=10.0, + + # covariates + x=[] + ) + + run_config = SimpleNamespace( + disease="flu", + ref_date=datetime.date.fromisoformat("2024-01-06"), + output_root=tmp_path / "model-output", + artifact_store_root=tmp_path / "artifact-store", + save_feat_importance=False, + locations=["US", "01", "02", "04", "05"], # Reduced for faster testing + max_horizon=2, # Reduced for faster testing + q_levels=[0.025, 0.50, 0.975], + q_labels=["0.025", "0.5", "0.975"], + num_warmup=100, # Reduced for faster testing + num_samples=100, + num_chains=1 + ) + + model = SARIXModel(model_config) + model.run(run_config) + + # Verify output structure + actual_df = pd.read_csv( + run_config.output_root / "UMass-sarix_p6_4rt_thetashared_sigmanone_tdist" / + "2024-01-06-UMass-sarix_p6_4rt_thetashared_sigmanone_tdist.csv" + ) + + # Assertions + assert len(actual_df) > 0, "Output dataframe should not be empty" + assert set(actual_df["location"].unique()) == set(run_config.locations), \ + "Output should contain predictions for all input locations" + assert all(actual_df["output_type"] == "quantile"), \ + "All outputs should be quantiles" + assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels), \ + "Output should contain all specified quantile levels" + assert actual_df["value"].notna().all(), \ + "All predictions should be non-null" + assert (actual_df["value"] >= 0).all(), \ + "All predictions should be non-negative" + + def _np_percentile_val(): return numpy.array( [[[2.22541624e-01, 1.82324940e-01, 1.27709944e-01], From 8aaa5d5015a06e36cce2829991694f377b3f70f2 Mon Sep 17 00:00:00 2001 From: Nicholas Reich Date: Sun, 16 Nov 2025 23:32:20 -0500 Subject: [PATCH 2/2] Fix linting errors: Replace single quotes with double quotes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace single quotes with double quotes in sarix.py to comply with ruff Q000 rule. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/idmodels/sarix.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index b0139f4..e9f5b38 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -17,10 +17,10 @@ def _get_extra_sarix_params(self, df): extra_params = {} # Add innovation distribution parameters if specified - if hasattr(self.model_config, 'innovation_dist'): - extra_params['innovation_dist'] = self.model_config.innovation_dist - if hasattr(self.model_config, 'innovation_df_prior_scale'): - extra_params['innovation_df_prior_scale'] = self.model_config.innovation_df_prior_scale + if hasattr(self.model_config, "innovation_dist"): + extra_params["innovation_dist"] = self.model_config.innovation_dist + if hasattr(self.model_config, "innovation_df_prior_scale"): + extra_params["innovation_df_prior_scale"] = self.model_config.innovation_df_prior_scale return extra_params