Skip to content

Commit 0e883d6

Browse files
committed
package updates; AR process and RV PR taken into account
1 parent ece3f47 commit 0e883d6

File tree

2 files changed

+30
-34
lines changed

2 files changed

+30
-34
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ polars = "^1.5.0"
1414
numpyro = "^0.15.2"
1515
arviz = "^0.19.0"
1616
pyrenew = {git = "https://github.com/CDCgov/PyRenew", rev = "main"}
17+
toml = "^0.10.2"
1718

1819
[tool.poetry.group.dev.dependencies]
1920
rpy2 = "^3.5.16"

pyrenew_flu_light/tut_epim_port_msr.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpyro
2525
import numpyro.distributions as dist
2626
import polars as pl
27+
import pyrenew.regression as r
2728
import pyrenew.transformation as t
2829
import toml
2930
from jax.typing import ArrayLike
@@ -36,16 +37,10 @@
3637
InitializeInfectionsFromVec,
3738
logistic_susceptibility_adjustment,
3839
)
39-
from pyrenew.metaclass import (
40-
DistributionalRV,
41-
Model,
42-
RandomVariable,
43-
SampledValue,
44-
TransformedRandomVariable,
45-
)
40+
from pyrenew.metaclass import Model, RandomVariable, SampledValue
4641
from pyrenew.observation import NegativeBinomialObservation
47-
from pyrenew.process import SimpleRandomWalkProcess
48-
from pyrenew.regression import GLMPrediction
42+
from pyrenew.process import RandomWalk
43+
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
4944

5045
FONT_PATH = "texgyreschola-regular.otf"
5146
if os.path.exists(FONT_PATH):
@@ -705,12 +700,12 @@ def __init__(
705700
): # numpydoc ignore=GL08
706701
logging.info("Initializing CFAEPIM_Observation")
707702

708-
CFAEPIM_Observation.validate(
709-
predictors,
710-
alpha_prior_dist,
711-
coefficient_priors,
712-
nb_concentration_prior,
713-
)
703+
# CFAEPIM_Observation.validate(
704+
# predictors,
705+
# alpha_prior_dist,
706+
# coefficient_priors,
707+
# nb_concentration_prior,
708+
# )
714709

715710
self.predictors = predictors
716711
self.alpha_prior_dist = alpha_prior_dist
@@ -728,9 +723,8 @@ def _init_alpha_t(self):
728723
transformation.
729724
"""
730725
logging.info("Initializing alpha process")
731-
self.alpha_process = GLMPrediction(
726+
self.alpha_process = r.GLMPrediction(
732727
name="alpha_t",
733-
fixed_predictor_values=self.predictors,
734728
intercept_prior=self.alpha_prior_dist,
735729
coefficient_priors=self.coefficient_priors,
736730
transform=t.SigmoidTransform().inv,
@@ -745,9 +739,9 @@ def _init_negative_binomial(self):
745739
logging.info("Initializing negative binomial process")
746740
self.nb_observation = NegativeBinomialObservation(
747741
name="negbinom_rv",
748-
concentration_rv=DistributionalRV(
742+
concentration_rv=DistributionalVariable(
749743
name="nb_concentration",
750-
dist=self.nb_concentration_prior,
744+
distribution=self.nb_concentration_prior,
751745
),
752746
)
753747

@@ -815,8 +809,9 @@ def sample(
815809
ascertainment values and the expected
816810
hospitalizations.
817811
"""
818-
alpha_samples = self.alpha_process.sample()["prediction"]
819-
alpha_samples = alpha_samples[: infections.shape[0]]
812+
813+
alpha_samples = self.alpha_process.sample(self.predictors)
814+
alpha_samples = alpha_samples[0].value[: infections.shape[0]]
820815
expected_hosp = (
821816
alpha_samples
822817
* jnp.convolve(infections, inf_to_hosp_dist, mode="full")[
@@ -917,24 +912,24 @@ def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08
917912
"Wt_rw_sd", dist.HalfNormal(self.gamma_RW_prior_scale)
918913
)
919914
# Rt random walk process
920-
wt_rv = SimpleRandomWalkProcess(
915+
init_rv = DistributionalVariable(
916+
name="init_Wt_rv",
917+
distribution=self.intercept_RW_prior,
918+
)
919+
wt_rv = RandomWalk(
921920
name="Wt",
922-
step_rv=DistributionalRV(
921+
step_rv=DistributionalVariable(
923922
name="rw_step_rv",
924-
dist=dist.Normal(0, sd_wt),
923+
distribution=dist.Normal(0, sd_wt),
925924
reparam=LocScaleReparam(0),
926925
),
927-
init_rv=DistributionalRV(
928-
name="init_Wt_rv",
929-
dist=self.intercept_RW_prior,
930-
),
931926
)
932927
# transform Rt random walk w/ scaled logit
933-
transformed_rt_samples = TransformedRandomVariable(
928+
transformed_rt_samples = TransformedVariable(
934929
name="transformed_rt_rw",
935930
base_rv=wt_rv,
936931
transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv,
937-
).sample(n_steps=n_steps, **kwargs)
932+
).sample(n=n_steps, init_vals=init_rv()[0].value)
938933
# broadcast the Rt samples to daily values
939934
broadcasted_rt_samples = transformed_rt_samples[0].value[
940935
self.week_indices
@@ -1027,11 +1022,11 @@ def __init__(
10271022
# infections: initial infections
10281023
self.I0 = InfectionInitializationProcess(
10291024
name="I0_initialization",
1030-
I_pre_init_rv=DistributionalRV(
1025+
I_pre_init_rv=DistributionalVariable(
10311026
name="I0",
1032-
dist=dist.Exponential(rate=1 / self.mean_inf_val).expand(
1033-
[self.inf_model_seed_days]
1034-
),
1027+
distribution=dist.Exponential(
1028+
rate=1 / self.mean_inf_val
1029+
).expand([self.inf_model_seed_days]),
10351030
),
10361031
infection_init_method=InitializeInfectionsFromVec(
10371032
n_timepoints=self.inf_model_seed_days

0 commit comments

Comments
 (0)