From db145daf763f55ac57ff05e883c6373beae5db7e Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Fri, 24 Oct 2025 14:55:07 -0400 Subject: [PATCH 01/13] refactor qwen2_vl vision add on --- mlx_engine/model_kit/model_kit.py | 2 +- .../process_prompt_with_images.py | 4 +- .../model_kit/vision_add_ons/qwen2_vl.py | 49 ++++++++++++++++--- .../vision_model_kit/vision_model_wrapper.py | 5 ++ 4 files changed, 52 insertions(+), 8 deletions(-) diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 08dfa877..c94c0f52 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -163,7 +163,7 @@ def process_prompt( speculative_decoding_toggle, prompt_progress_callback, ), None - ### WITH IMAGES PROMPT PROCESSING ###s + ### WITH IMAGES PROMPT PROCESSING ### if self.vision_add_on is None: raise ValueError( "Vision add-on is not loaded, but images were provided for processing" diff --git a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py index 0dde9705..2736d756 100644 --- a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py +++ b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py @@ -22,6 +22,7 @@ def common_process_prompt_with_images( processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], config, # expected to be a ModelConfig object as defined by mlx-vlm. Can vary by model max_size: tuple[int, int] | None, + should_pad: bool = True, ) -> ProcessedImagePrompt: """ Common prompt processing used by mlx-vlm vision add-ons. @@ -33,6 +34,7 @@ def common_process_prompt_with_images( processor: Tokenizer/processor for the model config: Model configuration object max_size: Maximum image size as (width, height) tuple. If None, no resizing. + should_pad: Whether to pad images to uniform size. Defaults to True. """ if len(images_b64) == 0: raise ValueError("Images must be non-empty") @@ -45,7 +47,7 @@ def common_process_prompt_with_images( logger.info(f"Prompt dump: {prompt}\n") images = convert_to_pil(images_b64) - images = custom_resize(images, max_size=max_size) + images = custom_resize(images, max_size=max_size, should_pad=should_pad) if hasattr(config, "image_token_index"): image_token_index = config.image_token_index diff --git a/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py b/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py index 124892b0..ad03945d 100644 --- a/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py +++ b/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py @@ -7,7 +7,9 @@ from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon -from mlx_engine.model_kit.vision_add_ons.qwen_vl_utils import compute_qwen_vl_embeddings +from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( + common_process_prompt_with_images, +) from mlx_vlm.models.qwen2_5_vl import ( VisionModel as Qwen25VLVisionTower, @@ -80,12 +82,47 @@ def compute_embeddings( """ Compute input_ids and embeddings for text with images. """ - - return compute_qwen_vl_embeddings( - addon=self, - text_model=text_model, + # Process prompt with images (cheap operation - tokenization and prepare_inputs) + # Note: Qwen models require should_pad=False for images + processed = common_process_prompt_with_images( prompt_tokens=prompt_tokens, images_b64=images_b64, - qwen_vl_version=2, + processor=self.processor, + config=self.config, max_size=max_size, + should_pad=False, + ) + + input_ids = processed.input_ids + pixel_values = processed.pixel_values + + # Get image_grid_thw from other_inputs if present + grid_thw = processed.other_inputs.get("image_grid_thw") + + # Get text embeddings + input_embeddings = text_model.language_model.model.embed_tokens(input_ids) + + # If no images, return input_ids and input_embeddings + if pixel_values is None: + return input_ids.squeeze(0), input_embeddings.squeeze(0) + + # Ensure pixel values are in the right format for vision tower + if pixel_values.dtype != input_embeddings.dtype: + pixel_values = pixel_values.astype(input_embeddings.dtype) + + # Process image through vision tower (expensive operation) + hidden_states = self.vision_tower( + pixel_values, grid_thw, output_hidden_states=False ) + + # Merge image features with text embeddings (expensive operation) + final_inputs_embeds = self.model_cls.merge_input_ids_with_image_features( + self.config.image_token_id, + self.config.video_token_id, + hidden_states, + input_embeddings, + input_ids, + ) + + # Remove batch dimension + return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) diff --git a/mlx_engine/vision_model_kit/vision_model_wrapper.py b/mlx_engine/vision_model_kit/vision_model_wrapper.py index e71b41c9..4f273b62 100644 --- a/mlx_engine/vision_model_kit/vision_model_wrapper.py +++ b/mlx_engine/vision_model_kit/vision_model_wrapper.py @@ -111,6 +111,11 @@ def __call__(self, *args, input_embeddings=None, **kwargs): "decoder_input_ids": self.decoder_input_ids, "encoder_outputs": outputs.encoder_outputs, } + # elif self.vision_model.config.model_type == "qwen3_vl": + # self.language_model_kwargs = { + # "visual_pos_masks": outputs.visual_pos_masks, + # "deepstack_visual_embeds": outputs.deepstack_visual_embeds, + # } # Add the cache we created here to the language model kwargs self.language_model_kwargs["cache"] = cache From 2ce6b9d77f861ce893e2618df1776e35989ee54a Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Fri, 24 Oct 2025 17:42:37 -0400 Subject: [PATCH 02/13] working p1 --- mlx_engine/cache_wrapper.py | 66 ++++++++++++++++++++++ mlx_engine/model_kit/model_kit.py | 91 +++++++++++++++++++++++++++++-- tests/test_vision_cache.py | 79 +++++++++++++++++++++++++++ 3 files changed, 232 insertions(+), 4 deletions(-) create mode 100644 tests/test_vision_cache.py diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 96673954..a0ece7e9 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -59,6 +59,10 @@ def __init__( ) self.chunk_size = chunk_size + # Vision prompt caching state + self.prev_images_hash: Optional[str] = None + self.prev_raw_prompt_tokens: Optional[List[int]] = None + def _get_num_tokens_in_cache(self) -> int | None: """ Get the number of tokens in the cache. @@ -138,6 +142,7 @@ def _get_unprocessed_tokens( ) self.cache = make_prompt_cache(self.model, self.max_kv_size) self.tokens = prompt_tokens + self.clear_vision_cache() return self.tokens num_tokens_to_trim = num_tokens_in_cache - common_prefix if num_tokens_to_trim > 0: @@ -147,6 +152,7 @@ def _get_unprocessed_tokens( ) self.cache = make_prompt_cache(self.model, self.max_kv_size) self.tokens = prompt_tokens + self.clear_vision_cache() return self.tokens tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) if tokens_trimmed != num_tokens_to_trim: @@ -156,6 +162,7 @@ def _get_unprocessed_tokens( ) self.cache = make_prompt_cache(self.model, self.max_kv_size) self.tokens = prompt_tokens + self.clear_vision_cache() return self.tokens logger.info(f"Trimmed {num_tokens_to_trim} tokens from the prompt cache") @@ -223,6 +230,7 @@ def _prefill( if num_tokens_in_cache is None: self.cache = make_prompt_cache(self.model, self.max_kv_size) self.tokens = None + self.clear_vision_cache() else: # Remember which tokens were processed so far, so that we can continue processing at a later point self.tokens = self.tokens[:num_tokens_in_cache] @@ -255,6 +263,7 @@ def set_draft_model(self, draft_model: nn.Module): # https://github.com/ml-explore/mlx-examples/blob/514502da22f0dc4c1ac439bdf78c07d5ec41acf7/llms/mlx_lm/utils.py#L381-L382 logger.info("Clearing current prompt cache and adding draft model to the cache") self.tokens = None + self.clear_vision_cache() self.cache: List[Any] = make_prompt_cache(self.model) if draft_model is not None: self.cache += make_prompt_cache(draft_model) @@ -332,3 +341,60 @@ def record_generated_token(self, token): Add the generated token to the token list, so that we can map the token to the KV cache. """ self.tokens = mx.concat([self.tokens, mx.array([token])]) + + def _compute_images_hash(self, images_b64: List[str]) -> str: + """Compute hash of images for cache validation.""" + import hashlib + + combined = "".join(images_b64) + return hashlib.sha256(combined.encode()).hexdigest() + + def can_reuse_vision_cache( + self, images_b64: List[str], raw_prompt_tokens: List[int] + ) -> bool: + """ + Check if we can skip expensive vision processing and reuse cached KV states. + + Args: + images_b64: Current request's base64-encoded images + raw_prompt_tokens: Current request's raw prompt tokens (before vision processing) + + Returns: + bool: True if we can skip vision processing, False otherwise + """ + if self.prev_images_hash is None or self.prev_raw_prompt_tokens is None: + return False + + # Check if images are identical + current_images_hash = self._compute_images_hash(images_b64) + if current_images_hash != self.prev_images_hash: + return False + + # Check if current prompt extends previous prompt + if len(raw_prompt_tokens) <= len(self.prev_raw_prompt_tokens): + return False + + # Check if prefix matches exactly + if ( + raw_prompt_tokens[: len(self.prev_raw_prompt_tokens)] + != self.prev_raw_prompt_tokens + ): + return False + + return True + + def record_vision_state(self, images_b64: List[str], raw_prompt_tokens: List[int]): + """ + Record vision processing state for future cache validation. + + Args: + images_b64: Base64-encoded images that were processed + raw_prompt_tokens: Raw prompt tokens (before vision processing) that were used + """ + self.prev_images_hash = self._compute_images_hash(images_b64) + self.prev_raw_prompt_tokens = raw_prompt_tokens + + def clear_vision_cache(self): + """Clear vision-specific cache state.""" + self.prev_images_hash = None + self.prev_raw_prompt_tokens = None diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index c94c0f52..12127836 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -136,6 +136,37 @@ def tokenize(self, prompt: str) -> List[int]: return [ids] return ids + def _get_input_ids_via_prepare_inputs( + self, + prompt_tokens: mx.array, + images_b64: List[str], + max_image_size: tuple[int, int] | None, + ) -> mx.array: + """ + Get input_ids with image tokens inserted (cheap operation). + Calls common_process_prompt_with_images but skips expensive vision processing. + """ + from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( + common_process_prompt_with_images, + ) + + # Determine should_pad based on model type + # Qwen models need should_pad=False, others use True (default) + should_pad = self.model_type not in ["qwen2_vl", "qwen2_5_vl"] + + processed = common_process_prompt_with_images( + prompt_tokens=prompt_tokens, + images_b64=images_b64, + processor=self.vision_add_on.processor, + config=self.vision_add_on.config, + max_size=max_image_size, + should_pad=should_pad, + ) + + # Return input_ids, squeeze batch dimension if present + input_ids = processed.input_ids + return input_ids.squeeze(0) if input_ids.ndim > 1 else input_ids + def process_prompt( self, prompt_tokens, @@ -168,11 +199,63 @@ def process_prompt( raise ValueError( "Vision add-on is not loaded, but images were provided for processing" ) - self._cross_prompt_cache_active = False - input_ids, embeddings = self.vision_add_on.compute_embeddings( - self.model, prompt_tokens, images_b64, max_size=max_image_size + + # Convert prompt_tokens to list for cache validation + prompt_tokens_list = ( + prompt_tokens if isinstance(prompt_tokens, list) else prompt_tokens.tolist() + ) + + # Check if we can skip expensive vision processing + can_skip_vision_processing = self.cache_wrapper.can_reuse_vision_cache( + images_b64, prompt_tokens_list ) - return input_ids, embeddings + + if can_skip_vision_processing: + # CHEAP PATH: Skip vision tower, reuse cached KV states + logger.info("Reusing cached vision features from previous request") + + # Get input_ids with image tokens (cheap - just tokenization + prepare_inputs) + input_ids = self._get_input_ids_via_prepare_inputs( + prompt_tokens, images_b64, max_image_size + ) + + # Enable caching so generated tokens are recorded + self._cross_prompt_cache_active = True + + # Use cache wrapper to find common prefix and return unprocessed tokens + unprocessed_tokens = self.cache_wrapper.update_cache( + input_ids, + prompt_progress_callback, + ) + generate_args["prompt_cache"] = self.cache_wrapper.cache + + # Update vision state for next request + self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + + # Return tokens only, no embeddings (model will use text embeddings for new tokens) + return unprocessed_tokens, None + else: + # EXPENSIVE PATH: Full vision processing (first request or images changed) + logger.info("Performing full vision processing with images") + + input_ids, embeddings = self.vision_add_on.compute_embeddings( + self.model, prompt_tokens, images_b64, max_size=max_image_size + ) + + # Enable caching - we want generated tokens recorded for future requests + self._cross_prompt_cache_active = True + + # Record vision state for future requests + self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + + # Initialize cache tracking with the processed input_ids + # This is critical - tells cache_wrapper what tokens are being processed + self.cache_wrapper.tokens = input_ids + + # Set prompt_cache for generation (fixes missing cache usage in vision path!) + generate_args["prompt_cache"] = self.cache_wrapper.cache + + return input_ids, embeddings def is_cross_prompt_cache_active(self) -> bool: """ diff --git a/tests/test_vision_cache.py b/tests/test_vision_cache.py new file mode 100644 index 00000000..a920c54a --- /dev/null +++ b/tests/test_vision_cache.py @@ -0,0 +1,79 @@ +import unittest +import base64 +from pathlib import Path +from mlx_engine.generate import ( + load_model, + tokenize, + create_generator, +) +from tests.shared import model_getter + + +MAX_IMAGE_SIZE = (1024, 1024) + + +class TestVisionCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + cls.description_prompt = "What is this" + cls.text_only_prompt = "What is a toucan?" + cls.test_data_dir = Path(__file__).parent / "data" + cls.demo_data_dir = Path(__file__).parent.parent / "demo-data" + + # Read and encode test images + cls.toucan_path = Path(__file__).parent.parent / "demo-data" / "toucan.jpeg" + with open(cls.toucan_path, "rb") as image_file: + cls.toucan_image_b64 = base64.b64encode(image_file.read()).decode("utf-8") + cls.chameleon_image_path = ( + Path(__file__).parent.parent / "demo-data" / "chameleon.webp" + ) + with open(cls.chameleon_image_path, "rb") as chameleon_image_file: + cls.chameleon_image_b64 = base64.b64encode( + chameleon_image_file.read() + ).decode("utf-8") + + ### MODEL-SPECIFIC TESTS ### + def test_gemma3n(self): + """Test LFM2-VL 450M model""" + prompt = "user\nIn one word each, describe the images\nmodel\n" + model_name = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit" + print(f"Testing model {model_name}") + model_path = model_getter(model_name=model_name) + + # Load the model + model_kit = load_model( + model_path=model_path, max_kv_size=2048, trust_remote_code=True + ) + + def generate_text(prompt): + # Tokenize the prompt + prompt_tokens = tokenize(model_kit, prompt) + + # Generate description + generated_text = "" + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + images_b64=[self.toucan_image_b64, self.chameleon_image_b64], + max_image_size=MAX_IMAGE_SIZE, + seed=0, + max_tokens=100, + temp=0.0, + repetition_penalty=1.01, # enable the logits processor code path + ): + generated_text += result.text + print(result.text, end="", flush=True) + if result.stop_condition: + break + print() + return generated_text + + generated_text = generate_text(prompt) + print("--") + prompt = ( + prompt + + generated_text + + "\nuser\nwhich direction is each animal facing?\nmodel\n" + ) + generated_text = generate_text(prompt) From 6a04c0e7e54e8d550e8da995a48ba8d1f6cc6707 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Mon, 27 Oct 2025 12:03:06 -0400 Subject: [PATCH 03/13] use text-only hook --- demo.py | 14 ++++++++++---- mlx_engine/model_kit/model_kit.py | 11 +++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/demo.py b/demo.py index dfcc3fd0..d48adbcb 100644 --- a/demo.py +++ b/demo.py @@ -92,6 +92,9 @@ def setup_arg_parser(): action="store_true", help="Enable printed prompt processing progress callback", ) + parser.add_argument( + "--max-img-size", type=int, help="Downscale images to this side length (px)" + ) return parser @@ -203,22 +206,22 @@ def prompt_progress_callback(percent): tf_tokenizer = AutoProcessor.from_pretrained(model_path) images_base64 = [image_to_base64(img_path) for img_path in args.images] conversation = [ - DEFAULT_SYSTEM_PROMPT, + # DEFAULT_SYSTEM_PROMPT, { "role": "user", "content": [ - {"type": "text", "text": prompt}, *[ {"type": "image", "base64": image_b64} for image_b64 in images_base64 ], + {"type": "text", "text": prompt}, ], }, ] else: tf_tokenizer = AutoTokenizer.from_pretrained(model_path) conversation = [ - DEFAULT_SYSTEM_PROMPT, + # DEFAULT_SYSTEM_PROMPT, {"role": "user", "content": prompt}, ] prompt = tf_tokenizer.apply_chat_template( @@ -232,12 +235,15 @@ def prompt_progress_callback(percent): # Initialize generation stats collector stats_collector = GenerationStatsCollector() + # Clamp image size + max_img_size = (args.max_img_size, args.max_img_size) if args.max_img_size else None + # Generate the response generator = create_generator( model_kit, prompt_tokens, images_b64=images_base64, - max_image_size=(1024, 1024), + max_image_size=max_img_size, stop_strings=args.stop_strings, max_tokens=1024, top_logprobs=args.top_logprobs, diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 12127836..fcfe34d4 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -222,12 +222,15 @@ def process_prompt( # Enable caching so generated tokens are recorded self._cross_prompt_cache_active = True - # Use cache wrapper to find common prefix and return unprocessed tokens - unprocessed_tokens = self.cache_wrapper.update_cache( + # Process like text-only: use cache wrapper to preprocess new tokens + unprocessed_tokens = process_prompt_text_only( input_ids, - prompt_progress_callback, + self.cache_wrapper, + generate_args, + draft_model=None, # Vision models don't support draft models + speculative_decoding_toggle=None, + prompt_progress_callback=prompt_progress_callback, ) - generate_args["prompt_cache"] = self.cache_wrapper.cache # Update vision state for next request self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) From c9777fe146a641b261d8b0a9e4b7025853b6338d Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 28 Oct 2025 14:45:52 -0400 Subject: [PATCH 04/13] working test --- mlx_engine/model_kit/model_kit.py | 7 +++---- tests/test_vision_cache.py | 31 +++++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index fcfe34d4..fdf3deac 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -16,6 +16,9 @@ from mlx_engine.model_kit.vision_add_ons.qwen2_vl import Qwen2_VLVisionAddOn from mlx_engine.utils.kv_cache_quantization import get_kv_cache_quantization_params from mlx_engine.utils.prompt_processing import process_prompt_text_only +from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( + common_process_prompt_with_images, +) logger = logging.getLogger(__name__) @@ -146,10 +149,6 @@ def _get_input_ids_via_prepare_inputs( Get input_ids with image tokens inserted (cheap operation). Calls common_process_prompt_with_images but skips expensive vision processing. """ - from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( - common_process_prompt_with_images, - ) - # Determine should_pad based on model type # Qwen models need should_pad=False, others use True (default) should_pad = self.model_type not in ["qwen2_vl", "qwen2_5_vl"] diff --git a/tests/test_vision_cache.py b/tests/test_vision_cache.py index a920c54a..ba4a103a 100644 --- a/tests/test_vision_cache.py +++ b/tests/test_vision_cache.py @@ -16,11 +16,6 @@ class TestVisionCache(unittest.TestCase): @classmethod def setUpClass(cls): """Set up test resources that will be shared across all test methods""" - cls.description_prompt = "What is this" - cls.text_only_prompt = "What is a toucan?" - cls.test_data_dir = Path(__file__).parent / "data" - cls.demo_data_dir = Path(__file__).parent.parent / "demo-data" - # Read and encode test images cls.toucan_path = Path(__file__).parent.parent / "demo-data" / "toucan.jpeg" with open(cls.toucan_path, "rb") as image_file: @@ -33,7 +28,6 @@ def setUpClass(cls): chameleon_image_file.read() ).decode("utf-8") - ### MODEL-SPECIFIC TESTS ### def test_gemma3n(self): """Test LFM2-VL 450M model""" prompt = "user\nIn one word each, describe the images\nmodel\n" @@ -46,7 +40,14 @@ def test_gemma3n(self): model_path=model_path, max_kv_size=2048, trust_remote_code=True ) - def generate_text(prompt): + callback_history = [] + + def prompt_callback(x): + nonlocal callback_history + callback_history.append(x) + return True + + def generate_text(prompt, images_b64=None): # Tokenize the prompt prompt_tokens = tokenize(model_kit, prompt) @@ -55,6 +56,7 @@ def generate_text(prompt): for result in create_generator( model_kit=model_kit, prompt_tokens=prompt_tokens, + prompt_progress_callback=prompt_callback, images_b64=[self.toucan_image_b64, self.chameleon_image_b64], max_image_size=MAX_IMAGE_SIZE, seed=0, @@ -70,10 +72,23 @@ def generate_text(prompt): return generated_text generated_text = generate_text(prompt) + self.assertEqual(len(callback_history), 4) # prompt processing by mlx-lm + callback_history = [] + + # ask a followup question print("--") prompt = ( prompt + generated_text - + "\nuser\nwhich direction is each animal facing?\nmodel\n" + + "\nuser\nwhich direction is each animal facing?\n" + + "also, remember this information: " + + ", ".join([str(x) for x in range(200)]) + + "\n" + + "model\n" ) generated_text = generate_text(prompt) + + # prompt processing by cache_wrapper. less work is done since the images are cached + self.assertEqual(len(callback_history), 3) + self.assertRegex(generated_text, "toucan.*left") + self.assertRegex(generated_text, "chameleon.*right") From cb2dba7b4227c6c3a71871168422035a7bbd8c94 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 28 Oct 2025 15:04:29 -0400 Subject: [PATCH 05/13] simplify --- mlx_engine/model_kit/model_kit.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index fdf3deac..7b7f8a05 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -54,7 +54,6 @@ class ModelKit: tokenizer: TokenizerWrapper = None detokenizer: StreamingDetokenizer = None cache_wrapper: Optional[CacheWrapper] = None - _cross_prompt_cache_active: bool = False max_kv_size: Optional[int] = None kv_bits: Optional[int] = None kv_group_size: Optional[int] = None @@ -178,7 +177,6 @@ def process_prompt( ### TEXT-ONLY PROCESS_PROMPT ### is_text_only_processing = images_b64 is None or len(images_b64) == 0 if is_text_only_processing: - self._cross_prompt_cache_active = True if len(prompt_tokens) == 0: logger.warning( "Received empty prompt. Generation quality will likely be poor" @@ -218,9 +216,6 @@ def process_prompt( prompt_tokens, images_b64, max_image_size ) - # Enable caching so generated tokens are recorded - self._cross_prompt_cache_active = True - # Process like text-only: use cache wrapper to preprocess new tokens unprocessed_tokens = process_prompt_text_only( input_ids, @@ -244,9 +239,6 @@ def process_prompt( self.model, prompt_tokens, images_b64, max_size=max_image_size ) - # Enable caching - we want generated tokens recorded for future requests - self._cross_prompt_cache_active = True - # Record vision state for future requests self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) @@ -263,8 +255,11 @@ def is_cross_prompt_cache_active(self) -> bool: """ Check if cross-prompt caching is currently enabled. Can be overridden by subclasses for custom behavior. + + ModelKit always supports cross-prompt caching. + VisionModelKit overrides this to return False. """ - return self._cross_prompt_cache_active + return True def record_token_to_cache(self, token: int) -> None: self.cache_wrapper.record_generated_token(token) From 1ae47321059103ad9db9263cccf70b7cf98580e3 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 28 Oct 2025 16:12:57 -0400 Subject: [PATCH 06/13] test non-swa model --- mlx_engine/cache_wrapper.py | 26 ++++++++++------ tests/test_vision_cache.py | 61 ++++++++++++++++++++++++------------- 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index a0ece7e9..9cf288bc 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -355,6 +355,9 @@ def can_reuse_vision_cache( """ Check if we can skip expensive vision processing and reuse cached KV states. + Supports both extending (longer prompt) and rewinding (shorter prompt) + as long as one prompt is a prefix of the other. + Args: images_b64: Current request's base64-encoded images raw_prompt_tokens: Current request's raw prompt tokens (before vision processing) @@ -370,18 +373,21 @@ def can_reuse_vision_cache( if current_images_hash != self.prev_images_hash: return False - # Check if current prompt extends previous prompt - if len(raw_prompt_tokens) <= len(self.prev_raw_prompt_tokens): - return False + # Use existing _find_common_prefix to check if one prompt is a prefix of the other + current_tokens = mx.array(raw_prompt_tokens) + prev_tokens = mx.array(self.prev_raw_prompt_tokens) - # Check if prefix matches exactly - if ( - raw_prompt_tokens[: len(self.prev_raw_prompt_tokens)] - != self.prev_raw_prompt_tokens - ): - return False + # Find common prefix length (num_tokens_to_exclude=0 since we don't need that constraint) + common_length = self._find_common_prefix( + current_tokens=prev_tokens, + prompt_tokens=current_tokens, + num_tokens_to_exclude=0, + ) - return True + # Can reuse if one prompt is a complete prefix of the other + # (common_length equals the length of the shorter prompt) + min_length = min(len(raw_prompt_tokens), len(self.prev_raw_prompt_tokens)) + return common_length == min_length def record_vision_state(self, images_b64: List[str], raw_prompt_tokens: List[int]): """ diff --git a/tests/test_vision_cache.py b/tests/test_vision_cache.py index ba4a103a..9eb162bd 100644 --- a/tests/test_vision_cache.py +++ b/tests/test_vision_cache.py @@ -7,7 +7,7 @@ create_generator, ) from tests.shared import model_getter - +import pytest MAX_IMAGE_SIZE = (1024, 1024) @@ -28,12 +28,15 @@ def setUpClass(cls): chameleon_image_file.read() ).decode("utf-8") - def test_gemma3n(self): - """Test LFM2-VL 450M model""" - prompt = "user\nIn one word each, describe the images\nmodel\n" - model_name = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit" - print(f"Testing model {model_name}") + @pytest.mark.heavy + def test_nonswa_model(self): + """ + Test that image caching works for models without a SWA cache + """ + prompt = "[INST][IMG][IMG]In one word each, describe the animal in the images[/INST]\n" + model_name = "lmstudio-community/Mistral-Small-3.2-24B-Instruct-2506-MLX-4bit" model_path = model_getter(model_name=model_name) + images_b64 = [self.toucan_image_b64, self.chameleon_image_b64] # Load the model model_kit = load_model( @@ -47,7 +50,10 @@ def prompt_callback(x): callback_history.append(x) return True - def generate_text(prompt, images_b64=None): + def generate_text(prompt): + nonlocal callback_history + callback_history = [] + # Tokenize the prompt prompt_tokens = tokenize(model_kit, prompt) @@ -57,7 +63,7 @@ def generate_text(prompt, images_b64=None): model_kit=model_kit, prompt_tokens=prompt_tokens, prompt_progress_callback=prompt_callback, - images_b64=[self.toucan_image_b64, self.chameleon_image_b64], + images_b64=images_b64, max_image_size=MAX_IMAGE_SIZE, seed=0, max_tokens=100, @@ -73,22 +79,35 @@ def generate_text(prompt, images_b64=None): generated_text = generate_text(prompt) self.assertEqual(len(callback_history), 4) # prompt processing by mlx-lm - callback_history = [] # ask a followup question print("--") - prompt = ( - prompt - + generated_text - + "\nuser\nwhich direction is each animal facing?\n" - + "also, remember this information: " - + ", ".join([str(x) for x in range(200)]) - + "\n" - + "model\n" + prompt2 = ( + prompt + generated_text + "[INST]what color is each animal's eyes?[/INST]" ) - generated_text = generate_text(prompt) + generated_text2 = generate_text(prompt2) # prompt processing by cache_wrapper. less work is done since the images are cached - self.assertEqual(len(callback_history), 3) - self.assertRegex(generated_text, "toucan.*left") - self.assertRegex(generated_text, "chameleon.*right") + self.assertEqual(len(callback_history), 2) + self.assertRegex(generated_text2, "toucan.*dark") + self.assertRegex(generated_text2, "chameleon.*orange") + + # rewind the cache but swap the images. full preprocessing happens + images_b64.reverse() + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 4) + + # rewind the cache and re-prompt; images are not processed + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 2) + + # add an image in the followup; all three images are re-processed + images_b64.append(self.chameleon_image_b64) + prompt3 = ( + prompt + + generated_text + + "[INST][IMG]do you see two toucans or chameleons?[/INST]" + ) + generated_text3 = generate_text(prompt3) + self.assertEqual(len(callback_history), 5) + self.assertRegex(generated_text3, "chameleon") From 47e1d8a52d181f310a1799acc7082eb598ed0d11 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 28 Oct 2025 16:51:43 -0400 Subject: [PATCH 07/13] checkpoint --- mlx_engine/cache_wrapper.py | 39 +++++++++++--- mlx_engine/model_kit/model_kit.py | 72 +++++++++++++------------ tests/test_vision_cache.py | 88 ++++++++++++++++++++++++++++++- 3 files changed, 158 insertions(+), 41 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 9cf288bc..6f30ebed 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -22,6 +22,13 @@ class StopPromptProcessing(Exception): """ +class CacheNotTrimmableError(Exception): + """ + Exception to signal that cache trimming is required but the cache is not trimmable. + Used in vision mode to signal that full reprocessing with vision add-on is needed. + """ + + class CacheWrapper: """ Wrapper class for the MLX LM cache to maintain an in-memory cache @@ -112,6 +119,30 @@ def _find_common_prefix( common_length = max(common_length - length_adjustment, 0) return common_length + def _handle_nontrimmable_cache( + self, num_tokens_to_trim: int, prompt_tokens: mx.array + ): + # Check if we've cached images + if self.prev_images_hash is not None: + logger.warning( + "Cache is not trimmable and vision processing is active. " + "Signaling need for full reprocessing." + ) + self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.tokens = None + # Don't clear vision cache - let caller handle full reprocessing + raise CacheNotTrimmableError() + else: + # Non-vision mode + logger.warning( + f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, " + f"but could not: Cache is not trimmable. Clearing the cache instead." + ) + self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.tokens = prompt_tokens + self.clear_vision_cache() + return self.tokens + def _get_unprocessed_tokens( self, prompt_tokens: mx.array, num_tokens_to_exclude: int ): @@ -147,13 +178,9 @@ def _get_unprocessed_tokens( num_tokens_to_trim = num_tokens_in_cache - common_prefix if num_tokens_to_trim > 0: if not can_trim_prompt_cache(self.cache): - logger.warning( - f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: Cache is not trimmable. Clearing the cache instead." + return self._handle_nontrimmable_cache( + num_tokens_to_trim, prompt_tokens ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - self.clear_vision_cache() - return self.tokens tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) if tokens_trimmed != num_tokens_to_trim: # If we trimmed fewer tokens than expected, the cache is invalid diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 7b7f8a05..7d948f1b 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -2,7 +2,7 @@ from typing import Callable, Optional, List, Tuple import mlx_lm from mlx_lm.tokenizer_utils import TokenizerWrapper, StreamingDetokenizer -from mlx_engine.cache_wrapper import CacheWrapper +from mlx_engine.cache_wrapper import CacheWrapper, CacheNotTrimmableError from pathlib import Path import mlx.nn as nn import mlx.core as mx @@ -208,48 +208,52 @@ def process_prompt( ) if can_skip_vision_processing: - # CHEAP PATH: Skip vision tower, reuse cached KV states - logger.info("Reusing cached vision features from previous request") + try: + # CHEAP PATH: Skip vision tower, reuse cached KV states - # Get input_ids with image tokens (cheap - just tokenization + prepare_inputs) - input_ids = self._get_input_ids_via_prepare_inputs( - prompt_tokens, images_b64, max_image_size - ) + # Get input_ids with image tokens + input_ids = self._get_input_ids_via_prepare_inputs( + prompt_tokens, images_b64, max_image_size + ) - # Process like text-only: use cache wrapper to preprocess new tokens - unprocessed_tokens = process_prompt_text_only( - input_ids, - self.cache_wrapper, - generate_args, - draft_model=None, # Vision models don't support draft models - speculative_decoding_toggle=None, - prompt_progress_callback=prompt_progress_callback, - ) + # Process like text-only: use cache wrapper to preprocess new tokens + unprocessed_tokens = process_prompt_text_only( + input_ids, + self.cache_wrapper, + generate_args, + draft_model=None, # Vision models don't support draft models + speculative_decoding_toggle=None, + prompt_progress_callback=prompt_progress_callback, + ) - # Update vision state for next request - self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + # Update vision state for next request + self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) - # Return tokens only, no embeddings (model will use text embeddings for new tokens) - return unprocessed_tokens, None - else: - # EXPENSIVE PATH: Full vision processing (first request or images changed) - logger.info("Performing full vision processing with images") + # Return tokens only, no embeddings (model will use text embeddings for new tokens) + return unprocessed_tokens, None - input_ids, embeddings = self.vision_add_on.compute_embeddings( - self.model, prompt_tokens, images_b64, max_size=max_image_size - ) + except CacheNotTrimmableError: + pass + # Fall through to expensive path below + + # EXPENSIVE PATH: Full vision processing (first request or images changed) + logger.info("Performing full vision processing with images") + + input_ids, embeddings = self.vision_add_on.compute_embeddings( + self.model, prompt_tokens, images_b64, max_size=max_image_size + ) - # Record vision state for future requests - self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + # Record vision state for future requests + self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) - # Initialize cache tracking with the processed input_ids - # This is critical - tells cache_wrapper what tokens are being processed - self.cache_wrapper.tokens = input_ids + # Initialize cache tracking with the processed input_ids + # This is critical - tells cache_wrapper what tokens are being processed + self.cache_wrapper.tokens = input_ids - # Set prompt_cache for generation (fixes missing cache usage in vision path!) - generate_args["prompt_cache"] = self.cache_wrapper.cache + # Set prompt_cache for generation (fixes missing cache usage in vision path!) + generate_args["prompt_cache"] = self.cache_wrapper.cache - return input_ids, embeddings + return input_ids, embeddings def is_cross_prompt_cache_active(self) -> bool: """ diff --git a/tests/test_vision_cache.py b/tests/test_vision_cache.py index 9eb162bd..1d2a7ab3 100644 --- a/tests/test_vision_cache.py +++ b/tests/test_vision_cache.py @@ -40,7 +40,7 @@ def test_nonswa_model(self): # Load the model model_kit = load_model( - model_path=model_path, max_kv_size=2048, trust_remote_code=True + model_path=model_path, max_kv_size=4096, trust_remote_code=True ) callback_history = [] @@ -111,3 +111,89 @@ def generate_text(prompt): generated_text3 = generate_text(prompt3) self.assertEqual(len(callback_history), 5) self.assertRegex(generated_text3, "chameleon") + + def test_swa_model(self): + """ + Test that image caching works for models with a SWA cache + """ + prompt = "user\nIn one word each, describe the images\nmodel\n" + model_name = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit" + model_path = model_getter(model_name=model_name) + images_b64 = [self.toucan_image_b64, self.chameleon_image_b64] + + # Load the model + model_kit = load_model( + model_path=model_path, max_kv_size=4096, trust_remote_code=True + ) + + callback_history = [] + + def prompt_callback(x): + nonlocal callback_history + callback_history.append(x) + return True + + def generate_text(prompt): + nonlocal callback_history + callback_history = [] + + # Tokenize the prompt + prompt_tokens = tokenize(model_kit, prompt) + + # Generate description + generated_text = "" + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + prompt_progress_callback=prompt_callback, + images_b64=images_b64, + max_image_size=MAX_IMAGE_SIZE, + seed=0, + max_tokens=100, + temp=0.0, + repetition_penalty=1.01, # enable the logits processor code path + ): + generated_text += result.text + print(result.text, end="", flush=True) + if result.stop_condition: + break + print() + return generated_text + + generated_text = generate_text(prompt) + self.assertEqual(len(callback_history), 4) # prompt processing by mlx-lm + + # ask a followup question + print("--") + prompt2 = ( + prompt + + generated_text + + "\nuser\nwhich direction is each animal facing?\n" + + "model\n" + ) + generated_text2 = generate_text(prompt2) + + # prompt processing by cache_wrapper. less work is done since the images are cached + self.assertEqual(len(callback_history), 2) + self.assertRegex(generated_text2, "toucan.*left") + self.assertRegex(generated_text2, "chameleon.*right") + + # swap the images; full preprocessing happens + images_b64.reverse() + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 4) + + # attempt cache rewind; images fully processed since the cache can't be trimmed + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 4) + + # add an image in the followup; all three images are re-processed + images_b64.append(self.chameleon_image_b64) + prompt3 = ( + prompt + + generated_text + + "user\nhow many animals do you count?\nmodel" + ) + generated_text3 = generate_text(prompt3) + self.assertEqual(len(callback_history), 4) + self.assertRegex(generated_text3, "chameleon") From c05c2416bba386ff1261e1627e492b8a4efc6fc9 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 28 Oct 2025 17:17:08 -0400 Subject: [PATCH 08/13] checkpoint --- mlx_engine/cache_wrapper.py | 68 ++++++++++++++----------------- mlx_engine/model_kit/model_kit.py | 47 +++++++++------------ 2 files changed, 51 insertions(+), 64 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 6f30ebed..0d933f07 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -22,13 +22,6 @@ class StopPromptProcessing(Exception): """ -class CacheNotTrimmableError(Exception): - """ - Exception to signal that cache trimming is required but the cache is not trimmable. - Used in vision mode to signal that full reprocessing with vision add-on is needed. - """ - - class CacheWrapper: """ Wrapper class for the MLX LM cache to maintain an in-memory cache @@ -119,30 +112,6 @@ def _find_common_prefix( common_length = max(common_length - length_adjustment, 0) return common_length - def _handle_nontrimmable_cache( - self, num_tokens_to_trim: int, prompt_tokens: mx.array - ): - # Check if we've cached images - if self.prev_images_hash is not None: - logger.warning( - "Cache is not trimmable and vision processing is active. " - "Signaling need for full reprocessing." - ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = None - # Don't clear vision cache - let caller handle full reprocessing - raise CacheNotTrimmableError() - else: - # Non-vision mode - logger.warning( - f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, " - f"but could not: Cache is not trimmable. Clearing the cache instead." - ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - self.clear_vision_cache() - return self.tokens - def _get_unprocessed_tokens( self, prompt_tokens: mx.array, num_tokens_to_exclude: int ): @@ -178,9 +147,13 @@ def _get_unprocessed_tokens( num_tokens_to_trim = num_tokens_in_cache - common_prefix if num_tokens_to_trim > 0: if not can_trim_prompt_cache(self.cache): - return self._handle_nontrimmable_cache( - num_tokens_to_trim, prompt_tokens + logger.warning( + f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: Cache is not trimmable. Clearing the cache instead." ) + self.cache = make_prompt_cache(self.model, self.max_kv_size) + self.tokens = prompt_tokens + self.clear_vision_cache() + return self.tokens tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) if tokens_trimmed != num_tokens_to_trim: # If we trimmed fewer tokens than expected, the cache is invalid @@ -383,7 +356,7 @@ def can_reuse_vision_cache( Check if we can skip expensive vision processing and reuse cached KV states. Supports both extending (longer prompt) and rewinding (shorter prompt) - as long as one prompt is a prefix of the other. + as long as one prompt is a prefix of the other AND cache operations are supported. Args: images_b64: Current request's base64-encoded images @@ -411,10 +384,31 @@ def can_reuse_vision_cache( num_tokens_to_exclude=0, ) - # Can reuse if one prompt is a complete prefix of the other - # (common_length equals the length of the shorter prompt) + # Check if one prompt is a complete prefix of the other min_length = min(len(raw_prompt_tokens), len(self.prev_raw_prompt_tokens)) - return common_length == min_length + if common_length != min_length: + return False # Not a prefix relationship + + # Check if cache supports required operations + # For non-trimmable caches (like SWA), we can only reuse vision cache when extending + # (adding new tokens). If we're re-prompting with same/shorter prompt, the cache + # will have generated tokens that need trimming, which non-trimmable caches can't do. + if not can_trim_prompt_cache(self.cache): + num_tokens_in_cache = self._get_num_tokens_in_cache() + if num_tokens_in_cache is None: + return False # Can't determine cache state + + # Only allow reuse for non-trimmable caches when extending the prompt + # If not extending (same or rewinding), cache will need trimming + if len(raw_prompt_tokens) <= len(self.prev_raw_prompt_tokens): + # Not extending - cache will need trimming (has generated tokens) + if num_tokens_in_cache > 0: + logger.info( + "Cannot reuse vision cache: cache is not trimmable and would require trimming" + ) + return False + + return True def record_vision_state(self, images_b64: List[str], raw_prompt_tokens: List[int]): """ diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 7d948f1b..6bb1cdda 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -2,7 +2,7 @@ from typing import Callable, Optional, List, Tuple import mlx_lm from mlx_lm.tokenizer_utils import TokenizerWrapper, StreamingDetokenizer -from mlx_engine.cache_wrapper import CacheWrapper, CacheNotTrimmableError +from mlx_engine.cache_wrapper import CacheWrapper from pathlib import Path import mlx.nn as nn import mlx.core as mx @@ -208,36 +208,29 @@ def process_prompt( ) if can_skip_vision_processing: - try: - # CHEAP PATH: Skip vision tower, reuse cached KV states + # CHEAP PATH: Skip vision tower, reuse cached KV states + logger.info("Reusing cached vision features from previous request") - # Get input_ids with image tokens - input_ids = self._get_input_ids_via_prepare_inputs( - prompt_tokens, images_b64, max_image_size - ) - - # Process like text-only: use cache wrapper to preprocess new tokens - unprocessed_tokens = process_prompt_text_only( - input_ids, - self.cache_wrapper, - generate_args, - draft_model=None, # Vision models don't support draft models - speculative_decoding_toggle=None, - prompt_progress_callback=prompt_progress_callback, - ) - - # Update vision state for next request - self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + # Get input_ids with image tokens (cheap - just tokenization + prepare_inputs) + input_ids = self._get_input_ids_via_prepare_inputs( + prompt_tokens, images_b64, max_image_size + ) - # Return tokens only, no embeddings (model will use text embeddings for new tokens) - return unprocessed_tokens, None + # Process like text-only: use cache wrapper to preprocess new tokens + unprocessed_tokens = process_prompt_text_only( + input_ids, + self.cache_wrapper, + generate_args, + draft_model=None, # Vision models don't support draft models + speculative_decoding_toggle=None, + prompt_progress_callback=prompt_progress_callback, + ) - except CacheNotTrimmableError: - pass - # Fall through to expensive path below + # Update vision state for next request + self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) - # EXPENSIVE PATH: Full vision processing (first request or images changed) - logger.info("Performing full vision processing with images") + # Return tokens only, no embeddings (model will use text embeddings for new tokens) + return unprocessed_tokens, None input_ids, embeddings = self.vision_add_on.compute_embeddings( self.model, prompt_tokens, images_b64, max_size=max_image_size From 973907bdf0ca1cd8f6ab54524e0ce11860d92e9d Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 28 Oct 2025 17:55:32 -0400 Subject: [PATCH 09/13] checkpoint --- mlx_engine/cache_wrapper.py | 34 ++++++++----------------------- mlx_engine/model_kit/model_kit.py | 29 ++++++++++++++++---------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 0d933f07..a852410a 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -61,7 +61,7 @@ def __init__( # Vision prompt caching state self.prev_images_hash: Optional[str] = None - self.prev_raw_prompt_tokens: Optional[List[int]] = None + self.prev_expanded_input_ids: Optional[List[int]] = None def _get_num_tokens_in_cache(self) -> int | None: """ @@ -350,7 +350,7 @@ def _compute_images_hash(self, images_b64: List[str]) -> str: return hashlib.sha256(combined.encode()).hexdigest() def can_reuse_vision_cache( - self, images_b64: List[str], raw_prompt_tokens: List[int] + self, images_b64: List[str], expanded_input_ids: List[int] ) -> bool: """ Check if we can skip expensive vision processing and reuse cached KV states. @@ -360,12 +360,12 @@ def can_reuse_vision_cache( Args: images_b64: Current request's base64-encoded images - raw_prompt_tokens: Current request's raw prompt tokens (before vision processing) + expanded_input_ids: Current request's expanded input_ids (with image tokens expanded to pad tokens) Returns: bool: True if we can skip vision processing, False otherwise """ - if self.prev_images_hash is None or self.prev_raw_prompt_tokens is None: + if self.prev_images_hash is None or self.prev_expanded_input_ids is None: return False # Check if images are identical @@ -373,22 +373,6 @@ def can_reuse_vision_cache( if current_images_hash != self.prev_images_hash: return False - # Use existing _find_common_prefix to check if one prompt is a prefix of the other - current_tokens = mx.array(raw_prompt_tokens) - prev_tokens = mx.array(self.prev_raw_prompt_tokens) - - # Find common prefix length (num_tokens_to_exclude=0 since we don't need that constraint) - common_length = self._find_common_prefix( - current_tokens=prev_tokens, - prompt_tokens=current_tokens, - num_tokens_to_exclude=0, - ) - - # Check if one prompt is a complete prefix of the other - min_length = min(len(raw_prompt_tokens), len(self.prev_raw_prompt_tokens)) - if common_length != min_length: - return False # Not a prefix relationship - # Check if cache supports required operations # For non-trimmable caches (like SWA), we can only reuse vision cache when extending # (adding new tokens). If we're re-prompting with same/shorter prompt, the cache @@ -400,7 +384,7 @@ def can_reuse_vision_cache( # Only allow reuse for non-trimmable caches when extending the prompt # If not extending (same or rewinding), cache will need trimming - if len(raw_prompt_tokens) <= len(self.prev_raw_prompt_tokens): + if len(expanded_input_ids) <= len(self.prev_expanded_input_ids): # Not extending - cache will need trimming (has generated tokens) if num_tokens_in_cache > 0: logger.info( @@ -410,18 +394,18 @@ def can_reuse_vision_cache( return True - def record_vision_state(self, images_b64: List[str], raw_prompt_tokens: List[int]): + def record_vision_state(self, images_b64: List[str], expanded_input_ids: List[int]): """ Record vision processing state for future cache validation. Args: images_b64: Base64-encoded images that were processed - raw_prompt_tokens: Raw prompt tokens (before vision processing) that were used + expanded_input_ids: Expanded input_ids (with image tokens expanded to pad tokens) that were used """ self.prev_images_hash = self._compute_images_hash(images_b64) - self.prev_raw_prompt_tokens = raw_prompt_tokens + self.prev_expanded_input_ids = expanded_input_ids def clear_vision_cache(self): """Clear vision-specific cache state.""" self.prev_images_hash = None - self.prev_raw_prompt_tokens = None + self.prev_expanded_input_ids = None diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 6bb1cdda..7e5e6ac0 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -197,25 +197,26 @@ def process_prompt( "Vision add-on is not loaded, but images were provided for processing" ) - # Convert prompt_tokens to list for cache validation - prompt_tokens_list = ( - prompt_tokens if isinstance(prompt_tokens, list) else prompt_tokens.tolist() + # Get expanded input_ids (cheap operation - just tokenization + prepare_inputs, no vision tower) + # We need this BEFORE the cache check to compare expanded tokens with what's in the cache + input_ids = self._get_input_ids_via_prepare_inputs( + prompt_tokens, images_b64, max_image_size + ) + + # Convert input_ids to list for cache validation + input_ids_list = ( + input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids) ) # Check if we can skip expensive vision processing can_skip_vision_processing = self.cache_wrapper.can_reuse_vision_cache( - images_b64, prompt_tokens_list + images_b64, input_ids_list ) if can_skip_vision_processing: # CHEAP PATH: Skip vision tower, reuse cached KV states logger.info("Reusing cached vision features from previous request") - # Get input_ids with image tokens (cheap - just tokenization + prepare_inputs) - input_ids = self._get_input_ids_via_prepare_inputs( - prompt_tokens, images_b64, max_image_size - ) - # Process like text-only: use cache wrapper to preprocess new tokens unprocessed_tokens = process_prompt_text_only( input_ids, @@ -227,17 +228,23 @@ def process_prompt( ) # Update vision state for next request - self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + self.cache_wrapper.record_vision_state(images_b64, input_ids_list) # Return tokens only, no embeddings (model will use text embeddings for new tokens) return unprocessed_tokens, None + # EXPENSIVE PATH: Full vision processing with vision tower input_ids, embeddings = self.vision_add_on.compute_embeddings( self.model, prompt_tokens, images_b64, max_size=max_image_size ) + # Update input_ids_list in case compute_embeddings returns different input_ids + input_ids_list = ( + input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids) + ) + # Record vision state for future requests - self.cache_wrapper.record_vision_state(images_b64, prompt_tokens_list) + self.cache_wrapper.record_vision_state(images_b64, input_ids_list) # Initialize cache tracking with the processed input_ids # This is critical - tells cache_wrapper what tokens are being processed From 4e94de98231f42ad01c3942d25e90c173d214b29 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 29 Oct 2025 09:50:37 -0400 Subject: [PATCH 10/13] remove list --- mlx_engine/cache_wrapper.py | 6 +++--- mlx_engine/model_kit/model_kit.py | 16 +++------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index a852410a..a26e1b23 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -61,7 +61,7 @@ def __init__( # Vision prompt caching state self.prev_images_hash: Optional[str] = None - self.prev_expanded_input_ids: Optional[List[int]] = None + self.prev_expanded_input_ids: Optional[mx.array] = None def _get_num_tokens_in_cache(self) -> int | None: """ @@ -350,7 +350,7 @@ def _compute_images_hash(self, images_b64: List[str]) -> str: return hashlib.sha256(combined.encode()).hexdigest() def can_reuse_vision_cache( - self, images_b64: List[str], expanded_input_ids: List[int] + self, images_b64: List[str], expanded_input_ids: mx.array ) -> bool: """ Check if we can skip expensive vision processing and reuse cached KV states. @@ -394,7 +394,7 @@ def can_reuse_vision_cache( return True - def record_vision_state(self, images_b64: List[str], expanded_input_ids: List[int]): + def record_vision_state(self, images_b64: List[str], expanded_input_ids: mx.array): """ Record vision processing state for future cache validation. diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 7e5e6ac0..8c7ba03e 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -203,14 +203,9 @@ def process_prompt( prompt_tokens, images_b64, max_image_size ) - # Convert input_ids to list for cache validation - input_ids_list = ( - input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids) - ) - # Check if we can skip expensive vision processing can_skip_vision_processing = self.cache_wrapper.can_reuse_vision_cache( - images_b64, input_ids_list + images_b64, input_ids ) if can_skip_vision_processing: @@ -228,7 +223,7 @@ def process_prompt( ) # Update vision state for next request - self.cache_wrapper.record_vision_state(images_b64, input_ids_list) + self.cache_wrapper.record_vision_state(images_b64, input_ids) # Return tokens only, no embeddings (model will use text embeddings for new tokens) return unprocessed_tokens, None @@ -238,13 +233,8 @@ def process_prompt( self.model, prompt_tokens, images_b64, max_size=max_image_size ) - # Update input_ids_list in case compute_embeddings returns different input_ids - input_ids_list = ( - input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids) - ) - # Record vision state for future requests - self.cache_wrapper.record_vision_state(images_b64, input_ids_list) + self.cache_wrapper.record_vision_state(images_b64, input_ids) # Initialize cache tracking with the processed input_ids # This is critical - tells cache_wrapper what tokens are being processed From 36c360a5ee9352439f954b43c182f570197d58f6 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 29 Oct 2025 10:18:15 -0400 Subject: [PATCH 11/13] cleanup --- mlx_engine/cache_wrapper.py | 59 ++++++++++++------- mlx_engine/generate.py | 1 + mlx_engine/model_kit/model_kit.py | 23 +++----- .../vision_model_kit/vision_model_kit.py | 6 +- .../vision_model_kit/vision_model_wrapper.py | 6 +- tests/test_vision_cache.py | 5 +- 6 files changed, 54 insertions(+), 46 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index a26e1b23..2a7d9b5f 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -9,6 +9,7 @@ import mlx.core as mx import mlx.nn as nn import sys +import hashlib PROMPT_PROCESSING_CHUNK_SIZE = 512 @@ -140,30 +141,21 @@ def _get_unprocessed_tokens( logger.warning( "Could not determine the number of tokens in the cache, clearing the cache." ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - self.clear_vision_cache() - return self.tokens + return self._reset_cache(prompt_tokens) num_tokens_to_trim = num_tokens_in_cache - common_prefix if num_tokens_to_trim > 0: if not can_trim_prompt_cache(self.cache): logger.warning( f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: Cache is not trimmable. Clearing the cache instead." ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - self.clear_vision_cache() - return self.tokens + return self._reset_cache(prompt_tokens) tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) if tokens_trimmed != num_tokens_to_trim: # If we trimmed fewer tokens than expected, the cache is invalid logger.error( f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected ({num_tokens_to_trim}). Clearing the cache." ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - self.clear_vision_cache() - return self.tokens + return self._reset_cache(prompt_tokens) logger.info(f"Trimmed {num_tokens_to_trim} tokens from the prompt cache") # Keep track of the prompt tokens @@ -228,9 +220,7 @@ def _prefill( ) num_tokens_in_cache = None if num_tokens_in_cache is None: - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = None - self.clear_vision_cache() + self._reset_cache() else: # Remember which tokens were processed so far, so that we can continue processing at a later point self.tokens = self.tokens[:num_tokens_in_cache] @@ -262,9 +252,7 @@ def set_draft_model(self, draft_model: nn.Module): # clear the current cache, append draft model cache to the end of the main model cache as per # https://github.com/ml-explore/mlx-examples/blob/514502da22f0dc4c1ac439bdf78c07d5ec41acf7/llms/mlx_lm/utils.py#L381-L382 logger.info("Clearing current prompt cache and adding draft model to the cache") - self.tokens = None - self.clear_vision_cache() - self.cache: List[Any] = make_prompt_cache(self.model) + self._reset_cache(use_max_kv_size=False) if draft_model is not None: self.cache += make_prompt_cache(draft_model) self.draft_model = draft_model @@ -344,8 +332,6 @@ def record_generated_token(self, token): def _compute_images_hash(self, images_b64: List[str]) -> str: """Compute hash of images for cache validation.""" - import hashlib - combined = "".join(images_b64) return hashlib.sha256(combined.encode()).hexdigest() @@ -405,6 +391,39 @@ def record_vision_state(self, images_b64: List[str], expanded_input_ids: mx.arra self.prev_images_hash = self._compute_images_hash(images_b64) self.prev_expanded_input_ids = expanded_input_ids + def set_vision_tokens(self, input_ids: mx.array): + """ + Set the token state after vision processing. + + This is used after vision tower processing where embeddings have been computed + and the full expanded input_ids should be tracked in the cache. + + Args: + input_ids: The full expanded input_ids (with image tokens expanded to pad tokens) + """ + self.tokens = input_ids + + def _reset_cache( + self, tokens: Optional[mx.array] = None, use_max_kv_size: bool = True + ) -> Optional[mx.array]: + """ + Reset the cache to a fresh state. + + Args: + tokens: The tokens to set (or None to clear) + use_max_kv_size: Whether to pass max_kv_size to make_prompt_cache + + Returns: + The tokens that were set + """ + if use_max_kv_size: + self.cache = make_prompt_cache(self.model, self.max_kv_size) + else: + self.cache = make_prompt_cache(self.model) + self.tokens = tokens + self.clear_vision_cache() + return self.tokens + def clear_vision_cache(self): """Clear vision-specific cache state.""" self.prev_images_hash = None diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py index e96b62e3..61aa57ce 100644 --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py @@ -208,6 +208,7 @@ def create_generator( ValueError: If top_logprobs exceeds MAX_TOP_LOGPROBS or if any parameters are invalid """ set_seed(seed) + images_b64 = [] if images_b64 is None else images_b64 generate_args = {} # For each call to create_generator, wrap all prompt progress calls with a ratchet that diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 8c7ba03e..81589f2a 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -141,7 +141,7 @@ def tokenize(self, prompt: str) -> List[int]: def _get_input_ids_via_prepare_inputs( self, prompt_tokens: mx.array, - images_b64: List[str], + images_b64: list[str], max_image_size: tuple[int, int] | None, ) -> mx.array: """ @@ -168,14 +168,14 @@ def _get_input_ids_via_prepare_inputs( def process_prompt( self, prompt_tokens, - images_b64: Optional[List[str]], + images_b64: list[str], prompt_progress_callback: Optional[Callable[[float], bool]], generate_args: dict, max_image_size: tuple[int, int] | None, speculative_decoding_toggle: Optional[bool] = None, ) -> Tuple[mx.array, Optional[mx.array]]: ### TEXT-ONLY PROCESS_PROMPT ### - is_text_only_processing = images_b64 is None or len(images_b64) == 0 + is_text_only_processing = len(images_b64) == 0 if is_text_only_processing: if len(prompt_tokens) == 0: logger.warning( @@ -197,8 +197,7 @@ def process_prompt( "Vision add-on is not loaded, but images were provided for processing" ) - # Get expanded input_ids (cheap operation - just tokenization + prepare_inputs, no vision tower) - # We need this BEFORE the cache check to compare expanded tokens with what's in the cache + # Get expanded input_ids, which add image pad tokens to the prompt input_ids = self._get_input_ids_via_prepare_inputs( prompt_tokens, images_b64, max_image_size ) @@ -209,10 +208,7 @@ def process_prompt( ) if can_skip_vision_processing: - # CHEAP PATH: Skip vision tower, reuse cached KV states - logger.info("Reusing cached vision features from previous request") - - # Process like text-only: use cache wrapper to preprocess new tokens + # Skip vision tower, reuse cached KV states unprocessed_tokens = process_prompt_text_only( input_ids, self.cache_wrapper, @@ -225,10 +221,9 @@ def process_prompt( # Update vision state for next request self.cache_wrapper.record_vision_state(images_b64, input_ids) - # Return tokens only, no embeddings (model will use text embeddings for new tokens) return unprocessed_tokens, None - # EXPENSIVE PATH: Full vision processing with vision tower + # Full vision processing input_ids, embeddings = self.vision_add_on.compute_embeddings( self.model, prompt_tokens, images_b64, max_size=max_image_size ) @@ -236,11 +231,9 @@ def process_prompt( # Record vision state for future requests self.cache_wrapper.record_vision_state(images_b64, input_ids) - # Initialize cache tracking with the processed input_ids - # This is critical - tells cache_wrapper what tokens are being processed - self.cache_wrapper.tokens = input_ids + # Set the tokens to the full expanded input_ids + self.cache_wrapper.set_vision_tokens(input_ids) - # Set prompt_cache for generation (fixes missing cache usage in vision path!) generate_args["prompt_cache"] = self.cache_wrapper.cache return input_ids, embeddings diff --git a/mlx_engine/vision_model_kit/vision_model_kit.py b/mlx_engine/vision_model_kit/vision_model_kit.py index ef8d4e3b..ae6e93dc 100644 --- a/mlx_engine/vision_model_kit/vision_model_kit.py +++ b/mlx_engine/vision_model_kit/vision_model_kit.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, List, Tuple +from typing import Union, Optional, Tuple from mlx_engine.model_kit.model_kit import ModelKit import logging @@ -98,7 +98,7 @@ def _reset_for_prediction(self): def process_prompt( self, prompt_tokens, - images_b64: Optional[List[str]], + images_b64: list[str], prompt_progress_callback, generate_args, max_image_size: tuple[int, int] | None, @@ -122,7 +122,7 @@ def process_prompt( # The VLM input_ids shape is important, but mlx_lm expects a flattened array # Send back a fake shape and input_ids, and save the real shape in `self.model.input_ids` - if images_b64 is None or len(images_b64) == 0: + if len(images_b64) == 0: # For text-only, enable mlx-lm prompt processing return self.model.input_ids.reshape(-1), None # Disable mlx-lm prompt processing by returning a fake input diff --git a/mlx_engine/vision_model_kit/vision_model_wrapper.py b/mlx_engine/vision_model_kit/vision_model_wrapper.py index 4f273b62..bc2634f9 100644 --- a/mlx_engine/vision_model_kit/vision_model_wrapper.py +++ b/mlx_engine/vision_model_kit/vision_model_wrapper.py @@ -2,7 +2,6 @@ import logging from mlx_vlm.models.cache import KVCache, SimpleKVCache -from typing import List, Optional from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( common_process_prompt_with_images, ) @@ -158,7 +157,7 @@ def record_sampled_token(self, token: int) -> None: def process_prompt_with_images( self, - images_b64: Optional[List[str]], + images_b64: list[str], prompt_tokens: mx.array, processor, detokenizer, @@ -168,9 +167,6 @@ def process_prompt_with_images( This method generates the input_ids, pixel_values, and mask for the vision model Call this before starting evaluation """ - if images_b64 is None: - images_b64 = [] - # Handle the case with no images if len(images_b64) == 0: detokenizer.reset() diff --git a/tests/test_vision_cache.py b/tests/test_vision_cache.py index 1d2a7ab3..44d49ea3 100644 --- a/tests/test_vision_cache.py +++ b/tests/test_vision_cache.py @@ -31,7 +31,7 @@ def setUpClass(cls): @pytest.mark.heavy def test_nonswa_model(self): """ - Test that image caching works for models without a SWA cache + Test that image caching works for models without SWA (sliding window attn) layers """ prompt = "[INST][IMG][IMG]In one word each, describe the animal in the images[/INST]\n" model_name = "lmstudio-community/Mistral-Small-3.2-24B-Instruct-2506-MLX-4bit" @@ -114,7 +114,7 @@ def generate_text(prompt): def test_swa_model(self): """ - Test that image caching works for models with a SWA cache + Test that image caching works for models with SWA layers """ prompt = "user\nIn one word each, describe the images\nmodel\n" model_name = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit" @@ -164,7 +164,6 @@ def generate_text(prompt): self.assertEqual(len(callback_history), 4) # prompt processing by mlx-lm # ask a followup question - print("--") prompt2 = ( prompt + generated_text From fd38585257b6c24ddd60bc4d7d43d89487cb9fca Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 29 Oct 2025 10:20:59 -0400 Subject: [PATCH 12/13] cleanup --- mlx_engine/model_kit/model_kit.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 81589f2a..131e1e9a 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -105,10 +105,7 @@ def _full_model_init( self.kv_group_size = kv_group_size self.quantized_kv_start = quantized_kv_start vision_add_on_class = self.VISION_ADD_ON_MAP.get(self.model_type) - should_load_vision_add_on = ( - vision_add_on_class is not None and "vision_config" in config_json - ) - if should_load_vision_add_on: + if vision_add_on_class and "vision_config" in config_json: self.vision_add_on = vision_add_on_class(model_path) logger.info("Model loaded successfully") From 7dfe3cdeeedf289b9b8bed72f67fcabe201765c2 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 29 Oct 2025 10:36:19 -0400 Subject: [PATCH 13/13] cleanup --- mlx_engine/cache_wrapper.py | 3 --- mlx_engine/vision_model_kit/vision_model_wrapper.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 2a7d9b5f..e41219d7 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -373,9 +373,6 @@ def can_reuse_vision_cache( if len(expanded_input_ids) <= len(self.prev_expanded_input_ids): # Not extending - cache will need trimming (has generated tokens) if num_tokens_in_cache > 0: - logger.info( - "Cannot reuse vision cache: cache is not trimmable and would require trimming" - ) return False return True diff --git a/mlx_engine/vision_model_kit/vision_model_wrapper.py b/mlx_engine/vision_model_kit/vision_model_wrapper.py index bc2634f9..5c688dbd 100644 --- a/mlx_engine/vision_model_kit/vision_model_wrapper.py +++ b/mlx_engine/vision_model_kit/vision_model_wrapper.py @@ -110,11 +110,6 @@ def __call__(self, *args, input_embeddings=None, **kwargs): "decoder_input_ids": self.decoder_input_ids, "encoder_outputs": outputs.encoder_outputs, } - # elif self.vision_model.config.model_type == "qwen3_vl": - # self.language_model_kwargs = { - # "visual_pos_masks": outputs.visual_pos_masks, - # "deepstack_visual_embeds": outputs.deepstack_visual_embeds, - # } # Add the cache we created here to the language model kwargs self.language_model_kwargs["cache"] = cache