@@ -50,62 +50,59 @@ def __init__(
5050 include_non_sample_nodes ,
5151 ):
5252 self .tree_sequence = tree_sequence
53- self .contig_id = contig_id
54- self .isolated_as_missing = isolated_as_missing
5553
5654 vcf_model = tree_sequence .map_to_vcf_model (
5755 individuals = individuals ,
5856 ploidy = ploidy ,
5957 individual_names = individual_names ,
6058 include_non_sample_nodes = include_non_sample_nodes ,
6159 position_transform = position_transform ,
60+ contig_id = contig_id ,
61+ isolated_as_missing = isolated_as_missing ,
6262 )
63+
64+ # We now make some tweaks to the VCF model required for
65+ # writing the VCF in text format
66+
6367 # Remove individuals with zero ploidy as these cannot be
6468 # represented in VCF.
65- individuals_nodes = vcf_model .individuals_nodes
66- to_keep = (individuals_nodes != - 1 ).any (axis = 1 )
67- individuals_nodes = individuals_nodes [to_keep ]
68- self .individual_names = vcf_model .individuals_name [to_keep ]
69+ to_keep = (vcf_model .individuals_nodes != - 1 ).any (axis = 1 )
70+ vcf_model .individuals_nodes = vcf_model .individuals_nodes [to_keep ]
71+ vcf_model .individual_names = vcf_model .individuals_name [to_keep ]
6972 self .individual_ploidies = [
70- len (nodes [nodes >= 0 ]) for nodes in individuals_nodes
73+ len (nodes [nodes >= 0 ]) for nodes in vcf_model . individuals_nodes
7174 ]
72- self .num_individuals = len (self .individual_names )
75+ self .num_individuals = len (vcf_model .individual_names )
7376
74- if len (individuals_nodes ) == 0 :
77+ if len (vcf_model . individuals_nodes ) == 0 :
7578 raise ValueError ("No samples in resulting VCF model" )
7679
80+ if len (vcf_model .transformed_positions ) > 0 :
81+ # Arguably this should be last_pos + 1, but if we hit this
82+ # condition the coordinate systems are all muddled up anyway
83+ # so it's simpler to stay with this rule that was inherited
84+ # from the legacy VCF output code.
85+ vcf_model .contig_length = max (
86+ vcf_model .transformed_positions [- 1 ], vcf_model .contig_length
87+ )
88+
7789 # Flatten the array of node IDs, filtering out the -1 padding values
7890 self .samples = []
79- for row in individuals_nodes :
91+ for row in vcf_model . individuals_nodes :
8092 for node_id in row :
8193 if node_id != - 1 :
8294 self .samples .append (node_id )
8395
84- self .transformed_positions = vcf_model .transformed_positions
85- self .contig_length = vcf_model .contig_length
86- if len (self .transformed_positions ) > 0 :
87- # Arguably this should be last_pos + 1, but if we hit this
88- # condition the coordinate systems are all muddled up anyway
89- # so it's simpler to stay with this rule that was inherited
90- # from the legacy VCF output code.
91- self .contig_length = max (self .transformed_positions [- 1 ], self .contig_length )
92-
9396 if site_mask is None :
9497 site_mask = np .zeros (tree_sequence .num_sites , dtype = bool )
9598 self .site_mask = np .array (site_mask , dtype = bool )
9699 if self .site_mask .shape != (tree_sequence .num_sites ,):
97100 raise ValueError ("Site mask must be 1D a boolean array of length num_sites" )
98101
99- self .sample_mask = sample_mask
100- if sample_mask is not None :
101- if not callable (sample_mask ):
102- sample_mask = np .array (sample_mask , dtype = bool )
103- self .sample_mask = lambda _ : sample_mask
104-
105102 # The VCF spec does not allow for positions to be 0, so we error if one of the
106103 # transformed positions is 0 and allow_position_zero is False.
107104 if not allow_position_zero and np .any (
108- self .transformed_positions [~ site_mask ] == 0
105+ vcf_model .transformed_positions [~ site_mask ] == 0
109106 ):
110107 raise ValueError (
111108 "A variant position of 0 was found in the VCF output, this is not "
@@ -116,17 +113,26 @@ def __init__(
116113 '"position_transform = lambda x: np.fmax(1, x)"'
117114 )
118115
116+ self .sample_mask = sample_mask
117+ if sample_mask is not None :
118+ if not callable (sample_mask ):
119+ sample_mask = np .array (sample_mask , dtype = bool )
120+ self .sample_mask = lambda _ : sample_mask
121+
122+ self .vcf_model = vcf_model
123+
119124 def __write_header (self , output ):
120125 print ("##fileformat=VCFv4.2" , file = output )
121126 print (f"##source=tskit { provenance .__version__ } " , file = output )
122127 print ('##FILTER=<ID=PASS,Description="All filters passed">' , file = output )
123128 print (
124- f"##contig=<ID={ self .contig_id } ,length={ self .contig_length } >" , file = output
129+ f"##contig=<ID={ self .vcf_model .contig_id } ,length={ self .vcf_model .contig_length } >" ,
130+ file = output ,
125131 )
126132 print (
127133 '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">' , file = output
128134 )
129- vcf_samples = "\t " .join (self .individual_names )
135+ vcf_samples = "\t " .join (self .vcf_model . individual_names )
130136 print (
131137 "#CHROM" ,
132138 "POS" ,
@@ -163,7 +169,7 @@ def write(self, output):
163169 indexes = np .array (indexes , dtype = int )
164170
165171 for variant in self .tree_sequence .variants (
166- samples = self .samples , isolated_as_missing = self .isolated_as_missing
172+ samples = self .samples , isolated_as_missing = self .vcf_model . isolated_as_missing
167173 ):
168174 site_id = variant .site .id
169175 # We check the mask before we do any checks so we can use this as a
@@ -176,13 +182,13 @@ def write(self, output):
176182 "More than 9 alleles not currently supported. Please open an issue "
177183 "on GitHub if this limitation affects you."
178184 )
179- pos = self .transformed_positions [variant .index ]
185+ pos = self .vcf_model . transformed_positions [variant .index ]
180186 ref = variant .alleles [0 ]
181187 alt = "."
182188 if variant .num_alleles > 1 :
183189 alt = "," .join (variant .alleles [1 : variant .num_alleles ])
184190 print (
185- self .contig_id ,
191+ self .vcf_model . contig_id ,
186192 pos ,
187193 site_id ,
188194 ref ,
0 commit comments