Add KV Cache Memory Estimator Example Script #29736
Conversation
Signed-off-by: Keerthana Senthilnathan <keethu@mac.lan>
Documentation preview: https://vllm--29736.org.readthedocs.build/en/29736/
Code Review
This PR introduces a useful example script for estimating KV cache memory. However, the current implementation has a significant flaw in its calculation that leads to overestimation for models with Grouped-Query Attention (GQA) or Multi-Query Attention (MQA). I've provided a comment with a suggested fix to use the correct formula based on num_key_value_heads and head_size for a more accurate estimation.
```python
# NOTE: the imports, the DTYPE_BYTES table, and the __main__ guard are
# reconstructed for completeness; the rendered diff hunk omits the top and
# bottom of the file, so the exact values of DTYPE_BYTES are assumed.
import argparse
from dataclasses import dataclass

from transformers import AutoConfig

DTYPE_BYTES = {"fp32": 4, "fp16": 2, "bf16": 2, "fp8": 1}


@dataclass
class KVEstimate:
    model_name: str
    num_layers: int
    hidden_size: int
    seq_len: int
    batch_size: int
    dtype: str

    def total_elements(self) -> int:
        # KV per token per layer = 2 * hidden_size
        return self.batch_size * self.seq_len * self.num_layers * (2 * self.hidden_size)

    def total_bytes(self) -> int:
        return self.total_elements() * DTYPE_BYTES[self.dtype]

    def total_gb(self) -> float:
        return self.total_bytes() / (1024 ** 3)

    def pretty(self) -> str:
        return (
            f"Model: {self.model_name}\n"
            f"Layers: {self.num_layers}\n"
            f"Hidden size: {self.hidden_size}\n"
            f"Batch size: {self.batch_size}\n"
            f"Seq length: {self.seq_len}\n"
            f"Dtype: {self.dtype}\n"
            f"-------------------------------\n"
            f"Approx KV cache memory: {self.total_gb():.2f} GB\n"
        )


def load_model_config(model_name: str):
    cfg = AutoConfig.from_pretrained(model_name)

    num_layers = getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", None))
    hidden_size = getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None))

    if num_layers is None or hidden_size is None:
        raise ValueError(
            f"Could not extract num_layers/hidden_size from config for {model_name}."
        )

    return num_layers, hidden_size


def parse_args():
    parser = argparse.ArgumentParser(description="Estimate KV cache memory usage.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--seq-len", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--dtype", type=str, default="fp16", choices=DTYPE_BYTES.keys())
    return parser.parse_args()


def main():
    args = parse_args()
    num_layers, hidden_size = load_model_config(args.model)

    est = KVEstimate(
        model_name=args.model,
        num_layers=num_layers,
        hidden_size=hidden_size,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        dtype=args.dtype,
    )

    print(est.pretty())


if __name__ == "__main__":
    main()
```
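As a quick illustration, the dataclass can also be used directly, bypassing the CLI. The GPT-2 config values below (12 layers, hidden size 768) are hard-coded by hand for the sketch rather than fetched from the hub:

```python
est = KVEstimate(
    model_name="gpt2", num_layers=12, hidden_size=768,
    seq_len=1024, batch_size=1, dtype="fp16",
)
print(f"{est.total_gb():.3f} GB")  # ~0.035 GB, which rounds to the 0.04 GB reported below
```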
The current formula for total_elements overestimates the KV cache size for models that use Grouped-Query Attention (GQA) or Multi-Query Attention (MQA). The formula 2 * hidden_size is equivalent to 2 * num_attention_heads * head_size, but the KV cache size depends on num_key_value_heads.
The correct formula for the number of elements per token per layer is 2 * num_key_value_heads * head_size.
For models like Llama-3-8B, where num_attention_heads=32 and num_key_value_heads=8, this leads to a 4x overestimation of the KV cache memory.
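As a quick sanity check of that claim, here is the per-token, per-layer element count under each formula; the config values are the commonly published Llama-3-8B numbers, hard-coded for illustration:

```python
# Assumed Llama-3-8B-style config: hidden_size=4096,
# num_attention_heads=32, num_key_value_heads=8.
hidden_size, n_heads, kv_heads = 4096, 32, 8
head_size = hidden_size // n_heads     # 128
naive = 2 * hidden_size                # 8192 elements (formula in this PR)
gqa_aware = 2 * kv_heads * head_size   # 2048 elements (corrected formula)
print(naive // gqa_aware)              # 4 -> the 4x overestimation
```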
To fix this, the script should be updated to extract num_key_value_heads and head_size from the model config and use them in the calculation. Here is a suggested refactoring of the KVEstimate class, load_model_config function, and main function to implement this correction.
```python
@dataclass
class KVEstimate:
    model_name: str
    num_layers: int
    num_kv_heads: int
    head_size: int
    seq_len: int
    batch_size: int
    dtype: str

    def total_elements(self) -> int:
        # Each token has a key and a value vector for each layer.
        # For each layer, the size of the key/value cache is
        # num_kv_heads * head_size, so the total per token is
        # 2 * num_kv_heads * head_size.
        return self.batch_size * self.seq_len * self.num_layers * (
            2 * self.num_kv_heads * self.head_size
        )

    def total_bytes(self) -> int:
        return self.total_elements() * DTYPE_BYTES[self.dtype]

    def total_gb(self) -> float:
        return self.total_bytes() / (1024 ** 3)

    def pretty(self) -> str:
        return (
            f"Model: {self.model_name}\n"
            f"Layers: {self.num_layers}\n"
            f"KV Heads: {self.num_kv_heads}\n"
            f"Head Size: {self.head_size}\n"
            f"Batch size: {self.batch_size}\n"
            f"Seq length: {self.seq_len}\n"
            f"Dtype: {self.dtype}\n"
            f"-------------------------------\n"
            f"Approx KV cache memory: {self.total_gb():.2f} GB\n"
        )


def load_model_config(model_name: str):
    cfg = AutoConfig.from_pretrained(model_name)

    num_layers = getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", None))
    hidden_size = getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None))
    num_attention_heads = getattr(cfg, "num_attention_heads", getattr(cfg, "n_head", None))

    if num_layers is None or hidden_size is None or num_attention_heads is None:
        raise ValueError(
            f"Could not extract num_layers/hidden_size/num_attention_heads "
            f"from config for {model_name}."
        )

    # Models without a num_key_value_heads field use standard multi-head
    # attention, where the number of KV heads equals the number of attention
    # heads, so the corrected formula reduces to 2 * hidden_size for them.
    num_key_value_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)
    head_size = hidden_size // num_attention_heads

    return num_layers, num_key_value_heads, head_size


def parse_args():
    parser = argparse.ArgumentParser(description="Estimate KV cache memory usage.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--seq-len", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--dtype", type=str, default="fp16", choices=DTYPE_BYTES.keys())
    return parser.parse_args()


def main():
    args = parse_args()
    num_layers, num_kv_heads, head_size = load_model_config(args.model)

    est = KVEstimate(
        model_name=args.model,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        dtype=args.dtype,
    )

    print(est.pretty())
```
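A usage sketch of the refactored class, with the Llama-3-8B values (32 layers, 8 KV heads, head size 128) again filled in by hand for illustration rather than read from the hub:

```python
est = KVEstimate(
    model_name="meta-llama/Llama-3-8B-Instruct",
    num_layers=32, num_kv_heads=8, head_size=128,
    seq_len=4096, batch_size=1, dtype="fp16",
)
print(f"{est.total_gb():.2f} GB")  # ~0.50 GB; the original formula would report ~2.00 GB
```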
Purpose
This PR introduces a new example script: examples/kv_cache_memory_estimator.py
The purpose of this example is to provide users with a simple, approximate way to estimate the memory footprint of the attention key/value (KV) cache for any Hugging Face causal LM. Users configuring vLLM for long-context or large-batch inference often want to know roughly how much GPU memory the KV cache will consume for a given model, sequence length, and batch size.
This example fills that gap by offering a small, self-contained tool that computes KV cache memory usage directly from the model configuration (num_hidden_layers, hidden_size) using the standard KV sizing formula. It is intended as a lightweight planning and educational aid, not as a precise profiler.
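For reference, the sizing rule the script applies reduces to a single expression. The helper below is a restatement for clarity and is not part of the PR itself:

```python
def kv_cache_bytes(batch_size: int, seq_len: int, num_layers: int,
                   hidden_size: int, dtype_bytes: int) -> int:
    # One key and one value vector of hidden_size elements per token, per layer.
    return batch_size * seq_len * num_layers * (2 * hidden_size) * dtype_bytes
```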
Test Plan
I tested the script manually with multiple Hugging Face causal LM models to ensure:

- The model config loads correctly via `AutoConfig.from_pretrained`.
- Layer count and hidden size are extracted under either naming convention (`num_hidden_layers`/`n_layer`, `hidden_size`/`n_embd`).

Commands used:
```
pip install transformers

python examples/kv_cache_memory_estimator.py \
    --model gpt2 \
    --seq-len 1024 \
    --batch-size 1 \
    --dtype fp16
```

Also tested with a larger model:

```
python examples/kv_cache_memory_estimator.py \
    --model meta-llama/Llama-3-8B-Instruct \
    --seq-len 4096 \
    --batch-size 1 \
    --dtype fp16
```

Test Result
Example output for the GPT-2 run:

```
Model: gpt2
Layers: 12
Hidden size: 768
Batch size: 1
Seq length: 1024
Dtype: fp16
-------------------------------
Approx KV cache memory: 0.04 GB
```

This matches the expected order of magnitude for GPT-2's KV cache at a context length of 1024.
The LLaMA-3-8B test run produced a larger KV estimate, consistent with its greater hidden size and depth.
The script runs without errors, produces stable results, and works as intended as an approximate estimator.
Essential Elements of an Effective PR Description Checklist
Purpose of the PR is clearly explained.
Test plan is described with commands used for verification.
Test results included with real sample outputs.
Optional documentation update (not required for this PR).
Optional release notes update (not user-facing core change).