Commit 3dc3d49

Partial prefix match support via chunked prefill
1 parent bfc0665

5 files changed: +233 -164 lines changed

deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py

Lines changed: 19 additions & 6 deletions
@@ -19,13 +19,15 @@
 
 import jax
 from jax.sharding import PartitionSpec as P
+from argparse import ArgumentParser
 
-from deepseek_r1_jax.model import ShardingRules, Config
-from deepseek_r1_jax import chkpt_utils as utils
 
-def main():
-    root_path = Path("/mnt/storage/DeepSeek-R1")
-    dest_path = Path("/mnt/storage/deepseek-r1-jax-chkpt")
+def main(root_path, dest_path):
+    from deepseek_r1_jax.model import ShardingRules, Config
+    from deepseek_r1_jax import chkpt_utils as utils
+
+    root_path, dest_path = Path(root_path), Path(dest_path)
+    dest_path.mkdir(exist_ok=True, parents=True)
 
     cfg = Config()
     cfg.quantize_mlp = False
@@ -39,4 +41,15 @@ def main():
     utils.convert_hf_checkpoint(params_map, root_path, dest_path, cfg)
 
 if __name__ == "__main__":
-    main()
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--source-path", default="/mnt/storage/DeepSeek-R1-weights-only", required=True, help="HF model directory path"
+    )
+    parser.add_argument(
+        "--dest-path",
+        default="~/deepseek_r1_jax",
+        required=True,
+        help="JAX model directory (to be created).",
+    )
+    args = parser.parse_args()
+    main(args.source_path, args.dest_path)
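With the new command-line interface, the conversion script would be invoked along these lines (the flag values shown are simply the parser defaults above; running from the repository root is assumed):

python deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py --source-path /mnt/storage/DeepSeek-R1-weights-only --dest-path ~/deepseek_r1_jax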

llama3/llama3_jax/model.py

Lines changed: 22 additions & 4 deletions
@@ -21,8 +21,8 @@
 import math
 from functools import partial
 from typing import Callable, Any, TypeVar
-from types import ModuleType
 from inspect import signature
+from collections import OrderedDict as odict
 
 import jax
 import jax.numpy as jnp
@@ -31,7 +31,8 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
 from jax.experimental.shard_map import shard_map
-from jax.sharding import PartitionSpec as P, use_mesh
+from jax.sharding import PartitionSpec as P
+from jax.experimental.array_serialization import pytree_serialization as ser
 try:
     from jax.experimental.shard import auto_axes as _auto_axes, reshard
 except ModuleNotFoundError:
@@ -213,6 +214,7 @@ class ArrayInfo:
 # module reload friendly isinstance check
 is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__)
 is_param = lambda x: is_type(x, ArrayInfo)
+which_platform = lambda cfg: cfg.mesh.devices.reshape(-1)[0].platform
 _count_left_padding = lambda ids, pad_id=0: auto_axes(
     lambda ids: jnp.sum(jnp.cumsum(ids != pad_id, axis=-1) == 0, axis=-1), out_sharding=P(None)
 )(ids)
@@ -404,15 +406,18 @@ def abstract(cls, cfg: Config):
 )
 
 
-@partial(jax_pytree_struct, meta_fields=("batch_size", "size", "time_axis"))
+@partial(jax_pytree_struct, meta_fields=("batch_size", "size", "time_axis", "insert_sequences"))
 class KVCache(_Init):
     k: list[tuple[jax.Array | QuantArray, ...]]  # (batch_size, key_heads, max_seq_len, head_dim)
     v: list[tuple[jax.Array | QuantArray, ...]]  # (batch_size, key_heads, max_seq_len, head_dim)
     iter: jax.Array  # []  # sequences are right-aligned for slice update performance
     starts: jax.Array  # [batch_size]  # sequences are right-aligned, we need start indices
-    batch_size: int = 0
+    batch_size: int = 1
     size: int = 2 ** 30
     time_axis: int = 2
+    #update_slice: Callable = None
+    insert_sequences: Callable = None
+    #get_sequence: Callable = None
 
     @classmethod
     def abstract(cls, cfg: Config, batch_size: int):
@@ -798,6 +803,8 @@ def _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale
 
 
 def paged_attention_kernel(q, k, v, block_tables, lengths, cfg: Config):
+    if which_platform(cfg) not in ("gpu", "cuda"):
+        raise ValueError("Paged attention is only supported on GPU.")
     k, k_scale = (k.quant, k.scale) if is_type(k, QuantArray) else (k, None)
     v, v_scale = (v.quant, v.scale) if is_type(v, QuantArray) else (v, None)
 
@@ -1030,6 +1037,17 @@ def prepare_chunk(chunk, pad_to: int, pad_id: int):
     return chunk, segment_ids
 
 
+## serialization
+#def save_pytree(data, path):
+#    flat_data = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(data)[0])
+#    ser.save(flat_data, path)  # save a flatten with path to avoid custom
+#
+#
+#def load_pytree(path, sharding=None):
+#    flat_sharding = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(sharding)[0])
+#    return jax.tree.unflatten(jax.tree.structure(sharding), jax.tree.leaves(ser.load(path, flat_sharding)))
+
+
 def prefill(tokens: jax.Array, weights: Weights, cache: KVCache | None, cfg: Config, pad_id: int = 0):
     """Samples from a prompt."""
     # Calculate the next power of 2 for padding, up to cfg.max_seq.
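A note on the KVCache change above: insert_sequences is declared both as a field and in the decorator's meta_fields, so it travels as static pytree metadata rather than as a traced array leaf. The snippet below is a self-contained illustration of that behavior using JAX's public register_dataclass; it is not code from this repo, TinyCache is a made-up stand-in for KVCache, and a recent JAX version is assumed.

import dataclasses
from typing import Callable, Optional

import jax
import jax.numpy as jnp

@dataclasses.dataclass
class TinyCache:
    k: jax.Array                                   # array leaf, participates in jax.tree.map / jit
    insert_sequences: Optional[Callable] = None    # static metadata, like KVCache.insert_sequences

jax.tree_util.register_dataclass(TinyCache, data_fields=["k"], meta_fields=["insert_sequences"])

cache = TinyCache(k=jnp.zeros((2, 4)), insert_sequences=lambda c, seq: c)
doubled = jax.tree.map(lambda x: x * 2, cache)  # only the array field k is mapped over
print(doubled.insert_sequences is cache.insert_sequences)  # True: the callable passes through untouched

Because the callable is metadata, transformations over the cache arrays leave it alone, and the serving code below simply assigns it after construction (cache.insert_sequences = attention_cache_utils.kvcache_update_cache).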

serving/main_serving.py

Lines changed: 40 additions & 20 deletions
@@ -8,8 +8,8 @@
 import time
 from typing import AsyncGenerator
 from contextlib import asynccontextmanager
-import os
 from argparse import ArgumentParser
+from typing import Any
 
 import jax
 from jax import random
@@ -24,16 +24,15 @@
 import serving_jax as serving
 from serving_jax import attention_cache_utils
 
+Config = Any
 
 TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None
 
 jax.config.update("jax_explain_cache_misses", True)
-jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser()))
-jax.config.update("jax_enable_empty_arrays", True)
+#jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser()))
 
 try: # newer JAX only
-    assert False
-    my_id = int(socket.gethostname().split("-")[-1]) - 1
+    my_id = int(socket.gethostname().split("-")[-1])
     my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0]
     jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}")
     jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8))
@@ -60,39 +59,60 @@ def load_model():
 
     #process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...)
     #jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx)
+    jax.distributed.initialize()
     print(jax.devices())
     print("-" * 80)
     print(jax.local_devices())
 
-    model_name = "Llama-3.1-8B-Instruct"
-    ckpt_path = Path(f"~/{model_name}").expanduser()
+    #model_name = "Llama-3.1-8B-Instruct"
+    #ckpt_path = Path(f"~/{model_name}").expanduser()
+    #model_name = "Llama-3.1-8B-Instruct-quant"
+    model_name = "Llama-3.1-70B-Instruct-quant"
+    ckpt_path = Path(f"~/bucket/llama3_jax_old/{model_name}").expanduser()
     cfg = l3jax.load_config(ckpt_path / "config.json")
     TOKENIZER = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
     assert ckpt_path.is_dir()
     print("---> Model config loaded")
 
     # two hosts, different device and host meshes
-    local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
-    decode_mesh, prefill_mesh = local_mesh, local_mesh
+    #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
+    #local_mesh = jax.make_mesh((1, 1, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
+    #decode_mesh, prefill_mesh = local_mesh, local_mesh
+    decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3)
+    prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3)
+    #decode_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3)
+    #prefill_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3)
     cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True)
-    cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=8192)
-    cfg = dataclasses.replace(cfg, quant_layer=False, quant_cache=False)
-    cfg.quant_cache = True
+    cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=2048)
+    cfg.quant_cache = False
 
     decode_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh)))
     prefill_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh)))
 
     print("---> Weights loaded")
 
-    serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64)
-    #decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size)
-    #decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry
-    #decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache
-    decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32)
-    decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry
-    decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences
+    serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64, prefix_chunk_size=64)
+    decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size)
+    decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry
+    decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache
+    #decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32)
+    #decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry
+    #decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences
+
+    def init_cache(cfg: Config, batch_size: int, actual_len: int):
+        cache = l3jax.KVCache.init(random.key(0), cfg, batch_size)
+        cache.get_sequence = attention_cache_utils.kvcache_get_entry
+        cache.insert_sequences = attention_cache_utils.kvcache_update_cache
+        cache.iter = actual_len
+        return cache
+
+    with jax.sharding.set_mesh(prefill_mesh):
+        prefill_cache = init_cache(dataclasses.replace(cfg, mesh=prefill_mesh), serve_cfg.prefill_batch_size, 8192)
+
+    forward_fn = l3jax.decode_step  # TODO: the model file needs to call it forward explicitly
     SERVE_LOOP = serving.ServingLoop(
-        serve_cfg, cfg, l3jax.prefill, prefill_weights, l3jax.decode_step, decode_weights, decode_cache, ARGS.server
+        #serve_cfg, cfg, init_cache, l3jax.decode_step, prefill_weights, decode_weights, decode_cache, ARGS.server
+        serve_cfg, cfg, forward_fn, prefill_weights, prefill_cache, decode_weights, decode_cache, ARGS.server
     )
     print("---> Created the serving loop")
 
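The commit title, "Partial prefix match support via chunked prefill", points at reusing previously computed KV cache at chunk granularity: the chunks of a new prompt that exactly match an earlier prompt keep their cached keys and values, and only the remainder is prefilled. The sketch below is purely illustrative of that idea and is not the repo's implementation; chunk_size stands in for the new prefix_chunk_size serving option.

import numpy as np

def matching_prefix_chunks(cached_ids: np.ndarray, new_ids: np.ndarray, chunk_size: int) -> int:
    """Number of leading chunk_size-token chunks shared by a cached prompt and a new prompt."""
    n = min(cached_ids.size, new_ids.size) // chunk_size * chunk_size  # compare whole chunks only
    same = cached_ids[:n] == new_ids[:n]
    matched = n if same.all() else int(np.argmin(same))  # length of the exactly matching prefix
    return matched // chunk_size

cached = np.array([1, 2, 3, 4, 5, 6, 7, 8])
new = np.array([1, 2, 3, 4, 9, 9, 9, 9, 10])
k = matching_prefix_chunks(cached, new, chunk_size=4)
print(k)  # 1: only the first 4-token chunk matches, so its cache entries could be reused
# only new[k * chunk_size:] would still need to go through (chunked) prefill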