diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index ef8dbf7bf..d1c3b87a1 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -471,6 +471,8 @@ def speculative_generate_step(
     model: nn.Module,
     draft_model: nn.Module,
     *,
+    tokenizer: Optional[Any] = None,
+    draft_tokenizer: Optional[Any] = None,
     num_draft_tokens: int = 2,
     max_tokens: int = 256,
     sampler: Optional[Callable[[mx.array], mx.array]] = None,
@@ -480,6 +482,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    translation_prefix_tokens: int = 0,
 ) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -488,6 +491,11 @@ def speculative_generate_step(
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         draft_model (nn.Module): The draft model for speculative decoding.
+        tokenizer: The verifier model's tokenizer. Required when
+          ``draft_tokenizer`` is provided for cross-tokenizer translation.
+        draft_tokenizer: The draft model's tokenizer. When provided (and
+          differs from ``tokenizer``), draft tokens are decoded to text and
+          re-encoded with the verifier tokenizer before verification.
         num_draft_tokens (int, optional): The number of draft tokens for
           speculative decoding. Default: ``2``.
         max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
@@ -505,12 +513,20 @@ def speculative_generate_step(
         kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
         quantized_kv_start (int): Step to begin using a quantized KV cache.
           when ``kv_bits`` is non-None. Default: ``0``.
+        translation_prefix_tokens (int): Number of previously generated verifier
+          tokens to prepend as context when re-encoding draft text into verifier
+          tokens during cross-tokenizer translation. Providing context avoids
+          boundary mis-tokenisation (e.g. a leading ``w`` being tokenised as the
+          beginning-of-word ``_w`` instead of the mid-word ``w``). Only active
+          when ``draft_tokenizer`` is set. Default: ``0`` (no prefix context).

     Yields:
         Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
           and a bool indicating if the token was generated by the draft model

     """

+    cross_tokenizer = draft_tokenizer is not None
+
     y = prompt.astype(mx.uint32)
     prev_tokens = None
@@ -540,13 +556,13 @@ def _process_and_sample(tokens, logits):
         y = sampler(logprobs)
         return y, logprobs

-    def _step(model, cache, y, n_predict=1):
+    def _step(model, cache, y, n_predict=1, apply_logits_processors=True):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=cache)
             logits = logits[:, -n_predict:, :]
             quantize_cache_fn(cache)

-            if logits_processors:
+            if apply_logits_processors and logits_processors:
                 nonlocal prev_tokens
                 out_y, out_logprobs = [], []
                 if n_predict > 1:
@@ -579,68 +595,247 @@ def _rewind_cache(num_draft, num_accept):
         cache.trim_prompt_cache(model_cache, num_draft - num_accept)
         cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))

+    def _translate_to_verifier(draft_token_list, context_ids=None):
+        draft_text = draft_tokenizer.decode(draft_token_list)
+        if context_ids:
+            prefix_text = tokenizer.decode(context_ids)
+            full_ids = tokenizer.encode(prefix_text + draft_text, add_special_tokens=False)
+            prefix_len = len(tokenizer.encode(prefix_text, add_special_tokens=False))
+            return draft_text, full_ids[prefix_len:]
+        return draft_text, tokenizer.encode(draft_text, add_special_tokens=False)
+
+    def _translate_to_draft(verifier_token_ids):
+        text = tokenizer.decode(verifier_token_ids)
+        return text, draft_tokenizer.encode(text, add_special_tokens=False)
+
     def _draft_generate(y, num_draft):
         if num_draft == 0:
             return mx.array([], mx.uint32)
         ys = []
         for _ in range(num_draft):
-            y, _ = _step(draft_model, draft_cache, y)
+            y, _ = _step(
+                draft_model, draft_cache, y,
+                apply_logits_processors=not cross_tokenizer,
+            )
             mx.async_eval(y)
             ys.append(y)
         return mx.concatenate(ys)

-    with mx.stream(generation_stream):
-        draft_y = _prefill(draft_model, draft_cache, y)
-        y = _prefill(model, model_cache, y)
+    if cross_tokenizer:
+        # Re-encode prompt for the draft model's tokenizer
+        prompt_text = tokenizer.decode(prompt.tolist())
+        draft_prompt = mx.array(
+            draft_tokenizer.encode(prompt_text, add_special_tokens=False),
+            mx.uint32,
+        )
+        with mx.stream(generation_stream):
+            draft_y = _prefill(draft_model, draft_cache, draft_prompt)
+            y = _prefill(model, model_cache, y)
+
+        ntoks = 0
+        num_draft = 0
+        n_verifier = 0
+        n_accept_v = 0
+        needs_cleanup = False
+        verifier_context: List[int] = []
+
+        try:
+            while True:
+                num_draft = min(max_tokens - ntoks, num_draft_tokens)
+
+                # Draft n tokens in draft token space
+                draft_tokens = _draft_generate(draft_y, num_draft)
+                mx.eval(draft_tokens)
+                draft_token_list = draft_tokens.tolist()
+
+                # Decode draft tokens to text, re-encode with verifier tokenizer.
+                # Pass recent verifier tokens as context so the tokenizer sees the
+                # correct word-boundary for the first subword of the draft text.
+                context_ids = (
+                    verifier_context[-translation_prefix_tokens:]
+                    if translation_prefix_tokens > 0
+                    else None
+                )
+                draft_text, verifier_token_ids = _translate_to_verifier(
+                    draft_token_list, context_ids
+                )
+                n_verifier = len(verifier_token_ids)
+
+                if n_verifier == 0:
+                    # Draft text produced no verifier tokens; do a plain verifier step
+                    tokens_v, logprobs_v = _step(model, model_cache, y, 1)
+                    mx.eval(tokens_v)
+                    ntoks += 1
+                    fallback_tok = tokens_v.item()
+                    verifier_context.append(fallback_tok)
+                    if translation_prefix_tokens > 0:
+                        verifier_context = verifier_context[-translation_prefix_tokens:]
+                    yield fallback_tok, logprobs_v.squeeze(0), False
+                    if ntoks == max_tokens:
+                        break
+                    y = mx.array([fallback_tok], mx.uint32)
+                    _, draft_corr = _translate_to_draft([fallback_tok])
+                    cache.trim_prompt_cache(draft_cache, num_draft - 1)
+                    num_draft = 0
+                    if draft_corr:
+                        if len(draft_corr) > 1:
+                            with mx.stream(generation_stream):
+                                mx.async_eval(
+                                    draft_model(
+                                        mx.array(draft_corr[:-1], mx.uint32)[None],
+                                        cache=draft_cache,
+                                    )
+                                )
+                        draft_y = mx.array(draft_corr[-1:], mx.uint32)
+                    else:
+                        fallback = draft_tokenizer.encode(
+                            " ", add_special_tokens=False
+                        )
+                        draft_y = mx.array(fallback[-1:], mx.uint32)
+                    continue
+
+                # Verify re-encoded tokens with the verifier model
+                if prev_tokens is not None:
+                    prev_tokens = prev_tokens[
+                        : prev_tokens.size - y.size - n_verifier + 1
+                    ]
+                y_verify = mx.concatenate(
+                    [y, mx.array(verifier_token_ids, mx.uint32)]
+                )
+                tokens_v, logprobs_v = _step(
+                    model, model_cache, y_verify, n_verifier + 1
+                )
+                needs_cleanup = True
+                mx.eval(tokens_v)
+                tokens_list = tokens_v.tolist()
+
+                # Find first disagreement between verifier samples and
+                # the re-encoded draft tokens
+                n_accept_v = 0
+                while n_accept_v < n_verifier:
+                    if tokens_list[n_accept_v] != verifier_token_ids[n_accept_v]:
+                        break
+                    n_accept_v += 1
+
+                # Yield accepted verifier tokens
+                for i in range(n_accept_v):
+                    ntoks += 1
+                    yield verifier_token_ids[i], logprobs_v[i], True
+                    if ntoks == max_tokens:
+                        break
+
+                # Yield the verifier's correction token
+                if ntoks < max_tokens:
+                    ntoks += 1
+                    yield tokens_list[n_accept_v], logprobs_v[n_accept_v], False
-    ntoks = 0
-    # Set these so the finally block doesn't raise
-    num_draft = 0
-    n = 0
-    try:
-        while True:
-            num_draft = min(max_tokens - ntoks, num_draft_tokens)
-            draft_tokens = _draft_generate(draft_y, num_draft)
-            if prev_tokens is not None:
-                prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
-            y = mx.concatenate([y, draft_tokens])
-            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
-            mx.eval(tokens, draft_tokens)
-            draft_tokens = draft_tokens.tolist()
-            tokens = tokens.tolist()
-            n = 0
-            while n < num_draft:
-                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
-                if tn != dtn:
-                    break
-                n += 1
-                ntoks += 1
-                yield tn, lpn, True
                 if ntoks == max_tokens:
                     break
-            if ntoks < max_tokens:
-                ntoks += 1
-                yield tokens[n], logprobs[n], False
-            if ntoks == max_tokens:
-                break
+                correction_token = tokens_list[n_accept_v]
-            y = mx.array([tokens[n]], mx.uint32)
-            draft_y = y
+                # Keep a sliding window of recent verifier tokens for prefix context
+                if translation_prefix_tokens > 0:
+                    verifier_context.extend(
+                        verifier_token_ids[:n_accept_v] + [correction_token]
+                    )
+                    verifier_context = verifier_context[-translation_prefix_tokens:]
+
+                # Rewind verifier cache: keep y + accepted positions
+                cache.trim_prompt_cache(model_cache, n_verifier - n_accept_v)
+
+                # Rewind draft cache: keep only the draft_y_old entry from
+                # this round so we can replay the accepted text
+                cache.trim_prompt_cache(draft_cache, num_draft - 1)
+                needs_cleanup = False
+
+                # Translate accepted + correction text back to draft tokens
+                # and feed them to the draft model so it stays in sync
+                all_new_verifier = verifier_token_ids[:n_accept_v] + [correction_token]
+                _, new_draft_tokens = _translate_to_draft(all_new_verifier)
+
+                if new_draft_tokens:
+                    if len(new_draft_tokens) > 1:
+                        with mx.stream(generation_stream):
+                            mx.async_eval(
+                                draft_model(
+                                    mx.array(
+                                        new_draft_tokens[:-1], mx.uint32
+                                    )[None],
+                                    cache=draft_cache,
+                                )
+                            )
+                    draft_y = mx.array(new_draft_tokens[-1:], mx.uint32)
+                else:
+                    fallback = draft_tokenizer.encode(
+                        " ", add_special_tokens=False
+                    )
+                    draft_y = mx.array(fallback[-1:], mx.uint32)
-            # If we accepted all the draft tokens, include the last
-            # draft token in the next draft step since it hasn't been
-            # processed yet by the draft model
-            if n == num_draft:
-                draft_y = mx.concatenate(
-                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
-                )
+                y = mx.array([correction_token], mx.uint32)
+
+                if prev_tokens is not None:
+                    prev_tokens = prev_tokens[: -max(n_verifier - n_accept_v, 1)]
+
+        finally:
+            if needs_cleanup:
+                cache.trim_prompt_cache(model_cache, n_verifier - n_accept_v)
+                cache.trim_prompt_cache(draft_cache, num_draft - 1)
-            if prev_tokens is not None:
-                prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
+    else:
+        # Original same-tokenizer speculative decoding
+        with mx.stream(generation_stream):
+            draft_y = _prefill(draft_model, draft_cache, y)
+            y = _prefill(model, model_cache, y)
+
+        ntoks = 0
+        # Set these so the finally block doesn't raise
+        num_draft = 0
+        n = 0
+        try:
+            while True:
+                num_draft = min(max_tokens - ntoks, num_draft_tokens)
+                draft_tokens = _draft_generate(draft_y, num_draft)
+                if prev_tokens is not None:
+                    prev_tokens = prev_tokens[
+                        : prev_tokens.size - y.size - num_draft + 1
+                    ]
+                y = mx.concatenate([y, draft_tokens])
+                tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
+                mx.eval(tokens, draft_tokens)
+                draft_tokens = draft_tokens.tolist()
+                tokens = tokens.tolist()
+
+                n = 0
+                while n < num_draft:
+                    tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
+                    if tn != dtn:
+                        break
+                    n += 1
+                    ntoks += 1
+                    yield tn, lpn, True
+                    if ntoks == max_tokens:
+                        break
+                if ntoks < max_tokens:
+                    ntoks += 1
+                    yield tokens[n], logprobs[n], False
+
+                if ntoks == max_tokens:
+                    break
+
+                y = mx.array([tokens[n]], mx.uint32)
+                draft_y = y
+
+                if n == num_draft:
+                    draft_y = mx.concatenate(
+                        [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
+                    )
+
+                if prev_tokens is not None:
+                    prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
+                _rewind_cache(num_draft, n)
+        finally:
             _rewind_cache(num_draft, n)
-    finally:
-        _rewind_cache(num_draft, n)


 def stream_generate(
@@ -649,6 +844,8 @@
     prompt: Union[str, mx.array, List[int]],
     max_tokens: int = 256,
    draft_model: Optional[nn.Module] = None,
+    draft_tokenizer: Optional[Union[PreTrainedTokenizer, TokenizerWrapper]] = None,
+    translation_prefix_tokens: int = 0,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -662,8 +859,14 @@
         prompt (Union[str, mx.array, List[int]]): The input prompt string or
           integer tokens.
         max_tokens (int): The maximum number of tokens to generate. Default: ``256``.
         draft_model (Optional[nn.Module]): An optional draft model. If provided
-          then speculative decoding is used. The draft model must use the same
-          tokenizer as the main model. Default: ``None``.
+          then speculative decoding is used. Default: ``None``.
+        draft_tokenizer: The draft model's tokenizer. When provided, enables
+          cross-tokenizer speculative decoding where draft tokens are decoded
+          to text and re-encoded with the verifier tokenizer before
+          verification. Default: ``None``.
+        translation_prefix_tokens (int): Number of previously generated verifier
+          tokens used as context when re-encoding draft text into verifier tokens.
+          Only active when ``draft_tokenizer`` is set. Default: ``0``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
@@ -698,7 +901,13 @@
         kwargs.pop("max_kv_size", None)
         kwargs.pop("prompt_progress_callback", None)
         token_generator = speculative_generate_step(
-            prompt, model, draft_model, **kwargs
+            prompt,
+            model,
+            draft_model,
+            tokenizer=tokenizer,
+            draft_tokenizer=draft_tokenizer,
+            translation_prefix_tokens=translation_prefix_tokens,
+            **kwargs,
         )
     with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
@@ -1500,10 +1709,9 @@

     if args.draft_model is not None:
         draft_model, draft_tokenizer = load(args.draft_model)
-        if draft_tokenizer.vocab_size != tokenizer.vocab_size:
-            raise ValueError("Draft model tokenizer does not match model tokenizer.")
     else:
         draft_model = None
+        draft_tokenizer = None
     sampler = make_sampler(
         args.temp,
         args.top_p,
@@ -1527,6 +1735,7 @@
         kv_group_size=args.kv_group_size,
         quantized_kv_start=args.quantized_kv_start,
         draft_model=draft_model,
+        draft_tokenizer=draft_tokenizer,
         num_draft_tokens=args.num_draft_tokens,
     )
     if not args.verbose:
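
Usage sketch: a minimal illustration of how the cross-tokenizer path added by this patch could be driven through stream_generate. The repo ids and the prompt are placeholders; load, stream_generate, and draft_model are the existing mlx_lm API, while draft_tokenizer and translation_prefix_tokens are the parameters introduced here, and the value 4 is only an illustrative choice.

    from mlx_lm import load, stream_generate

    # Verifier and draft models no longer need to share a tokenizer.
    model, tokenizer = load("<verifier-model-repo>")            # placeholder repo id
    draft_model, draft_tokenizer = load("<small-draft-model-repo>")  # placeholder repo id

    for response in stream_generate(
        model,
        tokenizer,
        "Explain speculative decoding in one paragraph.",
        max_tokens=128,
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,    # decode draft tokens to text, re-encode for the verifier
        translation_prefix_tokens=4,        # assumed small prefix; 0 disables the boundary context
    ):
        print(response.text, end="", flush=True)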