@@ -653,7 +653,10 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
653653 self .priors .nonfixed_node_ids (), self .ts .samples ()
654654 )
655655 for u in nonfixed_samples :
656- inside [u ][:] = self .priors [u ]
656+ # this is in the same probability space as the prior, so we should be
657+ # OK just to copy the prior values straight in (but we should check they
658+ # are normalised so that they sum to unity)
659+ inside [u ][:] = self .priors .sum_to_unity (self .priors [u ])
657660
658661 if cache_inside :
659662 g_i = np .full (
@@ -922,34 +925,31 @@ def posterior_mean_var(ts, timepoints, posterior, *, fixed_node_set=None):
922925 return ts , mn_post , vr_post
923926
924927
925- def constrain_ages_topo (ts , post_mn , eps , nodes_to_date = None , progress = False ):
928+ def constrain_ages_topo (ts , node_times , eps , progress = False ):
926929 """
927- If predicted node times violate topology, restrict node ages so that they
928- must be older than all their children.
930+ If node_times violate topology, return increased node_times so that each node is
931+ guaranteed to be older than any of its their children.
929932 """
930- new_mn_post = np .copy (post_mn )
931- if nodes_to_date is None :
932- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
933- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
934-
935933 tables = ts .tables
936- parents = tables .edges .parent
937- nd_children = tables .edges .child [np .argsort (parents )]
938- parents = sorted (parents )
939- parents_unique = np .unique (parents , return_index = True )
940- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
941- for index , nd in tqdm (
942- enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
934+ new_node_times = np .copy (node_times )
935+ # Traverse through the ARG, ensuring children come before parents.
936+ # This can be done by iterating over groups of edges with the same parent
937+ new_parent_edge_idx = np .concatenate (
938+ (
939+ [0 ],
940+ np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
941+ [tables .edges .num_rows ],
942+ )
943+ )
944+ for edges_start , edges_end in zip (
945+ new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
943946 ):
944- if index + 1 != len (nodes_to_date ):
945- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
946- else :
947- children_index = np .arange (parent_indices [index ], ts .num_edges )
948- children = nd_children [children_index ]
949- time = np .max (new_mn_post [children ])
950- if new_mn_post [nd ] <= time :
951- new_mn_post [nd ] = time + eps
952- return new_mn_post
947+ parent = tables .edges .parent [edges_start ]
948+ child_ids = tables .edges .child [edges_start :edges_end ] # May contain dups
949+ oldest_child_time = np .max (new_node_times [child_ids ])
950+ if oldest_child_time >= new_node_times [parent ]:
951+ new_node_times [parent ] = oldest_child_time + eps
952+ return new_node_times
953953
954954
955955def date (
@@ -1040,7 +1040,7 @@ def date(
10401040 progress = progress ,
10411041 ** kwargs
10421042 )
1043- constrained = constrain_ages_topo (tree_sequence , dates , eps , nds , progress )
1043+ constrained = constrain_ages_topo (tree_sequence , dates , eps , progress )
10441044 tables = tree_sequence .dump_tables ()
10451045 tables .time_units = time_units
10461046 tables .nodes .time = constrained
0 commit comments