From b013c0cd5a5979a9ff52806cb2641621d2ca1ab8 Mon Sep 17 00:00:00 2001
From: i3hz
Date: Fri, 28 Nov 2025 08:29:28 +0000
Subject: [PATCH 1/5] fixed static-cache crashes

---
 src/transformers/cache_utils.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 28f40952f2cd..a4efc4a352f9 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -331,16 +331,25 @@ def update(
         cache_position = (
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )
-
+        k_out = self.keys
+        v_out = self.values
+        batch_size = key_states.shape[0]
+        if k_out.shape[0] != batch_size:
+            k_out = k_out[:batch_size]
+            v_out = v_out[:batch_size]
+        if key_states.dtype != k_out.dtype:
+            key_states = key_states.to(k_out.dtype)
+        if value_states.dtype != v_out.dtype:
+            value_states = value_states.to(v_out.dtype)
         # Update the cache
         try:
-            self.keys.index_copy_(2, cache_position, key_states)
-            self.values.index_copy_(2, cache_position, value_states)
+            k_out.index_copy_(2, cache_position, key_states)
+            v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
             # Fallback for devices like MPS where index_copy_ might not be supported.
-            self.keys[:, :, cache_position] = key_states
-            self.values[:, :, cache_position] = value_states
-        return self.keys, self.values
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
+        return k_out, v_out
 
     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""

From d89e30bb455d58403597f90430241a6b91e8aeec Mon Sep 17 00:00:00 2001
From: i3hz
Date: Sat, 29 Nov 2025 10:04:11 +0000
Subject: [PATCH 2/5] removed redundant code

---
 src/transformers/cache_utils.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index a4efc4a352f9..ea3aa61af114 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -337,10 +337,6 @@ def update(
         if k_out.shape[0] != batch_size:
             k_out = k_out[:batch_size]
             v_out = v_out[:batch_size]
-        if key_states.dtype != k_out.dtype:
-            key_states = key_states.to(k_out.dtype)
-        if value_states.dtype != v_out.dtype:
-            value_states = value_states.to(v_out.dtype)
         # Update the cache
         try:
             k_out.index_copy_(2, cache_position, key_states)

From 1a6e2717f77c2ce7542b178517db200f83d59175 Mon Sep 17 00:00:00 2001
From: i3hz
Date: Wed, 3 Dec 2025 12:52:36 +0000
Subject: [PATCH 3/5] static cache bug

---
 src/transformers/cache_utils.py | 76 ++++++++++++++++++++++++---------
 1 file changed, 55 insertions(+), 21 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index ea3aa61af114..42530652d4a2 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -258,14 +258,17 @@ class StaticLayer(CacheLayerMixin):
     Args:
         max_cache_len (`int`):
             Maximum number of tokens that can be stored, used for tensor preallocation.
+        max_batch_size (`int`, *optional*):
+            Maximum batch size that can be stored. If not provided, it is inferred from the first key states received.
     """
 
     is_compileable = True
     is_sliding = False
 
-    def __init__(self, max_cache_len: int):
+    def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
         super().__init__()
         self.max_cache_len = max_cache_len
+        self.max_batch_size = max_batch_size
 
     def lazy_initialization(self, key_states: torch.Tensor):
         """
@@ -281,26 +284,30 @@ def lazy_initialization(self, key_states: torch.Tensor):
         i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
         not be compiled anyway for performances!
         """
-        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+        if self.max_batch_size is None:
+            self.max_batch_size = key_states.shape[0]
+        _, self.num_heads, _, self.head_dim = key_states.shape
         self.dtype, self.device = key_states.dtype, key_states.device
 
-        self.keys = torch.zeros(
+        self.keys_ = torch.zeros(
             (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
             dtype=self.dtype,
             device=self.device,
         )
-        self.values = torch.zeros(
+        self.values_ = torch.zeros(
             (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
             dtype=self.dtype,
             device=self.device,
         )
+        self.keys = self.keys_
+        self.values = self.values_
         # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
         # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
         # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
        # prefill explicitly, but this should be avoided!)
         if not is_torchdynamo_compiling():
-            torch._dynamo.mark_static_address(self.keys)
-            torch._dynamo.mark_static_address(self.values)
+            torch._dynamo.mark_static_address(self.keys_)
+            torch._dynamo.mark_static_address(self.values_)
 
         self.is_initialized = True
 
@@ -331,21 +338,25 @@ def update(
         cache_position = (
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )
-        k_out = self.keys
-        v_out = self.values
         batch_size = key_states.shape[0]
-        if k_out.shape[0] != batch_size:
-            k_out = k_out[:batch_size]
-            v_out = v_out[:batch_size]
-        # Update the cache
+        # 3. Dynamic Slicing: Update the view to match current batch
+        self.keys = self.keys_[:batch_size]
+        self.values = self.values_[:batch_size]
         try:
-            k_out.index_copy_(2, cache_position, key_states)
-            v_out.index_copy_(2, cache_position, value_states)
+            self.keys.index_copy_(2, cache_position, key_states)
+            self.values.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
-            # Fallback for devices like MPS where index_copy_ might not be supported.
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
-        return k_out, v_out
+            self.keys[:, :, cache_position] = key_states
+            self.values[:, :, cache_position] = value_states
+
+        return self.keys, self.values
+
+    def reset(self):
+        if self.is_initialized:
+            self.keys_.zero_()
+            self.values_.zero_()
+            self.keys = self.keys_
+            self.values = self.values_
 
     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
@@ -1024,6 +1035,8 @@ class StaticCache(Cache):
         offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
             If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
             usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+        max_batch_size (`int`, *optional*):
+            The maximum batch size that will be used with this cache.
 
     Example:
 
@@ -1052,6 +1065,7 @@ def __init__(
         max_cache_len: int,
         offloading: bool = False,
         offload_only_non_sliding: bool = True,
+        max_batch_size: int | None = None,
         **kwargs,
     ):
         config = config.get_text_config(decoder=True)
@@ -1071,19 +1085,39 @@ def __init__(
         layers = []
         for layer_type in layer_types:
             if layer_type == "sliding_attention":
-                layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
+                layer = StaticSlidingWindowLayer(
+                    max_cache_len=max_cache_len, sliding_window=config.sliding_window, max_batch_size=max_batch_size
+                )
             elif layer_type == "chunked_attention":
                 # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                 # states they should return - only the mask changes to make them different at the end!
                 layer = StaticSlidingWindowLayer(
-                    max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
+                    max_cache_len=max_cache_len,
+                    sliding_window=config.attention_chunk_size,
+                    max_batch_size=max_batch_size,
                 )
             else:
-                layer = StaticLayer(max_cache_len=max_cache_len)
+                layer = StaticLayer(max_cache_len=max_cache_len, max_batch_size=max_batch_size)
             layers.append(layer)
 
         super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
 
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+
+    def reset(self):
+        for layer in self.layers:
+            layer.reset()
+
+    def __len__(self):
+        return len(self.layers)
+
 
 class QuantizedCache(Cache):
     """

From 1944a2a6f5c4359b4ebc53ed2e7ee572199cd593 Mon Sep 17 00:00:00 2001
From: i3hz
Date: Fri, 5 Dec 2025 12:33:20 +0000
Subject: [PATCH 4/5] minor code cleanup

---
 src/transformers/cache_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 42530652d4a2..c42ab66732ea 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -339,7 +339,6 @@ def update(
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )
         batch_size = key_states.shape[0]
-        # 3. Dynamic Slicing: Update the view to match current batch
         self.keys = self.keys_[:batch_size]
         self.values = self.values_[:batch_size]
         try:

From acabeed84068a7e32a862f63612dc00df639a2c0 Mon Sep 17 00:00:00 2001
From: i3hz
Date: Fri, 5 Dec 2025 12:59:57 +0000
Subject: [PATCH 5/5] fixing StaticSlidingWindowLayer

---
 src/transformers/cache_utils.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index c42ab66732ea..29cbc7712a59 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -385,13 +385,15 @@ class StaticSlidingWindowLayer(StaticLayer):
             Maximum number of tokens that can be stored, used for tensor preallocation.
         sliding_window (`int`):
             The size of the sliding window.
+        max_batch_size (`int`, *optional*):
+            Maximum batch size that can be stored. If not provided, it is inferred from the first key states received.
     """
 
     is_sliding = True
 
-    def __init__(self, max_cache_len: int, sliding_window: int):
+    def __init__(self, max_cache_len: int, sliding_window: int, max_batch_size: int | None = None):
         effective_max_cache_len = min(sliding_window, max_cache_len)
-        super().__init__(max_cache_len=effective_max_cache_len)
+        super().__init__(max_cache_len=effective_max_cache_len, max_batch_size=max_batch_size)
         self.cumulative_length = 0
 
     def update(
@@ -422,6 +424,10 @@ def update(
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )
 
+        batch_size = key_states.shape[0]
+        self.keys = self.keys_[:batch_size]
+        self.values = self.values_[:batch_size]
+
         cumulative_length = self.cumulative_length
         is_full = cumulative_length >= self.max_cache_len
         # Update it now that we saved the value above
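
Reviewer note (not part of the patches above): a minimal usage sketch of the scenario this series targets, assuming the `max_batch_size` argument added to `StaticCache` in PATCH 3/5. The checkpoint name, cache length, and batch sizes are placeholders chosen for illustration only, not taken from the diff.

    # Hypothetical usage sketch: preallocate the static cache for a larger batch
    # than the one actually passed to generate(). The runtime batch (2) is smaller
    # than max_batch_size (4), which is the mismatch these patches handle by
    # slicing a per-batch view out of the preallocated key/value buffers.
    from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

    model_id = "Qwen/Qwen2.5-0.5B"  # placeholder; any decoder-only checkpoint should do
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(model_id)

    cache = StaticCache(config=model.config, max_cache_len=128, max_batch_size=4)

    inputs = tokenizer(["Hello there", "Static caches"], return_tensors="pt", padding=True)
    out = model.generate(**inputs, past_key_values=cache, max_new_tokens=8)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))

If `max_batch_size` is omitted, the layers still infer the batch size lazily from the first key states they see (PATCH 3/5, `lazy_initialization`), so existing callers are unaffected.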
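
For reference, a stripped-down, torch-only sketch of the buffer/view split that PATCH 3/5 through 5/5 rely on: `keys_` is the full preallocated buffer, `keys` is the per-batch view written into on each update. The shapes below are arbitrary and chosen only to make the example self-contained.

    import torch

    max_batch_size, num_heads, max_cache_len, head_dim = 4, 2, 16, 8
    keys_ = torch.zeros(max_batch_size, num_heads, max_cache_len, head_dim)

    batch_size = 2                # current forward pass has fewer sequences than preallocated
    keys = keys_[:batch_size]     # narrow view, shares storage with keys_

    new_keys = torch.randn(batch_size, num_heads, 3, head_dim)
    cache_position = torch.arange(3)
    keys.index_copy_(2, cache_position, new_keys)   # the write lands in the big buffer

    assert keys.data_ptr() == keys_.data_ptr()      # same storage, no copy made
    assert torch.equal(keys_[:batch_size, :, :3], new_keys)

Keeping `keys_` as the stable, full-size allocation is also why `torch._dynamo.mark_static_address` is applied to `keys_`/`values_` rather than to the per-batch views.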