@@ -2086,7 +2086,8 @@ class Matcher:
20862086 def __init__ (
20872087 self ,
20882088 variant_data ,
2089- inference_site_position ,
2089+ combined_position ,
2090+ terminal_position ,
20902091 num_threads = 1 ,
20912092 path_compression = True ,
20922093 recombination_rate = None ,
@@ -2104,30 +2105,33 @@ def __init__(
21042105 self .num_threads = num_threads
21052106 self .path_compression = path_compression
21062107 self .num_samples = self .variant_data .num_samples
2107- self .num_sites = len (inference_site_position )
2108- if self .num_sites == 0 :
2109- logger .warning ("No sites used for inference" )
2110- num_intervals = max (self .num_sites - 1 , 0 )
21112108 self .progress_monitor = _get_progress_monitor (progress_monitor )
21122109 self .match_progress = None # Allocated by subclass
21132110 self .extended_checks = extended_checks
21142111
2112+ assert np .isin (terminal_position , combined_position ).all ()
2113+ inference_position = np .setdiff1d (
2114+ combined_position , terminal_position , assume_unique = True
2115+ )
2116+ self .num_sites = len (inference_position )
2117+ if self .num_sites == 0 :
2118+ logger .warning ("No sites used for inference" )
2119+ num_intervals = max (self .num_sites - 1 , 0 )
2120+
21152121 all_sites = self .variant_data .sites_position [:]
2116- index = np .searchsorted (all_sites , inference_site_position )
2122+ index = np .searchsorted (all_sites , inference_position )
21172123 num_alleles = variant_data .num_alleles ()[index ]
21182124 self .num_alleles = num_alleles
2119- if not np .all (all_sites [index ] == inference_site_position ):
2125+ if not np .all (all_sites [index ] == inference_position ):
21202126 raise ValueError (
21212127 "Site positions for inference must be a subset of those in "
21222128 "the sample data file."
21232129 )
21242130 self .inference_site_id = index
21252131
2126- # Map of site index to tree sequence position. Bracketing
2127- # values of 0 and L are used for simplicity.
2128- self .position_map = np .hstack (
2129- [inference_site_position , [variant_data .sequence_length ]]
2130- )
2132+ # Map of site index to tree sequence position. The terminal site position
2133+ # is included and is no longer set to sequence_length.
2134+ self .position_map = combined_position .copy ()
21312135 self .position_map [0 ] = 0
21322136 self .recombination = np .zeros (self .num_sites ) # TODO: reduce len by 1
21332137 self .mismatch = np .zeros (self .num_sites )
@@ -2163,7 +2167,7 @@ def __init__(
21632167 )
21642168 else :
21652169 genetic_dists = self .recombination_rate_to_dist (
2166- recombination_rate , inference_site_position
2170+ recombination_rate , inference_position
21672171 )
21682172 recombination = self .recombination_dist_to_prob (genetic_dists )
21692173 if mismatch_ratio is None :
@@ -2356,6 +2360,12 @@ def convert_inference_mutations(self, tables):
23562360 progress .update ()
23572361 progress .close ()
23582362
2363+ site_id = tables .sites .add_row (
2364+ self .terminal_position [0 ],
2365+ ancestral_state = "N" ,
2366+ metadata = b"" ,
2367+ )
2368+
23592369 def restore_tree_sequence_builder (self ):
23602370 tables = self .ancestors_ts_tables
23612371 if self .variant_data .sequence_length != tables .sequence_length :
@@ -2421,8 +2431,14 @@ class AncestorMatcher(Matcher):
24212431 def __init__ (
24222432 self , variant_data , ancestor_data , ancestors_ts = None , time_units = None , ** kwargs
24232433 ):
2424- super ().__init__ (variant_data , ancestor_data .sites_position [:], ** kwargs )
2434+ super ().__init__ (
2435+ variant_data ,
2436+ combined_position = ancestor_data .sites_position [:],
2437+ terminal_position = ancestor_data .terminal_position [:],
2438+ ** kwargs ,
2439+ )
24252440 self .ancestor_data = ancestor_data
2441+ self .terminal_position = ancestor_data .terminal_position
24262442 if time_units is None :
24272443 time_units = tskit .TIME_UNITS_UNCALIBRATED
24282444 self .time_units = time_units
@@ -2688,8 +2704,18 @@ def store_output(self):
26882704class SampleMatcher (Matcher ):
26892705 def __init__ (self , variant_data , ancestors_ts , ** kwargs ):
26902706 self .ancestors_ts_tables = ancestors_ts .dump_tables ()
2707+
2708+ ancestral_state_vals = ancestors_ts .tables .sites .ancestral_state
2709+ ancestral_state = np .char .decode (ancestral_state_vals .view ("S1" ), "ascii" )
2710+ terminal_sites = np .where (ancestral_state == "N" )[0 ]
2711+ terminal_position = ancestors_ts .sites_position [terminal_sites ]
2712+ self .terminal_position = terminal_position
2713+
26912714 super ().__init__ (
2692- variant_data , self .ancestors_ts_tables .sites .position , ** kwargs
2715+ variant_data ,
2716+ combined_position = self .ancestors_ts_tables .sites .position ,
2717+ terminal_position = terminal_position ,
2718+ ** kwargs ,
26932719 )
26942720 self .restore_tree_sequence_builder ()
26952721 # Map from input sample indexes (IDs in the SampleData file) to the
0 commit comments