diff --git a/src/flash_head/_version.py b/src/flash_head/_version.py
index e58734c..497880d 100644
--- a/src/flash_head/_version.py
+++ b/src/flash_head/_version.py
@@ -2,4 +2,4 @@
 
 """FlashHead package version. Bump this to trigger a PyPI release."""
 
-__version__ = "0.1.9"
+__version__ = "0.1.10"
diff --git a/src/flash_head/patches/__init__.py b/src/flash_head/patches/__init__.py
index e76cb39..fbac341 100644
--- a/src/flash_head/patches/__init__.py
+++ b/src/flash_head/patches/__init__.py
@@ -12,7 +12,8 @@
 3. RejectionSampler.forward - greedy comparison for speculative decoding
 4. EagleProposer._greedy_sample - handle FlashHead in draft proposals
 5. GPUModelRunner._dummy_sampler_run - skip warmup for token IDs
-6. LLMEngine.from_engine_args - load FlashHead metadata before engine init
+6. LLMEngine.from_engine_args - load FlashHead metadata (Python LLM API)
+7. AsyncLLM.__init__ - load FlashHead metadata (vllm serve / OpenAI API)
 """
 
 import logging
@@ -28,6 +29,7 @@ def apply_all():
     from flash_head.patches.eagle import patch_eagle
     from flash_head.patches.gpu_model_runner import patch_gpu_model_runner
     from flash_head.patches.llm import patch_llm
+    from flash_head.patches.async_llm import patch_async_llm
 
     patch_logits_processor()
     patch_sampler()
@@ -35,5 +37,6 @@
     patch_eagle()
     patch_gpu_model_runner()
     patch_llm()
+    patch_async_llm()
 
     logger.info("[FlashHead] All patches applied")
diff --git a/src/flash_head/patches/async_llm.py b/src/flash_head/patches/async_llm.py
new file mode 100644
index 0000000..0e7469a
--- /dev/null
+++ b/src/flash_head/patches/async_llm.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2026 Embedl AB
+
+"""Patch AsyncLLM to load FlashHead metadata.
+
+The existing `patch_llm()` only covers `vllm.v1.engine.llm_engine.LLMEngine.
+from_engine_args`, which is the entry point used by the offline Python
+`LLM(...)` API. The `vllm serve` CLI in vLLM >= 0.19 goes through
+`vllm.entrypoints.openai.api_server.build_async_engine_client_from_engine_args`
+→ `AsyncLLM.from_vllm_config` → `AsyncLLM.__init__`, which never calls
+`LLMEngine.from_engine_args`. Without this patch, metadata never gets
+written for `vllm serve` and FlashHead silently falls back to the dense
+lm_head on every decode step.
+
+We patch `AsyncLLM.__init__` so the metadata-load happens regardless of
+which constructor is used (`from_engine_args`, `from_vllm_config`, or the
+raw `AsyncLLM(vllm_config=...)` form).
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _model_id_from_vllm_config(vllm_config) -> str | None:
+    """Pull the HF model id / local path out of a VllmConfig."""
+    try:
+        return vllm_config.model_config.model
+    except AttributeError:
+        return None
+
+
+def patch_async_llm():
+    try:
+        from vllm.v1.engine.async_llm import AsyncLLM
+    except Exception as e:
+        logger.debug("[FlashHead] AsyncLLM not available, skipping patch: %s", e)
+        return
+
+    _original_init = AsyncLLM.__init__
+
+    def _patched_init(self, vllm_config, *args, **kwargs):
+        model = _model_id_from_vllm_config(vllm_config)
+        if model is not None:
+            try:
+                from flash_head.loading import (
+                    load_flash_head_from_checkpoint,
+                    set_flash_head,
+                )
+                flash_head_meta = load_flash_head_from_checkpoint(model)
+                set_flash_head(flash_head_meta)
+                if flash_head_meta:
+                    logger.info("[FlashHead] Metadata saved for model: %s", model)
+            except Exception as e:
+                logger.warning("[FlashHead] Metadata loading skipped: %s", e)
+
+        return _original_init(self, vllm_config, *args, **kwargs)
+
+    AsyncLLM.__init__ = _patched_init
+    logger.info("[FlashHead] Patched AsyncLLM.__init__")
diff --git a/src/flash_head/patches/logits_processor.py b/src/flash_head/patches/logits_processor.py
index 811fd0e..adfd979 100644
--- a/src/flash_head/patches/logits_processor.py
+++ b/src/flash_head/patches/logits_processor.py
@@ -8,14 +8,19 @@
 
 logger = logging.getLogger(__name__)
 
-# Sentinel for lazy loading
-_FLASH_HEAD_NOT_LOADED = object()
-_flash_head = _FLASH_HEAD_NOT_LOADED
+_flash_head = None
 
 
 def _get_flash_head():
+    """Return the FlashHead module, lazy-loading on first successful call.
+
+    Negative results are NOT cached: if metadata is not yet available we keep
+    returning None and recheck on the next call. This makes it safe to write
+    `/tmp/flashhead_metadata.pt` after server startup (e.g. when the engine
+    was constructed through a vLLM entry point we don't patch directly).
+    """
     global _flash_head
-    if _flash_head is _FLASH_HEAD_NOT_LOADED:
+    if _flash_head is None:
         from flash_head.loading import get_flash_head
         _flash_head = get_flash_head()
     return _flash_head