24
24
import numpyro
25
25
import numpyro .distributions as dist
26
26
import polars as pl
27
+ import pyrenew .regression as r
27
28
import pyrenew .transformation as t
28
29
import toml
29
30
from jax .typing import ArrayLike
36
37
InitializeInfectionsFromVec ,
37
38
logistic_susceptibility_adjustment ,
38
39
)
39
- from pyrenew .metaclass import (
40
- DistributionalRV ,
41
- Model ,
42
- RandomVariable ,
43
- SampledValue ,
44
- TransformedRandomVariable ,
45
- )
40
+ from pyrenew .metaclass import Model , RandomVariable , SampledValue
46
41
from pyrenew .observation import NegativeBinomialObservation
47
- from pyrenew .process import SimpleRandomWalkProcess
48
- from pyrenew .regression import GLMPrediction
42
+ from pyrenew .process import RandomWalk
43
+ from pyrenew .randomvariable import DistributionalVariable , TransformedVariable
49
44
50
45
FONT_PATH = "texgyreschola-regular.otf"
51
46
if os .path .exists (FONT_PATH ):
@@ -705,12 +700,12 @@ def __init__(
705
700
): # numpydoc ignore=GL08
706
701
logging .info ("Initializing CFAEPIM_Observation" )
707
702
708
- CFAEPIM_Observation .validate (
709
- predictors ,
710
- alpha_prior_dist ,
711
- coefficient_priors ,
712
- nb_concentration_prior ,
713
- )
703
+ # CFAEPIM_Observation.validate(
704
+ # predictors,
705
+ # alpha_prior_dist,
706
+ # coefficient_priors,
707
+ # nb_concentration_prior,
708
+ # )
714
709
715
710
self .predictors = predictors
716
711
self .alpha_prior_dist = alpha_prior_dist
@@ -728,9 +723,8 @@ def _init_alpha_t(self):
728
723
transformation.
729
724
"""
730
725
logging .info ("Initializing alpha process" )
731
- self .alpha_process = GLMPrediction (
726
+ self .alpha_process = r . GLMPrediction (
732
727
name = "alpha_t" ,
733
- fixed_predictor_values = self .predictors ,
734
728
intercept_prior = self .alpha_prior_dist ,
735
729
coefficient_priors = self .coefficient_priors ,
736
730
transform = t .SigmoidTransform ().inv ,
@@ -745,9 +739,9 @@ def _init_negative_binomial(self):
745
739
logging .info ("Initializing negative binomial process" )
746
740
self .nb_observation = NegativeBinomialObservation (
747
741
name = "negbinom_rv" ,
748
- concentration_rv = DistributionalRV (
742
+ concentration_rv = DistributionalVariable (
749
743
name = "nb_concentration" ,
750
- dist = self .nb_concentration_prior ,
744
+ distribution = self .nb_concentration_prior ,
751
745
),
752
746
)
753
747
@@ -815,8 +809,9 @@ def sample(
815
809
ascertainment values and the expected
816
810
hospitalizations.
817
811
"""
818
- alpha_samples = self .alpha_process .sample ()["prediction" ]
819
- alpha_samples = alpha_samples [: infections .shape [0 ]]
812
+
813
+ alpha_samples = self .alpha_process .sample (self .predictors )
814
+ alpha_samples = alpha_samples [0 ].value [: infections .shape [0 ]]
820
815
expected_hosp = (
821
816
alpha_samples
822
817
* jnp .convolve (infections , inf_to_hosp_dist , mode = "full" )[
@@ -917,24 +912,24 @@ def sample(self, n_steps: int, **kwargs) -> tuple: # numpydoc ignore=GL08
917
912
"Wt_rw_sd" , dist .HalfNormal (self .gamma_RW_prior_scale )
918
913
)
919
914
# Rt random walk process
920
- wt_rv = SimpleRandomWalkProcess (
915
+ init_rv = DistributionalVariable (
916
+ name = "init_Wt_rv" ,
917
+ distribution = self .intercept_RW_prior ,
918
+ )
919
+ wt_rv = RandomWalk (
921
920
name = "Wt" ,
922
- step_rv = DistributionalRV (
921
+ step_rv = DistributionalVariable (
923
922
name = "rw_step_rv" ,
924
- dist = dist .Normal (0 , sd_wt ),
923
+ distribution = dist .Normal (0 , sd_wt ),
925
924
reparam = LocScaleReparam (0 ),
926
925
),
927
- init_rv = DistributionalRV (
928
- name = "init_Wt_rv" ,
929
- dist = self .intercept_RW_prior ,
930
- ),
931
926
)
932
927
# transform Rt random walk w/ scaled logit
933
- transformed_rt_samples = TransformedRandomVariable (
928
+ transformed_rt_samples = TransformedVariable (
934
929
name = "transformed_rt_rw" ,
935
930
base_rv = wt_rv ,
936
931
transforms = t .ScaledLogitTransform (x_max = self .max_rt ).inv ,
937
- ).sample (n_steps = n_steps , ** kwargs )
932
+ ).sample (n = n_steps , init_vals = init_rv ()[ 0 ]. value )
938
933
# broadcast the Rt samples to daily values
939
934
broadcasted_rt_samples = transformed_rt_samples [0 ].value [
940
935
self .week_indices
@@ -1027,11 +1022,11 @@ def __init__(
1027
1022
# infections: initial infections
1028
1023
self .I0 = InfectionInitializationProcess (
1029
1024
name = "I0_initialization" ,
1030
- I_pre_init_rv = DistributionalRV (
1025
+ I_pre_init_rv = DistributionalVariable (
1031
1026
name = "I0" ,
1032
- dist = dist .Exponential ( rate = 1 / self . mean_inf_val ). expand (
1033
- [ self .inf_model_seed_days ]
1034
- ),
1027
+ distribution = dist .Exponential (
1028
+ rate = 1 / self .mean_inf_val
1029
+ ). expand ([ self . inf_model_seed_days ]) ,
1035
1030
),
1036
1031
infection_init_method = InitializeInfectionsFromVec (
1037
1032
n_timepoints = self .inf_model_seed_days
0 commit comments