diff --git a/demo.py b/demo.py index dfcc3fd0..d48adbcb 100644 --- a/demo.py +++ b/demo.py @@ -92,6 +92,9 @@ def setup_arg_parser(): action="store_true", help="Enable printed prompt processing progress callback", ) + parser.add_argument( + "--max-img-size", type=int, help="Downscale images to this side length (px)" + ) return parser @@ -203,22 +206,22 @@ def prompt_progress_callback(percent): tf_tokenizer = AutoProcessor.from_pretrained(model_path) images_base64 = [image_to_base64(img_path) for img_path in args.images] conversation = [ - DEFAULT_SYSTEM_PROMPT, + # DEFAULT_SYSTEM_PROMPT, { "role": "user", "content": [ - {"type": "text", "text": prompt}, *[ {"type": "image", "base64": image_b64} for image_b64 in images_base64 ], + {"type": "text", "text": prompt}, ], }, ] else: tf_tokenizer = AutoTokenizer.from_pretrained(model_path) conversation = [ - DEFAULT_SYSTEM_PROMPT, + # DEFAULT_SYSTEM_PROMPT, {"role": "user", "content": prompt}, ] prompt = tf_tokenizer.apply_chat_template( @@ -232,12 +235,15 @@ def prompt_progress_callback(percent): # Initialize generation stats collector stats_collector = GenerationStatsCollector() + # Clamp image size + max_img_size = (args.max_img_size, args.max_img_size) if args.max_img_size else None + # Generate the response generator = create_generator( model_kit, prompt_tokens, images_b64=images_base64, - max_image_size=(1024, 1024), + max_image_size=max_img_size, stop_strings=args.stop_strings, max_tokens=1024, top_logprobs=args.top_logprobs, diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py index 96673954..e41219d7 100644 --- a/mlx_engine/cache_wrapper.py +++ b/mlx_engine/cache_wrapper.py @@ -9,6 +9,7 @@ import mlx.core as mx import mlx.nn as nn import sys +import hashlib PROMPT_PROCESSING_CHUNK_SIZE = 512 @@ -59,6 +60,10 @@ def __init__( ) self.chunk_size = chunk_size + # Vision prompt caching state + self.prev_images_hash: Optional[str] = None + self.prev_expanded_input_ids: Optional[mx.array] = None + def _get_num_tokens_in_cache(self) -> int | None: """ Get the number of tokens in the cache. @@ -136,27 +141,21 @@ def _get_unprocessed_tokens( logger.warning( "Could not determine the number of tokens in the cache, clearing the cache." ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - return self.tokens + return self._reset_cache(prompt_tokens) num_tokens_to_trim = num_tokens_in_cache - common_prefix if num_tokens_to_trim > 0: if not can_trim_prompt_cache(self.cache): logger.warning( f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: Cache is not trimmable. Clearing the cache instead." ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - return self.tokens + return self._reset_cache(prompt_tokens) tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim) if tokens_trimmed != num_tokens_to_trim: # If we trimmed fewer tokens than expected, the cache is invalid logger.error( f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected ({num_tokens_to_trim}). Clearing the cache." ) - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = prompt_tokens - return self.tokens + return self._reset_cache(prompt_tokens) logger.info(f"Trimmed {num_tokens_to_trim} tokens from the prompt cache") # Keep track of the prompt tokens @@ -221,8 +220,7 @@ def _prefill( ) num_tokens_in_cache = None if num_tokens_in_cache is None: - self.cache = make_prompt_cache(self.model, self.max_kv_size) - self.tokens = None + self._reset_cache() else: # Remember which tokens were processed so far, so that we can continue processing at a later point self.tokens = self.tokens[:num_tokens_in_cache] @@ -254,8 +252,7 @@ def set_draft_model(self, draft_model: nn.Module): # clear the current cache, append draft model cache to the end of the main model cache as per # https://github.com/ml-explore/mlx-examples/blob/514502da22f0dc4c1ac439bdf78c07d5ec41acf7/llms/mlx_lm/utils.py#L381-L382 logger.info("Clearing current prompt cache and adding draft model to the cache") - self.tokens = None - self.cache: List[Any] = make_prompt_cache(self.model) + self._reset_cache(use_max_kv_size=False) if draft_model is not None: self.cache += make_prompt_cache(draft_model) self.draft_model = draft_model @@ -332,3 +329,99 @@ def record_generated_token(self, token): Add the generated token to the token list, so that we can map the token to the KV cache. """ self.tokens = mx.concat([self.tokens, mx.array([token])]) + + def _compute_images_hash(self, images_b64: List[str]) -> str: + """Compute hash of images for cache validation.""" + combined = "".join(images_b64) + return hashlib.sha256(combined.encode()).hexdigest() + + def can_reuse_vision_cache( + self, images_b64: List[str], expanded_input_ids: mx.array + ) -> bool: + """ + Check if we can skip expensive vision processing and reuse cached KV states. + + Supports both extending (longer prompt) and rewinding (shorter prompt) + as long as one prompt is a prefix of the other AND cache operations are supported. + + Args: + images_b64: Current request's base64-encoded images + expanded_input_ids: Current request's expanded input_ids (with image tokens expanded to pad tokens) + + Returns: + bool: True if we can skip vision processing, False otherwise + """ + if self.prev_images_hash is None or self.prev_expanded_input_ids is None: + return False + + # Check if images are identical + current_images_hash = self._compute_images_hash(images_b64) + if current_images_hash != self.prev_images_hash: + return False + + # Check if cache supports required operations + # For non-trimmable caches (like SWA), we can only reuse vision cache when extending + # (adding new tokens). If we're re-prompting with same/shorter prompt, the cache + # will have generated tokens that need trimming, which non-trimmable caches can't do. + if not can_trim_prompt_cache(self.cache): + num_tokens_in_cache = self._get_num_tokens_in_cache() + if num_tokens_in_cache is None: + return False # Can't determine cache state + + # Only allow reuse for non-trimmable caches when extending the prompt + # If not extending (same or rewinding), cache will need trimming + if len(expanded_input_ids) <= len(self.prev_expanded_input_ids): + # Not extending - cache will need trimming (has generated tokens) + if num_tokens_in_cache > 0: + return False + + return True + + def record_vision_state(self, images_b64: List[str], expanded_input_ids: mx.array): + """ + Record vision processing state for future cache validation. + + Args: + images_b64: Base64-encoded images that were processed + expanded_input_ids: Expanded input_ids (with image tokens expanded to pad tokens) that were used + """ + self.prev_images_hash = self._compute_images_hash(images_b64) + self.prev_expanded_input_ids = expanded_input_ids + + def set_vision_tokens(self, input_ids: mx.array): + """ + Set the token state after vision processing. + + This is used after vision tower processing where embeddings have been computed + and the full expanded input_ids should be tracked in the cache. + + Args: + input_ids: The full expanded input_ids (with image tokens expanded to pad tokens) + """ + self.tokens = input_ids + + def _reset_cache( + self, tokens: Optional[mx.array] = None, use_max_kv_size: bool = True + ) -> Optional[mx.array]: + """ + Reset the cache to a fresh state. + + Args: + tokens: The tokens to set (or None to clear) + use_max_kv_size: Whether to pass max_kv_size to make_prompt_cache + + Returns: + The tokens that were set + """ + if use_max_kv_size: + self.cache = make_prompt_cache(self.model, self.max_kv_size) + else: + self.cache = make_prompt_cache(self.model) + self.tokens = tokens + self.clear_vision_cache() + return self.tokens + + def clear_vision_cache(self): + """Clear vision-specific cache state.""" + self.prev_images_hash = None + self.prev_expanded_input_ids = None diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py index e96b62e3..61aa57ce 100644 --- a/mlx_engine/generate.py +++ b/mlx_engine/generate.py @@ -208,6 +208,7 @@ def create_generator( ValueError: If top_logprobs exceeds MAX_TOP_LOGPROBS or if any parameters are invalid """ set_seed(seed) + images_b64 = [] if images_b64 is None else images_b64 generate_args = {} # For each call to create_generator, wrap all prompt progress calls with a ratchet that diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 08dfa877..131e1e9a 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -16,6 +16,9 @@ from mlx_engine.model_kit.vision_add_ons.qwen2_vl import Qwen2_VLVisionAddOn from mlx_engine.utils.kv_cache_quantization import get_kv_cache_quantization_params from mlx_engine.utils.prompt_processing import process_prompt_text_only +from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( + common_process_prompt_with_images, +) logger = logging.getLogger(__name__) @@ -51,7 +54,6 @@ class ModelKit: tokenizer: TokenizerWrapper = None detokenizer: StreamingDetokenizer = None cache_wrapper: Optional[CacheWrapper] = None - _cross_prompt_cache_active: bool = False max_kv_size: Optional[int] = None kv_bits: Optional[int] = None kv_group_size: Optional[int] = None @@ -103,10 +105,7 @@ def _full_model_init( self.kv_group_size = kv_group_size self.quantized_kv_start = quantized_kv_start vision_add_on_class = self.VISION_ADD_ON_MAP.get(self.model_type) - should_load_vision_add_on = ( - vision_add_on_class is not None and "vision_config" in config_json - ) - if should_load_vision_add_on: + if vision_add_on_class and "vision_config" in config_json: self.vision_add_on = vision_add_on_class(model_path) logger.info("Model loaded successfully") @@ -136,19 +135,45 @@ def tokenize(self, prompt: str) -> List[int]: return [ids] return ids + def _get_input_ids_via_prepare_inputs( + self, + prompt_tokens: mx.array, + images_b64: list[str], + max_image_size: tuple[int, int] | None, + ) -> mx.array: + """ + Get input_ids with image tokens inserted (cheap operation). + Calls common_process_prompt_with_images but skips expensive vision processing. + """ + # Determine should_pad based on model type + # Qwen models need should_pad=False, others use True (default) + should_pad = self.model_type not in ["qwen2_vl", "qwen2_5_vl"] + + processed = common_process_prompt_with_images( + prompt_tokens=prompt_tokens, + images_b64=images_b64, + processor=self.vision_add_on.processor, + config=self.vision_add_on.config, + max_size=max_image_size, + should_pad=should_pad, + ) + + # Return input_ids, squeeze batch dimension if present + input_ids = processed.input_ids + return input_ids.squeeze(0) if input_ids.ndim > 1 else input_ids + def process_prompt( self, prompt_tokens, - images_b64: Optional[List[str]], + images_b64: list[str], prompt_progress_callback: Optional[Callable[[float], bool]], generate_args: dict, max_image_size: tuple[int, int] | None, speculative_decoding_toggle: Optional[bool] = None, ) -> Tuple[mx.array, Optional[mx.array]]: ### TEXT-ONLY PROCESS_PROMPT ### - is_text_only_processing = images_b64 is None or len(images_b64) == 0 + is_text_only_processing = len(images_b64) == 0 if is_text_only_processing: - self._cross_prompt_cache_active = True if len(prompt_tokens) == 0: logger.warning( "Received empty prompt. Generation quality will likely be poor" @@ -163,23 +188,62 @@ def process_prompt( speculative_decoding_toggle, prompt_progress_callback, ), None - ### WITH IMAGES PROMPT PROCESSING ###s + ### WITH IMAGES PROMPT PROCESSING ### if self.vision_add_on is None: raise ValueError( "Vision add-on is not loaded, but images were provided for processing" ) - self._cross_prompt_cache_active = False + + # Get expanded input_ids, which add image pad tokens to the prompt + input_ids = self._get_input_ids_via_prepare_inputs( + prompt_tokens, images_b64, max_image_size + ) + + # Check if we can skip expensive vision processing + can_skip_vision_processing = self.cache_wrapper.can_reuse_vision_cache( + images_b64, input_ids + ) + + if can_skip_vision_processing: + # Skip vision tower, reuse cached KV states + unprocessed_tokens = process_prompt_text_only( + input_ids, + self.cache_wrapper, + generate_args, + draft_model=None, # Vision models don't support draft models + speculative_decoding_toggle=None, + prompt_progress_callback=prompt_progress_callback, + ) + + # Update vision state for next request + self.cache_wrapper.record_vision_state(images_b64, input_ids) + + return unprocessed_tokens, None + + # Full vision processing input_ids, embeddings = self.vision_add_on.compute_embeddings( self.model, prompt_tokens, images_b64, max_size=max_image_size ) + + # Record vision state for future requests + self.cache_wrapper.record_vision_state(images_b64, input_ids) + + # Set the tokens to the full expanded input_ids + self.cache_wrapper.set_vision_tokens(input_ids) + + generate_args["prompt_cache"] = self.cache_wrapper.cache + return input_ids, embeddings def is_cross_prompt_cache_active(self) -> bool: """ Check if cross-prompt caching is currently enabled. Can be overridden by subclasses for custom behavior. + + ModelKit always supports cross-prompt caching. + VisionModelKit overrides this to return False. """ - return self._cross_prompt_cache_active + return True def record_token_to_cache(self, token: int) -> None: self.cache_wrapper.record_generated_token(token) diff --git a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py index 0dde9705..2736d756 100644 --- a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py +++ b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py @@ -22,6 +22,7 @@ def common_process_prompt_with_images( processor: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], config, # expected to be a ModelConfig object as defined by mlx-vlm. Can vary by model max_size: tuple[int, int] | None, + should_pad: bool = True, ) -> ProcessedImagePrompt: """ Common prompt processing used by mlx-vlm vision add-ons. @@ -33,6 +34,7 @@ def common_process_prompt_with_images( processor: Tokenizer/processor for the model config: Model configuration object max_size: Maximum image size as (width, height) tuple. If None, no resizing. + should_pad: Whether to pad images to uniform size. Defaults to True. """ if len(images_b64) == 0: raise ValueError("Images must be non-empty") @@ -45,7 +47,7 @@ def common_process_prompt_with_images( logger.info(f"Prompt dump: {prompt}\n") images = convert_to_pil(images_b64) - images = custom_resize(images, max_size=max_size) + images = custom_resize(images, max_size=max_size, should_pad=should_pad) if hasattr(config, "image_token_index"): image_token_index = config.image_token_index diff --git a/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py b/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py index 124892b0..ad03945d 100644 --- a/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py +++ b/mlx_engine/model_kit/vision_add_ons/qwen2_vl.py @@ -7,7 +7,9 @@ from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn from mlx_engine.model_kit.vision_add_ons.load_utils import load_vision_addon -from mlx_engine.model_kit.vision_add_ons.qwen_vl_utils import compute_qwen_vl_embeddings +from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( + common_process_prompt_with_images, +) from mlx_vlm.models.qwen2_5_vl import ( VisionModel as Qwen25VLVisionTower, @@ -80,12 +82,47 @@ def compute_embeddings( """ Compute input_ids and embeddings for text with images. """ - - return compute_qwen_vl_embeddings( - addon=self, - text_model=text_model, + # Process prompt with images (cheap operation - tokenization and prepare_inputs) + # Note: Qwen models require should_pad=False for images + processed = common_process_prompt_with_images( prompt_tokens=prompt_tokens, images_b64=images_b64, - qwen_vl_version=2, + processor=self.processor, + config=self.config, max_size=max_size, + should_pad=False, + ) + + input_ids = processed.input_ids + pixel_values = processed.pixel_values + + # Get image_grid_thw from other_inputs if present + grid_thw = processed.other_inputs.get("image_grid_thw") + + # Get text embeddings + input_embeddings = text_model.language_model.model.embed_tokens(input_ids) + + # If no images, return input_ids and input_embeddings + if pixel_values is None: + return input_ids.squeeze(0), input_embeddings.squeeze(0) + + # Ensure pixel values are in the right format for vision tower + if pixel_values.dtype != input_embeddings.dtype: + pixel_values = pixel_values.astype(input_embeddings.dtype) + + # Process image through vision tower (expensive operation) + hidden_states = self.vision_tower( + pixel_values, grid_thw, output_hidden_states=False ) + + # Merge image features with text embeddings (expensive operation) + final_inputs_embeds = self.model_cls.merge_input_ids_with_image_features( + self.config.image_token_id, + self.config.video_token_id, + hidden_states, + input_embeddings, + input_ids, + ) + + # Remove batch dimension + return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) diff --git a/mlx_engine/vision_model_kit/vision_model_kit.py b/mlx_engine/vision_model_kit/vision_model_kit.py index ef8d4e3b..ae6e93dc 100644 --- a/mlx_engine/vision_model_kit/vision_model_kit.py +++ b/mlx_engine/vision_model_kit/vision_model_kit.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, List, Tuple +from typing import Union, Optional, Tuple from mlx_engine.model_kit.model_kit import ModelKit import logging @@ -98,7 +98,7 @@ def _reset_for_prediction(self): def process_prompt( self, prompt_tokens, - images_b64: Optional[List[str]], + images_b64: list[str], prompt_progress_callback, generate_args, max_image_size: tuple[int, int] | None, @@ -122,7 +122,7 @@ def process_prompt( # The VLM input_ids shape is important, but mlx_lm expects a flattened array # Send back a fake shape and input_ids, and save the real shape in `self.model.input_ids` - if images_b64 is None or len(images_b64) == 0: + if len(images_b64) == 0: # For text-only, enable mlx-lm prompt processing return self.model.input_ids.reshape(-1), None # Disable mlx-lm prompt processing by returning a fake input diff --git a/mlx_engine/vision_model_kit/vision_model_wrapper.py b/mlx_engine/vision_model_kit/vision_model_wrapper.py index e71b41c9..5c688dbd 100644 --- a/mlx_engine/vision_model_kit/vision_model_wrapper.py +++ b/mlx_engine/vision_model_kit/vision_model_wrapper.py @@ -2,7 +2,6 @@ import logging from mlx_vlm.models.cache import KVCache, SimpleKVCache -from typing import List, Optional from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( common_process_prompt_with_images, ) @@ -153,7 +152,7 @@ def record_sampled_token(self, token: int) -> None: def process_prompt_with_images( self, - images_b64: Optional[List[str]], + images_b64: list[str], prompt_tokens: mx.array, processor, detokenizer, @@ -163,9 +162,6 @@ def process_prompt_with_images( This method generates the input_ids, pixel_values, and mask for the vision model Call this before starting evaluation """ - if images_b64 is None: - images_b64 = [] - # Handle the case with no images if len(images_b64) == 0: detokenizer.reset() diff --git a/tests/test_vision_cache.py b/tests/test_vision_cache.py new file mode 100644 index 00000000..44d49ea3 --- /dev/null +++ b/tests/test_vision_cache.py @@ -0,0 +1,198 @@ +import unittest +import base64 +from pathlib import Path +from mlx_engine.generate import ( + load_model, + tokenize, + create_generator, +) +from tests.shared import model_getter +import pytest + +MAX_IMAGE_SIZE = (1024, 1024) + + +class TestVisionCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Set up test resources that will be shared across all test methods""" + # Read and encode test images + cls.toucan_path = Path(__file__).parent.parent / "demo-data" / "toucan.jpeg" + with open(cls.toucan_path, "rb") as image_file: + cls.toucan_image_b64 = base64.b64encode(image_file.read()).decode("utf-8") + cls.chameleon_image_path = ( + Path(__file__).parent.parent / "demo-data" / "chameleon.webp" + ) + with open(cls.chameleon_image_path, "rb") as chameleon_image_file: + cls.chameleon_image_b64 = base64.b64encode( + chameleon_image_file.read() + ).decode("utf-8") + + @pytest.mark.heavy + def test_nonswa_model(self): + """ + Test that image caching works for models without SWA (sliding window attn) layers + """ + prompt = "[INST][IMG][IMG]In one word each, describe the animal in the images[/INST]\n" + model_name = "lmstudio-community/Mistral-Small-3.2-24B-Instruct-2506-MLX-4bit" + model_path = model_getter(model_name=model_name) + images_b64 = [self.toucan_image_b64, self.chameleon_image_b64] + + # Load the model + model_kit = load_model( + model_path=model_path, max_kv_size=4096, trust_remote_code=True + ) + + callback_history = [] + + def prompt_callback(x): + nonlocal callback_history + callback_history.append(x) + return True + + def generate_text(prompt): + nonlocal callback_history + callback_history = [] + + # Tokenize the prompt + prompt_tokens = tokenize(model_kit, prompt) + + # Generate description + generated_text = "" + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + prompt_progress_callback=prompt_callback, + images_b64=images_b64, + max_image_size=MAX_IMAGE_SIZE, + seed=0, + max_tokens=100, + temp=0.0, + repetition_penalty=1.01, # enable the logits processor code path + ): + generated_text += result.text + print(result.text, end="", flush=True) + if result.stop_condition: + break + print() + return generated_text + + generated_text = generate_text(prompt) + self.assertEqual(len(callback_history), 4) # prompt processing by mlx-lm + + # ask a followup question + print("--") + prompt2 = ( + prompt + generated_text + "[INST]what color is each animal's eyes?[/INST]" + ) + generated_text2 = generate_text(prompt2) + + # prompt processing by cache_wrapper. less work is done since the images are cached + self.assertEqual(len(callback_history), 2) + self.assertRegex(generated_text2, "toucan.*dark") + self.assertRegex(generated_text2, "chameleon.*orange") + + # rewind the cache but swap the images. full preprocessing happens + images_b64.reverse() + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 4) + + # rewind the cache and re-prompt; images are not processed + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 2) + + # add an image in the followup; all three images are re-processed + images_b64.append(self.chameleon_image_b64) + prompt3 = ( + prompt + + generated_text + + "[INST][IMG]do you see two toucans or chameleons?[/INST]" + ) + generated_text3 = generate_text(prompt3) + self.assertEqual(len(callback_history), 5) + self.assertRegex(generated_text3, "chameleon") + + def test_swa_model(self): + """ + Test that image caching works for models with SWA layers + """ + prompt = "user\nIn one word each, describe the images\nmodel\n" + model_name = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit" + model_path = model_getter(model_name=model_name) + images_b64 = [self.toucan_image_b64, self.chameleon_image_b64] + + # Load the model + model_kit = load_model( + model_path=model_path, max_kv_size=4096, trust_remote_code=True + ) + + callback_history = [] + + def prompt_callback(x): + nonlocal callback_history + callback_history.append(x) + return True + + def generate_text(prompt): + nonlocal callback_history + callback_history = [] + + # Tokenize the prompt + prompt_tokens = tokenize(model_kit, prompt) + + # Generate description + generated_text = "" + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + prompt_progress_callback=prompt_callback, + images_b64=images_b64, + max_image_size=MAX_IMAGE_SIZE, + seed=0, + max_tokens=100, + temp=0.0, + repetition_penalty=1.01, # enable the logits processor code path + ): + generated_text += result.text + print(result.text, end="", flush=True) + if result.stop_condition: + break + print() + return generated_text + + generated_text = generate_text(prompt) + self.assertEqual(len(callback_history), 4) # prompt processing by mlx-lm + + # ask a followup question + prompt2 = ( + prompt + + generated_text + + "\nuser\nwhich direction is each animal facing?\n" + + "model\n" + ) + generated_text2 = generate_text(prompt2) + + # prompt processing by cache_wrapper. less work is done since the images are cached + self.assertEqual(len(callback_history), 2) + self.assertRegex(generated_text2, "toucan.*left") + self.assertRegex(generated_text2, "chameleon.*right") + + # swap the images; full preprocessing happens + images_b64.reverse() + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 4) + + # attempt cache rewind; images fully processed since the cache can't be trimmed + _ = generate_text(prompt) + self.assertEqual(len(callback_history), 4) + + # add an image in the followup; all three images are re-processed + images_b64.append(self.chameleon_image_b64) + prompt3 = ( + prompt + + generated_text + + "user\nhow many animals do you count?\nmodel" + ) + generated_text3 = generate_text(prompt3) + self.assertEqual(len(callback_history), 4) + self.assertRegex(generated_text3, "chameleon")