Skip to content

Commit 15dee64

Browse files
committed
Add support for terminal site to ancestor and sample matching (PY)
1 parent 580c3e6 commit 15dee64

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

tsinfer/inference.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,7 +2086,8 @@ class Matcher:
20862086
def __init__(
20872087
self,
20882088
variant_data,
2089-
inference_site_position,
2089+
combined_position,
2090+
terminal_position,
20902091
num_threads=1,
20912092
path_compression=True,
20922093
recombination_rate=None,
@@ -2104,30 +2105,33 @@ def __init__(
21042105
self.num_threads = num_threads
21052106
self.path_compression = path_compression
21062107
self.num_samples = self.variant_data.num_samples
2107-
self.num_sites = len(inference_site_position)
2108-
if self.num_sites == 0:
2109-
logger.warning("No sites used for inference")
2110-
num_intervals = max(self.num_sites - 1, 0)
21112108
self.progress_monitor = _get_progress_monitor(progress_monitor)
21122109
self.match_progress = None # Allocated by subclass
21132110
self.extended_checks = extended_checks
21142111

2112+
assert np.isin(terminal_position, combined_position).all()
2113+
inference_position = np.setdiff1d(
2114+
combined_position, terminal_position, assume_unique=True
2115+
)
2116+
self.num_sites = len(inference_position)
2117+
if self.num_sites == 0:
2118+
logger.warning("No sites used for inference")
2119+
num_intervals = max(self.num_sites - 1, 0)
2120+
21152121
all_sites = self.variant_data.sites_position[:]
2116-
index = np.searchsorted(all_sites, inference_site_position)
2122+
index = np.searchsorted(all_sites, inference_position)
21172123
num_alleles = variant_data.num_alleles()[index]
21182124
self.num_alleles = num_alleles
2119-
if not np.all(all_sites[index] == inference_site_position):
2125+
if not np.all(all_sites[index] == inference_position):
21202126
raise ValueError(
21212127
"Site positions for inference must be a subset of those in "
21222128
"the sample data file."
21232129
)
21242130
self.inference_site_id = index
21252131

2126-
# Map of site index to tree sequence position. Bracketing
2127-
# values of 0 and L are used for simplicity.
2128-
self.position_map = np.hstack(
2129-
[inference_site_position, [variant_data.sequence_length]]
2130-
)
2132+
# Map of site index to tree sequence position. Terminal site position
2133+
# is included is no longer set to sequence_length.
2134+
self.position_map = combined_position.copy()
21312135
self.position_map[0] = 0
21322136
self.recombination = np.zeros(self.num_sites) # TODO: reduce len by 1
21332137
self.mismatch = np.zeros(self.num_sites)
@@ -2163,7 +2167,7 @@ def __init__(
21632167
)
21642168
else:
21652169
genetic_dists = self.recombination_rate_to_dist(
2166-
recombination_rate, inference_site_position
2170+
recombination_rate, inference_position
21672171
)
21682172
recombination = self.recombination_dist_to_prob(genetic_dists)
21692173
if mismatch_ratio is None:
@@ -2356,6 +2360,12 @@ def convert_inference_mutations(self, tables):
23562360
progress.update()
23572361
progress.close()
23582362

2363+
site_id = tables.sites.add_row(
2364+
self.terminal_position[0],
2365+
ancestral_state="N",
2366+
metadata=b"",
2367+
)
2368+
23592369
def restore_tree_sequence_builder(self):
23602370
tables = self.ancestors_ts_tables
23612371
if self.variant_data.sequence_length != tables.sequence_length:
@@ -2421,8 +2431,14 @@ class AncestorMatcher(Matcher):
24212431
def __init__(
24222432
self, variant_data, ancestor_data, ancestors_ts=None, time_units=None, **kwargs
24232433
):
2424-
super().__init__(variant_data, ancestor_data.sites_position[:], **kwargs)
2434+
super().__init__(
2435+
variant_data,
2436+
combined_position=ancestor_data.sites_position[:],
2437+
terminal_position=ancestor_data.terminal_position[:],
2438+
**kwargs,
2439+
)
24252440
self.ancestor_data = ancestor_data
2441+
self.terminal_position = ancestor_data.terminal_position
24262442
if time_units is None:
24272443
time_units = tskit.TIME_UNITS_UNCALIBRATED
24282444
self.time_units = time_units
@@ -2688,8 +2704,18 @@ def store_output(self):
26882704
class SampleMatcher(Matcher):
26892705
def __init__(self, variant_data, ancestors_ts, **kwargs):
26902706
self.ancestors_ts_tables = ancestors_ts.dump_tables()
2707+
2708+
ancestral_state_vals = ancestors_ts.tables.sites.ancestral_state
2709+
ancestral_state = np.char.decode(ancestral_state_vals.view("S1"), "ascii")
2710+
terminal_sites = np.where(ancestral_state == "N")[0]
2711+
terminal_position = ancestors_ts.sites_position[terminal_sites]
2712+
self.terminal_position = terminal_position
2713+
26912714
super().__init__(
2692-
variant_data, self.ancestors_ts_tables.sites.position, **kwargs
2715+
variant_data,
2716+
combined_position=self.ancestors_ts_tables.sites.position,
2717+
terminal_position=terminal_position,
2718+
**kwargs,
26932719
)
26942720
self.restore_tree_sequence_builder()
26952721
# Map from input sample indexes (IDs in the SampleData file) to the

0 commit comments

Comments
 (0)