
ksenthilnathan02 commented on Nov 29, 2025

Purpose

This PR introduces a new example script: examples/kv_cache_memory_estimator.py

The purpose of this example is to provide users with a simple, approximate way to estimate the memory footprint of the attention key/value (KV) cache for any Hugging Face causal LM model. Many users configuring vLLM for long-context or large-batch inference have questions such as:

  • “How much GPU memory will the KV cache use for my model?”
  • “How does memory scale with sequence length or batch size?”
  • “Is my GPU large enough for this context length?”

This example fills a gap by offering a small, self-contained tool that computes KV cache memory usage directly from the model configuration (num_hidden_layers, hidden_size) using the standard KV sizing formula. It is intended as a lightweight planning and educational aid — not as a precise profiler.
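
Concretely, the estimate reduces to a single product, matching the script's total_elements/total_bytes methods (bytes_per_dtype below is just shorthand for the script's DTYPE_BYTES lookup, e.g. 2 for fp16):

    approx_bytes = batch_size * seq_len * num_layers * (2 * hidden_size) * bytes_per_dtype

Dividing by 1024**3 gives the value reported in GB.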


Test Plan

I tested the script manually using multiple Hugging Face causal LM models to ensure:

  1. The script loads the model configuration correctly via AutoConfig.from_pretrained.
  2. It extracts the expected attributes (num_hidden_layers / n_layer, hidden_size / n_embd).
  3. The approximate memory estimation formula runs without errors.
  4. The output is formatted cleanly and matches expected scaling behavior.

Commands used:

pip install transformers

python examples/kv_cache_memory_estimator.py \
    --model gpt2 \
    --seq-len 1024 \
    --batch-size 1 \
    --dtype fp16

Also tested with a larger model:

python examples/kv_cache_memory_estimator.py \
    --model meta-llama/Llama-3-8B-Instruct \
    --seq-len 4096 \
    --batch-size 1 \
    --dtype fp16

Test Result
Example output for GPT-2 test run:

Model:         gpt2
Layers:        12
Hidden size:   768
Batch size:    1
Seq length:    1024
Dtype:         fp16
-------------------------------
Approx KV cache memory: 0.04 GB 

This matches the expected order of magnitude for GPT-2's KV cache at a 1024-token context length.
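
As a rough cross-check against the formula above: 1 * 1024 * 12 * (2 * 768) = 18,874,368 elements; at 2 bytes each for fp16 that is about 0.035 GB, which rounds to the 0.04 GB shown.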

The LLaMA-3-8B test run produced a correspondingly larger KV estimate, consistent with its greater hidden size and depth.

The script runs without errors, produces stable results, and works as intended as an approximate estimator.

Essential Elements of an Effective PR Description Checklist

  • Purpose of the PR is clearly explained.
  • Test plan is described, with the commands used for verification.
  • Test results are included, with real sample outputs.
  • Optional documentation update (not required for this PR).
  • Optional release notes update (not a user-facing core change).

Signed-off-by: Keerthana Senthilnathan <keethu@mac.lan>
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀


mergify bot commented Nov 29, 2025

Documentation preview: https://vllm--29736.org.readthedocs.build/en/29736/

mergify bot added the documentation label (Improvements or additions to documentation) on Nov 29, 2025

gemini-code-assist bot left a comment


Code Review

This PR introduces a useful example script for estimating KV cache memory. However, the current implementation has a significant flaw in its calculation that leads to overestimation for models with Grouped-Query Attention (GQA) or Multi-Query Attention (MQA). I've provided a comment with a suggested fix to use the correct formula based on num_key_value_heads and head_size for a more accurate estimation.

Comment on lines +47 to +116
@dataclass
class KVEstimate:
    model_name: str
    num_layers: int
    hidden_size: int
    seq_len: int
    batch_size: int
    dtype: str

    def total_elements(self) -> int:
        # KV per token per layer = 2 * hidden_size
        return self.batch_size * self.seq_len * self.num_layers * (2 * self.hidden_size)

    def total_bytes(self) -> int:
        return self.total_elements() * DTYPE_BYTES[self.dtype]

    def total_gb(self) -> float:
        return self.total_bytes() / (1024 ** 3)

    def pretty(self) -> str:
        return (
            f"Model:         {self.model_name}\n"
            f"Layers:        {self.num_layers}\n"
            f"Hidden size:   {self.hidden_size}\n"
            f"Batch size:    {self.batch_size}\n"
            f"Seq length:    {self.seq_len}\n"
            f"Dtype:         {self.dtype}\n"
            f"-------------------------------\n"
            f"Approx KV cache memory: {self.total_gb():.2f} GB\n"
        )


def load_model_config(model_name: str):
    cfg = AutoConfig.from_pretrained(model_name)

    num_layers = getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", None))
    hidden_size = getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None))

    if num_layers is None or hidden_size is None:
        raise ValueError(
            f"Could not extract num_layers/hidden_size from config for {model_name}."
        )

    return num_layers, hidden_size


def parse_args():
    parser = argparse.ArgumentParser(description="Estimate KV cache memory usage.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--seq-len", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--dtype", type=str, default="fp16", choices=DTYPE_BYTES.keys())
    return parser.parse_args()


def main():
    args = parse_args()
    num_layers, hidden_size = load_model_config(args.model)

    est = KVEstimate(
        model_name=args.model,
        num_layers=num_layers,
        hidden_size=hidden_size,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        dtype=args.dtype,
    )

    print(est.pretty())

Severity: high

The current formula for total_elements overestimates the KV cache size for models that use Grouped-Query Attention (GQA) or Multi-Query Attention (MQA). The formula 2 * hidden_size is equivalent to 2 * num_attention_heads * head_size, but the KV cache size depends on num_key_value_heads.

The correct formula for the number of elements per token per layer is 2 * num_key_value_heads * head_size.

For models like Llama-3-8B, where num_attention_heads=32 and num_key_value_heads=8, this leads to a 4x overestimation of the KV cache memory.
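
In concrete terms, assuming Llama-3-8B's hidden_size of 4096: head_size = 4096 / 32 = 128, so the corrected per-token-per-layer count is 2 * 8 * 128 = 2048 elements, versus 2 * 4096 = 8192 with the current formula, which is exactly the 4x gap.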

To fix this, the script should be updated to extract num_key_value_heads and head_size from the model config and use them in the calculation. Here is a suggested refactoring of the KVEstimate class, load_model_config function, and main function to implement this correction.

@dataclass
class KVEstimate:
    model_name: str
    num_layers: int
    num_kv_heads: int
    head_size: int
    seq_len: int
    batch_size: int
    dtype: str

    def total_elements(self) -> int:
        # Each token has a key and a value vector for each layer.
        # For each layer, the size of the key/value cache is
        # num_kv_heads * head_size.
        # So, for each token, the total size is 2 * num_kv_heads * head_size.
        return self.batch_size * self.seq_len * self.num_layers * (2 * self.num_kv_heads * self.head_size)

    def total_bytes(self) -> int:
        return self.total_elements() * DTYPE_BYTES[self.dtype]

    def total_gb(self) -> float:
        return self.total_bytes() / (1024 ** 3)

    def pretty(self) -> str:
        return (
            f"Model:         {self.model_name}\n"
            f"Layers:        {self.num_layers}\n"
            f"KV Heads:      {self.num_kv_heads}\n"
            f"Head Size:     {self.head_size}\n"
            f"Batch size:    {self.batch_size}\n"
            f"Seq length:    {self.seq_len}\n"
            f"Dtype:         {self.dtype}\n"
            f"-------------------------------\n"
            f"Approx KV cache memory: {self.total_gb():.2f} GB\n"
        )


def load_model_config(model_name: str):
    cfg = AutoConfig.from_pretrained(model_name)

    num_layers = getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", None))
    hidden_size = getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None))
    num_attention_heads = getattr(cfg, "num_attention_heads", getattr(cfg, "n_head", None))

    if num_layers is None or hidden_size is None or num_attention_heads is None:
        raise ValueError(
            f"Could not extract num_layers/hidden_size/num_attention_heads from config for {model_name}."
        )

    num_key_value_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)
    head_size = hidden_size // num_attention_heads

    return num_layers, num_key_value_heads, head_size


def parse_args():
    parser = argparse.ArgumentParser(description="Estimate KV cache memory usage.")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--seq-len", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--dtype", type=str, default="fp16", choices=DTYPE_BYTES.keys())
    return parser.parse_args()


def main():
    args = parse_args()
    num_layers, num_kv_heads, head_size = load_model_config(args.model)

    est = KVEstimate(
        model_name=args.model,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        dtype=args.dtype,
    )

    print(est.pretty())
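
For illustration, here is a minimal sketch of how the refactored class could be sanity-checked directly, assuming Llama-3-8B-style config values (32 layers, 8 KV heads, head_size 128) rather than loading the real config:

est = KVEstimate(
    model_name="meta-llama/Llama-3-8B-Instruct",
    num_layers=32,    # assumed Llama-3-8B depth
    num_kv_heads=8,   # num_key_value_heads from the config
    head_size=128,    # hidden_size 4096 / 32 attention heads
    seq_len=4096,
    batch_size=1,
    dtype="fp16",
)
print(est.pretty())
# total elements = 1 * 4096 * 32 * (2 * 8 * 128) = 268,435,456
# at 2 bytes each, roughly 0.50 GB, versus about 2.0 GB with the
# original 2 * hidden_size formula (the 4x difference noted above)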
