@@ -947,29 +947,55 @@ def gamma_cdf(t_set, alpha, beta):
947947 return np .insert (t_set , 0 , 0 )
948948
949949
950- def fill_priors (node_parameters , timepoints , ts , Ne , * , prior_distr , progress = False ):
950+ def fill_priors (
951+ node_parameters ,
952+ timepoints ,
953+ ts ,
954+ Ne ,
955+ * ,
956+ prior_distr ,
957+ node_var_override = None ,
958+ progress = False ,
959+ ):
951960 """
952961 Take the alpha and beta values from the node_parameters array, which contains
953- one row for each node in the TS (including fixed nodes)
954- and fill out a NodeGridValues object with the prior values from the
955- gamma or lognormal distribution with those parameters.
962+ one row for each node in the TS (including fixed nodes, although alpha and beta
963+ are ignored for these nodes) and fill out a NodeGridValues object with the prior
964+ values from the gamma or lognormal distribution with those parameters.
965+
966+ For a description of `node_var_override`, see the parameter description in
967+ the `build_grid` function.
956968
957969 TODO - what if there is an internal fixed node? Should we truncate
958970 """
959971 if prior_distr == "lognorm" :
960972 cdf_func = scipy .stats .lognorm .cdf
961- main_param = np .sqrt (node_parameters [:, PriorParams .field_index ("beta" )])
973+ shape_param = np .sqrt (node_parameters [:, PriorParams .field_index ("beta" )])
962974 scale_param = np .exp (node_parameters [:, PriorParams .field_index ("alpha" )])
975+
976+ def shape_scale_from_mean_var (mean , var ):
977+ a , b = lognorm_approx (mean , var )
978+ return np .sqrt (b ), np .exp (a )
979+
963980 elif prior_distr == "gamma" :
964981 cdf_func = scipy .stats .gamma .cdf
965- main_param = node_parameters [:, PriorParams .field_index ("alpha" )]
966- scale_param = 1 / node_parameters [:, PriorParams .field_index ("beta" )]
982+ shape_param = node_parameters [:, PriorParams .field_index ("alpha" )]
983+ scale_param = 1.0 / node_parameters [:, PriorParams .field_index ("beta" )]
984+
985+ def shape_scale_from_mean_var (mean , var ):
986+ a , b = gamma_approx (mean , var )
987+ return a , 1.0 / b
988+
967989 else :
968990 raise ValueError ("prior distribution must be lognorm or gamma" )
969-
991+ if node_var_override is None :
992+ node_var_override = {}
970993 datable_nodes = np .ones (ts .num_nodes , dtype = bool )
971994 datable_nodes [ts .samples ()] = False
995+ # Mark all nodes in node_var_override as datable
996+ datable_nodes [list (node_var_override .keys ())] = True
972997 datable_nodes = np .where (datable_nodes )[0 ]
998+
973999 prior_times = base .NodeGridValues (
9741000 ts .num_nodes ,
9751001 datable_nodes [np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 ),
@@ -980,8 +1006,16 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
9801006 for node in tqdm (
9811007 datable_nodes , desc = "Assign Prior to Each Node" , disable = not progress
9821008 ):
1009+ if node in node_var_override :
1010+ shape , scale = shape_scale_from_mean_var (
1011+ mean = ts .node (node ).time ,
1012+ var = node_var_override [node ],
1013+ )
1014+ else :
1015+ shape = shape_param [node ]
1016+ scale = scale_param [node ]
9831017 with np .errstate (divide = "ignore" , invalid = "ignore" ):
984- prior_node = cdf_func (timepoints , main_param [ node ] , scale = scale_param [ node ] )
1018+ prior_node = cdf_func (timepoints , shape , scale = scale )
9851019 # force age to be less than max value
9861020 prior_node = np .divide (prior_node , np .max (prior_node ))
9871021 # prior in each epoch
@@ -994,7 +1028,7 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
9941028def _truncate_priors (ts , priors , progress = False ):
9951029 """
9961030 Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
997- if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1031+ if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
9981032 sequence
9991033 """
10001034 tables = ts .tables
@@ -1060,6 +1094,7 @@ def build_grid(
10601094 prior_distribution = "lognorm" ,
10611095 allow_historical_samples = None ,
10621096 truncate_priors = None ,
1097+ node_var_override = None ,
10631098 eps = 1e-6 ,
10641099 progress = False ,
10651100):
@@ -1094,6 +1129,13 @@ def build_grid(
10941129 priors of their direct ancestor nodes so that the probability of being younger
10951130 than the oldest descendant sample is zero. If the tree sequence is trustworthy
10961131 this should give better restults. Default: `True`
1132+ :param dict node_var_override: is a dict mapping node IDs to a variance value.
1133+ Any nodes listed here will be treated as non-fixed nodes whose prior is not
1134+ calculated from the conditional coalescent but instead are allocated a prior
1135+ whose mean is thenode time in the tree sequence and whose variance is the
1136+ value in this dictionary. This allows sample nodes to be treated as nonfixed
1137+ nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1138+ treated as occurring ata fixed time (as if this were an empty dict).
10971139 :param float eps: Specify minimum distance separating points in the time grid. Also
10981140 specifies the error factor in time difference calculations. Default: 1e-6
10991141 :return: A prior object to pass to tsdate.date() containing prior values for
@@ -1154,16 +1196,18 @@ def build_grid(
11541196 tree_sequence ,
11551197 Ne ,
11561198 prior_distr = prior_distribution ,
1199+ node_var_override = node_var_override ,
11571200 progress = progress ,
11581201 )
1159- if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1202+ tables = tree_sequence .tables
1203+ if np .any (tables .nodes .time [tree_sequence .samples ()] > 0 ):
11601204 if not allow_historical_samples :
11611205 raise ValueError (
11621206 "There are samples at non-zero times, invalidating the conditional "
11631207 "coalescent prior. You can set allow_historical_samples=True to carry "
11641208 "on regardless, calculating a prior as if all samples were "
11651209 "contemporaneous (reasonable if you only have a few ancient samples)"
11661210 )
1167- if truncate_priors :
1211+ if np . any ( tables . nodes . time [ priors . fixed_node_ids ()] > 0 ) and truncate_priors :
11681212 priors = _truncate_priors (tree_sequence , priors , progress = progress )
11691213 return priors
0 commit comments