2222"""
2323Routines and classes for creating priors and timeslices for use in tsdate
2424"""
25+ import itertools
2526import logging
2627import os
2728from collections import defaultdict
@@ -1030,10 +1031,8 @@ def _truncate_priors(ts, priors, progress=False):
10301031 Truncate priors for all nonfixed nodes
10311032 so they conform to the age of fixed nodes in the tree sequence
10321033 """
1033- tables = ts .tables
1034-
10351034 fixed_nodes = priors .fixed_node_ids ()
1036- fixed_times = tables . nodes . time [fixed_nodes ]
1035+ fixed_times = ts . nodes_time [fixed_nodes ]
10371036
10381037 grid_data = np .copy (priors .grid_data [:])
10391038 timepoints = priors .timepoints
@@ -1043,24 +1042,25 @@ def _truncate_priors(ts, priors, progress=False):
10431042 zero_value = 0
10441043 elif priors .probability_space == "logarithmic" :
10451044 zero_value = - np .inf
1046- constrained_min_times = np .zeros_like (tables . nodes . time )
1045+ constrained_min_times = np .zeros_like (ts . nodes_time )
10471046 # Set the min times of fixed nodes to those in the tree sequence
10481047 constrained_min_times [fixed_nodes ] = fixed_times
10491048
10501049 # Traverse through the ARG, ensuring children come before parents.
10511050 # This can be done by iterating over groups of edges with the same parent
1052- new_parent_edge_idx = np .concatenate (
1053- (
1054- [0 ],
1055- np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
1056- [tables .edges .num_rows ],
1057- )
1058- )
1059- for edges_start , edges_end in zip (
1060- new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
1051+ edges_parent = ts .edges_parent
1052+ edges_child = ts .edges_child
1053+ new_parent_edge_idx = np .where (np .diff (edges_parent ) != 0 )[0 ] + 1
1054+ for edges_start , edges_end in tqdm (
1055+ zip (
1056+ itertools .chain ([0 ], new_parent_edge_idx ),
1057+ itertools .chain (new_parent_edge_idx , [len (edges_parent )]),
1058+ ),
1059+ desc = "Trunc priors" ,
1060+ disable = not progress ,
10611061 ):
1062- parent = tables . edges . parent [edges_start ]
1063- child_ids = tables . edges . child [edges_start :edges_end ] # May contain dups
1062+ parent = edges_parent [edges_start ]
1063+ child_ids = edges_child [edges_start :edges_end ] # May contain dups
10641064 oldest_child_time = np .max (constrained_min_times [child_ids ])
10651065 if oldest_child_time > constrained_min_times [parent ]:
10661066 if priors .is_fixed (parent ):
@@ -1198,15 +1198,17 @@ def build_grid(
11981198 node_var_override = node_var_override ,
11991199 progress = progress ,
12001200 )
1201- tables = tree_sequence .tables
1202- if np .any (tables .nodes .time [tree_sequence .samples ()] > 0 ):
1201+ if np .any (tree_sequence .nodes_time [tree_sequence .samples ()] > 0 ):
12031202 if not allow_historical_samples :
12041203 raise ValueError (
12051204 "There are samples at non-zero times, invalidating the conditional "
12061205 "coalescent prior. You can set allow_historical_samples=True to carry "
12071206 "on regardless, calculating a prior as if all samples were "
12081207 "contemporaneous (reasonable if you only have a few ancient samples)"
12091208 )
1210- if np .any (tables .nodes .time [priors .fixed_node_ids ()] > 0 ) and truncate_priors :
1209+ if (
1210+ np .any (tree_sequence .nodes_time [priors .fixed_node_ids ()] > 0 )
1211+ and truncate_priors
1212+ ):
12111213 priors = _truncate_priors (tree_sequence , priors , progress = progress )
12121214 return priors
0 commit comments