@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
151151 """
152152 ll = scipy .stats .poisson .pmf (muts , dt * mutation_rate * span )
153153 if normalize :
154- return ll / np .max (ll )
154+ return ll / np .nanmax (ll )
155155 else :
156156 return ll
157157
@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):
258258
259259 mutations_on_edge = self .mut_edges [edge .id ]
260260 child_time = self .ts .node (edge .child ).time
261- #assert child_time == 0
262- # Temporary hack - we should really take a more precise likelihood
263- return self ._lik (
264- mutations_on_edge ,
265- edge .span ,
266- self .timediff ,
267- self .mut_rate ,
268- normalize = self .normalize ,
269- )
261+ if child_time == 0 :
262+ return self ._lik (
263+ mutations_on_edge ,
264+ edge .span ,
265+ self .timediff ,
266+ self .mut_rate ,
267+ normalize = self .normalize ,
268+ )
269+ else :
270+ timediff = self .timepoints - child_time + 1e-8
271+ # Temporary hack - we should really take a more precise likelihood
272+ likelihood = self ._lik (
273+ mutations_on_edge ,
274+ edge .span ,
275+ timediff ,
276+ self .mut_rate ,
277+ normalize = self .normalize ,
278+ )
279+ # Prevent child from being older than parent
280+ likelihood [timediff < 0 ] = 0
281+
282+ return likelihood
270283
271284 def get_mut_lik_lower_tri (self , edge ):
272285 """
@@ -389,7 +402,7 @@ def get_fixed(self, arr, edge):
389402 return arr * liks
390403
391404 def scale_geometric (self , fraction , value ):
392- return value ** fraction
405+ return value ** fraction
393406
394407
395408class LogLikelihoods (Likelihoods ):
@@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
429442 """
430443 ll = scipy .stats .poisson .logpmf (muts , dt * mutation_rate * span )
431444 if normalize :
432- return ll - np .max (ll )
445+ return ll - np .nanmax (ll )
433446 else :
434447 return ll
435448
@@ -634,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
634647 inside = self .priors .clone_with_new_data ( # store inside matrix values
635648 grid_data = np .nan , fixed_data = self .lik .identity_constant
636649 )
650+ # It is possible that a simple node is non-fixed, in which case we want to
651+ # provide an inside array that reflects the prior distribution
652+ nonfixed_samples = np .intersect1d (inside .nonfixed_node_ids (), self .ts .samples ())
653+ for u in nonfixed_samples :
654+ # this is in the same probability space as the prior, so we should be
655+ # OK just to copy the prior values straight in. It's unclear to me (Yan)
656+ # how/if they should be normalised, however
657+ inside [u ][:] = self .priors [u ]
658+
637659 if cache_inside :
638660 g_i = np .full (
639661 (self .ts .num_edges , self .lik .grid_size ), self .lik .identity_constant
640662 )
641663 norm = np .full (self .ts .num_nodes , np .nan )
664+ to_visit = np .zeros (self .ts .num_nodes , dtype = bool )
665+ to_visit [inside .nonfixed_node_ids ()] = True
642666 # Iterate through the nodes via groupby on parent node
643667 for parent , edges in tqdm (
644668 self .edges_by_parent_asc (),
@@ -673,14 +697,23 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
673697 "dangling nodes: please simplify it"
674698 )
675699 daughter_val = self .lik .scale_geometric (
676- spanfrac , self .lik .make_lower_tri (inside [ edge . child ] )
700+ spanfrac , self .lik .make_lower_tri (inside_values )
677701 )
678702 edge_lik = self .lik .get_inside (daughter_val , edge )
679703 val = self .lik .combine (val , edge_lik )
704+ if np .all (val == 0 ):
705+ raise ValueError
680706 if cache_inside :
681707 g_i [edge .id ] = edge_lik
682- norm [parent ] = np .max (val ) if normalize else 1
708+ norm [parent ] = np .max (val ) if normalize else self . lik . identity_constant
683709 inside [parent ] = self .lik .reduce (val , norm [parent ])
710+ to_visit [parent ] = False
711+
712+ # There may be nodes that are not parents but are also not fixed (e.g.
713+ # undated sample nodes). These need an identity normalization constant
714+ for unfixed_unvisited in np .where (to_visit )[0 ]:
715+ norm [unfixed_unvisited ] = self .lik .identity_constant
716+
684717 if cache_inside :
685718 self .g_i = self .lik .reduce (g_i , norm [self .ts .tables .edges .child , None ])
686719 # Keep the results in this object
@@ -732,10 +765,10 @@ def outside_pass(
732765 if ignore_oldest_root :
733766 if edge .parent == self .ts .num_nodes - 1 :
734767 continue
735- # if edge.parent in self.fixednodes:
736- # raise RuntimeError(
737- # "Fixed nodes cannot currently be parents in the TS"
738- # )
768+ if edge .parent in self .fixednodes :
769+ raise RuntimeError (
770+ "Fixed nodes cannot currently be parents in the TS"
771+ )
739772 # Geometric scaling works exactly for all nodes fixed in graph
740773 # but is an approximation when times are unknown.
741774 spanfrac = edge .span / self .spans [child ]
@@ -897,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
897930 return ts , mn_post , vr_post
898931
899932
900- def constrain_ages_topo (ts , post_mn , eps , nodes_to_date = None , progress = False ):
933+ def constrain_ages_topo (ts , node_times , eps , progress = False ):
901934 """
902- If predicted node times violate topology, restrict node ages so that they
903- must be older than all their children.
935+ If node_times violate topology, return increased node_times so that each node is
936+ guaranteed to be older than any of its their children.
904937 """
905- new_mn_post = np .copy (post_mn )
906- if nodes_to_date is None :
907- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
908- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
909-
910- tables = ts .tables
911- parents = tables .edges .parent
912- nd_children = tables .edges .child [np .argsort (parents )]
913- parents = sorted (parents )
914- parents_unique = np .unique (parents , return_index = True )
915- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
916- for index , nd in tqdm (
917- enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
938+ edges_parent = ts .edges_parent
939+ edges_child = ts .edges_child
940+
941+ new_node_times = np .copy (node_times )
942+ # Traverse through the ARG, ensuring children come before parents.
943+ # This can be done by iterating over groups of edges with the same parent
944+ new_parent_edge_idx = np .where (np .diff (edges_parent ) != 0 )[0 ] + 1
945+ for edges_start , edges_end in tqdm (
946+ zip (
947+ itertools .chain ([0 ], new_parent_edge_idx ),
948+ itertools .chain (new_parent_edge_idx , [len (edges_parent )]),
949+ ),
950+ desc = "Constrain Ages" ,
951+ disable = not progress ,
918952 ):
919- if index + 1 != len (nodes_to_date ):
920- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
921- else :
922- children_index = np .arange (parent_indices [index ], ts .num_edges )
923- children = nd_children [children_index ]
924- time = np .max (new_mn_post [children ])
925- if new_mn_post [nd ] <= time :
926- new_mn_post [nd ] = time + eps
927- return new_mn_post
953+ parent = edges_parent [edges_start ]
954+ child_ids = edges_child [edges_start :edges_end ] # May contain dups
955+ oldest_child_time = np .max (new_node_times [child_ids ])
956+ if oldest_child_time >= new_node_times [parent ]:
957+ new_node_times [parent ] = oldest_child_time + eps
958+ return new_node_times
928959
929960
930961def date (
@@ -1015,7 +1046,7 @@ def date(
10151046 progress = progress ,
10161047 ** kwargs
10171048 )
1018- constrained = constrain_ages_topo (tree_sequence , dates , eps , nds , progress )
1049+ constrained = constrain_ages_topo (tree_sequence , dates , eps , progress )
10191050 tables = tree_sequence .dump_tables ()
10201051 tables .time_units = time_units
10211052 tables .nodes .time = constrained
@@ -1064,12 +1095,6 @@ def get_dates(
10641095
10651096 :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
10661097 """
1067- # Stuff yet to be implemented. These can be deleted once fixed
1068- #for sample in tree_sequence.samples():
1069- # if tree_sequence.node(sample).time != 0:
1070- # raise NotImplementedError("Samples must all be at time 0")
1071- fixed_nodes = set (tree_sequence .samples ())
1072-
10731098 # Default to not creating approximate priors unless ts has > 1000 samples
10741099 approx_priors = False
10751100 if tree_sequence .num_samples > 1000 :
@@ -1097,6 +1122,8 @@ def get_dates(
10971122 )
10981123 priors = priors
10991124
1125+ fixed_nodes = set (priors .fixed_node_ids ())
1126+
11001127 if probability_space != base .LOG :
11011128 liklhd = Likelihoods (
11021129 tree_sequence ,
0 commit comments