Skip to content

Commit bfc0665

Browse files
committed
Skip non-participating hosts in computation
1 parent 3652849 commit bfc0665

File tree

2 files changed

+57
-29
lines changed

2 files changed

+57
-29
lines changed

serving/serving_jax/__init__.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16-
from functools import partial
16+
from functools import partial, wraps
1717
from typing import Any, Callable
1818
import math
1919
from concurrent.futures import ThreadPoolExecutor, Future
@@ -43,7 +43,7 @@
4343

4444
TIME_AXIS = 2
4545
USE_PREFIX_CACHE = True # the eviction mechanism is extremely simple right now
46-
#USE_PREFIX_CACHE = False
46+
# USE_PREFIX_CACHE = False
4747
is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__)
4848

4949
########################################################################################################################
@@ -98,6 +98,7 @@ def _ensure_all_args_on_mesh(*args, mesh: Mesh):
9898
# kv cache buffer management ###########################################################################################
9999
########################################################################################################################
100100

101+
101102
@partial(jax.jit, static_argnames=("axis", "chunk_size", "ns"))
102103
def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]:
103104
def _fn(val):
@@ -133,7 +134,7 @@ def _get_unique_buffer_ids(self, n: int):
133134
def offload_buffers(self, how_many: int):
134135
if how_many == 0:
135136
return
136-
candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2 ** 60)
137+
candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2**60)
137138
for i in candidates[:how_many]:
138139
if self.ondevice[i]:
139140
shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i])
@@ -168,6 +169,7 @@ def mark_visited(self, id: int):
168169
return [self.mark_visited(i) for i in id]
169170
self.usecount[id] += 1
170171

172+
171173
BUFFER_STORE = KVBufferStore()
172174

173175
########################################################################################################################
@@ -176,6 +178,7 @@ def mark_visited(self, id: int):
176178

177179
EMPTY, HASH_BITWIDTH = -1, 1
178180

