Summary
When serving datalab-to/chandra-ocr-2 on vLLM with --enable-prefix-caching, the prefix cache hit rate stays at ~0% even across thousands of OCR requests that share an identical ~500-token ocr_layout prompt. This leaves measurable prefill compute on the table.
Observation
Steady-state vLLM engine log on an RTX 3090 running 8 concurrent requests:
```
Engine 000: Avg prompt throughput: 673 tokens/s, Avg generation throughput: 410 tokens/s,
Running: 8, Waiting: 1, GPU KV cache usage: 33%,
Prefix cache hit rate: 0.0%, MM cache hit rate: 14.9%
```
Root cause
In chandra/model/vllm.py (generate_vllm), the message content is built as [image_url, text]. The chat_template.jinja shipped with Chandra 2 iterates content in list order, so every request renders as:
```
<|im_start|>user\n<|vision_start|>[per-page image]<|vision_end|>{ocr_layout prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n
```
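For reference, the per-page request body looks roughly like this (a paraphrase, not the exact code in chandra/model/vllm.py; variable names are placeholders):

```python
# Paraphrase of what generate_vllm sends per page; the exact code in
# chandra/model/vllm.py may differ. The key point is the content ordering.
ocr_layout_prompt = "...~500-token ocr_layout prompt, identical for every page..."
page_image_data_url = "data:image/png;base64,..."  # unique per page

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": page_image_data_url}},  # image first
            {"type": "text", "text": ocr_layout_prompt},                       # shared text last
        ],
    }
]
```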
vLLM's block-level prefix cache hashes KV blocks and, for any block touching MM tokens, mixes in the per-image mm_hash as an extra key (vllm/v1/core/kv_cache_utils.py::_gen_mm_extra_hash_keys). Because image tokens come first, every KV block after the tiny <|im_start|>user\n preamble inherits a unique per-page hash — the identical ocr_layout prompt that follows can never form a reusable prefix.
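In other words, the cache key for each KV block behaves roughly like this (a simplified sketch, not vLLM's actual code):

```python
# Simplified model of vLLM v1 block hashing (see vllm/v1/core/kv_cache_utils.py
# for the real implementation). Blocks that overlap image placeholder tokens get
# the image's mm_hash folded into their key.
def block_key(parent_key, block_token_ids, mm_hashes_touching_block):
    return hash((parent_key, tuple(block_token_ids), tuple(mm_hashes_touching_block)))

# With image-first content, the first full block already contains image tokens,
# so its key (and every later block's key, via parent_key chaining) depends on
# the per-page mm_hash and never matches blocks cached from another page.
```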
If content were [text, image] instead, the prompt tokens would form a shared prefix across all pages and prefix caching would hit.
Estimated impact
The ocr_layout prompt is ~500 tokens; a typical request is ~7,200 tokens (image dominates at ~6,700). Enabling prefix reuse on the prompt would save ~7% of prefill compute per page — modest but free, and it compounds at scale (extracting hundreds of thousands of scanned pages).
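The ~7% figure is just the prompt's share of prefill tokens, using approximate averages from my workload:

```python
prompt_tokens = 500      # shared ocr_layout prompt
prefill_tokens = 7_200   # image (~6_700) + prompt, per page
print(f"{prompt_tokens / prefill_tokens:.1%}")  # ~6.9% of prefill skipped on a cache hit
```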
Question / tradeoff
Before opening a PR, I'd like to confirm:
- Was Chandra 2 trained with [image, text] or [text, image] content ordering? A naive swap at inference could cause distribution shift on a model trained image-first, even though the tokens cover the same content.
- If training was [text, image], is image-first at inference an oversight that could be fixed directly in generate_vllm?
- If training was [image, text], would you accept a PR adding an opt-in flag (e.g. CHANDRA_TEXT_FIRST=1) so users can trade a small quality risk for ~7% throughput on vLLM, or is this something the model would need to be re-trained to support?
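For concreteness, the opt-in gate in generate_vllm could be as small as the sketch below (CHANDRA_TEXT_FIRST and the variable names are placeholders for illustration, not existing Chandra options):

```python
import os

# Hypothetical opt-in flag; not an existing Chandra option.
text_first = os.environ.get("CHANDRA_TEXT_FIRST") == "1"

ocr_layout_prompt = "...shared ocr_layout prompt..."
page_image_data_url = "data:image/png;base64,..."  # unique per page

image_part = {"type": "image_url", "image_url": {"url": page_image_data_url}}
text_part = {"type": "text", "text": ocr_layout_prompt}

# Default keeps today's image-first order; the flag swaps to text-first so the
# shared ocr_layout prompt can be reused by vLLM's prefix cache.
content = [text_part, image_part] if text_first else [image_part, text_part]
messages = [{"role": "user", "content": content}]
```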
Happy to submit a PR once the direction is clear.
Env
- vllm/vllm-openai:v0.17.0
- chandra-ocr 0.2.0 (library)
- Model: datalab-to/chandra-ocr-2
- Flags: --enable-prefix-caching --max-model-len 30720 --max-num-seqs 8 --max-num-batched-tokens 2048 --mm-processor-kwargs '{"min_pixels": 3136, "max_pixels": 6291456}'