-
Notifications
You must be signed in to change notification settings - Fork 1
Activate FlashHead under vllm serve #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
WilhelmTr
wants to merge
1
commit into
master
Choose a base branch
from
fix/vllm-serve-activation
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+73
−6
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # Copyright (C) 2026 Embedl AB | ||
|
|
||
| """Patch AsyncLLM to load FlashHead metadata. | ||
|
|
||
| The existing `patch_llm()` only covers `vllm.v1.engine.llm_engine.LLMEngine. | ||
| from_engine_args`, which is the entry point used by the offline Python | ||
| `LLM(...)` API. The `vllm serve` CLI in vLLM >= 0.19 goes through | ||
| `vllm.entrypoints.openai.api_server.build_async_engine_client_from_engine_args` | ||
| → `AsyncLLM.from_vllm_config` → `AsyncLLM.__init__`, which never calls | ||
| `LLMEngine.from_engine_args`. Without this patch, metadata never gets | ||
| written for `vllm serve` and FlashHead silently falls back to the dense | ||
| lm_head on every decode step. | ||
|
|
||
| We patch `AsyncLLM.__init__` so the metadata-load happens regardless of | ||
| which constructor is used (`from_engine_args`, `from_vllm_config`, or the | ||
| raw `AsyncLLM(vllm_config=...)` form). | ||
| """ | ||
|
|
||
| import logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _model_id_from_vllm_config(vllm_config) -> str | None: | ||
| """Pull the HF model id / local path out of a VllmConfig.""" | ||
| try: | ||
| return vllm_config.model_config.model | ||
| except AttributeError: | ||
| return None | ||
|
|
||
|
|
||
| def patch_async_llm(): | ||
| try: | ||
| from vllm.v1.engine.async_llm import AsyncLLM | ||
| except Exception as e: | ||
| logger.debug("[FlashHead] AsyncLLM not available, skipping patch: %s", e) | ||
| return | ||
|
|
||
| _original_init = AsyncLLM.__init__ | ||
|
|
||
| def _patched_init(self, vllm_config, *args, **kwargs): | ||
| model = _model_id_from_vllm_config(vllm_config) | ||
| if model is not None: | ||
| try: | ||
| from flash_head.loading import ( | ||
| load_flash_head_from_checkpoint, | ||
| set_flash_head, | ||
| ) | ||
| flash_head_meta = load_flash_head_from_checkpoint(model) | ||
| set_flash_head(flash_head_meta) | ||
| if flash_head_meta: | ||
| logger.info("[FlashHead] Metadata saved for model: %s", model) | ||
| except Exception as e: | ||
| logger.warning("[FlashHead] Metadata loading skipped: %s", e) | ||
|
|
||
| return _original_init(self, vllm_config, *args, **kwargs) | ||
|
|
||
| AsyncLLM.__init__ = _patched_init | ||
| logger.info("[FlashHead] Patched AsyncLLM.__init__") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,14 +8,19 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # Sentinel for lazy loading | ||
| _FLASH_HEAD_NOT_LOADED = object() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is needed since get_flash_head() may be None (e.g., when running non-FlashHead models). |
||
| _flash_head = _FLASH_HEAD_NOT_LOADED | ||
| _flash_head = None | ||
|
|
||
|
|
||
| def _get_flash_head(): | ||
| """Return the FlashHead module, lazy-loading on first successful call. | ||
|
|
||
| Negative results are NOT cached: if metadata is not yet available we keep | ||
| returning None and recheck on the next call. This makes it safe to write | ||
| `/tmp/flashhead_metadata.pt` after server startup (e.g. when the engine | ||
| was constructed through a vLLM entry point we don't patch directly). | ||
| """ | ||
| global _flash_head | ||
| if _flash_head is _FLASH_HEAD_NOT_LOADED: | ||
| if _flash_head is None: | ||
| from flash_head.loading import get_flash_head | ||
| _flash_head = get_flash_head() | ||
| return _flash_head | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should add a guard for idempotence (similar to what we do in logits_processor.py) [if _flash_head is None:...]
While AsyncLLm is run only once per engine construction (not per decode / request) there may be other parts of vllm that call it. We could add a _FLASH_HEAD_NOT_LOADED.