88import time
99from typing import AsyncGenerator
1010from contextlib import asynccontextmanager
11- import os
1211from argparse import ArgumentParser
12+ from typing import Any
1313
1414import jax
1515from jax import random
2424import serving_jax as serving
2525from serving_jax import attention_cache_utils
2626
27+ Config = Any
2728
2829TOKENIZER , SERVE_LOOP , SERVING_THREAD , ARGS = None , None , None , None
2930
3031jax .config .update ("jax_explain_cache_misses" , True )
31- jax .config .update ("jax_compilation_cache_dir" , str (Path ("~/.cache/jax" ).expanduser ()))
32- jax .config .update ("jax_enable_empty_arrays" , True )
32+ #jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser()))
3333
3434try : # newer JAX only
35- assert False
36- my_id = int (socket .gethostname ().split ("-" )[- 1 ]) - 1
35+ my_id = int (socket .gethostname ().split ("-" )[- 1 ])
3736 my_ip = socket .getaddrinfo (socket .gethostname (), 80 )[0 ][- 1 ][0 ]
3837 jax .config .update ("jax_cross_host_transfer_socket_address" , f"{ my_ip } :{ 17007 + my_id } " )
3938 jax .config .update ("jax_cross_host_transport_addresses" , "," .join ([f"{ my_ip } :0" ] * 8 ))
@@ -60,39 +59,60 @@ def load_model():
6059
6160 #process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...)
6261 #jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx)
62+ jax .distributed .initialize ()
6363 print (jax .devices ())
6464 print ("-" * 80 )
6565 print (jax .local_devices ())
6666
67- model_name = "Llama-3.1-8B-Instruct"
68- ckpt_path = Path (f"~/{ model_name } " ).expanduser ()
67+ #model_name = "Llama-3.1-8B-Instruct"
68+ #ckpt_path = Path(f"~/{model_name}").expanduser()
69+ #model_name = "Llama-3.1-8B-Instruct-quant"
70+ model_name = "Llama-3.1-70B-Instruct-quant"
71+ ckpt_path = Path (f"~/bucket/llama3_jax_old/{ model_name } " ).expanduser ()
6972 cfg = l3jax .load_config (ckpt_path / "config.json" )
7073 TOKENIZER = l3jax .load_tokenizer (ckpt_path / "tokenizer.json" , ckpt_path / "tokenizer_config.json" )
7174 assert ckpt_path .is_dir ()
7275 print ("---> Model config loaded" )
7376
7477 # two hosts, different device and host meshes
75- local_mesh = jax .make_mesh ((1 , 8 , 1 ), P ("x" , "y" , "z" ), devices = jax .local_devices (), axis_types = (AxisType .Explicit ,) * 3 )
76- decode_mesh , prefill_mesh = local_mesh , local_mesh
78+ #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
79+ #local_mesh = jax.make_mesh((1, 1, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
80+ #decode_mesh, prefill_mesh = local_mesh, local_mesh
81+ decode_mesh = jax .make_mesh ((1 , 8 , 1 ), P ("x" , "y" , "z" ), devices = jax .devices ()[:8 ], axis_types = (AxisType .Explicit ,) * 3 )
82+ prefill_mesh = jax .make_mesh ((1 , 8 , 1 ), P ("x" , "y" , "z" ), devices = jax .devices ()[8 :], axis_types = (AxisType .Explicit ,) * 3 )
83+ #decode_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3)
84+ #prefill_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3)
7785 cfg = dataclasses .replace (cfg , mesh = decode_mesh , quant_layer = True , quant_cache = True )
78- cfg = dataclasses .replace (cfg , use_prefill_attn_kernel = False , use_decode_attn_kernel = False , max_seq_len = 8192 )
79- cfg = dataclasses .replace (cfg , quant_layer = False , quant_cache = False )
80- cfg .quant_cache = True
86+ cfg = dataclasses .replace (cfg , use_prefill_attn_kernel = False , use_decode_attn_kernel = False , max_seq_len = 2048 )
87+ cfg .quant_cache = False
8188
8289 decode_weights = l3jax .load_pytree (ckpt_path , l3jax .Weights .shardings (dataclasses .replace (cfg , mesh = decode_mesh )))
8390 prefill_weights = l3jax .load_pytree (ckpt_path , l3jax .Weights .shardings (dataclasses .replace (cfg , mesh = prefill_mesh )))
8491
8592 print ("---> Weights loaded" )
8693
87- serve_cfg = serving .ServingConfig (decode_steps = 32 , max_decode_length = 64 )
88- #decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size)
89- #decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry
90- #decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache
91- decode_cache = l3jax .PagedKVCache .init (random .key (0 ), cfg , serve_cfg .decode_batch_size , 2048 , 32 )
92- decode_cache .get_sequence = attention_cache_utils .batch_paged_get_entry
93- decode_cache .insert_sequences = attention_cache_utils .batch_paged_update_sequences
94+ serve_cfg = serving .ServingConfig (decode_steps = 32 , max_decode_length = 64 , prefix_chunk_size = 64 )
95+ decode_cache = l3jax .KVCache .init (random .key (0 ), cfg , serve_cfg .decode_batch_size )
96+ decode_cache .get_sequence = attention_cache_utils .kvcache_get_entry
97+ decode_cache .insert_sequences = attention_cache_utils .kvcache_update_cache
98+ #decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32)
99+ #decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry
100+ #decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences
101+
102+ def init_cache (cfg : Config , batch_size : int , actual_len : int ):
103+ cache = l3jax .KVCache .init (random .key (0 ), cfg , batch_size )
104+ cache .get_sequence = attention_cache_utils .kvcache_get_entry
105+ cache .insert_sequences = attention_cache_utils .kvcache_update_cache
106+ cache .iter = actual_len
107+ return cache
108+
109+ with jax .sharding .set_mesh (prefill_mesh ):
110+ prefill_cache = init_cache (dataclasses .replace (cfg , mesh = prefill_mesh ), serve_cfg .prefill_batch_size , 8192 )
111+
112+ forward_fn = l3jax .decode_step # TODO: the model file needs to call it forward explicitly
94113 SERVE_LOOP = serving .ServingLoop (
95- serve_cfg , cfg , l3jax .prefill , prefill_weights , l3jax .decode_step , decode_weights , decode_cache , ARGS .server
114+ #serve_cfg, cfg, init_cache, l3jax.decode_step, prefill_weights, decode_weights, decode_cache, ARGS.server
115+ serve_cfg , cfg , forward_fn , prefill_weights , prefill_cache , decode_weights , decode_cache , ARGS .server
96116 )
97117 print ("---> Created the serving loop" )
98118
0 commit comments