1313# limitations under the License.
1414
1515import dataclasses
16- from functools import partial
16+ from functools import partial , wraps
1717from typing import Any , Callable
1818import math
1919from concurrent .futures import ThreadPoolExecutor , Future
4343
4444TIME_AXIS = 2
4545USE_PREFIX_CACHE = True # the eviction mechanism is extremely simple right now
46- #USE_PREFIX_CACHE = False
46+ # USE_PREFIX_CACHE = False
4747is_type = lambda x , cls : (type (x ).__name__ == cls .__name__ ) and (type (x ).__module__ == cls .__module__ )
4848
4949########################################################################################################################
@@ -98,6 +98,7 @@ def _ensure_all_args_on_mesh(*args, mesh: Mesh):
9898# kv cache buffer management ###########################################################################################
9999########################################################################################################################
100100
101+
101102@partial (jax .jit , static_argnames = ("axis" , "chunk_size" , "ns" ))
102103def _split (val : jax .Array | list [jax .Array ], axis : int , chunk_size : int , ns : int ) -> list [jax .Array ]:
103104 def _fn (val ):
@@ -133,7 +134,7 @@ def _get_unique_buffer_ids(self, n: int):
133134 def offload_buffers (self , how_many : int ):
134135 if how_many == 0 :
135136 return
136- candidates = sorted (self ._store .keys (), key = lambda i : self .usecount [i ] if self .ondevice [i ] else 2 ** 60 )
137+ candidates = sorted (self ._store .keys (), key = lambda i : self .usecount [i ] if self .ondevice [i ] else 2 ** 60 )
137138 for i in candidates [:how_many ]:
138139 if self .ondevice [i ]:
139140 shrd = jax .tree .map (lambda x : x .sharding .with_memory_kind ("pinned_host" ), self ._store [i ])
@@ -168,6 +169,7 @@ def mark_visited(self, id: int):
168169 return [self .mark_visited (i ) for i in id ]
169170 self .usecount [id ] += 1
170171
172+
171173BUFFER_STORE = KVBufferStore ()
172174
173175########################################################################################################################
@@ -176,6 +178,7 @@ def mark_visited(self, id: int):
176178
177179EMPTY , HASH_BITWIDTH = - 1 , 1
178180
181+
179182@dataclasses .dataclass
180183class ChildKeys :
181184 keys : np .ndarray
@@ -194,7 +197,12 @@ def _hash_encode(v: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int
194197
195198
196199def _prefilter_on_hash (
197- w : np .ndarray , keys : np .ndarray , vh : np .ndarray , vm : np .ndarray , hash_bitwidth : int = HASH_BITWIDTH , pad_idx : int = EMPTY
200+ w : np .ndarray ,
201+ keys : np .ndarray ,
202+ vh : np .ndarray ,
203+ vm : np .ndarray ,
204+ hash_bitwidth : int = HASH_BITWIDTH ,
205+ pad_idx : int = EMPTY ,
198206):
199207 wh , wm = _hash_encode (w , hash_bitwidth = hash_bitwidth , pad_idx = pad_idx )
200208 inv_match = (wh ^ vh ) & vm & wm
@@ -203,6 +211,7 @@ def _prefilter_on_hash(
203211 max_match_len = max (np .max (match_len ), 1 )
204212 return np .where (match_len == max_match_len )[0 ]
205213
214+
206215def _fast_pad (x , size , axis , pad_val = 0 ):
207216 new_buf = pad_val * np .ones ([size - s if i == axis else s for i , s in enumerate (x .shape )], dtype = x .dtype )
208217 return np .concat ([x , new_buf ], axis )
@@ -344,6 +353,9 @@ def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pa
344353next_power_of_2 = lambda x : 2 ** round (math .ceil (math .log2 (x )))
345354like_spec = lambda z : jax .tree .map (lambda x : jax .typeof (x ).sharding .spec , z )
346355like_shard = lambda z , mesh : jax .tree .map (lambda x : NamedSharding (mesh , jax .typeof (x ).sharding .spec ), z )
356+ _make_empty = lambda x , mesh : jax .make_array_from_single_device_arrays (
357+ x .shape , NamedSharding (mesh , x .sharding .spec ), [], dtype = x .dtype
358+ )
347359
348360
349361@dataclasses .dataclass
@@ -446,7 +458,15 @@ def body(carry, _):
446458 (curr_tokens , cache ), output_tokens = jax .lax .scan (body , (curr_tokens , cache ), length = steps )
447459 return (curr_tokens , cache ), output_tokens [..., 0 ].T
448460
449- return multistep_decode_fn
461+ @wraps (multistep_decode_fn )
462+ def wrapped (curr_tokens , decode_weights , cache , cfg , steps : int = 32 , * , participate : bool = True ):
463+ if participate :
464+ return multistep_decode_fn (curr_tokens , decode_weights , cache , cfg , steps = steps )
465+ else :
466+ _make_empty_ , fn = partial (_make_empty , mesh = cfg .mesh ), multistep_decode_fn
467+ return jax .tree .map (_make_empty_ , jax .eval_shape (fn , curr_tokens , decode_weights , cache , cfg , steps = steps ))
468+
469+ return wrapped
450470
451471
452472def _make_stacked_prefill (prefill_fn ):
@@ -461,7 +481,15 @@ def stacked_prefill(inputs, weights, cfg):
461481 stacked_kv = jax .tree .map (lambda * xs : jnp .stack (xs , axis = 0 ), * kv_list )
462482 return next_tokens , logits , stacked_kv
463483
464- return lambda inputs , weights , cfg : stacked_prefill (_numpy_pad_tokens (inputs ), weights , cfg )
484+ @wraps (stacked_prefill )
485+ def wrapped (inputs , weights , cfg , * , participate : bool = True ):
486+ if participate :
487+ return stacked_prefill (_numpy_pad_tokens (inputs ), weights , cfg )
488+ else :
489+ _make_empty_ = partial (_make_empty , mesh = cfg .mesh )
490+ return jax .tree .map (_make_empty_ , jax .eval_shape (stacked_prefill , _numpy_pad_tokens (inputs ), weights , cfg ))
491+
492+ return wrapped
465493
466494
467495class ServingLoop :
@@ -476,8 +504,8 @@ def __init__(
476504 decode_cache : KVCache ,
477505 is_server : bool = False ,
478506 ):
479- if not ( SyncServer .broadcast ("welcome" , 0 , True , is_server ) if jax . process_count () > 1 else is_server ):
480- raise ValueError ("No processes registered as the main server, at least one process must." )
507+ if not SyncServer .broadcast ("welcome" , 0 , is_server , is_server ):
508+ raise ValueError ("Neither this proccess nor any other processe is the main server, at least one must." )
481509 self .serve_cfg , self .cfg = serve_cfg , cfg
482510
483511 # setup decode
@@ -506,7 +534,7 @@ def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens,
506534 self .decode_output = (None , None )
507535
508536 # setup prefill
509- self .prefill_fn = staticmethod ( _make_stacked_prefill (prefill_fn ) )
537+ self .prefill_fn = _make_stacked_prefill (prefill_fn )
510538 self .prefill_weights = prefill_weights
511539 self .prefill_mesh = [x for x in jax .tree .leaves (prefill_weights ) if hasattr (x , "sharding" )][0 ].sharding .mesh
512540 self .prefill_work = PrefillWork ([], [], [])
@@ -572,19 +600,16 @@ def decode_step(self):
572600
573601 # 2. run N decode steps
574602 output_tokens , output_mapping = [], []
575- if "decode" in self .roles : # TODO(rdyro): revisit: don't issue the decode call on non-participating machines
576- with use_mesh (self .decode_mesh ):
577- config = dict (cfg = self .cfg , steps = self .serve_cfg .decode_steps )
578- (self .decode_work .curr_tokens , self .decode_work .cache ), output_tokens = self .multistep_decode_fn (
579- self .decode_work .curr_tokens , self .decode_weights , self .decode_work .cache , ** config
580- )
581- output_mapping = [
582- [getattr (result , "id" , - 1 ) for result in self .decode_work .active_results ]
583- ] * self .serve_cfg .decode_steps
584- output_mapping = np .array (output_mapping ).T
585- print (
586- f"Decoding with fill rate of { np .mean ([result is not None for result in self .decode_work .active_results ])} "
603+ with use_mesh (self .decode_mesh ):
604+ config = dict (cfg = self .cfg , steps = self .serve_cfg .decode_steps , participate = "decode" in self .roles )
605+ (self .decode_work .curr_tokens , self .decode_work .cache ), output_tokens = self .multistep_decode_fn (
606+ self .decode_work .curr_tokens , self .decode_weights , self .decode_work .cache , ** config
587607 )
608+ output_mapping = [
609+ [getattr (result , "id" , - 1 ) for result in self .decode_work .active_results ]
610+ ] * self .serve_cfg .decode_steps
611+ output_mapping = np .array (output_mapping ).T
612+ print (f"Decoding with fill rate: { np .mean ([result is not None for result in self .decode_work .active_results ])} " )
588613
589614 # 3. parse output tokens from previous decoding loop to allow for the tokens arrive (delayed EOS detection)
590615 self .decode_output , (output_tokens , output_mapping ) = (output_tokens , output_mapping ), self .decode_output
@@ -607,8 +632,7 @@ def decode_step(self):
607632 (output_tokens_flat , output_mapping_flat , done ),
608633 is_source = "decode_coordinator" in self .roles ,
609634 )
610- #if "server" in self.roles or "decode_coordinator" in self.roles:
611- for token , id in zip (output_tokens .reshape (- 1 ).tolist (), output_mapping .reshape (- 1 ).tolist ()):
635+ for token , id in zip (output_tokens_flat , output_mapping_flat ):
612636 if id > 0 :
613637 self .results [id ].token_list .append (token )
614638 self .results [id ].tokens_decoded += 1
@@ -664,7 +688,9 @@ def prefill_step(self):
664688 self .prefill_work .to_decode .append (new_decode )
665689 print (f"Found a full match" )
666690 else :
667- print (f"Need to prefill the request, only found a match for length { total_match / (len (request .text ) - 1 )} " )
691+ print (
692+ f"Need to prefill the request, only found a match for length { total_match / (len (request .text ) - 1 )} "
693+ )
668694 self .prefill_work .to_prefill .append (request )
669695
670696 if self .prefill_work .pending_prefill is not None : # a current prefill is still running, skip scheduling another
@@ -684,7 +710,9 @@ def _prefill_job():
684710 inputs = np .pad (inputs , ((0 , row_pad ), (0 , col_pad )), mode = "constant" , constant_values = self .pad_id )
685711 cfg = dataclasses .replace (self .cfg , mesh = self .prefill_mesh )
686712 with use_mesh (self .prefill_mesh ):
687- _ , _ , prefill_results = self .prefill_fn (inputs , self .prefill_weights , cfg )
713+ _ , _ , prefill_results = self .prefill_fn (
714+ inputs , self .prefill_weights , cfg , participate = "prefill" in self .roles
715+ )
688716 prefill_results = jax .block_until_ready (prefill_results )
689717 return prefill_input , prefill_results
690718
@@ -719,10 +747,10 @@ def serving_step(self):
719747 if "server" in self .roles :
720748 with self .state_lock :
721749 self .pending_requests , requests = [], list (self .pending_requests )
750+ serve_cfg , requests = SyncServer .broadcast (
751+ "requests" , self ._it , (dataclasses .asdict (self .serve_cfg ), requests ), is_source = "server" in self .roles
752+ )
722753 with self .state_lock :
723- serve_cfg , requests = SyncServer .broadcast (
724- "requests" , self ._it , (dataclasses .asdict (self .serve_cfg ), requests ), is_source = "server" in self .roles
725- )
726754 self .serve_cfg = dataclasses .replace (self .serve_cfg , ** serve_cfg )
727755 for request in requests :
728756 self .total_requests += 1
0 commit comments