Fix numerical issue on hybrid kv cache allocation #1139
Conversation
This PR doesn't have any tests. Please add the following tests:
- e2e correctness test: output with and without hybrid allocation is the same
- e2e performance test: performance with the hybrid allocator is higher than without it
- unit tests for the changed Python files and the runner. We need to keep coverage above 70%, and our PRs need to come with enough tests.
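The suggested e2e correctness check could be sketched roughly as below; the helper and the dummy outputs are hypothetical, not part of the repo's test suite. The idea is to generate the same prompts with the hybrid kv cache allocator enabled and disabled, then assert the outputs are identical token-for-token.

```python
# Hedged sketch of the suggested e2e correctness test. outputs_match() is a
# hypothetical helper; the real test would feed identical prompts to the
# engine twice (hybrid allocation on, then off) and compare the generations.

def outputs_match(baseline: list[str], hybrid: list[str]) -> bool:
    """True iff every generated sequence is identical across both runs."""
    return len(baseline) == len(hybrid) and all(
        a == b for a, b in zip(baseline, hybrid)
    )

# Dummy strings stand in for real engine generations here:
same = outputs_match(["The sky is blue.", "42"], ["The sky is blue.", "42"])
diff = outputs_match(["The sky is blue."], ["The sky is green."])
```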
py4 left a comment
Does this also work for the JAX path? If not, can we also make the JAX path work?
It should be backend agnostic, but to enable it in JAX we need to modify the individual JAX models. Previously, no JAX model needed hybrid kv cache, so it wasn't enabled. The numerical issue was also reported using a vLLM model rather than flax nnx.
Signed-off-by: Chenyaaang <chenyangli@google.com>
Force-pushed from 8f5b161 to a1d07b7
With this PR, I've verified on gpt-oss that the numerical issue has been solved, and a performance issue that stemmed from the numerical issue has also been resolved.
Description
Fix a numerical issue in hybrid kv cache allocation. When hybrid kv cache is enabled, the block_id differs between kv cache groups at each allocation round, meaning different layers write to different block_ids. We therefore need to create individual attention metadata for each layer, instead of sharing the same attention metadata across every layer.
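The per-layer metadata construction described above can be sketched roughly as follows. All names here (AttentionMetadata, build_per_layer_metadata, the group/layer mappings) are hypothetical illustrations, not the actual vLLM or runner API: each kv cache group owns its own block_ids in a given allocation round, so each layer's attention metadata must be built from that layer's group rather than shared globally.

```python
# Hedged sketch, assuming hypothetical names: with hybrid kv cache, layers in
# different kv cache groups (e.g. full attention vs sliding window) receive
# different block_ids per allocation round, so attention metadata is per-layer.

from dataclasses import dataclass

@dataclass
class AttentionMetadata:
    block_ids: list[int]  # blocks this layer's kv cache writes to this round

def build_per_layer_metadata(
    kv_cache_groups: dict[str, list[int]],  # group name -> block_ids this round
    layer_to_group: dict[str, str],         # layer name -> its kv cache group
) -> dict[str, AttentionMetadata]:
    """One AttentionMetadata per layer, using that layer's own group blocks."""
    return {
        layer: AttentionMetadata(block_ids=kv_cache_groups[group])
        for layer, group in layer_to_group.items()
    }

# Example: full-attention and sliding-window layers live in different groups,
# so in the same round they get different block_ids.
groups = {"full_attn": [0, 1, 2], "sliding_window": [7, 8]}
layers = {"layer_0": "full_attn", "layer_1": "sliding_window"}
meta = build_per_layer_metadata(groups, layers)
```

Sharing one metadata object across layers would make the sliding-window layers write to the full-attention group's blocks, which matches the corruption symptom this PR fixes.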
Tests
python examples/offline_inference.py --model google/gemma-3-27b-it --tensor-parallel-size 8
Checklist
Before submitting this PR, please make sure: