Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/flash_head/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

"""FlashHead package version. Bump this to trigger a PyPI release."""

__version__ = "0.1.9"
__version__ = "0.1.10"
5 changes: 4 additions & 1 deletion src/flash_head/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
3. RejectionSampler.forward - greedy comparison for speculative decoding
4. EagleProposer._greedy_sample - handle FlashHead in draft proposals
5. GPUModelRunner._dummy_sampler_run - skip warmup for token IDs
6. LLMEngine.from_engine_args - load FlashHead metadata before engine init
6. LLMEngine.from_engine_args - load FlashHead metadata (Python LLM API)
7. AsyncLLM.__init__ - load FlashHead metadata (vllm serve / OpenAI API)
"""

import logging
Expand All @@ -28,12 +29,14 @@ def apply_all():
from flash_head.patches.eagle import patch_eagle
from flash_head.patches.gpu_model_runner import patch_gpu_model_runner
from flash_head.patches.llm import patch_llm
from flash_head.patches.async_llm import patch_async_llm

patch_logits_processor()
patch_sampler()
patch_rejection_sampler()
patch_eagle()
patch_gpu_model_runner()
patch_llm()
patch_async_llm()

logger.info("[FlashHead] All patches applied")
59 changes: 59 additions & 0 deletions src/flash_head/patches/async_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (C) 2026 Embedl AB

"""Patch AsyncLLM to load FlashHead metadata.

The existing `patch_llm()` only covers `vllm.v1.engine.llm_engine.LLMEngine.
from_engine_args`, which is the entry point used by the offline Python
`LLM(...)` API. The `vllm serve` CLI in vLLM >= 0.19 goes through
`vllm.entrypoints.openai.api_server.build_async_engine_client_from_engine_args`
→ `AsyncLLM.from_vllm_config` → `AsyncLLM.__init__`, which never calls
`LLMEngine.from_engine_args`. Without this patch, metadata never gets
written for `vllm serve` and FlashHead silently falls back to the dense
lm_head on every decode step.

We patch `AsyncLLM.__init__` so the metadata-load happens regardless of
which constructor is used (`from_engine_args`, `from_vllm_config`, or the
raw `AsyncLLM(vllm_config=...)` form).
"""

import logging

logger = logging.getLogger(__name__)


def _model_id_from_vllm_config(vllm_config) -> str | None:
"""Pull the HF model id / local path out of a VllmConfig."""
try:
return vllm_config.model_config.model
except AttributeError:
return None


def patch_async_llm():
    """Monkey-patch ``AsyncLLM.__init__`` to load FlashHead metadata.

    Covers the ``vllm serve`` / OpenAI API entry point, which constructs the
    engine via ``AsyncLLM.from_vllm_config`` and never goes through
    ``LLMEngine.from_engine_args`` (handled separately by ``patch_llm``).

    Idempotent: if ``AsyncLLM.__init__`` is already our wrapper, this is a
    no-op, so repeated calls never stack wrappers or run the metadata load
    more than once per engine construction.
    """
    try:
        from vllm.v1.engine.async_llm import AsyncLLM
    except Exception as e:
        # vLLM missing or too old to expose AsyncLLM — nothing to patch.
        logger.debug("[FlashHead] AsyncLLM not available, skipping patch: %s", e)
        return

    # Idempotence guard: apply_all() (or other vLLM code paths) may call us
    # more than once; without this, each call would wrap the previous wrapper.
    if getattr(AsyncLLM.__init__, "_flash_head_patched", False):
        logger.debug("[FlashHead] AsyncLLM.__init__ already patched, skipping")
        return

    _original_init = AsyncLLM.__init__

    def _patched_init(self, vllm_config, *args, **kwargs):
        # Load FlashHead metadata for this model before the engine starts;
        # failures are non-fatal and fall back to the dense lm_head.
        model = _model_id_from_vllm_config(vllm_config)
        if model is not None:
            try:
                from flash_head.loading import (
                    load_flash_head_from_checkpoint,
                    set_flash_head,
                )
                flash_head_meta = load_flash_head_from_checkpoint(model)
                set_flash_head(flash_head_meta)
                if flash_head_meta:
                    logger.info("[FlashHead] Metadata saved for model: %s", model)
            except Exception as e:
                logger.warning("[FlashHead] Metadata loading skipped: %s", e)

        return _original_init(self, vllm_config, *args, **kwargs)

    # Marker checked by the guard above.
    _patched_init._flash_head_patched = True
    AsyncLLM.__init__ = _patched_init
    logger.info("[FlashHead] Patched AsyncLLM.__init__")
13 changes: 9 additions & 4 deletions src/flash_head/patches/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@

logger = logging.getLogger(__name__)

# Sentinel for lazy loading
_FLASH_HEAD_NOT_LOADED = object()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed since get_flash_head() may be None (e.g., when running non-FlashHead models).

_flash_head = _FLASH_HEAD_NOT_LOADED
_flash_head = None


def _get_flash_head():
    """Return the FlashHead module, lazy-loading on first successful call.

    Negative results are NOT cached: if metadata is not yet available we keep
    returning None and recheck on the next call. This makes it safe to write
    `/tmp/flashhead_metadata.pt` after server startup (e.g. when the engine
    was constructed through a vLLM entry point we don't patch directly).
    """
    global _flash_head
    # Only a successful (non-None) load is memoized; while it is None we
    # re-probe the loader on every call.
    if _flash_head is None:
        from flash_head.loading import get_flash_head
        _flash_head = get_flash_head()
    return _flash_head
Expand Down