Skip to content

Conversation

@i3hz
Copy link
Contributor

@i3hz i3hz commented Nov 28, 2025

What does this PR do?

Fixes #42454

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@zucchini-nlp @Rocketknight1 @mobicham

Benchmarking script -

import torch
import time
import numpy as np
from transformers import WhisperForConditionalGeneration
from transformers.cache_utils import StaticCache, EncoderDecoderCache


MODEL_ID = "openai/whisper-tiny"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MAX_BATCH_SIZE = 64
TEST_BATCHES = [1, 8, 32, 64]

SEQ_LEN = 128
WARMUP = 10
REPEATS = 50


def load_model():
    model = WhisperForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        attn_implementation="sdpa",
    ).to(DEVICE)
    model.eval()
    return model


def run_benchmark(model, batch_size, cache_cap, tag):
    decoder = model.model.decoder
    cache_len = SEQ_LEN + 10
    enc_len = SEQ_LEN

    self_cache = StaticCache(
        config=decoder.config,
        max_batch_size=cache_cap,
        max_cache_len=cache_len,
        device=DEVICE,
        dtype=DTYPE,
    )
    cross_cache = StaticCache(
        config=decoder.config,
        max_batch_size=cache_cap,
        max_cache_len=enc_len,
        device=DEVICE,
        dtype=DTYPE,
    )
    kv_cache = EncoderDecoderCache(self_cache, cross_cache)

    input_ids = torch.randint(0, 1000, (batch_size, SEQ_LEN), device=DEVICE)
    encoder_states = torch.randn(batch_size, enc_len, model.config.d_model, device=DEVICE, dtype=DTYPE)
    cache_pos = torch.arange(SEQ_LEN, device=DEVICE)

    for _ in range(WARMUP):
        kv_cache.reset()
        with torch.no_grad():
            decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_states,
                past_key_values=kv_cache,
                cache_position=cache_pos,
                use_cache=True,
            )

    # actual runs
    start = [torch.cuda.Event(enable_timing=True) for _ in range(REPEATS)]
    end = [torch.cuda.Event(enable_timing=True) for _ in range(REPEATS)]

    torch.cuda.synchronize()

    for i in range(REPEATS):
        kv_cache.reset()
        start[i].record()

        with torch.no_grad():
            decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_states,
                past_key_values=kv_cache,
                cache_position=cache_pos,
                use_cache=True,
            )

        end[i].record()

    torch.cuda.synchronize()
    times = [s.elapsed_time(e) for s, e in zip(start, end)]
    avg = float(np.mean(times))

    print(f"[{tag}]  batch={batch_size:2d}  cache_cap={cache_cap:2d}  latency={avg:.2f} ms")
    return avg



model = load_model()

results = []

for bs in TEST_BATCHES:
    base = run_benchmark(model, batch_size=bs, cache_cap=bs, tag="BASELINE")
    sliced = run_benchmark(model, batch_size=bs, cache_cap=MAX_BATCH_SIZE, tag="SLICED")
    diff = (sliced - base) / base * 100
    results.append((bs, base, sliced, diff))

print("Summary:")
print(f"{'Batch':<8} | {'Baseline (ms)':<15} | {'Sliced (ms)':<15} | Diff (%)")
for bs, base, sliced, diff in results:
    print(f"{bs:<8} | {base:<15.2f} | {sliced:<15.2f} | {diff:+.2f}%")

@i3hz
Copy link
Contributor Author

i3hz commented Nov 28, 2025

@zucchini-nlp This doesn't have the max_batch_size as you mentioned . If it's something that I should add , please lmk .

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for a quick fix @i3hz !

I think we need to allow max_batch_size which will take precedence if available when lazy initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use cache across several generation with max batch size. It would also require us to change a few places in generation imo

After that, we can verify that there are no unwanted graph breaks and run the bench

LMK if this makes sense and you need guidance

