Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/idmodels/sarix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ def __init__(self, model_config):

def _get_extra_sarix_params(self, df):
"""Return extra parameters to pass to SARIX constructor. Returns empty dict by default."""
return {}
extra_params = {}

# Add innovation distribution parameters if specified
if hasattr(self.model_config, "innovation_dist"):
extra_params["innovation_dist"] = self.model_config.innovation_dist
if hasattr(self.model_config, "innovation_df_prior_scale"):
extra_params["innovation_df_prior_scale"] = self.model_config.innovation_df_prior_scale

return extra_params

def run(self, run_config):
fdl = DiseaseDataLoader()
Expand Down Expand Up @@ -118,15 +126,21 @@ class SARIXFourierModel(SARIXModel):
"""
def _get_extra_sarix_params(self, df):
    """Return Fourier-specific parameters for the SARIX constructor.

    Parameters
    ----------
    df : DataFrame
        Model input data with "location" and "wk_end_date" columns;
        wk_end_date is assumed to be a datetime dtype.

    Returns
    -------
    dict
        Base-class extra parameters (e.g. innovation-distribution
        settings, if configured) plus day_of_year, fourier_K, and
        fourier_pooling entries.
    """
    # Get base parameters (includes innovation_dist if specified).
    extra_params = super()._get_extra_sarix_params(df)

    # Extract day-of-year from dates for Fourier features.
    # Take the first location's dates (same for all locations after
    # reshaping), so the result is one value per time point.
    day_of_year = df.groupby("location")["wk_end_date"].apply(lambda x: x.dt.dayofyear.values).iloc[0]

    # Add Fourier-specific parameters on top of the base ones.
    extra_params.update({
        "day_of_year": day_of_year,
        "fourier_K": self.model_config.fourier_K,
        "fourier_pooling": self.model_config.fourier_pooling
    })

    return extra_params


def _np_percentile(predictions, q_levels, axis):
Expand Down
72 changes: 72 additions & 0 deletions tests/integration/test_sarix.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,78 @@ def test_sarix_fourier_missing_pooling_parameter():
f"Error should mention fourier_pooling, got: {str(e)}"


def test_sarix_tdist_innovations(tmp_path):
    """Test SARIX model with t-distributed innovations."""
    # Model configuration: order-6 AR, shared theta, no sigma pooling,
    # fourth-root transform, and Student-t innovations.
    config_fields = {
        "model_class": "sarix",
        "model_name": "sarix_p6_4rt_thetashared_sigmanone_tdist",

        # data sources and adjustments for reporting issues
        "sources": ["nhsn"],

        # fit locations separately or jointly
        "fit_locations_separately": False,

        # SARI model parameters
        "p": 6,
        "P": 0,
        "d": 0,
        "D": 0,
        "season_period": 1,

        # power transform applied to surveillance signals
        "power_transform": "4rt",

        # sharing of information about parameters
        "theta_pooling": "shared",
        "sigma_pooling": "none",

        # innovation distribution parameters
        "innovation_dist": "t",
        "innovation_df_prior_scale": 10.0,

        # covariates
        "x": [],
    }
    model_config = SimpleNamespace(**config_fields)

    # Run settings are trimmed (locations, horizon, MCMC sizes) to keep
    # the integration test fast.
    run_config = SimpleNamespace(
        disease="flu",
        ref_date=datetime.date.fromisoformat("2024-01-06"),
        output_root=tmp_path / "model-output",
        artifact_store_root=tmp_path / "artifact-store",
        save_feat_importance=False,
        locations=["US", "01", "02", "04", "05"],  # Reduced for faster testing
        max_horizon=2,  # Reduced for faster testing
        q_levels=[0.025, 0.50, 0.975],
        q_labels=["0.025", "0.5", "0.975"],
        num_warmup=100,  # Reduced for faster testing
        num_samples=100,
        num_chains=1,
    )

    SARIXModel(model_config).run(run_config)

    # Verify output structure by reading back the forecast file.
    actual_df = pd.read_csv(
        run_config.output_root / "UMass-sarix_p6_4rt_thetashared_sigmanone_tdist" /
        "2024-01-06-UMass-sarix_p6_4rt_thetashared_sigmanone_tdist.csv"
    )

    # Assertions
    assert len(actual_df) > 0, "Output dataframe should not be empty"
    observed_locations = set(actual_df["location"].unique())
    assert observed_locations == set(run_config.locations), \
        "Output should contain predictions for all input locations"
    assert all(actual_df["output_type"] == "quantile"), \
        "All outputs should be quantiles"
    observed_q_labels = set(actual_df["output_type_id"].astype(str).unique())
    assert observed_q_labels == set(run_config.q_labels), \
        "Output should contain all specified quantile levels"
    values = actual_df["value"]
    assert values.notna().all(), \
        "All predictions should be non-null"
    assert (values >= 0).all(), \
        "All predictions should be non-negative"


def _np_percentile_val():
return numpy.array(
[[[2.22541624e-01, 1.82324940e-01, 1.27709944e-01],
Expand Down
Loading