181+
179182
@dataclasses.dataclass
180183
class ChildKeys:
181184
keys: np.ndarray
@@ -194,7 +197,12 @@ def _hash_encode(v: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int
194197

195198

196199
def _prefilter_on_hash(
197-
w: np.ndarray, keys: np.ndarray, vh: np.ndarray, vm: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int = EMPTY
200+
w: np.ndarray,
201+
keys: np.ndarray,
202+
vh: np.ndarray,
203+
vm: np.ndarray,
204+
hash_bitwidth: int = HASH_BITWIDTH,
205+
pad_idx: int = EMPTY,
198206
):
199207
wh, wm = _hash_encode(w, hash_bitwidth=hash_bitwidth, pad_idx=pad_idx)
200208
inv_match = (wh ^ vh) & vm & wm
@@ -203,6 +211,7 @@ def _prefilter_on_hash(
203211
max_match_len = max(np.max(match_len), 1)
204212
return np.where(match_len == max_match_len)[0]
205213

214+
206215
def _fast_pad(x, size, axis, pad_val=0):
207216
new_buf = pad_val * np.ones([size - s if i == axis else s for i, s in enumerate(x.shape)], dtype=x.dtype)
208217
return np.concat([x, new_buf], axis)
@@ -344,6 +353,9 @@ def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pa
344353
next_power_of_2 = lambda x: 2 ** round(math.ceil(math.log2(x)))
345354
like_spec = lambda z: jax.tree.map(lambda x: jax.typeof(x).sharding.spec, z)
346355
like_shard = lambda z, mesh: jax.tree.map(lambda x: NamedSharding(mesh, jax.typeof(x).sharding.spec), z)
356+
_make_empty = lambda x, mesh: jax.make_array_from_single_device_arrays(
357+
x.shape, NamedSharding(mesh, x.sharding.spec), [], dtype=x.dtype
358+
)
347359

348360

349361
@dataclasses.dataclass
@@ -446,7 +458,15 @@ def body(carry, _):
446458
(curr_tokens, cache), output_tokens = jax.lax.scan(body, (curr_tokens, cache), length=steps)
447459
return (curr_tokens, cache), output_tokens[..., 0].T
448460

449-
return multistep_decode_fn
461+
@wraps(multistep_decode_fn)
462+
def wrapped(curr_tokens, decode_weights, cache, cfg, steps: int = 32, *, participate: bool = True):
463+
if participate:
464+
return multistep_decode_fn(curr_tokens, decode_weights, cache, cfg, steps=steps)
465+
else:
466+
_make_empty_, fn = partial(_make_empty, mesh=cfg.mesh), multistep_decode_fn
467+
return jax.tree.map(_make_empty_, jax.eval_shape(fn, curr_tokens, decode_weights, cache, cfg, steps=steps))
468+
469+
return wrapped
450470

451471

452472
def _make_stacked_prefill(prefill_fn):
@@ -461,7 +481,15 @@ def stacked_prefill(inputs, weights, cfg):
461481
stacked_kv = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *kv_list)
462482
return next_tokens, logits, stacked_kv
463483

464-
return lambda inputs, weights, cfg: stacked_prefill(_numpy_pad_tokens(inputs), weights, cfg)
484+
@wraps(stacked_prefill)
485+
def wrapped(inputs, weights, cfg, *, participate: bool = True):
486+
if participate:
487+
return stacked_prefill(_numpy_pad_tokens(inputs), weights, cfg)
488+
else:
489+
_make_empty_ = partial(_make_empty, mesh=cfg.mesh)
490+
return jax.tree.map(_make_empty_, jax.eval_shape(stacked_prefill, _numpy_pad_tokens(inputs), weights, cfg))
491+
492+
return wrapped
465493

466494

467495
class ServingLoop:
@@ -476,8 +504,8 @@ def __init__(
476504
decode_cache: KVCache,
477505
is_server: bool = False,
478506
):
479-
if not (SyncServer.broadcast("welcome", 0, True, is_server) if jax.process_count() > 1 else is_server):
480-
raise ValueError("No processes registered as the main server, at least one process must.")
507+
if not SyncServer.broadcast("welcome", 0, is_server, is_server):
508+
raise ValueError("Neither this proccess nor any other processe is the main server, at least one must.")
481509
self.serve_cfg, self.cfg = serve_cfg, cfg
482510

483511
# setup decode
@@ -506,7 +534,7 @@ def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens,
506534
self.decode_output = (None, None)
507535

508536
# setup prefill
509-
self.prefill_fn = staticmethod(_make_stacked_prefill(prefill_fn))
537+
self.prefill_fn = _make_stacked_prefill(prefill_fn)
510538
self.prefill_weights = prefill_weights
511539
self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh
512540
self.prefill_work = PrefillWork([], [], [])
@@ -572,19 +600,16 @@ def decode_step(self):
572600

573601
# 2. run N decode steps
574602
output_tokens, output_mapping = [], []
575-
if "decode" in self.roles: # TODO(rdyro): revisit: don't issue the decode call on non-participating machines
576-
with use_mesh(self.decode_mesh):
577-
config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps)
578-
(self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn(
579-
self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config
580-
)
581-
output_mapping = [
582-
[getattr(result, "id", -1) for result in self.decode_work.active_results]
583-
] * self.serve_cfg.decode_steps
584-
output_mapping = np.array(output_mapping).T
585-
print(
586-
f"Decoding with fill rate of {np.mean([result is not None for result in self.decode_work.active_results])}"
603+
with use_mesh(self.decode_mesh):
604+
config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps, participate="decode" in self.roles)
605+
(self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn(
606+
self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config
587607
)
608+
output_mapping = [
609+
[getattr(result, "id", -1) for result in self.decode_work.active_results]
610+
] * self.serve_cfg.decode_steps
611+
output_mapping = np.array(output_mapping).T
612+
print(f"Decoding with fill rate: {np.mean([result is not None for result in self.decode_work.active_results])}")
588613

589614
# 3. parse output tokens from the previous decoding loop to allow time for the tokens to arrive (delayed EOS detection)
590615
self.decode_output, (output_tokens, output_mapping) = (output_tokens, output_mapping), self.decode_output
@@ -607,8 +632,7 @@ def decode_step(self):
607632
(output_tokens_flat, output_mapping_flat, done),
608633
is_source="decode_coordinator" in self.roles,
609634
)
610-
#if "server" in self.roles or "decode_coordinator" in self.roles:
611-
for token, id in zip(output_tokens.reshape(-1).tolist(), output_mapping.reshape(-1).tolist()):
635+
for token, id in zip(output_tokens_flat, output_mapping_flat):
612636
if id > 0:
613637
self.results[id].token_list.append(token)
614638
self.results[id].tokens_decoded += 1
@@ -664,7 +688,9 @@ def prefill_step(self):
664688
self.prefill_work.to_decode.append(new_decode)
665689
print(f"Found a full match")
666690
else:
667-
print(f"Need to prefill the request, only found a match for length {total_match / (len(request.text) - 1)}")
691+
print(
692+
f"Need to prefill the request, only found a match for length {total_match / (len(request.text) - 1)}"
693+
)
668694
self.prefill_work.to_prefill.append(request)
669695

670696
if self.prefill_work.pending_prefill is not None: # a current prefill is still running, skip scheduling another
@@ -684,7 +710,9 @@ def _prefill_job():
684710
inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id)
685711
cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh)
686712
with use_mesh(self.prefill_mesh):
687-
_, _, prefill_results = self.prefill_fn(inputs, self.prefill_weights, cfg)
713+
_, _, prefill_results = self.prefill_fn(
714+
inputs, self.prefill_weights, cfg, participate="prefill" in self.roles
715+
)
688716
prefill_results = jax.block_until_ready(prefill_results)
689717
return prefill_input, prefill_results
690718

@@ -719,10 +747,10 @@ def serving_step(self):
719747
if "server" in self.roles:
720748
with self.state_lock:
721749
self.pending_requests, requests = [], list(self.pending_requests)
750+
serve_cfg, requests = SyncServer.broadcast(
751+
"requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles
752+
)
722753
with self.state_lock:
723-
serve_cfg, requests = SyncServer.broadcast(
724-
"requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles
725-
)
726754
self.serve_cfg = dataclasses.replace(self.serve_cfg, **serve_cfg)
727755
for request in requests:
728756
self.total_requests += 1

serving/serving_jax/cross_host.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
99
import numpy as np
1010

11-
jax.config.update("jax_enable_empty_arrays", True)
11+
#jax.config.update("jax_enable_empty_arrays", True)
1212
PyTree = Any
1313

1414

0 commit comments

Comments
 (0)