@@ -541,48 +541,79 @@ class InterventionTimeEstimator(PyMCModel):
541541 ... t,
542542 ... y,
543543 ... coords,
544- ... effect=[ "impulse"]
544+ ... priors={ "impulse":[]}
545545 ... )
546546 Inference data...
547547 """
548548
549- def build_model (self , t , y , coords , effect , span , grain_season ):
549+ def build_model (self , t , y , coords , time_range , grain_season , priors ):
550550 """
551551 Defines the PyMC model
552552
553553 :param t: An array of values representing the time over which y is spread
554554 :param y: An array of values representing our outcome y
555- :param coords: A dictionary with the coordinate names for our instruments
555+ :param coords: An optional dictionary with the coordinate names for our instruments.
556+ In particular, used to determine the number of seasons.
557+ :param time_range: An optional tuple providing a specific time_range where the
558+ intervention effect should have taken place.
559+ :param priors: An optional dictionary of priors for the parameters of the
560+ different distributions.
561+ :code:`priors = {"alpha":[0, 5], "beta":[0,2], "level":[5, 5], "impulse":[1, 2 ,3]}`
556562 """
557563
558564 with self :
559565 self .add_coords (coords )
560566
561- if span is None :
562- span = (t .min (), t .max ())
567+ if time_range is None :
568+ time_range = (t .min (), t .max ())
563569
564570 # --- Priors ---
565- switchpoint = pm .Uniform ("switchpoint" , lower = span [0 ], upper = span [1 ])
566- alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 10 )
567- beta = pm .Normal (name = "beta" , mu = 0 , sigma = 10 )
571+ switchpoint = pm .Uniform (
572+ "switchpoint" , lower = time_range [0 ], upper = time_range [1 ]
573+ )
574+ alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 50 )
575+ beta = pm .Normal (name = "beta" , mu = 0 , sigma = 50 )
568576 seasons = 0
569577 if "seasons" in coords and len (coords ["seasons" ]) > 0 :
570578 season_idx = np .arange (len (y )) // grain_season % len (coords ["seasons" ])
571- season_effect = pm .Normal ("season" , mu = 0 , sigma = 1 , dims = "seasons" )
572- seasons = season_effect [season_idx ]
579+ seasons_effect = pm .Normal (
580+ "seasons_effect" , mu = 0 , sigma = 50 , dims = "seasons"
581+ )
582+ seasons = seasons_effect [season_idx ]
573583
574584 # --- Intervention effect ---
575585 level = trend = impulse = 0
576586
577- if "level" in effect :
578- level = pm .Normal ("level" , mu = 0 , sigma = 10 )
579-
580- if "trend" in effect :
581- trend = pm .Normal ("trend" , mu = 0 , sigma = 10 )
582-
583- if "impulse" in effect :
584- impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = 0 , sigma = 1 )
585- decay_rate = pm .HalfNormal ("decay_rate" , sigma = 1 )
587+ if "level" in priors :
588+ mu , sigma = (
589+ (0 , 50 )
590+ if len (priors ["level" ]) != 2
591+ else (priors ["level" ][0 ], priors ["level" ][1 ])
592+ )
593+ level = pm .Normal (
594+ "level" ,
595+ mu = mu ,
596+ sigma = sigma ,
597+ )
598+ if "trend" in priors :
599+ mu , sigma = (
600+ (0 , 50 )
601+ if len (priors ["trend" ]) != 2
602+ else (priors ["trend" ][0 ], priors ["trend" ][1 ])
603+ )
604+ trend = pm .Normal ("trend" , mu = mu , sigma = sigma )
605+ if "impulse" in priors :
606+ mu , sigma1 , sigma2 = (
607+ (0 , 50 , 50 )
608+ if len (priors ["impulse" ]) != 3
609+ else (
610+ priors ["impulse" ][0 ],
611+ priors ["impulse" ][1 ],
612+ priors ["impulse" ][2 ],
613+ )
614+ )
615+ impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = mu , sigma = sigma1 )
616+ decay_rate = pm .HalfNormal ("decay_rate" , sigma = sigma2 )
586617 impulse = impulse_amplitude * pm .math .exp (
587618 - decay_rate * abs (t - switchpoint )
588619 )
@@ -597,16 +628,16 @@ def build_model(self, t, y, coords, effect, span, grain_season):
597628 )
598629 # Compute and store the the sum of the intervention and the time series
599630 mu = pm .Deterministic ("mu" , mu_ts + weight * mu_in )
631+ sigma = pm .HalfNormal ("sigma" , 1 )
600632
601633 # --- Likelihood ---
602- pm .Normal ("y_hat" , mu = mu , sigma = 2 , observed = y )
634+ pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = y )
603635
604- def fit (self , t , y , coords , effect = [], span = None , grain_season = 1 , n = 1000 ):
636+ def fit (self , t , y , coords , time_range = None , grain_season = 1 , priors = {} , n = 1000 ):
605637 """
606638 Draw samples from posterior distribution
607639 """
608- self .sample_kwargs ["progressbar" ] = False
609- self .build_model (t , y , coords , effect , span , grain_season )
640+ self .build_model (t , y , coords , time_range , grain_season , priors )
610641 with self :
611- self .idata = pm .sample (n , ** self .sample_kwargs )
642+ self .idata = pm .sample (n , progressbar = False , ** self .sample_kwargs )
612643 return self .idata
0 commit comments