Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def setup_arg_parser():
action="store_true",
help="Enable printed prompt processing progress callback",
)
parser.add_argument(
"--max-img-size", type=int, help="Downscale images to this side length (px)"
)
return parser


Expand Down Expand Up @@ -203,22 +206,22 @@ def prompt_progress_callback(percent):
tf_tokenizer = AutoProcessor.from_pretrained(model_path)
images_base64 = [image_to_base64(img_path) for img_path in args.images]
conversation = [
DEFAULT_SYSTEM_PROMPT,
# DEFAULT_SYSTEM_PROMPT,
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
*[
{"type": "image", "base64": image_b64}
for image_b64 in images_base64
],
{"type": "text", "text": prompt},
],
},
]
else:
tf_tokenizer = AutoTokenizer.from_pretrained(model_path)
conversation = [
DEFAULT_SYSTEM_PROMPT,
# DEFAULT_SYSTEM_PROMPT,
{"role": "user", "content": prompt},
]
prompt = tf_tokenizer.apply_chat_template(
Expand All @@ -232,12 +235,15 @@ def prompt_progress_callback(percent):
# Initialize generation stats collector
stats_collector = GenerationStatsCollector()

# Clamp image size
max_img_size = (args.max_img_size, args.max_img_size) if args.max_img_size else None

# Generate the response
generator = create_generator(
model_kit,
prompt_tokens,
images_b64=images_base64,
max_image_size=(1024, 1024),
max_image_size=max_img_size,
stop_strings=args.stop_strings,
max_tokens=1024,
top_logprobs=args.top_logprobs,
Expand Down
119 changes: 106 additions & 13 deletions mlx_engine/cache_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import mlx.core as mx
import mlx.nn as nn
import sys
import hashlib


PROMPT_PROCESSING_CHUNK_SIZE = 512
Expand Down Expand Up @@ -59,6 +60,10 @@ def __init__(
)
self.chunk_size = chunk_size

# Vision prompt caching state
self.prev_images_hash: Optional[str] = None
self.prev_expanded_input_ids: Optional[mx.array] = None

def _get_num_tokens_in_cache(self) -> int | None:
"""
Get the number of tokens in the cache.
Expand Down Expand Up @@ -136,27 +141,21 @@ def _get_unprocessed_tokens(
logger.warning(
"Could not determine the number of tokens in the cache, clearing the cache."
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.tokens = prompt_tokens
return self.tokens
return self._reset_cache(prompt_tokens)
num_tokens_to_trim = num_tokens_in_cache - common_prefix
if num_tokens_to_trim > 0:
if not can_trim_prompt_cache(self.cache):
logger.warning(
f"Tried to trim '{num_tokens_to_trim}' tokens from the prompt cache, but could not: Cache is not trimmable. Clearing the cache instead."
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.tokens = prompt_tokens
return self.tokens
return self._reset_cache(prompt_tokens)
tokens_trimmed = trim_prompt_cache(self.cache, num_tokens_to_trim)
if tokens_trimmed != num_tokens_to_trim:
# If we trimmed fewer tokens than expected, the cache is invalid
logger.error(
f"Tokens trimmed from cache ({tokens_trimmed}) is less than expected ({num_tokens_to_trim}). Clearing the cache."
)
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.tokens = prompt_tokens
return self.tokens
return self._reset_cache(prompt_tokens)
logger.info(f"Trimmed {num_tokens_to_trim} tokens from the prompt cache")

# Keep track of the prompt tokens
Expand Down Expand Up @@ -221,8 +220,7 @@ def _prefill(
)
num_tokens_in_cache = None
if num_tokens_in_cache is None:
self.cache = make_prompt_cache(self.model, self.max_kv_size)
self.tokens = None
self._reset_cache()
else:
# Remember which tokens were processed so far, so that we can continue processing at a later point
self.tokens = self.tokens[:num_tokens_in_cache]
Expand Down Expand Up @@ -254,8 +252,7 @@ def set_draft_model(self, draft_model: nn.Module):
# clear the current cache, append draft model cache to the end of the main model cache as per
# https://github.com/ml-explore/mlx-examples/blob/514502da22f0dc4c1ac439bdf78c07d5ec41acf7/llms/mlx_lm/utils.py#L381-L382
logger.info("Clearing current prompt cache and adding draft model to the cache")
self.tokens = None
self.cache: List[Any] = make_prompt_cache(self.model)
self._reset_cache(use_max_kv_size=False)
if draft_model is not None:
self.cache += make_prompt_cache(draft_model)
self.draft_model = draft_model
Expand Down Expand Up @@ -332,3 +329,99 @@ def record_generated_token(self, token):
Add the generated token to the token list, so that we can map the token to the KV cache.
"""
self.tokens = mx.concat([self.tokens, mx.array([token])])

def _compute_images_hash(self, images_b64: List[str]) -> str:
"""Compute hash of images for cache validation."""
combined = "".join(images_b64)
return hashlib.sha256(combined.encode()).hexdigest()

def can_reuse_vision_cache(
self, images_b64: List[str], expanded_input_ids: mx.array
) -> bool:
"""
Check if we can skip expensive vision processing and reuse cached KV states.

Supports both extending (longer prompt) and rewinding (shorter prompt)
as long as one prompt is a prefix of the other AND cache operations are supported.

Args:
images_b64: Current request's base64-encoded images
expanded_input_ids: Current request's expanded input_ids (with image tokens expanded to pad tokens)

Returns:
bool: True if we can skip vision processing, False otherwise
"""
if self.prev_images_hash is None or self.prev_expanded_input_ids is None:
return False

# Check if images are identical
current_images_hash = self._compute_images_hash(images_b64)
if current_images_hash != self.prev_images_hash:
return False

# Check if cache supports required operations
# For non-trimmable caches (like SWA), we can only reuse vision cache when extending
# (adding new tokens). If we're re-prompting with same/shorter prompt, the cache
# will have generated tokens that need trimming, which non-trimmable caches can't do.
if not can_trim_prompt_cache(self.cache):
num_tokens_in_cache = self._get_num_tokens_in_cache()
if num_tokens_in_cache is None:
return False # Can't determine cache state

# Only allow reuse for non-trimmable caches when extending the prompt
# If not extending (same or rewinding), cache will need trimming
if len(expanded_input_ids) <= len(self.prev_expanded_input_ids):
# Not extending - cache will need trimming (has generated tokens)
if num_tokens_in_cache > 0:
return False

return True

def record_vision_state(self, images_b64: List[str], expanded_input_ids: mx.array):
"""
Record vision processing state for future cache validation.

Args:
images_b64: Base64-encoded images that were processed
expanded_input_ids: Expanded input_ids (with image tokens expanded to pad tokens) that were used
"""
self.prev_images_hash = self._compute_images_hash(images_b64)
self.prev_expanded_input_ids = expanded_input_ids

def set_vision_tokens(self, input_ids: mx.array):
"""
Set the token state after vision processing.

This is used after vision tower processing where embeddings have been computed
and the full expanded input_ids should be tracked in the cache.

Args:
input_ids: The full expanded input_ids (with image tokens expanded to pad tokens)
"""
self.tokens = input_ids

def _reset_cache(
self, tokens: Optional[mx.array] = None, use_max_kv_size: bool = True
) -> Optional[mx.array]:
"""
Reset the cache to a fresh state.

Args:
tokens: The tokens to set (or None to clear)
use_max_kv_size: Whether to pass max_kv_size to make_prompt_cache

Returns:
The tokens that were set
"""
if use_max_kv_size:
self.cache = make_prompt_cache(self.model, self.max_kv_size)
else:
self.cache = make_prompt_cache(self.model)
self.tokens = tokens
self.clear_vision_cache()
return self.tokens

def clear_vision_cache(self):
"""Clear vision-specific cache state."""
self.prev_images_hash = None
self.prev_expanded_input_ids = None
1 change: 1 addition & 0 deletions mlx_engine/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def create_generator(
ValueError: If top_logprobs exceeds MAX_TOP_LOGPROBS or if any parameters are invalid
"""
set_seed(seed)
images_b64 = [] if images_b64 is None else images_b64

generate_args = {}
# For each call to create_generator, wrap all prompt progress calls with a ratchet that
Expand Down
86 changes: 75 additions & 11 deletions mlx_engine/model_kit/model_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from mlx_engine.model_kit.vision_add_ons.qwen2_vl import Qwen2_VLVisionAddOn
from mlx_engine.utils.kv_cache_quantization import get_kv_cache_quantization_params
from mlx_engine.utils.prompt_processing import process_prompt_text_only
from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import (
common_process_prompt_with_images,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,7 +54,6 @@ class ModelKit:
tokenizer: TokenizerWrapper = None
detokenizer: StreamingDetokenizer = None
cache_wrapper: Optional[CacheWrapper] = None
_cross_prompt_cache_active: bool = False
max_kv_size: Optional[int] = None
kv_bits: Optional[int] = None
kv_group_size: Optional[int] = None
Expand Down Expand Up @@ -103,10 +105,7 @@ def _full_model_init(
self.kv_group_size = kv_group_size
self.quantized_kv_start = quantized_kv_start
vision_add_on_class = self.VISION_ADD_ON_MAP.get(self.model_type)
should_load_vision_add_on = (
vision_add_on_class is not None and "vision_config" in config_json
)
if should_load_vision_add_on:
if vision_add_on_class and "vision_config" in config_json:
self.vision_add_on = vision_add_on_class(model_path)
logger.info("Model loaded successfully")

Expand Down Expand Up @@ -136,19 +135,45 @@ def tokenize(self, prompt: str) -> List[int]:
return [ids]
return ids

def _get_input_ids_via_prepare_inputs(
self,
prompt_tokens: mx.array,
images_b64: list[str],
max_image_size: tuple[int, int] | None,
) -> mx.array:
"""
Get input_ids with image tokens inserted (cheap operation).
Calls common_process_prompt_with_images but skips expensive vision processing.
"""
# Determine should_pad based on model type
# Qwen models need should_pad=False, others use True (default)
should_pad = self.model_type not in ["qwen2_vl", "qwen2_5_vl"]

processed = common_process_prompt_with_images(
prompt_tokens=prompt_tokens,
images_b64=images_b64,
processor=self.vision_add_on.processor,
config=self.vision_add_on.config,
max_size=max_image_size,
should_pad=should_pad,
)

# Return input_ids, squeeze batch dimension if present
input_ids = processed.input_ids
return input_ids.squeeze(0) if input_ids.ndim > 1 else input_ids

def process_prompt(
self,
prompt_tokens,
images_b64: Optional[List[str]],
images_b64: list[str],
prompt_progress_callback: Optional[Callable[[float], bool]],
generate_args: dict,
max_image_size: tuple[int, int] | None,
speculative_decoding_toggle: Optional[bool] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
### TEXT-ONLY PROCESS_PROMPT ###
is_text_only_processing = images_b64 is None or len(images_b64) == 0
is_text_only_processing = len(images_b64) == 0
if is_text_only_processing:
self._cross_prompt_cache_active = True
if len(prompt_tokens) == 0:
logger.warning(
"Received empty prompt. Generation quality will likely be poor"
Expand All @@ -163,23 +188,62 @@ def process_prompt(
speculative_decoding_toggle,
prompt_progress_callback,
), None
### WITH IMAGES PROMPT PROCESSING ###s
### WITH IMAGES PROMPT PROCESSING ###
if self.vision_add_on is None:
raise ValueError(
"Vision add-on is not loaded, but images were provided for processing"
)
self._cross_prompt_cache_active = False

# Get expanded input_ids, which add image pad tokens to the prompt
input_ids = self._get_input_ids_via_prepare_inputs(
prompt_tokens, images_b64, max_image_size
)

# Check if we can skip expensive vision processing
can_skip_vision_processing = self.cache_wrapper.can_reuse_vision_cache(
images_b64, input_ids
)

if can_skip_vision_processing:
# Skip vision tower, reuse cached KV states
unprocessed_tokens = process_prompt_text_only(
input_ids,
self.cache_wrapper,
generate_args,
draft_model=None, # Vision models don't support draft models
speculative_decoding_toggle=None,
prompt_progress_callback=prompt_progress_callback,
)

# Update vision state for next request
self.cache_wrapper.record_vision_state(images_b64, input_ids)

return unprocessed_tokens, None

# Full vision processing
input_ids, embeddings = self.vision_add_on.compute_embeddings(
self.model, prompt_tokens, images_b64, max_size=max_image_size
)

# Record vision state for future requests
self.cache_wrapper.record_vision_state(images_b64, input_ids)

# Set the tokens to the full expanded input_ids
self.cache_wrapper.set_vision_tokens(input_ids)

generate_args["prompt_cache"] = self.cache_wrapper.cache

return input_ids, embeddings

def is_cross_prompt_cache_active(self) -> bool:
"""
Check if cross-prompt caching is currently enabled.
Can be overridden by subclasses for custom behavior.

ModelKit always supports cross-prompt caching.
VisionModelKit overrides this to return False.
"""
return self._cross_prompt_cache_active
return True

def record_token_to_cache(self, token: int) -> None:
self.cache_wrapper.record_generated_token(token)
Expand Down
Loading