@@ -1027,14 +1027,10 @@ def shape_scale_from_mean_var(mean, var):
10271027
10281028def _truncate_priors (ts , priors , progress = False ):
10291029 """
1030- Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1031- if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1032- sequence
1030+ Truncate priors for all nonfixed nodes
1031+ so they conform to the age of fixed nodes in the tree sequence
10331032 """
10341033 tables = ts .tables
1035- truncate_nodes = priors .nonfixed_node_ids ()
1036- # ensure truncate_nodes is ordered by node time
1037- truncate_nodes = truncate_nodes [np .argsort (tables .nodes .time [truncate_nodes ])]
10381034
10391035 fixed_nodes = priors .fixed_node_ids ()
10401036 fixed_times = tables .nodes .time [fixed_nodes ]
@@ -1050,29 +1046,29 @@ def _truncate_priors(ts, priors, progress=False):
10501046 constrained_min_times = np .zeros_like (tables .nodes .time )
10511047 # Set the min times of fixed nodes to those in the tree sequence
10521048 constrained_min_times [fixed_nodes ] = fixed_times
1053- constrained_max_times = np .full_like (constrained_min_times , np .inf )
1054-
1055- parents = tables .edges .parent
1056- nd_children = tables .edges .child [np .argsort (parents )]
1057- parents = sorted (parents )
1058- parents_unique = np .unique (parents , return_index = True )
1059- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], truncate_nodes )]
1060- for index , nd in tqdm (
1061- enumerate (truncate_nodes ), desc = "Constrain Ages" , disable = not progress
1049+
1050+ # Traverse through the ARG, ensuring children come before parents.
1051+ # This can be done by iterating over groups of edges with the same parent
1052+ new_parent_edge_idx = np .concatenate (
1053+ (
1054+ [0 ],
1055+ np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
1056+ [tables .edges .num_rows ],
1057+ )
1058+ )
1059+ for edges_start , edges_end in zip (
1060+ new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
10621061 ):
1063- if index + 1 != len (truncate_nodes ):
1064- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1065- else :
1066- children_index = np .arange (parent_indices [index ], ts .num_edges )
1067- children = nd_children [children_index ]
1068- time = np .max (constrained_min_times [children ])
1069- # The constrained time of the node should be the age of the oldest child
1070- if constrained_min_times [nd ] <= time :
1071- constrained_min_times [nd ] = time
1072- nearest_time = np .argmin (np .abs (timepoints - time ))
1073- lookup_index = priors .row_lookup [int (nd )]
1074- grid_data [lookup_index ][:nearest_time ] = zero_value
1075- assert np .all (constrained_min_times < constrained_max_times )
1062+ parent = tables .edges .parent [edges_start ]
1063+ child_ids = tables .edges .child [edges_start :edges_end ] # May contain dups
1064+ oldest_child_time = np .max (constrained_min_times [child_ids ])
1065+ if oldest_child_time > constrained_min_times [parent ]:
1066+ constrained_min_times [parent ] = oldest_child_time
1067+ if constrained_min_times [parent ] > 0 :
1068+ # What if the parent here is a fixed node?
1069+ nearest_time = np .argmin (np .abs (timepoints - constrained_min_times [parent ]))
1070+ lookup_index = priors .row_lookup [parent ]
1071+ grid_data [lookup_index ][:nearest_time ] = zero_value
10761072
10771073 rowmax = grid_data [:, 1 :].max (axis = 1 )
10781074 if priors .probability_space == "linear" :
0 commit comments