From 580c3e699e67a2abb3ff116cfa939969985d2a56 Mon Sep 17 00:00:00 2001 From: Duncan Mbuli-Robertson Date: Thu, 16 Oct 2025 16:10:12 +0100 Subject: [PATCH 1/3] Add support for terminal site to generate_ancestors with PY engine --- tsinfer/algorithm.py | 25 +++++++++++++++---------- tsinfer/formats.py | 36 +++++++++++++++++++++++++++++------- tsinfer/inference.py | 26 ++++++++++++++++++++------ 3 files changed, 64 insertions(+), 23 deletions(-) diff --git a/tsinfer/algorithm.py b/tsinfer/algorithm.py index d0d0c6ef..979948aa 100644 --- a/tsinfer/algorithm.py +++ b/tsinfer/algorithm.py @@ -58,6 +58,7 @@ class Site: id = attr.ib() time = attr.ib() derived_count = attr.ib() + terminal = attr.ib() class AncestorBuilder: @@ -137,21 +138,23 @@ def store_site_genotypes(self, site_id, genotypes): stop = start + self.encoded_genotypes_size self.genotype_store[start:stop] = genotypes - def add_site(self, time, genotypes): + def add_site(self, time, genotypes, terminal): """ Adds a new site at the specified ID to the builder. """ site_id = len(self.sites) derived_count = np.sum(genotypes == 1) - self.store_site_genotypes(site_id, genotypes) - self.sites.append(Site(site_id, time, derived_count)) - sites_at_fixed_timepoint = self.time_map[time] - # Sites with an identical variant distribution (i.e. with the same - # genotypes.tobytes() value) and at the same time, are put into the same ancestor - # to which we allocate a unique ID (just use the genotypes value) - ancestor_uid = tuple(genotypes) - # Add each site to the list for this ancestor_uid at this timepoint - sites_at_fixed_timepoint[ancestor_uid].append(site_id) + self.sites.append(Site(site_id, time, derived_count, terminal)) + if not terminal: + self.store_site_genotypes(site_id, genotypes) + sites_at_fixed_timepoint = self.time_map[time] + # Sites with an identical variant distribution (i.e. 
with the same + # genotypes.tobytes() value) and at the same time, are put into the + # same ancestor to which we allocate a unique ID (just use the genotypes + # value) + ancestor_uid = tuple(genotypes) + # Add each site to the list for this ancestor_uid at this timepoint + sites_at_fixed_timepoint[ancestor_uid].append(site_id) def print_state(self): print("Ancestor builder") @@ -221,6 +224,8 @@ def compute_ancestral_states(self, a, focal_site, sites): disagree = np.zeros(self.num_samples, dtype=bool) for site_index in sites: + if self.sites[site_index].terminal: + break a[site_index] = 0 last_site = site_index g_l = self.get_site_genotypes(site_index) diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 9728f725..cf78d990 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -3093,7 +3093,14 @@ class AncestorData(DataContainer): FORMAT_NAME = "tsinfer-ancestor-data" FORMAT_VERSION = (3, 0) - def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs): + def __init__( + self, + inference_position, + terminal_position, + sequence_length, + chunk_size_sites=None, + **kwargs, + ): super().__init__(**kwargs) self._last_time = 0 self.inference_sites_set = False @@ -3111,15 +3118,22 @@ def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs): self.create_dataset("sample_end", dtype=np.int32) self.create_dataset("sample_time", dtype=np.float64) self.create_dataset("sample_focal_sites", dtype="array:i4") - + variant_position = np.concatenate([inference_position, terminal_position]) self.create_dataset( "variant_position", - data=position, - shape=position.shape, + data=variant_position, + shape=variant_position.shape, chunks=self._chunk_size_sites, dtype=np.float64, dimensions=["variants"], ) + self.create_dataset( + "terminal_position", + data=terminal_position, + shape=terminal_position.shape, + dtype=np.float64, + dimensions=["terminal_sites"], + ) # We have to include a ploidy dimension sgkit compatibility a = 
self.create_dataset( @@ -3277,10 +3291,17 @@ def num_sites(self): @property def sites_position(self): """ - The positions of the inference sites used to generate the ancestors + The positions of the inference and terminal sites used to generate the ancestors """ return self.data["variant_position"] + @property + def terminal_position(self): + """ + The positions of the terminal sites used to generate the ancestors + """ + return self.data["terminal_position"] + @property def ancestors_start(self): return self.data["sample_start"] @@ -3314,10 +3335,10 @@ def ancestors_length(self): """ # Ancestor start and end are half-closed. The last site is assumed # to cover the region up to sequence length. - pos = np.hstack([self.sites_position[:], [self.sequence_length]]) + start = self.ancestors_start[:] end = self.ancestors_end[:] - return pos[end] - pos[start] + return self.sites_position[end] - self.sites_position[start] def insert_proxy_samples( self, @@ -3683,6 +3704,7 @@ def add_ancestor(self, start, end, time, focal_sites, haplotype): if start < 0: raise ValueError("Start must be >= 0") if end > self.num_sites: + print(f"[INFO] {end}, {self.num_sites}") raise ValueError("end must be <= num_sites") if start >= end: raise ValueError("start must be < end") diff --git a/tsinfer/inference.py b/tsinfer/inference.py index b11d3782..6a92cd3b 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1807,6 +1807,8 @@ def __init__( self.num_samples = variant_data.num_samples self.num_threads = num_threads self.mmap_temp_file = None + self.sites_position = None + self.terminal_position = None mmap_fd = -1 genotype_matrix_size = self.max_sites * self.num_samples @@ -1865,6 +1867,8 @@ def add_sites(self, exclude_positions=None): logger.info(f"Starting addition of {self.max_sites} sites") progress = self.progress_monitor.get("ga_add_sites", self.max_sites) inference_site_id = [] + last_position = 0 + for variant in self.variant_data.variants(recode_ancestral=True): # If 
there's missing data the last allele is None num_alleles = len(variant.alleles) - int(variant.alleles[-1] is None) @@ -1879,6 +1883,7 @@ def add_sites(self, exclude_positions=None): and site.ancestral_state is not None ): use_site = True + last_position = site.position time = site.time if tskit.is_unknown_time(time): # Non-variable sites have no obvious freq-as-time values @@ -1888,12 +1893,18 @@ def add_sites(self, exclude_positions=None): if np.isnan(time): use_site = False # Site with meaningless time value: skip inference if use_site: - self.ancestor_builder.add_site(time, variant.genotypes) + self.ancestor_builder.add_site(time, variant.genotypes, terminal=False) inference_site_id.append(site.id) self.num_sites += 1 progress.update() progress.close() self.inference_site_ids = inference_site_id + # Add terminal site at end of sequence + zeros = np.zeros(self.num_samples, dtype=np.int8) + self.ancestor_builder.add_site(tskit.UNKNOWN_TIME, zeros, terminal=True) + self.num_sites += 1 + self.terminal_position = np.array([last_position + 1], dtype=np.float64) + logger.info("Finished adding sites") def _run_synchronous(self, progress): @@ -2000,15 +2011,18 @@ def run(self): if t not in self.timepoint_to_epoch: self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1 self.ancestor_data = formats.AncestorData( - self.variant_data.sites_position[:][self.inference_site_ids], - self.variant_data.sequence_length, + inference_position=self.variant_data.sites_position[:][ + self.inference_site_ids + ], + terminal_position=self.terminal_position, + sequence_length=self.variant_data.sequence_length, path=self.ancestor_data_path, **self.ancestor_data_kwargs, ) if self.num_ancestors > 0: logger.info(f"Starting build for {self.num_ancestors} ancestors") progress = self.progress_monitor.get("ga_generate", self.num_ancestors) - a = np.zeros(self.num_sites, dtype=np.int8) + a = np.zeros(self.num_sites - 1, dtype=np.int8) root_time = max(self.timepoint_to_epoch.keys()) 
av_timestep = root_time / len(self.timepoint_to_epoch) root_time += av_timestep # Add a root a bit older than the oldest ancestor @@ -2017,7 +2031,7 @@ def run(self): # line up. It's normally removed when processing the final tree sequence. self.ancestor_data.add_ancestor( start=0, - end=self.num_sites, + end=self.num_sites - 1, time=root_time + av_timestep, focal_sites=np.array([], dtype=np.int32), haplotype=a, @@ -2025,7 +2039,7 @@ def run(self): # This is the the "ultimate ancestor" of all zeros self.ancestor_data.add_ancestor( start=0, - end=self.num_sites, + end=self.num_sites - 1, time=root_time, focal_sites=np.array([], dtype=np.int32), haplotype=a, From 15dee6452913306fdad3de42e5cd5af2d2505586 Mon Sep 17 00:00:00 2001 From: Duncan Mbuli-Robertson Date: Thu, 23 Oct 2025 10:33:01 +0100 Subject: [PATCH 2/3] Add support for terminal site to ancestor and sample matching (PY) --- tsinfer/inference.py | 56 ++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 6a92cd3b..1f200476 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -2086,7 +2086,8 @@ class Matcher: def __init__( self, variant_data, - inference_site_position, + combined_position, + terminal_position, num_threads=1, path_compression=True, recombination_rate=None, @@ -2104,30 +2105,33 @@ def __init__( self.num_threads = num_threads self.path_compression = path_compression self.num_samples = self.variant_data.num_samples - self.num_sites = len(inference_site_position) - if self.num_sites == 0: - logger.warning("No sites used for inference") - num_intervals = max(self.num_sites - 1, 0) self.progress_monitor = _get_progress_monitor(progress_monitor) self.match_progress = None # Allocated by subclass self.extended_checks = extended_checks + assert np.isin(terminal_position, combined_position).all() + inference_position = np.setdiff1d( + combined_position, terminal_position, assume_unique=True 
+ ) + self.num_sites = len(inference_position) + if self.num_sites == 0: + logger.warning("No sites used for inference") + num_intervals = max(self.num_sites - 1, 0) + all_sites = self.variant_data.sites_position[:] - index = np.searchsorted(all_sites, inference_site_position) + index = np.searchsorted(all_sites, inference_position) num_alleles = variant_data.num_alleles()[index] self.num_alleles = num_alleles - if not np.all(all_sites[index] == inference_site_position): + if not np.all(all_sites[index] == inference_position): raise ValueError( "Site positions for inference must be a subset of those in " "the sample data file." ) self.inference_site_id = index - # Map of site index to tree sequence position. Bracketing - # values of 0 and L are used for simplicity. - self.position_map = np.hstack( - [inference_site_position, [variant_data.sequence_length]] - ) + # Map of site index to tree sequence position. The terminal site position + # is included, so sequence_length is no longer appended to the map. + self.position_map = combined_position.copy() self.position_map[0] = 0 self.recombination = np.zeros(self.num_sites) # TODO: reduce len by 1 self.mismatch = np.zeros(self.num_sites) @@ -2163,7 +2167,7 @@ def __init__( ) else: genetic_dists = self.recombination_rate_to_dist( - recombination_rate, inference_site_position + recombination_rate, inference_position ) recombination = self.recombination_dist_to_prob(genetic_dists) if mismatch_ratio is None: @@ -2356,6 +2360,12 @@ def convert_inference_mutations(self, tables): progress.update() progress.close() + site_id = tables.sites.add_row( + self.terminal_position[0], + ancestral_state="N", + metadata=b"", + ) + def restore_tree_sequence_builder(self): tables = self.ancestors_ts_tables if self.variant_data.sequence_length != tables.sequence_length: @@ -2421,8 +2431,14 @@ class AncestorMatcher(Matcher): def __init__( self, variant_data, ancestor_data, ancestors_ts=None, time_units=None, **kwargs ): - super().__init__(variant_data, 
ancestor_data.sites_position[:], **kwargs) + super().__init__( + variant_data, + combined_position=ancestor_data.sites_position[:], + terminal_position=ancestor_data.terminal_position[:], + **kwargs, + ) self.ancestor_data = ancestor_data + self.terminal_position = ancestor_data.terminal_position if time_units is None: time_units = tskit.TIME_UNITS_UNCALIBRATED self.time_units = time_units @@ -2688,8 +2704,18 @@ def store_output(self): class SampleMatcher(Matcher): def __init__(self, variant_data, ancestors_ts, **kwargs): self.ancestors_ts_tables = ancestors_ts.dump_tables() + + ancestral_state_vals = ancestors_ts.tables.sites.ancestral_state + ancestral_state = np.char.decode(ancestral_state_vals.view("S1"), "ascii") + terminal_sites = np.where(ancestral_state == "N")[0] + terminal_position = ancestors_ts.sites_position[terminal_sites] + self.terminal_position = terminal_position + super().__init__( - variant_data, self.ancestors_ts_tables.sites.position, **kwargs + variant_data, + combined_position=self.ancestors_ts_tables.sites.position, + terminal_position=terminal_position, + **kwargs, ) self.restore_tree_sequence_builder() # Map from input sample indexes (IDs in the SampleData file) to the From e6935434178580629b80f47ea9baf243107fd02a Mon Sep 17 00:00:00 2001 From: Duncan Mbuli-Robertson Date: Thu, 23 Oct 2025 13:56:19 +0100 Subject: [PATCH 3/3] Fix edge case where site is near sequence_length --- tsinfer/inference.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 1f200476..d4c632a4 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1903,7 +1903,11 @@ def add_sites(self, exclude_positions=None): zeros = np.zeros(self.num_samples, dtype=np.int8) self.ancestor_builder.add_site(tskit.UNKNOWN_TIME, zeros, terminal=True) self.num_sites += 1 - self.terminal_position = np.array([last_position + 1], dtype=np.float64) + + terminal_position = last_position + 1 + if terminal_position 
== self.variant_data.sequence_length: + terminal_position -= 0.5 + self.terminal_position = np.array([terminal_position], dtype=np.float64) logger.info("Finished adding sites")