 _pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)])
 
 
+def safe_zip(*args):
+    if len(args) == 0:
+        return []
+    assert all(len(arg) == len(args[0]) for arg in args)
+    return zip(*args)
+
+
 def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
     "From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list."
 
@@ -28,7 +35,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
     for i, c in enumerate(kv_list[0]):
         els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list]  # [B, R_flat, L]
         els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els)  # [R_flat, L]
-        leaves_list = list(zip(*els))  # [L, R_flat]
+        leaves_list = list(safe_zip(*els))  # [L, R_flat]
         out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list]  # [L, R]
     return tuple(out), max_seq_len
 
@@ -41,7 +48,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
 @partial(jax.jit, donate_argnames=("cache",))
 def _kvcache_update_cache(
     cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
     update_mask: list[bool] | None = None,
@@ -62,15 +69,17 @@ def _update_element(x, u):
         # update_permute = [batch_dim, time_dim] + update_permute
         return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop")
 
-    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
+    cache_kvs = jax.tree.map(_update_element, cache.buffers, kvs)
     cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop")
     cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter)
-    return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts)
+
+    buffer_names = [field.name for field in dataclasses.fields(cache)][:len(cache_kvs)]
+    return dataclasses.replace(cache, **dict(safe_zip(buffer_names, cache_kvs)), iter=cache_iter, starts=cache_starts)
 
 
 def kvcache_update_cache(
     cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
 ):
@@ -85,7 +94,7 @@ def kvcache_update_cache(
 def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array):
     shift = -cache.starts[batch_idx]
     assert cache.time_axis > 0
-    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v))
+    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), cache.buffers)
     kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[1]))
     true_len = cache.fill_len()[batch_idx]
     return kvs, true_len
@@ -109,13 +118,13 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array |
     return jax.lax.top_k(free_pages, k)[1]
 
 
-def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
-    key_heads = cache.k[layer_idx].shape[0]
-    assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)
+def _paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
+    # key_heads = cache.buffers[0][layer_idx].shape[0]
+    # assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)  # TODO: write this generically
     needs_next_page = (cache.lengths % cache.page_size) == 0
     page_table_idx = cache.lengths // cache.page_size
     current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
-    avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[0] / cache.batch_size)
+    avg_pages_per_batch_entry = round(cache.buffers[0][layer_idx].shape[0] / cache.batch_size)
     even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
     proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
     free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
@@ -127,27 +136,28 @@ def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.A
     # for batch index update the target slice is (heads, i, j, head_dim)
     # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
     _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
-    cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v))
+    for buffer, new_buffer in safe_zip(cache.buffers, kv):
+        buffer[layer_idx] = jax.tree.map(_update, buffer[layer_idx], new_buffer)
 
     batch_idx = jnp.arange(cache.batch_size)
     new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)
 
     new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
     new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
-    return cache.k[layer_idx], cache.v[layer_idx], new_state
+    return tuple(buffer[layer_idx] for buffer in cache.buffers), new_state
 
 
-def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
+def paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
     repl_sharding = jax.typeof(cache.lengths).sharding
-    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, (cache.k[layer_idx], cache.v[layer_idx]))
-    sharding = (*kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
-    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, k, v)
+    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, tuple(buffer[layer_idx] for buffer in cache.buffers))
+    sharding = (kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
+    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, kv)
 
 
 @partial(jax.jit, donate_argnames=("cache",))
 def _batch_paged_update_sequences(
     cache: PagedKVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
     update_mask: list[bool] | None = None,
@@ -156,9 +166,7 @@ def _batch_paged_update_sequences(
     batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30)  # send masked to nowhere
     actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs]))
 
-    kvs, max_seq_len = _transpose_attention_tree(
-        kvs, time_axis=2
-    )  # undo stacking along the layer dimension for transit
+    kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=2)  # undo stack along layer dimension in transit
 
     # clear existing pages
     actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32)
@@ -186,21 +194,23 @@ def _update_element(x, u):
         update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)]
         return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop")
 
-    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
+    new_buffers = jax.tree.map(_update_element, cache.buffers, kvs)
     block_tables_idx = jnp.where(
         update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30
     )
     new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop")
     new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop")
     new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop")
+
+    named_buffers = dict(zip([field.name for field in dataclasses.fields(cache)][:len(new_buffers)], new_buffers))
     return dataclasses.replace(
-        cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
+        cache, **named_buffers, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
     )
 
 
 def batch_paged_update_sequences(
     cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
 ):
@@ -222,5 +232,5 @@ def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len
     _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0)
 
     # stack along layer dimensions for transit
-    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v)))
+    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, cache.buffers))
     return kvs, true_len
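
The core of this change is replacing hard-coded `(cache.k, cache.v)` pairs with a generic `cache.buffers` tuple, recovering field names positionally from `dataclasses.fields` so that `dataclasses.replace` works for any number of buffers. Below is a minimal self-contained sketch of that pattern using a hypothetical `ToyCache` stand-in (the real `KVCache` class is not shown in this diff; we assume only that its leading dataclass fields are the buffers, exposed in declaration order by a `buffers` property):

import dataclasses

import jax
import jax.numpy as jnp


def safe_zip(*args):
    # Like zip, but asserts all arguments have equal length.
    if len(args) == 0:
        return []
    assert all(len(arg) == len(args[0]) for arg in args)
    return zip(*args)


@dataclasses.dataclass
class ToyCache:  # hypothetical stand-in; the real KVCache has more fields
    k: jax.Array
    v: jax.Array
    iter: int

    @property
    def buffers(self):  # the leading dataclass fields, in declaration order
        return (self.k, self.v)


def update(cache: ToyCache, new_kvs: tuple[jax.Array, ...]) -> ToyCache:
    # Update all buffers generically, then map results back to their field
    # names positionally, so dataclasses.replace needs no hard-coded k=/v=.
    updated = jax.tree.map(lambda old, new: old + new, cache.buffers, new_kvs)
    names = [f.name for f in dataclasses.fields(cache)][:len(updated)]
    return dataclasses.replace(cache, **dict(safe_zip(names, updated)))


cache = ToyCache(k=jnp.zeros(4), v=jnp.zeros(4), iter=0)
cache = update(cache, (jnp.ones(4), 2 * jnp.ones(4)))
print(cache.k, cache.v)  # [1. 1. 1. 1.] [2. 2. 2. 2.]

Extending the cache to, say, a third quantization-scale buffer then only requires adding a field and widening `buffers`; none of the update functions change.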
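The page-table arithmetic in `_paged_update_slice` is easiest to follow with concrete numbers. Here is a standalone sketch of the cursor computation with made-up `lengths` and `block_tables` (purely illustrative values; `inpage_cursor` is our assumed reconstruction, as its definition falls outside the hunks shown):

import jax.numpy as jnp

page_size = 4
lengths = jnp.array([0, 3, 4, 9])  # tokens already stored per sequence
block_tables = jnp.array([[0, 0, 0], [1, 0, 0], [2, 5, 0], [3, 4, 6]])

# A sequence whose length is a multiple of page_size starts a fresh page;
# otherwise the new token lands at offset lengths % page_size in its page.
needs_next_page = (lengths % page_size) == 0
page_table_idx = lengths // page_size  # current slot in each page table row
current_page_cursor = jnp.take_along_axis(block_tables, page_table_idx[:, None], axis=-1)[..., 0]
inpage_cursor = lengths % page_size

print(needs_next_page)      # [ True False  True False]
print(current_page_cursor)  # [0 1 5 6]
print(inpage_cursor)        # [0 3 0 1]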