@@ -65,9 +65,8 @@ def __init__(
 
         # Precompute indices for logits corresponding to tree-internal
         # tokens across batches.
-        num_tree_internal_tokens = self.cu_tokens_per_level[-2]
-        self.tree_internal_index_offsets = torch.arange(
-            num_tree_internal_tokens, device=device)
+        self.tree_internal_size = self.cu_tokens_per_level[-2]
+        self.tree_index_offsets = torch.arange(self.tree_size, device=device)
 
     def forward(
         self,
@@ -113,101 +112,23 @@ def forward(
         #   3 4 5 6     level 2
 
         device = target_logits.device
+        draft_tree_size = self.tree_size - 1
+        tree_internal_index_offsets = (
+            self.tree_index_offsets[:self.tree_internal_size])
+        draft_index_offsets = self.tree_index_offsets[:draft_tree_size]
+
         num_reqs = len(metadata.num_draft_tokens)
-        # [8, 8, 0, 0, 0, 0, 0, 0]
+        # [1, 8, 8, 0, 0, 0, 0, 0]
         num_draft_tokens = torch.tensor(metadata.num_draft_tokens,
                                         device=device)
-        draft_tree_size = self.tree_size - 1
-        # [1, 1, 0, 0, 0, 0, 0, 0]
-        tree_decode_mask = num_draft_tokens == draft_tree_size
+        # [0, 1, 1, 0, 0, 0, 0, 0]
+        is_tree_decode = num_draft_tokens == draft_tree_size
         # [0, 1, 2, 3, 4, 5, 6, 7]
         start_indices = torch.arange(num_reqs, device=device)
-        # [0, 9, 18, 19, 20, 21, 22, 23]
+        # [0, 2, 11, 20, 21, 22, 23, 24]
         start_indices[1:] += metadata.cu_num_draft_tokens[:-1]
 
-        # Compute target probabilities for all logits corresponding to internal
-        # nodes in the tree.
-        vocab_size = target_logits.shape[-1]
-        # [0, 9]
-        tree_decode_start_indices = start_indices[tree_decode_mask]
-        # [[0, 1, 2],
-        #  [9, 10, 11]]
-        tree_internal_indices = tree_decode_start_indices.unsqueeze(
-            1) + self.tree_internal_index_offsets
-        num_tree_decodes, num_logits_per_batch = tree_internal_indices.shape
-        tree_internal_logits = target_logits[tree_internal_indices.flatten()]
-        target_probs = self.compute_probs(
-            tree_internal_logits,
-            num_logits_per_batch,
-            sampling_metadata,
-        ).view(num_tree_decodes, -1, vocab_size)
-
-        # Sample tokens from the target probabilities.
-        # TODO(TheEpicDolphin): Add support for probabilistic-style rejection
-        # sampling, as used in EAGLE.
-        target_token_ids = target_probs.argmax(dim=-1).cpu()
-
-        # Reshape the draft token ids to [num_tree_decodes, draft_tree_size].
-        draft_token_ids = metadata.draft_token_ids.view(num_tree_decodes, -1)
-
-        # Move sampled target and draft token tensors to CPU.
-        # [[311, 6435, 96618],
-        #  [279, 11, 15861]]
-        target_token_ids_cpu = target_token_ids.cpu()
-        # [[311, 8844, 2349, 387, 4732, 96618, 311, 334],
-        #  [3634, 279, 323, 11, 438, 15861, 3634, 7016]]
-        draft_token_ids_cpu = draft_token_ids.cpu()
-
-        # For each batch, find longest path from the root node.
-        path_lengths = torch.zeros(
-            # +1 for the root token.
-            (num_tree_decodes, draft_tree_size + 1),
-            dtype=torch.int32)
-        path_lengths[:, 0] = 1
-        for level in range(1, self.tree_depth):
-            # level 2:
-            # (3, 9)
-            start, end = self.draft_slices[level]
-            # [1, 1, 1, 2, 2, 2]
-            parent_indices = self.parent_indices[level]
-            # [[0, 0, 0, 0, 0, 0],
-            #  [0, 1, 0, 1, 0, 0]]
-            sample_match = draft_token_ids_cpu[:, start - 1:end -
-                                               1] == target_token_ids_cpu[:,
-                                                                          parent_indices]
-            nonzero_length = path_lengths[:, parent_indices] > 0
-            # [[1, 2, 0, 0, 0, 0, 0, 0, 0], -> [[1, 2, 0, 0, 0, 0, 0, 0, 0],
-            #  [1, 0, 2, 0, 0, 0, 0, 0, 0]]     [1, 0, 2, 0, 0, 0, 3, 0, 0]]
-            path_lengths[:,
-                         start:end].masked_fill_(sample_match & nonzero_length,
-                                                 level + 1)
-        # [1, 6]
-        accepted_token_index_offsets = path_lengths.argmax(dim=-1).to(device)
-
-        # Get boolean masks for the paths to the accepted tokens.
-        # [0, 1]
-        tree_batch_indices = self.batch_indices[:num_tree_decodes]
-        # [[[1, 0, 0, 0, 0, 0, 0],  <- batch 0
-        #   [1, 1, 0, 0, 0, 0, 0],
-        #   [1, 0, 1, 0, 0, 0, 0],
-        #   [1, 1, 0, 1, 0, 0, 0],
-        #   [1, 1, 0, 0, 1, 0, 0],
-        #   [1, 0, 1, 0, 0, 1, 0],
-        #   [1, 0, 1, 0, 0, 0, 1]],
-        #  [[1, 0, 0, 0, 0, 0, 0],  <- batch 1
-        #   [1, 1, 0, 0, 0, 0, 0],
-        #   [1, 0, 1, 0, 0, 0, 0],
-        #   [1, 1, 0, 1, 0, 0, 0],
-        #   [1, 1, 0, 0, 1, 0, 0],
-        #   [1, 0, 1, 0, 0, 1, 0],
-        #   [1, 0, 1, 0, 0, 0, 1]]]
-        tree_mask = self.expanded_tree_mask[:num_tree_decodes]
-        # [1, 6] => [[1, 0, 0, 0, 0, 0],  <- batch 0
-        #            [0, 1, 0, 0, 0, 1]]  <- batch 1
-        path_masks = tree_mask[tree_batch_indices,
-                               accepted_token_index_offsets]
-
-        # Create output buffer.
+        # Create output token ids buffer.
         output_token_ids = torch.empty(
             # +1 for the bonus token.
             (num_reqs, draft_tree_size + 1),
@@ -217,15 +138,116 @@ def forward(
         )
         output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
 
-        # Set accepted draft tokens.
-        accepted_draft_tokens = draft_token_ids[path_masks]
-        scatter_mask = torch.zeros_like(output_token_ids, dtype=torch.bool)
-        scatter_mask[tree_decode_mask, :-1] = path_masks
-        output_token_ids.masked_scatter_(scatter_mask, accepted_draft_tokens)
+        # [0, 0, 0, 0, 0, 0, 0, 0]
+        accepted_index_offsets = torch.zeros_like(is_tree_decode,
+                                                  dtype=torch.int32)
+
+        num_tree_decodes = is_tree_decode.sum()
+        if num_tree_decodes > 0:
+            # Compute target probabilities for all logits corresponding to
+            # internal nodes in the tree.
+            vocab_size = target_logits.shape[-1]
+            # [0, 9]
+            tree_decode_start_indices = start_indices[is_tree_decode]
+            # [[0, 1, 2],
+            #  [9, 10, 11]]
+            tree_internal_indices = (tree_decode_start_indices.unsqueeze(1) +
+                                     tree_internal_index_offsets)
+            tree_internal_logits = target_logits[
+                tree_internal_indices.flatten()]
+            target_probs = self.compute_tree_target_probs(
+                tree_internal_logits,
+                is_tree_decode,
+                num_tree_decodes,
+                sampling_metadata,
+            ).view(num_tree_decodes, -1, vocab_size)
+
+            # Sample tokens from the target probabilities.
+            # TODO(TheEpicDolphin): Add support for probabilistic-style
+            # rejection sampling, as used in EAGLE.
+            target_token_ids = target_probs.argmax(dim=-1)
+
+            # Get the draft token ids for batches with full draft trees.
+            # [0, 0]
+            draft_start_indices = torch.zeros(num_tree_decodes,
+                                              device=device,
+                                              dtype=torch.int32)
+            # [0, 8]
+            draft_start_indices[1:] = (
+                metadata.cu_num_draft_tokens[is_tree_decode][:-1])
+            # [[0, 1, 2, ..., 7],
+            #  [8, 9, 10, ..., 15]]
+            tree_draft_indices = (draft_start_indices.unsqueeze(1) +
+                                  draft_index_offsets)
+            draft_token_ids = metadata.draft_token_ids[tree_draft_indices]
+
+            # Move sampled target and draft token tensors to CPU.
+            # [[311, 6435, 96618],
+            #  [279, 11, 15861]]
+            target_token_ids_cpu = target_token_ids.cpu()
+            # [[311, 8844, 2349, 387, 4732, 96618, 311, 334],
+            #  [3634, 279, 323, 11, 438, 15861, 3634, 7016]]
+            draft_token_ids_cpu = draft_token_ids.cpu()
+
+            # For each tree decode batch, find longest path from the root node.
+            path_lengths_cpu = torch.zeros(
+                # +1 for the root token.
+                (num_tree_decodes, draft_tree_size + 1),
+                dtype=torch.int32,
+                device="cpu")
+            path_lengths_cpu[:, 0] = 1
+            for level in range(1, self.tree_depth):
+                # level 2:
+                # (3, 9)
+                start, end = self.draft_slices[level]
+                # [1, 1, 1, 2, 2, 2]
+                parent_indices = self.parent_indices[level]
+                # [[0, 0, 0, 0, 0, 0],
+                #  [0, 1, 0, 1, 0, 0]]
+                sample_match = (draft_token_ids_cpu[:, start - 1:end - 1] ==
+                                target_token_ids_cpu[:, parent_indices])
+                nonzero_length = path_lengths_cpu[:, parent_indices] > 0
+                # [[1, 2, 0, 0, 0, 0, 0, 0, 0], -> [[1, 2, 0, 0, 0, 0, 0, 0, 0],
+                #  [1, 0, 2, 0, 0, 0, 0, 0, 0]]     [1, 0, 2, 0, 0, 0, 3, 0, 0]]
+                path_lengths_cpu[:, start:end].masked_fill_(
+                    sample_match & nonzero_length, level + 1)
+            # [1, 6, 0, 0, 0, 0, 0, 0]
+            path_lengths = path_lengths_cpu.argmax(dim=-1).to(
+                device, dtype=torch.int32)
+            accepted_index_offsets[is_tree_decode] = path_lengths
+
+            # Get boolean masks for the paths to the accepted tokens.
+            # [0, 1]
+            tree_batch_indices = self.batch_indices[:num_tree_decodes]
+            # [[[1, 0, 0, 0, 0, 0, 0],  <- batch 0
+            #   [1, 1, 0, 0, 0, 0, 0],
+            #   [1, 0, 1, 0, 0, 0, 0],
+            #   [1, 1, 0, 1, 0, 0, 0],
+            #   [1, 1, 0, 0, 1, 0, 0],
+            #   [1, 0, 1, 0, 0, 1, 0],
+            #   [1, 0, 1, 0, 0, 0, 1]],
+            #  [[1, 0, 0, 0, 0, 0, 0],  <- batch 1
+            #   [1, 1, 0, 0, 0, 0, 0],
+            #   [1, 0, 1, 0, 0, 0, 0],
+            #   [1, 1, 0, 1, 0, 0, 0],
+            #   [1, 1, 0, 0, 1, 0, 0],
+            #   [1, 0, 1, 0, 0, 1, 0],
+            #   [1, 0, 1, 0, 0, 0, 1]]]
+            tree_mask = self.expanded_tree_mask[:num_tree_decodes]
+            # [1, 6] => [[1, 0, 0, 0, 0, 0],  <- batch 0
+            #            [0, 1, 0, 0, 0, 1]]  <- batch 1
+            path_masks = tree_mask[tree_batch_indices, path_lengths]
+
+            # Set accepted draft tokens.
+            accepted_draft_tokens = draft_token_ids[path_masks]
+            scatter_mask = torch.zeros_like(output_token_ids, dtype=torch.bool)
+            scatter_mask[is_tree_decode, :-1] = path_masks
+            output_token_ids.masked_scatter_(scatter_mask,
+                                             accepted_draft_tokens)
 
         # Sample and add a bonus token to the accepted paths.
-        bonus_token_indices = start_indices
-        bonus_token_indices[tree_decode_mask] += accepted_token_index_offsets
+        # [0, 2 + 1, 11 + 6, 20, 21, 22, 23, 24]
+        bonus_token_indices = start_indices + accepted_index_offsets
         bonus_sampler_output = self.main_sampler(
             logits=target_logits[bonus_token_indices],
             sampling_metadata=sampling_metadata,
@@ -234,22 +256,30 @@ def forward(
             -1] = bonus_sampler_output.sampled_token_ids.view(-1)
         return output_token_ids
 
-    def compute_probs(self, logits: torch.Tensor, logits_per_batch: int,
-                      sampling_metadata: SamplingMetadata):
+    def compute_tree_target_probs(self, logits: torch.Tensor,
+                                  is_tree_decode: torch.Tensor,
+                                  num_tree_decodes: int,
+                                  sampling_metadata: SamplingMetadata):
         if sampling_metadata.all_greedy:
             return logits
 
+        # How many times to repeat the temperature, top-k, and top-p
+        # for each tree-decode batch.
+        num_repeats = logits.shape[0] // num_tree_decodes
+
         assert sampling_metadata.temperature is not None
-        temperature = sampling_metadata.temperature.repeat_interleave(
-            logits_per_batch)
+        temperature = sampling_metadata.temperature[is_tree_decode]
+        temperature = temperature.repeat_interleave(num_repeats)
        logits.div_(temperature.view(-1, 1))
 
         top_k = None
         if sampling_metadata.top_k is not None:
-            top_k = sampling_metadata.top_k.repeat_interleave(logits_per_batch)
+            top_k = sampling_metadata.top_k[is_tree_decode]
+            top_k = top_k.repeat_interleave(num_repeats)
         top_p = None
         if sampling_metadata.top_p is not None:
-            top_p = sampling_metadata.top_p.repeat_interleave(logits_per_batch)
+            top_p = sampling_metadata.top_p[is_tree_decode]
+            top_p = top_p.repeat_interleave(num_repeats)
         logits = apply_top_k_top_p(logits, top_k, top_p)
         output_probs = logits.softmax(dim=-1, dtype=torch.float32)
         return output_probs
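
For reference, here is a minimal standalone sketch (not part of the diff) of the per-level longest-path acceptance that the loop above performs, using the single-request toy tree from the inline comments. The level-2 slice (3, 9) and parents [1, 1, 1, 2, 2, 2] come from those comments; the level-1 slice and parents are assumptions consistent with a tree of 1 root + 2 + 6 nodes.

import torch

# Toy tree: root + 2 level-1 nodes + 6 level-2 nodes (tree_size = 9,
# draft_tree_size = 8).
draft_slices = {1: (1, 3), 2: (3, 9)}                 # node-index range per level
parent_indices = {1: [0, 0], 2: [1, 1, 1, 2, 2, 2]}   # parent node of each draft node

# One request: draft tokens for nodes 1..8 and target tokens sampled at the
# internal nodes 0..2 (values taken from the inline comments above).
draft_token_ids = torch.tensor([[311, 8844, 2349, 387, 4732, 96618, 311, 334]])
target_token_ids = torch.tensor([[311, 6435, 96618]])

# path_lengths[i, n] is the accepted path length ending at node n (0 = rejected).
path_lengths = torch.zeros((1, 9), dtype=torch.int32)
path_lengths[:, 0] = 1                                # the root is always accepted
for level in (1, 2):
    start, end = draft_slices[level]
    parents = parent_indices[level]
    # A draft node is accepted iff it matches the token the target model
    # sampled at its parent and the path to its parent is itself accepted.
    match = draft_token_ids[:, start - 1:end - 1] == target_token_ids[:, parents]
    alive = path_lengths[:, parents] > 0
    path_lengths[:, start:end].masked_fill_(match & alive, level + 1)

# Deepest accepted node per request; here node 1, i.e. only draft token 311
# is accepted and the bonus token is sampled after it.
print(path_lengths.argmax(dim=-1))  # tensor([1])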
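
And a small sketch of how compute_tree_target_probs lines up the per-request sampling parameters with the flattened tree-internal logits; the tensor values here are made up for illustration.

import torch

# Four requests, of which only two are tree decodes; each full tree has
# 3 internal nodes, so there are 2 * 3 = 6 tree-internal logits.
temperature = torch.tensor([1.0, 0.7, 0.9, 0.5])        # one value per request
is_tree_decode = torch.tensor([False, True, True, False])
num_tree_decodes = int(is_tree_decode.sum())             # 2
num_tree_internal_logits = 6
num_repeats = num_tree_internal_logits // num_tree_decodes   # 3 nodes per tree

# Keep only the tree-decode requests, then repeat each value once per internal
# node so it is aligned with the flattened tree-internal logits.
expanded = temperature[is_tree_decode].repeat_interleave(num_repeats)
print(expanded)  # tensor([0.7000, 0.7000, 0.7000, 0.9000, 0.9000, 0.9000])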