diff --git a/ALE_python/__init__.py b/ALE_python/__init__.py new file mode 100644 index 0000000..4802e90 --- /dev/null +++ b/ALE_python/__init__.py @@ -0,0 +1 @@ +__version__ = "1.0" diff --git a/ALE_python/__main__.py b/ALE_python/__main__.py new file mode 100644 index 0000000..cf4fd4c --- /dev/null +++ b/ALE_python/__main__.py @@ -0,0 +1,6 @@ +"""Allow running ALE_python as: python -m ALE_python [args...]""" + +from .cli import main + +if __name__ == "__main__": + main() diff --git a/ALE_python/ale.py b/ALE_python/ale.py new file mode 100644 index 0000000..8f53ed7 --- /dev/null +++ b/ALE_python/ale.py @@ -0,0 +1,971 @@ +""" +Core approx_posterior class for the ALE Python port. + +Direct port of ALE.h / ALE.cpp. Uses Python integers as bitsets +(bit i set means species i is in the set). Leaf IDs are 1-based +(bit 0 is unused). +""" + +import math +import random + +from .newick import parse_newick, get_leaves, get_all_nodes, is_leaf, to_newick + + +def double_factorial(n): + """Compute n!! = n * (n-2) * (n-4) * ... * 1 (or 2). + + Returns 1 for n <= 0. + """ + if n <= 0: + return 1.0 + result = 1.0 + k = n + while k > 0: + result *= k + k -= 2 + return result + + +def _popcount(x): + """Count set bits in integer *x*.""" + return bin(x).count("1") + + +def _bits_to_list(x, size): + """Return sorted list of set bit positions in *x* (up to bit *size*).""" + out = [] + for i in range(1, size + 1): + if x & (1 << i): + out.append(i) + return out + + +# --------------------------------------------------------------------------- +# Unrooting helper +# --------------------------------------------------------------------------- + +def _build_unrooted_adjacency(root): + """Build an unrooted adjacency representation from a parsed tree. + + The C++ code calls ``G->unroot()`` which, for a rooted binary tree, + removes the root node and connects both subtrees into a trifurcation + at one of the root's children. 
+ + Returns + ------- + neighbor : dict[int, list[int]] + node-id -> list of adjacent node-ids + node_map : dict[int, Node] + node-id -> Node object + """ + all_nodes = get_all_nodes(root) + node_map = {n.id: n for n in all_nodes} + neighbor = {} + + # Determine if the tree is rooted (root has exactly 2 children) + if len(root.children) == 2: + # Unroot: merge the root's two children. The root is removed and + # the two children become neighbours of each other in addition to + # their existing children. + left, right = root.children + # Build normal adjacency for every node except root + for n in all_nodes: + if n is root: + continue + adj = [] + for c in n.children: + adj.append(c.id) + if n.parent is not None and n.parent is not root: + adj.append(n.parent.id) + neighbor[n.id] = adj + + # Connect left <-> right through the removed root + neighbor[left.id].append(right.id) + neighbor[right.id].append(left.id) + else: + # Already unrooted (or multifurcation at root) + for n in all_nodes: + adj = [] + for c in n.children: + adj.append(c.id) + if n.parent is not None: + adj.append(n.parent.id) + neighbor[n.id] = adj + + return neighbor, node_map + + +# --------------------------------------------------------------------------- +# approx_posterior +# --------------------------------------------------------------------------- + +class approx_posterior: + """Approximate posterior on phylogenetic trees via clade and conditional + clade probabilities. Direct port of the C++ ``approx_posterior`` class. 
+ """ + + def __init__(self, tree_string=None): + self.observations = 0.0 + self.constructor_string = "" + self.alpha = 0.0 + self.beta = 0.0 + self.leaf_ids = {} # name -> id (1-based) + self.id_leaves = {} # id -> name + self.Bip_counts = {} # bipartition_id -> count + self.Bip_bls = {} # bipartition_id -> sum of branch lengths + self.Dip_counts = [{}] # indexed by g_id; keys are (gp_id, gpp_id) + self.set_ids = {} # bitset -> bipartition_id + self.id_sets = {} # bipartition_id -> bitset + self.Gamma = 0 # bitset with all species + self.Gamma_size = 0 + self.K_Gamma = 0.0 + self.N_Gamma = 0.0 + self.last_leafset_id = 0 + self.set_sizes = {} # bipartition_id -> set size + self.size_ordered_bips = {} # size -> [bipartition_ids] + self.name_separator = "+" + + if tree_string is not None: + self.constructor_string = tree_string + self.construct(tree_string) + + # ------------------------------------------------------------------ + # construct + # ------------------------------------------------------------------ + + def construct(self, tree_string): + """Parse a Newick tree to initialise leaf mappings and Gamma.""" + self.last_leafset_id = 0 + self.observations = 0.0 + + tree_string = tree_string.strip() + + # Get leaf names -- from a tree or comma-separated list + if tree_string.startswith("("): + root = parse_newick(tree_string) + leaves = get_leaves(root) + leaf_names = sorted(n.name for n in leaves) + else: + leaf_names = sorted( + s.strip() for s in tree_string.split(",") if s.strip() + ) + + # Assign 1-based IDs in alphabetical order (matching C++ std::map) + self.leaf_ids = {} + self.id_leaves = {} + for idx, name in enumerate(leaf_names, start=1): + self.leaf_ids[name] = idx + self.id_leaves[idx] = name + + self.alpha = 0.0 + self.beta = 0.0 + self.Gamma_size = len(leaf_names) + + # Gamma bitset: bits 1..Gamma_size set, bit 0 unused + self.Gamma = 0 + for i in range(1, self.Gamma_size + 1): + self.Gamma |= (1 << i) + + # Number of bipartitions of Gamma + 
self.K_Gamma = 2.0 ** (self.Gamma_size - 1) - 1 + + # Number of unrooted trees on Gamma_size leaves + if self.Gamma_size < 3: + self.N_Gamma = 1.0 + else: + self.N_Gamma = double_factorial(2 * self.Gamma_size - 5) + + # Pre-allocate Dip_counts entries (matching C++ push_back(temp)) + self.Dip_counts = [{}] + while len(leaf_names) + 1 > len(self.Dip_counts): + self.Dip_counts.append({}) + + # ------------------------------------------------------------------ + # set2id + # ------------------------------------------------------------------ + + def set2id(self, leaf_set): + """Return existing ID or create a new one for *leaf_set* (int bitset).""" + if leaf_set in self.set_ids: + sid = self.set_ids[leaf_set] + if sid != 0: + return sid + + self.last_leafset_id += 1 + self.set_ids[leaf_set] = self.last_leafset_id + self.id_sets[self.last_leafset_id] = leaf_set + # Ensure Dip_counts has enough entries + while self.last_leafset_id >= len(self.Dip_counts): + self.Dip_counts.append({}) + self.Bip_bls[self.last_leafset_id] = 0.0 + return self.last_leafset_id + + # ------------------------------------------------------------------ + # decompose + # ------------------------------------------------------------------ + + def decompose(self, G_string, weight=1.0): + """Parse a tree and extract bipartitions, updating counts.""" + # Push an empty dict at the start (matching C++ behaviour) + self.Dip_counts.append({}) + + root = parse_newick(G_string) + neighbor, node_map = _build_unrooted_adjacency(root) + + # Build directed edges (from_id, to_id) -> count + dedges = {} + for from_id, adj in neighbor.items(): + for to_id in adj: + dedges[(from_id, to_id)] = 0 + + flat_names = {} # (from_id, to_id) -> bitset + + # Name all leaf edges + for dedge in list(dedges.keys()): + from_id, to_id = dedge + node = node_map[from_id] + if is_leaf(node): + leaf_bit = 1 << self.leaf_ids[node.name] + flat_names[dedge] = leaf_bit + + # Register leaf set and accumulate branch length + g_id = 
self.set2id(leaf_bit) + bl = node.branch_length if node.branch_length is not None else 0.0 + self.Bip_bls[g_id] += bl + + # Mark done + dedges[dedge] = -1 + + # Increment counts for outgoing edges from 'to' + for adj_id in neighbor[to_id]: + if adj_id != from_id: + dedge_out = (to_id, adj_id) + if dedge_out in dedges: + dedges[dedge_out] += 1 + + # Special case: only 2 leaves + leaf_count = sum(1 for n in node_map.values() if is_leaf(n)) + if leaf_count == 2: + self.Bip_counts[1] = self.Bip_counts.get(1, 0.0) + weight + return + + # Iteratively resolve internal edges + edges_left = any(v != -1 for v in dedges.values()) + while edges_left: + for dedge in list(dedges.keys()): + if dedges[dedge] != 2: + continue + + from_id, to_id = dedge + # Find the two incoming edges + dedges_in = [] + for adj_id in neighbor[from_id]: + if adj_id != to_id: + dedges_in.append((adj_id, from_id)) + + leaf_set_in_1 = flat_names[dedges_in[0]] + leaf_set_in_2 = flat_names[dedges_in[1]] + combined = leaf_set_in_1 | leaf_set_in_2 + flat_names[dedge] = combined + + g_id = self.set2id(combined) + + tmp_id1 = self.set2id(leaf_set_in_1) + tmp_id2 = self.set2id(leaf_set_in_2) + parts = (min(tmp_id1, tmp_id2), max(tmp_id1, tmp_id2)) + + # Branch length + from_node = node_map[from_id] + to_node = node_map[to_id] + if (from_node.parent is not None + and from_node.parent.id == to_id): + bl = (from_node.branch_length + if from_node.branch_length is not None else 0.0) + elif (to_node.parent is not None + and to_node.parent.id == from_id): + bl = (to_node.branch_length + if to_node.branch_length is not None else 0.0) + else: + # After unrooting, the edge between the two former + # children of the root may not have a direct + # parent-child relationship. Use the sum of their + # branch lengths to the (removed) root. 
+ bl_from = (from_node.branch_length + if from_node.branch_length is not None else 0.0) + bl_to = (to_node.branch_length + if to_node.branch_length is not None else 0.0) + bl = bl_from + bl_to + + self.Bip_bls[g_id] = self.Bip_bls.get(g_id, 0.0) + bl + + # Update counts + if g_id >= len(self.Dip_counts): + while g_id >= len(self.Dip_counts): + self.Dip_counts.append({}) + self.Dip_counts[g_id][parts] = ( + self.Dip_counts[g_id].get(parts, 0.0) + weight + ) + self.Bip_counts[g_id] = ( + self.Bip_counts.get(g_id, 0.0) + weight + ) + + # Mark done + dedges[dedge] = -1 + for adj_id in neighbor[to_id]: + if adj_id != from_id: + dedge_out = (to_id, adj_id) + if dedge_out in dedges: + dedges[dedge_out] += 1 + + edges_left = any(v != -1 for v in dedges.values()) + + # ------------------------------------------------------------------ + # observation + # ------------------------------------------------------------------ + + def observation(self, trees, weight=1.0): + """Observe a list of tree strings, updating all counts.""" + for tree_str in trees: + self.decompose(tree_str, weight=weight) + self.observations += weight + + # Rebuild set_sizes and size_ordered_bips + self.set_sizes = {} + self.size_ordered_bips = {} + for sid, bitset in self.id_sets.items(): + size = _popcount(bitset) + self.set_sizes[sid] = size + self.size_ordered_bips.setdefault(size, []).append(sid) + + # ------------------------------------------------------------------ + # Probability functions + # ------------------------------------------------------------------ + + def Bi(self, n2): + """Number of tree topologies for a bipartition of sizes n1, n2.""" + n1 = self.Gamma_size - n2 + if n1 == 1 or n2 == 1: + return double_factorial(2 * self.Gamma_size - 5) + n1 = max(2, n1) + n2 = max(2, n2) + return (double_factorial(2 * n1 - 3) + * double_factorial(2 * n2 - 3)) + + def Tri(self, n2, n3): + """Number of tree topologies for a trifurcation of sizes n1, n2, n3.""" + n1 = self.Gamma_size - n2 - n3 + n1 = 
max(2, n1) + n2 = max(2, n2) + n3 = max(2, n3) + return (double_factorial(2 * n1 - 3) + * double_factorial(2 * n2 - 3) + * double_factorial(2 * n3 - 3)) + + def p_bip(self, g_id): + """Bipartition probability by ID.""" + if self.Gamma_size < 4: + return 1.0 + + Bip_count = 0.0 + if g_id not in self.Bip_counts or g_id == -10 or g_id == 0: + Bip_count = 0.0 + else: + Bip_count = self.Bip_counts[g_id] + + if g_id not in self.set_sizes or g_id == -10: + Bip_count = 0.0 + elif (self.set_sizes[g_id] == 1 + or self.set_sizes[g_id] == self.Gamma_size - 1): + Bip_count = self.observations + + if self.alpha > 0: + size = self.set_sizes.get(g_id, 0) + return (Bip_count + self.alpha / self.N_Gamma + * self.Bi(size)) / ( + self.observations + self.alpha) + else: + return Bip_count / self.observations + + def p_bip_by_set(self, gamma): + """Bipartition probability by bitset.""" + if self.Gamma_size < 4: + return 1.0 + if gamma in self.set_ids: + g_id = self.set_ids[gamma] + else: + g_id = -10 + return self.p_bip(g_id) + + def p_dip(self, g_id, gp_id, gpp_id): + """Conditional clade probability (by IDs, with gp_id < gpp_id).""" + if self.Gamma_size < 4: + return 1.0 + + beta_switch = 1.0 + Dip_count = 0.0 + Bip_count = 0.0 + + if g_id not in self.Bip_counts or g_id == -10 or g_id == 0: + beta_switch = 0.0 + Bip_count = 0.0 + Dip_count = 0.0 + else: + parts = (min(gp_id, gpp_id), max(gp_id, gpp_id)) + Bip_count = self.Bip_counts[g_id] + + if (gp_id == -10 or gpp_id == -10 + or gp_id == 0 or gpp_id == 0 + or g_id >= len(self.Dip_counts) + or parts not in self.Dip_counts[g_id]): + Dip_count = 0.0 + else: + Dip_count = self.Dip_counts[g_id][parts] + + if (g_id not in self.set_sizes + or self.set_sizes[g_id] == 1 + or self.set_sizes[g_id] == self.Gamma_size - 1): + Bip_count = self.observations + + if self.alpha > 0 or self.beta > 0: + g_size = self.set_sizes.get(g_id, 1) + gp_size = self.set_sizes.get(gp_id, 1) + gpp_size = self.set_sizes.get(gpp_id, 1) + numerator = (Dip_count + + 
self.alpha / self.N_Gamma + * self.Tri(gp_size, gpp_size) + + beta_switch * self.beta + / (2.0 ** (g_size - 1) - 1)) + denominator = (Bip_count + + self.alpha / self.N_Gamma + * self.Bi(g_size) + + beta_switch * self.beta) + return numerator / denominator + else: + if Bip_count == 0: + return 0.0 + return Dip_count / Bip_count + + def p_dip_by_set(self, gamma, gammap, gammapp): + """Conditional clade probability by bitsets.""" + if self.Gamma_size < 4: + return 1.0 + + g_id = self.set_ids[gamma] if gamma in self.set_ids else -10 + gp_id = self.set_ids[gammap] if gammap in self.set_ids else -10 + gpp_id = self.set_ids[gammapp] if gammapp in self.set_ids else -10 + + if gpp_id > gp_id: + return self.p_dip(g_id, gp_id, gpp_id) + else: + return self.p_dip(g_id, gpp_id, gp_id) + + # ------------------------------------------------------------------ + # recompose + # ------------------------------------------------------------------ + + def recompose(self, G_string): + """Compute conditional clade probabilities for a given tree. + + Returns a dict mapping bitset -> probability. 
+ """ + return_map = {} + + root = parse_newick(G_string) + neighbor, node_map = _build_unrooted_adjacency(root) + + dedges = {} + for from_id, adj in neighbor.items(): + for to_id in adj: + dedges[(from_id, to_id)] = 0 + + flat_names = {} + q = {} + + # Name leaves + for dedge in list(dedges.keys()): + from_id, to_id = dedge + node = node_map[from_id] + if is_leaf(node): + leaf_bit = 1 << self.leaf_ids[node.name] + flat_names[dedge] = leaf_bit + q[dedge] = 1.0 + return_map[leaf_bit] = 1.0 + dedges[dedge] = -1 + for adj_id in neighbor[to_id]: + if adj_id != from_id: + dedge_out = (to_id, adj_id) + if dedge_out in dedges: + dedges[dedge_out] += 1 + + edges_left = any(v != -1 for v in dedges.values()) + while edges_left: + for dedge in list(dedges.keys()): + if dedges[dedge] != 2: + continue + + from_id, to_id = dedge + dedges_in = [] + for adj_id in neighbor[from_id]: + if adj_id != to_id: + dedges_in.append((adj_id, from_id)) + + leaf_set_in_1 = flat_names[dedges_in[0]] + leaf_set_in_2 = flat_names[dedges_in[1]] + combined = leaf_set_in_1 | leaf_set_in_2 + flat_names[dedge] = combined + + q[dedge] = (q[dedges_in[0]] * q[dedges_in[1]] + * self.p_dip_by_set(combined, + leaf_set_in_1, + leaf_set_in_2)) + return_map[combined] = q[dedge] + + dedges[dedge] = -1 + for adj_id in neighbor[to_id]: + if adj_id != from_id: + dedge_out = (to_id, adj_id) + if dedge_out in dedges: + dedges[dedge_out] += 1 + + edges_left = any(v != -1 for v in dedges.values()) + + return return_map + + # ------------------------------------------------------------------ + # p (tree probability) + # ------------------------------------------------------------------ + + def p(self, tree_string): + """Probability of a tree given the approximate posterior.""" + rec_map = self.recompose(tree_string) + p_val = 0.0 + for gamma, val in rec_map.items(): + p_val = val + # not_gamma: flip all bits except bit 0 + not_gamma = gamma ^ self.Gamma + p_val *= rec_map.get(not_gamma, 0.0) * self.p_bip_by_set(gamma) 
+ if math.isnan(p_val): + p_val = 0.0 + break + return p_val + + # ------------------------------------------------------------------ + # mpp_tree (maximum posterior probability tree) + # ------------------------------------------------------------------ + + def mpp_tree(self): + """Return (newick_string, max_posterior_probability).""" + qmpp = {} + + # DP from small to large + for size in sorted(self.size_ordered_bips.keys()): + for g_id in self.size_ordered_bips[size]: + if size == 1: + qmpp[g_id] = 1.0 + else: + max_cp = 0.0 + if g_id < len(self.Dip_counts): + for (gp_id, gpp_id), _ in self.Dip_counts[g_id].items(): + cp = (self.p_dip(g_id, gp_id, gpp_id) + * qmpp.get(gp_id, 0.0) + * qmpp.get(gpp_id, 0.0)) + if cp > max_cp: + max_cp = cp + qmpp[g_id] = max_cp + + # Find best root bipartition + max_pp = 0.0 + sum_pp = 0.0 + max_bip = -1 + max_not_bip = -1 + + for g_id in self.Bip_counts: + gamma = self.id_sets[g_id] + not_gamma = gamma ^ self.Gamma + if not_gamma not in self.set_ids: + continue + not_g_id = self.set_ids[not_gamma] + pp = (qmpp.get(g_id, 0.0) + * qmpp.get(not_g_id, 0.0) + * self.p_bip(g_id)) + sum_pp += pp + if pp > max_pp: + max_pp = pp + max_bip = g_id + max_not_bip = not_g_id + + if max_bip == -1 or sum_pp == 0: + return ("", 0.0) + + # Root support and branch length + support = max_pp / sum_pp * 2 # we looked at everything twice + bl = min( + self.Bip_bls.get(max_bip, 0.0) + / max(self.Bip_counts.get(max_bip, 1.0), 1e-100), + 0.99, + ) + bs_str = f"{support}:{bl}" + + left = self.mpp_backtrack(max_bip, qmpp) + right = self.mpp_backtrack(max_not_bip, qmpp) + max_tree = f"({left},{right}){bs_str};\n" + + return (max_tree, max_pp) + + def mpp_backtrack(self, g_id, qmpp): + """Recursive backtrack for mpp_tree.""" + # Leaf + if self.set_sizes.get(g_id, 0) == 1: + bl = self.Bip_bls.get(g_id, 0.0) / max(self.observations, 1e-100) + bitset = self.id_sets[g_id] + leaf_id = 0 + for i in range(1, self.Gamma_size + 1): + if bitset & (1 << i): + leaf_id = 
i + break + return f"{self.id_leaves[leaf_id]}:{bl}" + + # Internal node: find best split + max_cp = 0.0 + sum_cp = 0.0 + max_gp_id = -1 + max_gpp_id = -1 + + if g_id < len(self.Dip_counts): + for (gp_id, gpp_id), _ in self.Dip_counts[g_id].items(): + cp = (self.p_dip(g_id, gp_id, gpp_id) + * qmpp.get(gp_id, 0.0) + * qmpp.get(gpp_id, 0.0)) + sum_cp += cp + if cp > max_cp: + max_cp = cp + max_gp_id = gp_id + max_gpp_id = gpp_id + + if sum_cp == 0: + support_str = "0" + else: + support_str = str(max_cp / sum_cp) + bl = (self.Bip_bls.get(g_id, 0.0) + / max(self.Bip_counts.get(g_id, 1.0), 1e-100)) + + left = self.mpp_backtrack(max_gp_id, qmpp) + right = self.mpp_backtrack(max_gpp_id, qmpp) + return f"({left},{right}){support_str}:{bl}" + + # ------------------------------------------------------------------ + # count_trees + # ------------------------------------------------------------------ + + def count_trees(self): + """Count amalgamated trees with the complete leaf set.""" + g_id_count = {} + + for size in sorted(self.size_ordered_bips.keys()): + for g_id in self.size_ordered_bips[size]: + if size == 1: + g_id_count[g_id] = 1.0 + else: + g_id_count[g_id] = 0.0 + if g_id < len(self.Dip_counts): + for (gp_id, gpp_id), _ in self.Dip_counts[g_id].items(): + g_id_count[g_id] += ( + g_id_count.get(gp_id, 0.0) + * g_id_count.get(gpp_id, 0.0) + ) + + count = 0.0 + for g_id in self.Bip_counts: + gamma = self.id_sets[g_id] + not_gamma = gamma ^ self.Gamma + if not_gamma not in self.set_ids: + continue + gamma_size = _popcount(gamma) + not_gamma_size = _popcount(not_gamma) + val = (g_id_count.get(self.set_ids[gamma], 0.0) + * g_id_count.get(self.set_ids[not_gamma], 0.0)) + if gamma_size > not_gamma_size: + count += val + elif gamma_size == not_gamma_size: + count += val / 2.0 + + return count + + # ------------------------------------------------------------------ + # random_tree / random_split + # ------------------------------------------------------------------ + + def 
random_tree(self): + """Sample a random tree from the approximate posterior.""" + total = sum(self.Bip_counts.values()) + if total == 0: + return "" + rnd = random.random() + cumsum = 0.0 + g_id = None + for gid, cnt in self.Bip_counts.items(): + cumsum += cnt + g_id = gid + if cumsum > total * rnd: + break + + gamma = self.id_sets[g_id] + not_gamma = gamma ^ self.Gamma + return (f"({self.random_split(gamma)}:1," + f"{self.random_split(not_gamma)}:1);\n") + + def random_split(self, gamma): + """Recursive random split of a clade.""" + gamma_v = _bits_to_list(gamma, self.Gamma_size) + gamma_size = len(gamma_v) + + if gamma_size == 1: + return self.id_leaves[gamma_v[0]] + + rnd = random.random() + p_sum = 0.0 + g_id = self.set_ids.get(gamma, 0) + beta_switch = 1.0 + Bip_count = 0.0 + + if not g_id: + beta_switch = 0.0 + Bip_count = 0.0 + + gammap = 0 + gammapp = 0 + + for gp_size in range(1, gamma_size // 2 + 1): + saw = 0 + found = False + + if g_id and g_id < len(self.Dip_counts): + for (gp_id_k, gpp_id_k), _ in self.Dip_counts[g_id].items(): + gp_bits = self.id_sets.get(gp_id_k, 0) + gpp_bits = self.id_sets.get(gpp_id_k, 0) + gp_id_size = _popcount(gp_bits) + gpp_id_size = _popcount(gpp_bits) + this_size = min(gp_id_size, gpp_id_size) + + if this_size == gp_size: + p_sum += self.p_dip_by_set( + gamma, + self.id_sets[gp_id_k], + self.id_sets[gpp_id_k], + ) + if rnd < p_sum: + p_sum = -1 + saw += 1 + + if p_sum < 0: + gammap = self.id_sets[gp_id_k] + gammapp = self.id_sets[gpp_id_k] + found = True + break + + if found: + break + + # Unobserved partitions + if g_id: + Bip_count = self.Bip_counts.get(g_id, 0.0) + if gamma_size == 1 or gamma_size == self.Gamma_size - 1: + Bip_count = self.observations + + nbip = math.comb(gamma_size, gp_size) + if gamma_size - gp_size == gp_size: + nbip //= 2 + + denom = (Bip_count + + self.alpha / self.N_Gamma * self.Bi(gamma_size) + + beta_switch * self.beta) + if denom != 0: + p_sum += ((0 + + self.alpha / self.N_Gamma + * 
self.Tri(gp_size, gamma_size - gp_size) + + beta_switch * self.beta + / (2.0 ** (gamma_size - 1) - 1)) + / denom * (nbip - saw)) + + if rnd < p_sum: + p_sum = -1 + + if p_sum < 0: + # Pick random unsampled partition + while True: + chosen = random.sample(gamma_v, gp_size) + gammap = 0 + for v in chosen: + gammap |= (1 << v) + gammapp = gamma ^ gammap + + gp_id_t = self.set_ids.get(gammap, 0) + gpp_id_t = self.set_ids.get(gammapp, 0) + parts = (min(gp_id_t, gpp_id_t), max(gp_id_t, gpp_id_t)) + if (g_id < 0 or g_id >= len(self.Dip_counts) + or parts not in self.Dip_counts[g_id] + or self.Dip_counts[g_id][parts] == 0): + break + break + + return (f"({self.random_split(gammap)}:1," + f"{self.random_split(gammapp)}:1)") + + # ------------------------------------------------------------------ + # save_state / load_state + # ------------------------------------------------------------------ + + def save_state(self, fname): + """Write ALE file in the exact C++ format.""" + with open(fname, "w") as fout: + fout.write("#constructor_string\n") + fout.write(self.constructor_string.strip() + "\n") + + fout.write("#observations\n") + fout.write(f"{self.observations}\n") + + fout.write("#Bip_counts\n") + for gid in sorted(self.Bip_counts.keys()): + fout.write(f"{gid}\t{self.Bip_counts[gid]}\n") + + fout.write("#Bip_bls\n") + for gid in sorted(self.Bip_bls.keys()): + fout.write(f"{gid}\t{self.Bip_bls[gid]}\n") + + fout.write("#Dip_counts\n") + for index in range(len(self.Dip_counts)): + for (gp_id, gpp_id), count in self.Dip_counts[index].items(): + fout.write(f"{index}\t{gp_id}\t{gpp_id}\t{count}\n") + + fout.write("#last_leafset_id\n") + fout.write(f"{self.last_leafset_id}\n") + + fout.write("#leaf-id\n") + for name in sorted(self.leaf_ids.keys()): + fout.write(f"{name}\t{self.leaf_ids[name]}\n") + + fout.write("#set-id\n") + for bitset in sorted(self.set_ids.keys()): + sid = self.set_ids[bitset] + fout.write(f"{sid}\t:") + for i in range(self.Gamma_size + 1): + if bitset & (1 
<< i): + fout.write(f"\t{i}") + fout.write("\n") + + fout.write("#END\n") + + def load_state(self, fname): + """Read ALE file in the exact C++ format.""" + reading = "#nothing" + + with open(fname, "r") as fin: + for line in fin: + line = line.rstrip("\n") + if "#" in line: + reading = line.strip() + elif reading == "#constructor_string": + tree_string = line.strip() + self.constructor_string = tree_string + self.construct(tree_string) + reading = "#nothing" + elif reading == "#observations": + self.observations = float(line.strip()) + elif reading == "#Bip_counts": + tokens = line.strip().split() + if len(tokens) >= 2: + self.Bip_counts[int(tokens[0])] = float(tokens[1]) + elif reading == "#Bip_bls": + tokens = line.strip().split() + if len(tokens) >= 2: + self.Bip_bls[int(tokens[0])] = float(tokens[1]) + elif reading == "#Dip_counts": + tokens = line.strip().split() + if len(tokens) >= 4: + idx = int(tokens[0]) + parts = (int(tokens[1]), int(tokens[2])) + count = float(tokens[3]) + while idx >= len(self.Dip_counts): + self.Dip_counts.append({}) + self.Dip_counts[idx][parts] = count + elif reading == "#last_leafset_id": + self.last_leafset_id = int(line.strip()) + elif reading == "#leaf-id": + tokens = line.strip().split() + if len(tokens) >= 2: + name = tokens[0] + lid = int(tokens[1]) + self.leaf_ids[name] = lid + self.id_leaves[lid] = name + elif reading == "#set-id": + fields = line.strip().split(":") + if len(fields) >= 2: + set_id = int(fields[0].strip()) + tokens = fields[1].strip().split() + bitset = 0 + for t in tokens: + bitset |= (1 << int(t)) + self.set_ids[bitset] = set_id + self.id_sets[set_id] = bitset + + # Pad Dip_counts to match last_leafset_id + 1 so that indices + # corresponding to empty dicts (which save_state does not write) + # are present after loading. 
+ while len(self.Dip_counts) < self.last_leafset_id + 1: + self.Dip_counts.append({}) + + # Root bipartition (bits 1..Gamma_size set) + root_bits = 0 + for i in range(1, self.Gamma_size + 1): + root_bits |= (1 << i) + self.id_sets[-1] = root_bits + self.set_ids[root_bits] = -1 + + # Rebuild set_sizes and size_ordered_bips + self.set_sizes = {} + self.size_ordered_bips = {} + for sid, bitset in self.id_sets.items(): + size = _popcount(bitset) + self.set_sizes[sid] = size + self.size_ordered_bips.setdefault(size, []).append(sid) + + # ------------------------------------------------------------------ + # Utility methods + # ------------------------------------------------------------------ + + def get_leaf_names(self): + """Return all leaf names.""" + return list(self.leaf_ids.keys()) + + def set_alpha(self, a): + """Set the alpha correction parameter.""" + self.alpha = a + + def set_beta(self, b): + """Set the beta correction parameter.""" + self.beta = b + + def set2name(self, leaf_set): + """Return human-readable name for a bitset.""" + parts = [] + for i in range(1, self.Gamma_size + 1): + if leaf_set & (1 << i): + parts.append(self.id_leaves[i]) + return self.name_separator.join(parts) + + def compute_ordered_vector_of_clades(self): + """Return (ids, sizes) ordered by size, with root (-1) appended.""" + ids = [] + id_sizes = [] + for size in sorted(self.size_ordered_bips.keys()): + for gid in self.size_ordered_bips[size]: + ids.append(gid) + id_sizes.append(size) + ids.append(-1) + id_sizes.append(self.Gamma_size) + return ids, id_sizes + + +# Alias used by ale_util and cli modules +ApproxPosterior = approx_posterior diff --git a/ALE_python/ale_util.py b/ALE_python/ale_util.py new file mode 100644 index 0000000..82a23eb --- /dev/null +++ b/ALE_python/ale_util.py @@ -0,0 +1,181 @@ +"""I/O utilities for ALE (Amalgamated Likelihood Estimation). + +Ported from C++ ALE_util.cpp. 
Provides functions for observing, loading, +and saving ApproxPosterior objects from/to files and strings. +""" + +from __future__ import annotations + +import os +from typing import Union + +from .ale import ApproxPosterior + + +def _build_singleton_ale(name: str) -> ApproxPosterior: + """Build a one-leaf ApproxPosterior, matching C++ singleton handling.""" + import tempfile + import os + + ale = ApproxPosterior() + # Write a minimal .ale state file and load it, matching C++ behavior + with tempfile.NamedTemporaryFile(mode="w", suffix=".ale", delete=False) as f: + tmp = f.name + f.write("#constructor_string\n") + f.write(f"{name}\n") + f.write("#observations\n1\n") + f.write("#Bip_counts\n") + f.write("#Bip_bls\n") + f.write("1\t1\n") + f.write("#Dip_counts\n") + f.write("#last_leafset_id\n1\n") + f.write("#leaf-id\n") + f.write(f"{name}\t1\n") + f.write("#set-id\n") + f.write("1\t:\t1\n") + f.write("#END\n") + try: + ale.load_state(tmp) + finally: + os.unlink(tmp) + return ale + + +def observe_ALE_from_file( + fname_or_fnames: Union[str, list[str]], + burnin: int = 0, + every: int = 1, + until: int = -1, +) -> ApproxPosterior: + """Read tree(s) from one or more files and build an ApproxPosterior. + + For each file, lines containing '(' are treated as tree strings. + The first *burnin* trees are skipped, then every *every*-th tree is + kept. If *until* > 0 the collection is truncated to that many trees. + + Parameters + ---------- + fname_or_fnames: + A single filename or a list of filenames. + burnin: + Number of leading trees to discard per file. + every: + Sub-sampling stride (keep every *every*-th tree after burn-in). + until: + Maximum number of trees to use (all if <= 0). + + Returns + ------- + ApproxPosterior + The observed approximate posterior object. 
+ """ + if isinstance(fname_or_fnames, str): + fnames = [fname_or_fnames] + else: + fnames = list(fname_or_fnames) + + trees: list[str] = [] + + for fname in fnames: + if not os.path.isfile(fname): + raise FileNotFoundError( + f"Error, file {fname} does not seem accessible." + ) + + with open(fname) as fh: + tree_i = 0 + for line in fh: + line = line.rstrip("\n").strip() + if not line: + continue + if "(" in line: + tree_i += 1 + if tree_i > burnin and tree_i % every == 0: + trees.append(line) + elif ";" in line: + # Singleton tree like "A;" — C++ immediately builds + # a one-leaf ALE and returns, so we do the same. + name = line.rstrip(";").split(",")[0].split(":")[0].split()[0] + if name: + return _build_singleton_ale(name) + + if not trees: + raise ValueError("No trees found in the provided file(s).") + + if until > 0: + trees = trees[:until] + + ale = ApproxPosterior(trees[0]) + ale.observation(trees) + return ale + + +def observe_ALE_from_strings(trees: list[str]) -> ApproxPosterior: + """Build an ApproxPosterior from a list of tree strings. + + Parameters + ---------- + trees: + Newick tree strings. + + Returns + ------- + ApproxPosterior + """ + if not trees: + raise ValueError("Tree list must not be empty.") + + ale = ApproxPosterior(trees[0]) + ale.observation(trees) + return ale + + +def observe_ALE_from_string(tree: str) -> ApproxPosterior: + """Build an ApproxPosterior from a single tree string. + + Parameters + ---------- + tree: + A Newick tree string. + + Returns + ------- + ApproxPosterior + """ + return observe_ALE_from_strings([tree]) + + +def load_ALE_from_file(fname: str) -> ApproxPosterior: + """Load a previously saved ApproxPosterior from a ``.ale`` file. + + Parameters + ---------- + fname: + Path to the saved state file. + + Returns + ------- + ApproxPosterior + """ + ale = ApproxPosterior() + ale.load_state(fname) + return ale + + +def save_ALE_to_file(ale: ApproxPosterior, fname: str) -> str: + """Save an ApproxPosterior to a file. 
+ + Parameters + ---------- + ale: + The ApproxPosterior instance to save. + fname: + Destination file path. + + Returns + ------- + str + The filename that was written. + """ + ale.save_state(fname) + return fname diff --git a/ALE_python/cli.py b/ALE_python/cli.py new file mode 100644 index 0000000..0cb144d --- /dev/null +++ b/ALE_python/cli.py @@ -0,0 +1,1351 @@ +"""Main command-line interface that dispatches to individual ALE programs. + +Usage: python -m ALE_python [args...] + +Programs: ALEobserve, ALEml_undated, ALEmcmc_undated, ALEcount, + ls_leaves, CCPscore, ALEadd, ALEevaluate_undated +""" + +import math +import os +import random +import sys + +from .ale import approx_posterior as ApproxPosterior +from .ale_util import ( + load_ALE_from_file, + observe_ALE_from_file, + observe_ALE_from_string, + save_ALE_to_file, +) +from .exodt import ExODTModel +from .newick import get_leaves, parse_newick, to_newick + +ALE_VERSION = "1.0" + + +# --------------------------------------------------------------------------- +# Argument parsing helpers +# --------------------------------------------------------------------------- + +def _print_subcommand_help(name): + """Print the argparse help for a specific subcommand and exit.""" + parser = _build_argparse() + parser.parse_args([name, "--help"]) + + +def _normalize_argv(argv): + """Convert ``--key value`` long options to ``key=value`` form. + + This allows callers to use either the original C++ style + (``delta=0.01``, ``rate_multiplier:tau_to:43:0.0``) or standard + long options (``--delta 0.01``, ``--rate-multiplier tau_to:43:0.0``). + Positional arguments (not starting with ``--``) are passed through + unchanged. Bare flags (``--MLOR``, ``--reldate``) become ``MLOR`` + and ``reldate``. + """ + result = [] + i = 0 + while i < len(argv): + arg = argv[i] + if arg.startswith("--"): + key = arg[2:].replace("-", "_") + # Peek: is the next token a value (doesn't start with --)? 
+ if i + 1 < len(argv) and not argv[i + 1].startswith("--"): + result.append(f"{key}={argv[i + 1]}") + i += 2 + else: + # Boolean flag like --MLOR or --reldate + result.append(key) + i += 1 + else: + result.append(arg) + i += 1 + return result + + +def _parse_kwargs(argv): + """Parse key=value and key:value arguments from an argv list. + + Returns (positional_args, keyword_dict). + """ + positional = [] + kwargs = {} + for arg in argv: + if "=" in arg: + key, _, val = arg.partition("=") + kwargs[key] = val + elif ":" in arg: + # Some arguments use colon separators (e.g. S_branch_lengths:0.2) + key, _, val = arg.partition(":") + kwargs[key] = val + else: + positional.append(arg) + return positional, kwargs + + +# --------------------------------------------------------------------------- +# ALEobserve +# --------------------------------------------------------------------------- + +def ALEobserve(argv): + """Observe gene trees and build an ALE file. + + argv[0] = gene tree file, additional positional args are more files. + burnin=N keyword arg (default 0). 
+ """ + print(f"ALEobserve using ALE v{ALE_VERSION}") + + if not argv: + _print_subcommand_help("ALEobserve") + return 1 + + positional, kwargs = _parse_kwargs(argv) + burnin = int(kwargs.get("burnin", 0)) + + first_file = positional[0].strip() + ale_files = [f.strip() for f in positional] + ale_name = first_file + ".ale" + + ale = observe_ALE_from_file(ale_files, burnin=burnin) + + print(f"# observe {ale.observations} tree(s) from: {' '.join(ale_files)}") + print(f"{burnin} burn in per file discarded.") + ale.save_state(ale_name) + print(f"# saved in {ale_name}") + print("# mpp tree from sample: ") + mpp_tree_str, _mpp_pp = ale.mpp_tree() + print(mpp_tree_str) + return 0 + + +# --------------------------------------------------------------------------- +# ALEml_undated +# --------------------------------------------------------------------------- + +def ALEml_undated(argv): + """Maximum-likelihood reconciliation under the undated DTL model.""" + import numpy as np + + print(f"ALEml_undated using ALE v{ALE_VERSION}") + + if len(argv) < 2: + _print_subcommand_help("ALEml_undated") + return 1 + + S_treefile = argv[0] + ale_file_arg = argv[1] + + if not os.path.isfile(S_treefile): + print(f"Error, file {S_treefile} does not seem accessible.") + sys.exit(1) + + with open(S_treefile) as f: + Sstring = f.readline().strip() + print(f"Read species tree from: {S_treefile}..") + + ale = load_ALE_from_file(ale_file_arg) + print(f"Read summary of tree sample for {ale.observations} trees from: {ale_file_arg}..") + + # Output file radical: basename of S_treefile + "_" + basename of ale_file + ale_file = os.path.basename(S_treefile) + "_" + os.path.basename(ale_file_arg) + + model = ExODTModel() + + samples = 100 + O_R = 1.0 + beta = 1.0 + delta_fixed = False + tau_fixed = False + lambda_fixed = False + DT_fixed = False + MLOR = False + no_T = False + + delta = 1e-2 + tau = 1e-2 + lambda_ = 1e-1 + DT_ratio = 0.05 + fraction_missing_file = "" + output_species_tree = "" + 
def ALEml_undated(argv):
    """Maximum-likelihood reconciliation under the undated DTL model.

    argv[0] is the species-tree file, argv[1] the ``.ale`` file; the
    remaining tokens are ``key=value`` / ``key:value`` options.  Unless
    delta, tau and lambda are all fixed on the command line, the rates
    are optimised with scipy's Nelder-Mead; reconciled gene trees are
    then sampled and written to ``<S>_<ale>.uml_rec`` plus a transfer
    table in ``<S>_<ale>.uTs``.

    Returns 0 on success; exits the process when the species-tree file
    is not accessible.
    """
    import numpy as np

    print(f"ALEml_undated using ALE v{ALE_VERSION}")

    if len(argv) < 2:
        _print_subcommand_help("ALEml_undated")
        return 1

    S_treefile = argv[0]
    ale_file_arg = argv[1]

    if not os.path.isfile(S_treefile):
        print(f"Error, file {S_treefile} does not seem accessible.")
        sys.exit(1)

    with open(S_treefile) as f:
        Sstring = f.readline().strip()
    print(f"Read species tree from: {S_treefile}..")

    ale = load_ALE_from_file(ale_file_arg)
    print(f"Read summary of tree sample for {ale.observations} trees from: {ale_file_arg}..")

    # Output file radical: basename of S_treefile + "_" + basename of ale_file
    ale_file = os.path.basename(S_treefile) + "_" + os.path.basename(ale_file_arg)

    model = ExODTModel()

    # Defaults; the *_fixed flags record which rates the user pinned.
    samples = 100
    O_R = 1.0
    beta = 1.0
    delta_fixed = False
    tau_fixed = False
    lambda_fixed = False
    DT_fixed = False
    MLOR = False
    no_T = False

    delta = 1e-2
    tau = 1e-2
    lambda_ = 1e-1
    DT_ratio = 0.05
    fraction_missing_file = ""
    output_species_tree = ""

    rate_multipliers = {}
    ml_branch_multipliers = []  # [(branch_id, rate_name)] for per-branch optimization

    model.set_model_parameter("undatedBL", 0)
    model.set_model_parameter("reldate", 0)

    # Parse optional arguments ("=" and ":" are interchangeable separators).
    for arg in argv[2:]:
        if "=" in arg or ":" in arg:
            tokens = arg.replace("=", ":").split(":")
            key = tokens[0]
        else:
            continue

        if key == "sample":
            samples = int(tokens[1])
        elif key == "separators":
            model.set_model_parameter("gene_name_separators", tokens[1])
        elif key == "delta":
            delta = float(tokens[1])
            delta_fixed = True
            print(f"# delta fixed to {delta}")
        elif key == "tau":
            tau = float(tokens[1])
            tau_fixed = True
            if tau < 1e-10:
                # A (near-)zero tau disables transfers entirely.
                no_T = True
                print("# tau fixed to no transfer!")
                tau = 1e-19
            else:
                print(f"# tau fixed to {tau}")
        elif key == "lambda":
            lambda_ = float(tokens[1])
            lambda_fixed = True
            print(f"# lambda fixed to {lambda_}")
        elif key == "DT":
            DT_ratio = float(tokens[1])
            DT_fixed = True
            model.set_model_parameter("DT_ratio", DT_ratio)
            print(f"# D/T ratio fixed to {model.scalar_parameter['DT_ratio']}")
        elif key == "O_R":
            O_R = float(tokens[1])
            print(f"# O_R set to {O_R}")
        elif key == "beta":
            beta = float(tokens[1])
            print(f"# beta set to {beta}")
        elif key == "fraction_missing":
            fraction_missing_file = tokens[1]
            print(f"# File containing fractions of missing genes set to {fraction_missing_file}")
        elif key == "S_branch_lengths":
            model.set_model_parameter("undatedBL", 1)
            if len(tokens) == 1 or tokens[1] == "":
                model.set_model_parameter("root_BL", 1)
                print("# using branch lengths of input S tree as rate multipliers with 1 at root! ")
            else:
                root_rm = float(tokens[1])
                model.set_model_parameter("root_BL", root_rm)
                print(f"# using branch lengths of input S tree as rate multipliers with {root_rm} at root! ")
        elif key == "reldate":
            print("Respecting relative ages from input S tree, please make sure input S tree is ultrametric!")
            model.set_model_parameter("reldate", 1)
        elif key == "MLOR":
            print("Optimizing root origination multiplier.")
            MLOR = True
        elif key == "rate_multiplier":
            rate_name = tokens[1]
            e = int(tokens[2])
            rm = float(tokens[3])
            if rm >= -1:
                print(f"# rate multiplier for rate {rate_name} on branch with ID {e} set to {rm}")
                rate_multipliers.setdefault("rate_multiplier_" + rate_name, {})[e] = rm
            else:
                # A multiplier below -1 is the sentinel for "optimize this one".
                print(f"# rate multiplier for rate {rate_name} on branch with ID {e} to be optimized")
                ml_branch_multipliers.append((e, "rate_multiplier_" + rate_name))
        elif key == "output_species_tree":
            val = tokens[1].lower()
            if val in ("y", "ye", "yes"):
                output_species_tree = ale_file + ".spTree"
                print(f"# outputting the annotated species tree to {output_species_tree}")
        elif key == "seed":
            seed_val = int(tokens[1])
            print(f"Set random seed to {seed_val}")
            random.seed(seed_val)
            np.random.seed(seed_val)

    model.set_model_parameter("BOOTSTRAP_LABELS", "yes")
    model.construct_undated(Sstring, fraction_missing_file)

    # Apply fixed per-branch rate multipliers (vectors exist only after
    # construct_undated()).
    for rm_name, rm_dict in rate_multipliers.items():
        for e, rm_val in rm_dict.items():
            model.vector_parameter[rm_name][e] = rm_val

    model.set_model_parameter("seq_beta", beta)
    model.set_model_parameter("O_R", O_R)
    model.set_model_parameter("delta", delta)
    model.set_model_parameter("tau", tau)
    model.set_model_parameter("lambda", lambda_)

    model.calculate_undatedEs()
    print("Reconciliation model initialised, starting DTL rate optimisation..")

    # Optimize if not all rates are fixed, or if branch multipliers need optimization
    if not (delta_fixed and tau_fixed and lambda_fixed and not MLOR and not ml_branch_multipliers):
        print("#optimizing rates")

        from scipy.optimize import minimize

        def neg_log_lk(params):
            # Unpack the free parameters in the same order they were
            # appended to x0 below; fixed rates keep their closure values.
            idx = 0
            d_val = delta
            t_val = tau
            l_val = lambda_
            o_val = O_R

            if not delta_fixed and not DT_fixed:
                d_val = params[idx]
                idx += 1
            if not tau_fixed and not DT_fixed:
                t_val = params[idx]
                idx += 1
            if not lambda_fixed:
                l_val = params[idx]
                idx += 1
            if DT_fixed:
                t_val = params[idx]
                idx += 1
                d_val = t_val * model.scalar_parameter.get("DT_ratio", DT_ratio)
            if MLOR:
                o_val = params[idx]
                idx += 1

            for branch_e, rm_name in ml_branch_multipliers:
                multiplier = params[idx]
                idx += 1
                model.vector_parameter[rm_name][branch_e] = max(multiplier, 1e-7)

            model.set_model_parameter("delta", max(d_val, 1e-10))
            model.set_model_parameter("tau", max(t_val, 1e-10))
            model.set_model_parameter("lambda", max(l_val, 1e-10))
            if MLOR:
                model.set_model_parameter("O_R", max(o_val, 1e-10))
            model.calculate_undatedEs()
            lk = model.pun(ale, False, no_T)
            if lk <= 0:
                # Underflow / invalid region: return a huge penalty.
                return 1e50
            return -math.log(lk)

        # Build initial parameter vector
        x0 = []
        param_names = []
        if not delta_fixed and not DT_fixed:
            x0.append(delta)
            param_names.append("delta")
            print("#optimizing delta rate")
        if not tau_fixed and not DT_fixed:
            x0.append(tau)
            param_names.append("tau")
            print("#optimizing tau rate")
        if not lambda_fixed:
            x0.append(lambda_)
            param_names.append("lambda")
            print("#optimizing lambda rate")
        if DT_fixed:
            x0.append(tau)
            param_names.append("tau_DT")
            print("#optimizing delta and tau rates with fixed D/T ratio")
        if MLOR:
            x0.append(1.0)
            param_names.append("O_R")
            print("#optimizing O_R")

        for branch_e, rm_name in ml_branch_multipliers:
            x0.append(1.0)
            param_names.append(f"rm_{rm_name}_{branch_e}")
            print(f"#optimizing for branch {branch_e} ratemultiplier {rm_name}")

        x0 = np.array(x0)

        # Build per-parameter bounds matching C++ constraints
        rate_lo = 1e-6 if no_T else 1e-10
        bounds = []
        for pname in param_names:
            if pname == "O_R":
                bounds.append((1e-10, 1000.0))
            elif pname.startswith("rm_"):
                bounds.append((1e-7, 10000.0))
            else:
                bounds.append((rate_lo, 100.0))
        result = minimize(
            neg_log_lk,
            x0,
            method="Nelder-Mead",
            bounds=bounds,
            options={"maxiter": 10000, "xatol": 1e-6, "fatol": 1e-6, "adaptive": True},
        )

        # Extract optimized parameters (same order as x0 above).
        idx = 0
        if not delta_fixed and not DT_fixed:
            delta = result.x[idx]
            idx += 1
        if not tau_fixed and not DT_fixed:
            tau = result.x[idx]
            idx += 1
        if not lambda_fixed:
            lambda_ = result.x[idx]
            idx += 1
        if DT_fixed:
            tau = result.x[idx]
            idx += 1
            delta = tau * model.scalar_parameter.get("DT_ratio", DT_ratio)
        if MLOR:
            O_R = result.x[idx]
            idx += 1

        ml_rm_strings = []
        for branch_e, rm_name in ml_branch_multipliers:
            multiplier = result.x[idx]
            idx += 1
            model.vector_parameter[rm_name][branch_e] = multiplier
            ml_rm_strings.append(f"{rm_name}\t{branch_e}\t{multiplier};")

        mlll = -result.fun
    else:
        mlll = math.log(model.pun(ale, False, no_T))

    print()
    print(f"ML rates: delta={delta}; tau={tau}; lambda={lambda_}; O_R={O_R}.")
    if ml_branch_multipliers:
        # ml_rm_strings necessarily exists here: a non-empty
        # ml_branch_multipliers forces the optimisation branch above.
        print("ML rate multipliers:")
        for s in ml_rm_strings:
            print(s)
    print(f"LL={mlll}")

    # Set final rates for sampling
    model.set_model_parameter("delta", delta)
    model.set_model_parameter("tau", tau)
    model.set_model_parameter("lambda", lambda_)
    model.set_model_parameter("O_R", O_R)
    model.calculate_undatedEs()
    model.pun(ale, False, no_T)

    # Sample reconciled gene trees
    print("Sampling reconciled gene trees..")
    sample_strings = []
    total_events = {"D": 0.0, "T": 0.0, "L": 0.0, "S": 0.0}
    for i in range(int(samples)):
        # Reset event counts per sample
        model.MLRec_events.clear()
        model.Ttokens = []
        sample_tree = model.sample_undated(no_T)
        sample_strings.append(sample_tree)
        for key in total_events:
            total_events[key] += model.MLRec_events.get(key, 0.0)

    # Write .uml_rec output
    outname = ale_file + ".uml_rec"
    with open(outname, "w") as fout:
        fout.write(f"#ALEml_undated using ALE v{ALE_VERSION} by Szollosi GJ et al.; ssolo@elte.hu; CC BY-SA 3.0;\n\n")
        s_tree_str = model.string_parameter.get("S_with_ranks", model.string_parameter.get("S_un", Sstring))
        fout.write(f"S:\t{s_tree_str}\n")
        fout.write("\n")
        fout.write(f"Input ale from:\t{ale_file}\n")
        fout.write(f">logl: {mlll}\n")
        fout.write("rate of\t Duplications\tTransfers\tLosses\n")
        fout.write(f"ML \t{delta}\t{tau}\t{lambda_}\n")
        fout.write(f"{int(samples)} reconciled G-s:\n\n")
        for s in sample_strings:
            fout.write(s + "\n")
        fout.write("# of\t Duplications\tTransfers\tLosses\tSpeciations\n")
        div = samples if samples > 0 else 1
        fout.write(
            f"Total \t{total_events['D'] / div}\t"
            f"{total_events['T'] / div}\t"
            f"{total_events['L'] / div}\t"
            f"{total_events['S'] / div}\n"
        )
        fout.write("\n")
        fout.write("# of\t Duplications\tTransfers\tLosses\tOriginations\tcopies\tsingletons\textinction_prob\tpresence\tLL\n")
        fout.write(model.counts_string_undated(samples))

    # Output species tree
    if output_species_tree:
        with open(output_species_tree, "w") as fout:
            s_tree_str = model.string_parameter.get("S_with_ranks", model.string_parameter.get("S_un", Sstring))
            fout.write(s_tree_str + "\n")

    print(f"Results in: {outname}")

    # Write transfer file (average transfer frequency per sampled tree).
    t_name = ale_file + ".uTs"
    with open(t_name, "w") as tout:
        tout.write("#from\tto\tfreq.\n")
        for e in range(model.last_branch):
            for f in range(model.last_branch):
                if model.T_to_from[e][f] > 0:
                    if e < model.last_leaf:
                        e_name = f"{model._node_name[model._id_nodes[e].id]}({e})"
                    else:
                        e_name = str(e)
                    if f < model.last_leaf:
                        f_name = f"{model._node_name[model._id_nodes[f].id]}({f})"
                    else:
                        f_name = str(f)
                    tout.write(f"\t{e_name}\t{f_name}\t{model.T_to_from[e][f] / div}\n")
    print(f"Transfers in: {t_name}")
    return 0
_scale_double_constrained(value, maxi, lam): + """Scaling move on a positive real. Returns (new_value, hastings_ratio).""" + u = random.random() + scaling_factor = math.exp(lam * (u - 0.5)) + new_value = value * scaling_factor + if new_value < 0.00001: + new_value = 0.00001 + if new_value > maxi: + new_value = maxi + return new_value, scaling_factor + + +def _compute_exponential_log_probability(param, value): + """Log probability under Exponential(param).""" + if param <= 0: + return 0.0 + return math.log(param) - param * value + + +def _compute_log_prior(o, d, t, l, prior_o, prior_d, prior_t, prior_l): + pp = 0.0 + pp += _compute_exponential_log_probability(prior_o, o) + pp += _compute_exponential_log_probability(prior_d, d) + pp += _compute_exponential_log_probability(prior_t, t) + pp += _compute_exponential_log_probability(prior_l, l) + return pp + + +def _compute_log_lk(model, ale, o, d, t, l): + model.set_model_parameter("O_R", o) + model.set_model_parameter("delta", d) + model.set_model_parameter("tau", t) + model.set_model_parameter("lambda", l) + model.calculate_undatedEs() + lk = model.pun(ale) + if lk <= 0: + return -1e100 + return math.log(lk) + + +def ALEmcmc_undated(argv): + """MCMC sampling of reconciliation under the undated DTL model.""" + print(f"ALEmcmc using ALE v{ALE_VERSION}") + + if len(argv) < 2: + _print_subcommand_help("ALEmcmc_undated") + return 1 + + S_treefile = argv[0] + ale_file_arg = argv[1] + + if not os.path.isfile(S_treefile): + print(f"Error, file {S_treefile} does not seem accessible.") + sys.exit(1) + + with open(S_treefile) as f: + Sstring = f.readline().strip() + print(f"Read species tree from: {S_treefile}..") + + ale = load_ALE_from_file(ale_file_arg) + print(f"Read summary of tree sample for {ale.observations} trees from: {ale_file_arg}..") + + ale_file = os.path.basename(ale_file_arg) + + model = ExODTModel() + + samples = 100 + prior_origination = 1.0 + prior_delta = 0.01 + prior_tau = 0.01 + prior_lambda = 0.1 + 
def ALEmcmc_undated(argv):
    """MCMC sampling of reconciliation under the undated DTL model.

    argv[0] is the species-tree file, argv[1] the ``.ale`` file; remaining
    tokens are ``key=value`` / ``key:value`` options.  Runs a fixed 100
    burn-in iterations, then samples*sampling_rate Metropolis-Hastings
    iterations over (O_R, delta, tau, lambda) with exponential priors,
    sampling a reconciled gene tree every sampling_rate iterations.
    Writes a CSV trace, a ``.umcmc_rec`` file, and a ``_mcmc.uTs``
    transfer table.  Returns 0 on success.
    """
    print(f"ALEmcmc using ALE v{ALE_VERSION}")

    if len(argv) < 2:
        _print_subcommand_help("ALEmcmc_undated")
        return 1

    S_treefile = argv[0]
    ale_file_arg = argv[1]

    if not os.path.isfile(S_treefile):
        print(f"Error, file {S_treefile} does not seem accessible.")
        sys.exit(1)

    with open(S_treefile) as f:
        Sstring = f.readline().strip()
    print(f"Read species tree from: {S_treefile}..")

    ale = load_ALE_from_file(ale_file_arg)
    print(f"Read summary of tree sample for {ale.observations} trees from: {ale_file_arg}..")

    ale_file = os.path.basename(ale_file_arg)

    model = ExODTModel()

    # Defaults: rates of the exponential priors, not the DTL rates themselves.
    samples = 100
    prior_origination = 1.0
    prior_delta = 0.01
    prior_tau = 0.01
    prior_lambda = 0.1
    sampling_rate = 1
    beta = 1.0
    fraction_missing_file = ""
    output_species_tree = ""
    rate_multipliers = {}

    model.set_model_parameter("undatedBL", 0)
    model.set_model_parameter("reldate", 0)

    for arg in argv[2:]:
        if "=" in arg or ":" in arg:
            tokens = arg.replace("=", ":").split(":")
            key = tokens[0]
        else:
            continue

        if key == "sample":
            samples = int(tokens[1])
        elif key == "separators":
            model.set_model_parameter("gene_name_separators", tokens[1])
        elif key == "delta":
            prior_delta = float(tokens[1])
            print(f"# priorDelta fixed to {prior_delta}")
        elif key == "tau":
            prior_tau = float(tokens[1])
            print(f"# priorTau fixed to {prior_tau}")
        elif key == "lambda":
            prior_lambda = float(tokens[1])
            print(f"# priorLambda fixed to {prior_lambda}")
        elif key == "O_R":
            prior_origination = float(tokens[1])
            print(f"# priorOrigination set to {prior_origination}")
        elif key == "beta":
            beta = float(tokens[1])
            print(f"# beta set to {beta}")
        elif key == "sampling_rate":
            sampling_rate = int(tokens[1])
            print(f"# sampling_rate set to {sampling_rate}")
        elif key == "fraction_missing":
            fraction_missing_file = tokens[1]
            print(f"# File containing fractions of missing genes set to {fraction_missing_file}")
        elif key == "output_species_tree":
            val = tokens[1].lower()
            if val in ("y", "ye", "yes"):
                output_species_tree = ale_file + ".spTree"
                print(f"# outputting the annotated species tree to {output_species_tree}")
        elif key == "S_branch_lengths":
            model.set_model_parameter("undatedBL", 1)
            if len(tokens) == 1 or tokens[1] == "":
                model.set_model_parameter("root_BL", 1)
                print("# using branch lengths of input S tree as rate multipliers with 1 at root! ")
            else:
                root_rm = float(tokens[1])
                model.set_model_parameter("root_BL", root_rm)
                print(f"# using branch lengths of input S tree as rate multipliers with {root_rm} at root! ")
        elif key == "rate_multiplier":
            rate_name = tokens[1]
            e = int(tokens[2])
            rm = float(tokens[3])
            print(f"# rate multiplier for rate {rate_name} on branch with ID {e} set to {rm}")
            rate_multipliers.setdefault("rate_multiplier_" + rate_name, {})[e] = rm
        elif key == "reldate":
            print("Respecting relative ages from input S tree, please make sure input S tree is ultrametric!")
            model.set_model_parameter("reldate", 1)

    model.set_model_parameter("BOOTSTRAP_LABELS", "yes")
    model.set_model_parameter("seq_beta", beta)

    model.construct_undated(Sstring, fraction_missing_file)

    # Apply rate multipliers after construct_undated so vectors are initialized
    for rm_name, rm_dict in rate_multipliers.items():
        for e, rm_val in rm_dict.items():
            model.vector_parameter[rm_name][e] = rm_val

    # Draw initial values from exponential priors
    current_origination = random.expovariate(prior_origination) if prior_origination > 0 else 1.0
    current_delta = random.expovariate(prior_delta) if prior_delta > 0 else 0.01
    current_tau = random.expovariate(prior_tau) if prior_tau > 0 else 0.01
    current_lambda = random.expovariate(prior_lambda) if prior_lambda > 0 else 0.1

    new_origination = current_origination
    new_delta = current_delta
    new_tau = current_tau
    new_lambda = current_lambda

    current_log_lk = _compute_log_lk(model, ale, current_origination, current_delta, current_tau, current_lambda)
    current_log_prior = _compute_log_prior(
        current_origination, current_delta, current_tau, current_lambda,
        prior_origination, prior_delta, prior_tau, prior_lambda,
    )

    print(f"Initial logLK: {current_log_lk} and logPrior: {current_log_prior}")
    print("Reconciliation model initialised, starting DTL rate sampling..")

    # Move setup: one move type per parameter, uniform weights; the three
    # scale magnitudes give small/medium/large proposals.
    ORIGINATION_ID = 0
    DELTA_ID = 1
    LAMBDA_ID = 2
    TAU_ID = 3
    move_weights = [1.0, 1.0, 1.0, 1.0]
    max_sum_dtl = 10.0
    max_origination = 1000000.0
    scale_move_params = [0.1, 1.0, 10.0]

    # CSV trace file
    mcmc_outname = ale_file + "_umcmc.csv"
    mcmc_fh = open(mcmc_outname, "w")
    mcmc_fh.write("Iteration\tLogLk\tLogPrior\tOrigination\tDelta\tTau\tLambda\n")
    print("Iteration\tLogLk\tLogPrior\tOrigination\tDelta\tTau\tLambda")

    def _pick_weighted(weights):
        # Roulette-wheel selection over *weights*.
        total = sum(weights)
        r = random.random() * total
        cumsum = 0.0
        for i, w in enumerate(weights):
            cumsum += w
            if r < cumsum:
                return i
        return len(weights) - 1

    def _do_mcmc_step():
        nonlocal current_origination, current_delta, current_tau, current_lambda
        nonlocal new_origination, new_delta, new_tau, new_lambda
        nonlocal current_log_lk, current_log_prior

        move = _pick_weighted(move_weights)
        scale = scale_move_params[_pick_weighted([1, 1, 1])]
        hastings_ratio = 1.0

        # DTL moves are capped so delta + tau + lambda stays under max_sum_dtl.
        if move == ORIGINATION_ID:
            new_origination, hastings_ratio = _scale_double_constrained(current_origination, max_origination, scale)
        elif move == DELTA_ID:
            new_delta, hastings_ratio = _scale_double_constrained(current_delta, max_sum_dtl - current_lambda - current_tau, scale)
        elif move == LAMBDA_ID:
            new_lambda, hastings_ratio = _scale_double_constrained(current_lambda, max_sum_dtl - current_delta - current_tau, scale)
        elif move == TAU_ID:
            new_tau, hastings_ratio = _scale_double_constrained(current_tau, max_sum_dtl - current_lambda - current_delta, scale)

        new_log_lk = _compute_log_lk(model, ale, new_origination, new_delta, new_tau, new_lambda)
        new_log_prior = _compute_log_prior(
            new_origination, new_delta, new_tau, new_lambda,
            prior_origination, prior_delta, prior_tau, prior_lambda,
        )

        log_ratio = (new_log_lk + new_log_prior) - (current_log_lk + current_log_prior)
        # Clamp exponent to prevent OverflowError; large positive values
        # mean the proposal is overwhelmingly better, so auto-accept.
        acceptance_prob = math.exp(min(log_ratio, 700.0)) * hastings_ratio

        if random.random() < acceptance_prob:
            current_origination = new_origination
            current_delta = new_delta
            current_tau = new_tau
            current_lambda = new_lambda
            current_log_lk = new_log_lk
            current_log_prior = new_log_prior
        else:
            new_origination = current_origination
            new_delta = current_delta
            new_tau = current_tau
            new_lambda = current_lambda

    # BURNIN
    burnin_length = 100
    print(f"BURNIN during {burnin_length} iterations.")
    print("LogLk\tLogPrior\tOrigination\tDelta\tTau\tLambda")
    for i in range(burnin_length):
        _do_mcmc_step()
        print(f"{i}\t{current_log_lk}\t{current_log_prior}\t{current_origination}\t{current_delta}\t{current_tau}\t{current_lambda}")

    # MCMC
    total_iterations = int(samples * sampling_rate)
    print(f"MCMC during {total_iterations} iterations.")
    print("LogLk\tLogPrior\tOrigination\tDelta\tTau\tLambda")

    sample_strings = []
    num_speciations = 0.0
    num_duplications = 0.0
    num_transfers = 0.0
    num_losses = 0.0
    t_to_from_accum = {}  # {("from_name", "to_name"): count}

    for i in range(total_iterations):
        _do_mcmc_step()

        if i % sampling_rate == 0:
            model.MLRec_events.clear()
            model.reset_T_to_from()
            model.Ttokens = []
            # Recompute with current params for sampling
            model.set_model_parameter("O_R", current_origination)
            model.set_model_parameter("delta", current_delta)
            model.set_model_parameter("tau", current_tau)
            model.set_model_parameter("lambda", current_lambda)
            model.calculate_undatedEs()
            model.pun(ale)
            sample_tree = model.sample_undated()
            sample_strings.append(sample_tree)
            num_speciations += model.MLRec_events.get("S", 0)
            num_duplications += model.MLRec_events.get("D", 0)
            num_transfers += model.MLRec_events.get("T", 0)
            num_losses += model.MLRec_events.get("L", 0)
            # Accumulate T_to_from into running aggregate
            for e in range(model.last_branch):
                for f in range(model.last_branch):
                    if model.T_to_from[e][f] > 0:
                        e_name = model._node_name[model._id_nodes[e].id] if e < model.last_leaf else str(e)
                        f_name = model._node_name[model._id_nodes[f].id] if f < model.last_leaf else str(f)
                        key = (e_name, f_name)
                        t_to_from_accum[key] = t_to_from_accum.get(key, 0.0) + model.T_to_from[e][f]
            print(f"{i}\t{current_log_lk}\t{current_log_prior}\t{current_origination}\t{current_delta}\t{current_tau}\t{current_lambda}")

        # Trace row written every iteration; the console line above only on
        # sampling iterations.
        mcmc_fh.write(f"{i}\t{current_log_lk}\t{current_log_prior}\t{current_origination}\t{current_delta}\t{current_tau}\t{current_lambda}\n")

    mcmc_fh.close()

    # Write .umcmc_rec output
    outname = ale_file + ".umcmc_rec"
    with open(outname, "w") as fout:
        fout.write(f"#ALEmcmc_undated using ALE v{ALE_VERSION} by Szollosi GJ et al.; ssolo@elte.hu; CC BY-SA 3.0;\n\n")
        s_tree_str = model.string_parameter.get("S_with_ranks", model.string_parameter.get("S_un", Sstring))
        fout.write(f"S:\t{s_tree_str}\n")
        fout.write("\n")
        fout.write(f"Input ale from:\t{ale_file}\n")
        fout.write("\n")
        fout.write(f"{int(samples)} reconciled G-s:\n\n")
        for s in sample_strings:
            fout.write(s + "\n")
        fout.write("# of\t Duplications\tTransfers\tLosses\tSpeciations\n")
        div = samples if samples > 0 else 1
        fout.write(
            f"Total \t{num_duplications / div}\t"
            f"{num_transfers / div}\t"
            f"{num_losses / div}\t"
            f"{num_speciations / div}\n"
        )
        fout.write("\n")
        fout.write("# of\t Duplications\tTransfers\tLosses\tOriginations\tcopies\tsingletons\textinction_prob\tpresence\tLL\n")
        fout.write(model.counts_string_undated(samples))

    if output_species_tree:
        with open(output_species_tree, "w") as fout:
            s_tree_str = model.string_parameter.get("S_with_ranks", model.string_parameter.get("S_un", Sstring))
            fout.write(s_tree_str + "\n")

    print(f"Results in: {outname}")

    # Transfer file (from accumulated counts across all samples)
    t_name = ale_file + "_mcmc.uTs"
    with open(t_name, "w") as tout:
        tout.write("#from\tto\tfreq.\n")
        div = samples if samples > 0 else 1
        for (e_name, f_name), count in sorted(t_to_from_accum.items()):
            tout.write(f"\t{e_name}\t{f_name}\t{count / div}\n")
    print(f"Transfers in: {t_name}")
    return 0


# ---------------------------------------------------------------------------
# ALEcount
# ---------------------------------------------------------------------------

def ALEcount(argv):
    """Load an ALE file and print the number of amalgamated trees."""
    if not argv:
        _print_subcommand_help("ALEcount")
        return 1

    ale_file = argv[0]
    ale = load_ALE_from_file(ale_file)
    print(ale.count_trees())
    return 0


# ---------------------------------------------------------------------------
# ls_leaves
# ---------------------------------------------------------------------------

def ls_leaves(argv):
    """For each file: parse the first tree line, list leaf names with counts."""
    if not argv:
        _print_subcommand_help("ls_leaves")
        return 1

    names = {}
    for fname in argv:
        # NOTE(review): only the first line of each file is parsed —
        # presumably matching the C++ ls_leaves; confirm if multi-tree
        # files need to be supported.
        with open(fname) as f:
            tree_str = f.readline().strip()
        root = parse_newick(tree_str)
        leaves = get_leaves(root)
        for leaf in leaves:
            name = leaf.name
            names[name] = names.get(name, 0) + 1

    for name in sorted(names.keys()):
        print(f"{name} {names[name]}")
    return 0


# ---------------------------------------------------------------------------
# CCPscore
# ---------------------------------------------------------------------------

def CCPscore(argv):
    """Load an ALE file, read a tree, and print log(p(tree))."""
    if len(argv) < 2:
        _print_subcommand_help("CCPscore")
        return 1

    ale_file = argv[0]
    tree_file = argv[1]

    ale = load_ALE_from_file(ale_file)
    with open(tree_file) as f:
        tree_str = f.readline().strip()
    p_val = ale.p(tree_str)
    if p_val > 0:
        print(math.log(p_val))
    else:
        # Zero probability: the tree contains clades never observed.
        print("-inf")
    return 0
def ALEadd(argv):
    """Load an existing ALE file and add new trees.

    argv[0] = existing .ale file, argv[1] = file of additional Newick
    trees; optional key=value args: burnin, every, until, weight, outfile.
    Returns 0 on success.
    """
    print(f"ALEadd using ALE v{ALE_VERSION}")

    if len(argv) < 2:
        _print_subcommand_help("ALEadd")
        return 1

    ale_file = argv[0].strip()
    trees_file = argv[1].strip()
    # By default the updated state overwrites the input .ale file.
    ale_name = ale_file

    burnin = 0
    every = 1
    until = -1
    weight = 1.0

    for arg in argv[2:]:
        if "=" in arg:
            tokens = arg.split("=")
            key = tokens[0]
            val = tokens[1]
            if key == "burnin":
                burnin = int(val)
            elif key == "every":
                every = int(val)
            elif key == "until":
                until = int(val)
            elif key == "weight":
                weight = float(val)
            elif key == "outfile":
                ale_name = val

    ale = load_ALE_from_file(ale_file)
    print(".")

    # Read trees from file
    # NOTE(review): here the burnin test is `>=` with tree_i incremented
    # after the check, whereas ALEobserve increments first and uses `>` —
    # net effect is the same count of discarded trees; confirm against the
    # C++ originals if exact indices matter.
    trees = []
    tree_i = 0
    with open(trees_file) as f:
        for line in f:
            line = line.rstrip("\n")
            if "(" in line:
                if tree_i >= burnin and tree_i % every == 0:
                    trees.append(line)
                tree_i += 1

    print("..")

    observe_trees = trees
    if until > 0:
        observe_trees = trees[:until]

    ale.observation(observe_trees, weight=weight)

    print(f"# {len(observe_trees)} new tree(s) observed with weight {weight} from: {trees_file}")
    print(f"; {burnin} trees burnin discarded.")
    print(f"# .ale with {ale.observations} tree(s) from: {ale_file} and {trees_file}")
    ale.save_state(ale_name)
    print(f"# saved in {ale_name}")
    return 0


# ---------------------------------------------------------------------------
# ALEevaluate_undated
# ---------------------------------------------------------------------------

def ALEevaluate_undated(argv):
    """Evaluate a single gene tree under the undated DTL model.

    argv[0] = species tree file, argv[1] = gene tree file; optional
    key=value args set the fixed rates (delta, tau, lambda, O_R, beta),
    sample count, separators, fraction_missing file, and outputFiles.
    Prints the log-likelihood; with outputFiles=y also samples reconciled
    trees and writes .uml_rec / .uTs files.  Returns 0.
    """
    # NOTE(review): the banner says "ALEestimate" — presumably mirroring
    # the C++ binary's banner text; confirm before "fixing".
    print(f"ALEestimate using ALE v{ALE_VERSION}")

    if len(argv) < 2:
        _print_subcommand_help("ALEevaluate_undated")
        return 1

    species_tree_file = argv[0].strip()
    gene_tree_file = argv[1].strip()

    # Read species tree (the LAST line containing '(' wins).
    species_tree_str = ""
    with open(species_tree_file) as f:
        for line in f:
            line = line.strip()
            if "(" in line:
                species_tree_str = line
    print(f"\n\tRead species tree from: {species_tree_file}")

    # Read gene tree (same last-match convention).
    gene_tree_str = ""
    with open(gene_tree_file) as f:
        for line in f:
            line = line.strip()
            if "(" in line:
                gene_tree_str = line

    ale = ApproxPosterior(gene_tree_str)
    ale.observation([gene_tree_str])
    print(f"\n\tObserved {ale.observations} gene tree(s) from: {gene_tree_file}")

    model = ExODTModel()

    # Fixed-rate defaults; this command never optimises rates.
    samples = 100
    O_R = 1.0
    beta = 1.0
    delta = 0.01
    tau = 0.01
    lambda_ = 0.1
    fraction_missing_file = ""
    output_files = False

    for arg in argv[2:]:
        if "=" in arg:
            tokens = arg.split("=")
            key = tokens[0]
            val = tokens[1]
            if key == "sample":
                samples = int(val)
            elif key == "separators":
                model.set_model_parameter("gene_name_separators", val)
            elif key == "delta":
                delta = float(val)
                print(f"\n\tDelta fixed to {delta}")
            elif key == "tau":
                tau = float(val)
                print(f"\n\tTau fixed to {tau}")
            elif key == "lambda":
                lambda_ = float(val)
                print(f"Lambda fixed to {lambda_}")
            elif key == "O_R":
                O_R = float(val)
                print(f"\n\tO_R set to {O_R}")
            elif key == "beta":
                beta = float(val)
                print(f"\n\tBeta set to {beta}")
            elif key == "fraction_missing":
                fraction_missing_file = val
                print(f"\n\tFile containing fractions of missing genes set to {fraction_missing_file}")
            elif key == "outputFiles":
                if val.lower() in ("y", "yes"):
                    output_files = True

    model.set_model_parameter("BOOTSTRAP_LABELS", "yes")
    model.construct_undated(species_tree_str, fraction_missing_file)

    model.set_model_parameter("seq_beta", beta)
    model.set_model_parameter("O_R", O_R)
    model.set_model_parameter("delta", delta)
    model.set_model_parameter("tau", tau)
    model.set_model_parameter("lambda", lambda_)

    model.calculate_undatedEs()
    # Floor at the smallest positive float so log() never sees 0.
    loglk = math.log(max(model.pun(ale, True), sys.float_info.min))
    print(f"\n\tReconciliation model likelihood computed, logLk: {loglk}")

    if output_files:
        print("\n\tSampling reconciled gene trees..")
        sample_strings = []
        total_events = {"D": 0.0, "T": 0.0, "L": 0.0, "S": 0.0}
        for i in range(int(samples)):
            # Reset per-sample event bookkeeping before each draw.
            model.MLRec_events.clear()
            model.Ttokens = []
            sample_tree = model.sample_undated()
            sample_strings.append(sample_tree)
            for key in total_events:
                total_events[key] += model.MLRec_events.get(key, 0.0)

        ale_name = os.path.basename(gene_tree_file)
        outname = ale_name + ".uml_rec"
        with open(outname, "w") as fout:
            fout.write(f"#ALEevaluate using ALE v{ALE_VERSION}; CC BY-SA 3.0;\n\n")
            s_tree_str = model.string_parameter.get("S_with_ranks", model.string_parameter.get("S_un", species_tree_str))
            fout.write(f"S:\t{s_tree_str}\n")
            fout.write("\n")
            fout.write(f"Gene tree from:\t{gene_tree_file}\n")
            fout.write(f">logl: {loglk}\n")
            fout.write("rate of\t Duplications\tTransfers\tLosses\n")
            fout.write(f"\t{delta}\t{tau}\t{lambda_}\n")
            fout.write("\n")
            fout.write(f"{int(samples)} reconciled G-s:\n\n")
            for s in sample_strings:
                fout.write(s + "\n")
            fout.write("# of\t Duplications\tTransfers\tLosses\tSpeciations\n")
            div = samples if samples > 0 else 1
            fout.write(
                f"Total \t{total_events['D'] / div}\t"
                f"{total_events['T'] / div}\t"
                f"{total_events['L'] / div}\t"
                f"{total_events['S'] / div}\n"
            )
            fout.write("\n")
            fout.write("# of\t Duplications\tTransfers\tLosses\tOriginations\tcopies\n")
            fout.write(model.counts_string_undated(samples))

        print(f"Results in: {outname}")

        # Transfer file (per-sample average transfer frequencies).
        t_name = ale_name + ".uTs"
        with open(t_name, "w") as tout:
            tout.write("#from\tto\tfreq.\n")
            for e in range(model.last_branch):
                for f in range(model.last_branch):
                    if model.T_to_from[e][f] > 0:
                        if e < model.last_leaf:
                            e_name = model._node_name[model._id_nodes[e].id]
                        else:
                            e_name = str(e)
                        if f < model.last_leaf:
                            f_name = model._node_name[model._id_nodes[f].id]
                        else:
                            f_name = str(f)
                        tout.write(f"\t{e_name}\t{f_name}\t{model.T_to_from[e][f] / div}\n")
        print(f"Transfers in: {t_name}")

    return 0
return 0 + + +# --------------------------------------------------------------------------- +# Dispatch table and main entry point +# --------------------------------------------------------------------------- + +PROGRAMS = { + "ALEobserve": ALEobserve, + "ALEml_undated": ALEml_undated, + "ALEmcmc_undated": ALEmcmc_undated, + "ALEcount": ALEcount, + "ls_leaves": ls_leaves, + "CCPscore": CCPscore, + "ALEadd": ALEadd, + "ALEevaluate_undated": ALEevaluate_undated, +} + + +def _build_argparse(): + """Build an argparse parser with subcommands for help/documentation. + + The actual argument parsing is done by each subcommand's bespoke + parser for backward compatibility with the C++ key=value syntax. + argparse is used only for --help and subcommand dispatch. + """ + import argparse + + parser = argparse.ArgumentParser( + prog="python -m ALE_python", + description=( + "ALE (Amalgamated Likelihood Estimation) — " + "Python port of the phylogenetic reconciliation toolkit. " + "Options accept both C++ style (delta=0.01) and long-option " + "style (--delta 0.01). " + "Use 'python -m ALE_python --help' for per-command help." + ), + ) + parser.add_argument( + "--version", action="version", version=f"ALE {ALE_VERSION}", + ) + subparsers = parser.add_subparsers(dest="command", title="commands") + + # --- ALEobserve --- + p = subparsers.add_parser( + "ALEobserve", + help="Build an .ale file from a sample of gene trees", + description=( + "Read a sample of gene trees in Newick format and construct an " + "approximate posterior (ALE) file summarising clade frequencies. " + "The output .ale file is used as input by ALEml_undated, " + "ALEmcmc_undated, and other tools." 
+ ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Both key=value and --key value styles are accepted for options.\n\n" + "examples:\n" + " python -m ALE_python ALEobserve gene_trees.newicks\n" + " python -m ALE_python ALEobserve gene_trees.newicks burnin=1000\n" + " python -m ALE_python ALEobserve gene_trees.newicks --burnin 1000" + ), + ) + p.add_argument("gene_trees", nargs="+", help="one or more Newick gene tree files") + p.add_argument("burnin", nargs="?", default="burnin=0", help="burnin=N — discard the first N trees per file (default: 0)") + + # --- ALEml_undated --- + p = subparsers.add_parser( + "ALEml_undated", + help="Maximum-likelihood reconciliation under the undated DTL model", + description=( + "Optimise duplication, transfer, and loss (DTL) rates by maximum " + "likelihood under the undated reconciliation model, then sample " + "reconciled gene trees. Reads a species tree and an .ale file." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Both key=value and --key value styles are accepted for options.\n\n" + "options:\n" + " sample=N number of reconciled trees to sample (default: 100)\n" + " seed=INT random seed for reproducibility\n" + " separators=STR gene-name separator characters\n" + " delta=FLOAT fix duplication rate (skip optimisation)\n" + " tau=FLOAT fix transfer rate (tau<1e-10 disables transfers)\n" + " lambda=FLOAT fix loss rate\n" + " O_R=FLOAT origination-at-root multiplier (default: 1.0)\n" + " MLOR optimise the root origination multiplier\n" + " DT=FLOAT fix duplication/transfer ratio\n" + " beta=FLOAT weight of sequence evidence (default: 1.0)\n" + " fraction_missing=FILE file with per-species missing-gene fractions\n" + " S_branch_lengths:ROOT_BL use species-tree branch lengths as rate multipliers\n" + " reldate respect relative dates from an ultrametric species tree\n" + " output_species_tree=y write annotated species tree to a .spTree file\n" + " 
rate_multiplier:RATE:BRANCH:VALUE\n" + " set or optimise a per-branch rate multiplier\n" + " (value >= -1 → fixed; value < -1 → optimised)\n\n" + "examples:\n" + " python -m ALE_python ALEml_undated species.nwk genes.ale\n" + " python -m ALE_python ALEml_undated species.nwk genes.ale sample=1000 seed=42\n" + " python -m ALE_python ALEml_undated species.nwk genes.ale --delta 0.01 --tau 0.01 --lambda 0.1" + ), + ) + p.add_argument("species_tree", help="species tree in Newick format") + p.add_argument("ale_file", help=".ale file from ALEobserve") + p.add_argument("options", nargs="*", help="key=value options (see below)") + + # --- ALEmcmc_undated --- + p = subparsers.add_parser( + "ALEmcmc_undated", + help="MCMC sampling of reconciliation under the undated DTL model", + description=( + "Run a Metropolis-Hastings MCMC chain to sample DTL rates and " + "reconciled gene trees under the undated model. Priors are " + "exponential distributions parametrised by the given rate values." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Both key=value and --key value styles are accepted for options.\n\n" + "options:\n" + " sample=N total MCMC iterations (default: 100)\n" + " sampling_rate=N record a sample every N iterations (default: 1)\n" + " separators=STR gene-name separator characters\n" + " delta=FLOAT prior rate for duplications (default: 0.01)\n" + " tau=FLOAT prior rate for transfers (default: 0.01)\n" + " lambda=FLOAT prior rate for losses (default: 0.1)\n" + " O_R=FLOAT prior for origination at root (default: 1.0)\n" + " beta=FLOAT weight of sequence evidence (default: 1.0)\n" + " fraction_missing=FILE file with per-species missing-gene fractions\n" + " S_branch_lengths:ROOT_BL use species-tree branch lengths as rate multipliers\n" + " rate_multiplier:RATE:BRANCH:VALUE\n" + " fix a per-branch rate multiplier\n" + " output_species_tree=y write annotated species tree\n" + " reldate respect relative dates from an ultrametric species tree\n\n" + 
"examples:\n" + " python -m ALE_python ALEmcmc_undated species.nwk genes.ale sample=10000\n" + " python -m ALE_python ALEmcmc_undated species.nwk genes.ale --sample 5000 --sampling-rate 10" + ), + ) + p.add_argument("species_tree", help="species tree in Newick format") + p.add_argument("ale_file", help=".ale file from ALEobserve") + p.add_argument("options", nargs="*", help="key=value options (see below)") + + # --- ALEcount --- + p = subparsers.add_parser( + "ALEcount", + help="Print the number of amalgamated trees in an .ale file", + description="Load an .ale file and print the number of distinct amalgamated gene tree topologies.", + ) + p.add_argument("ale_file", help=".ale file to count trees from") + + # --- ls_leaves --- + p = subparsers.add_parser( + "ls_leaves", + help="List leaf names from one or more tree files", + description="Parse each Newick tree file and print sorted leaf names with occurrence counts.", + ) + p.add_argument("tree_files", nargs="+", help="one or more Newick tree files") + + # --- CCPscore --- + p = subparsers.add_parser( + "CCPscore", + help="Score a tree under the conditional clade probability model", + description="Load an .ale file and a tree, then print log(P(tree)) under the CCP model.", + ) + p.add_argument("ale_file", help=".ale file with clade posteriors") + p.add_argument("tree_file", help="Newick tree file to score") + + # --- ALEadd --- + p = subparsers.add_parser( + "ALEadd", + help="Add new tree observations to an existing .ale file", + description=( + "Load an existing .ale file and observe additional gene trees from " + "a Newick file, updating the clade frequencies." 
+ ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Both key=value and --key value styles are accepted for options.\n\n" + "options:\n" + " weight=FLOAT observation weight for new trees (default: 1.0)\n" + " burnin=N discard first N trees (default: 0)\n" + " every=N keep every Nth tree (default: 1)\n" + " until=N use at most N trees (default: all)\n" + " outfile=PATH output .ale file name (default: overwrite input)\n\n" + "examples:\n" + " python -m ALE_python ALEadd existing.ale new_trees.newicks\n" + " python -m ALE_python ALEadd existing.ale new_trees.newicks --burnin 100 --weight 0.5" + ), + ) + p.add_argument("ale_file", help="existing .ale file") + p.add_argument("gene_trees", help="Newick gene tree file to add") + p.add_argument("options", nargs="*", help="key=value options (see below)") + + # --- ALEevaluate_undated --- + p = subparsers.add_parser( + "ALEevaluate_undated", + help="Evaluate a single gene tree under the undated DTL model", + description=( + "Compute the reconciliation likelihood of a single gene tree " + "against a species tree under the undated DTL model with fixed rates." 
+ ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Both key=value and --key value styles are accepted for options.\n\n" + "options:\n" + " sample=N number of reconciled trees to sample (default: 100)\n" + " separators=STR gene-name separator characters\n" + " delta=FLOAT duplication rate (default: 0.01)\n" + " tau=FLOAT transfer rate (default: 0.01)\n" + " lambda=FLOAT loss rate (default: 0.1)\n" + " O_R=FLOAT origination-at-root multiplier (default: 1.0)\n" + " beta=FLOAT weight of sequence evidence (default: 1.0)\n" + " fraction_missing=FILE per-species missing-gene fractions\n" + " outputFiles=y write .uml_rec and .uTs output files\n\n" + "examples:\n" + " python -m ALE_python ALEevaluate_undated species.nwk gene.nwk\n" + " python -m ALE_python ALEevaluate_undated species.nwk gene.nwk --delta 0.02 --outputFiles y" + ), + ) + p.add_argument("species_tree", help="species tree in Newick format") + p.add_argument("gene_tree", help="gene tree in Newick format") + p.add_argument("options", nargs="*", help="key=value options (see below)") + + return parser + + +def main(): + # If invoked with no args, print help; for --help/--version let argparse handle it + if len(sys.argv) < 2: + parser = _build_argparse() + parser.print_help() + return + if sys.argv[1] in ("-h", "--help", "--version"): + parser = _build_argparse() + parser.parse_args() + return + + program_name = sys.argv[1] + + # If the subcommand itself is followed by --help, use argparse + if "--help" in sys.argv[2:] or "-h" in sys.argv[2:]: + parser = _build_argparse() + parser.parse_args() + return + + if program_name not in PROGRAMS: + print(f"Error: unknown program '{program_name}'") + print(f"Available programs: {', '.join(sorted(PROGRAMS.keys()))}") + sys.exit(1) + + # Normalize --key value args to key=value, then dispatch to the + # original bespoke parser for backward compatibility + remaining_args = _normalize_argv(sys.argv[2:]) + ret = PROGRAMS[program_name](remaining_args) + 
sys.exit(ret or 0) diff --git a/ALE_python/exodt.py b/ALE_python/exodt.py new file mode 100644 index 0000000..97599ce --- /dev/null +++ b/ALE_python/exodt.py @@ -0,0 +1,1273 @@ +"""exODT model -- core reconciliation model for ALE (undated). + +Ported from the C++ files: exODT.h, exODT.cpp, undated.cpp. +Focuses on the UNDATED model (most commonly used). + +All code by Szollosi GJ et al.; ssolo@elte.hu; GNU GPL 3.0; +Python port. +""" + +import math +import re +import sys +from collections import defaultdict + +from .newick import parse_newick, get_leaves, get_all_nodes, is_leaf +from .fraction_missing import read_fraction_missing_file + +# Smallest positive float -- used to avoid log(0) and division by zero +EPSILON = sys.float_info.min + + +class ExODTModel: + """Species-tree / gene-tree reconciliation model (undated variant). + + This class contains the description of a species tree with parameters of + duplication, transfer and loss. Given an ``ApproxPosterior`` object it can + compute the probability of the approximate posterior under the DTL model + and sample reconciled gene trees. 
+ """ + + # ------------------------------------------------------------------ + # Construction + # ------------------------------------------------------------------ + + def __init__(self): + # Parameters -------------------------------------------------- + self.scalar_parameter: dict[str, float] = {} + self.vector_parameter: dict[str, list[float]] = {} + self.string_parameter: dict[str, str] = {} + + # Species tree topology --------------------------------------- + self.father: dict[int, int] = {} # node_id -> parent_id + self.daughter: dict[int, int] = {} # node_id -> left child + self.son: dict[int, int] = {} # node_id -> right child + self.extant_species: dict[int, str] = {} # leaf_id -> species name + + self.last_branch: int = 0 + self.last_leaf: int = 0 + self.last_rank: int = 0 + self.root_i: int = -1 + + # Node objects keyed by name / branch id ---------------------- + self._name_node: dict[str, object] = {} + self._node_name: dict[int, str] = {} # node.id -> name + self._node_ids: dict[int, int] = {} # newick-node.id -> branch id + self._id_nodes: dict[int, object] = {} # branch id -> newick Node + + # Undated model arrays ---------------------------------------- + self.fm: list[float] = [] + self.uE: list[float] = [] + self.mPTE: float = 0.0 + self.mPTE_ancestral_correction: list[float] = [] + self.uq: list[list[float]] = [] + self.mPTuq: list[float] = [] + self.mPTuq_ancestral_correction: list[list[float]] = [] + + self.PD: list[float] = [] + self.wT: list[float] = [] + self.PL: list[float] = [] + self.PS: list[float] = [] + self.rmD: list[float] = [] + self.rmT: list[float] = [] + self.rmL: list[float] = [] + self.tau_norm: list[float] = [] + + self.ancestors: list[list[int]] = [] + self.below: dict[int, dict[int, int]] = {} + self.ancestral: dict[int, dict[int, int]] = {} + + # ALE pointer and gene-clade bookkeeping --------------------- + self.ale_pointer = None + self.g_ids: list[int] = [] + self.g_id_sizes: list[int] = [] + self.g_id2i: dict[int, 
int] = {} + self.gid_sps: dict[int, str] = {} + + # Event tracking ---------------------------------------------- + self.MLRec_events: dict[str, float] = defaultdict(float) + self.branch_counts: dict[str, list[float]] = {} + self.T_to_from: list[list[float]] = [] + self.Ttokens: list[str] = [] + + # Rank-to-label map (bootstrap values) ------------------------ + self.rank2label: dict[int, int] = {} + + # Set default parameters (mirrors C++ constructor) ------------ + self.string_parameter["BOOTSTRAP_LABELS"] = "no" + self.string_parameter["gene_name_separators"] = "_@" + self.scalar_parameter["species_field"] = 0 + self.scalar_parameter["event_node"] = 0 + self.scalar_parameter["min_bip_count"] = -1 + self.scalar_parameter["min_branch_lenghts"] = 0 + self.scalar_parameter["stem_length"] = 1 + self.scalar_parameter["D"] = 3 + self.scalar_parameter["grid_delta_t"] = 0.005 + self.scalar_parameter["min_D"] = 3 + self.scalar_parameter["DD"] = 10 + self.scalar_parameter["O_R"] = 1.0 + self.scalar_parameter["seq_beta"] = 1.0 + self.scalar_parameter["undatedBL"] = 0 + self.scalar_parameter["reldate"] = 0 + self.scalar_parameter["root_BL"] = 0 + + # ------------------------------------------------------------------ + # Tree height helper (average of the two children, like the C++) + # ------------------------------------------------------------------ + + def _height(self, node) -> float: + """Compute height of a node: 0 for leaves, otherwise average of + (child branch length + child height) over both children.""" + if is_leaf(node): + return 0.0 + sons = node.children + h0 = (sons[0].branch_length or 0.0) + self._height(sons[0]) + h1 = (sons[1].branch_length or 0.0) + self._height(sons[1]) + return 0.5 * (h0 + h1) + + # ------------------------------------------------------------------ + # construct_undated + # ------------------------------------------------------------------ + + def construct_undated(self, s_string: str, fraction_missing_file: str = ""): + """Construct 
the undated reconciliation model from a Newick species + tree string and (optionally) a fraction-missing file. + + Mirrors ``exODT_model::construct_undated`` in *undated.cpp*. + """ + # Reset maps + self.daughter.clear() + self.son.clear() + self._name_node.clear() + self._node_name.clear() + self._node_ids.clear() + self._id_nodes.clear() + + self.string_parameter["S_un"] = s_string + + # Parse the species tree + root = parse_newick(s_string) + all_nodes = get_all_nodes(root) # post-order + + # Build name_node / node_name maps + for nd in all_nodes: + if is_leaf(nd): + self._name_node[nd.name] = nd + self._node_name[nd.id] = nd.name + else: + leaf_names = sorted(lf.name for lf in get_leaves(nd)) + name = ".".join(leaf_names) + "." + self._name_node[name] = nd + self._node_name[nd.id] = name + + # Register leaves (alphabetical order, as C++ iterates std::map) + self.last_branch = 0 + self.last_leaf = 0 + + self.vector_parameter["BL_rate_multiplier"] = [] + self.vector_parameter["rate_multiplier_tau_to"] = [] + self.vector_parameter["rate_multiplier_tau_from"] = [] + self.vector_parameter["rate_multiplier_delta"] = [] + self.vector_parameter["rate_multiplier_lambda"] = [] + self.vector_parameter["rate_multiplier_O"] = [] + self.wT = [] + self.rmD = [] + self.rmT = [] + self.rmL = [] + + saw = set() # set of newick-node ids already registered + + # Sort leaf names alphabetically (std::map ordering) + leaf_names_sorted = sorted( + name for name, nd in self._name_node.items() if is_leaf(nd) + ) + + for lname in leaf_names_sorted: + node = self._name_node[lname] + self.extant_species[self.last_branch] = node.name + self._node_ids[node.id] = self.last_branch + self._id_nodes[self.last_branch] = node + self.last_branch += 1 + self.last_leaf += 1 + saw.add(node.id) + # Leaves have no children + self.daughter[self.last_branch] = -1 + self.son[self.last_branch] = -1 + self.vector_parameter["BL_rate_multiplier"].append( + node.branch_length if node.branch_length is not None 
else 0.0 + ) + self.vector_parameter["rate_multiplier_tau_to"].append(1.0) + self.vector_parameter["rate_multiplier_tau_from"].append(1.0) + self.wT.append(1.0) + self.rmD.append(1.0) + self.rmT.append(1.0) + self.rmL.append(1.0) + self.vector_parameter["rate_multiplier_delta"].append(1.0) + self.vector_parameter["rate_multiplier_lambda"].append(1.0) + self.vector_parameter["rate_multiplier_O"].append(1.0) + + # Ad-hoc post-order: propagate from leaves upward + next_generation = [ + self._name_node[lname] for lname in leaf_names_sorted + ] + + while next_generation: + new_generation = [] + for node in next_generation: + if node.parent is not None: + parent = node.parent + sons = parent.children + sister = sons[1] if sons[0] is node else sons[0] + if parent.id not in self._node_ids and sister.id in saw: + self._node_ids[parent.id] = self.last_branch + self._id_nodes[self.last_branch] = parent + # Note: C++ stores node->getDistanceToFather() which + # is the branch length of the *child* that triggered + # this registration. Replicating that behaviour. 
+ self.vector_parameter["BL_rate_multiplier"].append( + node.branch_length + if node.branch_length is not None + else 0.0 + ) + self.vector_parameter["rate_multiplier_tau_to"].append(1.0) + self.vector_parameter["rate_multiplier_tau_from"].append(1.0) + self.wT.append(1.0) + self.rmD.append(1.0) + self.rmT.append(1.0) + self.rmL.append(1.0) + self.vector_parameter["rate_multiplier_delta"].append(1.0) + self.vector_parameter["rate_multiplier_lambda"].append(1.0) + self.vector_parameter["rate_multiplier_O"].append(1.0) + self.last_branch += 1 + saw.add(parent.id) + new_generation.append(parent) + next_generation = new_generation + + # Build ``below`` matrix: below[e][f] = 1 iff + # height(father(e)) < height(f) + self.below.clear() + for e in range(self.last_branch - 1): + self.below[e] = {} + node_e = self._id_nodes[e] + h_father_e = self._height(node_e.parent) if node_e.parent is not None else 0.0 + for f in range(self.last_branch - 1): + node_f = self._id_nodes[f] + h_f = self._height(node_f) + self.below[e][f] = 1 if h_father_e < h_f else 0 + + # Extra rate-multiplier entries for the stem above root + # (index == last_branch, which is one past the last registered) + # In C++ this writes to vector_parameter["BL_rate_multiplier"][last_branch] + # which extends the vector by one. 
+ bl_vec = self.vector_parameter["BL_rate_multiplier"] + if len(bl_vec) <= self.last_branch: + bl_vec.append(self.scalar_parameter.get("root_BL", 0.0)) + else: + bl_vec[self.last_branch] = self.scalar_parameter.get("root_BL", 0.0) + self.vector_parameter["rate_multiplier_tau_to"].append(1.0) + self.vector_parameter["rate_multiplier_tau_from"].append(1.0) + self.wT.append(1.0) + self.vector_parameter["rate_multiplier_delta"].append(1.0) + self.vector_parameter["rate_multiplier_lambda"].append(1.0) + self.vector_parameter["rate_multiplier_O"].append(1.0) + + # Build ``ancestral`` and ``ancestors`` ---------------------------- + self.ancestors = [[] for _ in range(self.last_branch)] + self.ancestral = {} + for e in range(self.last_branch): + self.ancestral[e] = {} + for f in range(self.last_branch): + self.ancestral[e][f] = 0 + + for nd in all_nodes: + if nd.id not in self._node_ids: + continue + e = self._node_ids[nd.id] + walker = nd + while True: + f = self._node_ids[walker.id] + if not self.ancestral[e][f]: + self.ancestors[e].append(f) + self.ancestral[e][f] = 1 + if walker.parent is None: + break + walker = walker.parent + + # If reldate mode, also mark ``below`` entries as ancestral + if self.scalar_parameter.get("reldate", 0): + for e in range(self.last_branch): + for f in range(self.last_branch): + if self.below.get(e, {}).get(f, 0) == 1: + if not self.ancestral[e][f]: + self.ancestors[e].append(f) + self.ancestral[e][f] = 1 + + # Set daughter / son for internal nodes + for name, nd in self._name_node.items(): + if not is_leaf(nd): + sons = nd.children + self.daughter[self._node_ids[nd.id]] = self._node_ids[sons[0].id] + self.son[self._node_ids[nd.id]] = self._node_ids[sons[1].id] + + # Generate S_with_ranks: species tree with branch IDs as labels + # Mirrors C++: node->setBranchProperty("ID", rank) then treeToParenthesis + def _to_newick_with_ranks(nd): + if is_leaf(nd): + return nd.name + child_strs = [_to_newick_with_ranks(c) for c in nd.children] + rank = 
self._node_ids[nd.id] + return "(" + ",".join(child_strs) + ")" + str(rank) + self.string_parameter["S_with_ranks"] = _to_newick_with_ranks(root) + ";" + + # Initialise branch_counts + count_keys = [ + "Os", "Ds", "Ts", "Tfroms", "Ls", "count", + "presence", "saw", "O_LL", "copies", "singleton", + ] + for key in count_keys: + self.branch_counts[key] = [0.0] * self.last_branch + + # Initialise T_to_from matrix + self.T_to_from = [[0.0] * self.last_branch for _ in range(self.last_branch)] + + # last_rank mirrors C++ + self.last_rank = self.last_branch + + # Set N=1 (undated does not use population size) + self.set_model_parameter("N", 1.0) + + # Fraction missing + self.vector_parameter["fraction_missing"] = [0.0] * self.last_leaf + if fraction_missing_file: + frac_map = read_fraction_missing_file(fraction_missing_file) + idx = 0 + for lname in leaf_names_sorted: + node = self._name_node[lname] + species = node.name + if species in frac_map: + self.vector_parameter["fraction_missing"][idx] = frac_map[species] + idx += 1 + + # ------------------------------------------------------------------ + # set_model_parameter (three overloads merged into one) + # ------------------------------------------------------------------ + + def set_model_parameter(self, name: str, value): + """Set a model parameter. + + *value* can be a ``str``, a ``float``/``int``, or a ``list[float]``. + Mirrors the three C++ overloads. 
+ """ + if isinstance(value, str): + self.string_parameter[name] = value + return + + if isinstance(value, (list, tuple)): + self._set_model_parameter_vector(name, list(value)) + return + + # Scalar + value = float(value) + if name in ("delta", "tau", "lambda"): + N = self.vector_parameter.get("N", [1.0])[0] + self.vector_parameter[name] = [] + for _branch in range(self.last_branch): + if name == "tau": + self.vector_parameter[name].append(value / N) + else: + self.vector_parameter[name].append(value) + if name == "tau": + self.scalar_parameter[name + "_avg"] = value / N + else: + self.scalar_parameter[name + "_avg"] = value + elif name in ("N", "Delta_bar", "Lambda_bar"): + self.vector_parameter[name] = [value] * self.last_rank + else: + self.scalar_parameter[name] = value + + def _set_model_parameter_vector(self, name: str, value_vector: list[float]): + if name in ("delta", "tau", "lambda"): + N = self.vector_parameter.get("N", [1.0])[0] + self.vector_parameter[name] = [] + avg = 0.0 + for branch in range(self.last_branch): + if name == "tau": + v = value_vector[branch] / N + else: + v = value_vector[branch] + self.vector_parameter[name].append(v) + avg += v + self.scalar_parameter[name + "_avg"] = avg / max(self.last_branch, 1) + else: + self.vector_parameter[name] = [] + for rank in range(self.last_rank): + self.vector_parameter[name].append(value_vector[rank]) + + # ------------------------------------------------------------------ + # calculate_undatedEs + # ------------------------------------------------------------------ + + def calculate_undatedEs(self): + """Compute extinction probabilities for each branch (undated model). + + Mirrors ``exODT_model::calculate_undatedEs`` in *undated.cpp*. 
+ """ + self.uE = [] + self.fm = [] + self.mPTE_ancestral_correction = [] + self.PD = [] + self.tau_norm = [] + self.PL = [] + self.PS = [] + + use_bl = bool(self.scalar_parameter.get("undatedBL", 0)) + + # Compute raw rates per branch + for e in range(self.last_branch): + if use_bl: + self.wT[e] = ( + self.vector_parameter["rate_multiplier_tau_to"][e] + * self.vector_parameter["BL_rate_multiplier"][e] + ) + self.rmT[e] = ( + self.vector_parameter["tau"][e] + * self.vector_parameter["rate_multiplier_tau_from"][e] + * self.vector_parameter["BL_rate_multiplier"][e] + ) + self.rmD[e] = ( + self.vector_parameter["delta"][e] + * self.vector_parameter["rate_multiplier_delta"][e] + * self.vector_parameter["BL_rate_multiplier"][e] + ) + self.rmL[e] = ( + self.vector_parameter["lambda"][e] + * self.vector_parameter["rate_multiplier_lambda"][e] + * self.vector_parameter["BL_rate_multiplier"][e] + ) + else: + self.wT[e] = self.vector_parameter["rate_multiplier_tau_to"][e] + self.rmT[e] = ( + self.vector_parameter["tau"][e] + * self.vector_parameter["rate_multiplier_tau_from"][e] + ) + self.rmD[e] = ( + self.vector_parameter["delta"][e] + * self.vector_parameter["rate_multiplier_delta"][e] + ) + self.rmL[e] = ( + self.vector_parameter["lambda"][e] + * self.vector_parameter["rate_multiplier_lambda"][e] + ) + + tau_sum = sum(self.wT[f] for f in range(self.last_branch)) + + for e in range(self.last_branch): + P_D = self.rmD[e] + P_T = self.rmT[e] + P_L = self.rmL[e] + P_S = 1.0 + + total = P_D + P_T + P_L + P_S + P_D /= total + P_T /= total + P_L /= total + P_S /= total + + self.PD.append(P_D) + + # tau_norm[e] = (tau_sum - sum_wT_ancestors) / P_T + tn = tau_sum + for f in self.ancestors[e]: + tn -= self.wT[f] + if P_T > 0: + tn /= P_T + else: + tn = 1.0 # avoid division by zero + self.tau_norm.append(tn) + + self.PL.append(P_L) + self.PS.append(P_S) + self.uE.append(0.0) + + if e < self.last_leaf: + self.fm.append(self.vector_parameter["fraction_missing"][e]) + else: + 
self.fm.append(0.0) + + self.mPTE_ancestral_correction.append(0.0) + + # Iterative computation of extinction probabilities (4 iterations) + self.mPTE = 0.0 + for iteration in range(4): + new_mPTE = 0.0 + + if iteration > 0: + for e in range(self.last_branch): + self.mPTE_ancestral_correction[e] = 0.0 + for f in self.ancestors[e]: + self.mPTE_ancestral_correction[e] += self.wT[f] * self.uE[f] + + for e in range(self.last_branch): + if e < self.last_leaf: + # Leaf: no speciation + self.uE[e] = ( + self.PL[e] + + self.PD[e] * self.uE[e] * self.uE[e] + + self.uE[e] + * (self.mPTE - self.mPTE_ancestral_correction[e]) + / self.tau_norm[e] + ) + else: + f = self.daughter[e] + g = self.son[e] + self.uE[e] = ( + self.PL[e] + + self.PS[e] * self.uE[f] * self.uE[g] + + self.PD[e] * self.uE[e] * self.uE[e] + + self.uE[e] + * (self.mPTE - self.mPTE_ancestral_correction[e]) + / self.tau_norm[e] + ) + new_mPTE += self.wT[e] * self.uE[e] + self.mPTE = new_mPTE + + # One more update incorporating fraction missing + new_mPTE = 0.0 + for e in range(self.last_branch): + if e < self.last_leaf: + self.uE[e] = (1.0 - self.fm[e]) * self.uE[e] + self.fm[e] + else: + f = self.daughter[e] + g = self.son[e] + self.uE[e] = ( + self.PL[e] + + self.PS[e] * self.uE[f] * self.uE[g] + + self.PD[e] * self.uE[e] * self.uE[e] + + self.uE[e] + * (self.mPTE - self.mPTE_ancestral_correction[e]) + / self.tau_norm[e] + ) + new_mPTE += self.wT[e] * self.uE[e] + self.mPTE = new_mPTE + + # ------------------------------------------------------------------ + # Gene-name to species-name mapping helper + # ------------------------------------------------------------------ + + def _gene_to_species(self, gene_name: str) -> str: + """Extract species name from a gene name using the configured + separators and species field.""" + seps = self.string_parameter.get("gene_name_separators", "_@") + # Build regex pattern that splits on any separator character + pattern = "[" + re.escape(seps) + "]+" + tokens = 
re.split(pattern, gene_name) + field = int(self.scalar_parameter.get("species_field", 0)) + if field == -1: + return "_".join(tokens[1:]) + return tokens[field] + + # ------------------------------------------------------------------ + # pun -- compute log-likelihood under undated model + # ------------------------------------------------------------------ + + def pun(self, ale, verbose: bool = False, no_T: bool = False) -> float: + """Compute the (non-log) likelihood of *ale* under the undated model. + + Returns the likelihood value (not log-transformed). + Mirrors ``exODT_model::pun`` in *undated.cpp*. + + Parameters + ---------- + ale : ApproxPosterior + The approximate posterior object. + verbose : bool + If True, print gene-to-species mapping. + no_T : bool + If True, disable transfers. + """ + survive = 0.0 + root_sum = 0.0 + O_norm = 0.0 + self.mPTuq_ancestral_correction = [] + self.uq = [] + self.mPTuq = [] + self.ale_pointer = ale + + # Build ordered clade list (small to large) + self.g_ids = [] + self.g_id_sizes = [] + for size in sorted(ale.size_ordered_bips.keys()): + for g_id in ale.size_ordered_bips[size]: + self.g_ids.append(g_id) + self.g_id_sizes.append(size) + # Root bipartition (-1) handled separately + self.g_ids.append(-1) + self.g_id_sizes.append(ale.Gamma_size) + self.root_i = len(self.g_ids) - 1 + + # Gene <-> species mapping + self.gid_sps.clear() + species_set = set(self.extant_species.values()) + if verbose: + print("\nGene\t:\tSpecies") + for idx in range(len(self.g_ids)): + g_id = self.g_ids[idx] + if self.g_id_sizes[idx] == 1: + # Find the single leaf id in the bitset + leaf_id = None + for bit_i in range(ale.Gamma_size + 1): + if (ale.id_sets[g_id] >> bit_i) & 1: + leaf_id = bit_i + break + gene_name = ale.id_leaves[leaf_id] + species_name = self._gene_to_species(gene_name) + self.gid_sps[g_id] = species_name + if species_name not in species_set: + print( + f"Error: gene name {gene_name} is associated to " + f"species name 
{species_name} that cannot be found " + f"in the species tree.", + file=sys.stderr, + ) + sys.exit(-1) + if verbose: + print(f"{gene_name}\t:\t{species_name}") + + # Build g_id2i and initialise uq / mPTuq arrays + self.g_id2i = {} + for i in range(len(self.g_ids)): + g_id = self.g_ids[i] + self.g_id2i[g_id] = i + + if i >= len(self.uq): + self.uq.append([0.0] * self.last_branch) + self.mPTuq_ancestral_correction.append([0.0] * self.last_branch) + self.mPTuq.append(0.0) + else: + self.mPTuq[i] = 0.0 + for e in range(self.last_branch): + self.uq[i][e] = 0.0 + self.mPTuq_ancestral_correction[i][e] = 0.0 + + # Main iteration (4 rounds for convergence) + for _iteration in range(4): + for i in range(len(self.g_ids)): + new_mPTuq = 0.0 + + g_id = self.g_ids[i] + is_a_leaf = self.g_id_sizes[i] == 1 + + # Build partition lists + gp_is = [] + gpp_is = [] + p_part = [] + + if g_id != -1: + # Normal clade: iterate over Dip_counts + for (gp_id, gpp_id), count in ale.Dip_counts[g_id].items(): + gp_is.append(self.g_id2i[gp_id]) + gpp_is.append(self.g_id2i[gpp_id]) + if ale.Bip_counts.get(g_id, 0) <= self.scalar_parameter["min_bip_count"]: + p_part.append(0.0) + else: + p_part.append( + ale.p_dip(g_id, gp_id, gpp_id) + ** self.scalar_parameter["seq_beta"] + ) + else: + # Root bipartition: enumerate all bipartitions + bip_parts = {} + for gp_id in ale.Bip_counts: + gamma = ale.id_sets[gp_id] + not_gamma = gamma ^ ale.Gamma + if not_gamma not in ale.set_ids: + continue + gpp_id = ale.set_ids[not_gamma] + key = frozenset([gp_id, gpp_id]) + bip_parts[key] = 1 + for key in bip_parts: + parts = sorted(key) + gp_id = parts[0] + gpp_id = parts[1] if len(parts) > 1 else parts[0] + gp_is.append(self.g_id2i[parts[0]]) + gpp_is.append(self.g_id2i[parts[1]] if len(parts) > 1 else self.g_id2i[parts[0]]) + bip_count = ale.Bip_counts.get(gp_id, 0) + if bip_count <= self.scalar_parameter.get("min_bip_count", -1) and not ale.Gamma_size < 4: + p_part.append(0.0) + else: + p_part.append( + 
ale.p_bip(gp_id) + ** self.scalar_parameter["seq_beta"] + ) + + # Inner loop over branches + for e in range(self.last_branch): + uq_sum = 0.0 + + # S-leaf and G-leaf match + if ( + e < self.last_leaf + and is_a_leaf + and self.extant_species[e] == self.gid_sps.get(g_id, "") + ): + uq_sum += self.PS[e] * 1.0 + + # G internal: enumerate partitions + if not is_a_leaf: + n_parts = len(gp_is) + for pi in range(n_parts): + gp_i = gp_is[pi] + gpp_i = gpp_is[pi] + pp = p_part[pi] + + if not (e < self.last_leaf): + f = self.daughter[e] + g = self.son[e] + # Speciation event + uq_sum += self.PS[e] * ( + self.uq[gp_i][f] * self.uq[gpp_i][g] + + self.uq[gp_i][g] * self.uq[gpp_i][f] + ) * pp + + # Duplication event + uq_sum += self.PD[e] * ( + self.uq[gp_i][e] * self.uq[gpp_i][e] + ) * pp + + # Transfer event + if not no_T: + uq_sum += ( + self.uq[gp_i][e] + * ( + self.mPTuq[gpp_i] + - self.mPTuq_ancestral_correction[gpp_i][e] + ) + / self.tau_norm[e] + + self.uq[gpp_i][e] + * ( + self.mPTuq[gp_i] + - self.mPTuq_ancestral_correction[gp_i][e] + ) + / self.tau_norm[e] + ) * pp + + # SL (speciation-loss) event + if not (e < self.last_leaf): + f = self.daughter[e] + g = self.son[e] + uq_sum += self.PS[e] * ( + self.uq[i][f] * self.uE[g] + + self.uq[i][g] * self.uE[f] + ) + + # DL (duplication-loss) event + uq_sum += self.PD[e] * (self.uq[i][e] * self.uE[e] * 2) + + # TL (transfer-loss) event + if not no_T: + uq_sum += ( + ( + self.mPTuq[i] + - self.mPTuq_ancestral_correction[i][e] + ) + / self.tau_norm[e] + * self.uE[e] + + self.uq[i][e] + * (self.mPTE - self.mPTE_ancestral_correction[e]) + / self.tau_norm[e] + ) + + if uq_sum < EPSILON: + uq_sum = EPSILON + self.uq[i][e] = uq_sum + new_mPTuq += self.wT[e] * uq_sum + + # Update ancestral correction for this clade + self.mPTuq_ancestral_correction[i][e] = 0.0 + for f in self.ancestors[e]: + self.mPTuq_ancestral_correction[i][e] += self.wT[f] * uq_sum + + self.mPTuq[i] = new_mPTuq + + # Root summation + survive = 0.0 + root_sum = 0.0 
+ O_norm = 0.0 + + # Check for single-origination mode + single_O = any( + self.vector_parameter["rate_multiplier_O"][e] < 0 + for e in range(self.last_branch) + ) + if single_O: + for e in range(self.last_branch): + if self.vector_parameter["rate_multiplier_O"][e] > 0: + self.vector_parameter["rate_multiplier_O"][e] = 0.0 + else: + self.vector_parameter["rate_multiplier_O"][e] = 1.0 + + for e in range(self.last_branch): + O_p = self.vector_parameter["rate_multiplier_O"][e] + if e == (self.last_branch - 1) and O_p == 1.0: + O_p = self.scalar_parameter["O_R"] + O_norm += O_p + root_sum += self.uq[self.root_i][e] * O_p + survive += 1.0 - self.uE[e] + + for e in range(self.last_branch): + O_p = self.vector_parameter["rate_multiplier_O"][e] + if e == (self.last_branch - 1) and O_p == 1.0: + O_p = self.scalar_parameter["O_R"] + if O_p > 0 and O_norm > 0: + self.branch_counts["O_LL"][e] = ( + math.log(max(self.uq[self.root_i][e], EPSILON)) + + math.log(max(O_p, EPSILON)) + - math.log(O_norm) + ) + + # Return non-log likelihood (C++ returns this directly) + if survive <= 0 or O_norm <= 0: + return EPSILON + return root_sum / survive / O_norm * self.last_branch + + # ------------------------------------------------------------------ + # sample_undated -- sample a reconciled gene tree (entry point) + # ------------------------------------------------------------------ + + def sample_undated(self, no_T: bool = False) -> str: + """Sample a reconciled gene tree under the undated model. + + Returns a Newick-like string describing the reconciled tree. + Mirrors ``exODT_model::sample_undated()`` (no args) in *undated.cpp*. 
        """
        import random

        r = random.random()

        # First pass: total origination weight over all branches.
        root_sum = 0.0
        O_norm = 0.0
        for e in range(self.last_branch):
            self.branch_counts["saw"][e] = 0
            O_p = self.vector_parameter["rate_multiplier_O"][e]
            if e == (self.last_branch - 1) and O_p == 1.0:
                O_p = self.scalar_parameter["O_R"]
            O_norm += O_p
            root_sum += self.uq[len(self.g_ids) - 1][e] * O_p + EPSILON

        # Second pass: walk the same terms until the cumulative sum crosses
        # r * root_sum, then descend from the chosen origination branch.
        root_resum = 0.0
        for e in range(self.last_branch):
            O_p = self.vector_parameter["rate_multiplier_O"][e]
            if e == (self.last_branch - 1) and O_p == 1.0:
                O_p = self.scalar_parameter["O_R"]
            root_resum += self.uq[self.root_i][e] * O_p + EPSILON
            if r * root_sum < root_resum:
                self.register_O(e)
                return self._sample_undated_recursive(e, self.root_i, "O", "", no_T) + ";"

        # Unreachable unless floating-point drift breaks the resum walk.
        return "-!=-"

    # ------------------------------------------------------------------
    # _sample_undated_recursive -- recursive sampling
    # ------------------------------------------------------------------

    def _sample_undated_recursive(
        self,
        e: int,
        i: int,
        last_event: str,
        branch_string: str = "",
        no_T: bool = False,
    ) -> str:
        """Recursively sample events for a reconciled gene tree.

        Mirrors ``exODT_model::sample_undated(int e, int i, ...)`` in
        *undated.cpp*.
        """
        import random

        r = random.random()
        ale = self.ale_pointer

        is_a_leaf = False
        g_id = self.g_ids[i]
        if self.g_id_sizes[i] == 1:
            is_a_leaf = True

        # Branch length: mean observed bl of the clade, floored by the
        # configured minimum.  ("min_branch_lenghts" spelling matches the
        # key used elsewhere in the port — do not "fix" it in isolation.)
        if g_id in ale.Bip_counts and ale.Bip_counts[g_id] > 0:
            bl = max(
                ale.Bip_bls[g_id] / ale.Bip_counts[g_id],
                self.scalar_parameter["min_branch_lenghts"],
            )
        else:
            bl = max(
                ale.Bip_bls.get(g_id, 0.0) / ale.observations,
                self.scalar_parameter["min_branch_lenghts"],
            )
        branch_length = str(bl)

        # Build partition lists (same logic as pun)
        gp_is = []
        gpp_is = []
        p_part = []

        if g_id != -1:
            for (gp_id, gpp_id), count in ale.Dip_counts[g_id].items():
                gp_is.append(self.g_id2i[gp_id])
                gpp_is.append(self.g_id2i[gpp_id])
                if ale.Bip_counts.get(g_id, 0) <= self.scalar_parameter["min_bip_count"]:
                    p_part.append(0.0)
                else:
                    p_part.append(
                        ale.p_dip(g_id, gp_id, gpp_id)
                        ** self.scalar_parameter["seq_beta"]
                    )
        else:
            bip_parts = {}
            for gp_id in ale.Bip_counts:
                gamma = ale.id_sets[gp_id]
                not_gamma = gamma ^ ale.Gamma
                if not_gamma not in ale.set_ids:
                    continue
                gpp_id = ale.set_ids[not_gamma]
                key = frozenset([gp_id, gpp_id])
                bip_parts[key] = 1
            for key in bip_parts:
                parts = sorted(key)
                gp_is.append(self.g_id2i[parts[0]])
                gpp_is.append(self.g_id2i[parts[1]] if len(parts) > 1 else self.g_id2i[parts[0]])
                gp_id = parts[0]
                bip_count = ale.Bip_counts.get(gp_id, 0)
                if bip_count <= self.scalar_parameter.get("min_bip_count", -1) and not ale.Gamma_size < 4:
                    p_part.append(0.0)
                else:
                    p_part.append(
                        ale.p_bip(gp_id) ** self.scalar_parameter["seq_beta"]
                    )

        # Pass 1: compute the total event weight uq_sum.  The event terms
        # below MUST stay in exactly the same order as the resum walk in
        # pass 2, otherwise sampling probabilities are wrong.
        uq_sum = 0.0

        # S-leaf and G-leaf
        if e < self.last_leaf and is_a_leaf and self.extant_species[e] == self.gid_sps.get(g_id, ""):
            uq_sum += self.PS[e] * 1.0 + EPSILON

        # G internal
        if not is_a_leaf:
            n_parts = len(gp_is)
            for pi in range(n_parts):
                gp_i = gp_is[pi]
                gpp_i = gpp_is[pi]
                pp = p_part[pi]
                if not (e < self.last_leaf):
                    f = self.daughter[e]
                    g = self.son[e]
                    uq_sum += self.PS[e] * self.uq[gp_i][f] * self.uq[gpp_i][g] * pp + EPSILON
                    uq_sum += self.PS[e] * self.uq[gp_i][g] * self.uq[gpp_i][f] * pp + EPSILON
                # D event
                # NOTE(review): the sampler weighs duplication as uq*uq*2
                # while pun() accumulates uq*uq without the factor 2 —
                # confirm against undated.cpp before relying on sampled
                # duplication frequencies.
                uq_sum += self.PD[e] * (self.uq[gp_i][e] * self.uq[gpp_i][e] * 2) * pp + EPSILON
                # T event: any non-ancestral branch f can be the recipient
                for f in range(self.last_branch):
                    if not self.ancestral[e][f] and not no_T:
                        uq_sum += self.uq[gp_i][e] * (self.wT[f] / self.tau_norm[e]) * self.uq[gpp_i][f] * pp + EPSILON
                        uq_sum += self.uq[gpp_i][e] * (self.wT[f] / self.tau_norm[e]) * self.uq[gp_i][f] * pp + EPSILON

        # SL terms (clade survives in one daughter, lost in the other)
        if not (e < self.last_leaf):
            f = self.daughter[e]
            g = self.son[e]
            uq_sum += self.PS[e] * self.uq[i][f] * self.uE[g] + EPSILON
            uq_sum += self.PS[e] * self.uq[i][g] * self.uE[f] + EPSILON

        # DL term
        uq_sum += self.PD[e] * (self.uq[i][e] * self.uE[e] * 2) + EPSILON

        # TL terms
        for f in range(self.last_branch):
            if not self.ancestral[e][f] and not no_T:
                uq_sum += (self.wT[f] / self.tau_norm[e]) * self.uq[i][f] * self.uE[e] + EPSILON
                uq_sum += (self.wT[f] / self.tau_norm[e]) * self.uE[f] * self.uq[i][e] + EPSILON

        # Branch label: species name for terminal branches, index otherwise
        if not (e < self.last_leaf):
            estr = str(e)
        else:
            estr = self.extant_species[e]

        # Pass 2: resample — walk through the events in the same order,
        # accumulating until r * uq_sum is crossed, then emit that event.
        uq_resum = 0.0

        # S-leaf and G-leaf
        if e < self.last_leaf and is_a_leaf and self.extant_species[e] == self.gid_sps.get(g_id, ""):
            uq_resum += self.PS[e] * 1.0 + EPSILON
            if r * uq_sum < uq_resum:
                self.register_leafu(e, last_event)
                return ale.set2name(ale.id_sets[g_id]) + branch_string + ":" + branch_length

        # G internal
        if not is_a_leaf:
            n_parts = len(gp_is)
            for pi in range(n_parts):
                gp_i = gp_is[pi]
                gpp_i = gpp_is[pi]
                pp = p_part[pi]

                if not (e < self.last_leaf):
                    f = self.daughter[e]
                    g = self.son[e]

                    # S event (gp->f, gpp->g)
                    uq_resum += self.PS[e] * self.uq[gp_i][f] * self.uq[gpp_i][g] * pp + EPSILON
                    if r * uq_sum < uq_resum:
                        self.register_Su(e, last_event)
                        return (
                            "("
                            + self._sample_undated_recursive(f, gp_i, "S", "", no_T)
                            + ","
                            + self._sample_undated_recursive(g, gpp_i, "S", "", no_T)
                            + ")."
                            + estr
                            + branch_string
                            + ":"
                            + branch_length
                        )

                    # S event (gp->g, gpp->f)
                    uq_resum += self.PS[e] * self.uq[gp_i][g] * self.uq[gpp_i][f] * pp + EPSILON
                    if r * uq_sum < uq_resum:
                        self.register_Su(e, last_event)
                        return (
                            "("
                            + self._sample_undated_recursive(g, gp_i, "S", "", no_T)
                            + ","
                            + self._sample_undated_recursive(f, gpp_i, "S", "", no_T)
                            + ")."
                            + estr
                            + branch_string
                            + ":"
                            + branch_length
                        )

                # D event
                uq_resum += self.PD[e] * (self.uq[gp_i][e] * self.uq[gpp_i][e] * 2) * pp + EPSILON
                if r * uq_sum < uq_resum:
                    self.register_D(e)
                    return (
                        "("
                        + self._sample_undated_recursive(e, gp_i, "D", "", no_T)
                        + ","
                        + self._sample_undated_recursive(e, gpp_i, "D", "", no_T)
                        + ").D@"
                        + estr
                        + branch_string
                        + ":"
                        + branch_length
                    )

                # T event
                for f in range(self.last_branch):
                    if not self.ancestral[e][f] and not no_T:
                        fstr = self.extant_species[f] if f < self.last_leaf else str(f)

                        uq_resum += self.uq[gp_i][e] * (self.wT[f] / self.tau_norm[e]) * self.uq[gpp_i][f] * pp + EPSILON
                        if r * uq_sum < uq_resum:
                            self.register_Tfrom(e)
                            self.register_Tto(f)
                            self.register_T_to_from(e, f)
                            t_token = f"{estr}>{fstr}|{ale.set2name(ale.id_sets[self.g_ids[gpp_i]])}"
                            self.Ttokens.append(t_token)
                            return (
                                "("
                                + self._sample_undated_recursive(e, gp_i, "S", "", no_T)
                                + ","
                                + self._sample_undated_recursive(f, gpp_i, "T", "", no_T)
                                + ").T@"
                                + estr
                                + "->"
                                + fstr
                                + branch_string
                                + ":"
                                + branch_length
                            )

                        uq_resum += self.uq[gpp_i][e] * (self.wT[f] / self.tau_norm[e]) * self.uq[gp_i][f] * pp + EPSILON
                        if r * uq_sum < uq_resum:
                            self.register_Tfrom(e)
                            self.register_Tto(f)
                            self.register_T_to_from(e, f)
                            t_token = f"{estr}>{fstr}|{ale.set2name(ale.id_sets[self.g_ids[gp_i]])}"
                            self.Ttokens.append(t_token)
                            return (
                                "("
                                + self._sample_undated_recursive(e, gpp_i, "S", "", no_T)
                                + ","
                                + self._sample_undated_recursive(f, gp_i, "T", "", no_T)
                                + ").T@"
                                + estr
                                + "->"
                                + fstr
                                + branch_string
                                + ":"
                                + branch_length
                            )

        # SL event
        if not (e < self.last_leaf):
            f = self.daughter[e]
            g = self.son[e]
            uq_resum += self.PS[e] * self.uq[i][f] * self.uE[g] + EPSILON
            if r * uq_sum < uq_resum:
                self.register_Su(e, last_event)
                self.register_L(g)
                return self._sample_undated_recursive(f, i, "S", "." + estr + branch_string, no_T)

            uq_resum += self.PS[e] * self.uq[i][g] * self.uE[f] + EPSILON
            if r * uq_sum < uq_resum:
                self.register_Su(e, last_event)
                self.register_L(f)
                return self._sample_undated_recursive(g, i, "S", "." + estr + branch_string, no_T)

        # DL event: duplication immediately followed by loss of one copy is
        # a net no-op, so nothing is registered and we recurse in place.
        uq_resum += self.PD[e] * (self.uq[i][e] * self.uE[e] * 2) + EPSILON
        if r * uq_sum < uq_resum:
            return self._sample_undated_recursive(e, i, "S", branch_string, no_T)

        # TL event
        for f in range(self.last_branch):
            if not self.ancestral[e][f] and not no_T:
                fstr = self.extant_species[f] if f < self.last_leaf else str(f)

                uq_resum += (self.wT[f] / self.tau_norm[e]) * self.uq[i][f] * self.uE[e] + EPSILON
                if r * uq_sum < uq_resum:
                    self.register_Tfrom(e)
                    self.register_Tto(f)
                    self.register_T_to_from(e, f)
                    self.register_L(e)
                    return self._sample_undated_recursive(
                        f, i, "T", ".T@" + estr + "->" + fstr + branch_string, no_T
                    )

                uq_resum += (self.wT[f] / self.tau_norm[e]) * self.uE[f] * self.uq[i][e] + EPSILON
                if r * uq_sum < uq_resum:
                    # NOTE(review): this branch (transferred copy dies)
                    # registers no events and passes "" instead of the
                    # accumulated branch_string — verify against the C++
                    # TL case in undated.cpp.
                    return self._sample_undated_recursive(e, i, "S", "", no_T)

        print("sum error!", file=sys.stderr)
        return "-!=-"

    # ------------------------------------------------------------------
    # Event registration helpers
    # ------------------------------------------------------------------

    def register_O(self, e: int):
        """Register an origination event on branch *e*."""
        if e > -1:
            self.branch_counts["count"][e] += 1
            self.branch_counts["Os"][e] += 1

    def register_D(self, e: int) -> None:
        """Register a duplication event on branch *e*."""
        self.MLRec_events["D"] += 1
        if e > -1:
            self.branch_counts["Ds"][e] += 1

    def register_Tto(self, e: int) -> None:
        """Register a transfer-to event on branch *e*."""
        # The global T tally is incremented on the recipient side only.
        self.MLRec_events["T"] += 1
        if e > -1:
            self.branch_counts["Ts"][e] += 1

    def register_Tfrom(self, e: int) -> None:
        """Register a transfer-from event on branch *e*."""
        if e > -1:
            self.branch_counts["Tfroms"][e] += 1

    def register_L(self, e: int) -> None:
        """Register a loss event on branch *e*."""
        self.MLRec_events["L"] += 1
        if e > -1:
            self.branch_counts["Ls"][e] += 1

    def register_S(self, e: int) -> None:
        """Register a speciation event on branch *e* (dated model)."""
        self.MLRec_events["S"] += 1
        if e > -1:
            # A speciation on e produces one copy on each daughter branch.
            f = self.daughter[e]
            g = self.son[e]
            self.branch_counts["copies"][e] += 1
            self.branch_counts["count"][f] += 1
            self.branch_counts["count"][g] += 1

    def register_Su(self, e: int, last_event: str) -> None:
        """Register a speciation event on branch *e* (undated model)."""
        self.MLRec_events["S"] += 1
        if e > -1:
            f = self.daughter[e]
            g = self.son[e]
            # "singleton": first arrival on e via vertical descent (S/O),
            # as opposed to arriving via duplication or transfer.
            if last_event in ("S", "O"):
                self.branch_counts["singleton"][e] += 1
            self.branch_counts["copies"][e] += 1
            # "presence" counts each branch at most once per sampled tree;
            # "saw" is the per-sample visited flag reset by sample_undated.
            if self.branch_counts["saw"][e] == 0:
                self.branch_counts["presence"][e] += 1
                self.branch_counts["saw"][e] = 1
            self.branch_counts["count"][f] += 1
            self.branch_counts["count"][g] += 1

    def register_leafu(self, e: int, last_event: str) -> None:
        """Register reaching a leaf (undated model)."""
        if e > -1:
            self.branch_counts["copies"][e] += 1
            if self.branch_counts["saw"][e] == 0:
                self.branch_counts["presence"][e] += 1
                self.branch_counts["saw"][e] = 1
            if last_event in ("S", "O"):
                self.branch_counts["singleton"][e] += 1

    def register_leaf(self, e: int) -> None:
        """Register reaching a leaf (dated model)."""
        if e > -1:
            self.branch_counts["copies"][e] += 1

    def register_T_to_from(self, e: int, f: 
int):
        """Record a transfer from branch *e* to branch *f*."""
        self.T_to_from[e][f] += 1

    def reset_T_to_from(self) -> None:
        """Reset the transfer-direction matrix to zeros."""
        for e in range(self.last_branch):
            for f in range(self.last_branch):
                self.T_to_from[e][f] = 0

    def register_Ttoken(self, token: str) -> None:
        """Append a transfer token string."""
        self.Ttokens.append(token)

    # ------------------------------------------------------------------
    # counts_string_undated
    # ------------------------------------------------------------------

    def counts_string_undated(self, samples: float = 1.0) -> str:
        """Format event counts as a tab-separated string.

        One line per branch; event counters are averaged over *samples*
        sampled reconciliations.  Columns: branch kind, branch name,
        Ds, Ts, Ls, Os, copies, singleton, uE, presence, O_LL.
        Mirrors ``exODT_model::counts_string_undated`` in *undated.cpp*.
        """
        lines = []
        for e in range(self.last_branch):
            # Branches [0, last_leaf) are terminal (extant species).
            is_leaf_branch = e < self.last_leaf
            if is_leaf_branch:
                named = f"{self.extant_species[e]}({e})"
                prefix = "S_terminal_branch"
            else:
                named = str(e)
                prefix = "S_internal_branch"

            line = (
                f"{prefix}\t{named}\t"
                f"{self.branch_counts['Ds'][e] / samples}\t"
                f"{self.branch_counts['Ts'][e] / samples}\t"
                f"{self.branch_counts['Ls'][e] / samples}\t"
                f"{self.branch_counts['Os'][e] / samples}\t"
                f"{self.branch_counts['copies'][e] / samples}\t"
                f"{self.branch_counts['singleton'][e] / samples}\t"
                f"{self.uE[e]}\t"
                f"{self.branch_counts['presence'][e] / samples}\t"
                f"{self.branch_counts['O_LL'][e]}"
            )
            lines.append(line)
        return "\n".join(lines) + "\n"
diff --git a/ALE_python/fraction_missing.py b/ALE_python/fraction_missing.py
new file mode 100644
index 0000000..eac8454
--- /dev/null
+++ b/ALE_python/fraction_missing.py
@@ -0,0 +1,28 @@
"""Utility for reading ALE fraction missing files."""
def read_fraction_missing_file(filepath: str) -> dict[str, float]:
    """Read a fraction missing file and return a dict mapping species to fraction values.

    The file format is one line per species:
        species_name:fraction_value

    Args:
        filepath: Path to the fraction missing file. If empty string, returns empty dict.

    Returns:
        A dict mapping species name (str) to fraction value (float between 0 and 1).
    """
    if filepath == "":
        return {}

    result = {}
    with open(filepath, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # BUGFIX: split on the *last* colon only so that species names
            # containing ':' do not raise ValueError ("too many values to
            # unpack"), and strip stray whitespace from the key.
            species_name, fraction_value = line.rsplit(":", 1)
            result[species_name.strip()] = float(fraction_value)

    return result


# --- ALE_python/newick.py -------------------------------------------------


class Node:
    """A node in a rooted tree parsed from a Newick string.

    Attributes:
        name: node label ("" when unnamed).
        branch_length: branch length to the parent, or None when absent.
        children: list of child Node objects (empty for a leaf).
        parent: parent Node, or None for the root.
        id: integer id; the constructor hands out a provisional unique value
            from the class counter, which parse_newick later overwrites with
            post-order ids via _assign_ids.
    """

    # Class-level counter used only to hand out provisional unique ids.
    _next_id = 0

    def __init__(self, name="", branch_length=None, parent=None):
        self.name = name
        self.branch_length = branch_length
        self.children = []
        self.parent = parent
        self.id = Node._next_id
        Node._next_id += 1

    def add_child(self, child):
        """Attach *child* to this node (setting its parent) and return it."""
        child.parent = self
        self.children.append(child)
        return child

    def __repr__(self):
        if self.name:
            return f"Node({self.name!r})"
        return f"Node(id={self.id}, children={len(self.children)})"


def is_leaf(node):
    """Return True when *node* has no children."""
    return len(node.children) == 0


def get_leaves(node):
    """Return the leaves under *node* in left-to-right order."""
    leaves = []
    _collect_leaves(node, leaves)
    return leaves


def _collect_leaves(node, leaves):
    # Depth-first accumulation of leaf nodes into *leaves*.
    if is_leaf(node):
        leaves.append(node)
    else:
        for child in node.children:
            _collect_leaves(child, leaves)


def get_all_nodes(node):
    """Return every node under (and including) *node*, post-order."""
    nodes = []
    _collect_postorder(node, nodes)
    return nodes


def _collect_postorder(node, nodes):
    # Children first, then the node itself (post-order).
    for child in node.children:
        _collect_postorder(child, nodes)
    nodes.append(node)


def copy_tree(node, parent=None):
    """Deep-copy the subtree rooted at *node*, preserving node ids."""
    new_node = Node(
        name=node.name,
        branch_length=node.branch_length,
        parent=parent,
    )
    # Keep the source id (the constructor assigned a fresh provisional one).
    new_node.id = node.id
    for child in node.children:
        new_child = copy_tree(child, parent=new_node)
        new_node.children.append(new_child)
    return new_node
parse_newick(string): + string = string.strip() + if string.endswith(";"): + string = string[:-1] + tokens = _tokenize(string) + root, _ = _parse_tokens(tokens, 0) + _assign_ids(root) + return root + + +def _tokenize(s): + tokens = [] + i = 0 + n = len(s) + while i < n: + c = s[i] + if c in "(),": + tokens.append(c) + i += 1 + elif c == ":": + tokens.append(":") + i += 1 + start = i + while i < n and s[i] not in "(),;:": + i += 1 + tokens.append(s[start:i]) + else: + start = i + while i < n and s[i] not in "(),;:": + i += 1 + tokens.append(s[start:i]) + return tokens + + +def _parse_tokens(tokens, pos): + node = Node() + + if pos < len(tokens) and tokens[pos] == "(": + pos += 1 # consume '(' + child, pos = _parse_tokens(tokens, pos) + node.add_child(child) + + while pos < len(tokens) and tokens[pos] == ",": + pos += 1 # consume ',' + child, pos = _parse_tokens(tokens, pos) + node.add_child(child) + + if pos < len(tokens) and tokens[pos] == ")": + pos += 1 # consume ')' + + # Parse label (name or bootstrap) after ')' or as a leaf name + if pos < len(tokens) and tokens[pos] not in "(),:": + node.name = tokens[pos] + pos += 1 + + # Parse branch length + if pos < len(tokens) and tokens[pos] == ":": + pos += 1 # consume ':' + if pos < len(tokens): + try: + node.branch_length = float(tokens[pos]) + except ValueError: + node.branch_length = None + pos += 1 + + return node, pos + + +def _assign_ids(node): + counter = [0] + _assign_ids_postorder(node, counter) + + +def _assign_ids_postorder(node, counter): + for child in node.children: + _assign_ids_postorder(child, counter) + node.id = counter[0] + counter[0] += 1 + + +def to_newick(node): + return _to_newick_recursive(node) + ";" + + +def _to_newick_recursive(node): + result = "" + if node.children: + child_strs = [_to_newick_recursive(c) for c in node.children] + result = "(" + ",".join(child_strs) + ")" + if node.name: + result += node.name + if node.branch_length is not None: + result += ":" + str(node.branch_length) + 
    return result


diff --git a/ALE_python/tests/__init__.py b/ALE_python/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ALE_python/tests/conftest.py b/ALE_python/tests/conftest.py
new file mode 100644
index 0000000..daec29e
--- /dev/null
+++ b/ALE_python/tests/conftest.py
@@ -0,0 +1,31 @@
import pytest
import os
import subprocess
import tempfile
import shutil

@pytest.fixture
def project_root():
    # tests/ -> ALE_python/ -> repository root
    return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

@pytest.fixture
def example_data(project_root):
    # Directory holding the example species/gene trees used by the tests
    return os.path.join(project_root, "example_data")

@pytest.fixture
def cpp_bin(project_root):
    # Location of the compiled C++ reference binaries
    return os.path.join(project_root, "build", "bin")

@pytest.fixture
def species_tree(example_data):
    return os.path.join(example_data, "S.tree")

@pytest.fixture
def gene_trees(example_data):
    return os.path.join(example_data, "HBG745965_real.1.treelist")

@pytest.fixture
def tmp_dir():
    # Fresh temporary directory, removed after each test
    d = tempfile.mkdtemp()
    yield d
    shutil.rmtree(d)
diff --git a/ALE_python/tests/test_ale_core.py b/ALE_python/tests/test_ale_core.py
new file mode 100644
index 0000000..97db705
--- /dev/null
+++ b/ALE_python/tests/test_ale_core.py
@@ -0,0 +1,218 @@
"""Tests for core ALE data structures (ApproxPosterior, newick, etc.)."""

import os
import tempfile

import pytest

from ALE_python.ale import ApproxPosterior
from ALE_python.ale_util import (
    observe_ALE_from_strings,
    load_ALE_from_file,
    save_ALE_to_file,
)
from ALE_python.newick import parse_newick, get_leaves


class TestConstructSimpleTree:

    def test_construct_simple_tree(self):
        """Construct from '(A:1,B:1,C:1);', verify Gamma_size=3 and leaf_ids."""
        ale = ApproxPosterior("(A:1,B:1,C:1);")

        assert ale.Gamma_size == 3, f"Expected Gamma_size=3, got {ale.Gamma_size}"
        assert len(ale.leaf_ids) == 3, (
            f"Expected 3 leaf_ids, got {len(ale.leaf_ids)}"
        )
        assert set(ale.leaf_ids.keys()) == {"A", "B", "C"}

        # IDs should be 1-based 
and assigned alphabetically + assert ale.leaf_ids["A"] == 1 + assert ale.leaf_ids["B"] == 2 + assert ale.leaf_ids["C"] == 3 + + # Reverse mapping + assert ale.id_leaves[1] == "A" + assert ale.id_leaves[2] == "B" + assert ale.id_leaves[3] == "C" + + +class TestDecomposeAndObserve: + + def test_decompose_and_observe(self): + """Observe multiple trees, verify Bip_counts are populated.""" + trees = [ + "((A:1,B:1):1,C:1);", + "((A:1,C:1):1,B:1);", + "((A:1,B:1):1,C:1);", + "((B:1,C:1):1,A:1);", + ] + + ale = observe_ALE_from_strings(trees) + + assert ale.observations == pytest.approx(4.0), ( + f"Expected 4 observations, got {ale.observations}" + ) + assert ale.Gamma_size == 3 + + # With 3 leaves, all bipartitions are trivial (size 1 or 2), + # so Bip_counts should be non-empty + assert len(ale.Bip_counts) > 0, "Bip_counts should not be empty" + + +class TestSaveLoadRoundtrip: + + def test_save_load_roundtrip(self): + """Save ALE to file, load it back, compare all fields.""" + trees = [ + "((A:0.5,B:0.3):0.2,(C:0.4,D:0.6):0.1);", + "((A:0.5,C:0.3):0.2,(B:0.4,D:0.6):0.1);", + "((A:0.5,B:0.3):0.2,(C:0.4,D:0.6):0.1);", + "(((A:0.5,B:0.3):0.1,C:0.4):0.2,D:0.6);", + ] + original = observe_ALE_from_strings(trees) + + with tempfile.NamedTemporaryFile(suffix=".ale", delete=False) as tmp: + tmp_path = tmp.name + + try: + save_ALE_to_file(original, tmp_path) + loaded = load_ALE_from_file(tmp_path) + + # Compare observations + assert loaded.observations == pytest.approx(original.observations), ( + f"observations mismatch: {loaded.observations} vs {original.observations}" + ) + + # Compare Gamma_size + assert loaded.Gamma_size == original.Gamma_size + + # Compare leaf_ids + assert loaded.leaf_ids == original.leaf_ids + + # Compare Bip_counts + assert set(loaded.Bip_counts.keys()) == set(original.Bip_counts.keys()) + for key in original.Bip_counts: + assert loaded.Bip_counts[key] == pytest.approx( + original.Bip_counts[key], rel=1e-6 + ), f"Bip_counts[{key}] mismatch" + + # Compare 
non-empty Dip_counts entries (empty trailing entries + # may differ because decompose pushes empty dicts that are not + # serialised, so load_state pads to last_leafset_id+1 only). + orig_nonempty = {i: d for i, d in enumerate(original.Dip_counts) if d} + load_nonempty = {i: d for i, d in enumerate(loaded.Dip_counts) if d} + assert set(orig_nonempty.keys()) == set(load_nonempty.keys()), ( + f"Dip_counts non-empty index mismatch" + ) + for idx in orig_nonempty: + for key in orig_nonempty[idx]: + assert load_nonempty[idx][key] == pytest.approx( + orig_nonempty[idx][key], rel=1e-6 + ) + + # Compare set_ids mapping + for bitset, sid in original.set_ids.items(): + assert bitset in loaded.set_ids, ( + f"Bitset {bitset} missing from loaded set_ids" + ) + assert loaded.set_ids[bitset] == sid + + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + +class TestPBip: + + def test_p_bip_simple(self): + """Test bipartition probability calculation on a small example.""" + # With fewer than 4 leaves, p_bip should always return 1.0 + trees_3 = ["((A:1,B:1):1,C:1);"] * 5 + ale_3 = observe_ALE_from_strings(trees_3) + + for g_id in ale_3.Bip_counts: + assert ale_3.p_bip(g_id) == pytest.approx(1.0), ( + f"p_bip({g_id}) should be 1.0 for 3 leaves" + ) + + # With 4+ leaves, probabilities depend on observation frequencies + trees_4 = [ + "((A:1,B:1):1,(C:1,D:1):1);", + "((A:1,B:1):1,(C:1,D:1):1);", + "((A:1,C:1):1,(B:1,D:1):1);", + ] + ale_4 = observe_ALE_from_strings(trees_4) + + # Probabilities should be between 0 and 1 + for g_id in ale_4.Bip_counts: + p = ale_4.p_bip(g_id) + assert 0.0 <= p <= 1.0, f"p_bip({g_id})={p} out of range" + + # Leaf-level bipartitions (size 1) should have probability 1.0 + # because they appear in every tree + for g_id, bitset in ale_4.id_sets.items(): + size = ale_4.set_sizes.get(g_id, 0) + if size == 1: + assert ale_4.p_bip(g_id) == pytest.approx(1.0), ( + f"Leaf bipartition p_bip({g_id}) should be 1.0" + ) + + +class TestMppTree: + + def 
test_mpp_tree_simple(self): + """Test mpp tree on a small example with a clear majority topology.""" + # Give one topology a strong majority + trees = ( + ["((A:1,B:1):1,(C:1,D:1):1);"] * 8 + + ["((A:1,C:1):1,(B:1,D:1):1);"] * 2 + ) + ale = observe_ALE_from_strings(trees) + mpp_str, mpp_pp = ale.mpp_tree() + + assert len(mpp_str) > 0, "mpp_tree returned empty string" + assert mpp_pp > 0.0, f"mpp posterior probability should be > 0" + + # Parse and verify leaf count + root = parse_newick(mpp_str) + leaves = get_leaves(root) + leaf_names = sorted(n.name for n in leaves) + + assert len(leaf_names) == 4 + assert set(leaf_names) == {"A", "B", "C", "D"} + + +class TestCountTrees: + + def test_count_trees(self): + """Test count_trees function returns reasonable values.""" + # With 3 leaves, there are exactly 3 unrooted topologies, + # but amalgamated count depends on observed bipartitions + trees_3 = [ + "((A:1,B:1):1,C:1);", + "((A:1,C:1):1,B:1);", + "((B:1,C:1):1,A:1);", + ] + ale_3 = observe_ALE_from_strings(trees_3) + count_3 = ale_3.count_trees() + + # count_trees should return a positive number + assert count_3 > 0, f"count_trees returned {count_3}, expected > 0" + + # For 4 leaves with all topologies observed + trees_4 = [ + "((A:1,B:1):1,(C:1,D:1):1);", + "((A:1,C:1):1,(B:1,D:1):1);", + "((A:1,D:1):1,(B:1,C:1):1);", + ] + ale_4 = observe_ALE_from_strings(trees_4) + count_4 = ale_4.count_trees() + + assert count_4 > 0, f"count_trees returned {count_4}, expected > 0" + # With 4 leaves, count_trees counts amalgamable rooted trees (not + # unrooted topologies). Each of the 3 unrooted topologies has 5 + # rootings, giving 15 rooted trees total. 
+ assert count_4 == pytest.approx(15.0), ( + f"count_trees for 4 leaves (all topologies observed) should be 15, got {count_4}" + ) diff --git a/ALE_python/tests/test_ml_undated.py b/ALE_python/tests/test_ml_undated.py new file mode 100644 index 0000000..c174cc2 --- /dev/null +++ b/ALE_python/tests/test_ml_undated.py @@ -0,0 +1,247 @@ +"""Tests for ALEml_undated -- comparing Python output with C++ output.""" + +import math +import os +import re +import shutil +import subprocess + +import pytest + +from ALE_python.ale import ApproxPosterior +from ALE_python.ale_util import load_ALE_from_file, observe_ALE_from_file +from ALE_python.exodt import ExODTModel +from ALE_python.newick import parse_newick, get_leaves + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _prepare_inputs(species_tree, gene_trees, tmp_dir, burnin=1000): + """Copy input files to tmp_dir, run Python observe, return paths.""" + s_copy = os.path.join(tmp_dir, os.path.basename(species_tree)) + g_copy = os.path.join(tmp_dir, os.path.basename(gene_trees)) + shutil.copy(species_tree, s_copy) + shutil.copy(gene_trees, g_copy) + + ale = observe_ALE_from_file(g_copy, burnin=burnin) + ale_path = g_copy + ".ale" + ale.save_state(ale_path) + return s_copy, ale_path + + +def _run_python_ml_undated_fixed(species_tree_path, ale_path, delta=0.01, + tau=0.01, lambda_=0.1): + """Run Python ALEml_undated with fixed rates (no optimization) and return + the log-likelihood.""" + with open(species_tree_path) as f: + Sstring = f.readline().strip() + + ale = load_ALE_from_file(ale_path) + + model = ExODTModel() + model.set_model_parameter("BOOTSTRAP_LABELS", "yes") + model.set_model_parameter("undatedBL", 0) + model.set_model_parameter("reldate", 0) + model.construct_undated(Sstring) + + model.set_model_parameter("seq_beta", 1.0) + model.set_model_parameter("O_R", 1.0) + 
model.set_model_parameter("delta", delta) + model.set_model_parameter("tau", tau) + model.set_model_parameter("lambda", lambda_) + + model.calculate_undatedEs() + lk = model.pun(ale, False, False) + if lk <= 0: + return -1e50 + return math.log(lk) + + +def _run_cpp_ml_undated_fixed(cpp_bin, species_tree_path, ale_path, tmp_dir, + delta=0.01, tau=0.01, lambda_=0.1): + """Run C++ ALEml_undated with fixed rates and extract the LL from output.""" + exe = os.path.join(cpp_bin, "ALEml_undated") + result = subprocess.run( + [exe, species_tree_path, ale_path, + f"delta={delta}", f"tau={tau}", f"lambda={lambda_}", + "sample=1"], + capture_output=True, + text=True, + cwd=tmp_dir, + timeout=600, + ) + assert result.returncode == 0, f"C++ ALEml_undated failed:\n{result.stderr}" + + # Extract LL from stdout + for line in result.stdout.splitlines(): + if line.startswith("LL="): + return float(line.split("=")[1].strip()) + # Also check in the .uml_rec file + rec_files = [f for f in os.listdir(tmp_dir) if f.endswith(".uml_rec")] + for rec_file in rec_files: + with open(os.path.join(tmp_dir, rec_file)) as fh: + for line in fh: + m = re.search(r">logl:\s*([-\d.eE+]+)", line) + if m: + return float(m.group(1)) + + pytest.fail("Could not extract LL from C++ ALEml_undated output") + + +def _run_python_ml_undated_optimized(species_tree_path, ale_path): + """Run Python ALEml_undated with optimization and return the LL.""" + from ALE_python.cli import ALEml_undated as cli_ml_undated + + # The CLI function takes argv and writes output files. + # We call it from the directory containing the ale file. 
+ cwd_save = os.getcwd() + out_dir = os.path.dirname(ale_path) + os.chdir(out_dir) + try: + cli_ml_undated([species_tree_path, ale_path, "sample=10", "seed=42"]) + finally: + os.chdir(cwd_save) + + # Extract LL from output .uml_rec file + basename_s = os.path.basename(species_tree_path) + basename_ale = os.path.basename(ale_path) + rec_name = basename_s + "_" + basename_ale + ".uml_rec" + rec_path = os.path.join(out_dir, rec_name) + assert os.path.isfile(rec_path), f"Expected output file not found: {rec_path}" + + with open(rec_path) as fh: + for line in fh: + m = re.search(r">logl:\s*([-\d.eE+]+)", line) + if m: + return float(m.group(1)) + + pytest.fail("Could not extract LL from Python ALEml_undated output") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestMLUndated: + + def test_ml_undated_log_likelihood(self, species_tree, gene_trees, + cpp_bin, tmp_dir): + """Run Python ALEml_undated with fixed rates delta=0.01 tau=0.01 + lambda=0.1, verify LL is close to C++ LL.""" + s_copy, ale_path = _prepare_inputs(species_tree, gene_trees, tmp_dir) + + py_ll = _run_python_ml_undated_fixed(s_copy, ale_path) + cpp_ll = _run_cpp_ml_undated_fixed(cpp_bin, s_copy, ale_path, tmp_dir) + + # Allow some tolerance for floating-point differences between + # the Python and C++ implementations + assert py_ll == pytest.approx(cpp_ll, abs=1.0), ( + f"Python LL={py_ll}, C++ LL={cpp_ll}" + ) + + @pytest.mark.slow + def test_ml_undated_optimized_rates(self, species_tree, gene_trees, tmp_dir): + """Run optimization, check LL is around -34.6 (within 1.0 tolerance).""" + s_copy, ale_path = _prepare_inputs(species_tree, gene_trees, tmp_dir) + ll = _run_python_ml_undated_optimized(s_copy, ale_path) + + assert ll == pytest.approx(-34.6, abs=1.0), ( + f"Optimized LL={ll}, expected around -34.6" + ) + + @pytest.mark.slow + def test_ml_undated_output_files(self, 
species_tree, gene_trees, tmp_dir): + """Check that .uml_rec and .uTs files are created with correct format.""" + s_copy, ale_path = _prepare_inputs(species_tree, gene_trees, tmp_dir) + + from ALE_python.cli import ALEml_undated as cli_ml_undated + + cwd_save = os.getcwd() + os.chdir(tmp_dir) + try: + cli_ml_undated([s_copy, ale_path, "sample=10", "seed=42"]) + finally: + os.chdir(cwd_save) + + basename_s = os.path.basename(s_copy) + basename_ale = os.path.basename(ale_path) + radical = basename_s + "_" + basename_ale + + rec_path = os.path.join(tmp_dir, radical + ".uml_rec") + ts_path = os.path.join(tmp_dir, radical + ".uTs") + + assert os.path.isfile(rec_path), f".uml_rec file not found: {rec_path}" + assert os.path.isfile(ts_path), f".uTs file not found: {ts_path}" + + # Verify .uml_rec has expected structure + with open(rec_path) as fh: + content = fh.read() + assert "#ALEml_undated" in content, ".uml_rec missing header" + assert "S:\t" in content, ".uml_rec missing species tree line" + assert ">logl:" in content, ".uml_rec missing log-likelihood line" + assert "rate of" in content, ".uml_rec missing rate header" + assert "reconciled G-s:" in content, ".uml_rec missing reconciled trees section" + + # Verify .uTs has expected structure + with open(ts_path) as fh: + ts_content = fh.read() + assert "#from\tto\tfreq." in ts_content, ".uTs missing header" + + @pytest.mark.slow + def test_ml_undated_event_counts(self, species_tree, gene_trees, tmp_dir): + """Run with sample=100, verify total events are reasonable. 
+ S should be 35 = n_leaves - 1 per sample.""" + s_copy, ale_path = _prepare_inputs(species_tree, gene_trees, tmp_dir) + + from ALE_python.cli import ALEml_undated as cli_ml_undated + + cwd_save = os.getcwd() + os.chdir(tmp_dir) + try: + cli_ml_undated([s_copy, ale_path, "sample=100", "seed=42"]) + finally: + os.chdir(cwd_save) + + basename_s = os.path.basename(s_copy) + basename_ale = os.path.basename(ale_path) + radical = basename_s + "_" + basename_ale + rec_path = os.path.join(tmp_dir, radical + ".uml_rec") + + # Parse event counts from the .uml_rec file + with open(rec_path) as fh: + lines = fh.readlines() + + # Find the "Total" line after "# of\tDuplications\tTransfers\tLosses\tSpeciations" + total_d, total_t, total_l, total_s = None, None, None, None + for i, line in enumerate(lines): + if line.startswith("Total "): + parts = line.strip().split("\t") + # Format: "Total \tD\tT\tL\tS" + total_d = float(parts[1]) + total_t = float(parts[2]) + total_l = float(parts[3]) + total_s = float(parts[4]) + break + + assert total_s is not None, "Could not find Total event counts in .uml_rec" + + # Speciations should be exactly n_leaves - 1 = 35 per sample (averaged) + assert total_s == pytest.approx(35.0, abs=1.0), ( + f"Expected ~35 speciations per sample, got {total_s}" + ) + + # D, T, L should all be non-negative + assert total_d >= 0, f"Negative duplication count: {total_d}" + assert total_t >= 0, f"Negative transfer count: {total_t}" + assert total_l >= 0, f"Negative loss count: {total_l}" + + # Total events should be reasonable (not zero, not astronomical) + total_events = total_d + total_t + total_l + total_s + assert total_events > 35.0, ( + f"Total events ({total_events}) suspiciously low" + ) + assert total_events < 1000.0, ( + f"Total events ({total_events}) suspiciously high" + ) diff --git a/ALE_python/tests/test_observe.py b/ALE_python/tests/test_observe.py new file mode 100644 index 0000000..84b9207 --- /dev/null +++ b/ALE_python/tests/test_observe.py @@ 
-0,0 +1,197 @@ +"""Tests for ALEobserve -- comparing Python output with C++ output.""" + +import os +import shutil +import subprocess + +import pytest + +from ALE_python.ale import ApproxPosterior +from ALE_python.ale_util import observe_ALE_from_file, load_ALE_from_file +from ALE_python.newick import parse_newick, get_leaves + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _run_cpp_observe(cpp_bin, gene_trees_path, tmp_dir, burnin=1000): + """Run the C++ ALEobserve and return the path to the .ale file.""" + gene_trees_copy = os.path.join(tmp_dir, os.path.basename(gene_trees_path)) + shutil.copy(gene_trees_path, gene_trees_copy) + exe = os.path.join(cpp_bin, "ALEobserve") + result = subprocess.run( + [exe, gene_trees_copy, f"burnin={burnin}"], + capture_output=True, + text=True, + cwd=tmp_dir, + timeout=300, + ) + assert result.returncode == 0, f"C++ ALEobserve failed:\n{result.stderr}" + ale_path = gene_trees_copy + ".ale" + assert os.path.isfile(ale_path), f"Expected .ale file not created: {ale_path}" + return ale_path + + +def _load_cpp_ale(cpp_bin, gene_trees_path, tmp_dir, burnin=1000): + """Run C++ ALEobserve and load the resulting .ale file.""" + ale_path = _run_cpp_observe(cpp_bin, gene_trees_path, tmp_dir, burnin=burnin) + return load_ALE_from_file(ale_path) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestObserve: + + def test_observe_creates_ale_file(self, gene_trees, tmp_dir): + """Run Python ALEobserve on example data with burnin=1000 and verify + that an .ale file is created.""" + gene_trees_copy = os.path.join(tmp_dir, os.path.basename(gene_trees)) + shutil.copy(gene_trees, gene_trees_copy) + + ale = observe_ALE_from_file(gene_trees_copy, burnin=1000) + out_path = gene_trees_copy + ".ale" 
+ ale.save_state(out_path) + + assert os.path.isfile(out_path), "Python ALEobserve did not create .ale file" + assert os.path.getsize(out_path) > 0, ".ale file is empty" + + def test_observe_observation_count(self, gene_trees): + """Load the .ale and verify observations == 8504.""" + ale = observe_ALE_from_file(gene_trees, burnin=1000) + assert ale.observations == pytest.approx(8504.0), ( + f"Expected 8504 observations, got {ale.observations}" + ) + + def test_observe_matches_cpp_leaf_count(self, gene_trees, cpp_bin, tmp_dir): + """Compare number of leaves (36) in both Python and C++ .ale files.""" + py_ale = observe_ALE_from_file(gene_trees, burnin=1000) + cpp_ale = _load_cpp_ale(cpp_bin, gene_trees, tmp_dir, burnin=1000) + + assert py_ale.Gamma_size == 36, ( + f"Python Gamma_size={py_ale.Gamma_size}, expected 36" + ) + assert cpp_ale.Gamma_size == 36, ( + f"C++ Gamma_size={cpp_ale.Gamma_size}, expected 36" + ) + assert py_ale.Gamma_size == cpp_ale.Gamma_size + + def test_observe_matches_cpp_bip_counts(self, gene_trees, cpp_bin, tmp_dir): + """Load both .ale files and compare that Bip_counts are similar. + + Python and C++ may assign different integer IDs to the same + bipartitions, so we compare by bitset content rather than by ID. + + Leaf bipartitions (size 1 or Gamma_size-1) are excluded because + the C++ code does not store their counts in Bip_counts during + decompose; instead it special-cases them in p_bip(). 
+ """ + py_ale = observe_ALE_from_file(gene_trees, burnin=1000) + cpp_ale = _load_cpp_ale(cpp_bin, gene_trees, tmp_dir, burnin=1000) + + def _is_leaf_size(size, gamma_size): + return size <= 1 or size >= gamma_size - 1 + + def _bitset_to_leafset(ale, bitset): + """Convert a bitset integer to a frozenset of leaf names.""" + names = set() + for name, lid in ale.leaf_ids.items(): + if (bitset >> lid) & 1: + names.add(name) + return frozenset(names) + + # Build leaf-name-set -> count maps for non-leaf bipartitions + py_by_leaves = {} + for kid, count in py_ale.Bip_counts.items(): + size = py_ale.set_sizes.get(kid, 0) + if _is_leaf_size(size, py_ale.Gamma_size): + continue + leafset = _bitset_to_leafset(py_ale, py_ale.id_sets[kid]) + py_by_leaves[leafset] = count + + cpp_by_leaves = {} + for kid, count in cpp_ale.Bip_counts.items(): + size = cpp_ale.set_sizes.get(kid, 0) + if _is_leaf_size(size, cpp_ale.Gamma_size): + continue + leafset = _bitset_to_leafset(cpp_ale, cpp_ale.id_sets[kid]) + cpp_by_leaves[leafset] = count + + assert len(py_by_leaves) == len(cpp_by_leaves), ( + f"Non-leaf bipartition count mismatch: {len(py_by_leaves)} Python vs {len(cpp_by_leaves)} C++" + ) + + py_only = set(py_by_leaves.keys()) - set(cpp_by_leaves.keys()) + cpp_only = set(cpp_by_leaves.keys()) - set(py_by_leaves.keys()) + assert not py_only and not cpp_only, ( + f"Bipartition mismatch: {len(py_only)} only in Python, {len(cpp_only)} only in C++" + ) + + for leafset in py_by_leaves: + py_val = py_by_leaves[leafset] + cpp_val = cpp_by_leaves[leafset] + assert py_val == pytest.approx(cpp_val, rel=1e-6), ( + f"Bip_counts for {leafset}: Python={py_val}, C++={cpp_val}" + ) + + def test_observe_mpp_tree_has_all_leaves(self, gene_trees): + """Check the mpp tree string contains all 36 leaf names.""" + ale = observe_ALE_from_file(gene_trees, burnin=1000) + mpp_str, mpp_pp = ale.mpp_tree() + + assert len(mpp_str) > 0, "mpp_tree returned empty string" + assert mpp_pp > 0.0, f"mpp posterior 
probability should be > 0, got {mpp_pp}" + + # Parse the mpp tree and check leaf count + root = parse_newick(mpp_str) + leaves = get_leaves(root) + leaf_names = sorted(n.name for n in leaves) + + assert len(leaf_names) == 36, ( + f"Expected 36 leaves in mpp tree, got {len(leaf_names)}" + ) + + # Verify all expected leaf names are present + expected_leaves = sorted(ale.leaf_ids.keys()) + assert leaf_names == expected_leaves, ( + f"Leaf name mismatch between mpp tree and ale object" + ) + + def test_observe_singleton_tree_file(self, tmp_dir): + """Singleton input should produce a stable one-leaf ALE.""" + source_dir = os.path.join(tmp_dir, "src") + os.makedirs(source_dir, exist_ok=True) + singleton_path = os.path.join(source_dir, "singleton.trees") + with open(singleton_path, "w") as fh: + fh.write("A;\n") + + py_ale = observe_ALE_from_file(singleton_path) + + assert py_ale.observations == pytest.approx(1.0) + assert py_ale.Gamma_size == 1 + assert py_ale.get_leaf_names() == ["A"] + assert py_ale.Bip_bls[1] == pytest.approx(1.0) + + roundtrip_path = os.path.join(tmp_dir, "singleton.ale") + py_ale.save_state(roundtrip_path) + loaded = load_ALE_from_file(roundtrip_path) + assert loaded.observations == pytest.approx(py_ale.observations) + assert loaded.Gamma_size == py_ale.Gamma_size + assert loaded.get_leaf_names() == py_ale.get_leaf_names() + + def test_observe_singleton_then_tree_returns_singleton(self, tmp_dir): + """If a singleton appears first, observe should return that one-leaf ALE.""" + source_dir = os.path.join(tmp_dir, "src") + os.makedirs(source_dir, exist_ok=True) + mixed_path = os.path.join(source_dir, "singleton_then_tree.trees") + with open(mixed_path, "w") as fh: + fh.write("A;\n") + fh.write("(A,B);\n") + + py_ale = observe_ALE_from_file(mixed_path) + + assert py_ale.observations == pytest.approx(1.0) + assert py_ale.Gamma_size == 1 + assert py_ale.get_leaf_names() == ["A"] diff --git a/pixi.lock b/pixi.lock new file mode 100644 index 0000000..a4e572e 
--- /dev/null +++ b/pixi.lock @@ -0,0 +1,1134 @@ +version: 6 +environments: + default: + channels: + - url: https://conda.anaconda.org/conda-forge/ + options: + pypi-prerelease-mode: if-necessary-or-explicit + packages: + linux-64: + - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.2-hb03c661_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.45.1-default_hfdba357_101.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.34.6-hb03c661_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.2.25-hbd8a1cb_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/cmake-4.3.0-hc85cc9f_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/conda-gcc-specs-14.3.0-he8ccf15_18.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/gcc-14.3.0-h0dff253_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/gcc_impl_linux-64-14.3.0-hbdf3cc3_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/git-2.53.0-pl5321h6d3cee1_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/gxx-14.3.0-h76987e4_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-14.3.0-h2185e75_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.3-h33c6efd_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-4.18.0-he073ed8_9.conda + - conda: 
https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.3-hb9d3cd8_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/krb5-1.22.2-ha1258a1_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.45.1-default_hbd61a6d_101.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libattr-2.5.2-hb03c661_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-5_h4a7cf45_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libcap-2.77-h3ff7636_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-5_h0358290_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libcurl-8.19.0-hcf29cc6_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20250104-pl5321h7949ede_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libev-4.33-hd590300_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.7.4-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libfabric-2.4.0-ha770c72_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libfabric1-2.4.0-h8f87c3e_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-15.2.0-he0feb66_18.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/libgcc-devel_linux-64-14.3.0-hf649bbc_118.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-15.2.0-h69a702a_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-15.2.0-h69a702a_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-15.2.0-h68bc16d_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgomp-15.2.0-he0feb66_18.conda + - conda: 
https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.13.0-default_he001693_1000.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.18-h3b78370_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.11.0-5_h47877c9_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/liblzma-5.8.2-hb03c661_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb03c661_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.68.1-h877daf1_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libnl-3.11.0-hb9d3cd8_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.30-pthreads_h94d23a6_4.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libpmix-5.0.8-h31fc519_4.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsanitizer-14.3.0-h8f1669f_18.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.52.0-hf4e2dac_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.11.1-hcf80075_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_18.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/libstdcxx-devel_linux-64-14.3.0-h9f08a49_118.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-257.13-hd0affe5_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libudev1-257.13-hd0affe5_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.41.3-h5347b49_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libuv-1.51.0-hb03c661_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.2-hca6bf5a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.2-he237659_0.conda + - conda: 
https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/make-4.4.1-hb9d3cd8_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mpi-1.0.1-openmpi.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.3-py314h2b28147_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openmpi-5.0.10-h67ed482_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.1-h35e630c_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-26.0-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/perl-5.32.1-7_hd590300_perl5.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.14.3-h32b2ec7_101_cp314.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/python_abi-3.14-8_cp314.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/rdma-core-61.0-h192683f_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/rhash-1.4.6-hb9d3cd8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.17.1-py314hf07bd8e_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.28-h4ee821c_9.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/tomli-2.4.0-pyhcf101f3_0.conda + - conda: 
https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ucc-1.7.0-hcedbda0_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ucx-1.20.0-hf72d326_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda +packages: +- conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda + build_number: 20 + sha256: 1dd3fffd892081df9726d7eb7e0dea6198962ba775bd88842135a4ddb4deb3c9 + md5: a9f577daf3de00bca7c3c76c0ecbd1de + depends: + - __glibc >=2.17,<3.0.a0 + - libgomp >=7.5.0 + constrains: + - openmp_impl <0.0a0 + license: BSD-3-Clause + license_family: BSD + size: 28948 + timestamp: 1770939786096 +- conda: https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.2-hb03c661_1.conda + sha256: 78c516af87437f52d883193cf167378f592ad445294c69f7c69f56059087c40d + md5: 9bb149f49de3f322fca007283eaa2725 + depends: + - __glibc >=2.17,<3.0.a0 + - libattr 2.5.2 hb03c661_1 + - libgcc >=14 + license: GPL-2.0-or-later + license_family: GPL + size: 31386 + timestamp: 1773595914754 +- conda: https://conda.anaconda.org/conda-forge/linux-64/binutils_impl_linux-64-2.45.1-default_hfdba357_101.conda + sha256: 74341b26a2b9475dc14ba3cf12432fcd10a23af285101883e720216d81d44676 + md5: 83aa53cb3f5fc849851a84d777a60551 + depends: + - ld_impl_linux-64 2.45.1 default_hbd61a6d_101 + - sysroot_linux-64 + - zstd >=1.5.7,<1.6.0a0 + license: GPL-3.0-only + license_family: GPL + size: 3744895 + timestamp: 1770267152681 +- conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda + sha256: 0b75d45f0bba3e95dc693336fa51f40ea28c980131fec438afb7ce6118ed05f6 + md5: d2ffd7602c02f2b316fd921d39876885 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: bzip2-1.0.6 + license_family: BSD + size: 260182 + timestamp: 1771350215188 +- 
conda: https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.34.6-hb03c661_0.conda + sha256: cc9accf72fa028d31c2a038460787751127317dcfa991f8d1f1babf216bb454e + md5: 920bb03579f15389b9e512095ad995b7 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: MIT + license_family: MIT + size: 207882 + timestamp: 1765214722852 +- conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.2.25-hbd8a1cb_0.conda + sha256: 67cc7101b36421c5913a1687ef1b99f85b5d6868da3abbf6ec1a4181e79782fc + md5: 4492fd26db29495f0ba23f146cd5638d + depends: + - __unix + license: ISC + size: 147413 + timestamp: 1772006283803 +- conda: https://conda.anaconda.org/conda-forge/linux-64/cmake-4.3.0-hc85cc9f_0.conda + sha256: 77ac1115fb713dfc058b9d0eefd889610db62184e74dbe7eabd2a8c3c3a31ab1 + md5: 59c51e8455a594f1a78a307d9d94fde7 + depends: + - __glibc >=2.17,<3.0.a0 + - bzip2 >=1.0.8,<2.0a0 + - libcurl >=8.19.0,<9.0a0 + - libexpat >=2.7.4,<3.0a0 + - libgcc >=14 + - liblzma >=5.8.2,<6.0a0 + - libstdcxx >=14 + - libuv >=1.51.0,<2.0a0 + - libzlib >=1.3.1,<2.0a0 + - ncurses >=6.5,<7.0a0 + - rhash >=1.4.6,<2.0a0 + - zstd >=1.5.7,<1.6.0a0 + license: BSD-3-Clause + license_family: BSD + size: 23023925 + timestamp: 1773792006849 +- conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + sha256: ab29d57dc70786c1269633ba3dff20288b81664d3ff8d21af995742e2bb03287 + md5: 962b9857ee8e7018c22f2776ffa0b2d7 + depends: + - python >=3.9 + license: BSD-3-Clause + license_family: BSD + size: 27011 + timestamp: 1733218222191 +- conda: https://conda.anaconda.org/conda-forge/linux-64/conda-gcc-specs-14.3.0-he8ccf15_18.conda + sha256: b90ec0e6a9eb22f7240b3584fe785457cff961fec68d40e6aece5d596f9bbd9a + md5: 0e3e144115c43c9150d18fa20db5f31c + depends: + - gcc_impl_linux-64 >=14.3.0,<14.3.1.0a0 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 31705 + timestamp: 1771378159534 +- conda: 
https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda + sha256: ee6cf346d017d954255bbcbdb424cddea4d14e4ed7e9813e429db1d795d01144 + md5: 8e662bd460bda79b1ea39194e3c4c9ab + depends: + - python >=3.10 + - typing_extensions >=4.6.0 + license: MIT and PSF-2.0 + size: 21333 + timestamp: 1763918099466 +- conda: https://conda.anaconda.org/conda-forge/linux-64/gcc-14.3.0-h0dff253_18.conda + sha256: 9b34b57b06b485e33a40d430f71ac88c8f381673592507cf7161c50ff0832772 + md5: 52d6457abc42e320787ada5f9033fa99 + depends: + - conda-gcc-specs + - gcc_impl_linux-64 14.3.0 hbdf3cc3_18 + license: BSD-3-Clause + license_family: BSD + size: 29506 + timestamp: 1771378321585 +- conda: https://conda.anaconda.org/conda-forge/linux-64/gcc_impl_linux-64-14.3.0-hbdf3cc3_18.conda + sha256: 3b31a273b806c6851e16e9cf63ef87cae28d19be0df148433f3948e7da795592 + md5: 30bb690150536f622873758b0e8d6712 + depends: + - binutils_impl_linux-64 >=2.45 + - libgcc >=14.3.0 + - libgcc-devel_linux-64 14.3.0 hf649bbc_118 + - libgomp >=14.3.0 + - libsanitizer 14.3.0 h8f1669f_18 + - libstdcxx >=14.3.0 + - libstdcxx-devel_linux-64 14.3.0 h9f08a49_118 + - sysroot_linux-64 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 76302378 + timestamp: 1771378056505 +- conda: https://conda.anaconda.org/conda-forge/linux-64/git-2.53.0-pl5321h6d3cee1_0.conda + sha256: 33b20cf09ff1c6ca960e6c5f7fad1f08ffd3112a87d79e42ed56f4e1b4cdefe3 + md5: ad8d4260a6dae5f55960b26b237d576b + depends: + - __glibc >=2.28,<3.0.a0 + - libcurl >=8.18.0,<9.0a0 + - libexpat >=2.7.3,<3.0a0 + - libgcc >=14 + - libiconv >=1.18,<2.0a0 + - libzlib >=1.3.1,<2.0a0 + - openssl >=3.5.5,<4.0a0 + - pcre2 >=10.47,<10.48.0a0 + - perl 5.* + license: GPL-2.0-or-later and LGPL-2.1-or-later + size: 11447951 + timestamp: 1770982660115 +- conda: https://conda.anaconda.org/conda-forge/linux-64/gxx-14.3.0-h76987e4_18.conda + sha256: 1b490c9be9669f9c559db7b2a1f7d8b973c58ca0c6f21a5d2ba3f0ab2da63362 + md5: 
19189121d644d4ef75fed05383bc75f5 + depends: + - gcc 14.3.0 h0dff253_18 + - gxx_impl_linux-64 14.3.0 h2185e75_18 + license: BSD-3-Clause + license_family: BSD + size: 28883 + timestamp: 1771378355605 +- conda: https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-14.3.0-h2185e75_18.conda + sha256: 38ffca57cc9c264d461ac2ce9464a9d605e0f606d92d831de9075cb0d95fc68a + md5: 6514b3a10e84b6a849e1b15d3753eb22 + depends: + - gcc_impl_linux-64 14.3.0 hbdf3cc3_18 + - libstdcxx-devel_linux-64 14.3.0 h9f08a49_118 + - sysroot_linux-64 + - tzdata + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 14566100 + timestamp: 1771378271421 +- conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.3-h33c6efd_0.conda + sha256: fbf86c4a59c2ed05bbffb2ba25c7ed94f6185ec30ecb691615d42342baa1a16a + md5: c80d8a3b84358cb967fa81e7075fbc8a + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libstdcxx >=14 + license: MIT + license_family: MIT + size: 12723451 + timestamp: 1773822285671 +- conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + sha256: e1a9e3b1c8fe62dc3932a616c284b5d8cbe3124bbfbedcf4ce5c828cb166ee19 + md5: 9614359868482abba1bd15ce465e3c42 + depends: + - python >=3.10 + license: MIT + license_family: MIT + size: 13387 + timestamp: 1760831448842 +- conda: https://conda.anaconda.org/conda-forge/noarch/kernel-headers_linux-64-4.18.0-he073ed8_9.conda + sha256: 41557eeadf641de6aeae49486cef30d02a6912d8da98585d687894afd65b356a + md5: 86d9cba083cd041bfbf242a01a7a1999 + constrains: + - sysroot_linux-64 ==2.28 + license: LGPL-2.0-or-later AND LGPL-2.0-or-later WITH exceptions AND GPL-2.0-or-later + license_family: GPL + size: 1278712 + timestamp: 1765578681495 +- conda: https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.3-hb9d3cd8_0.conda + sha256: 0960d06048a7185d3542d850986d807c6e37ca2e644342dd0c72feefcf26c2a4 + md5: b38117a3c920364aff79f870c984b4a3 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 
+ license: LGPL-2.1-or-later + size: 134088 + timestamp: 1754905959823 +- conda: https://conda.anaconda.org/conda-forge/linux-64/krb5-1.22.2-ha1258a1_0.conda + sha256: 3e307628ca3527448dd1cb14ad7bb9d04d1d28c7d4c5f97ba196ae984571dd25 + md5: fb53fb07ce46a575c5d004bbc96032c2 + depends: + - __glibc >=2.17,<3.0.a0 + - keyutils >=1.6.3,<2.0a0 + - libedit >=3.1.20250104,<3.2.0a0 + - libedit >=3.1.20250104,<4.0a0 + - libgcc >=14 + - libstdcxx >=14 + - openssl >=3.5.5,<4.0a0 + license: MIT + license_family: MIT + size: 1386730 + timestamp: 1769769569681 +- conda: https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.45.1-default_hbd61a6d_101.conda + sha256: 565941ac1f8b0d2f2e8f02827cbca648f4d18cd461afc31f15604cd291b5c5f3 + md5: 12bd9a3f089ee6c9266a37dab82afabd + depends: + - __glibc >=2.17,<3.0.a0 + - zstd >=1.5.7,<1.6.0a0 + constrains: + - binutils_impl_linux-64 2.45.1 + license: GPL-3.0-only + license_family: GPL + size: 725507 + timestamp: 1770267139900 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libattr-2.5.2-hb03c661_1.conda + sha256: 0cef37eb013dc7091f17161c357afbdef9a9bc79ef6462508face6db3f37db77 + md5: 7e7f0a692eb62b95d3010563e7f963b6 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: LGPL-2.1-or-later + license_family: LGPL + size: 53316 + timestamp: 1773595896163 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-5_h4a7cf45_openblas.conda + build_number: 5 + sha256: 18c72545080b86739352482ba14ba2c4815e19e26a7417ca21a95b76ec8da24c + md5: c160954f7418d7b6e87eaf05a8913fa9 + depends: + - libopenblas >=0.3.30,<0.3.31.0a0 + - libopenblas >=0.3.30,<1.0a0 + constrains: + - mkl <2026 + - liblapack 3.11.0 5*_openblas + - libcblas 3.11.0 5*_openblas + - blas 2.305 openblas + - liblapacke 3.11.0 5*_openblas + license: BSD-3-Clause + license_family: BSD + size: 18213 + timestamp: 1765818813880 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libcap-2.77-h3ff7636_0.conda + sha256: 
9517cce5193144af0fcbf19b7bd67db0a329c2cc2618f28ffecaa921a1cbe9d3 + md5: 09c264d40c67b82b49a3f3b89037bd2e + depends: + - __glibc >=2.17,<3.0.a0 + - attr >=2.5.2,<2.6.0a0 + - libgcc >=14 + license: BSD-3-Clause + license_family: BSD + size: 121429 + timestamp: 1762349484074 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-5_h0358290_openblas.conda + build_number: 5 + sha256: 0cbdcc67901e02dc17f1d19e1f9170610bd828100dc207de4d5b6b8ad1ae7ad8 + md5: 6636a2b6f1a87572df2970d3ebc87cc0 + depends: + - libblas 3.11.0 5_h4a7cf45_openblas + constrains: + - liblapacke 3.11.0 5*_openblas + - blas 2.305 openblas + - liblapack 3.11.0 5*_openblas + license: BSD-3-Clause + license_family: BSD + size: 18194 + timestamp: 1765818837135 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libcurl-8.19.0-hcf29cc6_0.conda + sha256: a0390fd0536ebcd2244e243f5f00ab8e76ab62ed9aa214cd54470fe7496620f4 + md5: d50608c443a30c341c24277d28290f76 + depends: + - __glibc >=2.17,<3.0.a0 + - krb5 >=1.22.2,<1.23.0a0 + - libgcc >=14 + - libnghttp2 >=1.67.0,<2.0a0 + - libssh2 >=1.11.1,<2.0a0 + - libzlib >=1.3.1,<2.0a0 + - openssl >=3.5.5,<4.0a0 + - zstd >=1.5.7,<1.6.0a0 + license: curl + license_family: MIT + size: 466704 + timestamp: 1773218522665 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20250104-pl5321h7949ede_0.conda + sha256: d789471216e7aba3c184cd054ed61ce3f6dac6f87a50ec69291b9297f8c18724 + md5: c277e0a4d549b03ac1e9d6cbbe3d017b + depends: + - ncurses + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + - ncurses >=6.5,<7.0a0 + license: BSD-2-Clause + license_family: BSD + size: 134676 + timestamp: 1738479519902 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libev-4.33-hd590300_2.conda + sha256: 1cd6048169fa0395af74ed5d8f1716e22c19a81a8a36f934c110ca3ad4dd27b4 + md5: 172bf1cd1ff8629f2b1179945ed45055 + depends: + - libgcc-ng >=12 + license: BSD-2-Clause + license_family: BSD + size: 112766 + timestamp: 1702146165126 +- conda: 
https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda + sha256: 2e14399d81fb348e9d231a82ca4d816bf855206923759b69ad006ba482764131 + md5: a1cfcc585f0c42bf8d5546bb1dfb668d + depends: + - libgcc-ng >=12 + - openssl >=3.1.1,<4.0a0 + license: BSD-3-Clause + license_family: BSD + size: 427426 + timestamp: 1685725977222 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.7.4-hecca717_0.conda + sha256: d78f1d3bea8c031d2f032b760f36676d87929b18146351c4464c66b0869df3f5 + md5: e7f7ce06ec24cfcfb9e36d28cf82ba57 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + constrains: + - expat 2.7.4.* + license: MIT + license_family: MIT + size: 76798 + timestamp: 1771259418166 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libfabric-2.4.0-ha770c72_1.conda + sha256: c5298c27fe1be477b17cd989566eb6c1a1bb50222f2f90389143b6f06ba95398 + md5: 647939791f2cc2de3b4ecac28d216279 + depends: + - libfabric1 2.4.0 h8f87c3e_1 + license: BSD-2-Clause OR GPL-2.0-only + license_family: BSD + size: 14406 + timestamp: 1769190335747 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libfabric1-2.4.0-h8f87c3e_1.conda + sha256: 3110ee1b3debb97638897bb0d7074ee257ff33519520327064c36a35391dec50 + md5: c5fc7dbc3dbabcae1eec5d6c62251df8 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libnl >=3.11.0,<4.0a0 + - rdma-core >=61.0 + license: BSD-2-Clause OR GPL-2.0-only + license_family: BSD + size: 699849 + timestamp: 1769190335048 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda + sha256: 31f19b6a88ce40ebc0d5a992c131f57d919f73c0b92cd1617a5bec83f6e961e6 + md5: a360c33a5abe61c07959e449fa1453eb + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: MIT + license_family: MIT + size: 58592 + timestamp: 1769456073053 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-15.2.0-he0feb66_18.conda + sha256: faf7d2017b4d718951e3a59d081eb09759152f93038479b768e3d612688f83f5 + md5: 
0aa00f03f9e39fb9876085dee11a85d4 + depends: + - __glibc >=2.17,<3.0.a0 + - _openmp_mutex >=4.5 + constrains: + - libgcc-ng ==15.2.0=*_18 + - libgomp 15.2.0 he0feb66_18 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 1041788 + timestamp: 1771378212382 +- conda: https://conda.anaconda.org/conda-forge/noarch/libgcc-devel_linux-64-14.3.0-hf649bbc_118.conda + sha256: 1abc6a81ee66e8ac9ac09a26e2d6ad7bba23f0a0cc3a6118654f036f9c0e1854 + md5: 06901733131833f5edd68cf3d9679798 + depends: + - __unix + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 3084533 + timestamp: 1771377786730 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-15.2.0-h69a702a_18.conda + sha256: e318a711400f536c81123e753d4c797a821021fb38970cebfb3f454126016893 + md5: d5e96b1ed75ca01906b3d2469b4ce493 + depends: + - libgcc 15.2.0 he0feb66_18 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 27526 + timestamp: 1771378224552 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-15.2.0-h69a702a_18.conda + sha256: d2c9fad338fd85e4487424865da8e74006ab2e2475bd788f624d7a39b2a72aee + md5: 9063115da5bc35fdc3e1002e69b9ef6e + depends: + - libgfortran5 15.2.0 h68bc16d_18 + constrains: + - libgfortran-ng ==15.2.0=*_18 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 27523 + timestamp: 1771378269450 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-15.2.0-h68bc16d_18.conda + sha256: 539b57cf50ec85509a94ba9949b7e30717839e4d694bc94f30d41c9d34de2d12 + md5: 646855f357199a12f02a87382d429b75 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=15.2.0 + constrains: + - libgfortran 15.2.0 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 2482475 + timestamp: 1771378241063 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libgomp-15.2.0-he0feb66_18.conda + sha256: 21337ab58e5e0649d869ab168d4e609b033509de22521de1bfed0c031bfc5110 + md5: 
239c5e9546c38a1e884d69effcf4c882 + depends: + - __glibc >=2.17,<3.0.a0 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 603262 + timestamp: 1771378117851 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.13.0-default_he001693_1000.conda + sha256: 5041d295813dfb84652557839825880aae296222ab725972285c5abe3b6e4288 + md5: c197985b58bc813d26b42881f0021c82 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libstdcxx >=14 + - libxml2 + - libxml2-16 >=2.14.6 + license: BSD-3-Clause + license_family: BSD + size: 2436378 + timestamp: 1770953868164 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.18-h3b78370_2.conda + sha256: c467851a7312765447155e071752d7bf9bf44d610a5687e32706f480aad2833f + md5: 915f5995e94f60e9a4826e0b0920ee88 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: LGPL-2.1-only + size: 790176 + timestamp: 1754908768807 +- conda: https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.11.0-5_h47877c9_openblas.conda + build_number: 5 + sha256: c723b6599fcd4c6c75dee728359ef418307280fa3e2ee376e14e85e5bbdda053 + md5: b38076eb5c8e40d0106beda6f95d7609 + depends: + - libblas 3.11.0 5_h4a7cf45_openblas + constrains: + - blas 2.305 openblas + - liblapacke 3.11.0 5*_openblas + - libcblas 3.11.0 5*_openblas + license: BSD-3-Clause + license_family: BSD + size: 18200 + timestamp: 1765818857876 +- conda: https://conda.anaconda.org/conda-forge/linux-64/liblzma-5.8.2-hb03c661_0.conda + sha256: 755c55ebab181d678c12e49cced893598f2bab22d582fbbf4d8b83c18be207eb + md5: c7c83eecbb72d88b940c249af56c8b17 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + constrains: + - xz 5.8.2.* + license: 0BSD + size: 113207 + timestamp: 1768752626120 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb03c661_1.conda + sha256: fe171ed5cf5959993d43ff72de7596e8ac2853e9021dec0344e583734f1e0843 + md5: 2c21e66f50753a083cbe6b80f38268fa + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc 
>=14 + license: BSD-2-Clause + license_family: BSD + size: 92400 + timestamp: 1769482286018 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.68.1-h877daf1_0.conda + sha256: 663444d77a42f2265f54fb8b48c5450bfff4388d9c0f8253dd7855f0d993153f + md5: 2a45e7f8af083626f009645a6481f12d + depends: + - __glibc >=2.17,<3.0.a0 + - c-ares >=1.34.6,<2.0a0 + - libev >=4.33,<4.34.0a0 + - libev >=4.33,<5.0a0 + - libgcc >=14 + - libstdcxx >=14 + - libzlib >=1.3.1,<2.0a0 + - openssl >=3.5.5,<4.0a0 + license: MIT + license_family: MIT + size: 663344 + timestamp: 1773854035739 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libnl-3.11.0-hb9d3cd8_0.conda + sha256: ba7c5d294e3d80f08ac5a39564217702d1a752e352e486210faff794ac5001b4 + md5: db63358239cbe1ff86242406d440e44a + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + license: LGPL-2.1-or-later + license_family: LGPL + size: 741323 + timestamp: 1731846827427 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.30-pthreads_h94d23a6_4.conda + sha256: 199d79c237afb0d4780ccd2fbf829cea80743df60df4705202558675e07dd2c5 + md5: be43915efc66345cccb3c310b6ed0374 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libgfortran + - libgfortran5 >=14.3.0 + constrains: + - openblas >=0.3.30,<0.3.31.0a0 + license: BSD-3-Clause + license_family: BSD + size: 5927939 + timestamp: 1763114673331 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libpmix-5.0.8-h31fc519_4.conda + sha256: 2837759fb832ab8587622602531f8e4b5d5e5ec12ea9b3936be06f384b933800 + md5: bd15ae3916a0cbe005c683bbc33811b7 + depends: + - __glibc >=2.17,<3.0.a0 + - libevent >=2.1.12,<2.1.13.0a0 + - libgcc >=14 + - libhwloc >=2.13.0,<2.13.1.0a0 + - libzlib >=1.3.1,<2.0a0 + license: BSD-3-Clause + license_family: BSD + size: 731098 + timestamp: 1770962978877 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libsanitizer-14.3.0-h8f1669f_18.conda + sha256: e03ed186eefb46d7800224ad34bad1268c9d19ecb8f621380a50601c6221a4a7 + 
md5: ad3a0e2dc4cce549b2860e2ef0e6d75b + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14.3.0 + - libstdcxx >=14.3.0 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 7949259 + timestamp: 1771377982207 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.52.0-hf4e2dac_0.conda + sha256: d716847b7deca293d2e49ed1c8ab9e4b9e04b9d780aea49a97c26925b28a7993 + md5: fd893f6a3002a635b5e50ceb9dd2c0f4 + depends: + - __glibc >=2.17,<3.0.a0 + - icu >=78.2,<79.0a0 + - libgcc >=14 + - libzlib >=1.3.1,<2.0a0 + license: blessing + size: 951405 + timestamp: 1772818874251 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.11.1-hcf80075_0.conda + sha256: fa39bfd69228a13e553bd24601332b7cfeb30ca11a3ca50bb028108fe90a7661 + md5: eecce068c7e4eddeb169591baac20ac4 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + - libzlib >=1.3.1,<2.0a0 + - openssl >=3.5.0,<4.0a0 + license: BSD-3-Clause + license_family: BSD + size: 304790 + timestamp: 1745608545575 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_18.conda + sha256: 78668020064fdaa27e9ab65cd2997e2c837b564ab26ce3bf0e58a2ce1a525c6e + md5: 1b08cd684f34175e4514474793d44bcb + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc 15.2.0 he0feb66_18 + constrains: + - libstdcxx-ng ==15.2.0=*_18 + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 5852330 + timestamp: 1771378262446 +- conda: https://conda.anaconda.org/conda-forge/noarch/libstdcxx-devel_linux-64-14.3.0-h9f08a49_118.conda + sha256: b1c3824769b92a1486bf3e2cc5f13304d83ae613ea061b7bc47bb6080d6dfdba + md5: 865a399bce236119301ebd1532fced8d + depends: + - __unix + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 20171098 + timestamp: 1771377827750 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-257.13-hd0affe5_0.conda + sha256: c5008b602cb5c819f7b52d418b3ed17e1818cbbf6705b189e7ab36bb70cce3d8 + md5: 
8ee3cb7f64be0e8c4787f3a4dbe024e6 + depends: + - __glibc >=2.17,<3.0.a0 + - libcap >=2.77,<2.78.0a0 + - libgcc >=14 + license: LGPL-2.1-or-later + size: 492799 + timestamp: 1773797095649 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libudev1-257.13-hd0affe5_0.conda + sha256: 1a1e367c04d66030aa93b4d33905f7f6fbb59cfc292e816fe3e9c1e8b3f4d1e2 + md5: 2c2270f93d6f9073cbf72d821dfc7d72 + depends: + - __glibc >=2.17,<3.0.a0 + - libcap >=2.77,<2.78.0a0 + - libgcc >=14 + license: LGPL-2.1-or-later + size: 145087 + timestamp: 1773797108513 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.41.3-h5347b49_0.conda + sha256: 1a7539cfa7df00714e8943e18de0b06cceef6778e420a5ee3a2a145773758aee + md5: db409b7c1720428638e7c0d509d3e1b5 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: BSD-3-Clause + license_family: BSD + size: 40311 + timestamp: 1766271528534 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libuv-1.51.0-hb03c661_1.conda + sha256: c180f4124a889ac343fc59d15558e93667d894a966ec6fdb61da1604481be26b + md5: 0f03292cc56bf91a077a134ea8747118 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + license: MIT + license_family: MIT + size: 895108 + timestamp: 1753948278280 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda + sha256: 6ae68e0b86423ef188196fff6207ed0c8195dd84273cb5623b85aa08033a410c + md5: 5aa797f8787fe7a17d1b0821485b5adc + depends: + - libgcc-ng >=12 + license: LGPL-2.1-or-later + size: 100393 + timestamp: 1702724383534 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.15.2-he237659_0.conda + sha256: 275c324f87bda1a3b67d2f4fcc3555eeff9e228a37655aa001284a7ceb6b0392 + md5: e49238a1609f9a4a844b09d9926f2c3d + depends: + - __glibc >=2.17,<3.0.a0 + - icu >=78.2,<79.0a0 + - libgcc >=14 + - libiconv >=1.18,<2.0a0 + - liblzma >=5.8.2,<6.0a0 + - libxml2-16 2.15.2 hca6bf5a_0 + - libzlib >=1.3.1,<2.0a0 + license: MIT + license_family: MIT + size: 45968 + timestamp: 
1772704614539 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-16-2.15.2-hca6bf5a_0.conda + sha256: 08d2b34b49bec9613784f868209bb7c3bb8840d6cf835ff692e036b09745188c + md5: f3bc152cb4f86babe30f3a4bf0dbef69 + depends: + - __glibc >=2.17,<3.0.a0 + - icu >=78.2,<79.0a0 + - libgcc >=14 + - libiconv >=1.18,<2.0a0 + - liblzma >=5.8.2,<6.0a0 + - libzlib >=1.3.1,<2.0a0 + constrains: + - libxml2 2.15.2 + license: MIT + license_family: MIT + size: 557492 + timestamp: 1772704601644 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda + sha256: d4bfe88d7cb447768e31650f06257995601f89076080e76df55e3112d4e47dc4 + md5: edb0dca6bc32e4f4789199455a1dbeb8 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + constrains: + - zlib 1.3.1 *_2 + license: Zlib + license_family: Other + size: 60963 + timestamp: 1727963148474 +- conda: https://conda.anaconda.org/conda-forge/linux-64/make-4.4.1-hb9d3cd8_2.conda + sha256: d652c7bd4d3b6f82b0f6d063b0d8df6f54cc47531092d7ff008e780f3261bdda + md5: 33405d2a66b1411db9f7242c8b97c9e7 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + license: GPL-3.0-or-later + license_family: GPL + size: 513088 + timestamp: 1727801714848 +- conda: https://conda.anaconda.org/conda-forge/noarch/mpi-1.0.1-openmpi.conda + sha256: e1698675ec83a2139c0b02165f47eaf0701bcab043443d9008fc0f8867b07798 + md5: 78b827d2852c67c68cd5b2c55f31e376 + license: BSD-3-Clause + license_family: BSD + size: 6571 + timestamp: 1727683130230 +- conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda + sha256: 3fde293232fa3fca98635e1167de6b7c7fda83caf24b9d6c91ec9eefb4f4d586 + md5: 47e340acb35de30501a76c7c799c41d7 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + license: X11 AND BSD-3-Clause + size: 891641 + timestamp: 1738195959188 +- conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.3-py314h2b28147_0.conda + sha256: f2ba8cb0d86a6461a6bcf0d315c80c7076083f72c6733c9290086640723f79ec + md5: 
36f5b7eb328bdc204954a2225cf908e2 + depends: + - python + - libstdcxx >=14 + - libgcc >=14 + - __glibc >=2.17,<3.0.a0 + - python_abi 3.14.* *_cp314 + - libcblas >=3.9.0,<4.0a0 + - liblapack >=3.9.0,<4.0a0 + - libblas >=3.9.0,<4.0a0 + constrains: + - numpy-base <0a0 + license: BSD-3-Clause + license_family: BSD + size: 8927860 + timestamp: 1773839233468 +- conda: https://conda.anaconda.org/conda-forge/linux-64/openmpi-5.0.10-h67ed482_1.conda + sha256: 9d13b09c95f5e9429f295dc89a896dae41a4ca4f77118139b1ff02001ae25127 + md5: afa5d72e0e68fdf2b51b1c80a3d2086b + depends: + - mpi 1.0.* openmpi + - libgcc >=14 + - libgfortran5 >=14.3.0 + - libgfortran + - __glibc >=2.17,<3.0.a0 + - libstdcxx >=14 + - libevent >=2.1.12,<2.1.13.0a0 + - libhwloc >=2.13.0,<2.13.1.0a0 + - ucc >=1.6.0,<2.0a0 + - libzlib >=1.3.1,<2.0a0 + - ucx >=1.20.0,<1.21.0a0 + - libfabric + - libfabric1 >=1.14.0 + - libpmix >=5.0.8,<6.0a0 + - libnl >=3.11.0,<4.0a0 + constrains: + - __cuda >=12.0 + - cuda-version >=12.0 + - libprrte ==0.0.0 + license: BSD-3-Clause + license_family: BSD + size: 3940272 + timestamp: 1772089559619 +- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.1-h35e630c_1.conda + sha256: 44c877f8af015332a5d12f5ff0fb20ca32f896526a7d0cdb30c769df1144fb5c + md5: f61eb8cd60ff9057122a3d338b99c00f + depends: + - __glibc >=2.17,<3.0.a0 + - ca-certificates + - libgcc >=14 + license: Apache-2.0 + license_family: Apache + size: 3164551 + timestamp: 1769555830639 +- conda: https://conda.anaconda.org/conda-forge/noarch/packaging-26.0-pyhcf101f3_0.conda + sha256: c1fc0f953048f743385d31c468b4a678b3ad20caffdeaa94bed85ba63049fd58 + md5: b76541e68fea4d511b1ac46a28dcd2c6 + depends: + - python >=3.8 + - python + license: Apache-2.0 + license_family: APACHE + size: 72010 + timestamp: 1769093650580 +- conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda + sha256: 5e6f7d161356fefd981948bea5139c5aa0436767751a6930cb1ca801ebb113ff + md5: 7a3bff861a6583f1889021facefc08b1 
+ depends: + - __glibc >=2.17,<3.0.a0 + - bzip2 >=1.0.8,<2.0a0 + - libgcc >=14 + - libzlib >=1.3.1,<2.0a0 + license: BSD-3-Clause + license_family: BSD + size: 1222481 + timestamp: 1763655398280 +- conda: https://conda.anaconda.org/conda-forge/linux-64/perl-5.32.1-7_hd590300_perl5.conda + build_number: 7 + sha256: 9ec32b6936b0e37bcb0ed34f22ec3116e75b3c0964f9f50ecea5f58734ed6ce9 + md5: f2cfec9406850991f4e3d960cc9e3321 + depends: + - libgcc-ng >=12 + - libxcrypt >=4.4.36 + license: GPL-1.0-or-later OR Artistic-1.0-Perl + size: 13344463 + timestamp: 1703310653947 +- conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda + sha256: e14aafa63efa0528ca99ba568eaf506eb55a0371d12e6250aaaa61718d2eb62e + md5: d7585b6550ad04c8c5e21097ada2888e + depends: + - python >=3.9 + - python + license: MIT + license_family: MIT + size: 25877 + timestamp: 1764896838868 +- conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda + sha256: 5577623b9f6685ece2697c6eb7511b4c9ac5fb607c9babc2646c811b428fd46a + md5: 6b6ece66ebcae2d5f326c77ef2c5a066 + depends: + - python >=3.9 + license: BSD-2-Clause + license_family: BSD + size: 889287 + timestamp: 1750615908735 +- conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda + sha256: 9e749fb465a8bedf0184d8b8996992a38de351f7c64e967031944978de03a520 + md5: 2b694bad8a50dc2f712f5368de866480 + depends: + - pygments >=2.7.2 + - python >=3.10 + - iniconfig >=1.0.1 + - packaging >=22 + - pluggy >=1.5,<2 + - tomli >=1 + - colorama >=0.4 + - exceptiongroup >=1 + - python + constrains: + - pytest-faulthandler >=2 + license: MIT + license_family: MIT + size: 299581 + timestamp: 1765062031645 +- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.14.3-h32b2ec7_101_cp314.conda + build_number: 101 + sha256: cb0628c5f1732f889f53a877484da98f5a0e0f47326622671396fb4f2b0cd6bd + md5: c014ad06e60441661737121d3eae8a60 + depends: + - __glibc >=2.17,<3.0.a0 + - bzip2 
>=1.0.8,<2.0a0 + - ld_impl_linux-64 >=2.36.1 + - libexpat >=2.7.3,<3.0a0 + - libffi >=3.5.2,<3.6.0a0 + - libgcc >=14 + - liblzma >=5.8.2,<6.0a0 + - libmpdec >=4.0.0,<5.0a0 + - libsqlite >=3.51.2,<4.0a0 + - libuuid >=2.41.3,<3.0a0 + - libzlib >=1.3.1,<2.0a0 + - ncurses >=6.5,<7.0a0 + - openssl >=3.5.5,<4.0a0 + - python_abi 3.14.* *_cp314 + - readline >=8.3,<9.0a0 + - tk >=8.6.13,<8.7.0a0 + - tzdata + - zstd >=1.5.7,<1.6.0a0 + license: Python-2.0 + size: 36702440 + timestamp: 1770675584356 + python_site_packages_path: lib/python3.14/site-packages +- conda: https://conda.anaconda.org/conda-forge/noarch/python_abi-3.14-8_cp314.conda + build_number: 8 + sha256: ad6d2e9ac39751cc0529dd1566a26751a0bf2542adb0c232533d32e176e21db5 + md5: 0539938c55b6b1a59b560e843ad864a4 + constrains: + - python 3.14.* *_cp314 + license: BSD-3-Clause + license_family: BSD + size: 6989 + timestamp: 1752805904792 +- conda: https://conda.anaconda.org/conda-forge/linux-64/rdma-core-61.0-h192683f_0.conda + sha256: 8e0b7962cf8bec9a016cd91a6c6dc1f9ebc8e7e316b1d572f7b9047d0de54717 + md5: d487d93d170e332ab39803e05912a762 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libnl >=3.11.0,<4.0a0 + - libstdcxx >=14 + - libsystemd0 >=257.10 + - libudev1 >=257.10 + license: Linux-OpenIB + license_family: BSD + size: 1268666 + timestamp: 1769154883613 +- conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda + sha256: 12ffde5a6f958e285aa22c191ca01bbd3d6e710aa852e00618fa6ddc59149002 + md5: d7d95fc8287ea7bf33e0e7116d2b95ec + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - ncurses >=6.5,<7.0a0 + license: GPL-3.0-only + license_family: GPL + size: 345073 + timestamp: 1765813471974 +- conda: https://conda.anaconda.org/conda-forge/linux-64/rhash-1.4.6-hb9d3cd8_1.conda + sha256: d5c73079c1dd2c2a313c3bfd81c73dbd066b7eb08d213778c8bff520091ae894 + md5: c1c9b02933fdb2cfb791d936c20e887e + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + license: MIT + license_family: MIT 
+ size: 193775 + timestamp: 1748644872902 +- conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.17.1-py314hf07bd8e_0.conda + sha256: 1ae427836d7979779c9005388a05993a3addabcc66c4422694639a4272d7d972 + md5: d0510124f87c75403090e220db1e9d41 + depends: + - __glibc >=2.17,<3.0.a0 + - libblas >=3.9.0,<4.0a0 + - libcblas >=3.9.0,<4.0a0 + - libgcc >=14 + - libgfortran + - libgfortran5 >=14.3.0 + - liblapack >=3.9.0,<4.0a0 + - libstdcxx >=14 + - numpy <2.7 + - numpy >=1.23,<3 + - numpy >=1.25.2 + - python >=3.14,<3.15.0a0 + - python_abi 3.14.* *_cp314 + license: BSD-3-Clause + license_family: BSD + size: 17225275 + timestamp: 1771880751368 +- conda: https://conda.anaconda.org/conda-forge/noarch/sysroot_linux-64-2.28-h4ee821c_9.conda + sha256: c47299fe37aebb0fcf674b3be588e67e4afb86225be4b0d452c7eb75c086b851 + md5: 13dc3adbc692664cd3beabd216434749 + depends: + - __glibc >=2.28 + - kernel-headers_linux-64 4.18.0 he073ed8_9 + - tzdata + license: LGPL-2.0-or-later AND LGPL-2.0-or-later WITH exceptions AND GPL-2.0-or-later + license_family: GPL + size: 24008591 + timestamp: 1765578833462 +- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda + sha256: cafeec44494f842ffeca27e9c8b0c27ed714f93ac77ddadc6aaf726b5554ebac + md5: cffd3bdd58090148f4cfcd831f4b26ab + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libzlib >=1.3.1,<2.0a0 + constrains: + - xorg-libx11 >=1.8.12,<2.0a0 + license: TCL + license_family: BSD + size: 3301196 + timestamp: 1769460227866 +- conda: https://conda.anaconda.org/conda-forge/noarch/tomli-2.4.0-pyhcf101f3_0.conda + sha256: 62940c563de45790ba0f076b9f2085a842a65662268b02dd136a8e9b1eaf47a8 + md5: 72e780e9aa2d0a3295f59b1874e3768b + depends: + - python >=3.10 + - python + license: MIT + license_family: MIT + size: 21453 + timestamp: 1768146676791 +- conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.15.0-pyhcf101f3_0.conda + sha256: 
032271135bca55aeb156cee361c81350c6f3fb203f57d024d7e5a1fc9ef18731 + md5: 0caa1af407ecff61170c9437a808404d + depends: + - python >=3.10 + - python + license: PSF-2.0 + license_family: PSF + size: 51692 + timestamp: 1756220668932 +- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda + sha256: 1d30098909076af33a35017eed6f2953af1c769e273a0626a04722ac4acaba3c + md5: ad659d0a2b3e47e38d829aa8cad2d610 + license: LicenseRef-Public-Domain + size: 119135 + timestamp: 1767016325805 +- conda: https://conda.anaconda.org/conda-forge/linux-64/ucc-1.7.0-hcedbda0_0.conda + sha256: e50076036f16f558cc980c79e81be90cfd99e165481cdc199b59299916cdac8c + md5: 9dcf3dd1c01009915ef1e12732e5cfcb + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libstdcxx >=14 + - ucx >=1.20.0,<1.20.1.0a0 + constrains: + - cuda-version >=12,<13.0a0 + - nccl >=2.29.3.1,<3.0a0 + - cuda-cudart + license: BSD-3-Clause + license_family: BSD + size: 8830517 + timestamp: 1773083374330 +- conda: https://conda.anaconda.org/conda-forge/linux-64/ucx-1.20.0-hf72d326_1.conda + sha256: 350c5179e1bda17434acf99eb8247f5b6d9b7f991dfd19c582abf538ed41733a + md5: d878a39ba2fc02440785a6a5c4657b09 + depends: + - __glibc >=2.28,<3.0.a0 + - _openmp_mutex >=4.5 + - libgcc >=14 + - libstdcxx >=14 + - rdma-core >=61.0 + constrains: + - cuda-cudart + - cuda-version >=13,<14.0a0 + license: BSD-3-Clause + license_family: BSD + size: 7801740 + timestamp: 1769197798676 +- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda + sha256: 68f0206ca6e98fea941e5717cec780ed2873ffabc0e1ed34428c061e2c6268c7 + md5: 4a13eeac0b5c8e5b8ab496e6c4ddd829 + depends: + - __glibc >=2.17,<3.0.a0 + - libzlib >=1.3.1,<2.0a0 + license: BSD-3-Clause + license_family: BSD + size: 601375 + timestamp: 1764777111296 diff --git a/pixi.toml b/pixi.toml new file mode 100644 index 0000000..93834ef --- /dev/null +++ b/pixi.toml @@ -0,0 +1,24 @@ +[workspace] +name = "ale" +version = "0.1.0" +description = "ALE: 
Amalgamated Likelihood Estimation for reconciled gene trees" +channels = ["conda-forge"] +platforms = ["linux-64"] + +[dependencies] +cmake = ">=3.24" +openmpi = "*" +gxx = "*" +make = "*" +git = "*" +python = ">=3.10" +scipy = "*" +numpy = "*" +pytest = "*" + +[tasks] +configure = "cmake -S . -B build" +build = { cmd = "cmake --build build -j4", depends-on = ["configure"] } +python_ale = "python -m ALE_python" +test = { cmd = "pytest ALE_python/tests -v", depends-on = ["build"] } +test-fast = { cmd = "pytest ALE_python/tests -v -m 'not slow'", depends-on = ["build"] } diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7d0c185 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.pytest.ini_options] +markers = [ + "slow: marks tests that run the full optimizer (deselect with '-m \"not slow\"')", +]