Comment on lines 340 to 343
if key_states.dtype != k_out.dtype:
key_states = key_states.to(k_out.dtype)
if value_states.dtype != v_out.dtype:
value_states = value_states.to(v_out.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be handled by lazy init already, so not really needed imo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright I'll remove that

Comment on lines +336 to +339
batch_size = key_states.shape[0]
if k_out.shape[0] != batch_size:
k_out = k_out[:batch_size]
v_out = v_out[:batch_size]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you run benchmarks and check if it creates excessive cudagraph breaks, as per the last comment from mobicham. In any case, a small benchmark run will be needed before merging the PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified the fix with GPT2 using torch.compile(mode='reduce-overhead', fullgraph=True) . For some reason it keeps failing with whisper models and I can't really figure out why .

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add the bench script to PR description pls?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep I'll add it

@i3hz
Copy link
Contributor Author

i3hz commented Nov 28, 2025

I think we need to allow max_batch_size which will take precedence if available when lazy initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use cache across several generation with max batch size. It would also require us to change a few places in generation imo

So basically I should add max_batch_size to the __init__ method of StaticCache and then in StaticLayer modify the lazy_initialization to use max_batch_size .

And also change line 1837 from src/transformers/generation/utils.py to as seen in #37394

or cache_to_check.max_batch_size < batch_size

@zucchini-nlp
Copy link
Member

Yep, and a small test as well

@i3hz
Copy link
Contributor Author

i3hz commented Nov 29, 2025

hi @zucchini-nlp I'm still stuck on this. I’ve been testing with torch.compile and it works fine with GPT-2, but does not work with whisper small ,I’m not sure what I’m missing tbh .
If you have any pointers on what I should check or tweak, I’d really appreciate it.
Thanks a lot and sorry for the trouble

k_out = self.keys
v_out = self.values
batch_size = key_states.shape[0]
if k_out.shape[0] != batch_size:
Copy link
Contributor

@mobicham mobicham Nov 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess k_out.shape[0] >= batch_size is better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When debugging the torch.compile stuff, can you check this:

assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"
assert v_out.data_ptr() == v_out[:batch_size].data_ptr() , "invalid v_out data copy()!"

If there's no copy, I don't see why Cudagraphs would break with Whisper.
What error do you get exactly btw?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess k_out.shape[0] <= batch_size is better

Wait should it be < or > considering k_out will be larger than the batch_size

assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"

Builtin `operator.*` comparison with constant `self` failed
  Explanation: Failed to compare DataPtrVariable() with DataPtrVariable(), because DataPtrVariable() is not a Python constant or its mutation check fails.

About the actual torch.compile error i'm getting i'm trying max_batch_size = 8 and the list being 8,4,2,1
on 4 it crashes with

Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(s72, 6, 1, 64), dtype=torch.float16,
           grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
           grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
           grad_fn=<Error>)), **{'attn_mask': None, 'dropout_p': 0.0, 'scale': 1.0, 'is_causal': False}): got RuntimeError('Attempting to broadcast a dimension of length 8 at -2! Mismatching argument at index 1 had [8, 6]; but expected shape should be broadcastable to [s72, 6]')

from user code:
   File "/home/vedth/stuhdy/z.py", line 21, in decoder_forward
    out = model.model.decoder(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 865, in forward
    layer_outputs = decoder_layer(
  File "/home/vedth/stuhdy/transformers/src/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 501, in forward
    hidden_states, cross_attn_weights = self.encoder_attn(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 347, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/vedth/stuhdy/transformers/src/transformers/integrations/sdpa_attention.py", line 92, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait should it be < or > considering k_out will be larger than the batch_size

Oh sorry you're right, I meant current_batch_size <= max_batch_size

assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"

I meant run it without torch.compile, just to see if it performs any copy

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace

I see, I will try to debug next week too 👍

@mobicham
Copy link
Contributor

@i3hz I will also do some debugging next week

@i3hz
Copy link
Contributor Author

i3hz commented Nov 29, 2025

@i3hz I will also do some debugging next week

Thanks a lot
The main issue still lies within torch.compile as without it the model is working

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

StaticCache crashes when the batch-size changes

3 participants